implement the Transport

This commit is contained in:
Marten Seemann 2023-04-06 18:02:51 +08:00
parent ae5a8bd35c
commit 8189e75be6
31 changed files with 1309 additions and 1250 deletions

158
client.go
View file

@ -20,6 +20,7 @@ type client struct {
use0RTT bool
packetHandlers packetHandlerManager
onClose func()
tlsConf *tls.Config
config *Config
@ -45,32 +46,58 @@ var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
// The hostname for SNI is taken from the given address.
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Connection, error) {
return dialAddrContext(ctx, addr, tlsConf, config, false)
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
conn, err := dialAddrContext(ctx, addr, tlsConf, config, true)
if err != nil {
return nil, err
}
utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection")
return conn, nil
}
func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config, use0RTT bool) (quicConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return dialContext(ctx, udpConn, udpAddr, tlsConf, config, use0RTT, true)
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
dl, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
return dl.Dial(ctx, udpAddr, tlsConf, conf)
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
dl, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
// See DialEarly for details.
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
return nil, err
}
conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
@ -78,34 +105,43 @@ func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, conf
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
// packets.
// The same PacketConn can be used for multiple calls to Dial and Listen.
// QUIC connection IDs are used for demultiplexing the different connections.
// The tls.Config must define an application protocol (using NextProtos).
func Dial(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) {
return dialContext(ctx, pconn, addr, tlsConf, config, false, false)
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
// The same PacketConn can be used for multiple calls to Dial and Listen,
// QUIC connection IDs are used for demultiplexing the different connections.
// The tls.Config must define an application protocol (using NextProtos).
func DialEarly(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
return dialContext(ctx, pconn, addr, tlsConf, config, true, false)
}
func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config, use0RTT bool, createdPacketConn bool) (quicConn, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(config); err != nil {
return nil, err
}
config = populateClientConfig(config, createdPacketConn)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
return nil, err
}
c, err := newClient(pconn, addr, config, tlsConf, use0RTT, createdPacketConn)
conn, err := dl.Dial(ctx, addr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
return &Transport{
Conn: c,
createdConn: createdPacketConn,
isSingleUse: true,
}, nil
}
func dial(
ctx context.Context,
conn net.PacketConn,
packetHandlers packetHandlerManager,
addr net.Addr,
tlsConf *tls.Config,
config *Config,
onClose func(),
use0RTT bool,
createdPacketConn bool,
) (quicConn, error) {
c, err := newClient(conn, addr, config, tlsConf, onClose, use0RTT, createdPacketConn)
if err != nil {
return nil, err
}
@ -128,7 +164,7 @@ func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsCo
return c.conn, nil
}
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, use0RTT bool, createdPacketConn bool) (*client, error) {
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, onClose func(), use0RTT, createdPacketConn bool) (*client, error) {
if tlsConf == nil {
tlsConf = &tls.Config{}
} else {
@ -149,6 +185,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
sconn: newSendPconn(pconn, remoteAddr),
createdPacketConn: createdPacketConn,
use0RTT: use0RTT,
onClose: onClose,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
@ -179,13 +216,18 @@ func (c *client) dial(ctx context.Context) error {
c.packetHandlers.Add(c.srcConnID, c.conn)
errorChan := make(chan error, 1)
recreateChan := make(chan errCloseForRecreating)
go func() {
err := c.conn.run() // returns as soon as the connection is closed
if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
c.packetHandlers.Destroy()
err := c.conn.run()
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
recreateChan <- *recreateErr
return
}
errorChan <- err
if c.onClose != nil {
c.onClose()
}
errorChan <- err // returns as soon as the connection is closed
}()
// only set when we're using 0-RTT
@ -200,14 +242,12 @@ func (c *client) dial(ctx context.Context) error {
c.conn.shutdown()
return ctx.Err()
case err := <-errorChan:
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
}
return err
case recreateErr := <-recreateChan:
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
case <-earlyConnChan:
// ready to send 0-RTT data
return nil

View file

@ -18,13 +18,17 @@ import (
. "github.com/onsi/gomega"
)
type nullMultiplexer struct{}
func (n nullMultiplexer) AddConn(indexableConn) {}
func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil }
var _ = Describe("Client", func() {
var (
cl *client
packetConn *MockPacketConn
addr net.Addr
connID protocol.ConnectionID
mockMultiplexer *MockMultiplexer
origMultiplexer multiplexer
tlsConf *tls.Config
tracer *mocklogging.MockConnectionTracer
@ -53,6 +57,7 @@ var _ = Describe("Client", func() {
originalClientConnConstructor = newClientConnection
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
tr := mocklogging.NewMockTracer(mockCtrl)
tr.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}}
Eventually(areConnsRunning).Should(BeFalse())
@ -68,10 +73,9 @@ var _ = Describe("Client", func() {
logger: utils.DefaultLogger,
}
getMultiplexer() // make the sync.Once execute
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
mockMultiplexer = NewMockMultiplexer(mockCtrl)
// replace the clientMuxer. getMultiplexer will now return the nullMultiplexer
origMultiplexer = connMuxer
connMuxer = mockMultiplexer
connMuxer = &nullMultiplexer{}
})
AfterEach(func() {
@ -100,48 +104,14 @@ var _ = Describe("Client", func() {
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
})
It("resolves the address", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
manager.EXPECT().Destroy()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
remoteAddrChan := make(chan string, 1)
newClientConnection = func(
sconn sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
remoteAddrChan <- sconn.RemoteAddr().String()
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run()
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
_, err := DialAddr(context.Background(), "localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond})
Expect(err).ToNot(HaveOccurred())
Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
})
It("returns after the handshake is complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
run := make(chan struct{})
newClientConnection = func(
_ sendConn,
runner connRunner,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
@ -162,18 +132,17 @@ var _ = Describe("Client", func() {
conn.EXPECT().HandshakeComplete().Return(c)
return conn
}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
s, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, false, false)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
Expect(cl.dial(context.Background())).To(Succeed())
Eventually(run).Should(BeClosed())
})
It("returns early connections", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
readyChan := make(chan struct{})
done := make(chan struct{})
newClientConnection = func(
@ -193,29 +162,23 @@ var _ = Describe("Client", func() {
) quicConn {
Expect(enable0RTT).To(BeTrue())
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run().Do(func() { <-done })
conn.EXPECT().run().Do(func() { close(done) })
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().earlyConnReady().Return(readyChan)
return conn
}
go func() {
defer GinkgoRecover()
defer close(done)
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
s, err := DialEarly(context.Background(), packetConn, addr, tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
}()
Consistently(done).ShouldNot(BeClosed())
close(readyChan)
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, true, false)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
Expect(cl.dial(context.Background())).To(Succeed())
Eventually(done).Should(BeClosed())
})
It("returns an error that occurs while waiting for the handshake to complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
testErr := errors.New("early handshake error")
newClientConnection = func(
@ -236,108 +199,16 @@ var _ = Describe("Client", func() {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run().Return(testErr)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().earlyConnReady().Return(make(chan struct{}))
return conn
}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
Expect(err).To(MatchError(testErr))
})
It("closes the connection when the context is canceled", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
connRunning := make(chan struct{})
defer close(connRunning)
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run().Do(func() {
<-connRunning
})
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
newClientConnection = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
return conn
}
ctx, cancel := context.WithCancel(context.Background())
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := Dial(ctx, packetConn, addr, tlsConf, config)
Expect(err).To(MatchError(context.Canceled))
close(dialed)
}()
Consistently(dialed).ShouldNot(BeClosed())
conn.EXPECT().shutdown()
cancel()
Eventually(dialed).Should(BeClosed())
})
It("closes the connection when it was created by DialAddr", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
var sconn sendConn
run := make(chan struct{})
connCreated := make(chan struct{})
conn := NewMockQUICConn(mockCtrl)
newClientConnection = func(
connP sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
sconn = connP
close(connCreated)
return conn
}
conn.EXPECT().run().Do(func() {
<-run
})
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := DialAddr(context.Background(), "localhost:1337", tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(connCreated).Should(BeClosed())
// check that the connection is not closed
Expect(sconn.Write([]byte("foobar"))).To(Succeed())
manager.EXPECT().Destroy()
close(run)
time.Sleep(50 * time.Millisecond)
Eventually(done).Should(BeClosed())
var closed bool
cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, func() { closed = true }, true, false)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
Expect(cl.dial(context.Background())).To(MatchError(testErr))
Expect(closed).To(BeTrue())
})
Context("quic.Config", func() {
@ -365,12 +236,6 @@ var _ = Describe("Client", func() {
Expect(c.EnableDatagrams).To(BeTrue())
})
It("errors when the Config contains an invalid version", func() {
version := protocol.VersionNumber(0x1234)
_, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
})
It("disables bidirectional streams", func() {
config := &Config{
MaxIncomingStreams: -1,
@ -405,15 +270,12 @@ var _ = Describe("Client", func() {
})
It("creates new connections with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}}
c := make(chan struct{})
var cconn sendConn
var version protocol.VersionNumber
var conf *Config
done := make(chan struct{})
newClientConnection = func(
connP sendConn,
_ connRunner,
@ -437,8 +299,15 @@ var _ = Describe("Client", func() {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run()
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().destroy(gomock.Any())
close(done)
return conn
}
packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) {
<-done
return 0, nil, errors.New("closed")
})
packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed())
@ -448,17 +317,12 @@ var _ = Describe("Client", func() {
})
It("creates a new connections after version negotiation", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any()).Times(2)
manager.EXPECT().Destroy()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
var counter int
newClientConnection = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
runner connRunner,
_ protocol.ConnectionID,
connID protocol.ConnectionID,
configP *Config,
_ *tls.Config,
pn protocol.PacketNumber,
@ -474,20 +338,24 @@ var _ = Describe("Client", func() {
if counter == 0 {
Expect(pn).To(BeZero())
Expect(hasNegotiatedVersion).To(BeFalse())
conn.EXPECT().run().Return(&errCloseForRecreating{
nextPacketNumber: 109,
nextVersion: 789,
conn.EXPECT().run().DoAndReturn(func() error {
runner.Remove(connID)
return &errCloseForRecreating{
nextPacketNumber: 109,
nextVersion: 789,
}
})
} else {
Expect(pn).To(Equal(protocol.PacketNumber(109)))
Expect(hasNegotiatedVersion).To(BeTrue())
conn.EXPECT().run()
conn.EXPECT().destroy(gomock.Any())
}
counter++
return conn
}
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
@ -495,15 +363,3 @@ var _ = Describe("Client", func() {
})
})
})
type mockConnIDGenerator struct {
ConnID protocol.ConnectionID
}
func (m *mockConnIDGenerator) GenerateConnectionID() (protocol.ConnectionID, error) {
return m.ConnID, nil
}
func (m *mockConnIDGenerator) ConnectionIDLen() int {
return m.ConnID.Len()
}

View file

@ -660,6 +660,7 @@ var _ = Describe("Stream Cancellations", func() {
getQuicConfig(&quic.Config{MaxIncomingStreams: maxIncomingStreams, MaxIdleTimeout: 10 * time.Second}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
var wg sync.WaitGroup
wg.Add(2 * 4 * maxIncomingStreams)

View file

@ -24,6 +24,7 @@ var _ = Describe("Connection ID lengths tests", func() {
}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
var drop atomic.Bool
dropped := make(chan []byte, 100)
@ -50,6 +51,7 @@ var _ = Describe("Connection ID lengths tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
sconn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())

View file

@ -35,7 +35,11 @@ var _ = Describe("Connection ID lengths tests", func() {
randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
runServer := func(conf *quic.Config) *quic.Listener {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
if conf.ConnectionIDGenerator != nil {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", conf.ConnectionIDGenerator.ConnectionIDLen())))
} else {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
}
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf)
Expect(err).ToNot(HaveOccurred())
go func() {
@ -59,7 +63,11 @@ var _ = Describe("Connection ID lengths tests", func() {
}
runClient := func(addr net.Addr, conf *quic.Config) {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
if conf.ConnectionIDGenerator != nil {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", conf.ConnectionIDGenerator.ConnectionIDLen())))
} else {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
}
cl, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),

View file

@ -22,12 +22,11 @@ var _ = Describe("Datagram test", func() {
const num = 100
var (
proxy *quicproxy.QuicProxy
serverConn, clientConn *net.UDPConn
dropped, total int32
)
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) {
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
serverConn, err = net.ListenUDP("udp", addr)
@ -39,8 +38,10 @@ var _ = Describe("Datagram test", func() {
)
Expect(err).ToNot(HaveOccurred())
accepted := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(accepted)
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
@ -67,7 +68,7 @@ var _ = Describe("Datagram test", func() {
}()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
// drop 10% of Short Header packets sent from the server
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
@ -87,6 +88,11 @@ var _ = Describe("Datagram test", func() {
},
})
Expect(err).ToNot(HaveOccurred())
return proxy.LocalPort(), func() {
Eventually(accepted).Should(BeClosed())
proxy.Close()
ln.Close()
}
}
BeforeEach(func() {
@ -96,13 +102,10 @@ var _ = Describe("Datagram test", func() {
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
})
It("sends datagrams", func() {
startServerAndProxy(true, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
proxyPort, close := startServerAndProxy(true, true)
defer close()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
context.Background(),
@ -117,6 +120,7 @@ var _ = Describe("Datagram test", func() {
for {
// Close the connection if no message is received for 100 ms.
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() {
fmt.Println("closing conn")
conn.CloseWithError(0, "")
})
if _, err := conn.ReceiveMessage(); err != nil {
@ -134,11 +138,12 @@ var _ = Describe("Datagram test", func() {
BeNumerically(">", expVal*9/10),
BeNumerically("<", num),
))
Eventually(conn.Context().Done).Should(BeClosed())
})
It("server can disable datagram", func() {
startServerAndProxy(false, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
proxyPort, close := startServerAndProxy(false, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
context.Background(),
@ -150,13 +155,13 @@ var _ = Describe("Datagram test", func() {
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
close()
conn.CloseWithError(0, "")
<-time.After(10 * time.Millisecond)
})
It("client can disable datagram", func() {
startServerAndProxy(false, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
proxyPort, close := startServerAndProxy(false, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
context.Background(),
@ -169,7 +174,8 @@ var _ = Describe("Datagram test", func() {
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
Expect(conn.SendMessage([]byte{0})).To(HaveOccurred())
close()
conn.CloseWithError(0, "")
<-time.After(10 * time.Millisecond)
})
})

View file

@ -24,6 +24,7 @@ var _ = Describe("early data", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()

View file

@ -8,10 +8,9 @@ import (
"time"
)
var (
go120 = false
errNotSupported = errors.New("not supported")
)
const go120 = false
var errNotSupported = errors.New("not supported")
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
return errNotSupported

View file

@ -7,7 +7,7 @@ import (
"time"
)
var go120 = true
const go120 = true
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
rc := http.NewResponseController(w)

View file

@ -62,13 +62,14 @@ var _ = Describe("Handshake RTT tests", func() {
runProxy(ln.Addr())
startTime := time.Now()
_, err = quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
expectDurationInRTTs(startTime, 2)
})
@ -79,13 +80,14 @@ var _ = Describe("Handshake RTT tests", func() {
runProxy(ln.Addr())
startTime := time.Now()
_, err = quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
expectDurationInRTTs(startTime, 1)
})
@ -97,13 +99,14 @@ var _ = Describe("Handshake RTT tests", func() {
runProxy(ln.Addr())
startTime := time.Now()
_, err = quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
expectDurationInRTTs(startTime, 2)
})
@ -131,6 +134,7 @@ var _ = Describe("Handshake RTT tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
@ -166,6 +170,7 @@ var _ = Describe("Handshake RTT tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)

View file

@ -114,7 +114,7 @@ var _ = Describe("Handshake tests", func() {
context.Background(),
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
nil,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
@ -223,13 +223,14 @@ var _ = Describe("Handshake tests", func() {
var (
server *quic.Listener
pconn net.PacketConn
dialer *quic.Transport
)
dial := func() (quic.Connection, error) {
remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
raddr, err := net.ResolveUDPAddr("udp", remoteAddr)
Expect(err).ToNot(HaveOccurred())
return quic.Dial(context.Background(), pconn, raddr, getTLSClientConfig(), nil)
return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil))
}
BeforeEach(func() {
@ -243,11 +244,13 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
pconn, err = net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
dialer = &quic.Transport{Conn: pconn}
})
AfterEach(func() {
Expect(server.Close()).To(Succeed())
Expect(pconn.Close()).To(Succeed())
Expect(dialer.Close()).To(Succeed())
})
It("rejects new connection attempts if connections don't get accepted", func() {
@ -366,6 +369,7 @@ var _ = Describe("Handshake tests", func() {
It("uses tokens provided in NEW_TOKEN frames", func() {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
// dial the first connection and receive the token
go func() {

View file

@ -382,6 +382,7 @@ var _ = Describe("HTTP tests", func() {
tlsConf.NextProtos = []string{"h3"}
ln, err := quic.ListenAddr("localhost:0", tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -398,57 +399,51 @@ var _ = Describe("HTTP tests", func() {
Eventually(done).Should(BeClosed())
})
It("supports read deadlines", func() {
if !go120 {
Skip("This test requires Go 1.20+")
}
if go120 {
It("supports read deadlines", func() {
mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := setReadDeadline(w, time.Now().Add(deadlineDelay))
Expect(err).ToNot(HaveOccurred())
mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := setReadDeadline(w, time.Now().Add(deadlineDelay))
body, err := io.ReadAll(r.Body)
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
Expect(body).To(ContainSubstring("aa"))
w.Write([]byte("ok"))
})
expectedEnd := time.Now().Add(deadlineDelay)
resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a'))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := io.ReadAll(r.Body)
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
Expect(body).To(ContainSubstring("aa"))
w.Write([]byte("ok"))
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
Expect(err).ToNot(HaveOccurred())
Expect(time.Now().After(expectedEnd)).To(BeTrue())
Expect(string(body)).To(Equal("ok"))
})
expectedEnd := time.Now().Add(deadlineDelay)
resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a'))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
It("supports write deadlines", func() {
mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := setWriteDeadline(w, time.Now().Add(deadlineDelay))
Expect(err).ToNot(HaveOccurred())
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
Expect(err).ToNot(HaveOccurred())
Expect(time.Now().After(expectedEnd)).To(BeTrue())
Expect(string(body)).To(Equal("ok"))
})
_, err = io.Copy(w, neverEnding('a'))
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
})
It("supports write deadlines", func() {
if !go120 {
Skip("This test requires Go 1.20+")
}
expectedEnd := time.Now().Add(deadlineDelay)
mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := setWriteDeadline(w, time.Now().Add(deadlineDelay))
resp, err := client.Get("https://localhost:" + port + "/write-deadline")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
_, err = io.Copy(w, neverEnding('a'))
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
Expect(err).ToNot(HaveOccurred())
Expect(time.Now().After(expectedEnd)).To(BeTrue())
Expect(string(body)).To(ContainSubstring("aa"))
})
expectedEnd := time.Now().Add(deadlineDelay)
resp, err := client.Get("https://localhost:" + port + "/write-deadline")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
Expect(err).ToNot(HaveOccurred())
Expect(time.Now().After(expectedEnd)).To(BeTrue())
Expect(string(body)).To(ContainSubstring("aa"))
})
}
})

View file

@ -34,10 +34,9 @@ var _ = Describe("Multiplexing", func() {
}()
}
dial := func(pconn net.PacketConn, addr net.Addr) {
conn, err := quic.Dial(
dial := func(tr *quic.Transport, addr net.Addr) {
conn, err := tr.Dial(
context.Background(),
pconn,
addr,
getTLSClientConfig(),
getQuicConfig(nil),
@ -72,17 +71,18 @@ var _ = Describe("Multiplexing", func() {
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
dial(tr, server.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
dial(tr, server.Addr())
close(done2)
}()
timeout := 30 * time.Second
@ -106,17 +106,18 @@ var _ = Describe("Multiplexing", func() {
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server1.Addr())
dial(tr, server1.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn, server2.Addr())
dial(tr, server2.Addr())
close(done2)
}()
timeout := 30 * time.Second
@ -135,9 +136,9 @@ var _ = Describe("Multiplexing", func() {
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
server, err := quic.Listen(
conn,
server, err := tr.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
@ -146,7 +147,7 @@ var _ = Describe("Multiplexing", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
dial(tr, server.Addr())
close(done)
}()
timeout := 30 * time.Second
@ -165,15 +166,16 @@ var _ = Describe("Multiplexing", func() {
conn1, err := net.ListenUDP("udp", addr1)
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn2, err := net.ListenUDP("udp", addr2)
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
server1, err := quic.Listen(
conn1,
server1, err := tr1.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
@ -181,8 +183,7 @@ var _ = Describe("Multiplexing", func() {
runServer(server1)
defer server1.Close()
server2, err := quic.Listen(
conn2,
server2, err := tr2.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
@ -194,12 +195,12 @@ var _ = Describe("Multiplexing", func() {
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn2, server1.Addr())
dial(tr2, server1.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn1, server2.Addr())
dial(tr1, server2.Addr())
close(done2)
}()
timeout := 30 * time.Second

View file

@ -31,8 +31,8 @@ var _ = Describe("Packetization", func() {
}),
)
Expect(err).ToNot(HaveOccurred())
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
defer server.Close()
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: serverAddr,
@ -54,6 +54,7 @@ var _ = Describe("Packetization", func() {
}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
go func() {
defer GinkgoRecover()

View file

@ -199,8 +199,16 @@ func areHandshakesRunning() bool {
return strings.Contains(b.String(), "RunHandshake")
}
func areTransportsRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
}
var _ = AfterEach(func() {
Expect(areHandshakesRunning()).To(BeFalse())
Eventually(areTransportsRunning).Should(BeFalse())
if debugLog() {
logFile, err := os.Create(logFileName)
Expect(err).ToNot(HaveOccurred())

View file

@ -2,7 +2,6 @@ package self_test
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
@ -27,7 +26,13 @@ var _ = Describe("Stateless Resets", func() {
rand.Read(statelessResetKey[:])
serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: &statelessResetKey})
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
c, err := net.ListenUDP("udp", nil)
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: c,
}
defer tr.Close()
ln, err := tr.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port
@ -42,7 +47,8 @@ var _ = Describe("Stateless Resets", func() {
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
<-closeServer
ln.Close()
Expect(ln.Close()).To(Succeed())
Expect(tr.Close()).To(Succeed())
}()
var drop atomic.Bool
@ -77,11 +83,14 @@ var _ = Describe("Stateless Resets", func() {
close(closeServer)
time.Sleep(100 * time.Millisecond)
ln2, err := quic.ListenAddr(
fmt.Sprintf("localhost:%d", serverPort),
getTLSConfig(),
serverConfig,
)
// We need to create a new Transport here, since the old one is still sending out
// CONNECTION_CLOSE packets for (recently) closed connections).
tr2 := &quic.Transport{
Conn: c,
StatelessResetKey: &statelessResetKey,
}
defer tr2.Close()
ln2, err := tr2.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
drop.Store(false)
@ -100,8 +109,7 @@ var _ = Describe("Stateless Resets", func() {
_, serr = str.Read([]byte{0})
}
Expect(serr).To(HaveOccurred())
statelessResetErr := &quic.StatelessResetError{}
Expect(errors.As(serr, &statelessResetErr)).To(BeTrue())
Expect(serr).To(BeAssignableToTypeOf(&quic.StatelessResetError{}))
Expect(ln2.Close()).To(Succeed())
Eventually(acceptStopped).Should(BeClosed())
})

View file

@ -94,6 +94,8 @@ var _ = Describe("Bidirectional streams", func() {
)
Expect(err).ToNot(HaveOccurred())
runSendingPeer(client)
client.CloseWithError(0, "")
<-conn.Context().Done()
})
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
@ -149,5 +151,6 @@ var _ = Describe("Bidirectional streams", func() {
runReceivingPeer(client)
<-done1
<-done2
client.CloseWithError(0, "")
})
})

View file

@ -473,6 +473,7 @@ var _ = Describe("Timeout tests", func() {
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
serverErrChan := make(chan error, 1)
go func() {

View file

@ -88,11 +88,14 @@ var _ = Describe("Unidirectional Streams", func() {
})
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
runSendingPeer(conn)
<-conn.Context().Done()
}()
client, err := quic.DialAddr(
@ -103,6 +106,7 @@ var _ = Describe("Unidirectional Streams", func() {
)
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(client)
client.CloseWithError(0, "")
})
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {

View file

@ -1,65 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: Multiplexer)
// Package quic is a generated GoMock package.
package quic
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
logging "github.com/quic-go/quic-go/logging"
)
// MockMultiplexer is a mock of Multiplexer interface.
type MockMultiplexer struct {
ctrl *gomock.Controller
recorder *MockMultiplexerMockRecorder
}
// MockMultiplexerMockRecorder is the mock recorder for MockMultiplexer.
type MockMultiplexerMockRecorder struct {
mock *MockMultiplexer
}
// NewMockMultiplexer creates a new mock instance.
func NewMockMultiplexer(ctrl *gomock.Controller) *MockMultiplexer {
mock := &MockMultiplexer{ctrl: ctrl}
mock.recorder = &MockMultiplexerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder {
return m.recorder
}
// AddConn mocks base method.
func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 *StatelessResetKey, arg3 logging.Tracer) (packetHandlerManager, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(packetHandlerManager)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AddConn indicates an expected call of AddConn.
func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2, arg3)
}
// RemoveConn mocks base method.
func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveConn", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveConn indicates an expected call of RemoveConn.
func (mr *MockMultiplexerMockRecorder) RemoveConn(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveConn", reflect.TypeOf((*MockMultiplexer)(nil).RemoveConn), arg0)
}

View file

@ -74,6 +74,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 i
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2)
}
// Close mocks base method.
func (m *MockPacketHandlerManager) Close(arg0 error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Close", arg0)
}
// Close indicates an expected call of Close.
func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0)
}
// CloseServer mocks base method.
func (m *MockPacketHandlerManager) CloseServer() {
m.ctrl.T.Helper()
@ -86,20 +98,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
}
// Destroy mocks base method.
func (m *MockPacketHandlerManager) Destroy() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Destroy")
ret0, _ := ret[0].(error)
return ret0
}
// Destroy indicates an expected call of Destroy.
func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy))
}
// Get mocks base method.
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
m.ctrl.T.Helper()
@ -115,6 +113,21 @@ func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
}
// GetByResetToken mocks base method.
func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetToken) (packetHandler, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByResetToken", arg0)
ret0, _ := ret[0].(packetHandler)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// GetByResetToken indicates an expected call of GetByResetToken.
func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0)
}
// GetStatelessResetToken mocks base method.
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
m.ctrl.T.Helper()
@ -176,15 +189,3 @@ func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0)
}
// SetServer mocks base method.
func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetServer", arg0)
}
// SetServer indicates an expected call of SetServer.
func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0)
}

View file

@ -65,9 +65,6 @@ type UnknownPacketHandler = unknownPacketHandler
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
type PacketHandlerManager = packetHandlerManager
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_multiplexer_test.go github.com/quic-go/quic-go Multiplexer"
type Multiplexer = multiplexer
// Need to use source mode for the batchConn, since reflect mode follows type aliases.
// See https://github.com/golang/mock/issues/244 for details.
//

View file

@ -6,7 +6,6 @@ import (
"sync"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
var (
@ -14,30 +13,19 @@ var (
connMuxer multiplexer
)
type indexableConn interface {
LocalAddr() net.Addr
}
type indexableConn interface{ LocalAddr() net.Addr }
type multiplexer interface {
AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error)
AddConn(conn indexableConn)
RemoveConn(indexableConn) error
}
type connManager struct {
connIDLen int
statelessResetKey *StatelessResetKey
tracer logging.Tracer
manager packetHandlerManager
}
// The connMultiplexer listens on multiple net.PacketConns and dispatches
// incoming packets to the connection handler.
type connMultiplexer struct {
mutex sync.Mutex
conns map[string] /* LocalAddr().String() */ connManager
newPacketHandlerManager func(net.PacketConn, int, *StatelessResetKey, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests
conns map[string] /* LocalAddr().String() */ indexableConn
logger utils.Logger
}
@ -46,57 +34,38 @@ var _ multiplexer = &connMultiplexer{}
func getMultiplexer() multiplexer {
connMuxerOnce.Do(func() {
connMuxer = &connMultiplexer{
conns: make(map[string]connManager),
logger: utils.DefaultLogger.WithPrefix("muxer"),
newPacketHandlerManager: newPacketHandlerMap,
conns: make(map[string]indexableConn),
logger: utils.DefaultLogger.WithPrefix("muxer"),
}
})
return connMuxer
}
func (m *connMultiplexer) AddConn(
c net.PacketConn,
connIDLen int,
statelessResetKey *StatelessResetKey,
tracer logging.Tracer,
) (packetHandlerManager, error) {
func (m *connMultiplexer) index(addr net.Addr) string {
return addr.Network() + " " + addr.String()
}
func (m *connMultiplexer) AddConn(c indexableConn) {
m.mutex.Lock()
defer m.mutex.Unlock()
addr := c.LocalAddr()
connIndex := addr.Network() + " " + addr.String()
connIndex := m.index(c.LocalAddr())
p, ok := m.conns[connIndex]
if !ok {
manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger)
if err != nil {
return nil, err
}
p = connManager{
connIDLen: connIDLen,
statelessResetKey: statelessResetKey,
manager: manager,
tracer: tracer,
}
m.conns[connIndex] = p
} else {
if p.connIDLen != connIDLen {
return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
}
if statelessResetKey != nil && p.statelessResetKey != statelessResetKey {
return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn")
}
if tracer != p.tracer {
return nil, fmt.Errorf("cannot use different tracers on the same packet conn")
}
if ok {
// Panics if we're already listening on this connection.
// This is a safeguard because we're introducing a breaking API change, see
// https://github.com/quic-go/quic-go/issues/3727 for details.
// We'll remove this at a later time, when most users of the library have made the switch.
panic("connection already exists") // TODO: write a nice message
}
return p.manager, nil
m.conns[connIndex] = p
}
func (m *connMultiplexer) RemoveConn(c indexableConn) error {
m.mutex.Lock()
defer m.mutex.Unlock()
connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String()
connIndex := m.index(c.LocalAddr())
if _, ok := m.conns[connIndex]; !ok {
return fmt.Errorf("cannote remove connection, connection is unknown")
}

View file

@ -3,71 +3,24 @@ package quic
import (
"net"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
type testConn struct {
counter int
net.PacketConn
}
var _ = Describe("Multiplexer", func() {
It("adds a new packet conn ", func() {
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
_, err := getMultiplexer().AddConn(conn, 8, nil, nil)
Expect(err).ToNot(HaveOccurred())
It("adds new packet conns", func() {
conn1 := NewMockPacketConn(mockCtrl)
conn1.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
getMultiplexer().AddConn(conn1)
conn2 := NewMockPacketConn(mockCtrl)
conn2.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1235})
getMultiplexer().AddConn(conn2)
})
It("recognizes when the same connection is added twice", func() {
srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}
pconn := NewMockPacketConn(mockCtrl)
pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn := testConn{PacketConn: pconn}
tracer := mocklogging.NewMockTracer(mockCtrl)
_, err := getMultiplexer().AddConn(conn, 8, srk, tracer)
Expect(err).ToNot(HaveOccurred())
conn.counter++
_, err = getMultiplexer().AddConn(conn, 8, srk, tracer)
Expect(err).ToNot(HaveOccurred())
Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1))
})
It("errors when adding an existing conn with a different connection ID length", func() {
It("panics when the same connection is added twice", func() {
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
_, err := getMultiplexer().AddConn(conn, 5, nil, nil)
Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 6, nil, nil)
Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs"))
})
It("errors when adding an existing conn with a different stateless rest key", func() {
srk1 := &StatelessResetKey{'f', 'o', 'o'}
srk2 := &StatelessResetKey{'b', 'a', 'r'}
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
_, err := getMultiplexer().AddConn(conn, 7, srk1, nil)
Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 7, srk2, nil)
Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn"))
})
It("errors when adding an existing conn with different tracers", func() {
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
_, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
Expect(err).To(MatchError("cannot use different tracers on the same packet conn"))
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
getMultiplexer().AddConn(conn)
Expect(func() { getMultiplexer().AddConn(conn) }).To(Panic())
})
})

View file

@ -5,28 +5,22 @@ import (
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"hash"
"io"
"log"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// rawConn is a connection that allow reading of a receivedPacket.
// rawConn is a connection that allow reading of a receivedPackeh.
type rawConn interface {
ReadPacket() (*receivedPacket, error)
WritePacket(b []byte, addr net.Addr, oob []byte) (int, error)
LocalAddr() net.Addr
SetReadDeadline(time.Time) error
io.Closer
}
@ -36,113 +30,49 @@ type closePacket struct {
info *packetInfo
}
// The packetHandlerMap stores packetHandlers, identified by connection ID.
// It is used:
// * by the server to store connections
// * when multiplexing outgoing connections to store clients
type unknownPacketHandler interface {
handlePacket(*receivedPacket)
setCloseError(error)
}
var errListenerAlreadySet = errors.New("listener already set")
type packetHandlerMap struct {
mutex sync.Mutex
conn rawConn
connIDLen int
closeQueue chan closePacket
mutex sync.Mutex
handlers map[protocol.ConnectionID]packetHandler
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
server unknownPacketHandler
listening chan struct{} // is closed when listen returns
closed bool
closeChan chan struct{}
enqueueClosePacket func(closePacket)
deleteRetiredConnsAfter time.Duration
statelessResetEnabled bool
statelessResetMutex sync.Mutex
statelessResetHasher hash.Hash
statelessResetMutex sync.Mutex
statelessResetHasher hash.Hash
tracer logging.Tracer
logger utils.Logger
}
var _ packetHandlerManager = &packetHandlerMap{}
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
conn, ok := c.(interface{ SetReadBuffer(int) error })
if !ok {
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
}
size, err := inspectReadBuffer(c)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if size >= protocol.DesiredReceiveBufferSize {
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
return nil
}
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
return fmt.Errorf("failed to increase receive buffer size: %w", err)
}
newSize, err := inspectReadBuffer(c)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if newSize == size {
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
if newSize < protocol.DesiredReceiveBufferSize {
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
return nil
}
// only print warnings about the UDP receive buffer size once
var receiveBufferWarningOnce sync.Once
func newPacketHandlerMap(
c net.PacketConn,
connIDLen int,
statelessResetKey *StatelessResetKey,
tracer logging.Tracer,
logger utils.Logger,
) (packetHandlerManager, error) {
if err := setReceiveBuffer(c, logger); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
receiveBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
})
}
}
conn, err := wrapConn(c)
if err != nil {
return nil, err
}
m := &packetHandlerMap{
conn: conn,
connIDLen: connIDLen,
listening: make(chan struct{}),
func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
h := &packetHandlerMap{
closeChan: make(chan struct{}),
handlers: make(map[protocol.ConnectionID]packetHandler),
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
closeQueue: make(chan closePacket, 4),
statelessResetEnabled: statelessResetKey != nil,
tracer: tracer,
enqueueClosePacket: enqueueClosePacket,
logger: logger,
}
if m.statelessResetEnabled {
m.statelessResetHasher = hmac.New(sha256.New, statelessResetKey[:])
if key != nil {
h.statelessResetHasher = hmac.New(sha256.New, key[:])
}
go m.listen()
go m.runCloseQueue()
if logger.Debug() {
go m.logUsage()
if h.logger.Debug() {
go h.logUsage()
}
return m, nil
return h
}
func (h *packetHandlerMap) logUsage() {
@ -150,7 +80,7 @@ func (h *packetHandlerMap) logUsage() {
var printedZero bool
for {
select {
case <-h.listening:
case <-h.closeChan:
return
case <-ticker.C:
}
@ -233,12 +163,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
if connClosePacket != nil {
handler = newClosedLocalConn(
func(addr net.Addr, info *packetInfo) {
select {
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
default:
// Oops, we're backlogged.
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
}
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
},
pers,
h.logger,
@ -265,17 +190,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
})
}
func (h *packetHandlerMap) runCloseQueue() {
for {
select {
case <-h.listening:
return
case p := <-h.closeQueue:
h.conn.WritePacket(p.payload, p.addr, p.info.OOB())
}
}
}
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
h.mutex.Lock()
h.resetTokens[token] = handler
@ -288,19 +202,16 @@ func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken)
h.mutex.Unlock()
}
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) {
h.mutex.Lock()
h.server = s
h.mutex.Unlock()
defer h.mutex.Unlock()
handler, ok := h.resetTokens[token]
return handler, ok
}
func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock()
if h.server == nil {
h.mutex.Unlock()
return
}
h.server = nil
var wg sync.WaitGroup
for _, handler := range h.handlers {
if handler.getPerspective() == protocol.PerspectiveServer {
@ -316,23 +227,16 @@ func (h *packetHandlerMap) CloseServer() {
wg.Wait()
}
// Destroy closes the underlying connection and waits until listen() has returned.
// It does not close active connections.
func (h *packetHandlerMap) Destroy() error {
if err := h.conn.Close(); err != nil {
return err
}
<-h.listening // wait until listening returns
return nil
}
func (h *packetHandlerMap) close(e error) error {
func (h *packetHandlerMap) Close(e error) {
h.mutex.Lock()
if h.closed {
h.mutex.Unlock()
return nil
return
}
close(h.closeChan)
var wg sync.WaitGroup
for _, handler := range h.handlers {
wg.Add(1)
@ -341,89 +245,14 @@ func (h *packetHandlerMap) close(e error) error {
wg.Done()
}(handler)
}
if h.server != nil {
h.server.setCloseError(e)
}
h.closed = true
h.mutex.Unlock()
wg.Wait()
return getMultiplexer().RemoveConn(h.conn)
}
func (h *packetHandlerMap) listen() {
defer close(h.listening)
for {
p, err := h.conn.ReadPacket()
//nolint:staticcheck // SA1019 ignore this!
// TODO: This code is used to ignore wsa errors on Windows.
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
// See https://github.com/quic-go/quic-go/issues/1737 for details.
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
h.logger.Debugf("Temporary error reading from conn: %w", err)
continue
}
if err != nil {
h.close(err)
return
}
h.handlePacket(p)
}
}
func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, h.connIDLen)
if err != nil {
h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
if h.tracer != nil {
h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
p.buffer.MaybeRelease()
return
}
h.mutex.Lock()
defer h.mutex.Unlock()
if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
if handler, ok := h.handlers[connID]; ok {
handler.handlePacket(p)
return
}
if !wire.IsLongHeaderPacket(p.data[0]) {
go h.maybeSendStatelessReset(p, connID)
return
}
if h.server == nil { // no server set
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
return
}
h.server.handlePacket(p)
}
func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
// stateless resets are always short header packets
if wire.IsLongHeaderPacket(data[0]) {
return false
}
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
return false
}
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
go sess.destroy(&StatelessResetError{Token: token})
return true
}
return false
}
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
var token protocol.StatelessResetToken
if !h.statelessResetEnabled {
if h.statelessResetHasher == nil {
// Return a random stateless reset token.
// This token will be sent in the server's transport parameters.
// By using a random token, an off-path attacker won't be able to disrupt the connection.
@ -437,24 +266,3 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID)
h.statelessResetMutex.Unlock()
return token
}
func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
defer p.buffer.Release()
if !h.statelessResetEnabled {
return
}
// Don't send a stateless reset in response to very small packets.
// This includes packets that could be stateless resets.
if len(p.data) <= protocol.MinStatelessResetSize {
return
}
token := h.GetStatelessResetToken(connID)
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...)
if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
h.logger.Debugf("Error sending Stateless Reset: %s", err)
}
}

View file

@ -6,405 +6,188 @@ import (
"net"
"time"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Packet Handler Map", func() {
type packetToRead struct {
addr net.Addr
data []byte
err error
}
var (
handler *packetHandlerMap
conn *MockPacketConn
tracer *mocklogging.MockTracer
packetChan chan packetToRead
connIDLen int
statelessResetKey *StatelessResetKey
)
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
b, err := (&wire.ExtendedHeader{
Header: wire.Header{
Type: t,
DestConnectionID: connID,
Length: length,
Version: protocol.Version1,
},
PacketNumberLen: protocol.PacketNumberLen2,
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
return b
}
getPacket := func(connID protocol.ConnectionID) []byte {
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
}
BeforeEach(func() {
statelessResetKey = nil
connIDLen = 0
tracer = mocklogging.NewMockTracer(mockCtrl)
packetChan = make(chan packetToRead, 10)
It("adds and gets", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
handler := NewMockPacketHandler(mockCtrl)
Expect(m.Add(connID, handler)).To(BeTrue())
h, ok := m.Get(connID)
Expect(ok).To(BeTrue())
Expect(h).To(Equal(handler))
})
JustBeforeEach(func() {
conn = NewMockPacketConn(mockCtrl)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
p, ok := <-packetChan
if !ok {
return 0, nil, errors.New("closed")
It("refused to add duplicates", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
handler := NewMockPacketHandler(mockCtrl)
Expect(m.Add(connID, handler)).To(BeTrue())
Expect(m.Add(connID, handler)).To(BeFalse())
})
It("removes", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
handler := NewMockPacketHandler(mockCtrl)
Expect(m.Add(connID, handler)).To(BeTrue())
m.Remove(connID)
_, ok := m.Get(connID)
Expect(ok).To(BeFalse())
Expect(m.Add(connID, handler)).To(BeTrue())
})
It("retires", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
dur := scaleDuration(50 * time.Millisecond)
m.deleteRetiredConnsAfter = dur
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
handler := NewMockPacketHandler(mockCtrl)
Expect(m.Add(connID, handler)).To(BeTrue())
m.Retire(connID)
_, ok := m.Get(connID)
Expect(ok).To(BeTrue())
time.Sleep(dur)
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
})
It("adds newly to-be-constructed handlers", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
var called bool
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
Expect(m.AddWithConnID(connID1, connID2, func() packetHandler {
called = true
return NewMockPacketHandler(mockCtrl)
})).To(BeTrue())
Expect(called).To(BeTrue())
Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() packetHandler {
Fail("didn't expect the constructor to be executed")
return nil
})).To(BeFalse())
})
It("adds, gets and removes reset tokens", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
handler := NewMockPacketHandler(mockCtrl)
m.AddResetToken(token, handler)
h, ok := m.GetByResetToken(token)
Expect(ok).To(BeTrue())
Expect(h).To(Equal(h))
m.RemoveResetToken(token)
_, ok = m.GetByResetToken(token)
Expect(ok).To(BeFalse())
})
It("generates stateless reset token, if no key is set", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
b := make([]byte, 8)
rand.Read(b)
connID := protocol.ParseConnectionID(b)
token := m.GetStatelessResetToken(connID)
for i := 0; i < 1000; i++ {
to := m.GetStatelessResetToken(connID)
Expect(to).ToNot(Equal(token))
token = to
}
})
It("generates stateless reset token, if a key is set", func() {
var key StatelessResetKey
rand.Read(key[:])
m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
b := make([]byte, 8)
rand.Read(b)
connID := protocol.ParseConnectionID(b)
token := m.GetStatelessResetToken(connID)
Expect(token).ToNot(BeZero())
Expect(m.GetStatelessResetToken(connID)).To(Equal(token))
// generate a new connection ID
rand.Read(b)
connID2 := protocol.ParseConnectionID(b)
Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token))
})
It("replaces locally closed connections", func() {
var closePackets []closePacket
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
dur := scaleDuration(50 * time.Millisecond)
m.deleteRetiredConnsAfter = dur
handler := NewMockPacketHandler(mockCtrl)
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
Expect(m.Add(connID, handler)).To(BeTrue())
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar"))
h, ok := m.Get(connID)
Expect(ok).To(BeTrue())
Expect(h).ToNot(Equal(handler))
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
h.handlePacket(&receivedPacket{remoteAddr: addr})
Expect(closePackets).To(HaveLen(1))
Expect(closePackets[0].addr).To(Equal(addr))
Expect(closePackets[0].payload).To(Equal([]byte("foobar")))
time.Sleep(dur)
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
})
It("replaces remote closed connections", func() {
var closePackets []closePacket
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
dur := scaleDuration(50 * time.Millisecond)
m.deleteRetiredConnsAfter = dur
handler := NewMockPacketHandler(mockCtrl)
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
Expect(m.Add(connID, handler)).To(BeTrue())
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil)
h, ok := m.Get(connID)
Expect(ok).To(BeTrue())
Expect(h).ToNot(Equal(handler))
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
h.handlePacket(&receivedPacket{remoteAddr: addr})
Expect(closePackets).To(BeEmpty())
time.Sleep(dur)
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
})
It("closes the server", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
for i := 0; i < 10; i++ {
conn := NewMockPacketHandler(mockCtrl)
if i%2 == 0 {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
} else {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
conn.EXPECT().shutdown()
}
return copy(b, p.data), p.addr, p.err
}).AnyTimes()
phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger)
Expect(err).ToNot(HaveOccurred())
handler = phm.(*packetHandlerMap)
b := make([]byte, 12)
rand.Read(b)
m.Add(protocol.ParseConnectionID(b), conn)
}
m.CloseServer()
})
It("closes", func() {
getMultiplexer() // make the sync.Once execute
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
mockMultiplexer := NewMockMultiplexer(mockCtrl)
origMultiplexer := connMuxer
connMuxer = mockMultiplexer
defer func() {
connMuxer = origMultiplexer
}()
testErr := errors.New("test error ")
conn1 := NewMockPacketHandler(mockCtrl)
conn1.EXPECT().destroy(testErr)
conn2 := NewMockPacketHandler(mockCtrl)
conn2.EXPECT().destroy(testErr)
handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), conn1)
handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), conn2)
mockMultiplexer.EXPECT().RemoveConn(gomock.Any())
handler.close(testErr)
close(packetChan)
Eventually(handler.listening).Should(BeClosed())
})
Context("other operations", func() {
AfterEach(func() {
// delete connections and the server before closing
// They might be mock implementations, and we'd have to register the expected calls before otherwise.
handler.mutex.Lock()
for connID := range handler.handlers {
delete(handler.handlers, connID)
}
handler.server = nil
handler.mutex.Unlock()
conn.EXPECT().Close().MaxTimes(1)
close(packetChan)
handler.Destroy()
Eventually(handler.listening).Should(BeClosed())
})
Context("handling packets", func() {
BeforeEach(func() {
connIDLen = 5
})
It("handles packets for different packet handlers on the same packet conn", func() {
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
packetHandler1 := NewMockPacketHandler(mockCtrl)
packetHandler2 := NewMockPacketHandler(mockCtrl)
handledPacket1 := make(chan struct{})
handledPacket2 := make(chan struct{})
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID1))
close(handledPacket1)
})
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID2))
close(handledPacket2)
})
handler.Add(connID1, packetHandler1)
handler.Add(connID2, packetHandler2)
packetChan <- packetToRead{data: getPacket(connID1)}
packetChan <- packetToRead{data: getPacket(connID2)}
Eventually(handledPacket1).Should(BeClosed())
Eventually(handledPacket2).Should(BeClosed())
})
It("drops unparseable packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: []byte{0, 1, 2, 3},
})
})
It("deletes removed connections immediately", func() {
handler.deleteRetiredConnsAfter = time.Hour
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Remove(connID)
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
// don't EXPECT any calls to handlePacket of the MockPacketHandler
})
It("deletes retired connection entries after a wait time", func() {
handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
conn := NewMockPacketHandler(mockCtrl)
handler.Add(connID, conn)
handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond))
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
// don't EXPECT any calls to handlePacket of the MockPacketHandler
})
It("passes packets arriving late for closed connections to that connection", func() {
handler.deleteRetiredConnsAfter = time.Hour
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
packetHandler := NewMockPacketHandler(mockCtrl)
handled := make(chan struct{})
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
close(handled)
})
handler.Add(connID, packetHandler)
handler.Retire(connID)
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
Eventually(handled).Should(BeClosed())
})
It("drops packets for unknown receivers", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
})
It("closes the packet handlers when reading from the conn fails", func() {
done := make(chan struct{})
packetHandler := NewMockPacketHandler(mockCtrl)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
Expect(e).To(HaveOccurred())
close(done)
})
handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler)
packetChan <- packetToRead{err: errors.New("read failed")}
Eventually(done).Should(BeClosed())
})
It("continues listening for temporary errors", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler)
err := deadlineError{}
Expect(err.Temporary()).To(BeTrue())
packetChan <- packetToRead{err: err}
// don't EXPECT any calls to packetHandler.destroy
time.Sleep(50 * time.Millisecond)
})
It("says if a connection ID is already taken", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
})
It("says if a connection ID is already taken, for AddWithConnID", func() {
clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
newConnID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
newConnID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue())
Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse())
})
})
Context("running a server", func() {
It("adds a server", func() {
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
cid, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(cid).To(Equal(connID))
})
handler.SetServer(server)
handler.handlePacket(&receivedPacket{data: p})
})
It("closes all server connections", func() {
handler.SetServer(NewMockUnknownPacketHandler(mockCtrl))
clientConn := NewMockPacketHandler(mockCtrl)
clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
serverConn := NewMockPacketHandler(mockCtrl)
serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
serverConn.EXPECT().shutdown()
handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), clientConn)
handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), serverConn)
handler.CloseServer()
})
It("stops handling packets with unknown connection IDs after the server is closed", func() {
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
// don't EXPECT any calls to server.handlePacket
handler.SetServer(server)
handler.CloseServer()
handler.handlePacket(&receivedPacket{data: p})
})
})
Context("stateless resets", func() {
BeforeEach(func() {
connIDLen = 5
})
Context("handling", func() {
It("handles stateless resets", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
destroyed := make(chan struct{})
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
defer close(destroyed)
Expect(err).To(HaveOccurred())
var resetErr *StatelessResetError
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(resetErr.Token).To(Equal(token))
})
packetChan <- packetToRead{data: packet}
Eventually(destroyed).Should(BeClosed())
})
It("handles stateless resets for 0-length connection IDs", func() {
handler.connIDLen = 0
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
destroyed := make(chan struct{})
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
var resetErr *StatelessResetError
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(resetErr.Token).To(Equal(token))
close(destroyed)
})
packetChan <- packetToRead{data: packet}
Eventually(destroyed).Should(BeClosed())
})
It("removes reset tokens", func() {
connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
packetHandler := NewMockPacketHandler(mockCtrl)
handler.Add(connID, packetHandler)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, NewMockPacketHandler(mockCtrl))
handler.RemoveResetToken(token)
// don't EXPECT any call to packetHandler.destroy()
packetHandler.EXPECT().handlePacket(gomock.Any())
p := append([]byte{0x40} /* short header packet */, connID.Bytes()...)
p = append(p, make([]byte, 50)...)
p = append(p, token[:]...)
handler.handlePacket(&receivedPacket{data: p})
})
It("ignores packets too small to contain a stateless reset", func() {
handler.connIDLen = 0
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
done := make(chan struct{})
// don't EXPECT any calls here, but register the closing of the done channel
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) {
close(done)
}).AnyTimes()
packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)}
Consistently(done).ShouldNot(BeClosed())
})
})
Context("generating", func() {
BeforeEach(func() {
var key StatelessResetKey
rand.Read(key[:])
statelessResetKey = &key
})
It("generates stateless reset tokens", func() {
connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
connID2 := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})
Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2)))
})
It("sends stateless resets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, 100)...)
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) {
defer close(done)
Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet
Expect(b).To(HaveLen(protocol.MinStatelessResetSize))
})
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
})
Eventually(done).Should(BeClosed())
})
It("doesn't send stateless resets for small packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
})
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
})
})
Context("if no key is configured", func() {
It("doesn't send stateless resets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, 100)...)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
})
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
})
})
})
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
testErr := errors.New("shutdown")
for i := 0; i < 10; i++ {
conn := NewMockPacketHandler(mockCtrl)
conn.EXPECT().destroy(testErr)
b := make([]byte, 12)
rand.Read(b)
m.Add(protocol.ParseConnectionID(b), conn)
}
m.Close(testErr)
// check that Close can be called multiple times
m.Close(errors.New("close"))
})
})

