update: remove circl dependencies for kyber

This commit is contained in:
Mingye Chen 2025-01-10 14:44:59 -07:00
parent ff854c45c6
commit 0114defe06
7 changed files with 69 additions and 244 deletions

101
cfkem.go
View file

@ -1,101 +0,0 @@
// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code
// is governed by a BSD-style license that can be found in the LICENSE file.
//
// Glue to add Circl's (post-quantum) hybrid KEMs.
//
// To enable set CurvePreferences with the desired scheme as the first element:
//
// import (
// "crypto/tls"
//
// [...]
//
// config.CurvePreferences = []tls.CurveID{
// tls.X25519Kyber768Draft00,
// tls.X25519,
// tls.P256,
// }
package tls
import (
"fmt"
"io"
"crypto/ecdh"
"github.com/cloudflare/circl/kem"
"github.com/cloudflare/circl/kem/hybrid"
)
// Either *ecdh.PrivateKey or *kemPrivateKey
type clientKeySharePrivate interface{}
type kemPrivateKey struct {
secretKey kem.PrivateKey
curveID CurveID
}
var (
X25519Kyber512Draft00 = CurveID(0xfe30)
X25519Kyber768Draft00 = CurveID(0x6399)
X25519Kyber768Draft00Old = CurveID(0xfe31)
P256Kyber768Draft00 = CurveID(0xfe32)
invalidCurveID = CurveID(0)
)
// Extract CurveID from clientKeySharePrivate
func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID {
switch v := ks.(type) {
case *kemPrivateKey:
return v.curveID
case *ecdh.PrivateKey:
ret, ok := curveIDForCurve(v.Curve())
if !ok {
panic("cfkem: internal error: unknown curve")
}
return ret
default:
panic("cfkem: internal error: unknown clientKeySharePrivate")
}
}
// Returns scheme by CurveID if supported by Circl
func curveIdToCirclScheme(id CurveID) kem.Scheme {
switch id {
case X25519Kyber512Draft00:
return hybrid.Kyber512X25519()
case X25519Kyber768Draft00, X25519Kyber768Draft00Old:
return hybrid.Kyber768X25519()
case P256Kyber768Draft00:
return hybrid.P256Kyber768Draft00()
}
return nil
}
// Generate a new shared secret and encapsulates it for the packed
// public key in ppk using randomness from rnd.
func encapsulateForKem(scheme kem.Scheme, rnd io.Reader, ppk []byte) (
ct, ss []byte, alert alert, err error) {
pk, err := scheme.UnmarshalBinaryPublicKey(ppk)
if err != nil {
return nil, nil, alertIllegalParameter, fmt.Errorf("unpack pk: %w", err)
}
seed := make([]byte, scheme.EncapsulationSeedSize())
if _, err := io.ReadFull(rnd, seed); err != nil {
return nil, nil, alertInternalError, fmt.Errorf("random: %w", err)
}
ct, ss, err = scheme.EncapsulateDeterministically(pk, seed)
return ct, ss, alertIllegalParameter, err
}
// Generate a new keypair using randomness from rnd.
func generateKemKeyPair(scheme kem.Scheme, curveID CurveID, rnd io.Reader) (
kem.PublicKey, *kemPrivateKey, error) {
seed := make([]byte, scheme.SeedSize())
if _, err := io.ReadFull(rnd, seed); err != nil {
return nil, nil, err
}
pk, sk := scheme.DeriveKeyPair(seed)
return pk, &kemPrivateKey{sk, curveID}, nil
}

View file

@ -1,107 +0,0 @@
// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code
// is governed by a BSD-style license that can be found in the LICENSE file.
package tls
import (
"context"
"fmt"
"testing"
)
func testHybridKEX(t *testing.T, curveID CurveID, clientPQ, serverPQ,
clientTLS12, serverTLS12 bool) {
// var clientSelectedKEX *CurveID
// var retry bool
clientConfig := testConfig.Clone()
if clientPQ {
clientConfig.CurvePreferences = []CurveID{curveID, X25519}
}
// clientCFEventHandler := func(ev CFEvent) {
// switch e := ev.(type) {
// case CFEventTLSNegotiatedNamedKEX:
// clientSelectedKEX = &e.KEX
// case CFEventTLS13HRR:
// retry = true
// }
// }
if clientTLS12 {
clientConfig.MaxVersion = VersionTLS12
}
serverConfig := testConfig.Clone()
if serverPQ {
serverConfig.CurvePreferences = []CurveID{curveID, X25519}
} else {
serverConfig.CurvePreferences = []CurveID{X25519}
}
if serverTLS12 {
serverConfig.MaxVersion = VersionTLS12
}
c, s := localPipe(t)
done := make(chan error)
defer c.Close()
go func() {
defer s.Close()
done <- Server(s, serverConfig).Handshake()
}()
cli := Client(c, clientConfig)
// cCtx := context.WithValue(context.Background(), CFEventHandlerContextKey{}, clientCFEventHandler)
clientErr := cli.HandshakeContext(context.Background())
serverErr := <-done
if clientErr != nil {
t.Errorf("client error: %s", clientErr)
}
if serverErr != nil {
t.Errorf("server error: %s", serverErr)
}
// var expectedKEX CurveID
// var expectedRetry bool
// if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 {
// expectedKEX = curveID
// } else {
// expectedKEX = X25519
// }
// if !clientTLS12 && clientPQ && !serverPQ {
// expectedRetry = true
// }
// if expectedRetry != retry {
// t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry)
// }
// if clientSelectedKEX == nil {
// t.Error("No KEX happened?")
// } else if *clientSelectedKEX != expectedKEX {
// t.Errorf("failed to negotiate: expected %d, got %d",
// expectedKEX, *clientSelectedKEX)
// }
}
func TestHybridKEX(t *testing.T) {
run := func(curveID CurveID, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) {
t.Run(fmt.Sprintf("%#04x serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", uint16(curveID),
serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) {
testHybridKEX(t, curveID, clientPQ, serverPQ, clientTLS12, serverTLS12)
})
}
for _, curveID := range []CurveID{
X25519Kyber512Draft00,
X25519Kyber768Draft00,
X25519Kyber768Draft00Old,
P256Kyber768Draft00,
} {
run(curveID, true, true, false, false)
run(curveID, true, false, false, false)
run(curveID, false, true, false, false)
run(curveID, true, true, true, false)
run(curveID, true, true, false, true)
run(curveID, true, true, true, true)
}
}

View file

