move dialing logic from the client into the Transport (#4859)

This commit is contained in:
Marten Seemann 2025-01-14 00:40:20 -08:00 committed by GitHub
parent fbbc3c9e30
commit 62a94758e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 362 additions and 539 deletions

164
client.go
View file

@ -7,39 +7,8 @@ import (
"net"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
type client struct {
sendConn sendConn
use0RTT bool
packetHandlers packetHandlerManager
onClose func()
tlsConf *tls.Config
config *Config
connIDGenerator ConnectionIDGenerator
statelessResetter *statelessResetter
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialPacketNumber protocol.PacketNumber
hasNegotiatedVersion bool
version protocol.Version
handshakeChan chan struct{}
conn quicConn
tracer *logging.ConnectionTracer
tracingID ConnectionTracingID
logger utils.Logger
}
// make it possible to mock connection ID for initial generation in the tests
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
@ -133,136 +102,3 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo
isSingleUse: true,
}, nil
}
func dial(
ctx context.Context,
conn sendConn,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
packetHandlers packetHandlerManager,
tlsConf *tls.Config,
config *Config,
onClose func(),
use0RTT bool,
) (quicConn, error) {
c, err := newClient(conn, connIDGenerator, statelessResetter, config, tlsConf, onClose, use0RTT)
if err != nil {
return nil, err
}
c.packetHandlers = packetHandlers
c.tracingID = nextConnTracingID()
if c.config.Tracer != nil {
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
}
if c.tracer != nil && c.tracer.StartedConnection != nil {
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
}
if err := c.dial(ctx); err != nil {
return nil, err
}
return c.conn, nil
}
func newClient(
sendConn sendConn,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
config *Config,
tlsConf *tls.Config,
onClose func(),
use0RTT bool,
) (*client, error) {
srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return nil, err
}
c := &client{
connIDGenerator: connIDGenerator,
statelessResetter: statelessResetter,
srcConnID: srcConnID,
destConnID: destConnID,
sendConn: sendConn,
use0RTT: use0RTT,
onClose: onClose,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.conn = newClientConnection(
context.WithValue(context.WithoutCancel(ctx), ConnectionTracingKey, c.tracingID),
c.sendConn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.statelessResetter,
c.config,
c.tlsConf,
c.initialPacketNumber,
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.logger,
c.version,
)
c.packetHandlers.Add(c.srcConnID, c.conn)
errorChan := make(chan error, 1)
recreateChan := make(chan errCloseForRecreating)
go func() {
err := c.conn.run()
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
recreateChan <- *recreateErr
return
}
if c.onClose != nil {
c.onClose()
}
errorChan <- err // returns as soon as the connection is closed
}()
// only set when we're using 0-RTT
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
var earlyConnChan <-chan struct{}
if c.use0RTT {
earlyConnChan = c.conn.earlyConnReady()
}
select {
case <-ctx.Done():
c.conn.destroy(nil)
// wait until the Go routine that called Connection.run() returns
select {
case <-errorChan:
case <-recreateChan:
}
return context.Cause(ctx)
case err := <-errorChan:
return err
case recreateErr := <-recreateChan:
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
case <-earlyConnChan:
// ready to send 0-RTT data
return nil
case <-c.conn.HandshakeComplete():
// handshake successfully completed
return nil
}
}

View file