View file

@ -1,8 +1,11 @@
package quic
import (
"bytes"
"io"
"log"
"runtime/pprof"
"strings"
"sync"
"testing"
@ -29,6 +32,20 @@ var _ = BeforeSuite(func() {
log.SetOutput(io.Discard)
})
func areServersRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*baseServer).run")
}
func areTransportsRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
}
var _ = AfterEach(func() {
mockCtrl.Finish()
Eventually(areServersRunning).Should(BeFalse())
Eventually(areTransportsRunning()).Should(BeFalse())
})

View file

@ -20,7 +20,7 @@ import (
)
// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close.
var ErrServerClosed = errors.New("quic: Server closed")
var ErrServerClosed = errors.New("quic: server closed")
// packetHandler handles packets
type packetHandler interface {
@ -30,18 +30,13 @@ type packetHandler interface {
getPerspective() protocol.Perspective
}
type unknownPacketHandler interface {
handlePacket(*receivedPacket)
setCloseError(error)
}
type packetHandlerManager interface {
Get(protocol.ConnectionID) (packetHandler, bool)
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
Destroy() error
connRunner
SetServer(unknownPacketHandler)
Close(error)
CloseServer()
connRunner
}
type quicConn interface {
@ -70,13 +65,11 @@ type baseServer struct {
config *Config
conn rawConn
// If the server is started with ListenAddr, we create a packet conn.
// If it is started with Listen, we take a packet conn as a parameter.
createdPacketConn bool
tokenGenerator *handshake.TokenGenerator
connHandler packetHandlerManager
onClose func()
receivedPackets chan *receivedPacket
@ -114,8 +107,6 @@ type baseServer struct {
logger utils.Logger
}
var _ unknownPacketHandler = &baseServer{}
// A Listener listens for incoming QUIC connections.
// It returns connections once the handshake has completed.
type Listener struct {
@ -166,37 +157,36 @@ func (l *EarlyListener) Addr() net.Addr {
// The tls.Config must not be nil and must contain a certificate configuration.
// The quic.Config may be nil, in that case the default values will be used.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
s, err := listenAddr(addr, tlsConf, config, false)
conn, err := listenUDP(addr)
if err != nil {
return nil, err
}
return &Listener{baseServer: s}, nil
return (&Transport{
Conn: conn,
createdConn: true,
isSingleUse: true,
}).Listen(tlsConf, config)
}
// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
s, err := listenAddr(addr, tlsConf, config, true)
conn, err := listenUDP(addr)
if err != nil {
return nil, err
}
return &EarlyListener{baseServer: s}, nil
return (&Transport{
Conn: conn,
createdConn: true,
isSingleUse: true,
}).ListenEarly(tlsConf, config)
}
func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
func listenUDP(addr string) (*net.UDPConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, err
}
serv, err := listen(conn, tlsConf, config, acceptEarly)
if err != nil {
return nil, err
}
serv.createdPacketConn = true
return serv, nil
return net.ListenUDP("udp", udpAddr)
}
// Listen listens for QUIC connections on a given net.PacketConn. If the
@ -210,45 +200,23 @@ func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bo
// Furthermore, it must define an application control (using NextProtos).
// The quic.Config may be nil, in that case the default values will be used.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) {
s, err := listen(conn, tlsConf, config, false)
if err != nil {
return nil, err
}
return &Listener{baseServer: s}, nil
tr := &Transport{Conn: conn, isSingleUse: true}
return tr.Listen(tlsConf, config)
}
// ListenEarly works like Listen, but it returns connections before the handshake completes.
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
s, err := listen(conn, tlsConf, config, true)
if err != nil {
return nil, err
}
return &EarlyListener{baseServer: s}, nil
tr := &Transport{Conn: conn, isSingleUse: true}
return tr.ListenEarly(tlsConf, config)
}
func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(config); err != nil {
return nil, err
}
config = populateServerConfig(config)
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
if err != nil {
return nil, err
}
func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) {
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
if err != nil {
return nil, err
}
c, err := wrapConn(conn)
if err != nil {
return nil, err
}
s := &baseServer{
conn: c,
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: tokenGenerator,
@ -260,12 +228,12 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
newConn: newConnection,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
onClose: onClose,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
}
go s.run()
connHandler.SetServer(s)
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil
}
@ -317,18 +285,12 @@ func (s *baseServer) Close() error {
if s.serverError == nil {
s.serverError = ErrServerClosed
}
// If the server was started with ListenAddr, we created the packet conn.
// We need to close it in order to make the go routine reading from that conn return.
createdPacketConn := s.createdPacketConn
s.closed = true
close(s.errorChan)
s.mutex.Unlock()
<-s.running
s.connHandler.CloseServer()
if createdPacketConn {
return s.connHandler.Destroy()
}
s.onClose()
return nil
}

