diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index 53cf0916..c5702d35 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -62,47 +62,77 @@ func (s *sealer) KeyPhase() protocol.KeyPhase { return protocol.KeyPhaseZero } -type opener struct { +type longHeaderOpener struct { aead cipher.AEAD pnDecrypter cipher.Block // use a single slice to avoid allocations nonceBuf []byte hpMask []byte - - // short headers protect 5 bits in the first byte, long headers only 4 - is1RTT bool } -var _ Opener = &opener{} +var _ LongHeaderOpener = &longHeaderOpener{} -func newOpener(aead cipher.AEAD, pnDecrypter cipher.Block, is1RTT bool) Opener { - return &opener{ +func newLongHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) LongHeaderOpener { + return &longHeaderOpener{ aead: aead, nonceBuf: make([]byte, aead.NonceSize()), - is1RTT: is1RTT, pnDecrypter: pnDecrypter, hpMask: make([]byte, pnDecrypter.BlockSize()), } } -func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { +func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) // The AEAD we're using here will be the qtls.aeadAESGCM13. // It uses the nonce provided here and XOR it with the IV. return o.aead.Open(dst, o.nonceBuf, src, ad) } -func (o *opener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { +func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { if len(sample) != o.pnDecrypter.BlockSize() { panic("invalid sample size") } o.pnDecrypter.Encrypt(o.hpMask, sample) - if o.is1RTT { - *firstByte ^= o.hpMask[0] & 0x1f - } else { - *firstByte ^= o.hpMask[0] & 0xf - } + *firstByte ^= o.hpMask[0] & 0xf + for i := range pnBytes { + pnBytes[i] ^= o.hpMask[i+1] + } +} + +type shortHeaderOpener struct { + aead cipher.AEAD + pnDecrypter cipher.Block + + // use a single slice to avoid allocations + nonceBuf []byte + hpMask []byte +} + +var _ ShortHeaderOpener = &shortHeaderOpener{} + +func newShortHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) ShortHeaderOpener { + return &shortHeaderOpener{ + aead: aead, + nonceBuf: make([]byte, aead.NonceSize()), + pnDecrypter: pnDecrypter, + hpMask: make([]byte, pnDecrypter.BlockSize()), + } +} + +func (o *shortHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) { + binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + return o.aead.Open(dst, o.nonceBuf, src, ad) +} + +func (o *shortHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + if len(sample) != o.pnDecrypter.BlockSize() { + panic("invalid sample size") + } + o.pnDecrypter.Encrypt(o.hpMask, sample) + *firstByte ^= o.hpMask[0] & 0x1f for i := range pnBytes { pnBytes[i] ^= o.hpMask[i+1] } diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 8559b3c4..e1f105cc 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -5,12 +5,13 @@ import ( "crypto/cipher" "crypto/rand" + "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("AEAD", func() { - getSealerAndOpener := func(is1RTT bool) (ShortHeaderSealer, Opener) { + getLongHeaderSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { key := make([]byte, 16) hpKey := make([]byte, 16) rand.Read(key) @@ -24,36 +25,55 @@ var _ = Describe("AEAD", func() { iv := make([]byte, 12) rand.Read(iv) - return newSealer(aead, hpBlock, is1RTT), newOpener(aead, hpBlock, is1RTT) + return newSealer(aead, hpBlock, false), newLongHeaderOpener(aead, hpBlock) + } + + getShortHeaderSealerAndOpener := func() (ShortHeaderSealer, ShortHeaderOpener) { + key := make([]byte, 16) + hpKey := make([]byte, 16) + rand.Read(key) + rand.Read(hpKey) + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + aead, err := cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) + hpBlock, err := aes.NewCipher(hpKey) + Expect(err).ToNot(HaveOccurred()) + + iv := make([]byte, 12) + rand.Read(iv) + return newSealer(aead, hpBlock, true), newShortHeaderOpener(aead, hpBlock) } Context("message encryption", func() { - var ( - sealer ShortHeaderSealer - opener Opener - ) - - BeforeEach(func() { - sealer, opener = getSealerAndOpener(false) - }) - msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") ad := []byte("Donec in velit neque.") - It("encrypts and decrypts a message", func() { + It("encrypts and decrypts a message, for long headers", func() { + sealer, opener := getLongHeaderSealerAndOpener() encrypted := sealer.Seal(nil, msg, 0x1337, ad) opened, err := opener.Open(nil, encrypted, 0x1337, ad) Expect(err).ToNot(HaveOccurred()) Expect(opened).To(Equal(msg)) }) + It("encrypts and decrypts a message, for short headers", func() { + sealer, opener := getShortHeaderSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + opened, err := opener.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + It("fails to open a message if the associated data is not the same", func() { + sealer, opener := getLongHeaderSealerAndOpener() encrypted := sealer.Seal(nil, msg, 0x1337, ad) _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) Expect(err).To(MatchError("cipher: message authentication failed")) }) It("fails to open a message if the packet number is not the same", func() { + sealer, opener := getLongHeaderSealerAndOpener() encrypted := sealer.Seal(nil, msg, 0x1337, ad) _, err := opener.Open(nil, encrypted, 0x42, ad) Expect(err).To(MatchError("cipher: message authentication failed")) @@ -62,7 +82,7 @@ var _ = Describe("AEAD", func() { Context("header encryption", func() { It("encrypts and encrypts the header, for long headers", func() { - sealer, opener := getSealerAndOpener(false) + sealer, opener := getLongHeaderSealerAndOpener() var lastFourBitsDifferent int for i := 0; i < 100; i++ { sample := make([]byte, 16) @@ -82,7 +102,7 @@ var _ = Describe("AEAD", func() { }) It("encrypts and encrypts the header, for short headers", func() { - sealer, opener := getSealerAndOpener(true) + sealer, opener := getShortHeaderSealerAndOpener() var lastFiveBitsDifferent int for i := 0; i < 100; i++ { sample := make([]byte, 16) @@ -102,7 +122,7 @@ var _ = Describe("AEAD", func() { }) It("fails to decrypt the header when using a different sample", func() { - sealer, opener := getSealerAndOpener(true) + sealer, opener := getLongHeaderSealerAndOpener() header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} sample := make([]byte, 16) rand.Read(sample) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 81f5baa2..3606b810 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -105,15 +105,15 @@ type cryptoSetup struct { writeEncLevel protocol.EncryptionLevel initialStream io.Writer - initialOpener Opener + initialOpener LongHeaderOpener initialSealer LongHeaderSealer handshakeStream io.Writer - handshakeOpener Opener + handshakeOpener LongHeaderOpener handshakeSealer LongHeaderSealer oneRTTStream io.Writer - opener Opener + opener ShortHeaderOpener sealer ShortHeaderSealer } @@ -493,11 +493,11 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) switch h.readEncLevel { case protocol.EncryptionInitial: h.readEncLevel = protocol.EncryptionHandshake - h.handshakeOpener = newOpener(suite.AEAD(key, iv), hpDecrypter, false) + h.handshakeOpener = newLongHeaderOpener(suite.AEAD(key, iv), hpDecrypter) h.logger.Debugf("Installed Handshake Read keys") case protocol.EncryptionHandshake: h.readEncLevel = protocol.Encryption1RTT - h.opener = newOpener(suite.AEAD(key, iv), hpDecrypter, true) + h.opener = newShortHeaderOpener(suite.AEAD(key, iv), hpDecrypter) h.logger.Debugf("Installed 1-RTT Read keys") default: panic("unexpected read encryption level") @@ -591,7 +591,7 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { return h.sealer, nil } -func (h *cryptoSetup) GetInitialOpener() (Opener, error) { +func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { h.mutex.Lock() defer h.mutex.Unlock() @@ -601,7 +601,7 @@ func (h *cryptoSetup) GetInitialOpener() (Opener, error) { return h.initialOpener, nil } -func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) { +func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { h.mutex.Lock() defer h.mutex.Unlock() @@ -615,7 +615,7 @@ func (h *cryptoSetup) GetHandshakeOpener() (Opener, error) { return h.handshakeOpener, nil } -func (h *cryptoSetup) Get1RTTOpener() (Opener, error) { +func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { h.mutex.Lock() defer h.mutex.Unlock() diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index c561ae61..fcb08e15 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -11,7 +11,7 @@ import ( var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0} // NewInitialAEAD creates a new AEAD for Initial encryption / decryption. -func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, Opener, error) { +func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, LongHeaderOpener, error) { clientSecret, serverSecret := computeSecrets(connID) var mySecret, otherSecret []byte if pers == protocol.PerspectiveClient { @@ -34,7 +34,7 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo if err != nil { return nil, nil, err } - return newSealer(encrypter, hpEncrypter, false), newOpener(decrypter, hpDecrypter, false), nil + return newSealer(encrypter, hpEncrypter, false), newLongHeaderOpener(decrypter, hpDecrypter), nil } func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index c1653935..03f15fdf 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -8,12 +8,22 @@ import ( "github.com/marten-seemann/qtls" ) -// Opener opens a packet -type Opener interface { - Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) +type headerDecryptor interface { DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) } +// LongHeaderOpener opens a long header packet +type LongHeaderOpener interface { + headerDecryptor + Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error) +} + +// ShortHeaderOpener opens a short header packet +type ShortHeaderOpener interface { + headerDecryptor + Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhase, associatedData []byte) ([]byte, error) +} + // LongHeaderSealer seals a long header packet type LongHeaderSealer interface { Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte @@ -51,9 +61,9 @@ type CryptoSetup interface { Received1RTTAck() ConnectionState() tls.ConnectionState - GetInitialOpener() (Opener, error) - GetHandshakeOpener() (Opener, error) - Get1RTTOpener() (Opener, error) + GetInitialOpener() (LongHeaderOpener, error) + GetHandshakeOpener() (LongHeaderOpener, error) + Get1RTTOpener() (ShortHeaderOpener, error) GetInitialSealer() (LongHeaderSealer, error) GetHandshakeSealer() (LongHeaderSealer, error) diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 13f0a5bf..77ed0bed 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -79,10 +79,10 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { } // Get1RTTOpener mocks base method -func (m *MockCryptoSetup) Get1RTTOpener() (handshake.Opener, error) { +func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get1RTTOpener") - ret0, _ := ret[0].(handshake.Opener) + ret0, _ := ret[0].(handshake.ShortHeaderOpener) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -109,10 +109,10 @@ func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *gomock.Call { } // GetHandshakeOpener mocks base method -func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.Opener, error) { +func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.LongHeaderOpener, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetHandshakeOpener") - ret0, _ := ret[0].(handshake.Opener) + ret0, _ := ret[0].(handshake.LongHeaderOpener) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -139,10 +139,10 @@ func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *gomock.Call { } // GetInitialOpener mocks base method -func (m *MockCryptoSetup) GetInitialOpener() (handshake.Opener, error) { +func (m *MockCryptoSetup) GetInitialOpener() (handshake.LongHeaderOpener, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetInitialOpener") - ret0, _ := ret[0].(handshake.Opener) + ret0, _ := ret[0].(handshake.LongHeaderOpener) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/internal/mocks/long_header_opener.go b/internal/mocks/long_header_opener.go new file mode 100644 index 00000000..d99bd827 --- /dev/null +++ b/internal/mocks/long_header_opener.go @@ -0,0 +1,62 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: LongHeaderOpener) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockLongHeaderOpener is a mock of LongHeaderOpener interface +type MockLongHeaderOpener struct { + ctrl *gomock.Controller + recorder *MockLongHeaderOpenerMockRecorder +} + +// MockLongHeaderOpenerMockRecorder is the mock recorder for MockLongHeaderOpener +type MockLongHeaderOpenerMockRecorder struct { + mock *MockLongHeaderOpener +} + +// NewMockLongHeaderOpener creates a new mock instance +func NewMockLongHeaderOpener(ctrl *gomock.Controller) *MockLongHeaderOpener { + mock := &MockLongHeaderOpener{ctrl: ctrl} + mock.recorder = &MockLongHeaderOpenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder { + return m.recorder +} + +// DecryptHeader mocks base method +func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) +} + +// DecryptHeader indicates an expected call of DecryptHeader +func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) +} + +// Open mocks base method +func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open +func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3) +} diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 91ec4cb8..8b443f29 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -3,7 +3,8 @@ package mocks //go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go" //go:generate sh -c "mockgen -package mockquic -destination quic/session.go github.com/lucas-clemente/quic-go Session && goimports -w quic/session.go" //go:generate sh -c "../mockgen_internal.sh mocks short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer" -//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener" +//go:generate sh -c "../mockgen_internal.sh mocks short_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderOpener" +//go:generate sh -c "../mockgen_internal.sh mocks long_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake LongHeaderOpener" //go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup" //go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" //go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler" diff --git a/internal/mocks/opener.go b/internal/mocks/opener.go deleted file mode 100644 index 2fc399e4..00000000 --- a/internal/mocks/opener.go +++ /dev/null @@ -1,62 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Opener) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// MockOpener is a mock of Opener interface -type MockOpener struct { - ctrl *gomock.Controller - recorder *MockOpenerMockRecorder -} - -// MockOpenerMockRecorder is the mock recorder for MockOpener -type MockOpenerMockRecorder struct { - mock *MockOpener -} - -// NewMockOpener creates a new mock instance -func NewMockOpener(ctrl *gomock.Controller) *MockOpener { - mock := &MockOpener{ctrl: ctrl} - mock.recorder = &MockOpenerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockOpener) EXPECT() *MockOpenerMockRecorder { - return m.recorder -} - -// DecryptHeader mocks base method -func (m *MockOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) -} - -// DecryptHeader indicates an expected call of DecryptHeader -func (mr *MockOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockOpener)(nil).DecryptHeader), arg0, arg1, arg2) -} - -// Open mocks base method -func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Open indicates an expected call of Open -func (mr *MockOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockOpener)(nil).Open), arg0, arg1, arg2, arg3) -} diff --git a/internal/mocks/short_header_opener.go b/internal/mocks/short_header_opener.go new file mode 100644 index 00000000..a51f9eb5 --- /dev/null +++ b/internal/mocks/short_header_opener.go @@ -0,0 +1,62 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: ShortHeaderOpener) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockShortHeaderOpener is a mock of ShortHeaderOpener interface +type MockShortHeaderOpener struct { + ctrl *gomock.Controller + recorder *MockShortHeaderOpenerMockRecorder +} + +// MockShortHeaderOpenerMockRecorder is the mock recorder for MockShortHeaderOpener +type MockShortHeaderOpenerMockRecorder struct { + mock *MockShortHeaderOpener +} + +// NewMockShortHeaderOpener creates a new mock instance +func NewMockShortHeaderOpener(ctrl *gomock.Controller) *MockShortHeaderOpener { + mock := &MockShortHeaderOpener{ctrl: ctrl} + mock.recorder = &MockShortHeaderOpenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder { + return m.recorder +} + +// DecryptHeader mocks base method +func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) +} + +// DecryptHeader indicates an expected call of DecryptHeader +func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) +} + +// Open mocks base method +func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 protocol.KeyPhase, arg4 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open +func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4) +} diff --git a/packet_unpacker.go b/packet_unpacker.go index 6e246552..fe302d09 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -10,6 +10,10 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) +type headerDecryptor interface { + DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) +} + type unpackedPacket struct { packetNumber protocol.PacketNumber // the decoded packet number hdr *wire.ExtendedHeader @@ -36,28 +40,85 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) } func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) { - r := bytes.NewReader(data) - - var opener handshake.Opener var encLevel protocol.EncryptionLevel - var err error + var extHdr *wire.ExtendedHeader + var decrypted []byte switch hdr.Type { case protocol.PacketTypeInitial: encLevel = protocol.EncryptionInitial - opener, err = u.cs.GetInitialOpener() + opener, err := u.cs.GetInitialOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + if err != nil { + return nil, err + } case protocol.PacketTypeHandshake: encLevel = protocol.EncryptionHandshake - opener, err = u.cs.GetHandshakeOpener() + opener, err := u.cs.GetHandshakeOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + if err != nil { + return nil, err + } default: if hdr.IsLongHeader { return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) } encLevel = protocol.Encryption1RTT - opener, err = u.cs.Get1RTTOpener() + opener, err := u.cs.Get1RTTOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, data) + if err != nil { + return nil, err + } } + + // Only do this after decrypting, so we are sure the packet is not attacker-controlled + u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, extHdr.PacketNumber) + + return &unpackedPacket{ + hdr: extHdr, + packetNumber: extHdr.PacketNumber, + encryptionLevel: encLevel, + data: decrypted, + }, nil +} + +func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { + extHdr, err := u.unpack(opener, hdr, data) if err != nil { - return nil, err + return nil, nil, err } + extHdrLen := extHdr.GetLength(u.version) + decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen]) + if err != nil { + return nil, nil, err + } + return extHdr, decrypted, nil +} + +func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { + extHdr, err := u.unpack(opener, hdr, data) + if err != nil { + return nil, nil, err + } + extHdrLen := extHdr.GetLength(u.version) + decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) + if err != nil { + return nil, nil, err + } + return extHdr, decrypted, nil +} + +func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { + r := bytes.NewReader(data) + hdrLen := int(hdr.ParsedLen()) if len(data) < hdrLen+4+16 { return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", len(data)-hdrLen) @@ -67,7 +128,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, origPNBytes := make([]byte, 4) copy(origPNBytes, data[hdrLen:hdrLen+4]) // 2. decrypt the header, assuming a 4 byte packet number - opener.DecryptHeader( + hd.DecryptHeader( data[hdrLen+4:hdrLen+4+16], &data[0], data[hdrLen:hdrLen+4], @@ -77,30 +138,15 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, if err != nil { return nil, fmt.Errorf("error parsing extended header: %s", err) } - extHdrLen := hdrLen + int(extHdr.PacketNumberLen) // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier if extHdr.PacketNumberLen != protocol.PacketNumberLen4 { - copy(data[extHdrLen:hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):]) + copy(data[extHdr.GetLength(u.version):hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):]) } - pn := protocol.DecodePacketNumber( + extHdr.PacketNumber = protocol.DecodePacketNumber( extHdr.PacketNumberLen, u.largestRcvdPacketNumber, extHdr.PacketNumber, ) - - decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], pn, data[:extHdrLen]) - if err != nil { - return nil, err - } - - // Only do this after decrypting, so we are sure the packet is not attacker-controlled - u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, pn) - - return &unpackedPacket{ - hdr: extHdr, - packetNumber: pn, - encryptionLevel: encLevel, - data: decrypted, - }, nil + return extHdr, nil } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 7e274704..fdaba9b1 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -48,7 +48,7 @@ var _ = Describe("Packet Unpacker", func() { } hdr, hdrRaw := getHeader(extHdr) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) - opener := mocks.NewMockOpener(mockCtrl) + opener := mocks.NewMockShortHeaderOpener(mockCtrl) cs.EXPECT().Get1RTTOpener().Return(opener, nil) _, err := unpacker.Unpack(hdr, data) Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) @@ -67,7 +67,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumberLen: 3, } hdr, hdrRaw := getHeader(extHdr) - opener := mocks.NewMockOpener(mockCtrl) + opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetInitialOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil) @@ -102,7 +102,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumberLen: 3, } hdr, hdrRaw := getHeader(extHdr) - opener := mocks.NewMockOpener(mockCtrl) + opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err")) @@ -129,7 +129,7 @@ var _ = Describe("Packet Unpacker", func() { hdrRaw[len(hdrRaw)-2] ^= 0xff // invert the packet number hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte)) - opener := mocks.NewMockOpener(mockCtrl) + opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) gomock.InOrder( // we're using a 2 byte packet number, so the sample starts at the 3rd payload byte @@ -158,11 +158,12 @@ var _ = Describe("Packet Unpacker", func() { Header: wire.Header{DestConnectionID: connID}, PacketNumber: 0x1337, PacketNumberLen: 2, + KeyPhase: protocol.KeyPhaseOne, } - opener := mocks.NewMockOpener(mockCtrl) + opener := mocks.NewMockShortHeaderOpener(mockCtrl) cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil) hdr, hdrRaw := getHeader(firstHdr) packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) Expect(err).ToNot(HaveOccurred()) @@ -172,10 +173,11 @@ var _ = Describe("Packet Unpacker", func() { Header: wire.Header{DestConnectionID: connID}, PacketNumber: 0x38, PacketNumberLen: 1, + KeyPhase: protocol.KeyPhaseZero, } // expect the call with the decoded packet number opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil) hdr, hdrRaw = getHeader(secondHdr) packet, err = unpacker.Unpack(hdr, append(hdrRaw, payload...)) Expect(err).ToNot(HaveOccurred())