diff --git a/http3/client.go b/http3/client.go index 8ca78bb7..9bacd679 100644 --- a/http3/client.go +++ b/http3/client.go @@ -11,6 +11,7 @@ import ( "sync" "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/marten-seemann/qpack" @@ -28,6 +29,7 @@ const ( var defaultQuicConfig = &quic.Config{ MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams KeepAlive: true, + Versions: []protocol.VersionNumber{protocol.VersionTLS}, } var dialAddr = quic.DialAddrEarly @@ -63,19 +65,26 @@ func newClient( opts *roundTripperOpts, quicConfig *quic.Config, dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error), -) *client { +) (*client, error) { + if quicConfig == nil { + quicConfig = defaultQuicConfig + } else if len(quicConfig.Versions) == 0 { + quicConfig = quicConfig.Clone() + quicConfig.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} + } + if len(quicConfig.Versions) != 1 { + return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") + } + quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams + logger := utils.DefaultLogger.WithPrefix("h3 client") + if tlsConf == nil { tlsConf = &tls.Config{} } else { tlsConf = tlsConf.Clone() } // Replace existing ALPNs by H3 - tlsConf.NextProtos = []string{nextProtoH3Draft29} - if quicConfig == nil { - quicConfig = defaultQuicConfig - } - quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams - logger := utils.DefaultLogger.WithPrefix("h3 client") + tlsConf.NextProtos = []string{versionToALPN(quicConfig.Versions[0])} return &client{ hostname: authorityAddr("https", hostname), @@ -86,7 +95,7 @@ func newClient( opts: opts, dialer: dialer, logger: logger, - } + }, nil } func (c *client) dial() error { diff --git a/http3/client_test.go b/http3/client_test.go index fd2d998a..c107a9cb 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -35,10 +35,11 @@ var _ = Describe("Client", func() { BeforeEach(func() { origDialAddr = dialAddr hostname := "quic.clemente.io:1337" - client = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) + var err error + client, err = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) + Expect(err).ToNot(HaveOccurred()) Expect(client.hostname).To(Equal(hostname)) - var err error req, err = http.NewRequest("GET", "https://localhost:1337", nil) Expect(err).ToNot(HaveOccurred()) @@ -51,12 +52,22 @@ var _ = Describe("Client", func() { dialAddr = origDialAddr }) + It("rejects quic.Configs that allow multiple QUIC versions", func() { + qconf := &quic.Config{ + Versions: []quic.VersionNumber{protocol.VersionDraft29, protocol.VersionDraft32}, + } + _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) + Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) + }) + It("uses the default QUIC and TLS config if none is give", func() { - client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) + client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) { Expect(quicConf).To(Equal(defaultQuicConfig)) Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3Draft29})) + Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.VersionTLS})) dialAddrCalled = true return nil, errors.New("test done") } @@ -65,7 +76,8 @@ var _ = Describe("Client", func() { }) It("adds the port to the hostname, if none is given", func() { - client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) + client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { Expect(hostname).To(Equal("quic.clemente.io:443")) @@ -84,7 +96,8 @@ var _ = Describe("Client", func() { NextProtos: []string{"proto foo", "proto bar"}, } quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond} - client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) + client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) + Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool dialAddr = func( hostname string, @@ -117,25 +130,28 @@ var _ = Describe("Client", func() { dialerCalled = true return nil, testErr } - client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) - _, err := client.RoundTrip(req) + client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) + Expect(err).ToNot(HaveOccurred()) + _, err = client.RoundTrip(req) Expect(err).To(MatchError(testErr)) Expect(dialerCalled).To(BeTrue()) }) It("errors when dialing fails", func() { testErr := errors.New("handshake error") - client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) + client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return nil, testErr } - _, err := client.RoundTrip(req) + _, err = client.RoundTrip(req) Expect(err).To(MatchError(testErr)) }) It("errors if it can't open a stream", func() { testErr := errors.New("stream open error") - client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) + client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) session := mockquic.NewMockEarlySession(mockCtrl) session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1) session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1) @@ -145,14 +161,14 @@ var _ = Describe("Client", func() { return session, nil } defer GinkgoRecover() - _, err := client.RoundTrip(req) + _, err = client.RoundTrip(req) Expect(err).To(MatchError(testErr)) }) It("closes correctly if session was not created", func() { - client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) - err := client.Close() + client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) + Expect(client.Close()).To(Succeed()) }) Context("Doing requests", func() { @@ -437,7 +453,8 @@ var _ = Describe("Client", func() { }) It("doesn't add gzip if the header disable it", func() { - client = newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) + client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) + Expect(err).ToNot(HaveOccurred()) sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { @@ -448,7 +465,7 @@ var _ = Describe("Client", func() { str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors ) str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) - _, err := client.RoundTrip(request) + _, err = client.RoundTrip(request) Expect(err).To(MatchError("test done")) hfs := decodeHeader(buf) Expect(hfs).ToNot(HaveKey("accept-encoding")) diff --git a/http3/roundtrip.go b/http3/roundtrip.go index cc307029..3b074d32 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -130,7 +130,8 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr if onlyCached { return nil, ErrNoCachedConn } - client = newClient( + var err error + client, err = newClient( hostname, r.TLSClientConfig, &roundTripperOpts{ @@ -140,6 +141,9 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr r.QuicConfig, r.Dial, ) + if err != nil { + return nil, err + } r.clients[hostname] = client } return client, nil