diff --git a/http3/client.go b/http3/client.go index b0e6546c..b2b7bbdc 100644 --- a/http3/client.go +++ b/http3/client.go @@ -327,31 +327,43 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon return rsp, rerr.err } -func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error { - defer body.Close() - b := make([]byte, bodyCopyBufferSize) - for { - n, rerr := body.Read(b) - if n == 0 { - if rerr == nil { - continue - } - if rerr == io.EOF { - break - } - } - if _, err := str.Write(b[:n]); err != nil { - return err - } - if rerr != nil { - if rerr == io.EOF { - break - } - str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) - return rerr - } +// cancelingReader reads from the io.Reader. +// It cancels writing on the stream if any error other than io.EOF occurs. +type cancelingReader struct { + r io.Reader + str Stream +} + +func (r *cancelingReader) Read(b []byte) (int, error) { + n, err := r.r.Read(b) + if err != nil && err != io.EOF { + r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) } - return nil + return n, err +} + +func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error { + defer body.Close() + buf := make([]byte, bodyCopyBufferSize) + sr := &cancelingReader{str: str, r: body} + if contentLength == -1 { + _, err := io.CopyBuffer(str, sr, buf) + return err + } + + // make sure we don't send more bytes than the content length + n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf) + if err != nil { + return err + } + var extra int64 + extra, err = io.CopyBuffer(io.Discard, sr, buf) + n += extra + if n > contentLength { + str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) + return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n) + } + return err } func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { @@ -371,7 +383,13 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui if req.Body != nil { // send the request body asynchronously go func() { - if err := c.sendRequestBody(hstr, req.Body); err != nil { + contentLength := int64(-1) + // According to the documentation for http.Request.ContentLength, + // a value of 0 with a non-nil Body is also treated as unknown content length. + if req.ContentLength > 0 { + contentLength = req.ContentLength + } + if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil { c.logger.Errorf("Error writing request: %s", err) } if !opt.DontCloseRequestStream { diff --git a/http3/client_test.go b/http3/client_test.go index ebcfa3bc..1c5ac6f8 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "sync" "time" "github.com/quic-go/quic-go" @@ -803,6 +804,29 @@ var _ = Describe("Client", func() { Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) }) + It("doesn't send more bytes than allowed by http.Request.ContentLength", func() { + req.ContentLength = 7 + var once sync.Once + done := make(chan struct{}) + gomock.InOrder( + str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) { + once.Do(func() { + Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled))) + close(done) + }) + }).AnyTimes(), + str.EXPECT().Close().MaxTimes(1), + str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(), + ) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { + <-done + return 0, errors.New("done") + }) + cl.RoundTripOpt(req, RoundTripOpt{}) + Expect(strBuf.String()).To(ContainSubstring("request")) + Expect(strBuf.String()).ToNot(ContainSubstring("request body")) + }) + It("returns the error that occurred when reading the body", func() { req.Body.(*mockBody).readErr = errors.New("testErr") done := make(chan struct{})