From ba942715dbada3d26b80ac375836cbcefae4ec51 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 20 Apr 2023 11:37:56 +0200 Subject: [PATCH] remove ConnectionIDLength and ConnectionIDGenerator from the Config --- client.go | 23 ++++-- client_test.go | 31 ++++--- config.go | 26 +----- config_test.go | 17 +--- connection.go | 6 +- connection_test.go | 6 +- integrationtests/self/conn_id_test.go | 81 +++++++++++-------- integrationtests/self/handshake_test.go | 4 +- integrationtests/self/mitm_test.go | 42 ++++++---- integrationtests/self/multiplex_test.go | 1 - integrationtests/self/stateless_reset_test.go | 29 ++++--- integrationtests/self/zero_rtt_test.go | 46 ++++++++--- interface.go | 12 --- server.go | 14 ++-- server_test.go | 11 +++ transport.go | 36 +++++---- transport_test.go | 30 +++---- 17 files changed, 232 insertions(+), 183 deletions(-) diff --git a/client.go b/client.go index c8ea0641..ed6ccfb8 100644 --- a/client.go +++ b/client.go @@ -25,8 +25,9 @@ type client struct { tlsConf *tls.Config config *Config - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID + connIDGenerator ConnectionIDGenerator + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID initialPacketNumber protocol.PacketNumber hasNegotiatedVersion bool @@ -133,6 +134,7 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo func dial( ctx context.Context, conn net.PacketConn, + connIDGenerator ConnectionIDGenerator, packetHandlers packetHandlerManager, addr net.Addr, tlsConf *tls.Config, @@ -141,7 +143,7 @@ func dial( use0RTT bool, createdPacketConn bool, ) (quicConn, error) { - c, err := newClient(conn, addr, config, tlsConf, onClose, use0RTT, createdPacketConn) + c, err := newClient(conn, addr, connIDGenerator, config, tlsConf, onClose, use0RTT, createdPacketConn) if err != nil { return nil, err } @@ -164,14 +166,23 @@ func dial( return c.conn, nil } -func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, onClose func(), use0RTT, createdPacketConn bool) (*client, error) { +func newClient( + pconn net.PacketConn, + remoteAddr net.Addr, + connIDGenerator ConnectionIDGenerator, + config *Config, + tlsConf *tls.Config, + onClose func(), + use0RTT bool, + createdPacketConn bool, +) (*client, error) { if tlsConf == nil { tlsConf = &tls.Config{} } else { tlsConf = tlsConf.Clone() } - srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID() + srcConnID, err := connIDGenerator.GenerateConnectionID() if err != nil { return nil, err } @@ -180,6 +191,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon return nil, err } c := &client{ + connIDGenerator: connIDGenerator, srcConnID: srcConnID, destConnID: destConnID, sconn: newSendPconn(pconn, remoteAddr), @@ -203,6 +215,7 @@ func (c *client) dial(ctx context.Context) error { c.packetHandlers, c.destConnID, c.srcConnID, + c.connIDGenerator, c.config, c.tlsConf, c.initialPacketNumber, diff --git a/client_test.go b/client_test.go index 71e16f0b..ce53ef4b 100644 --- a/client_test.go +++ b/client_test.go @@ -39,6 +39,7 @@ var _ = Describe("Client", func() { runner connRunner, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, conf *Config, tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, @@ -114,6 +115,7 @@ var _ = Describe("Client", func() { _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -132,7 +134,7 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(c) return conn } - cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, false, false) + cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false, false) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -150,6 +152,7 @@ var _ = Describe("Client", func() { runner connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -168,7 +171,7 @@ var _ = Describe("Client", func() { return conn } - cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, true, false) + cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true, false) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -186,6 +189,7 @@ var _ = Describe("Client", func() { _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -203,7 +207,7 @@ var _ = Describe("Client", func() { return conn } var closed bool - cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, func() { closed = true }, true, false) + cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true, false) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -219,16 +223,14 @@ var _ = Describe("Client", func() { MaxIdleTimeout: 42 * time.Hour, MaxIncomingStreams: 1234, MaxIncomingUniStreams: 4321, - ConnectionIDLength: 13, TokenStore: tokenStore, EnableDatagrams: true, } - c := populateClientConfig(config, false) + c := populateConfig(config) Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute)) Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour)) Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) - Expect(c.ConnectionIDLength).To(Equal(13)) Expect(c.TokenStore).To(Equal(tokenStore)) Expect(c.EnableDatagrams).To(BeTrue()) }) @@ -238,7 +240,7 @@ var _ = Describe("Client", func() { MaxIncomingStreams: -1, MaxIncomingUniStreams: 4321, } - c := populateClientConfig(config, false) + c := populateConfig(config) Expect(c.MaxIncomingStreams).To(BeZero()) Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) }) @@ -248,18 +250,13 @@ var _ = Describe("Client", func() { MaxIncomingStreams: 1234, MaxIncomingUniStreams: -1, } - c := populateClientConfig(config, false) + c := populateConfig(config) Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) Expect(c.MaxIncomingUniStreams).To(BeZero()) }) - It("uses 0-byte connection IDs when dialing an address", func() { - c := populateClientConfig(&Config{}, true) - Expect(c.ConnectionIDLength).To(BeZero()) - }) - It("fills in default values if options are not set in the Config", func() { - c := populateClientConfig(&Config{}, false) + c := populateConfig(&Config{}) Expect(c.Versions).To(Equal(protocol.SupportedVersions)) Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) @@ -267,7 +264,7 @@ var _ = Describe("Client", func() { }) It("creates new connections with the right parameters", func() { - config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}} + config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}} c := make(chan struct{}) var cconn sendConn var version protocol.VersionNumber @@ -278,6 +275,7 @@ var _ = Describe("Client", func() { _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, configP *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -320,6 +318,7 @@ var _ = Describe("Client", func() { runner connRunner, _ protocol.ConnectionID, connID protocol.ConnectionID, + _ ConnectionIDGenerator, configP *Config, _ *tls.Config, pn protocol.PacketNumber, @@ -352,7 +351,7 @@ var _ = Describe("Client", func() { return conn } - config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}} + config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) diff --git a/config.go b/config.go index 7a48b407..ceb4fb68 100644 --- a/config.go +++ b/config.go @@ -42,7 +42,7 @@ func validateConfig(config *Config) error { // populateServerConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateServerConfig(config *Config) *Config { - config = populateConfig(config, protocol.DefaultConnectionIDLength) + config = populateConfig(config) if config.MaxTokenAge == 0 { config.MaxTokenAge = protocol.TokenValidity } @@ -55,19 +55,9 @@ func populateServerConfig(config *Config) *Config { return config } -// populateClientConfig populates fields in the quic.Config with their default values, if none are set +// populateConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil -func populateClientConfig(config *Config, createdPacketConn bool) *Config { - defaultConnIDLen := protocol.DefaultConnectionIDLength - if createdPacketConn { - defaultConnIDLen = 0 - } - - config = populateConfig(config, defaultConnIDLen) - return config -} - -func populateConfig(config *Config, defaultConnIDLen int) *Config { +func populateConfig(config *Config) *Config { if config == nil { config = &Config{} } @@ -75,10 +65,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { if len(versions) == 0 { versions = protocol.SupportedVersions } - conIDLen := config.ConnectionIDLength - if config.ConnectionIDLength == 0 { - conIDLen = defaultConnIDLen - } handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout if config.HandshakeIdleTimeout != 0 { handshakeIdleTimeout = config.HandshakeIdleTimeout @@ -115,10 +101,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } - connIDGenerator := config.ConnectionIDGenerator - if connIDGenerator == nil { - connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen} - } return &Config{ Versions: versions, @@ -135,8 +117,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, - ConnectionIDLength: conIDLen, - ConnectionIDGenerator: connIDGenerator, TokenStore: config.TokenStore, EnableDatagrams: config.EnableDatagrams, DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, diff --git a/config_test.go b/config_test.go index b9ca6a5c..f319deb2 100644 --- a/config_test.go +++ b/config_test.go @@ -142,18 +142,18 @@ var _ = Describe("Config", func() { var calledAddrValidation bool c1 := &Config{} c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } - c2 := populateConfig(c1, protocol.DefaultConnectionIDLength) + c2 := populateConfig(c1) c2.RequireAddressValidation(&net.UDPAddr{}) Expect(calledAddrValidation).To(BeTrue()) }) It("copies non-function fields", func() { c := configWithNonZeroNonFunctionFields() - Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c)) + Expect(populateConfig(c)).To(Equal(c)) }) It("populates empty fields with default values", func() { - c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength) + c := populateConfig(&Config{}) Expect(c.Versions).To(Equal(protocol.SupportedVersions)) Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData)) @@ -168,18 +168,7 @@ var _ = Describe("Config", func() { It("populates empty fields with default values, for the server", func() { c := populateServerConfig(&Config{}) - Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) Expect(c.RequireAddressValidation).ToNot(BeNil()) }) - - It("sets a default connection ID length if we didn't create the conn, for the client", func() { - c := populateClientConfig(&Config{}, false) - Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) - }) - - It("doesn't set a default connection ID length if we created the conn, for the client", func() { - c := populateClientConfig(&Config{}, true) - Expect(c.ConnectionIDLength).To(BeZero()) - }) }) }) diff --git a/connection.go b/connection.go index eb16ece5..78969ed9 100644 --- a/connection.go +++ b/connection.go @@ -240,6 +240,7 @@ var newConnection = func( clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, statelessResetToken protocol.StatelessResetToken, conf *Config, tlsConf *tls.Config, @@ -283,7 +284,7 @@ var newConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, - s.config.ConnectionIDGenerator, + connIDGenerator, ) s.preSetup() s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) @@ -363,6 +364,7 @@ var newClientConnection = func( runner connRunner, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, conf *Config, tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, @@ -402,7 +404,7 @@ var newClientConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, - s.config.ConnectionIDGenerator, + connIDGenerator, ) s.preSetup() s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) diff --git a/connection_test.go b/connection_test.go index a2829e85..10f0c202 100644 --- a/connection_test.go +++ b/connection_test.go @@ -113,6 +113,7 @@ var _ = Describe("Connection", func() { clientDestConnID, destConnID, srcConnID, + &protocol.DefaultConnectionIDGenerator{}, protocol.StatelessResetToken{}, populateServerConfig(&Config{DisablePathMTUDiscovery: true}), nil, // tls.Config @@ -2015,8 +2016,6 @@ var _ = Describe("Connection", func() { packer.EXPECT().HandleTransportParameters(params) packer.EXPECT().PackCoalescedPacket(false, conn.version).MaxTimes(3) Expect(conn.earlyConnReady()).ToNot(BeClosed()) - connRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) - connRunner.EXPECT().Add(gomock.Any(), conn).Times(2) tracer.EXPECT().ReceivedTransportParameters(params) conn.handleTransportParameters(params) Expect(conn.earlyConnReady()).To(BeClosed()) @@ -2378,7 +2377,7 @@ var _ = Describe("Client Connection", func() { } BeforeEach(func() { - quicConf = populateClientConfig(&Config{}, true) + quicConf = populateConfig(&Config{}) tlsConf = nil }) @@ -2402,6 +2401,7 @@ var _ = Describe("Client Connection", func() { connRunner, destConnID, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + &protocol.DefaultConnectionIDGenerator{}, quicConf, tlsConf, 42, // initial packet number diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index 7cb1904d..0d8c4b44 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -34,13 +34,23 @@ func (c *connIDGenerator) ConnectionIDLen() int { var _ = Describe("Connection ID lengths tests", func() { randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) } - runServer := func(conf *quic.Config) *quic.Listener { - if conf.ConnectionIDGenerator != nil { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", conf.ConnectionIDGenerator.ConnectionIDLen()))) + // connIDLen is ignored when connIDGenerator is set + runServer := func(connIDLen int, connIDGenerator quic.ConnectionIDGenerator) (*quic.Listener, func()) { + if connIDGenerator != nil { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", connIDGenerator.ConnectionIDLen()))) } else { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength))) + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", connIDLen))) } - ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf) + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{ + Conn: conn, + ConnectionIDLength: connIDLen, + ConnectionIDGenerator: connIDGenerator, + } + ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() @@ -59,20 +69,35 @@ var _ = Describe("Connection ID lengths tests", func() { }() } }() - return ln + return ln, func() { + ln.Close() + tr.Close() + } } - runClient := func(addr net.Addr, conf *quic.Config) { - if conf.ConnectionIDGenerator != nil { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", conf.ConnectionIDGenerator.ConnectionIDLen()))) + // connIDLen is ignored when connIDGenerator is set + runClient := func(addr net.Addr, connIDLen int, connIDGenerator quic.ConnectionIDGenerator) { + if connIDGenerator != nil { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", connIDGenerator.ConnectionIDLen()))) } else { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength))) + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", connIDLen))) } - cl, err := quic.DialAddr( + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + tr := &quic.Transport{ + Conn: conn, + ConnectionIDLength: connIDLen, + ConnectionIDGenerator: connIDGenerator, + } + defer tr.Close() + cl, err := tr.Dial( context.Background(), - fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port}, getTLSClientConfig(), - conf, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) defer cl.CloseWithError(0, "") @@ -84,32 +109,20 @@ var _ = Describe("Connection ID lengths tests", func() { } It("downloads a file using a 0-byte connection ID for the client", func() { - serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()}) - ln := runServer(serverConf) - defer ln.Close() - - runClient(ln.Addr(), getQuicConfig(nil)) + ln, closeFn := runServer(randomConnIDLen(), nil) + defer closeFn() + runClient(ln.Addr(), 0, nil) }) It("downloads a file when both client and server use a random connection ID length", func() { - serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()}) - ln := runServer(serverConf) - defer ln.Close() - - runClient(ln.Addr(), getQuicConfig(nil)) + ln, closeFn := runServer(randomConnIDLen(), nil) + defer closeFn() + runClient(ln.Addr(), randomConnIDLen(), nil) }) It("downloads a file when both client and server use a custom connection ID generator", func() { - serverConf := getQuicConfig(&quic.Config{ - ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()}, - }) - clientConf := getQuicConfig(&quic.Config{ - ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()}, - }) - - ln := runServer(serverConf) - defer ln.Close() - - runClient(ln.Addr(), clientConf) + ln, closeFn := runServer(0, &connIDGenerator{length: randomConnIDLen()}) + defer closeFn() + runClient(ln.Addr(), 0, &connIDGenerator{length: randomConnIDLen()}) }) }) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index b3a13e9d..fc77f424 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -244,7 +244,7 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) pconn, err = net.ListenUDP("udp", laddr) Expect(err).ToNot(HaveOccurred()) - dialer = &quic.Transport{Conn: pconn} + dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4} }) AfterEach(func() { @@ -303,7 +303,7 @@ var _ = Describe("Handshake tests", func() { // This should free one spot in the queue. Expect(firstConn.CloseWithError(0, "")) Eventually(firstConn.Context().Done()).Should(BeClosed()) - time.Sleep(scaleDuration(20 * time.Millisecond)) + time.Sleep(scaleDuration(200 * time.Millisecond)) // dial again, and expect that this dial succeeds _, err = dial() diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index 133c6f23..35e0af91 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -35,7 +35,11 @@ var _ = Describe("MITM test", func() { Expect(err).ToNot(HaveOccurred()) serverUDPConn, err = net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) - ln, err := quic.Listen(serverUDPConn, getTLSConfig(), serverConfig) + tr := &quic.Transport{ + Conn: serverUDPConn, + ConnectionIDLength: connIDLen, + } + ln, err := tr.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) go func() { @@ -68,7 +72,7 @@ var _ = Describe("MITM test", func() { } BeforeEach(func() { - serverConfig = getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}) + serverConfig = getQuicConfig(nil) addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) clientUDPConn, err = net.ListenUDP("udp", addr) @@ -146,12 +150,15 @@ var _ = Describe("MITM test", func() { defer closeFn() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) - conn, err := quic.Dial( + tr := &quic.Transport{ + Conn: clientUDPConn, + ConnectionIDLength: connIDLen, + } + conn, err := tr.Dial( context.Background(), - clientUDPConn, raddr, getTLSClientConfig(), - getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}), + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptUniStream(context.Background()) @@ -190,12 +197,15 @@ var _ = Describe("MITM test", func() { defer closeFn() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) - conn, err := quic.Dial( + tr := &quic.Transport{ + Conn: clientUDPConn, + ConnectionIDLength: connIDLen, + } + conn, err := tr.Dial( context.Background(), - clientUDPConn, raddr, getTLSClientConfig(), - getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}), + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptUniStream(context.Background()) @@ -302,20 +312,20 @@ var _ = Describe("MITM test", func() { const rtt = 20 * time.Millisecond runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) { - proxyPort, closeFn := startServerAndProxy(delayCb, nil) + proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil) raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) - _, err = quic.Dial( + tr := &quic.Transport{ + Conn: clientUDPConn, + ConnectionIDLength: connIDLen, + } + _, err = tr.Dial( context.Background(), - clientUDPConn, raddr, getTLSClientConfig(), - getQuicConfig(&quic.Config{ - ConnectionIDLength: connIDLen, - HandshakeIdleTimeout: 2 * time.Second, - }), + getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}), ) - return closeFn, err + return func() { tr.Close(); serverCloseFn() }, err } // fails immediately because client connection closes when it can't find compatible version diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index dcac1b46..2623e322 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -137,7 +137,6 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn.Close() tr := &quic.Transport{Conn: conn} - server, err := tr.Listen( getTLSConfig(), getQuicConfig(nil), diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 3c9a1703..4ceb8067 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -28,8 +28,9 @@ var _ = Describe("Stateless Resets", func() { c, err := net.ListenUDP("udp", nil) Expect(err).ToNot(HaveOccurred()) tr := &quic.Transport{ - Conn: c, - StatelessResetKey: &statelessResetKey, + Conn: c, + StatelessResetKey: &statelessResetKey, + ConnectionIDLength: connIDLen, } defer tr.Close() ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) @@ -61,14 +62,21 @@ var _ = Describe("Stateless Resets", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - conn, err := quic.DialAddr( + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + cl := &quic.Transport{ + Conn: udpConn, + ConnectionIDLength: connIDLen, + } + defer cl.Close() + conn, err := cl.Dial( context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()}, getTLSClientConfig(), - getQuicConfig(&quic.Config{ - ConnectionIDLength: connIDLen, - MaxIdleTimeout: 2 * time.Second, - }), + getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptStream(context.Background()) @@ -86,8 +94,9 @@ var _ = Describe("Stateless Resets", func() { // We need to create a new Transport here, since the old one is still sending out // CONNECTION_CLOSE packets for (recently) closed connections). tr2 := &quic.Transport{ - Conn: c, - StatelessResetKey: &statelessResetKey, + Conn: c, + ConnectionIDLength: connIDLen, + StatelessResetKey: &statelessResetKey, } defer tr2.Close() ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil)) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 1559c176..e83a9f96 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -101,6 +101,7 @@ var _ = Describe("0-RTT", func() { transfer0RTTData := func( ln *quic.EarlyListener, proxyPort int, + connIDLen int, clientTLSConf *tls.Config, clientConf *quic.Config, testdata []byte, // data to transfer @@ -125,13 +126,35 @@ var _ = Describe("0-RTT", func() { if clientConf == nil { clientConf = getQuicConfig(nil) } - conn, err := quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxyPort), - clientTLSConf, - clientConf, - ) - Expect(err).ToNot(HaveOccurred()) + var conn quic.EarlyConnection + if connIDLen == 0 { + var err error + conn, err = quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxyPort), + clientTLSConf, + clientConf, + ) + Expect(err).ToNot(HaveOccurred()) + } else { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + tr := &quic.Transport{ + Conn: udpConn, + ConnectionIDLength: connIDLen, + } + defer tr.Close() + conn, err = tr.DialEarly( + context.Background(), + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort}, + clientTLSConf, + clientConf, + ) + Expect(err).ToNot(HaveOccurred()) + } defer conn.CloseWithError(0, "") str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -212,8 +235,9 @@ var _ = Describe("0-RTT", func() { transfer0RTTData( ln, proxy.LocalPort(), + connIDLen, clientTLSConf, - getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}), + getQuicConfig(nil), PRData, ) @@ -373,7 +397,7 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData) + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) num0RTT := atomic.LoadUint32(&num0RTTPackets) numDropped := atomic.LoadUint32(&num0RTTDropped) @@ -448,7 +472,7 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, GeneratePRData(5000)) // ~5 packets + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets mutex.Lock() defer mutex.Unlock() @@ -768,7 +792,7 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData) + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) diff --git a/interface.go b/interface.go index 33309fac..267c07f5 100644 --- a/interface.go +++ b/interface.go @@ -242,18 +242,6 @@ type Config struct { // The QUIC versions that can be negotiated. // If not set, it uses all versions available. Versions []VersionNumber - // The length of the connection ID in bytes. - // It can be 0, or any value between 4 and 18. - // If not set, the interpretation depends on where the Config is used: - // If used for dialing an address, a 0 byte connection ID will be used. - // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. - // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. - ConnectionIDLength int - // An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection. - // The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers. - // By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used. - // Otherwise, if one is provided, then ConnectionIDLength is ignored. - ConnectionIDGenerator ConnectionIDGenerator // HandshakeIdleTimeout is the idle timeout before completion of the handshake. // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. // If this value is zero, the timeout is set to 5 seconds. diff --git a/server.go b/server.go index f8c9b3cd..a80709ea 100644 --- a/server.go +++ b/server.go @@ -68,8 +68,9 @@ type baseServer struct { tokenGenerator *handshake.TokenGenerator - connHandler packetHandlerManager - onClose func() + connIDGenerator ConnectionIDGenerator + connHandler packetHandlerManager + onClose func() receivedPackets chan *receivedPacket @@ -85,6 +86,7 @@ type baseServer struct { protocol.ConnectionID, /* client dest connection ID */ protocol.ConnectionID, /* destination connection ID */ protocol.ConnectionID, /* source connection ID */ + ConnectionIDGenerator, protocol.StatelessResetToken, *Config, *tls.Config, @@ -210,7 +212,7 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear return tr.ListenEarly(tlsConf, config) } -func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) { +func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) { tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) if err != nil { return nil, err @@ -220,6 +222,7 @@ func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Conf tlsConf: tlsConf, config: config, tokenGenerator: tokenGenerator, + connIDGenerator: connIDGenerator, connHandler: connHandler, connQueue: make(chan quicConn), errorChan: make(chan struct{}), @@ -574,7 +577,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return nil } - connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() + connID, err := s.connIDGenerator.GenerateConnectionID() if err != nil { return err } @@ -603,6 +606,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro hdr.DestConnectionID, hdr.SrcConnectionID, connID, + s.connIDGenerator, s.connHandler.GetStatelessResetToken(connID), s.config, s.tlsConf, @@ -669,7 +673,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the connection. (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() + srcConnID, err := s.connIDGenerator.GenerateConnectionID() if err != nil { return err } diff --git a/server_test.go b/server_test.go index 658bb45b..7a17ffd5 100644 --- a/server_test.go +++ b/server_test.go @@ -286,6 +286,7 @@ var _ = Describe("Server", func() { clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + _ ConnectionIDGenerator, tokenP protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -488,6 +489,7 @@ var _ = Describe("Server", func() { clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + _ ConnectionIDGenerator, tokenP protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -547,6 +549,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -600,6 +603,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -631,6 +635,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -702,6 +707,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -1009,6 +1015,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -1082,6 +1089,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -1124,6 +1132,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -1187,6 +1196,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -1309,6 +1319,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, diff --git a/transport.go b/transport.go index bf902fbc..ebb144f4 100644 --- a/transport.go +++ b/transport.go @@ -59,6 +59,9 @@ type Transport struct { // Set in init. // If no ConnectionIDGenerator is set, this is the ConnectionIDLength. connIDLen int + // Set in init. + // If no ConnectionIDGenerator is set, this is set to a default. + connIDGenerator ConnectionIDGenerator server unknownPacketHandler @@ -92,10 +95,10 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) return nil, errListenerAlreadySet } conf = populateServerConfig(conf) - if err := t.init(conf); err != nil { + if err := t.init(conf, true); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, false) + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, false) if err != nil { return nil, err } @@ -121,10 +124,10 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen return nil, errListenerAlreadySet } conf = populateServerConfig(conf) - if err := t.init(conf); err != nil { + if err := t.init(conf, true); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, true) + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, true) if err != nil { return nil, err } @@ -137,15 +140,15 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config if err := validateConfig(conf); err != nil { return nil, err } - conf = populateClientConfig(conf, t.createdConn) - if err := t.init(conf); err != nil { + conf = populateConfig(conf) + if err := t.init(conf, false); err != nil { return nil, err } var onClose func() if t.isSingleUse { onClose = func() { t.Close() } } - return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn) + return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn) } // DialEarly dials a new connection, attempting to use 0-RTT if possible. @@ -153,15 +156,15 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C if err := validateConfig(conf); err != nil { return nil, err } - conf = populateClientConfig(conf, t.createdConn) - if err := t.init(conf); err != nil { + conf = populateConfig(conf) + if err := t.init(conf, false); err != nil { return nil, err } var onClose func() if t.isSingleUse { onClose = func() { t.Close() } } - return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn) + return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn) } func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { @@ -197,7 +200,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { // only print warnings about the UDP receive buffer size once var receiveBufferWarningOnce sync.Once -func (t *Transport) init(conf *Config) error { +func (t *Transport) init(conf *Config, isServer bool) error { t.initOnce.Do(func() { getMultiplexer().AddConn(t.Conn) @@ -208,9 +211,6 @@ func (t *Transport) init(conf *Config) error { } t.Tracer = conf.Tracer - t.ConnectionIDLength = conf.ConnectionIDLength - t.ConnectionIDGenerator = conf.ConnectionIDGenerator - t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) @@ -219,9 +219,15 @@ func (t *Transport) init(conf *Config) error { t.closeQueue = make(chan closePacket, 4) if t.ConnectionIDGenerator != nil { + t.connIDGenerator = t.ConnectionIDGenerator t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen() } else { - t.connIDLen = t.ConnectionIDLength + connIDLen := t.ConnectionIDLength + if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) { + connIDLen = protocol.DefaultConnectionIDLength + } + t.connIDLen = connIDLen + t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen} } go t.listen(conn) diff --git a/transport_test.go b/transport_test.go index c404a3b9..19ef54f2 100644 --- a/transport_test.go +++ b/transport_test.go @@ -61,7 +61,7 @@ var _ = Describe("Transport", func() { It("handles packets for different packet handlers on the same packet conn", func() { packetChan := make(chan packetToRead) tr := &Transport{Conn: newMockPacketConn(packetChan)} - tr.init(&Config{}) + tr.init(&Config{}, true) phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) @@ -126,9 +126,10 @@ var _ = Describe("Transport", func() { packetChan := make(chan packetToRead) tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ - Conn: newMockPacketConn(packetChan), + Conn: newMockPacketConn(packetChan), + ConnectionIDLength: 10, } - tr.init(&Config{Tracer: tracer, ConnectionIDLength: 10}) + tr.init(&Config{Tracer: tracer}, true) dropped := make(chan struct{}) tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) packetChan <- packetToRead{ @@ -147,7 +148,7 @@ var _ = Describe("Transport", func() { tr := Transport{Conn: newMockPacketConn(packetChan)} defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}) + tr.init(&Config{}, true) tr.handlerMap = phm done := make(chan struct{}) @@ -165,7 +166,7 @@ var _ = Describe("Transport", func() { tr := Transport{Conn: newMockPacketConn(packetChan)} defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}) + tr.init(&Config{}, true) tr.handlerMap = phm tempErr := deadlineError{} @@ -183,11 +184,13 @@ var _ = Describe("Transport", func() { It("handles short header packets resets", func() { connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) packetChan := make(chan packetToRead) - tr := Transport{Conn: newMockPacketConn(packetChan)} - tr.init(&Config{ConnectionIDLength: connID.Len()}) + tr := Transport{ + Conn: newMockPacketConn(packetChan), + ConnectionIDLength: connID.Len(), + } + tr.init(&Config{}, true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}) tr.handlerMap = phm var token protocol.StatelessResetToken @@ -218,10 +221,9 @@ var _ = Describe("Transport", func() { connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) packetChan := make(chan packetToRead) tr := Transport{Conn: newMockPacketConn(packetChan)} - tr.init(&Config{ConnectionIDLength: connID.Len()}) + tr.init(&Config{}, true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}) tr.handlerMap = phm var token protocol.StatelessResetToken @@ -251,13 +253,13 @@ var _ = Describe("Transport", func() { packetChan := make(chan packetToRead) conn := newMockPacketConn(packetChan) tr := Transport{ - Conn: conn, - StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, + Conn: conn, + StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, + ConnectionIDLength: connID.Len(), } - tr.init(&Config{ConnectionIDLength: connID.Len()}) + tr.init(&Config{}, true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}) tr.handlerMap = phm var b []byte