mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
move dialing logic from the client into the Transport (#4859)
This commit is contained in:
parent
fbbc3c9e30
commit
62a94758e6
5 changed files with 362 additions and 539 deletions
164
client.go
164
client.go
|
@ -7,39 +7,8 @@ import (
|
|||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
sendConn sendConn
|
||||
|
||||
use0RTT bool
|
||||
|
||||
packetHandlers packetHandlerManager
|
||||
onClose func()
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
statelessResetter *statelessResetter
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
|
||||
initialPacketNumber protocol.PacketNumber
|
||||
hasNegotiatedVersion bool
|
||||
version protocol.Version
|
||||
|
||||
handshakeChan chan struct{}
|
||||
|
||||
conn quicConn
|
||||
|
||||
tracer *logging.ConnectionTracer
|
||||
tracingID ConnectionTracingID
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// make it possible to mock connection ID for initial generation in the tests
|
||||
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
|
||||
|
@ -133,136 +102,3 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo
|
|||
isSingleUse: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func dial(
|
||||
ctx context.Context,
|
||||
conn sendConn,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetter *statelessResetter,
|
||||
packetHandlers packetHandlerManager,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
onClose func(),
|
||||
use0RTT bool,
|
||||
) (quicConn, error) {
|
||||
c, err := newClient(conn, connIDGenerator, statelessResetter, config, tlsConf, onClose, use0RTT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.packetHandlers = packetHandlers
|
||||
|
||||
c.tracingID = nextConnTracingID()
|
||||
if c.config.Tracer != nil {
|
||||
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
|
||||
}
|
||||
if c.tracer != nil && c.tracer.StartedConnection != nil {
|
||||
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
|
||||
}
|
||||
if err := c.dial(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
func newClient(
|
||||
sendConn sendConn,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetter *statelessResetter,
|
||||
config *Config,
|
||||
tlsConf *tls.Config,
|
||||
onClose func(),
|
||||
use0RTT bool,
|
||||
) (*client, error) {
|
||||
srcConnID, err := connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
destConnID, err := generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &client{
|
||||
connIDGenerator: connIDGenerator,
|
||||
statelessResetter: statelessResetter,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
sendConn: sendConn,
|
||||
use0RTT: use0RTT,
|
||||
onClose: onClose,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
c.conn = newClientConnection(
|
||||
context.WithValue(context.WithoutCancel(ctx), ConnectionTracingKey, c.tracingID),
|
||||
c.sendConn,
|
||||
c.packetHandlers,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.connIDGenerator,
|
||||
c.statelessResetter,
|
||||
c.config,
|
||||
c.tlsConf,
|
||||
c.initialPacketNumber,
|
||||
c.use0RTT,
|
||||
c.hasNegotiatedVersion,
|
||||
c.tracer,
|
||||
c.logger,
|
||||
c.version,
|
||||
)
|
||||
c.packetHandlers.Add(c.srcConnID, c.conn)
|
||||
|
||||
errorChan := make(chan error, 1)
|
||||
recreateChan := make(chan errCloseForRecreating)
|
||||
go func() {
|
||||
err := c.conn.run()
|
||||
var recreateErr *errCloseForRecreating
|
||||
if errors.As(err, &recreateErr) {
|
||||
recreateChan <- *recreateErr
|
||||
return
|
||||
}
|
||||
if c.onClose != nil {
|
||||
c.onClose()
|
||||
}
|
||||
errorChan <- err // returns as soon as the connection is closed
|
||||
}()
|
||||
|
||||
// only set when we're using 0-RTT
|
||||
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
|
||||
var earlyConnChan <-chan struct{}
|
||||
if c.use0RTT {
|
||||
earlyConnChan = c.conn.earlyConnReady()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.conn.destroy(nil)
|
||||
// wait until the Go routine that called Connection.run() returns
|
||||
select {
|
||||
case <-errorChan:
|
||||
case <-recreateChan:
|
||||
}
|
||||
return context.Cause(ctx)
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
case recreateErr := <-recreateChan:
|
||||
c.initialPacketNumber = recreateErr.nextPacketNumber
|
||||
c.version = recreateErr.nextVersion
|
||||
c.hasNegotiatedVersion = true
|
||||
return c.dial(ctx)
|
||||
case <-earlyConnChan:
|
||||
// ready to send 0-RTT data
|
||||
return nil
|
||||
case <-c.conn.HandshakeComplete():
|
||||
// handshake successfully completed
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
447
client_test.go
447
client_test.go
|
@ -3,376 +3,93 @@ package quic
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"go.uber.org/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
var (
|
||||
cl *client
|
||||
packetConn *MockSendConn
|
||||
connID protocol.ConnectionID
|
||||
tlsConf *tls.Config
|
||||
tracer *mocklogging.MockConnectionTracer
|
||||
config *Config
|
||||
|
||||
originalClientConnConstructor func(
|
||||
ctx context.Context,
|
||||
conn sendConn,
|
||||
runner connRunner,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetToken *statelessResetter,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
enable0RTT bool,
|
||||
hasNegotiatedVersion bool,
|
||||
tracer *logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
v protocol.Version,
|
||||
) quicConn
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
|
||||
connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37})
|
||||
originalClientConnConstructor = newClientConnection
|
||||
var tr *logging.ConnectionTracer
|
||||
tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
||||
config = &Config{
|
||||
Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) *logging.ConnectionTracer {
|
||||
return tr
|
||||
func TestDial(t *testing.T) {
|
||||
t.Run("Dial", func(t *testing.T) {
|
||||
testDial(t,
|
||||
func(ctx context.Context, addr net.Addr) error {
|
||||
conn := newUPDConnLocalhost(t)
|
||||
_, err := Dial(ctx, conn, addr, &tls.Config{}, nil)
|
||||
return err
|
||||
},
|
||||
Versions: []protocol.Version{protocol.Version1},
|
||||
false,
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("DialEarly", func(t *testing.T) {
|
||||
testDial(t,
|
||||
func(ctx context.Context, addr net.Addr) error {
|
||||
conn := newUPDConnLocalhost(t)
|
||||
_, err := DialEarly(ctx, conn, addr, &tls.Config{}, nil)
|
||||
return err
|
||||
},
|
||||
false,
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("DialAddr", func(t *testing.T) {
|
||||
testDial(t,
|
||||
func(ctx context.Context, addr net.Addr) error {
|
||||
_, err := DialAddr(ctx, addr.String(), &tls.Config{}, nil)
|
||||
return err
|
||||
},
|
||||
true,
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("DialAddrEarly", func(t *testing.T) {
|
||||
testDial(t,
|
||||
func(ctx context.Context, addr net.Addr) error {
|
||||
_, err := DialAddrEarly(ctx, addr.String(), &tls.Config{}, nil)
|
||||
return err
|
||||
},
|
||||
true,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
func testDial(t *testing.T,
|
||||
dialFn func(context.Context, net.Addr) error,
|
||||
shouldCloseConn bool,
|
||||
) {
|
||||
server := newUPDConnLocalhost(t)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errChan := make(chan error, 1)
|
||||
go func() { errChan <- dialFn(ctx, server.LocalAddr()) }()
|
||||
|
||||
_, addr, err := server.ReadFrom(make([]byte, 1500))
|
||||
require.NoError(t, err)
|
||||
require.True(t, areTransportsRunning())
|
||||
cancel()
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// The socket that the client used for dialing should be closed now.
|
||||
// Binding to the same address would error if the address was still in use.
|
||||
conn, err := net.ListenUDP("udp", addr.(*net.UDPAddr))
|
||||
if shouldCloseConn {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
if runtime.GOOS == "windows" {
|
||||
require.ErrorContains(t, err, "bind: Only one usage of each socket address")
|
||||
} else {
|
||||
require.ErrorContains(t, err, "address already in use")
|
||||
}
|
||||
Eventually(areConnsRunning).Should(BeFalse())
|
||||
packetConn = NewMockSendConn(mockCtrl)
|
||||
packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
cl = &client{
|
||||
srcConnID: connID,
|
||||
destConnID: connID,
|
||||
version: protocol.Version1,
|
||||
sendConn: packetConn,
|
||||
tracer: tr,
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
AfterEach(func() {
|
||||
newClientConnection = originalClientConnConstructor
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
if s, ok := cl.conn.(*connection); ok {
|
||||
s.destroy(nil)
|
||||
}
|
||||
Eventually(areConnsRunning).Should(BeFalse())
|
||||
})
|
||||
|
||||
Context("Dialing", func() {
|
||||
var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error)
|
||||
|
||||
BeforeEach(func() {
|
||||
origGenerateConnectionIDForInitial = generateConnectionIDForInitial
|
||||
generateConnectionIDForInitial = func() (protocol.ConnectionID, error) {
|
||||
return connID, nil
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
|
||||
})
|
||||
|
||||
It("returns after the handshake is complete", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
|
||||
run := make(chan struct{})
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
enable0RTT bool,
|
||||
_ bool,
|
||||
_ *logging.ConnectionTracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.Version,
|
||||
) quicConn {
|
||||
Expect(enable0RTT).To(BeFalse())
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run().Do(func() error { close(run); return nil })
|
||||
c := make(chan struct{})
|
||||
close(c)
|
||||
conn.EXPECT().HandshakeComplete().Return(c)
|
||||
return conn
|
||||
}
|
||||
cl, err := newClient(
|
||||
packetConn,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
newStatelessResetter(nil),
|
||||
populateConfig(config),
|
||||
tlsConf,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
Expect(cl.dial(context.Background())).To(Succeed())
|
||||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns early connections", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
readyChan := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
enable0RTT bool,
|
||||
_ bool,
|
||||
_ *logging.ConnectionTracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.Version,
|
||||
) quicConn {
|
||||
Expect(enable0RTT).To(BeTrue())
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run().Do(func() error { close(done); return nil })
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().earlyConnReady().Return(readyChan)
|
||||
return conn
|
||||
}
|
||||
|
||||
cl, err := newClient(
|
||||
packetConn,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
newStatelessResetter(nil),
|
||||
populateConfig(config),
|
||||
tlsConf,
|
||||
nil,
|
||||
true,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
Expect(cl.dial(context.Background())).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns an error that occurs while waiting for the handshake to complete", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
|
||||
testErr := errors.New("early handshake error")
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ *logging.ConnectionTracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.Version,
|
||||
) quicConn {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run().Return(testErr)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().earlyConnReady().Return(make(chan struct{}))
|
||||
return conn
|
||||
}
|
||||
var closed bool
|
||||
cl, err := newClient(
|
||||
packetConn,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
newStatelessResetter(nil),
|
||||
populateConfig(config), tlsConf, func() { closed = true },
|
||||
true,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
Expect(cl.dial(context.Background())).To(MatchError(testErr))
|
||||
Expect(closed).To(BeTrue())
|
||||
})
|
||||
|
||||
Context("quic.Config", func() {
|
||||
It("setups with the right values", func() {
|
||||
tokenStore := NewLRUTokenStore(10, 4)
|
||||
config := &Config{
|
||||
HandshakeIdleTimeout: 1337 * time.Minute,
|
||||
MaxIdleTimeout: 42 * time.Hour,
|
||||
MaxIncomingStreams: 1234,
|
||||
MaxIncomingUniStreams: 4321,
|
||||
TokenStore: tokenStore,
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
c := populateConfig(config)
|
||||
Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute))
|
||||
Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour))
|
||||
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
|
||||
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
|
||||
Expect(c.TokenStore).To(Equal(tokenStore))
|
||||
Expect(c.EnableDatagrams).To(BeTrue())
|
||||
})
|
||||
|
||||
It("disables bidirectional streams", func() {
|
||||
config := &Config{
|
||||
MaxIncomingStreams: -1,
|
||||
MaxIncomingUniStreams: 4321,
|
||||
}
|
||||
c := populateConfig(config)
|
||||
Expect(c.MaxIncomingStreams).To(BeZero())
|
||||
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
|
||||
})
|
||||
|
||||
It("disables unidirectional streams", func() {
|
||||
config := &Config{
|
||||
MaxIncomingStreams: 1234,
|
||||
MaxIncomingUniStreams: -1,
|
||||
}
|
||||
c := populateConfig(config)
|
||||
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
|
||||
Expect(c.MaxIncomingUniStreams).To(BeZero())
|
||||
})
|
||||
|
||||
It("fills in default values if options are not set in the Config", func() {
|
||||
c := populateConfig(&Config{})
|
||||
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
||||
Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
|
||||
})
|
||||
})
|
||||
|
||||
It("creates new connections with the right parameters", func() {
|
||||
config := &Config{Versions: []protocol.Version{protocol.Version1}}
|
||||
c := make(chan struct{})
|
||||
var version protocol.Version
|
||||
var conf *Config
|
||||
done := make(chan struct{})
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
connP sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ *logging.ConnectionTracer,
|
||||
_ utils.Logger,
|
||||
versionP protocol.Version,
|
||||
) quicConn {
|
||||
version = versionP
|
||||
conf = configP
|
||||
close(c)
|
||||
// TODO: check connection IDs?
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
|
||||
close(done)
|
||||
return conn
|
||||
}
|
||||
packetConn := NewMockPacketConn(mockCtrl)
|
||||
packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) {
|
||||
<-done
|
||||
return 0, nil, errors.New("closed")
|
||||
})
|
||||
packetConn.EXPECT().LocalAddr()
|
||||
packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
|
||||
_, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(c).Should(BeClosed())
|
||||
Expect(version).To(Equal(config.Versions[0]))
|
||||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
})
|
||||
|
||||
It("creates a new connections after version negotiation", func() {
|
||||
var counter int
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
connID protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
pn protocol.PacketNumber,
|
||||
_ bool,
|
||||
hasNegotiatedVersion bool,
|
||||
_ *logging.ConnectionTracer,
|
||||
_ utils.Logger,
|
||||
versionP protocol.Version,
|
||||
) quicConn {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
if counter == 0 {
|
||||
Expect(pn).To(BeZero())
|
||||
Expect(hasNegotiatedVersion).To(BeFalse())
|
||||
conn.EXPECT().run().DoAndReturn(func() error {
|
||||
runner.Remove(connID)
|
||||
return &errCloseForRecreating{
|
||||
nextPacketNumber: 109,
|
||||
nextVersion: 789,
|
||||
}
|
||||
})
|
||||
} else {
|
||||
Expect(pn).To(Equal(protocol.PacketNumber(109)))
|
||||
Expect(hasNegotiatedVersion).To(BeTrue())
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().destroy(gomock.Any())
|
||||
}
|
||||
counter++
|
||||
return conn
|
||||
}
|
||||
|
||||
config := &Config{Tracer: config.Tracer, Versions: []protocol.Version{protocol.Version1}}
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(counter).To(Equal(2))
|
||||
})
|
||||
})
|
||||
})
|
||||
require.False(t, areTransportsRunning())
|
||||
}
|
||||
|
|
|
@ -175,6 +175,7 @@ func TestConfigDefaultValues(t *testing.T) {
|
|||
c = populateConfig(&Config{})
|
||||
require.Equal(t, protocol.SupportedVersions, c.Versions)
|
||||
require.Equal(t, protocol.DefaultHandshakeIdleTimeout, c.HandshakeIdleTimeout)
|
||||
require.Equal(t, protocol.DefaultIdleTimeout, c.MaxIdleTimeout)
|
||||
require.EqualValues(t, protocol.DefaultInitialMaxStreamData, c.InitialStreamReceiveWindow)
|
||||
require.EqualValues(t, protocol.DefaultMaxReceiveStreamFlowControlWindow, c.MaxStreamReceiveWindow)
|
||||
require.EqualValues(t, protocol.DefaultInitialMaxData, c.InitialConnectionReceiveWindow)
|
||||
|
@ -184,3 +185,13 @@ func TestConfigDefaultValues(t *testing.T) {
|
|||
require.False(t, c.DisablePathMTUDiscovery)
|
||||
require.Nil(t, c.GetConfigForClient)
|
||||
}
|
||||
|
||||
func TestConfigZeroLimits(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxIncomingStreams: -1,
|
||||
MaxIncomingUniStreams: -1,
|
||||
}
|
||||
c := populateConfig(config)
|
||||
require.Zero(t, c.MaxIncomingStreams)
|
||||
require.Zero(t, c.MaxIncomingUniStreams)
|
||||
}
|
||||
|
|
119
transport.go
119
transport.go
|
@ -218,25 +218,124 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsCon
|
|||
if err := t.init(t.isSingleUse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
tlsConf = tlsConf.Clone()
|
||||
setTLSConfigServerName(tlsConf, addr, host)
|
||||
return dial(
|
||||
ctx,
|
||||
return t.doDial(ctx,
|
||||
newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger),
|
||||
t.connIDGenerator,
|
||||
t.statelessResetter,
|
||||
t.handlerMap,
|
||||
tlsConf,
|
||||
conf,
|
||||
onClose,
|
||||
0,
|
||||
false,
|
||||
use0RTT,
|
||||
conf.Versions[0],
|
||||
)
|
||||
}
|
||||
|
||||
func (t *Transport) doDial(
|
||||
ctx context.Context,
|
||||
sendConn sendConn,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
hasNegotiatedVersion bool,
|
||||
use0RTT bool,
|
||||
version protocol.Version,
|
||||
) (quicConn, error) {
|
||||
srcConnID, err := t.connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
destConnID, err := generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tracingID := nextConnTracingID()
|
||||
ctx = context.WithValue(ctx, ConnectionTracingKey, tracingID)
|
||||
var tracer *logging.ConnectionTracer
|
||||
if config.Tracer != nil {
|
||||
tracer = config.Tracer(ctx, protocol.PerspectiveClient, destConnID)
|
||||
}
|
||||
if tracer != nil && tracer.StartedConnection != nil {
|
||||
tracer.StartedConnection(sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID)
|
||||
}
|
||||
|
||||
logger := utils.DefaultLogger.WithPrefix("client")
|
||||
logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", tlsConf.ServerName, sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID, version)
|
||||
|
||||
conn := newClientConnection(
|
||||
context.WithoutCancel(ctx),
|
||||
sendConn,
|
||||
t.handlerMap,
|
||||
destConnID,
|
||||
srcConnID,
|
||||
t.connIDGenerator,
|
||||
t.statelessResetter,
|
||||
config,
|
||||
tlsConf,
|
||||
initialPacketNumber,
|
||||
use0RTT,
|
||||
hasNegotiatedVersion,
|
||||
tracer,
|
||||
logger,
|
||||
version,
|
||||
)
|
||||
t.handlerMap.Add(srcConnID, conn)
|
||||
|
||||
// The error channel needs to be buffered, as the run loop will continue running
|
||||
// after doDial returns (if the handshake is successful).
|
||||
errChan := make(chan error, 1)
|
||||
recreateChan := make(chan errCloseForRecreating)
|
||||
go func() {
|
||||
err := conn.run()
|
||||
var recreateErr *errCloseForRecreating
|
||||
if errors.As(err, &recreateErr) {
|
||||
recreateChan <- *recreateErr
|
||||
return
|
||||
}
|
||||
if t.isSingleUse {
|
||||
t.Close()
|
||||
}
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
// Only set when we're using 0-RTT.
|
||||
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
|
||||
var earlyConnChan <-chan struct{}
|
||||
if use0RTT {
|
||||
earlyConnChan = conn.earlyConnReady()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
conn.destroy(nil)
|
||||
// wait until the Go routine that called Connection.run() returns
|
||||
select {
|
||||
case <-errChan:
|
||||
case <-recreateChan:
|
||||
}
|
||||
return nil, context.Cause(ctx)
|
||||
case params := <-recreateChan:
|
||||
return t.doDial(ctx,
|
||||
sendConn,
|
||||
tlsConf,
|
||||
config,
|
||||
params.nextPacketNumber,
|
||||
true,
|
||||
use0RTT,
|
||||
params.nextVersion,
|
||||
)
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-earlyConnChan:
|
||||
// ready to send 0-RTT data
|
||||
return conn, nil
|
||||
case <-conn.HandshakeComplete():
|
||||
// handshake successfully completed
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
||||
t.initOnce.Do(func() {
|
||||
var conn rawConn
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
|
@ -455,3 +456,162 @@ func TestTransportSetTLSConfigServerName(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportDial(t *testing.T) {
|
||||
t.Run("regular", func(t *testing.T) {
|
||||
testTransportDial(t, false)
|
||||
})
|
||||
|
||||
t.Run("early", func(t *testing.T) {
|
||||
testTransportDial(t, true)
|
||||
})
|
||||
}
|
||||
|
||||
func testTransportDial(t *testing.T, early bool) {
|
||||
originalClientConnConstructor := newClientConnection
|
||||
t.Cleanup(func() { newClientConnection = originalClientConnConstructor })
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
handshakeChan := make(chan struct{})
|
||||
if early {
|
||||
conn.EXPECT().earlyConnReady().Return(handshakeChan)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
} else {
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||
}
|
||||
blockRun := make(chan struct{})
|
||||
conn.EXPECT().run().DoAndReturn(func() error {
|
||||
<-blockRun
|
||||
return errors.New("done")
|
||||
})
|
||||
defer close(blockRun)
|
||||
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ *logging.ConnectionTracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.Version,
|
||||
) quicConn {
|
||||
return conn
|
||||
}
|
||||
|
||||
tr := &Transport{Conn: newUPDConnLocalhost(t)}
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
var err error
|
||||
if early {
|
||||
_, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil)
|
||||
} else {
|
||||
_, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil)
|
||||
}
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-errChan:
|
||||
t.Fatal("Dial shouldn't have returned")
|
||||
case <-time.After(scaleDuration(10 * time.Millisecond)):
|
||||
}
|
||||
|
||||
close(handshakeChan)
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
|
||||
// for test tear-down
|
||||
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
}
|
||||
|
||||
func TestTransportDialingVersionNegotiation(t *testing.T) {
|
||||
originalClientConnConstructor := newClientConnection
|
||||
t.Cleanup(func() { newClientConnection = originalClientConnConstructor })
|
||||
|
||||
// connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
mockCtrl := gomock.NewController(t)
|
||||
// runner := NewMockConnRunner(mockCtrl)
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().run().Return(&errCloseForRecreating{nextPacketNumber: 109, nextVersion: 789})
|
||||
|
||||
conn2 := NewMockQUICConn(mockCtrl)
|
||||
conn2.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn2.EXPECT().run().Return(errors.New("test done"))
|
||||
|
||||
type connParams struct {
|
||||
pn protocol.PacketNumber
|
||||
hasNegotiatedVersion bool
|
||||
version protocol.Version
|
||||
}
|
||||
|
||||
connChan := make(chan connParams, 2)
|
||||
var counter int
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
pn protocol.PacketNumber,
|
||||
_ bool,
|
||||
hasNegotiatedVersion bool,
|
||||
_ *logging.ConnectionTracer,
|
||||
_ utils.Logger,
|
||||
v protocol.Version,
|
||||
) quicConn {
|
||||
connChan <- connParams{pn: pn, hasNegotiatedVersion: hasNegotiatedVersion, version: v}
|
||||
if counter == 0 {
|
||||
counter++
|
||||
return conn
|
||||
}
|
||||
return conn2
|
||||
}
|
||||
|
||||
tr := &Transport{Conn: newUPDConnLocalhost(t)}
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
|
||||
_, err := tr.Dial(context.Background(), nil, &tls.Config{}, nil)
|
||||
require.EqualError(t, err, "test done")
|
||||
|
||||
select {
|
||||
case params := <-connChan:
|
||||
require.Zero(t, params.pn)
|
||||
require.False(t, params.hasNegotiatedVersion)
|
||||
require.Equal(t, protocol.Version1, params.version)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
select {
|
||||
case params := <-connChan:
|
||||
require.Equal(t, protocol.PacketNumber(109), params.pn)
|
||||
require.True(t, params.hasNegotiatedVersion)
|
||||
require.Equal(t, protocol.Version(789), params.version)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// for test tear down
|
||||
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
conn2.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue