fix decoding of packet numbers in different packet number spaces

When QUIC abandoned the "stream 0" design, it introduced separate
packet number spaces for packets with different encryption levels.
Packet number compression now also works per packet number space.
The current code doesn't lead to any problems if the peer starts sending
with packet number 0, as we only exchange a few packets in the Initial
and the Handshake packet number space and there's nothing to compress.
It might lead to problems if the peer starts with a large packet number
in one space (which is allowed by the spec), and then starts with a
small packet number in another packet number space.
This commit is contained in:
Marten Seemann 2020-11-23 20:46:28 +07:00
parent ed1956f5a9
commit 95334202c5
9 changed files with 125 additions and 57 deletions

View file

@ -6,6 +6,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD { func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD {
@ -50,6 +51,7 @@ func (s *longHeaderSealer) Overhead() int {
type longHeaderOpener struct { type longHeaderOpener struct {
aead cipher.AEAD aead cipher.AEAD
headerProtector headerProtector headerProtector headerProtector
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
// use a single slice to avoid allocations // use a single slice to avoid allocations
nonceBuf []byte nonceBuf []byte
@ -65,12 +67,18 @@ func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) Long
} }
} }
func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN)
}
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13. // The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV. // It uses the nonce provided here and XOR it with the IV.
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad) dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
if err != nil { if err == nil {
o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn)
} else {
err = ErrDecryptionFailed err = ErrDecryptionFailed
} }
return dec, err return dec, err

View file

@ -58,6 +58,22 @@ var _ = Describe("Long Header AEAD", func() {
_, err := opener.Open(nil, encrypted, 0x42, ad) _, err := opener.Open(nil, encrypted, 0x42, ad)
Expect(err).To(MatchError(ErrDecryptionFailed)) Expect(err).To(MatchError(ErrDecryptionFailed))
}) })
It("decodes the packet number", func() {
sealer, opener := getSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted, 0x1337, ad)
Expect(err).ToNot(HaveOccurred())
Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338))
})
It("ignores packets it can't decrypt for packet number derivation", func() {
sealer, opener := getSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad)
Expect(err).To(HaveOccurred())
Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38))
})
}) })
Context("header encryption", func() { Context("header encryption", func() {

View file

@ -33,12 +33,14 @@ type headerDecryptor interface {
// LongHeaderOpener opens a long header packet // LongHeaderOpener opens a long header packet
type LongHeaderOpener interface { type LongHeaderOpener interface {
headerDecryptor headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error) Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
} }
// ShortHeaderOpener opens a short header packet // ShortHeaderOpener opens a short header packet
type ShortHeaderOpener interface { type ShortHeaderOpener interface {
headerDecryptor headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
} }

View file

@ -37,6 +37,7 @@ type updatableAEAD struct {
firstRcvdWithCurrentKey protocol.PacketNumber firstRcvdWithCurrentKey protocol.PacketNumber
firstSentWithCurrentKey protocol.PacketNumber firstSentWithCurrentKey protocol.PacketNumber
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
numRcvdWithCurrentKey uint64 numRcvdWithCurrentKey uint64
numSentWithCurrentKey uint64 numSentWithCurrentKey uint64
rcvAEAD cipher.AEAD rcvAEAD cipher.AEAD
@ -153,6 +154,10 @@ func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSu
} }
} }
func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
}
func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
dec, err := a.open(dst, src, rcvTime, pn, kp, ad) dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
if err == ErrDecryptionFailed { if err == ErrDecryptionFailed {
@ -161,6 +166,9 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
return nil, qerr.AEADLimitReached return nil, qerr.AEADLimitReached
} }
} }
if err == nil {
a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn)
}
return dec, err return dec, err
} }

View file

@ -117,6 +117,20 @@ var _ = Describe("Updatable AEAD", func() {
Expect(err).To(MatchError(ErrDecryptionFailed)) Expect(err).To(MatchError(ErrDecryptionFailed))
}) })
It("decodes the packet number", func() {
encrypted := server.Seal(nil, msg, 0x1337, ad)
_, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338))
})
It("ignores packets it can't decrypt for packet number derivation", func() {
encrypted := server.Seal(nil, msg, 0x1337, ad)
_, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).To(HaveOccurred())
Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38))
})
It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() {
client.invalidPacketLimit = 10 client.invalidPacketLimit = 10
for i := 0; i < 9; i++ { for i := 0; i < 9; i++ {

View file

@ -34,6 +34,20 @@ func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder {
return m.recorder return m.recorder
} }
// DecodePacketNumber mocks base method
func (m *MockLongHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1)
ret0, _ := ret[0].(protocol.PacketNumber)
return ret0
}
// DecodePacketNumber indicates an expected call of DecodePacketNumber
func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), arg0, arg1)
}
// DecryptHeader mocks base method // DecryptHeader mocks base method
func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -35,6 +35,20 @@ func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder {
return m.recorder return m.recorder
} }
// DecodePacketNumber mocks base method
func (m *MockShortHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) protocol.PacketNumber {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DecodePacketNumber", arg0, arg1)
ret0, _ := ret[0].(protocol.PacketNumber)
return ret0
}
// DecodePacketNumber indicates an expected call of DecodePacketNumber
func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), arg0, arg1)
}
// DecryptHeader mocks base method // DecryptHeader mocks base method
func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -7,7 +7,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
) )
@ -44,8 +43,6 @@ type unpackedPacket struct {
type packetUnpacker struct { type packetUnpacker struct {
cs handshake.CryptoSetup cs handshake.CryptoSetup
largestRcvdPacketNumber protocol.PacketNumber
version protocol.VersionNumber version protocol.VersionNumber
} }
@ -112,9 +109,6 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte
} }
} }
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, extHdr.PacketNumber)
return &unpackedPacket{ return &unpackedPacket{
hdr: extHdr, hdr: extHdr,
packetNumber: extHdr.PacketNumber, packetNumber: extHdr.PacketNumber,
@ -132,6 +126,7 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene
return nil, nil, parseErr return nil, nil, parseErr
} }
extHdrLen := extHdr.ParsedLen() extHdrLen := extHdr.ParsedLen()
extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen]) decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -155,6 +150,7 @@ func (u *packetUnpacker) unpackShortHeaderPacket(
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, nil, parseErr return nil, nil, parseErr
} }
extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
extHdrLen := extHdr.ParsedLen() extHdrLen := extHdr.ParsedLen()
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen])
if err != nil { if err != nil {
@ -172,11 +168,6 @@ func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data
if err != nil && err != wire.ErrInvalidReservedBits { if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err} return nil, &headerParseError{err: err}
} }
extHdr.PacketNumber = protocol.DecodePacketNumber(
extHdr.PacketNumberLen,
u.largestRcvdPacketNumber,
extHdr.PacketNumber,
)
return extHdr, err return extHdr, err
} }

View file

@ -5,14 +5,14 @@ import (
"errors" "errors"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -97,9 +97,12 @@ var _ = Describe("Packet Unpacker", func() {
} }
hdr, hdrRaw := getHeader(extHdr) hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl) opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetInitialOpener().Return(opener, nil) gomock.InOrder(
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) cs.EXPECT().GetInitialOpener().Return(opener, nil),
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
@ -115,20 +118,45 @@ var _ = Describe("Packet Unpacker", func() {
DestConnectionID: connID, DestConnectionID: connID,
Version: version, Version: version,
}, },
PacketNumber: 2, PacketNumber: 20,
PacketNumberLen: 3, PacketNumberLen: 2,
} }
hdr, hdrRaw := getHeader(extHdr) hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl) opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().Get0RTTOpener().Return(opener, nil) gomock.InOrder(
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) cs.EXPECT().Get0RTTOpener().Return(opener, nil),
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT)) Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT))
Expect(packet.data).To(Equal([]byte("decrypted"))) Expect(packet.data).To(Equal([]byte("decrypted")))
}) })
It("opens short header packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
KeyPhase: protocol.KeyPhaseOne,
PacketNumber: 99,
PacketNumberLen: protocol.PacketNumberLen4,
}
hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
now := time.Now()
gomock.InOrder(
cs.EXPECT().Get1RTTOpener().Return(opener, nil),
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.Unpack(hdr, now, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption1RTT))
Expect(packet.data).To(Equal([]byte("decrypted")))
})
It("returns the error when getting the sealer fails", func() { It("returns the error when getting the sealer fails", func() {
extHdr := &wire.ExtendedHeader{ extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID}, Header: wire.Header{DestConnectionID: connID},
@ -157,6 +185,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockLongHeaderOpener(mockCtrl) opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, qerr.CryptoBufferExceeded) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, qerr.CryptoBufferExceeded)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(qerr.CryptoBufferExceeded)) Expect(err).To(MatchError(qerr.CryptoBufferExceeded))
@ -178,6 +207,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockLongHeaderOpener(mockCtrl) opener := mocks.NewMockLongHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
@ -194,6 +224,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
@ -210,6 +241,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
@ -247,7 +279,8 @@ var _ = Describe("Packet Unpacker", func() {
pnBytes[i] ^= 0xff // invert the packet number bytes pnBytes[i] ^= 0xff // invert the packet number bytes
} }
}), }),
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1337), origHdrRaw).Return([]byte{0}, nil), opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2).Return(protocol.PacketNumber(0x7331)),
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x7331), origHdrRaw).Return([]byte{0}, nil),
) )
data := hdrRaw data := hdrRaw
for i := 1; i <= 100; i++ { for i := 1; i <= 100; i++ {
@ -255,38 +288,6 @@ var _ = Describe("Packet Unpacker", func() {
} }
packet, err := unpacker.Unpack(hdr, time.Now(), data) packet, err := unpacker.Unpack(hdr, time.Now(), data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x7331)))
})
It("decodes the packet number", func() {
rcvTime := time.Now().Add(-time.Hour)
firstHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337,
PacketNumberLen: 2,
KeyPhase: protocol.KeyPhaseOne,
}
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw := getHeader(firstHdr)
packet, err := unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
// the real packet number is 0x1338, but only the last byte is sent
secondHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x38,
PacketNumberLen: 1,
KeyPhase: protocol.KeyPhaseZero,
}
// expect the call with the decoded packet number
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw = getHeader(secondHdr)
packet, err = unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338)))
}) })
}) })