mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
use separate sealer interfaces for long and short header packets
This commit is contained in:
parent
c503769bcd
commit
b2d3ef691e
12 changed files with 233 additions and 172 deletions
|
@ -19,9 +19,10 @@ type sealer struct {
|
|||
is1RTT bool
|
||||
}
|
||||
|
||||
var _ Sealer = &sealer{}
|
||||
var _ LongHeaderSealer = &sealer{}
|
||||
var _ ShortHeaderSealer = &sealer{}
|
||||
|
||||
func newSealer(aead cipher.AEAD, hpEncrypter cipher.Block, is1RTT bool) Sealer {
|
||||
func newSealer(aead cipher.AEAD, hpEncrypter cipher.Block, is1RTT bool) ShortHeaderSealer {
|
||||
return &sealer{
|
||||
aead: aead,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
|
@ -57,6 +58,10 @@ func (s *sealer) Overhead() int {
|
|||
return s.aead.Overhead()
|
||||
}
|
||||
|
||||
func (s *sealer) KeyPhase() protocol.KeyPhase {
|
||||
return protocol.KeyPhaseZero
|
||||
}
|
||||
|
||||
type opener struct {
|
||||
aead cipher.AEAD
|
||||
pnDecrypter cipher.Block
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("AEAD", func() {
|
||||
getSealerAndOpener := func(is1RTT bool) (Sealer, Opener) {
|
||||
getSealerAndOpener := func(is1RTT bool) (ShortHeaderSealer, Opener) {
|
||||
key := make([]byte, 16)
|
||||
hpKey := make([]byte, 16)
|
||||
rand.Read(key)
|
||||
|
@ -29,7 +29,7 @@ var _ = Describe("AEAD", func() {
|
|||
|
||||
Context("message encryption", func() {
|
||||
var (
|
||||
sealer Sealer
|
||||
sealer ShortHeaderSealer
|
||||
opener Opener
|
||||
)
|
||||
|
||||
|
|
|
@ -106,15 +106,15 @@ type cryptoSetup struct {
|
|||
|
||||
initialStream io.Writer
|
||||
initialOpener Opener
|
||||
initialSealer Sealer
|
||||
initialSealer LongHeaderSealer
|
||||
|
||||
handshakeStream io.Writer
|
||||
handshakeOpener Opener
|
||||
handshakeSealer Sealer
|
||||
handshakeSealer LongHeaderSealer
|
||||
|
||||
oneRTTStream io.Writer
|
||||
opener Opener
|
||||
sealer Sealer
|
||||
sealer ShortHeaderSealer
|
||||
}
|
||||
|
||||
var _ qtls.RecordLayer = &cryptoSetup{}
|
||||
|
@ -564,14 +564,14 @@ func (h *cryptoSetup) SendAlert(alert uint8) {
|
|||
h.alertChan <- alert
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetInitialSealer() (Sealer, error) {
|
||||
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
return h.initialSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeSealer() (Sealer, error) {
|
||||
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
|
@ -581,7 +581,7 @@ func (h *cryptoSetup) GetHandshakeSealer() (Sealer, error) {
|
|||
return h.handshakeSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTSealer() (Sealer, error) {
|
||||
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0}
|
||||
|
||||
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Sealer, Opener, error) {
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, Opener, error) {
|
||||
clientSecret, serverSecret := computeSecrets(connID)
|
||||
var mySecret, otherSecret []byte
|
||||
if pers == protocol.PerspectiveClient {
|
||||
|
|
|
@ -14,13 +14,19 @@ type Opener interface {
|
|||
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
}
|
||||
|
||||
// Sealer seals a packet
|
||||
type Sealer interface {
|
||||
// LongHeaderSealer seals a long header packet
|
||||
type LongHeaderSealer interface {
|
||||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
Overhead() int
|
||||
}
|
||||
|
||||
// ShortHeaderSealer seals a short header packet
|
||||
type ShortHeaderSealer interface {
|
||||
LongHeaderSealer
|
||||
KeyPhase() protocol.KeyPhase
|
||||
}
|
||||
|
||||
// A tlsExtensionHandler sends and received the QUIC TLS extension.
|
||||
type tlsExtensionHandler interface {
|
||||
GetExtensions(msgType uint8) []qtls.Extension
|
||||
|
@ -49,7 +55,7 @@ type CryptoSetup interface {
|
|||
GetHandshakeOpener() (Opener, error)
|
||||
Get1RTTOpener() (Opener, error)
|
||||
|
||||
GetInitialSealer() (Sealer, error)
|
||||
GetHandshakeSealer() (Sealer, error)
|
||||
Get1RTTSealer() (Sealer, error)
|
||||
GetInitialSealer() (LongHeaderSealer, error)
|
||||
GetHandshakeSealer() (LongHeaderSealer, error)
|
||||
Get1RTTSealer() (ShortHeaderSealer, error)
|
||||
}
|
||||
|
|
|
@ -94,10 +94,10 @@ func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *gomock.Call {
|
|||
}
|
||||
|
||||
// Get1RTTSealer mocks base method
|
||||
func (m *MockCryptoSetup) Get1RTTSealer() (handshake.Sealer, error) {
|
||||
func (m *MockCryptoSetup) Get1RTTSealer() (handshake.ShortHeaderSealer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get1RTTSealer")
|
||||
ret0, _ := ret[0].(handshake.Sealer)
|
||||
ret0, _ := ret[0].(handshake.ShortHeaderSealer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
@ -124,10 +124,10 @@ func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *gomock.Call {
|
|||
}
|
||||
|
||||
// GetHandshakeSealer mocks base method
|
||||
func (m *MockCryptoSetup) GetHandshakeSealer() (handshake.Sealer, error) {
|
||||
func (m *MockCryptoSetup) GetHandshakeSealer() (handshake.LongHeaderSealer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetHandshakeSealer")
|
||||
ret0, _ := ret[0].(handshake.Sealer)
|
||||
ret0, _ := ret[0].(handshake.LongHeaderSealer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
@ -154,10 +154,10 @@ func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *gomock.Call {
|
|||
}
|
||||
|
||||
// GetInitialSealer mocks base method
|
||||
func (m *MockCryptoSetup) GetInitialSealer() (handshake.Sealer, error) {
|
||||
func (m *MockCryptoSetup) GetInitialSealer() (handshake.LongHeaderSealer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetInitialSealer")
|
||||
ret0, _ := ret[0].(handshake.Sealer)
|
||||
ret0, _ := ret[0].(handshake.LongHeaderSealer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ package mocks
|
|||
|
||||
//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go"
|
||||
//go:generate sh -c "mockgen -package mockquic -destination quic/session.go github.com/lucas-clemente/quic-go Session && goimports -w quic/session.go"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Sealer)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockSealer is a mock of Sealer interface
|
||||
type MockSealer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSealerMockRecorder
|
||||
}
|
||||
|
||||
// MockSealerMockRecorder is the mock recorder for MockSealer
|
||||
type MockSealerMockRecorder struct {
|
||||
mock *MockSealer
|
||||
}
|
||||
|
||||
// NewMockSealer creates a new mock instance
|
||||
func NewMockSealer(ctrl *gomock.Controller) *MockSealer {
|
||||
mock := &MockSealer{ctrl: ctrl}
|
||||
mock.recorder = &MockSealerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockSealer) EXPECT() *MockSealerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// EncryptHeader mocks base method
|
||||
func (m *MockSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// EncryptHeader indicates an expected call of EncryptHeader
|
||||
func (mr *MockSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockSealer)(nil).EncryptHeader), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// Overhead mocks base method
|
||||
func (m *MockSealer) Overhead() int {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Overhead")
|
||||
ret0, _ := ret[0].(int)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Overhead indicates an expected call of Overhead
|
||||
func (mr *MockSealerMockRecorder) Overhead() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockSealer)(nil).Overhead))
|
||||
}
|
||||
|
||||
// Seal mocks base method
|
||||
func (m *MockSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Seal indicates an expected call of Seal
|
||||
func (mr *MockSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockSealer)(nil).Seal), arg0, arg1, arg2, arg3)
|
||||
}
|
89
internal/mocks/short_header_sealer.go
Normal file
89
internal/mocks/short_header_sealer.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: ShortHeaderSealer)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockShortHeaderSealer is a mock of ShortHeaderSealer interface
|
||||
type MockShortHeaderSealer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockShortHeaderSealerMockRecorder
|
||||
}
|
||||
|
||||
// MockShortHeaderSealerMockRecorder is the mock recorder for MockShortHeaderSealer
|
||||
type MockShortHeaderSealerMockRecorder struct {
|
||||
mock *MockShortHeaderSealer
|
||||
}
|
||||
|
||||
// NewMockShortHeaderSealer creates a new mock instance
|
||||
func NewMockShortHeaderSealer(ctrl *gomock.Controller) *MockShortHeaderSealer {
|
||||
mock := &MockShortHeaderSealer{ctrl: ctrl}
|
||||
mock.recorder = &MockShortHeaderSealerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockShortHeaderSealer) EXPECT() *MockShortHeaderSealerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// EncryptHeader mocks base method
|
||||
func (m *MockShortHeaderSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// EncryptHeader indicates an expected call of EncryptHeader
|
||||
func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// KeyPhase mocks base method
|
||||
func (m *MockShortHeaderSealer) KeyPhase() protocol.KeyPhase {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeyPhase")
|
||||
ret0, _ := ret[0].(protocol.KeyPhase)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// KeyPhase indicates an expected call of KeyPhase
|
||||
func (mr *MockShortHeaderSealerMockRecorder) KeyPhase() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyPhase", reflect.TypeOf((*MockShortHeaderSealer)(nil).KeyPhase))
|
||||
}
|
||||
|
||||
// Overhead mocks base method
|
||||
func (m *MockShortHeaderSealer) Overhead() int {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Overhead")
|
||||
ret0, _ := ret[0].(int)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Overhead indicates an expected call of Overhead
|
||||
func (mr *MockShortHeaderSealerMockRecorder) Overhead() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockShortHeaderSealer)(nil).Overhead))
|
||||
}
|
||||
|
||||
// Seal mocks base method
|
||||
func (m *MockShortHeaderSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Seal indicates an expected call of Seal
|
||||
func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), arg0, arg1, arg2, arg3)
|
||||
}
|
|
@ -35,10 +35,10 @@ func (m *MockSealingManager) EXPECT() *MockSealingManagerMockRecorder {
|
|||
}
|
||||
|
||||
// Get1RTTSealer mocks base method
|
||||
func (m *MockSealingManager) Get1RTTSealer() (handshake.Sealer, error) {
|
||||
func (m *MockSealingManager) Get1RTTSealer() (handshake.ShortHeaderSealer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get1RTTSealer")
|
||||
ret0, _ := ret[0].(handshake.Sealer)
|
||||
ret0, _ := ret[0].(handshake.ShortHeaderSealer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
@ -50,10 +50,10 @@ func (mr *MockSealingManagerMockRecorder) Get1RTTSealer() *gomock.Call {
|
|||
}
|
||||
|
||||
// GetHandshakeSealer mocks base method
|
||||
func (m *MockSealingManager) GetHandshakeSealer() (handshake.Sealer, error) {
|
||||
func (m *MockSealingManager) GetHandshakeSealer() (handshake.LongHeaderSealer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetHandshakeSealer")
|
||||
ret0, _ := ret[0].(handshake.Sealer)
|
||||
ret0, _ := ret[0].(handshake.LongHeaderSealer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
@ -65,10 +65,10 @@ func (mr *MockSealingManagerMockRecorder) GetHandshakeSealer() *gomock.Call {
|
|||
}
|
||||
|
||||
// GetInitialSealer mocks base method
|
||||
func (m *MockSealingManager) GetInitialSealer() (handshake.Sealer, error) {
|
||||
func (m *MockSealingManager) GetInitialSealer() (handshake.LongHeaderSealer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetInitialSealer")
|
||||
ret0, _ := ret[0].(handshake.Sealer)
|
||||
ret0, _ := ret[0].(handshake.LongHeaderSealer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
|
145
packet_packer.go
145
packet_packer.go
|
@ -25,6 +25,10 @@ type packer interface {
|
|||
ChangeDestConnectionID(protocol.ConnectionID)
|
||||
}
|
||||
|
||||
type sealer interface {
|
||||
handshake.LongHeaderSealer
|
||||
}
|
||||
|
||||
type payload struct {
|
||||
frames []wire.Frame
|
||||
ack *wire.AckFrame
|
||||
|
@ -93,9 +97,9 @@ type packetNumberManager interface {
|
|||
}
|
||||
|
||||
type sealingManager interface {
|
||||
GetInitialSealer() (handshake.Sealer, error)
|
||||
GetHandshakeSealer() (handshake.Sealer, error)
|
||||
Get1RTTSealer() (handshake.Sealer, error)
|
||||
GetInitialSealer() (handshake.LongHeaderSealer, error)
|
||||
GetHandshakeSealer() (handshake.LongHeaderSealer, error)
|
||||
Get1RTTSealer() (handshake.ShortHeaderSealer, error)
|
||||
}
|
||||
|
||||
type frameSource interface {
|
||||
|
@ -165,10 +169,11 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
|
|||
length: ccf.Length(p.version),
|
||||
}
|
||||
// send the CONNECTION_CLOSE frame with the highest available encryption level
|
||||
var sealer handshake.Sealer
|
||||
var err error
|
||||
var hdr *wire.ExtendedHeader
|
||||
var sealer sealer
|
||||
encLevel := protocol.Encryption1RTT
|
||||
sealer, err = p.cryptoSetup.Get1RTTSealer()
|
||||
s, err := p.cryptoSetup.Get1RTTSealer()
|
||||
if err != nil {
|
||||
encLevel = protocol.EncryptionHandshake
|
||||
sealer, err = p.cryptoSetup.GetHandshakeSealer()
|
||||
|
@ -178,11 +183,16 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hdr = p.getLongHeader(protocol.EncryptionInitial)
|
||||
} else {
|
||||
hdr = p.getLongHeader(protocol.EncryptionHandshake)
|
||||
}
|
||||
} else {
|
||||
sealer = s
|
||||
hdr = p.getShortHeader(s.KeyPhase())
|
||||
}
|
||||
|
||||
header := p.getHeader(encLevel)
|
||||
return p.writeAndSealPacket(header, payload, encLevel, sealer)
|
||||
return p.writeAndSealPacket(hdr, payload, encLevel, sealer)
|
||||
}
|
||||
|
||||
func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
|
||||
|
@ -199,7 +209,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
header := p.getHeader(protocol.Encryption1RTT)
|
||||
header := p.getShortHeader(sealer.KeyPhase())
|
||||
return p.writeAndSealPacket(header, payload, protocol.Encryption1RTT, sealer)
|
||||
}
|
||||
|
||||
|
@ -222,28 +232,42 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
|
|||
}
|
||||
|
||||
var packets []*packedPacket
|
||||
var err error
|
||||
var sealer handshake.Sealer
|
||||
switch packet.EncryptionLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
sealer, err = p.cryptoSetup.GetInitialSealer()
|
||||
case protocol.EncryptionHandshake:
|
||||
sealer, err = p.cryptoSetup.GetHandshakeSealer()
|
||||
case protocol.Encryption1RTT:
|
||||
sealer, err = p.cryptoSetup.Get1RTTSealer()
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected encryption level: %s", packet.EncryptionLevel)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for len(controlFrames) > 0 || len(streamFrames) > 0 {
|
||||
var frames []wire.Frame
|
||||
var length protocol.ByteCount
|
||||
|
||||
header := p.getHeader(packet.EncryptionLevel)
|
||||
headerLen := header.GetLength(p.version)
|
||||
maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen
|
||||
var sealer sealer
|
||||
var hdr *wire.ExtendedHeader
|
||||
switch packet.EncryptionLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
var err error
|
||||
sealer, err = p.cryptoSetup.GetInitialSealer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hdr = p.getLongHeader(protocol.EncryptionInitial)
|
||||
case protocol.EncryptionHandshake:
|
||||
var err error
|
||||
sealer, err = p.cryptoSetup.GetHandshakeSealer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hdr = p.getLongHeader(protocol.EncryptionHandshake)
|
||||
case protocol.Encryption1RTT:
|
||||
var s handshake.ShortHeaderSealer
|
||||
var err error
|
||||
s, err = p.cryptoSetup.Get1RTTSealer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sealer = s
|
||||
hdr = p.getShortHeader(s.KeyPhase())
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected encryption level: %s", packet.EncryptionLevel)
|
||||
}
|
||||
|
||||
hdrLen := hdr.GetLength(p.version)
|
||||
maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - hdrLen
|
||||
|
||||
for len(controlFrames) > 0 {
|
||||
frame := controlFrames[0]
|
||||
|
@ -277,7 +301,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
|
|||
if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
|
||||
sf.DataLenPresent = false
|
||||
}
|
||||
p, err := p.writeAndSealPacket(header, payload{frames: frames, length: length}, packet.EncryptionLevel, sealer)
|
||||
p, err := p.writeAndSealPacket(hdr, payload{frames: frames, length: length}, packet.EncryptionLevel, sealer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -302,7 +326,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
|
|||
// sealer not yet available
|
||||
return nil, nil
|
||||
}
|
||||
header := p.getHeader(protocol.Encryption1RTT)
|
||||
header := p.getShortHeader(sealer.KeyPhase())
|
||||
headerLen := header.GetLength(p.version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -340,7 +364,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
|
|||
|
||||
hasData := p.initialStream.HasData()
|
||||
ack := p.acks.GetAckFrame(protocol.EncryptionInitial)
|
||||
var sealer handshake.Sealer
|
||||
var sealer handshake.LongHeaderSealer
|
||||
var err error
|
||||
if hasData || ack != nil {
|
||||
s = p.initialStream
|
||||
|
@ -367,7 +391,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
|
|||
payload.ack = ack
|
||||
payload.length = ack.Length(p.version)
|
||||
}
|
||||
hdr := p.getHeader(encLevel)
|
||||
hdr := p.getLongHeader(encLevel)
|
||||
hdrLen := hdr.GetLength(p.version)
|
||||
if hasData {
|
||||
cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length)
|
||||
|
@ -396,40 +420,49 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (paylo
|
|||
return payload, nil
|
||||
}
|
||||
|
||||
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
|
||||
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
|
||||
header := &wire.ExtendedHeader{}
|
||||
header.PacketNumber = pn
|
||||
header.PacketNumberLen = pnLen
|
||||
header.Version = p.version
|
||||
header.DestConnectionID = p.destConnID
|
||||
func (p *packetPacker) getShortHeader(kp protocol.KeyPhase) *wire.ExtendedHeader {
|
||||
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
|
||||
hdr := &wire.ExtendedHeader{}
|
||||
hdr.PacketNumber = pn
|
||||
hdr.PacketNumberLen = pnLen
|
||||
hdr.DestConnectionID = p.destConnID
|
||||
hdr.KeyPhase = kp
|
||||
return hdr
|
||||
}
|
||||
|
||||
if encLevel != protocol.Encryption1RTT {
|
||||
header.IsLongHeader = true
|
||||
// Always send Initial and Handshake packets with the maximum packet number length.
|
||||
// This simplifies retransmissions: Since the header can't get any larger,
|
||||
// we don't need to split CRYPTO frames.
|
||||
header.PacketNumberLen = protocol.PacketNumberLen4
|
||||
header.SrcConnectionID = p.srcConnID
|
||||
// Set the length to the maximum packet size.
|
||||
// Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.
|
||||
header.Length = p.maxPacketSize
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
header.Type = protocol.PacketTypeInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
header.Type = protocol.PacketTypeHandshake
|
||||
}
|
||||
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
|
||||
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
|
||||
hdr := &wire.ExtendedHeader{}
|
||||
hdr.PacketNumber = pn
|
||||
hdr.PacketNumberLen = pnLen
|
||||
hdr.DestConnectionID = p.destConnID
|
||||
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
hdr.Type = protocol.PacketTypeInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
hdr.Type = protocol.PacketTypeHandshake
|
||||
}
|
||||
|
||||
return header
|
||||
hdr.Version = p.version
|
||||
hdr.IsLongHeader = true
|
||||
// Always send Initial and Handshake packets with the maximum packet number length.
|
||||
// This simplifies retransmissions: Since the header can't get any larger,
|
||||
// we don't need to split CRYPTO frames.
|
||||
hdr.PacketNumberLen = protocol.PacketNumberLen4
|
||||
hdr.SrcConnectionID = p.srcConnID
|
||||
// Set the length to the maximum packet size.
|
||||
// Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.
|
||||
hdr.Length = p.maxPacketSize
|
||||
|
||||
return hdr
|
||||
}
|
||||
|
||||
func (p *packetPacker) writeAndSealPacket(
|
||||
header *wire.ExtendedHeader,
|
||||
payload payload,
|
||||
encLevel protocol.EncryptionLevel,
|
||||
sealer handshake.Sealer,
|
||||
sealer sealer,
|
||||
) (*packedPacket, error) {
|
||||
var paddingLen protocol.ByteCount
|
||||
pnLen := protocol.ByteCount(header.PacketNumberLen)
|
||||
|
@ -454,7 +487,7 @@ func (p *packetPacker) writeAndSealPacketWithPadding(
|
|||
payload payload,
|
||||
paddingLen protocol.ByteCount,
|
||||
encLevel protocol.EncryptionLevel,
|
||||
sealer handshake.Sealer,
|
||||
sealer sealer,
|
||||
) (*packedPacket, error) {
|
||||
packetBuffer := getPacketBuffer()
|
||||
buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
|
||||
|
|
|
@ -106,7 +106,7 @@ var _ = Describe("Packet packer", func() {
|
|||
Context("generating a packet header", func() {
|
||||
It("uses the Long Header format", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
h := packer.getHeader(protocol.EncryptionHandshake)
|
||||
h := packer.getLongHeader(protocol.EncryptionHandshake)
|
||||
Expect(h.IsLongHeader).To(BeTrue())
|
||||
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x42)))
|
||||
// long headers always use 4 byte packet numbers, no matter what the packet number generator says
|
||||
|
@ -120,7 +120,7 @@ var _ = Describe("Packet packer", func() {
|
|||
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
packer.srcConnID = srcConnID
|
||||
packer.destConnID = destConnID
|
||||
h := packer.getHeader(protocol.EncryptionHandshake)
|
||||
h := packer.getLongHeader(protocol.EncryptionHandshake)
|
||||
Expect(h.SrcConnectionID).To(Equal(srcConnID))
|
||||
Expect(h.DestConnectionID).To(Equal(destConnID))
|
||||
})
|
||||
|
@ -132,21 +132,22 @@ var _ = Describe("Packet packer", func() {
|
|||
dest1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
dest2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
packer.ChangeDestConnectionID(dest1)
|
||||
h := packer.getHeader(protocol.EncryptionInitial)
|
||||
h := packer.getLongHeader(protocol.EncryptionInitial)
|
||||
Expect(h.SrcConnectionID).To(Equal(srcConnID))
|
||||
Expect(h.DestConnectionID).To(Equal(dest1))
|
||||
packer.ChangeDestConnectionID(dest2)
|
||||
h = packer.getHeader(protocol.EncryptionInitial)
|
||||
h = packer.getLongHeader(protocol.EncryptionInitial)
|
||||
Expect(h.SrcConnectionID).To(Equal(srcConnID))
|
||||
Expect(h.DestConnectionID).To(Equal(dest2))
|
||||
})
|
||||
|
||||
It("uses the Short Header format for 1-RTT packets", func() {
|
||||
It("gets a short header", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen4)
|
||||
h := packer.getHeader(protocol.Encryption1RTT)
|
||||
h := packer.getShortHeader(protocol.KeyPhaseOne)
|
||||
Expect(h.IsLongHeader).To(BeFalse())
|
||||
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x1337)))
|
||||
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(h.KeyPhase).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -156,10 +157,11 @@ var _ = Describe("Packet packer", func() {
|
|||
handshakeStream.EXPECT().HasData()
|
||||
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337))
|
||||
sealer := mocks.NewMockSealer(mockCtrl)
|
||||
sealer := mocks.NewMockShortHeaderSealer(mockCtrl)
|
||||
sealer.EXPECT().Overhead().Return(4).AnyTimes()
|
||||
var hdrRaw []byte
|
||||
gomock.InOrder(
|
||||
sealer.EXPECT().KeyPhase().Return(protocol.KeyPhaseOne),
|
||||
sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1337), gomock.Any()).DoAndReturn(func(_, src []byte, _ protocol.PacketNumber, aad []byte) []byte {
|
||||
hdrRaw = append([]byte{}, aad...)
|
||||
return append(src, []byte{0xde, 0xca, 0xfb, 0xad}...)
|
||||
|
@ -195,10 +197,11 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
Context("packing packets", func() {
|
||||
var sealer *mocks.MockSealer
|
||||
var sealer *mocks.MockShortHeaderSealer
|
||||
|
||||
BeforeEach(func() {
|
||||
sealer = mocks.NewMockSealer(mockCtrl)
|
||||
sealer = mocks.NewMockShortHeaderSealer(mockCtrl)
|
||||
sealer.EXPECT().KeyPhase().Return(protocol.KeyPhaseOne).AnyTimes()
|
||||
sealer.EXPECT().Overhead().Return(7).AnyTimes()
|
||||
sealer.EXPECT().EncryptHeader(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte {
|
||||
|
@ -505,7 +508,7 @@ var _ = Describe("Packet packer", func() {
|
|||
It("packs two packets for retransmission if the original packet contained many control frames", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
|
||||
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2)
|
||||
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
|
||||
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2)
|
||||
var frames []wire.Frame
|
||||
var totalLen protocol.ByteCount
|
||||
// pack a bunch of control frames, such that the packet is way bigger than a single packet
|
||||
|
@ -533,7 +536,7 @@ var _ = Describe("Packet packer", func() {
|
|||
It("splits a STREAM frame that doesn't fit", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
|
||||
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2)
|
||||
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
|
||||
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2)
|
||||
packets, err := packer.PackRetransmission(&ackhandler.Packet{
|
||||
EncryptionLevel: protocol.Encryption1RTT,
|
||||
Frames: []wire.Frame{&wire.StreamFrame{
|
||||
|
@ -598,7 +601,7 @@ var _ = Describe("Packet packer", func() {
|
|||
It("packs two packets for retransmission if the original packet contained many STREAM frames", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
|
||||
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2)
|
||||
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
|
||||
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2)
|
||||
var frames []wire.Frame
|
||||
var totalLen protocol.ByteCount
|
||||
// pack a bunch of control frames, such that the packet is way bigger than a single packet
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue