diff --git a/http3/client.go b/http3/client.go index 8744422e..b297b8b7 100644 --- a/http3/client.go +++ b/http3/client.go @@ -17,6 +17,7 @@ import ( ) const defaultUserAgent = "quic-go HTTP/3" +const defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB var defaultQuicConfig = &quic.Config{KeepAlive: true} @@ -24,6 +25,7 @@ var dialAddr = quic.DialAddr type roundTripperOpts struct { DisableCompression bool + MaxHeaderBytes int64 } // client is a HTTP3 client doing requests @@ -121,6 +123,13 @@ func (c *client) Close() error { return c.session.Close() } +func (c *client) maxHeaderBytes() uint64 { + if c.opts.MaxHeaderBytes <= 0 { + return defaultMaxResponseHeaderBytes + } + return uint64(c.opts.MaxHeaderBytes) +} + // Roundtrip executes a request and returns a response // TODO: handle request cancelations func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { @@ -160,7 +169,9 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { if !ok { return nil, errors.New("not a HEADERS frame") } - // TODO: check size + if hf.Length > c.maxHeaderBytes() { + return nil, fmt.Errorf("Headers frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()) + } headerBlock := make([]byte, hf.Length) if _, err := io.ReadFull(str, headerBlock); err != nil { return nil, err diff --git a/http3/client_test.go b/http3/client_test.go index f8ccc71c..1b65d253 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -31,7 +31,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { origDialAddr = dialAddr hostname := "quic.clemente.io:1337" - client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil) + client = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) Expect(client.hostname).To(Equal(hostname)) var err error @@ -275,6 +275,28 @@ var _ = Describe("Client", func() { _, err := client.RoundTrip(request) Expect(err).To(MatchError("test done")) }) + + It("errors when the first frame is not a HEADERS frame", func() { + buf := &bytes.Buffer{} + (&dataFrame{Length: 0x42}).Write(buf) + str.EXPECT().Close().MaxTimes(1) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { + return buf.Read(b) + }).AnyTimes() + _, err := client.RoundTrip(request) + Expect(err).To(MatchError("not a HEADERS frame")) + }) + + It("errors when the first frame is not a HEADERS frame", func() { + buf := &bytes.Buffer{} + (&headersFrame{Length: 1338}).Write(buf) + str.EXPECT().Close().MaxTimes(1) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { + return buf.Read(b) + }).AnyTimes() + _, err := client.RoundTrip(request) + Expect(err).To(MatchError("Headers frame too large: 1338 bytes (max: 1337)")) + }) }) Context("gzip compression", func() { diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 493e2fb3..003e17d5 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -46,6 +46,11 @@ type RoundTripper struct { // If Dial is nil, quic.DialAddr will be used. Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error) + // MaxResponseHeaderBytes specifies a limit on how many response bytes are + // allowed in the server's response header. + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 + clients map[string]roundTripCloser } @@ -128,7 +133,10 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr client = newClient( hostname, r.TLSClientConfig, - &roundTripperOpts{DisableCompression: r.DisableCompression}, + &roundTripperOpts{ + DisableCompression: r.DisableCompression, + MaxHeaderBytes: r.MaxResponseHeaderBytes, + }, r.QuicConfig, r.Dial, )