http3: use a single UDPConn in RoundTripper (#3720)

* http3: use a single UDPConn in RoundTripper

* update

* add tests
This commit is contained in:
Glonee 2023-03-15 09:58:26 +08:00 committed by GitHub
parent a92238b73c
commit 6d7280b7dc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 99 additions and 4 deletions

View file

@ -16,6 +16,9 @@ import (
"github.com/quic-go/quic-go"
)
// declare this as a variable, such that we can it mock it in the tests
var quicDialer = quic.DialEarlyContext
type roundTripCloser interface {
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
HandshakeComplete() bool
@ -69,7 +72,8 @@ type RoundTripper struct {
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddrEarlyContext will be used.
// If Dial is nil, a UDPConn will be created at the first request
// and will be reused for subsequent connections to other servers.
Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
@ -79,6 +83,7 @@ type RoundTripper struct {
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
clients map[string]roundTripCloser
udpConn *net.UDPConn
}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
@ -173,6 +178,16 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri
if r.newClient != nil {
newCl = r.newClient
}
dial := r.Dial
if dial == nil {
if r.udpConn == nil {
r.udpConn, err = net.ListenUDP("udp", nil)
if err != nil {
return nil, false, err
}
}
dial = r.makeDialer()
}
client, err = newCl(
hostname,
r.TLSClientConfig,
@ -184,7 +199,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri
UniStreamHijacker: r.UniStreamHijacker,
},
r.QuicConfig,
r.Dial,
dial,
)
if err != nil {
return nil, false, err
@ -205,7 +220,8 @@ func (r *RoundTripper) removeClient(hostname string) {
delete(r.clients, hostname)
}
// Close closes the QUIC connections that this RoundTripper has used
// Close closes the QUIC connections that this RoundTripper has used.
// It also closes the underlying UDPConn if it is not nil.
func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
@ -215,6 +231,10 @@ func (r *RoundTripper) Close() error {
}
}
r.clients = nil
if r.udpConn != nil {
r.udpConn.Close()
r.udpConn = nil
}
return nil
}
@ -245,3 +265,14 @@ func validMethod(method string) bool {
func isNotToken(r rune) bool {
return !httpguts.IsTokenRune(r)
}
// makeDialer makes a QUIC dialer using r.udpConn.
func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
return quicDialer(ctx, r.udpConn, udpAddr, addr, tlsCfg, cfg)
}
}

View file

@ -6,10 +6,12 @@ import (
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"time"
"github.com/quic-go/quic-go"
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/golang/mock/gomock"
@ -75,7 +77,7 @@ var _ = Describe("RoundTripper", func() {
It("uses the quic.Config, if provided", func() {
config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
var receivedConfig *quic.Config
dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
receivedConfig = config
return nil, errors.New("handshake error")
}
@ -320,4 +322,66 @@ var _ = Describe("RoundTripper", func() {
Expect(len(rt.clients)).To(BeZero())
})
})
Context("reusing udpconn", func() {
var originalDialer func(ctx context.Context, pconn net.PacketConn, remoteAddr net.Addr, host string, tlsConf *tls.Config, config *quic.Config) (quic.EarlyConnection, error)
var req1, req2 *http.Request
BeforeEach(func() {
var err error
originalDialer = quicDialer
req1, err = http.NewRequest("GET", "https://site1.com", nil)
Expect(err).ToNot(HaveOccurred())
req2, err = http.NewRequest("GET", "https://site2.com", nil)
Expect(err).ToNot(HaveOccurred())
Expect(req1.Host).ToNot(Equal(req2.Host))
})
AfterEach(func() {
quicDialer = originalDialer
err := rt.Close()
Expect(err).ToNot(HaveOccurred())
})
It("creates udpconn at first request", func() {
Expect(rt.udpConn).To(BeNil())
rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any())
cl.EXPECT().Close()
return cl, nil
}
_, err := rt.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rt.udpConn).ToNot(BeNil())
})
It("reuses udpconn in different hosts", func() {
Expect(rt.udpConn).To(BeNil())
quicDialer = func(_ context.Context, pconn net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
conn := mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().LocalAddr().Return(pconn.LocalAddr())
return conn, nil
}
rt.newClient = func(hostname string, tlsConf *tls.Config, _ *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *http.Request, _ RoundTripOpt) (*http.Response, error) {
header := make(http.Header)
quicConn, err := dialer(context.Background(), hostname, tlsConf, conf)
Expect(err).ToNot(HaveOccurred())
header.Set("udpconn", quicConn.LocalAddr().String())
return &http.Response{Header: header}, nil
})
cl.EXPECT().Close()
return cl, nil
}
rsp1, err := rt.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rsp1.Header.Get("udpconn")).ToNot(Equal(""))
rsp2, err := rt.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp2.Header.Get("udpconn")).ToNot(Equal(""))
Expect(rsp1.Header.Get("udpconn")).To(Equal(rsp2.Header.Get("udpconn")))
})
})
})