From 04d46526c7097e37b1edb988b0c38c159f013604 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 29 May 2022 19:22:05 +0200 Subject: [PATCH] refactor HTTP/3 stream handling to use a dedicated stream Reading from and writing onto this stream applies HTTP/3 DATA framing. --- http3/body.go | 91 +++++---------- http3/body_test.go | 207 ++++++----------------------------- http3/client.go | 50 ++++++++- http3/client_test.go | 1 + http3/http_stream.go | 71 ++++++++++++ http3/http_stream_test.go | 150 +++++++++++++++++++++++++ http3/request_writer.go | 54 +-------- http3/request_writer_test.go | 73 +----------- http3/server.go | 2 +- http3/server_test.go | 5 +- 10 files changed, 344 insertions(+), 360 deletions(-) create mode 100644 http3/http_stream.go create mode 100644 http3/http_stream_test.go diff --git a/http3/body.go b/http3/body.go index 23d4cf55..b2e2a933 100644 --- a/http3/body.go +++ b/http3/body.go @@ -2,7 +2,6 @@ package http3 import ( "context" - "fmt" "io" "net" @@ -29,42 +28,43 @@ type Hijacker interface { // The body of a http.Request or http.Response. type body struct { str quic.Stream +} + +var _ io.ReadCloser = &body{} + +func newRequestBody(str Stream) *body { + return &body{str: str} +} + +func (r *body) Read(b []byte) (int, error) { + return r.str.Read(b) +} + +func (r *body) Close() error { + r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + return nil +} + +type hijackableBody struct { + body + conn quic.Connection // only needed to implement Hijacker // only set for the http.Response // The channel is closed when the user is done with this response: // either when Read() errors, or when Close() is called. reqDone chan<- struct{} reqDoneClosed bool - - onFrameError func() - - bytesRemainingInFrame uint64 -} - -var _ io.ReadCloser = &body{} - -type hijackableBody struct { - body - conn quic.Connection // only needed to implement Hijacker } var _ Hijacker = &hijackableBody{} -func newRequestBody(str quic.Stream, onFrameError func()) *body { - return &body{ - str: str, - onFrameError: onFrameError, - } -} - -func newResponseBody(str quic.Stream, conn quic.Connection, done chan<- struct{}, onFrameError func()) *hijackableBody { +func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody { return &hijackableBody{ body: body{ - str: str, - onFrameError: onFrameError, - reqDone: done, + str: str, }, - conn: conn, + reqDone: done, + conn: conn, } } @@ -72,50 +72,15 @@ func (r *hijackableBody) StreamCreator() StreamCreator { return r.conn } -func (r *body) Read(b []byte) (int, error) { - n, err := r.readImpl(b) +func (r *hijackableBody) Read(b []byte) (int, error) { + n, err := r.str.Read(b) if err != nil { r.requestDone() } return n, err } -func (r *body) readImpl(b []byte) (int, error) { - if r.bytesRemainingInFrame == 0 { - parseLoop: - for { - frame, err := parseNextFrame(r.str, nil) - if err != nil { - return 0, err - } - switch f := frame.(type) { - case *headersFrame: - // skip HEADERS frames - continue - case *dataFrame: - r.bytesRemainingInFrame = f.Length - break parseLoop - default: - r.onFrameError() - // parseNextFrame skips over unknown frame types - // Therefore, this condition is only entered when we parsed another known frame type. - return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) - } - } - } - - var n int - var err error - if r.bytesRemainingInFrame < uint64(len(b)) { - n, err = r.str.Read(b[:r.bytesRemainingInFrame]) - } else { - n, err = r.str.Read(b) - } - r.bytesRemainingInFrame -= uint64(n) - return n, err -} - -func (r *body) requestDone() { +func (r *hijackableBody) requestDone() { if r.reqDoneClosed || r.reqDone == nil { return } @@ -127,7 +92,7 @@ func (r *body) StreamID() quic.StreamID { return r.str.StreamID() } -func (r *body) Close() error { +func (r *hijackableBody) Close() error { r.requestDone() // If the EOF was read, CancelRead() is a no-op. r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) diff --git a/http3/body_test.go b/http3/body_test.go index f50004dc..4920357d 100644 --- a/http3/body_test.go +++ b/http3/body_test.go @@ -1,189 +1,54 @@ package http3 import ( - "bytes" - "fmt" - "io" + "errors" - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go" mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -type bodyType uint8 +var _ = Describe("Response Body", func() { + var reqDone chan struct{} -const ( - bodyTypeRequest bodyType = iota - bodyTypeResponse -) + BeforeEach(func() { reqDone = make(chan struct{}) }) -func (t bodyType) String() string { - if t == bodyTypeRequest { - return "request" - } - return "response" -} - -var _ = Describe("Body", func() { - var ( - rb io.ReadCloser - str *mockquic.MockStream - buf *bytes.Buffer - reqDone chan struct{} - errorCbCalled bool - ) - - errorCb := func() { errorCbCalled = true } - - getDataFrame := func(data []byte) []byte { - b := &bytes.Buffer{} - (&dataFrame{Length: uint64(len(data))}).Write(b) - b.Write(data) - return b.Bytes() - } - - BeforeEach(func() { - buf = &bytes.Buffer{} - errorCbCalled = false + It("closes the reqDone channel when Read errors", func() { + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")) + rb := newResponseBody(str, nil, reqDone) + _, err := rb.Read([]byte{0}) + Expect(err).To(MatchError("test error")) + Expect(reqDone).To(BeClosed()) }) - for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} { - bodyType := bt + It("allows multiple calls to Read, when Read errors", func() { + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")).Times(2) + rb := newResponseBody(str, nil, reqDone) + _, err := rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(reqDone).To(BeClosed()) + _, err = rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + }) - Context(fmt.Sprintf("using a %s body", bodyType), func() { - BeforeEach(func() { - str = mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { - return buf.Write(b) - }).AnyTimes() - str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { - return buf.Read(b) - }).AnyTimes() + It("closes responses", func() { + str := mockquic.NewMockStream(mockCtrl) + rb := newResponseBody(str, nil, reqDone) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + Expect(rb.Close()).To(Succeed()) + }) - switch bodyType { - case bodyTypeRequest: - rb = newRequestBody(str, errorCb) - case bodyTypeResponse: - reqDone = make(chan struct{}) - rb = newResponseBody(str, nil, reqDone, errorCb) - } - }) - - It("reads DATA frames in a single run", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 6) - n, err := rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b).To(Equal([]byte("foobar"))) - }) - - It("reads DATA frames in multiple runs", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 3) - n, err := rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b).To(Equal([]byte("foo"))) - n, err = rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b).To(Equal([]byte("bar"))) - }) - - It("reads DATA frames into too large buffers", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 10) - n, err := rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b[:n]).To(Equal([]byte("foobar"))) - }) - - It("reads DATA frames into too large buffers, in multiple runs", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 4) - n, err := rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte("foob"))) - n, err = rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - Expect(b[:n]).To(Equal([]byte("ar"))) - }) - - It("reads multiple DATA frames", func() { - buf.Write(getDataFrame([]byte("foo"))) - buf.Write(getDataFrame([]byte("bar"))) - b := make([]byte, 6) - n, err := rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b[:n]).To(Equal([]byte("foo"))) - n, err = rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b[:n]).To(Equal([]byte("bar"))) - }) - - It("skips HEADERS frames", func() { - buf.Write(getDataFrame([]byte("foo"))) - (&headersFrame{Length: 10}).Write(buf) - buf.Write(make([]byte, 10)) - buf.Write(getDataFrame([]byte("bar"))) - b := make([]byte, 6) - n, err := io.ReadFull(rb, b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b).To(Equal([]byte("foobar"))) - }) - - It("errors when it can't parse the frame", func() { - buf.Write([]byte("invalid")) - _, err := rb.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - }) - - It("errors on unexpected frames, and calls the error callback", func() { - (&settingsFrame{}).Write(buf) - _, err := rb.Read([]byte{0}) - Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame")) - Expect(errorCbCalled).To(BeTrue()) - }) - - if bodyType == bodyTypeResponse { - It("closes the reqDone channel when Read errors", func() { - buf.Write([]byte("invalid")) - _, err := rb.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - Expect(reqDone).To(BeClosed()) - }) - - It("allows multiple calls to Read, when Read errors", func() { - buf.Write([]byte("invalid")) - _, err := rb.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - Expect(reqDone).To(BeClosed()) - _, err = rb.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - }) - - It("closes responses", func() { - str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)) - Expect(rb.Close()).To(Succeed()) - }) - - It("allows multiple calls to Close", func() { - str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2) - Expect(rb.Close()).To(Succeed()) - Expect(reqDone).To(BeClosed()) - Expect(rb.Close()).To(Succeed()) - }) - } - }) - } + It("allows multiple calls to Close", func() { + str := mockquic.NewMockStream(mockCtrl) + rb := newResponseBody(str, nil, reqDone) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2) + Expect(rb.Close()).To(Succeed()) + Expect(reqDone).To(BeClosed()) + Expect(rb.Close()).To(Succeed()) + }) }) diff --git a/http3/client.go b/http3/client.go index 325fd4d4..c56a8a35 100644 --- a/http3/client.go +++ b/http3/client.go @@ -298,15 +298,59 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon return rsp, rerr.err } +func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error { + defer body.Close() + b := make([]byte, bodyCopyBufferSize) + for { + n, rerr := body.Read(b) + if n == 0 { + if rerr == nil { + continue + } + if rerr == io.EOF { + break + } + } + if _, err := str.Write(b[:n]); err != nil { + return err + } + if rerr != nil { + if rerr == io.EOF { + break + } + str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) + return rerr + } + } + return nil +} + func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) { var requestGzip bool if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { requestGzip = true } - if err := c.requestWriter.WriteRequest(str, req, opt.DontCloseRequestStream, requestGzip); err != nil { + if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil { return nil, newStreamError(errorInternalError, err) } + if req.Body == nil && !opt.DontCloseRequestStream { + str.Close() + } + + hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) + if req.Body != nil { + // send the request body asynchronously + go func() { + if err := c.sendRequestBody(hstr, req.Body); err != nil { + c.logger.Errorf("Error writing request: %s", err) + } + if !opt.DontCloseRequestStream { + hstr.Close() + } + }() + } + frame, err := parseNextFrame(str, nil) if err != nil { return nil, newStreamError(errorFrameError, err) @@ -348,9 +392,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, res.Header.Add(hf.Name, hf.Value) } } - respBody := newResponseBody(str, c.conn, reqDone, func() { - c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") - }) + respBody := newResponseBody(hstr, c.conn, reqDone) // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. _, hasTransferEncoding := res.Header["Transfer-Encoding"] diff --git a/http3/client_test.go b/http3/client_test.go index 9be1c684..f512fd41 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -797,6 +797,7 @@ var _ = Describe("Client", func() { <-done return 0, errors.New("test done") }) + str.EXPECT().Close() _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) }) diff --git a/http3/http_stream.go b/http3/http_stream.go new file mode 100644 index 00000000..4c69068c --- /dev/null +++ b/http3/http_stream.go @@ -0,0 +1,71 @@ +package http3 + +import ( + "bytes" + "fmt" + + "github.com/lucas-clemente/quic-go" +) + +// A Stream is a HTTP/3 stream. +// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. +type Stream quic.Stream + +// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly +// from the QUIC stream, it writes to and reads from the HTTP stream. +type stream struct { + quic.Stream + + onFrameError func() + bytesRemainingInFrame uint64 +} + +var _ Stream = &stream{} + +func newStream(str quic.Stream, onFrameError func()) *stream { + return &stream{Stream: str, onFrameError: onFrameError} +} + +func (s *stream) Read(b []byte) (int, error) { + if s.bytesRemainingInFrame == 0 { + parseLoop: + for { + frame, err := parseNextFrame(s.Stream, nil) + if err != nil { + return 0, err + } + switch f := frame.(type) { + case *headersFrame: + // skip HEADERS frames + continue + case *dataFrame: + s.bytesRemainingInFrame = f.Length + break parseLoop + default: + s.onFrameError() + // parseNextFrame skips over unknown frame types + // Therefore, this condition is only entered when we parsed another known frame type. + return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) + } + } + } + + var n int + var err error + if s.bytesRemainingInFrame < uint64(len(b)) { + n, err = s.Stream.Read(b[:s.bytesRemainingInFrame]) + } else { + n, err = s.Stream.Read(b) + } + s.bytesRemainingInFrame -= uint64(n) + return n, err +} + +func (s *stream) Write(b []byte) (int, error) { + buf := &bytes.Buffer{} + (&dataFrame{Length: uint64(len(b))}).Write(buf) + if _, err := s.Stream.Write(buf.Bytes()); err != nil { + return 0, err + } + return s.Stream.Write(b) +} diff --git a/http3/http_stream_test.go b/http3/http_stream_test.go new file mode 100644 index 00000000..ad9833b9 --- /dev/null +++ b/http3/http_stream_test.go @@ -0,0 +1,150 @@ +package http3 + +import ( + "bytes" + "io" + + mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stream", func() { + Context("reading", func() { + var ( + str Stream + qstr *mockquic.MockStream + buf *bytes.Buffer + errorCbCalled bool + ) + + errorCb := func() { errorCbCalled = true } + getDataFrame := func(data []byte) []byte { + b := &bytes.Buffer{} + (&dataFrame{Length: uint64(len(data))}).Write(b) + b.Write(data) + return b.Bytes() + } + + BeforeEach(func() { + buf = &bytes.Buffer{} + errorCbCalled = false + qstr = mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() + qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + str = newStream(qstr, errorCb) + }) + + It("reads DATA frames in a single run", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 6) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b).To(Equal([]byte("foobar"))) + }) + + It("reads DATA frames in multiple runs", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 3) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b).To(Equal([]byte("foo"))) + n, err = str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b).To(Equal([]byte("bar"))) + }) + + It("reads DATA frames into too large buffers", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 10) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b[:n]).To(Equal([]byte("foobar"))) + }) + + It("reads DATA frames into too large buffers, in multiple runs", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 4) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte("foob"))) + n, err = str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(2)) + Expect(b[:n]).To(Equal([]byte("ar"))) + }) + + It("reads multiple DATA frames", func() { + buf.Write(getDataFrame([]byte("foo"))) + buf.Write(getDataFrame([]byte("bar"))) + b := make([]byte, 6) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b[:n]).To(Equal([]byte("foo"))) + n, err = str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b[:n]).To(Equal([]byte("bar"))) + }) + + It("skips HEADERS frames", func() { + buf.Write(getDataFrame([]byte("foo"))) + (&headersFrame{Length: 10}).Write(buf) + buf.Write(make([]byte, 10)) + buf.Write(getDataFrame([]byte("bar"))) + b := make([]byte, 6) + n, err := io.ReadFull(str, b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b).To(Equal([]byte("foobar"))) + }) + + It("errors when it can't parse the frame", func() { + buf.Write([]byte("invalid")) + _, err := str.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + }) + + It("errors on unexpected frames, and calls the error callback", func() { + (&settingsFrame{}).Write(buf) + _, err := str.Read([]byte{0}) + Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame")) + Expect(errorCbCalled).To(BeTrue()) + }) + }) + + Context("writing", func() { + It("writes data frames", func() { + buf := &bytes.Buffer{} + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() + str := newStream(qstr, nil) + str.Write([]byte("foo")) + str.Write([]byte("foobar")) + + f, err := parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(Equal(&dataFrame{Length: 3})) + b := make([]byte, 3) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte("foo"))) + + f, err = parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(Equal(&dataFrame{Length: 6})) + b = make([]byte, 6) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte("foobar"))) + }) + }) +}) diff --git a/http3/request_writer.go b/http3/request_writer.go index bd36c7dd..0a9c67ac 100644 --- a/http3/request_writer.go +++ b/http3/request_writer.go @@ -38,60 +38,14 @@ func newRequestWriter(logger utils.Logger) *requestWriter { } } -func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, dontCloseStr, gzip bool) error { +func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error { + // TODO: figure out how to add support for trailers buf := &bytes.Buffer{} if err := w.writeHeaders(buf, req, gzip); err != nil { return err } - if _, err := str.Write(buf.Bytes()); err != nil { - return err - } - // TODO: add support for trailers - if req.Body == nil { - if !dontCloseStr { - str.Close() - } - return nil - } - - // send the request body asynchronously - go func() { - defer req.Body.Close() - b := make([]byte, bodyCopyBufferSize) - for { - n, rerr := req.Body.Read(b) - if n == 0 { - if rerr == nil { - continue - } else if rerr == io.EOF { - break - } - } - buf := &bytes.Buffer{} - (&dataFrame{Length: uint64(n)}).Write(buf) - if _, err := str.Write(buf.Bytes()); err != nil { - w.logger.Errorf("Error writing request: %s", err) - return - } - if _, err := str.Write(b[:n]); err != nil { - w.logger.Errorf("Error writing request: %s", err) - return - } - if rerr != nil { - if rerr == io.EOF { - break - } - str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) - w.logger.Errorf("Error writing request: %s", rerr) - return - } - } - if !dontCloseStr { - str.Close() - } - }() - - return nil + _, err := str.Write(buf.Bytes()) + return err } func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error { diff --git a/http3/request_writer_test.go b/http3/request_writer_test.go index e2c80cdc..1e5a1614 100644 --- a/http3/request_writer_test.go +++ b/http3/request_writer_test.go @@ -4,7 +4,6 @@ import ( "bytes" "io" "net/http" - "strconv" mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" "github.com/lucas-clemente/quic-go/internal/utils" @@ -16,12 +15,6 @@ import ( . "github.com/onsi/gomega" ) -type foobarReader struct{} - -func (r *foobarReader) Read(b []byte) (int, error) { - return copy(b, []byte("foobar")), io.EOF -} - var _ = Describe("Request Writer", func() { var ( rw *requestWriter @@ -51,16 +44,13 @@ var _ = Describe("Request Writer", func() { rw = newRequestWriter(utils.DefaultLogger) strBuf = &bytes.Buffer{} str = mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { - return strBuf.Write(p) - }).AnyTimes() + str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() }) It("writes a GET request", func() { - str.EXPECT().Close() req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) Expect(headerFields).To(HaveKeyWithValue(":method", "GET")) @@ -69,55 +59,7 @@ var _ = Describe("Request Writer", func() { Expect(headerFields).ToNot(HaveKey("accept-encoding")) }) - It("writes a GET request without closing the stream", func() { - req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io", nil) - Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, true, false)).To(Succeed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) - }) - - It("writes a POST request", func() { - closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) - postData := bytes.NewReader([]byte("foobar")) - req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", postData) - Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) - - Eventually(closed).Should(BeClosed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) - Expect(headerFields).To(HaveKey("content-length")) - contentLength, err := strconv.Atoi(headerFields["content-length"]) - Expect(err).ToNot(HaveOccurred()) - Expect(contentLength).To(BeNumerically(">", 0)) - - frame, err := parseNextFrame(strBuf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) - Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6)) - }) - - It("writes a POST request, if the Body returns an EOF immediately", func() { - closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) - req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", &foobarReader{}) - Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) - - Eventually(closed).Should(BeClosed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) - - frame, err := parseNextFrame(strBuf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) - Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6)) - }) - It("sends cookies", func() { - str.EXPECT().Close() req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) cookie1 := &http.Cookie{ @@ -130,25 +72,23 @@ var _ = Describe("Request Writer", func() { } req.AddCookie(cookie1) req.AddCookie(cookie2) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`)) }) It("adds the header for gzip support", func() { - str.EXPECT().Close() req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false, true)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, true)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip")) }) It("writes a CONNECT request", func() { - str.EXPECT().Close() req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) @@ -158,11 +98,10 @@ var _ = Describe("Request Writer", func() { }) It("writes an Extended CONNECT request", func() { - str.EXPECT().Close() req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil) Expect(err).ToNot(HaveOccurred()) req.Proto = "webtransport" - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) diff --git a/http3/server.go b/http3/server.go index 45ca3f4c..040770e0 100644 --- a/http3/server.go +++ b/http3/server.go @@ -549,7 +549,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q } req.RemoteAddr = conn.RemoteAddr().String() - req.Body = newRequestBody(str, onFrameError) + req.Body = newRequestBody(newStream(str, onFrameError)) if s.logger.Debug() { s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID()) diff --git a/http3/server_test.go b/http3/server_test.go index 064380c3..e7fc0f3e 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -135,11 +135,8 @@ var _ = Describe("Server", func() { buf := &bytes.Buffer{} str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() - closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) rw := newRequestWriter(utils.DefaultLogger) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) - Eventually(closed).Should(BeClosed()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) return buf.Bytes() }