mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
drop old keys 3 PTOs after a key update
This commit is contained in:
parent
a2a4a216de
commit
a3878d99c9
5 changed files with 76 additions and 9 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
@ -120,6 +121,7 @@ func NewCryptoSetupClient(
|
|||
tp *TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
||||
cs, clientHelloWritten, err := newCryptoSetup(
|
||||
|
@ -130,6 +132,7 @@ func NewCryptoSetupClient(
|
|||
tp,
|
||||
runner,
|
||||
tlsConf,
|
||||
rttStats,
|
||||
logger,
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
|
@ -150,6 +153,7 @@ func NewCryptoSetupServer(
|
|||
tp *TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, error) {
|
||||
cs, _, err := newCryptoSetup(
|
||||
|
@ -160,6 +164,7 @@ func NewCryptoSetupServer(
|
|||
tp,
|
||||
runner,
|
||||
tlsConf,
|
||||
rttStats,
|
||||
logger,
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
|
@ -178,6 +183,7 @@ func newCryptoSetup(
|
|||
tp *TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
) (*cryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
||||
|
@ -192,7 +198,7 @@ func newCryptoSetup(
|
|||
initialOpener: initialOpener,
|
||||
handshakeStream: handshakeStream,
|
||||
oneRTTStream: oneRTTStream,
|
||||
aead: newUpdatableAEAD(logger),
|
||||
aead: newUpdatableAEAD(rttStats, logger),
|
||||
readEncLevel: protocol.EncryptionInitial,
|
||||
writeEncLevel: protocol.EncryptionInitial,
|
||||
runner: runner,
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
@ -93,6 +94,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
NewMockHandshakeRunner(mockCtrl),
|
||||
tlsConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -124,6 +126,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
runner,
|
||||
testdata.GetTLSConfig(),
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -161,6 +164,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
runner,
|
||||
testdata.GetTLSConfig(),
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -200,6 +204,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
runner,
|
||||
serverConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -233,6 +238,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
NewMockHandshakeRunner(mockCtrl),
|
||||
serverConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -311,6 +317,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
cRunner,
|
||||
clientConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -332,6 +339,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{StatelessResetToken: &token},
|
||||
sRunner,
|
||||
serverConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -384,6 +392,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
runner,
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -424,6 +433,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
cTransportParameters,
|
||||
cRunner,
|
||||
clientConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -446,6 +456,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
sTransportParameters,
|
||||
sRunner,
|
||||
serverConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -481,6 +492,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
cRunner,
|
||||
clientConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -498,6 +510,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
sRunner,
|
||||
serverConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -536,6 +549,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
cRunner,
|
||||
clientConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -553,6 +567,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
sRunner,
|
||||
serverConf,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -7,7 +7,9 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
|
@ -47,7 +49,9 @@ type updatableAEAD struct {
|
|||
largestAcked protocol.PacketNumber
|
||||
keyUpdateInterval uint64
|
||||
|
||||
prevRcvAEAD cipher.AEAD
|
||||
// Time when the keys should be dropped. Keys are dropped on the next call to Open().
|
||||
prevRcvAEADExpiry time.Time
|
||||
prevRcvAEAD cipher.AEAD
|
||||
|
||||
firstRcvdWithCurrentKey protocol.PacketNumber
|
||||
firstSentWithCurrentKey protocol.PacketNumber
|
||||
|
@ -64,6 +68,8 @@ type updatableAEAD struct {
|
|||
hpDecrypter cipher.Block
|
||||
hpEncrypter cipher.Block
|
||||
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
logger utils.Logger
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
|
@ -74,12 +80,13 @@ type updatableAEAD struct {
|
|||
var _ ShortHeaderOpener = &updatableAEAD{}
|
||||
var _ ShortHeaderSealer = &updatableAEAD{}
|
||||
|
||||
func newUpdatableAEAD(logger utils.Logger) *updatableAEAD {
|
||||
func newUpdatableAEAD(rttStats *congestion.RTTStats, logger utils.Logger) *updatableAEAD {
|
||||
return &updatableAEAD{
|
||||
largestAcked: protocol.InvalidPacketNumber,
|
||||
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
|
||||
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
|
||||
keyUpdateInterval: keyUpdateInterval,
|
||||
rttStats: rttStats,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
@ -91,6 +98,7 @@ func (a *updatableAEAD) rollKeys() {
|
|||
a.numRcvdWithCurrentKey = 0
|
||||
a.numSentWithCurrentKey = 0
|
||||
a.prevRcvAEAD = a.rcvAEAD
|
||||
a.prevRcvAEADExpiry = time.Now().Add(3 * a.rttStats.PTO())
|
||||
a.rcvAEAD = a.nextRcvAEAD
|
||||
a.sendAEAD = a.nextSendAEAD
|
||||
|
||||
|
@ -135,15 +143,22 @@ func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
|
|||
}
|
||||
|
||||
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
|
||||
if a.prevRcvAEAD != nil && time.Now().After(a.prevRcvAEADExpiry) {
|
||||
a.prevRcvAEAD = nil
|
||||
a.prevRcvAEADExpiry = time.Time{}
|
||||
}
|
||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||
if kp != a.keyPhase.Bit() {
|
||||
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
|
||||
if a.prevRcvAEAD == nil {
|
||||
if a.keyPhase == 0 {
|
||||
// This can only occur when the first packet received has key phase 1.
|
||||
// This is an error, since the key phase starts at 0,
|
||||
// and peers are only allowed to update keys after the handshake is confirmed.
|
||||
return nil, qerr.Error(qerr.ProtocolViolation, "wrong initial keyphase")
|
||||
}
|
||||
if a.prevRcvAEAD == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
// we updated the key, but the peer hasn't updated yet
|
||||
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err != nil {
|
||||
|
|
|
@ -6,7 +6,9 @@ import (
|
|||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
. "github.com/onsi/ginkgo"
|
||||
|
@ -29,14 +31,14 @@ func (c *mockCipherSuite) AEAD(key, _ []byte) cipher.AEAD {
|
|||
}
|
||||
|
||||
var _ = Describe("Updatable AEAD", func() {
|
||||
getPeers := func() (client, server *updatableAEAD) {
|
||||
getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) {
|
||||
trafficSecret1 := make([]byte, 16)
|
||||
trafficSecret2 := make([]byte, 16)
|
||||
rand.Read(trafficSecret1)
|
||||
rand.Read(trafficSecret2)
|
||||
|
||||
client = newUpdatableAEAD(utils.DefaultLogger)
|
||||
server = newUpdatableAEAD(utils.DefaultLogger)
|
||||
client = newUpdatableAEAD(rttStats, utils.DefaultLogger)
|
||||
server = newUpdatableAEAD(rttStats, utils.DefaultLogger)
|
||||
client.SetReadKey(&mockCipherSuite{}, trafficSecret2)
|
||||
client.SetWriteKey(&mockCipherSuite{}, trafficSecret1)
|
||||
server.SetReadKey(&mockCipherSuite{}, trafficSecret1)
|
||||
|
@ -46,7 +48,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
|
||||
Context("header protection", func() {
|
||||
It("encrypts and decrypts the header", func() {
|
||||
server, client := getPeers()
|
||||
server, client := getPeers(&congestion.RTTStats{})
|
||||
var lastFiveBitsDifferent int
|
||||
for i := 0; i < 100; i++ {
|
||||
sample := make([]byte, 16)
|
||||
|
@ -69,9 +71,11 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
Context("message encryption", func() {
|
||||
var msg, ad []byte
|
||||
var server, client *updatableAEAD
|
||||
var rttStats *congestion.RTTStats
|
||||
|
||||
BeforeEach(func() {
|
||||
server, client = getPeers()
|
||||
rttStats = &congestion.RTTStats{}
|
||||
server, client = getPeers(rttStats)
|
||||
msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
ad = []byte("Donec in velit neque.")
|
||||
})
|
||||
|
@ -132,6 +136,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
})
|
||||
|
||||
It("opens a reordered packet with the old keys after an update", func() {
|
||||
rttStats.UpdateRTT(time.Hour, 0, time.Time{}) // make sure the keys don't get dropped yet
|
||||
encrypted01 := client.Seal(nil, msg, 0x42, ad)
|
||||
encrypted02 := client.Seal(nil, msg, 0x43, ad)
|
||||
// receive the first packet with key phase 0
|
||||
|
@ -153,6 +158,30 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
|
||||
It("drops keys 3 PTOs after a key update", func() {
|
||||
rttStats.UpdateRTT(10*time.Millisecond, 0, time.Now())
|
||||
pto := rttStats.PTO()
|
||||
Expect(pto).To(BeNumerically("<", 50*time.Millisecond))
|
||||
encrypted01 := client.Seal(nil, msg, 0x42, ad)
|
||||
encrypted02 := client.Seal(nil, msg, 0x43, ad)
|
||||
// receive the first packet with key phase 0
|
||||
_, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// send one packet at key phase zero
|
||||
_ = server.Seal(nil, msg, 0x1, ad)
|
||||
// now receive a packet with key phase 1
|
||||
client.rollKeys()
|
||||
encrypted1 := client.Seal(nil, msg, 0x44, ad)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
_, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
// now receive a reordered packet with key phase 0
|
||||
time.Sleep(3 * pto)
|
||||
_, err = server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).To(MatchError(ErrKeysDropped))
|
||||
})
|
||||
|
||||
It("errors when the peer starts with key phase 1", func() {
|
||||
client.rollKeys()
|
||||
encrypted := client.Seal(nil, msg, 0x1337, ad)
|
||||
|
|
|
@ -223,6 +223,7 @@ var newSession = func(
|
|||
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
|
||||
},
|
||||
tlsConf,
|
||||
s.rttStats,
|
||||
logger,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -296,6 +297,7 @@ var newClientSession = func(
|
|||
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
|
||||
},
|
||||
tlsConf,
|
||||
s.rttStats,
|
||||
logger,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue