use a buffered writer for the http3 request writer

This commit is contained in:
Marten Seemann 2020-03-27 09:53:16 +07:00
parent 683230372e
commit bcffb77ad4
3 changed files with 50 additions and 42 deletions

View file

@ -222,7 +222,7 @@ var _ = Describe("Client", func() {
sess.EXPECT().HandshakeComplete().Return(handshakeCtx), sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
) )
str.EXPECT().Write(gomock.Any()).AnyTimes() str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close() str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return rspBuf.Read(p) return rspBuf.Read(p)
@ -458,7 +458,7 @@ var _ = Describe("Client", func() {
gz.Write([]byte("gzipped response")) gz.Write([]byte("gzipped response"))
gz.Close() gz.Close()
rw.Flush() rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes() str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p) return buf.Read(p)
}).AnyTimes() }).AnyTimes()
@ -480,7 +480,7 @@ var _ = Describe("Client", func() {
rw := newResponseWriter(buf, utils.DefaultLogger) rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Write([]byte("not gzipped")) rw.Write([]byte("not gzipped"))
rw.Flush() rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes() str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p) return buf.Read(p)
}).AnyTimes() }).AnyTimes()

View file

@ -1,6 +1,7 @@
package http3 package http3
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
@ -18,6 +19,8 @@ import (
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
const bodyCopyBufferSize = 8 * 1024
type requestWriter struct { type requestWriter struct {
mutex sync.Mutex mutex sync.Mutex
encoder *qpack.Encoder encoder *qpack.Encoder
@ -37,22 +40,47 @@ func newRequestWriter(logger utils.Logger) *requestWriter {
} }
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error { func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
headers, err := w.getHeaders(req, gzip) wr := bufio.NewWriter(str)
if err != nil {
return err if err := w.writeHeaders(wr, req, gzip); err != nil {
}
if _, err := str.Write(headers); err != nil {
return err return err
} }
// TODO: add support for trailers // TODO: add support for trailers
if req.Body == nil { if req.Body == nil {
if err := wr.Flush(); err != nil {
return err
}
str.Close() str.Close()
return nil return nil
} }
// send the request body asynchronously // send the request body asynchronously
go func() { go func() {
if err := w.sendRequestBody(req.Body, str); err != nil { defer req.Body.Close()
b := make([]byte, bodyCopyBufferSize)
for {
n, err := req.Body.Read(b)
if err == io.EOF {
break
}
if err != nil {
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
w.logger.Errorf("Error writing request: %s", err)
return
}
buf := &bytes.Buffer{}
(&dataFrame{Length: uint64(n)}).Write(buf)
if _, err := wr.Write(buf.Bytes()); err != nil {
w.logger.Errorf("Error writing request: %s", err)
return
}
if _, err := wr.Write(b[:n]); err != nil {
w.logger.Errorf("Error writing request: %s", err)
return
}
}
if err := wr.Flush(); err != nil {
fmt.Println(err)
w.logger.Errorf("Error writing request: %s", err) w.logger.Errorf("Error writing request: %s", err)
return return
} }
@ -62,46 +90,25 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
return nil return nil
} }
func (w *requestWriter) getHeaders(req *http.Request, gzip bool) ([]byte, error) { func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
w.mutex.Lock() w.mutex.Lock()
defer w.mutex.Unlock() defer w.mutex.Unlock()
defer w.encoder.Close() defer w.encoder.Close()
if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil { if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
return nil, err return err
} }
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
hf := headersFrame{Length: uint64(w.headerBuf.Len())} hf := headersFrame{Length: uint64(w.headerBuf.Len())}
hf.Write(buf) hf.Write(buf)
if _, err := io.Copy(buf, w.headerBuf); err != nil { if _, err := wr.Write(buf.Bytes()); err != nil {
return nil, err return err
}
if _, err := wr.Write(w.headerBuf.Bytes()); err != nil {
return err
} }
w.headerBuf.Reset() w.headerBuf.Reset()
return buf.Bytes(), nil
}
func (w *requestWriter) sendRequestBody(req io.ReadCloser, str quic.Stream) error {
b := make([]byte, 8*1024)
for {
n, err := req.Read(b)
if err == io.EOF {
break
}
if err != nil {
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
return err
}
buf := &bytes.Buffer{}
(&dataFrame{Length: uint64(n)}).Write(buf)
if _, err := str.Write(buf.Bytes()); err != nil {
return err
}
if _, err := str.Write(b[:n]); err != nil {
return err
}
}
req.Close()
return nil return nil
} }

View file

@ -25,15 +25,15 @@ var _ = Describe("Request Writer", func() {
decode := func(str io.Reader) map[string]string { decode := func(str io.Reader) map[string]string {
frame, err := parseNextFrame(str) frame, err := parseNextFrame(str)
Expect(err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame) headersFrame := frame.(*headersFrame)
data := make([]byte, headersFrame.Length) data := make([]byte, headersFrame.Length)
_, err = io.ReadFull(str, data) _, err = io.ReadFull(str, data)
Expect(err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
decoder := qpack.NewDecoder(nil) decoder := qpack.NewDecoder(nil)
hfs, err := decoder.DecodeFull(data) hfs, err := decoder.DecodeFull(data)
Expect(err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
values := make(map[string]string) values := make(map[string]string)
for _, hf := range hfs { for _, hf := range hfs {
values[hf.Name] = hf.Value values[hf.Name] = hf.Value
@ -70,6 +70,8 @@ var _ = Describe("Request Writer", func() {
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData) req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false)).To(Succeed()) Expect(rw.WriteRequest(str, req, false)).To(Succeed())
Eventually(closed).Should(BeClosed())
headerFields := decode(strBuf) headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
Expect(headerFields).To(HaveKey("content-length")) Expect(headerFields).To(HaveKey("content-length"))
@ -77,7 +79,6 @@ var _ = Describe("Request Writer", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(contentLength).To(BeNumerically(">", 0)) Expect(contentLength).To(BeNumerically(">", 0))
Eventually(closed).Should(BeClosed())
frame, err := parseNextFrame(strBuf) frame, err := parseNextFrame(strBuf)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))