diff --git a/handshake_client.go b/handshake_client.go index 1c14476..ea9c4c5 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -10,9 +10,9 @@ import ( "crypto" "crypto/ecdsa" "crypto/ed25519" + "crypto/internal/fips/mlkem" "crypto/internal/fips/tls13" "crypto/internal/hpke" - "crypto/internal/mlkem768" "crypto/rsa" "crypto/subtle" "crypto/x509" @@ -160,11 +160,11 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon if err != nil { return nil, nil, nil, err } - seed := make([]byte, mlkem768.SeedSize) + seed := make([]byte, mlkem.SeedSize) if _, err := io.ReadFull(config.rand(), seed); err != nil { return nil, nil, nil, err } - keyShareKeys.kyber, err = mlkem768.NewDecapsulationKey768(seed) + keyShareKeys.kyber, err = mlkem.NewDecapsulationKey768(seed) if err != nil { return nil, nil, nil, err } diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index fbec743..6ce83b9 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -10,8 +10,8 @@ import ( "crypto" "crypto/hmac" "crypto/internal/fips/hkdf" + "crypto/internal/fips/mlkem" "crypto/internal/fips/tls13" - "crypto/internal/mlkem768" "crypto/rsa" "crypto/subtle" "errors" @@ -481,7 +481,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { ecdhePeerData := hs.serverHello.serverShare.data if hs.serverHello.serverShare.group == x25519Kyber768Draft00 { - if len(ecdhePeerData) != x25519PublicKeySize+mlkem768.CiphertextSize768 { + if len(ecdhePeerData) != x25519PublicKeySize+mlkem.CiphertextSize768 { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid server key share") } diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 3591aa1..aa1ffd9 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -9,8 +9,8 @@ import ( "context" "crypto" "crypto/hmac" + "crypto/internal/fips/mlkem" "crypto/internal/fips/tls13" - "crypto/internal/mlkem768" "crypto/rsa" "errors" "hash" @@ -223,7 +223,7 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error { ecdhData := clientKeyShare.data if selectedGroup == x25519Kyber768Draft00 { ecdhGroup = X25519 - if len(ecdhData) != x25519PublicKeySize+mlkem768.EncapsulationKeySize768 { + if len(ecdhData) != x25519PublicKeySize+mlkem.EncapsulationKeySize768 { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid Kyber client key share") } diff --git a/key_schedule.go b/key_schedule.go index 8377807..99229ea 100644 --- a/key_schedule.go +++ b/key_schedule.go @@ -7,13 +7,12 @@ package tls import ( "crypto/ecdh" "crypto/hmac" + "crypto/internal/fips/mlkem" + "crypto/internal/fips/sha3" "crypto/internal/fips/tls13" - "crypto/internal/mlkem768" "errors" "hash" "io" - - "golang.org/x/crypto/sha3" ) // This file contains the functions necessary to compute the TLS 1.3 key @@ -54,11 +53,11 @@ func (c *cipherSuiteTLS13) exportKeyingMaterial(s *tls13.MasterSecret, transcrip type keySharePrivateKeys struct { curveID CurveID ecdhe *ecdh.PrivateKey - kyber *mlkem768.DecapsulationKey768 + kyber *mlkem.DecapsulationKey768 } // kyberDecapsulate implements decapsulation according to Kyber Round 3. -func kyberDecapsulate(dk *mlkem768.DecapsulationKey768, c []byte) ([]byte, error) { +func kyberDecapsulate(dk *mlkem.DecapsulationKey768, c []byte) ([]byte, error) { K, err := dk.Decapsulate(c) if err != nil { return nil, err @@ -68,7 +67,7 @@ func kyberDecapsulate(dk *mlkem768.DecapsulationKey768, c []byte) ([]byte, error // kyberEncapsulate implements encapsulation according to Kyber Round 3. func kyberEncapsulate(ek []byte) (c, ss []byte, err error) { - k, err := mlkem768.NewEncapsulationKey768(ek) + k, err := mlkem.NewEncapsulationKey768(ek) if err != nil { return nil, nil, err } @@ -77,13 +76,14 @@ func kyberEncapsulate(ek []byte) (c, ss []byte, err error) { } func kyberSharedSecret(c, K []byte) []byte { - // Package mlkem768 implements ML-KEM, which compared to Kyber removed a + // Package mlkem implements ML-KEM, which compared to Kyber removed a // final hashing step. Compute SHAKE-256(K || SHA3-256(c), 32) to match Kyber. // See https://words.filippo.io/mlkem768/#bonus-track-using-a-ml-kem-implementation-as-kyber-v3. h := sha3.NewShake256() h.Write(K) - ch := sha3.Sum256(c) - h.Write(ch[:]) + ch := sha3.New256() + ch.Write(c) + h.Write(ch.Sum(nil)) out := make([]byte, 32) h.Read(out) return out diff --git a/key_schedule_test.go b/key_schedule_test.go index 766370f..0dc3601 100644 --- a/key_schedule_test.go +++ b/key_schedule_test.go @@ -6,8 +6,8 @@ package tls import ( "bytes" + "crypto/internal/fips/mlkem" "crypto/internal/fips/tls13" - "crypto/internal/mlkem768" "crypto/sha256" "encoding/hex" "strings" @@ -120,7 +120,7 @@ func TestTrafficKey(t *testing.T) { } func TestKyberEncapsulate(t *testing.T) { - dk, err := mlkem768.GenerateKey768() + dk, err := mlkem.GenerateKey768() if err != nil { t.Fatal(err) }