diff --git a/http3/client.go b/http3/client.go index 4979a942..8bb212e8 100644 --- a/http3/client.go +++ b/http3/client.go @@ -23,7 +23,7 @@ var defaultQuicConfig = &quic.Config{ KeepAlive: true, } -var dialAddr = quic.DialAddr +var dialAddr = quic.DialAddrEarly type roundTripperOpts struct { DisableCompression bool @@ -37,7 +37,7 @@ type client struct { opts *roundTripperOpts dialOnce sync.Once - dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error) + dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) handshakeErr error requestWriter *requestWriter @@ -45,7 +45,7 @@ type client struct { decoder *qpack.Decoder hostname string - session quic.Session + session quic.EarlySession logger utils.Logger } @@ -55,7 +55,7 @@ func newClient( tlsConf *tls.Config, opts *roundTripperOpts, quicConfig *quic.Config, - dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error), + dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error), ) *client { if tlsConf == nil { tlsConf = &tls.Config{} @@ -93,6 +93,7 @@ func (c *client) dial() error { return err } + // run the sesssion setup using 0-RTT data go func() { if err := c.setupSession(); err != nil { c.logger.Debugf("Setting up session failed: %s", err) @@ -100,6 +101,7 @@ func (c *client) dial() error { } }() + <-c.session.HandshakeComplete().Done() return nil } diff --git a/http3/client_test.go b/http3/client_test.go index 10f10504..3893a4af 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -26,6 +26,7 @@ var _ = Describe("Client", func() { client *client req *http.Request origDialAddr = dialAddr + handshakeCtx context.Context // an already canceled context ) BeforeEach(func() { @@ -37,6 +38,10 @@ var _ = Describe("Client", func() { var err error req, err = http.NewRequest("GET", "https://localhost:1337", nil) Expect(err).ToNot(HaveOccurred()) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + handshakeCtx = ctx }) AfterEach(func() { @@ -46,7 +51,7 @@ var _ = Describe("Client", func() { It("uses the default QUIC and TLS config if none is give", func() { client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) var dialAddrCalled bool - dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.Session, error) { + dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) { Expect(quicConf).To(Equal(defaultQuicConfig)) Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3})) dialAddrCalled = true @@ -59,7 +64,7 @@ var _ = Describe("Client", func() { It("adds the port to the hostname, if none is given", func() { client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) var dialAddrCalled bool - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { Expect(hostname).To(Equal("quic.clemente.io:443")) dialAddrCalled = true return nil, errors.New("test done") @@ -82,7 +87,7 @@ var _ = Describe("Client", func() { hostname string, tlsConfP *tls.Config, quicConfP *quic.Config, - ) (quic.Session, error) { + ) (quic.EarlySession, error) { Expect(hostname).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3})) @@ -101,7 +106,7 @@ var _ = Describe("Client", func() { tlsConf := &tls.Config{ServerName: "foo.bar"} quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} var dialerCalled bool - dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.Session, error) { + dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) { Expect(network).To(Equal("udp")) Expect(address).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal("foo.bar")) @@ -118,7 +123,7 @@ var _ = Describe("Client", func() { It("errors when dialing fails", func() { testErr := errors.New("handshake error") client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return nil, testErr } _, err := client.RoundTrip(req) @@ -130,9 +135,10 @@ var _ = Describe("Client", func() { client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) session := mockquic.NewMockEarlySession(mockCtrl) session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1) + session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1) session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1) session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return session, nil } defer GinkgoRecover() @@ -173,7 +179,7 @@ var _ = Describe("Client", func() { str = mockquic.NewMockStream(mockCtrl) sess = mockquic.NewMockEarlySession(mockCtrl) sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil } var err error @@ -186,7 +192,10 @@ var _ = Describe("Client", func() { rw := newResponseWriter(rspBuf, utils.DefaultLogger) rw.WriteHeader(418) - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + gomock.InOrder( + sess.EXPECT().HandshakeComplete().Return(handshakeCtx), + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), + ) str.EXPECT().Write(gomock.Any()).AnyTimes() str.EXPECT().Close() str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { @@ -220,7 +229,10 @@ var _ = Describe("Client", func() { BeforeEach(func() { strBuf = &bytes.Buffer{} - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + gomock.InOrder( + sess.EXPECT().HandshakeComplete().Return(handshakeCtx), + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), + ) body := &mockBody{} body.SetData([]byte("request body")) var err error @@ -301,6 +313,7 @@ var _ = Describe("Client", func() { It("cancels a request while the request is still in flight", func() { ctx, cancel := context.WithCancel(context.Background()) req := request.WithContext(ctx) + sess.EXPECT().HandshakeComplete().Return(handshakeCtx) sess.EXPECT().OpenStreamSync(ctx).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) @@ -333,6 +346,7 @@ var _ = Describe("Client", func() { ctx, cancel := context.WithCancel(context.Background()) req := request.WithContext(ctx) + sess.EXPECT().HandshakeComplete().Return(handshakeCtx) sess.EXPECT().OpenStreamSync(ctx).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) @@ -354,21 +368,8 @@ var _ = Describe("Client", func() { }) Context("gzip compression", func() { - var gzippedData []byte // a gzipped foobar - var response *http.Response - BeforeEach(func() { - var b bytes.Buffer - w := gzip.NewWriter(&b) - w.Write([]byte("foobar")) - w.Close() - gzippedData = b.Bytes() - response = &http.Response{ - StatusCode: 200, - Header: http.Header{"Content-Length": []string{"1000"}}, - } - _ = gzippedData - _ = response + sess.EXPECT().HandshakeComplete().Return(handshakeCtx) }) It("adds the gzip header to requests", func() { diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 003e17d5..cc307029 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -44,7 +44,7 @@ type RoundTripper struct { // Dial specifies an optional dial function for creating QUIC // connections for requests. // If Dial is nil, quic.DialAddr will be used. - Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error) + Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) // MaxResponseHeaderBytes specifies a limit on how many response bytes are // allowed in the server's response header. diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index c661d4c4..878a6109 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -58,9 +58,10 @@ func (m *mockBody) Close() error { var _ = Describe("RoundTripper", func() { var ( - rt *RoundTripper - req1 *http.Request - session *mockquic.MockEarlySession + rt *RoundTripper + req1 *http.Request + session *mockquic.MockEarlySession + handshakeCtx context.Context // an already canceled context ) BeforeEach(func() { @@ -68,6 +69,10 @@ var _ = Describe("RoundTripper", func() { var err error req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) Expect(err).ToNot(HaveOccurred()) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + handshakeCtx = ctx }) Context("dialing hosts", func() { @@ -76,7 +81,7 @@ var _ = Describe("RoundTripper", func() { BeforeEach(func() { session = mockquic.NewMockEarlySession(mockCtrl) origDialAddr = dialAddr - dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { + dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) { // return an error when trying to open a stream // we don't want to test all the dial logic here, just that dialing happens at all return session, nil @@ -93,6 +98,7 @@ var _ = Describe("RoundTripper", func() { req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) + session.EXPECT().HandshakeComplete().Return(handshakeCtx) session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) }) _, err = rt.RoundTrip(req) @@ -104,7 +110,7 @@ var _ = Describe("RoundTripper", func() { It("uses the quic.Config, if provided", func() { config := &quic.Config{HandshakeTimeout: time.Millisecond} var receivedConfig *quic.Config - dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { + dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) { receivedConfig = config return nil, errors.New("handshake error") } @@ -116,7 +122,7 @@ var _ = Describe("RoundTripper", func() { It("uses the custom dialer, if provided", func() { var dialed bool - dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) { + dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) { dialed = true return nil, errors.New("handshake error") } @@ -129,6 +135,7 @@ var _ = Describe("RoundTripper", func() { It("reuses existing clients", func() { closed := make(chan struct{}) testErr := errors.New("test err") + session.EXPECT().HandshakeComplete().Return(handshakeCtx) session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2) session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })