mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
use separate opener interfaces for long and short header packets
This commit is contained in:
parent
b2d3ef691e
commit
598628d05b
12 changed files with 320 additions and 149 deletions
|
@ -62,47 +62,77 @@ func (s *sealer) KeyPhase() protocol.KeyPhase {
|
|||
return protocol.KeyPhaseZero
|
||||
}
|
||||
|
||||
type opener struct {
|
||||
type longHeaderOpener struct {
|
||||
aead cipher.AEAD
|
||||
pnDecrypter cipher.Block
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
hpMask []byte
|
||||
|
||||
// short headers protect 5 bits in the first byte, long headers only 4
|
||||
is1RTT bool
|
||||
}
|
||||
|
||||
var _ Opener = &opener{}
|
||||
var _ LongHeaderOpener = &longHeaderOpener{}
|
||||
|
||||
func newOpener(aead cipher.AEAD, pnDecrypter cipher.Block, is1RTT bool) Opener {
|
||||
return &opener{
|
||||
func newLongHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) LongHeaderOpener {
|
||||
return &longHeaderOpener{
|
||||
aead: aead,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
is1RTT: is1RTT,
|
||||
pnDecrypter: pnDecrypter,
|
||||
hpMask: make([]byte, pnDecrypter.BlockSize()),
|
||||
}
|
||||
}
|
||||
|
||||
func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
return o.aead.Open(dst, o.nonceBuf, src, ad)
|
||||
}
|
||||
|
||||
func (o *opener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
if len(sample) != o.pnDecrypter.BlockSize() {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
o.pnDecrypter.Encrypt(o.hpMask, sample)
|
||||
if o.is1RTT {
|
||||
*firstByte ^= o.hpMask[0] & 0x1f
|
||||
} else {
|
||||
*firstByte ^= o.hpMask[0] & 0xf
|
||||
}
|
||||
*firstByte ^= o.hpMask[0] & 0xf
|
||||
for i := range pnBytes {
|
||||
pnBytes[i] ^= o.hpMask[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
type shortHeaderOpener struct {
|
||||
aead cipher.AEAD
|
||||
pnDecrypter cipher.Block
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
hpMask []byte
|
||||
}
|
||||
|
||||
var _ ShortHeaderOpener = &shortHeaderOpener{}
|
||||
|
||||
func newShortHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) ShortHeaderOpener {
|
||||
return &shortHeaderOpener{
|
||||
aead: aead,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
pnDecrypter: pnDecrypter,
|
||||
hpMask: make([]byte, pnDecrypter.BlockSize()),
|
||||
}
|
||||
}
|
||||
|
||||
func (o *shortHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) {
|
||||
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
return o.aead.Open(dst, o.nonceBuf, src, ad)
|
||||
}
|
||||
|
||||
func (o *shortHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
if len(sample) != o.pnDecrypter.BlockSize() {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
o.pnDecrypter.Encrypt(o.hpMask, sample)
|
||||
*firstByte ^= o.hpMask[0] & 0x1f
|
||||
for i := range pnBytes {
|
||||
pnBytes[i] ^= o.hpMask[i+1]
|
||||
}
|
||||
|
|
|
@ -5,12 +5,13 @@ import (
|
|||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("AEAD", func() {
|
||||
getSealerAndOpener := func(is1RTT bool) (ShortHeaderSealer, Opener) {
|
||||
getLongHeaderSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) {
|
||||
key := make([]byte, 16)
|
||||
hpKey := make([]byte, 16)
|
||||
rand.Read(key)
|
||||
|
@ -24,36 +25,55 @@ var _ = Describe("AEAD", func() {
|
|||
|
||||
iv := make([]byte, 12)
|
||||
rand.Read(iv)
|
||||
return newSealer(aead, hpBlock, is1RTT), newOpener(aead, hpBlock, is1RTT)
|
||||
return newSealer(aead, hpBlock, false), newLongHeaderOpener(aead, hpBlock)
|
||||
}
|
||||
|
||||
getShortHeaderSealerAndOpener := func() (ShortHeaderSealer, ShortHeaderOpener) {
|
||||
key := make([]byte, 16)
|
||||
hpKey := make([]byte, 16)
|
||||
rand.Read(key)
|
||||
rand.Read(hpKey)
|
||||
block, err := aes.NewCipher(key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
aead, err := cipher.NewGCM(block)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hpBlock, err := aes.NewCipher(hpKey)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
iv := make([]byte, 12)
|
||||
rand.Read(iv)
|
||||
return newSealer(aead, hpBlock, true), newShortHeaderOpener(aead, hpBlock)
|
||||
}
|
||||
|
||||
Context("message encryption", func() {
|
||||
var (
|
||||
sealer ShortHeaderSealer
|
||||
opener Opener
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
sealer, opener = getSealerAndOpener(false)
|
||||
})
|
||||
|
||||
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
ad := []byte("Donec in velit neque.")
|
||||
|
||||
It("encrypts and decrypts a message", func() {
|
||||
It("encrypts and decrypts a message, for long headers", func() {
|
||||
sealer, opener := getLongHeaderSealerAndOpener()
|
||||
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
|
||||
opened, err := opener.Open(nil, encrypted, 0x1337, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(opened).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("encrypts and decrypts a message, for short headers", func() {
|
||||
sealer, opener := getShortHeaderSealerAndOpener()
|
||||
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
|
||||
opened, err := opener.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(opened).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("fails to open a message if the associated data is not the same", func() {
|
||||
sealer, opener := getLongHeaderSealerAndOpener()
|
||||
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad"))
|
||||
Expect(err).To(MatchError("cipher: message authentication failed"))
|
||||
})
|
||||
|
||||
It("fails to open a message if the packet number is not the same", func() {
|
||||
sealer, opener := getLongHeaderSealerAndOpener()
|
||||
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := opener.Open(nil, encrypted, 0x42, ad)
|
||||
Expect(err).To(MatchError("cipher: message authentication failed"))
|
||||
|
@ -62,7 +82,7 @@ var _ = Describe("AEAD", func() {
|
|||
|
||||
Context("header encryption", func() {
|
||||
It("encrypts and encrypts the header, for long headers", func() {
|
||||
sealer, opener := getSealerAndOpener(false)
|
||||
sealer, opener := getLongHeaderSealerAndOpener()
|
||||
var lastFourBitsDifferent int
|
||||
for i := 0; i < 100; i++ {
|
||||
sample := make([]byte, 16)
|
||||
|
@ -82,7 +102,7 @@ var _ = Describe("AEAD", func() {
|
|||
})
|
||||
|
||||
It("encrypts and encrypts the header, for short headers", func() {
|
||||
sealer, opener := getSealerAndOpener(true)
|
||||
sealer, opener := getShortHeaderSealerAndOpener()
|
||||
var lastFiveBitsDifferent int
|
||||
for i := 0; i < 100; i++ {
|
||||
sample := make([]byte, 16)
|
||||
|
@ -102,7 +122,7 @@ var _ = Describe("AEAD", func() {
|
|||
})
|
||||
|
||||
It("fails to decrypt the header when using a different sample", func() {
|
||||
sealer, opener := getSealerAndOpener(true)
|
||||
sealer, opener := getLongHeaderSealerAndOpener()
|
||||
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
|
||||
sample := make([]byte, 16)
|
||||
rand.Read(sample)
|
||||
|
|
|
@ -105,15 +105,15 @@ type cryptoSetup struct {
|
|||
writeEncLevel protocol.EncryptionLevel
|
||||
|
||||
initialStream io.Writer
|
||||
initialOpener Opener
|
||||
initialOpener LongHeaderOpener
|
||||
initialSealer LongHeaderSealer
|
||||
|
||||
handshakeStream io.Writer
|
||||
handshakeOpener Opener
|
||||
handshakeOpener LongHeaderOpener
|
||||
handshakeSealer LongHeaderSealer
|
||||
|
||||
oneRTTStream io.Writer
|
||||
opener Opener
|
||||
opener ShortHeaderOpener
|
||||
sealer ShortHeaderSealer
|
||||
}
|
||||
|
||||
|
@ -493,11 +493,11 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
|
|||
switch h.readEncLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.readEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeOpener = newOpener(suite.AEAD(key, iv), hpDecrypter, false)
|
||||
h.handshakeOpener = newLongHeaderOpener(suite.AEAD(key, iv), hpDecrypter)
|
||||
h.logger.Debugf("Installed Handshake Read keys")
|
||||
case protocol.EncryptionHandshake:
|
||||
h.readEncLevel = protocol.Encryption1RTT
|
||||
h.opener = newOpener(suite.AEAD(key, iv), hpDecrypter, true)
|
||||
h.opener = newShortHeaderOpener(suite.AEAD(key, iv), hpDecrypter)
|
||||
h.logger.Debugf("Installed 1-RTT Read keys")
|
||||
default:
|
||||
panic("unexpected read encryption level")
|
||||
|
@ -591,7 +591,7 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
|
|||
return h.sealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetInitialOpener() (Opener, error) {
|
||||
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
|
@ -601,7 +601,7 @@ func (h *cryptoSetup) GetInitialOpener() (Opener, error) {
|
|||
return h.initialOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) {
|
||||
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
|
@ -615,7 +615,7 @@ func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) {
|
|||
return h.handshakeOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTOpener() (Opener, error) {
|
||||
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0}
|
||||
|
||||
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, Opener, error) {
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, LongHeaderOpener, error) {
|
||||
clientSecret, serverSecret := computeSecrets(connID)
|
||||
var mySecret, otherSecret []byte
|
||||
if pers == protocol.PerspectiveClient {
|
||||
|
@ -34,7 +34,7 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return newSealer(encrypter, hpEncrypter, false), newOpener(decrypter, hpDecrypter, false), nil
|
||||
return newSealer(encrypter, hpEncrypter, false), newLongHeaderOpener(decrypter, hpDecrypter), nil
|
||||
}
|
||||
|
||||
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||
|
|
|
@ -8,12 +8,22 @@ import (
|
|||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
// Opener opens a packet
|
||||
type Opener interface {
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
type headerDecryptor interface {
|
||||
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
}
|
||||
|
||||
// LongHeaderOpener opens a long header packet
|
||||
type LongHeaderOpener interface {
|
||||
headerDecryptor
|
||||
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// ShortHeaderOpener opens a short header packet
|
||||
type ShortHeaderOpener interface {
|
||||
headerDecryptor
|
||||
Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhase, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// LongHeaderSealer seals a long header packet
|
||||
type LongHeaderSealer interface {
|
||||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||
|
@ -51,9 +61,9 @@ type CryptoSetup interface {
|
|||
Received1RTTAck()
|
||||
ConnectionState() tls.ConnectionState
|
||||
|
||||
GetInitialOpener() (Opener, error)
|
||||
GetHandshakeOpener() (Opener, error)
|
||||
Get1RTTOpener() (Opener, error)
|
||||
GetInitialOpener() (LongHeaderOpener, error)
|
||||
GetHandshakeOpener() (LongHeaderOpener, error)
|
||||
Get1RTTOpener() (ShortHeaderOpener, error)
|
||||
|
||||
GetInitialSealer() (LongHeaderSealer, error)
|
||||
GetHandshakeSealer() (LongHeaderSealer, error)
|
||||
|
|
|
@ -79,10 +79,10 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
|
|||
}
|
||||
|
||||
// Get1RTTOpener mocks base method
|
||||
func (m *MockCryptoSetup) Get1RTTOpener() (handshake.Opener, error) {
|
||||
func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get1RTTOpener")
|
||||
ret0, _ := ret[0].(handshake.Opener)
|
||||
ret0, _ := ret[0].(handshake.ShortHeaderOpener)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
@ -109,10 +109,10 @@ func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *gomock.Call {
|
|||
}
|
||||
|
||||
// GetHandshakeOpener mocks base method
|
||||
func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.Opener, error) {
|
||||
func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.LongHeaderOpener, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetHandshakeOpener")
|
||||
ret0, _ := ret[0].(handshake.Opener)
|
||||
ret0, _ := ret[0].(handshake.LongHeaderOpener)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
@ -139,10 +139,10 @@ func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *gomock.Call {
|
|||
}
|
||||
|
||||
// GetInitialOpener mocks base method
|
||||
func (m *MockCryptoSetup) GetInitialOpener() (handshake.Opener, error) {
|
||||
func (m *MockCryptoSetup) GetInitialOpener() (handshake.LongHeaderOpener, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetInitialOpener")
|
||||
ret0, _ := ret[0].(handshake.Opener)
|
||||
ret0, _ := ret[0].(handshake.LongHeaderOpener)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
|
62
internal/mocks/long_header_opener.go
Normal file
62
internal/mocks/long_header_opener.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: LongHeaderOpener)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockLongHeaderOpener is a mock of LongHeaderOpener interface
|
||||
type MockLongHeaderOpener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLongHeaderOpenerMockRecorder
|
||||
}
|
||||
|
||||
// MockLongHeaderOpenerMockRecorder is the mock recorder for MockLongHeaderOpener
|
||||
type MockLongHeaderOpenerMockRecorder struct {
|
||||
mock *MockLongHeaderOpener
|
||||
}
|
||||
|
||||
// NewMockLongHeaderOpener creates a new mock instance
|
||||
func NewMockLongHeaderOpener(ctrl *gomock.Controller) *MockLongHeaderOpener {
|
||||
mock := &MockLongHeaderOpener{ctrl: ctrl}
|
||||
mock.recorder = &MockLongHeaderOpenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// DecryptHeader mocks base method
|
||||
func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// DecryptHeader indicates an expected call of DecryptHeader
|
||||
func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// Open mocks base method
|
||||
func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Open indicates an expected call of Open
|
||||
func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3)
|
||||
}
|
|
@ -3,7 +3,8 @@ package mocks
|
|||
//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go"
|
||||
//go:generate sh -c "mockgen -package mockquic -destination quic/session.go github.com/lucas-clemente/quic-go Session && goimports -w quic/session.go"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks short_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderOpener"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks long_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake LongHeaderOpener"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
||||
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
|
||||
|
|
|
@ -1,62 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Opener)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockOpener is a mock of Opener interface
|
||||
type MockOpener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockOpenerMockRecorder
|
||||
}
|
||||
|
||||
// MockOpenerMockRecorder is the mock recorder for MockOpener
|
||||
type MockOpenerMockRecorder struct {
|
||||
mock *MockOpener
|
||||
}
|
||||
|
||||
// NewMockOpener creates a new mock instance
|
||||
func NewMockOpener(ctrl *gomock.Controller) *MockOpener {
|
||||
mock := &MockOpener{ctrl: ctrl}
|
||||
mock.recorder = &MockOpenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockOpener) EXPECT() *MockOpenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// DecryptHeader mocks base method
|
||||
func (m *MockOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// DecryptHeader indicates an expected call of DecryptHeader
|
||||
func (mr *MockOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockOpener)(nil).DecryptHeader), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// Open mocks base method
|
||||
func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Open indicates an expected call of Open
|
||||
func (mr *MockOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockOpener)(nil).Open), arg0, arg1, arg2, arg3)
|
||||
}
|
62
internal/mocks/short_header_opener.go
Normal file
62
internal/mocks/short_header_opener.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: ShortHeaderOpener)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockShortHeaderOpener is a mock of ShortHeaderOpener interface
|
||||
type MockShortHeaderOpener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockShortHeaderOpenerMockRecorder
|
||||
}
|
||||
|
||||
// MockShortHeaderOpenerMockRecorder is the mock recorder for MockShortHeaderOpener
|
||||
type MockShortHeaderOpenerMockRecorder struct {
|
||||
mock *MockShortHeaderOpener
|
||||
}
|
||||
|
||||
// NewMockShortHeaderOpener creates a new mock instance
|
||||
func NewMockShortHeaderOpener(ctrl *gomock.Controller) *MockShortHeaderOpener {
|
||||
mock := &MockShortHeaderOpener{ctrl: ctrl}
|
||||
mock.recorder = &MockShortHeaderOpenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// DecryptHeader mocks base method
|
||||
func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// DecryptHeader indicates an expected call of DecryptHeader
|
||||
func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// Open mocks base method
|
||||
func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 protocol.KeyPhase, arg4 []byte) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Open indicates an expected call of Open
|
||||
func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4)
|
||||
}
|
|
@ -10,6 +10,10 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type headerDecryptor interface {
|
||||
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
}
|
||||
|
||||
type unpackedPacket struct {
|
||||
packetNumber protocol.PacketNumber // the decoded packet number
|
||||
hdr *wire.ExtendedHeader
|
||||
|
@ -36,28 +40,85 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber)
|
|||
}
|
||||
|
||||
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
var opener handshake.Opener
|
||||
var encLevel protocol.EncryptionLevel
|
||||
var err error
|
||||
var extHdr *wire.ExtendedHeader
|
||||
var decrypted []byte
|
||||
switch hdr.Type {
|
||||
case protocol.PacketTypeInitial:
|
||||
encLevel = protocol.EncryptionInitial
|
||||
opener, err = u.cs.GetInitialOpener()
|
||||
opener, err := u.cs.GetInitialOpener()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case protocol.PacketTypeHandshake:
|
||||
encLevel = protocol.EncryptionHandshake
|
||||
opener, err = u.cs.GetHandshakeOpener()
|
||||
opener, err := u.cs.GetHandshakeOpener()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
if hdr.IsLongHeader {
|
||||
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
|
||||
}
|
||||
encLevel = protocol.Encryption1RTT
|
||||
opener, err = u.cs.Get1RTTOpener()
|
||||
opener, err := u.cs.Get1RTTOpener()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
|
||||
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, extHdr.PacketNumber)
|
||||
|
||||
return &unpackedPacket{
|
||||
hdr: extHdr,
|
||||
packetNumber: extHdr.PacketNumber,
|
||||
encryptionLevel: encLevel,
|
||||
data: decrypted,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
|
||||
extHdr, err := u.unpack(opener, hdr, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
extHdrLen := extHdr.GetLength(u.version)
|
||||
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return extHdr, decrypted, nil
|
||||
}
|
||||
|
||||
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
|
||||
extHdr, err := u.unpack(opener, hdr, data)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
extHdrLen := extHdr.GetLength(u.version)
|
||||
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return extHdr, decrypted, nil
|
||||
}
|
||||
|
||||
func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
hdrLen := int(hdr.ParsedLen())
|
||||
if len(data) < hdrLen+4+16 {
|
||||
return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
|
||||
|
@ -67,7 +128,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
|
|||
origPNBytes := make([]byte, 4)
|
||||
copy(origPNBytes, data[hdrLen:hdrLen+4])
|
||||
// 2. decrypt the header, assuming a 4 byte packet number
|
||||
opener.DecryptHeader(
|
||||
hd.DecryptHeader(
|
||||
data[hdrLen+4:hdrLen+4+16],
|
||||
&data[0],
|
||||
data[hdrLen:hdrLen+4],
|
||||
|
@ -77,30 +138,15 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing extended header: %s", err)
|
||||
}
|
||||
extHdrLen := hdrLen + int(extHdr.PacketNumberLen)
|
||||
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
|
||||
if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
|
||||
copy(data[extHdrLen:hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
|
||||
copy(data[extHdr.GetLength(u.version):hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
|
||||
}
|
||||
|
||||
pn := protocol.DecodePacketNumber(
|
||||
extHdr.PacketNumber = protocol.DecodePacketNumber(
|
||||
extHdr.PacketNumberLen,
|
||||
u.largestRcvdPacketNumber,
|
||||
extHdr.PacketNumber,
|
||||
)
|
||||
|
||||
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], pn, data[:extHdrLen])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
|
||||
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, pn)
|
||||
|
||||
return &unpackedPacket{
|
||||
hdr: extHdr,
|
||||
packetNumber: pn,
|
||||
encryptionLevel: encLevel,
|
||||
data: decrypted,
|
||||
}, nil
|
||||
return extHdr, nil
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
}
|
||||
hdr, hdrRaw := getHeader(extHdr)
|
||||
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
|
||||
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
|
||||
_, err := unpacker.Unpack(hdr, data)
|
||||
Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19"))
|
||||
|
@ -67,7 +67,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
PacketNumberLen: 3,
|
||||
}
|
||||
hdr, hdrRaw := getHeader(extHdr)
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
|
||||
cs.EXPECT().GetInitialOpener().Return(opener, nil)
|
||||
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
|
||||
|
@ -102,7 +102,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
PacketNumberLen: 3,
|
||||
}
|
||||
hdr, hdrRaw := getHeader(extHdr)
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
|
||||
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
|
||||
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
|
||||
|
@ -129,7 +129,7 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
hdrRaw[len(hdrRaw)-2] ^= 0xff // invert the packet number
|
||||
hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number
|
||||
Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte))
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
|
||||
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
|
||||
gomock.InOrder(
|
||||
// we're using a 2 byte packet number, so the sample starts at the 3rd payload byte
|
||||
|
@ -158,11 +158,12 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
Header: wire.Header{DestConnectionID: connID},
|
||||
PacketNumber: 0x1337,
|
||||
PacketNumberLen: 2,
|
||||
KeyPhase: protocol.KeyPhaseOne,
|
||||
}
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
|
||||
cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2)
|
||||
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil)
|
||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil)
|
||||
hdr, hdrRaw := getHeader(firstHdr)
|
||||
packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -172,10 +173,11 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
Header: wire.Header{DestConnectionID: connID},
|
||||
PacketNumber: 0x38,
|
||||
PacketNumberLen: 1,
|
||||
KeyPhase: protocol.KeyPhaseZero,
|
||||
}
|
||||
// expect the call with the decoded packet number
|
||||
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil)
|
||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil)
|
||||
hdr, hdrRaw = getHeader(secondHdr)
|
||||
packet, err = unpacker.Unpack(hdr, append(hdrRaw, payload...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue