From 66f6fe0b711bcfed0e66f4690fb45f854eee59ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Oliveirinha?= Date: Wed, 24 Aug 2022 12:06:16 +0100 Subject: [PATCH] add support for providing a custom Connection ID generator via Config (#3452) * Add support for providing a custom ConnectionID generator via Config This work makes it possible for servers or clients to control how ConnectionIDs are generated, which in turn will force peers in the connection to use those ConnectionIDs as destination connection IDs when sending packets. This is useful for scenarios where we want to perform some kind selection on the QUIC packets at the L4 level. * add more doc * refactor populate config to not use provided config * add an integration test for custom connection ID generators * fix linter warnings Co-authored-by: Marten Seemann --- client.go | 11 +++----- client_test.go | 21 ++++++++++----- config.go | 26 ++++++++++++------- config_test.go | 8 +++--- conn_id_generator.go | 9 ++++--- conn_id_generator_test.go | 1 + connection.go | 2 ++ integrationtests/self/conn_id_test.go | 37 +++++++++++++++++++++++++-- interface.go | 23 +++++++++++++++++ internal/protocol/connection_id.go | 12 +++++++++ server.go | 12 ++++----- 11 files changed, 124 insertions(+), 38 deletions(-) diff --git a/client.go b/client.go index be8390e6..3d7f1505 100644 --- a/client.go +++ b/client.go @@ -42,11 +42,8 @@ type client struct { logger utils.Logger } -var ( - // make it possible to mock connection ID generation in the tests - generateConnectionID = protocol.GenerateConnectionID - generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial -) +// make it possible to mock connection ID for initial generation in the tests +var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial // DialAddr establishes a new QUIC connection to a server. // It uses a new UDP connection and closes this connection when the QUIC connection is closed. @@ -193,7 +190,7 @@ func dialContext( return nil, err } config = populateClientConfig(config, createdPacketConn) - packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) + packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) if err != nil { return nil, err } @@ -256,7 +253,7 @@ func newClient( } } - srcConnID, err := generateConnectionID(config.ConnectionIDLength) + srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID() if err != nil { return nil, err } diff --git a/client_test.go b/client_test.go index c7fbc0d3..84b55d0f 100644 --- a/client_test.go +++ b/client_test.go @@ -88,22 +88,16 @@ var _ = Describe("Client", func() { }) Context("Dialing", func() { - var origGenerateConnectionID func(int) (protocol.ConnectionID, error) var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error) BeforeEach(func() { - origGenerateConnectionID = generateConnectionID origGenerateConnectionIDForInitial = generateConnectionIDForInitial - generateConnectionID = func(int) (protocol.ConnectionID, error) { - return connID, nil - } generateConnectionIDForInitial = func() (protocol.ConnectionID, error) { return connID, nil } }) AfterEach(func() { - generateConnectionID = origGenerateConnectionID generateConnectionIDForInitial = origGenerateConnectionIDForInitial }) @@ -524,7 +518,7 @@ var _ = Describe("Client", func() { manager.EXPECT().Add(connID, gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}} c := make(chan struct{}) var cconn sendConn var version protocol.VersionNumber @@ -602,6 +596,7 @@ var _ = Describe("Client", func() { return conn } + config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := DialAddr("localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) @@ -609,3 +604,15 @@ var _ = Describe("Client", func() { }) }) }) + +type mockedConnIDGenerator struct { + ConnID protocol.ConnectionID +} + +func (m *mockedConnIDGenerator) GenerateConnectionID() ([]byte, error) { + return m.ConnID, nil +} + +func (m *mockedConnIDGenerator) ConnectionIDLen() int { + return m.ConnID.Len() +} diff --git a/config.go b/config.go index 93735e72..0e8cc98a 100644 --- a/config.go +++ b/config.go @@ -35,10 +35,7 @@ func validateConfig(config *Config) error { // populateServerConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateServerConfig(config *Config) *Config { - config = populateConfig(config) - if config.ConnectionIDLength == 0 { - config.ConnectionIDLength = protocol.DefaultConnectionIDLength - } + config = populateConfig(config, protocol.DefaultConnectionIDLength) if config.MaxTokenAge == 0 { config.MaxTokenAge = protocol.TokenValidity } @@ -54,14 +51,16 @@ func populateServerConfig(config *Config) *Config { // populateClientConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateClientConfig(config *Config, createdPacketConn bool) *Config { - config = populateConfig(config) - if config.ConnectionIDLength == 0 && !createdPacketConn { - config.ConnectionIDLength = protocol.DefaultConnectionIDLength + defaultConnIDLen := protocol.DefaultConnectionIDLength + if createdPacketConn { + defaultConnIDLen = 0 } + + config = populateConfig(config, defaultConnIDLen) return config } -func populateConfig(config *Config) *Config { +func populateConfig(config *Config, defaultConnIDLen int) *Config { if config == nil { config = &Config{} } @@ -69,6 +68,10 @@ func populateConfig(config *Config) *Config { if len(versions) == 0 { versions = protocol.SupportedVersions } + conIDLen := config.ConnectionIDLength + if config.ConnectionIDLength == 0 { + conIDLen = defaultConnIDLen + } handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout if config.HandshakeIdleTimeout != 0 { handshakeIdleTimeout = config.HandshakeIdleTimeout @@ -105,6 +108,10 @@ func populateConfig(config *Config) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } + connIDGenerator := config.ConnectionIDGenerator + if connIDGenerator == nil { + connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen} + } return &Config{ Versions: versions, @@ -121,7 +128,8 @@ func populateConfig(config *Config) *Config { AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, - ConnectionIDLength: config.ConnectionIDLength, + ConnectionIDLength: conIDLen, + ConnectionIDGenerator: connIDGenerator, StatelessResetKey: config.StatelessResetKey, TokenStore: config.TokenStore, EnableDatagrams: config.EnableDatagrams, diff --git a/config_test.go b/config_test.go index f4cfe41d..0317253f 100644 --- a/config_test.go +++ b/config_test.go @@ -51,6 +51,8 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) case "ConnectionIDLength": f.Set(reflect.ValueOf(8)) + case "ConnectionIDGenerator": + f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength})) case "HandshakeIdleTimeout": f.Set(reflect.ValueOf(time.Second)) case "MaxIdleTimeout": @@ -140,18 +142,18 @@ var _ = Describe("Config", func() { var calledAddrValidation bool c1 := &Config{} c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } - c2 := populateConfig(c1) + c2 := populateConfig(c1, protocol.DefaultConnectionIDLength) c2.RequireAddressValidation(&net.UDPAddr{}) Expect(calledAddrValidation).To(BeTrue()) }) It("copies non-function fields", func() { c := configWithNonZeroNonFunctionFields() - Expect(populateConfig(c)).To(Equal(c)) + Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c)) }) It("populates empty fields with default values", func() { - c := populateConfig(&Config{}) + c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength) Expect(c.Versions).To(Equal(protocol.SupportedVersions)) Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData)) diff --git a/conn_id_generator.go b/conn_id_generator.go index 570045e6..0421d678 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -10,7 +10,7 @@ import ( ) type connIDGenerator struct { - connIDLen int + generator ConnectionIDGenerator highestSeq uint64 activeSrcConnIDs map[uint64]protocol.ConnectionID @@ -35,10 +35,11 @@ func newConnIDGenerator( retireConnectionID func(protocol.ConnectionID), replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte), queueControlFrame func(wire.Frame), + generator ConnectionIDGenerator, version protocol.VersionNumber, ) *connIDGenerator { m := &connIDGenerator{ - connIDLen: initialConnectionID.Len(), + generator: generator, activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), addConnectionID: addConnectionID, getStatelessResetToken: getStatelessResetToken, @@ -54,7 +55,7 @@ func newConnIDGenerator( } func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { - if m.connIDLen == 0 { + if m.generator.ConnectionIDLen() == 0 { return nil } // The active_connection_id_limit transport parameter is the number of @@ -99,7 +100,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect } func (m *connIDGenerator) issueNewConnID() error { - connID, err := protocol.GenerateConnectionID(m.connIDLen) + connID, err := m.generator.GenerateConnectionID() if err != nil { return err } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 98f1eb7d..167a70d6 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -44,6 +44,7 @@ var _ = Describe("Connection ID Generator", func() { replacedWithClosed = append(replacedWithClosed, cs...) }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, + &protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()}, protocol.VersionDraft29, ) }) diff --git a/connection.go b/connection.go index b8c4092d..015d98ad 100644 --- a/connection.go +++ b/connection.go @@ -281,6 +281,7 @@ var newConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, + s.config.ConnectionIDGenerator, s.version, ) s.preSetup() @@ -411,6 +412,7 @@ var newClientConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, + s.config.ConnectionIDGenerator, s.version, ) s.preSetup() diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index 6da758b0..dc47aa86 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -2,9 +2,10 @@ package self_test import ( "context" + "crypto/rand" "fmt" "io" - "math/rand" + mrand "math/rand" "net" "github.com/lucas-clemente/quic-go" @@ -14,9 +15,26 @@ import ( . "github.com/onsi/gomega" ) +type connIDGenerator struct { + length int +} + +func (c *connIDGenerator) GenerateConnectionID() ([]byte, error) { + b := make([]byte, c.length) + _, err := rand.Read(b) + if err != nil { + fmt.Fprintf(GinkgoWriter, "generating conn ID failed: %s", err) + } + return b, nil +} + +func (c *connIDGenerator) ConnectionIDLen() int { + return c.length +} + var _ = Describe("Connection ID lengths tests", func() { randomConnIDLen := func() int { - return 4 + int(rand.Int31n(15)) + return 4 + int(mrand.Int31n(15)) } runServer := func(conf *quic.Config) quic.Listener { @@ -87,4 +105,19 @@ var _ = Describe("Connection ID lengths tests", func() { defer ln.Close() runClient(ln.Addr(), clientConf) }) + + It("downloads a file when both client and server use a custom connection ID generator", func() { + serverConf := getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{protocol.VersionTLS}, + ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()}, + }) + clientConf := getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{protocol.VersionTLS}, + ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()}, + }) + + ln := runServer(serverConf) + defer ln.Close() + runClient(ln.Addr(), clientConf) + }) }) diff --git a/interface.go b/interface.go index c19a5c93..c400f986 100644 --- a/interface.go +++ b/interface.go @@ -201,6 +201,24 @@ type EarlyConnection interface { NextConnection() Connection } +// A ConnectionIDGenerator is an interface that allows clients to implement their own format +// for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets. +// +// Connection IDs generated by an implementation should always produce IDs of constant size. +type ConnectionIDGenerator interface { + // GenerateConnectionID generates a new ConnectionID. + // Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs. + GenerateConnectionID() ([]byte, error) + + // ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of + // this interface. + // Effectively, this means that implementations of ConnectionIDGenerator must always return constant-size + // connection IDs. Valid lengths are between 0 and 20 and calls to GenerateConnectionID. + // 0-length ConnectionsIDs can be used when an endpoint (server or client) does not require multiplexing connections + // in the presence of a connection migration environment. + ConnectionIDLen() int +} + // Config contains all configuration data needed for a QUIC server or client. type Config struct { // The QUIC versions that can be negotiated. @@ -213,6 +231,11 @@ type Config struct { // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. ConnectionIDLength int + // An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection. + // The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers. + // By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used. + // Otherwise, if one is provided, then ConnectionIDLength is ignored. + ConnectionIDGenerator ConnectionIDGenerator // HandshakeIdleTimeout is the idle timeout before completion of the handshake. // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. // If this value is zero, the timeout is set to 5 seconds. diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index 3aec2cd3..7ae7d9dc 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -67,3 +67,15 @@ func (c ConnectionID) String() string { } return fmt.Sprintf("%x", c.Bytes()) } + +type DefaultConnectionIDGenerator struct { + ConnLen int +} + +func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) { + return GenerateConnectionID(d.ConnLen) +} + +func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int { + return d.ConnLen +} diff --git a/server.go b/server.go index 726adcfa..ac68681c 100644 --- a/server.go +++ b/server.go @@ -191,7 +191,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl } } - connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) + connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) if err != nil { return nil, err } @@ -322,7 +322,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s } // If we're creating a new connection, the packet will be passed to the connection. // The header will then be parsed again. - hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) + hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDGenerator.ConnectionIDLen()) if err != nil && err != wire.ErrUnsupportedVersion { if s.config.Tracer != nil { s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) @@ -463,11 +463,11 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return nil } - connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() if err != nil { return err } - s.logger.Debugf("Changing connection ID to %s.", connID) + s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(connID)) var conn quicConn tracingID := nextConnTracingID() if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { @@ -549,7 +549,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the connection. (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() if err != nil { return err } @@ -565,7 +565,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack replyHdr.DestConnectionID = hdr.SrcConnectionID replyHdr.Token = token if s.logger.Debug() { - s.logger.Debugf("Changing connection ID to %s.", srcConnID) + s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(srcConnID)) s.logger.Debugf("-> Sending Retry") replyHdr.Log(s.logger) }