From 0114defe06457ee97f8d5071d303e20eff3fd281 Mon Sep 17 00:00:00 2001 From: Mingye Chen Date: Fri, 10 Jan 2025 14:44:59 -0700 Subject: [PATCH] update: remove circl dependencies for kyber --- cfkem.go | 101 ----------------------------------- cfkem_test.go | 107 -------------------------------------- handshake_client_tls13.go | 24 +++++---- u_common.go | 12 +++++ u_conn.go | 12 +---- u_parrots.go | 40 ++++++++++---- u_public.go | 17 ++++-- 7 files changed, 69 insertions(+), 244 deletions(-) delete mode 100644 cfkem.go delete mode 100644 cfkem_test.go diff --git a/cfkem.go b/cfkem.go deleted file mode 100644 index 8d440e4..0000000 --- a/cfkem.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code -// is governed by a BSD-style license that can be found in the LICENSE file. -// -// Glue to add Circl's (post-quantum) hybrid KEMs. -// -// To enable set CurvePreferences with the desired scheme as the first element: -// -// import ( -// "crypto/tls" -// -// [...] -// -// config.CurvePreferences = []tls.CurveID{ -// tls.X25519Kyber768Draft00, -// tls.X25519, -// tls.P256, -// } - -package tls - -import ( - "fmt" - "io" - - "crypto/ecdh" - - "github.com/cloudflare/circl/kem" - "github.com/cloudflare/circl/kem/hybrid" -) - -// Either *ecdh.PrivateKey or *kemPrivateKey -type clientKeySharePrivate interface{} - -type kemPrivateKey struct { - secretKey kem.PrivateKey - curveID CurveID -} - -var ( - X25519Kyber512Draft00 = CurveID(0xfe30) - X25519Kyber768Draft00 = CurveID(0x6399) - X25519Kyber768Draft00Old = CurveID(0xfe31) - P256Kyber768Draft00 = CurveID(0xfe32) - invalidCurveID = CurveID(0) -) - -// Extract CurveID from clientKeySharePrivate -func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID { - switch v := ks.(type) { - case *kemPrivateKey: - return v.curveID - case *ecdh.PrivateKey: - ret, ok := curveIDForCurve(v.Curve()) - if !ok { - panic("cfkem: internal error: unknown curve") - } - return ret - default: - panic("cfkem: internal error: unknown clientKeySharePrivate") - } -} - -// Returns scheme by CurveID if supported by Circl -func curveIdToCirclScheme(id CurveID) kem.Scheme { - switch id { - case X25519Kyber512Draft00: - return hybrid.Kyber512X25519() - case X25519Kyber768Draft00, X25519Kyber768Draft00Old: - return hybrid.Kyber768X25519() - case P256Kyber768Draft00: - return hybrid.P256Kyber768Draft00() - } - return nil -} - -// Generate a new shared secret and encapsulates it for the packed -// public key in ppk using randomness from rnd. -func encapsulateForKem(scheme kem.Scheme, rnd io.Reader, ppk []byte) ( - ct, ss []byte, alert alert, err error) { - pk, err := scheme.UnmarshalBinaryPublicKey(ppk) - if err != nil { - return nil, nil, alertIllegalParameter, fmt.Errorf("unpack pk: %w", err) - } - seed := make([]byte, scheme.EncapsulationSeedSize()) - if _, err := io.ReadFull(rnd, seed); err != nil { - return nil, nil, alertInternalError, fmt.Errorf("random: %w", err) - } - ct, ss, err = scheme.EncapsulateDeterministically(pk, seed) - return ct, ss, alertIllegalParameter, err -} - -// Generate a new keypair using randomness from rnd. -func generateKemKeyPair(scheme kem.Scheme, curveID CurveID, rnd io.Reader) ( - kem.PublicKey, *kemPrivateKey, error) { - seed := make([]byte, scheme.SeedSize()) - if _, err := io.ReadFull(rnd, seed); err != nil { - return nil, nil, err - } - pk, sk := scheme.DeriveKeyPair(seed) - return pk, &kemPrivateKey{sk, curveID}, nil -} diff --git a/cfkem_test.go b/cfkem_test.go deleted file mode 100644 index 7043d5f..0000000 --- a/cfkem_test.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code -// is governed by a BSD-style license that can be found in the LICENSE file. - -package tls - -import ( - "context" - "fmt" - "testing" -) - -func testHybridKEX(t *testing.T, curveID CurveID, clientPQ, serverPQ, - clientTLS12, serverTLS12 bool) { - // var clientSelectedKEX *CurveID - // var retry bool - - clientConfig := testConfig.Clone() - if clientPQ { - clientConfig.CurvePreferences = []CurveID{curveID, X25519} - } - // clientCFEventHandler := func(ev CFEvent) { - // switch e := ev.(type) { - // case CFEventTLSNegotiatedNamedKEX: - // clientSelectedKEX = &e.KEX - // case CFEventTLS13HRR: - // retry = true - // } - // } - if clientTLS12 { - clientConfig.MaxVersion = VersionTLS12 - } - - serverConfig := testConfig.Clone() - if serverPQ { - serverConfig.CurvePreferences = []CurveID{curveID, X25519} - } else { - serverConfig.CurvePreferences = []CurveID{X25519} - } - if serverTLS12 { - serverConfig.MaxVersion = VersionTLS12 - } - - c, s := localPipe(t) - done := make(chan error) - defer c.Close() - - go func() { - defer s.Close() - done <- Server(s, serverConfig).Handshake() - }() - - cli := Client(c, clientConfig) - // cCtx := context.WithValue(context.Background(), CFEventHandlerContextKey{}, clientCFEventHandler) - clientErr := cli.HandshakeContext(context.Background()) - serverErr := <-done - if clientErr != nil { - t.Errorf("client error: %s", clientErr) - } - if serverErr != nil { - t.Errorf("server error: %s", serverErr) - } - - // var expectedKEX CurveID - // var expectedRetry bool - - // if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 { - // expectedKEX = curveID - // } else { - // expectedKEX = X25519 - // } - // if !clientTLS12 && clientPQ && !serverPQ { - // expectedRetry = true - // } - - // if expectedRetry != retry { - // t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry) - // } - - // if clientSelectedKEX == nil { - // t.Error("No KEX happened?") - // } else if *clientSelectedKEX != expectedKEX { - // t.Errorf("failed to negotiate: expected %d, got %d", - // expectedKEX, *clientSelectedKEX) - // } -} - -func TestHybridKEX(t *testing.T) { - run := func(curveID CurveID, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) { - t.Run(fmt.Sprintf("%#04x serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", uint16(curveID), - serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) { - testHybridKEX(t, curveID, clientPQ, serverPQ, clientTLS12, serverTLS12) - }) - } - for _, curveID := range []CurveID{ - X25519Kyber512Draft00, - X25519Kyber768Draft00, - X25519Kyber768Draft00Old, - P256Kyber768Draft00, - } { - run(curveID, true, true, false, false) - run(curveID, true, false, false, false) - run(curveID, false, true, false, false) - run(curveID, true, true, true, false) - run(curveID, true, true, false, true) - run(curveID, true, true, true, true) - } -} diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index c84bd16..ffe71f1 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -61,7 +61,7 @@ func (ksp *KeySharesParameters) GetEcdhePubkey(curveID CurveID) (params *ecdh.Pu } func (ksp *KeySharesParameters) AddKemKeypair(curveID CurveID, kemKey kem.PrivateKey, kemPubKey kem.PublicKey) { - if curveIdToCirclScheme(curveID) != nil { // only store for circl schemes + if curveID == x25519Kyber768Draft00 { // only store for x25519Kyber768Draft00 ksp.kemPrivKeymap[curveID] = kemKey ksp.kemPubKeymap[curveID] = kemPubKey } @@ -121,18 +121,20 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } // [uTLS SECTION START] - // set echdheParams to what we received from server - if ecdheKey, ok := hs.keySharesParams.GetEcdheKey(hs.serverHello.serverShare.group); ok { - hs.keyShareKeys.ecdhe = ecdheKey - hs.keyShareKeys.curveID = hs.serverHello.serverShare.group - } - // set kemParams to what we received from server - if kemKey, ok := hs.keySharesParams.GetKemKey(hs.serverHello.serverShare.group); ok { - if kyberKey, ecdhKey, err := mlkemCirclToGo(kemKey); err == nil { - hs.keyShareKeys.kyber = kyberKey - hs.keyShareKeys.ecdhe = ecdhKey + if hs.keyShareKeys == nil { + // set echdheParams to what we received from server + if ecdheKey, ok := hs.keySharesParams.GetEcdheKey(hs.serverHello.serverShare.group); ok { + hs.keyShareKeys.ecdhe = ecdheKey hs.keyShareKeys.curveID = hs.serverHello.serverShare.group } + // set kemParams to what we received from server + if kemKey, ok := hs.keySharesParams.GetKemKey(hs.serverHello.serverShare.group); ok { + if kyberKey, ecdhKey, err := mlkemCirclToGo(kemKey); err == nil { + hs.keyShareKeys.kyber = kyberKey + hs.keyShareKeys.ecdhe = ecdhKey + hs.keyShareKeys.curveID = hs.serverHello.serverShare.group + } + } } // [uTLS SECTION END] diff --git a/u_common.go b/u_common.go index 59a5a2e..68ce11f 100644 --- a/u_common.go +++ b/u_common.go @@ -87,6 +87,18 @@ const ( FakeCurveFFDHE8192 CurveID = 0x0104 ) +const ( + X25519Kyber768Draft00 CurveID = x25519Kyber768Draft00 + + FakeCurveX25519Kyber512Draft00 CurveID = 0xfe30 + FakeCurveX25519Kyber768Draft00Old CurveID = 0xfe31 + FakeCurveP256Kyber768Draft00 CurveID = 0xfe32 + + X25519Kyber512Draft00 CurveID = FakeCurveX25519Kyber512Draft00 + X25519Kyber768Draft00Old CurveID = FakeCurveX25519Kyber768Draft00Old + P256Kyber768Draft00 CurveID = FakeCurveP256Kyber768Draft00 +) + // Other things const ( fakeRecordSizeLimit uint16 = 0x001c diff --git a/u_conn.go b/u_conn.go index 3372947..d39fca8 100644 --- a/u_conn.go +++ b/u_conn.go @@ -15,8 +15,6 @@ import ( "hash" "net" "strconv" - - "github.com/cloudflare/circl/kem/mlkem/mlkem768" ) type ClientHelloBuildStatus int @@ -115,15 +113,7 @@ func (uconn *UConn) buildHandshakeState(loadSession bool) error { } uconn.HandshakeState.Hello = hello.getPublicPtr() - if keySharePrivate.ecdhe != nil { - uconn.HandshakeState.State13.EcdheKey = keySharePrivate.ecdhe - } else if keySharePrivate.kyber != nil { - kemPrivKey := &mlkem768.PrivateKey{} - kemPrivKey.Unpack(keySharePrivate.kyber.EncapsulationKey()) - uconn.HandshakeState.State13.KEMKey = &KemPrivateKey{CurveID: keySharePrivate.curveID, SecretKey: kemPrivKey} - } else { - return fmt.Errorf("uTLS: unknown keySharePrivate type: %T", keySharePrivate) - } + uconn.HandshakeState.State13.KeyShareKeys = keySharePrivate.ToPublic() uconn.HandshakeState.C = uconn.Conn uconn.clientHelloBuildStatus = BuildByGoTLS } else { diff --git a/u_parrots.go b/u_parrots.go index 8f35962..5238b9e 100644 --- a/u_parrots.go +++ b/u_parrots.go @@ -18,6 +18,7 @@ import ( "strconv" "github.com/refraction-networking/utls/dicttls" + "github.com/refraction-networking/utls/internal/mlkem768" ) var ErrUnknownClientHelloID = errors.New("tls: unknown ClientHelloID") @@ -2732,35 +2733,52 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error { continue } - if scheme := curveIdToCirclScheme(curveID); scheme != nil { - pk, sk, err := generateKemKeyPair(scheme, curveID, uconn.config.rand()) + if curveID == x25519Kyber768Draft00 { + ecdheKey, err := generateECDHEKey(uconn.config.rand(), X25519) if err != nil { - return fmt.Errorf("HRR generateKemKeyPair %s: %w", - scheme.Name(), err) + return err } - packedPk, err := pk.MarshalBinary() + seed := make([]byte, mlkem768.SeedSize) + if _, err := io.ReadFull(uconn.config.rand(), seed); err != nil { + return err + } + kyberKey, err := mlkem768.NewKeyFromSeed(seed) if err != nil { - return fmt.Errorf("HRR pack circl public key %s: %w", - scheme.Name(), err) + return err } - uconn.HandshakeState.State13.KeySharesParams.AddKemKeypair(curveID, sk.secretKey, pk) - ext.KeyShares[i].Data = packedPk + + circlKyberKey, err := kyberGoToCircl(kyberKey, ecdheKey) + if err != nil { + return err + } + uconn.HandshakeState.State13.KeySharesParams.AddKemKeypair(curveID, circlKyberKey, circlKyberKey.Public()) + + ext.KeyShares[i].Data = append(ecdheKey.PublicKey().Bytes(), kyberKey.EncapsulationKey()...) if !preferredCurveIsSet { // only do this once for the first non-grease curve - uconn.HandshakeState.State13.KEMKey = sk.ToPublic() + uconn.HandshakeState.State13.KeyShareKeys.kyber = kyberKey preferredCurveIsSet = true } + + if len(ext.KeyShares) > i+1 && ext.KeyShares[i+1].Group == X25519 { + // Reuse the same X25519 ephemeral key for both keyshares, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2. + uconn.HandshakeState.State13.KeyShareKeys.Ecdhe = ecdheKey + uconn.HandshakeState.State13.KeySharesParams.AddEcdheKeypair(curveID, ecdheKey, ecdheKey.PublicKey()) + ext.KeyShares[i+1].Data = ecdheKey.PublicKey().Bytes() + } } else { ecdheKey, err := generateECDHEKey(uconn.config.rand(), curveID) if err != nil { return fmt.Errorf("unsupported Curve in KeyShareExtension: %v."+ "To mimic it, fill the Data(key) field manually", curveID) } + uconn.HandshakeState.State13.KeySharesParams.AddEcdheKeypair(curveID, ecdheKey, ecdheKey.PublicKey()) + ext.KeyShares[i].Data = ecdheKey.PublicKey().Bytes() if !preferredCurveIsSet { // only do this once for the first non-grease curve - uconn.HandshakeState.State13.EcdheKey = ecdheKey + uconn.HandshakeState.State13.KeyShareKeys.Ecdhe = ecdheKey preferredCurveIsSet = true } } diff --git a/u_public.go b/u_public.go index 2d4f11c..58fd85c 100644 --- a/u_public.go +++ b/u_public.go @@ -41,6 +41,7 @@ type PubClientHandshakeState struct { // TLS 1.3 only type TLS13OnlyState struct { + // Deprecated: Use KeyShareKeys instead. EcdheKey *ecdh.PrivateKey KeySharesParams *KeySharesParameters KEMKey *KemPrivateKey @@ -111,6 +112,10 @@ func (chs *TLS13OnlyState) private13KeyShareKeys() *keySharePrivateKeys { return nil } +func kyberGoToCircl(kyberKey *mlkem768.DecapsulationKey, ecdhKey *ecdh.PrivateKey) (kem.PrivateKey, error) { + return hybrid.Kyber768X25519().UnmarshalBinaryPrivateKey(append(ecdhKey.Bytes(), kyberKey.Bytes()...)) +} + func (ksp *keySharePrivateKeys) publicKEMKey() *KemPrivateKey { if ksp.kyber != nil && ksp.ecdhe != nil && ksp.curveID == x25519Kyber768Draft00 { key := append(ksp.ecdhe.Bytes(), ksp.kyber.Bytes()...) @@ -160,6 +165,7 @@ func (chs13 *clientHandshakeStateTLS13) toPublic13() *PubClientHandshakeState { KeySharesParams: chs13.keySharesParams, EcdheKey: chs13.keyShareKeys.ecdhe, KEMKey: chs13.keyShareKeys.publicKEMKey(), + KeyShareKeys: chs13.keyShareKeys.ToPublic(), EarlySecret: chs13.earlySecret, BinderKey: chs13.binderKey, CertReq: chs13.certReq.toPublic(), @@ -863,6 +869,11 @@ func (TKS TicketKeys) ToPrivate() []ticketKey { return tks } +type kemPrivateKey struct { + secretKey kem.PrivateKey + curveID CurveID +} + type KemPrivateKey struct { SecretKey kem.PrivateKey CurveID CurveID @@ -893,14 +904,14 @@ func (kpk *kemPrivateKey) ToPublic() *KemPrivateKey { type KeySharePrivateKeys struct { CurveID CurveID Ecdhe *ecdh.PrivateKey - Kyber *mlkem768.DecapsulationKey + kyber *mlkem768.DecapsulationKey } func (ksp *KeySharePrivateKeys) ToPrivate() *keySharePrivateKeys { return &keySharePrivateKeys{ curveID: ksp.CurveID, ecdhe: ksp.Ecdhe, - kyber: ksp.Kyber, + kyber: ksp.kyber, } } @@ -908,6 +919,6 @@ func (ksp *keySharePrivateKeys) ToPublic() *KeySharePrivateKeys { return &KeySharePrivateKeys{ CurveID: ksp.curveID, Ecdhe: ksp.ecdhe, - Kyber: ksp.kyber, + kyber: ksp.kyber, } }