mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-03-31 10:47:35 +03:00
implement the Transport
This commit is contained in:
parent
ae5a8bd35c
commit
8189e75be6
31 changed files with 1309 additions and 1250 deletions
158
client.go
158
client.go
|
@ -20,6 +20,7 @@ type client struct {
|
|||
use0RTT bool
|
||||
|
||||
packetHandlers packetHandlerManager
|
||||
onClose func()
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
@ -45,32 +46,58 @@ var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
|||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Connection, error) {
|
||||
return dialAddrContext(ctx, addr, tlsConf, config, false)
|
||||
}
|
||||
|
||||
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
|
||||
conn, err := dialAddrContext(ctx, addr, tlsConf, config, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection")
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config, use0RTT bool) (quicConn, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialContext(ctx, udpConn, udpAddr, tlsConf, config, use0RTT, true)
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, err := setupTransport(udpConn, tlsConf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dl.Dial(ctx, udpAddr, tlsConf, conf)
|
||||
}
|
||||
|
||||
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, err := setupTransport(udpConn, tlsConf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
|
||||
// See DialEarly for details.
|
||||
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
dl, err := setupTransport(c, tlsConf, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
|
||||
|
@ -78,34 +105,43 @@ func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, conf
|
|||
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
|
||||
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
|
||||
// packets.
|
||||
// The same PacketConn can be used for multiple calls to Dial and Listen.
|
||||
// QUIC connection IDs are used for demultiplexing the different connections.
|
||||
// The tls.Config must define an application protocol (using NextProtos).
|
||||
func Dial(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) {
|
||||
return dialContext(ctx, pconn, addr, tlsConf, config, false, false)
|
||||
}
|
||||
|
||||
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
|
||||
// The same PacketConn can be used for multiple calls to Dial and Listen,
|
||||
// QUIC connection IDs are used for demultiplexing the different connections.
|
||||
// The tls.Config must define an application protocol (using NextProtos).
|
||||
func DialEarly(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
|
||||
return dialContext(ctx, pconn, addr, tlsConf, config, true, false)
|
||||
}
|
||||
|
||||
func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config, use0RTT bool, createdPacketConn bool) (quicConn, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
if err := validateConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
|
||||
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
dl, err := setupTransport(c, tlsConf, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(pconn, addr, config, tlsConf, use0RTT, createdPacketConn)
|
||||
conn, err := dl.Dial(ctx, addr, tlsConf, conf)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
return &Transport{
|
||||
Conn: c,
|
||||
createdConn: createdPacketConn,
|
||||
isSingleUse: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func dial(
|
||||
ctx context.Context,
|
||||
conn net.PacketConn,
|
||||
packetHandlers packetHandlerManager,
|
||||
addr net.Addr,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
onClose func(),
|
||||
use0RTT bool,
|
||||
createdPacketConn bool,
|
||||
) (quicConn, error) {
|
||||
c, err := newClient(conn, addr, config, tlsConf, onClose, use0RTT, createdPacketConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -128,7 +164,7 @@ func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsCo
|
|||
return c.conn, nil
|
||||
}
|
||||
|
||||
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, use0RTT bool, createdPacketConn bool) (*client, error) {
|
||||
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, onClose func(), use0RTT, createdPacketConn bool) (*client, error) {
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
} else {
|
||||
|
@ -149,6 +185,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
|
|||
sconn: newSendPconn(pconn, remoteAddr),
|
||||
createdPacketConn: createdPacketConn,
|
||||
use0RTT: use0RTT,
|
||||
onClose: onClose,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
|
@ -179,13 +216,18 @@ func (c *client) dial(ctx context.Context) error {
|
|||
c.packetHandlers.Add(c.srcConnID, c.conn)
|
||||
|
||||
errorChan := make(chan error, 1)
|
||||
recreateChan := make(chan errCloseForRecreating)
|
||||
go func() {
|
||||
err := c.conn.run() // returns as soon as the connection is closed
|
||||
|
||||
if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
|
||||
c.packetHandlers.Destroy()
|
||||
err := c.conn.run()
|
||||
var recreateErr *errCloseForRecreating
|
||||
if errors.As(err, &recreateErr) {
|
||||
recreateChan <- *recreateErr
|
||||
return
|
||||
}
|
||||
errorChan <- err
|
||||
if c.onClose != nil {
|
||||
c.onClose()
|
||||
}
|
||||
errorChan <- err // returns as soon as the connection is closed
|
||||
}()
|
||||
|
||||
// only set when we're using 0-RTT
|
||||
|
@ -200,14 +242,12 @@ func (c *client) dial(ctx context.Context) error {
|
|||
c.conn.shutdown()
|
||||
return ctx.Err()
|
||||
case err := <-errorChan:
|
||||
var recreateErr *errCloseForRecreating
|
||||
if errors.As(err, &recreateErr) {
|
||||
c.initialPacketNumber = recreateErr.nextPacketNumber
|
||||
c.version = recreateErr.nextVersion
|
||||
c.hasNegotiatedVersion = true
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return err
|
||||
case recreateErr := <-recreateChan:
|
||||
c.initialPacketNumber = recreateErr.nextPacketNumber
|
||||
c.version = recreateErr.nextVersion
|
||||
c.hasNegotiatedVersion = true
|
||||
return c.dial(ctx)
|
||||
case <-earlyConnChan:
|
||||
// ready to send 0-RTT data
|
||||
return nil
|
||||
|
|
236
client_test.go
236
client_test.go
|
@ -18,13 +18,17 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type nullMultiplexer struct{}
|
||||
|
||||
func (n nullMultiplexer) AddConn(indexableConn) {}
|
||||
func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil }
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
var (
|
||||
cl *client
|
||||
packetConn *MockPacketConn
|
||||
addr net.Addr
|
||||
connID protocol.ConnectionID
|
||||
mockMultiplexer *MockMultiplexer
|
||||
origMultiplexer multiplexer
|
||||
tlsConf *tls.Config
|
||||
tracer *mocklogging.MockConnectionTracer
|
||||
|
@ -53,6 +57,7 @@ var _ = Describe("Client", func() {
|
|||
originalClientConnConstructor = newClientConnection
|
||||
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
||||
tr := mocklogging.NewMockTracer(mockCtrl)
|
||||
tr.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
|
||||
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}}
|
||||
Eventually(areConnsRunning).Should(BeFalse())
|
||||
|
@ -68,10 +73,9 @@ var _ = Describe("Client", func() {
|
|||
logger: utils.DefaultLogger,
|
||||
}
|
||||
getMultiplexer() // make the sync.Once execute
|
||||
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
|
||||
mockMultiplexer = NewMockMultiplexer(mockCtrl)
|
||||
// replace the clientMuxer. getMultiplexer will now return the nullMultiplexer
|
||||
origMultiplexer = connMuxer
|
||||
connMuxer = mockMultiplexer
|
||||
connMuxer = &nullMultiplexer{}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
@ -100,48 +104,14 @@ var _ = Describe("Client", func() {
|
|||
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
|
||||
})
|
||||
|
||||
It("resolves the address", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
manager.EXPECT().Destroy()
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
remoteAddrChan := make(chan string, 1)
|
||||
newClientConnection = func(
|
||||
sconn sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
remoteAddrChan <- sconn.RemoteAddr().String()
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
return conn
|
||||
}
|
||||
_, err := DialAddr(context.Background(), "localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
|
||||
})
|
||||
|
||||
It("returns after the handshake is complete", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
run := make(chan struct{})
|
||||
newClientConnection = func(
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
|
@ -162,18 +132,17 @@ var _ = Describe("Client", func() {
|
|||
conn.EXPECT().HandshakeComplete().Return(c)
|
||||
return conn
|
||||
}
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
s, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
||||
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, false, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
Expect(cl.dial(context.Background())).To(Succeed())
|
||||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns early connections", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
readyChan := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
newClientConnection = func(
|
||||
|
@ -193,29 +162,23 @@ var _ = Describe("Client", func() {
|
|||
) quicConn {
|
||||
Expect(enable0RTT).To(BeTrue())
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run().Do(func() { <-done })
|
||||
conn.EXPECT().run().Do(func() { close(done) })
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().earlyConnReady().Return(readyChan)
|
||||
return conn
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
s, err := DialEarly(context.Background(), packetConn, addr, tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
close(readyChan)
|
||||
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, true, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
Expect(cl.dial(context.Background())).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns an error that occurs while waiting for the handshake to complete", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
testErr := errors.New("early handshake error")
|
||||
newClientConnection = func(
|
||||
|
@ -236,108 +199,16 @@ var _ = Describe("Client", func() {
|
|||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run().Return(testErr)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().earlyConnReady().Return(make(chan struct{}))
|
||||
return conn
|
||||
}
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("closes the connection when the context is canceled", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
connRunning := make(chan struct{})
|
||||
defer close(connRunning)
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run().Do(func() {
|
||||
<-connRunning
|
||||
})
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
newClientConnection = func(
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
return conn
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
_, err := Dial(ctx, packetConn, addr, tlsConf, config)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
close(dialed)
|
||||
}()
|
||||
Consistently(dialed).ShouldNot(BeClosed())
|
||||
conn.EXPECT().shutdown()
|
||||
cancel()
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("closes the connection when it was created by DialAddr", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
|
||||
var sconn sendConn
|
||||
run := make(chan struct{})
|
||||
connCreated := make(chan struct{})
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
newClientConnection = func(
|
||||
connP sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
sconn = connP
|
||||
close(connCreated)
|
||||
return conn
|
||||
}
|
||||
conn.EXPECT().run().Do(func() {
|
||||
<-run
|
||||
})
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := DialAddr(context.Background(), "localhost:1337", tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Eventually(connCreated).Should(BeClosed())
|
||||
|
||||
// check that the connection is not closed
|
||||
Expect(sconn.Write([]byte("foobar"))).To(Succeed())
|
||||
|
||||
manager.EXPECT().Destroy()
|
||||
close(run)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
Eventually(done).Should(BeClosed())
|
||||
var closed bool
|
||||
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, func() { closed = true }, true, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
Expect(cl.dial(context.Background())).To(MatchError(testErr))
|
||||
Expect(closed).To(BeTrue())
|
||||
})
|
||||
|
||||
Context("quic.Config", func() {
|
||||
|
@ -365,12 +236,6 @@ var _ = Describe("Client", func() {
|
|||
Expect(c.EnableDatagrams).To(BeTrue())
|
||||
})
|
||||
|
||||
It("errors when the Config contains an invalid version", func() {
|
||||
version := protocol.VersionNumber(0x1234)
|
||||
_, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
|
||||
Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
|
||||
})
|
||||
|
||||
It("disables bidirectional streams", func() {
|
||||
config := &Config{
|
||||
MaxIncomingStreams: -1,
|
||||
|
@ -405,15 +270,12 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("creates new connections with the right parameters", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}}
|
||||
c := make(chan struct{})
|
||||
var cconn sendConn
|
||||
var version protocol.VersionNumber
|
||||
var conf *Config
|
||||
done := make(chan struct{})
|
||||
newClientConnection = func(
|
||||
connP sendConn,
|
||||
_ connRunner,
|
||||
|
@ -437,8 +299,15 @@ var _ = Describe("Client", func() {
|
|||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().destroy(gomock.Any())
|
||||
close(done)
|
||||
return conn
|
||||
}
|
||||
packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) {
|
||||
<-done
|
||||
return 0, nil, errors.New("closed")
|
||||
})
|
||||
packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
|
||||
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(c).Should(BeClosed())
|
||||
|
@ -448,17 +317,12 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("creates a new connections after version negotiation", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any()).Times(2)
|
||||
manager.EXPECT().Destroy()
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
var counter int
|
||||
newClientConnection = func(
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
connID protocol.ConnectionID,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
pn protocol.PacketNumber,
|
||||
|
@ -474,20 +338,24 @@ var _ = Describe("Client", func() {
|
|||
if counter == 0 {
|
||||
Expect(pn).To(BeZero())
|
||||
Expect(hasNegotiatedVersion).To(BeFalse())
|
||||
conn.EXPECT().run().Return(&errCloseForRecreating{
|
||||
nextPacketNumber: 109,
|
||||
nextVersion: 789,
|
||||
conn.EXPECT().run().DoAndReturn(func() error {
|
||||
runner.Remove(connID)
|
||||
return &errCloseForRecreating{
|
||||
nextPacketNumber: 109,
|
||||
nextVersion: 789,
|
||||
}
|
||||
})
|
||||
} else {
|
||||
Expect(pn).To(Equal(protocol.PacketNumber(109)))
|
||||
Expect(hasNegotiatedVersion).To(BeTrue())
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().destroy(gomock.Any())
|
||||
}
|
||||
counter++
|
||||
return conn
|
||||
}
|
||||
|
||||
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
|
||||
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}}
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -495,15 +363,3 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
type mockConnIDGenerator struct {
|
||||
ConnID protocol.ConnectionID
|
||||
}
|
||||
|
||||
func (m *mockConnIDGenerator) GenerateConnectionID() (protocol.ConnectionID, error) {
|
||||
return m.ConnID, nil
|
||||
}
|
||||
|
||||
func (m *mockConnIDGenerator) ConnectionIDLen() int {
|
||||
return m.ConnID.Len()
|
||||
}
|
||||
|
|
|
@ -660,6 +660,7 @@ var _ = Describe("Stream Cancellations", func() {
|
|||
getQuicConfig(&quic.Config{MaxIncomingStreams: maxIncomingStreams, MaxIdleTimeout: 10 * time.Second}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2 * 4 * maxIncomingStreams)
|
||||
|
|
|
@ -24,6 +24,7 @@ var _ = Describe("Connection ID lengths tests", func() {
|
|||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
var drop atomic.Bool
|
||||
dropped := make(chan []byte, 100)
|
||||
|
@ -50,6 +51,7 @@ var _ = Describe("Connection ID lengths tests", func() {
|
|||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
sconn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -35,7 +35,11 @@ var _ = Describe("Connection ID lengths tests", func() {
|
|||
randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
|
||||
|
||||
runServer := func(conf *quic.Config) *quic.Listener {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
|
||||
if conf.ConnectionIDGenerator != nil {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", conf.ConnectionIDGenerator.ConnectionIDLen())))
|
||||
} else {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
|
||||
}
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
|
@ -59,7 +63,11 @@ var _ = Describe("Connection ID lengths tests", func() {
|
|||
}
|
||||
|
||||
runClient := func(addr net.Addr, conf *quic.Config) {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
|
||||
if conf.ConnectionIDGenerator != nil {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", conf.ConnectionIDGenerator.ConnectionIDLen())))
|
||||
} else {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
|
||||
}
|
||||
cl, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
|
||||
|
|
|
@ -22,12 +22,11 @@ var _ = Describe("Datagram test", func() {
|
|||
const num = 100
|
||||
|
||||
var (
|
||||
proxy *quicproxy.QuicProxy
|
||||
serverConn, clientConn *net.UDPConn
|
||||
dropped, total int32
|
||||
)
|
||||
|
||||
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) {
|
||||
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConn, err = net.ListenUDP("udp", addr)
|
||||
|
@ -39,8 +38,10 @@ var _ = Describe("Datagram test", func() {
|
|||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
accepted := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(accepted)
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
|
@ -67,7 +68,7 @@ var _ = Describe("Datagram test", func() {
|
|||
}()
|
||||
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
// drop 10% of Short Header packets sent from the server
|
||||
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
|
||||
|
@ -87,6 +88,11 @@ var _ = Describe("Datagram test", func() {
|
|||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return proxy.LocalPort(), func() {
|
||||
Eventually(accepted).Should(BeClosed())
|
||||
proxy.Close()
|
||||
ln.Close()
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
|
@ -96,13 +102,10 @@ var _ = Describe("Datagram test", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("sends datagrams", func() {
|
||||
startServerAndProxy(true, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
|
||||
proxyPort, close := startServerAndProxy(true, true)
|
||||
defer close()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
|
@ -117,6 +120,7 @@ var _ = Describe("Datagram test", func() {
|
|||
for {
|
||||
// Close the connection if no message is received for 100 ms.
|
||||
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() {
|
||||
fmt.Println("closing conn")
|
||||
conn.CloseWithError(0, "")
|
||||
})
|
||||
if _, err := conn.ReceiveMessage(); err != nil {
|
||||
|
@ -134,11 +138,12 @@ var _ = Describe("Datagram test", func() {
|
|||
BeNumerically(">", expVal*9/10),
|
||||
BeNumerically("<", num),
|
||||
))
|
||||
Eventually(conn.Context().Done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("server can disable datagram", func() {
|
||||
startServerAndProxy(false, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
|
||||
proxyPort, close := startServerAndProxy(false, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
|
@ -150,13 +155,13 @@ var _ = Describe("Datagram test", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
||||
|
||||
close()
|
||||
conn.CloseWithError(0, "")
|
||||
<-time.After(10 * time.Millisecond)
|
||||
})
|
||||
|
||||
It("client can disable datagram", func() {
|
||||
startServerAndProxy(false, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
|
||||
proxyPort, close := startServerAndProxy(false, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
|
@ -169,7 +174,8 @@ var _ = Describe("Datagram test", func() {
|
|||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
||||
|
||||
Expect(conn.SendMessage([]byte{0})).To(HaveOccurred())
|
||||
|
||||
close()
|
||||
conn.CloseWithError(0, "")
|
||||
<-time.After(10 * time.Millisecond)
|
||||
})
|
||||
})
|
||||
|
|
|
@ -24,6 +24,7 @@ var _ = Describe("early data", func() {
|
|||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
|
|
@ -8,10 +8,9 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
go120 = false
|
||||
errNotSupported = errors.New("not supported")
|
||||
)
|
||||
const go120 = false
|
||||
|
||||
var errNotSupported = errors.New("not supported")
|
||||
|
||||
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
|
||||
return errNotSupported
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
var go120 = true
|
||||
const go120 = true
|
||||
|
||||
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
|
||||
rc := http.NewResponseController(w)
|
||||
|
|
|
@ -62,13 +62,14 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
_, err = quic.DialAddr(
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
expectDurationInRTTs(startTime, 2)
|
||||
})
|
||||
|
||||
|
@ -79,13 +80,14 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
_, err = quic.DialAddr(
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
expectDurationInRTTs(startTime, 1)
|
||||
})
|
||||
|
||||
|
@ -97,13 +99,14 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
_, err = quic.DialAddr(
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
expectDurationInRTTs(startTime, 2)
|
||||
})
|
||||
|
||||
|
@ -131,6 +134,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
|
@ -166,6 +170,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
|
|
|
@ -114,7 +114,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
nil,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
|
@ -223,13 +223,14 @@ var _ = Describe("Handshake tests", func() {
|
|||
var (
|
||||
server *quic.Listener
|
||||
pconn net.PacketConn
|
||||
dialer *quic.Transport
|
||||
)
|
||||
|
||||
dial := func() (quic.Connection, error) {
|
||||
remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
raddr, err := net.ResolveUDPAddr("udp", remoteAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return quic.Dial(context.Background(), pconn, raddr, getTLSClientConfig(), nil)
|
||||
return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil))
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
|
@ -243,11 +244,13 @@ var _ = Describe("Handshake tests", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
pconn, err = net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
dialer = &quic.Transport{Conn: pconn}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(server.Close()).To(Succeed())
|
||||
Expect(pconn.Close()).To(Succeed())
|
||||
Expect(dialer.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects new connection attempts if connections don't get accepted", func() {
|
||||
|
@ -366,6 +369,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
It("uses tokens provided in NEW_TOKEN frames", func() {
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
// dial the first connection and receive the token
|
||||
go func() {
|
||||
|
|
|
@ -382,6 +382,7 @@ var _ = Describe("HTTP tests", func() {
|
|||
tlsConf.NextProtos = []string{"h3"}
|
||||
ln, err := quic.ListenAddr("localhost:0", tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -398,57 +399,51 @@ var _ = Describe("HTTP tests", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("supports read deadlines", func() {
|
||||
if !go120 {
|
||||
Skip("This test requires Go 1.20+")
|
||||
}
|
||||
if go120 {
|
||||
It("supports read deadlines", func() {
|
||||
mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
err := setReadDeadline(w, time.Now().Add(deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
err := setReadDeadline(w, time.Now().Add(deadlineDelay))
|
||||
body, err := io.ReadAll(r.Body)
|
||||
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
|
||||
Expect(body).To(ContainSubstring("aa"))
|
||||
|
||||
w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
expectedEnd := time.Now().Add(deadlineDelay)
|
||||
resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a'))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
|
||||
Expect(body).To(ContainSubstring("aa"))
|
||||
|
||||
w.Write([]byte("ok"))
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(time.Now().After(expectedEnd)).To(BeTrue())
|
||||
Expect(string(body)).To(Equal("ok"))
|
||||
})
|
||||
|
||||
expectedEnd := time.Now().Add(deadlineDelay)
|
||||
resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a'))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
It("supports write deadlines", func() {
|
||||
mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
err := setWriteDeadline(w, time.Now().Add(deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(time.Now().After(expectedEnd)).To(BeTrue())
|
||||
Expect(string(body)).To(Equal("ok"))
|
||||
})
|
||||
_, err = io.Copy(w, neverEnding('a'))
|
||||
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
|
||||
})
|
||||
|
||||
It("supports write deadlines", func() {
|
||||
if !go120 {
|
||||
Skip("This test requires Go 1.20+")
|
||||
}
|
||||
expectedEnd := time.Now().Add(deadlineDelay)
|
||||
|
||||
mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
err := setWriteDeadline(w, time.Now().Add(deadlineDelay))
|
||||
resp, err := client.Get("https://localhost:" + port + "/write-deadline")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
_, err = io.Copy(w, neverEnding('a'))
|
||||
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(time.Now().After(expectedEnd)).To(BeTrue())
|
||||
Expect(string(body)).To(ContainSubstring("aa"))
|
||||
})
|
||||
|
||||
expectedEnd := time.Now().Add(deadlineDelay)
|
||||
|
||||
resp, err := client.Get("https://localhost:" + port + "/write-deadline")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(time.Now().After(expectedEnd)).To(BeTrue())
|
||||
Expect(string(body)).To(ContainSubstring("aa"))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
|
|
@ -34,10 +34,9 @@ var _ = Describe("Multiplexing", func() {
|
|||
}()
|
||||
}
|
||||
|
||||
dial := func(pconn net.PacketConn, addr net.Addr) {
|
||||
conn, err := quic.Dial(
|
||||
dial := func(tr *quic.Transport, addr net.Addr) {
|
||||
conn, err := tr.Dial(
|
||||
context.Background(),
|
||||
pconn,
|
||||
addr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
|
@ -72,17 +71,18 @@ var _ = Describe("Multiplexing", func() {
|
|||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server.Addr())
|
||||
dial(tr, server.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server.Addr())
|
||||
dial(tr, server.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
|
@ -106,17 +106,18 @@ var _ = Describe("Multiplexing", func() {
|
|||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server1.Addr())
|
||||
dial(tr, server1.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server2.Addr())
|
||||
dial(tr, server2.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
|
@ -135,9 +136,9 @@ var _ = Describe("Multiplexing", func() {
|
|||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
|
||||
server, err := quic.Listen(
|
||||
conn,
|
||||
server, err := tr.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
|
@ -146,7 +147,7 @@ var _ = Describe("Multiplexing", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server.Addr())
|
||||
dial(tr, server.Addr())
|
||||
close(done)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
|
@ -165,15 +166,16 @@ var _ = Describe("Multiplexing", func() {
|
|||
conn1, err := net.ListenUDP("udp", addr1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn1.Close()
|
||||
tr1 := &quic.Transport{Conn: conn1}
|
||||
|
||||
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn2, err := net.ListenUDP("udp", addr2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn2.Close()
|
||||
tr2 := &quic.Transport{Conn: conn2}
|
||||
|
||||
server1, err := quic.Listen(
|
||||
conn1,
|
||||
server1, err := tr1.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
|
@ -181,8 +183,7 @@ var _ = Describe("Multiplexing", func() {
|
|||
runServer(server1)
|
||||
defer server1.Close()
|
||||
|
||||
server2, err := quic.Listen(
|
||||
conn2,
|
||||
server2, err := tr2.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
|
@ -194,12 +195,12 @@ var _ = Describe("Multiplexing", func() {
|
|||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn2, server1.Addr())
|
||||
dial(tr2, server1.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn1, server2.Addr())
|
||||
dial(tr1, server2.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
|
|
|
@ -31,8 +31,8 @@ var _ = Describe("Packetization", func() {
|
|||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
defer server.Close()
|
||||
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: serverAddr,
|
||||
|
@ -54,6 +54,7 @@ var _ = Describe("Packetization", func() {
|
|||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
|
|
@ -199,8 +199,16 @@ func areHandshakesRunning() bool {
|
|||
return strings.Contains(b.String(), "RunHandshake")
|
||||
}
|
||||
|
||||
func areTransportsRunning() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
|
||||
}
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
Expect(areHandshakesRunning()).To(BeFalse())
|
||||
Eventually(areTransportsRunning).Should(BeFalse())
|
||||
|
||||
if debugLog() {
|
||||
logFile, err := os.Create(logFileName)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -2,7 +2,6 @@ package self_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
@ -27,7 +26,13 @@ var _ = Describe("Stateless Resets", func() {
|
|||
rand.Read(statelessResetKey[:])
|
||||
serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: &statelessResetKey})
|
||||
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
c, err := net.ListenUDP("udp", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{
|
||||
Conn: c,
|
||||
}
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
|
||||
|
@ -42,7 +47,8 @@ var _ = Describe("Stateless Resets", func() {
|
|||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
<-closeServer
|
||||
ln.Close()
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
Expect(tr.Close()).To(Succeed())
|
||||
}()
|
||||
|
||||
var drop atomic.Bool
|
||||
|
@ -77,11 +83,14 @@ var _ = Describe("Stateless Resets", func() {
|
|||
close(closeServer)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ln2, err := quic.ListenAddr(
|
||||
fmt.Sprintf("localhost:%d", serverPort),
|
||||
getTLSConfig(),
|
||||
serverConfig,
|
||||
)
|
||||
// We need to create a new Transport here, since the old one is still sending out
|
||||
// CONNECTION_CLOSE packets for (recently) closed connections).
|
||||
tr2 := &quic.Transport{
|
||||
Conn: c,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
}
|
||||
defer tr2.Close()
|
||||
ln2, err := tr2.Listen(getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
drop.Store(false)
|
||||
|
||||
|
@ -100,8 +109,7 @@ var _ = Describe("Stateless Resets", func() {
|
|||
_, serr = str.Read([]byte{0})
|
||||
}
|
||||
Expect(serr).To(HaveOccurred())
|
||||
statelessResetErr := &quic.StatelessResetError{}
|
||||
Expect(errors.As(serr, &statelessResetErr)).To(BeTrue())
|
||||
Expect(serr).To(BeAssignableToTypeOf(&quic.StatelessResetError{}))
|
||||
Expect(ln2.Close()).To(Succeed())
|
||||
Eventually(acceptStopped).Should(BeClosed())
|
||||
})
|
||||
|
|
|
@ -94,6 +94,8 @@ var _ = Describe("Bidirectional streams", func() {
|
|||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(client)
|
||||
client.CloseWithError(0, "")
|
||||
<-conn.Context().Done()
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||
|
@ -149,5 +151,6 @@ var _ = Describe("Bidirectional streams", func() {
|
|||
runReceivingPeer(client)
|
||||
<-done1
|
||||
<-done2
|
||||
client.CloseWithError(0, "")
|
||||
})
|
||||
})
|
||||
|
|
|
@ -473,6 +473,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
|
|
|
@ -88,11 +88,14 @@ var _ = Describe("Unidirectional Streams", func() {
|
|||
})
|
||||
|
||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(conn)
|
||||
<-conn.Context().Done()
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
|
@ -103,6 +106,7 @@ var _ = Describe("Unidirectional Streams", func() {
|
|||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(client)
|
||||
client.CloseWithError(0, "")
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/quic-go/quic-go (interfaces: Multiplexer)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
logging "github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// MockMultiplexer is a mock of Multiplexer interface.
|
||||
type MockMultiplexer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMultiplexerMockRecorder
|
||||
}
|
||||
|
||||
// MockMultiplexerMockRecorder is the mock recorder for MockMultiplexer.
|
||||
type MockMultiplexerMockRecorder struct {
|
||||
mock *MockMultiplexer
|
||||
}
|
||||
|
||||
// NewMockMultiplexer creates a new mock instance.
|
||||
func NewMockMultiplexer(ctrl *gomock.Controller) *MockMultiplexer {
|
||||
mock := &MockMultiplexer{ctrl: ctrl}
|
||||
mock.recorder = &MockMultiplexerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddConn mocks base method.
|
||||
func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 *StatelessResetKey, arg3 logging.Tracer) (packetHandlerManager, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(packetHandlerManager)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AddConn indicates an expected call of AddConn.
|
||||
func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// RemoveConn mocks base method.
|
||||
func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveConn", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemoveConn indicates an expected call of RemoveConn.
|
||||
func (mr *MockMultiplexerMockRecorder) RemoveConn(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveConn", reflect.TypeOf((*MockMultiplexer)(nil).RemoveConn), arg0)
|
||||
}
|
|
@ -74,6 +74,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 i
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockPacketHandlerManager) Close(arg0 error) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Close", arg0)
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0)
|
||||
}
|
||||
|
||||
// CloseServer mocks base method.
|
||||
func (m *MockPacketHandlerManager) CloseServer() {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -86,20 +98,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
||||
}
|
||||
|
||||
// Destroy mocks base method.
|
||||
func (m *MockPacketHandlerManager) Destroy() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Destroy")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Destroy indicates an expected call of Destroy.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy))
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -115,6 +113,21 @@ func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Ca
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
|
||||
}
|
||||
|
||||
// GetByResetToken mocks base method.
|
||||
func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetToken) (packetHandler, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetByResetToken", arg0)
|
||||
ret0, _ := ret[0].(packetHandler)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetByResetToken indicates an expected call of GetByResetToken.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0)
|
||||
}
|
||||
|
||||
// GetStatelessResetToken mocks base method.
|
||||
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -176,15 +189,3 @@ func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock
|
|||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0)
|
||||
}
|
||||
|
||||
// SetServer mocks base method.
|
||||
func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetServer", arg0)
|
||||
}
|
||||
|
||||
// SetServer indicates an expected call of SetServer.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0)
|
||||
}
|
||||
|
|
|
@ -65,9 +65,6 @@ type UnknownPacketHandler = unknownPacketHandler
|
|||
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
|
||||
type PacketHandlerManager = packetHandlerManager
|
||||
|
||||
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_multiplexer_test.go github.com/quic-go/quic-go Multiplexer"
|
||||
type Multiplexer = multiplexer
|
||||
|
||||
// Need to use source mode for the batchConn, since reflect mode follows type aliases.
|
||||
// See https://github.com/golang/mock/issues/244 for details.
|
||||
//
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -14,30 +13,19 @@ var (
|
|||
connMuxer multiplexer
|
||||
)
|
||||
|
||||
type indexableConn interface {
|
||||
LocalAddr() net.Addr
|
||||
}
|
||||
type indexableConn interface{ LocalAddr() net.Addr }
|
||||
|
||||
type multiplexer interface {
|
||||
AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error)
|
||||
AddConn(conn indexableConn)
|
||||
RemoveConn(indexableConn) error
|
||||
}
|
||||
|
||||
type connManager struct {
|
||||
connIDLen int
|
||||
statelessResetKey *StatelessResetKey
|
||||
tracer logging.Tracer
|
||||
manager packetHandlerManager
|
||||
}
|
||||
|
||||
// The connMultiplexer listens on multiple net.PacketConns and dispatches
|
||||
// incoming packets to the connection handler.
|
||||
type connMultiplexer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conns map[string] /* LocalAddr().String() */ connManager
|
||||
newPacketHandlerManager func(net.PacketConn, int, *StatelessResetKey, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests
|
||||
|
||||
conns map[string] /* LocalAddr().String() */ indexableConn
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
|
@ -46,57 +34,38 @@ var _ multiplexer = &connMultiplexer{}
|
|||
func getMultiplexer() multiplexer {
|
||||
connMuxerOnce.Do(func() {
|
||||
connMuxer = &connMultiplexer{
|
||||
conns: make(map[string]connManager),
|
||||
logger: utils.DefaultLogger.WithPrefix("muxer"),
|
||||
newPacketHandlerManager: newPacketHandlerMap,
|
||||
conns: make(map[string]indexableConn),
|
||||
logger: utils.DefaultLogger.WithPrefix("muxer"),
|
||||
}
|
||||
})
|
||||
return connMuxer
|
||||
}
|
||||
|
||||
func (m *connMultiplexer) AddConn(
|
||||
c net.PacketConn,
|
||||
connIDLen int,
|
||||
statelessResetKey *StatelessResetKey,
|
||||
tracer logging.Tracer,
|
||||
) (packetHandlerManager, error) {
|
||||
func (m *connMultiplexer) index(addr net.Addr) string {
|
||||
return addr.Network() + " " + addr.String()
|
||||
}
|
||||
|
||||
func (m *connMultiplexer) AddConn(c indexableConn) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
addr := c.LocalAddr()
|
||||
connIndex := addr.Network() + " " + addr.String()
|
||||
connIndex := m.index(c.LocalAddr())
|
||||
p, ok := m.conns[connIndex]
|
||||
if !ok {
|
||||
manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p = connManager{
|
||||
connIDLen: connIDLen,
|
||||
statelessResetKey: statelessResetKey,
|
||||
manager: manager,
|
||||
tracer: tracer,
|
||||
}
|
||||
m.conns[connIndex] = p
|
||||
} else {
|
||||
if p.connIDLen != connIDLen {
|
||||
return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
|
||||
}
|
||||
if statelessResetKey != nil && p.statelessResetKey != statelessResetKey {
|
||||
return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn")
|
||||
}
|
||||
if tracer != p.tracer {
|
||||
return nil, fmt.Errorf("cannot use different tracers on the same packet conn")
|
||||
}
|
||||
if ok {
|
||||
// Panics if we're already listening on this connection.
|
||||
// This is a safeguard because we're introducing a breaking API change, see
|
||||
// https://github.com/quic-go/quic-go/issues/3727 for details.
|
||||
// We'll remove this at a later time, when most users of the library have made the switch.
|
||||
panic("connection already exists") // TODO: write a nice message
|
||||
}
|
||||
return p.manager, nil
|
||||
m.conns[connIndex] = p
|
||||
}
|
||||
|
||||
func (m *connMultiplexer) RemoveConn(c indexableConn) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String()
|
||||
connIndex := m.index(c.LocalAddr())
|
||||
if _, ok := m.conns[connIndex]; !ok {
|
||||
return fmt.Errorf("cannote remove connection, connection is unknown")
|
||||
}
|
||||
|
|
|
@ -3,71 +3,24 @@ package quic
|
|||
import (
|
||||
"net"
|
||||
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type testConn struct {
|
||||
counter int
|
||||
net.PacketConn
|
||||
}
|
||||
|
||||
var _ = Describe("Multiplexer", func() {
|
||||
It("adds a new packet conn ", func() {
|
||||
conn := NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
|
||||
_, err := getMultiplexer().AddConn(conn, 8, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
It("adds new packet conns", func() {
|
||||
conn1 := NewMockPacketConn(mockCtrl)
|
||||
conn1.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
|
||||
getMultiplexer().AddConn(conn1)
|
||||
conn2 := NewMockPacketConn(mockCtrl)
|
||||
conn2.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1235})
|
||||
getMultiplexer().AddConn(conn2)
|
||||
})
|
||||
|
||||
It("recognizes when the same connection is added twice", func() {
|
||||
srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}
|
||||
pconn := NewMockPacketConn(mockCtrl)
|
||||
pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
|
||||
pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
conn := testConn{PacketConn: pconn}
|
||||
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||
_, err := getMultiplexer().AddConn(conn, 8, srk, tracer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.counter++
|
||||
_, err = getMultiplexer().AddConn(conn, 8, srk, tracer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("errors when adding an existing conn with a different connection ID length", func() {
|
||||
It("panics when the same connection is added twice", func() {
|
||||
conn := NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
|
||||
_, err := getMultiplexer().AddConn(conn, 5, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = getMultiplexer().AddConn(conn, 6, nil, nil)
|
||||
Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs"))
|
||||
})
|
||||
|
||||
It("errors when adding an existing conn with a different stateless rest key", func() {
|
||||
srk1 := &StatelessResetKey{'f', 'o', 'o'}
|
||||
srk2 := &StatelessResetKey{'b', 'a', 'r'}
|
||||
conn := NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
|
||||
_, err := getMultiplexer().AddConn(conn, 7, srk1, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = getMultiplexer().AddConn(conn, 7, srk2, nil)
|
||||
Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn"))
|
||||
})
|
||||
|
||||
It("errors when adding an existing conn with different tracers", func() {
|
||||
conn := NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
|
||||
_, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
|
||||
Expect(err).To(MatchError("cannot use different tracers on the same packet conn"))
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
|
||||
getMultiplexer().AddConn(conn)
|
||||
Expect(func() { getMultiplexer().AddConn(conn) }).To(Panic())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -5,28 +5,22 @@ import (
|
|||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// rawConn is a connection that allow reading of a receivedPacket.
|
||||
// rawConn is a connection that allow reading of a receivedPackeh.
|
||||
type rawConn interface {
|
||||
ReadPacket() (*receivedPacket, error)
|
||||
WritePacket(b []byte, addr net.Addr, oob []byte) (int, error)
|
||||
LocalAddr() net.Addr
|
||||
SetReadDeadline(time.Time) error
|
||||
io.Closer
|
||||
}
|
||||
|
||||
|
@ -36,113 +30,49 @@ type closePacket struct {
|
|||
info *packetInfo
|
||||
}
|
||||
|
||||
// The packetHandlerMap stores packetHandlers, identified by connection ID.
|
||||
// It is used:
|
||||
// * by the server to store connections
|
||||
// * when multiplexing outgoing connections to store clients
|
||||
type unknownPacketHandler interface {
|
||||
handlePacket(*receivedPacket)
|
||||
setCloseError(error)
|
||||
}
|
||||
|
||||
var errListenerAlreadySet = errors.New("listener already set")
|
||||
|
||||
type packetHandlerMap struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conn rawConn
|
||||
connIDLen int
|
||||
|
||||
closeQueue chan closePacket
|
||||
|
||||
mutex sync.Mutex
|
||||
handlers map[protocol.ConnectionID]packetHandler
|
||||
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
|
||||
server unknownPacketHandler
|
||||
|
||||
listening chan struct{} // is closed when listen returns
|
||||
closed bool
|
||||
closeChan chan struct{}
|
||||
|
||||
enqueueClosePacket func(closePacket)
|
||||
|
||||
deleteRetiredConnsAfter time.Duration
|
||||
|
||||
statelessResetEnabled bool
|
||||
statelessResetMutex sync.Mutex
|
||||
statelessResetHasher hash.Hash
|
||||
statelessResetMutex sync.Mutex
|
||||
statelessResetHasher hash.Hash
|
||||
|
||||
tracer logging.Tracer
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ packetHandlerManager = &packetHandlerMap{}
|
||||
|
||||
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
|
||||
conn, ok := c.(interface{ SetReadBuffer(int) error })
|
||||
if !ok {
|
||||
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
|
||||
}
|
||||
size, err := inspectReadBuffer(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
||||
}
|
||||
if size >= protocol.DesiredReceiveBufferSize {
|
||||
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
|
||||
return nil
|
||||
}
|
||||
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
|
||||
return fmt.Errorf("failed to increase receive buffer size: %w", err)
|
||||
}
|
||||
newSize, err := inspectReadBuffer(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
||||
}
|
||||
if newSize == size {
|
||||
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
||||
}
|
||||
if newSize < protocol.DesiredReceiveBufferSize {
|
||||
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
||||
}
|
||||
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
|
||||
return nil
|
||||
}
|
||||
|
||||
// only print warnings about the UDP receive buffer size once
|
||||
var receiveBufferWarningOnce sync.Once
|
||||
|
||||
func newPacketHandlerMap(
|
||||
c net.PacketConn,
|
||||
connIDLen int,
|
||||
statelessResetKey *StatelessResetKey,
|
||||
tracer logging.Tracer,
|
||||
logger utils.Logger,
|
||||
) (packetHandlerManager, error) {
|
||||
if err := setReceiveBuffer(c, logger); err != nil {
|
||||
if !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
receiveBufferWarningOnce.Do(func() {
|
||||
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
|
||||
return
|
||||
}
|
||||
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
|
||||
})
|
||||
}
|
||||
}
|
||||
conn, err := wrapConn(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := &packetHandlerMap{
|
||||
conn: conn,
|
||||
connIDLen: connIDLen,
|
||||
listening: make(chan struct{}),
|
||||
func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
|
||||
h := &packetHandlerMap{
|
||||
closeChan: make(chan struct{}),
|
||||
handlers: make(map[protocol.ConnectionID]packetHandler),
|
||||
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
|
||||
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
|
||||
closeQueue: make(chan closePacket, 4),
|
||||
statelessResetEnabled: statelessResetKey != nil,
|
||||
tracer: tracer,
|
||||
enqueueClosePacket: enqueueClosePacket,
|
||||
logger: logger,
|
||||
}
|
||||
if m.statelessResetEnabled {
|
||||
m.statelessResetHasher = hmac.New(sha256.New, statelessResetKey[:])
|
||||
if key != nil {
|
||||
h.statelessResetHasher = hmac.New(sha256.New, key[:])
|
||||
}
|
||||
go m.listen()
|
||||
go m.runCloseQueue()
|
||||
|
||||
if logger.Debug() {
|
||||
go m.logUsage()
|
||||
if h.logger.Debug() {
|
||||
go h.logUsage()
|
||||
}
|
||||
return m, nil
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) logUsage() {
|
||||
|
@ -150,7 +80,7 @@ func (h *packetHandlerMap) logUsage() {
|
|||
var printedZero bool
|
||||
for {
|
||||
select {
|
||||
case <-h.listening:
|
||||
case <-h.closeChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
@ -233,12 +163,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
|
|||
if connClosePacket != nil {
|
||||
handler = newClosedLocalConn(
|
||||
func(addr net.Addr, info *packetInfo) {
|
||||
select {
|
||||
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
|
||||
default:
|
||||
// Oops, we're backlogged.
|
||||
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
|
||||
}
|
||||
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
|
||||
},
|
||||
pers,
|
||||
h.logger,
|
||||
|
@ -265,17 +190,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
|
|||
})
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) runCloseQueue() {
|
||||
for {
|
||||
select {
|
||||
case <-h.listening:
|
||||
return
|
||||
case p := <-h.closeQueue:
|
||||
h.conn.WritePacket(p.payload, p.addr, p.info.OOB())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
|
||||
h.mutex.Lock()
|
||||
h.resetTokens[token] = handler
|
||||
|
@ -288,19 +202,16 @@ func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken)
|
|||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
|
||||
func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) {
|
||||
h.mutex.Lock()
|
||||
h.server = s
|
||||
h.mutex.Unlock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
handler, ok := h.resetTokens[token]
|
||||
return handler, ok
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) CloseServer() {
|
||||
h.mutex.Lock()
|
||||
if h.server == nil {
|
||||
h.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
h.server = nil
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range h.handlers {
|
||||
if handler.getPerspective() == protocol.PerspectiveServer {
|
||||
|
@ -316,23 +227,16 @@ func (h *packetHandlerMap) CloseServer() {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
// Destroy closes the underlying connection and waits until listen() has returned.
|
||||
// It does not close active connections.
|
||||
func (h *packetHandlerMap) Destroy() error {
|
||||
if err := h.conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
<-h.listening // wait until listening returns
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) close(e error) error {
|
||||
func (h *packetHandlerMap) Close(e error) {
|
||||
h.mutex.Lock()
|
||||
|
||||
if h.closed {
|
||||
h.mutex.Unlock()
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
close(h.closeChan)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range h.handlers {
|
||||
wg.Add(1)
|
||||
|
@ -341,89 +245,14 @@ func (h *packetHandlerMap) close(e error) error {
|
|||
wg.Done()
|
||||
}(handler)
|
||||
}
|
||||
|
||||
if h.server != nil {
|
||||
h.server.setCloseError(e)
|
||||
}
|
||||
h.closed = true
|
||||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
return getMultiplexer().RemoveConn(h.conn)
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) listen() {
|
||||
defer close(h.listening)
|
||||
for {
|
||||
p, err := h.conn.ReadPacket()
|
||||
//nolint:staticcheck // SA1019 ignore this!
|
||||
// TODO: This code is used to ignore wsa errors on Windows.
|
||||
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
|
||||
// See https://github.com/quic-go/quic-go/issues/1737 for details.
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||
h.logger.Debugf("Temporary error reading from conn: %w", err)
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
h.close(err)
|
||||
return
|
||||
}
|
||||
h.handlePacket(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
|
||||
connID, err := wire.ParseConnectionID(p.data, h.connIDLen)
|
||||
if err != nil {
|
||||
h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
||||
if h.tracer != nil {
|
||||
h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
p.buffer.MaybeRelease()
|
||||
return
|
||||
}
|
||||
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset {
|
||||
return
|
||||
}
|
||||
if handler, ok := h.handlers[connID]; ok {
|
||||
handler.handlePacket(p)
|
||||
return
|
||||
}
|
||||
if !wire.IsLongHeaderPacket(p.data[0]) {
|
||||
go h.maybeSendStatelessReset(p, connID)
|
||||
return
|
||||
}
|
||||
if h.server == nil { // no server set
|
||||
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
|
||||
return
|
||||
}
|
||||
h.server.handlePacket(p)
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
|
||||
// stateless resets are always short header packets
|
||||
if wire.IsLongHeaderPacket(data[0]) {
|
||||
return false
|
||||
}
|
||||
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
|
||||
return false
|
||||
}
|
||||
|
||||
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
|
||||
if sess, ok := h.resetTokens[token]; ok {
|
||||
h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
|
||||
go sess.destroy(&StatelessResetError{Token: token})
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
var token protocol.StatelessResetToken
|
||||
if !h.statelessResetEnabled {
|
||||
if h.statelessResetHasher == nil {
|
||||
// Return a random stateless reset token.
|
||||
// This token will be sent in the server's transport parameters.
|
||||
// By using a random token, an off-path attacker won't be able to disrupt the connection.
|
||||
|
@ -437,24 +266,3 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID)
|
|||
h.statelessResetMutex.Unlock()
|
||||
return token
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
|
||||
defer p.buffer.Release()
|
||||
if !h.statelessResetEnabled {
|
||||
return
|
||||
}
|
||||
// Don't send a stateless reset in response to very small packets.
|
||||
// This includes packets that could be stateless resets.
|
||||
if len(p.data) <= protocol.MinStatelessResetSize {
|
||||
return
|
||||
}
|
||||
token := h.GetStatelessResetToken(connID)
|
||||
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
||||
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
||||
rand.Read(data)
|
||||
data[0] = (data[0] & 0x7f) | 0x40
|
||||
data = append(data, token[:]...)
|
||||
if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
|
||||
h.logger.Debugf("Error sending Stateless Reset: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,405 +6,188 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Packet Handler Map", func() {
|
||||
type packetToRead struct {
|
||||
addr net.Addr
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
var (
|
||||
handler *packetHandlerMap
|
||||
conn *MockPacketConn
|
||||
tracer *mocklogging.MockTracer
|
||||
packetChan chan packetToRead
|
||||
|
||||
connIDLen int
|
||||
statelessResetKey *StatelessResetKey
|
||||
)
|
||||
|
||||
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
|
||||
b, err := (&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
Type: t,
|
||||
DestConnectionID: connID,
|
||||
Length: length,
|
||||
Version: protocol.Version1,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return b
|
||||
}
|
||||
|
||||
getPacket := func(connID protocol.ConnectionID) []byte {
|
||||
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
statelessResetKey = nil
|
||||
connIDLen = 0
|
||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
||||
packetChan = make(chan packetToRead, 10)
|
||||
It("adds and gets", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
h, ok := m.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(h).To(Equal(handler))
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
conn = NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
|
||||
p, ok := <-packetChan
|
||||
if !ok {
|
||||
return 0, nil, errors.New("closed")
|
||||
It("refused to add duplicates", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
Expect(m.Add(connID, handler)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("removes", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
m.Remove(connID)
|
||||
_, ok := m.Get(connID)
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("retires", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
dur := scaleDuration(50 * time.Millisecond)
|
||||
m.deleteRetiredConnsAfter = dur
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
m.Retire(connID)
|
||||
_, ok := m.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
time.Sleep(dur)
|
||||
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("adds newly to-be-constructed handlers", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
var called bool
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
Expect(m.AddWithConnID(connID1, connID2, func() packetHandler {
|
||||
called = true
|
||||
return NewMockPacketHandler(mockCtrl)
|
||||
})).To(BeTrue())
|
||||
Expect(called).To(BeTrue())
|
||||
Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() packetHandler {
|
||||
Fail("didn't expect the constructor to be executed")
|
||||
return nil
|
||||
})).To(BeFalse())
|
||||
})
|
||||
|
||||
It("adds, gets and removes reset tokens", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
m.AddResetToken(token, handler)
|
||||
h, ok := m.GetByResetToken(token)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(h).To(Equal(h))
|
||||
m.RemoveResetToken(token)
|
||||
_, ok = m.GetByResetToken(token)
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
|
||||
It("generates stateless reset token, if no key is set", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
connID := protocol.ParseConnectionID(b)
|
||||
token := m.GetStatelessResetToken(connID)
|
||||
for i := 0; i < 1000; i++ {
|
||||
to := m.GetStatelessResetToken(connID)
|
||||
Expect(to).ToNot(Equal(token))
|
||||
token = to
|
||||
}
|
||||
})
|
||||
|
||||
It("generates stateless reset token, if a key is set", func() {
|
||||
var key StatelessResetKey
|
||||
rand.Read(key[:])
|
||||
m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
connID := protocol.ParseConnectionID(b)
|
||||
token := m.GetStatelessResetToken(connID)
|
||||
Expect(token).ToNot(BeZero())
|
||||
Expect(m.GetStatelessResetToken(connID)).To(Equal(token))
|
||||
// generate a new connection ID
|
||||
rand.Read(b)
|
||||
connID2 := protocol.ParseConnectionID(b)
|
||||
Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token))
|
||||
})
|
||||
|
||||
It("replaces locally closed connections", func() {
|
||||
var closePackets []closePacket
|
||||
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
|
||||
dur := scaleDuration(50 * time.Millisecond)
|
||||
m.deleteRetiredConnsAfter = dur
|
||||
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar"))
|
||||
h, ok := m.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(h).ToNot(Equal(handler))
|
||||
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||
h.handlePacket(&receivedPacket{remoteAddr: addr})
|
||||
Expect(closePackets).To(HaveLen(1))
|
||||
Expect(closePackets[0].addr).To(Equal(addr))
|
||||
Expect(closePackets[0].payload).To(Equal([]byte("foobar")))
|
||||
|
||||
time.Sleep(dur)
|
||||
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("replaces remote closed connections", func() {
|
||||
var closePackets []closePacket
|
||||
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
|
||||
dur := scaleDuration(50 * time.Millisecond)
|
||||
m.deleteRetiredConnsAfter = dur
|
||||
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil)
|
||||
h, ok := m.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(h).ToNot(Equal(handler))
|
||||
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||
h.handlePacket(&receivedPacket{remoteAddr: addr})
|
||||
Expect(closePackets).To(BeEmpty())
|
||||
|
||||
time.Sleep(dur)
|
||||
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("closes the server", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
for i := 0; i < 10; i++ {
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
if i%2 == 0 {
|
||||
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
|
||||
} else {
|
||||
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
|
||||
conn.EXPECT().shutdown()
|
||||
}
|
||||
return copy(b, p.data), p.addr, p.err
|
||||
}).AnyTimes()
|
||||
phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler = phm.(*packetHandlerMap)
|
||||
b := make([]byte, 12)
|
||||
rand.Read(b)
|
||||
m.Add(protocol.ParseConnectionID(b), conn)
|
||||
}
|
||||
m.CloseServer()
|
||||
})
|
||||
|
||||
It("closes", func() {
|
||||
getMultiplexer() // make the sync.Once execute
|
||||
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
|
||||
mockMultiplexer := NewMockMultiplexer(mockCtrl)
|
||||
origMultiplexer := connMuxer
|
||||
connMuxer = mockMultiplexer
|
||||
|
||||
defer func() {
|
||||
connMuxer = origMultiplexer
|
||||
}()
|
||||
|
||||
testErr := errors.New("test error ")
|
||||
conn1 := NewMockPacketHandler(mockCtrl)
|
||||
conn1.EXPECT().destroy(testErr)
|
||||
conn2 := NewMockPacketHandler(mockCtrl)
|
||||
conn2.EXPECT().destroy(testErr)
|
||||
handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), conn1)
|
||||
handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), conn2)
|
||||
mockMultiplexer.EXPECT().RemoveConn(gomock.Any())
|
||||
handler.close(testErr)
|
||||
close(packetChan)
|
||||
Eventually(handler.listening).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("other operations", func() {
|
||||
AfterEach(func() {
|
||||
// delete connections and the server before closing
|
||||
// They might be mock implementations, and we'd have to register the expected calls before otherwise.
|
||||
handler.mutex.Lock()
|
||||
for connID := range handler.handlers {
|
||||
delete(handler.handlers, connID)
|
||||
}
|
||||
handler.server = nil
|
||||
handler.mutex.Unlock()
|
||||
conn.EXPECT().Close().MaxTimes(1)
|
||||
close(packetChan)
|
||||
handler.Destroy()
|
||||
Eventually(handler.listening).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
BeforeEach(func() {
|
||||
connIDLen = 5
|
||||
})
|
||||
|
||||
It("handles packets for different packet handlers on the same packet conn", func() {
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
|
||||
packetHandler1 := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler2 := NewMockPacketHandler(mockCtrl)
|
||||
handledPacket1 := make(chan struct{})
|
||||
handledPacket2 := make(chan struct{})
|
||||
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(connID1))
|
||||
close(handledPacket1)
|
||||
})
|
||||
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(connID2))
|
||||
close(handledPacket2)
|
||||
})
|
||||
handler.Add(connID1, packetHandler1)
|
||||
handler.Add(connID2, packetHandler2)
|
||||
packetChan <- packetToRead{data: getPacket(connID1)}
|
||||
packetChan <- packetToRead{data: getPacket(connID2)}
|
||||
|
||||
Eventually(handledPacket1).Should(BeClosed())
|
||||
Eventually(handledPacket2).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("drops unparseable packets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError)
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: []byte{0, 1, 2, 3},
|
||||
})
|
||||
})
|
||||
|
||||
It("deletes removed connections immediately", func() {
|
||||
handler.deleteRetiredConnsAfter = time.Hour
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||
handler.Remove(connID)
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
||||
})
|
||||
|
||||
It("deletes retired connection entries after a wait time", func() {
|
||||
handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
handler.Add(connID, conn)
|
||||
handler.Retire(connID)
|
||||
time.Sleep(scaleDuration(30 * time.Millisecond))
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
||||
})
|
||||
|
||||
It("passes packets arriving late for closed connections to that connection", func() {
|
||||
handler.deleteRetiredConnsAfter = time.Hour
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
handled := make(chan struct{})
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
close(handled)
|
||||
})
|
||||
handler.Add(connID, packetHandler)
|
||||
handler.Retire(connID)
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
Eventually(handled).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("drops packets for unknown receivers", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
})
|
||||
|
||||
It("closes the packet handlers when reading from the conn fails", func() {
|
||||
done := make(chan struct{})
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
|
||||
Expect(e).To(HaveOccurred())
|
||||
close(done)
|
||||
})
|
||||
handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler)
|
||||
packetChan <- packetToRead{err: errors.New("read failed")}
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("continues listening for temporary errors", func() {
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler)
|
||||
err := deadlineError{}
|
||||
Expect(err.Temporary()).To(BeTrue())
|
||||
packetChan <- packetToRead{err: err}
|
||||
// don't EXPECT any calls to packetHandler.destroy
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
|
||||
It("says if a connection ID is already taken", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
|
||||
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
|
||||
})
|
||||
|
||||
It("says if a connection ID is already taken, for AddWithConnID", func() {
|
||||
clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
newConnID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
newConnID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue())
|
||||
Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("running a server", func() {
|
||||
It("adds a server", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
|
||||
p := getPacket(connID)
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
cid, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cid).To(Equal(connID))
|
||||
})
|
||||
handler.SetServer(server)
|
||||
handler.handlePacket(&receivedPacket{data: p})
|
||||
})
|
||||
|
||||
It("closes all server connections", func() {
|
||||
handler.SetServer(NewMockUnknownPacketHandler(mockCtrl))
|
||||
clientConn := NewMockPacketHandler(mockCtrl)
|
||||
clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
|
||||
serverConn := NewMockPacketHandler(mockCtrl)
|
||||
serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
|
||||
serverConn.EXPECT().shutdown()
|
||||
|
||||
handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), clientConn)
|
||||
handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), serverConn)
|
||||
handler.CloseServer()
|
||||
})
|
||||
|
||||
It("stops handling packets with unknown connection IDs after the server is closed", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
|
||||
p := getPacket(connID)
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
// don't EXPECT any calls to server.handlePacket
|
||||
handler.SetServer(server)
|
||||
handler.CloseServer()
|
||||
handler.handlePacket(&receivedPacket{data: p})
|
||||
})
|
||||
})
|
||||
|
||||
Context("stateless resets", func() {
|
||||
BeforeEach(func() {
|
||||
connIDLen = 5
|
||||
})
|
||||
|
||||
Context("handling", func() {
|
||||
It("handles stateless resets", func() {
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
handler.AddResetToken(token, packetHandler)
|
||||
destroyed := make(chan struct{})
|
||||
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
|
||||
packet = append(packet, token[:]...)
|
||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||
defer GinkgoRecover()
|
||||
defer close(destroyed)
|
||||
Expect(err).To(HaveOccurred())
|
||||
var resetErr *StatelessResetError
|
||||
Expect(errors.As(err, &resetErr)).To(BeTrue())
|
||||
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
|
||||
Expect(resetErr.Token).To(Equal(token))
|
||||
})
|
||||
packetChan <- packetToRead{data: packet}
|
||||
Eventually(destroyed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("handles stateless resets for 0-length connection IDs", func() {
|
||||
handler.connIDLen = 0
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
handler.AddResetToken(token, packetHandler)
|
||||
destroyed := make(chan struct{})
|
||||
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
|
||||
packet = append(packet, token[:]...)
|
||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(err).To(HaveOccurred())
|
||||
var resetErr *StatelessResetError
|
||||
Expect(errors.As(err, &resetErr)).To(BeTrue())
|
||||
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
|
||||
Expect(resetErr.Token).To(Equal(token))
|
||||
close(destroyed)
|
||||
})
|
||||
packetChan <- packetToRead{data: packet}
|
||||
Eventually(destroyed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("removes reset tokens", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
handler.Add(connID, packetHandler)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
handler.AddResetToken(token, NewMockPacketHandler(mockCtrl))
|
||||
handler.RemoveResetToken(token)
|
||||
// don't EXPECT any call to packetHandler.destroy()
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any())
|
||||
p := append([]byte{0x40} /* short header packet */, connID.Bytes()...)
|
||||
p = append(p, make([]byte, 50)...)
|
||||
p = append(p, token[:]...)
|
||||
|
||||
handler.handlePacket(&receivedPacket{data: p})
|
||||
})
|
||||
|
||||
It("ignores packets too small to contain a stateless reset", func() {
|
||||
handler.connIDLen = 0
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
handler.AddResetToken(token, packetHandler)
|
||||
done := make(chan struct{})
|
||||
// don't EXPECT any calls here, but register the closing of the done channel
|
||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) {
|
||||
close(done)
|
||||
}).AnyTimes()
|
||||
packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)}
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("generating", func() {
|
||||
BeforeEach(func() {
|
||||
var key StatelessResetKey
|
||||
rand.Read(key[:])
|
||||
statelessResetKey = &key
|
||||
})
|
||||
|
||||
It("generates stateless reset tokens", func() {
|
||||
connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
|
||||
connID2 := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})
|
||||
Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2)))
|
||||
})
|
||||
|
||||
It("sends stateless resets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
p := append([]byte{40}, make([]byte, 100)...)
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) {
|
||||
defer close(done)
|
||||
Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet
|
||||
Expect(b).To(HaveLen(protocol.MinStatelessResetSize))
|
||||
})
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: p,
|
||||
})
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't send stateless resets for small packets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...)
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: p,
|
||||
})
|
||||
// make sure there are no Write calls on the packet conn
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
})
|
||||
|
||||
Context("if no key is configured", func() {
|
||||
It("doesn't send stateless resets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
p := append([]byte{40}, make([]byte, 100)...)
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: p,
|
||||
})
|
||||
// make sure there are no Write calls on the packet conn
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
})
|
||||
})
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
testErr := errors.New("shutdown")
|
||||
for i := 0; i < 10; i++ {
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
conn.EXPECT().destroy(testErr)
|
||||
b := make([]byte, 12)
|
||||
rand.Read(b)
|
||||
m.Add(protocol.ParseConnectionID(b), conn)
|
||||
}
|
||||
m.Close(testErr)
|
||||
// check that Close can be called multiple times
|
||||
m.Close(errors.New("close"))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
|
@ -29,6 +32,20 @@ var _ = BeforeSuite(func() {
|
|||
log.SetOutput(io.Discard)
|
||||
})
|
||||
|
||||
func areServersRunning() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "quic-go.(*baseServer).run")
|
||||
}
|
||||
|
||||
func areTransportsRunning() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
|
||||
}
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
mockCtrl.Finish()
|
||||
Eventually(areServersRunning).Should(BeFalse())
|
||||
Eventually(areTransportsRunning()).Should(BeFalse())
|
||||
})
|
||||
|
|
92
server.go
92
server.go
|
@ -20,7 +20,7 @@ import (
|
|||
)
|
||||
|
||||
// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close.
|
||||
var ErrServerClosed = errors.New("quic: Server closed")
|
||||
var ErrServerClosed = errors.New("quic: server closed")
|
||||
|
||||
// packetHandler handles packets
|
||||
type packetHandler interface {
|
||||
|
@ -30,18 +30,13 @@ type packetHandler interface {
|
|||
getPerspective() protocol.Perspective
|
||||
}
|
||||
|
||||
type unknownPacketHandler interface {
|
||||
handlePacket(*receivedPacket)
|
||||
setCloseError(error)
|
||||
}
|
||||
|
||||
type packetHandlerManager interface {
|
||||
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
|
||||
Destroy() error
|
||||
connRunner
|
||||
SetServer(unknownPacketHandler)
|
||||
Close(error)
|
||||
CloseServer()
|
||||
connRunner
|
||||
}
|
||||
|
||||
type quicConn interface {
|
||||
|
@ -70,13 +65,11 @@ type baseServer struct {
|
|||
config *Config
|
||||
|
||||
conn rawConn
|
||||
// If the server is started with ListenAddr, we create a packet conn.
|
||||
// If it is started with Listen, we take a packet conn as a parameter.
|
||||
createdPacketConn bool
|
||||
|
||||
tokenGenerator *handshake.TokenGenerator
|
||||
|
||||
connHandler packetHandlerManager
|
||||
onClose func()
|
||||
|
||||
receivedPackets chan *receivedPacket
|
||||
|
||||
|
@ -114,8 +107,6 @@ type baseServer struct {
|
|||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ unknownPacketHandler = &baseServer{}
|
||||
|
||||
// A Listener listens for incoming QUIC connections.
|
||||
// It returns connections once the handshake has completed.
|
||||
type Listener struct {
|
||||
|
@ -166,37 +157,36 @@ func (l *EarlyListener) Addr() net.Addr {
|
|||
// The tls.Config must not be nil and must contain a certificate configuration.
|
||||
// The quic.Config may be nil, in that case the default values will be used.
|
||||
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
|
||||
s, err := listenAddr(addr, tlsConf, config, false)
|
||||
conn, err := listenUDP(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{baseServer: s}, nil
|
||||
return (&Transport{
|
||||
Conn: conn,
|
||||
createdConn: true,
|
||||
isSingleUse: true,
|
||||
}).Listen(tlsConf, config)
|
||||
}
|
||||
|
||||
// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
|
||||
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
|
||||
s, err := listenAddr(addr, tlsConf, config, true)
|
||||
conn, err := listenUDP(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &EarlyListener{baseServer: s}, nil
|
||||
return (&Transport{
|
||||
Conn: conn,
|
||||
createdConn: true,
|
||||
isSingleUse: true,
|
||||
}).ListenEarly(tlsConf, config)
|
||||
}
|
||||
|
||||
func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
|
||||
func listenUDP(addr string) (*net.UDPConn, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serv, err := listen(conn, tlsConf, config, acceptEarly)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serv.createdPacketConn = true
|
||||
return serv, nil
|
||||
return net.ListenUDP("udp", udpAddr)
|
||||
}
|
||||
|
||||
// Listen listens for QUIC connections on a given net.PacketConn. If the
|
||||
|
@ -210,45 +200,23 @@ func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bo
|
|||
// Furthermore, it must define an application control (using NextProtos).
|
||||
// The quic.Config may be nil, in that case the default values will be used.
|
||||
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) {
|
||||
s, err := listen(conn, tlsConf, config, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{baseServer: s}, nil
|
||||
tr := &Transport{Conn: conn, isSingleUse: true}
|
||||
return tr.Listen(tlsConf, config)
|
||||
}
|
||||
|
||||
// ListenEarly works like Listen, but it returns connections before the handshake completes.
|
||||
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
|
||||
s, err := listen(conn, tlsConf, config, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &EarlyListener{baseServer: s}, nil
|
||||
tr := &Transport{Conn: conn, isSingleUse: true}
|
||||
return tr.ListenEarly(tlsConf, config)
|
||||
}
|
||||
|
||||
func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
if err := validateConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config = populateServerConfig(config)
|
||||
|
||||
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) {
|
||||
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := wrapConn(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &baseServer{
|
||||
conn: c,
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
tokenGenerator: tokenGenerator,
|
||||
|
@ -260,12 +228,12 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
|
|||
newConn: newConnection,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
acceptEarlyConns: acceptEarly,
|
||||
onClose: onClose,
|
||||
}
|
||||
if acceptEarly {
|
||||
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
||||
}
|
||||
go s.run()
|
||||
connHandler.SetServer(s)
|
||||
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
||||
return s, nil
|
||||
}
|
||||
|
@ -317,18 +285,12 @@ func (s *baseServer) Close() error {
|
|||
if s.serverError == nil {
|
||||
s.serverError = ErrServerClosed
|
||||
}
|
||||
// If the server was started with ListenAddr, we created the packet conn.
|
||||
// We need to close it in order to make the go routine reading from that conn return.
|
||||
createdPacketConn := s.createdPacketConn
|
||||
s.closed = true
|
||||
close(s.errorChan)
|
||||
s.mutex.Unlock()
|
||||
|
||||
<-s.running
|
||||
s.connHandler.CloseServer()
|
||||
if createdPacketConn {
|
||||
return s.connHandler.Destroy()
|
||||
}
|
||||
s.onClose()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -24,17 +21,10 @@ import (
|
|||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func areServersRunning() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "quic-go.(*baseServer).run")
|
||||
}
|
||||
|
||||
var _ = Describe("Server", func() {
|
||||
var (
|
||||
conn *MockPacketConn
|
||||
|
@ -96,15 +86,19 @@ var _ = Describe("Server", func() {
|
|||
BeforeEach(func() {
|
||||
conn = NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
wait := make(chan struct{})
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
|
||||
<-wait
|
||||
return 0, nil, errors.New("done")
|
||||
}).MaxTimes(1)
|
||||
conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) {
|
||||
close(wait)
|
||||
conn.EXPECT().SetReadDeadline(time.Time{})
|
||||
}).MaxTimes(1)
|
||||
tlsConf = testdata.GetTLSConfig()
|
||||
tlsConf.NextProtos = []string{"proto1"}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Eventually(areServersRunning).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("errors when no tls.Config is given", func() {
|
||||
_, err := ListenAddr("localhost:0", nil, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
|
@ -178,6 +172,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
Context("server accepting connections that completed the handshake", func() {
|
||||
var (
|
||||
ln *Listener
|
||||
serv *baseServer
|
||||
phm *MockPacketHandlerManager
|
||||
tracer *mocklogging.MockTracer
|
||||
|
@ -185,7 +180,8 @@ var _ = Describe("Server", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
||||
ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer})
|
||||
var err error
|
||||
ln, err = Listen(conn, tlsConf, &Config{Tracer: tracer})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv = ln.baseServer
|
||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||
|
@ -193,8 +189,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
AfterEach(func() {
|
||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
||||
serv.Close()
|
||||
ln.Close()
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
|
@ -753,8 +748,7 @@ var _ = Describe("Server", func() {
|
|||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
phm.EXPECT().CloseServer()
|
||||
conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
|
||||
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -968,6 +962,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
serv.setCloseError(testErr)
|
||||
Eventually(done).Should(BeClosed())
|
||||
serv.onClose() // shutdown
|
||||
})
|
||||
|
||||
It("returns immediately, if an error occurred before", func() {
|
||||
|
@ -977,6 +972,7 @@ var _ = Describe("Server", func() {
|
|||
_, err := serv.Accept(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
}
|
||||
serv.onClose() // shutdown
|
||||
})
|
||||
|
||||
It("returns when the context is canceled", func() {
|
||||
|
@ -1064,7 +1060,6 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
AfterEach(func() {
|
||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
||||
serv.Close()
|
||||
})
|
||||
|
||||
|
@ -1234,8 +1229,7 @@ var _ = Describe("Server", func() {
|
|||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
phm.EXPECT().CloseServer()
|
||||
conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
|
||||
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
|
410
transport.go
Normal file
410
transport.go
Normal file
|
@ -0,0 +1,410 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
type Transport struct {
|
||||
// A single net.PacketConn can only be handled by one Transport.
|
||||
// Bad things will happen if passed to multiple Transports.
|
||||
//
|
||||
// If the connection satisfies the OOBCapablePacketConn interface
|
||||
// (as a net.UDPConn does), ECN and packet info support will be enabled.
|
||||
// In this case, optimized syscalls might be used, skipping the
|
||||
// ReadFrom and WriteTo calls to read / write packets.
|
||||
Conn net.PacketConn
|
||||
|
||||
// The length of the connection ID in bytes.
|
||||
// It can be 0, or any value between 4 and 18.
|
||||
// If unset, a 4 byte connection ID will be used.
|
||||
ConnectionIDLength int
|
||||
|
||||
// Use for generating new connection IDs.
|
||||
// This allows the application to control of the connection IDs used,
|
||||
// which allows routing / load balancing based on connection IDs.
|
||||
// All Connection IDs returned by the ConnectionIDGenerator MUST
|
||||
// have the same length.
|
||||
ConnectionIDGenerator ConnectionIDGenerator
|
||||
|
||||
// The StatelessResetKey is used to generate stateless reset tokens.
|
||||
// If no key is configured, sending of stateless resets is disabled.
|
||||
StatelessResetKey *StatelessResetKey
|
||||
|
||||
// A Tracer traces events that don't belong to a single QUIC connection.
|
||||
Tracer logging.Tracer
|
||||
|
||||
handlerMap packetHandlerManager
|
||||
|
||||
mutex sync.Mutex
|
||||
initOnce sync.Once
|
||||
initErr error
|
||||
|
||||
// Set in init.
|
||||
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
|
||||
connIDLen int
|
||||
|
||||
server unknownPacketHandler
|
||||
|
||||
conn rawConn
|
||||
|
||||
closeQueue chan closePacket
|
||||
|
||||
listening chan struct{} // is closed when listen returns
|
||||
closed bool
|
||||
createdConn bool
|
||||
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// Listen starts listening for incoming QUIC connections.
|
||||
// There can only be a single listener on any net.PacketConn.
|
||||
// Listen may only be called again after the current Listener was closed.
|
||||
func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
if t.server != nil {
|
||||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
if err := t.init(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.server = s
|
||||
return &Listener{baseServer: s}, nil
|
||||
}
|
||||
|
||||
// ListenEarly starts listening for incoming QUIC connections.
|
||||
// There can only be a single listener on any net.PacketConn.
|
||||
// Listen may only be called again after the current Listener was closed.
|
||||
func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
if t.server != nil {
|
||||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
if err := t.init(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.server = s
|
||||
return &EarlyListener{baseServer: s}, nil
|
||||
}
|
||||
|
||||
// Dial dials a new connection to a remote host (not using 0-RTT).
|
||||
func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf = populateClientConfig(conf, t.createdConn)
|
||||
if err := t.init(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn)
|
||||
}
|
||||
|
||||
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
|
||||
func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf = populateClientConfig(conf, t.createdConn)
|
||||
if err := t.init(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn)
|
||||
}
|
||||
|
||||
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
|
||||
conn, ok := c.(interface{ SetReadBuffer(int) error })
|
||||
if !ok {
|
||||
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
|
||||
}
|
||||
size, err := inspectReadBuffer(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
||||
}
|
||||
if size >= protocol.DesiredReceiveBufferSize {
|
||||
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
|
||||
return nil
|
||||
}
|
||||
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
|
||||
return fmt.Errorf("failed to increase receive buffer size: %w", err)
|
||||
}
|
||||
newSize, err := inspectReadBuffer(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
||||
}
|
||||
if newSize == size {
|
||||
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
||||
}
|
||||
if newSize < protocol.DesiredReceiveBufferSize {
|
||||
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
||||
}
|
||||
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
|
||||
return nil
|
||||
}
|
||||
|
||||
// only print warnings about the UDP receive buffer size once
|
||||
var receiveBufferWarningOnce sync.Once
|
||||
|
||||
func (t *Transport) init(conf *Config) error {
|
||||
t.initOnce.Do(func() {
|
||||
getMultiplexer().AddConn(t.Conn)
|
||||
|
||||
conn, err := wrapConn(t.Conn)
|
||||
if err != nil {
|
||||
t.initErr = err
|
||||
return
|
||||
}
|
||||
|
||||
t.StatelessResetKey = conf.StatelessResetKey
|
||||
t.Tracer = conf.Tracer
|
||||
t.ConnectionIDLength = conf.ConnectionIDLength
|
||||
t.ConnectionIDGenerator = conf.ConnectionIDGenerator
|
||||
|
||||
t.logger = utils.DefaultLogger // TODO: make this configurable
|
||||
t.conn = conn
|
||||
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
|
||||
t.listening = make(chan struct{})
|
||||
|
||||
t.closeQueue = make(chan closePacket, 4)
|
||||
|
||||
if t.ConnectionIDGenerator != nil {
|
||||
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
|
||||
} else {
|
||||
t.connIDLen = t.ConnectionIDLength
|
||||
}
|
||||
|
||||
go t.listen(conn)
|
||||
go t.runCloseQueue()
|
||||
})
|
||||
return t.initErr
|
||||
}
|
||||
|
||||
func (t *Transport) enqueueClosePacket(p closePacket) {
|
||||
select {
|
||||
case t.closeQueue <- p:
|
||||
default:
|
||||
// Oops, we're backlogged.
|
||||
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) runCloseQueue() {
|
||||
for {
|
||||
select {
|
||||
case <-t.listening:
|
||||
return
|
||||
case p := <-t.closeQueue:
|
||||
t.conn.WritePacket(p.payload, p.addr, p.info.OOB())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the underlying connection and waits until listen has returned.
|
||||
// It is invalid to start new listeners or connections after that.
|
||||
func (t *Transport) Close() error {
|
||||
t.close(errors.New("closing"))
|
||||
if t.createdConn {
|
||||
if err := t.conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
t.conn.SetReadDeadline(time.Now())
|
||||
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
|
||||
}
|
||||
<-t.listening // wait until listening returns
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Transport) closeServer() {
|
||||
t.handlerMap.CloseServer()
|
||||
t.mutex.Lock()
|
||||
t.server = nil
|
||||
if t.isSingleUse {
|
||||
t.closed = true
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
if t.createdConn {
|
||||
t.Conn.Close()
|
||||
}
|
||||
if t.isSingleUse {
|
||||
t.conn.SetReadDeadline(time.Now())
|
||||
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
|
||||
<-t.listening // wait until listening returns
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) close(e error) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
if t.closed {
|
||||
return
|
||||
}
|
||||
|
||||
t.handlerMap.Close(e)
|
||||
if t.server != nil {
|
||||
t.server.setCloseError(e)
|
||||
}
|
||||
t.closed = true
|
||||
}
|
||||
|
||||
func (t *Transport) listen(conn rawConn) {
|
||||
defer close(t.listening)
|
||||
defer getMultiplexer().RemoveConn(t.Conn)
|
||||
|
||||
if err := setReceiveBuffer(t.Conn, t.logger); err != nil {
|
||||
if !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
receiveBufferWarningOnce.Do(func() {
|
||||
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
|
||||
return
|
||||
}
|
||||
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
p, err := conn.ReadPacket()
|
||||
//nolint:staticcheck // SA1019 ignore this!
|
||||
// TODO: This code is used to ignore wsa errors on Windows.
|
||||
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
|
||||
// See https://github.com/quic-go/quic-go/issues/1737 for details.
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||
t.mutex.Lock()
|
||||
closed := t.closed
|
||||
t.mutex.Unlock()
|
||||
if closed {
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Temporary error reading from conn: %w", err)
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.close(err)
|
||||
return
|
||||
}
|
||||
t.handlePacket(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) handlePacket(p *receivedPacket) {
|
||||
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
|
||||
if err != nil {
|
||||
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
||||
if t.Tracer != nil {
|
||||
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
p.buffer.MaybeRelease()
|
||||
return
|
||||
}
|
||||
|
||||
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
|
||||
return
|
||||
}
|
||||
if handler, ok := t.handlerMap.Get(connID); ok {
|
||||
handler.handlePacket(p)
|
||||
return
|
||||
}
|
||||
if !wire.IsLongHeaderPacket(p.data[0]) {
|
||||
go t.maybeSendStatelessReset(p, connID)
|
||||
return
|
||||
}
|
||||
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
if t.server == nil { // no server set
|
||||
t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
|
||||
return
|
||||
}
|
||||
t.server.handlePacket(p)
|
||||
}
|
||||
|
||||
func (t *Transport) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
|
||||
defer p.buffer.Release()
|
||||
if t.StatelessResetKey == nil {
|
||||
return
|
||||
}
|
||||
// Don't send a stateless reset in response to very small packets.
|
||||
// This includes packets that could be stateless resets.
|
||||
if len(p.data) <= protocol.MinStatelessResetSize {
|
||||
return
|
||||
}
|
||||
token := t.handlerMap.GetStatelessResetToken(connID)
|
||||
t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
||||
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
||||
rand.Read(data)
|
||||
data[0] = (data[0] & 0x7f) | 0x40
|
||||
data = append(data, token[:]...)
|
||||
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
|
||||
t.logger.Debugf("Error sending Stateless Reset: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
|
||||
// stateless resets are always short header packets
|
||||
if wire.IsLongHeaderPacket(data[0]) {
|
||||
return false
|
||||
}
|
||||
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
|
||||
return false
|
||||
}
|
||||
|
||||
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
|
||||
if conn, ok := t.handlerMap.GetByResetToken(token); ok {
|
||||
t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
|
||||
go conn.destroy(&StatelessResetError{Token: token})
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
287
transport_test.go
Normal file
287
transport_test.go
Normal file
|
@ -0,0 +1,287 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Transport", func() {
|
||||
type packetToRead struct {
|
||||
addr net.Addr
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
|
||||
b, err := (&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
Type: t,
|
||||
DestConnectionID: connID,
|
||||
Length: length,
|
||||
Version: protocol.Version1,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return b
|
||||
}
|
||||
|
||||
getPacket := func(connID protocol.ConnectionID) []byte {
|
||||
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
|
||||
}
|
||||
|
||||
newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn {
|
||||
conn := NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
|
||||
p, ok := <-packetChan
|
||||
if !ok {
|
||||
return 0, nil, errors.New("closed")
|
||||
}
|
||||
return copy(b, p.data), p.addr, p.err
|
||||
}).AnyTimes()
|
||||
// for shutdown
|
||||
conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
|
||||
return conn
|
||||
}
|
||||
|
||||
It("handles packets for different packet handlers on the same packet conn", func() {
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := &Transport{Conn: newMockPacketConn(packetChan)}
|
||||
tr.init(&Config{})
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
|
||||
|
||||
handled := make(chan struct{}, 2)
|
||||
phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
|
||||
h := NewMockPacketHandler(mockCtrl)
|
||||
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
defer GinkgoRecover()
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(connID1))
|
||||
handled <- struct{}{}
|
||||
})
|
||||
return h, true
|
||||
})
|
||||
phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
|
||||
h := NewMockPacketHandler(mockCtrl)
|
||||
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
defer GinkgoRecover()
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(connID2))
|
||||
handled <- struct{}{}
|
||||
})
|
||||
return h, true
|
||||
})
|
||||
|
||||
packetChan <- packetToRead{data: getPacket(connID1)}
|
||||
packetChan <- packetToRead{data: getPacket(connID2)}
|
||||
|
||||
Eventually(handled).Should(Receive())
|
||||
Eventually(handled).Should(Receive())
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("closes listeners", func() {
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := &Transport{Conn: newMockPacketConn(packetChan)}
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(&tls.Config{}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
|
||||
phm.EXPECT().CloseServer()
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("drops unparseable packets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||
packetChan := make(chan packetToRead)
|
||||
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||
tr := &Transport{
|
||||
Conn: newMockPacketConn(packetChan),
|
||||
}
|
||||
tr.init(&Config{Tracer: tracer, ConnectionIDLength: 10})
|
||||
dropped := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
|
||||
packetChan <- packetToRead{
|
||||
addr: addr,
|
||||
data: []byte{0, 1, 2, 3},
|
||||
}
|
||||
Eventually(dropped).Should(BeClosed())
|
||||
|
||||
// shutdown
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("closes when reading from the conn fails", func() {
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{})
|
||||
tr.handlerMap = phm
|
||||
|
||||
done := make(chan struct{})
|
||||
phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
|
||||
packetChan <- packetToRead{err: errors.New("read failed")}
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
||||
// shutdown
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("continues listening after temporary errors", func() {
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{})
|
||||
tr.handlerMap = phm
|
||||
|
||||
tempErr := deadlineError{}
|
||||
Expect(tempErr.Temporary()).To(BeTrue())
|
||||
packetChan <- packetToRead{err: tempErr}
|
||||
// don't expect any calls to phm.Close
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("handles short header packets resets", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
tr.init(&Config{ConnectionIDLength: connID.Len()})
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{})
|
||||
tr.handlerMap = phm
|
||||
|
||||
var token protocol.StatelessResetToken
|
||||
rand.Read(token[:])
|
||||
|
||||
var b []byte
|
||||
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b = append(b, token[:]...)
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().GetByResetToken(token),
|
||||
phm.EXPECT().Get(connID).Return(conn, true),
|
||||
conn.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.data).To(Equal(b))
|
||||
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
|
||||
}),
|
||||
)
|
||||
packetChan <- packetToRead{data: b}
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("handles stateless resets", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
tr.init(&Config{ConnectionIDLength: connID.Len()})
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{})
|
||||
tr.handlerMap = phm
|
||||
|
||||
var token protocol.StatelessResetToken
|
||||
rand.Read(token[:])
|
||||
|
||||
var b []byte
|
||||
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b = append(b, token[:]...)
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().GetByResetToken(token).Return(conn, true),
|
||||
conn.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||
Expect(err).To(MatchError(&StatelessResetError{Token: token}))
|
||||
}),
|
||||
)
|
||||
packetChan <- packetToRead{data: b}
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("sends stateless resets", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||
packetChan := make(chan packetToRead)
|
||||
conn := newMockPacketConn(packetChan)
|
||||
tr := Transport{
|
||||
Conn: conn,
|
||||
}
|
||||
tr.init(&Config{ConnectionIDLength: connID.Len(), StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}})
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{})
|
||||
tr.handlerMap = phm
|
||||
|
||||
var b []byte
|
||||
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...)
|
||||
|
||||
var token protocol.StatelessResetToken
|
||||
rand.Read(token[:])
|
||||
written := make(chan struct{})
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().GetByResetToken(gomock.Any()),
|
||||
phm.EXPECT().Get(connID),
|
||||
phm.EXPECT().GetStatelessResetToken(connID).Return(token),
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) {
|
||||
defer close(written)
|
||||
Expect(bytes.Contains(b, token[:])).To(BeTrue())
|
||||
}),
|
||||
)
|
||||
packetChan <- packetToRead{data: b}
|
||||
Eventually(written).Should(BeClosed())
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
})
|
Loading…
Add table
Add a link
Reference in a new issue