mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 05:07:36 +03:00
http3: correctly use the quic.Transport (#3869)
* use quic.Transport in http3 * add intergrationtests to dial server with different server names * update test
This commit is contained in:
parent
21549fcb4a
commit
c96fbd2e4a
6 changed files with 57 additions and 90 deletions
|
@ -17,9 +17,6 @@ import (
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
// declare this as a variable, such that we can it mock it in the tests
|
|
||||||
var quicDialer = quic.DialEarly
|
|
||||||
|
|
||||||
type roundTripCloser interface {
|
type roundTripCloser interface {
|
||||||
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
|
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
|
||||||
HandshakeComplete() bool
|
HandshakeComplete() bool
|
||||||
|
@ -89,7 +86,7 @@ type RoundTripper struct {
|
||||||
|
|
||||||
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
|
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
|
||||||
clients map[string]*roundTripCloserWithCount
|
clients map[string]*roundTripCloserWithCount
|
||||||
udpConn *net.UDPConn
|
transport *quic.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||||
|
@ -187,11 +184,12 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTr
|
||||||
}
|
}
|
||||||
dial := r.Dial
|
dial := r.Dial
|
||||||
if dial == nil {
|
if dial == nil {
|
||||||
if r.udpConn == nil {
|
if r.transport == nil {
|
||||||
r.udpConn, err = net.ListenUDP("udp", nil)
|
udpConn, err := net.ListenUDP("udp", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
r.transport = &quic.Transport{Conn: udpConn}
|
||||||
}
|
}
|
||||||
dial = r.makeDialer()
|
dial = r.makeDialer()
|
||||||
}
|
}
|
||||||
|
@ -240,9 +238,9 @@ func (r *RoundTripper) Close() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.clients = nil
|
r.clients = nil
|
||||||
if r.udpConn != nil {
|
if r.transport != nil {
|
||||||
r.udpConn.Close()
|
r.transport.Close()
|
||||||
r.udpConn = nil
|
r.transport = nil
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -282,7 +280,7 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg)
|
return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,13 +6,11 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
|
|
||||||
"github.com/quic-go/quic-go/internal/qerr"
|
"github.com/quic-go/quic-go/internal/qerr"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
@ -368,66 +366,4 @@ var _ = Describe("RoundTripper", func() {
|
||||||
Expect(len(rt.clients)).To(Equal(0))
|
Expect(len(rt.clients)).To(Equal(0))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("reusing udpconn", func() {
|
|
||||||
var originalDialer func(ctx context.Context, pconn net.PacketConn, remoteAddr net.Addr, tlsConf *tls.Config, config *quic.Config) (quic.EarlyConnection, error)
|
|
||||||
var req1, req2 *http.Request
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
originalDialer = quicDialer
|
|
||||||
req1, err = http.NewRequest("GET", "https://site1.com", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
req2, err = http.NewRequest("GET", "https://site2.com", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(req1.Host).ToNot(Equal(req2.Host))
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
quicDialer = originalDialer
|
|
||||||
err := rt.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("creates udpconn at first request", func() {
|
|
||||||
Expect(rt.udpConn).To(BeNil())
|
|
||||||
rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
|
|
||||||
cl := NewMockRoundTripCloser(mockCtrl)
|
|
||||||
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any())
|
|
||||||
cl.EXPECT().Close()
|
|
||||||
return cl, nil
|
|
||||||
}
|
|
||||||
_, err := rt.RoundTrip(req1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rt.udpConn).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reuses udpconn in different hosts", func() {
|
|
||||||
Expect(rt.udpConn).To(BeNil())
|
|
||||||
quicDialer = func(_ context.Context, pconn net.PacketConn, _ net.Addr, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
|
|
||||||
conn := mockquic.NewMockEarlyConnection(mockCtrl)
|
|
||||||
conn.EXPECT().LocalAddr().Return(pconn.LocalAddr())
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
rt.newClient = func(hostname string, tlsConf *tls.Config, _ *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
|
|
||||||
cl := NewMockRoundTripCloser(mockCtrl)
|
|
||||||
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
|
||||||
header := make(http.Header)
|
|
||||||
quicConn, err := dialer(context.Background(), hostname, tlsConf, conf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
header.Set("udpconn", quicConn.LocalAddr().String())
|
|
||||||
return &http.Response{Header: header}, nil
|
|
||||||
})
|
|
||||||
cl.EXPECT().Close()
|
|
||||||
return cl, nil
|
|
||||||
}
|
|
||||||
rsp1, err := rt.RoundTrip(req1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp1.Header.Get("udpconn")).ToNot(Equal(""))
|
|
||||||
rsp2, err := rt.RoundTrip(req2)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp2.Header.Get("udpconn")).ToNot(Equal(""))
|
|
||||||
Expect(rsp1.Header.Get("udpconn")).To(Equal(rsp2.Header.Get("udpconn")))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
|
@ -61,6 +61,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
|
||||||
mux1 *http.ServeMux
|
mux1 *http.ServeMux
|
||||||
mux2 *http.ServeMux
|
mux2 *http.ServeMux
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
rt *http3.RoundTripper
|
||||||
server1 *http3.Server
|
server1 *http3.Server
|
||||||
server2 *http3.Server
|
server2 *http3.Server
|
||||||
ln *listenerWrapper
|
ln *listenerWrapper
|
||||||
|
@ -97,17 +98,17 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
|
rt.Close()
|
||||||
Expect(ln.Close()).NotTo(HaveOccurred())
|
Expect(ln.Close()).NotTo(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
client = &http.Client{
|
rt = &http3.RoundTripper{
|
||||||
Transport: &http3.RoundTripper{
|
TLSClientConfig: getTLSClientConfig(),
|
||||||
TLSClientConfig: getTLSClientConfig(),
|
DisableCompression: true,
|
||||||
DisableCompression: true,
|
QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
|
||||||
QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
client = &http.Client{Transport: rt}
|
||||||
})
|
})
|
||||||
|
|
||||||
It("hotswap works", func() {
|
It("hotswap works", func() {
|
||||||
|
|
|
@ -40,6 +40,7 @@ var _ = Describe("HTTP tests", func() {
|
||||||
var (
|
var (
|
||||||
mux *http.ServeMux
|
mux *http.ServeMux
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
rt *http3.RoundTripper
|
||||||
server *http3.Server
|
server *http3.Server
|
||||||
stoppedServing chan struct{}
|
stoppedServing chan struct{}
|
||||||
port string
|
port string
|
||||||
|
@ -77,6 +78,12 @@ var _ = Describe("HTTP tests", func() {
|
||||||
w.Write(body) // don't check the error here. Stream may be reset.
|
w.Write(body) // don't check the error here. Stream may be reset.
|
||||||
})
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/remoteAddr", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
w.Header().Set("X-RemoteAddr", r.RemoteAddr)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
server = &http3.Server{
|
server = &http3.Server{
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
TLSConfig: getTLSConfig(),
|
TLSConfig: getTLSConfig(),
|
||||||
|
@ -99,18 +106,18 @@ var _ = Describe("HTTP tests", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
|
rt.Close()
|
||||||
Expect(server.Close()).NotTo(HaveOccurred())
|
Expect(server.Close()).NotTo(HaveOccurred())
|
||||||
Eventually(stoppedServing).Should(BeClosed())
|
Eventually(stoppedServing).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
client = &http.Client{
|
rt = &http3.RoundTripper{
|
||||||
Transport: &http3.RoundTripper{
|
TLSClientConfig: getTLSClientConfigWithoutServerName(),
|
||||||
TLSClientConfig: getTLSClientConfig(),
|
DisableCompression: true,
|
||||||
DisableCompression: true,
|
QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
|
||||||
QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
client = &http.Client{Transport: rt}
|
||||||
})
|
})
|
||||||
|
|
||||||
It("downloads a hello", func() {
|
It("downloads a hello", func() {
|
||||||
|
@ -122,6 +129,20 @@ var _ = Describe("HTTP tests", func() {
|
||||||
Expect(string(body)).To(Equal("Hello, World!\n"))
|
Expect(string(body)).To(Equal("Hello, World!\n"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("requests to different servers with the same udpconn", func() {
|
||||||
|
resp, err := client.Get("https://localhost:" + port + "/remoteAddr")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(200))
|
||||||
|
addr1 := resp.Header.Get("X-RemoteAddr")
|
||||||
|
Expect(addr1).ToNot(Equal(""))
|
||||||
|
resp, err = client.Get("https://127.0.0.1:" + port + "/remoteAddr")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(200))
|
||||||
|
addr2 := resp.Header.Get("X-RemoteAddr")
|
||||||
|
Expect(addr2).ToNot(Equal(""))
|
||||||
|
Expect(addr1).To(Equal(addr2))
|
||||||
|
})
|
||||||
|
|
||||||
It("downloads concurrently", func() {
|
It("downloads concurrently", func() {
|
||||||
group, ctx := errgroup.WithContext(context.Background())
|
group, ctx := errgroup.WithContext(context.Background())
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
|
|
|
@ -90,10 +90,11 @@ var (
|
||||||
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
|
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
|
||||||
enableQlog bool
|
enableQlog bool
|
||||||
|
|
||||||
version quic.VersionNumber
|
version quic.VersionNumber
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
tlsConfigLongChain *tls.Config
|
tlsConfigLongChain *tls.Config
|
||||||
tlsClientConfig *tls.Config
|
tlsClientConfig *tls.Config
|
||||||
|
tlsClientConfigWithoutServerName *tls.Config
|
||||||
)
|
)
|
||||||
|
|
||||||
// read the logfile command line flag
|
// read the logfile command line flag
|
||||||
|
@ -131,6 +132,10 @@ func init() {
|
||||||
RootCAs: root,
|
RootCAs: root,
|
||||||
NextProtos: []string{alpn},
|
NextProtos: []string{alpn},
|
||||||
}
|
}
|
||||||
|
tlsClientConfigWithoutServerName = &tls.Config{
|
||||||
|
RootCAs: root,
|
||||||
|
NextProtos: []string{alpn},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ = BeforeSuite(func() {
|
var _ = BeforeSuite(func() {
|
||||||
|
@ -165,6 +170,10 @@ func getTLSClientConfig() *tls.Config {
|
||||||
return tlsClientConfig.Clone()
|
return tlsClientConfig.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTLSClientConfigWithoutServerName() *tls.Config {
|
||||||
|
return tlsClientConfigWithoutServerName.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
func getQuicConfig(conf *quic.Config) *quic.Config {
|
func getQuicConfig(conf *quic.Config) *quic.Config {
|
||||||
if conf == nil {
|
if conf == nil {
|
||||||
conf = &quic.Config{}
|
conf = &quic.Config{}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,6 +45,7 @@ func GenerateLeafCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*x509.Cer
|
||||||
certTempl := &x509.Certificate{
|
certTempl := &x509.Certificate{
|
||||||
SerialNumber: big.NewInt(1),
|
SerialNumber: big.NewInt(1),
|
||||||
DNSNames: []string{"localhost"},
|
DNSNames: []string{"localhost"},
|
||||||
|
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)},
|
||||||
NotBefore: time.Now(),
|
NotBefore: time.Now(),
|
||||||
NotAfter: time.Now().Add(24 * time.Hour),
|
NotAfter: time.Now().Add(24 * time.Hour),
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue