use a buffered writer for the http3 request writer

This commit is contained in:
Marten Seemann 2020-03-27 09:53:16 +07:00
parent 683230372e
commit bcffb77ad4
3 changed files with 50 additions and 42 deletions

View file

@ -222,7 +222,7 @@ var _ = Describe("Client", func() {
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return rspBuf.Read(p)
@ -458,7 +458,7 @@ var _ = Describe("Client", func() {
gz.Write([]byte("gzipped response"))
gz.Close()
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p)
}).AnyTimes()
@ -480,7 +480,7 @@ var _ = Describe("Client", func() {
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p)
}).AnyTimes()

View file

@ -1,6 +1,7 @@
package http3
import (
"bufio"
"bytes"
"fmt"
"io"
@ -18,6 +19,8 @@ import (
"golang.org/x/net/idna"
)
const bodyCopyBufferSize = 8 * 1024
type requestWriter struct {
mutex sync.Mutex
encoder *qpack.Encoder
@ -37,22 +40,47 @@ func newRequestWriter(logger utils.Logger) *requestWriter {
}
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
headers, err := w.getHeaders(req, gzip)
if err != nil {
return err
}
if _, err := str.Write(headers); err != nil {
wr := bufio.NewWriter(str)
if err := w.writeHeaders(wr, req, gzip); err != nil {
return err
}
// TODO: add support for trailers
if req.Body == nil {
if err := wr.Flush(); err != nil {
return err
}
str.Close()
return nil
}
// send the request body asynchronously
go func() {
if err := w.sendRequestBody(req.Body, str); err != nil {
defer req.Body.Close()
b := make([]byte, bodyCopyBufferSize)
for {
n, err := req.Body.Read(b)
if err == io.EOF {
break
}
if err != nil {
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
w.logger.Errorf("Error writing request: %s", err)
return
}
buf := &bytes.Buffer{}
(&dataFrame{Length: uint64(n)}).Write(buf)
if _, err := wr.Write(buf.Bytes()); err != nil {
w.logger.Errorf("Error writing request: %s", err)
return
}
if _, err := wr.Write(b[:n]); err != nil {
w.logger.Errorf("Error writing request: %s", err)
return
}
}
if err := wr.Flush(); err != nil {
fmt.Println(err)
w.logger.Errorf("Error writing request: %s", err)
return
}
@ -62,46 +90,25 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
return nil
}
func (w *requestWriter) getHeaders(req *http.Request, gzip bool) ([]byte, error) {
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
w.mutex.Lock()
defer w.mutex.Unlock()
defer w.encoder.Close()
if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
return nil, err
return err
}
buf := &bytes.Buffer{}
hf := headersFrame{Length: uint64(w.headerBuf.Len())}
hf.Write(buf)
if _, err := io.Copy(buf, w.headerBuf); err != nil {
return nil, err
if _, err := wr.Write(buf.Bytes()); err != nil {
return err
}
if _, err := wr.Write(w.headerBuf.Bytes()); err != nil {
return err
}
w.headerBuf.Reset()
return buf.Bytes(), nil
}
func (w *requestWriter) sendRequestBody(req io.ReadCloser, str quic.Stream) error {
b := make([]byte, 8*1024)
for {
n, err := req.Read(b)
if err == io.EOF {
break
}
if err != nil {
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
return err
}
buf := &bytes.Buffer{}
(&dataFrame{Length: uint64(n)}).Write(buf)
if _, err := str.Write(buf.Bytes()); err != nil {
return err
}
if _, err := str.Write(b[:n]); err != nil {
return err
}
}
req.Close()
return nil
}

View file

@ -25,15 +25,15 @@ var _ = Describe("Request Writer", func() {
decode := func(str io.Reader) map[string]string {
frame, err := parseNextFrame(str)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame)
data := make([]byte, headersFrame.Length)
_, err = io.ReadFull(str, data)
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, err).ToNot(HaveOccurred())
decoder := qpack.NewDecoder(nil)
hfs, err := decoder.DecodeFull(data)
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, err).ToNot(HaveOccurred())
values := make(map[string]string)
for _, hf := range hfs {
values[hf.Name] = hf.Value
@ -70,6 +70,8 @@ var _ = Describe("Request Writer", func() {
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
Eventually(closed).Should(BeClosed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
Expect(headerFields).To(HaveKey("content-length"))
@ -77,7 +79,6 @@ var _ = Describe("Request Writer", func() {
Expect(err).ToNot(HaveOccurred())
Expect(contentLength).To(BeNumerically(">", 0))
Eventually(closed).Should(BeClosed())
frame, err := parseNextFrame(strBuf)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))