diff --git a/http3/client.go b/http3/client.go index 1f981509..402fd0df 100644 --- a/http3/client.go +++ b/http3/client.go @@ -58,6 +58,9 @@ type client struct { dialer dialFunc handshakeErr error + receivedSettings chan struct{} // closed once the server's SETTINGS frame was processed + settings *Settings // set once receivedSettings is closed + requestWriter *requestWriter decoder *qpack.Decoder @@ -107,14 +110,15 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} return &client{ - hostname: authorityAddr("https", hostname), - tlsConf: tlsConf, - requestWriter: newRequestWriter(logger), - decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), - config: conf, - opts: opts, - dialer: dialer, - logger: logger, + hostname: authorityAddr("https", hostname), + tlsConf: tlsConf, + requestWriter: newRequestWriter(logger), + receivedSettings: make(chan struct{}), + decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), + config: conf, + opts: opts, + dialer: dialer, + logger: logger, }, nil } @@ -234,6 +238,12 @@ func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") return } + c.settings = &Settings{ + EnableDatagram: sf.Datagram, + EnableExtendedConnect: sf.ExtendedConnect, + Other: sf.Other, + } + close(c.receivedSettings) if !sf.Datagram { return } @@ -299,6 +309,18 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } } + if opt.CheckSettings != nil { + // wait for the server's SETTINGS frame to arrive + select { + case <-c.receivedSettings: + case <-conn.Context().Done(): + return nil, context.Cause(conn.Context()) + } + if err := opt.CheckSettings(*c.settings); err != nil { + return nil, err + } + } + str, err := conn.OpenStreamSync(req.Context()) if err != nil { return nil, err diff --git a/http3/client_test.go b/http3/client_test.go index cb63bf34..5665aa41 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -456,7 +456,6 @@ var _ = Describe("Client", func() { conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().HandshakeComplete().Return(handshakeChan) - conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return conn, nil } @@ -472,10 +471,15 @@ var _ = Describe("Client", func() { It("parses the SETTINGS frame", func() { b := quicvarint.Append(nil, streamTypeControlStream) - b = (&settingsFrame{}).Append(b) + b = (&settingsFrame{ + Datagram: true, + ExtendedConnect: true, + Other: map[uint64]uint64{1337: 42}, + }).Append(b) r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) @@ -483,11 +487,42 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + conn.EXPECT().Context().Return(context.Background()) + _, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { + defer GinkgoRecover() + Expect(settings.EnableDatagram).To(BeTrue()) + Expect(settings.EnableExtendedConnect).To(BeTrue()) + Expect(settings.Other).To(HaveLen(1)) + Expect(settings.Other).To(HaveKeyWithValue(uint64(1337), uint64(42))) + return nil + }}) Expect(err).To(MatchError("done")) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) + It("allows the client to reject the SETTINGS using the CheckSettings RoundTripOpt", func() { + b := quicvarint.Append(nil, streamTypeControlStream) + b = (&settingsFrame{}).Append(b) + r := bytes.NewReader(b) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + // Don't EXPECT any call to OpenStreamSync. + // When the SETTINGS are rejected, we don't even open the request stream. + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + conn.EXPECT().Context().Return(context.Background()) + _, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { + return errors.New("wrong settings") + }}) + Expect(err).To(MatchError("wrong settings")) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + It("rejects duplicate control streams", func() { b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) @@ -498,6 +533,7 @@ var _ = Describe("Client", func() { controlStr2 := mockquic.NewMockStream(mockCtrl) controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes() done := make(chan struct{}) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error { close(done) return nil @@ -529,6 +565,7 @@ var _ = Describe("Client", func() { str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil }) @@ -542,13 +579,14 @@ var _ = Describe("Client", func() { }) } - It("resets streams Other than the control stream and the QPACK streams", func() { + It("resets streams other than the control stream and the QPACK streams", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() done := make(chan struct{}) str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) }) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil }) @@ -567,6 +605,8 @@ var _ = Describe("Client", func() { r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) @@ -584,12 +624,44 @@ var _ = Describe("Client", func() { Eventually(done).Should(BeClosed()) }) + It("errors when the first frame on the control stream is not a SETTINGS frame, when checking SETTINGS", func() { + b := quicvarint.Append(nil, streamTypeControlStream) + b = (&dataFrame{}).Append(b) + r := bytes.NewReader(b) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + + // Don't EXPECT any calls to OpenStreamSync. + // We fail before we even get the chance to open the request stream. + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + doneCtx, doneCancel := context.WithCancelCause(context.Background()) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { + doneCancel(errors.New("done")) + return nil + }) + conn.EXPECT().Context().Return(doneCtx).Times(2) + var checked bool + _, err := cl.RoundTripOpt(req, RoundTripOpt{ + CheckSettings: func(Settings) error { checked = true; return nil }, + }) + Expect(checked).To(BeFalse()) + Expect(err).To(MatchError("done")) + Eventually(doneCtx.Done()).Should(BeClosed()) + }) + It("errors when parsing the frame on the control stream fails", func() { b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) r := bytes.NewReader(b[:len(b)-1]) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) @@ -611,6 +683,7 @@ var _ = Describe("Client", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) @@ -635,6 +708,7 @@ var _ = Describe("Client", func() { r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 86d3d3d0..d297531c 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -17,6 +17,30 @@ import ( "github.com/quic-go/quic-go" ) +// Settings are HTTP/3 settings that apply to the underlying connection. +type Settings struct { + // Support for HTTP/3 datagrams (RFC 9297) + EnableDatagram bool + // Extended CONNECT, RFC 9220 + EnableExtendedConnect bool + // Other settings, defined by the application + Other map[uint64]uint64 +} + +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type RoundTripOpt struct { + // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. + // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. + OnlyCachedConn bool + // DontCloseRequestStream controls whether the request stream is closed after sending the request. + // If set, context cancellations have no effect after the response headers are received. + DontCloseRequestStream bool + // CheckSettings is run before the request is sent to the server. + // If not yet received, it blocks until the server's SETTINGS frame is received. + // If an error is returned, the request won't be sent to the server, and the error is returned. + CheckSettings func(Settings) error +} + type roundTripCloser interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) HandshakeComplete() bool @@ -88,16 +112,6 @@ type RoundTripper struct { transport *quic.Transport } -// RoundTripOpt are options for the Transport.RoundTripOpt method. -type RoundTripOpt struct { - // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. - // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. - OnlyCachedConn bool - // DontCloseRequestStream controls whether the request stream is closed after sending the request. - // If set, context cancellations have no effect after the response headers are received. - DontCloseRequestStream bool -} - var ( _ http.RoundTripper = &RoundTripper{} _ io.Closer = &RoundTripper{} diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index cf9f683e..10564c17 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -559,4 +559,17 @@ var _ = Describe("HTTP tests", func() { Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) }) + + It("checks the server's settings", func() { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/hello", port), nil) + Expect(err).ToNot(HaveOccurred()) + testErr := errors.New("test error") + _, err = rt.RoundTripOpt(req, http3.RoundTripOpt{CheckSettings: func(settings http3.Settings) error { + Expect(settings.EnableExtendedConnect).To(BeTrue()) + Expect(settings.EnableDatagram).To(BeFalse()) + Expect(settings.Other).To(BeEmpty()) + return testErr + }}) + Expect(err).To(MatchError(err)) + }) })