mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
http3: don't send more than http.Request.ContentLength bytes (#3960)
This commit is contained in:
parent
de8d7a32b8
commit
3dea8f8a9b
2 changed files with 67 additions and 25 deletions
|
@ -327,31 +327,43 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
|
||||||
return rsp, rerr.err
|
return rsp, rerr.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
|
// cancelingReader reads from the io.Reader.
|
||||||
defer body.Close()
|
// It cancels writing on the stream if any error other than io.EOF occurs.
|
||||||
b := make([]byte, bodyCopyBufferSize)
|
type cancelingReader struct {
|
||||||
for {
|
r io.Reader
|
||||||
n, rerr := body.Read(b)
|
str Stream
|
||||||
if n == 0 {
|
}
|
||||||
if rerr == nil {
|
|
||||||
continue
|
func (r *cancelingReader) Read(b []byte) (int, error) {
|
||||||
}
|
n, err := r.r.Read(b)
|
||||||
if rerr == io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
break
|
r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, err := str.Write(b[:n]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if rerr != nil {
|
|
||||||
if rerr == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
|
||||||
return rerr
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
|
||||||
|
defer body.Close()
|
||||||
|
buf := make([]byte, bodyCopyBufferSize)
|
||||||
|
sr := &cancelingReader{str: str, r: body}
|
||||||
|
if contentLength == -1 {
|
||||||
|
_, err := io.CopyBuffer(str, sr, buf)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure we don't send more bytes than the content length
|
||||||
|
n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var extra int64
|
||||||
|
extra, err = io.CopyBuffer(io.Discard, sr, buf)
|
||||||
|
n += extra
|
||||||
|
if n > contentLength {
|
||||||
|
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||||
|
return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
|
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
|
||||||
|
@ -371,7 +383,13 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
|
||||||
if req.Body != nil {
|
if req.Body != nil {
|
||||||
// send the request body asynchronously
|
// send the request body asynchronously
|
||||||
go func() {
|
go func() {
|
||||||
if err := c.sendRequestBody(hstr, req.Body); err != nil {
|
contentLength := int64(-1)
|
||||||
|
// According to the documentation for http.Request.ContentLength,
|
||||||
|
// a value of 0 with a non-nil Body is also treated as unknown content length.
|
||||||
|
if req.ContentLength > 0 {
|
||||||
|
contentLength = req.ContentLength
|
||||||
|
}
|
||||||
|
if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil {
|
||||||
c.logger.Errorf("Error writing request: %s", err)
|
c.logger.Errorf("Error writing request: %s", err)
|
||||||
}
|
}
|
||||||
if !opt.DontCloseRequestStream {
|
if !opt.DontCloseRequestStream {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
|
@ -803,6 +804,29 @@ var _ = Describe("Client", func() {
|
||||||
Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
|
Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("doesn't send more bytes than allowed by http.Request.ContentLength", func() {
|
||||||
|
req.ContentLength = 7
|
||||||
|
var once sync.Once
|
||||||
|
done := make(chan struct{})
|
||||||
|
gomock.InOrder(
|
||||||
|
str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) {
|
||||||
|
once.Do(func() {
|
||||||
|
Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled)))
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
|
}).AnyTimes(),
|
||||||
|
str.EXPECT().Close().MaxTimes(1),
|
||||||
|
str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(),
|
||||||
|
)
|
||||||
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
|
||||||
|
<-done
|
||||||
|
return 0, errors.New("done")
|
||||||
|
})
|
||||||
|
cl.RoundTripOpt(req, RoundTripOpt{})
|
||||||
|
Expect(strBuf.String()).To(ContainSubstring("request"))
|
||||||
|
Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
|
||||||
|
})
|
||||||
|
|
||||||
It("returns the error that occurred when reading the body", func() {
|
It("returns the error that occurred when reading the body", func() {
|
||||||
req.Body.(*mockBody).readErr = errors.New("testErr")
|
req.Body.(*mockBody).readErr = errors.New("testErr")
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue