diff --git a/http3/roundtrip.go b/http3/roundtrip.go index eef93c28..0e11242a 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -17,9 +17,6 @@ import ( "github.com/quic-go/quic-go" ) -// declare this as a variable, such that we can it mock it in the tests -var quicDialer = quic.DialEarly - type roundTripCloser interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) HandshakeComplete() bool @@ -89,7 +86,7 @@ type RoundTripper struct { newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests clients map[string]*roundTripCloserWithCount - udpConn *net.UDPConn + transport *quic.Transport } // RoundTripOpt are options for the Transport.RoundTripOpt method. @@ -187,11 +184,12 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTr } dial := r.Dial if dial == nil { - if r.udpConn == nil { - r.udpConn, err = net.ListenUDP("udp", nil) + if r.transport == nil { + udpConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, false, err } + r.transport = &quic.Transport{Conn: udpConn} } dial = r.makeDialer() } @@ -240,9 +238,9 @@ func (r *RoundTripper) Close() error { } } r.clients = nil - if r.udpConn != nil { - r.udpConn.Close() - r.udpConn = nil + if r.transport != nil { + r.transport.Close() + r.transport = nil } return nil } @@ -282,7 +280,7 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf if err != nil { return nil, err } - return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg) + return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) } } diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index 8e859de5..0c219db7 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -6,13 +6,11 @@ import ( "crypto/tls" "errors" "io" - "net" "net/http" "sync/atomic" "time" "github.com/quic-go/quic-go" - mockquic "github.com/quic-go/quic-go/internal/mocks/quic" "github.com/quic-go/quic-go/internal/qerr" "github.com/golang/mock/gomock" @@ -368,66 +366,4 @@ var _ = Describe("RoundTripper", func() { Expect(len(rt.clients)).To(Equal(0)) }) }) - - Context("reusing udpconn", func() { - var originalDialer func(ctx context.Context, pconn net.PacketConn, remoteAddr net.Addr, tlsConf *tls.Config, config *quic.Config) (quic.EarlyConnection, error) - var req1, req2 *http.Request - - BeforeEach(func() { - var err error - originalDialer = quicDialer - req1, err = http.NewRequest("GET", "https://site1.com", nil) - Expect(err).ToNot(HaveOccurred()) - req2, err = http.NewRequest("GET", "https://site2.com", nil) - Expect(err).ToNot(HaveOccurred()) - Expect(req1.Host).ToNot(Equal(req2.Host)) - }) - - AfterEach(func() { - quicDialer = originalDialer - err := rt.Close() - Expect(err).ToNot(HaveOccurred()) - }) - - It("creates udpconn at first request", func() { - Expect(rt.udpConn).To(BeNil()) - rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { - cl := NewMockRoundTripCloser(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()) - cl.EXPECT().Close() - return cl, nil - } - _, err := rt.RoundTrip(req1) - Expect(err).ToNot(HaveOccurred()) - Expect(rt.udpConn).ToNot(BeNil()) - }) - - It("reuses udpconn in different hosts", func() { - Expect(rt.udpConn).To(BeNil()) - quicDialer = func(_ context.Context, pconn net.PacketConn, _ net.Addr, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { - conn := mockquic.NewMockEarlyConnection(mockCtrl) - conn.EXPECT().LocalAddr().Return(pconn.LocalAddr()) - return conn, nil - } - rt.newClient = func(hostname string, tlsConf *tls.Config, _ *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { - cl := NewMockRoundTripCloser(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *http.Request, _ RoundTripOpt) (*http.Response, error) { - header := make(http.Header) - quicConn, err := dialer(context.Background(), hostname, tlsConf, conf) - Expect(err).ToNot(HaveOccurred()) - header.Set("udpconn", quicConn.LocalAddr().String()) - return &http.Response{Header: header}, nil - }) - cl.EXPECT().Close() - return cl, nil - } - rsp1, err := rt.RoundTrip(req1) - Expect(err).ToNot(HaveOccurred()) - Expect(rsp1.Header.Get("udpconn")).ToNot(Equal("")) - rsp2, err := rt.RoundTrip(req2) - Expect(err).ToNot(HaveOccurred()) - Expect(rsp2.Header.Get("udpconn")).ToNot(Equal("")) - Expect(rsp1.Header.Get("udpconn")).To(Equal(rsp2.Header.Get("udpconn"))) - }) - }) }) diff --git a/integrationtests/self/hotswap_test.go b/integrationtests/self/hotswap_test.go index 6cd73079..eeb16441 100644 --- a/integrationtests/self/hotswap_test.go +++ b/integrationtests/self/hotswap_test.go @@ -61,6 +61,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() { mux1 *http.ServeMux mux2 *http.ServeMux client *http.Client + rt *http3.RoundTripper server1 *http3.Server server2 *http3.Server ln *listenerWrapper @@ -97,17 +98,17 @@ var _ = Describe("HTTP3 Server hotswap test", func() { }) AfterEach(func() { + rt.Close() Expect(ln.Close()).NotTo(HaveOccurred()) }) BeforeEach(func() { - client = &http.Client{ - Transport: &http3.RoundTripper{ - TLSClientConfig: getTLSClientConfig(), - DisableCompression: true, - QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), - }, + rt = &http3.RoundTripper{ + TLSClientConfig: getTLSClientConfig(), + DisableCompression: true, + QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), } + client = &http.Client{Transport: rt} }) It("hotswap works", func() { diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 29385aef..334af45d 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -40,6 +40,7 @@ var _ = Describe("HTTP tests", func() { var ( mux *http.ServeMux client *http.Client + rt *http3.RoundTripper server *http3.Server stoppedServing chan struct{} port string @@ -77,6 +78,12 @@ var _ = Describe("HTTP tests", func() { w.Write(body) // don't check the error here. Stream may be reset. }) + mux.HandleFunc("/remoteAddr", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + w.Header().Set("X-RemoteAddr", r.RemoteAddr) + w.WriteHeader(http.StatusOK) + }) + server = &http3.Server{ Handler: mux, TLSConfig: getTLSConfig(), @@ -99,18 +106,18 @@ var _ = Describe("HTTP tests", func() { }) AfterEach(func() { + rt.Close() Expect(server.Close()).NotTo(HaveOccurred()) Eventually(stoppedServing).Should(BeClosed()) }) BeforeEach(func() { - client = &http.Client{ - Transport: &http3.RoundTripper{ - TLSClientConfig: getTLSClientConfig(), - DisableCompression: true, - QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), - }, + rt = &http3.RoundTripper{ + TLSClientConfig: getTLSClientConfigWithoutServerName(), + DisableCompression: true, + QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), } + client = &http.Client{Transport: rt} }) It("downloads a hello", func() { @@ -122,6 +129,20 @@ var _ = Describe("HTTP tests", func() { Expect(string(body)).To(Equal("Hello, World!\n")) }) + It("requests to different servers with the same udpconn", func() { + resp, err := client.Get("https://localhost:" + port + "/remoteAddr") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + addr1 := resp.Header.Get("X-RemoteAddr") + Expect(addr1).ToNot(Equal("")) + resp, err = client.Get("https://127.0.0.1:" + port + "/remoteAddr") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + addr2 := resp.Header.Get("X-RemoteAddr") + Expect(addr2).ToNot(Equal("")) + Expect(addr1).To(Equal(addr2)) + }) + It("downloads concurrently", func() { group, ctx := errgroup.WithContext(context.Background()) for i := 0; i < 2; i++ { diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 10adc856..8f9436e6 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -90,10 +90,11 @@ var ( qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer enableQlog bool - version quic.VersionNumber - tlsConfig *tls.Config - tlsConfigLongChain *tls.Config - tlsClientConfig *tls.Config + version quic.VersionNumber + tlsConfig *tls.Config + tlsConfigLongChain *tls.Config + tlsClientConfig *tls.Config + tlsClientConfigWithoutServerName *tls.Config ) // read the logfile command line flag @@ -131,6 +132,10 @@ func init() { RootCAs: root, NextProtos: []string{alpn}, } + tlsClientConfigWithoutServerName = &tls.Config{ + RootCAs: root, + NextProtos: []string{alpn}, + } } var _ = BeforeSuite(func() { @@ -165,6 +170,10 @@ func getTLSClientConfig() *tls.Config { return tlsClientConfig.Clone() } +func getTLSClientConfigWithoutServerName() *tls.Config { + return tlsClientConfigWithoutServerName.Clone() +} + func getQuicConfig(conf *quic.Config) *quic.Config { if conf == nil { conf = &quic.Config{} diff --git a/integrationtests/tools/crypto.go b/integrationtests/tools/crypto.go index d0bdfb41..6bd02ee0 100644 --- a/integrationtests/tools/crypto.go +++ b/integrationtests/tools/crypto.go @@ -9,6 +9,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "math/big" + "net" "time" ) @@ -44,6 +45,7 @@ func GenerateLeafCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*x509.Cer certTempl := &x509.Certificate{ SerialNumber: big.NewInt(1), DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, NotBefore: time.Now(), NotAfter: time.Now().Add(24 * time.Hour), ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},