diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index 13e18d26..54eda9b7 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -6,6 +6,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qtls" + "github.com/lucas-clemente/quic-go/internal/utils" ) func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD { @@ -50,6 +51,7 @@ func (s *longHeaderSealer) Overhead() int { type longHeaderOpener struct { aead cipher.AEAD headerProtector headerProtector + highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) // use a single slice to avoid allocations nonceBuf []byte @@ -65,12 +67,18 @@ func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) Long } } +func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { + return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN) +} + func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) // The AEAD we're using here will be the qtls.aeadAESGCM13. // It uses the nonce provided here and XOR it with the IV. dec, err := o.aead.Open(dst, o.nonceBuf, src, ad) - if err != nil { + if err == nil { + o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn) + } else { err = ErrDecryptionFailed } return dec, err diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index aca94d6a..20110737 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -58,6 +58,22 @@ var _ = Describe("Long Header AEAD", func() { _, err := opener.Open(nil, encrypted, 0x42, ad) Expect(err).To(MatchError(ErrDecryptionFailed)) }) + + It("decodes the packet number", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) + }) + + It("ignores packets it can't decrypt for packet number derivation", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad) + Expect(err).To(HaveOccurred()) + Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) + }) }) Context("header encryption", func() { diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 90b7238a..112f6c25 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -33,12 +33,14 @@ type headerDecryptor interface { // LongHeaderOpener opens a long header packet type LongHeaderOpener interface { headerDecryptor + DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error) } // ShortHeaderOpener opens a short header packet type ShortHeaderOpener interface { headerDecryptor + DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) } diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index fb3ae9a3..f28f3e03 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -37,6 +37,7 @@ type updatableAEAD struct { firstRcvdWithCurrentKey protocol.PacketNumber firstSentWithCurrentKey protocol.PacketNumber + highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) numRcvdWithCurrentKey uint64 numSentWithCurrentKey uint64 rcvAEAD cipher.AEAD @@ -153,6 +154,10 @@ func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSu } } +func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { + return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN) +} + func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { dec, err := a.open(dst, src, rcvTime, pn, kp, ad) if err == ErrDecryptionFailed { @@ -161,6 +166,9 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac return nil, qerr.AEADLimitReached } } + if err == nil { + a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn) + } return dec, err } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 0c115cdc..8726c3b0 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -117,6 +117,20 @@ var _ = Describe("Updatable AEAD", func() { Expect(err).To(MatchError(ErrDecryptionFailed)) }) + It("decodes the packet number", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) + }) + + It("ignores packets it can't decrypt for packet number derivation", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(HaveOccurred()) + Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) + }) + It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { client.invalidPacketLimit = 10 for i := 0; i < 9; i++ { diff --git a/internal/mocks/long_header_opener.go b/internal/mocks/long_header_opener.go index d99bd827..c5610adf 100644 --- a/internal/mocks/long_header_opener.go +++ b/internal/mocks/long_header_opener.go @@ -34,6 +34,20 @@ func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder { return m.recorder } +// DecodePacketNumber mocks base method +func (m *MockLongHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1) + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// DecodePacketNumber indicates an expected call of DecodePacketNumber +func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) +} + // DecryptHeader mocks base method func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { m.ctrl.T.Helper() diff --git a/internal/mocks/short_header_opener.go b/internal/mocks/short_header_opener.go index ab6ad374..fe81572e 100644 --- a/internal/mocks/short_header_opener.go +++ b/internal/mocks/short_header_opener.go @@ -35,6 +35,20 @@ func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder { return m.recorder } +// DecodePacketNumber mocks base method +func (m *MockShortHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1) + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// DecodePacketNumber indicates an expected call of DecodePacketNumber +func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) +} + // DecryptHeader mocks base method func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { m.ctrl.T.Helper() diff --git a/packet_unpacker.go b/packet_unpacker.go index 6e98f961..f78c6a63 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -7,7 +7,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -43,8 +42,6 @@ type unpackedPacket struct { type packetUnpacker struct { cs handshake.CryptoSetup - largestRcvdPacketNumber protocol.PacketNumber - version protocol.VersionNumber } @@ -111,9 +108,6 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte } } - // Only do this after decrypting, so we are sure the packet is not attacker-controlled - u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, extHdr.PacketNumber) - return &unpackedPacket{ hdr: extHdr, packetNumber: extHdr.PacketNumber, @@ -131,6 +125,7 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene return nil, nil, parseErr } extHdrLen := extHdr.ParsedLen() + extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen]) if err != nil { return nil, nil, err @@ -154,6 +149,7 @@ func (u *packetUnpacker) unpackShortHeaderPacket( if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { return nil, nil, parseErr } + extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) extHdrLen := extHdr.ParsedLen() decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) if err != nil { @@ -171,11 +167,6 @@ func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data if err != nil && err != wire.ErrInvalidReservedBits { return nil, &headerParseError{err: err} } - extHdr.PacketNumber = protocol.DecodePacketNumber( - extHdr.PacketNumberLen, - u.largestRcvdPacketNumber, - extHdr.PacketNumber, - ) return extHdr, err } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index d68b0da2..51eabf85 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -5,14 +5,14 @@ import ( "errors" "time" - "github.com/lucas-clemente/quic-go/internal/qerr" - - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -97,9 +97,12 @@ var _ = Describe("Packet Unpacker", func() { } hdr, hdrRaw := getHeader(extHdr) opener := mocks.NewMockLongHeaderOpener(mockCtrl) - cs.EXPECT().GetInitialOpener().Return(opener, nil) - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil) + gomock.InOrder( + cs.EXPECT().GetInitialOpener().Return(opener, nil), + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), + opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)), + opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil), + ) packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) Expect(err).ToNot(HaveOccurred()) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) @@ -115,20 +118,45 @@ var _ = Describe("Packet Unpacker", func() { DestConnectionID: connID, Version: version, }, - PacketNumber: 2, - PacketNumberLen: 3, + PacketNumber: 20, + PacketNumberLen: 2, } hdr, hdrRaw := getHeader(extHdr) opener := mocks.NewMockLongHeaderOpener(mockCtrl) - cs.EXPECT().Get0RTTOpener().Return(opener, nil) - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil) + gomock.InOrder( + cs.EXPECT().Get0RTTOpener().Return(opener, nil), + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), + opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)), + opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil), + ) packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) Expect(err).ToNot(HaveOccurred()) Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT)) Expect(packet.data).To(Equal([]byte("decrypted"))) }) + It("opens short header packets", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, + KeyPhase: protocol.KeyPhaseOne, + PacketNumber: 99, + PacketNumberLen: protocol.PacketNumberLen4, + } + hdr, hdrRaw := getHeader(extHdr) + opener := mocks.NewMockShortHeaderOpener(mockCtrl) + now := time.Now() + gomock.InOrder( + cs.EXPECT().Get1RTTOpener().Return(opener, nil), + opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), + opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)), + opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil), + ) + packet, err := unpacker.Unpack(hdr, now, append(hdrRaw, payload...)) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.encryptionLevel).To(Equal(protocol.Encryption1RTT)) + Expect(packet.data).To(Equal([]byte("decrypted"))) + }) + It("returns the error when getting the sealer fails", func() { extHdr := &wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: connID}, @@ -157,6 +185,7 @@ var _ = Describe("Packet Unpacker", func() { opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) + opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, qerr.CryptoBufferExceeded) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(qerr.CryptoBufferExceeded)) @@ -178,6 +207,7 @@ var _ = Describe("Packet Unpacker", func() { opener := mocks.NewMockLongHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) + opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) @@ -194,6 +224,7 @@ var _ = Describe("Packet Unpacker", func() { opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) cs.EXPECT().Get1RTTOpener().Return(opener, nil) + opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) @@ -210,6 +241,7 @@ var _ = Describe("Packet Unpacker", func() { opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) cs.EXPECT().Get1RTTOpener().Return(opener, nil) + opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) @@ -247,7 +279,8 @@ var _ = Describe("Packet Unpacker", func() { pnBytes[i] ^= 0xff // invert the packet number bytes } }), - opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1337), origHdrRaw).Return([]byte{0}, nil), + opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2).Return(protocol.PacketNumber(0x7331)), + opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x7331), origHdrRaw).Return([]byte{0}, nil), ) data := hdrRaw for i := 1; i <= 100; i++ { @@ -255,38 +288,6 @@ var _ = Describe("Packet Unpacker", func() { } packet, err := unpacker.Unpack(hdr, time.Now(), data) Expect(err).ToNot(HaveOccurred()) - Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337))) - }) - - It("decodes the packet number", func() { - rcvTime := time.Now().Add(-time.Hour) - firstHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 0x1337, - PacketNumberLen: 2, - KeyPhase: protocol.KeyPhaseOne, - } - opener := mocks.NewMockShortHeaderOpener(mockCtrl) - cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2) - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil) - hdr, hdrRaw := getHeader(firstHdr) - packet, err := unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...)) - Expect(err).ToNot(HaveOccurred()) - Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337))) - // the real packet number is 0x1338, but only the last byte is sent - secondHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 0x38, - PacketNumberLen: 1, - KeyPhase: protocol.KeyPhaseZero, - } - // expect the call with the decoded packet number - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil) - hdr, hdrRaw = getHeader(secondHdr) - packet, err = unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...)) - Expect(err).ToNot(HaveOccurred()) - Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338))) + Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x7331))) }) })