diff --git a/client.go b/client.go index be8390e6..3d7f1505 100644 --- a/client.go +++ b/client.go @@ -42,11 +42,8 @@ type client struct { logger utils.Logger } -var ( - // make it possible to mock connection ID generation in the tests - generateConnectionID = protocol.GenerateConnectionID - generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial -) +// make it possible to mock connection ID for initial generation in the tests +var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial // DialAddr establishes a new QUIC connection to a server. // It uses a new UDP connection and closes this connection when the QUIC connection is closed. @@ -193,7 +190,7 @@ func dialContext( return nil, err } config = populateClientConfig(config, createdPacketConn) - packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) + packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) if err != nil { return nil, err } @@ -256,7 +253,7 @@ func newClient( } } - srcConnID, err := generateConnectionID(config.ConnectionIDLength) + srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID() if err != nil { return nil, err } diff --git a/client_test.go b/client_test.go index c7fbc0d3..84b55d0f 100644 --- a/client_test.go +++ b/client_test.go @@ -88,22 +88,16 @@ var _ = Describe("Client", func() { }) Context("Dialing", func() { - var origGenerateConnectionID func(int) (protocol.ConnectionID, error) var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error) BeforeEach(func() { - origGenerateConnectionID = generateConnectionID origGenerateConnectionIDForInitial = generateConnectionIDForInitial - generateConnectionID = func(int) (protocol.ConnectionID, error) { - return connID, nil - } generateConnectionIDForInitial = func() (protocol.ConnectionID, error) { return connID, nil } }) AfterEach(func() { - generateConnectionID = origGenerateConnectionID generateConnectionIDForInitial = origGenerateConnectionIDForInitial }) @@ -524,7 +518,7 @@ var _ = Describe("Client", func() { manager.EXPECT().Add(connID, gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}} c := make(chan struct{}) var cconn sendConn var version protocol.VersionNumber @@ -602,6 +596,7 @@ var _ = Describe("Client", func() { return conn } + config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := DialAddr("localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) @@ -609,3 +604,15 @@ var _ = Describe("Client", func() { }) }) }) + +type mockedConnIDGenerator struct { + ConnID protocol.ConnectionID +} + +func (m *mockedConnIDGenerator) GenerateConnectionID() ([]byte, error) { + return m.ConnID, nil +} + +func (m *mockedConnIDGenerator) ConnectionIDLen() int { + return m.ConnID.Len() +} diff --git a/config.go b/config.go index 93735e72..0e8cc98a 100644 --- a/config.go +++ b/config.go @@ -35,10 +35,7 @@ func validateConfig(config *Config) error { // populateServerConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateServerConfig(config *Config) *Config { - config = populateConfig(config) - if config.ConnectionIDLength == 0 { - config.ConnectionIDLength = protocol.DefaultConnectionIDLength - } + config = populateConfig(config, protocol.DefaultConnectionIDLength) if config.MaxTokenAge == 0 { config.MaxTokenAge = protocol.TokenValidity } @@ -54,14 +51,16 @@ func populateServerConfig(config *Config) *Config { // populateClientConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateClientConfig(config *Config, createdPacketConn bool) *Config { - config = populateConfig(config) - if config.ConnectionIDLength == 0 && !createdPacketConn { - config.ConnectionIDLength = protocol.DefaultConnectionIDLength + defaultConnIDLen := protocol.DefaultConnectionIDLength + if createdPacketConn { + defaultConnIDLen = 0 } + + config = populateConfig(config, defaultConnIDLen) return config } -func populateConfig(config *Config) *Config { +func populateConfig(config *Config, defaultConnIDLen int) *Config { if config == nil { config = &Config{} } @@ -69,6 +68,10 @@ func populateConfig(config *Config) *Config { if len(versions) == 0 { versions = protocol.SupportedVersions } + conIDLen := config.ConnectionIDLength + if config.ConnectionIDLength == 0 { + conIDLen = defaultConnIDLen + } handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout if config.HandshakeIdleTimeout != 0 { handshakeIdleTimeout = config.HandshakeIdleTimeout @@ -105,6 +108,10 @@ func populateConfig(config *Config) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } + connIDGenerator := config.ConnectionIDGenerator + if connIDGenerator == nil { + connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen} + } return &Config{ Versions: versions, @@ -121,7 +128,8 @@ func populateConfig(config *Config) *Config { AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, - ConnectionIDLength: config.ConnectionIDLength, + ConnectionIDLength: conIDLen, + ConnectionIDGenerator: connIDGenerator, StatelessResetKey: config.StatelessResetKey, TokenStore: config.TokenStore, EnableDatagrams: config.EnableDatagrams, diff --git a/config_test.go b/config_test.go index f4cfe41d..0317253f 100644 --- a/config_test.go +++ b/config_test.go @@ -51,6 +51,8 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) case "ConnectionIDLength": f.Set(reflect.ValueOf(8)) + case "ConnectionIDGenerator": + f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength})) case "HandshakeIdleTimeout": f.Set(reflect.ValueOf(time.Second)) case "MaxIdleTimeout": @@ -140,18 +142,18 @@ var _ = Describe("Config", func() { var calledAddrValidation bool c1 := &Config{} c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } - c2 := populateConfig(c1) + c2 := populateConfig(c1, protocol.DefaultConnectionIDLength) c2.RequireAddressValidation(&net.UDPAddr{}) Expect(calledAddrValidation).To(BeTrue()) }) It("copies non-function fields", func() { c := configWithNonZeroNonFunctionFields() - Expect(populateConfig(c)).To(Equal(c)) + Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c)) }) It("populates empty fields with default values", func() { - c := populateConfig(&Config{}) + c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength) Expect(c.Versions).To(Equal(protocol.SupportedVersions)) Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData)) diff --git a/conn_id_generator.go b/conn_id_generator.go index 570045e6..0421d678 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -10,7 +10,7 @@ import ( ) type connIDGenerator struct { - connIDLen int + generator ConnectionIDGenerator highestSeq uint64 activeSrcConnIDs map[uint64]protocol.ConnectionID @@ -35,10 +35,11 @@ func newConnIDGenerator( retireConnectionID func(protocol.ConnectionID), replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte), queueControlFrame func(wire.Frame), + generator ConnectionIDGenerator, version protocol.VersionNumber, ) *connIDGenerator { m := &connIDGenerator{ - connIDLen: initialConnectionID.Len(), + generator: generator, activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), addConnectionID: addConnectionID, getStatelessResetToken: getStatelessResetToken, @@ -54,7 +55,7 @@ func newConnIDGenerator( } func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { - if m.connIDLen == 0 { + if m.generator.ConnectionIDLen() == 0 { return nil } // The active_connection_id_limit transport parameter is the number of @@ -99,7 +100,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect } func (m *connIDGenerator) issueNewConnID() error { - connID, err := protocol.GenerateConnectionID(m.connIDLen) + connID, err := m.generator.GenerateConnectionID() if err != nil { return err } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 98f1eb7d..167a70d6 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -44,6 +44,7 @@ var _ = Describe("Connection ID Generator", func() { replacedWithClosed = append(replacedWithClosed, cs...) }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, + &protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()}, protocol.VersionDraft29, ) }) diff --git a/connection.go b/connection.go index b8c4092d..015d98ad 100644 --- a/connection.go +++ b/connection.go @@ -281,6 +281,7 @@ var newConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, + s.config.ConnectionIDGenerator, s.version, ) s.preSetup() @@ -411,6 +412,7 @@ var newClientConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, + s.config.ConnectionIDGenerator, s.version, ) s.preSetup() diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index 6da758b0..dc47aa86 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -2,9 +2,10 @@ package self_test import ( "context" + "crypto/rand" "fmt" "io" - "math/rand" + mrand "math/rand" "net" "github.com/lucas-clemente/quic-go" @@ -14,9 +15,26 @@ import ( . "github.com/onsi/gomega" ) +type connIDGenerator struct { + length int +} + +func (c *connIDGenerator) GenerateConnectionID() ([]byte, error) { + b := make([]byte, c.length) + _, err := rand.Read(b) + if err != nil { + fmt.Fprintf(GinkgoWriter, "generating conn ID failed: %s", err) + } + return b, nil +} + +func (c *connIDGenerator) ConnectionIDLen() int { + return c.length +} + var _ = Describe("Connection ID lengths tests", func() { randomConnIDLen := func() int { - return 4 + int(rand.Int31n(15)) + return 4 + int(mrand.Int31n(15)) } runServer := func(conf *quic.Config) quic.Listener { @@ -87,4 +105,19 @@ var _ = Describe("Connection ID lengths tests", func() { defer ln.Close() runClient(ln.Addr(), clientConf) }) + + It("downloads a file when both client and server use a custom connection ID generator", func() { + serverConf := getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{protocol.VersionTLS}, + ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()}, + }) + clientConf := getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{protocol.VersionTLS}, + ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()}, + }) + + ln := runServer(serverConf) + defer ln.Close() + runClient(ln.Addr(), clientConf) + }) }) diff --git a/interface.go b/interface.go index c19a5c93..c400f986 100644 --- a/interface.go +++ b/interface.go @@ -201,6 +201,24 @@ type EarlyConnection interface { NextConnection() Connection } +// A ConnectionIDGenerator is an interface that allows clients to implement their own format +// for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets. +// +// Connection IDs generated by an implementation should always produce IDs of constant size. +type ConnectionIDGenerator interface { + // GenerateConnectionID generates a new ConnectionID. + // Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs. + GenerateConnectionID() ([]byte, error) + + // ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of + // this interface. + // Effectively, this means that implementations of ConnectionIDGenerator must always return constant-size + // connection IDs. Valid lengths are between 0 and 20 and calls to GenerateConnectionID. + // 0-length ConnectionsIDs can be used when an endpoint (server or client) does not require multiplexing connections + // in the presence of a connection migration environment. + ConnectionIDLen() int +} + // Config contains all configuration data needed for a QUIC server or client. type Config struct { // The QUIC versions that can be negotiated. @@ -213,6 +231,11 @@ type Config struct { // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. ConnectionIDLength int + // An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection. + // The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers. + // By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used. + // Otherwise, if one is provided, then ConnectionIDLength is ignored. + ConnectionIDGenerator ConnectionIDGenerator // HandshakeIdleTimeout is the idle timeout before completion of the handshake. // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. // If this value is zero, the timeout is set to 5 seconds. diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index 3aec2cd3..7ae7d9dc 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -67,3 +67,15 @@ func (c ConnectionID) String() string { } return fmt.Sprintf("%x", c.Bytes()) } + +type DefaultConnectionIDGenerator struct { + ConnLen int +} + +func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) { + return GenerateConnectionID(d.ConnLen) +} + +func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int { + return d.ConnLen +} diff --git a/server.go b/server.go index 726adcfa..ac68681c 100644 --- a/server.go +++ b/server.go @@ -191,7 +191,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl } } - connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) + connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) if err != nil { return nil, err } @@ -322,7 +322,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s } // If we're creating a new connection, the packet will be passed to the connection. // The header will then be parsed again. - hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) + hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDGenerator.ConnectionIDLen()) if err != nil && err != wire.ErrUnsupportedVersion { if s.config.Tracer != nil { s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) @@ -463,11 +463,11 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return nil } - connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() if err != nil { return err } - s.logger.Debugf("Changing connection ID to %s.", connID) + s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(connID)) var conn quicConn tracingID := nextConnTracingID() if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { @@ -549,7 +549,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the connection. (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() if err != nil { return err } @@ -565,7 +565,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack replyHdr.DestConnectionID = hdr.SrcConnectionID replyHdr.Token = token if s.logger.Debug() { - s.logger.Debugf("Changing connection ID to %s.", srcConnID) + s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(srcConnID)) s.logger.Debugf("-> Sending Retry") replyHdr.Log(s.logger) }