reject http3 responses that exceeded the header size limit

This commit is contained in:
Marten Seemann 2019-08-22 12:08:02 +07:00
parent 9294652ecc
commit 363de010ca
3 changed files with 44 additions and 3 deletions

View file

@ -17,6 +17,7 @@ import (
) )
const defaultUserAgent = "quic-go HTTP/3" const defaultUserAgent = "quic-go HTTP/3"
const defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
var defaultQuicConfig = &quic.Config{KeepAlive: true} var defaultQuicConfig = &quic.Config{KeepAlive: true}
@ -24,6 +25,7 @@ var dialAddr = quic.DialAddr
type roundTripperOpts struct { type roundTripperOpts struct {
DisableCompression bool DisableCompression bool
MaxHeaderBytes int64
} }
// client is a HTTP3 client doing requests // client is a HTTP3 client doing requests
@ -121,6 +123,13 @@ func (c *client) Close() error {
return c.session.Close() return c.session.Close()
} }
func (c *client) maxHeaderBytes() uint64 {
if c.opts.MaxHeaderBytes <= 0 {
return defaultMaxResponseHeaderBytes
}
return uint64(c.opts.MaxHeaderBytes)
}
// Roundtrip executes a request and returns a response // Roundtrip executes a request and returns a response
// TODO: handle request cancelations // TODO: handle request cancelations
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
@ -160,7 +169,9 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
if !ok { if !ok {
return nil, errors.New("not a HEADERS frame") return nil, errors.New("not a HEADERS frame")
} }
// TODO: check size if hf.Length > c.maxHeaderBytes() {
return nil, fmt.Errorf("Headers frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes())
}
headerBlock := make([]byte, hf.Length) headerBlock := make([]byte, hf.Length)
if _, err := io.ReadFull(str, headerBlock); err != nil { if _, err := io.ReadFull(str, headerBlock); err != nil {
return nil, err return nil, err

View file

@ -31,7 +31,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() { BeforeEach(func() {
origDialAddr = dialAddr origDialAddr = dialAddr
hostname := "quic.clemente.io:1337" hostname := "quic.clemente.io:1337"
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil) client = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
Expect(client.hostname).To(Equal(hostname)) Expect(client.hostname).To(Equal(hostname))
var err error var err error
@ -275,6 +275,28 @@ var _ = Describe("Client", func() {
_, err := client.RoundTrip(request) _, err := client.RoundTrip(request)
Expect(err).To(MatchError("test done")) Expect(err).To(MatchError("test done"))
}) })
It("errors when the first frame is not a HEADERS frame", func() {
buf := &bytes.Buffer{}
(&dataFrame{Length: 0x42}).Write(buf)
str.EXPECT().Close().MaxTimes(1)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
return buf.Read(b)
}).AnyTimes()
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("not a HEADERS frame"))
})
It("errors when the first frame is not a HEADERS frame", func() {
buf := &bytes.Buffer{}
(&headersFrame{Length: 1338}).Write(buf)
str.EXPECT().Close().MaxTimes(1)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
return buf.Read(b)
}).AnyTimes()
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("Headers frame too large: 1338 bytes (max: 1337)"))
})
}) })
Context("gzip compression", func() { Context("gzip compression", func() {

View file

@ -46,6 +46,11 @@ type RoundTripper struct {
// If Dial is nil, quic.DialAddr will be used. // If Dial is nil, quic.DialAddr will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error) Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
// allowed in the server's response header.
// Zero means to use a default limit.
MaxResponseHeaderBytes int64
clients map[string]roundTripCloser clients map[string]roundTripCloser
} }
@ -128,7 +133,10 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
client = newClient( client = newClient(
hostname, hostname,
r.TLSClientConfig, r.TLSClientConfig,
&roundTripperOpts{DisableCompression: r.DisableCompression}, &roundTripperOpts{
DisableCompression: r.DisableCompression,
MaxHeaderBytes: r.MaxResponseHeaderBytes,
},
r.QuicConfig, r.QuicConfig,
r.Dial, r.Dial,
) )