diff --git a/connection.go b/connection.go index 1920bb2a..d4e44be6 100644 --- a/connection.go +++ b/connection.go @@ -1832,7 +1832,7 @@ func (s *connection) sendPackets(now time.Time) error { if err != nil { return err } - ecn := s.sentPacketHandler.ECNMode() + ecn := s.sentPacketHandler.ECNMode(true) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) s.registerPackedShortHeaderPacket(p, ecn, now) s.sendQueue.Send(buf, 0, ecn) @@ -1855,7 +1855,7 @@ func (s *connection) sendPackets(now time.Time) error { return err } s.sentFirstPacket = true - if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(), now); err != nil { + if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now); err != nil { return err } sendMode := s.sentPacketHandler.SendMode(now) @@ -1876,8 +1876,8 @@ func (s *connection) sendPackets(now time.Time) error { func (s *connection) sendPacketsWithoutGSO(now time.Time) error { for { buf := getPacketBuffer() - ecn := s.sentPacketHandler.ECNMode() - if _, err := s.appendOnePacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil { + ecn := s.sentPacketHandler.ECNMode(true) + if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil { if err == errNothingToPack { buf.Release() return nil @@ -1912,8 +1912,8 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { for { var dontSendMore bool - ecn := s.sentPacketHandler.ECNMode() - size, err := s.appendOnePacket(buf, maxSize, ecn, now) + ecn := s.sentPacketHandler.ECNMode(true) + size, err := s.appendOneShortHeaderPacket(buf, maxSize, ecn, now) if err != nil { if err != errNothingToPack { return err @@ -1971,8 +1971,8 @@ func (s *connection) resetPacingDeadline() { } func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { - ecn := s.sentPacketHandler.ECNMode() if !s.handshakeConfirmed { + ecn := s.sentPacketHandler.ECNMode(false) packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { return err @@ -1983,6 +1983,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { return s.sendPackedCoalescedPacket(packet, ecn, time.Now()) } + ecn := s.sentPacketHandler.ECNMode(true) p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { if err == errNothingToPack { @@ -2024,12 +2025,12 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } - return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(), now) + return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now) } -// appendOnePacket appends a new packet to the given packetBuffer. +// appendOneShortHeaderPacket appends a new packet to the given packetBuffer. // If there was nothing to pack, the returned size is 0. -func (s *connection) appendOnePacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) { +func (s *connection) appendOneShortHeaderPacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) { startLen := buf.Len() p, err := s.packer.AppendPacket(buf, maxSize, s.version) if err != nil { @@ -2106,7 +2107,7 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { if err != nil { return nil, err } - ecn := s.sentPacketHandler.ECNMode() + ecn := s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()) s.logCoalescedPacket(packet, ecn) return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn) } diff --git a/connection_test.go b/connection_test.go index 828f4ced..c4e067af 100644 --- a/connection_test.go +++ b/connection_test.go @@ -613,7 +613,7 @@ var _ = Describe("Connection", func() { conn.sendQueue = newSendQueue(sconn) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes() - sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() + sph.EXPECT().ECNMode(true).Return(protocol.ECT1).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() // only expect a single SentPacket() call sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) @@ -1206,7 +1206,7 @@ var _ = Describe("Connection", func() { AfterEach(func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - sph.EXPECT().ECNMode().Return(protocol.ECNCE).MaxTimes(1) + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECNCE).MaxTimes(1) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -1234,7 +1234,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().Return(protocol.ECNNon).AnyTimes() + sph.EXPECT().ECNMode(true).Return(protocol.ECNNon).AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) runConn() p := shortHeaderPacket{ @@ -1262,7 +1262,7 @@ var _ = Describe("Connection", func() { conn.handshakeConfirmed = true sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(true).AnyTimes() runConn() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) @@ -1274,7 +1274,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) - sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() done := make(chan struct{}) packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) }) runConn() @@ -1287,7 +1287,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) @@ -1341,7 +1341,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().QueueProbePacket(encLevel) - sph.EXPECT().ECNMode() + sph.EXPECT().ECNMode(gomock.Any()) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) @@ -1363,7 +1363,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) - sph.EXPECT().ECNMode() + sph.EXPECT().ECNMode(gomock.Any()) sph.EXPECT().QueueProbePacket(encLevel).Return(false) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) @@ -1407,7 +1407,7 @@ var _ = Describe("Connection", func() { AfterEach(func() { // make the go routine return - sph.EXPECT().ECNMode().MaxTimes(1) + sph.EXPECT().ECNMode(gomock.Any()).MaxTimes(1) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -1422,7 +1422,7 @@ var _ = Describe("Connection", func() { It("sends multiple packets one by one immediately", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) - sph.EXPECT().ECNMode().Times(2) + sph.EXPECT().ECNMode(gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) @@ -1447,7 +1447,7 @@ var _ = Describe("Connection", func() { It("sends multiple packets one by one immediately, with GSO", func() { enableGSO() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) - sph.EXPECT().ECNMode().Times(3) + sph.EXPECT().ECNMode(true).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) @@ -1474,7 +1474,7 @@ var _ = Describe("Connection", func() { enableGSO() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) - sph.EXPECT().ECNMode().Times(2) + sph.EXPECT().ECNMode(true).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) rand.Read(payload1) @@ -1499,7 +1499,7 @@ var _ = Describe("Connection", func() { It("sends multiple packets, when the pacer allows immediate sending", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) - sph.EXPECT().ECNMode().Times(2) + sph.EXPECT().ECNMode(gomock.Any()).Times(2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() @@ -1518,7 +1518,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) - sph.EXPECT().ECNMode() + sph.EXPECT().ECNMode(gomock.Any()) packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{PacketNumber: 123}, getPacketBuffer(), nil) sender.EXPECT().WouldBlock().AnyTimes() @@ -1539,7 +1539,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) - sph.EXPECT().ECNMode().Times(2) + sph.EXPECT().ECNMode(gomock.Any()).Times(2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) @@ -1557,13 +1557,13 @@ var _ = Describe("Connection", func() { pacingDelay := scaleDuration(100 * time.Millisecond) gomock.InOrder( sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), - sph.EXPECT().ECNMode(), + sph.EXPECT().ECNMode(gomock.Any()), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")), sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), - sph.EXPECT().ECNMode(), + sph.EXPECT().ECNMode(gomock.Any()), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 101}, []byte("packet101")), sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), @@ -1587,7 +1587,7 @@ var _ = Describe("Connection", func() { It("sends multiple packets at once", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) - sph.EXPECT().ECNMode().Times(3) + sph.EXPECT().ECNMode(gomock.Any()).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) for pn := protocol.PacketNumber(1000); pn < 1003; pn++ { @@ -1628,7 +1628,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) @@ -1653,7 +1653,7 @@ var _ = Describe("Connection", func() { conn.handlePacket(receivedPacket{buffer: getPacketBuffer()}) }) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) }) @@ -1667,7 +1667,7 @@ var _ = Describe("Connection", func() { It("stops sending when the send queue is full", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) - sph.EXPECT().ECNMode() + sph.EXPECT().ECNMode(gomock.Any()) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() @@ -1688,7 +1688,7 @@ var _ = Describe("Connection", func() { // now make room in the send queue sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1001}, []byte("packet1001")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) @@ -1703,7 +1703,7 @@ var _ = Describe("Connection", func() { It("doesn't set a pacing timer when there is no data to send", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) // don't EXPECT any calls to mconn.Write() @@ -1723,7 +1723,7 @@ var _ = Describe("Connection", func() { conn.config.DisablePathMTUDiscovery = false sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) - sph.EXPECT().ECNMode() + sph.EXPECT().ECNMode(true) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock().AnyTimes() @@ -1774,7 +1774,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph @@ -1803,7 +1803,7 @@ var _ = Describe("Connection", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(1234), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) @@ -1855,7 +1855,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes() + sph.EXPECT().ECNMode(false).Return(protocol.ECT1).AnyTimes() sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes() gomock.InOrder( sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(13), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionInitial, protocol.ECT1, protocol.ByteCount(123), gomock.Any()), @@ -1968,7 +1968,7 @@ var _ = Describe("Connection", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SetHandshakeConfirmed() diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index ba224228..ba8cbbda 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -29,8 +29,7 @@ type SentPacketHandler interface { // only to be called once the handshake is complete QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ - ECNMode() protocol.ECN - + ECNMode(isShortHeaderPacket bool) protocol.ECN // isShortHeaderPacket should only be true for non-coalesced 1-RTT packets PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index da5d0718..53aba0dd 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -717,10 +717,13 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { return h.alarm } -func (h *sentPacketHandler) ECNMode() protocol.ECN { +func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN { if !h.enableECN { return protocol.ECNUnsupported } + if !isShortHeaderPacket { + return protocol.ECNNon + } // TODO: implement ECN logic return protocol.ECNNon } diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 3479d9e6..137ac2ef 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -50,17 +50,17 @@ func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomo } // ECNMode mocks base method. -func (m *MockSentPacketHandler) ECNMode() protocol.ECN { +func (m *MockSentPacketHandler) ECNMode(arg0 bool) protocol.ECN { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ECNMode") + ret := m.ctrl.Call(m, "ECNMode", arg0) ret0, _ := ret[0].(protocol.ECN) return ret0 } // ECNMode indicates an expected call of ECNMode. -func (mr *MockSentPacketHandlerMockRecorder) ECNMode() *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ECNMode(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode), arg0) } // GetLossDetectionTimeout mocks base method. diff --git a/packet_packer.go b/packet_packer.go index 0483f35b..64081c68 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -71,6 +71,11 @@ type coalescedPacket struct { shortHdrPacket *shortHeaderPacket } +// IsOnlyShortHeaderPacket says if this packet only contains a short header packet (and no long header packets). +func (p *coalescedPacket) IsOnlyShortHeaderPacket() bool { + return len(p.longHdrPackets) == 0 && p.shortHdrPacket != nil +} + func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel { //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). switch p.header.Type { diff --git a/packet_packer_test.go b/packet_packer_test.go index 6dba31f9..e68906f3 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -655,6 +655,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) Expect(packet.longHdrPackets).To(HaveLen(1)) + Expect(packet.IsOnlyShortHeaderPacket()).To(BeFalse()) // cut off the tag that the mock sealer added // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] hdr, _, _, err := wire.ParsePacket(packet.buffer.Data) @@ -874,6 +875,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeFalse()) parsePacket(p.buffer.Data) }) @@ -1047,6 +1049,7 @@ var _ = Describe("Packet packer", func() { packer.retransmissionQueue.addAppData(&wire.PingFrame{}) p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeFalse()) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -1422,6 +1425,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeTrue()) Expect(p.longHdrPackets).To(BeEmpty()) Expect(p.shortHdrPacket).ToNot(BeNil()) packet := p.shortHdrPacket @@ -1448,6 +1452,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) + Expect(p.IsOnlyShortHeaderPacket()).To(BeTrue()) Expect(p.longHdrPackets).To(BeEmpty()) Expect(p.shortHdrPacket).ToNot(BeNil()) packet := p.shortHdrPacket