refactor the packet unpacker to remove version param from constructor

This commit is contained in:
Marten Seemann 2022-12-24 13:00:21 +13:00
parent d4a0793b6a
commit 56985cbfd1
5 changed files with 49 additions and 53 deletions

View file

@ -25,7 +25,7 @@ import (
) )
type unpacker interface { type unpacker interface {
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
} }
@ -366,7 +366,7 @@ var newConnection = func(
s.perspective, s.perspective,
s.version, s.version,
) )
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
return s return s
} }
@ -476,7 +476,7 @@ var newClientConnection = func(
s.clientHelloWritten = clientHelloWritten s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newPacketPacker( s.packer = newPacketPacker(
srcConnID, srcConnID,
s.connIDManager.Get, s.connIDManager.Get,
@ -1009,7 +1009,7 @@ func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header)
return false return false
} }
packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data) packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data, s.version)
if err != nil { if err != nil {
wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr)) wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
return false return false

View file

@ -755,7 +755,7 @@ var _ = Describe("Connection", func() {
packet := getLongHeaderPacket(hdr, nil) packet := getLongHeaderPacket(hdr, nil)
packet.ecn = protocol.ECNCE packet.ecn = protocol.ECNCE
rcvTime := time.Now().Add(-10 * time.Second) rcvTime := time.Now().Add(-10 * time.Second)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().UnpackLongHeader(gomock.Any(), rcvTime, gomock.Any(), conn.version).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial, encryptionLevel: protocol.EncryptionInitial,
hdr: &unpackedHdr, hdr: &unpackedHdr,
data: []byte{0}, // one PADDING frame data: []byte{0}, // one PADDING frame
@ -801,7 +801,7 @@ var _ = Describe("Connection", func() {
}) })
It("drops a packet when unpacking fails", func() { It("drops a packet when unpacking fails", func() {
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrDecryptionFailed)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
@ -998,7 +998,7 @@ var _ = Describe("Connection", func() {
Expect(srcConnID).ToNot(Equal(hdr2.SrcConnectionID)) Expect(srcConnID).ToNot(Equal(hdr2.SrcConnectionID))
// Send one packet, which might change the connection ID. // Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker // only EXPECT one call to the unpacker
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(&unpackedPacket{
encryptionLevel: protocol.Encryption1RTT, encryptionLevel: protocol.Encryption1RTT,
hdr: hdr1, hdr: hdr1,
data: []byte{0}, // one PADDING frame data: []byte{0}, // one PADDING frame
@ -1026,7 +1026,7 @@ var _ = Describe("Connection", func() {
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 1, PacketNumber: 1,
} }
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable) unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrKeysNotYetAvailable)
packet := getLongHeaderPacket(hdr, nil) packet := getLongHeaderPacket(hdr, nil)
tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, packet.Size()) tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, packet.Size())
Expect(conn.handlePacketImpl(packet)).To(BeFalse()) Expect(conn.handlePacketImpl(packet)).To(BeFalse())
@ -1067,7 +1067,7 @@ var _ = Describe("Connection", func() {
It("cuts packets to the right length", func() { It("cuts packets to the right length", func() {
hdrLen, packet := getPacketWithLength(srcConnID, 456) hdrLen, packet := getPacketWithLength(srcConnID, 456)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen + 456 - 3)) Expect(data).To(HaveLen(hdrLen + 456 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -1081,7 +1081,7 @@ var _ = Describe("Connection", func() {
It("handles coalesced packets", func() { It("handles coalesced packets", func() {
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -1093,7 +1093,7 @@ var _ = Describe("Connection", func() {
}, nil }, nil
}) })
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -1117,8 +1117,8 @@ var _ = Describe("Connection", func() {
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
gomock.InOrder( gomock.InOrder(
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrKeysNotYetAvailable),
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -1142,7 +1142,7 @@ var _ = Describe("Connection", func() {
wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
Expect(srcConnID).ToNot(Equal(wrongConnID)) Expect(srcConnID).ToNot(Equal(wrongConnID))
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -2421,7 +2421,7 @@ var _ = Describe("Client Connection", func() {
It("changes the connection ID when receiving the first packet from the server", func() { It("changes the connection ID when receiving the first packet from the server", func() {
unpacker := NewMockUnpacker(mockCtrl) unpacker := NewMockUnpacker(mockCtrl)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.Encryption1RTT, encryptionLevel: protocol.Encryption1RTT,
hdr: &wire.ExtendedHeader{Header: *hdr}, hdr: &wire.ExtendedHeader{Header: *hdr},
@ -2469,7 +2469,7 @@ var _ = Describe("Client Connection", func() {
}) })
Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))) Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})))
// now receive a packet with the original source connection ID // now receive a packet with the original source connection ID
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
return &unpackedPacket{ return &unpackedPacket{
hdr: &wire.ExtendedHeader{Header: *hdr}, hdr: &wire.ExtendedHeader{Header: *hdr},
data: []byte{0}, data: []byte{0},
@ -2906,7 +2906,7 @@ var _ = Describe("Client Connection", func() {
Expect(hdr2.SrcConnectionID).ToNot(Equal(srcConnID)) Expect(hdr2.SrcConnectionID).ToNot(Equal(srcConnID))
// Send one packet, which might change the connection ID. // Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker // only EXPECT one call to the unpacker
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial, encryptionLevel: protocol.EncryptionInitial,
hdr: hdr1, hdr: hdr1,
data: []byte{0}, // one PADDING frame data: []byte{0}, // one PADDING frame

View file

@ -37,18 +37,18 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder {
} }
// UnpackLongHeader mocks base method. // UnpackLongHeader mocks base method.
func (m *MockUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { func (m *MockUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnpackLongHeader", hdr, rcvTime, data) ret := m.ctrl.Call(m, "UnpackLongHeader", hdr, rcvTime, data, v)
ret0, _ := ret[0].(*unpackedPacket) ret0, _ := ret[0].(*unpackedPacket)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// UnpackLongHeader indicates an expected call of UnpackLongHeader. // UnpackLongHeader indicates an expected call of UnpackLongHeader.
func (mr *MockUnpackerMockRecorder) UnpackLongHeader(hdr, rcvTime, data interface{}) *gomock.Call { func (mr *MockUnpackerMockRecorder) UnpackLongHeader(hdr, rcvTime, data, v interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), hdr, rcvTime, data) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), hdr, rcvTime, data, v)
} }
// UnpackShortHeader mocks base method. // UnpackShortHeader mocks base method.

View file

@ -38,16 +38,14 @@ type packetUnpacker struct {
cs handshake.CryptoSetup cs handshake.CryptoSetup
shortHdrConnIDLen int shortHdrConnIDLen int
version protocol.VersionNumber
} }
var _ unpacker = &packetUnpacker{} var _ unpacker = &packetUnpacker{}
func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int, version protocol.VersionNumber) unpacker { func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker {
return &packetUnpacker{ return &packetUnpacker{
cs: cs, cs: cs,
shortHdrConnIDLen: shortHdrConnIDLen, shortHdrConnIDLen: shortHdrConnIDLen,
version: version,
} }
} }
@ -55,7 +53,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int, version
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
// If any other error occurred when parsing the header, the error is of type headerParseError. // If any other error occurred when parsing the header, the error is of type headerParseError.
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD. // If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) {
var encLevel protocol.EncryptionLevel var encLevel protocol.EncryptionLevel
var extHdr *wire.ExtendedHeader var extHdr *wire.ExtendedHeader
var decrypted []byte var decrypted []byte
@ -67,7 +65,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil { if err != nil {
return nil, err return nil, err
} }
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -77,7 +75,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil { if err != nil {
return nil, err return nil, err
} }
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,7 +85,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil { if err != nil {
return nil, err return nil, err
} }
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -127,8 +125,8 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot
return pn, pnLen, kp, decrypted, nil return pn, pnLen, kp, decrypted, nil
} }
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data) extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
// If the reserved bits are set incorrectly, we still need to continue unpacking. // If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker // This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption. // to gain information about the header encryption.
@ -189,15 +187,15 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int
} }
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
extHdr, err := unpackLongHeader(hd, hdr, data, u.version) extHdr, err := unpackLongHeader(hd, hdr, data, v)
if err != nil && err != wire.ErrInvalidReservedBits { if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err} return nil, &headerParseError{err: err}
} }
return extHdr, err return extHdr, err
} }
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) { func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data) r := bytes.NewReader(data)
hdrLen := hdr.ParsedLen() hdrLen := hdr.ParsedLen()
@ -216,7 +214,7 @@ func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version
data[hdrLen:hdrLen+4], data[hdrLen:hdrLen+4],
) )
// 3. parse the header (and learn the actual length of the packet number) // 3. parse the header (and learn the actual length of the packet number)
extHdr, parseErr := hdr.ParseExtended(r, version) extHdr, parseErr := hdr.ParseExtended(r, v)
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, parseErr return nil, parseErr
} }

View file

@ -16,8 +16,6 @@ import (
) )
var _ = Describe("Packet Unpacker", func() { var _ = Describe("Packet Unpacker", func() {
const version = protocol.VersionTLS
var ( var (
unpacker *packetUnpacker unpacker *packetUnpacker
cs *mocks.MockCryptoSetup cs *mocks.MockCryptoSetup
@ -26,7 +24,7 @@ var _ = Describe("Packet Unpacker", func() {
) )
getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) { getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
b, err := extHdr.Append(nil, version) b, err := extHdr.Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
hdrLen := len(b) hdrLen := len(b)
@ -46,7 +44,7 @@ var _ = Describe("Packet Unpacker", func() {
BeforeEach(func() { BeforeEach(func() {
cs = mocks.NewMockCryptoSetup(mockCtrl) cs = mocks.NewMockCryptoSetup(mockCtrl)
unpacker = newPacketUnpacker(cs, 4, version).(*packetUnpacker) unpacker = newPacketUnpacker(cs, 4)
}) })
It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() { It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() {
@ -54,7 +52,7 @@ var _ = Describe("Packet Unpacker", func() {
Header: wire.Header{ Header: wire.Header{
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
DestConnectionID: connID, DestConnectionID: connID,
Version: version, Version: protocol.Version1,
}, },
PacketNumber: 1337, PacketNumber: 1337,
PacketNumberLen: protocol.PacketNumberLen2, PacketNumberLen: protocol.PacketNumberLen2,
@ -63,7 +61,7 @@ var _ = Describe("Packet Unpacker", func() {
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockLongHeaderOpener(mockCtrl) opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), data) _, err := unpacker.UnpackLongHeader(hdr, time.Now(), data, protocol.Version1)
Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
var headerErr *headerParseError var headerErr *headerParseError
Expect(errors.As(err, &headerErr)).To(BeTrue()) Expect(errors.As(err, &headerErr)).To(BeTrue())
@ -87,7 +85,7 @@ var _ = Describe("Packet Unpacker", func() {
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
Length: 3 + 6, // packet number len + payload Length: 3 + 6, // packet number len + payload
DestConnectionID: connID, DestConnectionID: connID,
Version: version, Version: protocol.Version1,
}, },
PacketNumber: 2, PacketNumber: 2,
PacketNumberLen: 3, PacketNumberLen: 3,
@ -100,7 +98,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)), opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil), opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil),
) )
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
Expect(packet.data).To(Equal([]byte("decrypted"))) Expect(packet.data).To(Equal([]byte("decrypted")))
@ -112,7 +110,7 @@ var _ = Describe("Packet Unpacker", func() {
Type: protocol.PacketType0RTT, Type: protocol.PacketType0RTT,
Length: 3 + 6, // packet number len + payload Length: 3 + 6, // packet number len + payload
DestConnectionID: connID, DestConnectionID: connID,
Version: version, Version: protocol.Version1,
}, },
PacketNumber: 20, PacketNumber: 20,
PacketNumberLen: 2, PacketNumberLen: 2,
@ -125,7 +123,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)), opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil), opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil),
) )
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT)) Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT))
Expect(packet.data).To(Equal([]byte("decrypted"))) Expect(packet.data).To(Equal([]byte("decrypted")))
@ -174,7 +172,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)), opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte(""), nil), opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte(""), nil),
) )
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) _, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).To(MatchError(&qerr.TransportError{ Expect(err).To(MatchError(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation, ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "empty packet", ErrorMessage: "empty packet",
@ -204,7 +202,7 @@ var _ = Describe("Packet Unpacker", func() {
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
Length: 3, // packet number len Length: 3, // packet number len
DestConnectionID: connID, DestConnectionID: connID,
Version: version, Version: protocol.Version1,
}, },
PacketNumber: 2, PacketNumber: 2,
PacketNumberLen: 3, PacketNumberLen: 3,
@ -216,7 +214,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded} unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded}
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) _, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).To(MatchError(unpackErr)) Expect(err).To(MatchError(unpackErr))
}) })
@ -225,7 +223,7 @@ var _ = Describe("Packet Unpacker", func() {
Header: wire.Header{ Header: wire.Header{
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
DestConnectionID: connID, DestConnectionID: connID,
Version: version, Version: protocol.Version1,
}, },
PacketNumber: 0x1337, PacketNumber: 0x1337,
PacketNumberLen: 2, PacketNumberLen: 2,
@ -237,7 +235,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) _, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
}) })
@ -258,7 +256,7 @@ var _ = Describe("Packet Unpacker", func() {
Header: wire.Header{ Header: wire.Header{
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
DestConnectionID: connID, DestConnectionID: connID,
Version: version, Version: protocol.Version1,
}, },
PacketNumber: 0x1337, PacketNumber: 0x1337,
PacketNumberLen: 2, PacketNumberLen: 2,
@ -270,7 +268,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...)) _, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
}) })
@ -292,7 +290,7 @@ var _ = Describe("Packet Unpacker", func() {
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
Length: 3, // packet number len Length: 3, // packet number len
DestConnectionID: connID, DestConnectionID: connID,
Version: version, Version: protocol.Version1,
}, },
PacketNumber: 0x1337, PacketNumber: 0x1337,
PacketNumberLen: 2, PacketNumberLen: 2,
@ -324,7 +322,7 @@ var _ = Describe("Packet Unpacker", func() {
for i := 1; i <= 100; i++ { for i := 1; i <= 100; i++ {
data = append(data, uint8(i)) data = append(data, uint8(i))
} }
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), data) packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), data, protocol.Version1)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x7331))) Expect(packet.hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x7331)))
}) })