From ff854c45c6158e853eb966550300587d80a7265d Mon Sep 17 00:00:00 2001 From: Mingye Chen Date: Thu, 9 Jan 2025 20:29:38 -0700 Subject: [PATCH] fix: broken code during merge --- handshake_client_tls13.go | 8 +++---- u_public.go | 50 +++++++++++++++++++++++++++------------ 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index fd2d6bd..c84bd16 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -124,14 +124,14 @@ func (hs *clientHandshakeStateTLS13) handshake() error { // set echdheParams to what we received from server if ecdheKey, ok := hs.keySharesParams.GetEcdheKey(hs.serverHello.serverShare.group); ok { hs.keyShareKeys.ecdhe = ecdheKey - hs.keyShareKeys.kyber = nil // unset kyber if any + hs.keyShareKeys.curveID = hs.serverHello.serverShare.group } // set kemParams to what we received from server if kemKey, ok := hs.keySharesParams.GetKemKey(hs.serverHello.serverShare.group); ok { - if gokey, err := mlkemCirclToGo(kemKey); err == nil { - hs.keyShareKeys.kyber = gokey + if kyberKey, ecdhKey, err := mlkemCirclToGo(kemKey); err == nil { + hs.keyShareKeys.kyber = kyberKey + hs.keyShareKeys.ecdhe = ecdhKey hs.keyShareKeys.curveID = hs.serverHello.serverShare.group - hs.keyShareKeys.ecdhe = nil // unset ecdheKey if any } } // [uTLS SECTION END] diff --git a/u_public.go b/u_public.go index 168fb9f..2d4f11c 100644 --- a/u_public.go +++ b/u_public.go @@ -8,11 +8,12 @@ import ( "crypto" "crypto/ecdh" "crypto/x509" + "fmt" "hash" "time" "github.com/cloudflare/circl/kem" - "github.com/cloudflare/circl/kem/kyber/kyber768" + "github.com/cloudflare/circl/kem/hybrid" "github.com/refraction-networking/utls/internal/mlkem768" ) @@ -40,7 +41,6 @@ type PubClientHandshakeState struct { // TLS 1.3 only type TLS13OnlyState struct { - // Deprecated: Use KeyShareKeys instead. EcdheKey *ecdh.PrivateKey KeySharesParams *KeySharesParameters KEMKey *KemPrivateKey @@ -62,13 +62,30 @@ type TLS12OnlyState struct { Suite PubCipherSuite } -func mlkemCirclToGo(circlKey kem.PrivateKey) (*mlkem768.DecapsulationKey, error) { - encodedKey, err := circlKey.MarshalBinary() - if err != nil { - return nil, err +func mlkemCirclToGo(circlKey kem.PrivateKey) (*mlkem768.DecapsulationKey, *ecdh.PrivateKey, error) { + if circlKey.Scheme().Name() != "Kyber768-X25519" { + return nil, nil, fmt.Errorf("circl key is not Kyber768-X25519") } - return mlkem768.NewKeyFromExtendedEncoding(encodedKey) + encodedKey, err := circlKey.MarshalBinary() + if err != nil { + return nil, nil, err + } + + ecdhKey := encodedKey[:x25519PublicKeySize] + kyberKey := encodedKey[x25519PublicKeySize:] + + goKyberkey, err := mlkem768.NewKeyFromExtendedEncoding(kyberKey) + if err != nil { + return nil, nil, err + } + + goEcdhKey, err := ecdh.X25519().NewPrivateKey(ecdhKey) + if err != nil { + return nil, nil, err + } + + return goKyberkey, goEcdhKey, nil } func (chs *TLS13OnlyState) private13KeyShareKeys() *keySharePrivateKeys { @@ -77,9 +94,10 @@ func (chs *TLS13OnlyState) private13KeyShareKeys() *keySharePrivateKeys { } if chs.KEMKey != nil { - if key, err := mlkemCirclToGo(chs.KEMKey.SecretKey); err == nil { + if kyberKey, ecdhKey, err := mlkemCirclToGo(chs.KEMKey.SecretKey); err == nil { return &keySharePrivateKeys{ - kyber: key, + kyber: kyberKey, + ecdhe: ecdhKey, } } } @@ -94,8 +112,9 @@ func (chs *TLS13OnlyState) private13KeyShareKeys() *keySharePrivateKeys { } func (ksp *keySharePrivateKeys) publicKEMKey() *KemPrivateKey { - if ksp.kyber != nil { - if privkey, err := kyber768.Scheme().UnmarshalBinaryPrivateKey(ksp.kyber.Bytes()); err == nil { + if ksp.kyber != nil && ksp.ecdhe != nil && ksp.curveID == x25519Kyber768Draft00 { + key := append(ksp.ecdhe.Bytes(), ksp.kyber.Bytes()...) + if privkey, err := hybrid.Kyber768X25519().UnmarshalBinaryPrivateKey(key); err == nil { return &KemPrivateKey{ SecretKey: privkey, } @@ -110,10 +129,11 @@ func (chs *PubClientHandshakeState) toPrivate13() *clientHandshakeStateTLS13 { return nil } else { return &clientHandshakeStateTLS13{ - c: chs.C, - serverHello: chs.ServerHello.getPrivatePtr(), - hello: chs.Hello.getPrivatePtr(), - keyShareKeys: chs.State13.private13KeyShareKeys(), + c: chs.C, + serverHello: chs.ServerHello.getPrivatePtr(), + hello: chs.Hello.getPrivatePtr(), + keyShareKeys: chs.State13.private13KeyShareKeys(), + keySharesParams: chs.State13.KeySharesParams, session: chs.Session, earlySecret: chs.State13.EarlySecret,