mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
271 lines
8.9 KiB
Go
271 lines
8.9 KiB
Go
package handshake
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/cipher"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
|
"github.com/lucas-clemente/quic-go/internal/qtls"
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
"github.com/lucas-clemente/quic-go/logging"
|
|
)
|
|
|
|
// By setting this environment variable, the key update interval can be adjusted.
|
|
// This is not needed in production, but useful for integration and interop testing.
|
|
// Note that no matter what value is set, a key update is only initiated once it is
|
|
// permitted (i.e. once an ACK for a packet sent at the current key phase has been received).
|
|
const keyUpdateEnv = "QUIC_GO_KEY_UPDATE_INTERVAL"

// keyUpdateInterval is the package-wide key update interval.
// It is assigned by setKeyUpdateInterval and copied into every
// updatableAEAD created by newUpdatableAEAD.
var keyUpdateInterval uint64
|
|
|
|
func init() {
|
|
setKeyUpdateInterval()
|
|
}
|
|
|
|
func setKeyUpdateInterval() {
|
|
env := os.Getenv(keyUpdateEnv)
|
|
if env == "" {
|
|
keyUpdateInterval = protocol.KeyUpdateInterval
|
|
return
|
|
}
|
|
interval, err := strconv.ParseUint(env, 10, 64)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Cannot parse %s: %s", keyUpdateEnv, err))
|
|
}
|
|
keyUpdateInterval = interval
|
|
}
|
|
|
|
type updatableAEAD struct {
|
|
suite *qtls.CipherSuiteTLS13
|
|
|
|
keyPhase protocol.KeyPhase
|
|
largestAcked protocol.PacketNumber
|
|
firstPacketNumber protocol.PacketNumber
|
|
keyUpdateInterval uint64
|
|
|
|
// Time when the keys should be dropped. Keys are dropped on the next call to Open().
|
|
prevRcvAEADExpiry time.Time
|
|
prevRcvAEAD cipher.AEAD
|
|
|
|
firstRcvdWithCurrentKey protocol.PacketNumber
|
|
firstSentWithCurrentKey protocol.PacketNumber
|
|
numRcvdWithCurrentKey uint64
|
|
numSentWithCurrentKey uint64
|
|
rcvAEAD cipher.AEAD
|
|
sendAEAD cipher.AEAD
|
|
// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
|
|
aeadOverhead int
|
|
|
|
nextRcvAEAD cipher.AEAD
|
|
nextSendAEAD cipher.AEAD
|
|
nextRcvTrafficSecret []byte
|
|
nextSendTrafficSecret []byte
|
|
|
|
headerDecrypter headerProtector
|
|
headerEncrypter headerProtector
|
|
|
|
rttStats *utils.RTTStats
|
|
|
|
tracer logging.ConnectionTracer
|
|
logger utils.Logger
|
|
|
|
// use a single slice to avoid allocations
|
|
nonceBuf []byte
|
|
}
|
|
|
|
var _ ShortHeaderOpener = &updatableAEAD{}
|
|
var _ ShortHeaderSealer = &updatableAEAD{}
|
|
|
|
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger) *updatableAEAD {
|
|
return &updatableAEAD{
|
|
firstPacketNumber: protocol.InvalidPacketNumber,
|
|
largestAcked: protocol.InvalidPacketNumber,
|
|
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
|
|
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
|
|
keyUpdateInterval: keyUpdateInterval,
|
|
rttStats: rttStats,
|
|
tracer: tracer,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (a *updatableAEAD) rollKeys(now time.Time) {
|
|
a.keyPhase++
|
|
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
|
|
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
|
|
a.numRcvdWithCurrentKey = 0
|
|
a.numSentWithCurrentKey = 0
|
|
a.prevRcvAEAD = a.rcvAEAD
|
|
a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO(true))
|
|
a.rcvAEAD = a.nextRcvAEAD
|
|
a.sendAEAD = a.nextSendAEAD
|
|
|
|
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
|
|
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
|
|
a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret)
|
|
a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret)
|
|
}
|
|
|
|
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
|
|
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
|
|
}
|
|
|
|
// For the client, this function is called before SetWriteKey.
|
|
// For the server, this function is called after SetWriteKey.
|
|
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
|
a.rcvAEAD = createAEAD(suite, trafficSecret)
|
|
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false)
|
|
if a.suite == nil {
|
|
a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize())
|
|
a.aeadOverhead = a.rcvAEAD.Overhead()
|
|
a.suite = suite
|
|
}
|
|
|
|
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
|
|
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret)
|
|
}
|
|
|
|
// For the client, this function is called after SetReadKey.
|
|
// For the server, this function is called before SetWriteKey.
|
|
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
|
a.sendAEAD = createAEAD(suite, trafficSecret)
|
|
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false)
|
|
if a.suite == nil {
|
|
a.nonceBuf = make([]byte, a.sendAEAD.NonceSize())
|
|
a.aeadOverhead = a.sendAEAD.Overhead()
|
|
a.suite = suite
|
|
}
|
|
|
|
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
|
|
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret)
|
|
}
|
|
|
|
func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
|
|
if a.prevRcvAEAD != nil && rcvTime.After(a.prevRcvAEADExpiry) {
|
|
a.prevRcvAEAD = nil
|
|
a.prevRcvAEADExpiry = time.Time{}
|
|
if a.tracer != nil {
|
|
a.tracer.DroppedKey(a.keyPhase - 1)
|
|
}
|
|
}
|
|
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
|
if kp != a.keyPhase.Bit() {
|
|
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
|
|
if a.keyPhase == 0 {
|
|
// This can only occur when the first packet received has key phase 1.
|
|
// This is an error, since the key phase starts at 0,
|
|
// and peers are only allowed to update keys after the handshake is confirmed.
|
|
return nil, qerr.NewError(qerr.ProtocolViolation, "wrong initial keyphase")
|
|
}
|
|
if a.prevRcvAEAD == nil {
|
|
return nil, ErrKeysDropped
|
|
}
|
|
// we updated the key, but the peer hasn't updated yet
|
|
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
|
if err != nil {
|
|
err = ErrDecryptionFailed
|
|
}
|
|
return dec, err
|
|
}
|
|
// try opening the packet with the next key phase
|
|
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
|
if err != nil {
|
|
return nil, ErrDecryptionFailed
|
|
}
|
|
// Opening succeeded. Check if the peer was allowed to update.
|
|
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
|
return nil, qerr.NewError(qerr.ProtocolViolation, "keys updated too quickly")
|
|
}
|
|
a.rollKeys(rcvTime)
|
|
a.logger.Debugf("Peer updated keys to %s", a.keyPhase)
|
|
if a.tracer != nil {
|
|
a.tracer.UpdatedKey(a.keyPhase, true)
|
|
}
|
|
a.firstRcvdWithCurrentKey = pn
|
|
return dec, err
|
|
}
|
|
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
|
// It uses the nonce provided here and XOR it with the IV.
|
|
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
|
if err != nil {
|
|
err = ErrDecryptionFailed
|
|
} else {
|
|
a.numRcvdWithCurrentKey++
|
|
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
|
|
a.firstRcvdWithCurrentKey = pn
|
|
}
|
|
}
|
|
return dec, err
|
|
}
|
|
|
|
func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
|
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
|
a.firstSentWithCurrentKey = pn
|
|
}
|
|
if a.firstPacketNumber == protocol.InvalidPacketNumber {
|
|
a.firstPacketNumber = pn
|
|
}
|
|
a.numSentWithCurrentKey++
|
|
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
|
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
|
// It uses the nonce provided here and XOR it with the IV.
|
|
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
|
|
}
|
|
|
|
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) {
|
|
a.largestAcked = pn
|
|
}
|
|
|
|
func (a *updatableAEAD) updateAllowed() bool {
|
|
return a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
|
|
a.largestAcked != protocol.InvalidPacketNumber &&
|
|
a.largestAcked >= a.firstSentWithCurrentKey
|
|
}
|
|
|
|
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
|
|
if !a.updateAllowed() {
|
|
return false
|
|
}
|
|
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
|
|
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %s", a.numRcvdWithCurrentKey, a.keyPhase+1)
|
|
return true
|
|
}
|
|
if a.numSentWithCurrentKey >= a.keyUpdateInterval {
|
|
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %s", a.numSentWithCurrentKey, a.keyPhase+1)
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
|
|
if a.shouldInitiateKeyUpdate() {
|
|
a.rollKeys(time.Now())
|
|
a.logger.Debugf("Initiating key update to key phase %s", a.keyPhase)
|
|
if a.tracer != nil {
|
|
a.tracer.UpdatedKey(a.keyPhase, false)
|
|
}
|
|
}
|
|
return a.keyPhase.Bit()
|
|
}
|
|
|
|
func (a *updatableAEAD) Overhead() int {
|
|
return a.aeadOverhead
|
|
}
|
|
|
|
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
|
a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
|
|
}
|
|
|
|
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
|
a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
|
|
}
|
|
|
|
func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
|
|
return a.firstPacketNumber
|
|
}
|