remove ConnectionIDLength and ConnectionIDGenerator from the Config

This commit is contained in:
Marten Seemann 2023-04-20 11:37:56 +02:00
parent b79b532b04
commit ba942715db
17 changed files with 232 additions and 183 deletions

View file

@ -25,8 +25,9 @@ type client struct {
tlsConf *tls.Config
config *Config
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
connIDGenerator ConnectionIDGenerator
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialPacketNumber protocol.PacketNumber
hasNegotiatedVersion bool
@ -133,6 +134,7 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo
func dial(
ctx context.Context,
conn net.PacketConn,
connIDGenerator ConnectionIDGenerator,
packetHandlers packetHandlerManager,
addr net.Addr,
tlsConf *tls.Config,
@ -141,7 +143,7 @@ func dial(
use0RTT bool,
createdPacketConn bool,
) (quicConn, error) {
c, err := newClient(conn, addr, config, tlsConf, onClose, use0RTT, createdPacketConn)
c, err := newClient(conn, addr, connIDGenerator, config, tlsConf, onClose, use0RTT, createdPacketConn)
if err != nil {
return nil, err
}
@ -164,14 +166,23 @@ func dial(
return c.conn, nil
}
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, onClose func(), use0RTT, createdPacketConn bool) (*client, error) {
func newClient(
pconn net.PacketConn,
remoteAddr net.Addr,
connIDGenerator ConnectionIDGenerator,
config *Config,
tlsConf *tls.Config,
onClose func(),
use0RTT bool,
createdPacketConn bool,
) (*client, error) {
if tlsConf == nil {
tlsConf = &tls.Config{}
} else {
tlsConf = tlsConf.Clone()
}
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
@ -180,6 +191,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
return nil, err
}
c := &client{
connIDGenerator: connIDGenerator,
srcConnID: srcConnID,
destConnID: destConnID,
sconn: newSendPconn(pconn, remoteAddr),
@ -203,6 +215,7 @@ func (c *client) dial(ctx context.Context) error {
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.config,
c.tlsConf,
c.initialPacketNumber,

View file

@ -39,6 +39,7 @@ var _ = Describe("Client", func() {
runner connRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
@ -114,6 +115,7 @@ var _ = Describe("Client", func() {
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@ -132,7 +134,7 @@ var _ = Describe("Client", func() {
conn.EXPECT().HandshakeComplete().Return(c)
return conn
}
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, false, false)
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false, false)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
@ -150,6 +152,7 @@ var _ = Describe("Client", func() {
runner connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@ -168,7 +171,7 @@ var _ = Describe("Client", func() {
return conn
}
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, true, false)
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true, false)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
@ -186,6 +189,7 @@ var _ = Describe("Client", func() {
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@ -203,7 +207,7 @@ var _ = Describe("Client", func() {
return conn
}
var closed bool
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, func() { closed = true }, true, false)
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true, false)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
@ -219,16 +223,14 @@ var _ = Describe("Client", func() {
MaxIdleTimeout: 42 * time.Hour,
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: 4321,
ConnectionIDLength: 13,
TokenStore: tokenStore,
EnableDatagrams: true,
}
c := populateClientConfig(config, false)
c := populateConfig(config)
Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute))
Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour))
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
Expect(c.ConnectionIDLength).To(Equal(13))
Expect(c.TokenStore).To(Equal(tokenStore))
Expect(c.EnableDatagrams).To(BeTrue())
})
@ -238,7 +240,7 @@ var _ = Describe("Client", func() {
MaxIncomingStreams: -1,
MaxIncomingUniStreams: 4321,
}
c := populateClientConfig(config, false)
c := populateConfig(config)
Expect(c.MaxIncomingStreams).To(BeZero())
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
})
@ -248,18 +250,13 @@ var _ = Describe("Client", func() {
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: -1,
}
c := populateClientConfig(config, false)
c := populateConfig(config)
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
Expect(c.MaxIncomingUniStreams).To(BeZero())
})
It("uses 0-byte connection IDs when dialing an address", func() {
c := populateClientConfig(&Config{}, true)
Expect(c.ConnectionIDLength).To(BeZero())
})
It("fills in default values if options are not set in the Config", func() {
c := populateClientConfig(&Config{}, false)
c := populateConfig(&Config{})
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
@ -267,7 +264,7 @@ var _ = Describe("Client", func() {
})
It("creates new connections with the right parameters", func() {
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}}
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}}
c := make(chan struct{})
var cconn sendConn
var version protocol.VersionNumber
@ -278,6 +275,7 @@ var _ = Describe("Client", func() {
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
configP *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@ -320,6 +318,7 @@ var _ = Describe("Client", func() {
runner connRunner,
_ protocol.ConnectionID,
connID protocol.ConnectionID,
_ ConnectionIDGenerator,
configP *Config,
_ *tls.Config,
pn protocol.PacketNumber,
@ -352,7 +351,7 @@ var _ = Describe("Client", func() {
return conn
}
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}}
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
Expect(err).ToNot(HaveOccurred())

View file

@ -42,7 +42,7 @@ func validateConfig(config *Config) error {
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
config = populateConfig(config, protocol.DefaultConnectionIDLength)
config = populateConfig(config)
if config.MaxTokenAge == 0 {
config.MaxTokenAge = protocol.TokenValidity
}
@ -55,19 +55,9 @@ func populateServerConfig(config *Config) *Config {
return config
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// populateConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
defaultConnIDLen := protocol.DefaultConnectionIDLength
if createdPacketConn {
defaultConnIDLen = 0
}
config = populateConfig(config, defaultConnIDLen)
return config
}
func populateConfig(config *Config, defaultConnIDLen int) *Config {
func populateConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
@ -75,10 +65,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
conIDLen := config.ConnectionIDLength
if config.ConnectionIDLength == 0 {
conIDLen = defaultConnIDLen
}
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
if config.HandshakeIdleTimeout != 0 {
handshakeIdleTimeout = config.HandshakeIdleTimeout
@ -115,10 +101,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDGenerator := config.ConnectionIDGenerator
if connIDGenerator == nil {
connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen}
}
return &Config{
Versions: versions,
@ -135,8 +117,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
ConnectionIDLength: conIDLen,
ConnectionIDGenerator: connIDGenerator,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,

View file

@ -142,18 +142,18 @@ var _ = Describe("Config", func() {
var calledAddrValidation bool
c1 := &Config{}
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
c2 := populateConfig(c1, protocol.DefaultConnectionIDLength)
c2 := populateConfig(c1)
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
})
It("copies non-function fields", func() {
c := configWithNonZeroNonFunctionFields()
Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c))
Expect(populateConfig(c)).To(Equal(c))
})
It("populates empty fields with default values", func() {
c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength)
c := populateConfig(&Config{})
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
@ -168,18 +168,7 @@ var _ = Describe("Config", func() {
It("populates empty fields with default values, for the server", func() {
c := populateServerConfig(&Config{})
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
Expect(c.RequireAddressValidation).ToNot(BeNil())
})
It("sets a default connection ID length if we didn't create the conn, for the client", func() {
c := populateClientConfig(&Config{}, false)
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
})
It("doesn't set a default connection ID length if we created the conn, for the client", func() {
c := populateClientConfig(&Config{}, true)
Expect(c.ConnectionIDLength).To(BeZero())
})
})
})

View file

@ -240,6 +240,7 @@ var newConnection = func(
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
statelessResetToken protocol.StatelessResetToken,
conf *Config,
tlsConf *tls.Config,
@ -283,7 +284,7 @@ var newConnection = func(
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
s.config.ConnectionIDGenerator,
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
@ -363,6 +364,7 @@ var newClientConnection = func(
runner connRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
@ -402,7 +404,7 @@ var newClientConnection = func(
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
s.config.ConnectionIDGenerator,
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))

View file

@ -113,6 +113,7 @@ var _ = Describe("Connection", func() {
clientDestConnID,
destConnID,
srcConnID,
&protocol.DefaultConnectionIDGenerator{},
protocol.StatelessResetToken{},
populateServerConfig(&Config{DisablePathMTUDiscovery: true}),
nil, // tls.Config
@ -2015,8 +2016,6 @@ var _ = Describe("Connection", func() {
packer.EXPECT().HandleTransportParameters(params)
packer.EXPECT().PackCoalescedPacket(false, conn.version).MaxTimes(3)
Expect(conn.earlyConnReady()).ToNot(BeClosed())
connRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2)
connRunner.EXPECT().Add(gomock.Any(), conn).Times(2)
tracer.EXPECT().ReceivedTransportParameters(params)
conn.handleTransportParameters(params)
Expect(conn.earlyConnReady()).To(BeClosed())
@ -2378,7 +2377,7 @@ var _ = Describe("Client Connection", func() {
}
BeforeEach(func() {
quicConf = populateClientConfig(&Config{}, true)
quicConf = populateConfig(&Config{})
tlsConf = nil
})
@ -2402,6 +2401,7 @@ var _ = Describe("Client Connection", func() {
connRunner,
destConnID,
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
&protocol.DefaultConnectionIDGenerator{},
quicConf,
tlsConf,
42, // initial packet number

View file

@ -34,13 +34,23 @@ func (c *connIDGenerator) ConnectionIDLen() int {
var _ = Describe("Connection ID lengths tests", func() {
randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
runServer := func(conf *quic.Config) *quic.Listener {
if conf.ConnectionIDGenerator != nil {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", conf.ConnectionIDGenerator.ConnectionIDLen())))
// connIDLen is ignored when connIDGenerator is set
runServer := func(connIDLen int, connIDGenerator quic.ConnectionIDGenerator) (*quic.Listener, func()) {
if connIDGenerator != nil {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", connIDGenerator.ConnectionIDLen())))
} else {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", connIDLen)))
}
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf)
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: conn,
ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator,
}
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
@ -59,20 +69,35 @@ var _ = Describe("Connection ID lengths tests", func() {
}()
}
}()
return ln
return ln, func() {
ln.Close()
tr.Close()
}
}
runClient := func(addr net.Addr, conf *quic.Config) {
if conf.ConnectionIDGenerator != nil {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", conf.ConnectionIDGenerator.ConnectionIDLen())))
// connIDLen is ignored when connIDGenerator is set
runClient := func(addr net.Addr, connIDLen int, connIDGenerator quic.ConnectionIDGenerator) {
if connIDGenerator != nil {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", connIDGenerator.ConnectionIDLen())))
} else {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", connIDLen)))
}
cl, err := quic.DialAddr(
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator,
}
defer tr.Close()
cl, err := tr.Dial(
context.Background(),
fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
getTLSClientConfig(),
conf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer cl.CloseWithError(0, "")
@ -84,32 +109,20 @@ var _ = Describe("Connection ID lengths tests", func() {
}
It("downloads a file using a 0-byte connection ID for the client", func() {
serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()})
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), getQuicConfig(nil))
ln, closeFn := runServer(randomConnIDLen(), nil)
defer closeFn()
runClient(ln.Addr(), 0, nil)
})
It("downloads a file when both client and server use a random connection ID length", func() {
serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()})
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), getQuicConfig(nil))
ln, closeFn := runServer(randomConnIDLen(), nil)
defer closeFn()
runClient(ln.Addr(), randomConnIDLen(), nil)
})
It("downloads a file when both client and server use a custom connection ID generator", func() {
serverConf := getQuicConfig(&quic.Config{
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
})
clientConf := getQuicConfig(&quic.Config{
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
})
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
ln, closeFn := runServer(0, &connIDGenerator{length: randomConnIDLen()})
defer closeFn()
runClient(ln.Addr(), 0, &connIDGenerator{length: randomConnIDLen()})
})
})

View file

@ -244,7 +244,7 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
pconn, err = net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
dialer = &quic.Transport{Conn: pconn}
dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4}
})
AfterEach(func() {
@ -303,7 +303,7 @@ var _ = Describe("Handshake tests", func() {
// This should free one spot in the queue.
Expect(firstConn.CloseWithError(0, ""))
Eventually(firstConn.Context().Done()).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond))
time.Sleep(scaleDuration(200 * time.Millisecond))
// dial again, and expect that this dial succeeds
_, err = dial()

View file

@ -35,7 +35,11 @@ var _ = Describe("MITM test", func() {
Expect(err).ToNot(HaveOccurred())
serverUDPConn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
ln, err := quic.Listen(serverUDPConn, getTLSConfig(), serverConfig)
tr := &quic.Transport{
Conn: serverUDPConn,
ConnectionIDLength: connIDLen,
}
ln, err := tr.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -68,7 +72,7 @@ var _ = Describe("MITM test", func() {
}
BeforeEach(func() {
serverConfig = getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen})
serverConfig = getQuicConfig(nil)
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
clientUDPConn, err = net.ListenUDP("udp", addr)
@ -146,12 +150,15 @@ var _ = Describe("MITM test", func() {
defer closeFn()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
conn, err := tr.Dial(
context.Background(),
clientUDPConn,
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background())
@ -190,12 +197,15 @@ var _ = Describe("MITM test", func() {
defer closeFn()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
conn, err := tr.Dial(
context.Background(),
clientUDPConn,
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background())
@ -302,20 +312,20 @@ var _ = Describe("MITM test", func() {
const rtt = 20 * time.Millisecond
runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) {
proxyPort, closeFn := startServerAndProxy(delayCb, nil)
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
_, err = quic.Dial(
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
_, err = tr.Dial(
context.Background(),
clientUDPConn,
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{
ConnectionIDLength: connIDLen,
HandshakeIdleTimeout: 2 * time.Second,
}),
getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}),
)
return closeFn, err
return func() { tr.Close(); serverCloseFn() }, err
}
// fails immediately because client connection closes when it can't find compatible version

View file

@ -137,7 +137,6 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
server, err := tr.Listen(
getTLSConfig(),
getQuicConfig(nil),

View file

@ -28,8 +28,9 @@ var _ = Describe("Stateless Resets", func() {
c, err := net.ListenUDP("udp", nil)
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: c,
StatelessResetKey: &statelessResetKey,
Conn: c,
StatelessResetKey: &statelessResetKey,
ConnectionIDLength: connIDLen,
}
defer tr.Close()
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
@ -61,14 +62,21 @@ var _ = Describe("Stateless Resets", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
conn, err := quic.DialAddr(
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
cl := &quic.Transport{
Conn: udpConn,
ConnectionIDLength: connIDLen,
}
defer cl.Close()
conn, err := cl.Dial(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
getTLSClientConfig(),
getQuicConfig(&quic.Config{
ConnectionIDLength: connIDLen,
MaxIdleTimeout: 2 * time.Second,
}),
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
@ -86,8 +94,9 @@ var _ = Describe("Stateless Resets", func() {
// We need to create a new Transport here, since the old one is still sending out
// CONNECTION_CLOSE packets for (recently) closed connections).
tr2 := &quic.Transport{
Conn: c,
StatelessResetKey: &statelessResetKey,
Conn: c,
ConnectionIDLength: connIDLen,
StatelessResetKey: &statelessResetKey,
}
defer tr2.Close()
ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))

View file

@ -101,6 +101,7 @@ var _ = Describe("0-RTT", func() {
transfer0RTTData := func(
ln *quic.EarlyListener,
proxyPort int,
connIDLen int,
clientTLSConf *tls.Config,
clientConf *quic.Config,
testdata []byte, // data to transfer
@ -125,13 +126,35 @@ var _ = Describe("0-RTT", func() {
if clientConf == nil {
clientConf = getQuicConfig(nil)
}
conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxyPort),
clientTLSConf,
clientConf,
)
Expect(err).ToNot(HaveOccurred())
var conn quic.EarlyConnection
if connIDLen == 0 {
var err error
conn, err = quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxyPort),
clientTLSConf,
clientConf,
)
Expect(err).ToNot(HaveOccurred())
} else {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
ConnectionIDLength: connIDLen,
}
defer tr.Close()
conn, err = tr.DialEarly(
context.Background(),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort},
clientTLSConf,
clientConf,
)
Expect(err).ToNot(HaveOccurred())
}
defer conn.CloseWithError(0, "")
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
@ -212,8 +235,9 @@ var _ = Describe("0-RTT", func() {
transfer0RTTData(
ln,
proxy.LocalPort(),
connIDLen,
clientTLSConf,
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
getQuicConfig(nil),
PRData,
)
@ -373,7 +397,7 @@ var _ = Describe("0-RTT", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData)
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
num0RTT := atomic.LoadUint32(&num0RTTPackets)
numDropped := atomic.LoadUint32(&num0RTTDropped)
@ -448,7 +472,7 @@ var _ = Describe("0-RTT", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, GeneratePRData(5000)) // ~5 packets
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets
mutex.Lock()
defer mutex.Unlock()
@ -768,7 +792,7 @@ var _ = Describe("0-RTT", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData)
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())

View file

@ -242,18 +242,6 @@ type Config struct {
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
Versions []VersionNumber
// The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18.
// If not set, the interpretation depends on where the Config is used:
// If used for dialing an address, a 0 byte connection ID will be used.
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection.
// The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers.
// By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used.
// Otherwise, if one is provided, then ConnectionIDLength is ignored.
ConnectionIDGenerator ConnectionIDGenerator
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted.
// If this value is zero, the timeout is set to 5 seconds.

View file

@ -68,8 +68,9 @@ type baseServer struct {
tokenGenerator *handshake.TokenGenerator
connHandler packetHandlerManager
onClose func()
connIDGenerator ConnectionIDGenerator
connHandler packetHandlerManager
onClose func()
receivedPackets chan *receivedPacket
@ -85,6 +86,7 @@ type baseServer struct {
protocol.ConnectionID, /* client dest connection ID */
protocol.ConnectionID, /* destination connection ID */
protocol.ConnectionID, /* source connection ID */
ConnectionIDGenerator,
protocol.StatelessResetToken,
*Config,
*tls.Config,
@ -210,7 +212,7 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear
return tr.ListenEarly(tlsConf, config)
}
func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) {
func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) {
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
if err != nil {
return nil, err
@ -220,6 +222,7 @@ func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Conf
tlsConf: tlsConf,
config: config,
tokenGenerator: tokenGenerator,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
errorChan: make(chan struct{}),
@ -574,7 +577,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return nil
}
connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
@ -603,6 +606,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
hdr.DestConnectionID,
hdr.SrcConnectionID,
connID,
s.connIDGenerator,
s.connHandler.GetStatelessResetToken(connID),
s.config,
s.tlsConf,
@ -669,7 +673,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
// Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the connection.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
srcConnID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
}

View file

@ -286,6 +286,7 @@ var _ = Describe("Server", func() {
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
_ ConnectionIDGenerator,
tokenP protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@ -488,6 +489,7 @@ var _ = Describe("Server", func() {
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
_ ConnectionIDGenerator,
tokenP protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@ -547,6 +549,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@ -600,6 +603,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@ -631,6 +635,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@ -702,6 +707,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@ -1009,6 +1015,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@ -1082,6 +1089,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@ -1124,6 +1132,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@ -1187,6 +1196,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@ -1309,6 +1319,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,

View file

@ -59,6 +59,9 @@ type Transport struct {
// Set in init.
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
connIDLen int
// Set in init.
// If no ConnectionIDGenerator is set, this is set to a default.
connIDGenerator ConnectionIDGenerator
server unknownPacketHandler
@ -92,10 +95,10 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error)
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(conf); err != nil {
if err := t.init(conf, true); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, false)
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, false)
if err != nil {
return nil, err
}
@ -121,10 +124,10 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(conf); err != nil {
if err := t.init(conf, true); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, true)
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, true)
if err != nil {
return nil, err
}
@ -137,15 +140,15 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateClientConfig(conf, t.createdConn)
if err := t.init(conf); err != nil {
conf = populateConfig(conf)
if err := t.init(conf, false); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn)
return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn)
}
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
@ -153,15 +156,15 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateClientConfig(conf, t.createdConn)
if err := t.init(conf); err != nil {
conf = populateConfig(conf)
if err := t.init(conf, false); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn)
return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn)
}
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
@ -197,7 +200,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
// only print warnings about the UDP receive buffer size once
var receiveBufferWarningOnce sync.Once
func (t *Transport) init(conf *Config) error {
func (t *Transport) init(conf *Config, isServer bool) error {
t.initOnce.Do(func() {
getMultiplexer().AddConn(t.Conn)
@ -208,9 +211,6 @@ func (t *Transport) init(conf *Config) error {
}
t.Tracer = conf.Tracer
t.ConnectionIDLength = conf.ConnectionIDLength
t.ConnectionIDGenerator = conf.ConnectionIDGenerator
t.logger = utils.DefaultLogger // TODO: make this configurable
t.conn = conn
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
@ -219,9 +219,15 @@ func (t *Transport) init(conf *Config) error {
t.closeQueue = make(chan closePacket, 4)
if t.ConnectionIDGenerator != nil {
t.connIDGenerator = t.ConnectionIDGenerator
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
} else {
t.connIDLen = t.ConnectionIDLength
connIDLen := t.ConnectionIDLength
if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) {
connIDLen = protocol.DefaultConnectionIDLength
}
t.connIDLen = connIDLen
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
}
go t.listen(conn)

View file

@ -61,7 +61,7 @@ var _ = Describe("Transport", func() {
It("handles packets for different packet handlers on the same packet conn", func() {
packetChan := make(chan packetToRead)
tr := &Transport{Conn: newMockPacketConn(packetChan)}
tr.init(&Config{})
tr.init(&Config{}, true)
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
@ -126,9 +126,10 @@ var _ = Describe("Transport", func() {
packetChan := make(chan packetToRead)
tracer := mocklogging.NewMockTracer(mockCtrl)
tr := &Transport{
Conn: newMockPacketConn(packetChan),
Conn: newMockPacketConn(packetChan),
ConnectionIDLength: 10,
}
tr.init(&Config{Tracer: tracer, ConnectionIDLength: 10})
tr.init(&Config{Tracer: tracer}, true)
dropped := make(chan struct{})
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
packetChan <- packetToRead{
@ -147,7 +148,7 @@ var _ = Describe("Transport", func() {
tr := Transport{Conn: newMockPacketConn(packetChan)}
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{})
tr.init(&Config{}, true)
tr.handlerMap = phm
done := make(chan struct{})
@ -165,7 +166,7 @@ var _ = Describe("Transport", func() {
tr := Transport{Conn: newMockPacketConn(packetChan)}
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{})
tr.init(&Config{}, true)
tr.handlerMap = phm
tempErr := deadlineError{}
@ -183,11 +184,13 @@ var _ = Describe("Transport", func() {
It("handles short header packets resets", func() {
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
packetChan := make(chan packetToRead)
tr := Transport{Conn: newMockPacketConn(packetChan)}
tr.init(&Config{ConnectionIDLength: connID.Len()})
tr := Transport{
Conn: newMockPacketConn(packetChan),
ConnectionIDLength: connID.Len(),
}
tr.init(&Config{}, true)
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{})
tr.handlerMap = phm
var token protocol.StatelessResetToken
@ -218,10 +221,9 @@ var _ = Describe("Transport", func() {
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
packetChan := make(chan packetToRead)
tr := Transport{Conn: newMockPacketConn(packetChan)}
tr.init(&Config{ConnectionIDLength: connID.Len()})
tr.init(&Config{}, true)
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{})
tr.handlerMap = phm
var token protocol.StatelessResetToken
@ -251,13 +253,13 @@ var _ = Describe("Transport", func() {
packetChan := make(chan packetToRead)
conn := newMockPacketConn(packetChan)
tr := Transport{
Conn: conn,
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
Conn: conn,
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
ConnectionIDLength: connID.Len(),
}
tr.init(&Config{ConnectionIDLength: connID.Len()})
tr.init(&Config{}, true)
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{})
tr.handlerMap = phm
var b []byte