diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 1610e9a2..71d42fb4 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -3,6 +3,7 @@ package handshake import ( "crypto" "crypto/cipher" + "crypto/tls" "encoding/binary" "fmt" "os" @@ -18,7 +19,7 @@ import ( // By setting this environment variable, the key update interval can be adjusted. // This is not needed in production, but useful for integration and interop testing. -// Note that no mattter what value is set, a key update is only initiated once it is +// Note that no matter what value is set, a key update is only initiated once it is // permitted (i.e. once an ACK for a packet sent at the current key phase has been received). const keyUpdateEnv = "QUIC_GO_KEY_UPDATE_INTERVAL" @@ -47,7 +48,10 @@ type updatableAEAD struct { keyPhase protocol.KeyPhase largestAcked protocol.PacketNumber firstPacketNumber protocol.PacketNumber - keyUpdateInterval uint64 + + keyUpdateInterval uint64 + invalidPacketLimit uint64 + invalidPacketCount uint64 // Time when the keys should be dropped. Keys are dropped on the next call to Open(). prevRcvAEADExpiry time.Time @@ -127,9 +131,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [ a.rcvAEAD = createAEAD(suite, trafficSecret) a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false) if a.suite == nil { - a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize()) - a.aeadOverhead = a.rcvAEAD.Overhead() - a.suite = suite + a.setAEADParameters(a.rcvAEAD, suite) } a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) @@ -142,16 +144,39 @@ func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret a.sendAEAD = createAEAD(suite, trafficSecret) a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false) if a.suite == nil { - a.nonceBuf = make([]byte, a.sendAEAD.NonceSize()) - a.aeadOverhead = a.sendAEAD.Overhead() - a.suite = suite + a.setAEADParameters(a.sendAEAD, suite) } a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret) } +func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) { + a.nonceBuf = make([]byte, aead.NonceSize()) + a.aeadOverhead = aead.Overhead() + a.suite = suite + switch suite.ID { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + a.invalidPacketLimit = protocol.InvalidPacketLimitAES + case tls.TLS_CHACHA20_POLY1305_SHA256: + a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha + default: + panic(fmt.Sprintf("unknown cipher suite %d", suite.ID)) + } +} + func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { + dec, err := a.open(dst, src, rcvTime, pn, kp, ad) + if err == ErrDecryptionFailed { + a.invalidPacketCount++ + if a.invalidPacketCount >= a.invalidPacketLimit { + return nil, qerr.AEADLimitReached + } + } + return dec, err +} + +func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { a.prevRcvAEAD = nil a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 6920f2f6..46ab5bcb 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -9,6 +9,7 @@ import ( mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/utils" @@ -116,6 +117,16 @@ var _ = Describe("Updatable AEAD", func() { Expect(err).To(MatchError(ErrDecryptionFailed)) }) + It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { + client.invalidPacketLimit = 10 + for i := 0; i < 9; i++ { + _, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + } + _, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).To(MatchError(qerr.AEADLimitReached)) + }) + Context("key updates", func() { Context("receiving key updates", func() { It("updates keys", func() { diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 484bf4c2..bf5333ab 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -75,3 +75,10 @@ const MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond // MaxConnIDLen is the maximum length of the connection ID const MaxConnIDLen = 20 + +// InvalidPacketLimitAES is the maximum number of packets that we can fail to decrypt when using +// AEAD_AES_128_GCM or AEAD_AES_265_GCM. +const InvalidPacketLimitAES = 1 << 54 + +// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305. +const InvalidPacketLimitChaCha = 1 << 36 diff --git a/internal/qerr/error_codes.go b/internal/qerr/error_codes.go index 4bfb7220..87d4d6f5 100644 --- a/internal/qerr/error_codes.go +++ b/internal/qerr/error_codes.go @@ -26,6 +26,7 @@ const ( ApplicationError ErrorCode = 0xc CryptoBufferExceeded ErrorCode = 0xd KeyUpdateError ErrorCode = 0xe + AEADLimitReached ErrorCode = 0xf ) func (e ErrorCode) isCryptoError() bool { @@ -80,6 +81,8 @@ func (e ErrorCode) String() string { return "CRYPTO_BUFFER_EXCEEDED" case KeyUpdateError: return "KEY_UPDATE_ERROR" + case AEADLimitReached: + return "AEAD_LIMIT_REACHED" default: if e.isCryptoError() { return "CRYPTO_ERROR" diff --git a/qlog/types.go b/qlog/types.go index 1dca4d5b..a5c9a492 100644 --- a/qlog/types.go +++ b/qlog/types.go @@ -211,6 +211,8 @@ func (e transportError) String() string { return "crypto_buffer_exceeded" case qerr.KeyUpdateError: return "key_update_error" + case qerr.AEADLimitReached: + return "aead_limit_reached" default: return "" }