diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 53b5eba..99fd6ea 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -629,6 +629,15 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { } ecdhePeerData = hs.serverHello.serverShare.data[mlkem.CiphertextSize768:] } + // [uTLS] SECTION BEGIN + if hs.serverHello.serverShare.group == X25519Kyber768Draft00 { + if len(ecdhePeerData) != x25519PublicKeySize+mlkem.CiphertextSize768 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server X25519Kyber768Draft00 key share") + } + ecdhePeerData = hs.serverHello.serverShare.data[:x25519PublicKeySize] + } + // [uTLS] SECTION END peerKey, err := hs.keyShareKeys.ecdhe.Curve().NewPublicKey(ecdhePeerData) if err != nil { c.sendAlert(alertIllegalParameter) @@ -651,6 +660,20 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { } sharedKey = append(mlkemShared, sharedKey...) } + // [uTLS] SECTION BEGIN + if hs.serverHello.serverShare.group == X25519Kyber768Draft00 { + if hs.keyShareKeys.mlkem == nil { + return c.sendAlert(alertInternalError) + } + ciphertext := hs.serverHello.serverShare.data[x25519PublicKeySize:] + kyberShared, err := kyberDecapsulate(hs.keyShareKeys.mlkem, ciphertext) + if err != nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid X25519Kyber768Draft00 server key share") + } + sharedKey = append(sharedKey, kyberShared...) + } + // [uTLS] SECTION END c.curveID = hs.serverHello.serverShare.group earlySecret := hs.earlySecret diff --git a/u_common.go b/u_common.go index 860b925..d43849b 100644 --- a/u_common.go +++ b/u_common.go @@ -88,7 +88,7 @@ const ( ) const ( - // X25519Kyber768Draft00 CurveID = x25519Kyber768Draft00 + X25519Kyber768Draft00 CurveID = 0x6399 FakeCurveX25519Kyber512Draft00 CurveID = 0xfe30 FakeCurveX25519Kyber768Draft00Old CurveID = 0xfe31 diff --git a/u_key_schedule.go b/u_key_schedule.go new file mode 100644 index 0000000..f52839d --- /dev/null +++ b/u_key_schedule.go @@ -0,0 +1,30 @@ +package tls + +import ( + "crypto/mlkem" + + "golang.org/x/crypto/sha3" +) + +// kyberDecapsulate implements decapsulation according to Kyber Round 3. +func kyberDecapsulate(dk *mlkem.DecapsulationKey768, c []byte) ([]byte, error) { + K, err := dk.Decapsulate(c) + if err != nil { + return nil, err + } + return kyberSharedSecret(c, K), nil +} + +func kyberSharedSecret(c, K []byte) []byte { + // Package mlkem implements ML-KEM, which compared to Kyber removed a + // final hashing step. Compute SHAKE-256(K || SHA3-256(c), 32) to match Kyber. + // See https://words.filippo.io/mlkem768/#bonus-track-using-a-ml-kem-implementation-as-kyber-v3. + h := sha3.NewShake256() + h.Write(K) + ch := sha3.New256() + ch.Write(c) + h.Write(ch.Sum(nil)) + out := make([]byte, 32) + h.Read(out) + return out +} diff --git a/u_parrots.go b/u_parrots.go index 78b7f6a..69c2636 100644 --- a/u_parrots.go +++ b/u_parrots.go @@ -618,7 +618,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient}, &SupportedCurvesExtension{[]CurveID{ GREASE_PLACEHOLDER, - X25519MLKEM768, + X25519Kyber768Draft00, X25519, CurveP256, CurveP384, @@ -642,7 +642,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SCTExtension{}, &KeyShareExtension{[]KeyShare{ {Group: CurveID(GREASE_PLACEHOLDER), Data: []byte{0}}, - {Group: X25519MLKEM768}, + {Group: X25519Kyber768Draft00}, {Group: X25519}, }}, &PSKKeyExchangeModesExtension{[]uint8{ @@ -764,7 +764,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient}, &SupportedCurvesExtension{[]CurveID{ GREASE_PLACEHOLDER, - X25519MLKEM768, + X25519Kyber768Draft00, X25519, CurveP256, CurveP384, @@ -788,7 +788,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SCTExtension{}, &KeyShareExtension{[]KeyShare{ {Group: CurveID(GREASE_PLACEHOLDER), Data: []byte{0}}, - {Group: X25519MLKEM768}, + {Group: X25519Kyber768Draft00}, {Group: X25519}, }}, &PSKKeyExchangeModesExtension{[]uint8{ @@ -2495,7 +2495,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient}, &SupportedCurvesExtension{[]CurveID{ GREASE_PLACEHOLDER, - X25519MLKEM768, + X25519Kyber768Draft00, X25519, CurveP256, CurveP384, @@ -2519,7 +2519,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SCTExtension{}, &KeyShareExtension{[]KeyShare{ {Group: CurveID(GREASE_PLACEHOLDER), Data: []byte{0}}, - {Group: X25519MLKEM768}, + {Group: X25519Kyber768Draft00}, {Group: X25519}, }}, &PSKKeyExchangeModesExtension{[]uint8{ @@ -2736,7 +2736,7 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error { continue } - if curveID == X25519MLKEM768 { + if curveID == X25519MLKEM768 || curveID == X25519Kyber768Draft00 { ecdheKey, err := generateECDHEKey(uconn.config.rand(), X25519) if err != nil { return err @@ -2756,7 +2756,11 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error { // } // uconn.HandshakeState.State13.KeySharesParams.AddKemKeypair(curveID, circlKyberKey, circlKyberKey.Public()) - ext.KeyShares[i].Data = append(mlkemKey.EncapsulationKey().Bytes(), ecdheKey.PublicKey().Bytes()...) + if curveID == X25519Kyber768Draft00 { + ext.KeyShares[i].Data = append(ecdheKey.PublicKey().Bytes(), mlkemKey.EncapsulationKey().Bytes()...) + } else { + ext.KeyShares[i].Data = append(mlkemKey.EncapsulationKey().Bytes(), ecdheKey.PublicKey().Bytes()...) + } if !preferredCurveIsSet { // only do this once for the first non-grease curve uconn.HandshakeState.State13.KeyShareKeys.mlkem = mlkemKey