@ -3,376 +3,93 @@ package quic
import (
"context"
"crypto/tls"
"errors"
"net"
"runtime"
"testing"
"time"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"go.uber.org/mock/gomock"
"github.com/stretchr/testify/require"
)
var _ = Describe("Client", func() {
var (
cl *client
packetConn *MockSendConn
connID protocol.ConnectionID
tlsConf *tls.Config
tracer *mocklogging.MockConnectionTracer
config *Config
originalClientConnConstructor func(
ctx context.Context,
conn sendConn,
runner connRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
statelessResetToken *statelessResetter,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
enable0RTT bool,
hasNegotiatedVersion bool,
tracer *logging.ConnectionTracer,
logger utils.Logger,
v protocol.Version,
) quicConn
)
BeforeEach(func() {
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37})
originalClientConnConstructor = newClientConnection
var tr *logging.ConnectionTracer
tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
config = &Config{
Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) *logging.ConnectionTracer {
return tr
func TestDial(t *testing.T) {
t.Run("Dial", func(t *testing.T) {
testDial(t,
func(ctx context.Context, addr net.Addr) error {
conn := newUPDConnLocalhost(t)
_, err := Dial(ctx, conn, addr, &tls.Config{}, nil)
return err
},
Versions: []protocol.Version{protocol.Version1},
false,
)
})
t.Run("DialEarly", func(t *testing.T) {
testDial(t,
func(ctx context.Context, addr net.Addr) error {
conn := newUPDConnLocalhost(t)
_, err := DialEarly(ctx, conn, addr, &tls.Config{}, nil)
return err
},
false,
)
})
t.Run("DialAddr", func(t *testing.T) {
testDial(t,
func(ctx context.Context, addr net.Addr) error {
_, err := DialAddr(ctx, addr.String(), &tls.Config{}, nil)
return err
},
true,
)
})
t.Run("DialAddrEarly", func(t *testing.T) {
testDial(t,
func(ctx context.Context, addr net.Addr) error {
_, err := DialAddrEarly(ctx, addr.String(), &tls.Config{}, nil)
return err
},
true,
)
})
}
func testDial(t *testing.T,
dialFn func(context.Context, net.Addr) error,
shouldCloseConn bool,
) {
server := newUPDConnLocalhost(t)
ctx, cancel := context.WithCancel(context.Background())
errChan := make(chan error, 1)
go func() { errChan <- dialFn(ctx, server.LocalAddr()) }()
_, addr, err := server.ReadFrom(make([]byte, 1500))
require.NoError(t, err)
require.True(t, areTransportsRunning())
cancel()
select {
case err := <-errChan:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// The socket that the client used for dialing should be closed now.
// Binding to the same address would error if the address was still in use.
conn, err := net.ListenUDP("udp", addr.(*net.UDPAddr))
if shouldCloseConn {
require.NoError(t, err)
defer conn.Close()
} else {
require.Error(t, err)
if runtime.GOOS == "windows" {
require.ErrorContains(t, err, "bind: Only one usage of each socket address")
} else {
require.ErrorContains(t, err, "address already in use")
}
Eventually(areConnsRunning).Should(BeFalse())
packetConn = NewMockSendConn(mockCtrl)
packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes()
cl = &client{
srcConnID: connID,
destConnID: connID,
version: protocol.Version1,
sendConn: packetConn,
tracer: tr,
logger: utils.DefaultLogger,
}
})
}
AfterEach(func() {
newClientConnection = originalClientConnConstructor
})
AfterEach(func() {
if s, ok := cl.conn.(*connection); ok {
s.destroy(nil)
}
Eventually(areConnsRunning).Should(BeFalse())
})
Context("Dialing", func() {
var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error)
BeforeEach(func() {
origGenerateConnectionIDForInitial = generateConnectionIDForInitial
generateConnectionIDForInitial = func() (protocol.ConnectionID, error) {
return connID, nil
}
})
AfterEach(func() {
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
})
It("returns after the handshake is complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
run := make(chan struct{})
newClientConnection = func(
_ context.Context,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
enable0RTT bool,
_ bool,
_ *logging.ConnectionTracer,
_ utils.Logger,
_ protocol.Version,
) quicConn {
Expect(enable0RTT).To(BeFalse())
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run().Do(func() error { close(run); return nil })
c := make(chan struct{})
close(c)
conn.EXPECT().HandshakeComplete().Return(c)
return conn
}
cl, err := newClient(
packetConn,
&protocol.DefaultConnectionIDGenerator{},
newStatelessResetter(nil),
populateConfig(config),
tlsConf,
nil,
false,
)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
Expect(cl.dial(context.Background())).To(Succeed())
Eventually(run).Should(BeClosed())
})
It("returns early connections", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
readyChan := make(chan struct{})
done := make(chan struct{})
newClientConnection = func(
_ context.Context,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
enable0RTT bool,
_ bool,
_ *logging.ConnectionTracer,
_ utils.Logger,
_ protocol.Version,
) quicConn {
Expect(enable0RTT).To(BeTrue())
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run().Do(func() error { close(done); return nil })
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().earlyConnReady().Return(readyChan)
return conn
}
cl, err := newClient(
packetConn,
&protocol.DefaultConnectionIDGenerator{},
newStatelessResetter(nil),
populateConfig(config),
tlsConf,
nil,
true,
)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
Expect(cl.dial(context.Background())).To(Succeed())
Eventually(done).Should(BeClosed())
})
It("returns an error that occurs while waiting for the handshake to complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
testErr := errors.New("early handshake error")
newClientConnection = func(
_ context.Context,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ *logging.ConnectionTracer,
_ utils.Logger,
_ protocol.Version,
) quicConn {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run().Return(testErr)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().earlyConnReady().Return(make(chan struct{}))
return conn
}
var closed bool
cl, err := newClient(
packetConn,
&protocol.DefaultConnectionIDGenerator{},
newStatelessResetter(nil),
populateConfig(config), tlsConf, func() { closed = true },
true,
)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
Expect(cl.dial(context.Background())).To(MatchError(testErr))
Expect(closed).To(BeTrue())
})
Context("quic.Config", func() {
It("setups with the right values", func() {
tokenStore := NewLRUTokenStore(10, 4)
config := &Config{
HandshakeIdleTimeout: 1337 * time.Minute,
MaxIdleTimeout: 42 * time.Hour,
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: 4321,
TokenStore: tokenStore,
EnableDatagrams: true,
}
c := populateConfig(config)
Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute))
Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour))
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
Expect(c.TokenStore).To(Equal(tokenStore))
Expect(c.EnableDatagrams).To(BeTrue())
})
It("disables bidirectional streams", func() {
config := &Config{
MaxIncomingStreams: -1,
MaxIncomingUniStreams: 4321,
}
c := populateConfig(config)
Expect(c.MaxIncomingStreams).To(BeZero())
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
})
It("disables unidirectional streams", func() {
config := &Config{
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: -1,
}
c := populateConfig(config)
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
Expect(c.MaxIncomingUniStreams).To(BeZero())
})
It("fills in default values if options are not set in the Config", func() {
c := populateConfig(&Config{})
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
})
})
It("creates new connections with the right parameters", func() {
config := &Config{Versions: []protocol.Version{protocol.Version1}}
c := make(chan struct{})
var version protocol.Version
var conf *Config
done := make(chan struct{})
newClientConnection = func(
_ context.Context,
connP sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
configP *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ *logging.ConnectionTracer,
_ utils.Logger,
versionP protocol.Version,
) quicConn {
version = versionP
conf = configP
close(c)
// TODO: check connection IDs?
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run()
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
close(done)
return conn
}
packetConn := NewMockPacketConn(mockCtrl)
packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) {
<-done
return 0, nil, errors.New("closed")
})
packetConn.EXPECT().LocalAddr()
packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
_, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed())
Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions))
})
It("creates a new connections after version negotiation", func() {
var counter int
newClientConnection = func(
_ context.Context,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
connID protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
configP *Config,
_ *tls.Config,
pn protocol.PacketNumber,
_ bool,
hasNegotiatedVersion bool,
_ *logging.ConnectionTracer,
_ utils.Logger,
versionP protocol.Version,
) quicConn {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
if counter == 0 {
Expect(pn).To(BeZero())
Expect(hasNegotiatedVersion).To(BeFalse())
conn.EXPECT().run().DoAndReturn(func() error {
runner.Remove(connID)
return &errCloseForRecreating{
nextPacketNumber: 109,
nextVersion: 789,
}
})
} else {
Expect(pn).To(Equal(protocol.PacketNumber(109)))
Expect(hasNegotiatedVersion).To(BeTrue())
conn.EXPECT().run()
conn.EXPECT().destroy(gomock.Any())
}
counter++
return conn
}
config := &Config{Tracer: config.Tracer, Versions: []protocol.Version{protocol.Version1}}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Expect(counter).To(Equal(2))
})
})
})
require.False(t, areTransportsRunning())
}

