From d4d3f09ee37643a179f6aad419befa286a58d5ae Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 10 Jun 2019 14:34:42 +0800 Subject: [PATCH] use separate functions per encryption level to get openers --- internal/handshake/crypto_setup.go | 47 +++++++++++++++++------------- internal/handshake/interface.go | 5 +++- internal/mocks/crypto_setup.go | 42 ++++++++++++++++++++++---- packet_unpacker.go | 6 +++- packet_unpacker_test.go | 12 ++++---- 5 files changed, 77 insertions(+), 35 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 1d343f45..7b5532a8 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -601,33 +601,38 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve } } -func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) { +func (h *cryptoSetup) GetInitialOpener() (Opener, error) { h.mutex.Lock() defer h.mutex.Unlock() - switch level { - case protocol.EncryptionInitial: - if h.initialOpener == nil { - return nil, ErrKeysDropped - } - return h.initialOpener, nil - case protocol.EncryptionHandshake: - if h.handshakeOpener == nil { - if h.initialOpener != nil { - return nil, ErrOpenerNotYetAvailable - } - // if the initial opener is also not available, the keys were already dropped - return nil, ErrKeysDropped - } - return h.handshakeOpener, nil - case protocol.Encryption1RTT: - if h.opener == nil { + if h.initialOpener == nil { + return nil, ErrKeysDropped + } + return h.initialOpener, nil +} + +func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.handshakeOpener == nil { + if h.initialOpener != nil { return nil, ErrOpenerNotYetAvailable } - return h.opener, nil - default: - return nil, fmt.Errorf("CryptoSetup: no opener with encryption level %s", level) + // if the initial opener is also not available, the keys were already dropped + return nil, ErrKeysDropped } + return h.handshakeOpener, nil +} + +func (h *cryptoSetup) Get1RTTOpener() (Opener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.opener == nil { + return nil, ErrOpenerNotYetAvailable + } + return h.opener, nil } func (h *cryptoSetup) ConnectionState() tls.ConnectionState { diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 09fdcf1a..e7411c71 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -45,7 +45,10 @@ type CryptoSetup interface { Received1RTTAck() ConnectionState() tls.ConnectionState + GetInitialOpener() (Opener, error) + GetHandshakeOpener() (Opener, error) + Get1RTTOpener() (Opener, error) + GetSealer() (protocol.EncryptionLevel, Sealer) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) - GetOpener(protocol.EncryptionLevel) (Opener, error) } diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 81585647..20acb112 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -78,19 +78,49 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) } -// GetOpener mocks base method -func (m *MockCryptoSetup) GetOpener(arg0 protocol.EncryptionLevel) (handshake.Opener, error) { +// Get1RTTOpener mocks base method +func (m *MockCryptoSetup) Get1RTTOpener() (handshake.Opener, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOpener", arg0) + ret := m.ctrl.Call(m, "Get1RTTOpener") ret0, _ := ret[0].(handshake.Opener) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetOpener indicates an expected call of GetOpener -func (mr *MockCryptoSetupMockRecorder) GetOpener(arg0 interface{}) *gomock.Call { +// Get1RTTOpener indicates an expected call of Get1RTTOpener +func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetOpener), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTOpener)) +} + +// GetHandshakeOpener mocks base method +func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.Opener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHandshakeOpener") + ret0, _ := ret[0].(handshake.Opener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetHandshakeOpener indicates an expected call of GetHandshakeOpener +func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeOpener)) +} + +// GetInitialOpener mocks base method +func (m *MockCryptoSetup) GetInitialOpener() (handshake.Opener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInitialOpener") + ret0, _ := ret[0].(handshake.Opener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInitialOpener indicates an expected call of GetInitialOpener +func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialOpener)) } // GetSealer mocks base method diff --git a/packet_unpacker.go b/packet_unpacker.go index f3b5a2f7..6e246552 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -38,19 +38,23 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) { r := bytes.NewReader(data) + var opener handshake.Opener var encLevel protocol.EncryptionLevel + var err error switch hdr.Type { case protocol.PacketTypeInitial: encLevel = protocol.EncryptionInitial + opener, err = u.cs.GetInitialOpener() case protocol.PacketTypeHandshake: encLevel = protocol.EncryptionHandshake + opener, err = u.cs.GetHandshakeOpener() default: if hdr.IsLongHeader { return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) } encLevel = protocol.Encryption1RTT + opener, err = u.cs.Get1RTTOpener() } - opener, err := u.cs.GetOpener(encLevel) if err != nil { return nil, err } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 579ff7f0..7e274704 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -49,7 +49,7 @@ var _ = Describe("Packet Unpacker", func() { hdr, hdrRaw := getHeader(extHdr) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) opener := mocks.NewMockOpener(mockCtrl) - cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil) + cs.EXPECT().Get1RTTOpener().Return(opener, nil) _, err := unpacker.Unpack(hdr, data) Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) }) @@ -68,7 +68,7 @@ var _ = Describe("Packet Unpacker", func() { } hdr, hdrRaw := getHeader(extHdr) opener := mocks.NewMockOpener(mockCtrl) - cs.EXPECT().GetOpener(protocol.EncryptionInitial).Return(opener, nil) + cs.EXPECT().GetInitialOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil) packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) @@ -84,7 +84,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumberLen: 2, } hdr, hdrRaw := getHeader(extHdr) - cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(nil, handshake.ErrOpenerNotYetAvailable) + cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrOpenerNotYetAvailable) _, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) Expect(err).To(MatchError(handshake.ErrOpenerNotYetAvailable)) }) @@ -103,7 +103,7 @@ var _ = Describe("Packet Unpacker", func() { } hdr, hdrRaw := getHeader(extHdr) opener := mocks.NewMockOpener(mockCtrl) - cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil) + cs.EXPECT().GetHandshakeOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err")) _, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) @@ -130,7 +130,7 @@ var _ = Describe("Packet Unpacker", func() { hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte)) opener := mocks.NewMockOpener(mockCtrl) - cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil) + cs.EXPECT().GetHandshakeOpener().Return(opener, nil) gomock.InOrder( // we're using a 2 byte packet number, so the sample starts at the 3rd payload byte opener.EXPECT().DecryptHeader( @@ -160,7 +160,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumberLen: 2, } opener := mocks.NewMockOpener(mockCtrl) - cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil).Times(2) + cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil) hdr, hdrRaw := getHeader(firstHdr)