mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
use separate functions per encryption level to get openers
This commit is contained in:
parent
14a31d49a0
commit
d4d3f09ee3
5 changed files with 77 additions and 35 deletions
|
@ -601,33 +601,38 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve
|
|||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) {
|
||||
func (h *cryptoSetup) GetInitialOpener() (Opener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
switch level {
|
||||
case protocol.EncryptionInitial:
|
||||
if h.initialOpener == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.initialOpener, nil
|
||||
case protocol.EncryptionHandshake:
|
||||
if h.handshakeOpener == nil {
|
||||
if h.initialOpener != nil {
|
||||
return nil, ErrOpenerNotYetAvailable
|
||||
}
|
||||
// if the initial opener is also not available, the keys were already dropped
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.handshakeOpener, nil
|
||||
case protocol.Encryption1RTT:
|
||||
if h.opener == nil {
|
||||
if h.initialOpener == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.initialOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.handshakeOpener == nil {
|
||||
if h.initialOpener != nil {
|
||||
return nil, ErrOpenerNotYetAvailable
|
||||
}
|
||||
return h.opener, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("CryptoSetup: no opener with encryption level %s", level)
|
||||
// if the initial opener is also not available, the keys were already dropped
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.handshakeOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTOpener() (Opener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.opener == nil {
|
||||
return nil, ErrOpenerNotYetAvailable
|
||||
}
|
||||
return h.opener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ConnectionState() tls.ConnectionState {
|
||||
|
|
|
@ -45,7 +45,10 @@ type CryptoSetup interface {
|
|||
Received1RTTAck()
|
||||
ConnectionState() tls.ConnectionState
|
||||
|
||||
GetInitialOpener() (Opener, error)
|
||||
GetHandshakeOpener() (Opener, error)
|
||||
Get1RTTOpener() (Opener, error)
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||
GetOpener(protocol.EncryptionLevel) (Opener, error)
|
||||
}
|
||||
|
|
|
@ -78,19 +78,49 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
|
||||
}
|
||||
|
||||
// GetOpener mocks base method
|
||||
func (m *MockCryptoSetup) GetOpener(arg0 protocol.EncryptionLevel) (handshake.Opener, error) {
|
||||
// Get1RTTOpener mocks base method
|
||||
func (m *MockCryptoSetup) Get1RTTOpener() (handshake.Opener, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOpener", arg0)
|
||||
ret := m.ctrl.Call(m, "Get1RTTOpener")
|
||||
ret0, _ := ret[0].(handshake.Opener)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOpener indicates an expected call of GetOpener
|
||||
func (mr *MockCryptoSetupMockRecorder) GetOpener(arg0 interface{}) *gomock.Call {
|
||||
// Get1RTTOpener indicates an expected call of Get1RTTOpener
|
||||
func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetOpener), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTOpener))
|
||||
}
|
||||
|
||||
// GetHandshakeOpener mocks base method
|
||||
func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.Opener, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetHandshakeOpener")
|
||||
ret0, _ := ret[0].(handshake.Opener)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetHandshakeOpener indicates an expected call of GetHandshakeOpener
|
||||
func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeOpener))
|
||||
}
|
||||
|
||||
// GetInitialOpener mocks base method
|
||||
func (m *MockCryptoSetup) GetInitialOpener() (handshake.Opener, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetInitialOpener")
|
||||
ret0, _ := ret[0].(handshake.Opener)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetInitialOpener indicates an expected call of GetInitialOpener
|
||||
func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialOpener))
|
||||
}
|
||||
|
||||
// GetSealer mocks base method
|
||||
|
|
|
@ -38,19 +38,23 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber)
|
|||
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
var opener handshake.Opener
|
||||
var encLevel protocol.EncryptionLevel
|
||||
var err error
|
||||
switch hdr.Type {
|
||||
case protocol.PacketTypeInitial:
|
||||
encLevel = protocol.EncryptionInitial
|
||||
opener, err = u.cs.GetInitialOpener()
|
||||
case protocol.PacketTypeHandshake:
|
||||
encLevel = protocol.EncryptionHandshake
|
||||
opener, err = u.cs.GetHandshakeOpener()
|
||||
default:
|
||||
if hdr.IsLongHeader {
|
||||
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
|
||||
}
|
||||
encLevel = protocol.Encryption1RTT
|
||||
opener, err = u.cs.Get1RTTOpener()
|
||||
}
|
||||
opener, err := u.cs.GetOpener(encLevel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
hdr, hdrRaw := getHeader(extHdr)
|
||||
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil)
|
||||
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
|
||||
_, err := unpacker.Unpack(hdr, data)
|
||||
Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19"))
|
||||
})
|
||||
|
@ -68,7 +68,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
}
|
||||
hdr, hdrRaw := getHeader(extHdr)
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
cs.EXPECT().GetOpener(protocol.EncryptionInitial).Return(opener, nil)
|
||||
cs.EXPECT().GetInitialOpener().Return(opener, nil)
|
||||
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
|
||||
packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
|
||||
|
@ -84,7 +84,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
PacketNumberLen: 2,
|
||||
}
|
||||
hdr, hdrRaw := getHeader(extHdr)
|
||||
cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(nil, handshake.ErrOpenerNotYetAvailable)
|
||||
cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrOpenerNotYetAvailable)
|
||||
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
|
||||
Expect(err).To(MatchError(handshake.ErrOpenerNotYetAvailable))
|
||||
})
|
||||
|
@ -103,7 +103,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
}
|
||||
hdr, hdrRaw := getHeader(extHdr)
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil)
|
||||
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
|
||||
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
|
||||
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
|
||||
|
@ -130,7 +130,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number
|
||||
Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte))
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil)
|
||||
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
|
||||
gomock.InOrder(
|
||||
// we're using a 2 byte packet number, so the sample starts at the 3rd payload byte
|
||||
opener.EXPECT().DecryptHeader(
|
||||
|
@ -160,7 +160,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
PacketNumberLen: 2,
|
||||
}
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil).Times(2)
|
||||
cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2)
|
||||
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil)
|
||||
hdr, hdrRaw := getHeader(firstHdr)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue