diff --git a/client.go b/client.go index 7a99a0f2..29a715cc 100644 --- a/client.go +++ b/client.go @@ -7,39 +7,8 @@ import ( "net" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/logging" ) -type client struct { - sendConn sendConn - - use0RTT bool - - packetHandlers packetHandlerManager - onClose func() - - tlsConf *tls.Config - config *Config - - connIDGenerator ConnectionIDGenerator - statelessResetter *statelessResetter - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID - - initialPacketNumber protocol.PacketNumber - hasNegotiatedVersion bool - version protocol.Version - - handshakeChan chan struct{} - - conn quicConn - - tracer *logging.ConnectionTracer - tracingID ConnectionTracingID - logger utils.Logger -} - // make it possible to mock connection ID for initial generation in the tests var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial @@ -133,136 +102,3 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo isSingleUse: true, }, nil } - -func dial( - ctx context.Context, - conn sendConn, - connIDGenerator ConnectionIDGenerator, - statelessResetter *statelessResetter, - packetHandlers packetHandlerManager, - tlsConf *tls.Config, - config *Config, - onClose func(), - use0RTT bool, -) (quicConn, error) { - c, err := newClient(conn, connIDGenerator, statelessResetter, config, tlsConf, onClose, use0RTT) - if err != nil { - return nil, err - } - c.packetHandlers = packetHandlers - - c.tracingID = nextConnTracingID() - if c.config.Tracer != nil { - c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) - } - if c.tracer != nil && c.tracer.StartedConnection != nil { - c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) - } - if err := c.dial(ctx); err != nil { - return nil, err - } - return c.conn, nil -} - -func newClient( - sendConn sendConn, - connIDGenerator ConnectionIDGenerator, - statelessResetter *statelessResetter, - config *Config, - tlsConf *tls.Config, - onClose func(), - use0RTT bool, -) (*client, error) { - srcConnID, err := connIDGenerator.GenerateConnectionID() - if err != nil { - return nil, err - } - destConnID, err := generateConnectionIDForInitial() - if err != nil { - return nil, err - } - c := &client{ - connIDGenerator: connIDGenerator, - statelessResetter: statelessResetter, - srcConnID: srcConnID, - destConnID: destConnID, - sendConn: sendConn, - use0RTT: use0RTT, - onClose: onClose, - tlsConf: tlsConf, - config: config, - version: config.Versions[0], - handshakeChan: make(chan struct{}), - logger: utils.DefaultLogger.WithPrefix("client"), - } - return c, nil -} - -func (c *client) dial(ctx context.Context) error { - c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - - c.conn = newClientConnection( - context.WithValue(context.WithoutCancel(ctx), ConnectionTracingKey, c.tracingID), - c.sendConn, - c.packetHandlers, - c.destConnID, - c.srcConnID, - c.connIDGenerator, - c.statelessResetter, - c.config, - c.tlsConf, - c.initialPacketNumber, - c.use0RTT, - c.hasNegotiatedVersion, - c.tracer, - c.logger, - c.version, - ) - c.packetHandlers.Add(c.srcConnID, c.conn) - - errorChan := make(chan error, 1) - recreateChan := make(chan errCloseForRecreating) - go func() { - err := c.conn.run() - var recreateErr *errCloseForRecreating - if errors.As(err, &recreateErr) { - recreateChan <- *recreateErr - return - } - if c.onClose != nil { - c.onClose() - } - errorChan <- err // returns as soon as the connection is closed - }() - - // only set when we're using 0-RTT - // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. - var earlyConnChan <-chan struct{} - if c.use0RTT { - earlyConnChan = c.conn.earlyConnReady() - } - - select { - case <-ctx.Done(): - c.conn.destroy(nil) - // wait until the Go routine that called Connection.run() returns - select { - case <-errorChan: - case <-recreateChan: - } - return context.Cause(ctx) - case err := <-errorChan: - return err - case recreateErr := <-recreateChan: - c.initialPacketNumber = recreateErr.nextPacketNumber - c.version = recreateErr.nextVersion - c.hasNegotiatedVersion = true - return c.dial(ctx) - case <-earlyConnChan: - // ready to send 0-RTT data - return nil - case <-c.conn.HandshakeComplete(): - // handshake successfully completed - return nil - } -} diff --git a/client_test.go b/client_test.go index 8714647a..6e726783 100644 --- a/client_test.go +++ b/client_test.go @@ -3,376 +3,93 @@ package quic import ( "context" "crypto/tls" - "errors" "net" + "runtime" + "testing" "time" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" - "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/logging" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "go.uber.org/mock/gomock" + "github.com/stretchr/testify/require" ) -var _ = Describe("Client", func() { - var ( - cl *client - packetConn *MockSendConn - connID protocol.ConnectionID - tlsConf *tls.Config - tracer *mocklogging.MockConnectionTracer - config *Config - - originalClientConnConstructor func( - ctx context.Context, - conn sendConn, - runner connRunner, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - connIDGenerator ConnectionIDGenerator, - statelessResetToken *statelessResetter, - conf *Config, - tlsConf *tls.Config, - initialPacketNumber protocol.PacketNumber, - enable0RTT bool, - hasNegotiatedVersion bool, - tracer *logging.ConnectionTracer, - logger utils.Logger, - v protocol.Version, - ) quicConn - ) - - BeforeEach(func() { - tlsConf = &tls.Config{NextProtos: []string{"proto1"}} - connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37}) - originalClientConnConstructor = newClientConnection - var tr *logging.ConnectionTracer - tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) - config = &Config{ - Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) *logging.ConnectionTracer { - return tr +func TestDial(t *testing.T) { + t.Run("Dial", func(t *testing.T) { + testDial(t, + func(ctx context.Context, addr net.Addr) error { + conn := newUPDConnLocalhost(t) + _, err := Dial(ctx, conn, addr, &tls.Config{}, nil) + return err }, - Versions: []protocol.Version{protocol.Version1}, + false, + ) + }) + + t.Run("DialEarly", func(t *testing.T) { + testDial(t, + func(ctx context.Context, addr net.Addr) error { + conn := newUPDConnLocalhost(t) + _, err := DialEarly(ctx, conn, addr, &tls.Config{}, nil) + return err + }, + false, + ) + }) + + t.Run("DialAddr", func(t *testing.T) { + testDial(t, + func(ctx context.Context, addr net.Addr) error { + _, err := DialAddr(ctx, addr.String(), &tls.Config{}, nil) + return err + }, + true, + ) + }) + + t.Run("DialAddrEarly", func(t *testing.T) { + testDial(t, + func(ctx context.Context, addr net.Addr) error { + _, err := DialAddrEarly(ctx, addr.String(), &tls.Config{}, nil) + return err + }, + true, + ) + }) +} + +func testDial(t *testing.T, + dialFn func(context.Context, net.Addr) error, + shouldCloseConn bool, +) { + server := newUPDConnLocalhost(t) + + ctx, cancel := context.WithCancel(context.Background()) + errChan := make(chan error, 1) + go func() { errChan <- dialFn(ctx, server.LocalAddr()) }() + + _, addr, err := server.ReadFrom(make([]byte, 1500)) + require.NoError(t, err) + require.True(t, areTransportsRunning()) + cancel() + select { + case err := <-errChan: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // The socket that the client used for dialing should be closed now. + // Binding to the same address would error if the address was still in use. + conn, err := net.ListenUDP("udp", addr.(*net.UDPAddr)) + if shouldCloseConn { + require.NoError(t, err) + defer conn.Close() + } else { + require.Error(t, err) + if runtime.GOOS == "windows" { + require.ErrorContains(t, err, "bind: Only one usage of each socket address") + } else { + require.ErrorContains(t, err, "address already in use") } - Eventually(areConnsRunning).Should(BeFalse()) - packetConn = NewMockSendConn(mockCtrl) - packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() - cl = &client{ - srcConnID: connID, - destConnID: connID, - version: protocol.Version1, - sendConn: packetConn, - tracer: tr, - logger: utils.DefaultLogger, - } - }) + } - AfterEach(func() { - newClientConnection = originalClientConnConstructor - }) - - AfterEach(func() { - if s, ok := cl.conn.(*connection); ok { - s.destroy(nil) - } - Eventually(areConnsRunning).Should(BeFalse()) - }) - - Context("Dialing", func() { - var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error) - - BeforeEach(func() { - origGenerateConnectionIDForInitial = generateConnectionIDForInitial - generateConnectionIDForInitial = func() (protocol.ConnectionID, error) { - return connID, nil - } - }) - - AfterEach(func() { - generateConnectionIDForInitial = origGenerateConnectionIDForInitial - }) - - It("returns after the handshake is complete", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - - run := make(chan struct{}) - newClientConnection = func( - _ context.Context, - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - enable0RTT bool, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - Expect(enable0RTT).To(BeFalse()) - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Do(func() error { close(run); return nil }) - c := make(chan struct{}) - close(c) - conn.EXPECT().HandshakeComplete().Return(c) - return conn - } - cl, err := newClient( - packetConn, - &protocol.DefaultConnectionIDGenerator{}, - newStatelessResetter(nil), - populateConfig(config), - tlsConf, - nil, - false, - ) - Expect(err).ToNot(HaveOccurred()) - cl.packetHandlers = manager - Expect(cl).ToNot(BeNil()) - Expect(cl.dial(context.Background())).To(Succeed()) - Eventually(run).Should(BeClosed()) - }) - - It("returns early connections", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - readyChan := make(chan struct{}) - done := make(chan struct{}) - newClientConnection = func( - _ context.Context, - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - enable0RTT bool, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - Expect(enable0RTT).To(BeTrue()) - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Do(func() error { close(done); return nil }) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - conn.EXPECT().earlyConnReady().Return(readyChan) - return conn - } - - cl, err := newClient( - packetConn, - &protocol.DefaultConnectionIDGenerator{}, - newStatelessResetter(nil), - populateConfig(config), - tlsConf, - nil, - true, - ) - Expect(err).ToNot(HaveOccurred()) - cl.packetHandlers = manager - Expect(cl).ToNot(BeNil()) - Expect(cl.dial(context.Background())).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("returns an error that occurs while waiting for the handshake to complete", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - - testErr := errors.New("early handshake error") - newClientConnection = func( - _ context.Context, - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Return(testErr) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - conn.EXPECT().earlyConnReady().Return(make(chan struct{})) - return conn - } - var closed bool - cl, err := newClient( - packetConn, - &protocol.DefaultConnectionIDGenerator{}, - newStatelessResetter(nil), - populateConfig(config), tlsConf, func() { closed = true }, - true, - ) - Expect(err).ToNot(HaveOccurred()) - cl.packetHandlers = manager - Expect(cl).ToNot(BeNil()) - Expect(cl.dial(context.Background())).To(MatchError(testErr)) - Expect(closed).To(BeTrue()) - }) - - Context("quic.Config", func() { - It("setups with the right values", func() { - tokenStore := NewLRUTokenStore(10, 4) - config := &Config{ - HandshakeIdleTimeout: 1337 * time.Minute, - MaxIdleTimeout: 42 * time.Hour, - MaxIncomingStreams: 1234, - MaxIncomingUniStreams: 4321, - TokenStore: tokenStore, - EnableDatagrams: true, - } - c := populateConfig(config) - Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute)) - Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour)) - Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) - Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) - Expect(c.TokenStore).To(Equal(tokenStore)) - Expect(c.EnableDatagrams).To(BeTrue()) - }) - - It("disables bidirectional streams", func() { - config := &Config{ - MaxIncomingStreams: -1, - MaxIncomingUniStreams: 4321, - } - c := populateConfig(config) - Expect(c.MaxIncomingStreams).To(BeZero()) - Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) - }) - - It("disables unidirectional streams", func() { - config := &Config{ - MaxIncomingStreams: 1234, - MaxIncomingUniStreams: -1, - } - c := populateConfig(config) - Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) - Expect(c.MaxIncomingUniStreams).To(BeZero()) - }) - - It("fills in default values if options are not set in the Config", func() { - c := populateConfig(&Config{}) - Expect(c.Versions).To(Equal(protocol.SupportedVersions)) - Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) - Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) - }) - }) - - It("creates new connections with the right parameters", func() { - config := &Config{Versions: []protocol.Version{protocol.Version1}} - c := make(chan struct{}) - var version protocol.Version - var conf *Config - done := make(chan struct{}) - newClientConnection = func( - _ context.Context, - connP sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - configP *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - versionP protocol.Version, - ) quicConn { - version = versionP - conf = configP - close(c) - // TODO: check connection IDs? - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run() - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - conn.EXPECT().destroy(gomock.Any()).MaxTimes(1) - close(done) - return conn - } - packetConn := NewMockPacketConn(mockCtrl) - packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) { - <-done - return 0, nil, errors.New("closed") - }) - packetConn.EXPECT().LocalAddr() - packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() - _, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config) - Expect(err).ToNot(HaveOccurred()) - Eventually(c).Should(BeClosed()) - Expect(version).To(Equal(config.Versions[0])) - Expect(conf.Versions).To(Equal(config.Versions)) - }) - - It("creates a new connections after version negotiation", func() { - var counter int - newClientConnection = func( - _ context.Context, - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - connID protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - configP *Config, - _ *tls.Config, - pn protocol.PacketNumber, - _ bool, - hasNegotiatedVersion bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - versionP protocol.Version, - ) quicConn { - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - if counter == 0 { - Expect(pn).To(BeZero()) - Expect(hasNegotiatedVersion).To(BeFalse()) - conn.EXPECT().run().DoAndReturn(func() error { - runner.Remove(connID) - return &errCloseForRecreating{ - nextPacketNumber: 109, - nextVersion: 789, - } - }) - } else { - Expect(pn).To(Equal(protocol.PacketNumber(109))) - Expect(hasNegotiatedVersion).To(BeTrue()) - conn.EXPECT().run() - conn.EXPECT().destroy(gomock.Any()) - } - counter++ - return conn - } - - config := &Config{Tracer: config.Tracer, Versions: []protocol.Version{protocol.Version1}} - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config) - Expect(err).ToNot(HaveOccurred()) - Expect(counter).To(Equal(2)) - }) - }) -}) + require.False(t, areTransportsRunning()) +} diff --git a/config_test.go b/config_test.go index ff4dd042..412300da 100644 --- a/config_test.go +++ b/config_test.go @@ -175,6 +175,7 @@ func TestConfigDefaultValues(t *testing.T) { c = populateConfig(&Config{}) require.Equal(t, protocol.SupportedVersions, c.Versions) require.Equal(t, protocol.DefaultHandshakeIdleTimeout, c.HandshakeIdleTimeout) + require.Equal(t, protocol.DefaultIdleTimeout, c.MaxIdleTimeout) require.EqualValues(t, protocol.DefaultInitialMaxStreamData, c.InitialStreamReceiveWindow) require.EqualValues(t, protocol.DefaultMaxReceiveStreamFlowControlWindow, c.MaxStreamReceiveWindow) require.EqualValues(t, protocol.DefaultInitialMaxData, c.InitialConnectionReceiveWindow) @@ -184,3 +185,13 @@ func TestConfigDefaultValues(t *testing.T) { require.False(t, c.DisablePathMTUDiscovery) require.Nil(t, c.GetConfigForClient) } + +func TestConfigZeroLimits(t *testing.T) { + config := &Config{ + MaxIncomingStreams: -1, + MaxIncomingUniStreams: -1, + } + c := populateConfig(config) + require.Zero(t, c.MaxIncomingStreams) + require.Zero(t, c.MaxIncomingUniStreams) +} diff --git a/transport.go b/transport.go index 32867550..ab7a77fd 100644 --- a/transport.go +++ b/transport.go @@ -218,25 +218,124 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsCon if err := t.init(t.isSingleUse); err != nil { return nil, err } - var onClose func() - if t.isSingleUse { - onClose = func() { t.Close() } - } tlsConf = tlsConf.Clone() setTLSConfigServerName(tlsConf, addr, host) - return dial( - ctx, + return t.doDial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), - t.connIDGenerator, - t.statelessResetter, - t.handlerMap, tlsConf, conf, - onClose, + 0, + false, use0RTT, + conf.Versions[0], ) } +func (t *Transport) doDial( + ctx context.Context, + sendConn sendConn, + tlsConf *tls.Config, + config *Config, + initialPacketNumber protocol.PacketNumber, + hasNegotiatedVersion bool, + use0RTT bool, + version protocol.Version, +) (quicConn, error) { + srcConnID, err := t.connIDGenerator.GenerateConnectionID() + if err != nil { + return nil, err + } + destConnID, err := generateConnectionIDForInitial() + if err != nil { + return nil, err + } + + tracingID := nextConnTracingID() + ctx = context.WithValue(ctx, ConnectionTracingKey, tracingID) + var tracer *logging.ConnectionTracer + if config.Tracer != nil { + tracer = config.Tracer(ctx, protocol.PerspectiveClient, destConnID) + } + if tracer != nil && tracer.StartedConnection != nil { + tracer.StartedConnection(sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID) + } + + logger := utils.DefaultLogger.WithPrefix("client") + logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", tlsConf.ServerName, sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID, version) + + conn := newClientConnection( + context.WithoutCancel(ctx), + sendConn, + t.handlerMap, + destConnID, + srcConnID, + t.connIDGenerator, + t.statelessResetter, + config, + tlsConf, + initialPacketNumber, + use0RTT, + hasNegotiatedVersion, + tracer, + logger, + version, + ) + t.handlerMap.Add(srcConnID, conn) + + // The error channel needs to be buffered, as the run loop will continue running + // after doDial returns (if the handshake is successful). + errChan := make(chan error, 1) + recreateChan := make(chan errCloseForRecreating) + go func() { + err := conn.run() + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + recreateChan <- *recreateErr + return + } + if t.isSingleUse { + t.Close() + } + errChan <- err + }() + + // Only set when we're using 0-RTT. + // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. + var earlyConnChan <-chan struct{} + if use0RTT { + earlyConnChan = conn.earlyConnReady() + } + + select { + case <-ctx.Done(): + conn.destroy(nil) + // wait until the Go routine that called Connection.run() returns + select { + case <-errChan: + case <-recreateChan: + } + return nil, context.Cause(ctx) + case params := <-recreateChan: + return t.doDial(ctx, + sendConn, + tlsConf, + config, + params.nextPacketNumber, + true, + use0RTT, + params.nextVersion, + ) + case err := <-errChan: + return nil, err + case <-earlyConnChan: + // ready to send 0-RTT data + return conn, nil + case <-conn.HandshakeComplete(): + // handshake successfully completed + return conn, nil + } +} + func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.initOnce.Do(func() { var conn rawConn diff --git a/transport_test.go b/transport_test.go index 4b4f54b4..37ee580f 100644 --- a/transport_test.go +++ b/transport_test.go @@ -13,6 +13,7 @@ import ( mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" @@ -455,3 +456,162 @@ func TestTransportSetTLSConfigServerName(t *testing.T) { }) } } + +func TestTransportDial(t *testing.T) { + t.Run("regular", func(t *testing.T) { + testTransportDial(t, false) + }) + + t.Run("early", func(t *testing.T) { + testTransportDial(t, true) + }) +} + +func testTransportDial(t *testing.T, early bool) { + originalClientConnConstructor := newClientConnection + t.Cleanup(func() { newClientConnection = originalClientConnConstructor }) + + mockCtrl := gomock.NewController(t) + conn := NewMockQUICConn(mockCtrl) + handshakeChan := make(chan struct{}) + if early { + conn.EXPECT().earlyConnReady().Return(handshakeChan) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) + } else { + conn.EXPECT().HandshakeComplete().Return(handshakeChan) + } + blockRun := make(chan struct{}) + conn.EXPECT().run().DoAndReturn(func() error { + <-blockRun + return errors.New("done") + }) + defer close(blockRun) + + newClientConnection = func( + _ context.Context, + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ ConnectionIDGenerator, + _ *statelessResetter, + _ *Config, + _ *tls.Config, + _ protocol.PacketNumber, + _ bool, + _ bool, + _ *logging.ConnectionTracer, + _ utils.Logger, + _ protocol.Version, + ) quicConn { + return conn + } + + tr := &Transport{Conn: newUPDConnLocalhost(t)} + tr.init(true) + defer tr.Close() + + errChan := make(chan error, 1) + go func() { + var err error + if early { + _, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil) + } else { + _, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil) + } + errChan <- err + }() + + select { + case <-errChan: + t.Fatal("Dial shouldn't have returned") + case <-time.After(scaleDuration(10 * time.Millisecond)): + } + + close(handshakeChan) + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(time.Second): + } + + // for test tear-down + conn.EXPECT().destroy(gomock.Any()).AnyTimes() +} + +func TestTransportDialingVersionNegotiation(t *testing.T) { + originalClientConnConstructor := newClientConnection + t.Cleanup(func() { newClientConnection = originalClientConnConstructor }) + + // connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + mockCtrl := gomock.NewController(t) + // runner := NewMockConnRunner(mockCtrl) + conn := NewMockQUICConn(mockCtrl) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) + conn.EXPECT().run().Return(&errCloseForRecreating{nextPacketNumber: 109, nextVersion: 789}) + + conn2 := NewMockQUICConn(mockCtrl) + conn2.EXPECT().HandshakeComplete().Return(make(chan struct{})) + conn2.EXPECT().run().Return(errors.New("test done")) + + type connParams struct { + pn protocol.PacketNumber + hasNegotiatedVersion bool + version protocol.Version + } + + connChan := make(chan connParams, 2) + var counter int + newClientConnection = func( + _ context.Context, + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ ConnectionIDGenerator, + _ *statelessResetter, + _ *Config, + _ *tls.Config, + pn protocol.PacketNumber, + _ bool, + hasNegotiatedVersion bool, + _ *logging.ConnectionTracer, + _ utils.Logger, + v protocol.Version, + ) quicConn { + connChan <- connParams{pn: pn, hasNegotiatedVersion: hasNegotiatedVersion, version: v} + if counter == 0 { + counter++ + return conn + } + return conn2 + } + + tr := &Transport{Conn: newUPDConnLocalhost(t)} + tr.init(true) + defer tr.Close() + + _, err := tr.Dial(context.Background(), nil, &tls.Config{}, nil) + require.EqualError(t, err, "test done") + + select { + case params := <-connChan: + require.Zero(t, params.pn) + require.False(t, params.hasNegotiatedVersion) + require.Equal(t, protocol.Version1, params.version) + case <-time.After(time.Second): + t.Fatal("timeout") + } + select { + case params := <-connChan: + require.Equal(t, protocol.PacketNumber(109), params.pn) + require.True(t, params.hasNegotiatedVersion) + require.Equal(t, protocol.Version(789), params.version) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // for test tear down + conn.EXPECT().destroy(gomock.Any()).AnyTimes() + conn2.EXPECT().destroy(gomock.Any()).AnyTimes() +}