mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
http3: correctly use the quic.Transport (#3869)
* use quic.Transport in http3 * add intergrationtests to dial server with different server names * update test
This commit is contained in:
parent
21549fcb4a
commit
c96fbd2e4a
6 changed files with 57 additions and 90 deletions
|
@ -17,9 +17,6 @@ import (
|
|||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// declare this as a variable, such that we can it mock it in the tests
|
||||
var quicDialer = quic.DialEarly
|
||||
|
||||
type roundTripCloser interface {
|
||||
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
|
||||
HandshakeComplete() bool
|
||||
|
@ -89,7 +86,7 @@ type RoundTripper struct {
|
|||
|
||||
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
|
||||
clients map[string]*roundTripCloserWithCount
|
||||
udpConn *net.UDPConn
|
||||
transport *quic.Transport
|
||||
}
|
||||
|
||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||
|
@ -187,11 +184,12 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTr
|
|||
}
|
||||
dial := r.Dial
|
||||
if dial == nil {
|
||||
if r.udpConn == nil {
|
||||
r.udpConn, err = net.ListenUDP("udp", nil)
|
||||
if r.transport == nil {
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
r.transport = &quic.Transport{Conn: udpConn}
|
||||
}
|
||||
dial = r.makeDialer()
|
||||
}
|
||||
|
@ -240,9 +238,9 @@ func (r *RoundTripper) Close() error {
|
|||
}
|
||||
}
|
||||
r.clients = nil
|
||||
if r.udpConn != nil {
|
||||
r.udpConn.Close()
|
||||
r.udpConn = nil
|
||||
if r.transport != nil {
|
||||
r.transport.Close()
|
||||
r.transport = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -282,7 +280,7 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg)
|
||||
return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,13 +6,11 @@ import (
|
|||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
@ -368,66 +366,4 @@ var _ = Describe("RoundTripper", func() {
|
|||
Expect(len(rt.clients)).To(Equal(0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("reusing udpconn", func() {
|
||||
var originalDialer func(ctx context.Context, pconn net.PacketConn, remoteAddr net.Addr, tlsConf *tls.Config, config *quic.Config) (quic.EarlyConnection, error)
|
||||
var req1, req2 *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
originalDialer = quicDialer
|
||||
req1, err = http.NewRequest("GET", "https://site1.com", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req2, err = http.NewRequest("GET", "https://site2.com", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(req1.Host).ToNot(Equal(req2.Host))
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
quicDialer = originalDialer
|
||||
err := rt.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("creates udpconn at first request", func() {
|
||||
Expect(rt.udpConn).To(BeNil())
|
||||
rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
|
||||
cl := NewMockRoundTripCloser(mockCtrl)
|
||||
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any())
|
||||
cl.EXPECT().Close()
|
||||
return cl, nil
|
||||
}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rt.udpConn).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("reuses udpconn in different hosts", func() {
|
||||
Expect(rt.udpConn).To(BeNil())
|
||||
quicDialer = func(_ context.Context, pconn net.PacketConn, _ net.Addr, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
|
||||
conn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
conn.EXPECT().LocalAddr().Return(pconn.LocalAddr())
|
||||
return conn, nil
|
||||
}
|
||||
rt.newClient = func(hostname string, tlsConf *tls.Config, _ *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
|
||||
cl := NewMockRoundTripCloser(mockCtrl)
|
||||
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
||||
header := make(http.Header)
|
||||
quicConn, err := dialer(context.Background(), hostname, tlsConf, conf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
header.Set("udpconn", quicConn.LocalAddr().String())
|
||||
return &http.Response{Header: header}, nil
|
||||
})
|
||||
cl.EXPECT().Close()
|
||||
return cl, nil
|
||||
}
|
||||
rsp1, err := rt.RoundTrip(req1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp1.Header.Get("udpconn")).ToNot(Equal(""))
|
||||
rsp2, err := rt.RoundTrip(req2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp2.Header.Get("udpconn")).ToNot(Equal(""))
|
||||
Expect(rsp1.Header.Get("udpconn")).To(Equal(rsp2.Header.Get("udpconn")))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue