http3: don't send more than http.Request.ContentLength bytes (#3960)

This commit is contained in:
Marten Seemann 2023-07-16 19:16:52 -07:00 committed by GitHub
parent de8d7a32b8
commit 3dea8f8a9b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 25 deletions

View file

@ -327,31 +327,43 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
return rsp, rerr.err return rsp, rerr.err
} }
func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error { // cancelingReader reads from the io.Reader.
defer body.Close() // It cancels writing on the stream if any error other than io.EOF occurs.
b := make([]byte, bodyCopyBufferSize) type cancelingReader struct {
for { r io.Reader
n, rerr := body.Read(b) str Stream
if n == 0 { }
if rerr == nil {
continue func (r *cancelingReader) Read(b []byte) (int, error) {
} n, err := r.r.Read(b)
if rerr == io.EOF { if err != nil && err != io.EOF {
break r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
}
}
if _, err := str.Write(b[:n]); err != nil {
return err
}
if rerr != nil {
if rerr == io.EOF {
break
}
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
return rerr
}
} }
return nil return n, err
}
func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
defer body.Close()
buf := make([]byte, bodyCopyBufferSize)
sr := &cancelingReader{str: str, r: body}
if contentLength == -1 {
_, err := io.CopyBuffer(str, sr, buf)
return err
}
// make sure we don't send more bytes than the content length
n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
if err != nil {
return err
}
var extra int64
extra, err = io.CopyBuffer(io.Discard, sr, buf)
n += extra
if n > contentLength {
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
}
return err
} }
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
@ -371,7 +383,13 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
if req.Body != nil { if req.Body != nil {
// send the request body asynchronously // send the request body asynchronously
go func() { go func() {
if err := c.sendRequestBody(hstr, req.Body); err != nil { contentLength := int64(-1)
// According to the documentation for http.Request.ContentLength,
// a value of 0 with a non-nil Body is also treated as unknown content length.
if req.ContentLength > 0 {
contentLength = req.ContentLength
}
if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil {
c.logger.Errorf("Error writing request: %s", err) c.logger.Errorf("Error writing request: %s", err)
} }
if !opt.DontCloseRequestStream { if !opt.DontCloseRequestStream {

View file

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
@ -803,6 +804,29 @@ var _ = Describe("Client", func() {
Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
}) })
It("doesn't send more bytes than allowed by http.Request.ContentLength", func() {
req.ContentLength = 7
var once sync.Once
done := make(chan struct{})
gomock.InOrder(
str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) {
once.Do(func() {
Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled)))
close(done)
})
}).AnyTimes(),
str.EXPECT().Close().MaxTimes(1),
str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(),
)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
<-done
return 0, errors.New("done")
})
cl.RoundTripOpt(req, RoundTripOpt{})
Expect(strBuf.String()).To(ContainSubstring("request"))
Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
})
It("returns the error that occurred when reading the body", func() { It("returns the error that occurred when reading the body", func() {
req.Body.(*mockBody).readErr = errors.New("testErr") req.Body.(*mockBody).readErr = errors.New("testErr")
done := make(chan struct{}) done := make(chan struct{})