refactor the packet unpacker to remove version param from constructor

This commit is contained in:
Marten Seemann 2022-12-24 13:00:21 +13:00
parent d4a0793b6a
commit 56985cbfd1
5 changed files with 49 additions and 53 deletions

View file

@ -16,8 +16,6 @@ import (
)
var _ = Describe("Packet Unpacker", func() {
const version = protocol.VersionTLS
var (
unpacker *packetUnpacker
cs *mocks.MockCryptoSetup
@ -26,7 +24,7 @@ var _ = Describe("Packet Unpacker", func() {
)
getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
b, err := extHdr.Append(nil, version)
b, err := extHdr.Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, err).ToNot(HaveOccurred())
hdrLen := len(b)
@ -46,7 +44,7 @@ var _ = Describe("Packet Unpacker", func() {
BeforeEach(func() {
cs = mocks.NewMockCryptoSetup(mockCtrl)
unpacker = newPacketUnpacker(cs, 4, version).(*packetUnpacker)
unpacker = newPacketUnpacker(cs, 4)
})
It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() {
@ -54,7 +52,7 @@ var _ = Describe("Packet Unpacker", func() {
Header: wire.Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
Version: version,
Version: protocol.Version1,
},
PacketNumber: 1337,
PacketNumberLen: protocol.PacketNumberLen2,
@ -63,7 +61,7 @@ var _ = Describe("Packet Unpacker", func() {
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), data)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), data, protocol.Version1)
Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
var headerErr *headerParseError
Expect(errors.As(err, &headerErr)).To(BeTrue())
@ -87,7 +85,7 @@ var _ = Describe("Packet Unpacker", func() {
Type: protocol.PacketTypeInitial,
Length: 3 + 6, // packet number len + payload
DestConnectionID: connID,
Version: version,
Version: protocol.Version1,
},
PacketNumber: 2,
PacketNumberLen: 3,
@ -100,7 +98,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
Expect(packet.data).To(Equal([]byte("decrypted")))
@ -112,7 +110,7 @@ var _ = Describe("Packet Unpacker", func() {
Type: protocol.PacketType0RTT,
Length: 3 + 6, // packet number len + payload
DestConnectionID: connID,
Version: version,
Version: protocol.Version1,
},
PacketNumber: 20,
PacketNumberLen: 2,
@ -125,7 +123,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT))
Expect(packet.data).To(Equal([]byte("decrypted")))
@ -174,7 +172,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte(""), nil),
)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).To(MatchError(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "empty packet",
@ -204,7 +202,7 @@ var _ = Describe("Packet Unpacker", func() {
Type: protocol.PacketTypeHandshake,
Length: 3, // packet number len
DestConnectionID: connID,
Version: version,
Version: protocol.Version1,
},
PacketNumber: 2,
PacketNumberLen: 3,
@ -216,7 +214,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded}
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).To(MatchError(unpackErr))
})
@ -225,7 +223,7 @@ var _ = Describe("Packet Unpacker", func() {
Header: wire.Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
Version: version,
Version: protocol.Version1,
},
PacketNumber: 0x1337,
PacketNumberLen: 2,
@ -237,7 +235,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
})
@ -258,7 +256,7 @@ var _ = Describe("Packet Unpacker", func() {
Header: wire.Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
Version: version,
Version: protocol.Version1,
},
PacketNumber: 0x1337,
PacketNumberLen: 2,
@ -270,7 +268,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
})
@ -292,7 +290,7 @@ var _ = Describe("Packet Unpacker", func() {
Type: protocol.PacketTypeHandshake,
Length: 3, // packet number len
DestConnectionID: connID,
Version: version,
Version: protocol.Version1,
},
PacketNumber: 0x1337,
PacketNumberLen: 2,
@ -324,7 +322,7 @@ var _ = Describe("Packet Unpacker", func() {
for i := 1; i <= 100; i++ {
data = append(data, uint8(i))
}
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), data)
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), data, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(packet.hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x7331)))
})