don't explicitely pass the header binary to the unpacker

This commit is contained in:
Marten Seemann 2018-11-29 16:20:13 +07:00
parent ce6b9f2f1d
commit c912860a94
5 changed files with 24 additions and 24 deletions

View file

@ -35,14 +35,14 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder {
}
// Unpack mocks base method
func (m *MockUnpacker) Unpack(arg0 []byte, arg1 *wire.ExtendedHeader, arg2 []byte) (*unpackedPacket, error) {
ret := m.ctrl.Call(m, "Unpack", arg0, arg1, arg2)
func (m *MockUnpacker) Unpack(arg0 *wire.ExtendedHeader, arg1 []byte) (*unpackedPacket, error) {
ret := m.ctrl.Call(m, "Unpack", arg0, arg1)
ret0, _ := ret[0].(*unpackedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Unpack indicates an expected call of Unpack
func (mr *MockUnpackerMockRecorder) Unpack(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), arg0, arg1, arg2)
func (mr *MockUnpackerMockRecorder) Unpack(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), arg0, arg1)
}

View file

@ -40,7 +40,7 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
}
}
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
pn := protocol.DecodePacketNumber(
hdr.PacketNumberLen,
u.largestRcvdPacketNumber,
@ -56,16 +56,16 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, d
var err error
switch hdr.Type {
case protocol.PacketTypeInitial:
decrypted, err = u.aead.OpenInitial(buf, data, pn, headerBinary)
decrypted, err = u.aead.OpenInitial(buf, data, pn, hdr.Raw)
encryptionLevel = protocol.EncryptionInitial
case protocol.PacketTypeHandshake:
decrypted, err = u.aead.OpenHandshake(buf, data, pn, headerBinary)
decrypted, err = u.aead.OpenHandshake(buf, data, pn, hdr.Raw)
encryptionLevel = protocol.EncryptionHandshake
default:
if hdr.IsLongHeader {
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
}
decrypted, err = u.aead.Open1RTT(buf, data, pn, headerBinary)
decrypted, err = u.aead.Open1RTT(buf, data, pn, hdr.Raw)
encryptionLevel = protocol.Encryption1RTT
}
if err != nil {

View file

@ -33,7 +33,7 @@ var _ = Describe("Packet Unpacker", func() {
It("errors if the packet doesn't contain any payload", func() {
data := []byte("foobar")
aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), hdr.PacketNumber, hdr.Raw).Return([]byte{}, nil)
_, err := unpacker.Unpack(hdr.Raw, hdr, data)
_, err := unpacker.Unpack(hdr, data)
Expect(err).To(MatchError(qerr.MissingPayload))
})
@ -41,7 +41,7 @@ var _ = Describe("Packet Unpacker", func() {
hdr.IsLongHeader = true
hdr.Type = protocol.PacketTypeInitial
aead.EXPECT().OpenInitial(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(hdr.Raw, hdr, nil)
packet, err := unpacker.Unpack(hdr, nil)
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
})
@ -50,7 +50,7 @@ var _ = Describe("Packet Unpacker", func() {
hdr.IsLongHeader = true
hdr.Type = protocol.PacketTypeHandshake
aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(hdr.Raw, hdr, nil)
packet, err := unpacker.Unpack(hdr, nil)
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionHandshake))
})
@ -59,7 +59,7 @@ var _ = Describe("Packet Unpacker", func() {
hdr.IsLongHeader = true
hdr.Type = protocol.PacketTypeHandshake
aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
_, err := unpacker.Unpack(hdr.Raw, hdr, nil)
_, err := unpacker.Unpack(hdr, nil)
Expect(err).To(MatchError(qerr.Error(qerr.DecryptionFailure, "test err")))
})
@ -69,7 +69,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumberLen: 2,
}
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(firstHdr.Raw, firstHdr, nil)
packet, err := unpacker.Unpack(firstHdr, nil)
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
// the real packet number is 0x1338, but only the last byte is sent
@ -79,7 +79,7 @@ var _ = Describe("Packet Unpacker", func() {
}
// expect the call with the decoded packet number
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil)
packet, err = unpacker.Unpack(secondHdr.Raw, secondHdr, nil)
packet, err = unpacker.Unpack(secondHdr, nil)
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338)))
})
@ -89,7 +89,7 @@ var _ = Describe("Packet Unpacker", func() {
(&wire.PingFrame{}).Write(buf, protocol.VersionWhatever)
(&wire.DataBlockedFrame{}).Write(buf, protocol.VersionWhatever)
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return(buf.Bytes(), nil)
packet, err := unpacker.Unpack(hdr.Raw, hdr, nil)
packet, err := unpacker.Unpack(hdr, nil)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{&wire.PingFrame{}, &wire.DataBlockedFrame{}}))
})

View file

@ -22,7 +22,7 @@ import (
)
type unpacker interface {
Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
}
type streamGetter interface {
@ -503,7 +503,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
// TODO(#1312): implement parsing of compound packets
}
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
packet, err := s.unpacker.Unpack(hdr, data)
// if the decryption failed, this might be a packet sent by an attacker
if err != nil {
return err

View file

@ -468,7 +468,7 @@ var _ = Describe("Session", func() {
PacketNumberLen: protocol.PacketNumberLen1,
}
rcvTime := time.Now().Add(-10 * time.Second)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{packetNumber: 0x1337}, nil)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{packetNumber: 0x1337}, nil)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false)
sess.receivedPacketHandler = rph
@ -481,7 +481,7 @@ var _ = Describe("Session", func() {
It("closes when handling a packet fails", func() {
testErr := errors.New("unpack error")
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, testErr)
streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
@ -499,7 +499,7 @@ var _ = Describe("Session", func() {
})
It("handles duplicate packets", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil).Times(2)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil).Times(2)
hdr := &wire.ExtendedHeader{
PacketNumber: 5,
PacketNumberLen: protocol.PacketNumberLen1,
@ -511,7 +511,7 @@ var _ = Describe("Session", func() {
It("ignores packets with a different source connection ID", func() {
// Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
Expect(sess.handlePacketImpl(&receivedPacket{
hdr: &wire.Header{
IsLongHeader: true,
@ -582,7 +582,7 @@ var _ = Describe("Session", func() {
payloadLen := 456 - int(pnLen)
data := getData(hdr)
data = append(data, make([]byte, payloadLen)...)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ []byte, _ *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(payloadLen))
return &unpackedPacket{}, nil
})
@ -591,7 +591,7 @@ var _ = Describe("Session", func() {
Context("updating the remote address", func() {
It("doesn't support connection migration", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
origAddr := sess.conn.(*mockConnection).remoteAddr
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
Expect(origAddr).ToNot(Equal(remoteIP))
@ -1362,7 +1362,7 @@ var _ = Describe("Client Session", func() {
It("changes the connection ID when receiving the first packet from the server", func() {
unpacker := NewMockUnpacker(mockCtrl)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
sess.unpacker = unpacker
go func() {
defer GinkgoRecover()