diff --git a/http3/roundtrip.go b/http3/roundtrip.go index d9812abb..ff813536 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -16,6 +16,9 @@ import ( "github.com/quic-go/quic-go" ) +// declare this as a variable, such that we can it mock it in the tests +var quicDialer = quic.DialEarlyContext + type roundTripCloser interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) HandshakeComplete() bool @@ -69,7 +72,8 @@ type RoundTripper struct { // Dial specifies an optional dial function for creating QUIC // connections for requests. - // If Dial is nil, quic.DialAddrEarlyContext will be used. + // If Dial is nil, a UDPConn will be created at the first request + // and will be reused for subsequent connections to other servers. Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) // MaxResponseHeaderBytes specifies a limit on how many response bytes are @@ -79,6 +83,7 @@ type RoundTripper struct { newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests clients map[string]roundTripCloser + udpConn *net.UDPConn } // RoundTripOpt are options for the Transport.RoundTripOpt method. @@ -173,6 +178,16 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri if r.newClient != nil { newCl = r.newClient } + dial := r.Dial + if dial == nil { + if r.udpConn == nil { + r.udpConn, err = net.ListenUDP("udp", nil) + if err != nil { + return nil, false, err + } + } + dial = r.makeDialer() + } client, err = newCl( hostname, r.TLSClientConfig, @@ -184,7 +199,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri UniStreamHijacker: r.UniStreamHijacker, }, r.QuicConfig, - r.Dial, + dial, ) if err != nil { return nil, false, err @@ -205,7 +220,8 @@ func (r *RoundTripper) removeClient(hostname string) { delete(r.clients, hostname) } -// Close closes the QUIC connections that this RoundTripper has used +// Close closes the QUIC connections that this RoundTripper has used. +// It also closes the underlying UDPConn if it is not nil. func (r *RoundTripper) Close() error { r.mutex.Lock() defer r.mutex.Unlock() @@ -215,6 +231,10 @@ func (r *RoundTripper) Close() error { } } r.clients = nil + if r.udpConn != nil { + r.udpConn.Close() + r.udpConn = nil + } return nil } @@ -245,3 +265,14 @@ func validMethod(method string) bool { func isNotToken(r rune) bool { return !httpguts.IsTokenRune(r) } + +// makeDialer makes a QUIC dialer using r.udpConn. +func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return quicDialer(ctx, r.udpConn, udpAddr, addr, tlsCfg, cfg) + } +} diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index b03eef60..26c46ec2 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -6,10 +6,12 @@ import ( "crypto/tls" "errors" "io" + "net" "net/http" "time" "github.com/quic-go/quic-go" + mockquic "github.com/quic-go/quic-go/internal/mocks/quic" "github.com/quic-go/quic-go/internal/qerr" "github.com/golang/mock/gomock" @@ -75,7 +77,7 @@ var _ = Describe("RoundTripper", func() { It("uses the quic.Config, if provided", func() { config := &quic.Config{HandshakeIdleTimeout: time.Millisecond} var receivedConfig *quic.Config - dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { + rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { receivedConfig = config return nil, errors.New("handshake error") } @@ -320,4 +322,66 @@ var _ = Describe("RoundTripper", func() { Expect(len(rt.clients)).To(BeZero()) }) }) + + Context("reusing udpconn", func() { + var originalDialer func(ctx context.Context, pconn net.PacketConn, remoteAddr net.Addr, host string, tlsConf *tls.Config, config *quic.Config) (quic.EarlyConnection, error) + var req1, req2 *http.Request + + BeforeEach(func() { + var err error + originalDialer = quicDialer + req1, err = http.NewRequest("GET", "https://site1.com", nil) + Expect(err).ToNot(HaveOccurred()) + req2, err = http.NewRequest("GET", "https://site2.com", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(req1.Host).ToNot(Equal(req2.Host)) + }) + + AfterEach(func() { + quicDialer = originalDialer + err := rt.Close() + Expect(err).ToNot(HaveOccurred()) + }) + + It("creates udpconn at first request", func() { + Expect(rt.udpConn).To(BeNil()) + rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { + cl := NewMockRoundTripCloser(mockCtrl) + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()) + cl.EXPECT().Close() + return cl, nil + } + _, err := rt.RoundTrip(req1) + Expect(err).ToNot(HaveOccurred()) + Expect(rt.udpConn).ToNot(BeNil()) + }) + + It("reuses udpconn in different hosts", func() { + Expect(rt.udpConn).To(BeNil()) + quicDialer = func(_ context.Context, pconn net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { + conn := mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().LocalAddr().Return(pconn.LocalAddr()) + return conn, nil + } + rt.newClient = func(hostname string, tlsConf *tls.Config, _ *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { + cl := NewMockRoundTripCloser(mockCtrl) + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *http.Request, _ RoundTripOpt) (*http.Response, error) { + header := make(http.Header) + quicConn, err := dialer(context.Background(), hostname, tlsConf, conf) + Expect(err).ToNot(HaveOccurred()) + header.Set("udpconn", quicConn.LocalAddr().String()) + return &http.Response{Header: header}, nil + }) + cl.EXPECT().Close() + return cl, nil + } + rsp1, err := rt.RoundTrip(req1) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp1.Header.Get("udpconn")).ToNot(Equal("")) + rsp2, err := rt.RoundTrip(req2) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp2.Header.Get("udpconn")).ToNot(Equal("")) + Expect(rsp1.Header.Get("udpconn")).To(Equal(rsp2.Header.Get("udpconn"))) + }) + }) })