drop old keys 3 PTOs after a key update

This commit is contained in:
Marten Seemann 2019-06-29 16:40:45 +07:00
parent a2a4a216de
commit a3878d99c9
5 changed files with 76 additions and 9 deletions

View file

@ -9,6 +9,7 @@ import (
"sync"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
@ -120,6 +121,7 @@ func NewCryptoSetupClient(
tp *TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
rttStats *congestion.RTTStats,
logger utils.Logger,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
cs, clientHelloWritten, err := newCryptoSetup(
@ -130,6 +132,7 @@ func NewCryptoSetupClient(
tp,
runner,
tlsConf,
rttStats,
logger,
protocol.PerspectiveClient,
)
@ -150,6 +153,7 @@ func NewCryptoSetupServer(
tp *TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
rttStats *congestion.RTTStats,
logger utils.Logger,
) (CryptoSetup, error) {
cs, _, err := newCryptoSetup(
@ -160,6 +164,7 @@ func NewCryptoSetupServer(
tp,
runner,
tlsConf,
rttStats,
logger,
protocol.PerspectiveServer,
)
@ -178,6 +183,7 @@ func newCryptoSetup(
tp *TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
rttStats *congestion.RTTStats,
logger utils.Logger,
perspective protocol.Perspective,
) (*cryptoSetup, <-chan struct{} /* ClientHello written */, error) {
@ -192,7 +198,7 @@ func newCryptoSetup(
initialOpener: initialOpener,
handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
aead: newUpdatableAEAD(logger),
aead: newUpdatableAEAD(rttStats, logger),
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,

View file

@ -13,6 +13,7 @@ import (
"time"
gomock "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/testdata"
@ -93,6 +94,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
NewMockHandshakeRunner(mockCtrl),
tlsConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
@ -124,6 +126,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
runner,
testdata.GetTLSConfig(),
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
@ -161,6 +164,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
runner,
testdata.GetTLSConfig(),
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
@ -200,6 +204,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
runner,
serverConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
@ -233,6 +238,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
NewMockHandshakeRunner(mockCtrl),
serverConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
@ -311,6 +317,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
cRunner,
clientConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
@ -332,6 +339,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
@ -384,6 +392,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
runner,
&tls.Config{InsecureSkipVerify: true},
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
@ -424,6 +433,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cTransportParameters,
cRunner,
clientConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
@ -446,6 +456,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sTransportParameters,
sRunner,
serverConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
@ -481,6 +492,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
cRunner,
clientConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
@ -498,6 +510,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
sRunner,
serverConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
@ -536,6 +549,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
cRunner,
clientConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
@ -553,6 +567,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
sRunner,
serverConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())

View file

@ -7,7 +7,9 @@ import (
"fmt"
"os"
"strconv"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
@ -47,7 +49,9 @@ type updatableAEAD struct {
largestAcked protocol.PacketNumber
keyUpdateInterval uint64
prevRcvAEAD cipher.AEAD
// Time when the keys should be dropped. Keys are dropped on the next call to Open().
prevRcvAEADExpiry time.Time
prevRcvAEAD cipher.AEAD
firstRcvdWithCurrentKey protocol.PacketNumber
firstSentWithCurrentKey protocol.PacketNumber
@ -64,6 +68,8 @@ type updatableAEAD struct {
hpDecrypter cipher.Block
hpEncrypter cipher.Block
rttStats *congestion.RTTStats
logger utils.Logger
// use a single slice to avoid allocations
@ -74,12 +80,13 @@ type updatableAEAD struct {
var _ ShortHeaderOpener = &updatableAEAD{}
var _ ShortHeaderSealer = &updatableAEAD{}
func newUpdatableAEAD(logger utils.Logger) *updatableAEAD {
func newUpdatableAEAD(rttStats *congestion.RTTStats, logger utils.Logger) *updatableAEAD {
return &updatableAEAD{
largestAcked: protocol.InvalidPacketNumber,
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
keyUpdateInterval: keyUpdateInterval,
rttStats: rttStats,
logger: logger,
}
}
@ -91,6 +98,7 @@ func (a *updatableAEAD) rollKeys() {
a.numRcvdWithCurrentKey = 0
a.numSentWithCurrentKey = 0
a.prevRcvAEAD = a.rcvAEAD
a.prevRcvAEADExpiry = time.Now().Add(3 * a.rttStats.PTO())
a.rcvAEAD = a.nextRcvAEAD
a.sendAEAD = a.nextSendAEAD
@ -135,15 +143,22 @@ func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
}
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
if a.prevRcvAEAD != nil && time.Now().After(a.prevRcvAEADExpiry) {
a.prevRcvAEAD = nil
a.prevRcvAEADExpiry = time.Time{}
}
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
if kp != a.keyPhase.Bit() {
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
if a.prevRcvAEAD == nil {
if a.keyPhase == 0 {
// This can only occur when the first packet received has key phase 1.
// This is an error, since the key phase starts at 0,
// and peers are only allowed to update keys after the handshake is confirmed.
return nil, qerr.Error(qerr.ProtocolViolation, "wrong initial keyphase")
}
if a.prevRcvAEAD == nil {
return nil, ErrKeysDropped
}
// we updated the key, but the peer hasn't updated yet
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {

View file

@ -6,7 +6,9 @@ import (
"crypto/cipher"
"crypto/rand"
"os"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
@ -29,14 +31,14 @@ func (c *mockCipherSuite) AEAD(key, _ []byte) cipher.AEAD {
}
var _ = Describe("Updatable AEAD", func() {
getPeers := func() (client, server *updatableAEAD) {
getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) {
trafficSecret1 := make([]byte, 16)
trafficSecret2 := make([]byte, 16)
rand.Read(trafficSecret1)
rand.Read(trafficSecret2)
client = newUpdatableAEAD(utils.DefaultLogger)
server = newUpdatableAEAD(utils.DefaultLogger)
client = newUpdatableAEAD(rttStats, utils.DefaultLogger)
server = newUpdatableAEAD(rttStats, utils.DefaultLogger)
client.SetReadKey(&mockCipherSuite{}, trafficSecret2)
client.SetWriteKey(&mockCipherSuite{}, trafficSecret1)
server.SetReadKey(&mockCipherSuite{}, trafficSecret1)
@ -46,7 +48,7 @@ var _ = Describe("Updatable AEAD", func() {
Context("header protection", func() {
It("encrypts and decrypts the header", func() {
server, client := getPeers()
server, client := getPeers(&congestion.RTTStats{})
var lastFiveBitsDifferent int
for i := 0; i < 100; i++ {
sample := make([]byte, 16)
@ -69,9 +71,11 @@ var _ = Describe("Updatable AEAD", func() {
Context("message encryption", func() {
var msg, ad []byte
var server, client *updatableAEAD
var rttStats *congestion.RTTStats
BeforeEach(func() {
server, client = getPeers()
rttStats = &congestion.RTTStats{}
server, client = getPeers(rttStats)
msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
ad = []byte("Donec in velit neque.")
})
@ -132,6 +136,7 @@ var _ = Describe("Updatable AEAD", func() {
})
It("opens a reordered packet with the old keys after an update", func() {
rttStats.UpdateRTT(time.Hour, 0, time.Time{}) // make sure the keys don't get dropped yet
encrypted01 := client.Seal(nil, msg, 0x42, ad)
encrypted02 := client.Seal(nil, msg, 0x43, ad)
// receive the first packet with key phase 0
@ -153,6 +158,30 @@ var _ = Describe("Updatable AEAD", func() {
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
It("drops keys 3 PTOs after a key update", func() {
rttStats.UpdateRTT(10*time.Millisecond, 0, time.Now())
pto := rttStats.PTO()
Expect(pto).To(BeNumerically("<", 50*time.Millisecond))
encrypted01 := client.Seal(nil, msg, 0x42, ad)
encrypted02 := client.Seal(nil, msg, 0x43, ad)
// receive the first packet with key phase 0
_, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
// send one packet at key phase zero
_ = server.Seal(nil, msg, 0x1, ad)
// now receive a packet with key phase 1
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// now receive a reordered packet with key phase 0
time.Sleep(3 * pto)
_, err = server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrKeysDropped))
})
It("errors when the peer starts with key phase 1", func() {
client.rollKeys()
encrypted := client.Seal(nil, msg, 0x1337, ad)

View file

@ -223,6 +223,7 @@ var newSession = func(
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
},
tlsConf,
s.rttStats,
logger,
)
if err != nil {
@ -296,6 +297,7 @@ var newClientSession = func(
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
},
tlsConf,
s.rttStats,
logger,
)
if err != nil {