From d7334c16e7d03fdf692c8e2a3d110687b90a1466 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 31 Aug 2023 13:33:40 +0700 Subject: [PATCH 01/43] move the DisableVersionNegotiationPackets flag to the Transport (#4047) * move the DisableVersionNegotiationPackets flag to the Transport * add an integration test for DisableVersionNegotiationPackets --- config.go | 41 ++++++++-------- config_test.go | 1 - .../versionnegotiation/handshake_test.go | 44 ++++++++++++++++- interface.go | 4 -- server.go | 41 ++++++++-------- server_test.go | 2 +- transport.go | 47 ++++++++++--------- 7 files changed, 112 insertions(+), 68 deletions(-) diff --git a/config.go b/config.go index 59df4cfd..59a4d922 100644 --- a/config.go +++ b/config.go @@ -110,26 +110,25 @@ func populateConfig(config *Config) *Config { } return &Config{ - GetConfigForClient: config.GetConfigForClient, - Versions: versions, - HandshakeIdleTimeout: handshakeIdleTimeout, - MaxIdleTimeout: idleTimeout, - MaxTokenAge: config.MaxTokenAge, - MaxRetryTokenAge: config.MaxRetryTokenAge, - RequireAddressValidation: config.RequireAddressValidation, - KeepAlivePeriod: config.KeepAlivePeriod, - InitialStreamReceiveWindow: initialStreamReceiveWindow, - MaxStreamReceiveWindow: maxStreamReceiveWindow, - InitialConnectionReceiveWindow: initialConnectionReceiveWindow, - MaxConnectionReceiveWindow: maxConnectionReceiveWindow, - AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, - MaxIncomingStreams: maxIncomingStreams, - MaxIncomingUniStreams: maxIncomingUniStreams, - TokenStore: config.TokenStore, - EnableDatagrams: config.EnableDatagrams, - DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, - DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, - Allow0RTT: config.Allow0RTT, - Tracer: config.Tracer, + GetConfigForClient: config.GetConfigForClient, + Versions: versions, + HandshakeIdleTimeout: handshakeIdleTimeout, + MaxIdleTimeout: idleTimeout, + MaxTokenAge: config.MaxTokenAge, + MaxRetryTokenAge: config.MaxRetryTokenAge, + RequireAddressValidation: config.RequireAddressValidation, + KeepAlivePeriod: config.KeepAlivePeriod, + InitialStreamReceiveWindow: initialStreamReceiveWindow, + MaxStreamReceiveWindow: maxStreamReceiveWindow, + InitialConnectionReceiveWindow: initialConnectionReceiveWindow, + MaxConnectionReceiveWindow: maxConnectionReceiveWindow, + AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, + MaxIncomingStreams: maxIncomingStreams, + MaxIncomingUniStreams: maxIncomingUniStreams, + TokenStore: config.TokenStore, + EnableDatagrams: config.EnableDatagrams, + DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, + Allow0RTT: config.Allow0RTT, + Tracer: config.Tracer, } } diff --git a/config_test.go b/config_test.go index 1eca3d5d..7208b4ad 100644 --- a/config_test.go +++ b/config_test.go @@ -192,7 +192,6 @@ var _ = Describe("Config", func() { Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow)) Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams)) Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) - Expect(c.DisableVersionNegotiationPackets).To(BeFalse()) Expect(c.DisablePathMTUDiscovery).To(BeFalse()) Expect(c.GetConfigForClient).To(BeNil()) }) diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go index 965700c1..a079f6e1 100644 --- a/integrationtests/versionnegotiation/handshake_test.go +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -3,8 +3,10 @@ package versionnegotiation import ( "context" "crypto/tls" + "errors" "fmt" "net" + "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/integrationtests/tools/israce" @@ -113,7 +115,7 @@ var _ = Describe("Handshake tests", func() { It("when the client supports more versions than the server supports", func() { expectedVersion := protocol.SupportedVersions[0] - // the server doesn't support the highest supported version, which is the first one the client will try + // The server doesn't support the highest supported version, which is the first one the client will try, // but it supports a bunch of versions that the client doesn't speak serverTracer := &versionNegotiationTracer{} serverConfig := &quic.Config{} @@ -147,5 +149,45 @@ var _ = Describe("Handshake tests", func() { Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) Expect(serverTracer.clientVersions).To(BeEmpty()) }) + + It("fails if the server disables version negotiation", func() { + // The server doesn't support the highest supported version, which is the first one the client will try, + // but it supports a bunch of versions that the client doesn't speak + serverTracer := &versionNegotiationTracer{} + serverConfig := &quic.Config{} + serverConfig.Versions = supportedVersions + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return serverTracer + } + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{ + Conn: conn, + DisableVersionNegotiationPackets: true, + } + ln, err := tr.Listen(getTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} + clientTracer := &versionNegotiationTracer{} + _, err = quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port), + getTLSClientConfig(), + maybeAddQLOGTracer(&quic.Config{ + Versions: clientVersions, + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return clientTracer + }, + HandshakeIdleTimeout: 100 * time.Millisecond, + }), + ) + Expect(err).To(HaveOccurred()) + var nerr net.Error + Expect(errors.As(err, &nerr)).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) + }) } }) diff --git a/interface.go b/interface.go index c3fb2b10..5d5ab5b0 100644 --- a/interface.go +++ b/interface.go @@ -322,10 +322,6 @@ type Config struct { // Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit. // If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. DisablePathMTUDiscovery bool - // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. - // This can be useful if version information is exchanged out-of-band. - // It has no effect for a client. - DisableVersionNegotiationPackets bool // Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted. // Only valid for the server. Allow0RTT bool diff --git a/server.go b/server.go index c06228c9..14cc6f82 100644 --- a/server.go +++ b/server.go @@ -59,7 +59,8 @@ type zeroRTTQueue struct { type baseServer struct { mutex sync.Mutex - acceptEarlyConns bool + disableVersionNegotiation bool + acceptEarlyConns bool tlsConf *tls.Config config *Config @@ -226,6 +227,7 @@ func newServer( config *Config, tracer logging.Tracer, onClose func(), + disableVersionNegotiation bool, acceptEarly bool, ) (*baseServer, error) { tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) @@ -233,23 +235,24 @@ func newServer( return nil, err } s := &baseServer{ - conn: conn, - tlsConf: tlsConf, - config: config, - tokenGenerator: tokenGenerator, - connIDGenerator: connIDGenerator, - connHandler: connHandler, - connQueue: make(chan quicConn), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), - versionNegotiationQueue: make(chan receivedPacket, 4), - invalidTokenQueue: make(chan receivedPacket, 4), - newConn: newConnection, - tracer: tracer, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlyConns: acceptEarly, - onClose: onClose, + conn: conn, + tlsConf: tlsConf, + config: config, + tokenGenerator: tokenGenerator, + connIDGenerator: connIDGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), + versionNegotiationQueue: make(chan receivedPacket, 4), + invalidTokenQueue: make(chan receivedPacket, 4), + newConn: newConnection, + tracer: tracer, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, + disableVersionNegotiation: disableVersionNegotiation, + onClose: onClose, } if acceptEarly { s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} @@ -383,7 +386,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st } // send a Version Negotiation Packet if the client is speaking a different protocol version if !protocol.IsSupportedVersion(s.config.Versions, v) { - if s.config.DisableVersionNegotiationPackets { + if s.disableVersionNegotiation { return false } diff --git a/server_test.go b/server_test.go index 4705225a..e959aee5 100644 --- a/server_test.go +++ b/server_test.go @@ -357,7 +357,7 @@ var _ = Describe("Server", func() { }) It("doesn't send a Version Negotiation packets if sending them is disabled", func() { - serv.config.DisableVersionNegotiationPackets = true + serv.disableVersionNegotiation = true srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) packet := getPacket(&wire.Header{ diff --git a/transport.go b/transport.go index d8da9b1a..9f93cbb9 100644 --- a/transport.go +++ b/transport.go @@ -57,6 +57,11 @@ type Transport struct { // See section 10.3 of RFC 9000 for details. StatelessResetKey *StatelessResetKey + // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. + // This can be useful if version information is exchanged out-of-band. + // It has no effect for clients. + DisableVersionNegotiationPackets bool + // A Tracer traces events that don't belong to a single QUIC connection. Tracer logging.Tracer @@ -95,28 +100,10 @@ type Transport struct { // There can only be a single listener on any net.PacketConn. // Listen may only be called again after the current Listener was closed. func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(conf); err != nil { - return nil, err - } - - t.mutex.Lock() - defer t.mutex.Unlock() - - if t.server != nil { - return nil, errListenerAlreadySet - } - conf = populateServerConfig(conf) - if err := t.init(false); err != nil { - return nil, err - } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false) + s, err := t.createServer(tlsConf, conf, false) if err != nil { return nil, err } - t.server = s return &Listener{baseServer: s}, nil } @@ -124,6 +111,14 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) // There can only be a single listener on any net.PacketConn. // Listen may only be called again after the current Listener was closed. func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { + s, err := t.createServer(tlsConf, conf, true) + if err != nil { + return nil, err + } + return &EarlyListener{baseServer: s}, nil +} + +func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) { if tlsConf == nil { return nil, errors.New("quic: tls.Config not set") } @@ -141,12 +136,22 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen if err := t.init(false); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true) + s, err := newServer( + t.conn, + t.handlerMap, + t.connIDGenerator, + tlsConf, + conf, + t.Tracer, + t.closeServer, + t.DisableVersionNegotiationPackets, + allow0RTT, + ) if err != nil { return nil, err } t.server = s - return &EarlyListener{baseServer: s}, nil + return s, nil } // Dial dials a new connection to a remote host (not using 0-RTT). From 090e505aa9072a4125765e59b0399fc7850aa3d9 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 31 Aug 2023 14:49:27 +0700 Subject: [PATCH 02/43] move GSO control message handling to the oobConn (#4056) * move GSO control message handling to the oobConn * disable OOB test on Windows * improve GSO tests * update ooConn.WritePacket comment --- connection.go | 12 ++++----- connection_test.go | 36 +++++++++++++------------- mock_raw_conn_test.go | 8 +++--- mock_send_conn_test.go | 3 +-- mock_sender_test.go | 3 +-- packet_handler_map.go | 5 ++-- send_conn.go | 40 +++++++++-------------------- send_conn_test.go | 58 ++++++++++++++++++++---------------------- send_queue.go | 15 +++++------ send_queue_test.go | 8 +++--- server.go | 6 ++--- sys_conn.go | 5 +++- sys_conn_oob.go | 10 ++++++-- sys_conn_oob_test.go | 34 +++++++++++++++++++++++++ transport.go | 6 ++--- 15 files changed, 134 insertions(+), 115 deletions(-) diff --git a/connection.go b/connection.go index 877f2d03..eb80cd13 100644 --- a/connection.go +++ b/connection.go @@ -1832,7 +1832,7 @@ func (s *connection) sendPackets(now time.Time) error { } s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, buf.Len()) + s.sendQueue.Send(buf, 0) // This is kind of a hack. We need to trigger sending again somehow. s.pacingDeadline = deadlineSendImmediately return nil @@ -1881,7 +1881,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error { return err } - s.sendQueue.Send(buf, buf.Len()) + s.sendQueue.Send(buf, 0) if s.sendQueue.WouldBlock() { return nil @@ -1938,7 +1938,7 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { continue } - s.sendQueue.Send(buf, maxSize) + s.sendQueue.Send(buf, uint16(maxSize)) if dontSendMore { return nil @@ -1986,7 +1986,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { } s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, buf.Len()) + s.sendQueue.Send(buf, 0) return nil } @@ -2078,7 +2078,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) } s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer, packet.buffer.Len()) + s.sendQueue.Send(packet.buffer, 0) return nil } @@ -2101,7 +2101,7 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { return nil, err } s.logCoalescedPacket(packet) - return packet.buffer.Data, s.conn.Write(packet.buffer.Data, packet.buffer.Len()) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0) } func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { diff --git a/connection_test.go b/connection_test.go index 05a0f20d..eb512513 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1243,7 +1243,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() sent := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(&logging.ShortHeader{ DestConnectionID: p.DestConnID, PacketNumber: p.PacketNumber, @@ -1291,7 +1291,7 @@ var _ = Describe("Connection", func() { conn.connFlowController = fc runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) @@ -1347,7 +1347,7 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) if enc == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) } else { @@ -1372,7 +1372,7 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) if enc == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) } else { @@ -1428,10 +1428,10 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, []byte("packet11")) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) { Expect(b.Data).To(Equal([]byte("packet10"))) }) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) { Expect(b.Data).To(Equal([]byte("packet11"))) }) go func() { @@ -1456,7 +1456,7 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), conn.mtuDiscoverer.CurrentSize()).Do(func(b *packetBuffer, l protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) go func() { @@ -1481,7 +1481,7 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), conn.mtuDiscoverer.CurrentSize()).Do(func(b *packetBuffer, l protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) go func() { @@ -1564,7 +1564,7 @@ var _ = Describe("Connection", func() { ) written := make(chan struct{}, 2) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(2) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(2) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1587,7 +1587,7 @@ var _ = Describe("Connection", func() { } written := make(chan struct{}, 3) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(3) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(3) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1622,7 +1622,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) }) available <- struct{}{} Eventually(written).Should(BeClosed()) }) @@ -1646,7 +1646,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) }) conn.scheduleSending() time.Sleep(scaleDuration(50 * time.Millisecond)) @@ -1661,7 +1661,7 @@ var _ = Describe("Connection", func() { written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1680,7 +1680,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1001}, []byte("packet1001")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) available <- struct{}{} Eventually(written).Should(Receive()) @@ -1713,7 +1713,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) ping := ackhandler.Frame{Frame: &wire.PingFrame{}} mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) @@ -1776,7 +1776,7 @@ var _ = Describe("Connection", func() { time.Sleep(50 * time.Millisecond) // only EXPECT calls after scheduleSending is called written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() conn.scheduleSending() Eventually(written).Should(BeClosed()) @@ -1799,7 +1799,7 @@ var _ = Describe("Connection", func() { conn.receivedPacketHandler = rph written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() @@ -1864,7 +1864,7 @@ var _ = Describe("Connection", func() { ) sent := make(chan struct{}) - mconn.EXPECT().Write([]byte("foobar"), protocol.ByteCount(6)).Do(func([]byte, protocol.ByteCount) { close(sent) }) + mconn.EXPECT().Write([]byte("foobar"), uint16(0)).Do(func([]byte, uint16) { close(sent) }) go func() { defer GinkgoRecover() diff --git a/mock_raw_conn_test.go b/mock_raw_conn_test.go index 0a1a0f3a..84d8f276 100644 --- a/mock_raw_conn_test.go +++ b/mock_raw_conn_test.go @@ -93,18 +93,18 @@ func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Cal } // WritePacket mocks base method. -func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte) (int, error) { +func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // WritePacket indicates an expected call of WritePacket. -func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3) } // capabilities mocks base method. diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index 04df8763..529b9c58 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -8,7 +8,6 @@ import ( net "net" reflect "reflect" - protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -78,7 +77,7 @@ func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call { } // Write mocks base method. -func (m *MockSendConn) Write(arg0 []byte, arg1 protocol.ByteCount) error { +func (m *MockSendConn) Write(arg0 []byte, arg1 uint16) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Write", arg0, arg1) ret0, _ := ret[0].(error) diff --git a/mock_sender_test.go b/mock_sender_test.go index c2a0fa8f..40671562 100644 --- a/mock_sender_test.go +++ b/mock_sender_test.go @@ -7,7 +7,6 @@ package quic import ( reflect "reflect" - protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -75,7 +74,7 @@ func (mr *MockSenderMockRecorder) Run() *gomock.Call { } // Send mocks base method. -func (m *MockSender) Send(arg0 *packetBuffer, arg1 protocol.ByteCount) { +func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16) { m.ctrl.T.Helper() m.ctrl.Call(m, "Send", arg0, arg1) } diff --git a/packet_handler_map.go b/packet_handler_map.go index e0f0567d..60b7cef9 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -27,8 +27,9 @@ type connCapabilities struct { type rawConn interface { ReadPacket() (receivedPacket, error) // WritePacket writes a packet on the wire. - // If GSO is enabled, it's the caller's responsibility to set the correct control message. - WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) + // gsoSize is the size of a single packet, or 0 to disable GSO. + // It is invalid to set gsoSize if capabilities.GSO is not set. + WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer diff --git a/send_conn.go b/send_conn.go index d8ddbc87..272c61b1 100644 --- a/send_conn.go +++ b/send_conn.go @@ -1,17 +1,14 @@ package quic import ( - "fmt" - "math" "net" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) // A sendConn allows sending using a simple Write() on a non-connected packet conn. type sendConn interface { - Write(b []byte, size protocol.ByteCount) error + Write(b []byte, gsoSize uint16) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr @@ -27,8 +24,7 @@ type sconn struct { logger utils.Logger - info packetInfo - oob []byte + packetInfoOOB []byte // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. gotGSOError bool } @@ -51,28 +47,16 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge oob = append(oob, make([]byte, 32)...) oob = oob[:l] return &sconn{ - rawConn: c, - localAddr: localAddr, - remoteAddr: remote, - info: info, - oob: oob, - logger: logger, + rawConn: c, + localAddr: localAddr, + remoteAddr: remote, + packetInfoOOB: oob, + logger: logger, } } -func (c *sconn) Write(p []byte, size protocol.ByteCount) error { - if !c.capabilities().GSO { - if protocol.ByteCount(len(p)) != size { - panic(fmt.Sprintf("inconsistent packet size (%d vs %d)", len(p), size)) - } - _, err := c.WritePacket(p, c.remoteAddr, c.oob) - return err - } - // GSO is supported. Append the control message and send. - if size > math.MaxUint16 { - panic("size overflow") - } - _, err := c.WritePacket(p, c.remoteAddr, appendUDPSegmentSizeMsg(c.oob, uint16(size))) +func (c *sconn) Write(p []byte, gsoSize uint16) error { + _, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize) if err != nil && isGSOError(err) { // disable GSO for future calls c.gotGSOError = true @@ -82,10 +66,10 @@ func (c *sconn) Write(p []byte, size protocol.ByteCount) error { // send out the packets one by one for len(p) > 0 { l := len(p) - if l > int(size) { - l = int(size) + if l > int(gsoSize) { + l = int(gsoSize) } - if _, err := c.WritePacket(p[:l], c.remoteAddr, c.oob); err != nil { + if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0); err != nil { return err } p = p[l:] diff --git a/send_conn_test.go b/send_conn_test.go index 7f072430..963f2482 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -3,6 +3,7 @@ package quic import ( "net" "net/netip" + "runtime" "github.com/quic-go/quic-go/internal/utils" @@ -35,48 +36,43 @@ var _ = Describe("Connection (for sending packets)", func() { Expect(c.LocalAddr().String()).To(Equal("127.0.0.42:1234")) }) - if platformSupportsGSO { - It("writes with GSO", func() { + // We're not using an OOB conn on windows, and packetInfo.OOB() always returns an empty slice. + if runtime.GOOS != "windows" { + It("sets the OOB", func() { rawConn := NewMockRawConn(mockCtrl) rawConn.EXPECT().LocalAddr() - rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() - c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).Do(func(_ []byte, _ net.Addr, oob []byte) { - msg := appendUDPSegmentSizeMsg([]byte{}, 3) - Expect(oob).To(Equal(msg)) - }) - Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + rawConn.EXPECT().capabilities().AnyTimes() + pi := packetInfo{addr: netip.IPv6Loopback()} + Expect(pi.OOB()).ToNot(BeEmpty()) + c := newSendConn(rawConn, remoteAddr, pi, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0)) + Expect(c.Write([]byte("foobar"), 0)).To(Succeed()) }) + } - It("disables GSO if writing fails", func() { + It("writes", func() { + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities().AnyTimes() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3)) + Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + }) + + if platformSupportsGSO { + It("disables GSO if sending fails", func() { rawConn := NewMockRawConn(mockCtrl) rawConn.EXPECT().LocalAddr() rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) Expect(c.capabilities().GSO).To(BeTrue()) gomock.InOrder( - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).DoAndReturn(func(_ []byte, _ net.Addr, oob []byte) (int, error) { - msg := appendUDPSegmentSizeMsg([]byte{}, 3) - Expect(oob).To(Equal(msg)) - return 0, errGSO - }), - rawConn.EXPECT().WritePacket([]byte("foo"), remoteAddr, gomock.Len(0)).Return(3, nil), - rawConn.EXPECT().WritePacket([]byte("bar"), remoteAddr, gomock.Len(0)).Return(3, nil), + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4)).Return(0, errGSO), + rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0)).Return(4, nil), + rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0)).Return(2, nil), ) - Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) - Expect(c.capabilities().GSO).To(BeFalse()) // GSO support is now disabled - // make sure we actually enforce that - Expect(func() { c.Write([]byte("foobar"), 3) }).To(PanicWith("inconsistent packet size (6 vs 3)")) - }) - } else { - It("writes without GSO", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - rawConn := NewMockRawConn(mockCtrl) - rawConn.EXPECT().LocalAddr() - rawConn.EXPECT().capabilities() - c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Len(0)) - Expect(c.Write([]byte("foobar"), 6)).To(Succeed()) + Expect(c.Write([]byte("foobar"), 4)).To(Succeed()) + Expect(c.capabilities().GSO).To(BeFalse()) }) } }) diff --git a/send_queue.go b/send_queue.go index a9f7ca1a..2da546e5 100644 --- a/send_queue.go +++ b/send_queue.go @@ -1,9 +1,8 @@ package quic -import "github.com/quic-go/quic-go/internal/protocol" - type sender interface { - Send(p *packetBuffer, packetSize protocol.ByteCount) + // Send sends a packet. GSO is only used if gsoSize > 0. + Send(p *packetBuffer, gsoSize uint16) Run() error WouldBlock() bool Available() <-chan struct{} @@ -11,8 +10,8 @@ type sender interface { } type queueEntry struct { - buf *packetBuffer - size protocol.ByteCount + buf *packetBuffer + gsoSize uint16 } type sendQueue struct { @@ -40,9 +39,9 @@ func newSendQueue(conn sendConn) sender { // Send sends out a packet. It's guaranteed to not block. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Otherwise Send will panic. -func (h *sendQueue) Send(p *packetBuffer, size protocol.ByteCount) { +func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16) { select { - case h.queue <- queueEntry{buf: p, size: size}: + case h.queue <- queueEntry{buf: p, gsoSize: gsoSize}: // clear available channel if we've reached capacity if len(h.queue) == sendQueueCapacity { select { @@ -77,7 +76,7 @@ func (h *sendQueue) Run() error { // make sure that all queued packets are actually sent out shouldClose = true case e := <-h.queue: - if err := h.conn.Write(e.buf.Data, e.size); err != nil { + if err := h.conn.Write(e.buf.Data, e.gsoSize); err != nil { // This additional check enables: // 1. Checking for "datagram too large" message from the kernel, as such, // 2. Path MTU discovery,and diff --git a/send_queue_test.go b/send_queue_test.go index 69562c58..0ed7bea5 100644 --- a/send_queue_test.go +++ b/send_queue_test.go @@ -3,8 +3,6 @@ package quic import ( "errors" - "github.com/quic-go/quic-go/internal/protocol" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "go.uber.org/mock/gomock" @@ -31,7 +29,7 @@ var _ = Describe("Send Queue", func() { q.Send(p, 10) // make sure the packet size is passed through to the conn written := make(chan struct{}) - c.EXPECT().Write([]byte("foobar"), protocol.ByteCount(10)).Do(func([]byte, protocol.ByteCount) { close(written) }) + c.EXPECT().Write([]byte("foobar"), uint16(10)).Do(func([]byte, uint16) { close(written) }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -79,7 +77,7 @@ var _ = Describe("Send Queue", func() { write := make(chan struct{}, 1) written := make(chan struct{}, 100) // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, protocol.ByteCount) error { + c.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16) error { written <- struct{}{} <-write return nil @@ -149,7 +147,7 @@ var _ = Describe("Send Queue", func() { It("blocks Close() until the packet has been sent out", func() { written := make(chan []byte) - c.EXPECT().Write(gomock.Any(), gomock.Any()).Do(func(p []byte, _ protocol.ByteCount) { written <- p }) + c.EXPECT().Write(gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16) { written <- p }) done := make(chan struct{}) go func() { defer GinkgoRecover() diff --git a/server.go b/server.go index 14cc6f82..495ca65a 100644 --- a/server.go +++ b/server.go @@ -745,7 +745,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0) return err } @@ -844,7 +844,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0) return err } @@ -882,7 +882,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { if s.tracer != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/sys_conn.go b/sys_conn.go index f2224e4c..a72aead5 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -104,7 +104,10 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16) (n int, err error) { + if gsoSize != 0 { + panic("cannot use GSO with a basicConn") + } return c.PacketConn.WriteTo(b, addr) } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 4026a7b3..24b73e9c 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -228,8 +228,14 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { } // WritePacket writes a new packet. -// If the connection supports GSO, it's the caller's responsibility to append the right control mesage. -func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) { +func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) { + oob := packetInfoOOB + if gsoSize > 0 { + if !c.capabilities().GSO { + panic("GSO disabled") + } + oob = appendUDPSegmentSizeMsg(oob, gsoSize) + } n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } diff --git a/sys_conn_oob_test.go b/sys_conn_oob_test.go index 3ae97ed9..54dac82d 100644 --- a/sys_conn_oob_test.go +++ b/sys_conn_oob_test.go @@ -18,6 +18,16 @@ import ( "go.uber.org/mock/gomock" ) +type oobRecordingConn struct { + *net.UDPConn + oobs [][]byte +} + +func (c *oobRecordingConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { + c.oobs = append(c.oobs, oob) + return c.UDPConn.WriteMsgUDP(b, oob, addr) +} + var _ = Describe("OOB Conn Test", func() { runServer := func(network, address string) (*net.UDPConn, <-chan receivedPacket) { addr, err := net.ResolveUDPAddr(network, address) @@ -242,4 +252,28 @@ var _ = Describe("OOB Conn Test", func() { } }) }) + + if platformSupportsGSO { + Context("GSO", func() { + It("appends the GSO control message", func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + c := &oobRecordingConn{UDPConn: udpConn} + oobConn, err := newConn(c, true) + Expect(err).ToNot(HaveOccurred()) + Expect(oobConn.capabilities().GSO).To(BeTrue()) + + oob := make([]byte, 0, 42) + oobConn.WritePacket([]byte("foobar"), addr, oob, 3) + Expect(c.oobs).To(HaveLen(1)) + oobMsg := c.oobs[0] + Expect(oobMsg).ToNot(BeEmpty()) + Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob + expected := appendUDPSegmentSizeMsg([]byte{}, 3) + Expect(oobMsg).To(Equal(expected)) + }) + }) + } }) diff --git a/transport.go b/transport.go index 9f93cbb9..41c97347 100644 --- a/transport.go +++ b/transport.go @@ -228,7 +228,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, addr, nil) + return t.conn.WritePacket(b, addr, nil, 0) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -246,7 +246,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -414,7 +414,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } } From 6cde43785f033f314e9c35fef7e0f56ec628e709 Mon Sep 17 00:00:00 2001 From: Ameagari <47713057+tanghaowillow@users.noreply.github.com> Date: Sat, 2 Sep 2023 10:40:35 +0800 Subject: [PATCH 03/43] integration tests: fix connection timeout in 0-RTT test (#4060) --- integrationtests/self/zero_rtt_oldgo_test.go | 4 ++-- integrationtests/self/zero_rtt_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go index e5809286..3e9277b9 100644 --- a/integrationtests/self/zero_rtt_oldgo_test.go +++ b/integrationtests/self/zero_rtt_oldgo_test.go @@ -852,12 +852,12 @@ var _ = Describe("0-RTT", func() { Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) Expect(receivedMessage).To(Equal(sentMessage)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) Expect(zeroRTTPackets).To(HaveLen(1)) - Expect(conn.CloseWithError(0, "")).To(Succeed()) }) It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { @@ -907,10 +907,10 @@ var _ = Describe("0-RTT", func() { Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) }) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 011687ae..c9bdeff6 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -989,13 +989,13 @@ var _ = Describe("0-RTT", func() { <-received Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(receivedMessage).To(Equal(sentMessage)) num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) Expect(zeroRTTPackets).To(HaveLen(1)) - Expect(conn.CloseWithError(0, "")).To(Succeed()) }) It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { @@ -1047,10 +1047,10 @@ var _ = Describe("0-RTT", func() { Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) }) }) From 96b1943cf58d228891e5298c429de85bdefd1ef4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 4 Sep 2023 11:45:41 +0700 Subject: [PATCH 04/43] ackhandler: rename variables to follow RFC 9002 terminology (#4062) --- internal/ackhandler/interfaces.go | 2 +- .../ackhandler/received_packet_handler.go | 10 ++--- .../ackhandler/received_packet_tracker.go | 44 +++++++++---------- .../received_packet_tracker_test.go | 14 +++--- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index ba95f8a9..9c4a5b80 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -44,7 +44,7 @@ type sentPacketTracker interface { // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool - ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error + ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, ackEliciting bool) error DropPackets(protocol.EncryptionLevel) GetAlarmTimeout() time.Time diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index e11ee1ee..b37f91c5 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -40,29 +40,29 @@ func (h *receivedPacketHandler) ReceivedPacket( ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, - shouldInstigateAck bool, + ackEliciting bool, ) error { h.sentPackets.ReceivedPacket(encLevel) switch encLevel { case protocol.EncryptionInitial: - return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) case protocol.EncryptionHandshake: // The Handshake packet number space might already have been dropped as a result // of processing the CRYPTO frame that was contained in this packet. if h.handshakePackets == nil { return nil } - return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) case protocol.Encryption0RTT: if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) } - return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) case protocol.Encryption1RTT: if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { h.lowest1RTTPacket = pn } - if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck); err != nil { + if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil { return err } h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked()) diff --git a/internal/ackhandler/received_packet_tracker.go b/internal/ackhandler/received_packet_tracker.go index b1883866..0ed929fa 100644 --- a/internal/ackhandler/received_packet_tracker.go +++ b/internal/ackhandler/received_packet_tracker.go @@ -13,10 +13,10 @@ import ( const packetsBeforeAck = 2 type receivedPacketTracker struct { - largestObserved protocol.PacketNumber - ignoreBelow protocol.PacketNumber - largestObservedReceivedTime time.Time - ect0, ect1, ecnce uint64 + largestObserved protocol.PacketNumber + ignoreBelow protocol.PacketNumber + largestObservedRcvdTime time.Time + ect0, ect1, ecnce uint64 packetHistory *receivedPacketHistory @@ -45,22 +45,22 @@ func newReceivedPacketTracker( } } -func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) error { - if isNew := h.packetHistory.ReceivedPacket(packetNumber); !isNew { - return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", packetNumber) +func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error { + if isNew := h.packetHistory.ReceivedPacket(pn); !isNew { + return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn) } - isMissing := h.isMissing(packetNumber) - if packetNumber >= h.largestObserved { - h.largestObserved = packetNumber - h.largestObservedReceivedTime = rcvTime + isMissing := h.isMissing(pn) + if pn >= h.largestObserved { + h.largestObserved = pn + h.largestObservedRcvdTime = rcvTime } - if shouldInstigateAck { + if ackEliciting { h.hasNewAck = true } - if shouldInstigateAck { - h.maybeQueueAck(packetNumber, rcvTime, isMissing) + if ackEliciting { + h.maybeQueueACK(pn, rcvTime, isMissing) } switch ecn { case protocol.ECNNon: @@ -76,14 +76,14 @@ func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumbe // IgnoreBelow sets a lower limit for acknowledging packets. // Packets with packet numbers smaller than p will not be acked. -func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) { - if p <= h.ignoreBelow { +func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { + if pn <= h.ignoreBelow { return } - h.ignoreBelow = p - h.packetHistory.DeleteBelow(p) + h.ignoreBelow = pn + h.packetHistory.DeleteBelow(pn) if h.logger.Debug() { - h.logger.Debugf("\tIgnoring all packets below %d.", p) + h.logger.Debugf("\tIgnoring all packets below %d.", pn) } } @@ -103,8 +103,8 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool { return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1 } -// maybeQueueAck queues an ACK, if necessary. -func (h *receivedPacketTracker) maybeQueueAck(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { +// maybeQueueACK queues an ACK, if necessary. +func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { // always acknowledge the first packet if h.lastAck == nil { if !h.ackQueued { @@ -175,7 +175,7 @@ func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { ack = &wire.AckFrame{} } ack.Reset() - ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime)) + ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedRcvdTime)) ack.ECT0 = h.ect0 ack.ECT1 = h.ect1 ack.ECNCE = h.ecnce diff --git a/internal/ackhandler/received_packet_tracker_test.go b/internal/ackhandler/received_packet_tracker_test.go index 4818bd09..8c76f207 100644 --- a/internal/ackhandler/received_packet_tracker_test.go +++ b/internal/ackhandler/received_packet_tracker_test.go @@ -25,26 +25,26 @@ var _ = Describe("Received Packet Tracker", func() { Context("accepting packets", func() { It("saves the time when each packet arrived", func() { Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, time.Now(), true)).To(Succeed()) - Expect(tracker.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) + Expect(tracker.largestObservedRcvdTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) }) - It("updates the largestObserved and the largestObservedReceivedTime", func() { + It("updates the largestObserved and the largestObservedRcvdTime", func() { now := time.Now() tracker.largestObserved = 3 - tracker.largestObservedReceivedTime = now.Add(-1 * time.Second) + tracker.largestObservedRcvdTime = now.Add(-1 * time.Second) Expect(tracker.ReceivedPacket(5, protocol.ECNNon, now, true)).To(Succeed()) Expect(tracker.largestObserved).To(Equal(protocol.PacketNumber(5))) - Expect(tracker.largestObservedReceivedTime).To(Equal(now)) + Expect(tracker.largestObservedRcvdTime).To(Equal(now)) }) - It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() { + It("doesn't update the largestObserved and the largestObservedRcvdTime for a belated packet", func() { now := time.Now() timestamp := now.Add(-1 * time.Second) tracker.largestObserved = 5 - tracker.largestObservedReceivedTime = timestamp + tracker.largestObservedRcvdTime = timestamp Expect(tracker.ReceivedPacket(4, protocol.ECNNon, now, true)).To(Succeed()) Expect(tracker.largestObserved).To(Equal(protocol.PacketNumber(5))) - Expect(tracker.largestObservedReceivedTime).To(Equal(timestamp)) + Expect(tracker.largestObservedRcvdTime).To(Equal(timestamp)) }) }) From 591d864e5ee1953b181644f15abc5fecda2b9fdc Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 5 Sep 2023 17:47:05 +0700 Subject: [PATCH 05/43] ci: update GitHub checkout and setup-go actions to v4 (#4067) --- .github/workflows/build-interop-docker.yml | 2 +- .github/workflows/cross-compile.yml | 4 ++-- .github/workflows/go-generate.yml | 4 ++-- .github/workflows/integration.yml | 4 ++-- .github/workflows/lint.yml | 8 ++++---- .github/workflows/unit.yml | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/build-interop-docker.yml b/.github/workflows/build-interop-docker.yml index bef6f205..8aaac826 100644 --- a/.github/workflows/build-interop-docker.yml +++ b/.github/workflows/build-interop-docker.yml @@ -10,7 +10,7 @@ jobs: interop: runs-on: ${{ fromJSON(vars['DOCKER_RUNNER_UBUNTU'] || '"ubuntu-latest"') }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up QEMU uses: docker/setup-qemu-action@v2 - name: Set up Docker Buildx diff --git a/.github/workflows/cross-compile.yml b/.github/workflows/cross-compile.yml index 46309b0a..1e8d16b9 100644 --- a/.github/workflows/cross-compile.yml +++ b/.github/workflows/cross-compile.yml @@ -8,8 +8,8 @@ jobs: runs-on: ${{ fromJSON(vars['CROSS_COMPILE_RUNNER_UBUNTU'] || '"ubuntu-latest"') }} name: "Cross Compilation (Go ${{matrix.go}})" steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Install build utils diff --git a/.github/workflows/go-generate.yml b/.github/workflows/go-generate.yml index 441f442f..4add84e3 100644 --- a/.github/workflows/go-generate.yml +++ b/.github/workflows/go-generate.yml @@ -3,8 +3,8 @@ jobs: gogenerate: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 with: go-version: "1.20.x" - name: Install dependencies diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 0e76cba8..90ed611b 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -18,8 +18,8 @@ jobs: TIMESCALE_FACTOR: 3 name: Integration Tests (${{ matrix.os }}, Go ${{ matrix.go }}) steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - run: go version diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index dfde7636..1c74c53b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -4,8 +4,8 @@ jobs: check: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 with: skip-pkg-cache: true go-version: "1.20.x" @@ -25,8 +25,8 @@ jobs: golangci-lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 with: go-version: "1.20.x" - name: golangci-lint (Linux) diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index c9eff032..6dd625d2 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -11,8 +11,8 @@ jobs: runs-on: ${{ fromJSON(vars[format('UNIT_RUNNER_{0}', matrix.os)] || format('"{0}-latest"', matrix.os)) }} name: Unit tests (${{ matrix.os}}, Go ${{ matrix.go }}) steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - run: go version From 6cac231f6a4d7e3241f3409962bd1a7819b4e1da Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 6 Sep 2023 23:02:33 +0700 Subject: [PATCH 06/43] update qtls-go1-20 to v0.3.4 (#4068) --- go.mod | 2 +- go.sum | 4 ++-- integrationtests/gomodvendor/go.sum | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index fac9c9fb..936a3786 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/onsi/ginkgo/v2 v2.9.5 github.com/onsi/gomega v1.27.6 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/qtls-go1-20 v0.3.3 + github.com/quic-go/qtls-go1-20 v0.3.4 go.uber.org/mock v0.2.0 golang.org/x/crypto v0.4.0 golang.org/x/exp v0.0.0-20221205204356-47842c84f3db diff --git a/go.sum b/go.sum index 28cab718..8a3bc655 100644 --- a/go.sum +++ b/go.sum @@ -88,8 +88,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM= -github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg= +github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index e24232c0..5e207944 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -134,8 +134,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM= -github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg= +github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= From dc0369cad423db07420a4e54bcf35d8b4e399bba Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 7 Sep 2023 11:27:03 +0700 Subject: [PATCH 07/43] remove TLS post-handshake message reassembly logic (#4073) Go 1.21.1 was released, which fixed the bug that made this workaround necessary. --- connection.go | 8 ++++---- crypto_stream.go | 25 ++----------------------- crypto_stream_test.go | 22 +--------------------- 3 files changed, 7 insertions(+), 48 deletions(-) diff --git a/connection.go b/connection.go index eb80cd13..42b6c122 100644 --- a/connection.go +++ b/connection.go @@ -243,7 +243,7 @@ var newConnection = func( handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), tokenGenerator: tokenGenerator, - oneRTTStream: newCryptoStream(true), + oneRTTStream: newCryptoStream(), perspective: protocol.PerspectiveServer, tracer: tracer, logger: logger, @@ -391,7 +391,7 @@ var newClientConnection = func( s.logger, ) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - oneRTTStream := newCryptoStream(true) + oneRTTStream := newCryptoStream() params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -447,8 +447,8 @@ var newClientConnection = func( } func (s *connection) preSetup() { - s.initialStream = newCryptoStream(false) - s.handshakeStream = newCryptoStream(false) + s.initialStream = newCryptoStream() + s.handshakeStream = newCryptoStream() s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue() s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) diff --git a/crypto_stream.go b/crypto_stream.go index 5ce2125d..4be2a07a 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -30,17 +30,10 @@ type cryptoStreamImpl struct { writeOffset protocol.ByteCount writeBuf []byte - - // Reassemble TLS handshake messages before returning them from GetCryptoData. - // This is only needed because crypto/tls doesn't correctly handle post-handshake messages. - onlyCompleteMsg bool } -func newCryptoStream(onlyCompleteMsg bool) cryptoStream { - return &cryptoStreamImpl{ - queue: newFrameSorter(), - onlyCompleteMsg: onlyCompleteMsg, - } +func newCryptoStream() cryptoStream { + return &cryptoStreamImpl{queue: newFrameSorter()} } func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { @@ -78,20 +71,6 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // GetCryptoData retrieves data that was received in CRYPTO frames func (s *cryptoStreamImpl) GetCryptoData() []byte { - if s.onlyCompleteMsg { - if len(s.msgBuf) < 4 { - return nil - } - msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) - if len(s.msgBuf) < msgLen { - return nil - } - msg := make([]byte, msgLen) - copy(msg, s.msgBuf[:msgLen]) - s.msgBuf = s.msgBuf[msgLen:] - return msg - } - b := s.msgBuf s.msgBuf = nil return b diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 67de9149..9a4a2ee5 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -1,7 +1,6 @@ package quic import ( - "crypto/rand" "fmt" "github.com/quic-go/quic-go/internal/protocol" @@ -16,7 +15,7 @@ var _ = Describe("Crypto Stream", func() { var str cryptoStream BeforeEach(func() { - str = newCryptoStream(false) + str = newCryptoStream() }) Context("handling incoming data", func() { @@ -138,23 +137,4 @@ var _ = Describe("Crypto Stream", func() { Expect(f.Data).To(Equal([]byte("bar"))) }) }) - - It("reassembles data", func() { - str = newCryptoStream(true) - data := make([]byte, 1337) - l := len(data) - 4 - data[1] = uint8(l >> 16) - data[2] = uint8(l >> 8) - data[3] = uint8(l) - rand.Read(data[4:]) - - for i, b := range data { - Expect(str.GetCryptoData()).To(BeEmpty()) - Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: protocol.ByteCount(i), - Data: []byte{b}, - })).To(Succeed()) - } - Expect(str.GetCryptoData()).To(Equal(data)) - }) }) From 54b76ceb3e91bd1f153c47d30e90891c75d83cb8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Sep 2023 20:12:19 +0700 Subject: [PATCH 08/43] ackhandler: use the receive time of the Retry packet for RTT estimation (#4070) --- connection.go | 6 ++-- connection_test.go | 9 ++++-- internal/ackhandler/interfaces.go | 4 +-- internal/ackhandler/sent_packet_handler.go | 3 +- .../ackhandler/sent_packet_handler_test.go | 28 +++++++++++-------- .../mocks/ackhandler/sent_packet_handler.go | 8 +++--- 6 files changed, 32 insertions(+), 26 deletions(-) diff --git a/connection.go b/connection.go index 42b6c122..1f3dadfd 100644 --- a/connection.go +++ b/connection.go @@ -927,7 +927,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) }() if hdr.Type == protocol.PacketTypeRetry { - return s.handleRetryPacket(hdr, p.data) + return s.handleRetryPacket(hdr, p.data, p.rcvTime) } // The server can change the source connection ID with the first Handshake packet. @@ -1013,7 +1013,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P return false } -func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { +func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) @@ -1062,7 +1062,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa } newDestConnID := hdr.SrcConnectionID s.receivedRetry = true - if err := s.sentPacketHandler.ResetForRetry(); err != nil { + if err := s.sentPacketHandler.ResetForRetry(rcvTime); err != nil { s.closeLocal(err) return false } diff --git a/connection_test.go b/connection_test.go index eb512513..8b7c763f 100644 --- a/connection_test.go +++ b/connection_test.go @@ -2767,9 +2767,10 @@ var _ = Describe("Client Connection", func() { } It("handles Retry packets", func() { + now := time.Now() sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph - sph.EXPECT().ResetForRetry() + sph.EXPECT().ResetForRetry(now) sph.EXPECT().ReceivedBytes(gomock.Any()) cryptoSetup.EXPECT().ChangeConnectionID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})) packer.EXPECT().SetToken([]byte("foobar")) @@ -2778,7 +2779,9 @@ var _ = Describe("Client Connection", func() { Expect(hdr.SrcConnectionID).To(Equal(retryHdr.SrcConnectionID)) Expect(hdr.Token).To(Equal(retryHdr.Token)) }) - Expect(conn.handlePacketImpl(getPacket(retryHdr, getRetryTag(retryHdr)))).To(BeTrue()) + p := getPacket(retryHdr, getRetryTag(retryHdr)) + p.rcvTime = now + Expect(conn.handlePacketImpl(p)).To(BeTrue()) }) It("ignores Retry packets after receiving a regular packet", func() { @@ -3143,7 +3146,7 @@ var _ = Describe("Client Connection", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2) - sph.EXPECT().ResetForRetry() + sph.EXPECT().ResetForRetry(gomock.Any()) newSrcConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) packer.EXPECT().SetToken([]byte("foobar")) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 9c4a5b80..7b4eaa31 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -13,10 +13,10 @@ type SentPacketHandler interface { SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool) // ReceivedAck processes an ACK frame. // It does not store a copy of the frame. - ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) + ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error) ReceivedBytes(protocol.ByteCount) DropPackets(protocol.EncryptionLevel) - ResetForRetry() error + ResetForRetry(rcvTime time.Time) error SetHandshakeConfirmed() // The SendMode determines if and what kind of packets can be sent. diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index c955da6e..a03d9f53 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -824,7 +824,7 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) { p.Frames = nil } -func (h *sentPacketHandler) ResetForRetry() error { +func (h *sentPacketHandler) ResetForRetry(now time.Time) error { h.bytesInFlight = 0 var firstPacketSendTime time.Time h.initialPackets.history.Iterate(func(p *packet) (bool, error) { @@ -850,7 +850,6 @@ func (h *sentPacketHandler) ResetForRetry() error { // Otherwise, we don't know which Initial the Retry was sent in response to. if h.ptoCount == 0 { // Don't set the RTT to a value lower than 5ms here. - now := time.Now() h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index fc120171..396014d7 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -1334,7 +1334,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.bytesInFlight).ToNot(BeZero()) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) // now receive a Retry - Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.ResetForRetry(time.Now())).To(Succeed()) Expect(lostPackets).To(Equal([]protocol.PacketNumber{42})) Expect(handler.bytesInFlight).To(BeZero()) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) @@ -1369,7 +1369,7 @@ var _ = Describe("SentPacketHandler", func() { }) Expect(handler.bytesInFlight).ToNot(BeZero()) // now receive a Retry - Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.ResetForRetry(time.Now())).To(Succeed()) Expect(handler.bytesInFlight).To(BeZero()) Expect(lostInitial).To(BeTrue()) Expect(lost0RTT).To(BeTrue()) @@ -1379,49 +1379,53 @@ var _ = Describe("SentPacketHandler", func() { }) It("uses a Retry for an RTT estimate, if it was not retransmitted", func() { + now := time.Now() sentPacket(ackElicitingPacket(&packet{ PacketNumber: 42, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-500 * time.Millisecond), + SendTime: now, })) sentPacket(ackElicitingPacket(&packet{ PacketNumber: 43, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-10 * time.Millisecond), + SendTime: now.Add(500 * time.Millisecond), })) - Expect(handler.ResetForRetry()).To(Succeed()) - Expect(handler.rttStats.SmoothedRTT()).To(BeNumerically("~", 500*time.Millisecond, 100*time.Millisecond)) + Expect(handler.ResetForRetry(now.Add(time.Second))).To(Succeed()) + Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second)) }) It("uses a Retry for an RTT estimate, but doesn't set the RTT to a value lower than 5ms", func() { + now := time.Now() sentPacket(ackElicitingPacket(&packet{ PacketNumber: 42, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-500 * time.Microsecond), + SendTime: now, })) sentPacket(ackElicitingPacket(&packet{ PacketNumber: 43, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-10 * time.Microsecond), + SendTime: now.Add(2 * time.Millisecond), })) - Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.ResetForRetry(now.Add(4 * time.Millisecond))).To(Succeed()) + Expect(minRTTAfterRetry).To(BeNumerically(">", 4*time.Millisecond)) Expect(handler.rttStats.SmoothedRTT()).To(Equal(minRTTAfterRetry)) }) It("doesn't use a Retry for an RTT estimate, if it was not retransmitted", func() { + now := time.Now() sentPacket(ackElicitingPacket(&packet{ PacketNumber: 42, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-800 * time.Millisecond), + SendTime: now, })) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOInitial)) sentPacket(ackElicitingPacket(&packet{ PacketNumber: 43, EncryptionLevel: protocol.EncryptionInitial, - SendTime: time.Now().Add(-100 * time.Millisecond), + SendTime: now.Add(500 * time.Millisecond), })) - Expect(handler.ResetForRetry()).To(Succeed()) + Expect(handler.ResetForRetry(now.Add(time.Second))).To(Succeed()) Expect(handler.rttStats.SmoothedRTT()).To(BeZero()) }) }) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 80aaae0f..24f5a157 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -148,17 +148,17 @@ func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *go } // ResetForRetry mocks base method. -func (m *MockSentPacketHandler) ResetForRetry() error { +func (m *MockSentPacketHandler) ResetForRetry(arg0 time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ResetForRetry") + ret := m.ctrl.Call(m, "ResetForRetry", arg0) ret0, _ := ret[0].(error) return ret0 } // ResetForRetry indicates an expected call of ResetForRetry. -func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry() *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry), arg0) } // SendMode mocks base method. From e1fcac3e4682b27880b1d0d312cc947a59bd0086 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Sep 2023 20:12:37 +0700 Subject: [PATCH 09/43] set the handshake timeout to twice the handshake idle timeout (#4063) --- config.go | 3 +-- config_test.go | 7 +------ connection_test.go | 4 ++-- integrationtests/self/timeout_test.go | 2 +- interface.go | 3 ++- internal/protocol/params.go | 3 --- 6 files changed, 7 insertions(+), 15 deletions(-) diff --git a/config.go b/config.go index 59a4d922..35b7c897 100644 --- a/config.go +++ b/config.go @@ -6,7 +6,6 @@ import ( "time" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" ) @@ -17,7 +16,7 @@ func (c *Config) Clone() *Config { } func (c *Config) handshakeTimeout() time.Duration { - return utils.Max(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout) + return 2 * c.HandshakeIdleTimeout } func validateConfig(config *Config) error { diff --git a/config_test.go b/config_test.go index 7208b4ad..87386122 100644 --- a/config_test.go +++ b/config_test.go @@ -115,12 +115,7 @@ var _ = Describe("Config", func() { return c } - It("uses 10s handshake timeout for short handshake idle timeouts", func() { - c := &Config{HandshakeIdleTimeout: time.Second} - Expect(c.handshakeTimeout()).To(Equal(protocol.DefaultHandshakeTimeout)) - }) - - It("uses twice the handshake idle timeouts for the handshake timeout, for long handshake idle timeouts", func() { + It("uses twice the handshake idle timeouts for the handshake timeout", func() { c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2} Expect(c.handshakeTimeout()).To(Equal(11 * time.Second)) }) diff --git a/connection_test.go b/connection_test.go index 8b7c763f..4c02c492 100644 --- a/connection_test.go +++ b/connection_test.go @@ -2201,7 +2201,7 @@ var _ = Describe("Connection", func() { It("times out due to non-completed handshake", func() { conn.handshakeComplete = false - conn.creationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) + conn.creationTime = time.Now().Add(-2 * protocol.DefaultHandshakeIdleTimeout).Add(-time.Second) connRunner.EXPECT().Remove(gomock.Any()).Times(2) cryptoSetup.EXPECT().Close() gomock.InOrder( @@ -2261,7 +2261,7 @@ var _ = Describe("Connection", func() { }) It("closes the connection due to the idle timeout before handshake", func() { - conn.config.HandshakeIdleTimeout = 0 + conn.config.HandshakeIdleTimeout = scaleDuration(25 * time.Millisecond) packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index c5ec46d4..47569b77 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -57,7 +57,7 @@ var _ = Describe("Timeout tests", func() { context.Background(), "localhost:12345", getTLSClientConfig(), - getQuicConfig(&quic.Config{HandshakeIdleTimeout: 10 * time.Millisecond}), + getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(50 * time.Millisecond)}), ) errChan <- err }() diff --git a/interface.go b/interface.go index 5d5ab5b0..fadd868d 100644 --- a/interface.go +++ b/interface.go @@ -250,7 +250,8 @@ type Config struct { // If not set, it uses all versions available. Versions []VersionNumber // HandshakeIdleTimeout is the idle timeout before completion of the handshake. - // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. + // If we don't receive any packet from the peer within this time, the connection attempt is aborted. + // Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted. // If this value is zero, the timeout is set to 5 seconds. HandshakeIdleTimeout time.Duration // MaxIdleTimeout is the maximum duration that may pass without any incoming network activity. diff --git a/internal/protocol/params.go b/internal/protocol/params.go index fe3a7562..0f118c49 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -108,9 +108,6 @@ const DefaultIdleTimeout = 30 * time.Second // DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. const DefaultHandshakeIdleTimeout = 5 * time.Second -// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. -const DefaultHandshakeTimeout = 10 * time.Second - // MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. // It should be shorter than the time that NATs clear their mapping. const MaxKeepAliveInterval = 20 * time.Second From abfe1ef548cc407102f0e3a201847c5043abd5d4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 10 Sep 2023 13:53:12 +0700 Subject: [PATCH 10/43] remove Config.MaxRetryTokenAge, set it to the handshake timeout (#4064) There is no good reason to manually set the validity period for Retry tokens. Retry tokens are only valid on a single connection during the handshake, so it makes sense to limit their validity to the configured handshake timeout. --- config.go | 8 ++++---- config_test.go | 2 -- integrationtests/self/handshake_test.go | 20 ++++++++++++++++++-- interface.go | 3 --- internal/protocol/params.go | 3 --- server.go | 2 +- server_test.go | 3 ++- 7 files changed, 25 insertions(+), 16 deletions(-) diff --git a/config.go b/config.go index 35b7c897..dea118b3 100644 --- a/config.go +++ b/config.go @@ -19,6 +19,10 @@ func (c *Config) handshakeTimeout() time.Duration { return 2 * c.HandshakeIdleTimeout } +func (c *Config) maxRetryTokenAge() time.Duration { + return c.handshakeTimeout() +} + func validateConfig(config *Config) error { if config == nil { return nil @@ -52,9 +56,6 @@ func populateServerConfig(config *Config) *Config { if config.MaxTokenAge == 0 { config.MaxTokenAge = protocol.TokenValidity } - if config.MaxRetryTokenAge == 0 { - config.MaxRetryTokenAge = protocol.RetryTokenValidity - } if config.RequireAddressValidation == nil { config.RequireAddressValidation = func(net.Addr) bool { return false } } @@ -114,7 +115,6 @@ func populateConfig(config *Config) *Config { HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, MaxTokenAge: config.MaxTokenAge, - MaxRetryTokenAge: config.MaxRetryTokenAge, RequireAddressValidation: config.RequireAddressValidation, KeepAlivePeriod: config.KeepAlivePeriod, InitialStreamReceiveWindow: initialStreamReceiveWindow, diff --git a/config_test.go b/config_test.go index 87386122..9cb665a3 100644 --- a/config_test.go +++ b/config_test.go @@ -80,8 +80,6 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf(time.Hour)) case "MaxTokenAge": f.Set(reflect.ValueOf(2 * time.Hour)) - case "MaxRetryTokenAge": - f.Set(reflect.ValueOf(2 * time.Minute)) case "TokenStore": f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) case "InitialStreamReceiveWindow": diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 8123c8fe..f90ac2fd 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/quic-go/quic-go" + quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/qtls" @@ -454,16 +455,31 @@ var _ = Describe("Handshake tests", func() { }) It("rejects invalid Retry token with the INVALID_TOKEN error", func() { + const rtt = 10 * time.Millisecond serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } - serverConfig.MaxRetryTokenAge = -time.Second + // The validity period of the retry token is the handshake timeout, + // which is twice the handshake idle timeout. + // By setting the handshake timeout shorter than the RTT, the token will have expired by the time + // it reaches the server. + serverConfig.HandshakeIdleTimeout = rtt / 5 server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) defer server.Close() + serverPort := server.Addr().(*net.UDPAddr).Port + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + _, err = quic.DialAddr( context.Background(), - fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), nil, ) diff --git a/interface.go b/interface.go index fadd868d..8faf0eb8 100644 --- a/interface.go +++ b/interface.go @@ -265,9 +265,6 @@ type Config struct { // See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details. // If not set, every client is forced to prove its remote address. RequireAddressValidation func(net.Addr) bool - // MaxRetryTokenAge is the maximum age of a Retry token. - // If not set, it defaults to 5 seconds. Only valid for a server. - MaxRetryTokenAge time.Duration // MaxTokenAge is the maximum age of the token presented during the handshake, // for tokens that were issued on a previous connection. // If not set, it defaults to 24 hours. Only valid for a server. diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 0f118c49..3ca68bf8 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -65,9 +65,6 @@ const MaxAcceptQueueSize = 32 // TokenValidity is the duration that a (non-retry) token is considered valid const TokenValidity = 24 * time.Hour -// RetryTokenValidity is the duration that a retry token is considered valid -const RetryTokenValidity = 10 * time.Second - // MaxOutstandingSentPackets is maximum number of packets saved for retransmission. // When reached, it imposes a soft limit on sending new packets: // Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. diff --git a/server.go b/server.go index 495ca65a..92b5e91a 100644 --- a/server.go +++ b/server.go @@ -531,7 +531,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge { return false } - if token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxRetryTokenAge { + if token.IsRetryToken && time.Since(token.SentTime) > s.config.maxRetryTokenAge() { return false } return true diff --git a/server_test.go b/server_test.go index e959aee5..31a97764 100644 --- a/server_test.go +++ b/server_test.go @@ -833,7 +833,8 @@ var _ = Describe("Server", func() { It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } - serv.config.MaxRetryTokenAge = time.Millisecond + serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout + Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond)) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) From a7f807856cf74836e6734ea04620679a21fcaf95 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 11 Sep 2023 11:49:29 +0700 Subject: [PATCH 11/43] randomize the serialization order of control frames (#4069) * randomize the serialization order of control frames * add comment for packetPacker.appendPacketPayload --- packet_packer.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/packet_packer.go b/packet_packer.go index 577f3b04..0483f35b 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -1,9 +1,13 @@ package quic import ( + crand "crypto/rand" + "encoding/binary" "errors" "fmt" + "golang.org/x/exp/rand" + "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" @@ -122,6 +126,7 @@ type packetPacker struct { acks ackFrameSource datagramQueue *datagramQueue retransmissionQueue *retransmissionQueue + rand rand.Rand numNonAckElicitingAcks int } @@ -140,6 +145,9 @@ func newPacketPacker( datagramQueue *datagramQueue, perspective protocol.Perspective, ) *packetPacker { + var b [8]byte + _, _ = crand.Read(b[:]) + return &packetPacker{ cryptoSetup: cryptoSetup, getDestConnID: getDestConnID, @@ -151,6 +159,7 @@ func newPacketPacker( perspective: perspective, framer: framer, acks: acks, + rand: *rand.New(rand.NewSource(binary.BigEndian.Uint64(b[:]))), pnManager: packetNumberManager, } } @@ -832,6 +841,8 @@ func (p *packetPacker) appendShortHeaderPacket( }, nil } +// appendPacketPayload serializes the payload of a packet into the raw byte slice. +// It modifies the order of payload.frames. func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { payloadOffset := len(raw) if pl.ack != nil { @@ -844,6 +855,11 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr if paddingLen > 0 { raw = append(raw, make([]byte, paddingLen)...) } + // Randomize the order of the control frames. + // This makes sure that the receiver doesn't rely on the order in which frames are packed. + if len(pl.frames) > 1 { + p.rand.Shuffle(len(pl.frames), func(i, j int) { pl.frames[i], pl.frames[j] = pl.frames[j], pl.frames[i] }) + } for _, f := range pl.frames { var err error raw, err = f.Frame.Append(raw, v) From f919473598e73a57dc7dde987f33b9d2822e8893 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 30 Jul 2023 13:42:15 -0400 Subject: [PATCH 12/43] add support for writing the ECN control message (Linux, macOS) --- send_conn.go | 13 ++++++++++--- sys_conn_df_windows.go | 4 ++++ sys_conn_helper_darwin.go | 2 ++ sys_conn_helper_freebsd.go | 2 ++ sys_conn_helper_linux.go | 2 ++ sys_conn_no_oob.go | 5 +++++ sys_conn_oob.go | 30 ++++++++++++++++++++++++++++++ sys_conn_oob_test.go | 36 ++++++++++++++++++++++++++++++++++++ 8 files changed, 91 insertions(+), 3 deletions(-) diff --git a/send_conn.go b/send_conn.go index 272c61b1..030e0fba 100644 --- a/send_conn.go +++ b/send_conn.go @@ -3,6 +3,7 @@ package quic import ( "net" + "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) @@ -42,10 +43,16 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge } oob := info.OOB() - // add 32 bytes, so we can add the UDP_SEGMENT msg + if remoteUDPAddr, ok := remote.(*net.UDPAddr); ok { + if remoteUDPAddr.IP.To4() != nil { + oob = appendIPv4ECNMsg(oob, protocol.ECT1) + } else { + oob = appendIPv6ECNMsg(oob, protocol.ECT1) + } + } + // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating l := len(oob) - oob = append(oob, make([]byte, 32)...) - oob = oob[:l] + oob = append(oob, make([]byte, 64)...)[:l] return &sconn{ rawConn: c, localAddr: localAddr, diff --git a/sys_conn_df_windows.go b/sys_conn_df_windows.go index e27635ec..e56b7460 100644 --- a/sys_conn_df_windows.go +++ b/sys_conn_df_windows.go @@ -8,6 +8,7 @@ import ( "golang.org/x/sys/windows" + "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) @@ -52,3 +53,6 @@ func isRecvMsgSizeErr(err error) bool { // https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2 return errors.Is(err, windows.WSAEMSGSIZE) } + +func appendIPv4ECNMsg([]byte, protocol.ECN) []byte { return nil } +func appendIPv6ECNMsg([]byte, protocol.ECN) []byte { return nil } diff --git a/sys_conn_helper_darwin.go b/sys_conn_helper_darwin.go index 758cf778..d761072f 100644 --- a/sys_conn_helper_darwin.go +++ b/sys_conn_helper_darwin.go @@ -15,6 +15,8 @@ const ( ipv4PKTINFO = unix.IP_RECVPKTINFO ) +const ecnIPv4DataLen = 4 + // ReadBatch only returns a single packet on OSX, // see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch. const batchSize = 1 diff --git a/sys_conn_helper_freebsd.go b/sys_conn_helper_freebsd.go index a2baae3b..5e635b18 100644 --- a/sys_conn_helper_freebsd.go +++ b/sys_conn_helper_freebsd.go @@ -14,6 +14,8 @@ const ( ipv4PKTINFO = 0x7 ) +const ecnIPv4DataLen = 4 + const batchSize = 8 func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) { diff --git a/sys_conn_helper_linux.go b/sys_conn_helper_linux.go index 6a049241..622f4e6f 100644 --- a/sys_conn_helper_linux.go +++ b/sys_conn_helper_linux.go @@ -19,6 +19,8 @@ const ( ipv4PKTINFO = unix.IP_PKTINFO ) +const ecnIPv4DataLen = 4 + const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error { diff --git a/sys_conn_no_oob.go b/sys_conn_no_oob.go index 2a1f807e..c1fb6fb1 100644 --- a/sys_conn_no_oob.go +++ b/sys_conn_no_oob.go @@ -5,6 +5,8 @@ package quic import ( "net" "net/netip" + + "github.com/quic-go/quic-go/internal/protocol" ) func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) { @@ -14,6 +16,9 @@ func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) { func inspectReadBuffer(any) (int, error) { return 0, nil } func inspectWriteBuffer(any) (int, error) { return 0, nil } +func appendIPv4ECNMsg([]byte, protocol.ECN) []byte { return nil } +func appendIPv6ECNMsg([]byte, protocol.ECN) []byte { return nil } + type packetInfo struct { addr netip.Addr } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 24b73e9c..6446d893 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -11,6 +11,7 @@ import ( "sync" "syscall" "time" + "unsafe" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -279,3 +280,32 @@ func (info *packetInfo) OOB() []byte { } return nil } + +func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte { + startLen := len(b) + b = append(b, make([]byte, unix.CmsgSpace(ecnIPv4DataLen))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) + h.Level = syscall.IPPROTO_IP + h.Type = unix.IP_TOS + h.SetLen(unix.CmsgLen(ecnIPv4DataLen)) + + // UnixRights uses the private `data` method, but I *think* this achieves the same goal. + offset := startLen + unix.CmsgSpace(0) + b[offset] = uint8(val) + return b +} + +func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte { + startLen := len(b) + const dataLen = 4 + b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) + h.Level = syscall.IPPROTO_IPV6 + h.Type = unix.IPV6_TCLASS + h.SetLen(unix.CmsgLen(dataLen)) + + // UnixRights uses the private `data` method, but I *think* this achieves the same goal. + offset := startLen + unix.CmsgSpace(0) + b[offset] = uint8(val) + return b +} diff --git a/sys_conn_oob_test.go b/sys_conn_oob_test.go index 54dac82d..9d0f5e9d 100644 --- a/sys_conn_oob_test.go +++ b/sys_conn_oob_test.go @@ -139,6 +139,42 @@ var _ = Describe("OOB Conn Test", func() { Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse()) Expect(p.ecn).To(Equal(protocol.ECT1)) }) + + It("sends packets with ECN on IPv4", func() { + conn, packetChan := runServer("udp4", "localhost:0") + defer conn.Close() + + c, err := net.ListenUDP("udp4", nil) + Expect(err).ToNot(HaveOccurred()) + defer c.Close() + + for _, val := range []protocol.ECN{protocol.ECNNon, protocol.ECT1, protocol.ECT0, protocol.ECNCE} { + _, _, err = c.WriteMsgUDP([]byte("foobar"), appendIPv4ECNMsg([]byte{}, val), conn.LocalAddr().(*net.UDPAddr)) + Expect(err).ToNot(HaveOccurred()) + var p receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.ecn).To(Equal(val)) + } + }) + + It("sends packets with ECN on IPv6", func() { + conn, packetChan := runServer("udp6", "[::1]:0") + defer conn.Close() + + c, err := net.ListenUDP("udp6", nil) + Expect(err).ToNot(HaveOccurred()) + defer c.Close() + + for _, val := range []protocol.ECN{protocol.ECNNon, protocol.ECT1, protocol.ECT0, protocol.ECNCE} { + _, _, err = c.WriteMsgUDP([]byte("foobar"), appendIPv6ECNMsg([]byte{}, val), conn.LocalAddr().(*net.UDPAddr)) + Expect(err).ToNot(HaveOccurred()) + var p receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.ecn).To(Equal(val)) + } + }) }) Context("Packet Info conn", func() { From 5dd6d91c11eb6ecf871c1a134ececb71c2e73a9e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 12 Aug 2023 10:08:40 +0800 Subject: [PATCH 13/43] send and track packets with ECN markings --- connection.go | 65 +++--- connection_test.go | 200 +++++++++--------- internal/ackhandler/interfaces.go | 4 +- internal/ackhandler/sent_packet_handler.go | 6 + .../ackhandler/sent_packet_handler_test.go | 2 +- .../mocks/ackhandler/sent_packet_handler.go | 22 +- mock_raw_conn_test.go | 9 +- mock_send_conn_test.go | 9 +- mock_sender_test.go | 9 +- packet_handler_map.go | 2 +- send_conn.go | 15 +- send_conn_test.go | 17 +- send_queue.go | 12 +- send_queue_test.go | 36 ++-- server.go | 6 +- sys_conn.go | 2 +- sys_conn_df_windows.go | 4 - sys_conn_no_oob.go | 5 - sys_conn_oob.go | 9 +- sys_conn_oob_test.go | 30 ++- transport.go | 6 +- 21 files changed, 264 insertions(+), 206 deletions(-) diff --git a/connection.go b/connection.go index 1f3dadfd..f234ea3f 100644 --- a/connection.go +++ b/connection.go @@ -1830,9 +1830,10 @@ func (s *connection) sendPackets(now time.Time) error { if err != nil { return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) - s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, 0) + ecn := s.sentPacketHandler.ECNMode() + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) + s.registerPackedShortHeaderPacket(p, ecn, now) + s.sendQueue.Send(buf, 0, ecn) // This is kind of a hack. We need to trigger sending again somehow. s.pacingDeadline = deadlineSendImmediately return nil @@ -1852,7 +1853,7 @@ func (s *connection) sendPackets(now time.Time) error { return err } s.sentFirstPacket = true - if err := s.sendPackedCoalescedPacket(packet, now); err != nil { + if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(), now); err != nil { return err } sendMode := s.sentPacketHandler.SendMode(now) @@ -1873,7 +1874,8 @@ func (s *connection) sendPackets(now time.Time) error { func (s *connection) sendPacketsWithoutGSO(now time.Time) error { for { buf := getPacketBuffer() - if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil { + ecn := s.sentPacketHandler.ECNMode() + if _, err := s.appendOnePacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil { if err == errNothingToPack { buf.Release() return nil @@ -1881,7 +1883,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error { return err } - s.sendQueue.Send(buf, 0) + s.sendQueue.Send(buf, 0, ecn) if s.sendQueue.WouldBlock() { return nil @@ -1908,7 +1910,8 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { for { var dontSendMore bool - size, err := s.appendPacket(buf, maxSize, now) + ecn := s.sentPacketHandler.ECNMode() + size, err := s.appendOnePacket(buf, maxSize, ecn, now) if err != nil { if err != errNothingToPack { return err @@ -1938,7 +1941,7 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { continue } - s.sendQueue.Send(buf, uint16(maxSize)) + s.sendQueue.Send(buf, uint16(maxSize), ecn) if dontSendMore { return nil @@ -1966,6 +1969,7 @@ func (s *connection) resetPacingDeadline() { } func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { + ecn := s.sentPacketHandler.ECNMode() if !s.handshakeConfirmed { packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { @@ -1974,7 +1978,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if packet == nil { return nil } - return s.sendPackedCoalescedPacket(packet, time.Now()) + return s.sendPackedCoalescedPacket(packet, ecn, time.Now()) } p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) @@ -1984,9 +1988,9 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { } return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) - s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, 0) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) + s.registerPackedShortHeaderPacket(p, ecn, now) + s.sendQueue.Send(buf, 0, ecn) return nil } @@ -2018,24 +2022,24 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } - return s.sendPackedCoalescedPacket(packet, now) + return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(), now) } -// appendPacket appends a new packet to the given packetBuffer. +// appendOnePacket appends a new packet to the given packetBuffer. // If there was nothing to pack, the returned size is 0. -func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) { +func (s *connection) appendOnePacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) { startLen := buf.Len() p, err := s.packer.AppendPacket(buf, maxSize, s.version) if err != nil { return 0, err } size := buf.Len() - startLen - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false) - s.registerPackedShortHeaderPacket(p, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, size, false) + s.registerPackedShortHeaderPacket(p, ecn, now) return size, nil } -func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now time.Time) { +func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now time.Time) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) { s.firstAckElicitingPacketAfterIdleSentTime = now } @@ -2044,12 +2048,12 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now ti if p.Ack != nil { largestAcked = p.Ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) + s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket) s.connIDManager.SentPacket() } -func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) error { - s.logCoalescedPacket(packet) +func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN, now time.Time) error { + s.logCoalescedPacket(packet, ecn) for _, p := range packet.longHdrPackets { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { s.firstAckElicitingPacketAfterIdleSentTime = now @@ -2058,7 +2062,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time if p.ack != nil { largestAcked = p.ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false) + s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false) if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake { // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // See Section 4.9.1 of RFC 9001. @@ -2075,10 +2079,10 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time if p.Ack != nil { largestAcked = p.Ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) + s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket) } s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer, 0) + s.sendQueue.Send(packet.buffer, 0, ecn) return nil } @@ -2100,8 +2104,9 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { if err != nil { return nil, err } - s.logCoalescedPacket(packet) - return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0) + ecn := s.sentPacketHandler.ECNMode() + s.logCoalescedPacket(packet, ecn) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn) } func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { @@ -2144,11 +2149,12 @@ func (s *connection) logShortHeaderPacket( pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit, + ecn protocol.ECN, size protocol.ByteCount, isCoalesced bool, ) { if s.logger.Debug() && !isCoalesced { - s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT", pn, size, s.logID) + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn) } // quic-go logging if s.logger.Debug() { @@ -2191,7 +2197,7 @@ func (s *connection) logShortHeaderPacket( } } -func (s *connection) logCoalescedPacket(packet *coalescedPacket) { +func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) { if s.logger.Debug() { // There's a short period between dropping both Initial and Handshake keys and completion of the handshake, // during which we might call PackCoalescedPacket but just pack a short header packet. @@ -2204,6 +2210,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { packet.shortHdrPacket.PacketNumber, packet.shortHdrPacket.PacketNumberLen, packet.shortHdrPacket.KeyPhase, + ecn, packet.shortHdrPacket.Length, false, ) @@ -2219,7 +2226,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { s.logLongHeaderPacket(p) } if p := packet.shortHdrPacket; p != nil { - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true) } } diff --git a/connection_test.go b/connection_test.go index 4c02c492..828f4ced 100644 --- a/connection_test.go +++ b/connection_test.go @@ -454,7 +454,7 @@ var _ = Describe("Connection", func() { Expect(e.ErrorMessage).To(BeEmpty()) return &coalescedPacket{buffer: buffer}, nil }) - mconn.EXPECT().Write([]byte("connection close"), gomock.Any()) + mconn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { var appErr *ApplicationError @@ -475,7 +475,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -494,7 +494,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().Close(), @@ -516,7 +516,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().Close(), @@ -565,7 +565,7 @@ var _ = Describe("Connection", func() { close(returned) }() Consistently(returned).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -609,13 +609,14 @@ var _ = Describe("Connection", func() { conn.handshakeConfirmed = true sconn := NewMockSendConn(mockCtrl) sconn.EXPECT().capabilities().AnyTimes() - sconn.EXPECT().Write(gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() + sconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() conn.sendQueue = newSendQueue(sconn) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes() + sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() // only expect a single SentPacket() call - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -837,7 +838,7 @@ var _ = Describe("Connection", func() { // make the go routine return tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -872,7 +873,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -908,7 +909,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -930,7 +931,7 @@ var _ = Describe("Connection", func() { close(done) }() expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) packet := getShortHeaderPacket(srcConnID, 0x42, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -958,7 +959,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.shutdown() Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -980,7 +981,7 @@ var _ = Describe("Connection", func() { close(done) }() expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil)) @@ -1190,6 +1191,7 @@ var _ = Describe("Connection", func() { var ( connDone chan struct{} sender *MockSender + sph *mockackhandler.MockSentPacketHandler ) BeforeEach(func() { @@ -1198,14 +1200,17 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() conn.sendQueue = sender connDone = make(chan struct{}) + sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) + conn.sentPacketHandler = sph }) AfterEach(func() { streamManager.EXPECT().CloseWithError(gomock.Any()) + sph.EXPECT().ECNMode().Return(protocol.ECNCE).MaxTimes(1) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() @@ -1226,12 +1231,11 @@ var _ = Describe("Connection", func() { It("sends packets", func() { conn.handshakeConfirmed = true - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - conn.sentPacketHandler = sph + sph.EXPECT().ECNMode().Return(protocol.ECNNon).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) runConn() p := shortHeaderPacket{ DestConnID: protocol.ParseConnectionID([]byte{1, 2, 3}), @@ -1243,7 +1247,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() sent := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(&logging.ShortHeader{ DestConnectionID: p.DestConnID, PacketNumber: p.PacketNumber, @@ -1256,6 +1260,9 @@ var _ = Describe("Connection", func() { It("doesn't send packets if there's nothing to send", func() { conn.handshakeConfirmed = true + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() runConn() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) @@ -1264,13 +1271,12 @@ var _ = Describe("Connection", func() { }) It("sends ACK only packets", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) + sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() done := make(chan struct{}) packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) }) - conn.sentPacketHandler = sph runConn() conn.scheduleSending() Eventually(done).Should(BeClosed()) @@ -1278,12 +1284,11 @@ var _ = Describe("Connection", func() { It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { conn.handshakeConfirmed = true - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - conn.sentPacketHandler = sph + sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 13}, []byte("foobar")) @@ -1291,7 +1296,7 @@ var _ = Describe("Connection", func() { conn.connFlowController = fc runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) @@ -1300,11 +1305,9 @@ var _ = Describe("Connection", func() { }) It("doesn't send when the SentPacketHandler doesn't allow it", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone).AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() - conn.sentPacketHandler = sph runConn() conn.scheduleSending() time.Sleep(50 * time.Millisecond) @@ -1333,21 +1336,19 @@ var _ = Describe("Connection", func() { }) It("sends a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().QueueProbePacket(encLevel) + sph.EXPECT().ECNMode() p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { - Expect(pn).To(Equal(protocol.PacketNumber(123))) - }) + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) if enc == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) } else { @@ -1358,21 +1359,18 @@ var _ = Describe("Connection", func() { }) It("sends a PING as a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) + sph.EXPECT().ECNMode() sph.EXPECT().QueueProbePacket(encLevel).Return(false) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { - Expect(pn).To(Equal(protocol.PacketNumber(123))) - }) - conn.sentPacketHandler = sph + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) if enc == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) } else { @@ -1409,10 +1407,11 @@ var _ = Describe("Connection", func() { AfterEach(func() { // make the go routine return + sph.EXPECT().ECNMode().MaxTimes(1) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() @@ -1421,17 +1420,18 @@ var _ = Describe("Connection", func() { }) It("sends multiple packets one by one immediately", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) + sph.EXPECT().ECNMode().Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, []byte("packet11")) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) { + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal([]byte("packet10"))) }) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) { + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal([]byte("packet11"))) }) go func() { @@ -1446,7 +1446,8 @@ var _ = Describe("Connection", func() { It("sends multiple packets one by one immediately, with GSO", func() { enableGSO() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().ECNMode().Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) @@ -1456,7 +1457,7 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) { + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) go func() { @@ -1471,8 +1472,9 @@ var _ = Describe("Connection", func() { It("stops appending packets when a smaller packet is packed, with GSO", func() { enableGSO() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) + sph.EXPECT().ECNMode().Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) @@ -1481,7 +1483,7 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) { + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) go func() { @@ -1495,12 +1497,13 @@ var _ = Describe("Connection", func() { }) It("sends multiple packets, when the pacer allows immediate sending", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) + sph.EXPECT().ECNMode().Times(2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1512,13 +1515,14 @@ var _ = Describe("Connection", func() { }) It("allows an ACK to be sent when pacing limited", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) + sph.EXPECT().ECNMode() packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{PacketNumber: 123}, getPacketBuffer(), nil) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1532,12 +1536,13 @@ var _ = Describe("Connection", func() { // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck // we shouldn't send the ACK in the same run It("doesn't send an ACK right after becoming congestion limited", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) + sph.EXPECT().ECNMode().Times(2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1552,19 +1557,21 @@ var _ = Describe("Connection", func() { pacingDelay := scaleDuration(100 * time.Millisecond) gomock.InOrder( sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), + sph.EXPECT().ECNMode(), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), + sph.EXPECT().ECNMode(), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 101}, []byte("packet101")), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), ) written := make(chan struct{}, 2) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(2) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }).Times(2) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1578,8 +1585,9 @@ var _ = Describe("Connection", func() { }) It("sends multiple packets at once", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) + sph.EXPECT().ECNMode().Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) for pn := protocol.PacketNumber(1000); pn < 1003; pn++ { @@ -1587,7 +1595,7 @@ var _ = Describe("Connection", func() { } written := make(chan struct{}, 3) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(3) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }).Times(3) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1618,11 +1626,12 @@ var _ = Describe("Connection", func() { written := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) available <- struct{}{} Eventually(written).Should(BeClosed()) }) @@ -1639,14 +1648,15 @@ var _ = Describe("Connection", func() { written := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ByteCount, bool) { + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ECN, protocol.ByteCount, bool) { sph.EXPECT().ReceivedBytes(gomock.Any()) conn.handlePacket(receivedPacket{buffer: getPacketBuffer()}) }) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) conn.scheduleSending() time.Sleep(scaleDuration(50 * time.Millisecond)) @@ -1655,13 +1665,14 @@ var _ = Describe("Connection", func() { }) It("stops sending when the send queue is full", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) + sph.EXPECT().ECNMode() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1675,12 +1686,13 @@ var _ = Describe("Connection", func() { time.Sleep(scaleDuration(50 * time.Millisecond)) // now make room in the send queue - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1001}, []byte("packet1001")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }) available <- struct{}{} Eventually(written).Should(Receive()) @@ -1691,6 +1703,7 @@ var _ = Describe("Connection", func() { It("doesn't set a pacing timer when there is no data to send", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) // don't EXPECT any calls to mconn.Write() @@ -1708,12 +1721,13 @@ var _ = Describe("Connection", func() { mtuDiscoverer := NewMockMTUDiscoverer(mockCtrl) conn.mtuDiscoverer = mtuDiscoverer conn.config.DisablePathMTUDiscovery = false - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) + sph.EXPECT().ECNMode() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }) mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) ping := ackhandler.Frame{Frame: &wire.PingFrame{}} mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) @@ -1747,7 +1761,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) sender.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -1760,8 +1774,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1}, []byte("packet1")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) @@ -1776,7 +1791,7 @@ var _ = Describe("Connection", func() { time.Sleep(50 * time.Millisecond) // only EXPECT calls after scheduleSending is called written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() conn.scheduleSending() Eventually(written).Should(BeClosed()) @@ -1788,9 +1803,8 @@ var _ = Describe("Connection", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { - Expect(pn).To(Equal(protocol.PacketNumber(1234))) - }) + sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(1234), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) @@ -1799,7 +1813,7 @@ var _ = Describe("Connection", func() { conn.receivedPacketHandler = rph written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() @@ -1841,18 +1855,11 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes() gomock.InOrder( - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, _ bool) { - Expect(encLevel).To(Equal(protocol.EncryptionInitial)) - Expect(pn).To(Equal(protocol.PacketNumber(13))) - Expect(size).To(BeEquivalentTo(123)) - }), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, _ bool) { - Expect(encLevel).To(Equal(protocol.EncryptionHandshake)) - Expect(pn).To(Equal(protocol.PacketNumber(37))) - Expect(size).To(BeEquivalentTo(1234)) - }), + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(13), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionInitial, protocol.ECT1, protocol.ByteCount(123), gomock.Any()), + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(37), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionHandshake, protocol.ECT1, protocol.ByteCount(1234), gomock.Any()), ) gomock.InOrder( tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { @@ -1864,7 +1871,7 @@ var _ = Describe("Connection", func() { ) sent := make(chan struct{}) - mconn.EXPECT().Write([]byte("foobar"), uint16(0)).Do(func([]byte, uint16) { close(sent) }) + mconn.EXPECT().Write([]byte("foobar"), uint16(0), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(sent) }) go func() { defer GinkgoRecover() @@ -1881,7 +1888,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -1952,7 +1959,7 @@ var _ = Describe("Connection", func() { }() handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("handshake error")) Consistently(handshakeCtx).ShouldNot(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed()) @@ -1961,11 +1968,12 @@ var _ = Describe("Connection", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SetHandshakeConfirmed() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph done := make(chan struct{}) @@ -1987,7 +1995,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handleHandshakeComplete()).To(Succeed()) conn.run() }() @@ -2016,7 +2024,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2043,7 +2051,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() Expect(conn.CloseWithError(0x1337, testErr.Error())).To(Succeed()) @@ -2102,7 +2110,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2255,7 +2263,7 @@ var _ = Describe("Connection", func() { // make the go routine return expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.shutdown() Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2338,7 +2346,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2554,7 +2562,7 @@ var _ = Describe("Client Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()).MaxTimes(1) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2851,7 +2859,7 @@ var _ = Describe("Client Connection", func() { packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) } cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()), tracer.EXPECT().Close(), diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 7b4eaa31..ba224228 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -10,7 +10,7 @@ import ( // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet - SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool) + SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool) // ReceivedAck processes an ACK frame. // It does not store a copy of the frame. ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error) @@ -29,6 +29,8 @@ type SentPacketHandler interface { // only to be called once the handshake is complete QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ + ECNMode() protocol.ECN + PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index a03d9f53..3b972b66 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -228,6 +228,7 @@ func (h *sentPacketHandler) SentPacket( streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, + _ protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool, ) { @@ -712,6 +713,11 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { return h.alarm } +func (h *sentPacketHandler) ECNMode() protocol.ECN { + // TODO: implement ECN logic + return protocol.ECNNon +} + func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { pnSpace := h.getPacketNumberSpace(encLevel) pn := pnSpace.pns.Peek() diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 396014d7..614b7c85 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -106,7 +106,7 @@ var _ = Describe("SentPacketHandler", func() { } sentPacket := func(p *packet) { - handler.SentPacket(p.SendTime, p.PacketNumber, p.LargestAcked, p.StreamFrames, p.Frames, p.EncryptionLevel, p.Length, p.IsPathMTUProbePacket) + handler.SentPacket(p.SendTime, p.PacketNumber, p.LargestAcked, p.StreamFrames, p.Frames, p.EncryptionLevel, protocol.ECNNon, p.Length, p.IsPathMTUProbePacket) } expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 24f5a157..3479d9e6 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -49,6 +49,20 @@ func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) } +// ECNMode mocks base method. +func (m *MockSentPacketHandler) ECNMode() protocol.ECN { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ECNMode") + ret0, _ := ret[0].(protocol.ECN) + return ret0 +} + +// ECNMode indicates an expected call of ECNMode. +func (mr *MockSentPacketHandlerMockRecorder) ECNMode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode)) +} + // GetLossDetectionTimeout mocks base method. func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time { m.ctrl.T.Helper() @@ -176,15 +190,15 @@ func (mr *MockSentPacketHandlerMockRecorder) SendMode(arg0 interface{}) *gomock. } // SentPacket mocks base method. -func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.PacketNumber, arg3 []ackhandler.StreamFrame, arg4 []ackhandler.Frame, arg5 protocol.EncryptionLevel, arg6 protocol.ByteCount, arg7 bool) { +func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.PacketNumber, arg3 []ackhandler.StreamFrame, arg4 []ackhandler.Frame, arg5 protocol.EncryptionLevel, arg6 protocol.ECN, arg7 protocol.ByteCount, arg8 bool) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7) + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } // SentPacket indicates an expected call of SentPacket. -func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } // SetHandshakeConfirmed mocks base method. diff --git a/mock_raw_conn_test.go b/mock_raw_conn_test.go index 84d8f276..d462fc87 100644 --- a/mock_raw_conn_test.go +++ b/mock_raw_conn_test.go @@ -9,6 +9,7 @@ import ( reflect "reflect" time "time" + protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -93,18 +94,18 @@ func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Cal } // WritePacket mocks base method. -func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16) (int, error) { +func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16, arg4 protocol.ECN) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // WritePacket indicates an expected call of WritePacket. -func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3, arg4) } // capabilities mocks base method. diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index 529b9c58..6568f9e5 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -8,6 +8,7 @@ import ( net "net" reflect "reflect" + protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -77,17 +78,17 @@ func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call { } // Write mocks base method. -func (m *MockSendConn) Write(arg0 []byte, arg1 uint16) error { +func (m *MockSendConn) Write(arg0 []byte, arg1 uint16, arg2 protocol.ECN) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0, arg1) + ret := m.ctrl.Call(m, "Write", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // Write indicates an expected call of Write. -func (mr *MockSendConnMockRecorder) Write(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockSendConnMockRecorder) Write(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1, arg2) } // capabilities mocks base method. diff --git a/mock_sender_test.go b/mock_sender_test.go index 40671562..67bb9d09 100644 --- a/mock_sender_test.go +++ b/mock_sender_test.go @@ -7,6 +7,7 @@ package quic import ( reflect "reflect" + protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -74,15 +75,15 @@ func (mr *MockSenderMockRecorder) Run() *gomock.Call { } // Send mocks base method. -func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16) { +func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16, arg2 protocol.ECN) { m.ctrl.T.Helper() - m.ctrl.Call(m, "Send", arg0, arg1) + m.ctrl.Call(m, "Send", arg0, arg1, arg2) } // Send indicates an expected call of Send. -func (mr *MockSenderMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockSenderMockRecorder) Send(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1, arg2) } // WouldBlock mocks base method. diff --git a/packet_handler_map.go b/packet_handler_map.go index 60b7cef9..d7d92377 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -29,7 +29,7 @@ type rawConn interface { // WritePacket writes a packet on the wire. // gsoSize is the size of a single packet, or 0 to disable GSO. // It is invalid to set gsoSize if capabilities.GSO is not set. - WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) + WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer diff --git a/send_conn.go b/send_conn.go index 030e0fba..4fda1469 100644 --- a/send_conn.go +++ b/send_conn.go @@ -9,7 +9,7 @@ import ( // A sendConn allows sending using a simple Write() on a non-connected packet conn. type sendConn interface { - Write(b []byte, gsoSize uint16) error + Write(b []byte, gsoSize uint16, ecn protocol.ECN) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr @@ -43,13 +43,6 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge } oob := info.OOB() - if remoteUDPAddr, ok := remote.(*net.UDPAddr); ok { - if remoteUDPAddr.IP.To4() != nil { - oob = appendIPv4ECNMsg(oob, protocol.ECT1) - } else { - oob = appendIPv6ECNMsg(oob, protocol.ECT1) - } - } // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating l := len(oob) oob = append(oob, make([]byte, 64)...)[:l] @@ -62,8 +55,8 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge } } -func (c *sconn) Write(p []byte, gsoSize uint16) error { - _, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize) +func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { + _, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn) if err != nil && isGSOError(err) { // disable GSO for future calls c.gotGSOError = true @@ -76,7 +69,7 @@ func (c *sconn) Write(p []byte, gsoSize uint16) error { if l > int(gsoSize) { l = int(gsoSize) } - if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0); err != nil { + if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil { return err } p = p[l:] diff --git a/send_conn_test.go b/send_conn_test.go index 963f2482..bbac8fe7 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "runtime" + "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" . "github.com/onsi/ginkgo/v2" @@ -45,8 +46,8 @@ var _ = Describe("Connection (for sending packets)", func() { pi := packetInfo{addr: netip.IPv6Loopback()} Expect(pi.OOB()).ToNot(BeEmpty()) c := newSendConn(rawConn, remoteAddr, pi, utils.DefaultLogger) - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0)) - Expect(c.Write([]byte("foobar"), 0)).To(Succeed()) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0), protocol.ECT1) + Expect(c.Write([]byte("foobar"), 0, protocol.ECT1)).To(Succeed()) }) } @@ -55,8 +56,8 @@ var _ = Describe("Connection (for sending packets)", func() { rawConn.EXPECT().LocalAddr() rawConn.EXPECT().capabilities().AnyTimes() c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3)) - Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3), protocol.ECNCE) + Expect(c.Write([]byte("foobar"), 3, protocol.ECNCE)).To(Succeed()) }) if platformSupportsGSO { @@ -67,11 +68,11 @@ var _ = Describe("Connection (for sending packets)", func() { c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) Expect(c.capabilities().GSO).To(BeTrue()) gomock.InOrder( - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4)).Return(0, errGSO), - rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0)).Return(4, nil), - rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0)).Return(2, nil), + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4), protocol.ECNCE).Return(0, errGSO), + rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(4, nil), + rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(2, nil), ) - Expect(c.Write([]byte("foobar"), 4)).To(Succeed()) + Expect(c.Write([]byte("foobar"), 4, protocol.ECNCE)).To(Succeed()) Expect(c.capabilities().GSO).To(BeFalse()) }) } diff --git a/send_queue.go b/send_queue.go index 2da546e5..bde02334 100644 --- a/send_queue.go +++ b/send_queue.go @@ -1,8 +1,9 @@ package quic +import "github.com/quic-go/quic-go/internal/protocol" + type sender interface { - // Send sends a packet. GSO is only used if gsoSize > 0. - Send(p *packetBuffer, gsoSize uint16) + Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) Run() error WouldBlock() bool Available() <-chan struct{} @@ -12,6 +13,7 @@ type sender interface { type queueEntry struct { buf *packetBuffer gsoSize uint16 + ecn protocol.ECN } type sendQueue struct { @@ -39,9 +41,9 @@ func newSendQueue(conn sendConn) sender { // Send sends out a packet. It's guaranteed to not block. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Otherwise Send will panic. -func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16) { +func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) { select { - case h.queue <- queueEntry{buf: p, gsoSize: gsoSize}: + case h.queue <- queueEntry{buf: p, gsoSize: gsoSize, ecn: ecn}: // clear available channel if we've reached capacity if len(h.queue) == sendQueueCapacity { select { @@ -76,7 +78,7 @@ func (h *sendQueue) Run() error { // make sure that all queued packets are actually sent out shouldClose = true case e := <-h.queue: - if err := h.conn.Write(e.buf.Data, e.gsoSize); err != nil { + if err := h.conn.Write(e.buf.Data, e.gsoSize, e.ecn); err != nil { // This additional check enables: // 1. Checking for "datagram too large" message from the kernel, as such, // 2. Path MTU discovery,and diff --git a/send_queue_test.go b/send_queue_test.go index 0ed7bea5..e8cb8bdc 100644 --- a/send_queue_test.go +++ b/send_queue_test.go @@ -3,6 +3,8 @@ package quic import ( "errors" + "github.com/quic-go/quic-go/internal/protocol" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "go.uber.org/mock/gomock" @@ -26,10 +28,10 @@ var _ = Describe("Send Queue", func() { It("sends a packet", func() { p := getPacket([]byte("foobar")) - q.Send(p, 10) // make sure the packet size is passed through to the conn + q.Send(p, 10, protocol.ECT1) // make sure the packet size is passed through to the conn written := make(chan struct{}) - c.EXPECT().Write([]byte("foobar"), uint16(10)).Do(func([]byte, uint16) { close(written) }) + c.EXPECT().Write([]byte("foobar"), uint16(10), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(written) }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -45,19 +47,19 @@ var _ = Describe("Send Queue", func() { It("panics when Send() is called although there's no space in the queue", func() { for i := 0; i < sendQueueCapacity; i++ { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) } Expect(q.WouldBlock()).To(BeTrue()) - Expect(func() { q.Send(getPacket([]byte("raboof")), 6) }).To(Panic()) + Expect(func() { q.Send(getPacket([]byte("raboof")), 6, protocol.ECNNon) }).To(Panic()) }) It("signals when sending is possible again", func() { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar1")), 6) + q.Send(getPacket([]byte("foobar1")), 6, protocol.ECNNon) Consistently(q.Available()).ShouldNot(Receive()) // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any(), gomock.Any()).MinTimes(1).MaxTimes(2) + c.EXPECT().Write(gomock.Any(), gomock.Any(), protocol.ECNNon).MinTimes(1).MaxTimes(2) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -67,7 +69,7 @@ var _ = Describe("Send Queue", func() { Eventually(q.Available()).Should(Receive()) Expect(q.WouldBlock()).To(BeFalse()) - Expect(func() { q.Send(getPacket([]byte("foobar2")), 7) }).ToNot(Panic()) + Expect(func() { q.Send(getPacket([]byte("foobar2")), 7, protocol.ECNNon) }).ToNot(Panic()) q.Close() Eventually(done).Should(BeClosed()) @@ -77,7 +79,7 @@ var _ = Describe("Send Queue", func() { write := make(chan struct{}, 1) written := make(chan struct{}, 100) // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16) error { + c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16, protocol.ECN) error { written <- struct{}{} <-write return nil @@ -92,19 +94,19 @@ var _ = Describe("Send Queue", func() { close(done) }() - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) <-written // now fill up the send queue for i := 0; i < sendQueueCapacity; i++ { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) } // One more packet is queued when it's picked up by Run and written to the connection. // In this test, it's blocked on write channel in the mocked Write call. <-written Eventually(q.WouldBlock()).Should(BeFalse()) - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) Expect(q.WouldBlock()).To(BeTrue()) Consistently(q.Available()).ShouldNot(Receive()) @@ -130,15 +132,15 @@ var _ = Describe("Send Queue", func() { // the run loop exits if there is a write error testErr := errors.New("test error") - c.EXPECT().Write(gomock.Any(), gomock.Any()).Return(testErr) - q.Send(getPacket([]byte("foobar")), 6) + c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(testErr) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) Eventually(done).Should(BeClosed()) sent := make(chan struct{}) go func() { defer GinkgoRecover() - q.Send(getPacket([]byte("raboof")), 6) - q.Send(getPacket([]byte("quux")), 4) + q.Send(getPacket([]byte("raboof")), 6, protocol.ECNNon) + q.Send(getPacket([]byte("quux")), 4, protocol.ECNNon) close(sent) }() @@ -147,7 +149,7 @@ var _ = Describe("Send Queue", func() { It("blocks Close() until the packet has been sent out", func() { written := make(chan []byte) - c.EXPECT().Write(gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16) { written <- p }) + c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16, _ protocol.ECN) { written <- p }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -155,7 +157,7 @@ var _ = Describe("Send Queue", func() { close(done) }() - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) closed := make(chan struct{}) go func() { diff --git a/server.go b/server.go index 92b5e91a..681e8956 100644 --- a/server.go +++ b/server.go @@ -745,7 +745,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon) return err } @@ -844,7 +844,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon) return err } @@ -882,7 +882,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { if s.tracer != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/sys_conn.go b/sys_conn.go index a72aead5..4d7a6577 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -104,7 +104,7 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, _ protocol.ECN) (n int, err error) { if gsoSize != 0 { panic("cannot use GSO with a basicConn") } diff --git a/sys_conn_df_windows.go b/sys_conn_df_windows.go index e56b7460..e27635ec 100644 --- a/sys_conn_df_windows.go +++ b/sys_conn_df_windows.go @@ -8,7 +8,6 @@ import ( "golang.org/x/sys/windows" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) @@ -53,6 +52,3 @@ func isRecvMsgSizeErr(err error) bool { // https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2 return errors.Is(err, windows.WSAEMSGSIZE) } - -func appendIPv4ECNMsg([]byte, protocol.ECN) []byte { return nil } -func appendIPv6ECNMsg([]byte, protocol.ECN) []byte { return nil } diff --git a/sys_conn_no_oob.go b/sys_conn_no_oob.go index c1fb6fb1..2a1f807e 100644 --- a/sys_conn_no_oob.go +++ b/sys_conn_no_oob.go @@ -5,8 +5,6 @@ package quic import ( "net" "net/netip" - - "github.com/quic-go/quic-go/internal/protocol" ) func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) { @@ -16,9 +14,6 @@ func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) { func inspectReadBuffer(any) (int, error) { return 0, nil } func inspectWriteBuffer(any) (int, error) { return 0, nil } -func appendIPv4ECNMsg([]byte, protocol.ECN) []byte { return nil } -func appendIPv6ECNMsg([]byte, protocol.ECN) []byte { return nil } - type packetInfo struct { addr netip.Addr } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 6446d893..7b6052ba 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -229,7 +229,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { } // WritePacket writes a new packet. -func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) { +func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) { oob := packetInfoOOB if gsoSize > 0 { if !c.capabilities().GSO { @@ -237,6 +237,13 @@ func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gso } oob = appendUDPSegmentSizeMsg(oob, gsoSize) } + if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok { + if remoteUDPAddr.IP.To4() != nil { + oob = appendIPv4ECNMsg(oob, ecn) + } else { + oob = appendIPv6ECNMsg(oob, ecn) + } + } n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } diff --git a/sys_conn_oob_test.go b/sys_conn_oob_test.go index 9d0f5e9d..40596421 100644 --- a/sys_conn_oob_test.go +++ b/sys_conn_oob_test.go @@ -53,7 +53,7 @@ var _ = Describe("OOB Conn Test", func() { return udpConn, packetChan } - Context("ECN conn", func() { + Context("reading ECN-marked packets", func() { sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr { conn, err := net.DialUDP(network, nil, addr) ExpectWithOffset(1, err).ToNot(HaveOccurred()) @@ -289,6 +289,27 @@ var _ = Describe("OOB Conn Test", func() { }) }) + Context("sending ECN-marked packets", func() { + It("sets the ECN control message", func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + c := &oobRecordingConn{UDPConn: udpConn} + oobConn, err := newConn(c, true) + Expect(err).ToNot(HaveOccurred()) + + oob := make([]byte, 0, 123) + oobConn.WritePacket([]byte("foobar"), addr, oob, 0, protocol.ECNCE) + Expect(c.oobs).To(HaveLen(1)) + oobMsg := c.oobs[0] + Expect(oobMsg).ToNot(BeEmpty()) + Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob + expected := appendIPv4ECNMsg([]byte{}, protocol.ECNCE) + Expect(oobMsg).To(Equal(expected)) + }) + }) + if platformSupportsGSO { Context("GSO", func() { It("appends the GSO control message", func() { @@ -301,14 +322,15 @@ var _ = Describe("OOB Conn Test", func() { Expect(err).ToNot(HaveOccurred()) Expect(oobConn.capabilities().GSO).To(BeTrue()) - oob := make([]byte, 0, 42) - oobConn.WritePacket([]byte("foobar"), addr, oob, 3) + oob := make([]byte, 0, 123) + oobConn.WritePacket([]byte("foobar"), addr, oob, 3, protocol.ECNCE) Expect(c.oobs).To(HaveLen(1)) oobMsg := c.oobs[0] Expect(oobMsg).ToNot(BeEmpty()) Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob expected := appendUDPSegmentSizeMsg([]byte{}, 3) - Expect(oobMsg).To(Equal(expected)) + // Check that the first control message is the OOB control message. + Expect(oobMsg[:len(expected)]).To(Equal(expected)) }) }) } diff --git a/transport.go b/transport.go index 41c97347..b841222e 100644 --- a/transport.go +++ b/transport.go @@ -228,7 +228,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, addr, nil, 0) + return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNNon) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -246,7 +246,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNNon) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -414,7 +414,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } } From b73a4de7eaa87773d9da570442752470b051ace6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 31 Aug 2023 13:06:37 +0700 Subject: [PATCH 14/43] only add an ECN control message if ECN is supported --- connection.go | 4 +- internal/ackhandler/ackhandler.go | 3 +- .../ackhandler/received_packet_tracker.go | 2 +- internal/ackhandler/sent_packet_handler.go | 7 ++++ .../ackhandler/sent_packet_handler_test.go | 4 +- internal/protocol/protocol.go | 42 +++++++++++++++++-- internal/protocol/protocol_test.go | 17 ++++++-- packet_handler_map.go | 2 + server.go | 6 +-- sys_conn.go | 5 ++- sys_conn_oob.go | 21 ++++++---- transport.go | 6 +-- 12 files changed, 90 insertions(+), 29 deletions(-) diff --git a/connection.go b/connection.go index f234ea3f..1920bb2a 100644 --- a/connection.go +++ b/connection.go @@ -278,6 +278,7 @@ var newConnection = func( getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, clientAddressValidated, + s.conn.capabilities().ECN, s.perspective, s.tracer, s.logger, @@ -385,7 +386,8 @@ var newClientConnection = func( initialPacketNumber, getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, - false, /* has no effect */ + false, // has no effect + s.conn.capabilities().ECN, s.perspective, s.tracer, s.logger, diff --git a/internal/ackhandler/ackhandler.go b/internal/ackhandler/ackhandler.go index 2c7cc4fc..036ed5ea 100644 --- a/internal/ackhandler/ackhandler.go +++ b/internal/ackhandler/ackhandler.go @@ -14,10 +14,11 @@ func NewAckHandler( initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, clientAddressValidated bool, + enableECN bool, pers protocol.Perspective, tracer logging.ConnectionTracer, logger utils.Logger, ) (SentPacketHandler, ReceivedPacketHandler) { - sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger) + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) return sph, newReceivedPacketHandler(sph, rttStats, logger) } diff --git a/internal/ackhandler/received_packet_tracker.go b/internal/ackhandler/received_packet_tracker.go index 0ed929fa..7fd071e6 100644 --- a/internal/ackhandler/received_packet_tracker.go +++ b/internal/ackhandler/received_packet_tracker.go @@ -62,8 +62,8 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro if ackEliciting { h.maybeQueueACK(pn, rcvTime, isMissing) } + //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECNCE. switch ecn { - case protocol.ECNNon: case protocol.ECT0: h.ect0++ case protocol.ECT1: diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 3b972b66..da5d0718 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -92,6 +92,8 @@ type sentPacketHandler struct { // The alarm timeout alarm time.Time + enableECN bool + perspective protocol.Perspective tracer logging.ConnectionTracer @@ -110,6 +112,7 @@ func newSentPacketHandler( initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, clientAddressValidated bool, + enableECN bool, pers protocol.Perspective, tracer logging.ConnectionTracer, logger utils.Logger, @@ -130,6 +133,7 @@ func newSentPacketHandler( appDataPackets: newPacketNumberSpace(0, true), rttStats: rttStats, congestion: congestion, + enableECN: enableECN, perspective: pers, tracer: tracer, logger: logger, @@ -714,6 +718,9 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { } func (h *sentPacketHandler) ECNMode() protocol.ECN { + if !h.enableECN { + return protocol.ECNUnsupported + } // TODO: implement ECN logic return protocol.ECNNon } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 614b7c85..a1e1f0d9 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -44,7 +44,7 @@ var _ = Describe("SentPacketHandler", func() { JustBeforeEach(func() { lostPackets = nil rttStats := utils.NewRTTStats() - handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, perspective, nil, utils.DefaultLogger) + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, false, perspective, nil, utils.DefaultLogger) streamFrame = wire.StreamFrame{ StreamID: 5, Data: []byte{0x13, 0x37}, @@ -984,7 +984,7 @@ var _ = Describe("SentPacketHandler", func() { Context("amplification limit, for the server, with validated address", func() { JustBeforeEach(func() { rttStats := utils.NewRTTStats() - handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, perspective, nil, utils.DefaultLogger) + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, false, perspective, nil, utils.DefaultLogger) }) It("do not limits the window", func() { diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index b8104882..d056cb9d 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -37,14 +37,48 @@ func (t PacketType) String() string { type ECN uint8 const ( - ECNNon ECN = iota // 00 - ECT1 // 01 - ECT0 // 10 - ECNCE // 11 + ECNUnsupported ECN = iota + ECNNon // 00 + ECT1 // 01 + ECT0 // 10 + ECNCE // 11 ) +func ParseECNHeaderBits(bits byte) ECN { + switch bits { + case 0: + return ECNNon + case 0b00000010: + return ECT0 + case 0b00000001: + return ECT1 + case 0b00000011: + return ECNCE + default: + panic("invalid ECN bits") + } +} + +func (e ECN) ToHeaderBits() byte { + //nolint:exhaustive // There are only 4 values. + switch e { + case ECNNon: + return 0 + case ECT0: + return 0b00000010 + case ECT1: + return 0b00000001 + case ECNCE: + return 0b00000011 + default: + panic("ECN unsupported") + } +} + func (e ECN) String() string { switch e { + case ECNUnsupported: + return "ECN unsupported" case ECNNon: return "Not-ECT" case ECT1: diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index e672d31e..22359e6a 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -17,13 +17,22 @@ var _ = Describe("Protocol", func() { }) It("converts ECN bits from the IP header wire to the correct types", func() { - Expect(ECN(0)).To(Equal(ECNNon)) - Expect(ECN(0b00000010)).To(Equal(ECT0)) - Expect(ECN(0b00000001)).To(Equal(ECT1)) - Expect(ECN(0b00000011)).To(Equal(ECNCE)) + Expect(ParseECNHeaderBits(0)).To(Equal(ECNNon)) + Expect(ParseECNHeaderBits(0b00000010)).To(Equal(ECT0)) + Expect(ParseECNHeaderBits(0b00000001)).To(Equal(ECT1)) + Expect(ParseECNHeaderBits(0b00000011)).To(Equal(ECNCE)) + Expect(func() { ParseECNHeaderBits(0b1010101) }).To(Panic()) + }) + + It("converts to IP header bits", func() { + for _, v := range [...]ECN{ECNNon, ECT0, ECT1, ECNCE} { + Expect(ParseECNHeaderBits(v.ToHeaderBits())).To(Equal(v)) + } + Expect(func() { ECN(42).ToHeaderBits() }).To(Panic()) }) It("has a string representation for ECN", func() { + Expect(ECNUnsupported.String()).To(Equal("ECN unsupported")) Expect(ECNNon.String()).To(Equal("Not-ECT")) Expect(ECT0.String()).To(Equal("ECT(0)")) Expect(ECT1.String()).To(Equal("ECT(1)")) diff --git a/packet_handler_map.go b/packet_handler_map.go index d7d92377..ba30fb23 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -21,6 +21,8 @@ type connCapabilities struct { DF bool // GSO (Generic Segmentation Offload) supported GSO bool + // ECN (Explicit Congestion Notifications) supported + ECN bool } // rawConn is a connection that allow reading of a receivedPackeh. diff --git a/server.go b/server.go index 681e8956..d41c6465 100644 --- a/server.go +++ b/server.go @@ -745,7 +745,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) return err } @@ -844,7 +844,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) return err } @@ -882,7 +882,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { if s.tracer != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/sys_conn.go b/sys_conn.go index 4d7a6577..71cc4607 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -104,10 +104,13 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, _ protocol.ECN) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, ecn protocol.ECN) (n int, err error) { if gsoSize != 0 { panic("cannot use GSO with a basicConn") } + if ecn != protocol.ECNUnsupported { + panic("cannot use ECN with a basicConn") + } return c.PacketConn.WriteTo(b, addr) } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 7b6052ba..b9ddf3c1 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -141,6 +141,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { cap: connCapabilities{ DF: supportsDF, GSO: isGSOSupported(rawConn), + ECN: true, }, } for i := 0; i < batchSize; i++ { @@ -189,7 +190,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { if hdr.Level == unix.IPPROTO_IP { switch hdr.Type { case msgTypeIPTOS: - p.ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask) case ipv4PKTINFO: ip, ifIndex, ok := parseIPv4PktInfo(body) if ok { @@ -206,7 +207,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { if hdr.Level == unix.IPPROTO_IPV6 { switch hdr.Type { case unix.IPV6_TCLASS: - p.ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask) case unix.IPV6_PKTINFO: // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ @@ -237,11 +238,13 @@ func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gso } oob = appendUDPSegmentSizeMsg(oob, gsoSize) } - if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok { - if remoteUDPAddr.IP.To4() != nil { - oob = appendIPv4ECNMsg(oob, ecn) - } else { - oob = appendIPv6ECNMsg(oob, ecn) + if ecn != protocol.ECNUnsupported { + if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok { + if remoteUDPAddr.IP.To4() != nil { + oob = appendIPv4ECNMsg(oob, ecn) + } else { + oob = appendIPv6ECNMsg(oob, ecn) + } } } n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) @@ -298,7 +301,7 @@ func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte { // UnixRights uses the private `data` method, but I *think* this achieves the same goal. offset := startLen + unix.CmsgSpace(0) - b[offset] = uint8(val) + b[offset] = val.ToHeaderBits() return b } @@ -313,6 +316,6 @@ func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte { // UnixRights uses the private `data` method, but I *think* this achieves the same goal. offset := startLen + unix.CmsgSpace(0) - b[offset] = uint8(val) + b[offset] = val.ToHeaderBits() return b } diff --git a/transport.go b/transport.go index b841222e..c021be77 100644 --- a/transport.go +++ b/transport.go @@ -228,7 +228,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNNon) + return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -246,7 +246,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNNon) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -414,7 +414,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } } From 8df7624c079106f0fb0e238ccf4a91a007d47504 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 1 Sep 2023 09:18:14 +0700 Subject: [PATCH 15/43] add a QUIC_GO_DISABLE_ECN env to disable ECN support --- .github/workflows/integration.yml | 5 +++++ sys_conn_oob.go | 12 +++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 90ed611b..c661e09c 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -41,6 +41,11 @@ jobs: env: QUIC_GO_DISABLE_GSO: true run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=1 ${{ env.QLOGFLAG }} + - name: Run self tests, with ECN disabled + if: ${{ matrix.os == 'ubuntu' && (success() || failure()) }} # run this step even if the previous one failed + env: + QUIC_GO_DISABLE_ECN: true + run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=1 ${{ env.QLOGFLAG }} - name: Run tests (32 bit) if: ${{ matrix.os != 'macos' && (success() || failure()) }} # run this step even if the previous one failed env: diff --git a/sys_conn_oob.go b/sys_conn_oob.go index b9ddf3c1..71a037d5 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -8,6 +8,8 @@ import ( "log" "net" "net/netip" + "os" + "strconv" "sync" "syscall" "time" @@ -57,6 +59,11 @@ func inspectWriteBuffer(c syscall.RawConn) (int, error) { return size, serr } +func isECNDisabled() bool { + disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_ECN")) + return err == nil && disabled +} + type oobConn struct { OOBCapablePacketConn batchConn batchConn @@ -141,7 +148,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { cap: connCapabilities{ DF: supportsDF, GSO: isGSOSupported(rawConn), - ECN: true, + ECN: !isECNDisabled(), }, } for i := 0; i < batchSize; i++ { @@ -239,6 +246,9 @@ func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gso oob = appendUDPSegmentSizeMsg(oob, gsoSize) } if ecn != protocol.ECNUnsupported { + if !c.capabilities().ECN { + panic("tried to send a ECN-marked packet although ECN is disabled") + } if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok { if remoteUDPAddr.IP.To4() != nil { oob = appendIPv4ECNMsg(oob, ecn) From bed8ebbd4ca413d2a16d3a89439894813ee02a4c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 1 Sep 2023 09:50:52 +0700 Subject: [PATCH 16/43] distinguish coalesced and 1-RTT packets when determining ECN mode --- connection.go | 23 ++++---- connection_test.go | 56 +++++++++---------- internal/ackhandler/interfaces.go | 3 +- internal/ackhandler/sent_packet_handler.go | 5 +- .../mocks/ackhandler/sent_packet_handler.go | 8 +-- packet_packer.go | 5 ++ packet_packer_test.go | 5 ++ 7 files changed, 59 insertions(+), 46 deletions(-) diff --git a/connection.go b/connection.go index 1920bb2a..d4e44be6 100644 --- a/connection.go +++ b/connection.go @@ -1832,7 +1832,7 @@ func (s *connection) sendPackets(now time.Time) error { if err != nil { return err } - ecn := s.sentPacketHandler.ECNMode() + ecn := s.sentPacketHandler.ECNMode(true) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) s.registerPackedShortHeaderPacket(p, ecn, now) s.sendQueue.Send(buf, 0, ecn) @@ -1855,7 +1855,7 @@ func (s *connection) sendPackets(now time.Time) error { return err } s.sentFirstPacket = true - if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(), now); err != nil { + if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now); err != nil { return err } sendMode := s.sentPacketHandler.SendMode(now) @@ -1876,8 +1876,8 @@ func (s *connection) sendPackets(now time.Time) error { func (s *connection) sendPacketsWithoutGSO(now time.Time) error { for { buf := getPacketBuffer() - ecn := s.sentPacketHandler.ECNMode() - if _, err := s.appendOnePacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil { + ecn := s.sentPacketHandler.ECNMode(true) + if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil { if err == errNothingToPack { buf.Release() return nil @@ -1912,8 +1912,8 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { for { var dontSendMore bool - ecn := s.sentPacketHandler.ECNMode() - size, err := s.appendOnePacket(buf, maxSize, ecn, now) + ecn := s.sentPacketHandler.ECNMode(true) + size, err := s.appendOneShortHeaderPacket(buf, maxSize, ecn, now) if err != nil { if err != errNothingToPack { return err @@ -1971,8 +1971,8 @@ func (s *connection) resetPacingDeadline() { } func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { - ecn := s.sentPacketHandler.ECNMode() if !s.handshakeConfirmed { + ecn := s.sentPacketHandler.ECNMode(false) packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { return err @@ -1983,6 +1983,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { return s.sendPackedCoalescedPacket(packet, ecn, time.Now()) } + ecn := s.sentPacketHandler.ECNMode(true) p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { if err == errNothingToPack { @@ -2024,12 +2025,12 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } - return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(), now) + return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now) } -// appendOnePacket appends a new packet to the given packetBuffer. +// appendOneShortHeaderPacket appends a new packet to the given packetBuffer. // If there was nothing to pack, the returned size is 0. -func (s *connection) appendOnePacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) { +func (s *connection) appendOneShortHeaderPacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) { startLen := buf.Len() p, err := s.packer.AppendPacket(buf, maxSize, s.version) if err != nil { @@ -2106,7 +2107,7 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { if err != nil { return nil, err } - ecn := s.sentPacketHandler.ECNMode() + ecn := s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()) s.logCoalescedPacket(packet, ecn) return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn) } diff --git a/connection_test.go b/connection_test.go index 828f4ced..c4e067af 100644 --- a/connection_test.go +++ b/connection_test.go @@ -613,7 +613,7 @@ var _ = Describe("Connection", func() { conn.sendQueue = newSendQueue(sconn) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes() - sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() + sph.EXPECT().ECNMode(true).Return(protocol.ECT1).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() // only expect a single SentPacket() call sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) @@ -1206,7 +1206,7 @@ var _ = Describe("Connection", func() { AfterEach(func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - sph.EXPECT().ECNMode().Return(protocol.ECNCE).MaxTimes(1) + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECNCE).MaxTimes(1) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -1234,7 +1234,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().Return(protocol.ECNNon).AnyTimes() + sph.EXPECT().ECNMode(true).Return(protocol.ECNNon).AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) runConn() p := shortHeaderPacket{ @@ -1262,7 +1262,7 @@ var _ = Describe("Connection", func() { conn.handshakeConfirmed = true sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(true).AnyTimes() runConn() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) @@ -1274,7 +1274,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) - sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() done := make(chan struct{}) packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) }) runConn() @@ -1287,7 +1287,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) @@ -1341,7 +1341,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().QueueProbePacket(encLevel) - sph.EXPECT().ECNMode() + sph.EXPECT().ECNMode(gomock.Any()) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) @@ -1363,7 +1363,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) - sph.EXPECT().ECNMode() + sph.EXPECT().ECNMode(gomock.Any()) sph.EXPECT().QueueProbePacket(encLevel).Return(false) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) @@ -1407,7 +1407,7 @@ var _ = Describe("Connection", func() { AfterEach(func() { // make the go routine return - sph.EXPECT().ECNMode().MaxTimes(1) + sph.EXPECT().ECNMode(gomock.Any()).MaxTimes(1) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -1422,7 +1422,7 @@ var _ = Describe("Connection", func() { It("sends multiple packets one by one immediately", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) - sph.EXPECT().ECNMode().Times(2) + sph.EXPECT().ECNMode(gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) @@ -1447,7 +1447,7 @@ var _ = Describe("Connection", func() { It("sends multiple packets one by one immediately, with GSO", func() { enableGSO() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) - sph.EXPECT().ECNMode().Times(3) + sph.EXPECT().ECNMode(true).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) @@ -1474,7 +1474,7 @@ var _ = Describe("Connection", func() { enableGSO() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) - sph.EXPECT().ECNMode().Times(2) + sph.EXPECT().ECNMode(true).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) @@ -1499,7 +1499,7 @@ var _ = Describe("Connection", func() { It("sends multiple packets, when the pacer allows immediate sending", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) - sph.EXPECT().ECNMode().Times(2) + sph.EXPECT().ECNMode(gomock.Any()).Times(2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() @@ -1518,7 +1518,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) - sph.EXPECT().ECNMode() + sph.EXPECT().ECNMode(gomock.Any()) packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{PacketNumber: 123}, getPacketBuffer(), nil) sender.EXPECT().WouldBlock().AnyTimes() @@ -1539,7 +1539,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) - sph.EXPECT().ECNMode().Times(2) + sph.EXPECT().ECNMode(gomock.Any()).Times(2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) @@ -1557,13 +1557,13 @@ var _ = Describe("Connection", func() { pacingDelay := scaleDuration(100 * time.Millisecond) gomock.InOrder( sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), - sph.EXPECT().ECNMode(), + sph.EXPECT().ECNMode(gomock.Any()), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")), sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), - sph.EXPECT().ECNMode(), + sph.EXPECT().ECNMode(gomock.Any()), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 101}, []byte("packet101")), sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), @@ -1587,7 +1587,7 @@ var _ = Describe("Connection", func() { It("sends multiple packets at once", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) - sph.EXPECT().ECNMode().Times(3) + sph.EXPECT().ECNMode(gomock.Any()).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) for pn := protocol.PacketNumber(1000); pn < 1003; pn++ { @@ -1628,7 +1628,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) @@ -1653,7 +1653,7 @@ var _ = Describe("Connection", func() { conn.handlePacket(receivedPacket{buffer: getPacketBuffer()}) }) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) @@ -1667,7 +1667,7 @@ var _ = Describe("Connection", func() { It("stops sending when the send queue is full", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) - sph.EXPECT().ECNMode() + sph.EXPECT().ECNMode(gomock.Any()) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() @@ -1688,7 +1688,7 @@ var _ = Describe("Connection", func() { // now make room in the send queue sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1001}, []byte("packet1001")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) @@ -1703,7 +1703,7 @@ var _ = Describe("Connection", func() { It("doesn't set a pacing timer when there is no data to send", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) // don't EXPECT any calls to mconn.Write() @@ -1723,7 +1723,7 @@ var _ = Describe("Connection", func() { conn.config.DisablePathMTUDiscovery = false sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) - sph.EXPECT().ECNMode() + sph.EXPECT().ECNMode(true) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock().AnyTimes() @@ -1774,7 +1774,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph @@ -1803,7 +1803,7 @@ var _ = Describe("Connection", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(1234), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) @@ -1855,7 +1855,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() + sph.EXPECT().ECNMode(false).Return(protocol.ECT1).AnyTimes() sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes() gomock.InOrder( sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(13), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionInitial, protocol.ECT1, protocol.ByteCount(123), gomock.Any()), @@ -1968,7 +1968,7 @@ var _ = Describe("Connection", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SetHandshakeConfirmed() diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index ba224228..ba8cbbda 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -29,8 +29,7 @@ type SentPacketHandler interface { // only to be called once the handshake is complete QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ - ECNMode() protocol.ECN - + ECNMode(isShortHeaderPacket bool) protocol.ECN // isShortHeaderPacket should only be true for non-coalesced 1-RTT packets PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index da5d0718..53aba0dd 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -717,10 +717,13 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { return h.alarm } -func (h *sentPacketHandler) ECNMode() protocol.ECN { +func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN { if !h.enableECN { return protocol.ECNUnsupported } + if !isShortHeaderPacket { + return protocol.ECNNon + } // TODO: implement ECN logic return protocol.ECNNon } diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 3479d9e6..137ac2ef 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -50,17 +50,17 @@ func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomo } // ECNMode mocks base method. -func (m *MockSentPacketHandler) ECNMode() protocol.ECN { +func (m *MockSentPacketHandler) ECNMode(arg0 bool) protocol.ECN { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ECNMode") + ret := m.ctrl.Call(m, "ECNMode", arg0) ret0, _ := ret[0].(protocol.ECN) return ret0 } // ECNMode indicates an expected call of ECNMode. -func (mr *MockSentPacketHandlerMockRecorder) ECNMode() *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ECNMode(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode), arg0) } // GetLossDetectionTimeout mocks base method. diff --git a/packet_packer.go b/packet_packer.go index 0483f35b..64081c68 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -71,6 +71,11 @@ type coalescedPacket struct { shortHdrPacket *shortHeaderPacket } +// IsOnlyShortHeaderPacket says if this packet only contains a short header packet (and no long header packets). +func (p *coalescedPacket) IsOnlyShortHeaderPacket() bool { + return len(p.longHdrPackets) == 0 && p.shortHdrPacket != nil +} + func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel { //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). switch p.header.Type { diff --git a/packet_packer_test.go b/packet_packer_test.go index 6dba31f9..e68906f3 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -655,6 +655,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) Expect(packet.longHdrPackets).To(HaveLen(1)) + Expect(packet.IsOnlyShortHeaderPacket()).To(BeFalse()) // cut off the tag that the mock sealer added // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] hdr, _, _, err := wire.ParsePacket(packet.buffer.Data) @@ -874,6 +875,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeFalse()) parsePacket(p.buffer.Data) }) @@ -1047,6 +1049,7 @@ var _ = Describe("Packet packer", func() { packer.retransmissionQueue.addAppData(&wire.PingFrame{}) p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeFalse()) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -1422,6 +1425,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeTrue()) Expect(p.longHdrPackets).To(BeEmpty()) Expect(p.shortHdrPacket).ToNot(BeNil()) packet := p.shortHdrPacket @@ -1448,6 +1452,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeTrue()) Expect(p.longHdrPackets).To(BeEmpty()) Expect(p.shortHdrPacket).ToNot(BeNil()) packet := p.shortHdrPacket From ad63e2a40af01c377cb02c5142111d99141a30bf Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 1 Sep 2023 11:47:05 +0700 Subject: [PATCH 17/43] trace and qlog the ECN marking on sent and received packets --- connection.go | 10 +-- connection_test.go | 70 +++++++++++---------- integrationtests/self/key_update_test.go | 6 +- integrationtests/self/self_suite_test.go | 8 ++- internal/mocks/logging/connection_tracer.go | 32 +++++----- logging/interface.go | 23 +++++-- logging/mock_connection_tracer_test.go | 32 +++++----- logging/multiplex.go | 16 ++--- logging/multiplex_test.go | 24 +++---- logging/null_tracer.go | 26 ++++---- qlog/event.go | 8 +++ qlog/qlog.go | 43 +++++++++---- qlog/qlog_test.go | 8 +++ qlog/types.go | 18 ++++++ qlog/types_test.go | 8 +++ 15 files changed, 212 insertions(+), 120 deletions(-) diff --git a/connection.go b/connection.go index d4e44be6..fdd056f1 100644 --- a/connection.go +++ b/connection.go @@ -907,6 +907,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc KeyPhase: keyPhase, }, p.Size(), + p.ecn, frames, ) } @@ -1192,7 +1193,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( var log func([]logging.Frame) if s.tracer != nil { log = func(frames []logging.Frame) { - s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, frames) + s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames) } } isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log) @@ -2112,7 +2113,7 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn) } -func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { +func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) { // quic-go logging if s.logger.Debug() { p.header.Log(s.logger) @@ -2140,7 +2141,7 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { if p.ack != nil { ack = logutils.ConvertAckFrame(p.ack) } - s.tracer.SentLongHeaderPacket(p.header, p.length, ack, frames) + s.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames) } } @@ -2194,6 +2195,7 @@ func (s *connection) logShortHeaderPacket( KeyPhase: kp, }, size, + ecn, ack, fs, ) @@ -2226,7 +2228,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.EC } } for _, p := range packet.longHdrPackets { - s.logLongHeaderPacket(p) + s.logLongHeaderPacket(p, ecn) } if p := packet.shortHdrPacket; p != nil { s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true) diff --git a/connection_test.go b/connection_test.go index c4e067af..4fe53910 100644 --- a/connection_test.go +++ b/connection_test.go @@ -590,7 +590,7 @@ var _ = Describe("Connection", func() { return 3, protocol.PacketNumberLen2, protocol.KeyPhaseOne, b, nil }) gomock.InOrder( - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()), + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), tracer.EXPECT().ClosedConnection(gomock.Any()), tracer.EXPECT().Close(), ) @@ -617,7 +617,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() // only expect a single SentPacket() call sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() streamManager.EXPECT().CloseWithError(gomock.Any()) @@ -778,7 +778,7 @@ var _ = Describe("Connection", func() { conn.receivedPacketHandler = rph packet.rcvTime = rcvTime tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), []logging.Frame{}) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECNCE, []logging.Frame{}) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) @@ -796,7 +796,12 @@ var _ = Describe("Connection", func() { ) conn.receivedPacketHandler = rph packet.rcvTime = rcvTime - tracer.EXPECT().ReceivedShortHeaderPacket(&logging.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 2, KeyPhase: protocol.KeyPhaseZero}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}}) + tracer.EXPECT().ReceivedShortHeaderPacket( + &logging.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 2, KeyPhase: protocol.KeyPhaseZero}, + protocol.ByteCount(len(packet.data)), + logging.ECT1, + []logging.Frame{&logging.PingFrame{}}, + ) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) @@ -850,8 +855,7 @@ var _ = Describe("Connection", func() { pn++ return pn, protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil }).Times(3) - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { - }).Times(3) + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version) // only expect a single call for i := 0; i < 3; i++ { @@ -886,8 +890,7 @@ var _ = Describe("Connection", func() { pn++ return pn, protocol.PacketNumberLen4, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil }).Times(3) - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { - }).Times(3) + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Times(3) for i := 0; i < 3; i++ { @@ -1021,7 +1024,7 @@ var _ = Describe("Connection", func() { }, nil) p1 := getLongHeaderPacket(hdr1, nil) tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(p1)).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. p2 := getLongHeaderPacket(hdr2, nil) @@ -1054,7 +1057,7 @@ var _ = Describe("Connection", func() { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* one PADDING frame */, nil) packet := getShortHeaderPacket(srcConnID, 0x42, nil) packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) }) @@ -1094,12 +1097,13 @@ var _ = Describe("Connection", func() { }) cryptoSetup.EXPECT().DiscardInitialKeys() tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) It("handles coalesced packets", func() { hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) + packet1.ecn = protocol.ECT1 unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) return &unpackedPacket{ @@ -1126,8 +1130,8 @@ var _ = Describe("Connection", func() { tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial).AnyTimes() cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes() gomock.InOrder( - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), logging.ECT1, gomock.Any()), + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.ECT1, gomock.Any()), ) packet1.data = append(packet1.data, packet2.data...) Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) @@ -1152,7 +1156,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes() gomock.InOrder( tracer.EXPECT().BufferedPacket(gomock.Any(), protocol.ByteCount(len(packet1.data))), - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any(), gomock.Any()), ) packet1.data = append(packet1.data, packet2.data...) Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) @@ -1178,7 +1182,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes() // don't EXPECT any more calls to unpacker.UnpackLongHeader() gomock.InOrder( - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any(), gomock.Any()), tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), ) packet1.data = append(packet1.data, packet2.data...) @@ -1253,7 +1257,7 @@ var _ = Describe("Connection", func() { PacketNumber: p.PacketNumber, PacketNumberLen: p.PacketNumberLen, KeyPhase: p.KeyPhase, - }, gomock.Any(), nil, []logging.Frame{}) + }, gomock.Any(), gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) }) @@ -1297,7 +1301,7 @@ var _ = Describe("Connection", func() { runConn() sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{}) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) frames, _ := conn.framer.AppendControlFrames(nil, 1000, protocol.Version1) @@ -1350,9 +1354,9 @@ var _ = Describe("Connection", func() { sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) if enc == protocol.Encryption1RTT { - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any(), gomock.Any()) } else { - tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any()) + tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any(), gomock.Any()) } conn.scheduleSending() Eventually(sent).Should(BeClosed()) @@ -1363,7 +1367,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) - sph.EXPECT().ECNMode(gomock.Any()) + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT0) sph.EXPECT().QueueProbePacket(encLevel).Return(false) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) @@ -1372,9 +1376,9 @@ var _ = Describe("Connection", func() { sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) if enc == protocol.Encryption1RTT { - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, logging.ECT0, gomock.Any(), gomock.Any()) } else { - tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any()) + tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, logging.ECT0, gomock.Any(), gomock.Any()) } conn.scheduleSending() Eventually(sent).Should(BeClosed()) @@ -1393,7 +1397,7 @@ var _ = Describe("Connection", func() { ) BeforeEach(func() { - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() conn.handshakeConfirmed = true @@ -1792,7 +1796,7 @@ var _ = Describe("Connection", func() { // only EXPECT calls after scheduleSending is called written := make(chan struct{}) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() conn.scheduleSending() Eventually(written).Should(BeClosed()) }) @@ -1814,7 +1818,7 @@ var _ = Describe("Connection", func() { written := make(chan struct{}) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1862,10 +1866,10 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(37), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionHandshake, protocol.ECT1, protocol.ByteCount(1234), gomock.Any()), ) gomock.InOrder( - tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { + tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ logging.ECN, _ *wire.AckFrame, _ []logging.Frame) { Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) }), - tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { + tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ logging.ECN, _ *wire.AckFrame, _ []logging.Frame) { Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) }), ) @@ -1974,7 +1978,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SetHandshakeConfirmed() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph done := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) @@ -2549,7 +2553,7 @@ var _ = Describe("Client Connection", func() { }, PacketNumberLen: protocol.PacketNumberLen2, }, []byte("foobar")) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), p.Size(), []logging.Frame{}) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), p.Size(), gomock.Any(), []logging.Frame{}) Expect(conn.handlePacketImpl(p)).To(BeTrue()) go func() { defer GinkgoRecover() @@ -2593,7 +2597,7 @@ var _ = Describe("Client Connection", func() { DestConnectionID: srcConnID, SrcConnectionID: destConnID, } - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handleLongHeaderPacket(receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) }) @@ -3105,7 +3109,7 @@ var _ = Describe("Client Connection", func() { hdr: hdr1, data: []byte{0}, // one PADDING frame }, nil) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) @@ -3132,7 +3136,7 @@ var _ = Describe("Client Connection", func() { It("fails on Initial-level ACK for unsent packet", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{ack}) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) }) @@ -3144,7 +3148,7 @@ var _ = Describe("Client Connection", func() { ReasonPhrase: "mitm attacker", } initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{connCloseFrame}) - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) }) diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index c24bef27..7975a77b 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -42,11 +42,13 @@ type keyUpdateConnTracer struct { logging.NullConnectionTracer } -func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ *logging.AckFrame, _ []logging.Frame) { +var _ logging.ConnectionTracer = &keyUpdateConnTracer{} + +func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) { sentHeaders = append(sentHeaders, hdr) } -func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, frames []logging.Frame) { +func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) { receivedHeaders = append(receivedHeaders, hdr) } diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 4b7ee5ef..6396d3a7 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -265,19 +265,21 @@ type packetTracer struct { rcvdLongHdr []packet } +var _ logging.ConnectionTracer = &packetTracer{} + func newPacketTracer() *packetTracer { return &packetTracer{closed: make(chan struct{})} } -func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, frames []logging.Frame) { +func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { t.rcvdLongHdr = append(t.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames}) } -func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, frames []logging.Frame) { +func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) } -func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, ack *wire.AckFrame, frames []logging.Frame) { +func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) { if ack != nil { frames = append(frames, ack) } diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index 1ca0895d..ff180062 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -184,15 +184,15 @@ func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 i } // ReceivedLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { +func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2) + m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2, arg3) } // ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) } // ReceivedRetry mocks base method. @@ -208,15 +208,15 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gom } // ReceivedShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { +func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2) + m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2, arg3) } // ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) } // ReceivedTransportParameters mocks base method. @@ -256,27 +256,27 @@ func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 int } // SentLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) { +func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3) + m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3, arg4) } // SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) } // SentShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) { +func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3) + m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3, arg4) } // SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) } // SentTransportParameters mocks base method. diff --git a/logging/interface.go b/logging/interface.go index 2ce8582e..a83fb507 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -15,6 +15,8 @@ import ( type ( // A ByteCount is used to count bytes. ByteCount = protocol.ByteCount + // ECN is the ECN value + ECN = protocol.ECN // A ConnectionID is a QUIC Connection ID. ConnectionID = protocol.ConnectionID // An ArbitraryLenConnectionID is a QUIC Connection ID that can be up to 255 bytes long. @@ -58,6 +60,19 @@ type ( RTTStats = utils.RTTStats ) +const ( + // ECNUnsupported means that no ECN value was set / received + ECNUnsupported = protocol.ECNUnsupported + // ECTNot is Not-ECT + ECTNot = protocol.ECNNon + // ECT0 is ECT(0) + ECT0 = protocol.ECT0 + // ECT1 is ECT(1) + ECT1 = protocol.ECT1 + // ECNCE is CE + ECNCE = protocol.ECNCE +) + const ( // KeyPhaseZero is key phase bit 0 KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero @@ -113,12 +128,12 @@ type ConnectionTracer interface { SentTransportParameters(*TransportParameters) ReceivedTransportParameters(*TransportParameters) RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT - SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) - SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) + SentLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) + SentShortHeaderPacket(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) ReceivedRetry(*Header) - ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) - ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame) + ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, []Frame) + ReceivedShortHeaderPacket(*ShortHeader, ByteCount, ECN, []Frame) BufferedPacket(PacketType, ByteCount) DroppedPacket(PacketType, ByteCount, PacketDropReason) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) diff --git a/logging/mock_connection_tracer_test.go b/logging/mock_connection_tracer_test.go index f4f7648a..f2b8cc8a 100644 --- a/logging/mock_connection_tracer_test.go +++ b/logging/mock_connection_tracer_test.go @@ -183,15 +183,15 @@ func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 i } // ReceivedLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []Frame) { +func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2) + m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2, arg3) } // ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) } // ReceivedRetry mocks base method. @@ -207,15 +207,15 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gom } // ReceivedShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 []Frame) { +func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2) + m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2, arg3) } // ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) } // ReceivedTransportParameters mocks base method. @@ -255,27 +255,27 @@ func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 int } // SentLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) { +func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3) + m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3, arg4) } // SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) } // SentShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) { +func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []Frame) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3) + m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3, arg4) } // SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) } // SentTransportParameters mocks base method. diff --git a/logging/multiplex.go b/logging/multiplex.go index 672a5cdb..0aecc9c2 100644 --- a/logging/multiplex.go +++ b/logging/multiplex.go @@ -93,15 +93,15 @@ func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParamet } } -func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) { +func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { for _, t := range m.tracers { - t.SentLongHeaderPacket(hdr, size, ack, frames) + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) } } -func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) { +func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { for _, t := range m.tracers { - t.SentShortHeaderPacket(hdr, size, ack, frames) + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) } } @@ -117,15 +117,15 @@ func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { } } -func (m *connTracerMultiplexer) ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) { +func (m *connTracerMultiplexer) ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { for _, t := range m.tracers { - t.ReceivedLongHeaderPacket(hdr, size, frames) + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) } } -func (m *connTracerMultiplexer) ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame) { +func (m *connTracerMultiplexer) ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { for _, t := range m.tracers { - t.ReceivedShortHeaderPacket(hdr, size, frames) + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) } } diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index d2220499..572ced6a 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -119,18 +119,18 @@ var _ = Describe("Tracing", func() { hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}} ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}} ping := &PingFrame{} - tr1.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tr2.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tracer.SentLongHeaderPacket(hdr, 1337, ack, []Frame{ping}) + tr1.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ECTNot, ack, []Frame{ping}) + tr2.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ECTNot, ack, []Frame{ping}) + tracer.SentLongHeaderPacket(hdr, 1337, ECTNot, ack, []Frame{ping}) }) It("traces the SentShortHeaderPacket event", func() { hdr := &ShortHeader{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}} ping := &PingFrame{} - tr1.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tr2.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping}) - tracer.SentShortHeaderPacket(hdr, 1337, ack, []Frame{ping}) + tr1.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ECNCE, ack, []Frame{ping}) + tr2.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ECNCE, ack, []Frame{ping}) + tracer.SentShortHeaderPacket(hdr, 1337, ECNCE, ack, []Frame{ping}) }) It("traces the ReceivedVersionNegotiationPacket event", func() { @@ -151,17 +151,17 @@ var _ = Describe("Tracing", func() { It("traces the ReceivedLongHeaderPacket event", func() { hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}} ping := &PingFrame{} - tr1.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), []Frame{ping}) - tr2.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), []Frame{ping}) - tracer.ReceivedLongHeaderPacket(hdr, 1337, []Frame{ping}) + tr1.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), ECT1, []Frame{ping}) + tr2.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), ECT1, []Frame{ping}) + tracer.ReceivedLongHeaderPacket(hdr, 1337, ECT1, []Frame{ping}) }) It("traces the ReceivedShortHeaderPacket event", func() { hdr := &ShortHeader{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} ping := &PingFrame{} - tr1.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), []Frame{ping}) - tr2.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), []Frame{ping}) - tracer.ReceivedShortHeaderPacket(hdr, 1337, []Frame{ping}) + tr1.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), ECT0, []Frame{ping}) + tr2.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), ECT0, []Frame{ping}) + tracer.ReceivedShortHeaderPacket(hdr, 1337, ECT0, []Frame{ping}) }) It("traces the BufferedPacket event", func() { diff --git a/logging/null_tracer.go b/logging/null_tracer.go index de970385..731039e3 100644 --- a/logging/null_tracer.go +++ b/logging/null_tracer.go @@ -27,19 +27,23 @@ func (n NullConnectionTracer) StartedConnection(local, remote net.Addr, srcConnI func (n NullConnectionTracer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { } -func (n NullConnectionTracer) ClosedConnection(err error) {} -func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, *AckFrame, []Frame) {} -func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, *AckFrame, []Frame) {} +func (n NullConnectionTracer) ClosedConnection(err error) {} +func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {} +func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {} +func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {} +func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) { +} + +func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) { +} + func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) { } -func (n NullConnectionTracer) ReceivedRetry(*Header) {} -func (n NullConnectionTracer) ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, []Frame) {} -func (n NullConnectionTracer) ReceivedShortHeaderPacket(*ShortHeader, ByteCount, []Frame) {} -func (n NullConnectionTracer) BufferedPacket(PacketType, ByteCount) {} -func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropReason) {} +func (n NullConnectionTracer) ReceivedRetry(*Header) {} +func (n NullConnectionTracer) ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, []Frame) {} +func (n NullConnectionTracer) ReceivedShortHeaderPacket(*ShortHeader, ByteCount, ECN, []Frame) {} +func (n NullConnectionTracer) BufferedPacket(PacketType, ByteCount) {} +func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropReason) {} func (n NullConnectionTracer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { } diff --git a/qlog/event.go b/qlog/event.go index 9dae7444..1a76df95 100644 --- a/qlog/event.go +++ b/qlog/event.go @@ -158,6 +158,7 @@ type eventPacketSent struct { PayloadLength logging.ByteCount Frames frames IsCoalesced bool + ECN logging.ECN Trigger string } @@ -172,6 +173,9 @@ func (e eventPacketSent) MarshalJSONObject(enc *gojay.Encoder) { enc.ObjectKey("raw", rawInfo{Length: e.Length, PayloadLength: e.PayloadLength}) enc.ArrayKeyOmitEmpty("frames", e.Frames) enc.BoolKeyOmitEmpty("is_coalesced", e.IsCoalesced) + if e.ECN != logging.ECNUnsupported { + enc.StringKey("ecn", ecn(e.ECN).String()) + } enc.StringKeyOmitEmpty("trigger", e.Trigger) } @@ -180,6 +184,7 @@ type eventPacketReceived struct { Length logging.ByteCount PayloadLength logging.ByteCount Frames frames + ECN logging.ECN IsCoalesced bool Trigger string } @@ -195,6 +200,9 @@ func (e eventPacketReceived) MarshalJSONObject(enc *gojay.Encoder) { enc.ObjectKey("raw", rawInfo{Length: e.Length, PayloadLength: e.PayloadLength}) enc.ArrayKeyOmitEmpty("frames", e.Frames) enc.BoolKeyOmitEmpty("is_coalesced", e.IsCoalesced) + if e.ECN != logging.ECNUnsupported { + enc.StringKey("ecn", ecn(e.ECN).String()) + } enc.StringKeyOmitEmpty("trigger", e.Trigger) } diff --git a/qlog/qlog.go b/qlog/qlog.go index 4c480e26..6d2ece07 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -253,15 +253,33 @@ func (t *connectionTracer) toTransportParameters(tp *wire.TransportParameters) * } } -func (t *connectionTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) { - t.sentPacket(*transformLongHeader(hdr), packetSize, hdr.Length, ack, frames) +func (t *connectionTracer) SentLongHeaderPacket( + hdr *logging.ExtendedHeader, + size logging.ByteCount, + ecn logging.ECN, + ack *logging.AckFrame, + frames []logging.Frame, +) { + t.sentPacket(*transformLongHeader(hdr), size, hdr.Length, ecn, ack, frames) } -func (t *connectionTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) { - t.sentPacket(*transformShortHeader(hdr), packetSize, 0, ack, frames) +func (t *connectionTracer) SentShortHeaderPacket( + hdr *logging.ShortHeader, + size logging.ByteCount, + ecn logging.ECN, + ack *logging.AckFrame, + frames []logging.Frame, +) { + t.sentPacket(*transformShortHeader(hdr), size, 0, ecn, ack, frames) } -func (t *connectionTracer) sentPacket(hdr gojay.MarshalerJSONObject, packetSize, payloadLen logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) { +func (t *connectionTracer) sentPacket( + hdr gojay.MarshalerJSONObject, + size, payloadLen logging.ByteCount, + ecn logging.ECN, + ack *logging.AckFrame, + frames []logging.Frame, +) { numFrames := len(frames) if ack != nil { numFrames++ @@ -276,14 +294,15 @@ func (t *connectionTracer) sentPacket(hdr gojay.MarshalerJSONObject, packetSize, t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketSent{ Header: hdr, - Length: packetSize, + Length: size, PayloadLength: payloadLen, + ECN: ecn, Frames: fs, }) t.mutex.Unlock() } -func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, packetSize logging.ByteCount, frames []logging.Frame) { +func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { fs := make([]frame, len(frames)) for i, f := range frames { fs[i] = frame{Frame: f} @@ -292,14 +311,15 @@ func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketReceived{ Header: header, - Length: packetSize, + Length: size, PayloadLength: hdr.Length, + ECN: ecn, Frames: fs, }) t.mutex.Unlock() } -func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, packetSize logging.ByteCount, frames []logging.Frame) { +func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { fs := make([]frame, len(frames)) for i, f := range frames { fs[i] = frame{Frame: f} @@ -308,8 +328,9 @@ func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, p t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketReceived{ Header: header, - Length: packetSize, - PayloadLength: packetSize - wire.ShortHeaderLen(hdr.DestConnectionID, hdr.PacketNumberLen), + Length: size, + PayloadLength: size - wire.ShortHeaderLen(hdr.DestConnectionID, hdr.PacketNumberLen), + ECN: ecn, Frames: fs, }) t.mutex.Unlock() diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index dc0d2dc1..d7634464 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -419,6 +419,7 @@ var _ = Describe("Tracing", func() { PacketNumber: 1337, }, 987, + logging.ECNCE, nil, []logging.Frame{ &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, @@ -439,6 +440,7 @@ var _ = Describe("Tracing", func() { Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) Expect(hdr).To(HaveKeyWithValue("scid", "04030201")) Expect(ev).To(HaveKey("frames")) + Expect(ev).To(HaveKeyWithValue("ecn", "CE")) frames := ev["frames"].([]interface{}) Expect(frames).To(HaveLen(2)) Expect(frames[0].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "max_stream_data")) @@ -452,6 +454,7 @@ var _ = Describe("Tracing", func() { PacketNumber: 1337, }, 123, + logging.ECNUnsupported, &logging.AckFrame{AckRanges: []logging.AckRange{{Smallest: 1, Largest: 10}}}, []logging.Frame{&logging.MaxDataFrame{MaximumData: 987}}, ) @@ -461,6 +464,7 @@ var _ = Describe("Tracing", func() { Expect(raw).To(HaveKeyWithValue("length", float64(123))) Expect(raw).ToNot(HaveKey("payload_length")) Expect(ev).To(HaveKey("header")) + Expect(ev).ToNot(HaveKey("ecn")) hdr := ev["header"].(map[string]interface{}) Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT")) Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) @@ -485,6 +489,7 @@ var _ = Describe("Tracing", func() { PacketNumber: 1337, }, 789, + logging.ECT0, []logging.Frame{ &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, @@ -498,6 +503,7 @@ var _ = Describe("Tracing", func() { raw := ev["raw"].(map[string]interface{}) Expect(raw).To(HaveKeyWithValue("length", float64(789))) Expect(raw).To(HaveKeyWithValue("payload_length", float64(1234))) + Expect(ev).To(HaveKeyWithValue("ecn", "ECT(0)")) Expect(ev).To(HaveKey("header")) hdr := ev["header"].(map[string]interface{}) Expect(hdr).To(HaveKeyWithValue("packet_type", "initial")) @@ -520,6 +526,7 @@ var _ = Describe("Tracing", func() { tracer.ReceivedShortHeaderPacket( shdr, 789, + logging.ECT1, []logging.Frame{ &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, @@ -533,6 +540,7 @@ var _ = Describe("Tracing", func() { raw := ev["raw"].(map[string]interface{}) Expect(raw).To(HaveKeyWithValue("length", float64(789))) Expect(raw).To(HaveKeyWithValue("payload_length", float64(789-(1+8+3)))) + Expect(ev).To(HaveKeyWithValue("ecn", "ECT(1)")) Expect(ev).To(HaveKey("header")) hdr := ev["header"].(map[string]interface{}) Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT")) diff --git a/qlog/types.go b/qlog/types.go index c47ad481..34104e61 100644 --- a/qlog/types.go +++ b/qlog/types.go @@ -312,3 +312,21 @@ func (s congestionState) String() string { return "unknown congestion state" } } + +type ecn logging.ECN + +func (e ecn) String() string { + //nolint:exhaustive // The unsupported value is never logged. + switch logging.ECN(e) { + case logging.ECTNot: + return "Not-ECT" + case logging.ECT0: + return "ECT(0)" + case logging.ECT1: + return "ECT(1)" + case logging.ECNCE: + return "CE" + default: + return "unknown ECN" + } +} diff --git a/qlog/types_test.go b/qlog/types_test.go index 9213ad31..6918b082 100644 --- a/qlog/types_test.go +++ b/qlog/types_test.go @@ -127,4 +127,12 @@ var _ = Describe("Types", func() { Expect(congestionState(logging.CongestionStateApplicationLimited).String()).To(Equal("application_limited")) Expect(congestionState(logging.CongestionStateRecovery).String()).To(Equal("recovery")) }) + + It("has a string representation for the ECN bits", func() { + Expect(ecn(logging.ECT0).String()).To(Equal("ECT(0)")) + Expect(ecn(logging.ECT1).String()).To(Equal("ECT(1)")) + Expect(ecn(logging.ECNCE).String()).To(Equal("CE")) + Expect(ecn(logging.ECTNot).String()).To(Equal("Not-ECT")) + Expect(ecn(42).String()).To(Equal("unknown ECN")) + }) }) From ffe6546833341205076b6c3631df129c4043d2b0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 1 Sep 2023 14:17:59 +0700 Subject: [PATCH 18/43] add tracing and qlogging of state transitions for ECN validation --- internal/mocks/logging/connection_tracer.go | 12 +++++++ logging/interface.go | 1 + logging/mock_connection_tracer_test.go | 12 +++++++ logging/multiplex.go | 6 ++++ logging/null_tracer.go | 27 ++++++++------- logging/types.go | 32 +++++++++++++++++ qlog/event.go | 14 ++++++++ qlog/qlog.go | 6 ++++ qlog/qlog_test.go | 21 ++++++++++++ qlog/types.go | 38 +++++++++++++++++++++ qlog/types_test.go | 18 ++++++++++ 11 files changed, 174 insertions(+), 13 deletions(-) diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index ff180062..811d1e99 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -135,6 +135,18 @@ func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 inter return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) } +// ECNStateUpdated mocks base method. +func (m *MockConnectionTracer) ECNStateUpdated(arg0 logging.ECNState, arg1 logging.ECNStateTrigger) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ECNStateUpdated", arg0, arg1) +} + +// ECNStateUpdated indicates an expected call of ECNStateUpdated. +func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) +} + // LossTimerCanceled mocks base method. func (m *MockConnectionTracer) LossTimerCanceled() { m.ctrl.T.Helper() diff --git a/logging/interface.go b/logging/interface.go index a83fb507..62028ebc 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -148,6 +148,7 @@ type ConnectionTracer interface { SetLossTimer(TimerType, EncryptionLevel, time.Time) LossTimerExpired(TimerType, EncryptionLevel) LossTimerCanceled() + ECNStateUpdated(state ECNState, trigger ECNStateTrigger) // Close is called when the connection is closed. Close() Debug(name, msg string) diff --git a/logging/mock_connection_tracer_test.go b/logging/mock_connection_tracer_test.go index f2b8cc8a..94722509 100644 --- a/logging/mock_connection_tracer_test.go +++ b/logging/mock_connection_tracer_test.go @@ -134,6 +134,18 @@ func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 inter return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) } +// ECNStateUpdated mocks base method. +func (m *MockConnectionTracer) ECNStateUpdated(arg0 ECNState, arg1 ECNStateTrigger) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ECNStateUpdated", arg0, arg1) +} + +// ECNStateUpdated indicates an expected call of ECNStateUpdated. +func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) +} + // LossTimerCanceled mocks base method. func (m *MockConnectionTracer) LossTimerCanceled() { m.ctrl.T.Helper() diff --git a/logging/multiplex.go b/logging/multiplex.go index 0aecc9c2..ef91f3f2 100644 --- a/logging/multiplex.go +++ b/logging/multiplex.go @@ -213,6 +213,12 @@ func (m *connTracerMultiplexer) LossTimerCanceled() { } } +func (m *connTracerMultiplexer) ECNStateUpdated(state ECNState, trigger ECNStateTrigger) { + for _, t := range m.tracers { + t.ECNStateUpdated(state, trigger) + } +} + func (m *connTracerMultiplexer) Debug(name, msg string) { for _, t := range m.tracers { t.Debug(name, msg) diff --git a/logging/null_tracer.go b/logging/null_tracer.go index 731039e3..3e3458f7 100644 --- a/logging/null_tracer.go +++ b/logging/null_tracer.go @@ -47,16 +47,17 @@ func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropRea func (n NullConnectionTracer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { } -func (n NullConnectionTracer) AcknowledgedPacket(EncryptionLevel, PacketNumber) {} -func (n NullConnectionTracer) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) {} -func (n NullConnectionTracer) UpdatedCongestionState(CongestionState) {} -func (n NullConnectionTracer) UpdatedPTOCount(uint32) {} -func (n NullConnectionTracer) UpdatedKeyFromTLS(EncryptionLevel, Perspective) {} -func (n NullConnectionTracer) UpdatedKey(keyPhase KeyPhase, remote bool) {} -func (n NullConnectionTracer) DroppedEncryptionLevel(EncryptionLevel) {} -func (n NullConnectionTracer) DroppedKey(KeyPhase) {} -func (n NullConnectionTracer) SetLossTimer(TimerType, EncryptionLevel, time.Time) {} -func (n NullConnectionTracer) LossTimerExpired(timerType TimerType, level EncryptionLevel) {} -func (n NullConnectionTracer) LossTimerCanceled() {} -func (n NullConnectionTracer) Close() {} -func (n NullConnectionTracer) Debug(name, msg string) {} +func (n NullConnectionTracer) AcknowledgedPacket(EncryptionLevel, PacketNumber) {} +func (n NullConnectionTracer) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) {} +func (n NullConnectionTracer) UpdatedCongestionState(CongestionState) {} +func (n NullConnectionTracer) UpdatedPTOCount(uint32) {} +func (n NullConnectionTracer) UpdatedKeyFromTLS(EncryptionLevel, Perspective) {} +func (n NullConnectionTracer) UpdatedKey(keyPhase KeyPhase, remote bool) {} +func (n NullConnectionTracer) DroppedEncryptionLevel(EncryptionLevel) {} +func (n NullConnectionTracer) DroppedKey(KeyPhase) {} +func (n NullConnectionTracer) SetLossTimer(TimerType, EncryptionLevel, time.Time) {} +func (n NullConnectionTracer) LossTimerExpired(TimerType, EncryptionLevel) {} +func (n NullConnectionTracer) LossTimerCanceled() {} +func (n NullConnectionTracer) ECNStateUpdated(ECNState, ECNStateTrigger) {} +func (n NullConnectionTracer) Close() {} +func (n NullConnectionTracer) Debug(name, msg string) {} diff --git a/logging/types.go b/logging/types.go index ad800692..bfee91c3 100644 --- a/logging/types.go +++ b/logging/types.go @@ -92,3 +92,35 @@ const ( // CongestionStateApplicationLimited means that the congestion controller is application limited CongestionStateApplicationLimited ) + +// ECNState is the state of the ECN state machine (see Appendix A.4 of RFC 9000) +type ECNState uint8 + +const ( + // ECNStateTesting is the testing state + ECNStateTesting ECNState = 1 + iota + // ECNStateUnknown is the unknown state + ECNStateUnknown + // ECNStateFailed is the failed state + ECNStateFailed + // ECNStateCapable is the capable state + ECNStateCapable +) + +// ECNStateTrigger is a trigger for an ECN state transition. +type ECNStateTrigger uint8 + +const ( + ECNTriggerNoTrigger ECNStateTrigger = iota + // ECNFailedNoECNCounts is emitted when an ACK acknowledges ECN-marked packets, + // but doesn't contain any ECN counts + ECNFailedNoECNCounts + // ECNFailedDecreasedECNCounts is emitted when an ACK frame decreases ECN counts + ECNFailedDecreasedECNCounts + // ECNFailedLostAllTestingPackets is emitted when all ECN testing packets are declared lost + ECNFailedLostAllTestingPackets + // ECNFailedMoreECNCountsThanSent is emitted when an ACK contains more ECN counts than ECN-marked packets were sent + ECNFailedMoreECNCountsThanSent + // ECNFailedTooFewECNCounts is emitted when contains fewer ECT(0) / ECT(1) counts than it acknowledges packets + ECNFailedTooFewECNCounts +) diff --git a/qlog/event.go b/qlog/event.go index 1a76df95..740d3c2e 100644 --- a/qlog/event.go +++ b/qlog/event.go @@ -524,6 +524,20 @@ func (e eventCongestionStateUpdated) MarshalJSONObject(enc *gojay.Encoder) { enc.StringKey("new", e.state.String()) } +type eventECNStateUpdated struct { + state logging.ECNState + trigger logging.ECNStateTrigger +} + +func (e eventECNStateUpdated) Category() category { return categoryRecovery } +func (e eventECNStateUpdated) Name() string { return "ecn_state_updated" } +func (e eventECNStateUpdated) IsNil() bool { return false } + +func (e eventECNStateUpdated) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("new", ecnState(e.state).String()) + enc.StringKeyOmitEmpty("trigger", ecnStateTrigger(e.trigger).String()) +} + type eventGeneric struct { name string msg string diff --git a/qlog/qlog.go b/qlog/qlog.go index 6d2ece07..943bbc3d 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -503,6 +503,12 @@ func (t *connectionTracer) LossTimerCanceled() { t.mutex.Unlock() } +func (t *connectionTracer) ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) { + t.mutex.Lock() + t.recordEvent(time.Now(), &eventECNStateUpdated{state: state, trigger: trigger}) + t.mutex.Unlock() +} + func (t *connectionTracer) Debug(name, msg string) { t.mutex.Lock() t.recordEvent(time.Now(), &eventGeneric{ diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index d7634464..c19a0d3e 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -865,6 +865,27 @@ var _ = Describe("Tracing", func() { Expect(ev).To(HaveKeyWithValue("event_type", "cancelled")) }) + It("records an ECN state transition, without a trigger", func() { + tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:ecn_state_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("new", "unknown")) + }) + + It("records an ECN state transition, with a trigger", func() { + tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:ecn_state_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(2)) + Expect(ev).To(HaveKeyWithValue("new", "failed")) + Expect(ev).To(HaveKeyWithValue("trigger", "ACK doesn't contain ECN marks")) + }) + It("records a generic event", func() { tracer.Debug("foo", "bar") entry := exportAndParseSingle() diff --git a/qlog/types.go b/qlog/types.go index 34104e61..49b4853a 100644 --- a/qlog/types.go +++ b/qlog/types.go @@ -330,3 +330,41 @@ func (e ecn) String() string { return "unknown ECN" } } + +type ecnState logging.ECNState + +func (e ecnState) String() string { + switch logging.ECNState(e) { + case logging.ECNStateTesting: + return "testing" + case logging.ECNStateUnknown: + return "unknown" + case logging.ECNStateCapable: + return "capable" + case logging.ECNStateFailed: + return "failed" + default: + return "unknown ECN state" + } +} + +type ecnStateTrigger logging.ECNStateTrigger + +func (e ecnStateTrigger) String() string { + switch logging.ECNStateTrigger(e) { + case logging.ECNTriggerNoTrigger: + return "" + case logging.ECNFailedNoECNCounts: + return "ACK doesn't contain ECN marks" + case logging.ECNFailedDecreasedECNCounts: + return "ACK decreases ECN counts" + case logging.ECNFailedLostAllTestingPackets: + return "all ECN testing packets declared lost" + case logging.ECNFailedMoreECNCountsThanSent: + return "ACK contains more ECN counts than ECN-marked packets sent" + case logging.ECNFailedTooFewECNCounts: + return "ACK contains fewer new ECN counts than acknowledged ECN-marked packets" + default: + return "unknown ECN state trigger" + } +} diff --git a/qlog/types_test.go b/qlog/types_test.go index 6918b082..d08eb739 100644 --- a/qlog/types_test.go +++ b/qlog/types_test.go @@ -135,4 +135,22 @@ var _ = Describe("Types", func() { Expect(ecn(logging.ECTNot).String()).To(Equal("Not-ECT")) Expect(ecn(42).String()).To(Equal("unknown ECN")) }) + + It("has a string representation for the ECN state", func() { + Expect(ecnState(logging.ECNStateTesting).String()).To(Equal("testing")) + Expect(ecnState(logging.ECNStateUnknown).String()).To(Equal("unknown")) + Expect(ecnState(logging.ECNStateFailed).String()).To(Equal("failed")) + Expect(ecnState(logging.ECNStateCapable).String()).To(Equal("capable")) + Expect(ecnState(42).String()).To(Equal("unknown ECN state")) + }) + + It("has a string representation for the ECN state trigger", func() { + Expect(ecnStateTrigger(logging.ECNTriggerNoTrigger).String()).To(Equal("")) + Expect(ecnStateTrigger(logging.ECNFailedNoECNCounts).String()).To(Equal("ACK doesn't contain ECN marks")) + Expect(ecnStateTrigger(logging.ECNFailedDecreasedECNCounts).String()).To(Equal("ACK decreases ECN counts")) + Expect(ecnStateTrigger(logging.ECNFailedLostAllTestingPackets).String()).To(Equal("all ECN testing packets declared lost")) + Expect(ecnStateTrigger(logging.ECNFailedMoreECNCountsThanSent).String()).To(Equal("ACK contains more ECN counts than ECN-marked packets sent")) + Expect(ecnStateTrigger(logging.ECNFailedTooFewECNCounts).String()).To(Equal("ACK contains fewer new ECN counts than acknowledged ECN-marked packets")) + Expect(ecnStateTrigger(42).String()).To(Equal("unknown ECN state trigger")) + }) }) From f9cfa248da6ceadf6c999111e1be214ead7698ac Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 1 Sep 2023 14:43:51 +0700 Subject: [PATCH 19/43] implement ECN path validation logic, send ECN-marked 1-RTT packets --- internal/ackhandler/ecn.go | 258 +++++++++++++++++++++ internal/ackhandler/ecn_test.go | 192 +++++++++++++++ internal/ackhandler/sent_packet_handler.go | 37 ++- 3 files changed, 479 insertions(+), 8 deletions(-) create mode 100644 internal/ackhandler/ecn.go create mode 100644 internal/ackhandler/ecn_test.go diff --git a/internal/ackhandler/ecn.go b/internal/ackhandler/ecn.go new file mode 100644 index 00000000..0fd39098 --- /dev/null +++ b/internal/ackhandler/ecn.go @@ -0,0 +1,258 @@ +package ackhandler + +import ( + "fmt" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" +) + +type ecnState uint8 + +const ( + ecnStateInitial ecnState = iota + ecnStateTesting + ecnStateUnknown + ecnStateCapable + ecnStateFailed +) + +// must fit into an uint8, otherwise numSentTesting and numLostTesting must have a larger type +const numECNTestingPackets = 10 + +// The ecnTracker performs ECN validation of a path. +// Once failed, it doesn't do any re-validation of the path. +// It is designed only work for 1-RTT packets, it doesn't handle multiple packet number spaces. +// In order to avoid revealing any internal state to on-path observers, +// callers should make sure to start using ECN (i.e. calling Mode) for the very first 1-RTT packet sent. +// The validation logic implemented here strictly follows the algorithm described in RFC 9000 section 13.4.2 and A.4. +type ecnTracker struct { + state ecnState + numSentTesting, numLostTesting uint8 + + firstTestingPacket protocol.PacketNumber + lastTestingPacket protocol.PacketNumber + firstCapablePacket protocol.PacketNumber + + numSentECT0, numSentECT1 int64 + numAckedECT0, numAckedECT1, numAckedECNCE int64 + + tracer logging.ConnectionTracer + logger utils.Logger +} + +func newECNTracker(logger utils.Logger, tracer logging.ConnectionTracer) *ecnTracker { + return &ecnTracker{ + firstTestingPacket: protocol.InvalidPacketNumber, + lastTestingPacket: protocol.InvalidPacketNumber, + firstCapablePacket: protocol.InvalidPacketNumber, + state: ecnStateInitial, + logger: logger, + tracer: tracer, + } +} + +func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) { + //nolint:exhaustive // These are the only ones we need to take care of. + switch ecn { + case protocol.ECNNon: + return + case protocol.ECT0: + e.numSentECT0++ + case protocol.ECT1: + e.numSentECT1++ + case protocol.ECNUnsupported: + if e.state != ecnStateFailed { + panic("didn't expect ECN to be unsupported") + } + default: + panic(fmt.Sprintf("sent packet with unexpected ECN marking: %s", ecn)) + } + + if e.state == ecnStateCapable && e.firstCapablePacket == protocol.InvalidPacketNumber { + e.firstCapablePacket = pn + } + + if e.state != ecnStateTesting { + return + } + + e.numSentTesting++ + if e.firstTestingPacket == protocol.InvalidPacketNumber { + e.firstTestingPacket = pn + } + if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets { + if e.tracer != nil { + e.tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + } + e.state = ecnStateUnknown + e.lastTestingPacket = pn + } +} + +func (e *ecnTracker) Mode() protocol.ECN { + switch e.state { + case ecnStateInitial: + if e.tracer != nil { + e.tracer.ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + } + e.state = ecnStateTesting + return e.Mode() + case ecnStateTesting, ecnStateCapable: + return protocol.ECT0 + case ecnStateUnknown, ecnStateFailed: + return protocol.ECNNon + default: + panic(fmt.Sprintf("unknown ECN state: %d", e.state)) + } +} + +func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) { + if e.state != ecnStateTesting && e.state != ecnStateUnknown { + return + } + if !e.isTestingPacket(pn) { + return + } + e.numLostTesting++ + if e.numLostTesting >= e.numSentTesting { + e.logger.Debugf("Disabling ECN. All testing packets were lost.") + if e.tracer != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) + } + e.state = ecnStateFailed + } +} + +// HandleNewlyAcked handles the ECN counts on an ACK frame. +// It must only be called for ACK frames that increase the largest acknowledged packet number, +// see section 13.4.2.1 of RFC 9000. +func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) { + if e.state == ecnStateFailed { + return false + } + + // ECN validation can fail if the received total count for either ECT(0) or ECT(1) exceeds + // the total number of packets sent with each corresponding ECT codepoint. + if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 { + e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.") + if e.tracer != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent) + } + e.state = ecnStateFailed + return false + } + + // Count ECT0 and ECT1 marks that we used when sending the packets that are now being acknowledged. + var ackedECT0, ackedECT1 int64 + for _, p := range packets { + //nolint:exhaustive // We only ever send ECT(0) and ECT(1). + switch e.ecnMarking(p.PacketNumber) { + case protocol.ECT0: + ackedECT0++ + case protocol.ECT1: + ackedECT1++ + } + } + + // If an ACK frame newly acknowledges a packet that the endpoint sent with either the ECT(0) or ECT(1) + // codepoint set, ECN validation fails if the corresponding ECN counts are not present in the ACK frame. + // This check detects: + // * paths that bleach all ECN marks, and + // * peers that don't report any ECN counts + if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 { + e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.") + if e.tracer != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts) + } + e.state = ecnStateFailed + return false + } + + // Determine the increase in ECT0, ECT1 and ECNCE marks + newECT0 := ect0 - e.numAckedECT0 + newECT1 := ect1 - e.numAckedECT1 + newECNCE := ecnce - e.numAckedECNCE + + // We're only processing ACKs that increase the Largest Acked. + // Therefore, the ECN counters should only ever increase. + // Any decrease means that the peer's counting logic is broken. + if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 { + e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.") + if e.tracer != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts) + } + e.state = ecnStateFailed + return false + } + + // ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is less than the number + // of newly acknowledged packets that were originally sent with an ECT(0) marking. + // This could be the result of (partial) bleaching. + if newECT0+newECNCE < ackedECT0 { + e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).") + if e.tracer != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) + } + e.state = ecnStateFailed + return false + } + // Similarly, ECN validation fails if the sum of the increases to ECT(1) and ECN-CE counts is less than + // the number of newly acknowledged packets sent with an ECT(1) marking. + if newECT1+newECNCE < ackedECT1 { + e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).") + if e.tracer != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) + } + e.state = ecnStateFailed + return false + } + + if e.state == ecnStateTesting || e.state == ecnStateUnknown { + var ackedTestingPacket bool + for _, p := range packets { + if e.isTestingPacket(p.PacketNumber) { + ackedTestingPacket = true + break + } + } + // This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE). + if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) { + e.logger.Debugf("ECN capability confirmed.") + if e.tracer != nil { + e.tracer.ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + } + e.state = ecnStateCapable + } + } + + // update our counters + e.numAckedECT0 = ect0 + e.numAckedECT1 = ect1 + e.numAckedECNCE = ecnce + + return newECNCE > 0 +} + +func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN { + if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber { + return protocol.ECNNon + } + if pn < e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber { + return protocol.ECT0 + } + if pn < e.firstCapablePacket || e.firstCapablePacket == protocol.InvalidPacketNumber { + return protocol.ECNNon + } + // We don't need to deal with the case when ECN validation fails, + // since we're ignoring any ECN counts reported in ACK frames in that case. + return protocol.ECT0 +} + +func (e *ecnTracker) isTestingPacket(pn protocol.PacketNumber) bool { + if e.firstTestingPacket == protocol.InvalidPacketNumber { + return false + } + return pn >= e.firstTestingPacket && (pn <= e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber) +} diff --git a/internal/ackhandler/ecn_test.go b/internal/ackhandler/ecn_test.go new file mode 100644 index 00000000..74fff094 --- /dev/null +++ b/internal/ackhandler/ecn_test.go @@ -0,0 +1,192 @@ +package ackhandler + +import ( + mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ECN tracker", func() { + var ecnTracker *ecnTracker + var tracer *mocklogging.MockConnectionTracer + + getAckedPackets := func(pns ...protocol.PacketNumber) []*packet { + var packets []*packet + for _, p := range pns { + packets = append(packets, &packet{PacketNumber: p}) + } + return packets + } + + BeforeEach(func() { + tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + ecnTracker = newECNTracker(utils.DefaultLogger, tracer) + }) + + It("sends exactly 10 testing packets", func() { + tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + for i := 0; i < 9; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + // Do this twice to make sure only sent packets are counted + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(10+i), protocol.ECT0) + } + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + tracer.EXPECT().ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + ecnTracker.SentPacket(20, protocol.ECT0) + // In unknown state, packets shouldn't be ECN-marked. + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + }) + + sendAllTestingPackets := func() { + tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + tracer.EXPECT().ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + for i := 0; i < 10; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0) + } + } + + It("fails ECN validation if all ECN testing packets are lost", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + for i := 0; i < 9; i++ { + ecnTracker.LostPacket(protocol.PacketNumber(i)) + } + // We don't care about the loss of non-testing packets + ecnTracker.LostPacket(15) + // Now lose the last testing packet. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) + ecnTracker.LostPacket(9) + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + // We still don't care about more non-testing packets being lost + ecnTracker.LostPacket(16) + }) + + It("passes ECN validation when a testing packet is acknowledged, while still in testing state", func() { + tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + for i := 0; i < 5; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0) + } + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(3), 1, 0, 0)).To(BeFalse()) + // make sure we continue sending ECT(0) packets + for i := 5; i < 100; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0) + } + }) + + It("passes ECN validation when a testing packet is acknowledged, while in unknown state", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // Lose some packets to make sure this doesn't influence the outcome. + for i := 0; i < 5; i++ { + ecnTracker.LostPacket(protocol.PacketNumber(i)) + } + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked([]*packet{{PacketNumber: 7}}, 1, 0, 0)).To(BeFalse()) + }) + + It("fails ECN validation when the ACK contains more ECN counts than we sent packets", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // only 10 ECT(0) packets were sent, but the ACK claims to have received 12 of them + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 12, 0, 0)).To(BeFalse()) + }) + + It("fails ECN validation when the ACK contains ECN counts for the wrong code point", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // We sent ECT(0), but this ACK acknowledges ECT(1). + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 0, 1, 0)).To(BeFalse()) + }) + + It("fails ECN validation when the ACK doesn't contain ECN counts", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // First only acknowledge packets sent without ECN marks. + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(12, 13, 14), 0, 0, 0)).To(BeFalse()) + // Now acknowledge some packets sent with ECN marks. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 15), 0, 0, 0)).To(BeFalse()) + }) + + It("fails ECN validation when an ACK decreases ECN counts", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 3, 0, 0)).To(BeFalse()) + // Now acknowledge some more packets, but decrease the ECN counts. Obviously, this doesn't make any sense. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 13), 2, 0, 0)).To(BeFalse()) + // make sure that new ACKs are ignored + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 14), 5, 0, 0)).To(BeFalse()) + }) + + // This can happen if ACK are lost / reordered. + It("doesn't fail validation if the ACK contains more ECN counts than it acknowledges packets", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 8, 0, 0)).To(BeFalse()) + }) + + It("fails ECN validation when the ACK doesn't contain enough ECN counts", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // First only acknowledge some packets sent with ECN marks. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 2, 0, 1)).To(BeTrue()) + // Now acknowledge some more packets sent with ECN marks, but don't increase the counters enough. + // This ACK acknowledges 3 more ECN-marked packets, but the counters only increase by 2. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 15), 3, 0, 2)).To(BeFalse()) + }) + + It("declares congestion", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // Receive one CE count. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 2, 0, 1)).To(BeTrue()) + // No increase in CE. No congestion. + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 13), 5, 0, 1)).To(BeFalse()) + // Increase in CE. More congestion. + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 14), 7, 0, 2)).To(BeTrue()) + }) +}) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 53aba0dd..7d3d4661 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -92,7 +92,8 @@ type sentPacketHandler struct { // The alarm timeout alarm time.Time - enableECN bool + enableECN bool + ecnTracker *ecnTracker perspective protocol.Perspective @@ -125,7 +126,7 @@ func newSentPacketHandler( tracer, ) - return &sentPacketHandler{ + h := &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated, initialPackets: newPacketNumberSpace(initialPN, false), @@ -133,11 +134,15 @@ func newSentPacketHandler( appDataPackets: newPacketNumberSpace(0, true), rttStats: rttStats, congestion: congestion, - enableECN: enableECN, perspective: pers, tracer: tracer, logger: logger, } + if enableECN { + h.enableECN = true + h.ecnTracker = newECNTracker(logger, tracer) + } + return h } func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { @@ -232,7 +237,7 @@ func (h *sentPacketHandler) SentPacket( streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, - _ protocol.ECN, + ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool, ) { @@ -257,6 +262,10 @@ func (h *sentPacketHandler) SentPacket( } h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting) + if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { + h.ecnTracker.SentPacket(pn, ecn) + } + if !isAckEliciting { pnSpace.history.SentNonAckElicitingPacket(pn) if !h.peerCompletedAddressValidation { @@ -307,8 +316,6 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } - pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked) - // Servers complete address validation when a protected packet is received. if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation && (encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) { @@ -338,6 +345,18 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.congestion.MaybeExitSlowStart() } } + + var ecnCongestionDetected bool + // Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked. + if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked { + ecnCongestionDetected = h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE)) + } + + pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked) + + // TODO: inform the congestion controller + _ = ecnCongestionDetected + if err := h.detectLostPackets(rcvTime, encLevel); err != nil { return false, err } @@ -642,6 +661,9 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if !p.IsPathMTUProbePacket { h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) } + if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { + h.ecnTracker.LostPacket(p.PacketNumber) + } } } return true, nil @@ -724,8 +746,7 @@ func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN { if !isShortHeaderPacket { return protocol.ECNNon } - // TODO: implement ECN logic - return protocol.ECNNon + return h.ecnTracker.Mode() } func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { From b6ce91bfe7ff63b353f555bc29b2c9f66efea4c8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 1 Sep 2023 18:12:22 +0700 Subject: [PATCH 20/43] stop appending to a GSO batch when the ECN marking changes --- connection.go | 10 +++++++--- connection_test.go | 47 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/connection.go b/connection.go index fdd056f1..10a36ba0 100644 --- a/connection.go +++ b/connection.go @@ -1911,9 +1911,9 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { buf := getLargePacketBuffer() maxSize := s.mtuDiscoverer.CurrentSize() + ecn := s.sentPacketHandler.ECNMode(true) for { var dontSendMore bool - ecn := s.sentPacketHandler.ECNMode(true) size, err := s.appendOneShortHeaderPacket(buf, maxSize, ecn, now) if err != nil { if err != errNothingToPack { @@ -1936,11 +1936,15 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { } } + // Don't send more packets in this batch if they require a different ECN marking than the previous ones. + nextECN := s.sentPacketHandler.ECNMode(true) + // Append another packet if // 1. The congestion controller and pacer allow sending more // 2. The last packet appended was a full-size packet - // 3. We still have enough space for another full-size packet in the buffer - if !dontSendMore && size == maxSize && buf.Len()+maxSize <= buf.Cap() { + // 3. The next packet will have the same ECN marking + // 4. We still have enough space for another full-size packet in the buffer + if !dontSendMore && size == maxSize && nextECN == ecn && buf.Len()+maxSize <= buf.Cap() { continue } diff --git a/connection_test.go b/connection_test.go index 4fe53910..67b02443 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1451,7 +1451,7 @@ var _ = Describe("Connection", func() { It("sends multiple packets one by one immediately, with GSO", func() { enableGSO() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) - sph.EXPECT().ECNMode(true).Times(3) + sph.EXPECT().ECNMode(true).Return(protocol.ECT1).Times(4) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) @@ -1476,20 +1476,59 @@ var _ = Describe("Connection", func() { It("stops appending packets when a smaller packet is packed, with GSO", func() { enableGSO() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) - sph.EXPECT().ECNMode(true).Times(2) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) + sph.EXPECT().ECNMode(true).Times(4) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize()-1) rand.Read(payload2) + payload3 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload3) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) + expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 12}, payload3) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { + Expect(b.Data).To(Equal(payload3)) + }) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) + conn.run() + }() + conn.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure that only 2 packets are sent + }) + + It("stops appending packets when the ECN marking changes, with GSO", func() { + enableGSO() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) + sph.EXPECT().ECNMode(true).Return(protocol.ECT1).Times(2) + sph.EXPECT().ECNMode(true).Return(protocol.ECT0).Times(2) + payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload1) + payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload2) + payload3 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload3) + expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1) + expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) + expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload3) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { + Expect(b.Data).To(Equal(append(payload1, payload2...))) + }) + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { + Expect(b.Data).To(Equal(payload3)) + }) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) From 797e275293e27a0c12f0bb79633086dacf1f0bd0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 1 Sep 2023 22:07:12 +0700 Subject: [PATCH 21/43] congestion: rename OnPacketLost to OnCongestionEvent --- internal/ackhandler/sent_packet_handler.go | 2 +- .../ackhandler/sent_packet_handler_test.go | 12 +++++----- internal/congestion/cubic_sender.go | 2 +- internal/congestion/cubic_sender_test.go | 4 ++-- internal/congestion/interface.go | 2 +- internal/mocks/congestion.go | 24 +++++++++---------- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 7d3d4661..90f80422 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -659,7 +659,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E h.removeFromBytesInFlight(p) h.queueFramesForRetransmission(p) if !p.IsPathMTUProbePacket { - h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + h.congestion.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight) } if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { h.ecnTracker.LostPacket(p.PacketNumber) diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index a1e1f0d9..acadcf37 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -563,7 +563,7 @@ var _ = Describe("SentPacketHandler", func() { // lose packet 1 gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)), + cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)), cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), ) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} @@ -575,7 +575,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) }) - It("doesn't call OnPacketLost when a Path MTU probe packet is lost", func() { + It("doesn't call OnCongestionEvent when a Path MTU probe packet is lost", func() { cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) var mtuPacketDeclaredLost bool sentPacket(ackElicitingPacket(&packet{ @@ -590,7 +590,7 @@ var _ = Describe("SentPacketHandler", func() { }, })) sentPacket(ackElicitingPacket(&packet{PacketNumber: 2})) - // lose packet 1, but don't EXPECT any calls to OnPacketLost() + // lose packet 1, but don't EXPECT any calls to OnCongestionEvent() gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), @@ -602,7 +602,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.bytesInFlight).To(BeZero()) }) - It("calls OnPacketAcked and OnPacketLost with the right bytes_in_flight value", func() { + It("calls OnPacketAcked and OnCongestionEvent with the right bytes_in_flight value", func() { cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(4) sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) sentPacket(ackElicitingPacket(&packet{PacketNumber: 2, SendTime: time.Now().Add(-30 * time.Minute)})) @@ -611,7 +611,7 @@ var _ = Describe("SentPacketHandler", func() { // receive the first ACK gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)), + cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)), cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4), gomock.Any()), ) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} @@ -620,7 +620,7 @@ var _ = Describe("SentPacketHandler", func() { // receive the second ACK gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)), + cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)), cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), ) ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}} diff --git a/internal/congestion/cubic_sender.go b/internal/congestion/cubic_sender.go index 2e508475..10eb4667 100644 --- a/internal/congestion/cubic_sender.go +++ b/internal/congestion/cubic_sender.go @@ -188,7 +188,7 @@ func (c *cubicSender) OnPacketAcked( } } -func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { +func (c *cubicSender) OnCongestionEvent(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets // already sent should be treated as a single loss event, since it's expected. if packetNumber <= c.largestSentAtLastCutback { diff --git a/internal/congestion/cubic_sender_test.go b/internal/congestion/cubic_sender_test.go index 85d81167..d02ec3e8 100644 --- a/internal/congestion/cubic_sender_test.go +++ b/internal/congestion/cubic_sender_test.go @@ -80,14 +80,14 @@ var _ = Describe("Cubic Sender", func() { LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) { for i := 0; i < n; i++ { ackedPacketNumber++ - sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight) + sender.OnCongestionEvent(ackedPacketNumber, packetLength, bytesInFlight) } bytesInFlight -= protocol.ByteCount(n) * packetLength } // Does not increment acked_packet_number_. LosePacket := func(number protocol.PacketNumber) { - sender.OnPacketLost(number, maxDatagramSize, bytesInFlight) + sender.OnCongestionEvent(number, maxDatagramSize, bytesInFlight) bytesInFlight -= maxDatagramSize } diff --git a/internal/congestion/interface.go b/internal/congestion/interface.go index 484bd5f8..881f453b 100644 --- a/internal/congestion/interface.go +++ b/internal/congestion/interface.go @@ -14,7 +14,7 @@ type SendAlgorithm interface { CanSend(bytesInFlight protocol.ByteCount) bool MaybeExitSlowStart() OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) - OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) + OnCongestionEvent(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) OnRetransmissionTimeout(packetsRetransmitted bool) SetMaxDatagramSize(protocol.ByteCount) } diff --git a/internal/mocks/congestion.go b/internal/mocks/congestion.go index d27113a7..4d38d1db 100644 --- a/internal/mocks/congestion.go +++ b/internal/mocks/congestion.go @@ -117,6 +117,18 @@ func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) MaybeExitSlowStart() *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).MaybeExitSlowStart)) } +// OnCongestionEvent mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) OnCongestionEvent(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnCongestionEvent", arg0, arg1, arg2) +} + +// OnCongestionEvent indicates an expected call of OnCongestionEvent. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnCongestionEvent(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnCongestionEvent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnCongestionEvent), arg0, arg1, arg2) +} + // OnPacketAcked mocks base method. func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) { m.ctrl.T.Helper() @@ -129,18 +141,6 @@ func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) } -// OnPacketLost mocks base method. -func (m *MockSendAlgorithmWithDebugInfos) OnPacketLost(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPacketLost", arg0, arg1, arg2) -} - -// OnPacketLost indicates an expected call of OnPacketLost. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketLost(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketLost", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketLost), arg0, arg1, arg2) -} - // OnPacketSent mocks base method. func (m *MockSendAlgorithmWithDebugInfos) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) { m.ctrl.T.Helper() From d6ac6300a443a636804399a1cd690ebafddc603d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 2 Sep 2023 09:26:49 +0700 Subject: [PATCH 22/43] feed ECN feedback into the congestion controller --- internal/ackhandler/ecn.go | 9 ++ internal/ackhandler/mock_ecn_handler_test.go | 87 +++++++++++ internal/ackhandler/mockgen.go | 3 + internal/ackhandler/sent_packet_handler.go | 11 +- .../ackhandler/sent_packet_handler_test.go | 147 ++++++++++++++++++ 5 files changed, 251 insertions(+), 6 deletions(-) create mode 100644 internal/ackhandler/mock_ecn_handler_test.go diff --git a/internal/ackhandler/ecn.go b/internal/ackhandler/ecn.go index 0fd39098..ff0dfe65 100644 --- a/internal/ackhandler/ecn.go +++ b/internal/ackhandler/ecn.go @@ -21,6 +21,13 @@ const ( // must fit into an uint8, otherwise numSentTesting and numLostTesting must have a larger type const numECNTestingPackets = 10 +type ecnHandler interface { + SentPacket(protocol.PacketNumber, protocol.ECN) + Mode() protocol.ECN + HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) + LostPacket(protocol.PacketNumber) +} + // The ecnTracker performs ECN validation of a path. // Once failed, it doesn't do any re-validation of the path. // It is designed only work for 1-RTT packets, it doesn't handle multiple packet number spaces. @@ -42,6 +49,8 @@ type ecnTracker struct { logger utils.Logger } +var _ ecnHandler = &ecnTracker{} + func newECNTracker(logger utils.Logger, tracer logging.ConnectionTracer) *ecnTracker { return &ecnTracker{ firstTestingPacket: protocol.InvalidPacketNumber, diff --git a/internal/ackhandler/mock_ecn_handler_test.go b/internal/ackhandler/mock_ecn_handler_test.go new file mode 100644 index 00000000..28f350c1 --- /dev/null +++ b/internal/ackhandler/mock_ecn_handler_test.go @@ -0,0 +1,87 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go/internal/ackhandler (interfaces: ECNHandler) + +// Package ackhandler is a generated GoMock package. +package ackhandler + +import ( + reflect "reflect" + + protocol "github.com/quic-go/quic-go/internal/protocol" + gomock "go.uber.org/mock/gomock" +) + +// MockECNHandler is a mock of ECNHandler interface. +type MockECNHandler struct { + ctrl *gomock.Controller + recorder *MockECNHandlerMockRecorder +} + +// MockECNHandlerMockRecorder is the mock recorder for MockECNHandler. +type MockECNHandlerMockRecorder struct { + mock *MockECNHandler +} + +// NewMockECNHandler creates a new mock instance. +func NewMockECNHandler(ctrl *gomock.Controller) *MockECNHandler { + mock := &MockECNHandler{ctrl: ctrl} + mock.recorder = &MockECNHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockECNHandler) EXPECT() *MockECNHandlerMockRecorder { + return m.recorder +} + +// HandleNewlyAcked mocks base method. +func (m *MockECNHandler) HandleNewlyAcked(arg0 []*packet, arg1, arg2, arg3 int64) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleNewlyAcked", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(bool) + return ret0 +} + +// HandleNewlyAcked indicates an expected call of HandleNewlyAcked. +func (mr *MockECNHandlerMockRecorder) HandleNewlyAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleNewlyAcked", reflect.TypeOf((*MockECNHandler)(nil).HandleNewlyAcked), arg0, arg1, arg2, arg3) +} + +// LostPacket mocks base method. +func (m *MockECNHandler) LostPacket(arg0 protocol.PacketNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LostPacket", arg0) +} + +// LostPacket indicates an expected call of LostPacket. +func (mr *MockECNHandlerMockRecorder) LostPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockECNHandler)(nil).LostPacket), arg0) +} + +// Mode mocks base method. +func (m *MockECNHandler) Mode() protocol.ECN { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Mode") + ret0, _ := ret[0].(protocol.ECN) + return ret0 +} + +// Mode indicates an expected call of Mode. +func (mr *MockECNHandlerMockRecorder) Mode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mode", reflect.TypeOf((*MockECNHandler)(nil).Mode)) +} + +// SentPacket mocks base method. +func (m *MockECNHandler) SentPacket(arg0 protocol.PacketNumber, arg1 protocol.ECN) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockECNHandlerMockRecorder) SentPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockECNHandler)(nil).SentPacket), arg0, arg1) +} diff --git a/internal/ackhandler/mockgen.go b/internal/ackhandler/mockgen.go index b9eb8a88..dbf6ee2d 100644 --- a/internal/ackhandler/mockgen.go +++ b/internal/ackhandler/mockgen.go @@ -4,3 +4,6 @@ package ackhandler //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker" type SentPacketTracker = sentPacketTracker + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler" +type ECNHandler = ecnHandler diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 90f80422..b498ce4b 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -93,7 +93,7 @@ type sentPacketHandler struct { alarm time.Time enableECN bool - ecnTracker *ecnTracker + ecnTracker ecnHandler perspective protocol.Perspective @@ -346,17 +346,16 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } - var ecnCongestionDetected bool // Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked. if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked { - ecnCongestionDetected = h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE)) + congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE)) + if congested { + h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight) + } } pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked) - // TODO: inform the congestion controller - _ = ecnCongestionDetected - if err := h.detectLostPackets(rcvTime, encLevel); err != nil { return false, err } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index acadcf37..63f1b834 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -1429,4 +1429,151 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.rttStats.SmoothedRTT()).To(BeZero()) }) }) + + Context("ECN handling", func() { + var ecnHandler *MockECNHandler + var cong *mocks.MockSendAlgorithmWithDebugInfos + + JustBeforeEach(func() { + cong = mocks.NewMockSendAlgorithmWithDebugInfos(mockCtrl) + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + cong.EXPECT().OnPacketAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + cong.EXPECT().MaybeExitSlowStart().AnyTimes() + ecnHandler = NewMockECNHandler(mockCtrl) + lostPackets = nil + rttStats := utils.NewRTTStats() + rttStats.UpdateRTT(time.Hour, 0, time.Now()) + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, false, perspective, nil, utils.DefaultLogger) + handler.ecnTracker = ecnHandler + handler.congestion = cong + }) + + It("informs about sent packets", func() { + // Check that only 1-RTT packets are reported + handler.SentPacket(time.Now(), 100, -1, nil, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false) + handler.SentPacket(time.Now(), 101, -1, nil, nil, protocol.EncryptionHandshake, protocol.ECT0, 1200, false) + handler.SentPacket(time.Now(), 102, -1, nil, nil, protocol.Encryption0RTT, protocol.ECNCE, 1200, false) + + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(103), protocol.ECT1) + handler.SentPacket(time.Now(), 103, -1, nil, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + }) + + It("informs about sent packets", func() { + // Check that only 1-RTT packets are reported + handler.SentPacket(time.Now(), 100, -1, nil, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false) + handler.SentPacket(time.Now(), 101, -1, nil, nil, protocol.EncryptionHandshake, protocol.ECT0, 1200, false) + handler.SentPacket(time.Now(), 102, -1, nil, nil, protocol.Encryption0RTT, protocol.ECNCE, 1200, false) + + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(103), protocol.ECT1) + handler.SentPacket(time.Now(), 103, -1, nil, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + }) + + It("informs about lost packets", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + cong.EXPECT().OnCongestionEvent(gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(10)) + ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(11)) + ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(12)) + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 16, Smallest: 13}}}, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("processes ACKs", func() { + // Check that we only care about 1-RTT packets. + handler.SentPacket(time.Now(), 100, -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 100}}}, protocol.EncryptionInitial, time.Now()) + Expect(err).ToNot(HaveOccurred()) + + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool { + Expect(packets).To(HaveLen(5)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(10))) + Expect(packets[1].PacketNumber).To(Equal(protocol.PacketNumber(11))) + Expect(packets[2].PacketNumber).To(Equal(protocol.PacketNumber(12))) + Expect(packets[3].PacketNumber).To(Equal(protocol.PacketNumber(14))) + Expect(packets[4].PacketNumber).To(Equal(protocol.PacketNumber(15))) + return false + }) + _, err = handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{ + {Largest: 15, Smallest: 14}, + {Largest: 12, Smallest: 10}, + }, + ECT0: 1, + ECT1: 2, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("ignores reordered ACKs", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool { + Expect(packets).To(HaveLen(2)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(11))) + Expect(packets[1].PacketNumber).To(Equal(protocol.PacketNumber(12))) + return false + }) + _, err := handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 12, Smallest: 11}}, + ECT0: 1, + ECT1: 2, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + // acknowledge packet 10 now, but don't increase the largest acked + _, err = handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 12, Smallest: 10}}, + ECT0: 1, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("ignores ACKs that don't increase the largest acked", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool { + Expect(packets).To(HaveLen(1)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(11))) + return false + }) + _, err := handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 11, Smallest: 11}}, + ECT0: 1, + ECT1: 2, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + _, err = handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 11, Smallest: 10}}, + ECT0: 1, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("informs the congestion controller about CE events", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT0) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT0, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(0), int64(0), int64(0)).Return(true) + cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(15), gomock.Any(), gomock.Any()) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 15, Smallest: 10}}}, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + }) }) From d1f6ea997cba051841632758ea22679817da90ba Mon Sep 17 00:00:00 2001 From: Ameagari <47713057+tanghaowillow@users.noreply.github.com> Date: Mon, 11 Sep 2023 23:05:31 +0800 Subject: [PATCH 23/43] save the RTT in non-0-RTT session tickets (#4042) * also send session ticket when 0-RTT is disabled for go1.21 * allow session ticket without transport parameters * do not include transport parameters for non-0RTT session ticket * remove the test assertion because it is not supported for go1.20 * Update internal/handshake/session_ticket.go Co-authored-by: Marten Seemann * add a 0-RTT argument to unmarshaling session tickets * bump sessionTicketRevision to 4 * check if non-0-RTT session ticket has expected length * change parameter order * add test checks --------- Co-authored-by: Marten Seemann --- internal/handshake/crypto_setup.go | 28 ++++++++------ internal/handshake/crypto_setup_test.go | 12 +++++- internal/handshake/session_ticket.go | 19 +++++++--- internal/handshake/session_ticket_test.go | 46 +++++++++++++++++++---- internal/qtls/go120.go | 8 ++-- internal/qtls/go121.go | 24 ++++++------ 6 files changed, 95 insertions(+), 42 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 332dbd7e..b5fb404b 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -127,7 +127,7 @@ func NewCryptoSetupServer( cs.allow0RTT = allow0RTT quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} - qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT) + qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket) addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr) cs.tlsConf = quicConf.TLSConfig @@ -347,10 +347,13 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo } func (h *cryptoSetup) getDataForSessionTicket() []byte { - return (&sessionTicket{ - Parameters: h.ourParams, - RTT: h.rttStats.SmoothedRTT(), - }).Marshal() + ticket := &sessionTicket{ + RTT: h.rttStats.SmoothedRTT(), + } + if h.allow0RTT { + ticket.Parameters = h.ourParams + } + return ticket.Marshal() } // GetSessionTicket generates a new session ticket. @@ -379,12 +382,16 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { return ticket, nil } -// accept0RTT is called for the server when receiving the client's session ticket. -// It decides whether to accept 0-RTT. -func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { +// handleSessionTicket is called for the server when receiving the client's session ticket. +// It reads parameters from the session ticket and decides whether to accept 0-RTT when the session ticket is used for 0-RTT. +func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool { var t sessionTicket - if err := t.Unmarshal(sessionTicketData); err != nil { - h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error()) + if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil { + h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error()) + return false + } + h.rttStats.SetInitialRTT(t.RTT) + if !using0RTT { return false } valid := h.ourParams.ValidFor0RTT(t.Parameters) @@ -397,7 +404,6 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { return false } h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) - h.rttStats.SetInitialRTT(t.RTT) return true } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index e40db500..9b2844f7 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -8,6 +8,8 @@ import ( "crypto/x509/pkix" "math/big" "net" + "runtime" + "strings" "time" mocktls "github.com/quic-go/quic-go/internal/mocks/tls" @@ -417,11 +419,13 @@ var _ = Describe("Crypto Setup TLS", func() { close(receivedSessionTicket) }) clientConf.ClientSessionCache = csc + const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. + serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, - clientOrigRTTStats, &utils.RTTStats{}, + clientOrigRTTStats, serverOrigRTTStats, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) @@ -434,9 +438,10 @@ var _ = Describe("Crypto Setup TLS", func() { csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) clientRTTStats := &utils.RTTStats{} + serverRTTStats := &utils.RTTStats{} client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( clientConf, serverConf, - clientRTTStats, &utils.RTTStats{}, + clientRTTStats, serverRTTStats, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) @@ -446,6 +451,9 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(server.ConnectionState().DidResume).To(BeTrue()) Expect(client.ConnectionState().DidResume).To(BeTrue()) Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) + if !strings.Contains(runtime.Version(), "go1.20") { + Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) + } }) It("doesn't use session resumption if the server disabled it", func() { diff --git a/internal/handshake/session_ticket.go b/internal/handshake/session_ticket.go index d3efeb29..9481af56 100644 --- a/internal/handshake/session_ticket.go +++ b/internal/handshake/session_ticket.go @@ -10,7 +10,7 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) -const sessionTicketRevision = 3 +const sessionTicketRevision = 4 type sessionTicket struct { Parameters *wire.TransportParameters @@ -21,10 +21,13 @@ func (t *sessionTicket) Marshal() []byte { b := make([]byte, 0, 256) b = quicvarint.Append(b, sessionTicketRevision) b = quicvarint.Append(b, uint64(t.RTT.Microseconds())) + if t.Parameters == nil { + return b + } return t.Parameters.MarshalForSessionTicket(b) } -func (t *sessionTicket) Unmarshal(b []byte) error { +func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error { r := bytes.NewReader(b) rev, err := quicvarint.Read(r) if err != nil { @@ -37,11 +40,15 @@ func (t *sessionTicket) Unmarshal(b []byte) error { if err != nil { return errors.New("failed to read RTT") } - var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) + if using0RTT { + var tp wire.TransportParameters + if err := tp.UnmarshalFromSessionTicket(r); err != nil { + return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) + } + t.Parameters = &tp + } else if r.Len() > 0 { + return fmt.Errorf("the session ticket has more bytes than expected") } - t.Parameters = &tp t.RTT = time.Duration(rtt) * time.Microsecond return nil } diff --git a/internal/handshake/session_ticket_test.go b/internal/handshake/session_ticket_test.go index 6f004de9..8e209d3a 100644 --- a/internal/handshake/session_ticket_test.go +++ b/internal/handshake/session_ticket_test.go @@ -11,7 +11,7 @@ import ( ) var _ = Describe("Session Ticket", func() { - It("marshals and unmarshals a session ticket", func() { + It("marshals and unmarshals a 0-RTT session ticket", func() { ticket := &sessionTicket{ Parameters: &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: 1, @@ -22,33 +22,65 @@ var _ = Describe("Session Ticket", func() { RTT: 1337 * time.Microsecond, } var t sessionTicket - Expect(t.Unmarshal(ticket.Marshal())).To(Succeed()) + Expect(t.Unmarshal(ticket.Marshal(), true)).To(Succeed()) Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) Expect(t.Parameters.ActiveConnectionIDLimit).To(BeEquivalentTo(10)) Expect(t.Parameters.MaxDatagramFrameSize).To(BeEquivalentTo(20)) Expect(t.RTT).To(Equal(1337 * time.Microsecond)) + // fails to unmarshal the ticket as a non-0-RTT ticket + Expect(t.Unmarshal(ticket.Marshal(), false)).To(MatchError("the session ticket has more bytes than expected")) + }) + + It("marshals and unmarshals a non-0-RTT session ticket", func() { + ticket := &sessionTicket{ + RTT: 1337 * time.Microsecond, + } + var t sessionTicket + Expect(t.Unmarshal(ticket.Marshal(), false)).To(Succeed()) + Expect(t.Parameters).To(BeNil()) + Expect(t.RTT).To(Equal(1337 * time.Microsecond)) + // fails to unmarshal the ticket as a 0-RTT ticket + Expect(t.Unmarshal(ticket.Marshal(), true)).To(MatchError(ContainSubstring("unmarshaling transport parameters from session ticket failed"))) }) It("refuses to unmarshal if the ticket is too short for the revision", func() { - Expect((&sessionTicket{}).Unmarshal([]byte{})).To(MatchError("failed to read session ticket revision")) + Expect((&sessionTicket{}).Unmarshal([]byte{}, true)).To(MatchError("failed to read session ticket revision")) + Expect((&sessionTicket{}).Unmarshal([]byte{}, false)).To(MatchError("failed to read session ticket revision")) }) It("refuses to unmarshal if the revision doesn't match", func() { b := quicvarint.Append(nil, 1337) - Expect((&sessionTicket{}).Unmarshal(b)).To(MatchError("unknown session ticket revision: 1337")) + Expect((&sessionTicket{}).Unmarshal(b, true)).To(MatchError("unknown session ticket revision: 1337")) + Expect((&sessionTicket{}).Unmarshal(b, false)).To(MatchError("unknown session ticket revision: 1337")) }) It("refuses to unmarshal if the RTT cannot be read", func() { b := quicvarint.Append(nil, sessionTicketRevision) - Expect((&sessionTicket{}).Unmarshal(b)).To(MatchError("failed to read RTT")) + Expect((&sessionTicket{}).Unmarshal(b, true)).To(MatchError("failed to read RTT")) + Expect((&sessionTicket{}).Unmarshal(b, false)).To(MatchError("failed to read RTT")) }) - It("refuses to unmarshal if unmarshaling the transport parameters fails", func() { + It("refuses to unmarshal a 0-RTT session ticket if unmarshaling the transport parameters fails", func() { b := quicvarint.Append(nil, sessionTicketRevision) b = append(b, []byte("foobar")...) - err := (&sessionTicket{}).Unmarshal(b) + err := (&sessionTicket{}).Unmarshal(b, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed")) }) + + It("refuses to unmarshal if the non-0-RTT session ticket has more bytes than expected", func() { + ticket := &sessionTicket{ + Parameters: &wire.TransportParameters{ + InitialMaxStreamDataBidiLocal: 1, + InitialMaxStreamDataBidiRemote: 2, + ActiveConnectionIDLimit: 10, + MaxDatagramFrameSize: 20, + }, + RTT: 1234 * time.Microsecond, + } + err := (&sessionTicket{}).Unmarshal(ticket.Marshal(), false) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("the session ticket has more bytes than expected")) + }) }) diff --git a/internal/qtls/go120.go b/internal/qtls/go120.go index 3b50441c..898f0352 100644 --- a/internal/qtls/go120.go +++ b/internal/qtls/go120.go @@ -39,13 +39,15 @@ const ( QUICHandshakeDone = qtls.QUICHandshakeDone ) -func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) { +func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, handleSessionTicket func([]byte, bool) bool) { qtls.InitSessionTicketKeys(conf.TLSConfig) conf.TLSConfig = conf.TLSConfig.Clone() conf.TLSConfig.MinVersion = tls.VersionTLS13 conf.ExtraConfig = &qtls.ExtraConfig{ - Enable0RTT: enable0RTT, - Accept0RTT: accept0RTT, + Enable0RTT: enable0RTT, + Accept0RTT: func(data []byte) bool { + return handleSessionTicket(data, true) + }, GetAppDataForSessionTicket: getDataForSessionTicket, } } diff --git a/internal/qtls/go121.go b/internal/qtls/go121.go index 4aebc4a1..65a274ac 100644 --- a/internal/qtls/go121.go +++ b/internal/qtls/go121.go @@ -41,7 +41,7 @@ const ( func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) } func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) } -func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) { +func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) { conf := qconf.TLSConfig // Workaround for https://github.com/golang/go/issues/60506. @@ -55,11 +55,9 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce // add callbacks to save transport parameters into the session ticket origWrapSession := conf.WrapSession conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) { - // Add QUIC transport parameters if this is a 0-RTT packet. - // TODO(#3853): also save the RTT for non-0-RTT tickets - if state.EarlyData { - state.Extra = append(state.Extra, addExtraPrefix(getData())) - } + // Add QUIC session ticket + state.Extra = append(state.Extra, addExtraPrefix(getData())) + if origWrapSession != nil { return origWrapSession(cs, state) } @@ -83,14 +81,14 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce if err != nil || state == nil { return nil, err } - if state.EarlyData { - extra := findExtraData(state.Extra) - if unwrapCount == 1 && extra != nil { // first session ticket - state.EarlyData = accept0RTT(extra) - } else { // subsequent session ticket, can't be used for 0-RTT - state.EarlyData = false - } + + extra := findExtraData(state.Extra) + if extra != nil { + state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1) + } else { + state.EarlyData = false } + return state, nil } } From 2a8dc12a530bfbe4ee9e4d6d54d26ececbfeed79 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 12 Sep 2023 13:12:13 +0700 Subject: [PATCH 24/43] remove duplicate mocks for the Tracer and the ConnectionTracer (#4076) They were introduced to avoid an import loop in the tests in the logging package, but the same can be achieved by having a dedicated package for the tests (logging_test). --- logging/logging_suite_test.go | 2 +- logging/mock_connection_tracer_test.go | 387 ------------------------- logging/mock_tracer_test.go | 73 ----- logging/mockgen.go | 4 - logging/multiplex_test.go | 22 +- logging/packet_header_test.go | 3 +- 6 files changed, 15 insertions(+), 476 deletions(-) delete mode 100644 logging/mock_connection_tracer_test.go delete mode 100644 logging/mock_tracer_test.go delete mode 100644 logging/mockgen.go diff --git a/logging/logging_suite_test.go b/logging/logging_suite_test.go index d808adfe..e595313d 100644 --- a/logging/logging_suite_test.go +++ b/logging/logging_suite_test.go @@ -1,4 +1,4 @@ -package logging +package logging_test import ( "testing" diff --git a/logging/mock_connection_tracer_test.go b/logging/mock_connection_tracer_test.go deleted file mode 100644 index 94722509..00000000 --- a/logging/mock_connection_tracer_test.go +++ /dev/null @@ -1,387 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go/logging (interfaces: ConnectionTracer) - -// Package logging is a generated GoMock package. -package logging - -import ( - net "net" - reflect "reflect" - time "time" - - protocol "github.com/quic-go/quic-go/internal/protocol" - utils "github.com/quic-go/quic-go/internal/utils" - wire "github.com/quic-go/quic-go/internal/wire" - gomock "go.uber.org/mock/gomock" -) - -// MockConnectionTracer is a mock of ConnectionTracer interface. -type MockConnectionTracer struct { - ctrl *gomock.Controller - recorder *MockConnectionTracerMockRecorder -} - -// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. -type MockConnectionTracerMockRecorder struct { - mock *MockConnectionTracer -} - -// NewMockConnectionTracer creates a new mock instance. -func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { - mock := &MockConnectionTracer{ctrl: ctrl} - mock.recorder = &MockConnectionTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { - return m.recorder -} - -// AcknowledgedPacket mocks base method. -func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) -} - -// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) -} - -// BufferedPacket mocks base method. -func (m *MockConnectionTracer) BufferedPacket(arg0 PacketType, arg1 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferedPacket", arg0, arg1) -} - -// BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0, arg1) -} - -// Close mocks base method. -func (m *MockConnectionTracer) Close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close. -func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) -} - -// ClosedConnection mocks base method. -func (m *MockConnectionTracer) ClosedConnection(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClosedConnection", arg0) -} - -// ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) -} - -// Debug mocks base method. -func (m *MockConnectionTracer) Debug(arg0, arg1 string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Debug", arg0, arg1) -} - -// Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) -} - -// DroppedEncryptionLevel mocks base method. -func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) -} - -// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) -} - -// DroppedKey mocks base method. -func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedKey", arg0) -} - -// DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) -} - -// DroppedPacket mocks base method. -func (m *MockConnectionTracer) DroppedPacket(arg0 PacketType, arg1 protocol.ByteCount, arg2 PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) -} - -// ECNStateUpdated mocks base method. -func (m *MockConnectionTracer) ECNStateUpdated(arg0 ECNState, arg1 ECNStateTrigger) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ECNStateUpdated", arg0, arg1) -} - -// ECNStateUpdated indicates an expected call of ECNStateUpdated. -func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) -} - -// LossTimerCanceled mocks base method. -func (m *MockConnectionTracer) LossTimerCanceled() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerCanceled") -} - -// LossTimerCanceled indicates an expected call of LossTimerCanceled. -func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) -} - -// LossTimerExpired mocks base method. -func (m *MockConnectionTracer) LossTimerExpired(arg0 TimerType, arg1 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) -} - -// LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) -} - -// LostPacket mocks base method. -func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 PacketLossReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) -} - -// LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) -} - -// NegotiatedVersion mocks base method. -func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) -} - -// NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) -} - -// ReceivedLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2, arg3) -} - -// ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) -} - -// ReceivedRetry mocks base method. -func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedRetry", arg0) -} - -// ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) -} - -// ReceivedShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2, arg3) -} - -// ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) -} - -// ReceivedTransportParameters mocks base method. -func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedTransportParameters", arg0) -} - -// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) -} - -// ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 protocol.ArbitraryLenConnectionID, arg2 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1, arg2) -} - -// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) -} - -// RestoredTransportParameters mocks base method. -func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RestoredTransportParameters", arg0) -} - -// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) -} - -// SentLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3, arg4) -} - -// SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) -} - -// SentShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3, arg4) -} - -// SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) -} - -// SentTransportParameters mocks base method. -func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentTransportParameters", arg0) -} - -// SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) -} - -// SetLossTimer mocks base method. -func (m *MockConnectionTracer) SetLossTimer(arg0 TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) -} - -// SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) -} - -// StartedConnection mocks base method. -func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) -} - -// StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) -} - -// UpdatedCongestionState mocks base method. -func (m *MockConnectionTracer) UpdatedCongestionState(arg0 CongestionState) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedCongestionState", arg0) -} - -// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) -} - -// UpdatedKey mocks base method. -func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKey", arg0, arg1) -} - -// UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) -} - -// UpdatedKeyFromTLS mocks base method. -func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) -} - -// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) -} - -// UpdatedMetrics mocks base method. -func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) -} - -// UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) -} - -// UpdatedPTOCount mocks base method. -func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedPTOCount", arg0) -} - -// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) -} diff --git a/logging/mock_tracer_test.go b/logging/mock_tracer_test.go deleted file mode 100644 index 61eb47ea..00000000 --- a/logging/mock_tracer_test.go +++ /dev/null @@ -1,73 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go/logging (interfaces: Tracer) - -// Package logging is a generated GoMock package. -package logging - -import ( - net "net" - reflect "reflect" - - protocol "github.com/quic-go/quic-go/internal/protocol" - wire "github.com/quic-go/quic-go/internal/wire" - gomock "go.uber.org/mock/gomock" -) - -// MockTracer is a mock of Tracer interface. -type MockTracer struct { - ctrl *gomock.Controller - recorder *MockTracerMockRecorder -} - -// MockTracerMockRecorder is the mock recorder for MockTracer. -type MockTracerMockRecorder struct { - mock *MockTracer -} - -// NewMockTracer creates a new mock instance. -func NewMockTracer(ctrl *gomock.Controller) *MockTracer { - mock := &MockTracer{ctrl: ctrl} - mock.recorder = &MockTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTracer) EXPECT() *MockTracerMockRecorder { - return m.recorder -} - -// DroppedPacket mocks base method. -func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 PacketType, arg2 protocol.ByteCount, arg3 PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) -} - -// SentPacket mocks base method. -func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// SentVersionNegotiationPacket mocks base method. -func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 protocol.ArbitraryLenConnectionID, arg3 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentVersionNegotiationPacket", arg0, arg1, arg2, arg3) -} - -// SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. -func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) -} diff --git a/logging/mockgen.go b/logging/mockgen.go deleted file mode 100644 index e44c15e1..00000000 --- a/logging/mockgen.go +++ /dev/null @@ -1,4 +0,0 @@ -package logging - -//go:generate sh -c "go run go.uber.org/mock/mockgen -package logging -self_package github.com/quic-go/quic-go/logging -destination mock_connection_tracer_test.go github.com/quic-go/quic-go/logging ConnectionTracer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package logging -self_package github.com/quic-go/quic-go/logging -destination mock_tracer_test.go github.com/quic-go/quic-go/logging Tracer" diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index 572ced6a..c0f784ec 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -1,12 +1,14 @@ -package logging +package logging_test import ( "errors" "net" "time" + mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" + . "github.com/quic-go/quic-go/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -19,20 +21,20 @@ var _ = Describe("Tracing", func() { }) It("returns the raw tracer if only one tracer is passed in", func() { - tr := NewMockTracer(mockCtrl) + tr := mocklogging.NewMockTracer(mockCtrl) tracer := NewMultiplexedTracer(tr) - Expect(tracer).To(BeAssignableToTypeOf(&MockTracer{})) + Expect(tracer).To(BeAssignableToTypeOf(&mocklogging.MockTracer{})) }) Context("tracing events", func() { var ( tracer Tracer - tr1, tr2 *MockTracer + tr1, tr2 *mocklogging.MockTracer ) BeforeEach(func() { - tr1 = NewMockTracer(mockCtrl) - tr2 = NewMockTracer(mockCtrl) + tr1 = mocklogging.NewMockTracer(mockCtrl) + tr2 = mocklogging.NewMockTracer(mockCtrl) tracer = NewMultiplexedTracer(tr1, tr2) }) @@ -67,13 +69,13 @@ var _ = Describe("Tracing", func() { Context("Connection Tracer", func() { var ( tracer ConnectionTracer - tr1 *MockConnectionTracer - tr2 *MockConnectionTracer + tr1 *mocklogging.MockConnectionTracer + tr2 *mocklogging.MockConnectionTracer ) BeforeEach(func() { - tr1 = NewMockConnectionTracer(mockCtrl) - tr2 = NewMockConnectionTracer(mockCtrl) + tr1 = mocklogging.NewMockConnectionTracer(mockCtrl) + tr2 = mocklogging.NewMockConnectionTracer(mockCtrl) tracer = NewMultiplexedConnectionTracer(tr1, tr2) }) diff --git a/logging/packet_header_test.go b/logging/packet_header_test.go index 6df41925..66810a9c 100644 --- a/logging/packet_header_test.go +++ b/logging/packet_header_test.go @@ -1,8 +1,9 @@ -package logging +package logging_test import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" + . "github.com/quic-go/quic-go/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" From d52e9f35bc58864196226c8d9f7129f14315ee4d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 12 Sep 2023 13:18:33 +0700 Subject: [PATCH 25/43] ackhandler: detect ECN mangling (#4080) * ackhandler: detect ECN mangling Mangling means that a path is re-marking all ECN-marked packets as CE. * ackhandler: only detect ECN mangling once all testing packets were sent --- internal/ackhandler/ecn.go | 23 +++++++++++++++++------ internal/ackhandler/ecn_test.go | 30 ++++++++++++++++++++++++++++++ logging/types.go | 4 +++- qlog/types.go | 2 ++ qlog/types_test.go | 1 + 5 files changed, 53 insertions(+), 7 deletions(-) diff --git a/internal/ackhandler/ecn.go b/internal/ackhandler/ecn.go index ff0dfe65..bb8b37f8 100644 --- a/internal/ackhandler/ecn.go +++ b/internal/ackhandler/ecn.go @@ -218,7 +218,21 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 return false } + // update our counters + e.numAckedECT0 = ect0 + e.numAckedECT1 = ect1 + e.numAckedECNCE = ecnce + if e.state == ecnStateTesting || e.state == ecnStateUnknown { + // Detect mangling (a path remarking all ECN-marked testing packets as CE). + if e.numSentECT0+e.numSentECT1 == e.numAckedECNCE && e.numAckedECNCE >= numECNTestingPackets { + if e.tracer != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + } + e.state = ecnStateFailed + return false + } + var ackedTestingPacket bool for _, p := range packets { if e.isTestingPacket(p.PacketNumber) { @@ -236,12 +250,9 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 } } - // update our counters - e.numAckedECT0 = ect0 - e.numAckedECT1 = ect1 - e.numAckedECNCE = ecnce - - return newECNCE > 0 + // Don't trust CE marks before having confirmed ECN capability of the path. + // Otherwise, mangling would be misinterpreted as actual congestion. + return e.state == ecnStateCapable && newECNCE > 0 } func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN { diff --git a/internal/ackhandler/ecn_test.go b/internal/ackhandler/ecn_test.go index 74fff094..984a97a9 100644 --- a/internal/ackhandler/ecn_test.go +++ b/internal/ackhandler/ecn_test.go @@ -175,6 +175,36 @@ var _ = Describe("ECN tracker", func() { Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 15), 3, 0, 2)).To(BeFalse()) }) + It("detects ECN mangling", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // ECN capability not confirmed yet, therefore CE marks are not regarded as congestion events + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3), 0, 0, 4)).To(BeFalse()) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 10, 11, 12), 0, 0, 7)).To(BeFalse()) + // With the next ACK, all testing packets will now have been marked CE. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 13), 0, 0, 10)).To(BeFalse()) + }) + + It("only detects ECN mangling after sending all testing packets", func() { + tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + for i := 0; i < 9; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(protocol.PacketNumber(i)), 0, 0, int64(i+1))).To(BeFalse()) + } + // Send the last testing packet, and receive a + tracer.EXPECT().ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(9, protocol.ECT0) + // This ACK now reports the last testing packets as CE as well. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(9), 0, 0, 10)).To(BeFalse()) + }) + It("declares congestion", func() { sendAllTestingPackets() for i := 10; i < 20; i++ { diff --git a/logging/types.go b/logging/types.go index bfee91c3..0d79b0a9 100644 --- a/logging/types.go +++ b/logging/types.go @@ -121,6 +121,8 @@ const ( ECNFailedLostAllTestingPackets // ECNFailedMoreECNCountsThanSent is emitted when an ACK contains more ECN counts than ECN-marked packets were sent ECNFailedMoreECNCountsThanSent - // ECNFailedTooFewECNCounts is emitted when contains fewer ECT(0) / ECT(1) counts than it acknowledges packets + // ECNFailedTooFewECNCounts is emitted when an ACK contains fewer ECN counts than it acknowledges packets ECNFailedTooFewECNCounts + // ECNFailedManglingDetected is emitted when the path marks all ECN-marked packets as CE + ECNFailedManglingDetected ) diff --git a/qlog/types.go b/qlog/types.go index 49b4853a..e0c0b243 100644 --- a/qlog/types.go +++ b/qlog/types.go @@ -364,6 +364,8 @@ func (e ecnStateTrigger) String() string { return "ACK contains more ECN counts than ECN-marked packets sent" case logging.ECNFailedTooFewECNCounts: return "ACK contains fewer new ECN counts than acknowledged ECN-marked packets" + case logging.ECNFailedManglingDetected: + return "ECN mangling detected" default: return "unknown ECN state trigger" } diff --git a/qlog/types_test.go b/qlog/types_test.go index d08eb739..d6be3e68 100644 --- a/qlog/types_test.go +++ b/qlog/types_test.go @@ -151,6 +151,7 @@ var _ = Describe("Types", func() { Expect(ecnStateTrigger(logging.ECNFailedLostAllTestingPackets).String()).To(Equal("all ECN testing packets declared lost")) Expect(ecnStateTrigger(logging.ECNFailedMoreECNCountsThanSent).String()).To(Equal("ACK contains more ECN counts than ECN-marked packets sent")) Expect(ecnStateTrigger(logging.ECNFailedTooFewECNCounts).String()).To(Equal("ACK contains fewer new ECN counts than acknowledged ECN-marked packets")) + Expect(ecnStateTrigger(logging.ECNFailedManglingDetected).String()).To(Equal("ECN mangling detected")) Expect(ecnStateTrigger(42).String()).To(Equal("unknown ECN state trigger")) }) }) From 7599f81faf622e3cd72560d425b593c2411ff0a8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 12 Sep 2023 18:37:12 +0700 Subject: [PATCH 26/43] ci: clean up Codecov ignore list (#4081) --- codecov.yml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/codecov.yml b/codecov.yml index 69435150..0b95a103 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,17 +1,9 @@ coverage: round: nearest ignore: - - streams_map_incoming_bidi.go - - streams_map_incoming_uni.go - - streams_map_outgoing_bidi.go - - streams_map_outgoing_uni.go - http3/gzip_reader.go - interop/ - - internal/ackhandler/packet_linkedlist.go - internal/handshake/cipher_suite.go - - internal/utils/byteinterval_linkedlist.go - - internal/utils/newconnectionid_linkedlist.go - - internal/utils/packetinterval_linkedlist.go - internal/utils/linkedlist/linkedlist.go - logging/null_tracer.go - fuzzing/ From 37a3c417a7c56167f9eeff462ddb1ed4258025b8 Mon Sep 17 00:00:00 2001 From: Benedikt Spies Date: Thu, 14 Sep 2023 08:55:49 +0200 Subject: [PATCH 27/43] expose GSO usage through ConnectionState (#4083) --- connection.go | 1 + interface.go | 2 ++ 2 files changed, 3 insertions(+) diff --git a/connection.go b/connection.go index 10a36ba0..d6cc886d 100644 --- a/connection.go +++ b/connection.go @@ -673,6 +673,7 @@ func (s *connection) ConnectionState() ConnectionState { cs := s.cryptoStreamHandler.ConnectionState() s.connState.TLS = cs.ConnectionState s.connState.Used0RTT = cs.Used0RTT + s.connState.GSO = s.conn.capabilities().GSO return s.connState } diff --git a/interface.go b/interface.go index 8faf0eb8..746ad12e 100644 --- a/interface.go +++ b/interface.go @@ -345,4 +345,6 @@ type ConnectionState struct { Used0RTT bool // Version is the QUIC version of the QUIC connection. Version VersionNumber + // GSO says if generic segmentation offload is used + GSO bool } From 862e64c7b928cec7a37f7c383e2cfeb10efd86fa Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 15 Sep 2023 18:33:57 +0700 Subject: [PATCH 28/43] add a Transport config option for the key used to encrypt tokens (#4066) * add a Transport config option for the key used to encrypt tokens * handshake: remove unused error return values --- connection_test.go | 3 +- fuzzing/tokens/fuzz.go | 14 +++-- interface.go | 3 ++ internal/handshake/token_generator.go | 13 ++--- internal/handshake/token_generator_test.go | 6 +-- internal/handshake/token_protector.go | 33 +++++------- internal/handshake/token_protector_test.go | 59 +++++++++++----------- server.go | 12 ++--- transport.go | 20 ++++++-- 9 files changed, 78 insertions(+), 85 deletions(-) diff --git a/connection_test.go b/connection_test.go index 67b02443..44564f84 100644 --- a/connection_test.go +++ b/connection_test.go @@ -100,8 +100,7 @@ var _ = Describe("Connection", func() { mconn.EXPECT().capabilities().DoAndReturn(func() connCapabilities { return capabilities }).AnyTimes() mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() - tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) - Expect(err).ToNot(HaveOccurred()) + tokenGenerator := handshake.NewTokenGenerator([32]byte{0xa, 0xb, 0xc}) tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) diff --git a/fuzzing/tokens/fuzz.go b/fuzzing/tokens/fuzz.go index ea261fb6..d0043ead 100644 --- a/fuzzing/tokens/fuzz.go +++ b/fuzzing/tokens/fuzz.go @@ -2,24 +2,22 @@ package tokens import ( "encoding/binary" - "math/rand" "net" "time" + "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" ) func Fuzz(data []byte) int { - if len(data) < 8 { + if len(data) < 32 { return -1 } - seed := binary.BigEndian.Uint64(data[:8]) - data = data[8:] - tg, err := handshake.NewTokenGenerator(rand.New(rand.NewSource(int64(seed)))) - if err != nil { - panic(err) - } + var key quic.TokenGeneratorKey + copy(key[:], data[:32]) + data = data[32:] + tg := handshake.NewTokenGenerator(key) if len(data) < 1 { return -1 } diff --git a/interface.go b/interface.go index 746ad12e..f5ee28d8 100644 --- a/interface.go +++ b/interface.go @@ -212,6 +212,9 @@ type EarlyConnection interface { // StatelessResetKey is a key used to derive stateless reset tokens. type StatelessResetKey [32]byte +// TokenGeneratorKey is a key used to encrypt session resumption tokens. +type TokenGeneratorKey = handshake.TokenProtectorKey + // A ConnectionID is a QUIC Connection ID, as defined in RFC 9000. // It is not able to handle QUIC Connection IDs longer than 20 bytes, // as they are allowed by RFC 8999. diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go index e5e90bb3..2d91e6b2 100644 --- a/internal/handshake/token_generator.go +++ b/internal/handshake/token_generator.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/asn1" "fmt" - "io" "net" "time" @@ -45,15 +44,9 @@ type TokenGenerator struct { tokenProtector tokenProtector } -// NewTokenGenerator initializes a new TookenGenerator -func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { - tokenProtector, err := newTokenProtector(rand) - if err != nil { - return nil, err - } - return &TokenGenerator{ - tokenProtector: tokenProtector, - }, nil +// NewTokenGenerator initializes a new TokenGenerator +func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator { + return &TokenGenerator{tokenProtector: newTokenProtector(key)} } // NewRetryToken generates a new token for a Retry for a given source address diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go index 870ba3ed..7abc53b4 100644 --- a/internal/handshake/token_generator_test.go +++ b/internal/handshake/token_generator_test.go @@ -16,9 +16,9 @@ var _ = Describe("Token Generator", func() { var tokenGen *TokenGenerator BeforeEach(func() { - var err error - tokenGen, err = NewTokenGenerator(rand.Reader) - Expect(err).ToNot(HaveOccurred()) + var key TokenProtectorKey + rand.Read(key[:]) + tokenGen = NewTokenGenerator(key) }) It("generates a token", func() { diff --git a/internal/handshake/token_protector.go b/internal/handshake/token_protector.go index 650f230b..6dcf7f77 100644 --- a/internal/handshake/token_protector.go +++ b/internal/handshake/token_protector.go @@ -3,6 +3,7 @@ package handshake import ( "crypto/aes" "crypto/cipher" + "crypto/rand" "crypto/sha256" "fmt" "io" @@ -10,6 +11,9 @@ import ( "golang.org/x/crypto/hkdf" ) +// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens. +type TokenProtectorKey [32]byte + // TokenProtector is used to create and verify a token type tokenProtector interface { // NewToken creates a new token @@ -18,40 +22,29 @@ type tokenProtector interface { DecodeToken([]byte) ([]byte, error) } -const ( - tokenSecretSize = 32 - tokenNonceSize = 32 -) +const tokenNonceSize = 32 // tokenProtector is used to create and verify a token type tokenProtectorImpl struct { - rand io.Reader - secret []byte + key TokenProtectorKey } // newTokenProtector creates a source for source address tokens -func newTokenProtector(rand io.Reader) (tokenProtector, error) { - secret := make([]byte, tokenSecretSize) - if _, err := rand.Read(secret); err != nil { - return nil, err - } - return &tokenProtectorImpl{ - rand: rand, - secret: secret, - }, nil +func newTokenProtector(key TokenProtectorKey) tokenProtector { + return &tokenProtectorImpl{key: key} } // NewToken encodes data into a new token. func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { - nonce := make([]byte, tokenNonceSize) - if _, err := s.rand.Read(nonce); err != nil { + var nonce [tokenNonceSize]byte + if _, err := rand.Read(nonce[:]); err != nil { return nil, err } - aead, aeadNonce, err := s.createAEAD(nonce) + aead, aeadNonce, err := s.createAEAD(nonce[:]) if err != nil { return nil, err } - return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil + return append(nonce[:], aead.Seal(nil, aeadNonce, data, nil)...), nil } // DecodeToken decodes a token. @@ -68,7 +61,7 @@ func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { } func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { - h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) + h := hkdf.New(sha256.New, s.key[:], nonce[:], []byte("quic-go token source")) key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 if _, err := io.ReadFull(h, key); err != nil { return nil, nil, err diff --git a/internal/handshake/token_protector_test.go b/internal/handshake/token_protector_test.go index 03cd5320..74eb1f0c 100644 --- a/internal/handshake/token_protector_test.go +++ b/internal/handshake/token_protector_test.go @@ -7,41 +7,17 @@ import ( . "github.com/onsi/gomega" ) -type zeroReader struct{} - -func (r *zeroReader) Read(b []byte) (int, error) { - for i := range b { - b[i] = 0 - } - return len(b), nil -} - var _ = Describe("Token Protector", func() { var tp tokenProtector BeforeEach(func() { + var key TokenProtectorKey + rand.Read(key[:]) var err error - tp, err = newTokenProtector(rand.Reader) + tp = newTokenProtector(key) Expect(err).ToNot(HaveOccurred()) }) - It("uses the random source", func() { - tp1, err := newTokenProtector(&zeroReader{}) - Expect(err).ToNot(HaveOccurred()) - tp2, err := newTokenProtector(&zeroReader{}) - Expect(err).ToNot(HaveOccurred()) - t1, err := tp1.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - t2, err := tp2.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(t1).To(Equal(t2)) - tp3, err := newTokenProtector(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - t3, err := tp3.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(t3).ToNot(Equal(t1)) - }) - It("encodes and decodes tokens", func() { token, err := tp.NewToken([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) @@ -51,11 +27,34 @@ var _ = Describe("Token Protector", func() { Expect(decoded).To(Equal([]byte("foobar"))) }) - It("fails deconding invalid tokens", func() { + It("uses the different keys", func() { + var key1, key2 TokenProtectorKey + rand.Read(key1[:]) + rand.Read(key2[:]) + tp1 := newTokenProtector(key1) + tp2 := newTokenProtector(key2) + t1, err := tp1.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + t2, err := tp2.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + + _, err = tp1.DecodeToken(t1) + Expect(err).ToNot(HaveOccurred()) + _, err = tp1.DecodeToken(t2) + Expect(err).To(HaveOccurred()) + + // now create another token protector, reusing key1 + tp3 := newTokenProtector(key1) + _, err = tp3.DecodeToken(t1) + Expect(err).ToNot(HaveOccurred()) + _, err = tp3.DecodeToken(t2) + Expect(err).To(HaveOccurred()) + }) + + It("doesn't decode invalid tokens", func() { token, err := tp.NewToken([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) - token = token[1:] // remove the first byte - _, err = tp.DecodeToken(token) + _, err = tp.DecodeToken(token[1:]) // the token is invalid without the first byte Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("message authentication failed")) }) diff --git a/server.go b/server.go index d41c6465..671ae971 100644 --- a/server.go +++ b/server.go @@ -2,7 +2,6 @@ package quic import ( "context" - "crypto/rand" "crypto/tls" "errors" "fmt" @@ -227,18 +226,15 @@ func newServer( config *Config, tracer logging.Tracer, onClose func(), + tokenGeneratorKey TokenGeneratorKey, disableVersionNegotiation bool, acceptEarly bool, -) (*baseServer, error) { - tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) - if err != nil { - return nil, err - } +) *baseServer { s := &baseServer{ conn: conn, tlsConf: tlsConf, config: config, - tokenGenerator: tokenGenerator, + tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey), connIDGenerator: connIDGenerator, connHandler: connHandler, connQueue: make(chan quicConn), @@ -260,7 +256,7 @@ func newServer( go s.run() go s.runSendQueue() s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) - return s, nil + return s } func (s *baseServer) run() { diff --git a/transport.go b/transport.go index c021be77..b5855329 100644 --- a/transport.go +++ b/transport.go @@ -57,6 +57,12 @@ type Transport struct { // See section 10.3 of RFC 9000 for details. StatelessResetKey *StatelessResetKey + // The TokenGeneratorKey is used to encrypt session resumption tokens. + // If no key is configured, a random key will be generated. + // If multiple servers are authoritative for the same domain, they should use the same key, + // see section 8.1.3 of RFC 9000 for details. + TokenGeneratorKey *TokenGeneratorKey + // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. // This can be useful if version information is exchanged out-of-band. // It has no effect for clients. @@ -136,7 +142,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo if err := t.init(false); err != nil { return nil, err } - s, err := newServer( + s := newServer( t.conn, t.handlerMap, t.connIDGenerator, @@ -144,12 +150,10 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo conf, t.Tracer, t.closeServer, + *t.TokenGeneratorKey, t.DisableVersionNegotiationPackets, allow0RTT, ) - if err != nil { - return nil, err - } t.server = s return s, nil } @@ -203,6 +207,14 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.closeQueue = make(chan closePacket, 4) t.statelessResetQueue = make(chan receivedPacket, 4) + if t.TokenGeneratorKey == nil { + var key TokenGeneratorKey + if _, err := rand.Read(key[:]); err != nil { + t.initErr = err + return + } + t.TokenGeneratorKey = &key + } if t.ConnectionIDGenerator != nil { t.connIDGenerator = t.ConnectionIDGenerator From c1ce4a8e9209fca2a16ab362ab93a2e94eec1c27 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 15 Sep 2023 18:35:09 +0700 Subject: [PATCH 29/43] http09: increase the startup timeout in tests (#4071) --- interop/http09/http_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/interop/http09/http_test.go b/interop/http09/http_test.go index f2d48994..03f39263 100644 --- a/interop/http09/http_test.go +++ b/interop/http09/http_test.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/http/httptest" + "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/testdata" @@ -42,7 +43,7 @@ var _ = Describe("HTTP 0.9 integration tests", func() { defer server.mutex.Unlock() ln = server.listener return server.listener - }).ShouldNot(BeNil()) + }, 5*time.Second).ShouldNot(BeNil()) saddr = ln.Addr() saddr.(*net.UDPAddr).IP = net.IP{127, 0, 0, 1} }) From 5b25d8b5be05f6e81c7db4934f0276ba924e6c05 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 15 Sep 2023 18:35:53 +0700 Subject: [PATCH 30/43] ci: fail if any Go files contain an ignore directive (#4055) --- .github/workflows/lint.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1c74c53b..ee56f6e4 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -11,6 +11,13 @@ jobs: go-version: "1.20.x" - name: Check that no non-test files import Ginkgo or Gomega run: .github/workflows/no_ginkgo.sh + - name: Check for //go:build ignore in .go files + run: | + IGNORED_FILES=$(grep -rl '//go:build ignore' . --include='*.go') || true + if [ -n "$IGNORED_FILES" ]; then + echo "::error::Found ignored Go files: $IGNORED_FILES" + exit 1 + fi - name: Check that go.mod is tidied run: | cp go.mod go.mod.orig From 22eac50276dd6e2d794e1dff72b3f0e25b1f43ab Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 15 Sep 2023 23:56:20 +0700 Subject: [PATCH 31/43] ci: combine the go generate workflow with the linting workflow (#4053) * ci: combine the go generate workflow with the linting workflow * reorder --- .github/workflows/go-generate.yml | 13 ------------- .github/workflows/lint.yml | 5 +++++ 2 files changed, 5 insertions(+), 13 deletions(-) delete mode 100644 .github/workflows/go-generate.yml diff --git a/.github/workflows/go-generate.yml b/.github/workflows/go-generate.yml deleted file mode 100644 index 4add84e3..00000000 --- a/.github/workflows/go-generate.yml +++ /dev/null @@ -1,13 +0,0 @@ -on: [push, pull_request] -jobs: - gogenerate: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 - with: - go-version: "1.20.x" - - name: Install dependencies - run: go build - - name: Run code generators - run: .github/workflows/go-generate.sh diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index ee56f6e4..1d564861 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -19,13 +19,18 @@ jobs: exit 1 fi - name: Check that go.mod is tidied + if: success() || failure() # run this step even if the previous one failed run: | cp go.mod go.mod.orig cp go.sum go.sum.orig go mod tidy diff go.mod go.mod.orig diff go.sum go.sum.orig + - name: Run code generators + if: success() || failure() # run this step even if the previous one failed + run: .github/workflows/go-generate.sh - name: Check that go mod vendor works + if: success() || failure() # run this step even if the previous one failed run: | cd integrationtests/gomodvendor go mod vendor From ab1c1be9a9d07920b55418fbf1f37ffa2b3ff9a2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 15 Sep 2023 23:57:13 +0700 Subject: [PATCH 32/43] basic ClusterFuzzLite integration (#4034) --- .clusterfuzzlite/Dockerfile | 21 ++++++++++ .clusterfuzzlite/build.sh | 9 +++++ .clusterfuzzlite/project.yaml | 1 + .github/workflows/clusterfuzz-lite-pr.yml | 48 +++++++++++++++++++++++ 4 files changed, 79 insertions(+) create mode 100644 .clusterfuzzlite/Dockerfile create mode 100755 .clusterfuzzlite/build.sh create mode 100644 .clusterfuzzlite/project.yaml create mode 100644 .github/workflows/clusterfuzz-lite-pr.yml diff --git a/.clusterfuzzlite/Dockerfile b/.clusterfuzzlite/Dockerfile new file mode 100644 index 00000000..b57da8ca --- /dev/null +++ b/.clusterfuzzlite/Dockerfile @@ -0,0 +1,21 @@ +FROM gcr.io/oss-fuzz-base/base-builder-go:v1 + +ARG TARGETPLATFORM +RUN echo "TARGETPLATFORM: ${TARGETPLATFORM}" + +ENV GOVERSION=1.20.7 + +RUN platform=$(echo ${TARGETPLATFORM} | tr '/' '-') && \ + filename="go${GOVERSION}.${platform}.tar.gz" && \ + wget https://dl.google.com/go/${filename} && \ + mkdir temp-go && \ + rm -rf /root/.go/* && \ + tar -C temp-go/ -xzf ${filename} && \ + mv temp-go/go/* /root/.go/ && \ + rm -r ${filename} temp-go + +RUN apt-get update && apt-get install -y make autoconf automake libtool + +COPY . $SRC/quic-go +WORKDIR quic-go +COPY .clusterfuzzlite/build.sh $SRC/ diff --git a/.clusterfuzzlite/build.sh b/.clusterfuzzlite/build.sh new file mode 100755 index 00000000..92a24169 --- /dev/null +++ b/.clusterfuzzlite/build.sh @@ -0,0 +1,9 @@ +#!/bin/bash -eu + +export CXX="${CXX} -lresolv" # required by Go 1.20 + +compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/frames Fuzz frame_fuzzer +compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/header Fuzz header_fuzzer +compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/transportparameters Fuzz transportparameter_fuzzer +compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/tokens Fuzz token_fuzzer +compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/handshake Fuzz handshake_fuzzer diff --git a/.clusterfuzzlite/project.yaml b/.clusterfuzzlite/project.yaml new file mode 100644 index 00000000..4f2ee4d9 --- /dev/null +++ b/.clusterfuzzlite/project.yaml @@ -0,0 +1 @@ +language: go diff --git a/.github/workflows/clusterfuzz-lite-pr.yml b/.github/workflows/clusterfuzz-lite-pr.yml new file mode 100644 index 00000000..c902db19 --- /dev/null +++ b/.github/workflows/clusterfuzz-lite-pr.yml @@ -0,0 +1,48 @@ +name: ClusterFuzzLite PR fuzzing +on: + pull_request: + paths: + - '**' + +permissions: read-all +jobs: + PR: + runs-on: ${{ fromJSON(vars['CLUSTERFUZZ_LITE_RUNNER_UBUNTU'] || '"ubuntu-latest"') }} + concurrency: + group: ${{ github.workflow }}-${{ matrix.sanitizer }}-${{ github.ref }} + cancel-in-progress: true + strategy: + fail-fast: false + matrix: + sanitizer: + - address + steps: + - name: Build Fuzzers (${{ matrix.sanitizer }}) + id: build + uses: google/clusterfuzzlite/actions/build_fuzzers@v1 + with: + language: go + github-token: ${{ secrets.GITHUB_TOKEN }} + sanitizer: ${{ matrix.sanitizer }} + # Optional but recommended: used to only run fuzzers that are affected + # by the PR. + # See later section on "Git repo for storage". + # storage-repo: https://${{ secrets.PERSONAL_ACCESS_TOKEN }}@github.com/OWNER/STORAGE-REPO-NAME.git + # storage-repo-branch: main # Optional. Defaults to "main" + # storage-repo-branch-coverage: gh-pages # Optional. Defaults to "gh-pages". + - name: Run Fuzzers (${{ matrix.sanitizer }}) + id: run + uses: google/clusterfuzzlite/actions/run_fuzzers@v1 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + fuzz-seconds: 480 + mode: 'code-change' + sanitizer: ${{ matrix.sanitizer }} + output-sarif: true + parallel-fuzzing: true + # Optional but recommended: used to download the corpus produced by + # batch fuzzing. + # See later section on "Git repo for storage". + # storage-repo: https://${{ secrets.PERSONAL_ACCESS_TOKEN }}@github.com/OWNER/STORAGE-REPO-NAME.git + # storage-repo-branch: main # Optional. Defaults to "main" + # storage-repo-branch-coverage: gh-pages # Optional. Defaults to "gh-pages". From d8cc4cb3ef5bff86cdde8c8c5b0a24c8e9180e91 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Sep 2023 18:57:50 +0700 Subject: [PATCH 33/43] http3: introduce an HTTP/3 error type (#4039) * http3: introduce an HTTP/3 error type * http3: use a pointer receiver for the Error --- http3/body.go | 5 +-- http3/client.go | 4 +-- http3/error.go | 58 ++++++++++++++++++++++++++++++ http3/error_codes.go | 10 +++++- http3/error_test.go | 41 +++++++++++++++++++++ http3/response_writer.go | 7 ++-- integrationtests/self/http_test.go | 12 ++++--- 7 files changed, 125 insertions(+), 12 deletions(-) create mode 100644 http3/error.go create mode 100644 http3/error_test.go diff --git a/http3/body.go b/http3/body.go index 63ff4366..dc168e9e 100644 --- a/http3/body.go +++ b/http3/body.go @@ -63,7 +63,8 @@ func (r *body) wasStreamHijacked() bool { } func (r *body) Read(b []byte) (int, error) { - return r.str.Read(b) + n, err := r.str.Read(b) + return n, maybeReplaceError(err) } func (r *body) Close() error { @@ -106,7 +107,7 @@ func (r *hijackableBody) Read(b []byte) (int, error) { if err != nil { r.requestDone() } - return n, err + return n, maybeReplaceError(err) } func (r *hijackableBody) requestDone() { diff --git a/http3/client.go b/http3/client.go index cc2400e2..8aca4807 100644 --- a/http3/client.go +++ b/http3/client.go @@ -318,13 +318,13 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) } - return nil, rerr.err + return nil, maybeReplaceError(rerr.err) } if opt.DontCloseRequestStream { close(reqDone) <-done } - return rsp, rerr.err + return rsp, maybeReplaceError(rerr.err) } // cancelingReader reads from the io.Reader. diff --git a/http3/error.go b/http3/error.go new file mode 100644 index 00000000..b96ebeec --- /dev/null +++ b/http3/error.go @@ -0,0 +1,58 @@ +package http3 + +import ( + "errors" + "fmt" + + "github.com/quic-go/quic-go" +) + +// Error is returned from the round tripper (for HTTP clients) +// and inside the HTTP handler (for HTTP servers) if an HTTP/3 error occurs. +// See section 8 of RFC 9114. +type Error struct { + Remote bool + ErrorCode ErrCode + ErrorMessage string +} + +var _ error = &Error{} + +func (e *Error) Error() string { + s := e.ErrorCode.string() + if s == "" { + s = fmt.Sprintf("H3 error (%#x)", uint64(e.ErrorCode)) + } + // Usually errors are remote. Only make it explicit for local errors. + if !e.Remote { + s += " (local)" + } + if e.ErrorMessage != "" { + s += ": " + e.ErrorMessage + } + return s +} + +func maybeReplaceError(err error) error { + if err == nil { + return nil + } + + var ( + e Error + strErr *quic.StreamError + appErr *quic.ApplicationError + ) + switch { + default: + return err + case errors.As(err, &strErr): + e.Remote = strErr.Remote + e.ErrorCode = ErrCode(strErr.ErrorCode) + case errors.As(err, &appErr): + e.Remote = appErr.Remote + e.ErrorCode = ErrCode(appErr.ErrorCode) + e.ErrorMessage = appErr.ErrorMessage + } + return &e +} diff --git a/http3/error_codes.go b/http3/error_codes.go index e8428d0a..ae646586 100644 --- a/http3/error_codes.go +++ b/http3/error_codes.go @@ -30,6 +30,14 @@ const ( ) func (e ErrCode) String() string { + s := e.string() + if s != "" { + return s + } + return fmt.Sprintf("unknown error code: %#x", uint16(e)) +} + +func (e ErrCode) string() string { switch e { case ErrCodeNoError: return "H3_NO_ERROR" @@ -68,6 +76,6 @@ func (e ErrCode) String() string { case ErrCodeDatagramError: return "H3_DATAGRAM_ERROR" default: - return fmt.Sprintf("unknown error code: %#x", uint16(e)) + return "" } } diff --git a/http3/error_test.go b/http3/error_test.go new file mode 100644 index 00000000..71567128 --- /dev/null +++ b/http3/error_test.go @@ -0,0 +1,41 @@ +package http3 + +import ( + "errors" + + "github.com/quic-go/quic-go" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("HTTP/3 errors", func() { + It("converts", func() { + Expect(maybeReplaceError(nil)).To(BeNil()) + Expect(maybeReplaceError(errors.New("foobar"))).To(MatchError("foobar")) + Expect(maybeReplaceError(&quic.StreamError{ + ErrorCode: 1337, + Remote: true, + })).To(Equal(&Error{ + Remote: true, + ErrorCode: 1337, + })) + Expect(maybeReplaceError(&quic.ApplicationError{ + ErrorCode: 42, + Remote: true, + ErrorMessage: "foobar", + })).To(Equal(&Error{ + Remote: true, + ErrorCode: 42, + ErrorMessage: "foobar", + })) + }) + + It("has a string representation", func() { + Expect((&Error{ErrorCode: 0x10c, Remote: true}).Error()).To(Equal("H3_REQUEST_CANCELLED")) + Expect((&Error{ErrorCode: 0x10c, Remote: true, ErrorMessage: "foobar"}).Error()).To(Equal("H3_REQUEST_CANCELLED: foobar")) + Expect((&Error{ErrorCode: 0x10c, Remote: false}).Error()).To(Equal("H3_REQUEST_CANCELLED (local)")) + Expect((&Error{ErrorCode: 0x10c, Remote: false, ErrorMessage: "foobar"}).Error()).To(Equal("H3_REQUEST_CANCELLED (local): foobar")) + Expect((&Error{ErrorCode: 0x1337, Remote: true}).Error()).To(Equal("H3 error (0x1337)")) + }) +}) diff --git a/http3/response_writer.go b/http3/response_writer.go index 90a30497..0d931478 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -166,9 +166,10 @@ func (w *responseWriter) Write(p []byte) (int, error) { w.buf = w.buf[:0] w.buf = df.Append(w.buf) if _, err := w.bufferedStr.Write(w.buf); err != nil { - return 0, err + return 0, maybeReplaceError(err) } - return w.bufferedStr.Write(p) + n, err := w.bufferedStr.Write(p) + return n, maybeReplaceError(err) } func (w *responseWriter) FlushError() error { @@ -177,7 +178,7 @@ func (w *responseWriter) FlushError() error { } if !w.written { if err := w.writeHeader(); err != nil { - return err + return maybeReplaceError(err) } w.written = true } diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index a1067070..ac16ffa5 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -307,9 +307,10 @@ var _ = Describe("HTTP tests", func() { for { if _, err := w.Write([]byte("foobar")); err != nil { Expect(r.Context().Done()).To(BeClosed()) - var strErr *quic.StreamError - Expect(errors.As(err, &strErr)).To(BeTrue()) - Expect(strErr.ErrorCode).To(Equal(quic.StreamErrorCode(0x10c))) + var http3Err *http3.Error + Expect(errors.As(err, &http3Err)).To(BeTrue()) + Expect(http3Err.ErrorCode).To(Equal(http3.ErrCode(0x10c))) + Expect(http3Err.Error()).To(Equal("H3_REQUEST_CANCELLED")) return } } @@ -325,7 +326,10 @@ var _ = Describe("HTTP tests", func() { cancel() Eventually(handlerCalled).Should(BeClosed()) _, err = resp.Body.Read([]byte{0}) - Expect(err).To(HaveOccurred()) + var http3Err *http3.Error + Expect(errors.As(err, &http3Err)).To(BeTrue()) + Expect(http3Err.ErrorCode).To(Equal(http3.ErrCode(0x10c))) + Expect(http3Err.Error()).To(Equal("H3_REQUEST_CANCELLED (local)")) }) It("allows streamed HTTP requests", func() { From 9b8219657899fd188678f0ca6c5751aa701faa43 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Sep 2023 18:58:51 +0700 Subject: [PATCH 34/43] make the logging.Tracer and logging.ConnectionTracer a struct (#4082) --- client.go | 4 +- client_test.go | 21 +- codecov.yml | 1 - config_test.go | 2 +- connection.go | 82 +-- connection_test.go | 10 +- example/client/main.go | 2 +- example/main.go | 2 +- integrationtests/self/key_update_test.go | 23 +- integrationtests/self/packetization_test.go | 8 +- integrationtests/self/self_suite_test.go | 67 +-- integrationtests/self/timeout_test.go | 6 +- integrationtests/self/tracer_test.go | 12 +- integrationtests/self/zero_rtt_oldgo_test.go | 56 +- integrationtests/self/zero_rtt_test.go | 68 ++- integrationtests/tools/qlog.go | 4 +- .../versionnegotiation/handshake_test.go | 86 ++-- .../versionnegotiation_suite_test.go | 2 +- interface.go | 2 +- internal/ackhandler/ackhandler.go | 2 +- internal/ackhandler/ecn.go | 24 +- internal/ackhandler/ecn_test.go | 5 +- internal/ackhandler/sent_packet_handler.go | 46 +- internal/congestion/cubic_sender.go | 10 +- internal/handshake/crypto_setup.go | 22 +- internal/handshake/updatable_aead.go | 12 +- internal/handshake/updatable_aead_test.go | 6 +- internal/mocks/logging/connection_tracer.go | 478 ++++-------------- .../logging/internal/connection_tracer.go | 388 ++++++++++++++ internal/mocks/logging/internal/tracer.go | 74 +++ internal/mocks/logging/mockgen.go | 51 ++ internal/mocks/logging/tracer.go | 85 +--- internal/mocks/mockgen.go | 28 +- interop/utils/logging.go | 2 +- logging/connection_tracer.go | 255 ++++++++++ logging/interface.go | 44 -- logging/multiplex.go | 232 --------- logging/multiplex_test.go | 33 +- logging/null_tracer.go | 63 --- logging/tracer.go | 43 ++ qlog/qlog.go | 85 +++- qlog/qlog_test.go | 2 +- server.go | 48 +- server_test.go | 34 +- transport.go | 6 +- transport_test.go | 10 +- 46 files changed, 1388 insertions(+), 1158 deletions(-) create mode 100644 internal/mocks/logging/internal/connection_tracer.go create mode 100644 internal/mocks/logging/internal/tracer.go create mode 100644 internal/mocks/logging/mockgen.go create mode 100644 logging/connection_tracer.go delete mode 100644 logging/multiplex.go delete mode 100644 logging/null_tracer.go create mode 100644 logging/tracer.go diff --git a/client.go b/client.go index 61cd7526..670e7a7c 100644 --- a/client.go +++ b/client.go @@ -34,7 +34,7 @@ type client struct { conn quicConn - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer tracingID uint64 logger utils.Logger } @@ -153,7 +153,7 @@ func dial( if c.config.Tracer != nil { c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) } - if c.tracer != nil { + if c.tracer != nil && c.tracer.StartedConnection != nil { c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) } if err := c.dial(ctx); err != nil { diff --git a/client_test.go b/client_test.go index e908b82f..3fa5fb5a 100644 --- a/client_test.go +++ b/client_test.go @@ -43,7 +43,7 @@ var _ = Describe("Client", func() { initialPacketNumber protocol.PacketNumber, enable0RTT bool, hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, v protocol.VersionNumber, @@ -54,10 +54,11 @@ var _ = Describe("Client", func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37}) originalClientConnConstructor = newClientConnection - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) config = &Config{ - Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) logging.ConnectionTracer { - return tracer + Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) *logging.ConnectionTracer { + return tr }, Versions: []protocol.VersionNumber{protocol.Version1}, } @@ -70,7 +71,7 @@ var _ = Describe("Client", func() { destConnID: connID, version: protocol.Version1, sendConn: packetConn, - tracer: tracer, + tracer: tr, logger: utils.DefaultLogger, } getMultiplexer() // make the sync.Once execute @@ -121,7 +122,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, enable0RTT bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -158,7 +159,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, enable0RTT bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -195,7 +196,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -280,7 +281,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, versionP protocol.VersionNumber, @@ -323,7 +324,7 @@ var _ = Describe("Client", func() { pn protocol.PacketNumber, _ bool, hasNegotiatedVersion bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, versionP protocol.VersionNumber, diff --git a/codecov.yml b/codecov.yml index 0b95a103..a24c7a15 100644 --- a/codecov.yml +++ b/codecov.yml @@ -5,7 +5,6 @@ coverage: - interop/ - internal/handshake/cipher_suite.go - internal/utils/linkedlist/linkedlist.go - - logging/null_tracer.go - fuzzing/ - metrics/ status: diff --git a/config_test.go b/config_test.go index 9cb665a3..c5701608 100644 --- a/config_test.go +++ b/config_test.go @@ -125,7 +125,7 @@ var _ = Describe("Config", func() { GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, - Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer { + Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer { calledTracer = true return nil }, diff --git a/connection.go b/connection.go index d6cc886d..09b522a9 100644 --- a/connection.go +++ b/connection.go @@ -208,7 +208,7 @@ type connection struct { connState ConnectionState logID string - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger } @@ -232,7 +232,7 @@ var newConnection = func( tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, clientAddressValidated bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, v protocol.VersionNumber, @@ -311,7 +311,7 @@ var newConnection = func( } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentTransportParameters != nil { s.tracer.SentTransportParameters(params) } cs := handshake.NewCryptoSetupServer( @@ -345,7 +345,7 @@ var newClientConnection = func( initialPacketNumber protocol.PacketNumber, enable0RTT bool, hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, v protocol.VersionNumber, @@ -418,7 +418,7 @@ var newClientConnection = func( } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentTransportParameters != nil { s.tracer.SentTransportParameters(params) } cs := handshake.NewCryptoSetupClient( @@ -642,8 +642,10 @@ runLoop: s.cryptoStreamHandler.Close() s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE s.handleCloseError(&closeErr) - if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { - s.tracer.Close() + if s.tracer != nil && s.tracer.Close != nil { + if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) { + s.tracer.Close() + } } s.logger.Infof("Connection %s closed.", s.logID) s.timer.Stop() @@ -801,14 +803,14 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { var err error destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) } s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err) break } if destConnID != lastConnID { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID) @@ -819,7 +821,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { if wire.IsLongHeaderPacket(p.data[0]) { hdr, packetData, rest, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { dropReason := logging.PacketDropHeaderParseError if err == wire.ErrUnsupportedVersion { dropReason = logging.PacketDropUnsupportedVersion @@ -832,7 +834,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { lastConnID = hdr.DestConnectionID if hdr.Version != s.version { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) } s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) @@ -891,14 +893,14 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) { s.logger.Debugf("Dropping (potentially) duplicate packet.") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) } return false } var log func([]logging.Frame) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedShortHeaderPacket != nil { log = func(frames []logging.Frame) { s.tracer.ReceivedShortHeaderPacket( &logging.ShortHeader{ @@ -937,7 +939,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) @@ -945,7 +947,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) } // drop 0-RTT packets, if we are a client if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable) } return false @@ -964,7 +966,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.hdr.PacketNumber, packet.encryptionLevel) { s.logger.Debugf("Dropping (potentially) duplicate packet.") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) } return false @@ -980,7 +982,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) { switch err { case handshake.ErrKeysDropped: - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable) } s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size()) @@ -996,7 +998,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P }) case handshake.ErrDecryptionFailed: // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err) @@ -1004,7 +1006,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P var headerErr *headerParseError if errors.As(err, &headerErr) { // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err) @@ -1019,14 +1021,14 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry.") return false } if s.receivedFirstPacket { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since we already received a packet.") @@ -1034,7 +1036,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti } destConnID := s.connIDManager.Get() if hdr.SrcConnectionID == destConnID { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") @@ -1049,7 +1051,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version) if !bytes.Equal(data[len(data)-16:], tag[:]) { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.") @@ -1061,7 +1063,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) } - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedRetry != nil { s.tracer.ReceivedRetry(hdr) } newDestConnID := hdr.SrcConnectionID @@ -1082,7 +1084,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) } return @@ -1090,7 +1092,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) @@ -1099,7 +1101,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { for _, v := range supportedVersions { if v == s.version { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) } // The Version Negotiation packet contains the version that we offered. @@ -1109,7 +1111,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { } s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedVersionNegotiationPacket != nil { s.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions) } newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) @@ -1121,7 +1123,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { s.logger.Infof("No compatible QUIC version found.") return } - if s.tracer != nil { + if s.tracer != nil && s.tracer.NegotiatedVersion != nil { s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions) } @@ -1141,7 +1143,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( ) error { if !s.receivedFirstPacket { s.receivedFirstPacket = true - if !s.versionNegotiated && s.tracer != nil { + if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil { var clientVersions, serverVersions []protocol.VersionNumber switch s.perspective { case protocol.PerspectiveClient: @@ -1168,7 +1170,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( s.handshakeDestConnID = packet.hdr.SrcConnectionID s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) } - if s.tracer != nil { + if s.tracer != nil && s.tracer.StartedConnection != nil { s.tracer.StartedConnection( s.conn.LocalAddr(), s.conn.RemoteAddr(), @@ -1192,7 +1194,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( s.keepAlivePingSent = false var log func([]logging.Frame) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedLongHeaderPacket != nil { log = func(frames []logging.Frame) { s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames) } @@ -1340,7 +1342,7 @@ func (s *connection) handlePacket(p receivedPacket) { select { case s.receivedPackets <- p: default: - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } @@ -1620,7 +1622,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { s.datagramQueue.CloseWithError(e) } - if s.tracer != nil && !errors.As(e, &recreateErr) { + if s.tracer != nil && s.tracer.ClosedConnection != nil && !errors.As(e, &recreateErr) { s.tracer.ClosedConnection(e) } @@ -1647,7 +1649,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { } func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedEncryptionLevel != nil { s.tracer.DroppedEncryptionLevel(encLevel) } s.sentPacketHandler.DropPackets(encLevel) @@ -1682,7 +1684,7 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters } func (s *connection) handleTransportParameters(params *wire.TransportParameters) error { - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedTransportParameters != nil { s.tracer.ReceivedTransportParameters(params) } if err := s.checkTransportParameters(params); err != nil { @@ -2134,7 +2136,7 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) } // tracing - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentLongHeaderPacket != nil { frames := make([]logging.Frame, 0, len(p.frames)) for _, f := range p.frames { frames = append(frames, logutils.ConvertFrame(f.Frame)) @@ -2180,7 +2182,7 @@ func (s *connection) logShortHeaderPacket( } // tracing - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentShortHeaderPacket != nil { fs := make([]logging.Frame, 0, len(frames)+len(streamFrames)) for _, f := range frames { fs = append(fs, logutils.ConvertFrame(f.Frame)) @@ -2302,14 +2304,14 @@ func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging panic("shouldn't queue undecryptable packets after handshake completion") } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropDOSPrevention) } s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) return } s.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.BufferedPacket != nil { s.tracer.BufferedPacket(pt, p.Size()) } s.undecryptablePackets = append(s.undecryptablePackets, p) diff --git a/connection_test.go b/connection_test.go index 44564f84..e2cca32e 100644 --- a/connection_test.go +++ b/connection_test.go @@ -101,7 +101,8 @@ var _ = Describe("Connection", func() { mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() tokenGenerator := handshake.NewTokenGenerator([32]byte{0xa, 0xb, 0xc}) - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() @@ -120,7 +121,7 @@ var _ = Describe("Connection", func() { &tls.Config{}, tokenGenerator, false, - tracer, + tr, 1234, utils.DefaultLogger, protocol.Version1, @@ -2540,7 +2541,8 @@ var _ = Describe("Client Connection", func() { tlsConf = &tls.Config{} } connRunner = NewMockConnRunner(mockCtrl) - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() @@ -2556,7 +2558,7 @@ var _ = Describe("Client Connection", func() { 42, // initial packet number false, false, - tracer, + tr, 1234, utils.DefaultLogger, protocol.Version1, diff --git a/example/client/main.go b/example/client/main.go index 83f810fd..5b86faa3 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -58,7 +58,7 @@ func main() { var qconf quic.Config if *enableQlog { - qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { filename := fmt.Sprintf("client_%x.qlog", connID) f, err := os.Create(filename) if err != nil { diff --git a/example/main.go b/example/main.go index 05814405..c2fb574d 100644 --- a/example/main.go +++ b/example/main.go @@ -163,7 +163,7 @@ func main() { handler := setupHandler(*www) quicConf := &quic.Config{} if *enableQlog { - quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { filename := fmt.Sprintf("server_%x.qlog", connID) f, err := os.Create(filename) if err != nil { diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index 7975a77b..3704f96e 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -38,18 +38,13 @@ func countKeyPhases() (sent, received int) { return } -type keyUpdateConnTracer struct { - logging.NullConnectionTracer -} - -var _ logging.ConnectionTracer = &keyUpdateConnTracer{} - -func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) { - sentHeaders = append(sentHeaders, hdr) -} - -func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) { - receivedHeaders = append(receivedHeaders, hdr) +var keyUpdateConnTracer = &logging.ConnectionTracer{ + SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) { + sentHeaders = append(sentHeaders, hdr) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) { + receivedHeaders = append(receivedHeaders, hdr) + }, } var _ = Describe("Key Update tests", func() { @@ -77,8 +72,8 @@ var _ = Describe("Key Update tests", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return &keyUpdateConnTracer{} + getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return keyUpdateConnTracer }}), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 740956c5..8c59bc0b 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -21,7 +21,7 @@ var _ = Describe("Packetization", func() { It("bundles ACKs", func() { const numMsg = 100 - serverTracer := newPacketTracer() + serverCounter, serverTracer := newPacketTracer() server, err := quic.ListenAddr( "localhost:0", getTLSConfig(), @@ -43,7 +43,7 @@ var _ = Describe("Packetization", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - clientTracer := newPacketTracer() + clientCounter, clientTracer := newPacketTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), @@ -104,8 +104,8 @@ var _ = Describe("Packetization", func() { return } - numBundledIncoming := countBundledPackets(clientTracer.getRcvdShortHeaderPackets()) - numBundledOutgoing := countBundledPackets(serverTracer.getRcvdShortHeaderPackets()) + numBundledIncoming := countBundledPackets(clientCounter.getRcvdShortHeaderPackets()) + numBundledOutgoing := countBundledPackets(serverCounter.getRcvdShortHeaderPackets()) fmt.Fprintf(GinkgoWriter, "bundled incoming packets: %d / %d\n", numBundledIncoming, numMsg) fmt.Fprintf(GinkgoWriter, "bundled outgoing packets: %d / %d\n", numBundledOutgoing, numMsg) Expect(numBundledIncoming).To(And( diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 6396d3a7..582575f5 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -86,7 +86,7 @@ var ( logBuf *syncedBuffer versionParam string - qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer + qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer enableQlog bool version quic.VersionNumber @@ -177,10 +177,16 @@ func getQuicConfig(conf *quic.Config) *quic.Config { } if enableQlog { if conf.Tracer == nil { - conf.Tracer = qlogTracer + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + qlogTracer(ctx, p, connID), + // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere + &logging.ConnectionTracer{}, + ) + } } else if qlogTracer != nil { origTracer := conf.Tracer - conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { return logging.NewMultiplexedConnectionTracer( qlogTracer(ctx, p, connID), origTracer(ctx, p, connID), @@ -242,8 +248,8 @@ func scaleDuration(d time.Duration) time.Duration { return time.Duration(scaleFactor) * d } -func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer } +func newTracer(tracer *logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return tracer } } type packet struct { @@ -258,51 +264,46 @@ type shortHeaderPacket struct { frames []logging.Frame } -type packetTracer struct { - logging.NullConnectionTracer +type packetCounter struct { closed chan struct{} sentShortHdr, rcvdShortHdr []shortHeaderPacket rcvdLongHdr []packet } -var _ logging.ConnectionTracer = &packetTracer{} - -func newPacketTracer() *packetTracer { - return &packetTracer{closed: make(chan struct{})} -} - -func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { - t.rcvdLongHdr = append(t.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames}) -} - -func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { - t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) -} - -func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) { - if ack != nil { - frames = append(frames, ack) - } - t.sentShortHdr = append(t.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) -} - -func (t *packetTracer) Close() { close(t.closed) } - -func (t *packetTracer) getSentShortHeaderPackets() []shortHeaderPacket { +func (t *packetCounter) getSentShortHeaderPackets() []shortHeaderPacket { <-t.closed return t.sentShortHdr } -func (t *packetTracer) getRcvdLongHeaderPackets() []packet { +func (t *packetCounter) getRcvdLongHeaderPackets() []packet { <-t.closed return t.rcvdLongHdr } -func (t *packetTracer) getRcvdShortHeaderPackets() []shortHeaderPacket { +func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket { <-t.closed return t.rcvdShortHdr } +func newPacketTracer() (*packetCounter, *logging.ConnectionTracer) { + c := &packetCounter{closed: make(chan struct{})} + return c, &logging.ConnectionTracer{ + ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { + c.rcvdLongHdr = append(c.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames}) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { + c.rcvdShortHdr = append(c.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) + }, + SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) { + if ack != nil { + frames = append(frames, ack) + } + c.sentShortHdr = append(c.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) + }, + Close: func() { close(c.closed) }, + } +} + func TestSelf(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Self integration tests") diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 47569b77..ddd0b973 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -194,7 +194,7 @@ var _ = Describe("Timeout tests", func() { close(serverConnClosed) }() - tr := newPacketTracer() + counter, tr := newPacketTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -215,7 +215,7 @@ var _ = Describe("Timeout tests", func() { }() Eventually(done, 2*idleTimeout).Should(BeClosed()) var lastAckElicitingPacketSentAt time.Time - for _, p := range tr.getSentShortHeaderPackets() { + for _, p := range counter.getSentShortHeaderPackets() { var hasAckElicitingFrame bool for _, f := range p.frames { if _, ok := f.(*logging.AckFrame); ok { @@ -228,7 +228,7 @@ var _ = Describe("Timeout tests", func() { lastAckElicitingPacketSentAt = p.time } } - rcvdPackets := tr.getRcvdShortHeaderPackets() + rcvdPackets := counter.getRcvdShortHeaderPackets() lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time // We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout. // This is ok since we're dealing with a lossless connection here, diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index 3bfae3c6..5179646e 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -26,9 +26,9 @@ var _ = Describe("Handshake tests", func() { fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer) - var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer + var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer if enableQlog { - tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID) return nil @@ -38,13 +38,13 @@ var _ = Describe("Handshake tests", func() { }) } if enableCustomTracer { - tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return logging.NullConnectionTracer{} + tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return &logging.ConnectionTracer{} }) } c := conf.Clone() - c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { - tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors)) + c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { + tracers := make([]*logging.ConnectionTracer, 0, len(tracerConstructors)) for _, c := range tracerConstructors { if tr := c(ctx, p, connID); tr != nil { tracers = append(tracers, tr) diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go index 3e9277b9..eb2302d3 100644 --- a/integrationtests/self/zero_rtt_oldgo_test.go +++ b/integrationtests/self/zero_rtt_oldgo_test.go @@ -202,7 +202,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) } - // can be used to extract 0-RTT from a packetTracer + // can be used to extract 0-RTT from a packetCounter get0RTTPackets := func(packets []packet) []protocol.PacketNumber { var zeroRTTPackets []protocol.PacketNumber for _, p := range packets { @@ -219,7 +219,7 @@ var _ = Describe("0-RTT", func() { It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() { tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -244,7 +244,7 @@ var _ = Describe("0-RTT", func() { ) var numNewConnIDs int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if _, ok := f.(*logging.NewConnectionIDFrame); ok { numNewConnIDs++ @@ -260,7 +260,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0))) }) @@ -273,7 +273,7 @@ var _ = Describe("0-RTT", func() { zeroRTTData := GeneratePRData(5 << 10) oneRTTData := PRData - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -330,7 +330,7 @@ var _ = Describe("0-RTT", func() { // check that 0-RTT packets only contain STREAM frames for the first stream var num0RTT int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { if p.hdr.Header.Type != protocol.PacketType0RTT { continue } @@ -355,7 +355,7 @@ var _ = Describe("0-RTT", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -406,7 +406,7 @@ var _ = Describe("0-RTT", func() { fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) Expect(numDropped).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) }) It("retransmits all 0-RTT data when the server performs a Retry", func() { @@ -430,7 +430,7 @@ var _ = Describe("0-RTT", func() { return } - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -480,7 +480,7 @@ var _ = Describe("0-RTT", func() { defer mutex.Unlock() Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) }) @@ -491,14 +491,12 @@ var _ = Describe("0-RTT", func() { MaxIncomingUniStreams: maxStreams, })) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams + 1, Allow0RTT: true, - Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -536,7 +534,7 @@ var _ = Describe("0-RTT", func() { MaxIncomingStreams: maxStreams, })) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -556,7 +554,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the ALPN changed", func() { @@ -565,7 +563,7 @@ var _ = Describe("0-RTT", func() { // now close the listener and dial new connection with a different ALPN clientConf.NextProtos = []string{"new-alpn"} tlsConf.NextProtos = []string{"new-alpn"} - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -585,14 +583,14 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the application doesn't allow it", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) // now close the listener and dial new connection with a different ALPN - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -612,12 +610,12 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) DescribeTable("flow control limits", func(addFlowControlLimit func(*quic.Config, uint64)) { - tracer := newPacketTracer() + counter, tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) addFlowControlLimit(firstConf, 3) tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) @@ -669,7 +667,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) var processedFirst bool - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if sf, ok := f.(*logging.StreamFrame); ok { if !processedFirst { @@ -695,7 +693,7 @@ var _ = Describe("0-RTT", func() { It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) // now dial new connection with different transport parameters - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -764,14 +762,14 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) } It("queues 0-RTT packets, if the Initial is delayed", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -796,8 +794,8 @@ var _ = Describe("0-RTT", func() { transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) - Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) @@ -807,7 +805,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, })) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -856,7 +854,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(zeroRTTPackets).To(HaveLen(1)) }) @@ -865,7 +863,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, })) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -911,6 +909,6 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index c9bdeff6..1f750d3a 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -232,7 +232,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) } - // can be used to extract 0-RTT from a packetTracer + // can be used to extract 0-RTT from a packetCounter get0RTTPackets := func(packets []packet) []protocol.PacketNumber { var zeroRTTPackets []protocol.PacketNumber for _, p := range packets { @@ -251,7 +251,7 @@ var _ = Describe("0-RTT", func() { clientTLSConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -276,7 +276,7 @@ var _ = Describe("0-RTT", func() { ) var numNewConnIDs int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if _, ok := f.(*logging.NewConnectionIDFrame); ok { numNewConnIDs++ @@ -292,7 +292,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0))) }) @@ -307,7 +307,7 @@ var _ = Describe("0-RTT", func() { zeroRTTData := GeneratePRData(5 << 10) oneRTTData := PRData - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -364,7 +364,7 @@ var _ = Describe("0-RTT", func() { // check that 0-RTT packets only contain STREAM frames for the first stream var num0RTT int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { if p.hdr.Header.Type != protocol.PacketType0RTT { continue } @@ -391,7 +391,7 @@ var _ = Describe("0-RTT", func() { clientConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -442,7 +442,7 @@ var _ = Describe("0-RTT", func() { fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) Expect(numDropped).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) }) It("retransmits all 0-RTT data when the server performs a Retry", func() { @@ -468,7 +468,7 @@ var _ = Describe("0-RTT", func() { return } - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -518,7 +518,7 @@ var _ = Describe("0-RTT", func() { defer mutex.Unlock() Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) }) @@ -531,14 +531,12 @@ var _ = Describe("0-RTT", func() { MaxIncomingUniStreams: maxStreams, }), clientConf) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams + 1, Allow0RTT: true, - Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -578,7 +576,7 @@ var _ = Describe("0-RTT", func() { MaxIncomingStreams: maxStreams, }), clientConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -599,7 +597,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the ALPN changed", func() { @@ -612,7 +610,7 @@ var _ = Describe("0-RTT", func() { // Append to the client's ALPN. // crypto/tls will attempt to resume with the ALPN from the original connection clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn") - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -632,7 +630,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the application doesn't allow it", func() { @@ -641,7 +639,7 @@ var _ = Describe("0-RTT", func() { dialAndReceiveSessionTicket(tlsConf, nil, clientConf) // now close the listener and dial new connection with a different ALPN - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -661,12 +659,12 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) DescribeTable("flow control limits", func(addFlowControlLimit func(*quic.Config, uint64)) { - tracer := newPacketTracer() + counter, tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) addFlowControlLimit(firstConf, 3) tlsConf := getTLSConfig() @@ -720,7 +718,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) var processedFirst bool - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if sf, ok := f.(*logging.StreamFrame); ok { if !processedFirst { @@ -748,7 +746,7 @@ var _ = Describe("0-RTT", func() { clientConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientConf) // now dial new connection with different transport parameters - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -817,7 +815,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) } @@ -826,7 +824,7 @@ var _ = Describe("0-RTT", func() { clientConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -851,8 +849,8 @@ var _ = Describe("0-RTT", func() { transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) - Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) @@ -878,14 +876,10 @@ var _ = Describe("0-RTT", func() { clientTLSConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), + getQuicConfig(&quic.Config{Allow0RTT: true}), ) Expect(err).ToNot(HaveOccurred()) defer ln.Close() @@ -916,14 +910,10 @@ var _ = Describe("0-RTT", func() { } dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), + getQuicConfig(&quic.Config{Allow0RTT: true}), ) Expect(err).ToNot(HaveOccurred()) defer ln.Close() @@ -946,7 +936,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, }), clientTLSConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -994,7 +984,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(zeroRTTPackets).To(HaveLen(1)) }) @@ -1005,7 +995,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, }), clientTLSConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -1051,6 +1041,6 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) }) diff --git a/integrationtests/tools/qlog.go b/integrationtests/tools/qlog.go index 352e0a61..ea37456e 100644 --- a/integrationtests/tools/qlog.go +++ b/integrationtests/tools/qlog.go @@ -14,8 +14,8 @@ import ( "github.com/quic-go/quic-go/qlog" ) -func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { +func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { role := "server" if p == logging.PerspectiveClient { role = "client" diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go index a079f6e1..eefcd8e7 100644 --- a/integrationtests/versionnegotiation/handshake_test.go +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -21,29 +21,29 @@ type versioner interface { GetVersion() protocol.VersionNumber } -type versionNegotiationTracer struct { - logging.NullConnectionTracer - +type result struct { loggedVersions bool receivedVersionNegotiation bool chosen logging.VersionNumber clientVersions, serverVersions []logging.VersionNumber } -var _ logging.ConnectionTracer = &versionNegotiationTracer{} - -func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { - if t.loggedVersions { - Fail("only expected one call to NegotiatedVersions") +func newVersionNegotiationTracer() (*result, *logging.ConnectionTracer) { + r := &result{} + return r, &logging.ConnectionTracer{ + NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + if r.loggedVersions { + Fail("only expected one call to NegotiatedVersions") + } + r.loggedVersions = true + r.chosen = chosen + r.clientVersions = clientVersions + r.serverVersions = serverVersions + }, + ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { + r.receivedVersionNegotiation = true + }, } - t.loggedVersions = true - t.chosen = chosen - t.clientVersions = clientVersions - t.serverVersions = serverVersions -} - -func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { - t.receivedVersionNegotiation = true } var _ = Describe("Handshake tests", func() { @@ -86,54 +86,54 @@ var _ = Describe("Handshake tests", func() { // but it supports a bunch of versions that the client doesn't speak serverConfig := &quic.Config{} serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} - serverTracer := &versionNegotiationTracer{} - serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + serverResult, serverTracer := newVersionNegotiationTracer() + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return serverTracer } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() - clientTracer := &versionNegotiationTracer{} + clientResult, clientTracer := newVersionNegotiationTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer { + maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) *logging.ConnectionTracer { return clientTracer }}), ) Expect(err).ToNot(HaveOccurred()) Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion)) Expect(conn.CloseWithError(0, "")).To(Succeed()) - Expect(clientTracer.chosen).To(Equal(expectedVersion)) - Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) - Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions)) - Expect(clientTracer.serverVersions).To(BeEmpty()) - Expect(serverTracer.chosen).To(Equal(expectedVersion)) - Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) - Expect(serverTracer.clientVersions).To(BeEmpty()) + Expect(clientResult.chosen).To(Equal(expectedVersion)) + Expect(clientResult.receivedVersionNegotiation).To(BeFalse()) + Expect(clientResult.clientVersions).To(Equal(protocol.SupportedVersions)) + Expect(clientResult.serverVersions).To(BeEmpty()) + Expect(serverResult.chosen).To(Equal(expectedVersion)) + Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions)) + Expect(serverResult.clientVersions).To(BeEmpty()) }) It("when the client supports more versions than the server supports", func() { expectedVersion := protocol.SupportedVersions[0] // The server doesn't support the highest supported version, which is the first one the client will try, // but it supports a bunch of versions that the client doesn't speak - serverTracer := &versionNegotiationTracer{} + serverResult, serverTracer := newVersionNegotiationTracer() serverConfig := &quic.Config{} serverConfig.Versions = supportedVersions - serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return serverTracer } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} - clientTracer := &versionNegotiationTracer{} + clientResult, clientTracer := newVersionNegotiationTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), maybeAddQLOGTracer(&quic.Config{ Versions: clientVersions, - Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return clientTracer }, }), @@ -141,22 +141,22 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) Expect(conn.CloseWithError(0, "")).To(Succeed()) - Expect(clientTracer.chosen).To(Equal(expectedVersion)) - Expect(clientTracer.receivedVersionNegotiation).To(BeTrue()) - Expect(clientTracer.clientVersions).To(Equal(clientVersions)) - Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions - Expect(serverTracer.chosen).To(Equal(expectedVersion)) - Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) - Expect(serverTracer.clientVersions).To(BeEmpty()) + Expect(clientResult.chosen).To(Equal(expectedVersion)) + Expect(clientResult.receivedVersionNegotiation).To(BeTrue()) + Expect(clientResult.clientVersions).To(Equal(clientVersions)) + Expect(clientResult.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions + Expect(serverResult.chosen).To(Equal(expectedVersion)) + Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions)) + Expect(serverResult.clientVersions).To(BeEmpty()) }) It("fails if the server disables version negotiation", func() { // The server doesn't support the highest supported version, which is the first one the client will try, // but it supports a bunch of versions that the client doesn't speak - serverTracer := &versionNegotiationTracer{} + _, serverTracer := newVersionNegotiationTracer() serverConfig := &quic.Config{} serverConfig.Versions = supportedVersions - serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return serverTracer } conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) @@ -170,14 +170,14 @@ var _ = Describe("Handshake tests", func() { defer ln.Close() clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} - clientTracer := &versionNegotiationTracer{} + clientResult, clientTracer := newVersionNegotiationTracer() _, err = quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), maybeAddQLOGTracer(&quic.Config{ Versions: clientVersions, - Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return clientTracer }, HandshakeIdleTimeout: 100 * time.Millisecond, @@ -187,7 +187,7 @@ var _ = Describe("Handshake tests", func() { var nerr net.Error Expect(errors.As(err, &nerr)).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) - Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) + Expect(clientResult.receivedVersionNegotiation).To(BeFalse()) }) } }) diff --git a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go index a01ac1f8..150181f2 100644 --- a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go +++ b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go @@ -70,7 +70,7 @@ func maybeAddQLOGTracer(c *quic.Config) *quic.Config { c.Tracer = qlogger } else if qlogger != nil { origTracer := c.Tracer - c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { return logging.NewMultiplexedConnectionTracer( qlogger(ctx, p, connID), origTracer(ctx, p, connID), diff --git a/interface.go b/interface.go index f5ee28d8..c7e53bdc 100644 --- a/interface.go +++ b/interface.go @@ -328,7 +328,7 @@ type Config struct { Allow0RTT bool // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool - Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer + Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer } type ClientHelloInfo struct { diff --git a/internal/ackhandler/ackhandler.go b/internal/ackhandler/ackhandler.go index 036ed5ea..cb28582a 100644 --- a/internal/ackhandler/ackhandler.go +++ b/internal/ackhandler/ackhandler.go @@ -16,7 +16,7 @@ func NewAckHandler( clientAddressValidated bool, enableECN bool, pers protocol.Perspective, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, ) (SentPacketHandler, ReceivedPacketHandler) { sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) diff --git a/internal/ackhandler/ecn.go b/internal/ackhandler/ecn.go index bb8b37f8..43cc3dd6 100644 --- a/internal/ackhandler/ecn.go +++ b/internal/ackhandler/ecn.go @@ -45,13 +45,13 @@ type ecnTracker struct { numSentECT0, numSentECT1 int64 numAckedECT0, numAckedECT1, numAckedECNCE int64 - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger } var _ ecnHandler = &ecnTracker{} -func newECNTracker(logger utils.Logger, tracer logging.ConnectionTracer) *ecnTracker { +func newECNTracker(logger utils.Logger, tracer *logging.ConnectionTracer) *ecnTracker { return &ecnTracker{ firstTestingPacket: protocol.InvalidPacketNumber, lastTestingPacket: protocol.InvalidPacketNumber, @@ -92,7 +92,7 @@ func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) { e.firstTestingPacket = pn } if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets { - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) } e.state = ecnStateUnknown @@ -103,7 +103,7 @@ func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) { func (e *ecnTracker) Mode() protocol.ECN { switch e.state { case ecnStateInitial: - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) } e.state = ecnStateTesting @@ -127,7 +127,7 @@ func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) { e.numLostTesting++ if e.numLostTesting >= e.numSentTesting { e.logger.Debugf("Disabling ECN. All testing packets were lost.") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) } e.state = ecnStateFailed @@ -146,7 +146,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // the total number of packets sent with each corresponding ECT codepoint. if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 { e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent) } e.state = ecnStateFailed @@ -172,7 +172,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // * peers that don't report any ECN counts if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 { e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts) } e.state = ecnStateFailed @@ -189,7 +189,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // Any decrease means that the peer's counting logic is broken. if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 { e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts) } e.state = ecnStateFailed @@ -201,7 +201,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // This could be the result of (partial) bleaching. if newECT0+newECNCE < ackedECT0 { e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) } e.state = ecnStateFailed @@ -211,7 +211,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // the number of newly acknowledged packets sent with an ECT(1) marking. if newECT1+newECNCE < ackedECT1 { e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) } e.state = ecnStateFailed @@ -226,7 +226,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 if e.state == ecnStateTesting || e.state == ecnStateUnknown { // Detect mangling (a path remarking all ECN-marked testing packets as CE). if e.numSentECT0+e.numSentECT1 == e.numAckedECNCE && e.numAckedECNCE >= numECNTestingPackets { - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) } e.state = ecnStateFailed @@ -243,7 +243,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE). if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) { e.logger.Debugf("ECN capability confirmed.") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) } e.state = ecnStateCapable diff --git a/internal/ackhandler/ecn_test.go b/internal/ackhandler/ecn_test.go index 984a97a9..26d54eed 100644 --- a/internal/ackhandler/ecn_test.go +++ b/internal/ackhandler/ecn_test.go @@ -23,8 +23,9 @@ var _ = Describe("ECN tracker", func() { } BeforeEach(func() { - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) - ecnTracker = newECNTracker(utils.DefaultLogger, tracer) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + ecnTracker = newECNTracker(utils.DefaultLogger, tr) }) It("sends exactly 10 testing packets", func() { diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index b498ce4b..c8265a78 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -97,7 +97,7 @@ type sentPacketHandler struct { perspective protocol.Perspective - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger } @@ -115,7 +115,7 @@ func newSentPacketHandler( clientAddressValidated bool, enableECN bool, pers protocol.Perspective, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, ) *sentPacketHandler { congestion := congestion.NewCubicSender( @@ -196,7 +196,7 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { default: panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) } - if h.tracer != nil && h.ptoCount != 0 { + if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 { h.tracer.UpdatedPTOCount(0) } h.ptoCount = 0 @@ -286,7 +286,7 @@ func (h *sentPacketHandler) SentPacket( p.includedInBytesInFlight = true pnSpace.history.SentAckElicitingPacket(p) - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } h.setLossDetectionTimer() @@ -376,14 +376,14 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En // Reset the pto_count unless the client is unsure if the server has validated the client's address. if h.peerCompletedAddressValidation { - if h.tracer != nil && h.ptoCount != 0 { + if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 { h.tracer.UpdatedPTOCount(0) } h.ptoCount = 0 } h.numProbesToSend = 0 - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } @@ -462,7 +462,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL if err := pnSpace.history.Remove(p.PacketNumber); err != nil { return nil, err } - if h.tracer != nil { + if h.tracer != nil && h.tracer.AcknowledgedPacket != nil { h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber) } } @@ -555,7 +555,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { if !lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = lossTime - if h.tracer != nil && h.alarm != oldAlarm { + if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm { h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm) } return @@ -566,7 +566,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { h.alarm = time.Time{} if !oldAlarm.IsZero() { h.logger.Debugf("Canceling loss detection timer. Amplification limited.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } @@ -578,7 +578,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { h.alarm = time.Time{} if !oldAlarm.IsZero() { h.logger.Debugf("Canceling loss detection timer. No packets in flight.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } @@ -591,14 +591,14 @@ func (h *sentPacketHandler) setLossDetectionTimer() { if !oldAlarm.IsZero() { h.alarm = time.Time{} h.logger.Debugf("Canceling loss detection timer. No PTO needed..") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } return } h.alarm = ptoTime - if h.tracer != nil && h.alarm != oldAlarm { + if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm { h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm) } } @@ -629,7 +629,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LostPacket != nil { h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) } } @@ -639,7 +639,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LostPacket != nil { h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) } } @@ -676,7 +676,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { if h.logger.Debug() { h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerExpired != nil { h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) } // Early retransmit or time loss detection @@ -713,8 +713,12 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) } if h.tracer != nil { - h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) - h.tracer.UpdatedPTOCount(h.ptoCount) + if h.tracer.LossTimerExpired != nil { + h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) + } + if h.tracer.UpdatedPTOCount != nil { + h.tracer.UpdatedPTOCount(h.ptoCount) + } } h.numProbesToSend += 2 //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. @@ -890,7 +894,7 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error { if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } } @@ -899,8 +903,10 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error { oldAlarm := h.alarm h.alarm = time.Time{} if h.tracer != nil { - h.tracer.UpdatedPTOCount(0) - if !oldAlarm.IsZero() { + if h.tracer.UpdatedPTOCount != nil { + h.tracer.UpdatedPTOCount(0) + } + if !oldAlarm.IsZero() && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } diff --git a/internal/congestion/cubic_sender.go b/internal/congestion/cubic_sender.go index 10eb4667..ee558f2d 100644 --- a/internal/congestion/cubic_sender.go +++ b/internal/congestion/cubic_sender.go @@ -56,7 +56,7 @@ type cubicSender struct { maxDatagramSize protocol.ByteCount lastState logging.CongestionState - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer } var ( @@ -70,7 +70,7 @@ func NewCubicSender( rttStats *utils.RTTStats, initialMaxDatagramSize protocol.ByteCount, reno bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, ) *cubicSender { return newCubicSender( clock, @@ -90,7 +90,7 @@ func newCubicSender( initialMaxDatagramSize, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, ) *cubicSender { c := &cubicSender{ rttStats: rttStats, @@ -108,7 +108,7 @@ func newCubicSender( maxDatagramSize: initialMaxDatagramSize, } c.pacer = newPacer(c.BandwidthEstimate) - if c.tracer != nil { + if c.tracer != nil && c.tracer.UpdatedCongestionState != nil { c.lastState = logging.CongestionStateSlowStart c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) } @@ -296,7 +296,7 @@ func (c *cubicSender) OnConnectionMigration() { } func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { - if c.tracer == nil || new == c.lastState { + if c.tracer == nil || c.tracer.UpdatedCongestionState == nil || new == c.lastState { return } c.tracer.UpdatedCongestionState(new) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index b5fb404b..cae02873 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -43,7 +43,7 @@ type cryptoSetup struct { rttStats *utils.RTTStats - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger perspective protocol.Perspective @@ -77,7 +77,7 @@ func NewCryptoSetupClient( tlsConf *tls.Config, enable0RTT bool, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber, ) CryptoSetup { @@ -111,7 +111,7 @@ func NewCryptoSetupServer( tlsConf *tls.Config, allow0RTT bool, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber, ) CryptoSetup { @@ -165,13 +165,13 @@ func newCryptoSetup( connID protocol.ConnectionID, tp *wire.TransportParameters, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, perspective protocol.Perspective, version protocol.VersionNumber, ) *cryptoSetup { initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) - if tracer != nil { + if tracer != nil && tracer.UpdatedKeyFromTLS != nil { tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } @@ -193,7 +193,7 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version) h.initialSealer = initialSealer h.initialOpener = initialOpener - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } @@ -457,7 +457,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr } h.mutex.Unlock() h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } } @@ -479,7 +479,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) } // don't set used0RTT here. 0-RTT might still get rejected. @@ -503,7 +503,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t h.used0RTT.Store(true) h.zeroRTTSealer = nil h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil { h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) } } @@ -511,7 +511,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t panic("unexpected write encryption level") } h.mutex.Unlock() - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) } } @@ -651,7 +651,7 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { h.zeroRTTOpener = nil h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil { h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) } } diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 919b8a5b..a583f277 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -57,7 +57,7 @@ type updatableAEAD struct { rttStats *utils.RTTStats - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger version protocol.VersionNumber @@ -70,7 +70,7 @@ var ( _ ShortHeaderSealer = &updatableAEAD{} ) -func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { +func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { return &updatableAEAD{ firstPacketNumber: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, @@ -86,7 +86,7 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, func (a *updatableAEAD) rollKeys() { if a.prevRcvAEAD != nil { a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) - if a.tracer != nil { + if a.tracer != nil && a.tracer.DroppedKey != nil { a.tracer.DroppedKey(a.keyPhase - 1) } a.prevRcvAEADExpiry = time.Time{} @@ -182,7 +182,7 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac a.prevRcvAEAD = nil a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) a.prevRcvAEADExpiry = time.Time{} - if a.tracer != nil { + if a.tracer != nil && a.tracer.DroppedKey != nil { a.tracer.DroppedKey(a.keyPhase - 1) } } @@ -216,7 +216,7 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac // The peer initiated this key update. It's safe to drop the keys for the previous generation now. // Start a timer to drop the previous key generation. a.startKeyDropTimer(rcvTime) - if a.tracer != nil { + if a.tracer != nil && a.tracer.UpdatedKey != nil { a.tracer.UpdatedKey(a.keyPhase, true) } a.firstRcvdWithCurrentKey = pn @@ -308,7 +308,7 @@ func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { if a.shouldInitiateKeyUpdate() { a.rollKeys() a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) - if a.tracer != nil { + if a.tracer != nil && a.tracer.UpdatedKey != nil { a.tracer.UpdatedKey(a.keyPhase, false) } } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index a4a91f01..a57030bf 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -11,6 +11,7 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -62,7 +63,8 @@ var _ = Describe("Updatable AEAD", func() { ) BeforeEach(func() { - serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) rand.Read(trafficSecret1) @@ -70,7 +72,7 @@ var _ = Describe("Updatable AEAD", func() { rttStats = utils.NewRTTStats() client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v) - server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v) + server = newUpdatableAEAD(rttStats, tr, utils.DefaultLogger, v) client.SetReadKey(cs, trafficSecret2) client.SetWriteKey(cs, trafficSecret1) server.SetReadKey(cs, trafficSecret1) diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index 811d1e99..a2c74b1e 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -1,388 +1,108 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go/logging (interfaces: ConnectionTracer) +//go:build !gomock && !generate -// Package mocklogging is a generated GoMock package. package mocklogging import ( - net "net" - reflect "reflect" - time "time" + "net" + "time" - protocol "github.com/quic-go/quic-go/internal/protocol" - utils "github.com/quic-go/quic-go/internal/utils" - wire "github.com/quic-go/quic-go/internal/wire" - logging "github.com/quic-go/quic-go/logging" - gomock "go.uber.org/mock/gomock" + "github.com/quic-go/quic-go/internal/mocks/logging/internal" + "github.com/quic-go/quic-go/logging" + + "go.uber.org/mock/gomock" ) -// MockConnectionTracer is a mock of ConnectionTracer interface. -type MockConnectionTracer struct { - ctrl *gomock.Controller - recorder *MockConnectionTracerMockRecorder -} +type MockConnectionTracer = internal.MockConnectionTracer -// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. -type MockConnectionTracerMockRecorder struct { - mock *MockConnectionTracer -} - -// NewMockConnectionTracer creates a new mock instance. -func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { - mock := &MockConnectionTracer{ctrl: ctrl} - mock.recorder = &MockConnectionTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { - return m.recorder -} - -// AcknowledgedPacket mocks base method. -func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) -} - -// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) -} - -// BufferedPacket mocks base method. -func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferedPacket", arg0, arg1) -} - -// BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0, arg1) -} - -// Close mocks base method. -func (m *MockConnectionTracer) Close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close. -func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) -} - -// ClosedConnection mocks base method. -func (m *MockConnectionTracer) ClosedConnection(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClosedConnection", arg0) -} - -// ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) -} - -// Debug mocks base method. -func (m *MockConnectionTracer) Debug(arg0, arg1 string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Debug", arg0, arg1) -} - -// Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) -} - -// DroppedEncryptionLevel mocks base method. -func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) -} - -// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) -} - -// DroppedKey mocks base method. -func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedKey", arg0) -} - -// DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) -} - -// DroppedPacket mocks base method. -func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) -} - -// ECNStateUpdated mocks base method. -func (m *MockConnectionTracer) ECNStateUpdated(arg0 logging.ECNState, arg1 logging.ECNStateTrigger) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ECNStateUpdated", arg0, arg1) -} - -// ECNStateUpdated indicates an expected call of ECNStateUpdated. -func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) -} - -// LossTimerCanceled mocks base method. -func (m *MockConnectionTracer) LossTimerCanceled() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerCanceled") -} - -// LossTimerCanceled indicates an expected call of LossTimerCanceled. -func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) -} - -// LossTimerExpired mocks base method. -func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) -} - -// LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) -} - -// LostPacket mocks base method. -func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 logging.PacketLossReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) -} - -// LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) -} - -// NegotiatedVersion mocks base method. -func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) -} - -// NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) -} - -// ReceivedLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2, arg3) -} - -// ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) -} - -// ReceivedRetry mocks base method. -func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedRetry", arg0) -} - -// ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) -} - -// ReceivedShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2, arg3) -} - -// ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) -} - -// ReceivedTransportParameters mocks base method. -func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedTransportParameters", arg0) -} - -// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) -} - -// ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 protocol.ArbitraryLenConnectionID, arg2 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1, arg2) -} - -// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) -} - -// RestoredTransportParameters mocks base method. -func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RestoredTransportParameters", arg0) -} - -// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) -} - -// SentLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3, arg4) -} - -// SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) -} - -// SentShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3, arg4) -} - -// SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) -} - -// SentTransportParameters mocks base method. -func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentTransportParameters", arg0) -} - -// SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) -} - -// SetLossTimer mocks base method. -func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) -} - -// SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) -} - -// StartedConnection mocks base method. -func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) -} - -// StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) -} - -// UpdatedCongestionState mocks base method. -func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedCongestionState", arg0) -} - -// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) -} - -// UpdatedKey mocks base method. -func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKey", arg0, arg1) -} - -// UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) -} - -// UpdatedKeyFromTLS mocks base method. -func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) -} - -// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) -} - -// UpdatedMetrics mocks base method. -func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) -} - -// UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) -} - -// UpdatedPTOCount mocks base method. -func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedPTOCount", arg0) -} - -// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) +func NewMockConnectionTracer(ctrl *gomock.Controller) (*logging.ConnectionTracer, *MockConnectionTracer) { + t := internal.NewMockConnectionTracer(ctrl) + return &logging.ConnectionTracer{ + StartedConnection: func(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) { + t.StartedConnection(local, remote, srcConnID, destConnID) + }, + NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + }, + ClosedConnection: func(e error) { + t.ClosedConnection(e) + }, + SentTransportParameters: func(tp *logging.TransportParameters) { + t.SentTransportParameters(tp) + }, + ReceivedTransportParameters: func(tp *logging.TransportParameters) { + t.ReceivedTransportParameters(tp) + }, + RestoredTransportParameters: func(tp *logging.TransportParameters) { + t.RestoredTransportParameters(tp) + }, + SentLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + }, + SentShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + }, + ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + }, + ReceivedRetry: func(hdr *logging.Header) { + t.ReceivedRetry(hdr) + }, + ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + }, + BufferedPacket: func(typ logging.PacketType, size logging.ByteCount) { + t.BufferedPacket(typ, size) + }, + DroppedPacket: func(typ logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(typ, size, reason) + }, + UpdatedMetrics: func(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + }, + AcknowledgedPacket: func(encLevel logging.EncryptionLevel, pn logging.PacketNumber) { + t.AcknowledgedPacket(encLevel, pn) + }, + LostPacket: func(encLevel logging.EncryptionLevel, pn logging.PacketNumber, reason logging.PacketLossReason) { + t.LostPacket(encLevel, pn, reason) + }, + UpdatedCongestionState: func(state logging.CongestionState) { + t.UpdatedCongestionState(state) + }, + UpdatedPTOCount: func(value uint32) { + t.UpdatedPTOCount(value) + }, + UpdatedKeyFromTLS: func(encLevel logging.EncryptionLevel, perspective logging.Perspective) { + t.UpdatedKeyFromTLS(encLevel, perspective) + }, + UpdatedKey: func(generation logging.KeyPhase, remote bool) { + t.UpdatedKey(generation, remote) + }, + DroppedEncryptionLevel: func(encLevel logging.EncryptionLevel) { + t.DroppedEncryptionLevel(encLevel) + }, + DroppedKey: func(generation logging.KeyPhase) { + t.DroppedKey(generation) + }, + SetLossTimer: func(typ logging.TimerType, encLevel logging.EncryptionLevel, exp time.Time) { + t.SetLossTimer(typ, encLevel, exp) + }, + LossTimerExpired: func(typ logging.TimerType, encLevel logging.EncryptionLevel) { + t.LossTimerExpired(typ, encLevel) + }, + LossTimerCanceled: func() { + t.LossTimerCanceled() + }, + ECNStateUpdated: func(state logging.ECNState, trigger logging.ECNStateTrigger) { + t.ECNStateUpdated(state, trigger) + }, + Close: func() { + t.Close() + }, + Debug: func(name, msg string) { + t.Debug(name, msg) + }, + }, t } diff --git a/internal/mocks/logging/internal/connection_tracer.go b/internal/mocks/logging/internal/connection_tracer.go new file mode 100644 index 00000000..1cc0bd2e --- /dev/null +++ b/internal/mocks/logging/internal/connection_tracer.go @@ -0,0 +1,388 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go/internal/mocks/logging (interfaces: ConnectionTracer) + +// Package internal is a generated GoMock package. +package internal + +import ( + net "net" + reflect "reflect" + time "time" + + protocol "github.com/quic-go/quic-go/internal/protocol" + utils "github.com/quic-go/quic-go/internal/utils" + wire "github.com/quic-go/quic-go/internal/wire" + logging "github.com/quic-go/quic-go/logging" + gomock "go.uber.org/mock/gomock" +) + +// MockConnectionTracer is a mock of ConnectionTracer interface. +type MockConnectionTracer struct { + ctrl *gomock.Controller + recorder *MockConnectionTracerMockRecorder +} + +// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. +type MockConnectionTracerMockRecorder struct { + mock *MockConnectionTracer +} + +// NewMockConnectionTracer creates a new mock instance. +func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { + mock := &MockConnectionTracer{ctrl: ctrl} + mock.recorder = &MockConnectionTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { + return m.recorder +} + +// AcknowledgedPacket mocks base method. +func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) +} + +// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. +func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) +} + +// BufferedPacket mocks base method. +func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "BufferedPacket", arg0, arg1) +} + +// BufferedPacket indicates an expected call of BufferedPacket. +func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0, arg1) +} + +// Close mocks base method. +func (m *MockConnectionTracer) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) +} + +// ClosedConnection mocks base method. +func (m *MockConnectionTracer) ClosedConnection(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClosedConnection", arg0) +} + +// ClosedConnection indicates an expected call of ClosedConnection. +func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) +} + +// Debug mocks base method. +func (m *MockConnectionTracer) Debug(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0, arg1) +} + +// Debug indicates an expected call of Debug. +func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) +} + +// DroppedEncryptionLevel mocks base method. +func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) +} + +// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. +func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) +} + +// DroppedKey mocks base method. +func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedKey", arg0) +} + +// DroppedKey indicates an expected call of DroppedKey. +func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) +} + +// DroppedPacket mocks base method. +func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) +} + +// ECNStateUpdated mocks base method. +func (m *MockConnectionTracer) ECNStateUpdated(arg0 logging.ECNState, arg1 logging.ECNStateTrigger) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ECNStateUpdated", arg0, arg1) +} + +// ECNStateUpdated indicates an expected call of ECNStateUpdated. +func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) +} + +// LossTimerCanceled mocks base method. +func (m *MockConnectionTracer) LossTimerCanceled() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerCanceled") +} + +// LossTimerCanceled indicates an expected call of LossTimerCanceled. +func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) +} + +// LossTimerExpired mocks base method. +func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) +} + +// LossTimerExpired indicates an expected call of LossTimerExpired. +func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) +} + +// LostPacket mocks base method. +func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 logging.PacketLossReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) +} + +// LostPacket indicates an expected call of LostPacket. +func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) +} + +// NegotiatedVersion mocks base method. +func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) +} + +// NegotiatedVersion indicates an expected call of NegotiatedVersion. +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) +} + +// ReceivedLongHeaderPacket mocks base method. +func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2, arg3) +} + +// ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) +} + +// ReceivedRetry mocks base method. +func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedRetry", arg0) +} + +// ReceivedRetry indicates an expected call of ReceivedRetry. +func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) +} + +// ReceivedShortHeaderPacket mocks base method. +func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2, arg3) +} + +// ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) +} + +// ReceivedTransportParameters mocks base method. +func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedTransportParameters", arg0) +} + +// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. +func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) +} + +// ReceivedVersionNegotiationPacket mocks base method. +func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 protocol.ArbitraryLenConnectionID, arg2 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1, arg2) +} + +// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) +} + +// RestoredTransportParameters mocks base method. +func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RestoredTransportParameters", arg0) +} + +// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. +func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) +} + +// SentLongHeaderPacket mocks base method. +func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3, arg4) +} + +// SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. +func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) +} + +// SentShortHeaderPacket mocks base method. +func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3, arg4) +} + +// SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. +func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) +} + +// SentTransportParameters mocks base method. +func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentTransportParameters", arg0) +} + +// SentTransportParameters indicates an expected call of SentTransportParameters. +func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) +} + +// SetLossTimer mocks base method. +func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) +} + +// SetLossTimer indicates an expected call of SetLossTimer. +func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) +} + +// StartedConnection mocks base method. +func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) +} + +// StartedConnection indicates an expected call of StartedConnection. +func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) +} + +// UpdatedCongestionState mocks base method. +func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedCongestionState", arg0) +} + +// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) +} + +// UpdatedKey mocks base method. +func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKey", arg0, arg1) +} + +// UpdatedKey indicates an expected call of UpdatedKey. +func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) +} + +// UpdatedKeyFromTLS mocks base method. +func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) +} + +// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. +func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) +} + +// UpdatedMetrics mocks base method. +func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) +} + +// UpdatedMetrics indicates an expected call of UpdatedMetrics. +func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) +} + +// UpdatedPTOCount mocks base method. +func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedPTOCount", arg0) +} + +// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. +func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) +} diff --git a/internal/mocks/logging/internal/tracer.go b/internal/mocks/logging/internal/tracer.go new file mode 100644 index 00000000..15dbcaed --- /dev/null +++ b/internal/mocks/logging/internal/tracer.go @@ -0,0 +1,74 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go/internal/mocks/logging (interfaces: Tracer) + +// Package internal is a generated GoMock package. +package internal + +import ( + net "net" + reflect "reflect" + + protocol "github.com/quic-go/quic-go/internal/protocol" + wire "github.com/quic-go/quic-go/internal/wire" + logging "github.com/quic-go/quic-go/logging" + gomock "go.uber.org/mock/gomock" +) + +// MockTracer is a mock of Tracer interface. +type MockTracer struct { + ctrl *gomock.Controller + recorder *MockTracerMockRecorder +} + +// MockTracerMockRecorder is the mock recorder for MockTracer. +type MockTracerMockRecorder struct { + mock *MockTracer +} + +// NewMockTracer creates a new mock instance. +func NewMockTracer(ctrl *gomock.Controller) *MockTracer { + mock := &MockTracer{ctrl: ctrl} + mock.recorder = &MockTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTracer) EXPECT() *MockTracerMockRecorder { + return m.recorder +} + +// DroppedPacket mocks base method. +func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) +} + +// SentPacket mocks base method. +func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) +} + +// SentVersionNegotiationPacket mocks base method. +func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 protocol.ArbitraryLenConnectionID, arg3 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentVersionNegotiationPacket", arg0, arg1, arg2, arg3) +} + +// SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. +func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) +} diff --git a/internal/mocks/logging/mockgen.go b/internal/mocks/logging/mockgen.go new file mode 100644 index 00000000..8bf08dc8 --- /dev/null +++ b/internal/mocks/logging/mockgen.go @@ -0,0 +1,51 @@ +//go:build gomock || generate + +package mocklogging + +import ( + "net" + "time" + + "github.com/quic-go/quic-go/logging" +) + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package internal -destination internal/tracer.go github.com/quic-go/quic-go/internal/mocks/logging Tracer" +type Tracer interface { + SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) + SentVersionNegotiationPacket(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) + DroppedPacket(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) +} + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package internal -destination internal/connection_tracer.go github.com/quic-go/quic-go/internal/mocks/logging ConnectionTracer" +type ConnectionTracer interface { + StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) + NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) + ClosedConnection(error) + SentTransportParameters(*logging.TransportParameters) + ReceivedTransportParameters(*logging.TransportParameters) + RestoredTransportParameters(parameters *logging.TransportParameters) // for 0-RTT + SentLongHeaderPacket(*logging.ExtendedHeader, logging.ByteCount, logging.ECN, *logging.AckFrame, []logging.Frame) + SentShortHeaderPacket(*logging.ShortHeader, logging.ByteCount, logging.ECN, *logging.AckFrame, []logging.Frame) + ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) + ReceivedRetry(*logging.Header) + ReceivedLongHeaderPacket(*logging.ExtendedHeader, logging.ByteCount, logging.ECN, []logging.Frame) + ReceivedShortHeaderPacket(*logging.ShortHeader, logging.ByteCount, logging.ECN, []logging.Frame) + BufferedPacket(logging.PacketType, logging.ByteCount) + DroppedPacket(logging.PacketType, logging.ByteCount, logging.PacketDropReason) + UpdatedMetrics(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) + AcknowledgedPacket(logging.EncryptionLevel, logging.PacketNumber) + LostPacket(logging.EncryptionLevel, logging.PacketNumber, logging.PacketLossReason) + UpdatedCongestionState(logging.CongestionState) + UpdatedPTOCount(value uint32) + UpdatedKeyFromTLS(logging.EncryptionLevel, logging.Perspective) + UpdatedKey(generation logging.KeyPhase, remote bool) + DroppedEncryptionLevel(logging.EncryptionLevel) + DroppedKey(generation logging.KeyPhase) + SetLossTimer(logging.TimerType, logging.EncryptionLevel, time.Time) + LossTimerExpired(logging.TimerType, logging.EncryptionLevel) + LossTimerCanceled() + ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) + // Close is called when the connection is closed. + Close() + Debug(name, msg string) +} diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index 2cdd2335..115f578a 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -1,74 +1,29 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go/logging (interfaces: Tracer) +//go:build !gomock && !generate -// Package mocklogging is a generated GoMock package. package mocklogging import ( - net "net" - reflect "reflect" + "net" - protocol "github.com/quic-go/quic-go/internal/protocol" - wire "github.com/quic-go/quic-go/internal/wire" - logging "github.com/quic-go/quic-go/logging" - gomock "go.uber.org/mock/gomock" + "github.com/quic-go/quic-go/internal/mocks/logging/internal" + "github.com/quic-go/quic-go/logging" + + "go.uber.org/mock/gomock" ) -// MockTracer is a mock of Tracer interface. -type MockTracer struct { - ctrl *gomock.Controller - recorder *MockTracerMockRecorder -} +type MockTracer = internal.MockTracer -// MockTracerMockRecorder is the mock recorder for MockTracer. -type MockTracerMockRecorder struct { - mock *MockTracer -} - -// NewMockTracer creates a new mock instance. -func NewMockTracer(ctrl *gomock.Controller) *MockTracer { - mock := &MockTracer{ctrl: ctrl} - mock.recorder = &MockTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTracer) EXPECT() *MockTracerMockRecorder { - return m.recorder -} - -// DroppedPacket mocks base method. -func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) -} - -// SentPacket mocks base method. -func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// SentVersionNegotiationPacket mocks base method. -func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 protocol.ArbitraryLenConnectionID, arg3 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentVersionNegotiationPacket", arg0, arg1, arg2, arg3) -} - -// SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. -func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) +func NewMockTracer(ctrl *gomock.Controller) (*logging.Tracer, *MockTracer) { + t := internal.NewMockTracer(ctrl) + return &logging.Tracer{ + SentPacket: func(remote net.Addr, hdr *logging.Header, size logging.ByteCount, frames []logging.Frame) { + t.SentPacket(remote, hdr, size, frames) + }, + SentVersionNegotiationPacket: func(remote net.Addr, dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { + t.SentVersionNegotiationPacket(remote, dest, src, versions) + }, + DroppedPacket: func(remote net.Addr, typ logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(remote, typ, size, reason) + }, + }, t } diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 9bee4270..23bcda00 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -1,19 +1,19 @@ +//go:build gomock || generate + package mocks -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockquic -destination quic/stream.go github.com/quic-go/quic-go Stream" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/quic-go/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocklogging -destination logging/tracer.go github.com/quic-go/quic-go/logging Tracer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocklogging -destination logging/connection_tracer.go github.com/quic-go/quic-go/logging ConnectionTracer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination short_header_sealer.go github.com/quic-go/quic-go/internal/handshake ShortHeaderSealer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination short_header_opener.go github.com/quic-go/quic-go/internal/handshake ShortHeaderOpener" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination long_header_opener.go github.com/quic-go/quic-go/internal/handshake LongHeaderOpener" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination crypto_setup_tmp.go github.com/quic-go/quic-go/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/quic-go/quic-go/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination stream_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol StreamFlowController" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination congestion.go github.com/quic-go/quic-go/internal/congestion SendAlgorithmWithDebugInfos" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination connection_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol ConnectionFlowController" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler SentPacketHandler" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler ReceivedPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockquic -destination quic/stream.go github.com/quic-go/quic-go Stream" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockquic -destination quic/early_conn_tmp.go github.com/quic-go/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination short_header_sealer.go github.com/quic-go/quic-go/internal/handshake ShortHeaderSealer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination short_header_opener.go github.com/quic-go/quic-go/internal/handshake ShortHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination long_header_opener.go github.com/quic-go/quic-go/internal/handshake LongHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination crypto_setup_tmp.go github.com/quic-go/quic-go/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/quic-go/quic-go/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination stream_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol StreamFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination congestion.go github.com/quic-go/quic-go/internal/congestion SendAlgorithmWithDebugInfos" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination connection_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol ConnectionFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler SentPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler ReceivedPacketHandler" // The following command produces a warning message on OSX, however, it still generates the correct mock file. // See https://github.com/golang/mock/issues/339 for details. -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" diff --git a/interop/utils/logging.go b/interop/utils/logging.go index 30e3f663..3cb42244 100644 --- a/interop/utils/logging.go +++ b/interop/utils/logging.go @@ -29,7 +29,7 @@ func GetSSLKeyLog() (io.WriteCloser, error) { } // NewQLOGConnectionTracer create a qlog file in QLOGDIR -func NewQLOGConnectionTracer(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { +func NewQLOGConnectionTracer(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { qlogDir := os.Getenv("QLOGDIR") if len(qlogDir) == 0 { return nil diff --git a/logging/connection_tracer.go b/logging/connection_tracer.go new file mode 100644 index 00000000..e3f322d9 --- /dev/null +++ b/logging/connection_tracer.go @@ -0,0 +1,255 @@ +package logging + +import ( + "net" + "time" +) + +// A ConnectionTracer records events. +type ConnectionTracer struct { + StartedConnection func(local, remote net.Addr, srcConnID, destConnID ConnectionID) + NegotiatedVersion func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) + ClosedConnection func(error) + SentTransportParameters func(*TransportParameters) + ReceivedTransportParameters func(*TransportParameters) + RestoredTransportParameters func(parameters *TransportParameters) // for 0-RTT + SentLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) + SentShortHeaderPacket func(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) + ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, _ []VersionNumber) + ReceivedRetry func(*Header) + ReceivedLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, []Frame) + ReceivedShortHeaderPacket func(*ShortHeader, ByteCount, ECN, []Frame) + BufferedPacket func(PacketType, ByteCount) + DroppedPacket func(PacketType, ByteCount, PacketDropReason) + UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) + AcknowledgedPacket func(EncryptionLevel, PacketNumber) + LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason) + UpdatedCongestionState func(CongestionState) + UpdatedPTOCount func(value uint32) + UpdatedKeyFromTLS func(EncryptionLevel, Perspective) + UpdatedKey func(generation KeyPhase, remote bool) + DroppedEncryptionLevel func(EncryptionLevel) + DroppedKey func(generation KeyPhase) + SetLossTimer func(TimerType, EncryptionLevel, time.Time) + LossTimerExpired func(TimerType, EncryptionLevel) + LossTimerCanceled func() + ECNStateUpdated func(state ECNState, trigger ECNStateTrigger) + // Close is called when the connection is closed. + Close func() + Debug func(name, msg string) +} + +// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. +func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &ConnectionTracer{ + StartedConnection: func(local, remote net.Addr, srcConnID, destConnID ConnectionID) { + for _, t := range tracers { + if t.StartedConnection != nil { + t.StartedConnection(local, remote, srcConnID, destConnID) + } + } + }, + NegotiatedVersion: func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { + for _, t := range tracers { + if t.NegotiatedVersion != nil { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + } + } + }, + ClosedConnection: func(e error) { + for _, t := range tracers { + if t.ClosedConnection != nil { + t.ClosedConnection(e) + } + } + }, + SentTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.SentTransportParameters != nil { + t.SentTransportParameters(tp) + } + } + }, + ReceivedTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.ReceivedTransportParameters != nil { + t.ReceivedTransportParameters(tp) + } + } + }, + RestoredTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.RestoredTransportParameters != nil { + t.RestoredTransportParameters(tp) + } + } + }, + SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentLongHeaderPacket != nil { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentShortHeaderPacket != nil { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + ReceivedVersionNegotiationPacket: func(dest, src ArbitraryLenConnectionID, versions []VersionNumber) { + for _, t := range tracers { + if t.ReceivedVersionNegotiationPacket != nil { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + } + } + }, + ReceivedRetry: func(hdr *Header) { + for _, t := range tracers { + if t.ReceivedRetry != nil { + t.ReceivedRetry(hdr) + } + } + }, + ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedLongHeaderPacket != nil { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + } + } + }, + ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedShortHeaderPacket != nil { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + } + } + }, + BufferedPacket: func(typ PacketType, size ByteCount) { + for _, t := range tracers { + if t.BufferedPacket != nil { + t.BufferedPacket(typ, size) + } + } + }, + DroppedPacket: func(typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(typ, size, reason) + } + } + }, + UpdatedMetrics: func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { + for _, t := range tracers { + if t.UpdatedMetrics != nil { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + } + } + }, + AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) { + for _, t := range tracers { + if t.AcknowledgedPacket != nil { + t.AcknowledgedPacket(encLevel, pn) + } + } + }, + LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { + for _, t := range tracers { + if t.LostPacket != nil { + t.LostPacket(encLevel, pn, reason) + } + } + }, + UpdatedCongestionState: func(state CongestionState) { + for _, t := range tracers { + if t.UpdatedCongestionState != nil { + t.UpdatedCongestionState(state) + } + } + }, + UpdatedPTOCount: func(value uint32) { + for _, t := range tracers { + if t.UpdatedPTOCount != nil { + t.UpdatedPTOCount(value) + } + } + }, + UpdatedKeyFromTLS: func(encLevel EncryptionLevel, perspective Perspective) { + for _, t := range tracers { + if t.UpdatedKeyFromTLS != nil { + t.UpdatedKeyFromTLS(encLevel, perspective) + } + } + }, + UpdatedKey: func(generation KeyPhase, remote bool) { + for _, t := range tracers { + if t.UpdatedKey != nil { + t.UpdatedKey(generation, remote) + } + } + }, + DroppedEncryptionLevel: func(encLevel EncryptionLevel) { + for _, t := range tracers { + if t.DroppedEncryptionLevel != nil { + t.DroppedEncryptionLevel(encLevel) + } + } + }, + DroppedKey: func(generation KeyPhase) { + for _, t := range tracers { + if t.DroppedKey != nil { + t.DroppedKey(generation) + } + } + }, + SetLossTimer: func(typ TimerType, encLevel EncryptionLevel, exp time.Time) { + for _, t := range tracers { + if t.SetLossTimer != nil { + t.SetLossTimer(typ, encLevel, exp) + } + } + }, + LossTimerExpired: func(typ TimerType, encLevel EncryptionLevel) { + for _, t := range tracers { + if t.LossTimerExpired != nil { + t.LossTimerExpired(typ, encLevel) + } + } + }, + LossTimerCanceled: func() { + for _, t := range tracers { + if t.LossTimerCanceled != nil { + t.LossTimerCanceled() + } + } + }, + ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) { + for _, t := range tracers { + if t.ECNStateUpdated != nil { + t.ECNStateUpdated(state, trigger) + } + } + }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, + Debug: func(name, msg string) { + for _, t := range tracers { + if t.Debug != nil { + t.Debug(name, msg) + } + } + }, + } +} diff --git a/logging/interface.go b/logging/interface.go index 62028ebc..10ac038f 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -3,9 +3,6 @@ package logging import ( - "net" - "time" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" @@ -112,44 +109,3 @@ type ShortHeader struct { PacketNumberLen protocol.PacketNumberLen KeyPhase KeyPhaseBit } - -// A Tracer traces events. -type Tracer interface { - SentPacket(net.Addr, *Header, ByteCount, []Frame) - SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) - DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) -} - -// A ConnectionTracer records events. -type ConnectionTracer interface { - StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) - NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) - ClosedConnection(error) - SentTransportParameters(*TransportParameters) - ReceivedTransportParameters(*TransportParameters) - RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT - SentLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) - SentShortHeaderPacket(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) - ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) - ReceivedRetry(*Header) - ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, []Frame) - ReceivedShortHeaderPacket(*ShortHeader, ByteCount, ECN, []Frame) - BufferedPacket(PacketType, ByteCount) - DroppedPacket(PacketType, ByteCount, PacketDropReason) - UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) - AcknowledgedPacket(EncryptionLevel, PacketNumber) - LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) - UpdatedCongestionState(CongestionState) - UpdatedPTOCount(value uint32) - UpdatedKeyFromTLS(EncryptionLevel, Perspective) - UpdatedKey(generation KeyPhase, remote bool) - DroppedEncryptionLevel(EncryptionLevel) - DroppedKey(generation KeyPhase) - SetLossTimer(TimerType, EncryptionLevel, time.Time) - LossTimerExpired(TimerType, EncryptionLevel) - LossTimerCanceled() - ECNStateUpdated(state ECNState, trigger ECNStateTrigger) - // Close is called when the connection is closed. - Close() - Debug(name, msg string) -} diff --git a/logging/multiplex.go b/logging/multiplex.go deleted file mode 100644 index ef91f3f2..00000000 --- a/logging/multiplex.go +++ /dev/null @@ -1,232 +0,0 @@ -package logging - -import ( - "net" - "time" -) - -type tracerMultiplexer struct { - tracers []Tracer -} - -var _ Tracer = &tracerMultiplexer{} - -// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. -func NewMultiplexedTracer(tracers ...Tracer) Tracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &tracerMultiplexer{tracers} -} - -func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.SentPacket(remote, hdr, size, frames) - } -} - -func (m *tracerMultiplexer) SentVersionNegotiationPacket(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) { - for _, t := range m.tracers { - t.SentVersionNegotiationPacket(remote, dest, src, versions) - } -} - -func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(remote, typ, size, reason) - } -} - -type connTracerMultiplexer struct { - tracers []ConnectionTracer -} - -var _ ConnectionTracer = &connTracerMultiplexer{} - -// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. -func NewMultiplexedConnectionTracer(tracers ...ConnectionTracer) ConnectionTracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &connTracerMultiplexer{tracers: tracers} -} - -func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { - for _, t := range m.tracers { - t.StartedConnection(local, remote, srcConnID, destConnID) - } -} - -func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { - for _, t := range m.tracers { - t.NegotiatedVersion(chosen, clientVersions, serverVersions) - } -} - -func (m *connTracerMultiplexer) ClosedConnection(e error) { - for _, t := range m.tracers { - t.ClosedConnection(e) - } -} - -func (m *connTracerMultiplexer) SentTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.SentTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) ReceivedTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.ReceivedTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.RestoredTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { - for _, t := range m.tracers { - t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) - } -} - -func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { - for _, t := range m.tracers { - t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) - } -} - -func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, versions []VersionNumber) { - for _, t := range m.tracers { - t.ReceivedVersionNegotiationPacket(dest, src, versions) - } -} - -func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { - for _, t := range m.tracers { - t.ReceivedRetry(hdr) - } -} - -func (m *connTracerMultiplexer) ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { - for _, t := range m.tracers { - t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) - } -} - -func (m *connTracerMultiplexer) ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { - for _, t := range m.tracers { - t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) - } -} - -func (m *connTracerMultiplexer) BufferedPacket(typ PacketType, size ByteCount) { - for _, t := range m.tracers { - t.BufferedPacket(typ, size) - } -} - -func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(typ, size, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) { - for _, t := range m.tracers { - t.UpdatedCongestionState(state) - } -} - -func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) { - for _, t := range m.tracers { - t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight) - } -} - -func (m *connTracerMultiplexer) AcknowledgedPacket(encLevel EncryptionLevel, pn PacketNumber) { - for _, t := range m.tracers { - t.AcknowledgedPacket(encLevel, pn) - } -} - -func (m *connTracerMultiplexer) LostPacket(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { - for _, t := range m.tracers { - t.LostPacket(encLevel, pn, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedPTOCount(value uint32) { - for _, t := range m.tracers { - t.UpdatedPTOCount(value) - } -} - -func (m *connTracerMultiplexer) UpdatedKeyFromTLS(encLevel EncryptionLevel, perspective Perspective) { - for _, t := range m.tracers { - t.UpdatedKeyFromTLS(encLevel, perspective) - } -} - -func (m *connTracerMultiplexer) UpdatedKey(generation KeyPhase, remote bool) { - for _, t := range m.tracers { - t.UpdatedKey(generation, remote) - } -} - -func (m *connTracerMultiplexer) DroppedEncryptionLevel(encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.DroppedEncryptionLevel(encLevel) - } -} - -func (m *connTracerMultiplexer) DroppedKey(generation KeyPhase) { - for _, t := range m.tracers { - t.DroppedKey(generation) - } -} - -func (m *connTracerMultiplexer) SetLossTimer(typ TimerType, encLevel EncryptionLevel, exp time.Time) { - for _, t := range m.tracers { - t.SetLossTimer(typ, encLevel, exp) - } -} - -func (m *connTracerMultiplexer) LossTimerExpired(typ TimerType, encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.LossTimerExpired(typ, encLevel) - } -} - -func (m *connTracerMultiplexer) LossTimerCanceled() { - for _, t := range m.tracers { - t.LossTimerCanceled() - } -} - -func (m *connTracerMultiplexer) ECNStateUpdated(state ECNState, trigger ECNStateTrigger) { - for _, t := range m.tracers { - t.ECNStateUpdated(state, trigger) - } -} - -func (m *connTracerMultiplexer) Debug(name, msg string) { - for _, t := range m.tracers { - t.Debug(name, msg) - } -} - -func (m *connTracerMultiplexer) Close() { - for _, t := range m.tracers { - t.Close() - } -} diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index c0f784ec..69606346 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -21,21 +21,22 @@ var _ = Describe("Tracing", func() { }) It("returns the raw tracer if only one tracer is passed in", func() { - tr := mocklogging.NewMockTracer(mockCtrl) + tr := &Tracer{} tracer := NewMultiplexedTracer(tr) - Expect(tracer).To(BeAssignableToTypeOf(&mocklogging.MockTracer{})) + Expect(tracer).To(Equal(tr)) }) Context("tracing events", func() { var ( - tracer Tracer + tracer *Tracer tr1, tr2 *mocklogging.MockTracer ) BeforeEach(func() { - tr1 = mocklogging.NewMockTracer(mockCtrl) - tr2 = mocklogging.NewMockTracer(mockCtrl) - tracer = NewMultiplexedTracer(tr1, tr2) + var t1, t2 *Tracer + t1, tr1 = mocklogging.NewMockTracer(mockCtrl) + t2, tr2 = mocklogging.NewMockTracer(mockCtrl) + tracer = NewMultiplexedTracer(t1, t2, &Tracer{}) }) It("traces the PacketSent event", func() { @@ -68,18 +69,19 @@ var _ = Describe("Tracing", func() { Context("Connection Tracer", func() { var ( - tracer ConnectionTracer + tracer *ConnectionTracer tr1 *mocklogging.MockConnectionTracer tr2 *mocklogging.MockConnectionTracer ) BeforeEach(func() { - tr1 = mocklogging.NewMockConnectionTracer(mockCtrl) - tr2 = mocklogging.NewMockConnectionTracer(mockCtrl) - tracer = NewMultiplexedConnectionTracer(tr1, tr2) + var t1, t2 *ConnectionTracer + t1, tr1 = mocklogging.NewMockConnectionTracer(mockCtrl) + t2, tr2 = mocklogging.NewMockConnectionTracer(mockCtrl) + tracer = NewMultiplexedConnectionTracer(t1, t2) }) - It("trace the ConnectionStarted event", func() { + It("traces the StartedConnection event", func() { local := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4)} remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} dest := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) @@ -89,6 +91,15 @@ var _ = Describe("Tracing", func() { tracer.StartedConnection(local, remote, src, dest) }) + It("traces the NegotiatedVersion event", func() { + chosen := protocol.Version2 + client := []protocol.VersionNumber{protocol.Version1} + server := []protocol.VersionNumber{13, 37} + tr1.EXPECT().NegotiatedVersion(chosen, client, server) + tr2.EXPECT().NegotiatedVersion(chosen, client, server) + tracer.NegotiatedVersion(chosen, client, server) + }) + It("traces the ClosedConnection event", func() { e := errors.New("test err") tr1.EXPECT().ClosedConnection(e) diff --git a/logging/null_tracer.go b/logging/null_tracer.go deleted file mode 100644 index 3e3458f7..00000000 --- a/logging/null_tracer.go +++ /dev/null @@ -1,63 +0,0 @@ -package logging - -import ( - "net" - "time" -) - -// The NullTracer is a Tracer that does nothing. -// It is useful for embedding. -type NullTracer struct{} - -var _ Tracer = &NullTracer{} - -func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} -func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) { -} -func (n NullTracer) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) {} - -// The NullConnectionTracer is a ConnectionTracer that does nothing. -// It is useful for embedding. -type NullConnectionTracer struct{} - -var _ ConnectionTracer = &NullConnectionTracer{} - -func (n NullConnectionTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { -} - -func (n NullConnectionTracer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { -} -func (n NullConnectionTracer) ClosedConnection(err error) {} -func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) { -} - -func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) { -} - -func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) { -} -func (n NullConnectionTracer) ReceivedRetry(*Header) {} -func (n NullConnectionTracer) ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, []Frame) {} -func (n NullConnectionTracer) ReceivedShortHeaderPacket(*ShortHeader, ByteCount, ECN, []Frame) {} -func (n NullConnectionTracer) BufferedPacket(PacketType, ByteCount) {} -func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropReason) {} - -func (n NullConnectionTracer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { -} -func (n NullConnectionTracer) AcknowledgedPacket(EncryptionLevel, PacketNumber) {} -func (n NullConnectionTracer) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) {} -func (n NullConnectionTracer) UpdatedCongestionState(CongestionState) {} -func (n NullConnectionTracer) UpdatedPTOCount(uint32) {} -func (n NullConnectionTracer) UpdatedKeyFromTLS(EncryptionLevel, Perspective) {} -func (n NullConnectionTracer) UpdatedKey(keyPhase KeyPhase, remote bool) {} -func (n NullConnectionTracer) DroppedEncryptionLevel(EncryptionLevel) {} -func (n NullConnectionTracer) DroppedKey(KeyPhase) {} -func (n NullConnectionTracer) SetLossTimer(TimerType, EncryptionLevel, time.Time) {} -func (n NullConnectionTracer) LossTimerExpired(TimerType, EncryptionLevel) {} -func (n NullConnectionTracer) LossTimerCanceled() {} -func (n NullConnectionTracer) ECNStateUpdated(ECNState, ECNStateTrigger) {} -func (n NullConnectionTracer) Close() {} -func (n NullConnectionTracer) Debug(name, msg string) {} diff --git a/logging/tracer.go b/logging/tracer.go new file mode 100644 index 00000000..5918f30f --- /dev/null +++ b/logging/tracer.go @@ -0,0 +1,43 @@ +package logging + +import "net" + +// A Tracer traces events. +type Tracer struct { + SentPacket func(net.Addr, *Header, ByteCount, []Frame) + SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) + DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason) +} + +// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. +func NewMultiplexedTracer(tracers ...*Tracer) *Tracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &Tracer{ + SentPacket: func(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { + for _, t := range tracers { + if t.SentPacket != nil { + t.SentPacket(remote, hdr, size, frames) + } + } + }, + SentVersionNegotiationPacket: func(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) { + for _, t := range tracers { + if t.SentVersionNegotiationPacket != nil { + t.SentVersionNegotiationPacket(remote, dest, src, versions) + } + } + }, + DroppedPacket: func(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(remote, typ, size, reason) + } + } + }, + } +} diff --git a/qlog/qlog.go b/qlog/qlog.go index 943bbc3d..7801b646 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -63,11 +63,9 @@ type connectionTracer struct { lastMetrics *metrics } -var _ logging.ConnectionTracer = &connectionTracer{} - // NewConnectionTracer creates a new tracer to record a qlog for a connection. -func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { - t := &connectionTracer{ +func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) *logging.ConnectionTracer { + t := connectionTracer{ w: w, perspective: p, odcid: odcid, @@ -76,7 +74,84 @@ func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protoco referenceTime: time.Now(), } go t.run() - return t + return &logging.ConnectionTracer{ + StartedConnection: func(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) { + t.StartedConnection(local, remote, srcConnID, destConnID) + }, + NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + }, + ClosedConnection: func(e error) { t.ClosedConnection(e) }, + SentTransportParameters: func(tp *wire.TransportParameters) { t.SentTransportParameters(tp) }, + ReceivedTransportParameters: func(tp *wire.TransportParameters) { t.ReceivedTransportParameters(tp) }, + RestoredTransportParameters: func(tp *wire.TransportParameters) { t.RestoredTransportParameters(tp) }, + SentLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + }, + SentShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + }, + ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + }, + ReceivedRetry: func(hdr *wire.Header) { + t.ReceivedRetry(hdr) + }, + ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + }, + BufferedPacket: func(pt logging.PacketType, size protocol.ByteCount) { + t.BufferedPacket(pt, size) + }, + DroppedPacket: func(pt logging.PacketType, size protocol.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(pt, size, reason) + }, + UpdatedMetrics: func(rttStats *utils.RTTStats, cwnd, bytesInFlight protocol.ByteCount, packetsInFlight int) { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + }, + LostPacket: func(encLevel protocol.EncryptionLevel, pn protocol.PacketNumber, lossReason logging.PacketLossReason) { + t.LostPacket(encLevel, pn, lossReason) + }, + UpdatedCongestionState: func(state logging.CongestionState) { + t.UpdatedCongestionState(state) + }, + UpdatedPTOCount: func(value uint32) { + t.UpdatedPTOCount(value) + }, + UpdatedKeyFromTLS: func(encLevel protocol.EncryptionLevel, pers protocol.Perspective) { + t.UpdatedKeyFromTLS(encLevel, pers) + }, + UpdatedKey: func(generation protocol.KeyPhase, remote bool) { + t.UpdatedKey(generation, remote) + }, + DroppedEncryptionLevel: func(encLevel protocol.EncryptionLevel) { + t.DroppedEncryptionLevel(encLevel) + }, + DroppedKey: func(generation protocol.KeyPhase) { + t.DroppedKey(generation) + }, + SetLossTimer: func(tt logging.TimerType, encLevel protocol.EncryptionLevel, timeout time.Time) { + t.SetLossTimer(tt, encLevel, timeout) + }, + LossTimerExpired: func(tt logging.TimerType, encLevel protocol.EncryptionLevel) { + t.LossTimerExpired(tt, encLevel) + }, + LossTimerCanceled: func() { + t.LossTimerCanceled() + }, + ECNStateUpdated: func(state logging.ECNState, trigger logging.ECNStateTrigger) { + t.ECNStateUpdated(state, trigger) + }, + Debug: func(name, msg string) { + t.Debug(name, msg) + }, + Close: func() { + t.Close() + }, + } } func (t *connectionTracer) run() { diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index c19a0d3e..a58f0278 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -70,7 +70,7 @@ var _ = Describe("Tracing", func() { Context("connection tracer", func() { var ( - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer buf *bytes.Buffer ) diff --git a/server.go b/server.go index 671ae971..ebfb6afe 100644 --- a/server.go +++ b/server.go @@ -92,7 +92,7 @@ type baseServer struct { *tls.Config, *handshake.TokenGenerator, bool, /* client address validated by an address validation token */ - logging.ConnectionTracer, + *logging.ConnectionTracer, uint64, utils.Logger, protocol.VersionNumber, @@ -108,7 +108,7 @@ type baseServer struct { connQueue chan quicConn connQueueLen int32 // to be used as an atomic - tracer logging.Tracer + tracer *logging.Tracer logger utils.Logger } @@ -224,7 +224,7 @@ func newServer( connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, - tracer logging.Tracer, + tracer *logging.Tracer, onClose func(), tokenGeneratorKey TokenGeneratorKey, disableVersionNegotiation bool, @@ -349,7 +349,7 @@ func (s *baseServer) handlePacket(p receivedPacket) { case s.receivedPackets <- p: default: s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } @@ -362,7 +362,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st if wire.IsVersionNegotiationPacket(p.data) { s.logger.Debugf("Dropping Version Negotiation packet.") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -375,7 +375,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // drop the packet if we failed to parse the protocol version if err != nil { s.logger.Debugf("Dropping a packet with an unknown version") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -388,7 +388,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st if p.Size() < protocol.MinUnknownVersionPacketSize { s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -398,7 +398,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st if wire.Is0RTTPacket(p.data) { if !s.acceptEarlyConns { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -410,7 +410,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) @@ -418,7 +418,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st } if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -429,7 +429,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // There's little point in sending a Stateless Reset, since the client // might not have received the token yet. s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -448,7 +448,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { connID, err := wire.ParseConnectionID(p.data, 0) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) } return false @@ -462,7 +462,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { if q, ok := s.zeroRTTQueues[connID]; ok { if len(q.packets) >= protocol.Max0RTTQueueLen { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false @@ -472,7 +472,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { } if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false @@ -500,7 +500,7 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) { continue } for _, p := range q.packets { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } p.buffer.Release() @@ -536,7 +536,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { p.buffer.Release() - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return errors.New("too short connection ID") @@ -621,7 +621,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error } config = populateConfig(conf) } - var tracer logging.ConnectionTracer + var tracer *logging.ConnectionTracer if config.Tracer != nil { // Use the same connection ID that is passed to the client's GetLogWriter callback. connID := hdr.DestConnectionID @@ -738,7 +738,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe // append the Retry integrity tag tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) buf.Data = append(buf.Data, tag[:]...) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentPacket != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) @@ -759,7 +759,7 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) @@ -774,14 +774,14 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { // Only send INVALID_TOKEN if we can unprotect the packet. // This makes sure that we won't send it for packets that were corrupted. if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) } return } hdrLen := extHdr.ParsedLen() if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) } return @@ -837,7 +837,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han replyHdr.Log(s.logger) wire.LogFrame(s.logger, ccf, true) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentPacket != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) @@ -866,7 +866,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) if err != nil { // should never happen s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return @@ -875,7 +875,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentVersionNegotiationPacket != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { diff --git a/server_test.go b/server_test.go index 31a97764..71c968bd 100644 --- a/server_test.go +++ b/server_test.go @@ -177,8 +177,9 @@ var _ = Describe("Server", func() { ) BeforeEach(func() { - tracer = mocklogging.NewMockTracer(mockCtrl) - tr = &Transport{Conn: conn, Tracer: tracer} + var t *logging.Tracer + t, tracer = mocklogging.NewMockTracer(mockCtrl) + tr = &Transport{Conn: conn, Tracer: t} ln, err := tr.Listen(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) serv = ln.baseServer @@ -291,7 +292,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -493,7 +494,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -552,7 +553,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -605,7 +606,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -641,7 +642,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -712,7 +713,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1022,7 +1023,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1099,7 +1100,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1172,7 +1173,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1215,7 +1216,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1279,7 +1280,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1329,8 +1330,9 @@ var _ = Describe("Server", func() { ) BeforeEach(func() { - tracer = mocklogging.NewMockTracer(mockCtrl) - tr = &Transport{Conn: conn, Tracer: tracer} + var t *logging.Tracer + t, tracer = mocklogging.NewMockTracer(mockCtrl) + tr = &Transport{Conn: conn, Tracer: t} ln, err := tr.ListenEarly(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) phm = NewMockPacketHandlerManager(mockCtrl) @@ -1404,7 +1406,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, diff --git a/transport.go b/transport.go index b5855329..0119d629 100644 --- a/transport.go +++ b/transport.go @@ -69,7 +69,7 @@ type Transport struct { DisableVersionNegotiationPackets bool // A Tracer traces events that don't belong to a single QUIC connection. - Tracer logging.Tracer + Tracer *logging.Tracer handlerMap packetHandlerManager @@ -363,7 +363,7 @@ func (t *Transport) handlePacket(p receivedPacket) { connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) - if t.Tracer != nil { + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } p.buffer.MaybeRelease() @@ -458,7 +458,7 @@ func (t *Transport) handleNonQUICPacket(p receivedPacket) { select { case t.nonQUICPackets <- p: default: - if t.Tracer != nil { + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } diff --git a/transport_test.go b/transport_test.go index 14c6fdbb..25019719 100644 --- a/transport_test.go +++ b/transport_test.go @@ -126,11 +126,11 @@ var _ = Describe("Transport", func() { It("drops unparseable QUIC packets", func() { addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) - tracer := mocklogging.NewMockTracer(mockCtrl) + t, tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, - Tracer: tracer, + Tracer: t, } tr.init(true) dropped := make(chan struct{}) @@ -328,11 +328,9 @@ var _ = Describe("Transport", func() { It("allows receiving non-QUIC packets", func() { remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) - tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, - Tracer: tracer, } tr.init(true) receivedPacketChan := make(chan []byte) @@ -362,11 +360,11 @@ var _ = Describe("Transport", func() { It("drops non-QUIC packet if the application doesn't process them quickly enough", func() { remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) - tracer := mocklogging.NewMockTracer(mockCtrl) + t, tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, - Tracer: tracer, + Tracer: t, } tr.init(true) From 1affe38703fdb78ccaa3716d6f44fbec0063c7ce Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Sep 2023 19:10:20 +0700 Subject: [PATCH 35/43] move MaxTokenAge configuration option to the Transport (#4084) --- config.go | 4 ---- config_test.go | 2 -- interface.go | 4 ---- internal/handshake/token_protector.go | 2 +- server.go | 5 ++++- server_test.go | 2 +- transport.go | 8 ++++++++ 7 files changed, 14 insertions(+), 13 deletions(-) diff --git a/config.go b/config.go index dea118b3..49b9fc3f 100644 --- a/config.go +++ b/config.go @@ -53,9 +53,6 @@ func validateConfig(config *Config) error { // it may be called with nil func populateServerConfig(config *Config) *Config { config = populateConfig(config) - if config.MaxTokenAge == 0 { - config.MaxTokenAge = protocol.TokenValidity - } if config.RequireAddressValidation == nil { config.RequireAddressValidation = func(net.Addr) bool { return false } } @@ -114,7 +111,6 @@ func populateConfig(config *Config) *Config { Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, - MaxTokenAge: config.MaxTokenAge, RequireAddressValidation: config.RequireAddressValidation, KeepAlivePeriod: config.KeepAlivePeriod, InitialStreamReceiveWindow: initialStreamReceiveWindow, diff --git a/config_test.go b/config_test.go index c5701608..e40c1cfc 100644 --- a/config_test.go +++ b/config_test.go @@ -78,8 +78,6 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf(time.Second)) case "MaxIdleTimeout": f.Set(reflect.ValueOf(time.Hour)) - case "MaxTokenAge": - f.Set(reflect.ValueOf(2 * time.Hour)) case "TokenStore": f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) case "InitialStreamReceiveWindow": diff --git a/interface.go b/interface.go index c7e53bdc..6eac385d 100644 --- a/interface.go +++ b/interface.go @@ -268,10 +268,6 @@ type Config struct { // See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details. // If not set, every client is forced to prove its remote address. RequireAddressValidation func(net.Addr) bool - // MaxTokenAge is the maximum age of the token presented during the handshake, - // for tokens that were issued on a previous connection. - // If not set, it defaults to 24 hours. Only valid for a server. - MaxTokenAge time.Duration // The TokenStore stores tokens received from the server. // Tokens are used to skip address validation on future connection attempts. // The key used to store tokens is the ServerName from the tls.Config, if set diff --git a/internal/handshake/token_protector.go b/internal/handshake/token_protector.go index 6dcf7f77..f3a99e41 100644 --- a/internal/handshake/token_protector.go +++ b/internal/handshake/token_protector.go @@ -61,7 +61,7 @@ func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { } func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { - h := hkdf.New(sha256.New, s.key[:], nonce[:], []byte("quic-go token source")) + h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source")) key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 if _, err := io.ReadFull(h, key); err != nil { return nil, nil, err diff --git a/server.go b/server.go index ebfb6afe..119289b3 100644 --- a/server.go +++ b/server.go @@ -67,6 +67,7 @@ type baseServer struct { conn rawConn tokenGenerator *handshake.TokenGenerator + maxTokenAge time.Duration connIDGenerator ConnectionIDGenerator connHandler packetHandlerManager @@ -227,6 +228,7 @@ func newServer( tracer *logging.Tracer, onClose func(), tokenGeneratorKey TokenGeneratorKey, + maxTokenAge time.Duration, disableVersionNegotiation bool, acceptEarly bool, ) *baseServer { @@ -235,6 +237,7 @@ func newServer( tlsConf: tlsConf, config: config, tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey), + maxTokenAge: maxTokenAge, connIDGenerator: connIDGenerator, connHandler: connHandler, connQueue: make(chan quicConn), @@ -524,7 +527,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { if !token.ValidateRemoteAddr(addr) { return false } - if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge { + if !token.IsRetryToken && time.Since(token.SentTime) > s.maxTokenAge { return false } if token.IsRetryToken && time.Since(token.SentTime) > s.config.maxRetryTokenAge() { diff --git a/server_test.go b/server_test.go index 71c968bd..6e50de82 100644 --- a/server_test.go +++ b/server_test.go @@ -901,7 +901,7 @@ var _ = Describe("Server", func() { It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } - serv.config.MaxTokenAge = time.Millisecond + serv.maxTokenAge = time.Millisecond raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} token, err := serv.tokenGenerator.NewToken(raddr) Expect(err).ToNot(HaveOccurred()) diff --git a/transport.go b/transport.go index 0119d629..14d6b1fd 100644 --- a/transport.go +++ b/transport.go @@ -63,6 +63,13 @@ type Transport struct { // see section 8.1.3 of RFC 9000 for details. TokenGeneratorKey *TokenGeneratorKey + // MaxTokenAge is the maximum age of the resumption token presented during the handshake. + // These tokens allow skipping address resumption when resuming a QUIC connection, + // and are especially useful when using 0-RTT. + // If not set, it defaults to 24 hours. + // See section 8.1.3 of RFC 9000 for details. + MaxTokenAge time.Duration + // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. // This can be useful if version information is exchanged out-of-band. // It has no effect for clients. @@ -151,6 +158,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo t.Tracer, t.closeServer, *t.TokenGeneratorKey, + t.MaxTokenAge, t.DisableVersionNegotiationPackets, allow0RTT, ) From 55eebd49ff040846a90d818f5d513d48ee2d2c82 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Sep 2023 19:37:58 +0700 Subject: [PATCH 36/43] return the cancellation cause for cancelled dials (#4078) --- client.go | 2 +- integrationtests/self/handshake_test.go | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 670e7a7c..e0ae5d3c 100644 --- a/client.go +++ b/client.go @@ -233,7 +233,7 @@ func (c *client) dial(ctx context.Context) error { select { case <-ctx.Done(): c.conn.shutdown() - return ctx.Err() + return context.Cause(ctx) case err := <-errorChan: return err case recreateErr := <-recreateChan: diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index f90ac2fd..d4f65a12 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -82,6 +82,26 @@ var _ = Describe("Handshake tests", func() { }() } + It("returns the cancellation reason when a dial is canceled", func() { + ctx, cancel := context.WithCancelCause(context.Background()) + errChan := make(chan error, 1) + go func() { + _, err := quic.DialAddr( + ctx, + "localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway + getTLSClientConfig(), + getQuicConfig(nil), + ) + errChan <- err + }() + + cancel(errors.New("application cancelled")) + var err error + Eventually(errChan).Should(Receive(&err)) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("application cancelled")) + }) + Context("using different cipher suites", func() { for n, id := range map[string]uint16{ "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, From 22fb59ee6f2962e9987a7dede7a91821fbffd0b0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 17 Sep 2023 05:18:43 -0700 Subject: [PATCH 37/43] create FUNDING.yml --- .github/FUNDING.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..7de30a1d --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,13 @@ +# These are supported funding model platforms + +github: [marten-seemann] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] From 9010cfd2bbc9da8c16a839a4b9494bfe3b6b9d41 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 17 Sep 2023 19:20:13 +0700 Subject: [PATCH 38/43] remove unused unknownPacketHandler interface (#4093) --- mock_unknown_packet_handler_test.go | 58 ----------------------------- mockgen.go | 3 -- packet_handler_map.go | 8 ---- transport.go | 4 +- 4 files changed, 3 insertions(+), 70 deletions(-) delete mode 100644 mock_unknown_packet_handler_test.go diff --git a/mock_unknown_packet_handler_test.go b/mock_unknown_packet_handler_test.go deleted file mode 100644 index 6f29101a..00000000 --- a/mock_unknown_packet_handler_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go (interfaces: UnknownPacketHandler) - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "go.uber.org/mock/gomock" -) - -// MockUnknownPacketHandler is a mock of UnknownPacketHandler interface. -type MockUnknownPacketHandler struct { - ctrl *gomock.Controller - recorder *MockUnknownPacketHandlerMockRecorder -} - -// MockUnknownPacketHandlerMockRecorder is the mock recorder for MockUnknownPacketHandler. -type MockUnknownPacketHandlerMockRecorder struct { - mock *MockUnknownPacketHandler -} - -// NewMockUnknownPacketHandler creates a new mock instance. -func NewMockUnknownPacketHandler(ctrl *gomock.Controller) *MockUnknownPacketHandler { - mock := &MockUnknownPacketHandler{ctrl: ctrl} - mock.recorder = &MockUnknownPacketHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorder { - return m.recorder -} - -// handlePacket mocks base method. -func (m *MockUnknownPacketHandler) handlePacket(arg0 receivedPacket) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "handlePacket", arg0) -} - -// handlePacket indicates an expected call of handlePacket. -func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0) -} - -// setCloseError mocks base method. -func (m *MockUnknownPacketHandler) setCloseError(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "setCloseError", arg0) -} - -// setCloseError indicates an expected call of setCloseError. -func (mr *MockUnknownPacketHandlerMockRecorder) setCloseError(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setCloseError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).setCloseError), arg0) -} diff --git a/mockgen.go b/mockgen.go index ea0aa58b..eb247386 100644 --- a/mockgen.go +++ b/mockgen.go @@ -62,9 +62,6 @@ type QUICConn = quicConn //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler" type PacketHandler = packetHandler -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unknown_packet_handler_test.go github.com/quic-go/quic-go UnknownPacketHandler" -type UnknownPacketHandler = unknownPacketHandler - //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" type PacketHandlerManager = packetHandlerManager diff --git a/packet_handler_map.go b/packet_handler_map.go index ba30fb23..006eadf9 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -4,7 +4,6 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" - "errors" "hash" "io" "net" @@ -45,13 +44,6 @@ type closePacket struct { info packetInfo } -type unknownPacketHandler interface { - handlePacket(receivedPacket) - setCloseError(error) -} - -var errListenerAlreadySet = errors.New("listener already set") - type packetHandlerMap struct { mutex sync.Mutex handlers map[protocol.ConnectionID]packetHandler diff --git a/transport.go b/transport.go index 14d6b1fd..60d44a43 100644 --- a/transport.go +++ b/transport.go @@ -16,6 +16,8 @@ import ( "github.com/quic-go/quic-go/logging" ) +var errListenerAlreadySet = errors.New("listener already set") + // The Transport is the central point to manage incoming and outgoing QUIC connections. // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple. // This means that a single UDP socket can be used for listening for incoming connections, as well as @@ -91,7 +93,7 @@ type Transport struct { // If no ConnectionIDGenerator is set, this is set to a default. connIDGenerator ConnectionIDGenerator - server unknownPacketHandler + server *baseServer conn rawConn From c12f42580329a9fa6a724f7f0f56c7abb2a7dde6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 17 Sep 2023 23:00:05 +0400 Subject: [PATCH 39/43] ackhandler: don't fail ECN validation if less than 10 testing packets are lost (#4088) * ackhandler: don't fail ECN validation less than 10 testing packets lost * ackhandler: simplify checks for mangling and loss of all testing packets --- internal/ackhandler/ecn.go | 22 +++++++++++++--------- internal/ackhandler/ecn_test.go | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/internal/ackhandler/ecn.go b/internal/ackhandler/ecn.go index 43cc3dd6..3e851c8e 100644 --- a/internal/ackhandler/ecn.go +++ b/internal/ackhandler/ecn.go @@ -125,6 +125,10 @@ func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) { return } e.numLostTesting++ + // Only proceed if we have sent all 10 testing packets. + if e.state != ecnStateUnknown { + return + } if e.numLostTesting >= e.numSentTesting { e.logger.Debugf("Disabling ECN. All testing packets were lost.") if e.tracer != nil && e.tracer.ECNStateUpdated != nil { @@ -223,16 +227,16 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 e.numAckedECT1 = ect1 e.numAckedECNCE = ecnce - if e.state == ecnStateTesting || e.state == ecnStateUnknown { - // Detect mangling (a path remarking all ECN-marked testing packets as CE). - if e.numSentECT0+e.numSentECT1 == e.numAckedECNCE && e.numAckedECNCE >= numECNTestingPackets { - if e.tracer != nil && e.tracer.ECNStateUpdated != nil { - e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) - } - e.state = ecnStateFailed - return false + // Detect mangling (a path remarking all ECN-marked testing packets as CE), + // once all 10 testing packets have been sent out. + if e.state == ecnStateUnknown && e.numSentECT0+e.numSentECT1 == e.numAckedECNCE { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) } - + e.state = ecnStateFailed + return false + } + if e.state == ecnStateTesting || e.state == ecnStateUnknown { var ackedTestingPacket bool for _, p := range packets { if e.isTestingPacket(p.PacketNumber) { diff --git a/internal/ackhandler/ecn_test.go b/internal/ackhandler/ecn_test.go index 26d54eed..78eef8c9 100644 --- a/internal/ackhandler/ecn_test.go +++ b/internal/ackhandler/ecn_test.go @@ -71,6 +71,22 @@ var _ = Describe("ECN tracker", func() { ecnTracker.LostPacket(16) }) + It("only detects ECN mangling after sending all testing packets", func() { + tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + for i := 0; i < 9; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0) + ecnTracker.LostPacket(protocol.PacketNumber(i)) + } + // Send the last testing packet, and receive a + tracer.EXPECT().ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0)) + ecnTracker.SentPacket(9, protocol.ECT0) + // Now lose the last testing packet. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) + ecnTracker.LostPacket(9) + }) + It("passes ECN validation when a testing packet is acknowledged, while still in testing state", func() { tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) for i := 0; i < 5; i++ { From 4a046185b7d5275fd70504d2d630d10e731dbcc6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 18 Sep 2023 09:08:33 +0400 Subject: [PATCH 40/43] ackhandler: fix ECN mangling detection when packets are lost (#4089) Some of the 10 testing packets are might be lost, while others might be CE-marked. We need to detect mangling if all testing packets are either lost are CE-marked. --- internal/ackhandler/ecn.go | 24 +++++++++++++++++----- internal/ackhandler/ecn_test.go | 35 ++++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/internal/ackhandler/ecn.go b/internal/ackhandler/ecn.go index 3e851c8e..68415ac6 100644 --- a/internal/ackhandler/ecn.go +++ b/internal/ackhandler/ecn.go @@ -135,7 +135,10 @@ func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) } e.state = ecnStateFailed + return } + // Path validation also fails if some testing packets are lost, and all other testing packets where CE-marked + e.failIfMangled() } // HandleNewlyAcked handles the ECN counts on an ACK frame. @@ -229,12 +232,11 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // Detect mangling (a path remarking all ECN-marked testing packets as CE), // once all 10 testing packets have been sent out. - if e.state == ecnStateUnknown && e.numSentECT0+e.numSentECT1 == e.numAckedECNCE { - if e.tracer != nil && e.tracer.ECNStateUpdated != nil { - e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + if e.state == ecnStateUnknown { + e.failIfMangled() + if e.state == ecnStateFailed { + return false } - e.state = ecnStateFailed - return false } if e.state == ecnStateTesting || e.state == ecnStateUnknown { var ackedTestingPacket bool @@ -259,6 +261,18 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 return e.state == ecnStateCapable && newECNCE > 0 } +// failIfMangled fails ECN validation if all testing packets are lost or CE-marked. +func (e *ecnTracker) failIfMangled() { + numAckedECNCE := e.numAckedECNCE + int64(e.numLostTesting) + if e.numSentECT0+e.numSentECT1 > numAckedECNCE { + return + } + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + } + e.state = ecnStateFailed +} + func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN { if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber { return protocol.ECNNon diff --git a/internal/ackhandler/ecn_test.go b/internal/ackhandler/ecn_test.go index 78eef8c9..ab648d81 100644 --- a/internal/ackhandler/ecn_test.go +++ b/internal/ackhandler/ecn_test.go @@ -192,7 +192,7 @@ var _ = Describe("ECN tracker", func() { Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 15), 3, 0, 2)).To(BeFalse()) }) - It("detects ECN mangling", func() { + It("detects ECN mangling if all testing packets are marked CE", func() { sendAllTestingPackets() for i := 10; i < 20; i++ { Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) @@ -222,6 +222,39 @@ var _ = Describe("ECN tracker", func() { Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(9), 0, 0, 10)).To(BeFalse()) }) + It("detects ECN mangling, if some testing packets are marked CE, and then others are lost", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // ECN capability not confirmed yet, therefore CE marks are not regarded as congestion events + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3), 0, 0, 4)).To(BeFalse()) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(6, 7, 8, 9), 0, 0, 8)).To(BeFalse()) + // Lose one of the two unacknowledged packets. + ecnTracker.LostPacket(4) + // By losing the last unacknowledged testing packets, we should detect the mangling. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + ecnTracker.LostPacket(5) + }) + + It("detects ECN mangling, if some testing packets are lost, and then others are marked CE", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // Lose a few packets. + ecnTracker.LostPacket(0) + ecnTracker.LostPacket(1) + ecnTracker.LostPacket(2) + // ECN capability not confirmed yet, therefore CE marks are not regarded as congestion events + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(3, 4, 5, 6, 7, 8), 0, 0, 6)).To(BeFalse()) + // By CE-marking the last unacknowledged testing packets, we should detect the mangling. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(9), 0, 0, 7)).To(BeFalse()) + }) + It("declares congestion", func() { sendAllTestingPackets() for i := 10; i < 20; i++ { From 4bdff39ff0ce859f3f9a8f7bddc8a389add6c0f0 Mon Sep 17 00:00:00 2001 From: Toby Date: Sun, 24 Sep 2023 04:01:03 -0700 Subject: [PATCH 41/43] README: add Hysteria (#4085) * chore: add my project * fix order --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a793f984..976a07cc 100644 --- a/README.md +++ b/README.md @@ -200,12 +200,13 @@ http.Client{ ## Projects using quic-go | Project | Description | Stars | -|-----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------| +| --------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- | | [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) | | [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | | [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | | [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | | [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) | +| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) | | [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) | | [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | | [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | From 9a397abc177badb035e63b4d9321fc287430ee2d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 24 Sep 2023 11:38:28 +0000 Subject: [PATCH 42/43] update gomock to v0.3.0 (#4087) --- go.mod | 4 +- go.sum | 8 +-- http3/mock_quic_early_listener_test.go | 8 ++- http3/mock_roundtripcloser_test.go | 8 ++- integrationtests/gomodvendor/go.sum | 9 ++- internal/ackhandler/mock_ecn_handler_test.go | 12 ++-- .../mock_sent_packet_tracker_test.go | 8 ++- .../ackhandler/received_packet_handler.go | 14 +++-- .../mocks/ackhandler/sent_packet_handler.go | 28 +++++---- internal/mocks/congestion.go | 22 ++++--- internal/mocks/connection_flow_controller.go | 12 ++-- internal/mocks/crypto_setup.go | 12 ++-- .../logging/internal/connection_tracer.go | 60 ++++++++++--------- internal/mocks/logging/internal/tracer.go | 12 ++-- internal/mocks/long_header_opener.go | 12 ++-- internal/mocks/quic/early_conn.go | 20 ++++--- internal/mocks/quic/stream.go | 20 ++++--- internal/mocks/short_header_opener.go | 12 ++-- internal/mocks/short_header_sealer.go | 10 +++- internal/mocks/stream_flow_controller.go | 14 +++-- internal/mocks/tls/client_session_cache.go | 10 +++- mock_ack_frame_source_test.go | 8 ++- mock_batch_conn_test.go | 8 ++- mock_conn_runner_test.go | 20 ++++--- mock_crypto_data_handler_test.go | 8 ++- mock_crypto_stream_test.go | 12 ++-- mock_frame_source_test.go | 10 +++- mock_mtu_discoverer_test.go | 10 +++- mock_packer_test.go | 22 ++++--- mock_packet_handler_manager_test.go | 28 +++++---- mock_packet_handler_test.go | 10 +++- mock_packetconn_test.go | 16 +++-- mock_quic_conn_test.go | 24 ++++---- mock_raw_conn_test.go | 10 +++- mock_receive_stream_internal_test.go | 18 +++--- mock_sealing_manager_test.go | 6 +- mock_send_conn_test.go | 8 ++- mock_send_stream_internal_test.go | 20 ++++--- mock_sender_test.go | 8 ++- mock_stream_getter_test.go | 10 +++- mock_stream_internal_test.go | 32 +++++----- mock_stream_manager_test.go | 26 ++++---- mock_stream_sender_test.go | 12 ++-- mock_token_store_test.go | 10 +++- mock_unpacker_test.go | 10 +++- server_test.go | 2 +- 46 files changed, 415 insertions(+), 248 deletions(-) diff --git a/go.mod b/go.mod index 936a3786..0a563463 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/onsi/gomega v1.27.6 github.com/quic-go/qpack v0.4.0 github.com/quic-go/qtls-go1-20 v0.3.4 - go.uber.org/mock v0.2.0 + go.uber.org/mock v0.3.0 golang.org/x/crypto v0.4.0 golang.org/x/exp v0.0.0-20221205204356-47842c84f3db golang.org/x/net v0.10.0 @@ -21,7 +21,7 @@ require ( github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect - golang.org/x/mod v0.10.0 // indirect + golang.org/x/mod v0.11.0 // indirect golang.org/x/text v0.9.0 // indirect golang.org/x/tools v0.9.1 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect diff --git a/go.sum b/go.sum index 8a3bc655..62b640db 100644 --- a/go.sum +++ b/go.sum @@ -124,8 +124,8 @@ github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cb github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= -go.uber.org/mock v0.2.0 h1:TaP3xedm7JaAgScZO7tlvlKrqT0p7I6OsdGB5YNSMDU= -go.uber.org/mock v0.2.0/go.mod h1:J0y0rp9L3xiff1+ZBfKxlC1fz2+aO16tw0tsDOixfuM= +go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= +go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -139,8 +139,8 @@ golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZ golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= -golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= +golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/http3/mock_quic_early_listener_test.go b/http3/mock_quic_early_listener_test.go index b79e9858..c6977bea 100644 --- a/http3/mock_quic_early_listener_test.go +++ b/http3/mock_quic_early_listener_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/http3 (interfaces: QUICEarlyListener) - +// +// Generated by this command: +// +// mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener +// // Package http3 is a generated GoMock package. package http3 @@ -46,7 +50,7 @@ func (m *MockQUICEarlyListener) Accept(arg0 context.Context) (quic.EarlyConnecti } // Accept indicates an expected call of Accept. -func (mr *MockQUICEarlyListenerMockRecorder) Accept(arg0 interface{}) *gomock.Call { +func (mr *MockQUICEarlyListenerMockRecorder) Accept(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockQUICEarlyListener)(nil).Accept), arg0) } diff --git a/http3/mock_roundtripcloser_test.go b/http3/mock_roundtripcloser_test.go index ce7f0d19..2aa3558d 100644 --- a/http3/mock_roundtripcloser_test.go +++ b/http3/mock_roundtripcloser_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/http3 (interfaces: RoundTripCloser) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser +// // Package http3 is a generated GoMock package. package http3 @@ -72,7 +76,7 @@ func (m *MockRoundTripCloser) RoundTripOpt(arg0 *http.Request, arg1 RoundTripOpt } // RoundTripOpt indicates an expected call of RoundTripOpt. -func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockRoundTripCloser)(nil).RoundTripOpt), arg0, arg1) } diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index 5e207944..29427a8c 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -174,8 +174,8 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= -go.uber.org/mock v0.2.0 h1:TaP3xedm7JaAgScZO7tlvlKrqT0p7I6OsdGB5YNSMDU= -go.uber.org/mock v0.2.0/go.mod h1:J0y0rp9L3xiff1+ZBfKxlC1fz2+aO16tw0tsDOixfuM= +go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= +go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -194,15 +194,15 @@ golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTk golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= +golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -305,7 +305,6 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.8/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= diff --git a/internal/ackhandler/mock_ecn_handler_test.go b/internal/ackhandler/mock_ecn_handler_test.go index 28f350c1..194ab891 100644 --- a/internal/ackhandler/mock_ecn_handler_test.go +++ b/internal/ackhandler/mock_ecn_handler_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/ackhandler (interfaces: ECNHandler) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler +// // Package ackhandler is a generated GoMock package. package ackhandler @@ -43,7 +47,7 @@ func (m *MockECNHandler) HandleNewlyAcked(arg0 []*packet, arg1, arg2, arg3 int64 } // HandleNewlyAcked indicates an expected call of HandleNewlyAcked. -func (mr *MockECNHandlerMockRecorder) HandleNewlyAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockECNHandlerMockRecorder) HandleNewlyAcked(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleNewlyAcked", reflect.TypeOf((*MockECNHandler)(nil).HandleNewlyAcked), arg0, arg1, arg2, arg3) } @@ -55,7 +59,7 @@ func (m *MockECNHandler) LostPacket(arg0 protocol.PacketNumber) { } // LostPacket indicates an expected call of LostPacket. -func (mr *MockECNHandlerMockRecorder) LostPacket(arg0 interface{}) *gomock.Call { +func (mr *MockECNHandlerMockRecorder) LostPacket(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockECNHandler)(nil).LostPacket), arg0) } @@ -81,7 +85,7 @@ func (m *MockECNHandler) SentPacket(arg0 protocol.PacketNumber, arg1 protocol.EC } // SentPacket indicates an expected call of SentPacket. -func (mr *MockECNHandlerMockRecorder) SentPacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockECNHandlerMockRecorder) SentPacket(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockECNHandler)(nil).SentPacket), arg0, arg1) } diff --git a/internal/ackhandler/mock_sent_packet_tracker_test.go b/internal/ackhandler/mock_sent_packet_tracker_test.go index 63a9a163..aa8c8674 100644 --- a/internal/ackhandler/mock_sent_packet_tracker_test.go +++ b/internal/ackhandler/mock_sent_packet_tracker_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/ackhandler (interfaces: SentPacketTracker) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker +// // Package ackhandler is a generated GoMock package. package ackhandler @@ -55,7 +59,7 @@ func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel) { } // ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0) } diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index b068586e..6455ec1c 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/ackhandler (interfaces: ReceivedPacketHandler) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler ReceivedPacketHandler +// // Package mockackhandler is a generated GoMock package. package mockackhandler @@ -43,7 +47,7 @@ func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { } // DropPackets indicates an expected call of DropPackets. -func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0) } @@ -57,7 +61,7 @@ func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel, a } // GetAckFrame indicates an expected call of GetAckFrame. -func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0, arg1) } @@ -85,7 +89,7 @@ func (m *MockReceivedPacketHandler) IsPotentiallyDuplicate(arg0 protocol.PacketN } // IsPotentiallyDuplicate indicates an expected call of IsPotentiallyDuplicate. -func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPotentiallyDuplicate", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IsPotentiallyDuplicate), arg0, arg1) } @@ -99,7 +103,7 @@ func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, a } // ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3, arg4) } diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 137ac2ef..6cbde5dc 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/ackhandler (interfaces: SentPacketHandler) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler SentPacketHandler +// // Package mockackhandler is a generated GoMock package. package mockackhandler @@ -44,7 +48,7 @@ func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { } // DropPackets indicates an expected call of DropPackets. -func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) } @@ -58,7 +62,7 @@ func (m *MockSentPacketHandler) ECNMode(arg0 bool) protocol.ECN { } // ECNMode indicates an expected call of ECNMode. -func (mr *MockSentPacketHandlerMockRecorder) ECNMode(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ECNMode(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode), arg0) } @@ -101,7 +105,7 @@ func (m *MockSentPacketHandler) PeekPacketNumber(arg0 protocol.EncryptionLevel) } // PeekPacketNumber indicates an expected call of PeekPacketNumber. -func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber), arg0) } @@ -115,7 +119,7 @@ func (m *MockSentPacketHandler) PopPacketNumber(arg0 protocol.EncryptionLevel) p } // PopPacketNumber indicates an expected call of PopPacketNumber. -func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0) } @@ -129,7 +133,7 @@ func (m *MockSentPacketHandler) QueueProbePacket(arg0 protocol.EncryptionLevel) } // QueueProbePacket indicates an expected call of QueueProbePacket. -func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) } @@ -144,7 +148,7 @@ func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.E } // ReceivedAck indicates an expected call of ReceivedAck. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) } @@ -156,7 +160,7 @@ func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) { } // ReceivedBytes indicates an expected call of ReceivedBytes. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) } @@ -170,7 +174,7 @@ func (m *MockSentPacketHandler) ResetForRetry(arg0 time.Time) error { } // ResetForRetry indicates an expected call of ResetForRetry. -func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry), arg0) } @@ -184,7 +188,7 @@ func (m *MockSentPacketHandler) SendMode(arg0 time.Time) ackhandler.SendMode { } // SendMode indicates an expected call of SendMode. -func (mr *MockSentPacketHandlerMockRecorder) SendMode(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SendMode(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMode", reflect.TypeOf((*MockSentPacketHandler)(nil).SendMode), arg0) } @@ -196,7 +200,7 @@ func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.P } // SentPacket indicates an expected call of SentPacket. -func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } @@ -220,7 +224,7 @@ func (m *MockSentPacketHandler) SetMaxDatagramSize(arg0 protocol.ByteCount) { } // SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. -func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSentPacketHandler)(nil).SetMaxDatagramSize), arg0) } diff --git a/internal/mocks/congestion.go b/internal/mocks/congestion.go index 4d38d1db..9a96239c 100644 --- a/internal/mocks/congestion.go +++ b/internal/mocks/congestion.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/congestion (interfaces: SendAlgorithmWithDebugInfos) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination congestion.go github.com/quic-go/quic-go/internal/congestion SendAlgorithmWithDebugInfos +// // Package mocks is a generated GoMock package. package mocks @@ -44,7 +48,7 @@ func (m *MockSendAlgorithmWithDebugInfos) CanSend(arg0 protocol.ByteCount) bool } // CanSend indicates an expected call of CanSend. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(arg0 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).CanSend), arg0) } @@ -72,7 +76,7 @@ func (m *MockSendAlgorithmWithDebugInfos) HasPacingBudget(arg0 time.Time) bool { } // HasPacingBudget indicates an expected call of HasPacingBudget. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget(arg0 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).HasPacingBudget), arg0) } @@ -124,7 +128,7 @@ func (m *MockSendAlgorithmWithDebugInfos) OnCongestionEvent(arg0 protocol.Packet } // OnCongestionEvent indicates an expected call of OnCongestionEvent. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnCongestionEvent(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnCongestionEvent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnCongestionEvent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnCongestionEvent), arg0, arg1, arg2) } @@ -136,7 +140,7 @@ func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(arg0 protocol.PacketNumb } // OnPacketAcked indicates an expected call of OnPacketAcked. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) } @@ -148,7 +152,7 @@ func (m *MockSendAlgorithmWithDebugInfos) OnPacketSent(arg0 time.Time, arg1 prot } // OnPacketSent indicates an expected call of OnPacketSent. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketSent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketSent), arg0, arg1, arg2, arg3, arg4) } @@ -160,7 +164,7 @@ func (m *MockSendAlgorithmWithDebugInfos) OnRetransmissionTimeout(arg0 bool) { } // OnRetransmissionTimeout indicates an expected call of OnRetransmissionTimeout. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(arg0 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnRetransmissionTimeout), arg0) } @@ -172,7 +176,7 @@ func (m *MockSendAlgorithmWithDebugInfos) SetMaxDatagramSize(arg0 protocol.ByteC } // SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).SetMaxDatagramSize), arg0) } @@ -186,7 +190,7 @@ func (m *MockSendAlgorithmWithDebugInfos) TimeUntilSend(arg0 protocol.ByteCount) } // TimeUntilSend indicates an expected call of TimeUntilSend. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(arg0 interface{}) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).TimeUntilSend), arg0) } diff --git a/internal/mocks/connection_flow_controller.go b/internal/mocks/connection_flow_controller.go index 1ec76f9e..7407054c 100644 --- a/internal/mocks/connection_flow_controller.go +++ b/internal/mocks/connection_flow_controller.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/flowcontrol (interfaces: ConnectionFlowController) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination connection_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol ConnectionFlowController +// // Package mocks is a generated GoMock package. package mocks @@ -41,7 +45,7 @@ func (m *MockConnectionFlowController) AddBytesRead(arg0 protocol.ByteCount) { } // AddBytesRead indicates an expected call of AddBytesRead. -func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesRead), arg0) } @@ -53,7 +57,7 @@ func (m *MockConnectionFlowController) AddBytesSent(arg0 protocol.ByteCount) { } // AddBytesSent indicates an expected call of AddBytesSent. -func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesSent), arg0) } @@ -122,7 +126,7 @@ func (m *MockConnectionFlowController) UpdateSendWindow(arg0 protocol.ByteCount) } // UpdateSendWindow indicates an expected call of UpdateSendWindow. -func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockConnectionFlowController)(nil).UpdateSendWindow), arg0) } diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index bf496c22..1e4278ec 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/handshake (interfaces: CryptoSetup) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination crypto_setup_tmp.go github.com/quic-go/quic-go/internal/handshake CryptoSetup +// // Package mocks is a generated GoMock package. package mocks @@ -42,7 +46,7 @@ func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) { } // ChangeConnectionID indicates an expected call of ChangeConnectionID. -func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 interface{}) *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeConnectionID", reflect.TypeOf((*MockCryptoSetup)(nil).ChangeConnectionID), arg0) } @@ -231,7 +235,7 @@ func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLev } // HandleMessage indicates an expected call of HandleMessage. -func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) } @@ -271,7 +275,7 @@ func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error } // SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked. -func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 interface{}) *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0) } diff --git a/internal/mocks/logging/internal/connection_tracer.go b/internal/mocks/logging/internal/connection_tracer.go index 1cc0bd2e..5131453a 100644 --- a/internal/mocks/logging/internal/connection_tracer.go +++ b/internal/mocks/logging/internal/connection_tracer.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/mocks/logging (interfaces: ConnectionTracer) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package internal -destination internal/connection_tracer.go github.com/quic-go/quic-go/internal/mocks/logging ConnectionTracer +// // Package internal is a generated GoMock package. package internal @@ -46,7 +50,7 @@ func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, } // AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) } @@ -58,7 +62,7 @@ func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType, arg1 prot } // BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0, arg1) } @@ -82,7 +86,7 @@ func (m *MockConnectionTracer) ClosedConnection(arg0 error) { } // ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) } @@ -94,7 +98,7 @@ func (m *MockConnectionTracer) Debug(arg0, arg1 string) { } // Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) } @@ -106,7 +110,7 @@ func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLe } // DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) } @@ -118,7 +122,7 @@ func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { } // DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) } @@ -130,7 +134,7 @@ func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 proto } // DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) } @@ -142,7 +146,7 @@ func (m *MockConnectionTracer) ECNStateUpdated(arg0 logging.ECNState, arg1 loggi } // ECNStateUpdated indicates an expected call of ECNStateUpdated. -func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) } @@ -166,7 +170,7 @@ func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 pro } // LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) } @@ -178,7 +182,7 @@ func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 pr } // LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) } @@ -190,7 +194,7 @@ func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, ar } // NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) } @@ -202,7 +206,7 @@ func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeade } // ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) } @@ -214,7 +218,7 @@ func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { } // ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) } @@ -226,7 +230,7 @@ func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHead } // ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) } @@ -238,7 +242,7 @@ func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportP } // ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) } @@ -250,7 +254,7 @@ func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 proto } // ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) } @@ -262,7 +266,7 @@ func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportP } // RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) } @@ -274,7 +278,7 @@ func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, a } // SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) } @@ -286,7 +290,7 @@ func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, } // SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) } @@ -298,7 +302,7 @@ func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParam } // SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) } @@ -310,7 +314,7 @@ func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protoco } // SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) } @@ -322,7 +326,7 @@ func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 } // StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) } @@ -334,7 +338,7 @@ func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionSta } // UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) } @@ -346,7 +350,7 @@ func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { } // UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) } @@ -358,7 +362,7 @@ func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, } // UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) } @@ -370,7 +374,7 @@ func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 p } // UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) } @@ -382,7 +386,7 @@ func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { } // UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) } diff --git a/internal/mocks/logging/internal/tracer.go b/internal/mocks/logging/internal/tracer.go index 15dbcaed..14ce172c 100644 --- a/internal/mocks/logging/internal/tracer.go +++ b/internal/mocks/logging/internal/tracer.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/mocks/logging (interfaces: Tracer) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package internal -destination internal/tracer.go github.com/quic-go/quic-go/internal/mocks/logging Tracer +// // Package internal is a generated GoMock package. package internal @@ -44,7 +48,7 @@ func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 } // DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) } @@ -56,7 +60,7 @@ func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol. } // SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) } @@ -68,7 +72,7 @@ func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 prot } // SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. -func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) } diff --git a/internal/mocks/long_header_opener.go b/internal/mocks/long_header_opener.go index fc284dd8..ad5d6b69 100644 --- a/internal/mocks/long_header_opener.go +++ b/internal/mocks/long_header_opener.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/handshake (interfaces: LongHeaderOpener) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination long_header_opener.go github.com/quic-go/quic-go/internal/handshake LongHeaderOpener +// // Package mocks is a generated GoMock package. package mocks @@ -43,7 +47,7 @@ func (m *MockLongHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, ar } // DecodePacketNumber indicates an expected call of DecodePacketNumber. -func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) } @@ -55,7 +59,7 @@ func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byt } // DecryptHeader indicates an expected call of DecryptHeader. -func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) } @@ -70,7 +74,7 @@ func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumbe } // Open indicates an expected call of Open. -func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3) } diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go index c4db88d2..205ed7f2 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/mocks/quic/early_conn.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: EarlyConnection) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mockquic -destination quic/early_conn_tmp.go github.com/quic-go/quic-go EarlyConnection +// // Package mockquic is a generated GoMock package. package mockquic @@ -47,7 +51,7 @@ func (m *MockEarlyConnection) AcceptStream(arg0 context.Context) (quic.Stream, e } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockEarlyConnectionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) AcceptStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptStream), arg0) } @@ -62,7 +66,7 @@ func (m *MockEarlyConnection) AcceptUniStream(arg0 context.Context) (quic.Receiv } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptUniStream), arg0) } @@ -76,7 +80,7 @@ func (m *MockEarlyConnection) CloseWithError(arg0 qerr.ApplicationErrorCode, arg } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockEarlyConnectionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) CloseWithError(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockEarlyConnection)(nil).CloseWithError), arg0, arg1) } @@ -176,7 +180,7 @@ func (m *MockEarlyConnection) OpenStreamSync(arg0 context.Context) (quic.Stream, } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockEarlyConnectionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStreamSync), arg0) } @@ -206,7 +210,7 @@ func (m *MockEarlyConnection) OpenUniStreamSync(arg0 context.Context) (quic.Send } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStreamSync), arg0) } @@ -221,7 +225,7 @@ func (m *MockEarlyConnection) ReceiveMessage(arg0 context.Context) ([]byte, erro } // ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage), arg0) } @@ -249,7 +253,7 @@ func (m *MockEarlyConnection) SendMessage(arg0 []byte) error { } // SendMessage indicates an expected call of SendMessage. -func (mr *MockEarlyConnectionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) SendMessage(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockEarlyConnection)(nil).SendMessage), arg0) } diff --git a/internal/mocks/quic/stream.go b/internal/mocks/quic/stream.go index 7fc65600..2043fdfb 100644 --- a/internal/mocks/quic/stream.go +++ b/internal/mocks/quic/stream.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: Stream) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mockquic -destination quic/stream.go github.com/quic-go/quic-go Stream +// // Package mockquic is a generated GoMock package. package mockquic @@ -44,7 +48,7 @@ func (m *MockStream) CancelRead(arg0 qerr.StreamErrorCode) { } // CancelRead indicates an expected call of CancelRead. -func (mr *MockStreamMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) CancelRead(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStream)(nil).CancelRead), arg0) } @@ -56,7 +60,7 @@ func (m *MockStream) CancelWrite(arg0 qerr.StreamErrorCode) { } // CancelWrite indicates an expected call of CancelWrite. -func (mr *MockStreamMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) CancelWrite(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStream)(nil).CancelWrite), arg0) } @@ -99,7 +103,7 @@ func (m *MockStream) Read(arg0 []byte) (int, error) { } // Read indicates an expected call of Read. -func (mr *MockStreamMockRecorder) Read(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) Read(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStream)(nil).Read), arg0) } @@ -113,7 +117,7 @@ func (m *MockStream) SetDeadline(arg0 time.Time) error { } // SetDeadline indicates an expected call of SetDeadline. -func (mr *MockStreamMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) SetDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStream)(nil).SetDeadline), arg0) } @@ -127,7 +131,7 @@ func (m *MockStream) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStream)(nil).SetReadDeadline), arg0) } @@ -141,7 +145,7 @@ func (m *MockStream) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStream)(nil).SetWriteDeadline), arg0) } @@ -170,7 +174,7 @@ func (m *MockStream) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockStreamMockRecorder) Write(arg0 interface{}) *gomock.Call { +func (mr *MockStreamMockRecorder) Write(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStream)(nil).Write), arg0) } diff --git a/internal/mocks/short_header_opener.go b/internal/mocks/short_header_opener.go index ffc8bd48..4b9f77f8 100644 --- a/internal/mocks/short_header_opener.go +++ b/internal/mocks/short_header_opener.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/handshake (interfaces: ShortHeaderOpener) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination short_header_opener.go github.com/quic-go/quic-go/internal/handshake ShortHeaderOpener +// // Package mocks is a generated GoMock package. package mocks @@ -44,7 +48,7 @@ func (m *MockShortHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, a } // DecodePacketNumber indicates an expected call of DecodePacketNumber. -func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) } @@ -56,7 +60,7 @@ func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []by } // DecryptHeader indicates an expected call of DecryptHeader. -func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) } @@ -71,7 +75,7 @@ func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 time.Time, arg3 pro } // Open indicates an expected call of Open. -func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { +func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4, arg5) } diff --git a/internal/mocks/short_header_sealer.go b/internal/mocks/short_header_sealer.go index d0109b5e..aabe6867 100644 --- a/internal/mocks/short_header_sealer.go +++ b/internal/mocks/short_header_sealer.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/handshake (interfaces: ShortHeaderSealer) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination short_header_sealer.go github.com/quic-go/quic-go/internal/handshake ShortHeaderSealer +// // Package mocks is a generated GoMock package. package mocks @@ -41,7 +45,7 @@ func (m *MockShortHeaderSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []by } // EncryptHeader indicates an expected call of EncryptHeader. -func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), arg0, arg1, arg2) } @@ -83,7 +87,7 @@ func (m *MockShortHeaderSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumb } // Seal indicates an expected call of Seal. -func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), arg0, arg1, arg2, arg3) } diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go index 05eaf8b9..342661c9 100644 --- a/internal/mocks/stream_flow_controller.go +++ b/internal/mocks/stream_flow_controller.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/flowcontrol (interfaces: StreamFlowController) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocks -destination stream_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol StreamFlowController +// // Package mocks is a generated GoMock package. package mocks @@ -53,7 +57,7 @@ func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) { } // AddBytesRead indicates an expected call of AddBytesRead. -func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) } @@ -65,7 +69,7 @@ func (m *MockStreamFlowController) AddBytesSent(arg0 protocol.ByteCount) { } // AddBytesSent indicates an expected call of AddBytesSent. -func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) } @@ -122,7 +126,7 @@ func (m *MockStreamFlowController) UpdateHighestReceived(arg0 protocol.ByteCount } // UpdateHighestReceived indicates an expected call of UpdateHighestReceived. -func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), arg0, arg1) } @@ -134,7 +138,7 @@ func (m *MockStreamFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { } // UpdateSendWindow indicates an expected call of UpdateSendWindow. -func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) } diff --git a/internal/mocks/tls/client_session_cache.go b/internal/mocks/tls/client_session_cache.go index 6877f9bb..8837755f 100644 --- a/internal/mocks/tls/client_session_cache.go +++ b/internal/mocks/tls/client_session_cache.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: crypto/tls (interfaces: ClientSessionCache) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache +// // Package mocktls is a generated GoMock package. package mocktls @@ -44,7 +48,7 @@ func (m *MockClientSessionCache) Get(arg0 string) (*tls.ClientSessionState, bool } // Get indicates an expected call of Get. -func (mr *MockClientSessionCacheMockRecorder) Get(arg0 interface{}) *gomock.Call { +func (mr *MockClientSessionCacheMockRecorder) Get(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockClientSessionCache)(nil).Get), arg0) } @@ -56,7 +60,7 @@ func (m *MockClientSessionCache) Put(arg0 string, arg1 *tls.ClientSessionState) } // Put indicates an expected call of Put. -func (mr *MockClientSessionCacheMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockClientSessionCacheMockRecorder) Put(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockClientSessionCache)(nil).Put), arg0, arg1) } diff --git a/mock_ack_frame_source_test.go b/mock_ack_frame_source_test.go index 54105f9d..4990e472 100644 --- a/mock_ack_frame_source_test.go +++ b/mock_ack_frame_source_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: AckFrameSource) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource +// // Package quic is a generated GoMock package. package quic @@ -44,7 +48,7 @@ func (m *MockAckFrameSource) GetAckFrame(arg0 protocol.EncryptionLevel, arg1 boo } // GetAckFrame indicates an expected call of GetAckFrame. -func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame), arg0, arg1) } diff --git a/mock_batch_conn_test.go b/mock_batch_conn_test.go index 9621e7b4..1ecc4265 100644 --- a/mock_batch_conn_test.go +++ b/mock_batch_conn_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: sys_conn_oob.go - +// +// Generated by this command: +// +// mockgen -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn +// // Package quic is a generated GoMock package. package quic @@ -44,7 +48,7 @@ func (m *MockBatchConn) ReadBatch(ms []ipv4.Message, flags int) (int, error) { } // ReadBatch indicates an expected call of ReadBatch. -func (mr *MockBatchConnMockRecorder) ReadBatch(ms, flags interface{}) *gomock.Call { +func (mr *MockBatchConnMockRecorder) ReadBatch(ms, flags any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadBatch", reflect.TypeOf((*MockBatchConn)(nil).ReadBatch), ms, flags) } diff --git a/mock_conn_runner_test.go b/mock_conn_runner_test.go index e404c283..37c0f6cd 100644 --- a/mock_conn_runner_test.go +++ b/mock_conn_runner_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: ConnRunner) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_conn_runner_test.go github.com/quic-go/quic-go ConnRunner +// // Package quic is a generated GoMock package. package quic @@ -43,7 +47,7 @@ func (m *MockConnRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) boo } // Add indicates an expected call of Add. -func (mr *MockConnRunnerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Add(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockConnRunner)(nil).Add), arg0, arg1) } @@ -55,7 +59,7 @@ func (m *MockConnRunner) AddResetToken(arg0 protocol.StatelessResetToken, arg1 p } // AddResetToken indicates an expected call of AddResetToken. -func (mr *MockConnRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) AddResetToken(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockConnRunner)(nil).AddResetToken), arg0, arg1) } @@ -69,7 +73,7 @@ func (m *MockConnRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) prot } // GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockConnRunner)(nil).GetStatelessResetToken), arg0) } @@ -81,7 +85,7 @@ func (m *MockConnRunner) Remove(arg0 protocol.ConnectionID) { } // Remove indicates an expected call of Remove. -func (mr *MockConnRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Remove(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockConnRunner)(nil).Remove), arg0) } @@ -93,7 +97,7 @@ func (m *MockConnRunner) RemoveResetToken(arg0 protocol.StatelessResetToken) { } // RemoveResetToken indicates an expected call of RemoveResetToken. -func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockConnRunner)(nil).RemoveResetToken), arg0) } @@ -105,7 +109,7 @@ func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 pr } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1, arg2) } @@ -117,7 +121,7 @@ func (m *MockConnRunner) Retire(arg0 protocol.ConnectionID) { } // Retire indicates an expected call of Retire. -func (mr *MockConnRunnerMockRecorder) Retire(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Retire(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockConnRunner)(nil).Retire), arg0) } diff --git a/mock_crypto_data_handler_test.go b/mock_crypto_data_handler_test.go index 22e6eaa2..3240b981 100644 --- a/mock_crypto_data_handler_test.go +++ b/mock_crypto_data_handler_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: CryptoDataHandler) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler +// // Package quic is a generated GoMock package. package quic @@ -44,7 +48,7 @@ func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.Encrypt } // HandleMessage indicates an expected call of HandleMessage. -func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1) } diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index 253e95e8..176b51f3 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: CryptoStream) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_stream_test.go github.com/quic-go/quic-go CryptoStream +// // Package quic is a generated GoMock package. package quic @@ -72,7 +76,7 @@ func (m *MockCryptoStream) HandleCryptoFrame(arg0 *wire.CryptoFrame) error { } // HandleCryptoFrame indicates an expected call of HandleCryptoFrame. -func (mr *MockCryptoStreamMockRecorder) HandleCryptoFrame(arg0 interface{}) *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) HandleCryptoFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).HandleCryptoFrame), arg0) } @@ -100,7 +104,7 @@ func (m *MockCryptoStream) PopCryptoFrame(arg0 protocol.ByteCount) *wire.CryptoF } // PopCryptoFrame indicates an expected call of PopCryptoFrame. -func (mr *MockCryptoStreamMockRecorder) PopCryptoFrame(arg0 interface{}) *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) PopCryptoFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).PopCryptoFrame), arg0) } @@ -115,7 +119,7 @@ func (m *MockCryptoStream) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockCryptoStreamMockRecorder) Write(arg0 interface{}) *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) Write(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), arg0) } diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index 040c7581..b3b6638b 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: FrameSource) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource +// // Package quic is a generated GoMock package. package quic @@ -45,7 +49,7 @@ func (m *MockFrameSource) AppendControlFrames(arg0 []ackhandler.Frame, arg1 prot } // AppendControlFrames indicates an expected call of AppendControlFrames. -func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendControlFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendControlFrames), arg0, arg1, arg2) } @@ -60,7 +64,7 @@ func (m *MockFrameSource) AppendStreamFrames(arg0 []ackhandler.StreamFrame, arg1 } // AppendStreamFrames indicates an expected call of AppendStreamFrames. -func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendStreamFrames), arg0, arg1, arg2) } diff --git a/mock_mtu_discoverer_test.go b/mock_mtu_discoverer_test.go index 1f5432cd..a8a9be3e 100644 --- a/mock_mtu_discoverer_test.go +++ b/mock_mtu_discoverer_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: MTUDiscoverer) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_mtu_discoverer_test.go github.com/quic-go/quic-go MTUDiscoverer +// // Package quic is a generated GoMock package. package quic @@ -74,7 +78,7 @@ func (m *MockMTUDiscoverer) ShouldSendProbe(arg0 time.Time) bool { } // ShouldSendProbe indicates an expected call of ShouldSendProbe. -func (mr *MockMTUDiscovererMockRecorder) ShouldSendProbe(arg0 interface{}) *gomock.Call { +func (mr *MockMTUDiscovererMockRecorder) ShouldSendProbe(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendProbe", reflect.TypeOf((*MockMTUDiscoverer)(nil).ShouldSendProbe), arg0) } @@ -86,7 +90,7 @@ func (m *MockMTUDiscoverer) Start(arg0 protocol.ByteCount) { } // Start indicates an expected call of Start. -func (mr *MockMTUDiscovererMockRecorder) Start(arg0 interface{}) *gomock.Call { +func (mr *MockMTUDiscovererMockRecorder) Start(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockMTUDiscoverer)(nil).Start), arg0) } diff --git a/mock_packer_test.go b/mock_packer_test.go index 4eedab98..26fc63a0 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: Packer) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packer_test.go github.com/quic-go/quic-go Packer +// // Package quic is a generated GoMock package. package quic @@ -46,7 +50,7 @@ func (m *MockPacker) AppendPacket(arg0 *packetBuffer, arg1 protocol.ByteCount, a } // AppendPacket indicates an expected call of AppendPacket. -func (mr *MockPackerMockRecorder) AppendPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) AppendPacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendPacket", reflect.TypeOf((*MockPacker)(nil).AppendPacket), arg0, arg1, arg2) } @@ -61,7 +65,7 @@ func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel, arg1 pr } // MaybePackProbePacket indicates an expected call of MaybePackProbePacket. -func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0, arg1, arg2) } @@ -77,7 +81,7 @@ func (m *MockPacker) PackAckOnlyPacket(arg0 protocol.ByteCount, arg1 protocol.Ve } // PackAckOnlyPacket indicates an expected call of PackAckOnlyPacket. -func (mr *MockPackerMockRecorder) PackAckOnlyPacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackAckOnlyPacket(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackAckOnlyPacket", reflect.TypeOf((*MockPacker)(nil).PackAckOnlyPacket), arg0, arg1) } @@ -92,7 +96,7 @@ func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError, arg1 prot } // PackApplicationClose indicates an expected call of PackApplicationClose. -func (mr *MockPackerMockRecorder) PackApplicationClose(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackApplicationClose(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackApplicationClose", reflect.TypeOf((*MockPacker)(nil).PackApplicationClose), arg0, arg1, arg2) } @@ -107,7 +111,7 @@ func (m *MockPacker) PackCoalescedPacket(arg0 bool, arg1 protocol.ByteCount, arg } // PackCoalescedPacket indicates an expected call of PackCoalescedPacket. -func (mr *MockPackerMockRecorder) PackCoalescedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackCoalescedPacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket), arg0, arg1, arg2) } @@ -122,7 +126,7 @@ func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError, arg1 protoco } // PackConnectionClose indicates an expected call of PackConnectionClose. -func (mr *MockPackerMockRecorder) PackConnectionClose(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackConnectionClose(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0, arg1, arg2) } @@ -138,7 +142,7 @@ func (m *MockPacker) PackMTUProbePacket(arg0 ackhandler.Frame, arg1 protocol.Byt } // PackMTUProbePacket indicates an expected call of PackMTUProbePacket. -func (mr *MockPackerMockRecorder) PackMTUProbePacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackMTUProbePacket(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), arg0, arg1, arg2) } @@ -150,7 +154,7 @@ func (m *MockPacker) SetToken(arg0 []byte) { } // SetToken indicates an expected call of SetToken. -func (mr *MockPackerMockRecorder) SetToken(arg0 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) SetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetToken", reflect.TypeOf((*MockPacker)(nil).SetToken), arg0) } diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index c948d9d5..e8c57416 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: PacketHandlerManager) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager +// // Package quic is a generated GoMock package. package quic @@ -43,7 +47,7 @@ func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHa } // Add indicates an expected call of Add. -func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) } @@ -55,7 +59,7 @@ func (m *MockPacketHandlerManager) AddResetToken(arg0 protocol.StatelessResetTok } // AddResetToken indicates an expected call of AddResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) } @@ -69,7 +73,7 @@ func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionI } // AddWithConnID indicates an expected call of AddWithConnID. -func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2) } @@ -81,7 +85,7 @@ func (m *MockPacketHandlerManager) Close(arg0 error) { } // Close indicates an expected call of Close. -func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0) } @@ -108,7 +112,7 @@ func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandle } // Get indicates an expected call of Get. -func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0) } @@ -123,7 +127,7 @@ func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetT } // GetByResetToken indicates an expected call of GetByResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0) } @@ -137,7 +141,7 @@ func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.Connecti } // GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) } @@ -149,7 +153,7 @@ func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { } // Remove indicates an expected call of Remove. -func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) } @@ -161,7 +165,7 @@ func (m *MockPacketHandlerManager) RemoveResetToken(arg0 protocol.StatelessReset } // RemoveResetToken indicates an expected call of RemoveResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) } @@ -173,7 +177,7 @@ func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionI } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1, arg2) } @@ -185,7 +189,7 @@ func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) { } // Retire indicates an expected call of Retire. -func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0) } diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go index e8490589..f30e8f07 100644 --- a/mock_packet_handler_test.go +++ b/mock_packet_handler_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: PacketHandler) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler +// // Package quic is a generated GoMock package. package quic @@ -41,7 +45,7 @@ func (m *MockPacketHandler) destroy(arg0 error) { } // destroy indicates an expected call of destroy. -func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerMockRecorder) destroy(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0) } @@ -67,7 +71,7 @@ func (m *MockPacketHandler) handlePacket(arg0 receivedPacket) { } // handlePacket indicates an expected call of handlePacket. -func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { +func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0) } diff --git a/mock_packetconn_test.go b/mock_packetconn_test.go index c8e20bf2..f148bb32 100644 --- a/mock_packetconn_test.go +++ b/mock_packetconn_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: net (interfaces: PacketConn) - +// +// Generated by this command: +// +// mockgen -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn +// // Package quic is a generated GoMock package. package quic @@ -74,7 +78,7 @@ func (m *MockPacketConn) ReadFrom(arg0 []byte) (int, net.Addr, error) { } // ReadFrom indicates an expected call of ReadFrom. -func (mr *MockPacketConnMockRecorder) ReadFrom(arg0 interface{}) *gomock.Call { +func (mr *MockPacketConnMockRecorder) ReadFrom(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFrom", reflect.TypeOf((*MockPacketConn)(nil).ReadFrom), arg0) } @@ -88,7 +92,7 @@ func (m *MockPacketConn) SetDeadline(arg0 time.Time) error { } // SetDeadline indicates an expected call of SetDeadline. -func (mr *MockPacketConnMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockPacketConnMockRecorder) SetDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetDeadline), arg0) } @@ -102,7 +106,7 @@ func (m *MockPacketConn) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockPacketConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockPacketConnMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetReadDeadline), arg0) } @@ -116,7 +120,7 @@ func (m *MockPacketConn) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockPacketConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockPacketConnMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetWriteDeadline), arg0) } @@ -131,7 +135,7 @@ func (m *MockPacketConn) WriteTo(arg0 []byte, arg1 net.Addr) (int, error) { } // WriteTo indicates an expected call of WriteTo. -func (mr *MockPacketConnMockRecorder) WriteTo(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPacketConnMockRecorder) WriteTo(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockPacketConn)(nil).WriteTo), arg0, arg1) } diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index 20015c66..d30a939d 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: QUICConn) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_quic_conn_test.go github.com/quic-go/quic-go QUICConn +// // Package quic is a generated GoMock package. package quic @@ -47,7 +51,7 @@ func (m *MockQUICConn) AcceptStream(arg0 context.Context) (Stream, error) { } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockQUICConnMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) AcceptStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQUICConn)(nil).AcceptStream), arg0) } @@ -62,7 +66,7 @@ func (m *MockQUICConn) AcceptUniStream(arg0 context.Context) (ReceiveStream, err } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockQUICConnMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQUICConn)(nil).AcceptUniStream), arg0) } @@ -76,7 +80,7 @@ func (m *MockQUICConn) CloseWithError(arg0 qerr.ApplicationErrorCode, arg1 strin } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockQUICConnMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) CloseWithError(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQUICConn)(nil).CloseWithError), arg0, arg1) } @@ -190,7 +194,7 @@ func (m *MockQUICConn) OpenStreamSync(arg0 context.Context) (Stream, error) { } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockQUICConnMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQUICConn)(nil).OpenStreamSync), arg0) } @@ -220,7 +224,7 @@ func (m *MockQUICConn) OpenUniStreamSync(arg0 context.Context) (SendStream, erro } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockQUICConnMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQUICConn)(nil).OpenUniStreamSync), arg0) } @@ -235,7 +239,7 @@ func (m *MockQUICConn) ReceiveMessage(arg0 context.Context) ([]byte, error) { } // ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockQUICConnMockRecorder) ReceiveMessage(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) ReceiveMessage(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQUICConn)(nil).ReceiveMessage), arg0) } @@ -263,7 +267,7 @@ func (m *MockQUICConn) SendMessage(arg0 []byte) error { } // SendMessage indicates an expected call of SendMessage. -func (mr *MockQUICConnMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) SendMessage(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockQUICConn)(nil).SendMessage), arg0) } @@ -275,7 +279,7 @@ func (m *MockQUICConn) destroy(arg0 error) { } // destroy indicates an expected call of destroy. -func (mr *MockQUICConnMockRecorder) destroy(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) destroy(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQUICConn)(nil).destroy), arg0) } @@ -315,7 +319,7 @@ func (m *MockQUICConn) handlePacket(arg0 receivedPacket) { } // handlePacket indicates an expected call of handlePacket. -func (mr *MockQUICConnMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { +func (mr *MockQUICConnMockRecorder) handlePacket(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockQUICConn)(nil).handlePacket), arg0) } diff --git a/mock_raw_conn_test.go b/mock_raw_conn_test.go index d462fc87..bf4751d9 100644 --- a/mock_raw_conn_test.go +++ b/mock_raw_conn_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: RawConn) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_raw_conn_test.go github.com/quic-go/quic-go RawConn +// // Package quic is a generated GoMock package. package quic @@ -88,7 +92,7 @@ func (m *MockRawConn) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockRawConn)(nil).SetReadDeadline), arg0) } @@ -103,7 +107,7 @@ func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 } // WritePacket indicates an expected call of WritePacket. -func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3, arg4) } diff --git a/mock_receive_stream_internal_test.go b/mock_receive_stream_internal_test.go index 518e275b..74ef3fed 100644 --- a/mock_receive_stream_internal_test.go +++ b/mock_receive_stream_internal_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: ReceiveStreamI) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI +// // Package quic is a generated GoMock package. package quic @@ -44,7 +48,7 @@ func (m *MockReceiveStreamI) CancelRead(arg0 qerr.StreamErrorCode) { } // CancelRead indicates an expected call of CancelRead. -func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0) } @@ -59,7 +63,7 @@ func (m *MockReceiveStreamI) Read(arg0 []byte) (int, error) { } // Read indicates an expected call of Read. -func (mr *MockReceiveStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) Read(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), arg0) } @@ -73,7 +77,7 @@ func (m *MockReceiveStreamI) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), arg0) } @@ -99,7 +103,7 @@ func (m *MockReceiveStreamI) closeForShutdown(arg0 error) { } // closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0) } @@ -127,7 +131,7 @@ func (m *MockReceiveStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) } // handleResetStreamFrame indicates an expected call of handleResetStreamFrame. -func (mr *MockReceiveStreamIMockRecorder) handleResetStreamFrame(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) handleResetStreamFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleResetStreamFrame), arg0) } @@ -141,7 +145,7 @@ func (m *MockReceiveStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { } // handleStreamFrame indicates an expected call of handleStreamFrame. -func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0) } diff --git a/mock_sealing_manager_test.go b/mock_sealing_manager_test.go index b77c747a..c8ecaad2 100644 --- a/mock_sealing_manager_test.go +++ b/mock_sealing_manager_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: SealingManager) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager +// // Package quic is a generated GoMock package. package quic diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index 6568f9e5..407c798b 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: SendConn) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn +// // Package quic is a generated GoMock package. package quic @@ -86,7 +90,7 @@ func (m *MockSendConn) Write(arg0 []byte, arg1 uint16, arg2 protocol.ECN) error } // Write indicates an expected call of Write. -func (mr *MockSendConnMockRecorder) Write(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockSendConnMockRecorder) Write(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1, arg2) } diff --git a/mock_send_stream_internal_test.go b/mock_send_stream_internal_test.go index 4fb89a35..ded44a85 100644 --- a/mock_send_stream_internal_test.go +++ b/mock_send_stream_internal_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: SendStreamI) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI +// // Package quic is a generated GoMock package. package quic @@ -46,7 +50,7 @@ func (m *MockSendStreamI) CancelWrite(arg0 qerr.StreamErrorCode) { } // CancelWrite indicates an expected call of CancelWrite. -func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0) } @@ -88,7 +92,7 @@ func (m *MockSendStreamI) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), arg0) } @@ -117,7 +121,7 @@ func (m *MockSendStreamI) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockSendStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) Write(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), arg0) } @@ -129,7 +133,7 @@ func (m *MockSendStreamI) closeForShutdown(arg0 error) { } // closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0) } @@ -141,7 +145,7 @@ func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { } // handleStopSendingFrame indicates an expected call of handleStopSendingFrame. -func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0) } @@ -171,7 +175,7 @@ func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol. } // popStreamFrame indicates an expected call of popStreamFrame. -func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), arg0, arg1) } @@ -183,7 +187,7 @@ func (m *MockSendStreamI) updateSendWindow(arg0 protocol.ByteCount) { } // updateSendWindow indicates an expected call of updateSendWindow. -func (mr *MockSendStreamIMockRecorder) updateSendWindow(arg0 interface{}) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) updateSendWindow(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockSendStreamI)(nil).updateSendWindow), arg0) } diff --git a/mock_sender_test.go b/mock_sender_test.go index 67bb9d09..2565a8fa 100644 --- a/mock_sender_test.go +++ b/mock_sender_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: Sender) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender +// // Package quic is a generated GoMock package. package quic @@ -81,7 +85,7 @@ func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16, arg2 protocol.ECN) { } // Send indicates an expected call of Send. -func (mr *MockSenderMockRecorder) Send(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockSenderMockRecorder) Send(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1, arg2) } diff --git a/mock_stream_getter_test.go b/mock_stream_getter_test.go index 0ff7f13f..d0acde1d 100644 --- a/mock_stream_getter_test.go +++ b/mock_stream_getter_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: StreamGetter) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter +// // Package quic is a generated GoMock package. package quic @@ -44,7 +48,7 @@ func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (recei } // GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. -func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0) } @@ -59,7 +63,7 @@ func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStre } // GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. -func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0) } diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index 447438c1..9121a461 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: StreamI) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_internal_test.go github.com/quic-go/quic-go StreamI +// // Package quic is a generated GoMock package. package quic @@ -46,7 +50,7 @@ func (m *MockStreamI) CancelRead(arg0 qerr.StreamErrorCode) { } // CancelRead indicates an expected call of CancelRead. -func (mr *MockStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) CancelRead(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStreamI)(nil).CancelRead), arg0) } @@ -58,7 +62,7 @@ func (m *MockStreamI) CancelWrite(arg0 qerr.StreamErrorCode) { } // CancelWrite indicates an expected call of CancelWrite. -func (mr *MockStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) CancelWrite(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStreamI)(nil).CancelWrite), arg0) } @@ -101,7 +105,7 @@ func (m *MockStreamI) Read(arg0 []byte) (int, error) { } // Read indicates an expected call of Read. -func (mr *MockStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) Read(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), arg0) } @@ -115,7 +119,7 @@ func (m *MockStreamI) SetDeadline(arg0 time.Time) error { } // SetDeadline indicates an expected call of SetDeadline. -func (mr *MockStreamIMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) SetDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), arg0) } @@ -129,7 +133,7 @@ func (m *MockStreamI) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), arg0) } @@ -143,7 +147,7 @@ func (m *MockStreamI) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0) } @@ -172,7 +176,7 @@ func (m *MockStreamI) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) Write(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), arg0) } @@ -184,7 +188,7 @@ func (m *MockStreamI) closeForShutdown(arg0 error) { } // closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0) } @@ -212,7 +216,7 @@ func (m *MockStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) error } // handleResetStreamFrame indicates an expected call of handleResetStreamFrame. -func (mr *MockStreamIMockRecorder) handleResetStreamFrame(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) handleResetStreamFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleResetStreamFrame), arg0) } @@ -224,7 +228,7 @@ func (m *MockStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { } // handleStopSendingFrame indicates an expected call of handleStopSendingFrame. -func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).handleStopSendingFrame), arg0) } @@ -238,7 +242,7 @@ func (m *MockStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { } // handleStreamFrame indicates an expected call of handleStreamFrame. -func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0) } @@ -268,7 +272,7 @@ func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.Vers } // popStreamFrame indicates an expected call of popStreamFrame. -func (mr *MockStreamIMockRecorder) popStreamFrame(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) popStreamFrame(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamI)(nil).popStreamFrame), arg0, arg1) } @@ -280,7 +284,7 @@ func (m *MockStreamI) updateSendWindow(arg0 protocol.ByteCount) { } // updateSendWindow indicates an expected call of updateSendWindow. -func (mr *MockStreamIMockRecorder) updateSendWindow(arg0 interface{}) *gomock.Call { +func (mr *MockStreamIMockRecorder) updateSendWindow(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockStreamI)(nil).updateSendWindow), arg0) } diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index dc4a39eb..3dacd419 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: StreamManager) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager +// // Package quic is a generated GoMock package. package quic @@ -46,7 +50,7 @@ func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) { } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0) } @@ -61,7 +65,7 @@ func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0) } @@ -73,7 +77,7 @@ func (m *MockStreamManager) CloseWithError(arg0 error) { } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0) } @@ -87,7 +91,7 @@ func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error { } // DeleteStream indicates an expected call of DeleteStream. -func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0) } @@ -102,7 +106,7 @@ func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (rece } // GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. -func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0) } @@ -117,7 +121,7 @@ func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStr } // GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. -func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0) } @@ -129,7 +133,7 @@ func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) { } // HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame. -func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0) } @@ -159,7 +163,7 @@ func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (Stream, error) } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0) } @@ -189,7 +193,7 @@ func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (SendStream, } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0) } @@ -213,7 +217,7 @@ func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) { } // UpdateLimits indicates an expected call of UpdateLimits. -func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0) } diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index 94606942..87a05a48 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: StreamSender) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender +// // Package quic is a generated GoMock package. package quic @@ -42,7 +46,7 @@ func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) { } // onHasStreamData indicates an expected call of onHasStreamData. -func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomock.Call { +func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0) } @@ -54,7 +58,7 @@ func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { } // onStreamCompleted indicates an expected call of onStreamCompleted. -func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call { +func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0) } @@ -66,7 +70,7 @@ func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) { } // queueControlFrame indicates an expected call of queueControlFrame. -func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 interface{}) *gomock.Call { +func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0) } diff --git a/mock_token_store_test.go b/mock_token_store_test.go index 8576a3e2..3d251052 100644 --- a/mock_token_store_test.go +++ b/mock_token_store_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: TokenStore) - +// +// Generated by this command: +// +// mockgen -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_token_store_test.go github.com/quic-go/quic-go TokenStore +// // Package quic is a generated GoMock package. package quic @@ -42,7 +46,7 @@ func (m *MockTokenStore) Pop(arg0 string) *ClientToken { } // Pop indicates an expected call of Pop. -func (mr *MockTokenStoreMockRecorder) Pop(arg0 interface{}) *gomock.Call { +func (mr *MockTokenStoreMockRecorder) Pop(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pop", reflect.TypeOf((*MockTokenStore)(nil).Pop), arg0) } @@ -54,7 +58,7 @@ func (m *MockTokenStore) Put(arg0 string, arg1 *ClientToken) { } // Put indicates an expected call of Put. -func (mr *MockTokenStoreMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockTokenStoreMockRecorder) Put(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockTokenStore)(nil).Put), arg0, arg1) } diff --git a/mock_unpacker_test.go b/mock_unpacker_test.go index 272585c0..372e4ac3 100644 --- a/mock_unpacker_test.go +++ b/mock_unpacker_test.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: Unpacker) - +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_unpacker_test.go github.com/quic-go/quic-go Unpacker +// // Package quic is a generated GoMock package. package quic @@ -46,7 +50,7 @@ func (m *MockUnpacker) UnpackLongHeader(arg0 *wire.Header, arg1 time.Time, arg2 } // UnpackLongHeader indicates an expected call of UnpackLongHeader. -func (mr *MockUnpackerMockRecorder) UnpackLongHeader(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockUnpackerMockRecorder) UnpackLongHeader(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), arg0, arg1, arg2, arg3) } @@ -64,7 +68,7 @@ func (m *MockUnpacker) UnpackShortHeader(arg0 time.Time, arg1 []byte) (protocol. } // UnpackShortHeader indicates an expected call of UnpackShortHeader. -func (mr *MockUnpackerMockRecorder) UnpackShortHeader(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockUnpackerMockRecorder) UnpackShortHeader(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackShortHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackShortHeader), arg0, arg1) } diff --git a/server_test.go b/server_test.go index 6e50de82..6944a81e 100644 --- a/server_test.go +++ b/server_test.go @@ -1412,7 +1412,7 @@ var _ = Describe("Server", func() { _ protocol.VersionNumber, ) quicConn { conn := NewMockQUICConn(mockCtrl) - var calls []*gomock.Call + var calls []any calls = append(calls, conn.EXPECT().handlePacket(initial)) for _, p := range zeroRTTPackets { calls = append(calls, conn.EXPECT().handlePacket(p)) From 2b290742f6b91718681cc38ae1b5c989aef5ec25 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 17 Oct 2023 12:50:40 +0700 Subject: [PATCH 43/43] fix IPv4 ECN control message length on FreeBSD (#4110) --- sys_conn_helper_freebsd.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sys_conn_helper_freebsd.go b/sys_conn_helper_freebsd.go index 5e635b18..a53ca2ea 100644 --- a/sys_conn_helper_freebsd.go +++ b/sys_conn_helper_freebsd.go @@ -14,7 +14,7 @@ const ( ipv4PKTINFO = 0x7 ) -const ecnIPv4DataLen = 4 +const ecnIPv4DataLen = 1 const batchSize = 8