From 5a22ac89701811545f13b662c8134e0bf24d6542 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 20 Jul 2023 09:47:01 -0700 Subject: [PATCH] http3: enforce that DATA frames don't exceed Content-Length (#3980) --- http3/client.go | 10 ++++++- http3/http_stream.go | 48 ++++++++++++++++++++++++++++++ http3/http_stream_test.go | 61 ++++++++++++++++++++++++++++++++++++--- http3/server.go | 11 ++++++- 4 files changed, 124 insertions(+), 6 deletions(-) diff --git a/http3/client.go b/http3/client.go index 849cc5d9..d53391f0 100644 --- a/http3/client.go +++ b/http3/client.go @@ -426,7 +426,15 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui connState := conn.ConnectionState().TLS res.TLS = &connState res.Request = req - respBody := newResponseBody(hstr, conn, reqDone) + // Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set). + // See section 4.1.2 of RFC 9114. + var httpStr Stream + if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 { + httpStr = newLengthLimitedStream(hstr, req.ContentLength) + } else { + httpStr = hstr + } + respBody := newResponseBody(httpStr, conn, reqDone) // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. _, hasTransferEncoding := res.Header["Transfer-Encoding"] diff --git a/http3/http_stream.go b/http3/http_stream.go index 2799e2b3..1c0ec4f1 100644 --- a/http3/http_stream.go +++ b/http3/http_stream.go @@ -1,9 +1,11 @@ package http3 import ( + "errors" "fmt" "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/internal/utils" ) // A Stream is a HTTP/3 stream. @@ -66,6 +68,10 @@ func (s *stream) Read(b []byte) (int, error) { return n, err } +func (s *stream) hasMoreData() bool { + return s.bytesRemainingInFrame > 0 +} + func (s *stream) Write(b []byte) (int, error) { s.buf = s.buf[:0] s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf) @@ -74,3 +80,45 @@ func (s *stream) Write(b []byte) (int, error) { } return s.Stream.Write(b) } + +var errTooMuchData = errors.New("peer sent too much data") + +type lengthLimitedStream struct { + *stream + contentLength int64 + read int64 + resetStream bool +} + +var _ Stream = &lengthLimitedStream{} + +func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream { + return &lengthLimitedStream{ + stream: str, + contentLength: contentLength, + } +} + +func (s *lengthLimitedStream) checkContentLengthViolation() error { + if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() { + if !s.resetStream { + s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) + s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) + s.resetStream = true + } + return errTooMuchData + } + return nil +} + +func (s *lengthLimitedStream) Read(b []byte) (int, error) { + if err := s.checkContentLengthViolation(); err != nil { + return 0, err + } + n, err := s.stream.Read(b[:utils.Min(int64(len(b)), s.contentLength-s.read)]) + s.read += int64(n) + if err := s.checkContentLengthViolation(); err != nil { + return n, err + } + return n, err +} diff --git a/http3/http_stream_test.go b/http3/http_stream_test.go index 8a4ee26c..cff5476b 100644 --- a/http3/http_stream_test.go +++ b/http3/http_stream_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io" + "github.com/quic-go/quic-go" mockquic "github.com/quic-go/quic-go/internal/mocks/quic" "github.com/golang/mock/gomock" @@ -11,6 +12,11 @@ import ( . "github.com/onsi/gomega" ) +func getDataFrame(data []byte) []byte { + b := (&dataFrame{Length: uint64(len(data))}).Append(nil) + return append(b, data...) +} + var _ = Describe("Stream", func() { Context("reading", func() { var ( @@ -21,10 +27,6 @@ var _ = Describe("Stream", func() { ) errorCb := func() { errorCbCalled = true } - getDataFrame := func(data []byte) []byte { - b := (&dataFrame{Length: uint64(len(data))}).Append(nil) - return append(b, data...) - } BeforeEach(func() { buf = &bytes.Buffer{} @@ -148,3 +150,54 @@ var _ = Describe("Stream", func() { }) }) }) + +var _ = Describe("length-limited streams", func() { + var ( + str *stream + qstr *mockquic.MockStream + buf *bytes.Buffer + ) + + BeforeEach(func() { + buf = &bytes.Buffer{} + qstr = mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() + qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + str = newStream(qstr, func() { Fail("didn't expect error callback to be called") }) + }) + + It("reads all frames", func() { + s := newLengthLimitedStream(str, 6) + buf.Write(getDataFrame([]byte("foo"))) + buf.Write(getDataFrame([]byte("bar"))) + data, err := io.ReadAll(s) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + }) + + It("errors if more data than the maximum length is sent, in the middle of a frame", func() { + s := newLengthLimitedStream(str, 4) + buf.Write(getDataFrame([]byte("foo"))) + buf.Write(getDataFrame([]byte("bar"))) + qstr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) + qstr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) + data, err := io.ReadAll(s) + Expect(err).To(MatchError(errTooMuchData)) + Expect(data).To(Equal([]byte("foob"))) + // check that repeated calls to Read also return the right error + n, err := s.Read([]byte{0}) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(errTooMuchData)) + }) + + It("errors if more data than the maximum length is sent, as an additional frame", func() { + s := newLengthLimitedStream(str, 3) + buf.Write(getDataFrame([]byte("foo"))) + buf.Write(getDataFrame([]byte("bar"))) + qstr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) + qstr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) + data, err := io.ReadAll(s) + Expect(err).To(MatchError(errTooMuchData)) + Expect(data).To(Equal([]byte("foo"))) + }) +}) diff --git a/http3/server.go b/http3/server.go index 3dd3d96b..b90c850c 100644 --- a/http3/server.go +++ b/http3/server.go @@ -579,7 +579,16 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q connState := conn.ConnectionState().TLS req.TLS = &connState req.RemoteAddr = conn.RemoteAddr().String() - body := newRequestBody(newStream(str, onFrameError)) + + // Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set). + // See section 4.1.2 of RFC 9114. + var httpStr Stream + if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 { + httpStr = newLengthLimitedStream(newStream(str, onFrameError), req.ContentLength) + } else { + httpStr = newStream(str, onFrameError) + } + body := newRequestBody(httpStr) req.Body = body if s.logger.Debug() {