diff --git a/handshake_client.go b/handshake_client.go index f6bccc4..8965ad6 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -164,7 +164,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon if _, err := io.ReadFull(config.rand(), seed); err != nil { return nil, nil, nil, err } - keyShareKeys.kyber, err = mlkem768.NewKeyFromSeed(seed) + keyShareKeys.kyber, err = mlkem768.NewDecapsulationKey(seed) if err != nil { return nil, nil, nil, err } @@ -174,7 +174,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon // both, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2. hello.keyShares = []keyShare{ {group: x25519Kyber768Draft00, data: append(keyShareKeys.ecdhe.PublicKey().Bytes(), - keyShareKeys.kyber.EncapsulationKey()...)}, + keyShareKeys.kyber.EncapsulationKey().Bytes()...)}, {group: X25519, data: keyShareKeys.ecdhe.PublicKey().Bytes()}, } } else { diff --git a/key_schedule.go b/key_schedule.go index e8ee9ce..3bbfc1b 100644 --- a/key_schedule.go +++ b/key_schedule.go @@ -63,19 +63,20 @@ func kyberDecapsulate(dk *mlkem768.DecapsulationKey, c []byte) ([]byte, error) { if err != nil { return nil, err } - return kyberSharedSecret(K, c), nil + return kyberSharedSecret(c, K), nil } // kyberEncapsulate implements encapsulation according to Kyber Round 3. func kyberEncapsulate(ek []byte) (c, ss []byte, err error) { - c, ss, err = mlkem768.Encapsulate(ek) + k, err := mlkem768.NewEncapsulationKey(ek) if err != nil { return nil, nil, err } - return c, kyberSharedSecret(ss, c), nil + c, ss = k.Encapsulate() + return c, kyberSharedSecret(c, ss), nil } -func kyberSharedSecret(K, c []byte) []byte { +func kyberSharedSecret(c, K []byte) []byte { // Package mlkem768 implements ML-KEM, which compared to Kyber removed a // final hashing step. Compute SHAKE-256(K || SHA3-256(c), 32) to match Kyber. // See https://words.filippo.io/mlkem768/#bonus-track-using-a-ml-kem-implementation-as-kyber-v3. diff --git a/key_schedule_test.go b/key_schedule_test.go index 095113c..3253277 100644 --- a/key_schedule_test.go +++ b/key_schedule_test.go @@ -124,7 +124,7 @@ func TestKyberEncapsulate(t *testing.T) { if err != nil { t.Fatal(err) } - ct, ss, err := kyberEncapsulate(dk.EncapsulationKey()) + ct, ss, err := kyberEncapsulate(dk.EncapsulationKey().Bytes()) if err != nil { t.Fatal(err) }