mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
remove ConnectionIDLength and ConnectionIDGenerator from the Config
This commit is contained in:
parent
b79b532b04
commit
ba942715db
17 changed files with 232 additions and 183 deletions
19
client.go
19
client.go
|
@ -25,6 +25,7 @@ type client struct {
|
|||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
|
||||
|
@ -133,6 +134,7 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo
|
|||
func dial(
|
||||
ctx context.Context,
|
||||
conn net.PacketConn,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
packetHandlers packetHandlerManager,
|
||||
addr net.Addr,
|
||||
tlsConf *tls.Config,
|
||||
|
@ -141,7 +143,7 @@ func dial(
|
|||
use0RTT bool,
|
||||
createdPacketConn bool,
|
||||
) (quicConn, error) {
|
||||
c, err := newClient(conn, addr, config, tlsConf, onClose, use0RTT, createdPacketConn)
|
||||
c, err := newClient(conn, addr, connIDGenerator, config, tlsConf, onClose, use0RTT, createdPacketConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -164,14 +166,23 @@ func dial(
|
|||
return c.conn, nil
|
||||
}
|
||||
|
||||
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, onClose func(), use0RTT, createdPacketConn bool) (*client, error) {
|
||||
func newClient(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
config *Config,
|
||||
tlsConf *tls.Config,
|
||||
onClose func(),
|
||||
use0RTT bool,
|
||||
createdPacketConn bool,
|
||||
) (*client, error) {
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
} else {
|
||||
tlsConf = tlsConf.Clone()
|
||||
}
|
||||
|
||||
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
|
||||
srcConnID, err := connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -180,6 +191,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
|
|||
return nil, err
|
||||
}
|
||||
c := &client{
|
||||
connIDGenerator: connIDGenerator,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
sconn: newSendPconn(pconn, remoteAddr),
|
||||
|
@ -203,6 +215,7 @@ func (c *client) dial(ctx context.Context) error {
|
|||
c.packetHandlers,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.connIDGenerator,
|
||||
c.config,
|
||||
c.tlsConf,
|
||||
c.initialPacketNumber,
|
||||
|
|
|
@ -39,6 +39,7 @@ var _ = Describe("Client", func() {
|
|||
runner connRunner,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
|
@ -114,6 +115,7 @@ var _ = Describe("Client", func() {
|
|||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
|
@ -132,7 +134,7 @@ var _ = Describe("Client", func() {
|
|||
conn.EXPECT().HandshakeComplete().Return(c)
|
||||
return conn
|
||||
}
|
||||
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, false, false)
|
||||
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
|
@ -150,6 +152,7 @@ var _ = Describe("Client", func() {
|
|||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
|
@ -168,7 +171,7 @@ var _ = Describe("Client", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, true, false)
|
||||
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
|
@ -186,6 +189,7 @@ var _ = Describe("Client", func() {
|
|||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
|
@ -203,7 +207,7 @@ var _ = Describe("Client", func() {
|
|||
return conn
|
||||
}
|
||||
var closed bool
|
||||
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, func() { closed = true }, true, false)
|
||||
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
|
@ -219,16 +223,14 @@ var _ = Describe("Client", func() {
|
|||
MaxIdleTimeout: 42 * time.Hour,
|
||||
MaxIncomingStreams: 1234,
|
||||
MaxIncomingUniStreams: 4321,
|
||||
ConnectionIDLength: 13,
|
||||
TokenStore: tokenStore,
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
c := populateClientConfig(config, false)
|
||||
c := populateConfig(config)
|
||||
Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute))
|
||||
Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour))
|
||||
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
|
||||
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
|
||||
Expect(c.ConnectionIDLength).To(Equal(13))
|
||||
Expect(c.TokenStore).To(Equal(tokenStore))
|
||||
Expect(c.EnableDatagrams).To(BeTrue())
|
||||
})
|
||||
|
@ -238,7 +240,7 @@ var _ = Describe("Client", func() {
|
|||
MaxIncomingStreams: -1,
|
||||
MaxIncomingUniStreams: 4321,
|
||||
}
|
||||
c := populateClientConfig(config, false)
|
||||
c := populateConfig(config)
|
||||
Expect(c.MaxIncomingStreams).To(BeZero())
|
||||
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
|
||||
})
|
||||
|
@ -248,18 +250,13 @@ var _ = Describe("Client", func() {
|
|||
MaxIncomingStreams: 1234,
|
||||
MaxIncomingUniStreams: -1,
|
||||
}
|
||||
c := populateClientConfig(config, false)
|
||||
c := populateConfig(config)
|
||||
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
|
||||
Expect(c.MaxIncomingUniStreams).To(BeZero())
|
||||
})
|
||||
|
||||
It("uses 0-byte connection IDs when dialing an address", func() {
|
||||
c := populateClientConfig(&Config{}, true)
|
||||
Expect(c.ConnectionIDLength).To(BeZero())
|
||||
})
|
||||
|
||||
It("fills in default values if options are not set in the Config", func() {
|
||||
c := populateClientConfig(&Config{}, false)
|
||||
c := populateConfig(&Config{})
|
||||
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
||||
Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
|
||||
|
@ -267,7 +264,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("creates new connections with the right parameters", func() {
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}}
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}}
|
||||
c := make(chan struct{})
|
||||
var cconn sendConn
|
||||
var version protocol.VersionNumber
|
||||
|
@ -278,6 +275,7 @@ var _ = Describe("Client", func() {
|
|||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
|
@ -320,6 +318,7 @@ var _ = Describe("Client", func() {
|
|||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
connID protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
pn protocol.PacketNumber,
|
||||
|
@ -352,7 +351,7 @@ var _ = Describe("Client", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}}
|
||||
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}}
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
26
config.go
26
config.go
|
@ -42,7 +42,7 @@ func validateConfig(config *Config) error {
|
|||
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
config = populateConfig(config, protocol.DefaultConnectionIDLength)
|
||||
config = populateConfig(config)
|
||||
if config.MaxTokenAge == 0 {
|
||||
config.MaxTokenAge = protocol.TokenValidity
|
||||
}
|
||||
|
@ -55,19 +55,9 @@ func populateServerConfig(config *Config) *Config {
|
|||
return config
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// populateConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||
defaultConnIDLen := protocol.DefaultConnectionIDLength
|
||||
if createdPacketConn {
|
||||
defaultConnIDLen = 0
|
||||
}
|
||||
|
||||
config = populateConfig(config, defaultConnIDLen)
|
||||
return config
|
||||
}
|
||||
|
||||
func populateConfig(config *Config, defaultConnIDLen int) *Config {
|
||||
func populateConfig(config *Config) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
|
@ -75,10 +65,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
|
|||
if len(versions) == 0 {
|
||||
versions = protocol.SupportedVersions
|
||||
}
|
||||
conIDLen := config.ConnectionIDLength
|
||||
if config.ConnectionIDLength == 0 {
|
||||
conIDLen = defaultConnIDLen
|
||||
}
|
||||
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
|
||||
if config.HandshakeIdleTimeout != 0 {
|
||||
handshakeIdleTimeout = config.HandshakeIdleTimeout
|
||||
|
@ -115,10 +101,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
|
|||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
connIDGenerator := config.ConnectionIDGenerator
|
||||
if connIDGenerator == nil {
|
||||
connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen}
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Versions: versions,
|
||||
|
@ -135,8 +117,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
|
|||
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
ConnectionIDLength: conIDLen,
|
||||
ConnectionIDGenerator: connIDGenerator,
|
||||
TokenStore: config.TokenStore,
|
||||
EnableDatagrams: config.EnableDatagrams,
|
||||
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
|
||||
|
|
|
@ -142,18 +142,18 @@ var _ = Describe("Config", func() {
|
|||
var calledAddrValidation bool
|
||||
c1 := &Config{}
|
||||
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
|
||||
c2 := populateConfig(c1, protocol.DefaultConnectionIDLength)
|
||||
c2 := populateConfig(c1)
|
||||
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||
Expect(calledAddrValidation).To(BeTrue())
|
||||
})
|
||||
|
||||
It("copies non-function fields", func() {
|
||||
c := configWithNonZeroNonFunctionFields()
|
||||
Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c))
|
||||
Expect(populateConfig(c)).To(Equal(c))
|
||||
})
|
||||
|
||||
It("populates empty fields with default values", func() {
|
||||
c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength)
|
||||
c := populateConfig(&Config{})
|
||||
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
||||
Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
|
||||
|
@ -168,18 +168,7 @@ var _ = Describe("Config", func() {
|
|||
|
||||
It("populates empty fields with default values, for the server", func() {
|
||||
c := populateServerConfig(&Config{})
|
||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||
Expect(c.RequireAddressValidation).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("sets a default connection ID length if we didn't create the conn, for the client", func() {
|
||||
c := populateClientConfig(&Config{}, false)
|
||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||
})
|
||||
|
||||
It("doesn't set a default connection ID length if we created the conn, for the client", func() {
|
||||
c := populateClientConfig(&Config{}, true)
|
||||
Expect(c.ConnectionIDLength).To(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -240,6 +240,7 @@ var newConnection = func(
|
|||
clientDestConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetToken protocol.StatelessResetToken,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
|
@ -283,7 +284,7 @@ var newConnection = func(
|
|||
runner.Retire,
|
||||
runner.ReplaceWithClosed,
|
||||
s.queueControlFrame,
|
||||
s.config.ConnectionIDGenerator,
|
||||
connIDGenerator,
|
||||
)
|
||||
s.preSetup()
|
||||
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
|
||||
|
@ -363,6 +364,7 @@ var newClientConnection = func(
|
|||
runner connRunner,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
|
@ -402,7 +404,7 @@ var newClientConnection = func(
|
|||
runner.Retire,
|
||||
runner.ReplaceWithClosed,
|
||||
s.queueControlFrame,
|
||||
s.config.ConnectionIDGenerator,
|
||||
connIDGenerator,
|
||||
)
|
||||
s.preSetup()
|
||||
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
|
||||
|
|
|
@ -113,6 +113,7 @@ var _ = Describe("Connection", func() {
|
|||
clientDestConnID,
|
||||
destConnID,
|
||||
srcConnID,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
protocol.StatelessResetToken{},
|
||||
populateServerConfig(&Config{DisablePathMTUDiscovery: true}),
|
||||
nil, // tls.Config
|
||||
|
@ -2015,8 +2016,6 @@ var _ = Describe("Connection", func() {
|
|||
packer.EXPECT().HandleTransportParameters(params)
|
||||
packer.EXPECT().PackCoalescedPacket(false, conn.version).MaxTimes(3)
|
||||
Expect(conn.earlyConnReady()).ToNot(BeClosed())
|
||||
connRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2)
|
||||
connRunner.EXPECT().Add(gomock.Any(), conn).Times(2)
|
||||
tracer.EXPECT().ReceivedTransportParameters(params)
|
||||
conn.handleTransportParameters(params)
|
||||
Expect(conn.earlyConnReady()).To(BeClosed())
|
||||
|
@ -2378,7 +2377,7 @@ var _ = Describe("Client Connection", func() {
|
|||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
quicConf = populateClientConfig(&Config{}, true)
|
||||
quicConf = populateConfig(&Config{})
|
||||
tlsConf = nil
|
||||
})
|
||||
|
||||
|
@ -2402,6 +2401,7 @@ var _ = Describe("Client Connection", func() {
|
|||
connRunner,
|
||||
destConnID,
|
||||
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
quicConf,
|
||||
tlsConf,
|
||||
42, // initial packet number
|
||||
|
|
|
@ -34,13 +34,23 @@ func (c *connIDGenerator) ConnectionIDLen() int {
|
|||
var _ = Describe("Connection ID lengths tests", func() {
|
||||
randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
|
||||
|
||||
runServer := func(conf *quic.Config) *quic.Listener {
|
||||
if conf.ConnectionIDGenerator != nil {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", conf.ConnectionIDGenerator.ConnectionIDLen())))
|
||||
// connIDLen is ignored when connIDGenerator is set
|
||||
runServer := func(connIDLen int, connIDGenerator quic.ConnectionIDGenerator) (*quic.Listener, func()) {
|
||||
if connIDGenerator != nil {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", connIDGenerator.ConnectionIDLen())))
|
||||
} else {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", connIDLen)))
|
||||
}
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf)
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
ConnectionIDGenerator: connIDGenerator,
|
||||
}
|
||||
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -59,20 +69,35 @@ var _ = Describe("Connection ID lengths tests", func() {
|
|||
}()
|
||||
}
|
||||
}()
|
||||
return ln
|
||||
return ln, func() {
|
||||
ln.Close()
|
||||
tr.Close()
|
||||
}
|
||||
}
|
||||
|
||||
runClient := func(addr net.Addr, conf *quic.Config) {
|
||||
if conf.ConnectionIDGenerator != nil {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", conf.ConnectionIDGenerator.ConnectionIDLen())))
|
||||
// connIDLen is ignored when connIDGenerator is set
|
||||
runClient := func(addr net.Addr, connIDLen int, connIDGenerator quic.ConnectionIDGenerator) {
|
||||
if connIDGenerator != nil {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", connIDGenerator.ConnectionIDLen())))
|
||||
} else {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", connIDLen)))
|
||||
}
|
||||
cl, err := quic.DialAddr(
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
ConnectionIDGenerator: connIDGenerator,
|
||||
}
|
||||
defer tr.Close()
|
||||
cl, err := tr.Dial(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
|
||||
getTLSClientConfig(),
|
||||
conf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer cl.CloseWithError(0, "")
|
||||
|
@ -84,32 +109,20 @@ var _ = Describe("Connection ID lengths tests", func() {
|
|||
}
|
||||
|
||||
It("downloads a file using a 0-byte connection ID for the client", func() {
|
||||
serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()})
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
|
||||
runClient(ln.Addr(), getQuicConfig(nil))
|
||||
ln, closeFn := runServer(randomConnIDLen(), nil)
|
||||
defer closeFn()
|
||||
runClient(ln.Addr(), 0, nil)
|
||||
})
|
||||
|
||||
It("downloads a file when both client and server use a random connection ID length", func() {
|
||||
serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()})
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
|
||||
runClient(ln.Addr(), getQuicConfig(nil))
|
||||
ln, closeFn := runServer(randomConnIDLen(), nil)
|
||||
defer closeFn()
|
||||
runClient(ln.Addr(), randomConnIDLen(), nil)
|
||||
})
|
||||
|
||||
It("downloads a file when both client and server use a custom connection ID generator", func() {
|
||||
serverConf := getQuicConfig(&quic.Config{
|
||||
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
|
||||
})
|
||||
clientConf := getQuicConfig(&quic.Config{
|
||||
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
|
||||
})
|
||||
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
|
||||
runClient(ln.Addr(), clientConf)
|
||||
ln, closeFn := runServer(0, &connIDGenerator{length: randomConnIDLen()})
|
||||
defer closeFn()
|
||||
runClient(ln.Addr(), 0, &connIDGenerator{length: randomConnIDLen()})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -244,7 +244,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
pconn, err = net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
dialer = &quic.Transport{Conn: pconn}
|
||||
dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
@ -303,7 +303,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
// This should free one spot in the queue.
|
||||
Expect(firstConn.CloseWithError(0, ""))
|
||||
Eventually(firstConn.Context().Done()).Should(BeClosed())
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond))
|
||||
time.Sleep(scaleDuration(200 * time.Millisecond))
|
||||
|
||||
// dial again, and expect that this dial succeeds
|
||||
_, err = dial()
|
||||
|
|
|
@ -35,7 +35,11 @@ var _ = Describe("MITM test", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
serverUDPConn, err = net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ln, err := quic.Listen(serverUDPConn, getTLSConfig(), serverConfig)
|
||||
tr := &quic.Transport{
|
||||
Conn: serverUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
ln, err := tr.Listen(getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -68,7 +72,7 @@ var _ = Describe("MITM test", func() {
|
|||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
serverConfig = getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen})
|
||||
serverConfig = getQuicConfig(nil)
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
clientUDPConn, err = net.ListenUDP("udp", addr)
|
||||
|
@ -146,12 +150,15 @@ var _ = Describe("MITM test", func() {
|
|||
defer closeFn()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
tr := &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
conn, err := tr.Dial(
|
||||
context.Background(),
|
||||
clientUDPConn,
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
|
@ -190,12 +197,15 @@ var _ = Describe("MITM test", func() {
|
|||
defer closeFn()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
tr := &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
conn, err := tr.Dial(
|
||||
context.Background(),
|
||||
clientUDPConn,
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
|
@ -302,20 +312,20 @@ var _ = Describe("MITM test", func() {
|
|||
const rtt = 20 * time.Millisecond
|
||||
|
||||
runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) {
|
||||
proxyPort, closeFn := startServerAndProxy(delayCb, nil)
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = quic.Dial(
|
||||
tr := &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
_, err = tr.Dial(
|
||||
context.Background(),
|
||||
clientUDPConn,
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
ConnectionIDLength: connIDLen,
|
||||
HandshakeIdleTimeout: 2 * time.Second,
|
||||
}),
|
||||
getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}),
|
||||
)
|
||||
return closeFn, err
|
||||
return func() { tr.Close(); serverCloseFn() }, err
|
||||
}
|
||||
|
||||
// fails immediately because client connection closes when it can't find compatible version
|
||||
|
|
|
@ -137,7 +137,6 @@ var _ = Describe("Multiplexing", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
|
||||
server, err := tr.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
|
|
|
@ -30,6 +30,7 @@ var _ = Describe("Stateless Resets", func() {
|
|||
tr := &quic.Transport{
|
||||
Conn: c,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
|
@ -61,14 +62,21 @@ var _ = Describe("Stateless Resets", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
cl := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
MaxIdleTimeout: 2 * time.Second,
|
||||
}),
|
||||
}
|
||||
defer cl.Close()
|
||||
conn, err := cl.Dial(
|
||||
context.Background(),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
|
@ -87,6 +95,7 @@ var _ = Describe("Stateless Resets", func() {
|
|||
// CONNECTION_CLOSE packets for (recently) closed connections).
|
||||
tr2 := &quic.Transport{
|
||||
Conn: c,
|
||||
ConnectionIDLength: connIDLen,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
}
|
||||
defer tr2.Close()
|
||||
|
|
|
@ -101,6 +101,7 @@ var _ = Describe("0-RTT", func() {
|
|||
transfer0RTTData := func(
|
||||
ln *quic.EarlyListener,
|
||||
proxyPort int,
|
||||
connIDLen int,
|
||||
clientTLSConf *tls.Config,
|
||||
clientConf *quic.Config,
|
||||
testdata []byte, // data to transfer
|
||||
|
@ -125,13 +126,35 @@ var _ = Describe("0-RTT", func() {
|
|||
if clientConf == nil {
|
||||
clientConf = getQuicConfig(nil)
|
||||
}
|
||||
conn, err := quic.DialAddrEarly(
|
||||
var conn quic.EarlyConnection
|
||||
if connIDLen == 0 {
|
||||
var err error
|
||||
conn, err = quic.DialAddrEarly(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxyPort),
|
||||
clientTLSConf,
|
||||
clientConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
} else {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
defer tr.Close()
|
||||
conn, err = tr.DialEarly(
|
||||
context.Background(),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort},
|
||||
clientTLSConf,
|
||||
clientConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -212,8 +235,9 @@ var _ = Describe("0-RTT", func() {
|
|||
transfer0RTTData(
|
||||
ln,
|
||||
proxy.LocalPort(),
|
||||
connIDLen,
|
||||
clientTLSConf,
|
||||
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
|
||||
getQuicConfig(nil),
|
||||
PRData,
|
||||
)
|
||||
|
||||
|
@ -373,7 +397,7 @@ var _ = Describe("0-RTT", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData)
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
|
||||
|
||||
num0RTT := atomic.LoadUint32(&num0RTTPackets)
|
||||
numDropped := atomic.LoadUint32(&num0RTTDropped)
|
||||
|
@ -448,7 +472,7 @@ var _ = Describe("0-RTT", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, GeneratePRData(5000)) // ~5 packets
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
@ -768,7 +792,7 @@ var _ = Describe("0-RTT", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData)
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
|
||||
|
||||
Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
|
|
12
interface.go
12
interface.go
|
@ -242,18 +242,6 @@ type Config struct {
|
|||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
Versions []VersionNumber
|
||||
// The length of the connection ID in bytes.
|
||||
// It can be 0, or any value between 4 and 18.
|
||||
// If not set, the interpretation depends on where the Config is used:
|
||||
// If used for dialing an address, a 0 byte connection ID will be used.
|
||||
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
|
||||
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
|
||||
ConnectionIDLength int
|
||||
// An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection.
|
||||
// The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers.
|
||||
// By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used.
|
||||
// Otherwise, if one is provided, then ConnectionIDLength is ignored.
|
||||
ConnectionIDGenerator ConnectionIDGenerator
|
||||
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
|
||||
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted.
|
||||
// If this value is zero, the timeout is set to 5 seconds.
|
||||
|
|
10
server.go
10
server.go
|
@ -68,6 +68,7 @@ type baseServer struct {
|
|||
|
||||
tokenGenerator *handshake.TokenGenerator
|
||||
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
connHandler packetHandlerManager
|
||||
onClose func()
|
||||
|
||||
|
@ -85,6 +86,7 @@ type baseServer struct {
|
|||
protocol.ConnectionID, /* client dest connection ID */
|
||||
protocol.ConnectionID, /* destination connection ID */
|
||||
protocol.ConnectionID, /* source connection ID */
|
||||
ConnectionIDGenerator,
|
||||
protocol.StatelessResetToken,
|
||||
*Config,
|
||||
*tls.Config,
|
||||
|
@ -210,7 +212,7 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear
|
|||
return tr.ListenEarly(tlsConf, config)
|
||||
}
|
||||
|
||||
func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) {
|
||||
func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) {
|
||||
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -220,6 +222,7 @@ func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Conf
|
|||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
tokenGenerator: tokenGenerator,
|
||||
connIDGenerator: connIDGenerator,
|
||||
connHandler: connHandler,
|
||||
connQueue: make(chan quicConn),
|
||||
errorChan: make(chan struct{}),
|
||||
|
@ -574,7 +577,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
return nil
|
||||
}
|
||||
|
||||
connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
|
||||
connID, err := s.connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -603,6 +606,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
hdr.DestConnectionID,
|
||||
hdr.SrcConnectionID,
|
||||
connID,
|
||||
s.connIDGenerator,
|
||||
s.connHandler.GetStatelessResetToken(connID),
|
||||
s.config,
|
||||
s.tlsConf,
|
||||
|
@ -669,7 +673,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
|
|||
// Log the Initial packet now.
|
||||
// If no Retry is sent, the packet will be logged by the connection.
|
||||
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
|
||||
srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
|
||||
srcConnID, err := s.connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -286,6 +286,7 @@ var _ = Describe("Server", func() {
|
|||
clientDestConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
tokenP protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
@ -488,6 +489,7 @@ var _ = Describe("Server", func() {
|
|||
clientDestConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
tokenP protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
@ -547,6 +549,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
@ -600,6 +603,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
@ -631,6 +635,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
@ -702,6 +707,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
@ -1009,6 +1015,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
@ -1082,6 +1089,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
@ -1124,6 +1132,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
@ -1187,6 +1196,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
@ -1309,6 +1319,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
|
|
36
transport.go
36
transport.go
|
@ -59,6 +59,9 @@ type Transport struct {
|
|||
// Set in init.
|
||||
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
|
||||
connIDLen int
|
||||
// Set in init.
|
||||
// If no ConnectionIDGenerator is set, this is set to a default.
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
|
||||
server unknownPacketHandler
|
||||
|
||||
|
@ -92,10 +95,10 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error)
|
|||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
if err := t.init(conf); err != nil {
|
||||
if err := t.init(conf, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, false)
|
||||
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -121,10 +124,10 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen
|
|||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
if err := t.init(conf); err != nil {
|
||||
if err := t.init(conf, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, true)
|
||||
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -137,15 +140,15 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
|
|||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf = populateClientConfig(conf, t.createdConn)
|
||||
if err := t.init(conf); err != nil {
|
||||
conf = populateConfig(conf)
|
||||
if err := t.init(conf, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn)
|
||||
return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn)
|
||||
}
|
||||
|
||||
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
|
||||
|
@ -153,15 +156,15 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
|
|||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf = populateClientConfig(conf, t.createdConn)
|
||||
if err := t.init(conf); err != nil {
|
||||
conf = populateConfig(conf)
|
||||
if err := t.init(conf, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn)
|
||||
return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn)
|
||||
}
|
||||
|
||||
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
|
||||
|
@ -197,7 +200,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
|
|||
// only print warnings about the UDP receive buffer size once
|
||||
var receiveBufferWarningOnce sync.Once
|
||||
|
||||
func (t *Transport) init(conf *Config) error {
|
||||
func (t *Transport) init(conf *Config, isServer bool) error {
|
||||
t.initOnce.Do(func() {
|
||||
getMultiplexer().AddConn(t.Conn)
|
||||
|
||||
|
@ -208,9 +211,6 @@ func (t *Transport) init(conf *Config) error {
|
|||
}
|
||||
|
||||
t.Tracer = conf.Tracer
|
||||
t.ConnectionIDLength = conf.ConnectionIDLength
|
||||
t.ConnectionIDGenerator = conf.ConnectionIDGenerator
|
||||
|
||||
t.logger = utils.DefaultLogger // TODO: make this configurable
|
||||
t.conn = conn
|
||||
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
|
||||
|
@ -219,9 +219,15 @@ func (t *Transport) init(conf *Config) error {
|
|||
t.closeQueue = make(chan closePacket, 4)
|
||||
|
||||
if t.ConnectionIDGenerator != nil {
|
||||
t.connIDGenerator = t.ConnectionIDGenerator
|
||||
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
|
||||
} else {
|
||||
t.connIDLen = t.ConnectionIDLength
|
||||
connIDLen := t.ConnectionIDLength
|
||||
if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) {
|
||||
connIDLen = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
t.connIDLen = connIDLen
|
||||
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
|
||||
}
|
||||
|
||||
go t.listen(conn)
|
||||
|
|
|
@ -61,7 +61,7 @@ var _ = Describe("Transport", func() {
|
|||
It("handles packets for different packet handlers on the same packet conn", func() {
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := &Transport{Conn: newMockPacketConn(packetChan)}
|
||||
tr.init(&Config{})
|
||||
tr.init(&Config{}, true)
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
|
@ -127,8 +127,9 @@ var _ = Describe("Transport", func() {
|
|||
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||
tr := &Transport{
|
||||
Conn: newMockPacketConn(packetChan),
|
||||
ConnectionIDLength: 10,
|
||||
}
|
||||
tr.init(&Config{Tracer: tracer, ConnectionIDLength: 10})
|
||||
tr.init(&Config{Tracer: tracer}, true)
|
||||
dropped := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
|
||||
packetChan <- packetToRead{
|
||||
|
@ -147,7 +148,7 @@ var _ = Describe("Transport", func() {
|
|||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{})
|
||||
tr.init(&Config{}, true)
|
||||
tr.handlerMap = phm
|
||||
|
||||
done := make(chan struct{})
|
||||
|
@ -165,7 +166,7 @@ var _ = Describe("Transport", func() {
|
|||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{})
|
||||
tr.init(&Config{}, true)
|
||||
tr.handlerMap = phm
|
||||
|
||||
tempErr := deadlineError{}
|
||||
|
@ -183,11 +184,13 @@ var _ = Describe("Transport", func() {
|
|||
It("handles short header packets resets", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
tr.init(&Config{ConnectionIDLength: connID.Len()})
|
||||
tr := Transport{
|
||||
Conn: newMockPacketConn(packetChan),
|
||||
ConnectionIDLength: connID.Len(),
|
||||
}
|
||||
tr.init(&Config{}, true)
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{})
|
||||
tr.handlerMap = phm
|
||||
|
||||
var token protocol.StatelessResetToken
|
||||
|
@ -218,10 +221,9 @@ var _ = Describe("Transport", func() {
|
|||
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
tr.init(&Config{ConnectionIDLength: connID.Len()})
|
||||
tr.init(&Config{}, true)
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{})
|
||||
tr.handlerMap = phm
|
||||
|
||||
var token protocol.StatelessResetToken
|
||||
|
@ -253,11 +255,11 @@ var _ = Describe("Transport", func() {
|
|||
tr := Transport{
|
||||
Conn: conn,
|
||||
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
|
||||
ConnectionIDLength: connID.Len(),
|
||||
}
|
||||
tr.init(&Config{ConnectionIDLength: connID.Len()})
|
||||
tr.init(&Config{}, true)
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{})
|
||||
tr.handlerMap = phm
|
||||
|
||||
var b []byte
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue