diff --git a/packet_unpacker.go b/packet_unpacker.go index 93fb206a..6a9d5fbf 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -11,6 +11,7 @@ import ( ) type unpackedPacket struct { + packetNumber protocol.PacketNumber encryptionLevel protocol.EncryptionLevel frames []wire.Frame } @@ -40,7 +41,7 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker { } func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) { - hdr.PacketNumber = protocol.DecodePacketNumber( + pn := protocol.DecodePacketNumber( hdr.PacketNumberLen, u.largestRcvdPacketNumber, hdr.PacketNumber, @@ -55,16 +56,16 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, d var err error switch hdr.Type { case protocol.PacketTypeInitial: - decrypted, err = u.aead.OpenInitial(buf, data, hdr.PacketNumber, headerBinary) + decrypted, err = u.aead.OpenInitial(buf, data, pn, headerBinary) encryptionLevel = protocol.EncryptionInitial case protocol.PacketTypeHandshake: - decrypted, err = u.aead.OpenHandshake(buf, data, hdr.PacketNumber, headerBinary) + decrypted, err = u.aead.OpenHandshake(buf, data, pn, headerBinary) encryptionLevel = protocol.EncryptionHandshake default: if hdr.IsLongHeader { return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) } - decrypted, err = u.aead.Open1RTT(buf, data, hdr.PacketNumber, headerBinary) + decrypted, err = u.aead.Open1RTT(buf, data, pn, headerBinary) encryptionLevel = protocol.Encryption1RTT } if err != nil { @@ -72,7 +73,7 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, d } // Only do this after decrypting, so we are sure the packet is not attacker-controlled - u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, hdr.PacketNumber) + u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, pn) fs, err := u.parseFrames(decrypted) if err != nil { @@ -80,6 +81,7 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, d } return &unpackedPacket{ + packetNumber: pn, encryptionLevel: encryptionLevel, frames: fs, }, nil diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 69a4f9d5..a69f6ecd 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -69,8 +69,9 @@ var _ = Describe("Packet Unpacker", func() { PacketNumberLen: 2, } aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil) - _, err := unpacker.Unpack(firstHdr.Raw, firstHdr, nil) + packet, err := unpacker.Unpack(firstHdr.Raw, firstHdr, nil) Expect(err).ToNot(HaveOccurred()) + Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337))) // the real packet number is 0x1338, but only the last byte is sent secondHdr := &wire.ExtendedHeader{ PacketNumber: 0x38, @@ -78,8 +79,9 @@ var _ = Describe("Packet Unpacker", func() { } // expect the call with the decoded packet number aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil) - _, err = unpacker.Unpack(secondHdr.Raw, secondHdr, nil) + packet, err = unpacker.Unpack(secondHdr.Raw, secondHdr, nil) Expect(err).ToNot(HaveOccurred()) + Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338))) }) It("unpacks the frames", func() { diff --git a/session.go b/session.go index 7d15b4e5..76ffe799 100644 --- a/session.go +++ b/session.go @@ -504,19 +504,16 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { } packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) - if s.logger.Debug() { - if err != nil { - s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data), hdr.DestConnectionID) - } else { - s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s, %s", hdr.PacketNumber, len(p.data), hdr.DestConnectionID, packet.encryptionLevel) - } - hdr.Log(s.logger) - } // if the decryption failed, this might be a packet sent by an attacker if err != nil { return err } + if s.logger.Debug() { + s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), p.hdr.DestConnectionID, packet.encryptionLevel) + hdr.Log(s.logger) + } + // The server can change the source connection ID with the first Handshake packet. if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", hdr.SrcConnectionID) @@ -541,12 +538,12 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { // The session will be closed and recreated as soon as the crypto setup processed the HRR. if hdr.Type != protocol.PacketTypeRetry { isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames) - if err := s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, p.rcvTime, isRetransmittable); err != nil { + if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, p.rcvTime, isRetransmittable); err != nil { return err } } - return s.handleFrames(packet.frames, hdr.PacketNumber, packet.encryptionLevel) + return s.handleFrames(packet.frames, packet.packetNumber, packet.encryptionLevel) } func (s *session) handleFrames(fs []wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error { diff --git a/session_test.go b/session_test.go index d188b7f5..239afe31 100644 --- a/session_test.go +++ b/session_test.go @@ -464,13 +464,13 @@ var _ = Describe("Session", func() { It("informs the ReceivedPacketHandler", func() { hdr := &wire.ExtendedHeader{ Raw: []byte("raw header"), - PacketNumber: 5, - PacketNumberLen: protocol.PacketNumberLen4, + PacketNumber: 0x37, + PacketNumberLen: protocol.PacketNumberLen1, } rcvTime := time.Now().Add(-10 * time.Second) - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{packetNumber: 0x1337}, nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().ReceivedPacket(protocol.PacketNumber(5), rcvTime, false) + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false) sess.receivedPacketHandler = rph Expect(sess.handlePacketImpl(&receivedPacket{ rcvTime: rcvTime,