mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 04:37:36 +03:00
refactor the packet unpacker to remove version param from constructor
This commit is contained in:
parent
d4a0793b6a
commit
56985cbfd1
5 changed files with 49 additions and 53 deletions
|
@ -25,7 +25,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type unpacker interface {
|
type unpacker interface {
|
||||||
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
|
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error)
|
||||||
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
|
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -366,7 +366,7 @@ var newConnection = func(
|
||||||
s.perspective,
|
s.perspective,
|
||||||
s.version,
|
s.version,
|
||||||
)
|
)
|
||||||
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version)
|
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
|
||||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
|
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -476,7 +476,7 @@ var newClientConnection = func(
|
||||||
s.clientHelloWritten = clientHelloWritten
|
s.clientHelloWritten = clientHelloWritten
|
||||||
s.cryptoStreamHandler = cs
|
s.cryptoStreamHandler = cs
|
||||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
|
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
|
||||||
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version)
|
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
|
||||||
s.packer = newPacketPacker(
|
s.packer = newPacketPacker(
|
||||||
srcConnID,
|
srcConnID,
|
||||||
s.connIDManager.Get,
|
s.connIDManager.Get,
|
||||||
|
@ -1009,7 +1009,7 @@ func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data)
|
packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data, s.version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
|
wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -755,7 +755,7 @@ var _ = Describe("Connection", func() {
|
||||||
packet := getLongHeaderPacket(hdr, nil)
|
packet := getLongHeaderPacket(hdr, nil)
|
||||||
packet.ecn = protocol.ECNCE
|
packet.ecn = protocol.ECNCE
|
||||||
rcvTime := time.Now().Add(-10 * time.Second)
|
rcvTime := time.Now().Add(-10 * time.Second)
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), rcvTime, gomock.Any(), conn.version).Return(&unpackedPacket{
|
||||||
encryptionLevel: protocol.EncryptionInitial,
|
encryptionLevel: protocol.EncryptionInitial,
|
||||||
hdr: &unpackedHdr,
|
hdr: &unpackedHdr,
|
||||||
data: []byte{0}, // one PADDING frame
|
data: []byte{0}, // one PADDING frame
|
||||||
|
@ -801,7 +801,7 @@ var _ = Describe("Connection", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops a packet when unpacking fails", func() {
|
It("drops a packet when unpacking fails", func() {
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrDecryptionFailed)
|
||||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||||
cryptoSetup.EXPECT().Close()
|
cryptoSetup.EXPECT().Close()
|
||||||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
|
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
|
||||||
|
@ -998,7 +998,7 @@ var _ = Describe("Connection", func() {
|
||||||
Expect(srcConnID).ToNot(Equal(hdr2.SrcConnectionID))
|
Expect(srcConnID).ToNot(Equal(hdr2.SrcConnectionID))
|
||||||
// Send one packet, which might change the connection ID.
|
// Send one packet, which might change the connection ID.
|
||||||
// only EXPECT one call to the unpacker
|
// only EXPECT one call to the unpacker
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(&unpackedPacket{
|
||||||
encryptionLevel: protocol.Encryption1RTT,
|
encryptionLevel: protocol.Encryption1RTT,
|
||||||
hdr: hdr1,
|
hdr: hdr1,
|
||||||
data: []byte{0}, // one PADDING frame
|
data: []byte{0}, // one PADDING frame
|
||||||
|
@ -1026,7 +1026,7 @@ var _ = Describe("Connection", func() {
|
||||||
PacketNumberLen: protocol.PacketNumberLen1,
|
PacketNumberLen: protocol.PacketNumberLen1,
|
||||||
PacketNumber: 1,
|
PacketNumber: 1,
|
||||||
}
|
}
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable)
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrKeysNotYetAvailable)
|
||||||
packet := getLongHeaderPacket(hdr, nil)
|
packet := getLongHeaderPacket(hdr, nil)
|
||||||
tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, packet.Size())
|
tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, packet.Size())
|
||||||
Expect(conn.handlePacketImpl(packet)).To(BeFalse())
|
Expect(conn.handlePacketImpl(packet)).To(BeFalse())
|
||||||
|
@ -1067,7 +1067,7 @@ var _ = Describe("Connection", func() {
|
||||||
|
|
||||||
It("cuts packets to the right length", func() {
|
It("cuts packets to the right length", func() {
|
||||||
hdrLen, packet := getPacketWithLength(srcConnID, 456)
|
hdrLen, packet := getPacketWithLength(srcConnID, 456)
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
|
||||||
Expect(data).To(HaveLen(hdrLen + 456 - 3))
|
Expect(data).To(HaveLen(hdrLen + 456 - 3))
|
||||||
return &unpackedPacket{
|
return &unpackedPacket{
|
||||||
encryptionLevel: protocol.EncryptionHandshake,
|
encryptionLevel: protocol.EncryptionHandshake,
|
||||||
|
@ -1081,7 +1081,7 @@ var _ = Describe("Connection", func() {
|
||||||
|
|
||||||
It("handles coalesced packets", func() {
|
It("handles coalesced packets", func() {
|
||||||
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
|
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
|
||||||
Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
|
Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
|
||||||
return &unpackedPacket{
|
return &unpackedPacket{
|
||||||
encryptionLevel: protocol.EncryptionHandshake,
|
encryptionLevel: protocol.EncryptionHandshake,
|
||||||
|
@ -1093,7 +1093,7 @@ var _ = Describe("Connection", func() {
|
||||||
}, nil
|
}, nil
|
||||||
})
|
})
|
||||||
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
|
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
|
||||||
Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
|
Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
|
||||||
return &unpackedPacket{
|
return &unpackedPacket{
|
||||||
encryptionLevel: protocol.EncryptionHandshake,
|
encryptionLevel: protocol.EncryptionHandshake,
|
||||||
|
@ -1117,8 +1117,8 @@ var _ = Describe("Connection", func() {
|
||||||
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
|
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
|
||||||
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
|
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
|
||||||
gomock.InOrder(
|
gomock.InOrder(
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrKeysNotYetAvailable),
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
|
||||||
Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
|
Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
|
||||||
return &unpackedPacket{
|
return &unpackedPacket{
|
||||||
encryptionLevel: protocol.EncryptionHandshake,
|
encryptionLevel: protocol.EncryptionHandshake,
|
||||||
|
@ -1142,7 +1142,7 @@ var _ = Describe("Connection", func() {
|
||||||
wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
|
wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
|
||||||
Expect(srcConnID).ToNot(Equal(wrongConnID))
|
Expect(srcConnID).ToNot(Equal(wrongConnID))
|
||||||
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
|
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
|
||||||
Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
|
Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
|
||||||
return &unpackedPacket{
|
return &unpackedPacket{
|
||||||
encryptionLevel: protocol.EncryptionHandshake,
|
encryptionLevel: protocol.EncryptionHandshake,
|
||||||
|
@ -2421,7 +2421,7 @@ var _ = Describe("Client Connection", func() {
|
||||||
|
|
||||||
It("changes the connection ID when receiving the first packet from the server", func() {
|
It("changes the connection ID when receiving the first packet from the server", func() {
|
||||||
unpacker := NewMockUnpacker(mockCtrl)
|
unpacker := NewMockUnpacker(mockCtrl)
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
|
||||||
return &unpackedPacket{
|
return &unpackedPacket{
|
||||||
encryptionLevel: protocol.Encryption1RTT,
|
encryptionLevel: protocol.Encryption1RTT,
|
||||||
hdr: &wire.ExtendedHeader{Header: *hdr},
|
hdr: &wire.ExtendedHeader{Header: *hdr},
|
||||||
|
@ -2469,7 +2469,7 @@ var _ = Describe("Client Connection", func() {
|
||||||
})
|
})
|
||||||
Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})))
|
Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})))
|
||||||
// now receive a packet with the original source connection ID
|
// now receive a packet with the original source connection ID
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) {
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
|
||||||
return &unpackedPacket{
|
return &unpackedPacket{
|
||||||
hdr: &wire.ExtendedHeader{Header: *hdr},
|
hdr: &wire.ExtendedHeader{Header: *hdr},
|
||||||
data: []byte{0},
|
data: []byte{0},
|
||||||
|
@ -2906,7 +2906,7 @@ var _ = Describe("Client Connection", func() {
|
||||||
Expect(hdr2.SrcConnectionID).ToNot(Equal(srcConnID))
|
Expect(hdr2.SrcConnectionID).ToNot(Equal(srcConnID))
|
||||||
// Send one packet, which might change the connection ID.
|
// Send one packet, which might change the connection ID.
|
||||||
// only EXPECT one call to the unpacker
|
// only EXPECT one call to the unpacker
|
||||||
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(&unpackedPacket{
|
||||||
encryptionLevel: protocol.EncryptionInitial,
|
encryptionLevel: protocol.EncryptionInitial,
|
||||||
hdr: hdr1,
|
hdr: hdr1,
|
||||||
data: []byte{0}, // one PADDING frame
|
data: []byte{0}, // one PADDING frame
|
||||||
|
|
|
@ -37,18 +37,18 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnpackLongHeader mocks base method.
|
// UnpackLongHeader mocks base method.
|
||||||
func (m *MockUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) {
|
func (m *MockUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "UnpackLongHeader", hdr, rcvTime, data)
|
ret := m.ctrl.Call(m, "UnpackLongHeader", hdr, rcvTime, data, v)
|
||||||
ret0, _ := ret[0].(*unpackedPacket)
|
ret0, _ := ret[0].(*unpackedPacket)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnpackLongHeader indicates an expected call of UnpackLongHeader.
|
// UnpackLongHeader indicates an expected call of UnpackLongHeader.
|
||||||
func (mr *MockUnpackerMockRecorder) UnpackLongHeader(hdr, rcvTime, data interface{}) *gomock.Call {
|
func (mr *MockUnpackerMockRecorder) UnpackLongHeader(hdr, rcvTime, data, v interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), hdr, rcvTime, data)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), hdr, rcvTime, data, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnpackShortHeader mocks base method.
|
// UnpackShortHeader mocks base method.
|
||||||
|
|
|
@ -38,16 +38,14 @@ type packetUnpacker struct {
|
||||||
cs handshake.CryptoSetup
|
cs handshake.CryptoSetup
|
||||||
|
|
||||||
shortHdrConnIDLen int
|
shortHdrConnIDLen int
|
||||||
version protocol.VersionNumber
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ unpacker = &packetUnpacker{}
|
var _ unpacker = &packetUnpacker{}
|
||||||
|
|
||||||
func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int, version protocol.VersionNumber) unpacker {
|
func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker {
|
||||||
return &packetUnpacker{
|
return &packetUnpacker{
|
||||||
cs: cs,
|
cs: cs,
|
||||||
shortHdrConnIDLen: shortHdrConnIDLen,
|
shortHdrConnIDLen: shortHdrConnIDLen,
|
||||||
version: version,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +53,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int, version
|
||||||
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
|
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
|
||||||
// If any other error occurred when parsing the header, the error is of type headerParseError.
|
// If any other error occurred when parsing the header, the error is of type headerParseError.
|
||||||
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
|
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
|
||||||
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) {
|
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) {
|
||||||
var encLevel protocol.EncryptionLevel
|
var encLevel protocol.EncryptionLevel
|
||||||
var extHdr *wire.ExtendedHeader
|
var extHdr *wire.ExtendedHeader
|
||||||
var decrypted []byte
|
var decrypted []byte
|
||||||
|
@ -67,7 +65,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
|
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -77,7 +75,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
|
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -87,7 +85,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
|
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -127,8 +125,8 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot
|
||||||
return pn, pnLen, kp, decrypted, nil
|
return pn, pnLen, kp, decrypted, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
|
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, []byte, error) {
|
||||||
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data)
|
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
|
||||||
// If the reserved bits are set incorrectly, we still need to continue unpacking.
|
// If the reserved bits are set incorrectly, we still need to continue unpacking.
|
||||||
// This avoids a timing side-channel, which otherwise might allow an attacker
|
// This avoids a timing side-channel, which otherwise might allow an attacker
|
||||||
// to gain information about the header encryption.
|
// to gain information about the header encryption.
|
||||||
|
@ -189,15 +187,15 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int
|
||||||
}
|
}
|
||||||
|
|
||||||
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
|
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
|
||||||
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
|
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
|
||||||
extHdr, err := unpackLongHeader(hd, hdr, data, u.version)
|
extHdr, err := unpackLongHeader(hd, hdr, data, v)
|
||||||
if err != nil && err != wire.ErrInvalidReservedBits {
|
if err != nil && err != wire.ErrInvalidReservedBits {
|
||||||
return nil, &headerParseError{err: err}
|
return nil, &headerParseError{err: err}
|
||||||
}
|
}
|
||||||
return extHdr, err
|
return extHdr, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) {
|
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
|
||||||
r := bytes.NewReader(data)
|
r := bytes.NewReader(data)
|
||||||
|
|
||||||
hdrLen := hdr.ParsedLen()
|
hdrLen := hdr.ParsedLen()
|
||||||
|
@ -216,7 +214,7 @@ func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version
|
||||||
data[hdrLen:hdrLen+4],
|
data[hdrLen:hdrLen+4],
|
||||||
)
|
)
|
||||||
// 3. parse the header (and learn the actual length of the packet number)
|
// 3. parse the header (and learn the actual length of the packet number)
|
||||||
extHdr, parseErr := hdr.ParseExtended(r, version)
|
extHdr, parseErr := hdr.ParseExtended(r, v)
|
||||||
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
|
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
|
||||||
return nil, parseErr
|
return nil, parseErr
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,8 +16,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Packet Unpacker", func() {
|
var _ = Describe("Packet Unpacker", func() {
|
||||||
const version = protocol.VersionTLS
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
unpacker *packetUnpacker
|
unpacker *packetUnpacker
|
||||||
cs *mocks.MockCryptoSetup
|
cs *mocks.MockCryptoSetup
|
||||||
|
@ -26,7 +24,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
)
|
)
|
||||||
|
|
||||||
getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
|
getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
|
||||||
b, err := extHdr.Append(nil, version)
|
b, err := extHdr.Append(nil, protocol.Version1)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||||
hdrLen := len(b)
|
hdrLen := len(b)
|
||||||
|
@ -46,7 +44,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
cs = mocks.NewMockCryptoSetup(mockCtrl)
|
cs = mocks.NewMockCryptoSetup(mockCtrl)
|
||||||
unpacker = newPacketUnpacker(cs, 4, version).(*packetUnpacker)
|
unpacker = newPacketUnpacker(cs, 4)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() {
|
It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() {
|
||||||
|
@ -54,7 +52,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
Header: wire.Header{
|
Header: wire.Header{
|
||||||
Type: protocol.PacketTypeHandshake,
|
Type: protocol.PacketTypeHandshake,
|
||||||
DestConnectionID: connID,
|
DestConnectionID: connID,
|
||||||
Version: version,
|
Version: protocol.Version1,
|
||||||
},
|
},
|
||||||
PacketNumber: 1337,
|
PacketNumber: 1337,
|
||||||
PacketNumberLen: protocol.PacketNumberLen2,
|
PacketNumberLen: protocol.PacketNumberLen2,
|
||||||
|
@ -63,7 +61,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
|
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
|
||||||
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
|
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
|
||||||
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
|
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
|
||||||
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), data)
|
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), data, protocol.Version1)
|
||||||
Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
|
Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
|
||||||
var headerErr *headerParseError
|
var headerErr *headerParseError
|
||||||
Expect(errors.As(err, &headerErr)).To(BeTrue())
|
Expect(errors.As(err, &headerErr)).To(BeTrue())
|
||||||
|
@ -87,7 +85,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
Length: 3 + 6, // packet number len + payload
|
Length: 3 + 6, // packet number len + payload
|
||||||
DestConnectionID: connID,
|
DestConnectionID: connID,
|
||||||
Version: version,
|
Version: protocol.Version1,
|
||||||
},
|
},
|
||||||
PacketNumber: 2,
|
PacketNumber: 2,
|
||||||
PacketNumberLen: 3,
|
PacketNumberLen: 3,
|
||||||
|
@ -100,7 +98,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)),
|
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)),
|
||||||
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil),
|
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil),
|
||||||
)
|
)
|
||||||
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
|
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
|
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
|
||||||
Expect(packet.data).To(Equal([]byte("decrypted")))
|
Expect(packet.data).To(Equal([]byte("decrypted")))
|
||||||
|
@ -112,7 +110,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
Type: protocol.PacketType0RTT,
|
Type: protocol.PacketType0RTT,
|
||||||
Length: 3 + 6, // packet number len + payload
|
Length: 3 + 6, // packet number len + payload
|
||||||
DestConnectionID: connID,
|
DestConnectionID: connID,
|
||||||
Version: version,
|
Version: protocol.Version1,
|
||||||
},
|
},
|
||||||
PacketNumber: 20,
|
PacketNumber: 20,
|
||||||
PacketNumberLen: 2,
|
PacketNumberLen: 2,
|
||||||
|
@ -125,7 +123,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)),
|
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)),
|
||||||
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil),
|
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil),
|
||||||
)
|
)
|
||||||
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
|
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT))
|
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT))
|
||||||
Expect(packet.data).To(Equal([]byte("decrypted")))
|
Expect(packet.data).To(Equal([]byte("decrypted")))
|
||||||
|
@ -174,7 +172,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
|
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
|
||||||
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte(""), nil),
|
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte(""), nil),
|
||||||
)
|
)
|
||||||
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
|
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
|
||||||
Expect(err).To(MatchError(&qerr.TransportError{
|
Expect(err).To(MatchError(&qerr.TransportError{
|
||||||
ErrorCode: qerr.ProtocolViolation,
|
ErrorCode: qerr.ProtocolViolation,
|
||||||
ErrorMessage: "empty packet",
|
ErrorMessage: "empty packet",
|
||||||
|
@ -204,7 +202,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
Type: protocol.PacketTypeHandshake,
|
Type: protocol.PacketTypeHandshake,
|
||||||
Length: 3, // packet number len
|
Length: 3, // packet number len
|
||||||
DestConnectionID: connID,
|
DestConnectionID: connID,
|
||||||
Version: version,
|
Version: protocol.Version1,
|
||||||
},
|
},
|
||||||
PacketNumber: 2,
|
PacketNumber: 2,
|
||||||
PacketNumberLen: 3,
|
PacketNumberLen: 3,
|
||||||
|
@ -216,7 +214,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
|
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
|
||||||
unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded}
|
unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded}
|
||||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr)
|
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr)
|
||||||
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
|
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
|
||||||
Expect(err).To(MatchError(unpackErr))
|
Expect(err).To(MatchError(unpackErr))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -225,7 +223,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
Header: wire.Header{
|
Header: wire.Header{
|
||||||
Type: protocol.PacketTypeHandshake,
|
Type: protocol.PacketTypeHandshake,
|
||||||
DestConnectionID: connID,
|
DestConnectionID: connID,
|
||||||
Version: version,
|
Version: protocol.Version1,
|
||||||
},
|
},
|
||||||
PacketNumber: 0x1337,
|
PacketNumber: 0x1337,
|
||||||
PacketNumberLen: 2,
|
PacketNumberLen: 2,
|
||||||
|
@ -237,7 +235,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
|
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
|
||||||
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
|
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
|
||||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
|
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
|
||||||
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
|
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
|
||||||
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
|
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -258,7 +256,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
Header: wire.Header{
|
Header: wire.Header{
|
||||||
Type: protocol.PacketTypeHandshake,
|
Type: protocol.PacketTypeHandshake,
|
||||||
DestConnectionID: connID,
|
DestConnectionID: connID,
|
||||||
Version: version,
|
Version: protocol.Version1,
|
||||||
},
|
},
|
||||||
PacketNumber: 0x1337,
|
PacketNumber: 0x1337,
|
||||||
PacketNumberLen: 2,
|
PacketNumberLen: 2,
|
||||||
|
@ -270,7 +268,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
|
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
|
||||||
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
|
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
|
||||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
|
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
|
||||||
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...))
|
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
|
||||||
Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
|
Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -292,7 +290,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
Type: protocol.PacketTypeHandshake,
|
Type: protocol.PacketTypeHandshake,
|
||||||
Length: 3, // packet number len
|
Length: 3, // packet number len
|
||||||
DestConnectionID: connID,
|
DestConnectionID: connID,
|
||||||
Version: version,
|
Version: protocol.Version1,
|
||||||
},
|
},
|
||||||
PacketNumber: 0x1337,
|
PacketNumber: 0x1337,
|
||||||
PacketNumberLen: 2,
|
PacketNumberLen: 2,
|
||||||
|
@ -324,7 +322,7 @@ var _ = Describe("Packet Unpacker", func() {
|
||||||
for i := 1; i <= 100; i++ {
|
for i := 1; i <= 100; i++ {
|
||||||
data = append(data, uint8(i))
|
data = append(data, uint8(i))
|
||||||
}
|
}
|
||||||
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), data)
|
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), data, protocol.Version1)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(packet.hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x7331)))
|
Expect(packet.hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x7331)))
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue