don't explicitely pass the header binary to the unpacker

This commit is contained in:
Marten Seemann 2018-11-29 16:20:13 +07:00
parent ce6b9f2f1d
commit c912860a94
5 changed files with 24 additions and 24 deletions

View file

@ -35,14 +35,14 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder {
} }
// Unpack mocks base method // Unpack mocks base method
func (m *MockUnpacker) Unpack(arg0 []byte, arg1 *wire.ExtendedHeader, arg2 []byte) (*unpackedPacket, error) { func (m *MockUnpacker) Unpack(arg0 *wire.ExtendedHeader, arg1 []byte) (*unpackedPacket, error) {
ret := m.ctrl.Call(m, "Unpack", arg0, arg1, arg2) ret := m.ctrl.Call(m, "Unpack", arg0, arg1)
ret0, _ := ret[0].(*unpackedPacket) ret0, _ := ret[0].(*unpackedPacket)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// Unpack indicates an expected call of Unpack // Unpack indicates an expected call of Unpack
func (mr *MockUnpackerMockRecorder) Unpack(arg0, arg1, arg2 interface{}) *gomock.Call { func (mr *MockUnpackerMockRecorder) Unpack(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), arg0, arg1)
} }

View file

@ -40,7 +40,7 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
} }
} }
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) { func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
pn := protocol.DecodePacketNumber( pn := protocol.DecodePacketNumber(
hdr.PacketNumberLen, hdr.PacketNumberLen,
u.largestRcvdPacketNumber, u.largestRcvdPacketNumber,
@ -56,16 +56,16 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, d
var err error var err error
switch hdr.Type { switch hdr.Type {
case protocol.PacketTypeInitial: case protocol.PacketTypeInitial:
decrypted, err = u.aead.OpenInitial(buf, data, pn, headerBinary) decrypted, err = u.aead.OpenInitial(buf, data, pn, hdr.Raw)
encryptionLevel = protocol.EncryptionInitial encryptionLevel = protocol.EncryptionInitial
case protocol.PacketTypeHandshake: case protocol.PacketTypeHandshake:
decrypted, err = u.aead.OpenHandshake(buf, data, pn, headerBinary) decrypted, err = u.aead.OpenHandshake(buf, data, pn, hdr.Raw)
encryptionLevel = protocol.EncryptionHandshake encryptionLevel = protocol.EncryptionHandshake
default: default:
if hdr.IsLongHeader { if hdr.IsLongHeader {
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
} }
decrypted, err = u.aead.Open1RTT(buf, data, pn, headerBinary) decrypted, err = u.aead.Open1RTT(buf, data, pn, hdr.Raw)
encryptionLevel = protocol.Encryption1RTT encryptionLevel = protocol.Encryption1RTT
} }
if err != nil { if err != nil {

View file

@ -33,7 +33,7 @@ var _ = Describe("Packet Unpacker", func() {
It("errors if the packet doesn't contain any payload", func() { It("errors if the packet doesn't contain any payload", func() {
data := []byte("foobar") data := []byte("foobar")
aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), hdr.PacketNumber, hdr.Raw).Return([]byte{}, nil) aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), hdr.PacketNumber, hdr.Raw).Return([]byte{}, nil)
_, err := unpacker.Unpack(hdr.Raw, hdr, data) _, err := unpacker.Unpack(hdr, data)
Expect(err).To(MatchError(qerr.MissingPayload)) Expect(err).To(MatchError(qerr.MissingPayload))
}) })
@ -41,7 +41,7 @@ var _ = Describe("Packet Unpacker", func() {
hdr.IsLongHeader = true hdr.IsLongHeader = true
hdr.Type = protocol.PacketTypeInitial hdr.Type = protocol.PacketTypeInitial
aead.EXPECT().OpenInitial(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil) aead.EXPECT().OpenInitial(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(hdr.Raw, hdr, nil) packet, err := unpacker.Unpack(hdr, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
}) })
@ -50,7 +50,7 @@ var _ = Describe("Packet Unpacker", func() {
hdr.IsLongHeader = true hdr.IsLongHeader = true
hdr.Type = protocol.PacketTypeHandshake hdr.Type = protocol.PacketTypeHandshake
aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil) aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(hdr.Raw, hdr, nil) packet, err := unpacker.Unpack(hdr, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionHandshake)) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionHandshake))
}) })
@ -59,7 +59,7 @@ var _ = Describe("Packet Unpacker", func() {
hdr.IsLongHeader = true hdr.IsLongHeader = true
hdr.Type = protocol.PacketTypeHandshake hdr.Type = protocol.PacketTypeHandshake
aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err")) aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
_, err := unpacker.Unpack(hdr.Raw, hdr, nil) _, err := unpacker.Unpack(hdr, nil)
Expect(err).To(MatchError(qerr.Error(qerr.DecryptionFailure, "test err"))) Expect(err).To(MatchError(qerr.Error(qerr.DecryptionFailure, "test err")))
}) })
@ -69,7 +69,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumberLen: 2, PacketNumberLen: 2,
} }
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil) aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(firstHdr.Raw, firstHdr, nil) packet, err := unpacker.Unpack(firstHdr, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
// the real packet number is 0x1338, but only the last byte is sent // the real packet number is 0x1338, but only the last byte is sent
@ -79,7 +79,7 @@ var _ = Describe("Packet Unpacker", func() {
} }
// expect the call with the decoded packet number // expect the call with the decoded packet number
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil) aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil)
packet, err = unpacker.Unpack(secondHdr.Raw, secondHdr, nil) packet, err = unpacker.Unpack(secondHdr, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338))) Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338)))
}) })
@ -89,7 +89,7 @@ var _ = Describe("Packet Unpacker", func() {
(&wire.PingFrame{}).Write(buf, protocol.VersionWhatever) (&wire.PingFrame{}).Write(buf, protocol.VersionWhatever)
(&wire.DataBlockedFrame{}).Write(buf, protocol.VersionWhatever) (&wire.DataBlockedFrame{}).Write(buf, protocol.VersionWhatever)
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return(buf.Bytes(), nil) aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return(buf.Bytes(), nil)
packet, err := unpacker.Unpack(hdr.Raw, hdr, nil) packet, err := unpacker.Unpack(hdr, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{&wire.PingFrame{}, &wire.DataBlockedFrame{}})) Expect(packet.frames).To(Equal([]wire.Frame{&wire.PingFrame{}, &wire.DataBlockedFrame{}}))
}) })

View file

@ -22,7 +22,7 @@ import (
) )
type unpacker interface { type unpacker interface {
Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
} }
type streamGetter interface { type streamGetter interface {
@ -503,7 +503,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
// TODO(#1312): implement parsing of compound packets // TODO(#1312): implement parsing of compound packets
} }
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) packet, err := s.unpacker.Unpack(hdr, data)
// if the decryption failed, this might be a packet sent by an attacker // if the decryption failed, this might be a packet sent by an attacker
if err != nil { if err != nil {
return err return err

View file

@ -468,7 +468,7 @@ var _ = Describe("Session", func() {
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
} }
rcvTime := time.Now().Add(-10 * time.Second) rcvTime := time.Now().Add(-10 * time.Second)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{packetNumber: 0x1337}, nil) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{packetNumber: 0x1337}, nil)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false) rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false)
sess.receivedPacketHandler = rph sess.receivedPacketHandler = rph
@ -481,7 +481,7 @@ var _ = Describe("Session", func() {
It("closes when handling a packet fails", func() { It("closes when handling a packet fails", func() {
testErr := errors.New("unpack error") testErr := errors.New("unpack error")
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, testErr)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
@ -499,7 +499,7 @@ var _ = Describe("Session", func() {
}) })
It("handles duplicate packets", func() { It("handles duplicate packets", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil).Times(2) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil).Times(2)
hdr := &wire.ExtendedHeader{ hdr := &wire.ExtendedHeader{
PacketNumber: 5, PacketNumber: 5,
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
@ -511,7 +511,7 @@ var _ = Describe("Session", func() {
It("ignores packets with a different source connection ID", func() { It("ignores packets with a different source connection ID", func() {
// Send one packet, which might change the connection ID. // Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker // only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
Expect(sess.handlePacketImpl(&receivedPacket{ Expect(sess.handlePacketImpl(&receivedPacket{
hdr: &wire.Header{ hdr: &wire.Header{
IsLongHeader: true, IsLongHeader: true,
@ -582,7 +582,7 @@ var _ = Describe("Session", func() {
payloadLen := 456 - int(pnLen) payloadLen := 456 - int(pnLen)
data := getData(hdr) data := getData(hdr)
data = append(data, make([]byte, payloadLen)...) data = append(data, make([]byte, payloadLen)...)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ []byte, _ *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(payloadLen)) Expect(data).To(HaveLen(payloadLen))
return &unpackedPacket{}, nil return &unpackedPacket{}, nil
}) })
@ -591,7 +591,7 @@ var _ = Describe("Session", func() {
Context("updating the remote address", func() { Context("updating the remote address", func() {
It("doesn't support connection migration", func() { It("doesn't support connection migration", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
origAddr := sess.conn.(*mockConnection).remoteAddr origAddr := sess.conn.(*mockConnection).remoteAddr
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
Expect(origAddr).ToNot(Equal(remoteIP)) Expect(origAddr).ToNot(Equal(remoteIP))
@ -1362,7 +1362,7 @@ var _ = Describe("Client Session", func() {
It("changes the connection ID when receiving the first packet from the server", func() { It("changes the connection ID when receiving the first packet from the server", func() {
unpacker := NewMockUnpacker(mockCtrl) unpacker := NewMockUnpacker(mockCtrl)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
sess.unpacker = unpacker sess.unpacker = unpacker
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()