View file

@ -1,15 +1,12 @@
package quic
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"reflect"
"runtime/pprof"
"strings"
"sync"
"sync/atomic"
"time"
@ -24,17 +21,10 @@ import (
"github.com/quic-go/quic-go/logging"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func areServersRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*baseServer).run")
}
var _ = Describe("Server", func() {
var (
conn *MockPacketConn
@ -96,15 +86,19 @@ var _ = Describe("Server", func() {
BeforeEach(func() {
conn = NewMockPacketConn(mockCtrl)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1)
wait := make(chan struct{})
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
<-wait
return 0, nil, errors.New("done")
}).MaxTimes(1)
conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) {
close(wait)
conn.EXPECT().SetReadDeadline(time.Time{})
}).MaxTimes(1)
tlsConf = testdata.GetTLSConfig()
tlsConf.NextProtos = []string{"proto1"}
})
AfterEach(func() {
Eventually(areServersRunning).Should(BeFalse())
})
It("errors when no tls.Config is given", func() {
_, err := ListenAddr("localhost:0", nil, nil)
Expect(err).To(HaveOccurred())
@ -178,6 +172,7 @@ var _ = Describe("Server", func() {
Context("server accepting connections that completed the handshake", func() {
var (
ln *Listener
serv *baseServer
phm *MockPacketHandlerManager
tracer *mocklogging.MockTracer
@ -185,7 +180,8 @@ var _ = Describe("Server", func() {
BeforeEach(func() {
tracer = mocklogging.NewMockTracer(mockCtrl)
ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer})
var err error
ln, err = Listen(conn, tlsConf, &Config{Tracer: tracer})
Expect(err).ToNot(HaveOccurred())
serv = ln.baseServer
phm = NewMockPacketHandlerManager(mockCtrl)
@ -193,8 +189,7 @@ var _ = Describe("Server", func() {
})
AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
serv.Close()
ln.Close()
})
Context("handling packets", func() {
@ -753,8 +748,7 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
phm.EXPECT().CloseServer()
conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
@ -968,6 +962,7 @@ var _ = Describe("Server", func() {
serv.setCloseError(testErr)
Eventually(done).Should(BeClosed())
serv.onClose() // shutdown
})
It("returns immediately, if an error occurred before", func() {
@ -977,6 +972,7 @@ var _ = Describe("Server", func() {
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
}
serv.onClose() // shutdown
})
It("returns when the context is canceled", func() {
@ -1064,7 +1060,6 @@ var _ = Describe("Server", func() {
})
AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
serv.Close()
})
@ -1234,8 +1229,7 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
phm.EXPECT().CloseServer()
conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})

410
transport.go Normal file
View file

@ -0,0 +1,410 @@
package quic
import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
type Transport struct {
// A single net.PacketConn can only be handled by one Transport.
// Bad things will happen if passed to multiple Transports.
//
// If the connection satisfies the OOBCapablePacketConn interface
// (as a net.UDPConn does), ECN and packet info support will be enabled.
// In this case, optimized syscalls might be used, skipping the
// ReadFrom and WriteTo calls to read / write packets.
Conn net.PacketConn
// The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18.
// If unset, a 4 byte connection ID will be used.
ConnectionIDLength int
// Use for generating new connection IDs.
// This allows the application to control of the connection IDs used,
// which allows routing / load balancing based on connection IDs.
// All Connection IDs returned by the ConnectionIDGenerator MUST
// have the same length.
ConnectionIDGenerator ConnectionIDGenerator
// The StatelessResetKey is used to generate stateless reset tokens.
// If no key is configured, sending of stateless resets is disabled.
StatelessResetKey *StatelessResetKey
// A Tracer traces events that don't belong to a single QUIC connection.
Tracer logging.Tracer
handlerMap packetHandlerManager
mutex sync.Mutex
initOnce sync.Once
initErr error
// Set in init.
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
connIDLen int
server unknownPacketHandler
conn rawConn
closeQueue chan closePacket
listening chan struct{} // is closed when listen returns
closed bool
createdConn bool
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
logger utils.Logger
}
// Listen starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current Listener was closed.
func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(conf); err != nil {
return nil, err
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server != nil {
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(conf); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, false)
if err != nil {
return nil, err
}
t.server = s
return &Listener{baseServer: s}, nil
}
// ListenEarly starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current Listener was closed.
func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(conf); err != nil {
return nil, err
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server != nil {
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(conf); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, true)
if err != nil {
return nil, err
}
t.server = s
return &EarlyListener{baseServer: s}, nil
}
// Dial dials a new connection to a remote host (not using 0-RTT).
func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateClientConfig(conf, t.createdConn)
if err := t.init(conf); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn)
}
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateClientConfig(conf, t.createdConn)
if err := t.init(conf); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn)
}
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
conn, ok := c.(interface{ SetReadBuffer(int) error })
if !ok {
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
}
size, err := inspectReadBuffer(c)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if size >= protocol.DesiredReceiveBufferSize {
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
return nil
}
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
return fmt.Errorf("failed to increase receive buffer size: %w", err)
}
newSize, err := inspectReadBuffer(c)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if newSize == size {
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
if newSize < protocol.DesiredReceiveBufferSize {
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
return nil
}
// only print warnings about the UDP receive buffer size once
var receiveBufferWarningOnce sync.Once
func (t *Transport) init(conf *Config) error {
t.initOnce.Do(func() {
getMultiplexer().AddConn(t.Conn)
conn, err := wrapConn(t.Conn)
if err != nil {
t.initErr = err
return
}
t.StatelessResetKey = conf.StatelessResetKey
t.Tracer = conf.Tracer
t.ConnectionIDLength = conf.ConnectionIDLength
t.ConnectionIDGenerator = conf.ConnectionIDGenerator
t.logger = utils.DefaultLogger // TODO: make this configurable
t.conn = conn
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
t.listening = make(chan struct{})
t.closeQueue = make(chan closePacket, 4)
if t.ConnectionIDGenerator != nil {
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
} else {
t.connIDLen = t.ConnectionIDLength
}
go t.listen(conn)
go t.runCloseQueue()
})
return t.initErr
}
func (t *Transport) enqueueClosePacket(p closePacket) {
select {
case t.closeQueue <- p:
default:
// Oops, we're backlogged.
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
}
}
func (t *Transport) runCloseQueue() {
for {
select {
case <-t.listening:
return
case p := <-t.closeQueue:
t.conn.WritePacket(p.payload, p.addr, p.info.OOB())
}
}
}
// Close closes the underlying connection and waits until listen has returned.
// It is invalid to start new listeners or connections after that.
func (t *Transport) Close() error {
t.close(errors.New("closing"))
if t.createdConn {
if err := t.conn.Close(); err != nil {
return err
}
} else {
t.conn.SetReadDeadline(time.Now())
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
}
<-t.listening // wait until listening returns
return nil
}
func (t *Transport) closeServer() {
t.handlerMap.CloseServer()
t.mutex.Lock()
t.server = nil
if t.isSingleUse {
t.closed = true
}
t.mutex.Unlock()
if t.createdConn {
t.Conn.Close()
}
if t.isSingleUse {
t.conn.SetReadDeadline(time.Now())
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
<-t.listening // wait until listening returns
}
}
func (t *Transport) close(e error) {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.closed {
return
}
t.handlerMap.Close(e)
if t.server != nil {
t.server.setCloseError(e)
}
t.closed = true
}
func (t *Transport) listen(conn rawConn) {
defer close(t.listening)
defer getMultiplexer().RemoveConn(t.Conn)
if err := setReceiveBuffer(t.Conn, t.logger); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
receiveBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
})
}
}
for {
p, err := conn.ReadPacket()
//nolint:staticcheck // SA1019 ignore this!
// TODO: This code is used to ignore wsa errors on Windows.
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
// See https://github.com/quic-go/quic-go/issues/1737 for details.
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
t.mutex.Lock()
closed := t.closed
t.mutex.Unlock()
if closed {
return
}
t.logger.Debugf("Temporary error reading from conn: %w", err)
continue
}
if err != nil {
t.close(err)
return
}
t.handlePacket(p)
}
}
func (t *Transport) handlePacket(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
if err != nil {
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
if t.Tracer != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
p.buffer.MaybeRelease()
return
}
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
if handler, ok := t.handlerMap.Get(connID); ok {
handler.handlePacket(p)
return
}
if !wire.IsLongHeaderPacket(p.data[0]) {
go t.maybeSendStatelessReset(p, connID)
return
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server == nil { // no server set
t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
return
}
t.server.handlePacket(p)
}
func (t *Transport) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
defer p.buffer.Release()
if t.StatelessResetKey == nil {
return
}
// Don't send a stateless reset in response to very small packets.
// This includes packets that could be stateless resets.
if len(p.data) <= protocol.MinStatelessResetSize {
return
}
token := t.handlerMap.GetStatelessResetToken(connID)
t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...)
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
t.logger.Debugf("Error sending Stateless Reset: %s", err)
}
}
func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
// stateless resets are always short header packets
if wire.IsLongHeaderPacket(data[0]) {
return false
}
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
return false
}
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
if conn, ok := t.handlerMap.GetByResetToken(token); ok {
t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
go conn.destroy(&StatelessResetError{Token: token})
return true
}
return false
}

287
transport_test.go Normal file
View file

@ -0,0 +1,287 @@
package quic
import (
"bytes"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"time"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Transport", func() {
type packetToRead struct {
addr net.Addr
data []byte
err error
}
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
b, err := (&wire.ExtendedHeader{
Header: wire.Header{
Type: t,
DestConnectionID: connID,
Length: length,
Version: protocol.Version1,
},
PacketNumberLen: protocol.PacketNumberLen2,
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
return b
}
getPacket := func(connID protocol.ConnectionID) []byte {
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
}
newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn {
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
p, ok := <-packetChan
if !ok {
return 0, nil, errors.New("closed")
}
return copy(b, p.data), p.addr, p.err
}).AnyTimes()
// for shutdown
conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
return conn
}
It("handles packets for different packet handlers on the same packet conn", func() {
packetChan := make(chan packetToRead)
tr := &Transport{Conn: newMockPacketConn(packetChan)}
tr.init(&Config{})
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
handled := make(chan struct{}, 2)
phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
h := NewMockPacketHandler(mockCtrl)
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
defer GinkgoRecover()
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID1))
handled <- struct{}{}
})
return h, true
})
phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
h := NewMockPacketHandler(mockCtrl)
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
defer GinkgoRecover()
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID2))
handled <- struct{}{}
})
return h, true
})
packetChan <- packetToRead{data: getPacket(connID1)}
packetChan <- packetToRead{data: getPacket(connID2)}
Eventually(handled).Should(Receive())
Eventually(handled).Should(Receive())
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
It("closes listeners", func() {
packetChan := make(chan packetToRead)
tr := &Transport{Conn: newMockPacketConn(packetChan)}
defer tr.Close()
ln, err := tr.Listen(&tls.Config{}, nil)
Expect(err).ToNot(HaveOccurred())
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
phm.EXPECT().CloseServer()
Expect(ln.Close()).To(Succeed())
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
It("drops unparseable packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
packetChan := make(chan packetToRead)
tracer := mocklogging.NewMockTracer(mockCtrl)
tr := &Transport{
Conn: newMockPacketConn(packetChan),
}
tr.init(&Config{Tracer: tracer, ConnectionIDLength: 10})
dropped := make(chan struct{})
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
packetChan <- packetToRead{
addr: addr,
data: []byte{0, 1, 2, 3},
}
Eventually(dropped).Should(BeClosed())
// shutdown
close(packetChan)
tr.Close()
})
It("closes when reading from the conn fails", func() {
packetChan := make(chan packetToRead)
tr := Transport{Conn: newMockPacketConn(packetChan)}
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{})
tr.handlerMap = phm
done := make(chan struct{})
phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
packetChan <- packetToRead{err: errors.New("read failed")}
Eventually(done).Should(BeClosed())
// shutdown
close(packetChan)
tr.Close()
})
It("continues listening after temporary errors", func() {
packetChan := make(chan packetToRead)
tr := Transport{Conn: newMockPacketConn(packetChan)}
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{})
tr.handlerMap = phm
tempErr := deadlineError{}
Expect(tempErr.Temporary()).To(BeTrue())
packetChan <- packetToRead{err: tempErr}
// don't expect any calls to phm.Close
time.Sleep(50 * time.Millisecond)
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
It("handles short header packets resets", func() {
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
packetChan := make(chan packetToRead)
tr := Transport{Conn: newMockPacketConn(packetChan)}
tr.init(&Config{ConnectionIDLength: connID.Len()})
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{})
tr.handlerMap = phm
var token protocol.StatelessResetToken
rand.Read(token[:])
var b []byte
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred())
b = append(b, token[:]...)
conn := NewMockPacketHandler(mockCtrl)
gomock.InOrder(
phm.EXPECT().GetByResetToken(token),
phm.EXPECT().Get(connID).Return(conn, true),
conn.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.data).To(Equal(b))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
}),
)
packetChan <- packetToRead{data: b}
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
It("handles stateless resets", func() {
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
packetChan := make(chan packetToRead)
tr := Transport{Conn: newMockPacketConn(packetChan)}
tr.init(&Config{ConnectionIDLength: connID.Len()})
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{})
tr.handlerMap = phm
var token protocol.StatelessResetToken
rand.Read(token[:])
var b []byte
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred())
b = append(b, token[:]...)
conn := NewMockPacketHandler(mockCtrl)
gomock.InOrder(
phm.EXPECT().GetByResetToken(token).Return(conn, true),
conn.EXPECT().destroy(gomock.Any()).Do(func(err error) {
Expect(err).To(MatchError(&StatelessResetError{Token: token}))
}),
)
packetChan <- packetToRead{data: b}
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
It("sends stateless resets", func() {
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
packetChan := make(chan packetToRead)
conn := newMockPacketConn(packetChan)
tr := Transport{
Conn: conn,
}
tr.init(&Config{ConnectionIDLength: connID.Len(), StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}})
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{})
tr.handlerMap = phm
var b []byte
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred())
b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...)
var token protocol.StatelessResetToken
rand.Read(token[:])
written := make(chan struct{})
gomock.InOrder(
phm.EXPECT().GetByResetToken(gomock.Any()),
phm.EXPECT().Get(connID),
phm.EXPECT().GetStatelessResetToken(connID).Return(token),
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) {
defer close(written)
Expect(bytes.Contains(b, token[:])).To(BeTrue())
}),
)
packetChan <- packetToRead{data: b}
Eventually(written).Should(BeClosed())
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
})