@ -61,7 +61,7 @@ func (ksp *KeySharesParameters) GetEcdhePubkey(curveID CurveID) (params *ecdh.Pu
}
func (ksp *KeySharesParameters) AddKemKeypair(curveID CurveID, kemKey kem.PrivateKey, kemPubKey kem.PublicKey) {
if curveIdToCirclScheme(curveID) != nil { // only store for circl schemes
if curveID == x25519Kyber768Draft00 { // only store for x25519Kyber768Draft00
ksp.kemPrivKeymap[curveID] = kemKey
ksp.kemPubKeymap[curveID] = kemPubKey
}
@ -121,18 +121,20 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
}
// [uTLS SECTION START]
// set echdheParams to what we received from server
if ecdheKey, ok := hs.keySharesParams.GetEcdheKey(hs.serverHello.serverShare.group); ok {
hs.keyShareKeys.ecdhe = ecdheKey
hs.keyShareKeys.curveID = hs.serverHello.serverShare.group
}
// set kemParams to what we received from server
if kemKey, ok := hs.keySharesParams.GetKemKey(hs.serverHello.serverShare.group); ok {
if kyberKey, ecdhKey, err := mlkemCirclToGo(kemKey); err == nil {
hs.keyShareKeys.kyber = kyberKey
hs.keyShareKeys.ecdhe = ecdhKey
if hs.keyShareKeys == nil {
// set echdheParams to what we received from server
if ecdheKey, ok := hs.keySharesParams.GetEcdheKey(hs.serverHello.serverShare.group); ok {
hs.keyShareKeys.ecdhe = ecdheKey
hs.keyShareKeys.curveID = hs.serverHello.serverShare.group
}
// set kemParams to what we received from server
if kemKey, ok := hs.keySharesParams.GetKemKey(hs.serverHello.serverShare.group); ok {
if kyberKey, ecdhKey, err := mlkemCirclToGo(kemKey); err == nil {
hs.keyShareKeys.kyber = kyberKey
hs.keyShareKeys.ecdhe = ecdhKey
hs.keyShareKeys.curveID = hs.serverHello.serverShare.group
}
}
}
// [uTLS SECTION END]

View file

@ -87,6 +87,18 @@ const (
FakeCurveFFDHE8192 CurveID = 0x0104
)
const (
X25519Kyber768Draft00 CurveID = x25519Kyber768Draft00
FakeCurveX25519Kyber512Draft00 CurveID = 0xfe30
FakeCurveX25519Kyber768Draft00Old CurveID = 0xfe31
FakeCurveP256Kyber768Draft00 CurveID = 0xfe32
X25519Kyber512Draft00 CurveID = FakeCurveX25519Kyber512Draft00
X25519Kyber768Draft00Old CurveID = FakeCurveX25519Kyber768Draft00Old
P256Kyber768Draft00 CurveID = FakeCurveP256Kyber768Draft00
)
// Other things
const (
fakeRecordSizeLimit uint16 = 0x001c

View file

@ -15,8 +15,6 @@ import (
"hash"
"net"
"strconv"
"github.com/cloudflare/circl/kem/mlkem/mlkem768"
)
type ClientHelloBuildStatus int
@ -115,15 +113,7 @@ func (uconn *UConn) buildHandshakeState(loadSession bool) error {
}
uconn.HandshakeState.Hello = hello.getPublicPtr()
if keySharePrivate.ecdhe != nil {
uconn.HandshakeState.State13.EcdheKey = keySharePrivate.ecdhe
} else if keySharePrivate.kyber != nil {
kemPrivKey := &mlkem768.PrivateKey{}
kemPrivKey.Unpack(keySharePrivate.kyber.EncapsulationKey())
uconn.HandshakeState.State13.KEMKey = &KemPrivateKey{CurveID: keySharePrivate.curveID, SecretKey: kemPrivKey}
} else {
return fmt.Errorf("uTLS: unknown keySharePrivate type: %T", keySharePrivate)
}
uconn.HandshakeState.State13.KeyShareKeys = keySharePrivate.ToPublic()
uconn.HandshakeState.C = uconn.Conn
uconn.clientHelloBuildStatus = BuildByGoTLS
} else {

View file

@ -18,6 +18,7 @@ import (
"strconv"
"github.com/refraction-networking/utls/dicttls"
"github.com/refraction-networking/utls/internal/mlkem768"
)
var ErrUnknownClientHelloID = errors.New("tls: unknown ClientHelloID")
@ -2732,35 +2733,52 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
continue
}
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
pk, sk, err := generateKemKeyPair(scheme, curveID, uconn.config.rand())
if curveID == x25519Kyber768Draft00 {
ecdheKey, err := generateECDHEKey(uconn.config.rand(), X25519)
if err != nil {
return fmt.Errorf("HRR generateKemKeyPair %s: %w",
scheme.Name(), err)
return err
}
packedPk, err := pk.MarshalBinary()
seed := make([]byte, mlkem768.SeedSize)
if _, err := io.ReadFull(uconn.config.rand(), seed); err != nil {
return err
}
kyberKey, err := mlkem768.NewKeyFromSeed(seed)
if err != nil {
return fmt.Errorf("HRR pack circl public key %s: %w",
scheme.Name(), err)
return err
}
uconn.HandshakeState.State13.KeySharesParams.AddKemKeypair(curveID, sk.secretKey, pk)
ext.KeyShares[i].Data = packedPk
circlKyberKey, err := kyberGoToCircl(kyberKey, ecdheKey)
if err != nil {
return err
}
uconn.HandshakeState.State13.KeySharesParams.AddKemKeypair(curveID, circlKyberKey, circlKyberKey.Public())
ext.KeyShares[i].Data = append(ecdheKey.PublicKey().Bytes(), kyberKey.EncapsulationKey()...)
if !preferredCurveIsSet {
// only do this once for the first non-grease curve
uconn.HandshakeState.State13.KEMKey = sk.ToPublic()
uconn.HandshakeState.State13.KeyShareKeys.kyber = kyberKey
preferredCurveIsSet = true
}
if len(ext.KeyShares) > i+1 && ext.KeyShares[i+1].Group == X25519 {
// Reuse the same X25519 ephemeral key for both keyshares, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2.
uconn.HandshakeState.State13.KeyShareKeys.Ecdhe = ecdheKey
uconn.HandshakeState.State13.KeySharesParams.AddEcdheKeypair(curveID, ecdheKey, ecdheKey.PublicKey())
ext.KeyShares[i+1].Data = ecdheKey.PublicKey().Bytes()
}
} else {
ecdheKey, err := generateECDHEKey(uconn.config.rand(), curveID)
if err != nil {
return fmt.Errorf("unsupported Curve in KeyShareExtension: %v."+
"To mimic it, fill the Data(key) field manually", curveID)
}
uconn.HandshakeState.State13.KeySharesParams.AddEcdheKeypair(curveID, ecdheKey, ecdheKey.PublicKey())
ext.KeyShares[i].Data = ecdheKey.PublicKey().Bytes()
if !preferredCurveIsSet {
// only do this once for the first non-grease curve
uconn.HandshakeState.State13.EcdheKey = ecdheKey
uconn.HandshakeState.State13.KeyShareKeys.Ecdhe = ecdheKey
preferredCurveIsSet = true
}
}

View file

@ -41,6 +41,7 @@ type PubClientHandshakeState struct {
// TLS 1.3 only
type TLS13OnlyState struct {
// Deprecated: Use KeyShareKeys instead.
EcdheKey *ecdh.PrivateKey
KeySharesParams *KeySharesParameters
KEMKey *KemPrivateKey
@ -111,6 +112,10 @@ func (chs *TLS13OnlyState) private13KeyShareKeys() *keySharePrivateKeys {
return nil
}
func kyberGoToCircl(kyberKey *mlkem768.DecapsulationKey, ecdhKey *ecdh.PrivateKey) (kem.PrivateKey, error) {
return hybrid.Kyber768X25519().UnmarshalBinaryPrivateKey(append(ecdhKey.Bytes(), kyberKey.Bytes()...))
}
func (ksp *keySharePrivateKeys) publicKEMKey() *KemPrivateKey {
if ksp.kyber != nil && ksp.ecdhe != nil && ksp.curveID == x25519Kyber768Draft00 {
key := append(ksp.ecdhe.Bytes(), ksp.kyber.Bytes()...)
@ -160,6 +165,7 @@ func (chs13 *clientHandshakeStateTLS13) toPublic13() *PubClientHandshakeState {
KeySharesParams: chs13.keySharesParams,
EcdheKey: chs13.keyShareKeys.ecdhe,
KEMKey: chs13.keyShareKeys.publicKEMKey(),
KeyShareKeys: chs13.keyShareKeys.ToPublic(),
EarlySecret: chs13.earlySecret,
BinderKey: chs13.binderKey,
CertReq: chs13.certReq.toPublic(),
@ -863,6 +869,11 @@ func (TKS TicketKeys) ToPrivate() []ticketKey {
return tks
}
type kemPrivateKey struct {
secretKey kem.PrivateKey
curveID CurveID
}
type KemPrivateKey struct {
SecretKey kem.PrivateKey
CurveID CurveID
@ -893,14 +904,14 @@ func (kpk *kemPrivateKey) ToPublic() *KemPrivateKey {
type KeySharePrivateKeys struct {
CurveID CurveID
Ecdhe *ecdh.PrivateKey
Kyber *mlkem768.DecapsulationKey
kyber *mlkem768.DecapsulationKey
}
func (ksp *KeySharePrivateKeys) ToPrivate() *keySharePrivateKeys {
return &keySharePrivateKeys{
curveID: ksp.CurveID,
ecdhe: ksp.Ecdhe,
kyber: ksp.Kyber,
kyber: ksp.kyber,
}
}
@ -908,6 +919,6 @@ func (ksp *keySharePrivateKeys) ToPublic() *KeySharePrivateKeys {
return &KeySharePrivateKeys{
CurveID: ksp.curveID,
Ecdhe: ksp.ecdhe,
Kyber: ksp.kyber,
kyber: ksp.kyber,
}
}