http3: correctly use the quic.Transport (#3869)

* use quic.Transport in http3

* add intergrationtests to dial server with different server names

* update test
This commit is contained in:
Glonee 2023-06-01 14:31:20 +08:00 committed by GitHub
parent 21549fcb4a
commit c96fbd2e4a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 57 additions and 90 deletions

View file

@ -17,9 +17,6 @@ import (
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
) )
// declare this as a variable, such that we can it mock it in the tests
var quicDialer = quic.DialEarly
type roundTripCloser interface { type roundTripCloser interface {
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
HandshakeComplete() bool HandshakeComplete() bool
@ -89,7 +86,7 @@ type RoundTripper struct {
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
clients map[string]*roundTripCloserWithCount clients map[string]*roundTripCloserWithCount
udpConn *net.UDPConn transport *quic.Transport
} }
// RoundTripOpt are options for the Transport.RoundTripOpt method. // RoundTripOpt are options for the Transport.RoundTripOpt method.
@ -187,11 +184,12 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTr
} }
dial := r.Dial dial := r.Dial
if dial == nil { if dial == nil {
if r.udpConn == nil { if r.transport == nil {
r.udpConn, err = net.ListenUDP("udp", nil) udpConn, err := net.ListenUDP("udp", nil)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
r.transport = &quic.Transport{Conn: udpConn}
} }
dial = r.makeDialer() dial = r.makeDialer()
} }
@ -240,9 +238,9 @@ func (r *RoundTripper) Close() error {
} }
} }
r.clients = nil r.clients = nil
if r.udpConn != nil { if r.transport != nil {
r.udpConn.Close() r.transport.Close()
r.udpConn = nil r.transport = nil
} }
return nil return nil
} }
@ -282,7 +280,7 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf
if err != nil { if err != nil {
return nil, err return nil, err
} }
return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg) return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
} }
} }

View file

@ -6,13 +6,11 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"io" "io"
"net"
"net/http" "net/http"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
"github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/qerr"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -368,66 +366,4 @@ var _ = Describe("RoundTripper", func() {
Expect(len(rt.clients)).To(Equal(0)) Expect(len(rt.clients)).To(Equal(0))
}) })
}) })
Context("reusing udpconn", func() {
var originalDialer func(ctx context.Context, pconn net.PacketConn, remoteAddr net.Addr, tlsConf *tls.Config, config *quic.Config) (quic.EarlyConnection, error)
var req1, req2 *http.Request
BeforeEach(func() {
var err error
originalDialer = quicDialer
req1, err = http.NewRequest("GET", "https://site1.com", nil)
Expect(err).ToNot(HaveOccurred())
req2, err = http.NewRequest("GET", "https://site2.com", nil)
Expect(err).ToNot(HaveOccurred())
Expect(req1.Host).ToNot(Equal(req2.Host))
})
AfterEach(func() {
quicDialer = originalDialer
err := rt.Close()
Expect(err).ToNot(HaveOccurred())
})
It("creates udpconn at first request", func() {
Expect(rt.udpConn).To(BeNil())
rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any())
cl.EXPECT().Close()
return cl, nil
}
_, err := rt.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rt.udpConn).ToNot(BeNil())
})
It("reuses udpconn in different hosts", func() {
Expect(rt.udpConn).To(BeNil())
quicDialer = func(_ context.Context, pconn net.PacketConn, _ net.Addr, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
conn := mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().LocalAddr().Return(pconn.LocalAddr())
return conn, nil
}
rt.newClient = func(hostname string, tlsConf *tls.Config, _ *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *http.Request, _ RoundTripOpt) (*http.Response, error) {
header := make(http.Header)
quicConn, err := dialer(context.Background(), hostname, tlsConf, conf)
Expect(err).ToNot(HaveOccurred())
header.Set("udpconn", quicConn.LocalAddr().String())
return &http.Response{Header: header}, nil
})
cl.EXPECT().Close()
return cl, nil
}
rsp1, err := rt.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rsp1.Header.Get("udpconn")).ToNot(Equal(""))
rsp2, err := rt.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp2.Header.Get("udpconn")).ToNot(Equal(""))
Expect(rsp1.Header.Get("udpconn")).To(Equal(rsp2.Header.Get("udpconn")))
})
})
}) })

View file

