diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index bdb6bb98..b3b73951 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -34,7 +34,7 @@ const ( type receivedPacketHandler struct { initialPackets *receivedPacketTracker handshakePackets *receivedPacketTracker - oneRTTPackets *receivedPacketTracker + appDataPackets *receivedPacketTracker } var _ ReceivedPacketHandler = &receivedPacketHandler{} @@ -48,7 +48,7 @@ func NewReceivedPacketHandler( return &receivedPacketHandler{ initialPackets: newReceivedPacketTracker(rttStats, logger, version), handshakePackets: newReceivedPacketTracker(rttStats, logger, version), - oneRTTPackets: newReceivedPacketTracker(rttStats, logger, version), + appDataPackets: newReceivedPacketTracker(rttStats, logger, version), } } @@ -63,8 +63,9 @@ func (h *receivedPacketHandler) ReceivedPacket( h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) case protocol.EncryptionHandshake: h.handshakePackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) - case protocol.Encryption1RTT: - h.oneRTTPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) + case protocol.Encryption0RTT, protocol.Encryption1RTT: + // TODO: implement a check that the client doesn't switch back to 0-RTT after sending a 1-RTT packet + h.appDataPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) default: panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel)) } @@ -72,7 +73,7 @@ func (h *receivedPacketHandler) ReceivedPacket( // only to be used with 1-RTT packets func (h *receivedPacketHandler) IgnoreBelow(pn protocol.PacketNumber) { - h.oneRTTPackets.IgnoreBelow(pn) + h.appDataPackets.IgnoreBelow(pn) } func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { @@ -94,7 +95,7 @@ func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { if h.handshakePackets != nil { handshakeAlarm = h.handshakePackets.GetAlarmTimeout() } - oneRTTAlarm := h.oneRTTPackets.GetAlarmTimeout() + oneRTTAlarm := h.appDataPackets.GetAlarmTimeout() return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) } @@ -110,7 +111,8 @@ func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel) * ack = h.handshakePackets.GetAckFrame() } case protocol.Encryption1RTT: - return h.oneRTTPackets.GetAckFrame() + // 0-RTT packets can't contain ACK frames + return h.appDataPackets.GetAckFrame() default: return nil } diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index a057db2b..65edc26b 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -48,6 +48,16 @@ var _ = Describe("Received Packet Handler", func() { Expect(oneRTTAck.DelayTime).To(BeNumerically("~", time.Second, 50*time.Millisecond)) }) + It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { + sendTime := time.Now().Add(-time.Second) + handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true) + handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true) + ack := handler.GetAckFrame(protocol.Encryption1RTT) + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(HaveLen(1)) + Expect(ack.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 2, Largest: 3})) + }) + It("drops Initial packets", func() { sendTime := time.Now().Add(-time.Second) handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)