From 40a993e31ceb01187cce8e192c98280c0a77fc6f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 10 Aug 2019 21:40:10 +0700 Subject: [PATCH] check that the client doesn't switch back to 0-RTT after sending 1-RTT --- internal/ackhandler/interfaces.go | 2 +- .../ackhandler/received_packet_handler.go | 17 +++++++-- .../received_packet_handler_test.go | 38 +++++++++++++------ .../ackhandler/received_packet_handler.go | 6 ++- session.go | 3 +- 5 files changed, 46 insertions(+), 20 deletions(-) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 02979faa..c57099bc 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -57,7 +57,7 @@ type SentPacketHandler interface { // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { - ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) + ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error IgnoreBelow(protocol.PacketNumber) DropPackets(protocol.EncryptionLevel) diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index b3b73951..bc5680de 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -35,6 +35,8 @@ type receivedPacketHandler struct { initialPackets *receivedPacketTracker handshakePackets *receivedPacketTracker appDataPackets *receivedPacketTracker + + lowest1RTTPacket protocol.PacketNumber } var _ ReceivedPacketHandler = &receivedPacketHandler{} @@ -49,6 +51,7 @@ func NewReceivedPacketHandler( initialPackets: newReceivedPacketTracker(rttStats, logger, version), handshakePackets: newReceivedPacketTracker(rttStats, logger, version), appDataPackets: newReceivedPacketTracker(rttStats, logger, version), + lowest1RTTPacket: protocol.InvalidPacketNumber, } } @@ -57,18 +60,26 @@ func (h *receivedPacketHandler) ReceivedPacket( encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool, -) { +) error { switch encLevel { case protocol.EncryptionInitial: h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) case protocol.EncryptionHandshake: h.handshakePackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) - case protocol.Encryption0RTT, protocol.Encryption1RTT: - // TODO: implement a check that the client doesn't switch back to 0-RTT after sending a 1-RTT packet + case protocol.Encryption0RTT: + if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { + return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) + } + h.appDataPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) + case protocol.Encryption1RTT: + if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { + h.lowest1RTTPacket = pn + } h.appDataPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) default: panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel)) } + return nil } // only to be used with 1-RTT packets diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index 65edc26b..7cd8b293 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -25,12 +25,12 @@ var _ = Describe("Received Packet Handler", func() { It("generates ACKs for different packet number spaces", func() { sendTime := time.Now().Add(-time.Second) - handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true) - handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true) - handler.ReceivedPacket(5, protocol.Encryption1RTT, sendTime, true) - handler.ReceivedPacket(3, protocol.EncryptionInitial, sendTime, true) - handler.ReceivedPacket(2, protocol.EncryptionHandshake, sendTime, true) - handler.ReceivedPacket(4, protocol.Encryption1RTT, sendTime, true) + Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(5, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(4, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) initialAck := handler.GetAckFrame(protocol.EncryptionInitial) Expect(initialAck).ToNot(BeNil()) Expect(initialAck.AckRanges).To(HaveLen(1)) @@ -50,18 +50,32 @@ var _ = Describe("Received Packet Handler", func() { It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { sendTime := time.Now().Add(-time.Second) - handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true) - handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true) + Expect(handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) ack := handler.GetAckFrame(protocol.Encryption1RTT) Expect(ack).ToNot(BeNil()) Expect(ack.AckRanges).To(HaveLen(1)) Expect(ack.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 2, Largest: 3})) }) + It("rejects 0-RTT packets with higher packet numbers than 1-RTT packets", func() { + sendTime := time.Now() + Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(11, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(12, protocol.Encryption0RTT, sendTime, true)).To(MatchError("received packet number 12 on a 0-RTT packet after receiving 11 on a 1-RTT packet")) + }) + + It("allows reordered 0-RTT packets", func() { + sendTime := time.Now() + Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(12, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(11, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + }) + It("drops Initial packets", func() { sendTime := time.Now().Add(-time.Second) - handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true) - handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true) + Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.GetAckFrame(protocol.EncryptionInitial)).ToNot(BeNil()) handler.DropPackets(protocol.EncryptionInitial) Expect(handler.GetAckFrame(protocol.EncryptionInitial)).To(BeNil()) @@ -70,8 +84,8 @@ var _ = Describe("Received Packet Handler", func() { It("drops Handshake packets", func() { sendTime := time.Now().Add(-time.Second) - handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true) - handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true) + Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil()) handler.DropPackets(protocol.EncryptionInitial) Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).To(BeNil()) diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index 5d8f6169..640b725a 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -89,9 +89,11 @@ func (mr *MockReceivedPacketHandlerMockRecorder) IgnoreBelow(arg0 interface{}) * } // ReceivedPacket mocks base method -func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel, arg2 time.Time, arg3 bool) { +func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel, arg2 time.Time, arg3 bool) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 } // ReceivedPacket indicates an expected call of ReceivedPacket diff --git a/session.go b/session.go index 0c40f1cf..0765424f 100644 --- a/session.go +++ b/session.go @@ -821,8 +821,7 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time }) } - s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, packet.encryptionLevel, rcvTime, isAckEliciting) - return nil + return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, packet.encryptionLevel, rcvTime, isAckEliciting) } func (s *session) handleFrame(f wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {