use separate opener interfaces for long and short header packets

This commit is contained in:
Marten Seemann 2019-06-10 17:11:47 +08:00
parent b2d3ef691e
commit 598628d05b
12 changed files with 320 additions and 149 deletions

View file

@ -62,47 +62,77 @@ func (s *sealer) KeyPhase() protocol.KeyPhase {
return protocol.KeyPhaseZero return protocol.KeyPhaseZero
} }
type opener struct { type longHeaderOpener struct {
aead cipher.AEAD aead cipher.AEAD
pnDecrypter cipher.Block pnDecrypter cipher.Block
// use a single slice to avoid allocations // use a single slice to avoid allocations
nonceBuf []byte nonceBuf []byte
hpMask []byte hpMask []byte
// short headers protect 5 bits in the first byte, long headers only 4
is1RTT bool
} }
var _ Opener = &opener{} var _ LongHeaderOpener = &longHeaderOpener{}
func newOpener(aead cipher.AEAD, pnDecrypter cipher.Block, is1RTT bool) Opener { func newLongHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) LongHeaderOpener {
return &opener{ return &longHeaderOpener{
aead: aead, aead: aead,
nonceBuf: make([]byte, aead.NonceSize()), nonceBuf: make([]byte, aead.NonceSize()),
is1RTT: is1RTT,
pnDecrypter: pnDecrypter, pnDecrypter: pnDecrypter,
hpMask: make([]byte, pnDecrypter.BlockSize()), hpMask: make([]byte, pnDecrypter.BlockSize()),
} }
} }
func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13. // The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV. // It uses the nonce provided here and XOR it with the IV.
return o.aead.Open(dst, o.nonceBuf, src, ad) return o.aead.Open(dst, o.nonceBuf, src, ad)
} }
func (o *opener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != o.pnDecrypter.BlockSize() { if len(sample) != o.pnDecrypter.BlockSize() {
panic("invalid sample size") panic("invalid sample size")
} }
o.pnDecrypter.Encrypt(o.hpMask, sample) o.pnDecrypter.Encrypt(o.hpMask, sample)
if o.is1RTT { *firstByte ^= o.hpMask[0] & 0xf
*firstByte ^= o.hpMask[0] & 0x1f for i := range pnBytes {
} else { pnBytes[i] ^= o.hpMask[i+1]
*firstByte ^= o.hpMask[0] & 0xf }
} }
type shortHeaderOpener struct {
aead cipher.AEAD
pnDecrypter cipher.Block
// use a single slice to avoid allocations
nonceBuf []byte
hpMask []byte
}
var _ ShortHeaderOpener = &shortHeaderOpener{}
func newShortHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) ShortHeaderOpener {
return &shortHeaderOpener{
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
pnDecrypter: pnDecrypter,
hpMask: make([]byte, pnDecrypter.BlockSize()),
}
}
func (o *shortHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return o.aead.Open(dst, o.nonceBuf, src, ad)
}
func (o *shortHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != o.pnDecrypter.BlockSize() {
panic("invalid sample size")
}
o.pnDecrypter.Encrypt(o.hpMask, sample)
*firstByte ^= o.hpMask[0] & 0x1f
for i := range pnBytes { for i := range pnBytes {
pnBytes[i] ^= o.hpMask[i+1] pnBytes[i] ^= o.hpMask[i+1]
} }

View file

@ -5,12 +5,13 @@ import (
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
var _ = Describe("AEAD", func() { var _ = Describe("AEAD", func() {
getSealerAndOpener := func(is1RTT bool) (ShortHeaderSealer, Opener) { getLongHeaderSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) {
key := make([]byte, 16) key := make([]byte, 16)
hpKey := make([]byte, 16) hpKey := make([]byte, 16)
rand.Read(key) rand.Read(key)
@ -24,36 +25,55 @@ var _ = Describe("AEAD", func() {
iv := make([]byte, 12) iv := make([]byte, 12)
rand.Read(iv) rand.Read(iv)
return newSealer(aead, hpBlock, is1RTT), newOpener(aead, hpBlock, is1RTT) return newSealer(aead, hpBlock, false), newLongHeaderOpener(aead, hpBlock)
}
getShortHeaderSealerAndOpener := func() (ShortHeaderSealer, ShortHeaderOpener) {
key := make([]byte, 16)
hpKey := make([]byte, 16)
rand.Read(key)
rand.Read(hpKey)
block, err := aes.NewCipher(key)
Expect(err).ToNot(HaveOccurred())
aead, err := cipher.NewGCM(block)
Expect(err).ToNot(HaveOccurred())
hpBlock, err := aes.NewCipher(hpKey)
Expect(err).ToNot(HaveOccurred())
iv := make([]byte, 12)
rand.Read(iv)
return newSealer(aead, hpBlock, true), newShortHeaderOpener(aead, hpBlock)
} }
Context("message encryption", func() { Context("message encryption", func() {
var (
sealer ShortHeaderSealer
opener Opener
)
BeforeEach(func() {
sealer, opener = getSealerAndOpener(false)
})
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
ad := []byte("Donec in velit neque.") ad := []byte("Donec in velit neque.")
It("encrypts and decrypts a message", func() { It("encrypts and decrypts a message, for long headers", func() {
sealer, opener := getLongHeaderSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad) encrypted := sealer.Seal(nil, msg, 0x1337, ad)
opened, err := opener.Open(nil, encrypted, 0x1337, ad) opened, err := opener.Open(nil, encrypted, 0x1337, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(opened).To(Equal(msg)) Expect(opened).To(Equal(msg))
}) })
It("encrypts and decrypts a message, for short headers", func() {
sealer, opener := getShortHeaderSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
opened, err := opener.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(opened).To(Equal(msg))
})
It("fails to open a message if the associated data is not the same", func() { It("fails to open a message if the associated data is not the same", func() {
sealer, opener := getLongHeaderSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad) encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad"))
Expect(err).To(MatchError("cipher: message authentication failed")) Expect(err).To(MatchError("cipher: message authentication failed"))
}) })
It("fails to open a message if the packet number is not the same", func() { It("fails to open a message if the packet number is not the same", func() {
sealer, opener := getLongHeaderSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad) encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted, 0x42, ad) _, err := opener.Open(nil, encrypted, 0x42, ad)
Expect(err).To(MatchError("cipher: message authentication failed")) Expect(err).To(MatchError("cipher: message authentication failed"))
@ -62,7 +82,7 @@ var _ = Describe("AEAD", func() {
Context("header encryption", func() { Context("header encryption", func() {
It("encrypts and encrypts the header, for long headers", func() { It("encrypts and encrypts the header, for long headers", func() {
sealer, opener := getSealerAndOpener(false) sealer, opener := getLongHeaderSealerAndOpener()
var lastFourBitsDifferent int var lastFourBitsDifferent int
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
sample := make([]byte, 16) sample := make([]byte, 16)
@ -82,7 +102,7 @@ var _ = Describe("AEAD", func() {
}) })
It("encrypts and encrypts the header, for short headers", func() { It("encrypts and encrypts the header, for short headers", func() {
sealer, opener := getSealerAndOpener(true) sealer, opener := getShortHeaderSealerAndOpener()
var lastFiveBitsDifferent int var lastFiveBitsDifferent int
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
sample := make([]byte, 16) sample := make([]byte, 16)
@ -102,7 +122,7 @@ var _ = Describe("AEAD", func() {
}) })
It("fails to decrypt the header when using a different sample", func() { It("fails to decrypt the header when using a different sample", func() {
sealer, opener := getSealerAndOpener(true) sealer, opener := getLongHeaderSealerAndOpener()
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
sample := make([]byte, 16) sample := make([]byte, 16)
rand.Read(sample) rand.Read(sample)

View file

@ -105,15 +105,15 @@ type cryptoSetup struct {
writeEncLevel protocol.EncryptionLevel writeEncLevel protocol.EncryptionLevel
initialStream io.Writer initialStream io.Writer
initialOpener Opener initialOpener LongHeaderOpener
initialSealer LongHeaderSealer initialSealer LongHeaderSealer
handshakeStream io.Writer handshakeStream io.Writer
handshakeOpener Opener handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer handshakeSealer LongHeaderSealer
oneRTTStream io.Writer oneRTTStream io.Writer
opener Opener opener ShortHeaderOpener
sealer ShortHeaderSealer sealer ShortHeaderSealer
} }
@ -493,11 +493,11 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
switch h.readEncLevel { switch h.readEncLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
h.readEncLevel = protocol.EncryptionHandshake h.readEncLevel = protocol.EncryptionHandshake
h.handshakeOpener = newOpener(suite.AEAD(key, iv), hpDecrypter, false) h.handshakeOpener = newLongHeaderOpener(suite.AEAD(key, iv), hpDecrypter)
h.logger.Debugf("Installed Handshake Read keys") h.logger.Debugf("Installed Handshake Read keys")
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
h.readEncLevel = protocol.Encryption1RTT h.readEncLevel = protocol.Encryption1RTT
h.opener = newOpener(suite.AEAD(key, iv), hpDecrypter, true) h.opener = newShortHeaderOpener(suite.AEAD(key, iv), hpDecrypter)
h.logger.Debugf("Installed 1-RTT Read keys") h.logger.Debugf("Installed 1-RTT Read keys")
default: default:
panic("unexpected read encryption level") panic("unexpected read encryption level")
@ -591,7 +591,7 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
return h.sealer, nil return h.sealer, nil
} }
func (h *cryptoSetup) GetInitialOpener() (Opener, error) { func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
@ -601,7 +601,7 @@ func (h *cryptoSetup) GetInitialOpener() (Opener, error) {
return h.initialOpener, nil return h.initialOpener, nil
} }
func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) { func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
@ -615,7 +615,7 @@ func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) {
return h.handshakeOpener, nil return h.handshakeOpener, nil
} }
func (h *cryptoSetup) Get1RTTOpener() (Opener, error) { func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()

View file

@ -11,7 +11,7 @@ import (
var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0} var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0}
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption. // NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, Opener, error) { func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, LongHeaderOpener, error) {
clientSecret, serverSecret := computeSecrets(connID) clientSecret, serverSecret := computeSecrets(connID)
var mySecret, otherSecret []byte var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient { if pers == protocol.PerspectiveClient {
@ -34,7 +34,7 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return newSealer(encrypter, hpEncrypter, false), newOpener(decrypter, hpDecrypter, false), nil return newSealer(encrypter, hpEncrypter, false), newLongHeaderOpener(decrypter, hpDecrypter), nil
} }
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {

View file

@ -8,12 +8,22 @@ import (
"github.com/marten-seemann/qtls" "github.com/marten-seemann/qtls"
) )
// Opener opens a packet type headerDecryptor interface {
type Opener interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
} }
// LongHeaderOpener opens a long header packet
type LongHeaderOpener interface {
headerDecryptor
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
// ShortHeaderOpener opens a short header packet
type ShortHeaderOpener interface {
headerDecryptor
Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhase, associatedData []byte) ([]byte, error)
}
// LongHeaderSealer seals a long header packet // LongHeaderSealer seals a long header packet
type LongHeaderSealer interface { type LongHeaderSealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
@ -51,9 +61,9 @@ type CryptoSetup interface {
Received1RTTAck() Received1RTTAck()
ConnectionState() tls.ConnectionState ConnectionState() tls.ConnectionState
GetInitialOpener() (Opener, error) GetInitialOpener() (LongHeaderOpener, error)
GetHandshakeOpener() (Opener, error) GetHandshakeOpener() (LongHeaderOpener, error)
Get1RTTOpener() (Opener, error) Get1RTTOpener() (ShortHeaderOpener, error)
GetInitialSealer() (LongHeaderSealer, error) GetInitialSealer() (LongHeaderSealer, error)
GetHandshakeSealer() (LongHeaderSealer, error) GetHandshakeSealer() (LongHeaderSealer, error)

View file

@ -79,10 +79,10 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
} }
// Get1RTTOpener mocks base method // Get1RTTOpener mocks base method
func (m *MockCryptoSetup) Get1RTTOpener() (handshake.Opener, error) { func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get1RTTOpener") ret := m.ctrl.Call(m, "Get1RTTOpener")
ret0, _ := ret[0].(handshake.Opener) ret0, _ := ret[0].(handshake.ShortHeaderOpener)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@ -109,10 +109,10 @@ func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *gomock.Call {
} }
// GetHandshakeOpener mocks base method // GetHandshakeOpener mocks base method
func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.Opener, error) { func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.LongHeaderOpener, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetHandshakeOpener") ret := m.ctrl.Call(m, "GetHandshakeOpener")
ret0, _ := ret[0].(handshake.Opener) ret0, _ := ret[0].(handshake.LongHeaderOpener)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@ -139,10 +139,10 @@ func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *gomock.Call {
} }
// GetInitialOpener mocks base method // GetInitialOpener mocks base method
func (m *MockCryptoSetup) GetInitialOpener() (handshake.Opener, error) { func (m *MockCryptoSetup) GetInitialOpener() (handshake.LongHeaderOpener, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetInitialOpener") ret := m.ctrl.Call(m, "GetInitialOpener")
ret0, _ := ret[0].(handshake.Opener) ret0, _ := ret[0].(handshake.LongHeaderOpener)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }

View file

@ -0,0 +1,62 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: LongHeaderOpener)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockLongHeaderOpener is a mock of LongHeaderOpener interface
type MockLongHeaderOpener struct {
ctrl *gomock.Controller
recorder *MockLongHeaderOpenerMockRecorder
}
// MockLongHeaderOpenerMockRecorder is the mock recorder for MockLongHeaderOpener
type MockLongHeaderOpenerMockRecorder struct {
mock *MockLongHeaderOpener
}
// NewMockLongHeaderOpener creates a new mock instance
func NewMockLongHeaderOpener(ctrl *gomock.Controller) *MockLongHeaderOpener {
mock := &MockLongHeaderOpener{ctrl: ctrl}
mock.recorder = &MockLongHeaderOpenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder {
return m.recorder
}
// DecryptHeader mocks base method
func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
}
// DecryptHeader indicates an expected call of DecryptHeader
func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2)
}
// Open mocks base method
func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open indicates an expected call of Open
func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3)
}

View file

@ -3,7 +3,8 @@ package mocks
//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go" //go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go"
//go:generate sh -c "mockgen -package mockquic -destination quic/session.go github.com/lucas-clemente/quic-go Session && goimports -w quic/session.go" //go:generate sh -c "mockgen -package mockquic -destination quic/session.go github.com/lucas-clemente/quic-go Session && goimports -w quic/session.go"
//go:generate sh -c "../mockgen_internal.sh mocks short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer" //go:generate sh -c "../mockgen_internal.sh mocks short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer"
//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener" //go:generate sh -c "../mockgen_internal.sh mocks short_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderOpener"
//go:generate sh -c "../mockgen_internal.sh mocks long_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake LongHeaderOpener"
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup" //go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" //go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler" //go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"

View file

@ -1,62 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Opener)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockOpener is a mock of Opener interface
type MockOpener struct {
ctrl *gomock.Controller
recorder *MockOpenerMockRecorder
}
// MockOpenerMockRecorder is the mock recorder for MockOpener
type MockOpenerMockRecorder struct {
mock *MockOpener
}
// NewMockOpener creates a new mock instance
func NewMockOpener(ctrl *gomock.Controller) *MockOpener {
mock := &MockOpener{ctrl: ctrl}
mock.recorder = &MockOpenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockOpener) EXPECT() *MockOpenerMockRecorder {
return m.recorder
}
// DecryptHeader mocks base method
func (m *MockOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
}
// DecryptHeader indicates an expected call of DecryptHeader
func (mr *MockOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockOpener)(nil).DecryptHeader), arg0, arg1, arg2)
}
// Open mocks base method
func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open indicates an expected call of Open
func (mr *MockOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockOpener)(nil).Open), arg0, arg1, arg2, arg3)
}

View file

@ -0,0 +1,62 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: ShortHeaderOpener)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockShortHeaderOpener is a mock of ShortHeaderOpener interface
type MockShortHeaderOpener struct {
ctrl *gomock.Controller
recorder *MockShortHeaderOpenerMockRecorder
}
// MockShortHeaderOpenerMockRecorder is the mock recorder for MockShortHeaderOpener
type MockShortHeaderOpenerMockRecorder struct {
mock *MockShortHeaderOpener
}
// NewMockShortHeaderOpener creates a new mock instance
func NewMockShortHeaderOpener(ctrl *gomock.Controller) *MockShortHeaderOpener {
mock := &MockShortHeaderOpener{ctrl: ctrl}
mock.recorder = &MockShortHeaderOpenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder {
return m.recorder
}
// DecryptHeader mocks base method
func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
}
// DecryptHeader indicates an expected call of DecryptHeader
func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2)
}
// Open mocks base method
func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 protocol.KeyPhase, arg4 []byte) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open indicates an expected call of Open
func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4)
}

View file

@ -10,6 +10,10 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
) )
type headerDecryptor interface {
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
type unpackedPacket struct { type unpackedPacket struct {
packetNumber protocol.PacketNumber // the decoded packet number packetNumber protocol.PacketNumber // the decoded packet number
hdr *wire.ExtendedHeader hdr *wire.ExtendedHeader
@ -36,28 +40,85 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber)
} }
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) { func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
r := bytes.NewReader(data)
var opener handshake.Opener
var encLevel protocol.EncryptionLevel var encLevel protocol.EncryptionLevel
var err error var extHdr *wire.ExtendedHeader
var decrypted []byte
switch hdr.Type { switch hdr.Type {
case protocol.PacketTypeInitial: case protocol.PacketTypeInitial:
encLevel = protocol.EncryptionInitial encLevel = protocol.EncryptionInitial
opener, err = u.cs.GetInitialOpener() opener, err := u.cs.GetInitialOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
case protocol.PacketTypeHandshake: case protocol.PacketTypeHandshake:
encLevel = protocol.EncryptionHandshake encLevel = protocol.EncryptionHandshake
opener, err = u.cs.GetHandshakeOpener() opener, err := u.cs.GetHandshakeOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
default: default:
if hdr.IsLongHeader { if hdr.IsLongHeader {
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
} }
encLevel = protocol.Encryption1RTT encLevel = protocol.Encryption1RTT
opener, err = u.cs.Get1RTTOpener() opener, err := u.cs.Get1RTTOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
} }
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, extHdr.PacketNumber)
return &unpackedPacket{
hdr: extHdr,
packetNumber: extHdr.PacketNumber,
encryptionLevel: encLevel,
data: decrypted,
}, nil
}
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, err := u.unpack(opener, hdr, data)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
extHdrLen := extHdr.GetLength(u.version)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
if err != nil {
return nil, nil, err
}
return extHdr, decrypted, nil
}
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, err := u.unpack(opener, hdr, data)
if err != nil {
return nil, nil, err
}
extHdrLen := extHdr.GetLength(u.version)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen])
if err != nil {
return nil, nil, err
}
return extHdr, decrypted, nil
}
func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data)
hdrLen := int(hdr.ParsedLen()) hdrLen := int(hdr.ParsedLen())
if len(data) < hdrLen+4+16 { if len(data) < hdrLen+4+16 {
return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", len(data)-hdrLen) return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
@ -67,7 +128,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
origPNBytes := make([]byte, 4) origPNBytes := make([]byte, 4)
copy(origPNBytes, data[hdrLen:hdrLen+4]) copy(origPNBytes, data[hdrLen:hdrLen+4])
// 2. decrypt the header, assuming a 4 byte packet number // 2. decrypt the header, assuming a 4 byte packet number
opener.DecryptHeader( hd.DecryptHeader(
data[hdrLen+4:hdrLen+4+16], data[hdrLen+4:hdrLen+4+16],
&data[0], &data[0],
data[hdrLen:hdrLen+4], data[hdrLen:hdrLen+4],
@ -77,30 +138,15 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing extended header: %s", err) return nil, fmt.Errorf("error parsing extended header: %s", err)
} }
extHdrLen := hdrLen + int(extHdr.PacketNumberLen)
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
if extHdr.PacketNumberLen != protocol.PacketNumberLen4 { if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
copy(data[extHdrLen:hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):]) copy(data[extHdr.GetLength(u.version):hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
} }
pn := protocol.DecodePacketNumber( extHdr.PacketNumber = protocol.DecodePacketNumber(
extHdr.PacketNumberLen, extHdr.PacketNumberLen,
u.largestRcvdPacketNumber, u.largestRcvdPacketNumber,
extHdr.PacketNumber, extHdr.PacketNumber,
) )
return extHdr, nil
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], pn, data[:extHdrLen])
if err != nil {
return nil, err
}
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, pn)
return &unpackedPacket{
hdr: extHdr,
packetNumber: pn,
encryptionLevel: encLevel,
data: decrypted,
}, nil
} }

View file

@ -48,7 +48,7 @@ var _ = Describe("Packet Unpacker", func() {
} }
hdr, hdrRaw := getHeader(extHdr) hdr, hdrRaw := getHeader(extHdr)
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, err := unpacker.Unpack(hdr, data) _, err := unpacker.Unpack(hdr, data)
Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19"))
@ -67,7 +67,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumberLen: 3, PacketNumberLen: 3,
} }
hdr, hdrRaw := getHeader(extHdr) hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockOpener(mockCtrl) opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetInitialOpener().Return(opener, nil) cs.EXPECT().GetInitialOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil) opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
@ -102,7 +102,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumberLen: 3, PacketNumberLen: 3,
} }
hdr, hdrRaw := getHeader(extHdr) hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockOpener(mockCtrl) opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err")) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
@ -129,7 +129,7 @@ var _ = Describe("Packet Unpacker", func() {
hdrRaw[len(hdrRaw)-2] ^= 0xff // invert the packet number hdrRaw[len(hdrRaw)-2] ^= 0xff // invert the packet number
hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number
Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte)) Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte))
opener := mocks.NewMockOpener(mockCtrl) opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
gomock.InOrder( gomock.InOrder(
// we're using a 2 byte packet number, so the sample starts at the 3rd payload byte // we're using a 2 byte packet number, so the sample starts at the 3rd payload byte
@ -158,11 +158,12 @@ var _ = Describe("Packet Unpacker", func() {
Header: wire.Header{DestConnectionID: connID}, Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337, PacketNumber: 0x1337,
PacketNumberLen: 2, PacketNumberLen: 2,
KeyPhase: protocol.KeyPhaseOne,
} }
opener := mocks.NewMockOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2) cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw := getHeader(firstHdr) hdr, hdrRaw := getHeader(firstHdr)
packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -172,10 +173,11 @@ var _ = Describe("Packet Unpacker", func() {
Header: wire.Header{DestConnectionID: connID}, Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x38, PacketNumber: 0x38,
PacketNumberLen: 1, PacketNumberLen: 1,
KeyPhase: protocol.KeyPhaseZero,
} }
// expect the call with the decoded packet number // expect the call with the decoded packet number
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw = getHeader(secondHdr) hdr, hdrRaw = getHeader(secondHdr)
packet, err = unpacker.Unpack(hdr, append(hdrRaw, payload...)) packet, err = unpacker.Unpack(hdr, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())