add support for providing a custom Connection ID generator via Config (#3452)

* Add support for providing a custom ConnectionID generator via Config

This work makes it possible for servers or clients to control how
ConnectionIDs are generated, which in turn will force peers in the
connection to use those ConnectionIDs as destination connection IDs  when sending packets.

This is useful for scenarios where we want to perform some kind
selection on the QUIC packets at the L4 level.

* add more doc

* refactor populate config to not use provided config

* add an integration test for custom connection ID generators

* fix linter warnings

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
João Oliveirinha 2022-08-24 12:06:16 +01:00 committed by GitHub
parent 034fc4e09a
commit 66f6fe0b71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 124 additions and 38 deletions

View file

@ -42,11 +42,8 @@ type client struct {
logger utils.Logger
}
var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = protocol.GenerateConnectionID
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
)
// make it possible to mock connection ID for initial generation in the tests
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
@ -193,7 +190,7 @@ func dialContext(
return nil, err
}
config = populateClientConfig(config, createdPacketConn)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
if err != nil {
return nil, err
}
@ -256,7 +253,7 @@ func newClient(
}
}
srcConnID, err := generateConnectionID(config.ConnectionIDLength)
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}

View file

@ -88,22 +88,16 @@ var _ = Describe("Client", func() {
})
Context("Dialing", func() {
var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error)
BeforeEach(func() {
origGenerateConnectionID = generateConnectionID
origGenerateConnectionIDForInitial = generateConnectionIDForInitial
generateConnectionID = func(int) (protocol.ConnectionID, error) {
return connID, nil
}
generateConnectionIDForInitial = func() (protocol.ConnectionID, error) {
return connID, nil
}
})
AfterEach(func() {
generateConnectionID = origGenerateConnectionID
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
})
@ -524,7 +518,7 @@ var _ = Describe("Client", func() {
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}}
c := make(chan struct{})
var cconn sendConn
var version protocol.VersionNumber
@ -602,6 +596,7 @@ var _ = Describe("Client", func() {
return conn
}
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := DialAddr("localhost:7890", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
@ -609,3 +604,15 @@ var _ = Describe("Client", func() {
})
})
})
type mockedConnIDGenerator struct {
ConnID protocol.ConnectionID
}
func (m *mockedConnIDGenerator) GenerateConnectionID() ([]byte, error) {
return m.ConnID, nil
}
func (m *mockedConnIDGenerator) ConnectionIDLen() int {
return m.ConnID.Len()
}

View file

@ -35,10 +35,7 @@ func validateConfig(config *Config) error {
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
config = populateConfig(config)
if config.ConnectionIDLength == 0 {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
}
config = populateConfig(config, protocol.DefaultConnectionIDLength)
if config.MaxTokenAge == 0 {
config.MaxTokenAge = protocol.TokenValidity
}
@ -54,14 +51,16 @@ func populateServerConfig(config *Config) *Config {
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
config = populateConfig(config)
if config.ConnectionIDLength == 0 && !createdPacketConn {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
defaultConnIDLen := protocol.DefaultConnectionIDLength
if createdPacketConn {
defaultConnIDLen = 0
}
config = populateConfig(config, defaultConnIDLen)
return config
}
func populateConfig(config *Config) *Config {
func populateConfig(config *Config, defaultConnIDLen int) *Config {
if config == nil {
config = &Config{}
}
@ -69,6 +68,10 @@ func populateConfig(config *Config) *Config {
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
conIDLen := config.ConnectionIDLength
if config.ConnectionIDLength == 0 {
conIDLen = defaultConnIDLen
}
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
if config.HandshakeIdleTimeout != 0 {
handshakeIdleTimeout = config.HandshakeIdleTimeout
@ -105,6 +108,10 @@ func populateConfig(config *Config) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDGenerator := config.ConnectionIDGenerator
if connIDGenerator == nil {
connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen}
}
return &Config{
Versions: versions,
@ -121,7 +128,8 @@ func populateConfig(config *Config) *Config {
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
ConnectionIDLength: config.ConnectionIDLength,
ConnectionIDLength: conIDLen,
ConnectionIDGenerator: connIDGenerator,
StatelessResetKey: config.StatelessResetKey,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,

View file

@ -51,6 +51,8 @@ var _ = Describe("Config", func() {
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
case "ConnectionIDLength":
f.Set(reflect.ValueOf(8))
case "ConnectionIDGenerator":
f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength}))
case "HandshakeIdleTimeout":
f.Set(reflect.ValueOf(time.Second))
case "MaxIdleTimeout":
@ -140,18 +142,18 @@ var _ = Describe("Config", func() {
var calledAddrValidation bool
c1 := &Config{}
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
c2 := populateConfig(c1)
c2 := populateConfig(c1, protocol.DefaultConnectionIDLength)
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
})
It("copies non-function fields", func() {
c := configWithNonZeroNonFunctionFields()
Expect(populateConfig(c)).To(Equal(c))
Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c))
})
It("populates empty fields with default values", func() {
c := populateConfig(&Config{})
c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength)
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))

View file

@ -10,7 +10,7 @@ import (
)
type connIDGenerator struct {
connIDLen int
generator ConnectionIDGenerator
highestSeq uint64
activeSrcConnIDs map[uint64]protocol.ConnectionID
@ -35,10 +35,11 @@ func newConnIDGenerator(
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator,
version protocol.VersionNumber,
) *connIDGenerator {
m := &connIDGenerator{
connIDLen: initialConnectionID.Len(),
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
addConnectionID: addConnectionID,
getStatelessResetToken: getStatelessResetToken,
@ -54,7 +55,7 @@ func newConnIDGenerator(
}
func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
if m.connIDLen == 0 {
if m.generator.ConnectionIDLen() == 0 {
return nil
}
// The active_connection_id_limit transport parameter is the number of
@ -99,7 +100,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
}
func (m *connIDGenerator) issueNewConnID() error {
connID, err := protocol.GenerateConnectionID(m.connIDLen)
connID, err := m.generator.GenerateConnectionID()
if err != nil {
return err
}

View file

@ -44,6 +44,7 @@ var _ = Describe("Connection ID Generator", func() {
replacedWithClosed = append(replacedWithClosed, cs...)
},
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
&protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()},
protocol.VersionDraft29,
)
})

View file

@ -281,6 +281,7 @@ var newConnection = func(
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
s.config.ConnectionIDGenerator,
s.version,
)
s.preSetup()
@ -411,6 +412,7 @@ var newClientConnection = func(
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
s.config.ConnectionIDGenerator,
s.version,
)
s.preSetup()

View file

@ -2,9 +2,10 @@ package self_test
import (
"context"
"crypto/rand"
"fmt"
"io"
"math/rand"
mrand "math/rand"
"net"
"github.com/lucas-clemente/quic-go"
@ -14,9 +15,26 @@ import (
. "github.com/onsi/gomega"
)
type connIDGenerator struct {
length int
}
func (c *connIDGenerator) GenerateConnectionID() ([]byte, error) {
b := make([]byte, c.length)
_, err := rand.Read(b)
if err != nil {
fmt.Fprintf(GinkgoWriter, "generating conn ID failed: %s", err)
}
return b, nil
}
func (c *connIDGenerator) ConnectionIDLen() int {
return c.length
}
var _ = Describe("Connection ID lengths tests", func() {
randomConnIDLen := func() int {
return 4 + int(rand.Int31n(15))
return 4 + int(mrand.Int31n(15))
}
runServer := func(conf *quic.Config) quic.Listener {
@ -87,4 +105,19 @@ var _ = Describe("Connection ID lengths tests", func() {
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
It("downloads a file when both client and server use a custom connection ID generator", func() {
serverConf := getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
})
clientConf := getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
})
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
})

View file

@ -201,6 +201,24 @@ type EarlyConnection interface {
NextConnection() Connection
}
// A ConnectionIDGenerator is an interface that allows clients to implement their own format
// for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets.
//
// Connection IDs generated by an implementation should always produce IDs of constant size.
type ConnectionIDGenerator interface {
// GenerateConnectionID generates a new ConnectionID.
// Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs.
GenerateConnectionID() ([]byte, error)
// ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of
// this interface.
// Effectively, this means that implementations of ConnectionIDGenerator must always return constant-size
// connection IDs. Valid lengths are between 0 and 20 and calls to GenerateConnectionID.
// 0-length ConnectionsIDs can be used when an endpoint (server or client) does not require multiplexing connections
// in the presence of a connection migration environment.
ConnectionIDLen() int
}
// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
// The QUIC versions that can be negotiated.
@ -213,6 +231,11 @@ type Config struct {
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection.
// The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers.
// By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used.
// Otherwise, if one is provided, then ConnectionIDLength is ignored.
ConnectionIDGenerator ConnectionIDGenerator
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted.
// If this value is zero, the timeout is set to 5 seconds.

View file

@ -67,3 +67,15 @@ func (c ConnectionID) String() string {
}
return fmt.Sprintf("%x", c.Bytes())
}
type DefaultConnectionIDGenerator struct {
ConnLen int
}
func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) {
return GenerateConnectionID(d.ConnLen)
}
func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int {
return d.ConnLen
}

View file

@ -191,7 +191,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
}
}
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
if err != nil {
return nil, err
}
@ -322,7 +322,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
}
// If we're creating a new connection, the packet will be passed to the connection.
// The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength)
hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDGenerator.ConnectionIDLen())
if err != nil && err != wire.ErrUnsupportedVersion {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
@ -463,11 +463,11 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return nil
}
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
s.logger.Debugf("Changing connection ID to %s.", connID)
s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(connID))
var conn quicConn
tracingID := nextConnTracingID()
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
@ -549,7 +549,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
// Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the connection.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
@ -565,7 +565,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
replyHdr.DestConnectionID = hdr.SrcConnectionID
replyHdr.Token = token
if s.logger.Debug() {
s.logger.Debugf("Changing connection ID to %s.", srcConnID)
s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(srcConnID))
s.logger.Debugf("-> Sending Retry")
replyHdr.Log(s.logger)
}