mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
uTLS is not yet bumped to the new version, so this commit breaks the dependencies relationship by getting rid of the local replace.
594 lines
27 KiB
Go
594 lines
27 KiB
Go
package handshake
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
tls "github.com/refraction-networking/utls"
|
|
|
|
mocklogging "github.com/refraction-networking/uquic/internal/mocks/logging"
|
|
"github.com/refraction-networking/uquic/internal/protocol"
|
|
"github.com/refraction-networking/uquic/internal/qerr"
|
|
"github.com/refraction-networking/uquic/internal/utils"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
var _ = Describe("Updatable AEAD", func() {
|
|
DescribeTable("ChaCha test vector",
|
|
func(v protocol.VersionNumber, expectedPayload, expectedPacket []byte) {
|
|
secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b")
|
|
aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v)
|
|
chacha := cipherSuites[2]
|
|
Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256))
|
|
aead.SetWriteKey(chacha, secret)
|
|
const pnOffset = 1
|
|
header := splitHexString("4200bff4")
|
|
payloadOffset := len(header)
|
|
plaintext := splitHexString("01")
|
|
payload := aead.Seal(nil, plaintext, 654360564, header)
|
|
Expect(payload).To(Equal(expectedPayload))
|
|
packet := append(header, payload...)
|
|
aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset])
|
|
Expect(packet).To(Equal(expectedPacket))
|
|
},
|
|
Entry("QUIC v1",
|
|
protocol.Version1,
|
|
splitHexString("655e5cd55c41f69080575d7999c25a5bfb"),
|
|
splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"),
|
|
),
|
|
Entry("QUIC v2",
|
|
protocol.Version2,
|
|
splitHexString("0ae7b6b932bc27d786f4bc2bb20f2162ba"),
|
|
splitHexString("5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"),
|
|
),
|
|
)
|
|
|
|
for _, ver := range []protocol.VersionNumber{protocol.Version1, protocol.Version2} {
|
|
v := ver
|
|
|
|
Context(fmt.Sprintf("using version %s", v), func() {
|
|
for i := range cipherSuites {
|
|
cs := cipherSuites[i]
|
|
|
|
Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() {
|
|
var (
|
|
client, server *updatableAEAD
|
|
serverTracer *mocklogging.MockConnectionTracer
|
|
rttStats *utils.RTTStats
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
trafficSecret1 := make([]byte, 16)
|
|
trafficSecret2 := make([]byte, 16)
|
|
rand.Read(trafficSecret1)
|
|
rand.Read(trafficSecret2)
|
|
|
|
rttStats = utils.NewRTTStats()
|
|
client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v)
|
|
server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v)
|
|
client.SetReadKey(cs, trafficSecret2)
|
|
client.SetWriteKey(cs, trafficSecret1)
|
|
server.SetReadKey(cs, trafficSecret1)
|
|
server.SetWriteKey(cs, trafficSecret2)
|
|
})
|
|
|
|
Context("header protection", func() {
|
|
It("encrypts and decrypts the header", func() {
|
|
var lastFiveBitsDifferent int
|
|
for i := 0; i < 100; i++ {
|
|
sample := make([]byte, 16)
|
|
rand.Read(sample)
|
|
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
|
|
client.EncryptHeader(sample, &header[0], header[9:13])
|
|
if header[0]&0x1f != 0xb5&0x1f {
|
|
lastFiveBitsDifferent++
|
|
}
|
|
Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0)))
|
|
Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
|
|
Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
|
|
server.DecryptHeader(sample, &header[0], header[9:13])
|
|
Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
|
|
}
|
|
Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75))
|
|
})
|
|
})
|
|
|
|
Context("message encryption", func() {
|
|
var msg, ad []byte
|
|
|
|
BeforeEach(func() {
|
|
msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
|
ad = []byte("Donec in velit neque.")
|
|
})
|
|
|
|
It("encrypts and decrypts a message", func() {
|
|
encrypted := server.Seal(nil, msg, 0x1337, ad)
|
|
opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(opened).To(Equal(msg))
|
|
})
|
|
|
|
It("saves the first packet number", func() {
|
|
client.Seal(nil, msg, 0x1337, ad)
|
|
Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337)))
|
|
client.Seal(nil, msg, 0x1338, ad)
|
|
Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337)))
|
|
})
|
|
|
|
It("fails to open a message if the associated data is not the same", func() {
|
|
encrypted := client.Seal(nil, msg, 0x1337, ad)
|
|
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
|
|
Expect(err).To(MatchError(ErrDecryptionFailed))
|
|
})
|
|
|
|
It("fails to open a message if the packet number is not the same", func() {
|
|
encrypted := server.Seal(nil, msg, 0x1337, ad)
|
|
_, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
|
|
Expect(err).To(MatchError(ErrDecryptionFailed))
|
|
})
|
|
|
|
It("decodes the packet number", func() {
|
|
encrypted := server.Seal(nil, msg, 0x1337, ad)
|
|
_, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338))
|
|
})
|
|
|
|
It("ignores packets it can't decrypt for packet number derivation", func() {
|
|
encrypted := server.Seal(nil, msg, 0x1337, ad)
|
|
_, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38))
|
|
})
|
|
|
|
It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() {
|
|
client.invalidPacketLimit = 10
|
|
for i := 0; i < 9; i++ {
|
|
_, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad"))
|
|
Expect(err).To(MatchError(ErrDecryptionFailed))
|
|
}
|
|
_, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad"))
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
|
|
Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached))
|
|
})
|
|
|
|
Context("key updates", func() {
|
|
Context("receiving key updates", func() {
|
|
It("updates keys", func() {
|
|
now := time.Now()
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
encrypted0 := server.Seal(nil, msg, 0x1337, ad)
|
|
server.rollKeys()
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
encrypted1 := server.Seal(nil, msg, 0x1337, ad)
|
|
Expect(encrypted0).ToNot(Equal(encrypted1))
|
|
// expect opening to fail. The client didn't roll keys yet
|
|
_, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad)
|
|
Expect(err).To(MatchError(ErrDecryptionFailed))
|
|
client.rollKeys()
|
|
decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(decrypted).To(Equal(msg))
|
|
})
|
|
|
|
It("updates the keys when receiving a packet with the next key phase", func() {
|
|
now := time.Now()
|
|
// receive the first packet at key phase zero
|
|
encrypted0 := client.Seal(nil, msg, 0x42, ad)
|
|
decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(decrypted).To(Equal(msg))
|
|
// send one packet at key phase zero
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
_ = server.Seal(nil, msg, 0x1, ad)
|
|
// now received a message at key phase one
|
|
client.rollKeys()
|
|
encrypted1 := client.Seal(nil, msg, 0x43, ad)
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
|
|
decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(decrypted).To(Equal(msg))
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
})
|
|
|
|
It("opens a reordered packet with the old keys after an update", func() {
|
|
now := time.Now()
|
|
encrypted01 := client.Seal(nil, msg, 0x42, ad)
|
|
encrypted02 := client.Seal(nil, msg, 0x43, ad)
|
|
// receive the first packet with key phase 0
|
|
_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
// send one packet at key phase zero
|
|
_ = server.Seal(nil, msg, 0x1, ad)
|
|
// now receive a packet with key phase 1
|
|
client.rollKeys()
|
|
encrypted1 := client.Seal(nil, msg, 0x44, ad)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
|
|
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
// now receive a reordered packet with key phase 0
|
|
decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(decrypted).To(Equal(msg))
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
})
|
|
|
|
It("drops keys 3 PTOs after a key update", func() {
|
|
now := time.Now()
|
|
rttStats.UpdateRTT(10*time.Millisecond, 0, now)
|
|
pto := rttStats.PTO(true)
|
|
encrypted01 := client.Seal(nil, msg, 0x42, ad)
|
|
encrypted02 := client.Seal(nil, msg, 0x43, ad)
|
|
// receive the first packet with key phase 0
|
|
_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
// send one packet at key phase zero
|
|
_ = server.Seal(nil, msg, 0x1, ad)
|
|
// now receive a packet with key phase 1
|
|
client.rollKeys()
|
|
encrypted1 := client.Seal(nil, msg, 0x44, ad)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
|
|
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
|
|
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
// now receive a reordered packet with key phase 0
|
|
_, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad)
|
|
Expect(err).To(MatchError(ErrKeysDropped))
|
|
})
|
|
|
|
It("allows the first key update immediately", func() {
|
|
// receive a packet at key phase one, before having sent or received any packets at key phase 0
|
|
client.rollKeys()
|
|
encrypted1 := client.Seal(nil, msg, 0x1337, ad)
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
|
|
_, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() {
|
|
client.rollKeys()
|
|
encrypted := client.Seal(nil, msg, 0x1337, ad)
|
|
encrypted = encrypted[:len(encrypted)-1]
|
|
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
|
|
Expect(err).To(MatchError(ErrDecryptionFailed))
|
|
})
|
|
|
|
It("errors when the peer updates keys too frequently", func() {
|
|
server.rollKeys()
|
|
client.rollKeys()
|
|
// receive the first packet at key phase one
|
|
encrypted0 := client.Seal(nil, msg, 0x42, ad)
|
|
_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
// now receive a packet at key phase two, before having sent any packets
|
|
client.rollKeys()
|
|
encrypted1 := client.Seal(nil, msg, 0x42, ad)
|
|
_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
|
|
Expect(err).To(MatchError(&qerr.TransportError{
|
|
ErrorCode: qerr.KeyUpdateError,
|
|
ErrorMessage: "keys updated too quickly",
|
|
}))
|
|
})
|
|
})
|
|
|
|
Context("initiating key updates", func() {
|
|
const firstKeyUpdateInterval = 5
|
|
const keyUpdateInterval = 20
|
|
var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64
|
|
|
|
BeforeEach(func() {
|
|
origKeyUpdateInterval = KeyUpdateInterval
|
|
origFirstKeyUpdateInterval = FirstKeyUpdateInterval
|
|
KeyUpdateInterval = keyUpdateInterval
|
|
FirstKeyUpdateInterval = firstKeyUpdateInterval
|
|
server.SetHandshakeConfirmed()
|
|
})
|
|
|
|
AfterEach(func() {
|
|
KeyUpdateInterval = origKeyUpdateInterval
|
|
FirstKeyUpdateInterval = origFirstKeyUpdateInterval
|
|
})
|
|
|
|
It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
|
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
server.Seal(nil, msg, pn, ad)
|
|
}
|
|
// the first update is allowed without receiving an acknowledgement
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
})
|
|
|
|
It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() {
|
|
server.rollKeys()
|
|
client.rollKeys()
|
|
for i := 0; i < keyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
server.Seal(nil, msg, pn, ad)
|
|
}
|
|
// no update allowed before receiving an acknowledgement for the current key phase
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
// receive an ACK for a packet sent in key phase 0
|
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad"))
|
|
Expect(err).ToNot(HaveOccurred())
|
|
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
|
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
})
|
|
|
|
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
|
|
// First make sure that we update our keys.
|
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
server.Seal(nil, msg, pn, ad)
|
|
}
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
// Now that our keys are updated, send a packet using the new keys.
|
|
const nextPN = firstKeyUpdateInterval + 1
|
|
server.Seal(nil, msg, nextPN, ad)
|
|
// We haven't decrypted any packet in the new key phase yet.
|
|
// This means that the ACK must have been sent in the old key phase.
|
|
Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{
|
|
ErrorCode: qerr.KeyUpdateError,
|
|
ErrorMessage: "received ACK for key phase 1, but peer didn't update keys",
|
|
}))
|
|
})
|
|
|
|
It("doesn't error before actually sending a packet in the new key phase", func() {
|
|
// First make sure that we update our keys.
|
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
server.Seal(nil, msg, pn, ad)
|
|
}
|
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
|
Expect(err).ToNot(HaveOccurred())
|
|
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
|
// Now that our keys are updated, send a packet using the new keys.
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
// We haven't decrypted any packet in the new key phase yet.
|
|
// This means that the ACK must have been sent in the old key phase.
|
|
Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred())
|
|
})
|
|
|
|
It("initiates a key update after opening the maximum number of packets, for the first update", func() {
|
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
encrypted := client.Seal(nil, msg, pn, ad)
|
|
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
}
|
|
// the first update is allowed without receiving an acknowledgement
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
})
|
|
|
|
It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() {
|
|
server.rollKeys()
|
|
client.rollKeys()
|
|
for i := 0; i < keyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
encrypted := client.Seal(nil, msg, pn, ad)
|
|
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
}
|
|
// no update allowed before receiving an acknowledgement for the current key phase
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
server.Seal(nil, msg, 1, ad)
|
|
Expect(server.SetLargestAcked(1)).To(Succeed())
|
|
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
})
|
|
|
|
It("drops keys 3 PTOs after a key update", func() {
|
|
now := time.Now()
|
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
server.Seal(nil, msg, pn, ad)
|
|
}
|
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
|
_, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad"))
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(server.SetLargestAcked(0)).To(Succeed())
|
|
// Now we've initiated the first key update.
|
|
// Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there
|
|
threePTO := 3 * rttStats.PTO(false)
|
|
dataKeyPhaseZero := client.Seal(nil, msg, 1, ad)
|
|
_, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
// Now receive a packet with key phase 1.
|
|
// This should start the timer to drop the keys after 3 PTOs.
|
|
client.rollKeys()
|
|
dataKeyPhaseOne := client.Seal(nil, msg, 10, ad)
|
|
t := now.Add(threePTO).Add(time.Second)
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
|
|
_, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
// Make sure the keys are still here.
|
|
_, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
|
|
_, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad)
|
|
Expect(err).To(MatchError(ErrKeysDropped))
|
|
})
|
|
|
|
It("doesn't drop the first key generation too early", func() {
|
|
now := time.Now()
|
|
data1 := client.Seal(nil, msg, 1, ad)
|
|
_, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
server.Seal(nil, msg, pn, ad)
|
|
Expect(server.SetLargestAcked(pn)).To(Succeed())
|
|
}
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
// The server never received a packet at key phase 1.
|
|
// Make sure the key phase 0 is still there at a much later point.
|
|
data2 := client.Seal(nil, msg, 1, ad)
|
|
_, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() {
|
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
server.Seal(nil, msg, pn, ad)
|
|
}
|
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
|
Expect(err).ToNot(HaveOccurred())
|
|
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
const nextPN = keyUpdateInterval + 1
|
|
// Send and receive an acknowledgement for a packet in key phase 1.
|
|
// We are now running a timer to drop the keys with 3 PTO.
|
|
server.Seal(nil, msg, nextPN, ad)
|
|
client.rollKeys()
|
|
dataKeyPhaseOne := client.Seal(nil, msg, 2, ad)
|
|
now := time.Now()
|
|
_, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(server.SetLargestAcked(nextPN))
|
|
// Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over.
|
|
// This mean that we need to drop the keys for key phase 0 immediately.
|
|
client.rollKeys()
|
|
dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad)
|
|
gomock.InOrder(
|
|
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)),
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true),
|
|
)
|
|
_, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
})
|
|
|
|
It("drops keys early when we initiate another key update within the 3 PTO period", func() {
|
|
server.SetHandshakeConfirmed()
|
|
// send so many packets that we initiate the first key update
|
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
server.Seal(nil, msg, pn, ad)
|
|
}
|
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
|
Expect(err).ToNot(HaveOccurred())
|
|
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
// send so many packets that we initiate the next key update
|
|
for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ {
|
|
pn := protocol.PacketNumber(i)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
|
server.Seal(nil, msg, pn, ad)
|
|
}
|
|
client.rollKeys()
|
|
b = client.Seal(nil, []byte("foobar"), 2, []byte("ad"))
|
|
now := time.Now()
|
|
_, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad"))
|
|
Expect(err).ToNot(HaveOccurred())
|
|
ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed())
|
|
gomock.InOrder(
|
|
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)),
|
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false),
|
|
)
|
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
// We haven't received an ACK for a packet sent in key phase 2 yet.
|
|
// Make sure we canceled the timer to drop the previous key phase.
|
|
b = client.Seal(nil, []byte("foobar"), 3, []byte("ad"))
|
|
_, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad"))
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
})
|
|
})
|
|
})
|
|
})
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
func getClientAndServer() (client, server *updatableAEAD) {
|
|
trafficSecret1 := make([]byte, 16)
|
|
trafficSecret2 := make([]byte, 16)
|
|
rand.Read(trafficSecret1)
|
|
rand.Read(trafficSecret2)
|
|
|
|
cs := cipherSuites[0]
|
|
rttStats := utils.NewRTTStats()
|
|
client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1)
|
|
server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1)
|
|
client.SetReadKey(cs, trafficSecret2)
|
|
client.SetWriteKey(cs, trafficSecret1)
|
|
server.SetReadKey(cs, trafficSecret1)
|
|
server.SetWriteKey(cs, trafficSecret2)
|
|
return
|
|
}
|
|
|
|
func BenchmarkPacketEncryption(b *testing.B) {
|
|
client, _ := getClientAndServer()
|
|
const l = 1200
|
|
src := make([]byte, l)
|
|
rand.Read(src)
|
|
ad := make([]byte, 32)
|
|
rand.Read(ad)
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
src = client.Seal(src[:0], src[:l], protocol.PacketNumber(i), ad)
|
|
}
|
|
}
|
|
|
|
func BenchmarkPacketDecryption(b *testing.B) {
|
|
client, server := getClientAndServer()
|
|
const l = 1200
|
|
src := make([]byte, l)
|
|
dst := make([]byte, l)
|
|
rand.Read(src)
|
|
ad := make([]byte, 32)
|
|
rand.Read(ad)
|
|
src = client.Seal(src[:0], src[:l], 1337, ad)
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
if _, err := server.Open(dst[:0], src, time.Time{}, 1337, protocol.KeyPhaseZero, ad); err != nil {
|
|
b.Fatalf("opening failed: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkRollKeys(b *testing.B) {
|
|
client, _ := getClientAndServer()
|
|
for i := 0; i < b.N; i++ {
|
|
client.rollKeys()
|
|
}
|
|
if int(client.keyPhase) != b.N {
|
|
b.Fatal("didn't roll keys often enough")
|
|
}
|
|
}
|