View file

@ -175,6 +175,7 @@ func TestConfigDefaultValues(t *testing.T) {
c = populateConfig(&Config{})
require.Equal(t, protocol.SupportedVersions, c.Versions)
require.Equal(t, protocol.DefaultHandshakeIdleTimeout, c.HandshakeIdleTimeout)
require.Equal(t, protocol.DefaultIdleTimeout, c.MaxIdleTimeout)
require.EqualValues(t, protocol.DefaultInitialMaxStreamData, c.InitialStreamReceiveWindow)
require.EqualValues(t, protocol.DefaultMaxReceiveStreamFlowControlWindow, c.MaxStreamReceiveWindow)
require.EqualValues(t, protocol.DefaultInitialMaxData, c.InitialConnectionReceiveWindow)
@ -184,3 +185,13 @@ func TestConfigDefaultValues(t *testing.T) {
require.False(t, c.DisablePathMTUDiscovery)
require.Nil(t, c.GetConfigForClient)
}
func TestConfigZeroLimits(t *testing.T) {
config := &Config{
MaxIncomingStreams: -1,
MaxIncomingUniStreams: -1,
}
c := populateConfig(config)
require.Zero(t, c.MaxIncomingStreams)
require.Zero(t, c.MaxIncomingUniStreams)
}

View file

@ -218,25 +218,124 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsCon
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
tlsConf = tlsConf.Clone()
setTLSConfigServerName(tlsConf, addr, host)
return dial(
ctx,
return t.doDial(ctx,
newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger),
t.connIDGenerator,
t.statelessResetter,
t.handlerMap,
tlsConf,
conf,
onClose,
0,
false,
use0RTT,
conf.Versions[0],
)
}
func (t *Transport) doDial(
ctx context.Context,
sendConn sendConn,
tlsConf *tls.Config,
config *Config,
initialPacketNumber protocol.PacketNumber,
hasNegotiatedVersion bool,
use0RTT bool,
version protocol.Version,
) (quicConn, error) {
srcConnID, err := t.connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return nil, err
}
tracingID := nextConnTracingID()
ctx = context.WithValue(ctx, ConnectionTracingKey, tracingID)
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
tracer = config.Tracer(ctx, protocol.PerspectiveClient, destConnID)
}
if tracer != nil && tracer.StartedConnection != nil {
tracer.StartedConnection(sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID)
}
logger := utils.DefaultLogger.WithPrefix("client")
logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", tlsConf.ServerName, sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID, version)
conn := newClientConnection(
context.WithoutCancel(ctx),
sendConn,
t.handlerMap,
destConnID,
srcConnID,
t.connIDGenerator,
t.statelessResetter,
config,
tlsConf,
initialPacketNumber,
use0RTT,
hasNegotiatedVersion,
tracer,
logger,
version,
)
t.handlerMap.Add(srcConnID, conn)
// The error channel needs to be buffered, as the run loop will continue running
// after doDial returns (if the handshake is successful).
errChan := make(chan error, 1)
recreateChan := make(chan errCloseForRecreating)
go func() {
err := conn.run()
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
recreateChan <- *recreateErr
return
}
if t.isSingleUse {
t.Close()
}
errChan <- err
}()
// Only set when we're using 0-RTT.
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
var earlyConnChan <-chan struct{}
if use0RTT {
earlyConnChan = conn.earlyConnReady()
}
select {
case <-ctx.Done():
conn.destroy(nil)
// wait until the Go routine that called Connection.run() returns
select {
case <-errChan:
case <-recreateChan:
}
return nil, context.Cause(ctx)
case params := <-recreateChan:
return t.doDial(ctx,
sendConn,
tlsConf,
config,
params.nextPacketNumber,
true,
use0RTT,
params.nextVersion,
)
case err := <-errChan:
return nil, err
case <-earlyConnChan:
// ready to send 0-RTT data
return conn, nil
case <-conn.HandshakeComplete():
// handshake successfully completed
return conn, nil
}
}
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.initOnce.Do(func() {
var conn rawConn

View file

@ -13,6 +13,7 @@ import (
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
@ -455,3 +456,162 @@ func TestTransportSetTLSConfigServerName(t *testing.T) {
})
}
}
func TestTransportDial(t *testing.T) {
t.Run("regular", func(t *testing.T) {
testTransportDial(t, false)
})
t.Run("early", func(t *testing.T) {
testTransportDial(t, true)
})
}
func testTransportDial(t *testing.T, early bool) {
originalClientConnConstructor := newClientConnection
t.Cleanup(func() { newClientConnection = originalClientConnConstructor })
mockCtrl := gomock.NewController(t)
conn := NewMockQUICConn(mockCtrl)
handshakeChan := make(chan struct{})
if early {
conn.EXPECT().earlyConnReady().Return(handshakeChan)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
} else {
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
}
blockRun := make(chan struct{})
conn.EXPECT().run().DoAndReturn(func() error {
<-blockRun
return errors.New("done")
})
defer close(blockRun)
newClientConnection = func(
_ context.Context,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ *logging.ConnectionTracer,
_ utils.Logger,
_ protocol.Version,
) quicConn {
return conn
}
tr := &Transport{Conn: newUPDConnLocalhost(t)}
tr.init(true)
defer tr.Close()
errChan := make(chan error, 1)
go func() {
var err error
if early {
_, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil)
} else {
_, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil)
}
errChan <- err
}()
select {
case <-errChan:
t.Fatal("Dial shouldn't have returned")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
close(handshakeChan)
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
}
// for test tear-down
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
}
func TestTransportDialingVersionNegotiation(t *testing.T) {
originalClientConnConstructor := newClientConnection
t.Cleanup(func() { newClientConnection = originalClientConnConstructor })
// connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
mockCtrl := gomock.NewController(t)
// runner := NewMockConnRunner(mockCtrl)
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().run().Return(&errCloseForRecreating{nextPacketNumber: 109, nextVersion: 789})
conn2 := NewMockQUICConn(mockCtrl)
conn2.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn2.EXPECT().run().Return(errors.New("test done"))
type connParams struct {
pn protocol.PacketNumber
hasNegotiatedVersion bool
version protocol.Version
}
connChan := make(chan connParams, 2)
var counter int
newClientConnection = func(
_ context.Context,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
pn protocol.PacketNumber,
_ bool,
hasNegotiatedVersion bool,
_ *logging.ConnectionTracer,
_ utils.Logger,
v protocol.Version,
) quicConn {
connChan <- connParams{pn: pn, hasNegotiatedVersion: hasNegotiatedVersion, version: v}
if counter == 0 {
counter++
return conn
}
return conn2
}
tr := &Transport{Conn: newUPDConnLocalhost(t)}
tr.init(true)
defer tr.Close()
_, err := tr.Dial(context.Background(), nil, &tls.Config{}, nil)
require.EqualError(t, err, "test done")
select {
case params := <-connChan:
require.Zero(t, params.pn)
require.False(t, params.hasNegotiatedVersion)
require.Equal(t, protocol.Version1, params.version)
case <-time.After(time.Second):
t.Fatal("timeout")
}
select {
case params := <-connChan:
require.Equal(t, protocol.PacketNumber(109), params.pn)
require.True(t, params.hasNegotiatedVersion)
require.Equal(t, protocol.Version(789), params.version)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// for test tear down
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
conn2.EXPECT().destroy(gomock.Any()).AnyTimes()
}