mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
implement the Transport
This commit is contained in:
parent
ae5a8bd35c
commit
8189e75be6
31 changed files with 1309 additions and 1250 deletions
152
client.go
152
client.go
|
@ -20,6 +20,7 @@ type client struct {
|
||||||
use0RTT bool
|
use0RTT bool
|
||||||
|
|
||||||
packetHandlers packetHandlerManager
|
packetHandlers packetHandlerManager
|
||||||
|
onClose func()
|
||||||
|
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
config *Config
|
config *Config
|
||||||
|
@ -45,32 +46,58 @@ var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||||
|
|
||||||
// DialAddr establishes a new QUIC connection to a server.
|
// DialAddr establishes a new QUIC connection to a server.
|
||||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||||
// The hostname for SNI is taken from the given address.
|
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||||
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Connection, error) {
|
|
||||||
return dialAddrContext(ctx, addr, tlsConf, config, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
|
|
||||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
|
||||||
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
|
|
||||||
conn, err := dialAddrContext(ctx, addr, tlsConf, config, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection")
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config, use0RTT bool) (quicConn, error) {
|
|
||||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return dialContext(ctx, udpConn, udpAddr, tlsConf, config, use0RTT, true)
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dl, err := setupTransport(udpConn, tlsConf, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dl.Dial(ctx, udpAddr, tlsConf, conf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
|
||||||
|
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||||
|
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dl, err := setupTransport(udpConn, tlsConf, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf)
|
||||||
|
if err != nil {
|
||||||
|
dl.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
|
||||||
|
// See DialEarly for details.
|
||||||
|
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||||
|
dl, err := setupTransport(c, tlsConf, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
|
||||||
|
if err != nil {
|
||||||
|
dl.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
|
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
|
||||||
|
@ -78,34 +105,43 @@ func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, conf
|
||||||
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
|
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
|
||||||
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
|
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
|
||||||
// packets.
|
// packets.
|
||||||
// The same PacketConn can be used for multiple calls to Dial and Listen.
|
|
||||||
// QUIC connection IDs are used for demultiplexing the different connections.
|
|
||||||
// The tls.Config must define an application protocol (using NextProtos).
|
// The tls.Config must define an application protocol (using NextProtos).
|
||||||
func Dial(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) {
|
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||||
return dialContext(ctx, pconn, addr, tlsConf, config, false, false)
|
dl, err := setupTransport(c, tlsConf, false)
|
||||||
}
|
|
||||||
|
|
||||||
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
|
|
||||||
// The same PacketConn can be used for multiple calls to Dial and Listen,
|
|
||||||
// QUIC connection IDs are used for demultiplexing the different connections.
|
|
||||||
// The tls.Config must define an application protocol (using NextProtos).
|
|
||||||
func DialEarly(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
|
|
||||||
return dialContext(ctx, pconn, addr, tlsConf, config, true, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config, use0RTT bool, createdPacketConn bool) (quicConn, error) {
|
|
||||||
if tlsConf == nil {
|
|
||||||
return nil, errors.New("quic: tls.Config not set")
|
|
||||||
}
|
|
||||||
if err := validateConfig(config); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
config = populateClientConfig(config, createdPacketConn)
|
|
||||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c, err := newClient(pconn, addr, config, tlsConf, use0RTT, createdPacketConn)
|
conn, err := dl.Dial(ctx, addr, tlsConf, conf)
|
||||||
|
if err != nil {
|
||||||
|
dl.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
|
||||||
|
if tlsConf == nil {
|
||||||
|
return nil, errors.New("quic: tls.Config not set")
|
||||||
|
}
|
||||||
|
return &Transport{
|
||||||
|
Conn: c,
|
||||||
|
createdConn: createdPacketConn,
|
||||||
|
isSingleUse: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dial(
|
||||||
|
ctx context.Context,
|
||||||
|
conn net.PacketConn,
|
||||||
|
packetHandlers packetHandlerManager,
|
||||||
|
addr net.Addr,
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
config *Config,
|
||||||
|
onClose func(),
|
||||||
|
use0RTT bool,
|
||||||
|
createdPacketConn bool,
|
||||||
|
) (quicConn, error) {
|
||||||
|
c, err := newClient(conn, addr, config, tlsConf, onClose, use0RTT, createdPacketConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -128,7 +164,7 @@ func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsCo
|
||||||
return c.conn, nil
|
return c.conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, use0RTT bool, createdPacketConn bool) (*client, error) {
|
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, onClose func(), use0RTT, createdPacketConn bool) (*client, error) {
|
||||||
if tlsConf == nil {
|
if tlsConf == nil {
|
||||||
tlsConf = &tls.Config{}
|
tlsConf = &tls.Config{}
|
||||||
} else {
|
} else {
|
||||||
|
@ -149,6 +185,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
|
||||||
sconn: newSendPconn(pconn, remoteAddr),
|
sconn: newSendPconn(pconn, remoteAddr),
|
||||||
createdPacketConn: createdPacketConn,
|
createdPacketConn: createdPacketConn,
|
||||||
use0RTT: use0RTT,
|
use0RTT: use0RTT,
|
||||||
|
onClose: onClose,
|
||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
config: config,
|
config: config,
|
||||||
version: config.Versions[0],
|
version: config.Versions[0],
|
||||||
|
@ -179,13 +216,18 @@ func (c *client) dial(ctx context.Context) error {
|
||||||
c.packetHandlers.Add(c.srcConnID, c.conn)
|
c.packetHandlers.Add(c.srcConnID, c.conn)
|
||||||
|
|
||||||
errorChan := make(chan error, 1)
|
errorChan := make(chan error, 1)
|
||||||
|
recreateChan := make(chan errCloseForRecreating)
|
||||||
go func() {
|
go func() {
|
||||||
err := c.conn.run() // returns as soon as the connection is closed
|
err := c.conn.run()
|
||||||
|
var recreateErr *errCloseForRecreating
|
||||||
if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
|
if errors.As(err, &recreateErr) {
|
||||||
c.packetHandlers.Destroy()
|
recreateChan <- *recreateErr
|
||||||
|
return
|
||||||
}
|
}
|
||||||
errorChan <- err
|
if c.onClose != nil {
|
||||||
|
c.onClose()
|
||||||
|
}
|
||||||
|
errorChan <- err // returns as soon as the connection is closed
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// only set when we're using 0-RTT
|
// only set when we're using 0-RTT
|
||||||
|
@ -200,14 +242,12 @@ func (c *client) dial(ctx context.Context) error {
|
||||||
c.conn.shutdown()
|
c.conn.shutdown()
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case err := <-errorChan:
|
case err := <-errorChan:
|
||||||
var recreateErr *errCloseForRecreating
|
return err
|
||||||
if errors.As(err, &recreateErr) {
|
case recreateErr := <-recreateChan:
|
||||||
c.initialPacketNumber = recreateErr.nextPacketNumber
|
c.initialPacketNumber = recreateErr.nextPacketNumber
|
||||||
c.version = recreateErr.nextVersion
|
c.version = recreateErr.nextVersion
|
||||||
c.hasNegotiatedVersion = true
|
c.hasNegotiatedVersion = true
|
||||||
return c.dial(ctx)
|
return c.dial(ctx)
|
||||||
}
|
|
||||||
return err
|
|
||||||
case <-earlyConnChan:
|
case <-earlyConnChan:
|
||||||
// ready to send 0-RTT data
|
// ready to send 0-RTT data
|
||||||
return nil
|
return nil
|
||||||
|
|
228
client_test.go
228
client_test.go
|
@ -18,13 +18,17 @@ import (
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type nullMultiplexer struct{}
|
||||||
|
|
||||||
|
func (n nullMultiplexer) AddConn(indexableConn) {}
|
||||||
|
func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil }
|
||||||
|
|
||||||
var _ = Describe("Client", func() {
|
var _ = Describe("Client", func() {
|
||||||
var (
|
var (
|
||||||
cl *client
|
cl *client
|
||||||
packetConn *MockPacketConn
|
packetConn *MockPacketConn
|
||||||
addr net.Addr
|
addr net.Addr
|
||||||
connID protocol.ConnectionID
|
connID protocol.ConnectionID
|
||||||
mockMultiplexer *MockMultiplexer
|
|
||||||
origMultiplexer multiplexer
|
origMultiplexer multiplexer
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
tracer *mocklogging.MockConnectionTracer
|
tracer *mocklogging.MockConnectionTracer
|
||||||
|
@ -53,6 +57,7 @@ var _ = Describe("Client", func() {
|
||||||
originalClientConnConstructor = newClientConnection
|
originalClientConnConstructor = newClientConnection
|
||||||
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
||||||
tr := mocklogging.NewMockTracer(mockCtrl)
|
tr := mocklogging.NewMockTracer(mockCtrl)
|
||||||
|
tr.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
|
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
|
||||||
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}}
|
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}}
|
||||||
Eventually(areConnsRunning).Should(BeFalse())
|
Eventually(areConnsRunning).Should(BeFalse())
|
||||||
|
@ -68,10 +73,9 @@ var _ = Describe("Client", func() {
|
||||||
logger: utils.DefaultLogger,
|
logger: utils.DefaultLogger,
|
||||||
}
|
}
|
||||||
getMultiplexer() // make the sync.Once execute
|
getMultiplexer() // make the sync.Once execute
|
||||||
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
|
// replace the clientMuxer. getMultiplexer will now return the nullMultiplexer
|
||||||
mockMultiplexer = NewMockMultiplexer(mockCtrl)
|
|
||||||
origMultiplexer = connMuxer
|
origMultiplexer = connMuxer
|
||||||
connMuxer = mockMultiplexer
|
connMuxer = &nullMultiplexer{}
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
|
@ -100,48 +104,14 @@ var _ = Describe("Client", func() {
|
||||||
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
|
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
|
||||||
})
|
})
|
||||||
|
|
||||||
It("resolves the address", func() {
|
|
||||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
|
||||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
|
||||||
manager.EXPECT().Destroy()
|
|
||||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
|
||||||
|
|
||||||
remoteAddrChan := make(chan string, 1)
|
|
||||||
newClientConnection = func(
|
|
||||||
sconn sendConn,
|
|
||||||
_ connRunner,
|
|
||||||
_ protocol.ConnectionID,
|
|
||||||
_ protocol.ConnectionID,
|
|
||||||
_ *Config,
|
|
||||||
_ *tls.Config,
|
|
||||||
_ protocol.PacketNumber,
|
|
||||||
_ bool,
|
|
||||||
_ bool,
|
|
||||||
_ logging.ConnectionTracer,
|
|
||||||
_ uint64,
|
|
||||||
_ utils.Logger,
|
|
||||||
_ protocol.VersionNumber,
|
|
||||||
) quicConn {
|
|
||||||
remoteAddrChan <- sconn.RemoteAddr().String()
|
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
|
||||||
conn.EXPECT().run()
|
|
||||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
|
||||||
return conn
|
|
||||||
}
|
|
||||||
_, err := DialAddr(context.Background(), "localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns after the handshake is complete", func() {
|
It("returns after the handshake is complete", func() {
|
||||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
|
||||||
|
|
||||||
run := make(chan struct{})
|
run := make(chan struct{})
|
||||||
newClientConnection = func(
|
newClientConnection = func(
|
||||||
_ sendConn,
|
_ sendConn,
|
||||||
runner connRunner,
|
_ connRunner,
|
||||||
_ protocol.ConnectionID,
|
_ protocol.ConnectionID,
|
||||||
_ protocol.ConnectionID,
|
_ protocol.ConnectionID,
|
||||||
_ *Config,
|
_ *Config,
|
||||||
|
@ -162,18 +132,17 @@ var _ = Describe("Client", func() {
|
||||||
conn.EXPECT().HandshakeComplete().Return(c)
|
conn.EXPECT().HandshakeComplete().Return(c)
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, false, false)
|
||||||
s, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(s).ToNot(BeNil())
|
cl.packetHandlers = manager
|
||||||
|
Expect(cl).ToNot(BeNil())
|
||||||
|
Expect(cl.dial(context.Background())).To(Succeed())
|
||||||
Eventually(run).Should(BeClosed())
|
Eventually(run).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns early connections", func() {
|
It("returns early connections", func() {
|
||||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
|
||||||
|
|
||||||
readyChan := make(chan struct{})
|
readyChan := make(chan struct{})
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
newClientConnection = func(
|
newClientConnection = func(
|
||||||
|
@ -193,29 +162,23 @@ var _ = Describe("Client", func() {
|
||||||
) quicConn {
|
) quicConn {
|
||||||
Expect(enable0RTT).To(BeTrue())
|
Expect(enable0RTT).To(BeTrue())
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
conn.EXPECT().run().Do(func() { <-done })
|
conn.EXPECT().run().Do(func() { close(done) })
|
||||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||||
conn.EXPECT().earlyConnReady().Return(readyChan)
|
conn.EXPECT().earlyConnReady().Return(readyChan)
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, true, false)
|
||||||
defer GinkgoRecover()
|
|
||||||
defer close(done)
|
|
||||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
||||||
s, err := DialEarly(context.Background(), packetConn, addr, tlsConf, config)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(s).ToNot(BeNil())
|
cl.packetHandlers = manager
|
||||||
}()
|
Expect(cl).ToNot(BeNil())
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Expect(cl.dial(context.Background())).To(Succeed())
|
||||||
close(readyChan)
|
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns an error that occurs while waiting for the handshake to complete", func() {
|
It("returns an error that occurs while waiting for the handshake to complete", func() {
|
||||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
|
||||||
|
|
||||||
testErr := errors.New("early handshake error")
|
testErr := errors.New("early handshake error")
|
||||||
newClientConnection = func(
|
newClientConnection = func(
|
||||||
|
@ -236,108 +199,16 @@ var _ = Describe("Client", func() {
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
conn.EXPECT().run().Return(testErr)
|
conn.EXPECT().run().Return(testErr)
|
||||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||||
|
conn.EXPECT().earlyConnReady().Return(make(chan struct{}))
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
var closed bool
|
||||||
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, func() { closed = true }, true, false)
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes the connection when the context is canceled", func() {
|
|
||||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
|
||||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
|
||||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
|
||||||
|
|
||||||
connRunning := make(chan struct{})
|
|
||||||
defer close(connRunning)
|
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
|
||||||
conn.EXPECT().run().Do(func() {
|
|
||||||
<-connRunning
|
|
||||||
})
|
|
||||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
|
||||||
newClientConnection = func(
|
|
||||||
_ sendConn,
|
|
||||||
_ connRunner,
|
|
||||||
_ protocol.ConnectionID,
|
|
||||||
_ protocol.ConnectionID,
|
|
||||||
_ *Config,
|
|
||||||
_ *tls.Config,
|
|
||||||
_ protocol.PacketNumber,
|
|
||||||
_ bool,
|
|
||||||
_ bool,
|
|
||||||
_ logging.ConnectionTracer,
|
|
||||||
_ uint64,
|
|
||||||
_ utils.Logger,
|
|
||||||
_ protocol.VersionNumber,
|
|
||||||
) quicConn {
|
|
||||||
return conn
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
dialed := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
||||||
_, err := Dial(ctx, packetConn, addr, tlsConf, config)
|
|
||||||
Expect(err).To(MatchError(context.Canceled))
|
|
||||||
close(dialed)
|
|
||||||
}()
|
|
||||||
Consistently(dialed).ShouldNot(BeClosed())
|
|
||||||
conn.EXPECT().shutdown()
|
|
||||||
cancel()
|
|
||||||
Eventually(dialed).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes the connection when it was created by DialAddr", func() {
|
|
||||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
|
||||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
|
||||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
|
||||||
|
|
||||||
var sconn sendConn
|
|
||||||
run := make(chan struct{})
|
|
||||||
connCreated := make(chan struct{})
|
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
|
||||||
newClientConnection = func(
|
|
||||||
connP sendConn,
|
|
||||||
_ connRunner,
|
|
||||||
_ protocol.ConnectionID,
|
|
||||||
_ protocol.ConnectionID,
|
|
||||||
_ *Config,
|
|
||||||
_ *tls.Config,
|
|
||||||
_ protocol.PacketNumber,
|
|
||||||
_ bool,
|
|
||||||
_ bool,
|
|
||||||
_ logging.ConnectionTracer,
|
|
||||||
_ uint64,
|
|
||||||
_ utils.Logger,
|
|
||||||
_ protocol.VersionNumber,
|
|
||||||
) quicConn {
|
|
||||||
sconn = connP
|
|
||||||
close(connCreated)
|
|
||||||
return conn
|
|
||||||
}
|
|
||||||
conn.EXPECT().run().Do(func() {
|
|
||||||
<-run
|
|
||||||
})
|
|
||||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := DialAddr(context.Background(), "localhost:1337", tlsConf, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
close(done)
|
cl.packetHandlers = manager
|
||||||
}()
|
Expect(cl).ToNot(BeNil())
|
||||||
|
Expect(cl.dial(context.Background())).To(MatchError(testErr))
|
||||||
Eventually(connCreated).Should(BeClosed())
|
Expect(closed).To(BeTrue())
|
||||||
|
|
||||||
// check that the connection is not closed
|
|
||||||
Expect(sconn.Write([]byte("foobar"))).To(Succeed())
|
|
||||||
|
|
||||||
manager.EXPECT().Destroy()
|
|
||||||
close(run)
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("quic.Config", func() {
|
Context("quic.Config", func() {
|
||||||
|
@ -365,12 +236,6 @@ var _ = Describe("Client", func() {
|
||||||
Expect(c.EnableDatagrams).To(BeTrue())
|
Expect(c.EnableDatagrams).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors when the Config contains an invalid version", func() {
|
|
||||||
version := protocol.VersionNumber(0x1234)
|
|
||||||
_, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
|
|
||||||
Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("disables bidirectional streams", func() {
|
It("disables bidirectional streams", func() {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
MaxIncomingStreams: -1,
|
MaxIncomingStreams: -1,
|
||||||
|
@ -405,15 +270,12 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("creates new connections with the right parameters", func() {
|
It("creates new connections with the right parameters", func() {
|
||||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}}
|
||||||
manager.EXPECT().Add(connID, gomock.Any())
|
|
||||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
|
||||||
|
|
||||||
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
|
|
||||||
c := make(chan struct{})
|
c := make(chan struct{})
|
||||||
var cconn sendConn
|
var cconn sendConn
|
||||||
var version protocol.VersionNumber
|
var version protocol.VersionNumber
|
||||||
var conf *Config
|
var conf *Config
|
||||||
|
done := make(chan struct{})
|
||||||
newClientConnection = func(
|
newClientConnection = func(
|
||||||
connP sendConn,
|
connP sendConn,
|
||||||
_ connRunner,
|
_ connRunner,
|
||||||
|
@ -437,8 +299,15 @@ var _ = Describe("Client", func() {
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
conn.EXPECT().run()
|
conn.EXPECT().run()
|
||||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||||
|
conn.EXPECT().destroy(gomock.Any())
|
||||||
|
close(done)
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) {
|
||||||
|
<-done
|
||||||
|
return 0, nil, errors.New("closed")
|
||||||
|
})
|
||||||
|
packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
|
||||||
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Eventually(c).Should(BeClosed())
|
Eventually(c).Should(BeClosed())
|
||||||
|
@ -448,17 +317,12 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("creates a new connections after version negotiation", func() {
|
It("creates a new connections after version negotiation", func() {
|
||||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
|
||||||
manager.EXPECT().Add(connID, gomock.Any()).Times(2)
|
|
||||||
manager.EXPECT().Destroy()
|
|
||||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
|
||||||
|
|
||||||
var counter int
|
var counter int
|
||||||
newClientConnection = func(
|
newClientConnection = func(
|
||||||
_ sendConn,
|
_ sendConn,
|
||||||
_ connRunner,
|
runner connRunner,
|
||||||
_ protocol.ConnectionID,
|
|
||||||
_ protocol.ConnectionID,
|
_ protocol.ConnectionID,
|
||||||
|
connID protocol.ConnectionID,
|
||||||
configP *Config,
|
configP *Config,
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
pn protocol.PacketNumber,
|
pn protocol.PacketNumber,
|
||||||
|
@ -474,20 +338,24 @@ var _ = Describe("Client", func() {
|
||||||
if counter == 0 {
|
if counter == 0 {
|
||||||
Expect(pn).To(BeZero())
|
Expect(pn).To(BeZero())
|
||||||
Expect(hasNegotiatedVersion).To(BeFalse())
|
Expect(hasNegotiatedVersion).To(BeFalse())
|
||||||
conn.EXPECT().run().Return(&errCloseForRecreating{
|
conn.EXPECT().run().DoAndReturn(func() error {
|
||||||
|
runner.Remove(connID)
|
||||||
|
return &errCloseForRecreating{
|
||||||
nextPacketNumber: 109,
|
nextPacketNumber: 109,
|
||||||
nextVersion: 789,
|
nextVersion: 789,
|
||||||
|
}
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
Expect(pn).To(Equal(protocol.PacketNumber(109)))
|
Expect(pn).To(Equal(protocol.PacketNumber(109)))
|
||||||
Expect(hasNegotiatedVersion).To(BeTrue())
|
Expect(hasNegotiatedVersion).To(BeTrue())
|
||||||
conn.EXPECT().run()
|
conn.EXPECT().run()
|
||||||
|
conn.EXPECT().destroy(gomock.Any())
|
||||||
}
|
}
|
||||||
counter++
|
counter++
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
|
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}}
|
||||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
|
_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -495,15 +363,3 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
type mockConnIDGenerator struct {
|
|
||||||
ConnID protocol.ConnectionID
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockConnIDGenerator) GenerateConnectionID() (protocol.ConnectionID, error) {
|
|
||||||
return m.ConnID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockConnIDGenerator) ConnectionIDLen() int {
|
|
||||||
return m.ConnID.Len()
|
|
||||||
}
|
|
||||||
|
|
|
@ -660,6 +660,7 @@ var _ = Describe("Stream Cancellations", func() {
|
||||||
getQuicConfig(&quic.Config{MaxIncomingStreams: maxIncomingStreams, MaxIdleTimeout: 10 * time.Second}),
|
getQuicConfig(&quic.Config{MaxIncomingStreams: maxIncomingStreams, MaxIdleTimeout: 10 * time.Second}),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2 * 4 * maxIncomingStreams)
|
wg.Add(2 * 4 * maxIncomingStreams)
|
||||||
|
|
|
@ -24,6 +24,7 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
var drop atomic.Bool
|
var drop atomic.Bool
|
||||||
dropped := make(chan []byte, 100)
|
dropped := make(chan []byte, 100)
|
||||||
|
@ -50,6 +51,7 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn.CloseWithError(0, "")
|
||||||
|
|
||||||
sconn, err := server.Accept(context.Background())
|
sconn, err := server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -35,7 +35,11 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||||
randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
|
randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
|
||||||
|
|
||||||
runServer := func(conf *quic.Config) *quic.Listener {
|
runServer := func(conf *quic.Config) *quic.Listener {
|
||||||
|
if conf.ConnectionIDGenerator != nil {
|
||||||
|
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", conf.ConnectionIDGenerator.ConnectionIDLen())))
|
||||||
|
} else {
|
||||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
|
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
|
||||||
|
}
|
||||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf)
|
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -59,7 +63,11 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
runClient := func(addr net.Addr, conf *quic.Config) {
|
runClient := func(addr net.Addr, conf *quic.Config) {
|
||||||
|
if conf.ConnectionIDGenerator != nil {
|
||||||
|
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", conf.ConnectionIDGenerator.ConnectionIDLen())))
|
||||||
|
} else {
|
||||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
|
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
|
||||||
|
}
|
||||||
cl, err := quic.DialAddr(
|
cl, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
|
||||||
|
|
|
@ -22,12 +22,11 @@ var _ = Describe("Datagram test", func() {
|
||||||
const num = 100
|
const num = 100
|
||||||
|
|
||||||
var (
|
var (
|
||||||
proxy *quicproxy.QuicProxy
|
|
||||||
serverConn, clientConn *net.UDPConn
|
serverConn, clientConn *net.UDPConn
|
||||||
dropped, total int32
|
dropped, total int32
|
||||||
)
|
)
|
||||||
|
|
||||||
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) {
|
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) {
|
||||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
serverConn, err = net.ListenUDP("udp", addr)
|
serverConn, err = net.ListenUDP("udp", addr)
|
||||||
|
@ -39,8 +38,10 @@ var _ = Describe("Datagram test", func() {
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
accepted := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
defer close(accepted)
|
||||||
conn, err := ln.Accept(context.Background())
|
conn, err := ln.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
@ -67,7 +68,7 @@ var _ = Describe("Datagram test", func() {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||||
// drop 10% of Short Header packets sent from the server
|
// drop 10% of Short Header packets sent from the server
|
||||||
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
|
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
|
||||||
|
@ -87,6 +88,11 @@ var _ = Describe("Datagram test", func() {
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
return proxy.LocalPort(), func() {
|
||||||
|
Eventually(accepted).Should(BeClosed())
|
||||||
|
proxy.Close()
|
||||||
|
ln.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -96,13 +102,10 @@ var _ = Describe("Datagram test", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(proxy.Close()).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends datagrams", func() {
|
It("sends datagrams", func() {
|
||||||
startServerAndProxy(true, true)
|
proxyPort, close := startServerAndProxy(true, true)
|
||||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
|
defer close()
|
||||||
|
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
conn, err := quic.Dial(
|
conn, err := quic.Dial(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
|
@ -117,6 +120,7 @@ var _ = Describe("Datagram test", func() {
|
||||||
for {
|
for {
|
||||||
// Close the connection if no message is received for 100 ms.
|
// Close the connection if no message is received for 100 ms.
|
||||||
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() {
|
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() {
|
||||||
|
fmt.Println("closing conn")
|
||||||
conn.CloseWithError(0, "")
|
conn.CloseWithError(0, "")
|
||||||
})
|
})
|
||||||
if _, err := conn.ReceiveMessage(); err != nil {
|
if _, err := conn.ReceiveMessage(); err != nil {
|
||||||
|
@ -134,11 +138,12 @@ var _ = Describe("Datagram test", func() {
|
||||||
BeNumerically(">", expVal*9/10),
|
BeNumerically(">", expVal*9/10),
|
||||||
BeNumerically("<", num),
|
BeNumerically("<", num),
|
||||||
))
|
))
|
||||||
|
Eventually(conn.Context().Done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("server can disable datagram", func() {
|
It("server can disable datagram", func() {
|
||||||
startServerAndProxy(false, true)
|
proxyPort, close := startServerAndProxy(false, true)
|
||||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
|
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
conn, err := quic.Dial(
|
conn, err := quic.Dial(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
|
@ -150,13 +155,13 @@ var _ = Describe("Datagram test", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
||||||
|
|
||||||
|
close()
|
||||||
conn.CloseWithError(0, "")
|
conn.CloseWithError(0, "")
|
||||||
<-time.After(10 * time.Millisecond)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("client can disable datagram", func() {
|
It("client can disable datagram", func() {
|
||||||
startServerAndProxy(false, true)
|
proxyPort, close := startServerAndProxy(false, true)
|
||||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
|
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
conn, err := quic.Dial(
|
conn, err := quic.Dial(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
|
@ -169,7 +174,8 @@ var _ = Describe("Datagram test", func() {
|
||||||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
||||||
|
|
||||||
Expect(conn.SendMessage([]byte{0})).To(HaveOccurred())
|
Expect(conn.SendMessage([]byte{0})).To(HaveOccurred())
|
||||||
|
|
||||||
|
close()
|
||||||
conn.CloseWithError(0, "")
|
conn.CloseWithError(0, "")
|
||||||
<-time.After(10 * time.Millisecond)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -24,6 +24,7 @@ var _ = Describe("early data", func() {
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer ln.Close()
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
|
|
@ -8,10 +8,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const go120 = false
|
||||||
go120 = false
|
|
||||||
errNotSupported = errors.New("not supported")
|
var errNotSupported = errors.New("not supported")
|
||||||
)
|
|
||||||
|
|
||||||
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
|
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
|
||||||
return errNotSupported
|
return errNotSupported
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var go120 = true
|
const go120 = true
|
||||||
|
|
||||||
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
|
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
|
||||||
rc := http.NewResponseController(w)
|
rc := http.NewResponseController(w)
|
||||||
|
|
|
@ -62,13 +62,14 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||||
|
|
||||||
runProxy(ln.Addr())
|
runProxy(ln.Addr())
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
_, err = quic.DialAddr(
|
conn, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||||
getTLSClientConfig(),
|
getTLSClientConfig(),
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn.CloseWithError(0, "")
|
||||||
expectDurationInRTTs(startTime, 2)
|
expectDurationInRTTs(startTime, 2)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -79,13 +80,14 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||||
|
|
||||||
runProxy(ln.Addr())
|
runProxy(ln.Addr())
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
_, err = quic.DialAddr(
|
conn, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||||
getTLSClientConfig(),
|
getTLSClientConfig(),
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn.CloseWithError(0, "")
|
||||||
expectDurationInRTTs(startTime, 1)
|
expectDurationInRTTs(startTime, 1)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -97,13 +99,14 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||||
|
|
||||||
runProxy(ln.Addr())
|
runProxy(ln.Addr())
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
_, err = quic.DialAddr(
|
conn, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||||
getTLSClientConfig(),
|
getTLSClientConfig(),
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn.CloseWithError(0, "")
|
||||||
expectDurationInRTTs(startTime, 2)
|
expectDurationInRTTs(startTime, 2)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -131,6 +134,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn.CloseWithError(0, "")
|
||||||
str, err := conn.AcceptUniStream(context.Background())
|
str, err := conn.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
data, err := io.ReadAll(str)
|
data, err := io.ReadAll(str)
|
||||||
|
@ -166,6 +170,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn.CloseWithError(0, "")
|
||||||
str, err := conn.AcceptUniStream(context.Background())
|
str, err := conn.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
data, err := io.ReadAll(str)
|
data, err := io.ReadAll(str)
|
||||||
|
|
|
@ -114,7 +114,7 @@ var _ = Describe("Handshake tests", func() {
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||||
getTLSClientConfig(),
|
getTLSClientConfig(),
|
||||||
nil,
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := conn.AcceptStream(context.Background())
|
str, err := conn.AcceptStream(context.Background())
|
||||||
|
@ -223,13 +223,14 @@ var _ = Describe("Handshake tests", func() {
|
||||||
var (
|
var (
|
||||||
server *quic.Listener
|
server *quic.Listener
|
||||||
pconn net.PacketConn
|
pconn net.PacketConn
|
||||||
|
dialer *quic.Transport
|
||||||
)
|
)
|
||||||
|
|
||||||
dial := func() (quic.Connection, error) {
|
dial := func() (quic.Connection, error) {
|
||||||
remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||||
raddr, err := net.ResolveUDPAddr("udp", remoteAddr)
|
raddr, err := net.ResolveUDPAddr("udp", remoteAddr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
return quic.Dial(context.Background(), pconn, raddr, getTLSClientConfig(), nil)
|
return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -243,11 +244,13 @@ var _ = Describe("Handshake tests", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
pconn, err = net.ListenUDP("udp", laddr)
|
pconn, err = net.ListenUDP("udp", laddr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
dialer = &quic.Transport{Conn: pconn}
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
Expect(server.Close()).To(Succeed())
|
Expect(server.Close()).To(Succeed())
|
||||||
Expect(pconn.Close()).To(Succeed())
|
Expect(pconn.Close()).To(Succeed())
|
||||||
|
Expect(dialer.Close()).To(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects new connection attempts if connections don't get accepted", func() {
|
It("rejects new connection attempts if connections don't get accepted", func() {
|
||||||
|
@ -366,6 +369,7 @@ var _ = Describe("Handshake tests", func() {
|
||||||
It("uses tokens provided in NEW_TOKEN frames", func() {
|
It("uses tokens provided in NEW_TOKEN frames", func() {
|
||||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
// dial the first connection and receive the token
|
// dial the first connection and receive the token
|
||||||
go func() {
|
go func() {
|
||||||
|
|
|
@ -382,6 +382,7 @@ var _ = Describe("HTTP tests", func() {
|
||||||
tlsConf.NextProtos = []string{"h3"}
|
tlsConf.NextProtos = []string{"h3"}
|
||||||
ln, err := quic.ListenAddr("localhost:0", tlsConf, nil)
|
ln, err := quic.ListenAddr("localhost:0", tlsConf, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer ln.Close()
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -398,11 +399,8 @@ var _ = Describe("HTTP tests", func() {
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if go120 {
|
||||||
It("supports read deadlines", func() {
|
It("supports read deadlines", func() {
|
||||||
if !go120 {
|
|
||||||
Skip("This test requires Go 1.20+")
|
|
||||||
}
|
|
||||||
|
|
||||||
mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
err := setReadDeadline(w, time.Now().Add(deadlineDelay))
|
err := setReadDeadline(w, time.Now().Add(deadlineDelay))
|
||||||
|
@ -427,10 +425,6 @@ var _ = Describe("HTTP tests", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("supports write deadlines", func() {
|
It("supports write deadlines", func() {
|
||||||
if !go120 {
|
|
||||||
Skip("This test requires Go 1.20+")
|
|
||||||
}
|
|
||||||
|
|
||||||
mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
err := setWriteDeadline(w, time.Now().Add(deadlineDelay))
|
err := setWriteDeadline(w, time.Now().Add(deadlineDelay))
|
||||||
|
@ -451,4 +445,5 @@ var _ = Describe("HTTP tests", func() {
|
||||||
Expect(time.Now().After(expectedEnd)).To(BeTrue())
|
Expect(time.Now().After(expectedEnd)).To(BeTrue())
|
||||||
Expect(string(body)).To(ContainSubstring("aa"))
|
Expect(string(body)).To(ContainSubstring("aa"))
|
||||||
})
|
})
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -34,10 +34,9 @@ var _ = Describe("Multiplexing", func() {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
dial := func(pconn net.PacketConn, addr net.Addr) {
|
dial := func(tr *quic.Transport, addr net.Addr) {
|
||||||
conn, err := quic.Dial(
|
conn, err := tr.Dial(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
pconn,
|
|
||||||
addr,
|
addr,
|
||||||
getTLSClientConfig(),
|
getTLSClientConfig(),
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
|
@ -72,17 +71,18 @@ var _ = Describe("Multiplexing", func() {
|
||||||
conn, err := net.ListenUDP("udp", addr)
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
tr := &quic.Transport{Conn: conn}
|
||||||
|
|
||||||
done1 := make(chan struct{})
|
done1 := make(chan struct{})
|
||||||
done2 := make(chan struct{})
|
done2 := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
dial(conn, server.Addr())
|
dial(tr, server.Addr())
|
||||||
close(done1)
|
close(done1)
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
dial(conn, server.Addr())
|
dial(tr, server.Addr())
|
||||||
close(done2)
|
close(done2)
|
||||||
}()
|
}()
|
||||||
timeout := 30 * time.Second
|
timeout := 30 * time.Second
|
||||||
|
@ -106,17 +106,18 @@ var _ = Describe("Multiplexing", func() {
|
||||||
conn, err := net.ListenUDP("udp", addr)
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
tr := &quic.Transport{Conn: conn}
|
||||||
|
|
||||||
done1 := make(chan struct{})
|
done1 := make(chan struct{})
|
||||||
done2 := make(chan struct{})
|
done2 := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
dial(conn, server1.Addr())
|
dial(tr, server1.Addr())
|
||||||
close(done1)
|
close(done1)
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
dial(conn, server2.Addr())
|
dial(tr, server2.Addr())
|
||||||
close(done2)
|
close(done2)
|
||||||
}()
|
}()
|
||||||
timeout := 30 * time.Second
|
timeout := 30 * time.Second
|
||||||
|
@ -135,9 +136,9 @@ var _ = Describe("Multiplexing", func() {
|
||||||
conn, err := net.ListenUDP("udp", addr)
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
tr := &quic.Transport{Conn: conn}
|
||||||
|
|
||||||
server, err := quic.Listen(
|
server, err := tr.Listen(
|
||||||
conn,
|
|
||||||
getTLSConfig(),
|
getTLSConfig(),
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
|
@ -146,7 +147,7 @@ var _ = Describe("Multiplexing", func() {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
dial(conn, server.Addr())
|
dial(tr, server.Addr())
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
timeout := 30 * time.Second
|
timeout := 30 * time.Second
|
||||||
|
@ -165,15 +166,16 @@ var _ = Describe("Multiplexing", func() {
|
||||||
conn1, err := net.ListenUDP("udp", addr1)
|
conn1, err := net.ListenUDP("udp", addr1)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer conn1.Close()
|
defer conn1.Close()
|
||||||
|
tr1 := &quic.Transport{Conn: conn1}
|
||||||
|
|
||||||
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
|
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
conn2, err := net.ListenUDP("udp", addr2)
|
conn2, err := net.ListenUDP("udp", addr2)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer conn2.Close()
|
defer conn2.Close()
|
||||||
|
tr2 := &quic.Transport{Conn: conn2}
|
||||||
|
|
||||||
server1, err := quic.Listen(
|
server1, err := tr1.Listen(
|
||||||
conn1,
|
|
||||||
getTLSConfig(),
|
getTLSConfig(),
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
|
@ -181,8 +183,7 @@ var _ = Describe("Multiplexing", func() {
|
||||||
runServer(server1)
|
runServer(server1)
|
||||||
defer server1.Close()
|
defer server1.Close()
|
||||||
|
|
||||||
server2, err := quic.Listen(
|
server2, err := tr2.Listen(
|
||||||
conn2,
|
|
||||||
getTLSConfig(),
|
getTLSConfig(),
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
|
@ -194,12 +195,12 @@ var _ = Describe("Multiplexing", func() {
|
||||||
done2 := make(chan struct{})
|
done2 := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
dial(conn2, server1.Addr())
|
dial(tr2, server1.Addr())
|
||||||
close(done1)
|
close(done1)
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
dial(conn1, server2.Addr())
|
dial(tr1, server2.Addr())
|
||||||
close(done2)
|
close(done2)
|
||||||
}()
|
}()
|
||||||
timeout := 30 * time.Second
|
timeout := 30 * time.Second
|
||||||
|
|
|
@ -31,8 +31,8 @@ var _ = Describe("Packetization", func() {
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||||
|
|
||||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||||
RemoteAddr: serverAddr,
|
RemoteAddr: serverAddr,
|
||||||
|
@ -54,6 +54,7 @@ var _ = Describe("Packetization", func() {
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn.CloseWithError(0, "")
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
|
|
@ -199,8 +199,16 @@ func areHandshakesRunning() bool {
|
||||||
return strings.Contains(b.String(), "RunHandshake")
|
return strings.Contains(b.String(), "RunHandshake")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func areTransportsRunning() bool {
|
||||||
|
var b bytes.Buffer
|
||||||
|
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||||
|
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
|
||||||
|
}
|
||||||
|
|
||||||
var _ = AfterEach(func() {
|
var _ = AfterEach(func() {
|
||||||
Expect(areHandshakesRunning()).To(BeFalse())
|
Expect(areHandshakesRunning()).To(BeFalse())
|
||||||
|
Eventually(areTransportsRunning).Should(BeFalse())
|
||||||
|
|
||||||
if debugLog() {
|
if debugLog() {
|
||||||
logFile, err := os.Create(logFileName)
|
logFile, err := os.Create(logFileName)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -2,7 +2,6 @@ package self_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
@ -27,7 +26,13 @@ var _ = Describe("Stateless Resets", func() {
|
||||||
rand.Read(statelessResetKey[:])
|
rand.Read(statelessResetKey[:])
|
||||||
serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: &statelessResetKey})
|
serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: &statelessResetKey})
|
||||||
|
|
||||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
c, err := net.ListenUDP("udp", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
tr := &quic.Transport{
|
||||||
|
Conn: c,
|
||||||
|
}
|
||||||
|
defer tr.Close()
|
||||||
|
ln, err := tr.Listen(getTLSConfig(), serverConfig)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||||
|
|
||||||
|
@ -42,7 +47,8 @@ var _ = Describe("Stateless Resets", func() {
|
||||||
_, err = str.Write([]byte("foobar"))
|
_, err = str.Write([]byte("foobar"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
<-closeServer
|
<-closeServer
|
||||||
ln.Close()
|
Expect(ln.Close()).To(Succeed())
|
||||||
|
Expect(tr.Close()).To(Succeed())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var drop atomic.Bool
|
var drop atomic.Bool
|
||||||
|
@ -77,11 +83,14 @@ var _ = Describe("Stateless Resets", func() {
|
||||||
close(closeServer)
|
close(closeServer)
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
ln2, err := quic.ListenAddr(
|
// We need to create a new Transport here, since the old one is still sending out
|
||||||
fmt.Sprintf("localhost:%d", serverPort),
|
// CONNECTION_CLOSE packets for (recently) closed connections).
|
||||||
getTLSConfig(),
|
tr2 := &quic.Transport{
|
||||||
serverConfig,
|
Conn: c,
|
||||||
)
|
StatelessResetKey: &statelessResetKey,
|
||||||
|
}
|
||||||
|
defer tr2.Close()
|
||||||
|
ln2, err := tr2.Listen(getTLSConfig(), serverConfig)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
drop.Store(false)
|
drop.Store(false)
|
||||||
|
|
||||||
|
@ -100,8 +109,7 @@ var _ = Describe("Stateless Resets", func() {
|
||||||
_, serr = str.Read([]byte{0})
|
_, serr = str.Read([]byte{0})
|
||||||
}
|
}
|
||||||
Expect(serr).To(HaveOccurred())
|
Expect(serr).To(HaveOccurred())
|
||||||
statelessResetErr := &quic.StatelessResetError{}
|
Expect(serr).To(BeAssignableToTypeOf(&quic.StatelessResetError{}))
|
||||||
Expect(errors.As(serr, &statelessResetErr)).To(BeTrue())
|
|
||||||
Expect(ln2.Close()).To(Succeed())
|
Expect(ln2.Close()).To(Succeed())
|
||||||
Eventually(acceptStopped).Should(BeClosed())
|
Eventually(acceptStopped).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
|
@ -94,6 +94,8 @@ var _ = Describe("Bidirectional streams", func() {
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
runSendingPeer(client)
|
runSendingPeer(client)
|
||||||
|
client.CloseWithError(0, "")
|
||||||
|
<-conn.Context().Done()
|
||||||
})
|
})
|
||||||
|
|
||||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||||
|
@ -149,5 +151,6 @@ var _ = Describe("Bidirectional streams", func() {
|
||||||
runReceivingPeer(client)
|
runReceivingPeer(client)
|
||||||
<-done1
|
<-done1
|
||||||
<-done2
|
<-done2
|
||||||
|
client.CloseWithError(0, "")
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -473,6 +473,7 @@ var _ = Describe("Timeout tests", func() {
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
serverErrChan := make(chan error, 1)
|
serverErrChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
|
|
|
@ -88,11 +88,14 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||||
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
defer close(done)
|
||||||
conn, err := server.Accept(context.Background())
|
conn, err := server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
runSendingPeer(conn)
|
runSendingPeer(conn)
|
||||||
|
<-conn.Context().Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := quic.DialAddr(
|
client, err := quic.DialAddr(
|
||||||
|
@ -103,6 +106,7 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
runReceivingPeer(client)
|
runReceivingPeer(client)
|
||||||
|
client.CloseWithError(0, "")
|
||||||
})
|
})
|
||||||
|
|
||||||
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {
|
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {
|
||||||
|
|
|
@ -1,65 +0,0 @@
|
||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: github.com/quic-go/quic-go (interfaces: Multiplexer)
|
|
||||||
|
|
||||||
// Package quic is a generated GoMock package.
|
|
||||||
package quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
net "net"
|
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
logging "github.com/quic-go/quic-go/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockMultiplexer is a mock of Multiplexer interface.
|
|
||||||
type MockMultiplexer struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockMultiplexerMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockMultiplexerMockRecorder is the mock recorder for MockMultiplexer.
|
|
||||||
type MockMultiplexerMockRecorder struct {
|
|
||||||
mock *MockMultiplexer
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockMultiplexer creates a new mock instance.
|
|
||||||
func NewMockMultiplexer(ctrl *gomock.Controller) *MockMultiplexer {
|
|
||||||
mock := &MockMultiplexer{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockMultiplexerMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
|
||||||
func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddConn mocks base method.
|
|
||||||
func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 *StatelessResetKey, arg3 logging.Tracer) (packetHandlerManager, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2, arg3)
|
|
||||||
ret0, _ := ret[0].(packetHandlerManager)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddConn indicates an expected call of AddConn.
|
|
||||||
func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2, arg3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveConn mocks base method.
|
|
||||||
func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "RemoveConn", arg0)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveConn indicates an expected call of RemoveConn.
|
|
||||||
func (mr *MockMultiplexerMockRecorder) RemoveConn(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveConn", reflect.TypeOf((*MockMultiplexer)(nil).RemoveConn), arg0)
|
|
||||||
}
|
|
|
@ -74,6 +74,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 i
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close mocks base method.
|
||||||
|
func (m *MockPacketHandlerManager) Close(arg0 error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Close", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close indicates an expected call of Close.
|
||||||
|
func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
// CloseServer mocks base method.
|
// CloseServer mocks base method.
|
||||||
func (m *MockPacketHandlerManager) CloseServer() {
|
func (m *MockPacketHandlerManager) CloseServer() {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -86,20 +98,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destroy mocks base method.
|
|
||||||
func (m *MockPacketHandlerManager) Destroy() error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "Destroy")
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Destroy indicates an expected call of Destroy.
|
|
||||||
func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get mocks base method.
|
// Get mocks base method.
|
||||||
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -115,6 +113,21 @@ func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Ca
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetByResetToken mocks base method.
|
||||||
|
func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetToken) (packetHandler, bool) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetByResetToken", arg0)
|
||||||
|
ret0, _ := ret[0].(packetHandler)
|
||||||
|
ret1, _ := ret[1].(bool)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByResetToken indicates an expected call of GetByResetToken.
|
||||||
|
func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
// GetStatelessResetToken mocks base method.
|
// GetStatelessResetToken mocks base method.
|
||||||
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
|
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -176,15 +189,3 @@ func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetServer mocks base method.
|
|
||||||
func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "SetServer", arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetServer indicates an expected call of SetServer.
|
|
||||||
func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0)
|
|
||||||
}
|
|
||||||
|
|
|
@ -65,9 +65,6 @@ type UnknownPacketHandler = unknownPacketHandler
|
||||||
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
|
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
|
||||||
type PacketHandlerManager = packetHandlerManager
|
type PacketHandlerManager = packetHandlerManager
|
||||||
|
|
||||||
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_multiplexer_test.go github.com/quic-go/quic-go Multiplexer"
|
|
||||||
type Multiplexer = multiplexer
|
|
||||||
|
|
||||||
// Need to use source mode for the batchConn, since reflect mode follows type aliases.
|
// Need to use source mode for the batchConn, since reflect mode follows type aliases.
|
||||||
// See https://github.com/golang/mock/issues/244 for details.
|
// See https://github.com/golang/mock/issues/244 for details.
|
||||||
//
|
//
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
"github.com/quic-go/quic-go/internal/utils"
|
||||||
"github.com/quic-go/quic-go/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -14,30 +13,19 @@ var (
|
||||||
connMuxer multiplexer
|
connMuxer multiplexer
|
||||||
)
|
)
|
||||||
|
|
||||||
type indexableConn interface {
|
type indexableConn interface{ LocalAddr() net.Addr }
|
||||||
LocalAddr() net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
type multiplexer interface {
|
type multiplexer interface {
|
||||||
AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error)
|
AddConn(conn indexableConn)
|
||||||
RemoveConn(indexableConn) error
|
RemoveConn(indexableConn) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type connManager struct {
|
|
||||||
connIDLen int
|
|
||||||
statelessResetKey *StatelessResetKey
|
|
||||||
tracer logging.Tracer
|
|
||||||
manager packetHandlerManager
|
|
||||||
}
|
|
||||||
|
|
||||||
// The connMultiplexer listens on multiple net.PacketConns and dispatches
|
// The connMultiplexer listens on multiple net.PacketConns and dispatches
|
||||||
// incoming packets to the connection handler.
|
// incoming packets to the connection handler.
|
||||||
type connMultiplexer struct {
|
type connMultiplexer struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
conns map[string] /* LocalAddr().String() */ connManager
|
conns map[string] /* LocalAddr().String() */ indexableConn
|
||||||
newPacketHandlerManager func(net.PacketConn, int, *StatelessResetKey, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests
|
|
||||||
|
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,57 +34,38 @@ var _ multiplexer = &connMultiplexer{}
|
||||||
func getMultiplexer() multiplexer {
|
func getMultiplexer() multiplexer {
|
||||||
connMuxerOnce.Do(func() {
|
connMuxerOnce.Do(func() {
|
||||||
connMuxer = &connMultiplexer{
|
connMuxer = &connMultiplexer{
|
||||||
conns: make(map[string]connManager),
|
conns: make(map[string]indexableConn),
|
||||||
logger: utils.DefaultLogger.WithPrefix("muxer"),
|
logger: utils.DefaultLogger.WithPrefix("muxer"),
|
||||||
newPacketHandlerManager: newPacketHandlerMap,
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return connMuxer
|
return connMuxer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *connMultiplexer) AddConn(
|
func (m *connMultiplexer) index(addr net.Addr) string {
|
||||||
c net.PacketConn,
|
return addr.Network() + " " + addr.String()
|
||||||
connIDLen int,
|
}
|
||||||
statelessResetKey *StatelessResetKey,
|
|
||||||
tracer logging.Tracer,
|
func (m *connMultiplexer) AddConn(c indexableConn) {
|
||||||
) (packetHandlerManager, error) {
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
addr := c.LocalAddr()
|
connIndex := m.index(c.LocalAddr())
|
||||||
connIndex := addr.Network() + " " + addr.String()
|
|
||||||
p, ok := m.conns[connIndex]
|
p, ok := m.conns[connIndex]
|
||||||
if !ok {
|
if ok {
|
||||||
manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger)
|
// Panics if we're already listening on this connection.
|
||||||
if err != nil {
|
// This is a safeguard because we're introducing a breaking API change, see
|
||||||
return nil, err
|
// https://github.com/quic-go/quic-go/issues/3727 for details.
|
||||||
}
|
// We'll remove this at a later time, when most users of the library have made the switch.
|
||||||
p = connManager{
|
panic("connection already exists") // TODO: write a nice message
|
||||||
connIDLen: connIDLen,
|
|
||||||
statelessResetKey: statelessResetKey,
|
|
||||||
manager: manager,
|
|
||||||
tracer: tracer,
|
|
||||||
}
|
}
|
||||||
m.conns[connIndex] = p
|
m.conns[connIndex] = p
|
||||||
} else {
|
|
||||||
if p.connIDLen != connIDLen {
|
|
||||||
return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
|
|
||||||
}
|
|
||||||
if statelessResetKey != nil && p.statelessResetKey != statelessResetKey {
|
|
||||||
return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn")
|
|
||||||
}
|
|
||||||
if tracer != p.tracer {
|
|
||||||
return nil, fmt.Errorf("cannot use different tracers on the same packet conn")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.manager, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *connMultiplexer) RemoveConn(c indexableConn) error {
|
func (m *connMultiplexer) RemoveConn(c indexableConn) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String()
|
connIndex := m.index(c.LocalAddr())
|
||||||
if _, ok := m.conns[connIndex]; !ok {
|
if _, ok := m.conns[connIndex]; !ok {
|
||||||
return fmt.Errorf("cannote remove connection, connection is unknown")
|
return fmt.Errorf("cannote remove connection, connection is unknown")
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,71 +3,24 @@ package quic
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testConn struct {
|
|
||||||
counter int
|
|
||||||
net.PacketConn
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Multiplexer", func() {
|
var _ = Describe("Multiplexer", func() {
|
||||||
It("adds a new packet conn ", func() {
|
It("adds new packet conns", func() {
|
||||||
conn := NewMockPacketConn(mockCtrl)
|
conn1 := NewMockPacketConn(mockCtrl)
|
||||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
conn1.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
|
||||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
|
getMultiplexer().AddConn(conn1)
|
||||||
_, err := getMultiplexer().AddConn(conn, 8, nil, nil)
|
conn2 := NewMockPacketConn(mockCtrl)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
conn2.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1235})
|
||||||
|
getMultiplexer().AddConn(conn2)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("recognizes when the same connection is added twice", func() {
|
It("panics when the same connection is added twice", func() {
|
||||||
srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}
|
|
||||||
pconn := NewMockPacketConn(mockCtrl)
|
|
||||||
pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
|
|
||||||
pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
|
||||||
conn := testConn{PacketConn: pconn}
|
|
||||||
tracer := mocklogging.NewMockTracer(mockCtrl)
|
|
||||||
_, err := getMultiplexer().AddConn(conn, 8, srk, tracer)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
conn.counter++
|
|
||||||
_, err = getMultiplexer().AddConn(conn, 8, srk, tracer)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when adding an existing conn with a different connection ID length", func() {
|
|
||||||
conn := NewMockPacketConn(mockCtrl)
|
conn := NewMockPacketConn(mockCtrl)
|
||||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
|
||||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
|
getMultiplexer().AddConn(conn)
|
||||||
_, err := getMultiplexer().AddConn(conn, 5, nil, nil)
|
Expect(func() { getMultiplexer().AddConn(conn) }).To(Panic())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = getMultiplexer().AddConn(conn, 6, nil, nil)
|
|
||||||
Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when adding an existing conn with a different stateless rest key", func() {
|
|
||||||
srk1 := &StatelessResetKey{'f', 'o', 'o'}
|
|
||||||
srk2 := &StatelessResetKey{'b', 'a', 'r'}
|
|
||||||
conn := NewMockPacketConn(mockCtrl)
|
|
||||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
|
||||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
|
|
||||||
_, err := getMultiplexer().AddConn(conn, 7, srk1, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = getMultiplexer().AddConn(conn, 7, srk2, nil)
|
|
||||||
Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when adding an existing conn with different tracers", func() {
|
|
||||||
conn := NewMockPacketConn(mockCtrl)
|
|
||||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
|
||||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
|
|
||||||
_, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
|
|
||||||
Expect(err).To(MatchError("cannot use different tracers on the same packet conn"))
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -5,28 +5,22 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"hash"
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
"github.com/quic-go/quic-go/internal/utils"
|
||||||
"github.com/quic-go/quic-go/internal/wire"
|
|
||||||
"github.com/quic-go/quic-go/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// rawConn is a connection that allow reading of a receivedPacket.
|
// rawConn is a connection that allow reading of a receivedPackeh.
|
||||||
type rawConn interface {
|
type rawConn interface {
|
||||||
ReadPacket() (*receivedPacket, error)
|
ReadPacket() (*receivedPacket, error)
|
||||||
WritePacket(b []byte, addr net.Addr, oob []byte) (int, error)
|
WritePacket(b []byte, addr net.Addr, oob []byte) (int, error)
|
||||||
LocalAddr() net.Addr
|
LocalAddr() net.Addr
|
||||||
|
SetReadDeadline(time.Time) error
|
||||||
io.Closer
|
io.Closer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,113 +30,49 @@ type closePacket struct {
|
||||||
info *packetInfo
|
info *packetInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// The packetHandlerMap stores packetHandlers, identified by connection ID.
|
type unknownPacketHandler interface {
|
||||||
// It is used:
|
handlePacket(*receivedPacket)
|
||||||
// * by the server to store connections
|
setCloseError(error)
|
||||||
// * when multiplexing outgoing connections to store clients
|
}
|
||||||
|
|
||||||
|
var errListenerAlreadySet = errors.New("listener already set")
|
||||||
|
|
||||||
type packetHandlerMap struct {
|
type packetHandlerMap struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
conn rawConn
|
|
||||||
connIDLen int
|
|
||||||
|
|
||||||
closeQueue chan closePacket
|
|
||||||
|
|
||||||
handlers map[protocol.ConnectionID]packetHandler
|
handlers map[protocol.ConnectionID]packetHandler
|
||||||
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
|
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
|
||||||
server unknownPacketHandler
|
|
||||||
|
|
||||||
listening chan struct{} // is closed when listen returns
|
|
||||||
closed bool
|
closed bool
|
||||||
|
closeChan chan struct{}
|
||||||
|
|
||||||
|
enqueueClosePacket func(closePacket)
|
||||||
|
|
||||||
deleteRetiredConnsAfter time.Duration
|
deleteRetiredConnsAfter time.Duration
|
||||||
|
|
||||||
statelessResetEnabled bool
|
|
||||||
statelessResetMutex sync.Mutex
|
statelessResetMutex sync.Mutex
|
||||||
statelessResetHasher hash.Hash
|
statelessResetHasher hash.Hash
|
||||||
|
|
||||||
tracer logging.Tracer
|
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ packetHandlerManager = &packetHandlerMap{}
|
var _ packetHandlerManager = &packetHandlerMap{}
|
||||||
|
|
||||||
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
|
func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
|
||||||
conn, ok := c.(interface{ SetReadBuffer(int) error })
|
h := &packetHandlerMap{
|
||||||
if !ok {
|
closeChan: make(chan struct{}),
|
||||||
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
|
|
||||||
}
|
|
||||||
size, err := inspectReadBuffer(c)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
|
||||||
}
|
|
||||||
if size >= protocol.DesiredReceiveBufferSize {
|
|
||||||
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
|
|
||||||
return fmt.Errorf("failed to increase receive buffer size: %w", err)
|
|
||||||
}
|
|
||||||
newSize, err := inspectReadBuffer(c)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
|
||||||
}
|
|
||||||
if newSize == size {
|
|
||||||
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
|
||||||
}
|
|
||||||
if newSize < protocol.DesiredReceiveBufferSize {
|
|
||||||
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
|
||||||
}
|
|
||||||
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// only print warnings about the UDP receive buffer size once
|
|
||||||
var receiveBufferWarningOnce sync.Once
|
|
||||||
|
|
||||||
func newPacketHandlerMap(
|
|
||||||
c net.PacketConn,
|
|
||||||
connIDLen int,
|
|
||||||
statelessResetKey *StatelessResetKey,
|
|
||||||
tracer logging.Tracer,
|
|
||||||
logger utils.Logger,
|
|
||||||
) (packetHandlerManager, error) {
|
|
||||||
if err := setReceiveBuffer(c, logger); err != nil {
|
|
||||||
if !strings.Contains(err.Error(), "use of closed network connection") {
|
|
||||||
receiveBufferWarningOnce.Do(func() {
|
|
||||||
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn, err := wrapConn(c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m := &packetHandlerMap{
|
|
||||||
conn: conn,
|
|
||||||
connIDLen: connIDLen,
|
|
||||||
listening: make(chan struct{}),
|
|
||||||
handlers: make(map[protocol.ConnectionID]packetHandler),
|
handlers: make(map[protocol.ConnectionID]packetHandler),
|
||||||
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
|
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
|
||||||
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
|
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
|
||||||
closeQueue: make(chan closePacket, 4),
|
enqueueClosePacket: enqueueClosePacket,
|
||||||
statelessResetEnabled: statelessResetKey != nil,
|
|
||||||
tracer: tracer,
|
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
if m.statelessResetEnabled {
|
if key != nil {
|
||||||
m.statelessResetHasher = hmac.New(sha256.New, statelessResetKey[:])
|
h.statelessResetHasher = hmac.New(sha256.New, key[:])
|
||||||
}
|
}
|
||||||
go m.listen()
|
if h.logger.Debug() {
|
||||||
go m.runCloseQueue()
|
go h.logUsage()
|
||||||
|
|
||||||
if logger.Debug() {
|
|
||||||
go m.logUsage()
|
|
||||||
}
|
}
|
||||||
return m, nil
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) logUsage() {
|
func (h *packetHandlerMap) logUsage() {
|
||||||
|
@ -150,7 +80,7 @@ func (h *packetHandlerMap) logUsage() {
|
||||||
var printedZero bool
|
var printedZero bool
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-h.listening:
|
case <-h.closeChan:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
|
@ -233,12 +163,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
|
||||||
if connClosePacket != nil {
|
if connClosePacket != nil {
|
||||||
handler = newClosedLocalConn(
|
handler = newClosedLocalConn(
|
||||||
func(addr net.Addr, info *packetInfo) {
|
func(addr net.Addr, info *packetInfo) {
|
||||||
select {
|
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
|
||||||
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
|
|
||||||
default:
|
|
||||||
// Oops, we're backlogged.
|
|
||||||
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
pers,
|
pers,
|
||||||
h.logger,
|
h.logger,
|
||||||
|
@ -265,17 +190,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) runCloseQueue() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-h.listening:
|
|
||||||
return
|
|
||||||
case p := <-h.closeQueue:
|
|
||||||
h.conn.WritePacket(p.payload, p.addr, p.info.OOB())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
|
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
h.resetTokens[token] = handler
|
h.resetTokens[token] = handler
|
||||||
|
@ -288,19 +202,16 @@ func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken)
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
|
func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
h.server = s
|
defer h.mutex.Unlock()
|
||||||
h.mutex.Unlock()
|
|
||||||
|
handler, ok := h.resetTokens[token]
|
||||||
|
return handler, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) CloseServer() {
|
func (h *packetHandlerMap) CloseServer() {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
if h.server == nil {
|
|
||||||
h.mutex.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.server = nil
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, handler := range h.handlers {
|
for _, handler := range h.handlers {
|
||||||
if handler.getPerspective() == protocol.PerspectiveServer {
|
if handler.getPerspective() == protocol.PerspectiveServer {
|
||||||
|
@ -316,23 +227,16 @@ func (h *packetHandlerMap) CloseServer() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destroy closes the underlying connection and waits until listen() has returned.
|
func (h *packetHandlerMap) Close(e error) {
|
||||||
// It does not close active connections.
|
|
||||||
func (h *packetHandlerMap) Destroy() error {
|
|
||||||
if err := h.conn.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
<-h.listening // wait until listening returns
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *packetHandlerMap) close(e error) error {
|
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
|
|
||||||
if h.closed {
|
if h.closed {
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
close(h.closeChan)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, handler := range h.handlers {
|
for _, handler := range h.handlers {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
@ -341,89 +245,14 @@ func (h *packetHandlerMap) close(e error) error {
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}(handler)
|
}(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.server != nil {
|
|
||||||
h.server.setCloseError(e)
|
|
||||||
}
|
|
||||||
h.closed = true
|
h.closed = true
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
return getMultiplexer().RemoveConn(h.conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *packetHandlerMap) listen() {
|
|
||||||
defer close(h.listening)
|
|
||||||
for {
|
|
||||||
p, err := h.conn.ReadPacket()
|
|
||||||
//nolint:staticcheck // SA1019 ignore this!
|
|
||||||
// TODO: This code is used to ignore wsa errors on Windows.
|
|
||||||
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
|
|
||||||
// See https://github.com/quic-go/quic-go/issues/1737 for details.
|
|
||||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
|
||||||
h.logger.Debugf("Temporary error reading from conn: %w", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
h.close(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.handlePacket(p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
|
|
||||||
connID, err := wire.ParseConnectionID(p.data, h.connIDLen)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
|
||||||
if h.tracer != nil {
|
|
||||||
h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
|
||||||
}
|
|
||||||
p.buffer.MaybeRelease()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if handler, ok := h.handlers[connID]; ok {
|
|
||||||
handler.handlePacket(p)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !wire.IsLongHeaderPacket(p.data[0]) {
|
|
||||||
go h.maybeSendStatelessReset(p, connID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if h.server == nil { // no server set
|
|
||||||
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.server.handlePacket(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
|
|
||||||
// stateless resets are always short header packets
|
|
||||||
if wire.IsLongHeaderPacket(data[0]) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
|
|
||||||
if sess, ok := h.resetTokens[token]; ok {
|
|
||||||
h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
|
|
||||||
go sess.destroy(&StatelessResetError{Token: token})
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
|
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
|
||||||
var token protocol.StatelessResetToken
|
var token protocol.StatelessResetToken
|
||||||
if !h.statelessResetEnabled {
|
if h.statelessResetHasher == nil {
|
||||||
// Return a random stateless reset token.
|
// Return a random stateless reset token.
|
||||||
// This token will be sent in the server's transport parameters.
|
// This token will be sent in the server's transport parameters.
|
||||||
// By using a random token, an off-path attacker won't be able to disrupt the connection.
|
// By using a random token, an off-path attacker won't be able to disrupt the connection.
|
||||||
|
@ -437,24 +266,3 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID)
|
||||||
h.statelessResetMutex.Unlock()
|
h.statelessResetMutex.Unlock()
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
|
|
||||||
defer p.buffer.Release()
|
|
||||||
if !h.statelessResetEnabled {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Don't send a stateless reset in response to very small packets.
|
|
||||||
// This includes packets that could be stateless resets.
|
|
||||||
if len(p.data) <= protocol.MinStatelessResetSize {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
token := h.GetStatelessResetToken(connID)
|
|
||||||
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
|
||||||
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
|
||||||
rand.Read(data)
|
|
||||||
data[0] = (data[0] & 0x7f) | 0x40
|
|
||||||
data = append(data, token[:]...)
|
|
||||||
if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
|
|
||||||
h.logger.Debugf("Error sending Stateless Reset: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -6,405 +6,188 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
"github.com/quic-go/quic-go/internal/utils"
|
||||||
"github.com/quic-go/quic-go/internal/wire"
|
|
||||||
"github.com/quic-go/quic-go/logging"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Packet Handler Map", func() {
|
var _ = Describe("Packet Handler Map", func() {
|
||||||
type packetToRead struct {
|
It("adds and gets", func() {
|
||||||
addr net.Addr
|
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||||
data []byte
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||||
err error
|
handler := NewMockPacketHandler(mockCtrl)
|
||||||
}
|
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||||
|
h, ok := m.Get(connID)
|
||||||
var (
|
Expect(ok).To(BeTrue())
|
||||||
handler *packetHandlerMap
|
Expect(h).To(Equal(handler))
|
||||||
conn *MockPacketConn
|
|
||||||
tracer *mocklogging.MockTracer
|
|
||||||
packetChan chan packetToRead
|
|
||||||
|
|
||||||
connIDLen int
|
|
||||||
statelessResetKey *StatelessResetKey
|
|
||||||
)
|
|
||||||
|
|
||||||
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
|
|
||||||
b, err := (&wire.ExtendedHeader{
|
|
||||||
Header: wire.Header{
|
|
||||||
Type: t,
|
|
||||||
DestConnectionID: connID,
|
|
||||||
Length: length,
|
|
||||||
Version: protocol.Version1,
|
|
||||||
},
|
|
||||||
PacketNumberLen: protocol.PacketNumberLen2,
|
|
||||||
}).Append(nil, protocol.Version1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
getPacket := func(connID protocol.ConnectionID) []byte {
|
|
||||||
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
statelessResetKey = nil
|
|
||||||
connIDLen = 0
|
|
||||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
|
||||||
packetChan = make(chan packetToRead, 10)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
JustBeforeEach(func() {
|
It("refused to add duplicates", func() {
|
||||||
conn = NewMockPacketConn(mockCtrl)
|
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||||
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
|
handler := NewMockPacketHandler(mockCtrl)
|
||||||
p, ok := <-packetChan
|
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||||
if !ok {
|
Expect(m.Add(connID, handler)).To(BeFalse())
|
||||||
return 0, nil, errors.New("closed")
|
})
|
||||||
|
|
||||||
|
It("removes", func() {
|
||||||
|
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||||
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||||
|
handler := NewMockPacketHandler(mockCtrl)
|
||||||
|
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||||
|
m.Remove(connID)
|
||||||
|
_, ok := m.Get(connID)
|
||||||
|
Expect(ok).To(BeFalse())
|
||||||
|
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("retires", func() {
|
||||||
|
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||||
|
dur := scaleDuration(50 * time.Millisecond)
|
||||||
|
m.deleteRetiredConnsAfter = dur
|
||||||
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||||
|
handler := NewMockPacketHandler(mockCtrl)
|
||||||
|
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||||
|
m.Retire(connID)
|
||||||
|
_, ok := m.Get(connID)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
time.Sleep(dur)
|
||||||
|
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("adds newly to-be-constructed handlers", func() {
|
||||||
|
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||||
|
var called bool
|
||||||
|
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||||
|
connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||||
|
Expect(m.AddWithConnID(connID1, connID2, func() packetHandler {
|
||||||
|
called = true
|
||||||
|
return NewMockPacketHandler(mockCtrl)
|
||||||
|
})).To(BeTrue())
|
||||||
|
Expect(called).To(BeTrue())
|
||||||
|
Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() packetHandler {
|
||||||
|
Fail("didn't expect the constructor to be executed")
|
||||||
|
return nil
|
||||||
|
})).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("adds, gets and removes reset tokens", func() {
|
||||||
|
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||||
|
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
|
||||||
|
handler := NewMockPacketHandler(mockCtrl)
|
||||||
|
m.AddResetToken(token, handler)
|
||||||
|
h, ok := m.GetByResetToken(token)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(h).To(Equal(h))
|
||||||
|
m.RemoveResetToken(token)
|
||||||
|
_, ok = m.GetByResetToken(token)
|
||||||
|
Expect(ok).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("generates stateless reset token, if no key is set", func() {
|
||||||
|
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||||
|
b := make([]byte, 8)
|
||||||
|
rand.Read(b)
|
||||||
|
connID := protocol.ParseConnectionID(b)
|
||||||
|
token := m.GetStatelessResetToken(connID)
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
to := m.GetStatelessResetToken(connID)
|
||||||
|
Expect(to).ToNot(Equal(token))
|
||||||
|
token = to
|
||||||
}
|
}
|
||||||
return copy(b, p.data), p.addr, p.err
|
})
|
||||||
}).AnyTimes()
|
|
||||||
phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger)
|
It("generates stateless reset token, if a key is set", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
var key StatelessResetKey
|
||||||
handler = phm.(*packetHandlerMap)
|
rand.Read(key[:])
|
||||||
|
m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
|
||||||
|
b := make([]byte, 8)
|
||||||
|
rand.Read(b)
|
||||||
|
connID := protocol.ParseConnectionID(b)
|
||||||
|
token := m.GetStatelessResetToken(connID)
|
||||||
|
Expect(token).ToNot(BeZero())
|
||||||
|
Expect(m.GetStatelessResetToken(connID)).To(Equal(token))
|
||||||
|
// generate a new connection ID
|
||||||
|
rand.Read(b)
|
||||||
|
connID2 := protocol.ParseConnectionID(b)
|
||||||
|
Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("replaces locally closed connections", func() {
|
||||||
|
var closePackets []closePacket
|
||||||
|
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
|
||||||
|
dur := scaleDuration(50 * time.Millisecond)
|
||||||
|
m.deleteRetiredConnsAfter = dur
|
||||||
|
|
||||||
|
handler := NewMockPacketHandler(mockCtrl)
|
||||||
|
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||||
|
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||||
|
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar"))
|
||||||
|
h, ok := m.Get(connID)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(h).ToNot(Equal(handler))
|
||||||
|
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||||
|
h.handlePacket(&receivedPacket{remoteAddr: addr})
|
||||||
|
Expect(closePackets).To(HaveLen(1))
|
||||||
|
Expect(closePackets[0].addr).To(Equal(addr))
|
||||||
|
Expect(closePackets[0].payload).To(Equal([]byte("foobar")))
|
||||||
|
|
||||||
|
time.Sleep(dur)
|
||||||
|
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("replaces remote closed connections", func() {
|
||||||
|
var closePackets []closePacket
|
||||||
|
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
|
||||||
|
dur := scaleDuration(50 * time.Millisecond)
|
||||||
|
m.deleteRetiredConnsAfter = dur
|
||||||
|
|
||||||
|
handler := NewMockPacketHandler(mockCtrl)
|
||||||
|
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||||
|
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||||
|
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil)
|
||||||
|
h, ok := m.Get(connID)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(h).ToNot(Equal(handler))
|
||||||
|
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||||
|
h.handlePacket(&receivedPacket{remoteAddr: addr})
|
||||||
|
Expect(closePackets).To(BeEmpty())
|
||||||
|
|
||||||
|
time.Sleep(dur)
|
||||||
|
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("closes the server", func() {
|
||||||
|
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
conn := NewMockPacketHandler(mockCtrl)
|
||||||
|
if i%2 == 0 {
|
||||||
|
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
|
||||||
|
} else {
|
||||||
|
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
|
||||||
|
conn.EXPECT().shutdown()
|
||||||
|
}
|
||||||
|
b := make([]byte, 12)
|
||||||
|
rand.Read(b)
|
||||||
|
m.Add(protocol.ParseConnectionID(b), conn)
|
||||||
|
}
|
||||||
|
m.CloseServer()
|
||||||
})
|
})
|
||||||
|
|
||||||
It("closes", func() {
|
It("closes", func() {
|
||||||
getMultiplexer() // make the sync.Once execute
|
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||||
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
|
testErr := errors.New("shutdown")
|
||||||
mockMultiplexer := NewMockMultiplexer(mockCtrl)
|
for i := 0; i < 10; i++ {
|
||||||
origMultiplexer := connMuxer
|
|
||||||
connMuxer = mockMultiplexer
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
connMuxer = origMultiplexer
|
|
||||||
}()
|
|
||||||
|
|
||||||
testErr := errors.New("test error ")
|
|
||||||
conn1 := NewMockPacketHandler(mockCtrl)
|
|
||||||
conn1.EXPECT().destroy(testErr)
|
|
||||||
conn2 := NewMockPacketHandler(mockCtrl)
|
|
||||||
conn2.EXPECT().destroy(testErr)
|
|
||||||
handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), conn1)
|
|
||||||
handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), conn2)
|
|
||||||
mockMultiplexer.EXPECT().RemoveConn(gomock.Any())
|
|
||||||
handler.close(testErr)
|
|
||||||
close(packetChan)
|
|
||||||
Eventually(handler.listening).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("other operations", func() {
|
|
||||||
AfterEach(func() {
|
|
||||||
// delete connections and the server before closing
|
|
||||||
// They might be mock implementations, and we'd have to register the expected calls before otherwise.
|
|
||||||
handler.mutex.Lock()
|
|
||||||
for connID := range handler.handlers {
|
|
||||||
delete(handler.handlers, connID)
|
|
||||||
}
|
|
||||||
handler.server = nil
|
|
||||||
handler.mutex.Unlock()
|
|
||||||
conn.EXPECT().Close().MaxTimes(1)
|
|
||||||
close(packetChan)
|
|
||||||
handler.Destroy()
|
|
||||||
Eventually(handler.listening).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("handling packets", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
connIDLen = 5
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles packets for different packet handlers on the same packet conn", func() {
|
|
||||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
|
||||||
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
|
|
||||||
packetHandler1 := NewMockPacketHandler(mockCtrl)
|
|
||||||
packetHandler2 := NewMockPacketHandler(mockCtrl)
|
|
||||||
handledPacket1 := make(chan struct{})
|
|
||||||
handledPacket2 := make(chan struct{})
|
|
||||||
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
|
||||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(connID).To(Equal(connID1))
|
|
||||||
close(handledPacket1)
|
|
||||||
})
|
|
||||||
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
|
||||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(connID).To(Equal(connID2))
|
|
||||||
close(handledPacket2)
|
|
||||||
})
|
|
||||||
handler.Add(connID1, packetHandler1)
|
|
||||||
handler.Add(connID2, packetHandler2)
|
|
||||||
packetChan <- packetToRead{data: getPacket(connID1)}
|
|
||||||
packetChan <- packetToRead{data: getPacket(connID2)}
|
|
||||||
|
|
||||||
Eventually(handledPacket1).Should(BeClosed())
|
|
||||||
Eventually(handledPacket2).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("drops unparseable packets", func() {
|
|
||||||
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
|
||||||
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError)
|
|
||||||
handler.handlePacket(&receivedPacket{
|
|
||||||
buffer: getPacketBuffer(),
|
|
||||||
remoteAddr: addr,
|
|
||||||
data: []byte{0, 1, 2, 3},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("deletes removed connections immediately", func() {
|
|
||||||
handler.deleteRetiredConnsAfter = time.Hour
|
|
||||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
|
||||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
|
||||||
handler.Remove(connID)
|
|
||||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
|
||||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
|
||||||
})
|
|
||||||
|
|
||||||
It("deletes retired connection entries after a wait time", func() {
|
|
||||||
handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond)
|
|
||||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
|
||||||
conn := NewMockPacketHandler(mockCtrl)
|
conn := NewMockPacketHandler(mockCtrl)
|
||||||
handler.Add(connID, conn)
|
conn.EXPECT().destroy(testErr)
|
||||||
handler.Retire(connID)
|
b := make([]byte, 12)
|
||||||
time.Sleep(scaleDuration(30 * time.Millisecond))
|
rand.Read(b)
|
||||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
m.Add(protocol.ParseConnectionID(b), conn)
|
||||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
}
|
||||||
})
|
m.Close(testErr)
|
||||||
|
// check that Close can be called multiple times
|
||||||
It("passes packets arriving late for closed connections to that connection", func() {
|
m.Close(errors.New("close"))
|
||||||
handler.deleteRetiredConnsAfter = time.Hour
|
|
||||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
|
||||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
|
||||||
handled := make(chan struct{})
|
|
||||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
|
||||||
close(handled)
|
|
||||||
})
|
|
||||||
handler.Add(connID, packetHandler)
|
|
||||||
handler.Retire(connID)
|
|
||||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
|
||||||
Eventually(handled).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("drops packets for unknown receivers", func() {
|
|
||||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
|
||||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes the packet handlers when reading from the conn fails", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
|
||||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
|
|
||||||
Expect(e).To(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
})
|
|
||||||
handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler)
|
|
||||||
packetChan <- packetToRead{err: errors.New("read failed")}
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("continues listening for temporary errors", func() {
|
|
||||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
|
||||||
handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler)
|
|
||||||
err := deadlineError{}
|
|
||||||
Expect(err.Temporary()).To(BeTrue())
|
|
||||||
packetChan <- packetToRead{err: err}
|
|
||||||
// don't EXPECT any calls to packetHandler.destroy
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("says if a connection ID is already taken", func() {
|
|
||||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
|
||||||
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
|
|
||||||
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("says if a connection ID is already taken, for AddWithConnID", func() {
|
|
||||||
clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
|
||||||
newConnID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
|
||||||
newConnID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
|
||||||
Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue())
|
|
||||||
Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("running a server", func() {
|
|
||||||
It("adds a server", func() {
|
|
||||||
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
|
|
||||||
p := getPacket(connID)
|
|
||||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
|
||||||
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
|
||||||
cid, err := wire.ParseConnectionID(p.data, 0)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cid).To(Equal(connID))
|
|
||||||
})
|
|
||||||
handler.SetServer(server)
|
|
||||||
handler.handlePacket(&receivedPacket{data: p})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes all server connections", func() {
|
|
||||||
handler.SetServer(NewMockUnknownPacketHandler(mockCtrl))
|
|
||||||
clientConn := NewMockPacketHandler(mockCtrl)
|
|
||||||
clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
|
|
||||||
serverConn := NewMockPacketHandler(mockCtrl)
|
|
||||||
serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
|
|
||||||
serverConn.EXPECT().shutdown()
|
|
||||||
|
|
||||||
handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), clientConn)
|
|
||||||
handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), serverConn)
|
|
||||||
handler.CloseServer()
|
|
||||||
})
|
|
||||||
|
|
||||||
It("stops handling packets with unknown connection IDs after the server is closed", func() {
|
|
||||||
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
|
|
||||||
p := getPacket(connID)
|
|
||||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
|
||||||
// don't EXPECT any calls to server.handlePacket
|
|
||||||
handler.SetServer(server)
|
|
||||||
handler.CloseServer()
|
|
||||||
handler.handlePacket(&receivedPacket{data: p})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("stateless resets", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
connIDLen = 5
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("handling", func() {
|
|
||||||
It("handles stateless resets", func() {
|
|
||||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
|
||||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
|
||||||
handler.AddResetToken(token, packetHandler)
|
|
||||||
destroyed := make(chan struct{})
|
|
||||||
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
|
|
||||||
packet = append(packet, token[:]...)
|
|
||||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
defer close(destroyed)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
var resetErr *StatelessResetError
|
|
||||||
Expect(errors.As(err, &resetErr)).To(BeTrue())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
|
|
||||||
Expect(resetErr.Token).To(Equal(token))
|
|
||||||
})
|
|
||||||
packetChan <- packetToRead{data: packet}
|
|
||||||
Eventually(destroyed).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles stateless resets for 0-length connection IDs", func() {
|
|
||||||
handler.connIDLen = 0
|
|
||||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
|
||||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
|
||||||
handler.AddResetToken(token, packetHandler)
|
|
||||||
destroyed := make(chan struct{})
|
|
||||||
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
|
|
||||||
packet = append(packet, token[:]...)
|
|
||||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
var resetErr *StatelessResetError
|
|
||||||
Expect(errors.As(err, &resetErr)).To(BeTrue())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
|
|
||||||
Expect(resetErr.Token).To(Equal(token))
|
|
||||||
close(destroyed)
|
|
||||||
})
|
|
||||||
packetChan <- packetToRead{data: packet}
|
|
||||||
Eventually(destroyed).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("removes reset tokens", func() {
|
|
||||||
connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
|
|
||||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
|
||||||
handler.Add(connID, packetHandler)
|
|
||||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
|
||||||
handler.AddResetToken(token, NewMockPacketHandler(mockCtrl))
|
|
||||||
handler.RemoveResetToken(token)
|
|
||||||
// don't EXPECT any call to packetHandler.destroy()
|
|
||||||
packetHandler.EXPECT().handlePacket(gomock.Any())
|
|
||||||
p := append([]byte{0x40} /* short header packet */, connID.Bytes()...)
|
|
||||||
p = append(p, make([]byte, 50)...)
|
|
||||||
p = append(p, token[:]...)
|
|
||||||
|
|
||||||
handler.handlePacket(&receivedPacket{data: p})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores packets too small to contain a stateless reset", func() {
|
|
||||||
handler.connIDLen = 0
|
|
||||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
|
||||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
|
||||||
handler.AddResetToken(token, packetHandler)
|
|
||||||
done := make(chan struct{})
|
|
||||||
// don't EXPECT any calls here, but register the closing of the done channel
|
|
||||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) {
|
|
||||||
close(done)
|
|
||||||
}).AnyTimes()
|
|
||||||
packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)}
|
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("generating", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
var key StatelessResetKey
|
|
||||||
rand.Read(key[:])
|
|
||||||
statelessResetKey = &key
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates stateless reset tokens", func() {
|
|
||||||
connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
|
|
||||||
connID2 := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})
|
|
||||||
Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends stateless resets", func() {
|
|
||||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
|
||||||
p := append([]byte{40}, make([]byte, 100)...)
|
|
||||||
done := make(chan struct{})
|
|
||||||
conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) {
|
|
||||||
defer close(done)
|
|
||||||
Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet
|
|
||||||
Expect(b).To(HaveLen(protocol.MinStatelessResetSize))
|
|
||||||
})
|
|
||||||
handler.handlePacket(&receivedPacket{
|
|
||||||
buffer: getPacketBuffer(),
|
|
||||||
remoteAddr: addr,
|
|
||||||
data: p,
|
|
||||||
})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't send stateless resets for small packets", func() {
|
|
||||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
|
||||||
p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...)
|
|
||||||
handler.handlePacket(&receivedPacket{
|
|
||||||
buffer: getPacketBuffer(),
|
|
||||||
remoteAddr: addr,
|
|
||||||
data: p,
|
|
||||||
})
|
|
||||||
// make sure there are no Write calls on the packet conn
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("if no key is configured", func() {
|
|
||||||
It("doesn't send stateless resets", func() {
|
|
||||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
|
||||||
p := append([]byte{40}, make([]byte, 100)...)
|
|
||||||
handler.handlePacket(&receivedPacket{
|
|
||||||
buffer: getPacketBuffer(),
|
|
||||||
remoteAddr: addr,
|
|
||||||
data: p,
|
|
||||||
})
|
|
||||||
// make sure there are no Write calls on the packet conn
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"runtime/pprof"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -29,6 +32,20 @@ var _ = BeforeSuite(func() {
|
||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
func areServersRunning() bool {
|
||||||
|
var b bytes.Buffer
|
||||||
|
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||||
|
return strings.Contains(b.String(), "quic-go.(*baseServer).run")
|
||||||
|
}
|
||||||
|
|
||||||
|
func areTransportsRunning() bool {
|
||||||
|
var b bytes.Buffer
|
||||||
|
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||||
|
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
|
||||||
|
}
|
||||||
|
|
||||||
var _ = AfterEach(func() {
|
var _ = AfterEach(func() {
|
||||||
mockCtrl.Finish()
|
mockCtrl.Finish()
|
||||||
|
Eventually(areServersRunning).Should(BeFalse())
|
||||||
|
Eventually(areTransportsRunning()).Should(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
92
server.go
92
server.go
|
@ -20,7 +20,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close.
|
// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close.
|
||||||
var ErrServerClosed = errors.New("quic: Server closed")
|
var ErrServerClosed = errors.New("quic: server closed")
|
||||||
|
|
||||||
// packetHandler handles packets
|
// packetHandler handles packets
|
||||||
type packetHandler interface {
|
type packetHandler interface {
|
||||||
|
@ -30,18 +30,13 @@ type packetHandler interface {
|
||||||
getPerspective() protocol.Perspective
|
getPerspective() protocol.Perspective
|
||||||
}
|
}
|
||||||
|
|
||||||
type unknownPacketHandler interface {
|
|
||||||
handlePacket(*receivedPacket)
|
|
||||||
setCloseError(error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type packetHandlerManager interface {
|
type packetHandlerManager interface {
|
||||||
Get(protocol.ConnectionID) (packetHandler, bool)
|
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||||
|
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
||||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
|
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
|
||||||
Destroy() error
|
Close(error)
|
||||||
connRunner
|
|
||||||
SetServer(unknownPacketHandler)
|
|
||||||
CloseServer()
|
CloseServer()
|
||||||
|
connRunner
|
||||||
}
|
}
|
||||||
|
|
||||||
type quicConn interface {
|
type quicConn interface {
|
||||||
|
@ -70,13 +65,11 @@ type baseServer struct {
|
||||||
config *Config
|
config *Config
|
||||||
|
|
||||||
conn rawConn
|
conn rawConn
|
||||||
// If the server is started with ListenAddr, we create a packet conn.
|
|
||||||
// If it is started with Listen, we take a packet conn as a parameter.
|
|
||||||
createdPacketConn bool
|
|
||||||
|
|
||||||
tokenGenerator *handshake.TokenGenerator
|
tokenGenerator *handshake.TokenGenerator
|
||||||
|
|
||||||
connHandler packetHandlerManager
|
connHandler packetHandlerManager
|
||||||
|
onClose func()
|
||||||
|
|
||||||
receivedPackets chan *receivedPacket
|
receivedPackets chan *receivedPacket
|
||||||
|
|
||||||
|
@ -114,8 +107,6 @@ type baseServer struct {
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ unknownPacketHandler = &baseServer{}
|
|
||||||
|
|
||||||
// A Listener listens for incoming QUIC connections.
|
// A Listener listens for incoming QUIC connections.
|
||||||
// It returns connections once the handshake has completed.
|
// It returns connections once the handshake has completed.
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
|
@ -166,37 +157,36 @@ func (l *EarlyListener) Addr() net.Addr {
|
||||||
// The tls.Config must not be nil and must contain a certificate configuration.
|
// The tls.Config must not be nil and must contain a certificate configuration.
|
||||||
// The quic.Config may be nil, in that case the default values will be used.
|
// The quic.Config may be nil, in that case the default values will be used.
|
||||||
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
|
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
|
||||||
s, err := listenAddr(addr, tlsConf, config, false)
|
conn, err := listenUDP(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Listener{baseServer: s}, nil
|
return (&Transport{
|
||||||
|
Conn: conn,
|
||||||
|
createdConn: true,
|
||||||
|
isSingleUse: true,
|
||||||
|
}).Listen(tlsConf, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
|
// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
|
||||||
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
|
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
|
||||||
s, err := listenAddr(addr, tlsConf, config, true)
|
conn, err := listenUDP(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &EarlyListener{baseServer: s}, nil
|
return (&Transport{
|
||||||
|
Conn: conn,
|
||||||
|
createdConn: true,
|
||||||
|
isSingleUse: true,
|
||||||
|
}).ListenEarly(tlsConf, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
|
func listenUDP(addr string) (*net.UDPConn, error) {
|
||||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
conn, err := net.ListenUDP("udp", udpAddr)
|
return net.ListenUDP("udp", udpAddr)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
serv, err := listen(conn, tlsConf, config, acceptEarly)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
serv.createdPacketConn = true
|
|
||||||
return serv, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen listens for QUIC connections on a given net.PacketConn. If the
|
// Listen listens for QUIC connections on a given net.PacketConn. If the
|
||||||
|
@ -210,45 +200,23 @@ func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bo
|
||||||
// Furthermore, it must define an application control (using NextProtos).
|
// Furthermore, it must define an application control (using NextProtos).
|
||||||
// The quic.Config may be nil, in that case the default values will be used.
|
// The quic.Config may be nil, in that case the default values will be used.
|
||||||
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) {
|
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) {
|
||||||
s, err := listen(conn, tlsConf, config, false)
|
tr := &Transport{Conn: conn, isSingleUse: true}
|
||||||
if err != nil {
|
return tr.Listen(tlsConf, config)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Listener{baseServer: s}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenEarly works like Listen, but it returns connections before the handshake completes.
|
// ListenEarly works like Listen, but it returns connections before the handshake completes.
|
||||||
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
|
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
|
||||||
s, err := listen(conn, tlsConf, config, true)
|
tr := &Transport{Conn: conn, isSingleUse: true}
|
||||||
if err != nil {
|
return tr.ListenEarly(tlsConf, config)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &EarlyListener{baseServer: s}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
|
func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) {
|
||||||
if tlsConf == nil {
|
|
||||||
return nil, errors.New("quic: tls.Config not set")
|
|
||||||
}
|
|
||||||
if err := validateConfig(config); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
config = populateServerConfig(config)
|
|
||||||
|
|
||||||
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
|
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c, err := wrapConn(conn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
s := &baseServer{
|
s := &baseServer{
|
||||||
conn: c,
|
conn: conn,
|
||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
config: config,
|
config: config,
|
||||||
tokenGenerator: tokenGenerator,
|
tokenGenerator: tokenGenerator,
|
||||||
|
@ -260,12 +228,12 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
|
||||||
newConn: newConnection,
|
newConn: newConnection,
|
||||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||||
acceptEarlyConns: acceptEarly,
|
acceptEarlyConns: acceptEarly,
|
||||||
|
onClose: onClose,
|
||||||
}
|
}
|
||||||
if acceptEarly {
|
if acceptEarly {
|
||||||
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
||||||
}
|
}
|
||||||
go s.run()
|
go s.run()
|
||||||
connHandler.SetServer(s)
|
|
||||||
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
@ -317,18 +285,12 @@ func (s *baseServer) Close() error {
|
||||||
if s.serverError == nil {
|
if s.serverError == nil {
|
||||||
s.serverError = ErrServerClosed
|
s.serverError = ErrServerClosed
|
||||||
}
|
}
|
||||||
// If the server was started with ListenAddr, we created the packet conn.
|
|
||||||
// We need to close it in order to make the go routine reading from that conn return.
|
|
||||||
createdPacketConn := s.createdPacketConn
|
|
||||||
s.closed = true
|
s.closed = true
|
||||||
close(s.errorChan)
|
close(s.errorChan)
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
|
|
||||||
<-s.running
|
<-s.running
|
||||||
s.connHandler.CloseServer()
|
s.onClose()
|
||||||
if createdPacketConn {
|
|
||||||
return s.connHandler.Destroy()
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,12 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime/pprof"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
@ -24,17 +21,10 @@ import (
|
||||||
"github.com/quic-go/quic-go/logging"
|
"github.com/quic-go/quic-go/logging"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
func areServersRunning() bool {
|
|
||||||
var b bytes.Buffer
|
|
||||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
|
||||||
return strings.Contains(b.String(), "quic-go.(*baseServer).run")
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Server", func() {
|
var _ = Describe("Server", func() {
|
||||||
var (
|
var (
|
||||||
conn *MockPacketConn
|
conn *MockPacketConn
|
||||||
|
@ -96,15 +86,19 @@ var _ = Describe("Server", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
conn = NewMockPacketConn(mockCtrl)
|
conn = NewMockPacketConn(mockCtrl)
|
||||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
wait := make(chan struct{})
|
||||||
|
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
|
||||||
|
<-wait
|
||||||
|
return 0, nil, errors.New("done")
|
||||||
|
}).MaxTimes(1)
|
||||||
|
conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) {
|
||||||
|
close(wait)
|
||||||
|
conn.EXPECT().SetReadDeadline(time.Time{})
|
||||||
|
}).MaxTimes(1)
|
||||||
tlsConf = testdata.GetTLSConfig()
|
tlsConf = testdata.GetTLSConfig()
|
||||||
tlsConf.NextProtos = []string{"proto1"}
|
tlsConf.NextProtos = []string{"proto1"}
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Eventually(areServersRunning).Should(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when no tls.Config is given", func() {
|
It("errors when no tls.Config is given", func() {
|
||||||
_, err := ListenAddr("localhost:0", nil, nil)
|
_, err := ListenAddr("localhost:0", nil, nil)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
|
@ -178,6 +172,7 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
Context("server accepting connections that completed the handshake", func() {
|
Context("server accepting connections that completed the handshake", func() {
|
||||||
var (
|
var (
|
||||||
|
ln *Listener
|
||||||
serv *baseServer
|
serv *baseServer
|
||||||
phm *MockPacketHandlerManager
|
phm *MockPacketHandlerManager
|
||||||
tracer *mocklogging.MockTracer
|
tracer *mocklogging.MockTracer
|
||||||
|
@ -185,7 +180,8 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
tracer = mocklogging.NewMockTracer(mockCtrl)
|
||||||
ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer})
|
var err error
|
||||||
|
ln, err = Listen(conn, tlsConf, &Config{Tracer: tracer})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
serv = ln.baseServer
|
serv = ln.baseServer
|
||||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
@ -193,8 +189,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
ln.Close()
|
||||||
serv.Close()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("handling packets", func() {
|
Context("handling packets", func() {
|
||||||
|
@ -753,8 +748,7 @@ var _ = Describe("Server", func() {
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
|
|
||||||
// make the go routine return
|
// make the go routine return
|
||||||
phm.EXPECT().CloseServer()
|
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
|
||||||
conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
|
|
||||||
Expect(serv.Close()).To(Succeed())
|
Expect(serv.Close()).To(Succeed())
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -968,6 +962,7 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
serv.setCloseError(testErr)
|
serv.setCloseError(testErr)
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
|
serv.onClose() // shutdown
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns immediately, if an error occurred before", func() {
|
It("returns immediately, if an error occurred before", func() {
|
||||||
|
@ -977,6 +972,7 @@ var _ = Describe("Server", func() {
|
||||||
_, err := serv.Accept(context.Background())
|
_, err := serv.Accept(context.Background())
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
}
|
}
|
||||||
|
serv.onClose() // shutdown
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns when the context is canceled", func() {
|
It("returns when the context is canceled", func() {
|
||||||
|
@ -1064,7 +1060,6 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
|
||||||
serv.Close()
|
serv.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -1234,8 +1229,7 @@ var _ = Describe("Server", func() {
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
|
|
||||||
// make the go routine return
|
// make the go routine return
|
||||||
phm.EXPECT().CloseServer()
|
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
|
||||||
conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
|
|
||||||
Expect(serv.Close()).To(Succeed())
|
Expect(serv.Close()).To(Succeed())
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
410
transport.go
Normal file
410
transport.go
Normal file
|
@ -0,0 +1,410 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/quic-go/quic-go/internal/wire"
|
||||||
|
|
||||||
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
|
"github.com/quic-go/quic-go/internal/utils"
|
||||||
|
"github.com/quic-go/quic-go/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Transport struct {
|
||||||
|
// A single net.PacketConn can only be handled by one Transport.
|
||||||
|
// Bad things will happen if passed to multiple Transports.
|
||||||
|
//
|
||||||
|
// If the connection satisfies the OOBCapablePacketConn interface
|
||||||
|
// (as a net.UDPConn does), ECN and packet info support will be enabled.
|
||||||
|
// In this case, optimized syscalls might be used, skipping the
|
||||||
|
// ReadFrom and WriteTo calls to read / write packets.
|
||||||
|
Conn net.PacketConn
|
||||||
|
|
||||||
|
// The length of the connection ID in bytes.
|
||||||
|
// It can be 0, or any value between 4 and 18.
|
||||||
|
// If unset, a 4 byte connection ID will be used.
|
||||||
|
ConnectionIDLength int
|
||||||
|
|
||||||
|
// Use for generating new connection IDs.
|
||||||
|
// This allows the application to control of the connection IDs used,
|
||||||
|
// which allows routing / load balancing based on connection IDs.
|
||||||
|
// All Connection IDs returned by the ConnectionIDGenerator MUST
|
||||||
|
// have the same length.
|
||||||
|
ConnectionIDGenerator ConnectionIDGenerator
|
||||||
|
|
||||||
|
// The StatelessResetKey is used to generate stateless reset tokens.
|
||||||
|
// If no key is configured, sending of stateless resets is disabled.
|
||||||
|
StatelessResetKey *StatelessResetKey
|
||||||
|
|
||||||
|
// A Tracer traces events that don't belong to a single QUIC connection.
|
||||||
|
Tracer logging.Tracer
|
||||||
|
|
||||||
|
handlerMap packetHandlerManager
|
||||||
|
|
||||||
|
mutex sync.Mutex
|
||||||
|
initOnce sync.Once
|
||||||
|
initErr error
|
||||||
|
|
||||||
|
// Set in init.
|
||||||
|
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
|
||||||
|
connIDLen int
|
||||||
|
|
||||||
|
server unknownPacketHandler
|
||||||
|
|
||||||
|
conn rawConn
|
||||||
|
|
||||||
|
closeQueue chan closePacket
|
||||||
|
|
||||||
|
listening chan struct{} // is closed when listen returns
|
||||||
|
closed bool
|
||||||
|
createdConn bool
|
||||||
|
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen starts listening for incoming QUIC connections.
|
||||||
|
// There can only be a single listener on any net.PacketConn.
|
||||||
|
// Listen may only be called again after the current Listener was closed.
|
||||||
|
func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
|
||||||
|
if tlsConf == nil {
|
||||||
|
return nil, errors.New("quic: tls.Config not set")
|
||||||
|
}
|
||||||
|
if err := validateConfig(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
|
if t.server != nil {
|
||||||
|
return nil, errListenerAlreadySet
|
||||||
|
}
|
||||||
|
conf = populateServerConfig(conf)
|
||||||
|
if err := t.init(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.server = s
|
||||||
|
return &Listener{baseServer: s}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenEarly starts listening for incoming QUIC connections.
|
||||||
|
// There can only be a single listener on any net.PacketConn.
|
||||||
|
// Listen may only be called again after the current Listener was closed.
|
||||||
|
func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
|
||||||
|
if tlsConf == nil {
|
||||||
|
return nil, errors.New("quic: tls.Config not set")
|
||||||
|
}
|
||||||
|
if err := validateConfig(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
|
if t.server != nil {
|
||||||
|
return nil, errListenerAlreadySet
|
||||||
|
}
|
||||||
|
conf = populateServerConfig(conf)
|
||||||
|
if err := t.init(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.server = s
|
||||||
|
return &EarlyListener{baseServer: s}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial dials a new connection to a remote host (not using 0-RTT).
|
||||||
|
func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||||
|
if err := validateConfig(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conf = populateClientConfig(conf, t.createdConn)
|
||||||
|
if err := t.init(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var onClose func()
|
||||||
|
if t.isSingleUse {
|
||||||
|
onClose = func() { t.Close() }
|
||||||
|
}
|
||||||
|
return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
|
||||||
|
func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||||
|
if err := validateConfig(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conf = populateClientConfig(conf, t.createdConn)
|
||||||
|
if err := t.init(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var onClose func()
|
||||||
|
if t.isSingleUse {
|
||||||
|
onClose = func() { t.Close() }
|
||||||
|
}
|
||||||
|
return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
|
||||||
|
conn, ok := c.(interface{ SetReadBuffer(int) error })
|
||||||
|
if !ok {
|
||||||
|
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
|
||||||
|
}
|
||||||
|
size, err := inspectReadBuffer(c)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
||||||
|
}
|
||||||
|
if size >= protocol.DesiredReceiveBufferSize {
|
||||||
|
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
|
||||||
|
return fmt.Errorf("failed to increase receive buffer size: %w", err)
|
||||||
|
}
|
||||||
|
newSize, err := inspectReadBuffer(c)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
||||||
|
}
|
||||||
|
if newSize == size {
|
||||||
|
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
||||||
|
}
|
||||||
|
if newSize < protocol.DesiredReceiveBufferSize {
|
||||||
|
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
||||||
|
}
|
||||||
|
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// only print warnings about the UDP receive buffer size once
|
||||||
|
var receiveBufferWarningOnce sync.Once
|
||||||
|
|
||||||
|
func (t *Transport) init(conf *Config) error {
|
||||||
|
t.initOnce.Do(func() {
|
||||||
|
getMultiplexer().AddConn(t.Conn)
|
||||||
|
|
||||||
|
conn, err := wrapConn(t.Conn)
|
||||||
|
if err != nil {
|
||||||
|
t.initErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.StatelessResetKey = conf.StatelessResetKey
|
||||||
|
t.Tracer = conf.Tracer
|
||||||
|
t.ConnectionIDLength = conf.ConnectionIDLength
|
||||||
|
t.ConnectionIDGenerator = conf.ConnectionIDGenerator
|
||||||
|
|
||||||
|
t.logger = utils.DefaultLogger // TODO: make this configurable
|
||||||
|
t.conn = conn
|
||||||
|
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
|
||||||
|
t.listening = make(chan struct{})
|
||||||
|
|
||||||
|
t.closeQueue = make(chan closePacket, 4)
|
||||||
|
|
||||||
|
if t.ConnectionIDGenerator != nil {
|
||||||
|
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
|
||||||
|
} else {
|
||||||
|
t.connIDLen = t.ConnectionIDLength
|
||||||
|
}
|
||||||
|
|
||||||
|
go t.listen(conn)
|
||||||
|
go t.runCloseQueue()
|
||||||
|
})
|
||||||
|
return t.initErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) enqueueClosePacket(p closePacket) {
|
||||||
|
select {
|
||||||
|
case t.closeQueue <- p:
|
||||||
|
default:
|
||||||
|
// Oops, we're backlogged.
|
||||||
|
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) runCloseQueue() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.listening:
|
||||||
|
return
|
||||||
|
case p := <-t.closeQueue:
|
||||||
|
t.conn.WritePacket(p.payload, p.addr, p.info.OOB())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying connection and waits until listen has returned.
|
||||||
|
// It is invalid to start new listeners or connections after that.
|
||||||
|
func (t *Transport) Close() error {
|
||||||
|
t.close(errors.New("closing"))
|
||||||
|
if t.createdConn {
|
||||||
|
if err := t.conn.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.conn.SetReadDeadline(time.Now())
|
||||||
|
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
|
||||||
|
}
|
||||||
|
<-t.listening // wait until listening returns
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) closeServer() {
|
||||||
|
t.handlerMap.CloseServer()
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.server = nil
|
||||||
|
if t.isSingleUse {
|
||||||
|
t.closed = true
|
||||||
|
}
|
||||||
|
t.mutex.Unlock()
|
||||||
|
if t.createdConn {
|
||||||
|
t.Conn.Close()
|
||||||
|
}
|
||||||
|
if t.isSingleUse {
|
||||||
|
t.conn.SetReadDeadline(time.Now())
|
||||||
|
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
|
||||||
|
<-t.listening // wait until listening returns
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) close(e error) {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
if t.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.handlerMap.Close(e)
|
||||||
|
if t.server != nil {
|
||||||
|
t.server.setCloseError(e)
|
||||||
|
}
|
||||||
|
t.closed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) listen(conn rawConn) {
|
||||||
|
defer close(t.listening)
|
||||||
|
defer getMultiplexer().RemoveConn(t.Conn)
|
||||||
|
|
||||||
|
if err := setReceiveBuffer(t.Conn, t.logger); err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
|
receiveBufferWarningOnce.Do(func() {
|
||||||
|
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
p, err := conn.ReadPacket()
|
||||||
|
//nolint:staticcheck // SA1019 ignore this!
|
||||||
|
// TODO: This code is used to ignore wsa errors on Windows.
|
||||||
|
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
|
||||||
|
// See https://github.com/quic-go/quic-go/issues/1737 for details.
|
||||||
|
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||||
|
t.mutex.Lock()
|
||||||
|
closed := t.closed
|
||||||
|
t.mutex.Unlock()
|
||||||
|
if closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.logger.Debugf("Temporary error reading from conn: %w", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.close(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.handlePacket(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) handlePacket(p *receivedPacket) {
|
||||||
|
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
|
||||||
|
if err != nil {
|
||||||
|
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
||||||
|
if t.Tracer != nil {
|
||||||
|
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||||
|
}
|
||||||
|
p.buffer.MaybeRelease()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if handler, ok := t.handlerMap.Get(connID); ok {
|
||||||
|
handler.handlePacket(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !wire.IsLongHeaderPacket(p.data[0]) {
|
||||||
|
go t.maybeSendStatelessReset(p, connID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
if t.server == nil { // no server set
|
||||||
|
t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.server.handlePacket(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
|
||||||
|
defer p.buffer.Release()
|
||||||
|
if t.StatelessResetKey == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Don't send a stateless reset in response to very small packets.
|
||||||
|
// This includes packets that could be stateless resets.
|
||||||
|
if len(p.data) <= protocol.MinStatelessResetSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token := t.handlerMap.GetStatelessResetToken(connID)
|
||||||
|
t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
||||||
|
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
||||||
|
rand.Read(data)
|
||||||
|
data[0] = (data[0] & 0x7f) | 0x40
|
||||||
|
data = append(data, token[:]...)
|
||||||
|
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
|
||||||
|
t.logger.Debugf("Error sending Stateless Reset: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
|
||||||
|
// stateless resets are always short header packets
|
||||||
|
if wire.IsLongHeaderPacket(data[0]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
|
||||||
|
if conn, ok := t.handlerMap.GetByResetToken(token); ok {
|
||||||
|
t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
|
||||||
|
go conn.destroy(&StatelessResetError{Token: token})
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
287
transport_test.go
Normal file
287
transport_test.go
Normal file
|
@ -0,0 +1,287 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||||
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
|
"github.com/quic-go/quic-go/internal/wire"
|
||||||
|
"github.com/quic-go/quic-go/logging"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Transport", func() {
|
||||||
|
type packetToRead struct {
|
||||||
|
addr net.Addr
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
|
||||||
|
b, err := (&wire.ExtendedHeader{
|
||||||
|
Header: wire.Header{
|
||||||
|
Type: t,
|
||||||
|
DestConnectionID: connID,
|
||||||
|
Length: length,
|
||||||
|
Version: protocol.Version1,
|
||||||
|
},
|
||||||
|
PacketNumberLen: protocol.PacketNumberLen2,
|
||||||
|
}).Append(nil, protocol.Version1)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
getPacket := func(connID protocol.ConnectionID) []byte {
|
||||||
|
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn {
|
||||||
|
conn := NewMockPacketConn(mockCtrl)
|
||||||
|
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||||
|
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
|
||||||
|
p, ok := <-packetChan
|
||||||
|
if !ok {
|
||||||
|
return 0, nil, errors.New("closed")
|
||||||
|
}
|
||||||
|
return copy(b, p.data), p.addr, p.err
|
||||||
|
}).AnyTimes()
|
||||||
|
// for shutdown
|
||||||
|
conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
It("handles packets for different packet handlers on the same packet conn", func() {
|
||||||
|
packetChan := make(chan packetToRead)
|
||||||
|
tr := &Transport{Conn: newMockPacketConn(packetChan)}
|
||||||
|
tr.init(&Config{})
|
||||||
|
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
tr.handlerMap = phm
|
||||||
|
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||||
|
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
|
||||||
|
|
||||||
|
handled := make(chan struct{}, 2)
|
||||||
|
phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
|
||||||
|
h := NewMockPacketHandler(mockCtrl)
|
||||||
|
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(connID).To(Equal(connID1))
|
||||||
|
handled <- struct{}{}
|
||||||
|
})
|
||||||
|
return h, true
|
||||||
|
})
|
||||||
|
phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
|
||||||
|
h := NewMockPacketHandler(mockCtrl)
|
||||||
|
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(connID).To(Equal(connID2))
|
||||||
|
handled <- struct{}{}
|
||||||
|
})
|
||||||
|
return h, true
|
||||||
|
})
|
||||||
|
|
||||||
|
packetChan <- packetToRead{data: getPacket(connID1)}
|
||||||
|
packetChan <- packetToRead{data: getPacket(connID2)}
|
||||||
|
|
||||||
|
Eventually(handled).Should(Receive())
|
||||||
|
Eventually(handled).Should(Receive())
|
||||||
|
|
||||||
|
// shutdown
|
||||||
|
phm.EXPECT().Close(gomock.Any())
|
||||||
|
close(packetChan)
|
||||||
|
tr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("closes listeners", func() {
|
||||||
|
packetChan := make(chan packetToRead)
|
||||||
|
tr := &Transport{Conn: newMockPacketConn(packetChan)}
|
||||||
|
defer tr.Close()
|
||||||
|
ln, err := tr.Listen(&tls.Config{}, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
tr.handlerMap = phm
|
||||||
|
|
||||||
|
phm.EXPECT().CloseServer()
|
||||||
|
Expect(ln.Close()).To(Succeed())
|
||||||
|
|
||||||
|
// shutdown
|
||||||
|
phm.EXPECT().Close(gomock.Any())
|
||||||
|
close(packetChan)
|
||||||
|
tr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("drops unparseable packets", func() {
|
||||||
|
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||||
|
packetChan := make(chan packetToRead)
|
||||||
|
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||||
|
tr := &Transport{
|
||||||
|
Conn: newMockPacketConn(packetChan),
|
||||||
|
}
|
||||||
|
tr.init(&Config{Tracer: tracer, ConnectionIDLength: 10})
|
||||||
|
dropped := make(chan struct{})
|
||||||
|
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
|
||||||
|
packetChan <- packetToRead{
|
||||||
|
addr: addr,
|
||||||
|
data: []byte{0, 1, 2, 3},
|
||||||
|
}
|
||||||
|
Eventually(dropped).Should(BeClosed())
|
||||||
|
|
||||||
|
// shutdown
|
||||||
|
close(packetChan)
|
||||||
|
tr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("closes when reading from the conn fails", func() {
|
||||||
|
packetChan := make(chan packetToRead)
|
||||||
|
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||||
|
defer tr.Close()
|
||||||
|
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
tr.init(&Config{})
|
||||||
|
tr.handlerMap = phm
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
|
||||||
|
packetChan <- packetToRead{err: errors.New("read failed")}
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
|
||||||
|
// shutdown
|
||||||
|
close(packetChan)
|
||||||
|
tr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("continues listening after temporary errors", func() {
|
||||||
|
packetChan := make(chan packetToRead)
|
||||||
|
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||||
|
defer tr.Close()
|
||||||
|
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
tr.init(&Config{})
|
||||||
|
tr.handlerMap = phm
|
||||||
|
|
||||||
|
tempErr := deadlineError{}
|
||||||
|
Expect(tempErr.Temporary()).To(BeTrue())
|
||||||
|
packetChan <- packetToRead{err: tempErr}
|
||||||
|
// don't expect any calls to phm.Close
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// shutdown
|
||||||
|
phm.EXPECT().Close(gomock.Any())
|
||||||
|
close(packetChan)
|
||||||
|
tr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("handles short header packets resets", func() {
|
||||||
|
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||||
|
packetChan := make(chan packetToRead)
|
||||||
|
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||||
|
tr.init(&Config{ConnectionIDLength: connID.Len()})
|
||||||
|
defer tr.Close()
|
||||||
|
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
tr.init(&Config{})
|
||||||
|
tr.handlerMap = phm
|
||||||
|
|
||||||
|
var token protocol.StatelessResetToken
|
||||||
|
rand.Read(token[:])
|
||||||
|
|
||||||
|
var b []byte
|
||||||
|
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
b = append(b, token[:]...)
|
||||||
|
conn := NewMockPacketHandler(mockCtrl)
|
||||||
|
gomock.InOrder(
|
||||||
|
phm.EXPECT().GetByResetToken(token),
|
||||||
|
phm.EXPECT().Get(connID).Return(conn, true),
|
||||||
|
conn.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||||
|
Expect(p.data).To(Equal(b))
|
||||||
|
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
packetChan <- packetToRead{data: b}
|
||||||
|
|
||||||
|
// shutdown
|
||||||
|
phm.EXPECT().Close(gomock.Any())
|
||||||
|
close(packetChan)
|
||||||
|
tr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("handles stateless resets", func() {
|
||||||
|
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||||
|
packetChan := make(chan packetToRead)
|
||||||
|
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||||
|
tr.init(&Config{ConnectionIDLength: connID.Len()})
|
||||||
|
defer tr.Close()
|
||||||
|
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
tr.init(&Config{})
|
||||||
|
tr.handlerMap = phm
|
||||||
|
|
||||||
|
var token protocol.StatelessResetToken
|
||||||
|
rand.Read(token[:])
|
||||||
|
|
||||||
|
var b []byte
|
||||||
|
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
b = append(b, token[:]...)
|
||||||
|
conn := NewMockPacketHandler(mockCtrl)
|
||||||
|
gomock.InOrder(
|
||||||
|
phm.EXPECT().GetByResetToken(token).Return(conn, true),
|
||||||
|
conn.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||||
|
Expect(err).To(MatchError(&StatelessResetError{Token: token}))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
packetChan <- packetToRead{data: b}
|
||||||
|
|
||||||
|
// shutdown
|
||||||
|
phm.EXPECT().Close(gomock.Any())
|
||||||
|
close(packetChan)
|
||||||
|
tr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sends stateless resets", func() {
|
||||||
|
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||||
|
packetChan := make(chan packetToRead)
|
||||||
|
conn := newMockPacketConn(packetChan)
|
||||||
|
tr := Transport{
|
||||||
|
Conn: conn,
|
||||||
|
}
|
||||||
|
tr.init(&Config{ConnectionIDLength: connID.Len(), StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}})
|
||||||
|
defer tr.Close()
|
||||||
|
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
tr.init(&Config{})
|
||||||
|
tr.handlerMap = phm
|
||||||
|
|
||||||
|
var b []byte
|
||||||
|
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...)
|
||||||
|
|
||||||
|
var token protocol.StatelessResetToken
|
||||||
|
rand.Read(token[:])
|
||||||
|
written := make(chan struct{})
|
||||||
|
gomock.InOrder(
|
||||||
|
phm.EXPECT().GetByResetToken(gomock.Any()),
|
||||||
|
phm.EXPECT().Get(connID),
|
||||||
|
phm.EXPECT().GetStatelessResetToken(connID).Return(token),
|
||||||
|
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) {
|
||||||
|
defer close(written)
|
||||||
|
Expect(bytes.Contains(b, token[:])).To(BeTrue())
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
packetChan <- packetToRead{data: b}
|
||||||
|
Eventually(written).Should(BeClosed())
|
||||||
|
|
||||||
|
// shutdown
|
||||||
|
phm.EXPECT().Close(gomock.Any())
|
||||||
|
close(packetChan)
|
||||||
|
tr.Close()
|
||||||
|
})
|
||||||
|
})
|
Loading…
Add table
Add a link
Reference in a new issue