http3: add a RoundTripOpt to check the server's SETTINGS frame (#4355)

For some requests, the client is required to check the server's HTTP/3
SETTINGS. For example, a client is only allowed to send HTTP/3 datagrams
if the server explicitly enabled support.

SETTINGS are sent asynchronously on a control stream (usually the first
unidirectional stream). This means that the SETTINGS might not be
available at the beginning of the connection. This is not expected to be
the common case, since the server can send the SETTINGS in 0.5-RTT data,
but we have to be able to deal with arbitrary delays.

For WebTransport, there are even more SETTINGS values that the client
needs to check. By making CheckSettings a callback on the RoundTripOpt,
this entire validation logic can live at the WebTransport layer.
This commit is contained in:
Marten Seemann 2024-03-12 17:33:00 +09:30 committed by GitHub
parent ca787d6f00
commit 497d3f58a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 145 additions and 22 deletions

View file

@ -58,6 +58,9 @@ type client struct {
dialer dialFunc dialer dialFunc
handshakeErr error handshakeErr error
receivedSettings chan struct{} // closed once the server's SETTINGS frame was processed
settings *Settings // set once receivedSettings is closed
requestWriter *requestWriter requestWriter *requestWriter
decoder *qpack.Decoder decoder *qpack.Decoder
@ -107,14 +110,15 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con
tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
return &client{ return &client{
hostname: authorityAddr("https", hostname), hostname: authorityAddr("https", hostname),
tlsConf: tlsConf, tlsConf: tlsConf,
requestWriter: newRequestWriter(logger), requestWriter: newRequestWriter(logger),
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), receivedSettings: make(chan struct{}),
config: conf, decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
opts: opts, config: conf,
dialer: dialer, opts: opts,
logger: logger, dialer: dialer,
logger: logger,
}, nil }, nil
} }
@ -234,6 +238,12 @@ func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) {
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
return return
} }
c.settings = &Settings{
EnableDatagram: sf.Datagram,
EnableExtendedConnect: sf.ExtendedConnect,
Other: sf.Other,
}
close(c.receivedSettings)
if !sf.Datagram { if !sf.Datagram {
return return
} }
@ -299,6 +309,18 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
} }
} }
if opt.CheckSettings != nil {
// wait for the server's SETTINGS frame to arrive
select {
case <-c.receivedSettings:
case <-conn.Context().Done():
return nil, context.Cause(conn.Context())
}
if err := opt.CheckSettings(*c.settings); err != nil {
return nil, err
}
}
str, err := conn.OpenStreamSync(req.Context()) str, err := conn.OpenStreamSync(req.Context())
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -456,7 +456,6 @@ var _ = Describe("Client", func() {
conn = mockquic.NewMockEarlyConnection(mockCtrl) conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil return conn, nil
} }
@ -472,10 +471,15 @@ var _ = Describe("Client", func() {
It("parses the SETTINGS frame", func() { It("parses the SETTINGS frame", func() {
b := quicvarint.Append(nil, streamTypeControlStream) b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{}).Append(b) b = (&settingsFrame{
Datagram: true,
ExtendedConnect: true,
Other: map[uint64]uint64{1337: 42},
}).Append(b)
r := bytes.NewReader(b) r := bytes.NewReader(b)
controlStr := mockquic.NewMockStream(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil return controlStr, nil
}) })
@ -483,11 +487,42 @@ var _ = Describe("Client", func() {
<-testDone <-testDone
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
_, err := cl.RoundTripOpt(req, RoundTripOpt{}) conn.EXPECT().Context().Return(context.Background())
_, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error {
defer GinkgoRecover()
Expect(settings.EnableDatagram).To(BeTrue())
Expect(settings.EnableExtendedConnect).To(BeTrue())
Expect(settings.Other).To(HaveLen(1))
Expect(settings.Other).To(HaveKeyWithValue(uint64(1337), uint64(42)))
return nil
}})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
}) })
It("allows the client to reject the SETTINGS using the CheckSettings RoundTripOpt", func() {
b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{}).Append(b)
r := bytes.NewReader(b)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
// Don't EXPECT any call to OpenStreamSync.
// When the SETTINGS are rejected, we don't even open the request stream.
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
conn.EXPECT().Context().Return(context.Background())
_, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error {
return errors.New("wrong settings")
}})
Expect(err).To(MatchError("wrong settings"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
It("rejects duplicate control streams", func() { It("rejects duplicate control streams", func() {
b := quicvarint.Append(nil, streamTypeControlStream) b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{}).Append(b) b = (&settingsFrame{}).Append(b)
@ -498,6 +533,7 @@ var _ = Describe("Client", func() {
controlStr2 := mockquic.NewMockStream(mockCtrl) controlStr2 := mockquic.NewMockStream(mockCtrl)
controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes() controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes()
done := make(chan struct{}) done := make(chan struct{})
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error { conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error {
close(done) close(done)
return nil return nil
@ -529,6 +565,7 @@ var _ = Describe("Client", func() {
str := mockquic.NewMockStream(mockCtrl) str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return str, nil return str, nil
}) })
@ -542,13 +579,14 @@ var _ = Describe("Client", func() {
}) })
} }
It("resets streams Other than the control stream and the QPACK streams", func() { It("resets streams other than the control stream and the QPACK streams", func() {
buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337))
str := mockquic.NewMockStream(mockCtrl) str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
done := make(chan struct{}) done := make(chan struct{})
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) }) str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) })
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return str, nil return str, nil
}) })
@ -567,6 +605,8 @@ var _ = Describe("Client", func() {
r := bytes.NewReader(b) r := bytes.NewReader(b)
controlStr := mockquic.NewMockStream(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil return controlStr, nil
}) })
@ -584,12 +624,44 @@ var _ = Describe("Client", func() {
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("errors when the first frame on the control stream is not a SETTINGS frame, when checking SETTINGS", func() {
b := quicvarint.Append(nil, streamTypeControlStream)
b = (&dataFrame{}).Append(b)
r := bytes.NewReader(b)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
// Don't EXPECT any calls to OpenStreamSync.
// We fail before we even get the chance to open the request stream.
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
doneCtx, doneCancel := context.WithCancelCause(context.Background())
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
doneCancel(errors.New("done"))
return nil
})
conn.EXPECT().Context().Return(doneCtx).Times(2)
var checked bool
_, err := cl.RoundTripOpt(req, RoundTripOpt{
CheckSettings: func(Settings) error { checked = true; return nil },
})
Expect(checked).To(BeFalse())
Expect(err).To(MatchError("done"))
Eventually(doneCtx.Done()).Should(BeClosed())
})
It("errors when parsing the frame on the control stream fails", func() { It("errors when parsing the frame on the control stream fails", func() {
b := quicvarint.Append(nil, streamTypeControlStream) b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{}).Append(b) b = (&settingsFrame{}).Append(b)
r := bytes.NewReader(b[:len(b)-1]) r := bytes.NewReader(b[:len(b)-1])
controlStr := mockquic.NewMockStream(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil return controlStr, nil
}) })
@ -611,6 +683,7 @@ var _ = Describe("Client", func() {
buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream))
controlStr := mockquic.NewMockStream(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil return controlStr, nil
}) })
@ -635,6 +708,7 @@ var _ = Describe("Client", func() {
r := bytes.NewReader(b) r := bytes.NewReader(b)
controlStr := mockquic.NewMockStream(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil return controlStr, nil
}) })

View file

@ -17,6 +17,30 @@ import (
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
) )
// Settings are HTTP/3 settings that apply to the underlying connection.
type Settings struct {
// Support for HTTP/3 datagrams (RFC 9297)
EnableDatagram bool
// Extended CONNECT, RFC 9220
EnableExtendedConnect bool
// Other settings, defined by the application
Other map[uint64]uint64
}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
OnlyCachedConn bool
// DontCloseRequestStream controls whether the request stream is closed after sending the request.
// If set, context cancellations have no effect after the response headers are received.
DontCloseRequestStream bool
// CheckSettings is run before the request is sent to the server.
// If not yet received, it blocks until the server's SETTINGS frame is received.
// If an error is returned, the request won't be sent to the server, and the error is returned.
CheckSettings func(Settings) error
}
type roundTripCloser interface { type roundTripCloser interface {
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
HandshakeComplete() bool HandshakeComplete() bool
@ -88,16 +112,6 @@ type RoundTripper struct {
transport *quic.Transport transport *quic.Transport
} }
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
OnlyCachedConn bool
// DontCloseRequestStream controls whether the request stream is closed after sending the request.
// If set, context cancellations have no effect after the response headers are received.
DontCloseRequestStream bool
}
var ( var (
_ http.RoundTripper = &RoundTripper{} _ http.RoundTripper = &RoundTripper{}
_ io.Closer = &RoundTripper{} _ io.Closer = &RoundTripper{}

View file

@ -559,4 +559,17 @@ var _ = Describe("HTTP tests", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200)) Expect(resp.StatusCode).To(Equal(200))
}) })
It("checks the server's settings", func() {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/hello", port), nil)
Expect(err).ToNot(HaveOccurred())
testErr := errors.New("test error")
_, err = rt.RoundTripOpt(req, http3.RoundTripOpt{CheckSettings: func(settings http3.Settings) error {
Expect(settings.EnableExtendedConnect).To(BeTrue())
Expect(settings.EnableDatagram).To(BeFalse())
Expect(settings.Other).To(BeEmpty())
return testErr
}})
Expect(err).To(MatchError(err))
})
}) })