From 5dd6d91c11eb6ecf871c1a134ececb71c2e73a9e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 12 Aug 2023 10:08:40 +0800 Subject: [PATCH] send and track packets with ECN markings --- connection.go | 65 +++--- connection_test.go | 200 +++++++++--------- internal/ackhandler/interfaces.go | 4 +- internal/ackhandler/sent_packet_handler.go | 6 + .../ackhandler/sent_packet_handler_test.go | 2 +- .../mocks/ackhandler/sent_packet_handler.go | 22 +- mock_raw_conn_test.go | 9 +- mock_send_conn_test.go | 9 +- mock_sender_test.go | 9 +- packet_handler_map.go | 2 +- send_conn.go | 15 +- send_conn_test.go | 17 +- send_queue.go | 12 +- send_queue_test.go | 36 ++-- server.go | 6 +- sys_conn.go | 2 +- sys_conn_df_windows.go | 4 - sys_conn_no_oob.go | 5 - sys_conn_oob.go | 9 +- sys_conn_oob_test.go | 30 ++- transport.go | 6 +- 21 files changed, 264 insertions(+), 206 deletions(-) diff --git a/connection.go b/connection.go index 1f3dadfd..f234ea3f 100644 --- a/connection.go +++ b/connection.go @@ -1830,9 +1830,10 @@ func (s *connection) sendPackets(now time.Time) error { if err != nil { return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) - s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, 0) + ecn := s.sentPacketHandler.ECNMode() + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) + s.registerPackedShortHeaderPacket(p, ecn, now) + s.sendQueue.Send(buf, 0, ecn) // This is kind of a hack. We need to trigger sending again somehow. s.pacingDeadline = deadlineSendImmediately return nil @@ -1852,7 +1853,7 @@ func (s *connection) sendPackets(now time.Time) error { return err } s.sentFirstPacket = true - if err := s.sendPackedCoalescedPacket(packet, now); err != nil { + if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(), now); err != nil { return err } sendMode := s.sentPacketHandler.SendMode(now) @@ -1873,7 +1874,8 @@ func (s *connection) sendPackets(now time.Time) error { func (s *connection) sendPacketsWithoutGSO(now time.Time) error { for { buf := getPacketBuffer() - if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil { + ecn := s.sentPacketHandler.ECNMode() + if _, err := s.appendOnePacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil { if err == errNothingToPack { buf.Release() return nil @@ -1881,7 +1883,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error { return err } - s.sendQueue.Send(buf, 0) + s.sendQueue.Send(buf, 0, ecn) if s.sendQueue.WouldBlock() { return nil @@ -1908,7 +1910,8 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { for { var dontSendMore bool - size, err := s.appendPacket(buf, maxSize, now) + ecn := s.sentPacketHandler.ECNMode() + size, err := s.appendOnePacket(buf, maxSize, ecn, now) if err != nil { if err != errNothingToPack { return err @@ -1938,7 +1941,7 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { continue } - s.sendQueue.Send(buf, uint16(maxSize)) + s.sendQueue.Send(buf, uint16(maxSize), ecn) if dontSendMore { return nil @@ -1966,6 +1969,7 @@ func (s *connection) resetPacingDeadline() { } func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { + ecn := s.sentPacketHandler.ECNMode() if !s.handshakeConfirmed { packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { @@ -1974,7 +1978,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if packet == nil { return nil } - return s.sendPackedCoalescedPacket(packet, time.Now()) + return s.sendPackedCoalescedPacket(packet, ecn, time.Now()) } p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) @@ -1984,9 +1988,9 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { } return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) - s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, 0) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) + s.registerPackedShortHeaderPacket(p, ecn, now) + s.sendQueue.Send(buf, 0, ecn) return nil } @@ -2018,24 +2022,24 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } - return s.sendPackedCoalescedPacket(packet, now) + return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(), now) } -// appendPacket appends a new packet to the given packetBuffer. +// appendOnePacket appends a new packet to the given packetBuffer. // If there was nothing to pack, the returned size is 0. -func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) { +func (s *connection) appendOnePacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) { startLen := buf.Len() p, err := s.packer.AppendPacket(buf, maxSize, s.version) if err != nil { return 0, err } size := buf.Len() - startLen - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false) - s.registerPackedShortHeaderPacket(p, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, size, false) + s.registerPackedShortHeaderPacket(p, ecn, now) return size, nil } -func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now time.Time) { +func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now time.Time) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) { s.firstAckElicitingPacketAfterIdleSentTime = now } @@ -2044,12 +2048,12 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now ti if p.Ack != nil { largestAcked = p.Ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) + s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket) s.connIDManager.SentPacket() } -func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) error { - s.logCoalescedPacket(packet) +func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN, now time.Time) error { + s.logCoalescedPacket(packet, ecn) for _, p := range packet.longHdrPackets { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { s.firstAckElicitingPacketAfterIdleSentTime = now @@ -2058,7 +2062,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time if p.ack != nil { largestAcked = p.ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false) + s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false) if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake { // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // See Section 4.9.1 of RFC 9001. @@ -2075,10 +2079,10 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time if p.Ack != nil { largestAcked = p.Ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) + s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket) } s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer, 0) + s.sendQueue.Send(packet.buffer, 0, ecn) return nil } @@ -2100,8 +2104,9 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { if err != nil { return nil, err } - s.logCoalescedPacket(packet) - return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0) + ecn := s.sentPacketHandler.ECNMode() + s.logCoalescedPacket(packet, ecn) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn) } func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { @@ -2144,11 +2149,12 @@ func (s *connection) logShortHeaderPacket( pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit, + ecn protocol.ECN, size protocol.ByteCount, isCoalesced bool, ) { if s.logger.Debug() && !isCoalesced { - s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT", pn, size, s.logID) + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn) } // quic-go logging if s.logger.Debug() { @@ -2191,7 +2197,7 @@ func (s *connection) logShortHeaderPacket( } } -func (s *connection) logCoalescedPacket(packet *coalescedPacket) { +func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) { if s.logger.Debug() { // There's a short period between dropping both Initial and Handshake keys and completion of the handshake, // during which we might call PackCoalescedPacket but just pack a short header packet. @@ -2204,6 +2210,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { packet.shortHdrPacket.PacketNumber, packet.shortHdrPacket.PacketNumberLen, packet.shortHdrPacket.KeyPhase, + ecn, packet.shortHdrPacket.Length, false, ) @@ -2219,7 +2226,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { s.logLongHeaderPacket(p) } if p := packet.shortHdrPacket; p != nil { - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true) } } diff --git a/connection_test.go b/connection_test.go index 4c02c492..828f4ced 100644 --- a/connection_test.go +++ b/connection_test.go @@ -454,7 +454,7 @@ var _ = Describe("Connection", func() { Expect(e.ErrorMessage).To(BeEmpty()) return &coalescedPacket{buffer: buffer}, nil }) - mconn.EXPECT().Write([]byte("connection close"), gomock.Any()) + mconn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { var appErr *ApplicationError @@ -475,7 +475,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -494,7 +494,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().Close(), @@ -516,7 +516,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().Close(), @@ -565,7 +565,7 @@ var _ = Describe("Connection", func() { close(returned) }() Consistently(returned).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -609,13 +609,14 @@ var _ = Describe("Connection", func() { conn.handshakeConfirmed = true sconn := NewMockSendConn(mockCtrl) sconn.EXPECT().capabilities().AnyTimes() - sconn.EXPECT().Write(gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() + sconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() conn.sendQueue = newSendQueue(sconn) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes() + sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() // only expect a single SentPacket() call - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -837,7 +838,7 @@ var _ = Describe("Connection", func() { // make the go routine return tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -872,7 +873,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -908,7 +909,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -930,7 +931,7 @@ var _ = Describe("Connection", func() { close(done) }() expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) packet := getShortHeaderPacket(srcConnID, 0x42, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -958,7 +959,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.shutdown() Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -980,7 +981,7 @@ var _ = Describe("Connection", func() { close(done) }() expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil)) @@ -1190,6 +1191,7 @@ var _ = Describe("Connection", func() { var ( connDone chan struct{} sender *MockSender + sph *mockackhandler.MockSentPacketHandler ) BeforeEach(func() { @@ -1198,14 +1200,17 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() conn.sendQueue = sender connDone = make(chan struct{}) + sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) + conn.sentPacketHandler = sph }) AfterEach(func() { streamManager.EXPECT().CloseWithError(gomock.Any()) + sph.EXPECT().ECNMode().Return(protocol.ECNCE).MaxTimes(1) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() @@ -1226,12 +1231,11 @@ var _ = Describe("Connection", func() { It("sends packets", func() { conn.handshakeConfirmed = true - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - conn.sentPacketHandler = sph + sph.EXPECT().ECNMode().Return(protocol.ECNNon).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) runConn() p := shortHeaderPacket{ DestConnID: protocol.ParseConnectionID([]byte{1, 2, 3}), @@ -1243,7 +1247,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() sent := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(&logging.ShortHeader{ DestConnectionID: p.DestConnID, PacketNumber: p.PacketNumber, @@ -1256,6 +1260,9 @@ var _ = Describe("Connection", func() { It("doesn't send packets if there's nothing to send", func() { conn.handshakeConfirmed = true + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() runConn() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) @@ -1264,13 +1271,12 @@ var _ = Describe("Connection", func() { }) It("sends ACK only packets", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) + sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() done := make(chan struct{}) packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) }) - conn.sentPacketHandler = sph runConn() conn.scheduleSending() Eventually(done).Should(BeClosed()) @@ -1278,12 +1284,11 @@ var _ = Describe("Connection", func() { It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { conn.handshakeConfirmed = true - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - conn.sentPacketHandler = sph + sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 13}, []byte("foobar")) @@ -1291,7 +1296,7 @@ var _ = Describe("Connection", func() { conn.connFlowController = fc runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) @@ -1300,11 +1305,9 @@ var _ = Describe("Connection", func() { }) It("doesn't send when the SentPacketHandler doesn't allow it", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone).AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() - conn.sentPacketHandler = sph runConn() conn.scheduleSending() time.Sleep(50 * time.Millisecond) @@ -1333,21 +1336,19 @@ var _ = Describe("Connection", func() { }) It("sends a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().QueueProbePacket(encLevel) + sph.EXPECT().ECNMode() p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { - Expect(pn).To(Equal(protocol.PacketNumber(123))) - }) + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) if enc == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) } else { @@ -1358,21 +1359,18 @@ var _ = Describe("Connection", func() { }) It("sends a PING as a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) + sph.EXPECT().ECNMode() sph.EXPECT().QueueProbePacket(encLevel).Return(false) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { - Expect(pn).To(Equal(protocol.PacketNumber(123))) - }) - conn.sentPacketHandler = sph + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) if enc == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) } else { @@ -1409,10 +1407,11 @@ var _ = Describe("Connection", func() { AfterEach(func() { // make the go routine return + sph.EXPECT().ECNMode().MaxTimes(1) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() @@ -1421,17 +1420,18 @@ var _ = Describe("Connection", func() { }) It("sends multiple packets one by one immediately", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) + sph.EXPECT().ECNMode().Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, []byte("packet11")) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) { + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal([]byte("packet10"))) }) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) { + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal([]byte("packet11"))) }) go func() { @@ -1446,7 +1446,8 @@ var _ = Describe("Connection", func() { It("sends multiple packets one by one immediately, with GSO", func() { enableGSO() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().ECNMode().Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) @@ -1456,7 +1457,7 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) { + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) go func() { @@ -1471,8 +1472,9 @@ var _ = Describe("Connection", func() { It("stops appending packets when a smaller packet is packed, with GSO", func() { enableGSO() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) + sph.EXPECT().ECNMode().Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) @@ -1481,7 +1483,7 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) { + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) go func() { @@ -1495,12 +1497,13 @@ var _ = Describe("Connection", func() { }) It("sends multiple packets, when the pacer allows immediate sending", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) + sph.EXPECT().ECNMode().Times(2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1512,13 +1515,14 @@ var _ = Describe("Connection", func() { }) It("allows an ACK to be sent when pacing limited", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) + sph.EXPECT().ECNMode() packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{PacketNumber: 123}, getPacketBuffer(), nil) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1532,12 +1536,13 @@ var _ = Describe("Connection", func() { // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck // we shouldn't send the ACK in the same run It("doesn't send an ACK right after becoming congestion limited", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) + sph.EXPECT().ECNMode().Times(2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1552,19 +1557,21 @@ var _ = Describe("Connection", func() { pacingDelay := scaleDuration(100 * time.Millisecond) gomock.InOrder( sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), + sph.EXPECT().ECNMode(), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), + sph.EXPECT().ECNMode(), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 101}, []byte("packet101")), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), ) written := make(chan struct{}, 2) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(2) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }).Times(2) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1578,8 +1585,9 @@ var _ = Describe("Connection", func() { }) It("sends multiple packets at once", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) + sph.EXPECT().ECNMode().Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) for pn := protocol.PacketNumber(1000); pn < 1003; pn++ { @@ -1587,7 +1595,7 @@ var _ = Describe("Connection", func() { } written := make(chan struct{}, 3) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(3) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }).Times(3) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1618,11 +1626,12 @@ var _ = Describe("Connection", func() { written := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) available <- struct{}{} Eventually(written).Should(BeClosed()) }) @@ -1639,14 +1648,15 @@ var _ = Describe("Connection", func() { written := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ByteCount, bool) { + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ECN, protocol.ByteCount, bool) { sph.EXPECT().ReceivedBytes(gomock.Any()) conn.handlePacket(receivedPacket{buffer: getPacketBuffer()}) }) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) conn.scheduleSending() time.Sleep(scaleDuration(50 * time.Millisecond)) @@ -1655,13 +1665,14 @@ var _ = Describe("Connection", func() { }) It("stops sending when the send queue is full", func() { - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) + sph.EXPECT().ECNMode() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1675,12 +1686,13 @@ var _ = Describe("Connection", func() { time.Sleep(scaleDuration(50 * time.Millisecond)) // now make room in the send queue - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1001}, []byte("packet1001")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }) available <- struct{}{} Eventually(written).Should(Receive()) @@ -1691,6 +1703,7 @@ var _ = Describe("Connection", func() { It("doesn't set a pacing timer when there is no data to send", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) // don't EXPECT any calls to mconn.Write() @@ -1708,12 +1721,13 @@ var _ = Describe("Connection", func() { mtuDiscoverer := NewMockMTUDiscoverer(mockCtrl) conn.mtuDiscoverer = mtuDiscoverer conn.config.DisablePathMTUDiscovery = false - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) + sph.EXPECT().ECNMode() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }) mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) ping := ackhandler.Frame{Frame: &wire.PingFrame{}} mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) @@ -1747,7 +1761,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) sender.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -1760,8 +1774,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1}, []byte("packet1")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) @@ -1776,7 +1791,7 @@ var _ = Describe("Connection", func() { time.Sleep(50 * time.Millisecond) // only EXPECT calls after scheduleSending is called written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() conn.scheduleSending() Eventually(written).Should(BeClosed()) @@ -1788,9 +1803,8 @@ var _ = Describe("Connection", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { - Expect(pn).To(Equal(protocol.PacketNumber(1234))) - }) + sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(1234), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) @@ -1799,7 +1813,7 @@ var _ = Describe("Connection", func() { conn.receivedPacketHandler = rph written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() @@ -1841,18 +1855,11 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes() gomock.InOrder( - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, _ bool) { - Expect(encLevel).To(Equal(protocol.EncryptionInitial)) - Expect(pn).To(Equal(protocol.PacketNumber(13))) - Expect(size).To(BeEquivalentTo(123)) - }), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, _ bool) { - Expect(encLevel).To(Equal(protocol.EncryptionHandshake)) - Expect(pn).To(Equal(protocol.PacketNumber(37))) - Expect(size).To(BeEquivalentTo(1234)) - }), + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(13), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionInitial, protocol.ECT1, protocol.ByteCount(123), gomock.Any()), + sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(37), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionHandshake, protocol.ECT1, protocol.ByteCount(1234), gomock.Any()), ) gomock.InOrder( tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { @@ -1864,7 +1871,7 @@ var _ = Describe("Connection", func() { ) sent := make(chan struct{}) - mconn.EXPECT().Write([]byte("foobar"), uint16(0)).Do(func([]byte, uint16) { close(sent) }) + mconn.EXPECT().Write([]byte("foobar"), uint16(0), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(sent) }) go func() { defer GinkgoRecover() @@ -1881,7 +1888,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -1952,7 +1959,7 @@ var _ = Describe("Connection", func() { }() handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("handshake error")) Consistently(handshakeCtx).ShouldNot(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed()) @@ -1961,11 +1968,12 @@ var _ = Describe("Connection", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SetHandshakeConfirmed() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph done := make(chan struct{}) @@ -1987,7 +1995,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handleHandshakeComplete()).To(Succeed()) conn.run() }() @@ -2016,7 +2024,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2043,7 +2051,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() Expect(conn.CloseWithError(0x1337, testErr.Error())).To(Succeed()) @@ -2102,7 +2110,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2255,7 +2263,7 @@ var _ = Describe("Connection", func() { // make the go routine return expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) conn.shutdown() Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2338,7 +2346,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2554,7 +2562,7 @@ var _ = Describe("Client Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any()).MaxTimes(1) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2851,7 +2859,7 @@ var _ = Describe("Client Connection", func() { packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) } cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()), tracer.EXPECT().Close(), diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 7b4eaa31..ba224228 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -10,7 +10,7 @@ import ( // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet - SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool) + SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool) // ReceivedAck processes an ACK frame. // It does not store a copy of the frame. ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error) @@ -29,6 +29,8 @@ type SentPacketHandler interface { // only to be called once the handshake is complete QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ + ECNMode() protocol.ECN + PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index a03d9f53..3b972b66 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -228,6 +228,7 @@ func (h *sentPacketHandler) SentPacket( streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, + _ protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool, ) { @@ -712,6 +713,11 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { return h.alarm } +func (h *sentPacketHandler) ECNMode() protocol.ECN { + // TODO: implement ECN logic + return protocol.ECNNon +} + func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { pnSpace := h.getPacketNumberSpace(encLevel) pn := pnSpace.pns.Peek() diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 396014d7..614b7c85 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -106,7 +106,7 @@ var _ = Describe("SentPacketHandler", func() { } sentPacket := func(p *packet) { - handler.SentPacket(p.SendTime, p.PacketNumber, p.LargestAcked, p.StreamFrames, p.Frames, p.EncryptionLevel, p.Length, p.IsPathMTUProbePacket) + handler.SentPacket(p.SendTime, p.PacketNumber, p.LargestAcked, p.StreamFrames, p.Frames, p.EncryptionLevel, protocol.ECNNon, p.Length, p.IsPathMTUProbePacket) } expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 24f5a157..3479d9e6 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -49,6 +49,20 @@ func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) } +// ECNMode mocks base method. +func (m *MockSentPacketHandler) ECNMode() protocol.ECN { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ECNMode") + ret0, _ := ret[0].(protocol.ECN) + return ret0 +} + +// ECNMode indicates an expected call of ECNMode. +func (mr *MockSentPacketHandlerMockRecorder) ECNMode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode)) +} + // GetLossDetectionTimeout mocks base method. func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time { m.ctrl.T.Helper() @@ -176,15 +190,15 @@ func (mr *MockSentPacketHandlerMockRecorder) SendMode(arg0 interface{}) *gomock. } // SentPacket mocks base method. -func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.PacketNumber, arg3 []ackhandler.StreamFrame, arg4 []ackhandler.Frame, arg5 protocol.EncryptionLevel, arg6 protocol.ByteCount, arg7 bool) { +func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.PacketNumber, arg3 []ackhandler.StreamFrame, arg4 []ackhandler.Frame, arg5 protocol.EncryptionLevel, arg6 protocol.ECN, arg7 protocol.ByteCount, arg8 bool) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7) + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } // SentPacket indicates an expected call of SentPacket. -func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7 interface{}) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } // SetHandshakeConfirmed mocks base method. diff --git a/mock_raw_conn_test.go b/mock_raw_conn_test.go index 84d8f276..d462fc87 100644 --- a/mock_raw_conn_test.go +++ b/mock_raw_conn_test.go @@ -9,6 +9,7 @@ import ( reflect "reflect" time "time" + protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -93,18 +94,18 @@ func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Cal } // WritePacket mocks base method. -func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16) (int, error) { +func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16, arg4 protocol.ECN) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // WritePacket indicates an expected call of WritePacket. -func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3, arg4) } // capabilities mocks base method. diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index 529b9c58..6568f9e5 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -8,6 +8,7 @@ import ( net "net" reflect "reflect" + protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -77,17 +78,17 @@ func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call { } // Write mocks base method. -func (m *MockSendConn) Write(arg0 []byte, arg1 uint16) error { +func (m *MockSendConn) Write(arg0 []byte, arg1 uint16, arg2 protocol.ECN) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0, arg1) + ret := m.ctrl.Call(m, "Write", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // Write indicates an expected call of Write. -func (mr *MockSendConnMockRecorder) Write(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockSendConnMockRecorder) Write(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1, arg2) } // capabilities mocks base method. diff --git a/mock_sender_test.go b/mock_sender_test.go index 40671562..67bb9d09 100644 --- a/mock_sender_test.go +++ b/mock_sender_test.go @@ -7,6 +7,7 @@ package quic import ( reflect "reflect" + protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -74,15 +75,15 @@ func (mr *MockSenderMockRecorder) Run() *gomock.Call { } // Send mocks base method. -func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16) { +func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16, arg2 protocol.ECN) { m.ctrl.T.Helper() - m.ctrl.Call(m, "Send", arg0, arg1) + m.ctrl.Call(m, "Send", arg0, arg1, arg2) } // Send indicates an expected call of Send. -func (mr *MockSenderMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockSenderMockRecorder) Send(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1, arg2) } // WouldBlock mocks base method. diff --git a/packet_handler_map.go b/packet_handler_map.go index 60b7cef9..d7d92377 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -29,7 +29,7 @@ type rawConn interface { // WritePacket writes a packet on the wire. // gsoSize is the size of a single packet, or 0 to disable GSO. // It is invalid to set gsoSize if capabilities.GSO is not set. - WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) + WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer diff --git a/send_conn.go b/send_conn.go index 030e0fba..4fda1469 100644 --- a/send_conn.go +++ b/send_conn.go @@ -9,7 +9,7 @@ import ( // A sendConn allows sending using a simple Write() on a non-connected packet conn. type sendConn interface { - Write(b []byte, gsoSize uint16) error + Write(b []byte, gsoSize uint16, ecn protocol.ECN) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr @@ -43,13 +43,6 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge } oob := info.OOB() - if remoteUDPAddr, ok := remote.(*net.UDPAddr); ok { - if remoteUDPAddr.IP.To4() != nil { - oob = appendIPv4ECNMsg(oob, protocol.ECT1) - } else { - oob = appendIPv6ECNMsg(oob, protocol.ECT1) - } - } // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating l := len(oob) oob = append(oob, make([]byte, 64)...)[:l] @@ -62,8 +55,8 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge } } -func (c *sconn) Write(p []byte, gsoSize uint16) error { - _, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize) +func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { + _, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn) if err != nil && isGSOError(err) { // disable GSO for future calls c.gotGSOError = true @@ -76,7 +69,7 @@ func (c *sconn) Write(p []byte, gsoSize uint16) error { if l > int(gsoSize) { l = int(gsoSize) } - if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0); err != nil { + if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil { return err } p = p[l:] diff --git a/send_conn_test.go b/send_conn_test.go index 963f2482..bbac8fe7 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "runtime" + "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" . "github.com/onsi/ginkgo/v2" @@ -45,8 +46,8 @@ var _ = Describe("Connection (for sending packets)", func() { pi := packetInfo{addr: netip.IPv6Loopback()} Expect(pi.OOB()).ToNot(BeEmpty()) c := newSendConn(rawConn, remoteAddr, pi, utils.DefaultLogger) - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0)) - Expect(c.Write([]byte("foobar"), 0)).To(Succeed()) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0), protocol.ECT1) + Expect(c.Write([]byte("foobar"), 0, protocol.ECT1)).To(Succeed()) }) } @@ -55,8 +56,8 @@ var _ = Describe("Connection (for sending packets)", func() { rawConn.EXPECT().LocalAddr() rawConn.EXPECT().capabilities().AnyTimes() c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3)) - Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3), protocol.ECNCE) + Expect(c.Write([]byte("foobar"), 3, protocol.ECNCE)).To(Succeed()) }) if platformSupportsGSO { @@ -67,11 +68,11 @@ var _ = Describe("Connection (for sending packets)", func() { c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) Expect(c.capabilities().GSO).To(BeTrue()) gomock.InOrder( - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4)).Return(0, errGSO), - rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0)).Return(4, nil), - rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0)).Return(2, nil), + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4), protocol.ECNCE).Return(0, errGSO), + rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(4, nil), + rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(2, nil), ) - Expect(c.Write([]byte("foobar"), 4)).To(Succeed()) + Expect(c.Write([]byte("foobar"), 4, protocol.ECNCE)).To(Succeed()) Expect(c.capabilities().GSO).To(BeFalse()) }) } diff --git a/send_queue.go b/send_queue.go index 2da546e5..bde02334 100644 --- a/send_queue.go +++ b/send_queue.go @@ -1,8 +1,9 @@ package quic +import "github.com/quic-go/quic-go/internal/protocol" + type sender interface { - // Send sends a packet. GSO is only used if gsoSize > 0. - Send(p *packetBuffer, gsoSize uint16) + Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) Run() error WouldBlock() bool Available() <-chan struct{} @@ -12,6 +13,7 @@ type sender interface { type queueEntry struct { buf *packetBuffer gsoSize uint16 + ecn protocol.ECN } type sendQueue struct { @@ -39,9 +41,9 @@ func newSendQueue(conn sendConn) sender { // Send sends out a packet. It's guaranteed to not block. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Otherwise Send will panic. -func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16) { +func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) { select { - case h.queue <- queueEntry{buf: p, gsoSize: gsoSize}: + case h.queue <- queueEntry{buf: p, gsoSize: gsoSize, ecn: ecn}: // clear available channel if we've reached capacity if len(h.queue) == sendQueueCapacity { select { @@ -76,7 +78,7 @@ func (h *sendQueue) Run() error { // make sure that all queued packets are actually sent out shouldClose = true case e := <-h.queue: - if err := h.conn.Write(e.buf.Data, e.gsoSize); err != nil { + if err := h.conn.Write(e.buf.Data, e.gsoSize, e.ecn); err != nil { // This additional check enables: // 1. Checking for "datagram too large" message from the kernel, as such, // 2. Path MTU discovery,and diff --git a/send_queue_test.go b/send_queue_test.go index 0ed7bea5..e8cb8bdc 100644 --- a/send_queue_test.go +++ b/send_queue_test.go @@ -3,6 +3,8 @@ package quic import ( "errors" + "github.com/quic-go/quic-go/internal/protocol" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "go.uber.org/mock/gomock" @@ -26,10 +28,10 @@ var _ = Describe("Send Queue", func() { It("sends a packet", func() { p := getPacket([]byte("foobar")) - q.Send(p, 10) // make sure the packet size is passed through to the conn + q.Send(p, 10, protocol.ECT1) // make sure the packet size is passed through to the conn written := make(chan struct{}) - c.EXPECT().Write([]byte("foobar"), uint16(10)).Do(func([]byte, uint16) { close(written) }) + c.EXPECT().Write([]byte("foobar"), uint16(10), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(written) }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -45,19 +47,19 @@ var _ = Describe("Send Queue", func() { It("panics when Send() is called although there's no space in the queue", func() { for i := 0; i < sendQueueCapacity; i++ { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) } Expect(q.WouldBlock()).To(BeTrue()) - Expect(func() { q.Send(getPacket([]byte("raboof")), 6) }).To(Panic()) + Expect(func() { q.Send(getPacket([]byte("raboof")), 6, protocol.ECNNon) }).To(Panic()) }) It("signals when sending is possible again", func() { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar1")), 6) + q.Send(getPacket([]byte("foobar1")), 6, protocol.ECNNon) Consistently(q.Available()).ShouldNot(Receive()) // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any(), gomock.Any()).MinTimes(1).MaxTimes(2) + c.EXPECT().Write(gomock.Any(), gomock.Any(), protocol.ECNNon).MinTimes(1).MaxTimes(2) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -67,7 +69,7 @@ var _ = Describe("Send Queue", func() { Eventually(q.Available()).Should(Receive()) Expect(q.WouldBlock()).To(BeFalse()) - Expect(func() { q.Send(getPacket([]byte("foobar2")), 7) }).ToNot(Panic()) + Expect(func() { q.Send(getPacket([]byte("foobar2")), 7, protocol.ECNNon) }).ToNot(Panic()) q.Close() Eventually(done).Should(BeClosed()) @@ -77,7 +79,7 @@ var _ = Describe("Send Queue", func() { write := make(chan struct{}, 1) written := make(chan struct{}, 100) // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16) error { + c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16, protocol.ECN) error { written <- struct{}{} <-write return nil @@ -92,19 +94,19 @@ var _ = Describe("Send Queue", func() { close(done) }() - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) <-written // now fill up the send queue for i := 0; i < sendQueueCapacity; i++ { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) } // One more packet is queued when it's picked up by Run and written to the connection. // In this test, it's blocked on write channel in the mocked Write call. <-written Eventually(q.WouldBlock()).Should(BeFalse()) - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) Expect(q.WouldBlock()).To(BeTrue()) Consistently(q.Available()).ShouldNot(Receive()) @@ -130,15 +132,15 @@ var _ = Describe("Send Queue", func() { // the run loop exits if there is a write error testErr := errors.New("test error") - c.EXPECT().Write(gomock.Any(), gomock.Any()).Return(testErr) - q.Send(getPacket([]byte("foobar")), 6) + c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(testErr) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) Eventually(done).Should(BeClosed()) sent := make(chan struct{}) go func() { defer GinkgoRecover() - q.Send(getPacket([]byte("raboof")), 6) - q.Send(getPacket([]byte("quux")), 4) + q.Send(getPacket([]byte("raboof")), 6, protocol.ECNNon) + q.Send(getPacket([]byte("quux")), 4, protocol.ECNNon) close(sent) }() @@ -147,7 +149,7 @@ var _ = Describe("Send Queue", func() { It("blocks Close() until the packet has been sent out", func() { written := make(chan []byte) - c.EXPECT().Write(gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16) { written <- p }) + c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16, _ protocol.ECN) { written <- p }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -155,7 +157,7 @@ var _ = Describe("Send Queue", func() { close(done) }() - q.Send(getPacket([]byte("foobar")), 6) + q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon) closed := make(chan struct{}) go func() { diff --git a/server.go b/server.go index 92b5e91a..681e8956 100644 --- a/server.go +++ b/server.go @@ -745,7 +745,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon) return err } @@ -844,7 +844,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon) return err } @@ -882,7 +882,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { if s.tracer != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/sys_conn.go b/sys_conn.go index a72aead5..4d7a6577 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -104,7 +104,7 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, _ protocol.ECN) (n int, err error) { if gsoSize != 0 { panic("cannot use GSO with a basicConn") } diff --git a/sys_conn_df_windows.go b/sys_conn_df_windows.go index e56b7460..e27635ec 100644 --- a/sys_conn_df_windows.go +++ b/sys_conn_df_windows.go @@ -8,7 +8,6 @@ import ( "golang.org/x/sys/windows" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) @@ -53,6 +52,3 @@ func isRecvMsgSizeErr(err error) bool { // https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2 return errors.Is(err, windows.WSAEMSGSIZE) } - -func appendIPv4ECNMsg([]byte, protocol.ECN) []byte { return nil } -func appendIPv6ECNMsg([]byte, protocol.ECN) []byte { return nil } diff --git a/sys_conn_no_oob.go b/sys_conn_no_oob.go index c1fb6fb1..2a1f807e 100644 --- a/sys_conn_no_oob.go +++ b/sys_conn_no_oob.go @@ -5,8 +5,6 @@ package quic import ( "net" "net/netip" - - "github.com/quic-go/quic-go/internal/protocol" ) func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) { @@ -16,9 +14,6 @@ func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) { func inspectReadBuffer(any) (int, error) { return 0, nil } func inspectWriteBuffer(any) (int, error) { return 0, nil } -func appendIPv4ECNMsg([]byte, protocol.ECN) []byte { return nil } -func appendIPv6ECNMsg([]byte, protocol.ECN) []byte { return nil } - type packetInfo struct { addr netip.Addr } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 6446d893..7b6052ba 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -229,7 +229,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { } // WritePacket writes a new packet. -func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) { +func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) { oob := packetInfoOOB if gsoSize > 0 { if !c.capabilities().GSO { @@ -237,6 +237,13 @@ func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gso } oob = appendUDPSegmentSizeMsg(oob, gsoSize) } + if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok { + if remoteUDPAddr.IP.To4() != nil { + oob = appendIPv4ECNMsg(oob, ecn) + } else { + oob = appendIPv6ECNMsg(oob, ecn) + } + } n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } diff --git a/sys_conn_oob_test.go b/sys_conn_oob_test.go index 9d0f5e9d..40596421 100644 --- a/sys_conn_oob_test.go +++ b/sys_conn_oob_test.go @@ -53,7 +53,7 @@ var _ = Describe("OOB Conn Test", func() { return udpConn, packetChan } - Context("ECN conn", func() { + Context("reading ECN-marked packets", func() { sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr { conn, err := net.DialUDP(network, nil, addr) ExpectWithOffset(1, err).ToNot(HaveOccurred()) @@ -289,6 +289,27 @@ var _ = Describe("OOB Conn Test", func() { }) }) + Context("sending ECN-marked packets", func() { + It("sets the ECN control message", func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + c := &oobRecordingConn{UDPConn: udpConn} + oobConn, err := newConn(c, true) + Expect(err).ToNot(HaveOccurred()) + + oob := make([]byte, 0, 123) + oobConn.WritePacket([]byte("foobar"), addr, oob, 0, protocol.ECNCE) + Expect(c.oobs).To(HaveLen(1)) + oobMsg := c.oobs[0] + Expect(oobMsg).ToNot(BeEmpty()) + Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob + expected := appendIPv4ECNMsg([]byte{}, protocol.ECNCE) + Expect(oobMsg).To(Equal(expected)) + }) + }) + if platformSupportsGSO { Context("GSO", func() { It("appends the GSO control message", func() { @@ -301,14 +322,15 @@ var _ = Describe("OOB Conn Test", func() { Expect(err).ToNot(HaveOccurred()) Expect(oobConn.capabilities().GSO).To(BeTrue()) - oob := make([]byte, 0, 42) - oobConn.WritePacket([]byte("foobar"), addr, oob, 3) + oob := make([]byte, 0, 123) + oobConn.WritePacket([]byte("foobar"), addr, oob, 3, protocol.ECNCE) Expect(c.oobs).To(HaveLen(1)) oobMsg := c.oobs[0] Expect(oobMsg).ToNot(BeEmpty()) Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob expected := appendUDPSegmentSizeMsg([]byte{}, 3) - Expect(oobMsg).To(Equal(expected)) + // Check that the first control message is the OOB control message. + Expect(oobMsg[:len(expected)]).To(Equal(expected)) }) }) } diff --git a/transport.go b/transport.go index 41c97347..b841222e 100644 --- a/transport.go +++ b/transport.go @@ -228,7 +228,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, addr, nil, 0) + return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNNon) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -246,7 +246,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNNon) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -414,7 +414,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } }