diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 6d584c7d..ef14753a 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -9,6 +9,7 @@ import ( "sync" "unsafe" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/utils" @@ -120,6 +121,7 @@ func NewCryptoSetupClient( tp *TransportParameters, runner handshakeRunner, tlsConf *tls.Config, + rttStats *congestion.RTTStats, logger utils.Logger, ) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) { cs, clientHelloWritten, err := newCryptoSetup( @@ -130,6 +132,7 @@ func NewCryptoSetupClient( tp, runner, tlsConf, + rttStats, logger, protocol.PerspectiveClient, ) @@ -150,6 +153,7 @@ func NewCryptoSetupServer( tp *TransportParameters, runner handshakeRunner, tlsConf *tls.Config, + rttStats *congestion.RTTStats, logger utils.Logger, ) (CryptoSetup, error) { cs, _, err := newCryptoSetup( @@ -160,6 +164,7 @@ func NewCryptoSetupServer( tp, runner, tlsConf, + rttStats, logger, protocol.PerspectiveServer, ) @@ -178,6 +183,7 @@ func newCryptoSetup( tp *TransportParameters, runner handshakeRunner, tlsConf *tls.Config, + rttStats *congestion.RTTStats, logger utils.Logger, perspective protocol.Perspective, ) (*cryptoSetup, <-chan struct{} /* ClientHello written */, error) { @@ -192,7 +198,7 @@ func newCryptoSetup( initialOpener: initialOpener, handshakeStream: handshakeStream, oneRTTStream: oneRTTStream, - aead: newUpdatableAEAD(logger), + aead: newUpdatableAEAD(rttStats, logger), readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, runner: runner, diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 8778ab5b..366fcb23 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -13,6 +13,7 @@ import ( "time" gomock "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/testdata" @@ -93,6 +94,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, NewMockHandshakeRunner(mockCtrl), tlsConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("server"), ) Expect(err).ToNot(HaveOccurred()) @@ -124,6 +126,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, runner, testdata.GetTLSConfig(), + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("server"), ) Expect(err).ToNot(HaveOccurred()) @@ -161,6 +164,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, runner, testdata.GetTLSConfig(), + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("server"), ) Expect(err).ToNot(HaveOccurred()) @@ -200,6 +204,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, runner, serverConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("server"), ) Expect(err).ToNot(HaveOccurred()) @@ -233,6 +238,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, NewMockHandshakeRunner(mockCtrl), serverConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("server"), ) Expect(err).ToNot(HaveOccurred()) @@ -311,6 +317,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, cRunner, clientConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("client"), ) Expect(err).ToNot(HaveOccurred()) @@ -332,6 +339,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{StatelessResetToken: &token}, sRunner, serverConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("server"), ) Expect(err).ToNot(HaveOccurred()) @@ -384,6 +392,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, runner, &tls.Config{InsecureSkipVerify: true}, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("client"), ) Expect(err).ToNot(HaveOccurred()) @@ -424,6 +433,7 @@ var _ = Describe("Crypto Setup TLS", func() { cTransportParameters, cRunner, clientConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("client"), ) Expect(err).ToNot(HaveOccurred()) @@ -446,6 +456,7 @@ var _ = Describe("Crypto Setup TLS", func() { sTransportParameters, sRunner, serverConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("server"), ) Expect(err).ToNot(HaveOccurred()) @@ -481,6 +492,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, cRunner, clientConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("client"), ) Expect(err).ToNot(HaveOccurred()) @@ -498,6 +510,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, sRunner, serverConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("server"), ) Expect(err).ToNot(HaveOccurred()) @@ -536,6 +549,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, cRunner, clientConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("client"), ) Expect(err).ToNot(HaveOccurred()) @@ -553,6 +567,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, sRunner, serverConf, + &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("server"), ) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index a08824de..65ea365a 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -7,7 +7,9 @@ import ( "fmt" "os" "strconv" + "time" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/utils" @@ -47,7 +49,9 @@ type updatableAEAD struct { largestAcked protocol.PacketNumber keyUpdateInterval uint64 - prevRcvAEAD cipher.AEAD + // Time when the keys should be dropped. Keys are dropped on the next call to Open(). + prevRcvAEADExpiry time.Time + prevRcvAEAD cipher.AEAD firstRcvdWithCurrentKey protocol.PacketNumber firstSentWithCurrentKey protocol.PacketNumber @@ -64,6 +68,8 @@ type updatableAEAD struct { hpDecrypter cipher.Block hpEncrypter cipher.Block + rttStats *congestion.RTTStats + logger utils.Logger // use a single slice to avoid allocations @@ -74,12 +80,13 @@ type updatableAEAD struct { var _ ShortHeaderOpener = &updatableAEAD{} var _ ShortHeaderSealer = &updatableAEAD{} -func newUpdatableAEAD(logger utils.Logger) *updatableAEAD { +func newUpdatableAEAD(rttStats *congestion.RTTStats, logger utils.Logger) *updatableAEAD { return &updatableAEAD{ largestAcked: protocol.InvalidPacketNumber, firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, firstSentWithCurrentKey: protocol.InvalidPacketNumber, keyUpdateInterval: keyUpdateInterval, + rttStats: rttStats, logger: logger, } } @@ -91,6 +98,7 @@ func (a *updatableAEAD) rollKeys() { a.numRcvdWithCurrentKey = 0 a.numSentWithCurrentKey = 0 a.prevRcvAEAD = a.rcvAEAD + a.prevRcvAEADExpiry = time.Now().Add(3 * a.rttStats.PTO()) a.rcvAEAD = a.nextRcvAEAD a.sendAEAD = a.nextSendAEAD @@ -135,15 +143,22 @@ func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { } func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { + if a.prevRcvAEAD != nil && time.Now().After(a.prevRcvAEADExpiry) { + a.prevRcvAEAD = nil + a.prevRcvAEADExpiry = time.Time{} + } binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) if kp != a.keyPhase.Bit() { if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { - if a.prevRcvAEAD == nil { + if a.keyPhase == 0 { // This can only occur when the first packet received has key phase 1. // This is an error, since the key phase starts at 0, // and peers are only allowed to update keys after the handshake is confirmed. return nil, qerr.Error(qerr.ProtocolViolation, "wrong initial keyphase") } + if a.prevRcvAEAD == nil { + return nil, ErrKeysDropped + } // we updated the key, but the peer hasn't updated yet dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) if err != nil { diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 314d048d..3ec8e920 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -6,7 +6,9 @@ import ( "crypto/cipher" "crypto/rand" "os" + "time" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" @@ -29,14 +31,14 @@ func (c *mockCipherSuite) AEAD(key, _ []byte) cipher.AEAD { } var _ = Describe("Updatable AEAD", func() { - getPeers := func() (client, server *updatableAEAD) { + getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) { trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) rand.Read(trafficSecret1) rand.Read(trafficSecret2) - client = newUpdatableAEAD(utils.DefaultLogger) - server = newUpdatableAEAD(utils.DefaultLogger) + client = newUpdatableAEAD(rttStats, utils.DefaultLogger) + server = newUpdatableAEAD(rttStats, utils.DefaultLogger) client.SetReadKey(&mockCipherSuite{}, trafficSecret2) client.SetWriteKey(&mockCipherSuite{}, trafficSecret1) server.SetReadKey(&mockCipherSuite{}, trafficSecret1) @@ -46,7 +48,7 @@ var _ = Describe("Updatable AEAD", func() { Context("header protection", func() { It("encrypts and decrypts the header", func() { - server, client := getPeers() + server, client := getPeers(&congestion.RTTStats{}) var lastFiveBitsDifferent int for i := 0; i < 100; i++ { sample := make([]byte, 16) @@ -69,9 +71,11 @@ var _ = Describe("Updatable AEAD", func() { Context("message encryption", func() { var msg, ad []byte var server, client *updatableAEAD + var rttStats *congestion.RTTStats BeforeEach(func() { - server, client = getPeers() + rttStats = &congestion.RTTStats{} + server, client = getPeers(rttStats) msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") ad = []byte("Donec in velit neque.") }) @@ -132,6 +136,7 @@ var _ = Describe("Updatable AEAD", func() { }) It("opens a reordered packet with the old keys after an update", func() { + rttStats.UpdateRTT(time.Hour, 0, time.Time{}) // make sure the keys don't get dropped yet encrypted01 := client.Seal(nil, msg, 0x42, ad) encrypted02 := client.Seal(nil, msg, 0x43, ad) // receive the first packet with key phase 0 @@ -153,6 +158,30 @@ var _ = Describe("Updatable AEAD", func() { Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) + It("drops keys 3 PTOs after a key update", func() { + rttStats.UpdateRTT(10*time.Millisecond, 0, time.Now()) + pto := rttStats.PTO() + Expect(pto).To(BeNumerically("<", 50*time.Millisecond)) + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + time.Sleep(3 * pto) + _, err = server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrKeysDropped)) + }) + It("errors when the peer starts with key phase 1", func() { client.rollKeys() encrypted := client.Seal(nil, msg, 0x1337, ad) diff --git a/session.go b/session.go index b71e96c8..2ea582d3 100644 --- a/session.go +++ b/session.go @@ -223,6 +223,7 @@ var newSession = func( onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, }, tlsConf, + s.rttStats, logger, ) if err != nil { @@ -296,6 +297,7 @@ var newClientSession = func( onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, }, tlsConf, + s.rttStats, logger, ) if err != nil {