mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
add support for providing a custom Connection ID generator via Config (#3452)
* Add support for providing a custom ConnectionID generator via Config This work makes it possible for servers or clients to control how ConnectionIDs are generated, which in turn will force peers in the connection to use those ConnectionIDs as destination connection IDs when sending packets. This is useful for scenarios where we want to perform some kind selection on the QUIC packets at the L4 level. * add more doc * refactor populate config to not use provided config * add an integration test for custom connection ID generators * fix linter warnings Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
parent
034fc4e09a
commit
66f6fe0b71
11 changed files with 124 additions and 38 deletions
11
client.go
11
client.go
|
@ -42,11 +42,8 @@ type client struct {
|
|||
logger utils.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
// make it possible to mock connection ID generation in the tests
|
||||
generateConnectionID = protocol.GenerateConnectionID
|
||||
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
)
|
||||
// make it possible to mock connection ID for initial generation in the tests
|
||||
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
|
@ -193,7 +190,7 @@ func dialContext(
|
|||
return nil, err
|
||||
}
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
|
||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -256,7 +253,7 @@ func newClient(
|
|||
}
|
||||
}
|
||||
|
||||
srcConnID, err := generateConnectionID(config.ConnectionIDLength)
|
||||
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -88,22 +88,16 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
Context("Dialing", func() {
|
||||
var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
|
||||
var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error)
|
||||
|
||||
BeforeEach(func() {
|
||||
origGenerateConnectionID = generateConnectionID
|
||||
origGenerateConnectionIDForInitial = generateConnectionIDForInitial
|
||||
generateConnectionID = func(int) (protocol.ConnectionID, error) {
|
||||
return connID, nil
|
||||
}
|
||||
generateConnectionIDForInitial = func() (protocol.ConnectionID, error) {
|
||||
return connID, nil
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
generateConnectionID = origGenerateConnectionID
|
||||
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
|
||||
})
|
||||
|
||||
|
@ -524,7 +518,7 @@ var _ = Describe("Client", func() {
|
|||
manager.EXPECT().Add(connID, gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}}
|
||||
c := make(chan struct{})
|
||||
var cconn sendConn
|
||||
var version protocol.VersionNumber
|
||||
|
@ -602,6 +596,7 @@ var _ = Describe("Client", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}}
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
_, err := DialAddr("localhost:7890", tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -609,3 +604,15 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
type mockedConnIDGenerator struct {
|
||||
ConnID protocol.ConnectionID
|
||||
}
|
||||
|
||||
func (m *mockedConnIDGenerator) GenerateConnectionID() ([]byte, error) {
|
||||
return m.ConnID, nil
|
||||
}
|
||||
|
||||
func (m *mockedConnIDGenerator) ConnectionIDLen() int {
|
||||
return m.ConnID.Len()
|
||||
}
|
||||
|
|
26
config.go
26
config.go
|
@ -35,10 +35,7 @@ func validateConfig(config *Config) error {
|
|||
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
config = populateConfig(config)
|
||||
if config.ConnectionIDLength == 0 {
|
||||
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
config = populateConfig(config, protocol.DefaultConnectionIDLength)
|
||||
if config.MaxTokenAge == 0 {
|
||||
config.MaxTokenAge = protocol.TokenValidity
|
||||
}
|
||||
|
@ -54,14 +51,16 @@ func populateServerConfig(config *Config) *Config {
|
|||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||
config = populateConfig(config)
|
||||
if config.ConnectionIDLength == 0 && !createdPacketConn {
|
||||
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
|
||||
defaultConnIDLen := protocol.DefaultConnectionIDLength
|
||||
if createdPacketConn {
|
||||
defaultConnIDLen = 0
|
||||
}
|
||||
|
||||
config = populateConfig(config, defaultConnIDLen)
|
||||
return config
|
||||
}
|
||||
|
||||
func populateConfig(config *Config) *Config {
|
||||
func populateConfig(config *Config, defaultConnIDLen int) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
|
@ -69,6 +68,10 @@ func populateConfig(config *Config) *Config {
|
|||
if len(versions) == 0 {
|
||||
versions = protocol.SupportedVersions
|
||||
}
|
||||
conIDLen := config.ConnectionIDLength
|
||||
if config.ConnectionIDLength == 0 {
|
||||
conIDLen = defaultConnIDLen
|
||||
}
|
||||
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
|
||||
if config.HandshakeIdleTimeout != 0 {
|
||||
handshakeIdleTimeout = config.HandshakeIdleTimeout
|
||||
|
@ -105,6 +108,10 @@ func populateConfig(config *Config) *Config {
|
|||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
connIDGenerator := config.ConnectionIDGenerator
|
||||
if connIDGenerator == nil {
|
||||
connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen}
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Versions: versions,
|
||||
|
@ -121,7 +128,8 @@ func populateConfig(config *Config) *Config {
|
|||
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
ConnectionIDLength: config.ConnectionIDLength,
|
||||
ConnectionIDLength: conIDLen,
|
||||
ConnectionIDGenerator: connIDGenerator,
|
||||
StatelessResetKey: config.StatelessResetKey,
|
||||
TokenStore: config.TokenStore,
|
||||
EnableDatagrams: config.EnableDatagrams,
|
||||
|
|
|
@ -51,6 +51,8 @@ var _ = Describe("Config", func() {
|
|||
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
|
||||
case "ConnectionIDLength":
|
||||
f.Set(reflect.ValueOf(8))
|
||||
case "ConnectionIDGenerator":
|
||||
f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength}))
|
||||
case "HandshakeIdleTimeout":
|
||||
f.Set(reflect.ValueOf(time.Second))
|
||||
case "MaxIdleTimeout":
|
||||
|
@ -140,18 +142,18 @@ var _ = Describe("Config", func() {
|
|||
var calledAddrValidation bool
|
||||
c1 := &Config{}
|
||||
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
|
||||
c2 := populateConfig(c1)
|
||||
c2 := populateConfig(c1, protocol.DefaultConnectionIDLength)
|
||||
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||
Expect(calledAddrValidation).To(BeTrue())
|
||||
})
|
||||
|
||||
It("copies non-function fields", func() {
|
||||
c := configWithNonZeroNonFunctionFields()
|
||||
Expect(populateConfig(c)).To(Equal(c))
|
||||
Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c))
|
||||
})
|
||||
|
||||
It("populates empty fields with default values", func() {
|
||||
c := populateConfig(&Config{})
|
||||
c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength)
|
||||
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
||||
Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
type connIDGenerator struct {
|
||||
connIDLen int
|
||||
generator ConnectionIDGenerator
|
||||
highestSeq uint64
|
||||
|
||||
activeSrcConnIDs map[uint64]protocol.ConnectionID
|
||||
|
@ -35,10 +35,11 @@ func newConnIDGenerator(
|
|||
retireConnectionID func(protocol.ConnectionID),
|
||||
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
|
||||
queueControlFrame func(wire.Frame),
|
||||
generator ConnectionIDGenerator,
|
||||
version protocol.VersionNumber,
|
||||
) *connIDGenerator {
|
||||
m := &connIDGenerator{
|
||||
connIDLen: initialConnectionID.Len(),
|
||||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
addConnectionID: addConnectionID,
|
||||
getStatelessResetToken: getStatelessResetToken,
|
||||
|
@ -54,7 +55,7 @@ func newConnIDGenerator(
|
|||
}
|
||||
|
||||
func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
|
||||
if m.connIDLen == 0 {
|
||||
if m.generator.ConnectionIDLen() == 0 {
|
||||
return nil
|
||||
}
|
||||
// The active_connection_id_limit transport parameter is the number of
|
||||
|
@ -99,7 +100,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
|
|||
}
|
||||
|
||||
func (m *connIDGenerator) issueNewConnID() error {
|
||||
connID, err := protocol.GenerateConnectionID(m.connIDLen)
|
||||
connID, err := m.generator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ var _ = Describe("Connection ID Generator", func() {
|
|||
replacedWithClosed = append(replacedWithClosed, cs...)
|
||||
},
|
||||
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
|
||||
&protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()},
|
||||
protocol.VersionDraft29,
|
||||
)
|
||||
})
|
||||
|
|
|
@ -281,6 +281,7 @@ var newConnection = func(
|
|||
runner.Retire,
|
||||
runner.ReplaceWithClosed,
|
||||
s.queueControlFrame,
|
||||
s.config.ConnectionIDGenerator,
|
||||
s.version,
|
||||
)
|
||||
s.preSetup()
|
||||
|
@ -411,6 +412,7 @@ var newClientConnection = func(
|
|||
runner.Retire,
|
||||
runner.ReplaceWithClosed,
|
||||
s.queueControlFrame,
|
||||
s.config.ConnectionIDGenerator,
|
||||
s.version,
|
||||
)
|
||||
s.preSetup()
|
||||
|
|
|
@ -2,9 +2,10 @@ package self_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
|
@ -14,9 +15,26 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type connIDGenerator struct {
|
||||
length int
|
||||
}
|
||||
|
||||
func (c *connIDGenerator) GenerateConnectionID() ([]byte, error) {
|
||||
b := make([]byte, c.length)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
fmt.Fprintf(GinkgoWriter, "generating conn ID failed: %s", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (c *connIDGenerator) ConnectionIDLen() int {
|
||||
return c.length
|
||||
}
|
||||
|
||||
var _ = Describe("Connection ID lengths tests", func() {
|
||||
randomConnIDLen := func() int {
|
||||
return 4 + int(rand.Int31n(15))
|
||||
return 4 + int(mrand.Int31n(15))
|
||||
}
|
||||
|
||||
runServer := func(conf *quic.Config) quic.Listener {
|
||||
|
@ -87,4 +105,19 @@ var _ = Describe("Connection ID lengths tests", func() {
|
|||
defer ln.Close()
|
||||
runClient(ln.Addr(), clientConf)
|
||||
})
|
||||
|
||||
It("downloads a file when both client and server use a custom connection ID generator", func() {
|
||||
serverConf := getQuicConfig(&quic.Config{
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
|
||||
})
|
||||
clientConf := getQuicConfig(&quic.Config{
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
|
||||
})
|
||||
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
runClient(ln.Addr(), clientConf)
|
||||
})
|
||||
})
|
||||
|
|
23
interface.go
23
interface.go
|
@ -201,6 +201,24 @@ type EarlyConnection interface {
|
|||
NextConnection() Connection
|
||||
}
|
||||
|
||||
// A ConnectionIDGenerator is an interface that allows clients to implement their own format
|
||||
// for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets.
|
||||
//
|
||||
// Connection IDs generated by an implementation should always produce IDs of constant size.
|
||||
type ConnectionIDGenerator interface {
|
||||
// GenerateConnectionID generates a new ConnectionID.
|
||||
// Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs.
|
||||
GenerateConnectionID() ([]byte, error)
|
||||
|
||||
// ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of
|
||||
// this interface.
|
||||
// Effectively, this means that implementations of ConnectionIDGenerator must always return constant-size
|
||||
// connection IDs. Valid lengths are between 0 and 20 and calls to GenerateConnectionID.
|
||||
// 0-length ConnectionsIDs can be used when an endpoint (server or client) does not require multiplexing connections
|
||||
// in the presence of a connection migration environment.
|
||||
ConnectionIDLen() int
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
type Config struct {
|
||||
// The QUIC versions that can be negotiated.
|
||||
|
@ -213,6 +231,11 @@ type Config struct {
|
|||
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
|
||||
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
|
||||
ConnectionIDLength int
|
||||
// An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection.
|
||||
// The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers.
|
||||
// By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used.
|
||||
// Otherwise, if one is provided, then ConnectionIDLength is ignored.
|
||||
ConnectionIDGenerator ConnectionIDGenerator
|
||||
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
|
||||
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted.
|
||||
// If this value is zero, the timeout is set to 5 seconds.
|
||||
|
|
|
@ -67,3 +67,15 @@ func (c ConnectionID) String() string {
|
|||
}
|
||||
return fmt.Sprintf("%x", c.Bytes())
|
||||
}
|
||||
|
||||
type DefaultConnectionIDGenerator struct {
|
||||
ConnLen int
|
||||
}
|
||||
|
||||
func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) {
|
||||
return GenerateConnectionID(d.ConnLen)
|
||||
}
|
||||
|
||||
func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int {
|
||||
return d.ConnLen
|
||||
}
|
||||
|
|
12
server.go
12
server.go
|
@ -191,7 +191,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
|
|||
}
|
||||
}
|
||||
|
||||
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
|
||||
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -322,7 +322,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
|||
}
|
||||
// If we're creating a new connection, the packet will be passed to the connection.
|
||||
// The header will then be parsed again.
|
||||
hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength)
|
||||
hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDGenerator.ConnectionIDLen())
|
||||
if err != nil && err != wire.ErrUnsupportedVersion {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
|
@ -463,11 +463,11 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
return nil
|
||||
}
|
||||
|
||||
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
|
||||
connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||
s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(connID))
|
||||
var conn quicConn
|
||||
tracingID := nextConnTracingID()
|
||||
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
|
||||
|
@ -549,7 +549,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
|
|||
// Log the Initial packet now.
|
||||
// If no Retry is sent, the packet will be logged by the connection.
|
||||
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
|
||||
srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
|
||||
srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -565,7 +565,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
|
|||
replyHdr.DestConnectionID = hdr.SrcConnectionID
|
||||
replyHdr.Token = token
|
||||
if s.logger.Debug() {
|
||||
s.logger.Debugf("Changing connection ID to %s.", srcConnID)
|
||||
s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(srcConnID))
|
||||
s.logger.Debugf("-> Sending Retry")
|
||||
replyHdr.Log(s.logger)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue