mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
http3: enforce that DATA frames don't exceed Content-Length (#3980)
This commit is contained in:
parent
56cd866840
commit
5a22ac8970
4 changed files with 124 additions and 6 deletions
|
@ -426,7 +426,15 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
|
|||
connState := conn.ConnectionState().TLS
|
||||
res.TLS = &connState
|
||||
res.Request = req
|
||||
respBody := newResponseBody(hstr, conn, reqDone)
|
||||
// Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
|
||||
// See section 4.1.2 of RFC 9114.
|
||||
var httpStr Stream
|
||||
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
|
||||
httpStr = newLengthLimitedStream(hstr, req.ContentLength)
|
||||
} else {
|
||||
httpStr = hstr
|
||||
}
|
||||
respBody := newResponseBody(httpStr, conn, reqDone)
|
||||
|
||||
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
|
||||
_, hasTransferEncoding := res.Header["Transfer-Encoding"]
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A Stream is a HTTP/3 stream.
|
||||
|
@ -66,6 +68,10 @@ func (s *stream) Read(b []byte) (int, error) {
|
|||
return n, err
|
||||
}
|
||||
|
||||
func (s *stream) hasMoreData() bool {
|
||||
return s.bytesRemainingInFrame > 0
|
||||
}
|
||||
|
||||
func (s *stream) Write(b []byte) (int, error) {
|
||||
s.buf = s.buf[:0]
|
||||
s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf)
|
||||
|
@ -74,3 +80,45 @@ func (s *stream) Write(b []byte) (int, error) {
|
|||
}
|
||||
return s.Stream.Write(b)
|
||||
}
|
||||
|
||||
var errTooMuchData = errors.New("peer sent too much data")
|
||||
|
||||
type lengthLimitedStream struct {
|
||||
*stream
|
||||
contentLength int64
|
||||
read int64
|
||||
resetStream bool
|
||||
}
|
||||
|
||||
var _ Stream = &lengthLimitedStream{}
|
||||
|
||||
func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream {
|
||||
return &lengthLimitedStream{
|
||||
stream: str,
|
||||
contentLength: contentLength,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *lengthLimitedStream) checkContentLengthViolation() error {
|
||||
if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() {
|
||||
if !s.resetStream {
|
||||
s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
s.resetStream = true
|
||||
}
|
||||
return errTooMuchData
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *lengthLimitedStream) Read(b []byte) (int, error) {
|
||||
if err := s.checkContentLengthViolation(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err := s.stream.Read(b[:utils.Min(int64(len(b)), s.contentLength-s.read)])
|
||||
s.read += int64(n)
|
||||
if err := s.checkContentLengthViolation(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
@ -11,6 +12,11 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func getDataFrame(data []byte) []byte {
|
||||
b := (&dataFrame{Length: uint64(len(data))}).Append(nil)
|
||||
return append(b, data...)
|
||||
}
|
||||
|
||||
var _ = Describe("Stream", func() {
|
||||
Context("reading", func() {
|
||||
var (
|
||||
|
@ -21,10 +27,6 @@ var _ = Describe("Stream", func() {
|
|||
)
|
||||
|
||||
errorCb := func() { errorCbCalled = true }
|
||||
getDataFrame := func(data []byte) []byte {
|
||||
b := (&dataFrame{Length: uint64(len(data))}).Append(nil)
|
||||
return append(b, data...)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
buf = &bytes.Buffer{}
|
||||
|
@ -148,3 +150,54 @@ var _ = Describe("Stream", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("length-limited streams", func() {
|
||||
var (
|
||||
str *stream
|
||||
qstr *mockquic.MockStream
|
||||
buf *bytes.Buffer
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
buf = &bytes.Buffer{}
|
||||
qstr = mockquic.NewMockStream(mockCtrl)
|
||||
qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
|
||||
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
str = newStream(qstr, func() { Fail("didn't expect error callback to be called") })
|
||||
})
|
||||
|
||||
It("reads all frames", func() {
|
||||
s := newLengthLimitedStream(str, 6)
|
||||
buf.Write(getDataFrame([]byte("foo")))
|
||||
buf.Write(getDataFrame([]byte("bar")))
|
||||
data, err := io.ReadAll(s)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("errors if more data than the maximum length is sent, in the middle of a frame", func() {
|
||||
s := newLengthLimitedStream(str, 4)
|
||||
buf.Write(getDataFrame([]byte("foo")))
|
||||
buf.Write(getDataFrame([]byte("bar")))
|
||||
qstr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
qstr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
data, err := io.ReadAll(s)
|
||||
Expect(err).To(MatchError(errTooMuchData))
|
||||
Expect(data).To(Equal([]byte("foob")))
|
||||
// check that repeated calls to Read also return the right error
|
||||
n, err := s.Read([]byte{0})
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(errTooMuchData))
|
||||
})
|
||||
|
||||
It("errors if more data than the maximum length is sent, as an additional frame", func() {
|
||||
s := newLengthLimitedStream(str, 3)
|
||||
buf.Write(getDataFrame([]byte("foo")))
|
||||
buf.Write(getDataFrame([]byte("bar")))
|
||||
qstr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
qstr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
data, err := io.ReadAll(s)
|
||||
Expect(err).To(MatchError(errTooMuchData))
|
||||
Expect(data).To(Equal([]byte("foo")))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -579,7 +579,16 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
|
|||
connState := conn.ConnectionState().TLS
|
||||
req.TLS = &connState
|
||||
req.RemoteAddr = conn.RemoteAddr().String()
|
||||
body := newRequestBody(newStream(str, onFrameError))
|
||||
|
||||
// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
|
||||
// See section 4.1.2 of RFC 9114.
|
||||
var httpStr Stream
|
||||
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
|
||||
httpStr = newLengthLimitedStream(newStream(str, onFrameError), req.ContentLength)
|
||||
} else {
|
||||
httpStr = newStream(str, onFrameError)
|
||||
}
|
||||
body := newRequestBody(httpStr)
|
||||
req.Body = body
|
||||
|
||||
if s.logger.Debug() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue