From 67f923c7369043826b3db82e152f07aa724b6ae0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 13 Dec 2018 14:10:24 +0630 Subject: [PATCH] return the Opener from the crypto setup --- internal/handshake/crypto_setup.go | 31 +++++++------ internal/handshake/interface.go | 5 +- internal/mocks/crypto_setup.go | 52 ++++++--------------- internal/mocks/mockgen.go | 1 + internal/mocks/opener.go | 48 +++++++++++++++++++ mock_quic_aead_test.go | 74 ------------------------------ mockgen.go | 1 - packet_unpacker.go | 32 ++++++------- packet_unpacker_test.go | 64 ++++++++++++++------------ 9 files changed, 128 insertions(+), 180 deletions(-) create mode 100644 internal/mocks/opener.go delete mode 100644 mock_quic_aead_test.go diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 70161e67..f36f3e93 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -493,22 +493,23 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve } } -func (h *cryptoSetup) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { - return h.initialOpener.Open(dst, src, pn, ad) -} - -func (h *cryptoSetup) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { - if h.handshakeOpener == nil { - return nil, errors.New("no handshake opener") +func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) { + switch level { + case protocol.EncryptionInitial: + return h.initialOpener, nil + case protocol.EncryptionHandshake: + if h.handshakeOpener == nil { + return nil, errors.New("CryptoSetup: no opener with encryption level Handshake") + } + return h.handshakeOpener, nil + case protocol.Encryption1RTT: + if h.opener == nil { + return nil, errors.New("CryptoSetup: no opener with encryption level 1-RTT") + } + return h.opener, nil + default: + return nil, fmt.Errorf("CryptoSetup: no opener with encryption level %s", level) } - return h.handshakeOpener.Open(dst, src, pn, ad) -} - -func (h *cryptoSetup) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { - if h.opener == nil { - return nil, errors.New("no 1-RTT opener") - } - return h.opener.Open(dst, src, pn, ad) } func (h *cryptoSetup) ConnectionState() ConnectionState { diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index a656df12..521af784 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -35,10 +35,7 @@ type CryptoSetup interface { GetSealer() (protocol.EncryptionLevel, Sealer) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) - - OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) - OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) - Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) + GetOpener(protocol.EncryptionLevel) (Opener, error) } // ConnectionState records basic details about the QUIC connection. diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 76b9e755..ebded366 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -59,6 +59,19 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) } +// GetOpener mocks base method +func (m *MockCryptoSetup) GetOpener(arg0 protocol.EncryptionLevel) (handshake.Opener, error) { + ret := m.ctrl.Call(m, "GetOpener", arg0) + ret0, _ := ret[0].(handshake.Opener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOpener indicates an expected call of GetOpener +func (mr *MockCryptoSetupMockRecorder) GetOpener(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetOpener), arg0) +} + // GetSealer mocks base method func (m *MockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) { ret := m.ctrl.Call(m, "GetSealer") @@ -97,45 +110,6 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) } -// Open1RTT mocks base method -func (m *MockCryptoSetup) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { - ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Open1RTT indicates an expected call of Open1RTT -func (mr *MockCryptoSetupMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockCryptoSetup)(nil).Open1RTT), arg0, arg1, arg2, arg3) -} - -// OpenHandshake mocks base method -func (m *MockCryptoSetup) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { - ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenHandshake indicates an expected call of OpenHandshake -func (mr *MockCryptoSetupMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).OpenHandshake), arg0, arg1, arg2, arg3) -} - -// OpenInitial mocks base method -func (m *MockCryptoSetup) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { - ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenInitial indicates an expected call of OpenInitial -func (mr *MockCryptoSetupMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockCryptoSetup)(nil).OpenInitial), arg0, arg1, arg2, arg3) -} - // RunHandshake mocks base method func (m *MockCryptoSetup) RunHandshake() error { ret := m.ctrl.Call(m, "RunHandshake") diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index e49365c5..c46ce502 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -1,6 +1,7 @@ package mocks //go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer" +//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener" //go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup" //go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" //go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler" diff --git a/internal/mocks/opener.go b/internal/mocks/opener.go new file mode 100644 index 00000000..0268a4a4 --- /dev/null +++ b/internal/mocks/opener.go @@ -0,0 +1,48 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Opener) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockOpener is a mock of Opener interface +type MockOpener struct { + ctrl *gomock.Controller + recorder *MockOpenerMockRecorder +} + +// MockOpenerMockRecorder is the mock recorder for MockOpener +type MockOpenerMockRecorder struct { + mock *MockOpener +} + +// NewMockOpener creates a new mock instance +func NewMockOpener(ctrl *gomock.Controller) *MockOpener { + mock := &MockOpener{ctrl: ctrl} + mock.recorder = &MockOpenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockOpener) EXPECT() *MockOpenerMockRecorder { + return m.recorder +} + +// Open mocks base method +func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { + ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open +func (mr *MockOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockOpener)(nil).Open), arg0, arg1, arg2, arg3) +} diff --git a/mock_quic_aead_test.go b/mock_quic_aead_test.go deleted file mode 100644 index 80bb2224..00000000 --- a/mock_quic_aead_test.go +++ /dev/null @@ -1,74 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go (interfaces: QuicAEAD) - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// MockQuicAEAD is a mock of QuicAEAD interface -type MockQuicAEAD struct { - ctrl *gomock.Controller - recorder *MockQuicAEADMockRecorder -} - -// MockQuicAEADMockRecorder is the mock recorder for MockQuicAEAD -type MockQuicAEADMockRecorder struct { - mock *MockQuicAEAD -} - -// NewMockQuicAEAD creates a new mock instance -func NewMockQuicAEAD(ctrl *gomock.Controller) *MockQuicAEAD { - mock := &MockQuicAEAD{ctrl: ctrl} - mock.recorder = &MockQuicAEADMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockQuicAEAD) EXPECT() *MockQuicAEADMockRecorder { - return m.recorder -} - -// Open1RTT mocks base method -func (m *MockQuicAEAD) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { - ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Open1RTT indicates an expected call of Open1RTT -func (mr *MockQuicAEADMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockQuicAEAD)(nil).Open1RTT), arg0, arg1, arg2, arg3) -} - -// OpenHandshake mocks base method -func (m *MockQuicAEAD) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { - ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenHandshake indicates an expected call of OpenHandshake -func (mr *MockQuicAEADMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockQuicAEAD)(nil).OpenHandshake), arg0, arg1, arg2, arg3) -} - -// OpenInitial mocks base method -func (m *MockQuicAEAD) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { - ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenInitial indicates an expected call of OpenInitial -func (mr *MockQuicAEADMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockQuicAEAD)(nil).OpenInitial), arg0, arg1, arg2, arg3) -} diff --git a/mockgen.go b/mockgen.go index eb0d2e48..1a882c35 100644 --- a/mockgen.go +++ b/mockgen.go @@ -13,7 +13,6 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/lucas-clemente/quic-go sealingManager" //go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker" //go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer" -//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner" //go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession" //go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler" diff --git a/packet_unpacker.go b/packet_unpacker.go index cb7f2c0f..cc998b65 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/utils" @@ -17,15 +18,9 @@ type unpackedPacket struct { frames []wire.Frame } -type quicAEAD interface { - OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) - OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) - Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) -} - // The packetUnpacker unpacks QUIC packets. type packetUnpacker struct { - aead quicAEAD + cs handshake.CryptoSetup largestRcvdPacketNumber protocol.PacketNumber @@ -34,9 +29,9 @@ type packetUnpacker struct { var _ unpacker = &packetUnpacker{} -func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker { +func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) unpacker { return &packetUnpacker{ - aead: aead, + cs: cs, version: version, } } @@ -69,22 +64,23 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, buf = buf[:0] defer putPacketBuffer(&buf) - var decrypted []byte - var encryptionLevel protocol.EncryptionLevel + var encLevel protocol.EncryptionLevel switch hdr.Type { case protocol.PacketTypeInitial: - decrypted, err = u.aead.OpenInitial(buf, data, pn, extHdr.Raw) - encryptionLevel = protocol.EncryptionInitial + encLevel = protocol.EncryptionInitial case protocol.PacketTypeHandshake: - decrypted, err = u.aead.OpenHandshake(buf, data, pn, extHdr.Raw) - encryptionLevel = protocol.EncryptionHandshake + encLevel = protocol.EncryptionHandshake default: if hdr.IsLongHeader { return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) } - decrypted, err = u.aead.Open1RTT(buf, data, pn, extHdr.Raw) - encryptionLevel = protocol.Encryption1RTT + encLevel = protocol.Encryption1RTT } + opener, err := u.cs.GetOpener(encLevel) + if err != nil { + return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) + } + decrypted, err := opener.Open(buf, data, pn, extHdr.Raw) if err != nil { return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) } @@ -100,7 +96,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, return &unpackedPacket{ hdr: extHdr, packetNumber: pn, - encryptionLevel: encryptionLevel, + encryptionLevel: encLevel, frames: fs, }, nil } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 331289ba..110c518b 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" @@ -17,7 +18,7 @@ var _ = Describe("Packet Unpacker", func() { const version = protocol.VersionTLS var ( unpacker *packetUnpacker - aead *MockQuicAEAD + cs *mocks.MockCryptoSetup connID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} ) @@ -30,8 +31,8 @@ var _ = Describe("Packet Unpacker", func() { } BeforeEach(func() { - aead = NewMockQuicAEAD(mockCtrl) - unpacker = newPacketUnpacker(aead, version).(*packetUnpacker) + cs = mocks.NewMockCryptoSetup(mockCtrl) + unpacker = newPacketUnpacker(cs, version).(*packetUnpacker) }) It("errors if the packet doesn't contain any payload", func() { @@ -43,7 +44,9 @@ var _ = Describe("Packet Unpacker", func() { hdr, hdrRaw := getHeader(extHdr) data := append(hdrRaw, []byte("foobar")...) // add some payload // return an empty (unencrypted) payload - aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{}, nil) + opener := mocks.NewMockOpener(mockCtrl) + cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil) + opener.EXPECT().Open(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{}, nil) _, err := unpacker.Unpack(hdr, data) Expect(err).To(MatchError(qerr.MissingPayload)) }) @@ -61,31 +64,14 @@ var _ = Describe("Packet Unpacker", func() { PacketNumberLen: 3, } hdr, hdrRaw := getHeader(extHdr) - aead.EXPECT().OpenInitial(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil) + opener := mocks.NewMockOpener(mockCtrl) + cs.EXPECT().GetOpener(protocol.EncryptionInitial).Return(opener, nil) + opener.EXPECT().Open(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil) packet, err := unpacker.Unpack(hdr, append(hdrRaw, []byte("foobar")...)) Expect(err).ToNot(HaveOccurred()) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) }) - It("opens Handshake packets", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Length: 3 + 6, // packet number len + payload - DestConnectionID: connID, - Version: version, - }, - PacketNumber: 2, - PacketNumberLen: 3, - } - hdr, hdrRaw := getHeader(extHdr) - aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil) - packet, err := unpacker.Unpack(hdr, append(hdrRaw, []byte("foobar")...)) - Expect(err).ToNot(HaveOccurred()) - Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionHandshake)) - }) - It("errors on packets that are smaller than the length in the packet header", func() { extHdr := &wire.ExtendedHeader{ Header: wire.Header{ @@ -118,7 +104,9 @@ var _ = Describe("Packet Unpacker", func() { payloadLen := 456 - int(pnLen) hdr, hdrRaw := getHeader(extHdr) data := append(hdrRaw, make([]byte, payloadLen)...) - aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) { + opener := mocks.NewMockOpener(mockCtrl) + cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) { Expect(payload).To(HaveLen(payloadLen)) return []byte{0}, nil }) @@ -126,6 +114,18 @@ var _ = Describe("Packet Unpacker", func() { Expect(err).ToNot(HaveOccurred()) }) + It("returns the error when getting the sealer fails", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: connID}, + PacketNumber: 0x1337, + PacketNumberLen: 2, + } + hdr, hdrRaw := getHeader(extHdr) + cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(nil, errors.New("test err")) + _, err := unpacker.Unpack(hdr, hdrRaw) + Expect(err).To(MatchError(qerr.Error(qerr.DecryptionFailure, "test err"))) + }) + It("returns the error when unpacking fails", func() { extHdr := &wire.ExtendedHeader{ Header: wire.Header{ @@ -139,7 +139,9 @@ var _ = Describe("Packet Unpacker", func() { PacketNumberLen: 3, } hdr, hdrRaw := getHeader(extHdr) - aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err")) + opener := mocks.NewMockOpener(mockCtrl) + cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err")) _, err := unpacker.Unpack(hdr, hdrRaw) Expect(err).To(MatchError(qerr.Error(qerr.DecryptionFailure, "test err"))) }) @@ -150,7 +152,9 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 0x1337, PacketNumberLen: 2, } - aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil) + opener := mocks.NewMockOpener(mockCtrl) + cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil).Times(2) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil) packet, err := unpacker.Unpack(getHeader(firstHdr)) Expect(err).ToNot(HaveOccurred()) Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337))) @@ -161,7 +165,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumberLen: 1, } // expect the call with the decoded packet number - aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil) packet, err = unpacker.Unpack(getHeader(secondHdr)) Expect(err).ToNot(HaveOccurred()) Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338))) @@ -177,7 +181,9 @@ var _ = Describe("Packet Unpacker", func() { (&wire.PingFrame{}).Write(buf, protocol.VersionWhatever) (&wire.DataBlockedFrame{}).Write(buf, protocol.VersionWhatever) hdr, hdrRaw := getHeader(extHdr) - aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return(buf.Bytes(), nil) + opener := mocks.NewMockOpener(mockCtrl) + cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return(buf.Bytes(), nil) packet, err := unpacker.Unpack(hdr, append(hdrRaw, buf.Bytes()...)) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]wire.Frame{&wire.PingFrame{}, &wire.DataBlockedFrame{}}))