@ -61,6 +61,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
mux1 *http.ServeMux mux1 *http.ServeMux
mux2 *http.ServeMux mux2 *http.ServeMux
client *http.Client client *http.Client
rt *http3.RoundTripper
server1 *http3.Server server1 *http3.Server
server2 *http3.Server server2 *http3.Server
ln *listenerWrapper ln *listenerWrapper
@ -97,17 +98,17 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
}) })
AfterEach(func() { AfterEach(func() {
rt.Close()
Expect(ln.Close()).NotTo(HaveOccurred()) Expect(ln.Close()).NotTo(HaveOccurred())
}) })
BeforeEach(func() { BeforeEach(func() {
client = &http.Client{ rt = &http3.RoundTripper{
Transport: &http3.RoundTripper{ TLSClientConfig: getTLSClientConfig(),
TLSClientConfig: getTLSClientConfig(), DisableCompression: true,
DisableCompression: true, QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
},
} }
client = &http.Client{Transport: rt}
}) })
It("hotswap works", func() { It("hotswap works", func() {

View file

@ -40,6 +40,7 @@ var _ = Describe("HTTP tests", func() {
var ( var (
mux *http.ServeMux mux *http.ServeMux
client *http.Client client *http.Client
rt *http3.RoundTripper
server *http3.Server server *http3.Server
stoppedServing chan struct{} stoppedServing chan struct{}
port string port string
@ -77,6 +78,12 @@ var _ = Describe("HTTP tests", func() {
w.Write(body) // don't check the error here. Stream may be reset. w.Write(body) // don't check the error here. Stream may be reset.
}) })
mux.HandleFunc("/remoteAddr", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
w.Header().Set("X-RemoteAddr", r.RemoteAddr)
w.WriteHeader(http.StatusOK)
})
server = &http3.Server{ server = &http3.Server{
Handler: mux, Handler: mux,
TLSConfig: getTLSConfig(), TLSConfig: getTLSConfig(),
@ -99,18 +106,18 @@ var _ = Describe("HTTP tests", func() {
}) })
AfterEach(func() { AfterEach(func() {
rt.Close()
Expect(server.Close()).NotTo(HaveOccurred()) Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed()) Eventually(stoppedServing).Should(BeClosed())
}) })
BeforeEach(func() { BeforeEach(func() {
client = &http.Client{ rt = &http3.RoundTripper{
Transport: &http3.RoundTripper{ TLSClientConfig: getTLSClientConfigWithoutServerName(),
TLSClientConfig: getTLSClientConfig(), DisableCompression: true,
DisableCompression: true, QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
},
} }
client = &http.Client{Transport: rt}
}) })
It("downloads a hello", func() { It("downloads a hello", func() {
@ -122,6 +129,20 @@ var _ = Describe("HTTP tests", func() {
Expect(string(body)).To(Equal("Hello, World!\n")) Expect(string(body)).To(Equal("Hello, World!\n"))
}) })
It("requests to different servers with the same udpconn", func() {
resp, err := client.Get("https://localhost:" + port + "/remoteAddr")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
addr1 := resp.Header.Get("X-RemoteAddr")
Expect(addr1).ToNot(Equal(""))
resp, err = client.Get("https://127.0.0.1:" + port + "/remoteAddr")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
addr2 := resp.Header.Get("X-RemoteAddr")
Expect(addr2).ToNot(Equal(""))
Expect(addr1).To(Equal(addr2))
})
It("downloads concurrently", func() { It("downloads concurrently", func() {
group, ctx := errgroup.WithContext(context.Background()) group, ctx := errgroup.WithContext(context.Background())
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {

View file

@ -90,10 +90,11 @@ var (
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
enableQlog bool enableQlog bool
version quic.VersionNumber version quic.VersionNumber
tlsConfig *tls.Config tlsConfig *tls.Config
tlsConfigLongChain *tls.Config tlsConfigLongChain *tls.Config
tlsClientConfig *tls.Config tlsClientConfig *tls.Config
tlsClientConfigWithoutServerName *tls.Config
) )
// read the logfile command line flag // read the logfile command line flag
@ -131,6 +132,10 @@ func init() {
RootCAs: root, RootCAs: root,
NextProtos: []string{alpn}, NextProtos: []string{alpn},
} }
tlsClientConfigWithoutServerName = &tls.Config{
RootCAs: root,
NextProtos: []string{alpn},
}
} }
var _ = BeforeSuite(func() { var _ = BeforeSuite(func() {
@ -165,6 +170,10 @@ func getTLSClientConfig() *tls.Config {
return tlsClientConfig.Clone() return tlsClientConfig.Clone()
} }
func getTLSClientConfigWithoutServerName() *tls.Config {
return tlsClientConfigWithoutServerName.Clone()
}
func getQuicConfig(conf *quic.Config) *quic.Config { func getQuicConfig(conf *quic.Config) *quic.Config {
if conf == nil { if conf == nil {
conf = &quic.Config{} conf = &quic.Config{}

View file

@ -9,6 +9,7 @@ import (
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"math/big" "math/big"
"net"
"time" "time"
) )
@ -44,6 +45,7 @@ func GenerateLeafCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*x509.Cer
certTempl := &x509.Certificate{ certTempl := &x509.Certificate{
SerialNumber: big.NewInt(1), SerialNumber: big.NewInt(1),
DNSNames: []string{"localhost"}, DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)},
NotBefore: time.Now(), NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour), NotAfter: time.Now().Add(24 * time.Hour),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},