use separate opener interfaces for long and short header packets

This commit is contained in:
Marten Seemann 2019-06-10 17:11:47 +08:00
parent b2d3ef691e
commit 598628d05b
12 changed files with 320 additions and 149 deletions

View file

@ -62,47 +62,77 @@ func (s *sealer) KeyPhase() protocol.KeyPhase {
return protocol.KeyPhaseZero
}
type opener struct {
type longHeaderOpener struct {
aead cipher.AEAD
pnDecrypter cipher.Block
// use a single slice to avoid allocations
nonceBuf []byte
hpMask []byte
// short headers protect 5 bits in the first byte, long headers only 4
is1RTT bool
}
var _ Opener = &opener{}
var _ LongHeaderOpener = &longHeaderOpener{}
func newOpener(aead cipher.AEAD, pnDecrypter cipher.Block, is1RTT bool) Opener {
return &opener{
func newLongHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) LongHeaderOpener {
return &longHeaderOpener{
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
is1RTT: is1RTT,
pnDecrypter: pnDecrypter,
hpMask: make([]byte, pnDecrypter.BlockSize()),
}
}
func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return o.aead.Open(dst, o.nonceBuf, src, ad)
}
func (o *opener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != o.pnDecrypter.BlockSize() {
panic("invalid sample size")
}
o.pnDecrypter.Encrypt(o.hpMask, sample)
if o.is1RTT {
*firstByte ^= o.hpMask[0] & 0x1f
} else {
*firstByte ^= o.hpMask[0] & 0xf
}
*firstByte ^= o.hpMask[0] & 0xf
for i := range pnBytes {
pnBytes[i] ^= o.hpMask[i+1]
}
}
type shortHeaderOpener struct {
aead cipher.AEAD
pnDecrypter cipher.Block
// use a single slice to avoid allocations
nonceBuf []byte
hpMask []byte
}
var _ ShortHeaderOpener = &shortHeaderOpener{}
func newShortHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) ShortHeaderOpener {
return &shortHeaderOpener{
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
pnDecrypter: pnDecrypter,
hpMask: make([]byte, pnDecrypter.BlockSize()),
}
}
func (o *shortHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return o.aead.Open(dst, o.nonceBuf, src, ad)
}
func (o *shortHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != o.pnDecrypter.BlockSize() {
panic("invalid sample size")
}
o.pnDecrypter.Encrypt(o.hpMask, sample)
*firstByte ^= o.hpMask[0] & 0x1f
for i := range pnBytes {
pnBytes[i] ^= o.hpMask[i+1]
}

View file

@ -5,12 +5,13 @@ import (
"crypto/cipher"
"crypto/rand"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("AEAD", func() {
getSealerAndOpener := func(is1RTT bool) (ShortHeaderSealer, Opener) {
getLongHeaderSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) {
key := make([]byte, 16)
hpKey := make([]byte, 16)
rand.Read(key)
@ -24,36 +25,55 @@ var _ = Describe("AEAD", func() {
iv := make([]byte, 12)
rand.Read(iv)
return newSealer(aead, hpBlock, is1RTT), newOpener(aead, hpBlock, is1RTT)
return newSealer(aead, hpBlock, false), newLongHeaderOpener(aead, hpBlock)
}
getShortHeaderSealerAndOpener := func() (ShortHeaderSealer, ShortHeaderOpener) {
key := make([]byte, 16)
hpKey := make([]byte, 16)
rand.Read(key)
rand.Read(hpKey)
block, err := aes.NewCipher(key)
Expect(err).ToNot(HaveOccurred())
aead, err := cipher.NewGCM(block)
Expect(err).ToNot(HaveOccurred())
hpBlock, err := aes.NewCipher(hpKey)
Expect(err).ToNot(HaveOccurred())
iv := make([]byte, 12)
rand.Read(iv)
return newSealer(aead, hpBlock, true), newShortHeaderOpener(aead, hpBlock)
}
Context("message encryption", func() {
var (
sealer ShortHeaderSealer
opener Opener
)
BeforeEach(func() {
sealer, opener = getSealerAndOpener(false)
})
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
ad := []byte("Donec in velit neque.")
It("encrypts and decrypts a message", func() {
It("encrypts and decrypts a message, for long headers", func() {
sealer, opener := getLongHeaderSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
opened, err := opener.Open(nil, encrypted, 0x1337, ad)
Expect(err).ToNot(HaveOccurred())
Expect(opened).To(Equal(msg))
})
It("encrypts and decrypts a message, for short headers", func() {
sealer, opener := getShortHeaderSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
opened, err := opener.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(opened).To(Equal(msg))
})
It("fails to open a message if the associated data is not the same", func() {
sealer, opener := getLongHeaderSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad"))
Expect(err).To(MatchError("cipher: message authentication failed"))
})
It("fails to open a message if the packet number is not the same", func() {
sealer, opener := getLongHeaderSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted, 0x42, ad)
Expect(err).To(MatchError("cipher: message authentication failed"))
@ -62,7 +82,7 @@ var _ = Describe("AEAD", func() {
Context("header encryption", func() {
It("encrypts and encrypts the header, for long headers", func() {
sealer, opener := getSealerAndOpener(false)
sealer, opener := getLongHeaderSealerAndOpener()
var lastFourBitsDifferent int
for i := 0; i < 100; i++ {
sample := make([]byte, 16)
@ -82,7 +102,7 @@ var _ = Describe("AEAD", func() {
})
It("encrypts and encrypts the header, for short headers", func() {
sealer, opener := getSealerAndOpener(true)
sealer, opener := getShortHeaderSealerAndOpener()
var lastFiveBitsDifferent int
for i := 0; i < 100; i++ {
sample := make([]byte, 16)
@ -102,7 +122,7 @@ var _ = Describe("AEAD", func() {
})
It("fails to decrypt the header when using a different sample", func() {
sealer, opener := getSealerAndOpener(true)
sealer, opener := getLongHeaderSealerAndOpener()
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
sample := make([]byte, 16)
rand.Read(sample)

View file

@ -105,15 +105,15 @@ type cryptoSetup struct {
writeEncLevel protocol.EncryptionLevel
initialStream io.Writer
initialOpener Opener
initialOpener LongHeaderOpener
initialSealer LongHeaderSealer
handshakeStream io.Writer
handshakeOpener Opener
handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer
oneRTTStream io.Writer
opener Opener
opener ShortHeaderOpener
sealer ShortHeaderSealer
}
@ -493,11 +493,11 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
switch h.readEncLevel {
case protocol.EncryptionInitial:
h.readEncLevel = protocol.EncryptionHandshake
h.handshakeOpener = newOpener(suite.AEAD(key, iv), hpDecrypter, false)
h.handshakeOpener = newLongHeaderOpener(suite.AEAD(key, iv), hpDecrypter)
h.logger.Debugf("Installed Handshake Read keys")
case protocol.EncryptionHandshake:
h.readEncLevel = protocol.Encryption1RTT
h.opener = newOpener(suite.AEAD(key, iv), hpDecrypter, true)
h.opener = newShortHeaderOpener(suite.AEAD(key, iv), hpDecrypter)
h.logger.Debugf("Installed 1-RTT Read keys")
default:
panic("unexpected read encryption level")
@ -591,7 +591,7 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
return h.sealer, nil
}
func (h *cryptoSetup) GetInitialOpener() (Opener, error) {
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
@ -601,7 +601,7 @@ func (h *cryptoSetup) GetInitialOpener() (Opener, error) {
return h.initialOpener, nil
}
func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) {
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
@ -615,7 +615,7 @@ func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) {
return h.handshakeOpener, nil
}
func (h *cryptoSetup) Get1RTTOpener() (Opener, error) {
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()

View file

@ -11,7 +11,7 @@ import (
var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0}
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, Opener, error) {
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, LongHeaderOpener, error) {
clientSecret, serverSecret := computeSecrets(connID)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
@ -34,7 +34,7 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo
if err != nil {
return nil, nil, err
}
return newSealer(encrypter, hpEncrypter, false), newOpener(decrypter, hpDecrypter, false), nil
return newSealer(encrypter, hpEncrypter, false), newLongHeaderOpener(decrypter, hpDecrypter), nil
}
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {

View file

@ -8,12 +8,22 @@ import (
"github.com/marten-seemann/qtls"
)
// Opener opens a packet
type Opener interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
type headerDecryptor interface {
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
// LongHeaderOpener opens a long header packet
type LongHeaderOpener interface {
headerDecryptor
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
// ShortHeaderOpener opens a short header packet
type ShortHeaderOpener interface {
headerDecryptor
Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhase, associatedData []byte) ([]byte, error)
}
// LongHeaderSealer seals a long header packet
type LongHeaderSealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
@ -51,9 +61,9 @@ type CryptoSetup interface {
Received1RTTAck()
ConnectionState() tls.ConnectionState
GetInitialOpener() (Opener, error)
GetHandshakeOpener() (Opener, error)
Get1RTTOpener() (Opener, error)
GetInitialOpener() (LongHeaderOpener, error)
GetHandshakeOpener() (LongHeaderOpener, error)
Get1RTTOpener() (ShortHeaderOpener, error)
GetInitialSealer() (LongHeaderSealer, error)
GetHandshakeSealer() (LongHeaderSealer, error)

View file

@ -79,10 +79,10 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
}
// Get1RTTOpener mocks base method
func (m *MockCryptoSetup) Get1RTTOpener() (handshake.Opener, error) {
func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get1RTTOpener")
ret0, _ := ret[0].(handshake.Opener)
ret0, _ := ret[0].(handshake.ShortHeaderOpener)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@ -109,10 +109,10 @@ func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *gomock.Call {
}
// GetHandshakeOpener mocks base method
func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.Opener, error) {
func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.LongHeaderOpener, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetHandshakeOpener")
ret0, _ := ret[0].(handshake.Opener)
ret0, _ := ret[0].(handshake.LongHeaderOpener)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@ -139,10 +139,10 @@ func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *gomock.Call {
}
// GetInitialOpener mocks base method
func (m *MockCryptoSetup) GetInitialOpener() (handshake.Opener, error) {
func (m *MockCryptoSetup) GetInitialOpener() (handshake.LongHeaderOpener, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetInitialOpener")
ret0, _ := ret[0].(handshake.Opener)
ret0, _ := ret[0].(handshake.LongHeaderOpener)
ret1, _ := ret[1].(error)
return ret0, ret1
}

View file

@ -0,0 +1,62 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: LongHeaderOpener)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockLongHeaderOpener is a mock of LongHeaderOpener interface
type MockLongHeaderOpener struct {
ctrl *gomock.Controller
recorder *MockLongHeaderOpenerMockRecorder
}
// MockLongHeaderOpenerMockRecorder is the mock recorder for MockLongHeaderOpener
type MockLongHeaderOpenerMockRecorder struct {
mock *MockLongHeaderOpener
}
// NewMockLongHeaderOpener creates a new mock instance
func NewMockLongHeaderOpener(ctrl *gomock.Controller) *MockLongHeaderOpener {
mock := &MockLongHeaderOpener{ctrl: ctrl}
mock.recorder = &MockLongHeaderOpenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder {
return m.recorder
}
// DecryptHeader mocks base method
func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
}
// DecryptHeader indicates an expected call of DecryptHeader
func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2)
}
// Open mocks base method
func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open indicates an expected call of Open
func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3)
}

View file

@ -3,7 +3,8 @@ package mocks
//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go"
//go:generate sh -c "mockgen -package mockquic -destination quic/session.go github.com/lucas-clemente/quic-go Session && goimports -w quic/session.go"
//go:generate sh -c "../mockgen_internal.sh mocks short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer"
//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener"
//go:generate sh -c "../mockgen_internal.sh mocks short_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderOpener"
//go:generate sh -c "../mockgen_internal.sh mocks long_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake LongHeaderOpener"
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"

View file

@ -1,62 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Opener)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockOpener is a mock of Opener interface
type MockOpener struct {
ctrl *gomock.Controller
recorder *MockOpenerMockRecorder
}
// MockOpenerMockRecorder is the mock recorder for MockOpener
type MockOpenerMockRecorder struct {
mock *MockOpener
}
// NewMockOpener creates a new mock instance
func NewMockOpener(ctrl *gomock.Controller) *MockOpener {
mock := &MockOpener{ctrl: ctrl}
mock.recorder = &MockOpenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockOpener) EXPECT() *MockOpenerMockRecorder {
return m.recorder
}
// DecryptHeader mocks base method
func (m *MockOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
}
// DecryptHeader indicates an expected call of DecryptHeader
func (mr *MockOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockOpener)(nil).DecryptHeader), arg0, arg1, arg2)
}
// Open mocks base method
func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open indicates an expected call of Open
func (mr *MockOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockOpener)(nil).Open), arg0, arg1, arg2, arg3)
}

View file

@ -0,0 +1,62 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: ShortHeaderOpener)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockShortHeaderOpener is a mock of ShortHeaderOpener interface
type MockShortHeaderOpener struct {
ctrl *gomock.Controller
recorder *MockShortHeaderOpenerMockRecorder
}
// MockShortHeaderOpenerMockRecorder is the mock recorder for MockShortHeaderOpener
type MockShortHeaderOpenerMockRecorder struct {
mock *MockShortHeaderOpener
}
// NewMockShortHeaderOpener creates a new mock instance
func NewMockShortHeaderOpener(ctrl *gomock.Controller) *MockShortHeaderOpener {
mock := &MockShortHeaderOpener{ctrl: ctrl}
mock.recorder = &MockShortHeaderOpenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder {
return m.recorder
}
// DecryptHeader mocks base method
func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
}
// DecryptHeader indicates an expected call of DecryptHeader
func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2)
}
// Open mocks base method
func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 protocol.KeyPhase, arg4 []byte) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open indicates an expected call of Open
func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4)
}

View file

@ -10,6 +10,10 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
type headerDecryptor interface {
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
type unpackedPacket struct {
packetNumber protocol.PacketNumber // the decoded packet number
hdr *wire.ExtendedHeader
@ -36,28 +40,85 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber)
}
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
r := bytes.NewReader(data)
var opener handshake.Opener
var encLevel protocol.EncryptionLevel
var err error
var extHdr *wire.ExtendedHeader
var decrypted []byte
switch hdr.Type {
case protocol.PacketTypeInitial:
encLevel = protocol.EncryptionInitial
opener, err = u.cs.GetInitialOpener()
opener, err := u.cs.GetInitialOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
case protocol.PacketTypeHandshake:
encLevel = protocol.EncryptionHandshake
opener, err = u.cs.GetHandshakeOpener()
opener, err := u.cs.GetHandshakeOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
default:
if hdr.IsLongHeader {
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
}
encLevel = protocol.Encryption1RTT
opener, err = u.cs.Get1RTTOpener()
opener, err := u.cs.Get1RTTOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
}
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, extHdr.PacketNumber)
return &unpackedPacket{
hdr: extHdr,
packetNumber: extHdr.PacketNumber,
encryptionLevel: encLevel,
data: decrypted,
}, nil
}
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, err := u.unpack(opener, hdr, data)
if err != nil {
return nil, err
return nil, nil, err
}
extHdrLen := extHdr.GetLength(u.version)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
if err != nil {
return nil, nil, err
}
return extHdr, decrypted, nil
}
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, err := u.unpack(opener, hdr, data)
if err != nil {
return nil, nil, err
}
extHdrLen := extHdr.GetLength(u.version)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen])
if err != nil {
return nil, nil, err
}
return extHdr, decrypted, nil
}
func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data)
hdrLen := int(hdr.ParsedLen())
if len(data) < hdrLen+4+16 {
return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
@ -67,7 +128,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
origPNBytes := make([]byte, 4)
copy(origPNBytes, data[hdrLen:hdrLen+4])
// 2. decrypt the header, assuming a 4 byte packet number
opener.DecryptHeader(
hd.DecryptHeader(
data[hdrLen+4:hdrLen+4+16],
&data[0],
data[hdrLen:hdrLen+4],
@ -77,30 +138,15 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
if err != nil {
return nil, fmt.Errorf("error parsing extended header: %s", err)
}
extHdrLen := hdrLen + int(extHdr.PacketNumberLen)
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
copy(data[extHdrLen:hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
copy(data[extHdr.GetLength(u.version):hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
}
pn := protocol.DecodePacketNumber(
extHdr.PacketNumber = protocol.DecodePacketNumber(
extHdr.PacketNumberLen,
u.largestRcvdPacketNumber,
extHdr.PacketNumber,
)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], pn, data[:extHdrLen])
if err != nil {
return nil, err
}
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, pn)
return &unpackedPacket{
hdr: extHdr,
packetNumber: pn,
encryptionLevel: encLevel,
data: decrypted,
}, nil
return extHdr, nil
}

View file

@ -48,7 +48,7 @@ var _ = Describe("Packet Unpacker", func() {
}
hdr, hdrRaw := getHeader(extHdr)
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockOpener(mockCtrl)
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, err := unpacker.Unpack(hdr, data)
Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19"))
@ -67,7 +67,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumberLen: 3,
}
hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockOpener(mockCtrl)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetInitialOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
@ -102,7 +102,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumberLen: 3,
}
hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockOpener(mockCtrl)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
@ -129,7 +129,7 @@ var _ = Describe("Packet Unpacker", func() {
hdrRaw[len(hdrRaw)-2] ^= 0xff // invert the packet number
hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number
Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte))
opener := mocks.NewMockOpener(mockCtrl)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
gomock.InOrder(
// we're using a 2 byte packet number, so the sample starts at the 3rd payload byte
@ -158,11 +158,12 @@ var _ = Describe("Packet Unpacker", func() {
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337,
PacketNumberLen: 2,
KeyPhase: protocol.KeyPhaseOne,
}
opener := mocks.NewMockOpener(mockCtrl)
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil)
opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw := getHeader(firstHdr)
packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
@ -172,10 +173,11 @@ var _ = Describe("Packet Unpacker", func() {
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x38,
PacketNumberLen: 1,
KeyPhase: protocol.KeyPhaseZero,
}
// expect the call with the decoded packet number
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil)
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw = getHeader(secondHdr)
packet, err = unpacker.Unpack(hdr, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())