use separate functions per encryption level to get openers

This commit is contained in:
Marten Seemann 2019-06-10 14:34:42 +08:00
parent 14a31d49a0
commit d4d3f09ee3
5 changed files with 77 additions and 35 deletions

View file

@ -601,33 +601,38 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve
} }
} }
func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) { func (h *cryptoSetup) GetInitialOpener() (Opener, error) {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
switch level { if h.initialOpener == nil {
case protocol.EncryptionInitial: return nil, ErrKeysDropped
if h.initialOpener == nil { }
return nil, ErrKeysDropped return h.initialOpener, nil
} }
return h.initialOpener, nil
case protocol.EncryptionHandshake: func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) {
if h.handshakeOpener == nil { h.mutex.Lock()
if h.initialOpener != nil { defer h.mutex.Unlock()
return nil, ErrOpenerNotYetAvailable
} if h.handshakeOpener == nil {
// if the initial opener is also not available, the keys were already dropped if h.initialOpener != nil {
return nil, ErrKeysDropped
}
return h.handshakeOpener, nil
case protocol.Encryption1RTT:
if h.opener == nil {
return nil, ErrOpenerNotYetAvailable return nil, ErrOpenerNotYetAvailable
} }
return h.opener, nil // if the initial opener is also not available, the keys were already dropped
default: return nil, ErrKeysDropped
return nil, fmt.Errorf("CryptoSetup: no opener with encryption level %s", level)
} }
return h.handshakeOpener, nil
}
func (h *cryptoSetup) Get1RTTOpener() (Opener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.opener == nil {
return nil, ErrOpenerNotYetAvailable
}
return h.opener, nil
} }
func (h *cryptoSetup) ConnectionState() tls.ConnectionState { func (h *cryptoSetup) ConnectionState() tls.ConnectionState {

View file

@ -45,7 +45,10 @@ type CryptoSetup interface {
Received1RTTAck() Received1RTTAck()
ConnectionState() tls.ConnectionState ConnectionState() tls.ConnectionState
GetInitialOpener() (Opener, error)
GetHandshakeOpener() (Opener, error)
Get1RTTOpener() (Opener, error)
GetSealer() (protocol.EncryptionLevel, Sealer) GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetOpener(protocol.EncryptionLevel) (Opener, error)
} }

View file

@ -78,19 +78,49 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
} }
// GetOpener mocks base method // Get1RTTOpener mocks base method
func (m *MockCryptoSetup) GetOpener(arg0 protocol.EncryptionLevel) (handshake.Opener, error) { func (m *MockCryptoSetup) Get1RTTOpener() (handshake.Opener, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOpener", arg0) ret := m.ctrl.Call(m, "Get1RTTOpener")
ret0, _ := ret[0].(handshake.Opener) ret0, _ := ret[0].(handshake.Opener)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetOpener indicates an expected call of GetOpener // Get1RTTOpener indicates an expected call of Get1RTTOpener
func (mr *MockCryptoSetupMockRecorder) GetOpener(arg0 interface{}) *gomock.Call { func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetOpener), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTOpener))
}
// GetHandshakeOpener mocks base method
func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.Opener, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetHandshakeOpener")
ret0, _ := ret[0].(handshake.Opener)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetHandshakeOpener indicates an expected call of GetHandshakeOpener
func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeOpener))
}
// GetInitialOpener mocks base method
func (m *MockCryptoSetup) GetInitialOpener() (handshake.Opener, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetInitialOpener")
ret0, _ := ret[0].(handshake.Opener)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetInitialOpener indicates an expected call of GetInitialOpener
func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialOpener))
} }
// GetSealer mocks base method // GetSealer mocks base method

View file

@ -38,19 +38,23 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber)
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) { func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
r := bytes.NewReader(data) r := bytes.NewReader(data)
var opener handshake.Opener
var encLevel protocol.EncryptionLevel var encLevel protocol.EncryptionLevel
var err error
switch hdr.Type { switch hdr.Type {
case protocol.PacketTypeInitial: case protocol.PacketTypeInitial:
encLevel = protocol.EncryptionInitial encLevel = protocol.EncryptionInitial
opener, err = u.cs.GetInitialOpener()
case protocol.PacketTypeHandshake: case protocol.PacketTypeHandshake:
encLevel = protocol.EncryptionHandshake encLevel = protocol.EncryptionHandshake
opener, err = u.cs.GetHandshakeOpener()
default: default:
if hdr.IsLongHeader { if hdr.IsLongHeader {
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
} }
encLevel = protocol.Encryption1RTT encLevel = protocol.Encryption1RTT
opener, err = u.cs.Get1RTTOpener()
} }
opener, err := u.cs.GetOpener(encLevel)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -49,7 +49,7 @@ var _ = Describe("Packet Unpacker", func() {
hdr, hdrRaw := getHeader(extHdr) hdr, hdrRaw := getHeader(extHdr)
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockOpener(mockCtrl) opener := mocks.NewMockOpener(mockCtrl)
cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, err := unpacker.Unpack(hdr, data) _, err := unpacker.Unpack(hdr, data)
Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19"))
}) })
@ -68,7 +68,7 @@ var _ = Describe("Packet Unpacker", func() {
} }
hdr, hdrRaw := getHeader(extHdr) hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockOpener(mockCtrl) opener := mocks.NewMockOpener(mockCtrl)
cs.EXPECT().GetOpener(protocol.EncryptionInitial).Return(opener, nil) cs.EXPECT().GetInitialOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil) opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
@ -84,7 +84,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumberLen: 2, PacketNumberLen: 2,
} }
hdr, hdrRaw := getHeader(extHdr) hdr, hdrRaw := getHeader(extHdr)
cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(nil, handshake.ErrOpenerNotYetAvailable) cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrOpenerNotYetAvailable)
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrOpenerNotYetAvailable)) Expect(err).To(MatchError(handshake.ErrOpenerNotYetAvailable))
}) })
@ -103,7 +103,7 @@ var _ = Describe("Packet Unpacker", func() {
} }
hdr, hdrRaw := getHeader(extHdr) hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockOpener(mockCtrl) opener := mocks.NewMockOpener(mockCtrl)
cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err")) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
@ -130,7 +130,7 @@ var _ = Describe("Packet Unpacker", func() {
hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number
Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte)) Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte))
opener := mocks.NewMockOpener(mockCtrl) opener := mocks.NewMockOpener(mockCtrl)
cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
gomock.InOrder( gomock.InOrder(
// we're using a 2 byte packet number, so the sample starts at the 3rd payload byte // we're using a 2 byte packet number, so the sample starts at the 3rd payload byte
opener.EXPECT().DecryptHeader( opener.EXPECT().DecryptHeader(
@ -160,7 +160,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumberLen: 2, PacketNumberLen: 2,
} }
opener := mocks.NewMockOpener(mockCtrl) opener := mocks.NewMockOpener(mockCtrl)
cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil).Times(2) cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw := getHeader(firstHdr) hdr, hdrRaw := getHeader(firstHdr)