crypto/tls: add new X25519Kyber768Draft00 code point

Ported from cloudflare/go to support the upcoming new post-quantum keyshare.

----

* Point tls.X25519Kyber768Draft00 to the new 0x6399 identifier while the
  old 0xfe31 identifier is available as tls.X25519Kyber768Draft00Old.
* Make sure that the kem.PrivateKey can always be mapped to the CurveID
  that was linked to it. This is needed since we now have two ID
  aliasing to the same scheme, and clients need to be able to detect
  whether the key share presented by the server actually matches the key
  share that the client originally sent.
* Update tests, add the new identifier and remove unnecessary code.

Link: https://mailarchive.ietf.org/arch/msg/tls/HAWpNpgptl--UZNSYuvsjB-Pc2k/
Link: https://datatracker.ietf.org/doc/draft-tls-westerbaan-xyber768d00/02/
Co-Authored-By: Peter Wu <peter@lekensteyn.nl>
Co-Authored-By: Bas Westerbaan <bas@westerbaan.name>
This commit is contained in:
Gaukas Wang 2023-08-12 16:19:18 -06:00
parent c4c1fcb925
commit bc2a3c2388
No known key found for this signature in database
GPG key ID: 9E2F8986D76F8B5D
6 changed files with 91 additions and 91 deletions

View file

@ -6,13 +6,12 @@
// To enable set CurvePreferences with the desired scheme as the first element:
//
// import (
// "github.com/cloudflare/circl/kem/tls"
// "github.com/cloudflare/circl/kem/hybrid"
// "crypto/tls"
//
// [...]
//
// config.CurvePreferences = []tls.CurveID{
// hybrid.X25519Kyber512Draft00().(tls.TLSScheme).TLSCurveID(),
// tls.X25519Kyber768Draft00,
// tls.X25519,
// tls.P256,
// }
@ -29,38 +28,27 @@ import (
"github.com/cloudflare/circl/kem/hybrid"
)
// Either ecdheParameters or kem.PrivateKey
// Either *ecdh.PrivateKey or *kemPrivateKey
type clientKeySharePrivate interface{}
var (
X25519Kyber512Draft00 = CurveID(0xfe30)
X25519Kyber768Draft00 = CurveID(0xfe31)
P256Kyber768Draft00 = CurveID(0xfe32)
invalidCurveID = CurveID(0)
)
func kemSchemeKeyToCurveID(s kem.Scheme) CurveID {
switch s.Name() {
case "Kyber512-X25519":
return X25519Kyber512Draft00
case "Kyber768-X25519":
return X25519Kyber768Draft00
case "P256Kyber768Draft00":
return P256Kyber768Draft00
default:
return invalidCurveID
}
type kemPrivateKey struct {
secretKey kem.PrivateKey
curveID CurveID
}
var (
X25519Kyber512Draft00 = CurveID(0xfe30)
X25519Kyber768Draft00 = CurveID(0x6399)
X25519Kyber768Draft00Old = CurveID(0xfe31)
P256Kyber768Draft00 = CurveID(0xfe32)
invalidCurveID = CurveID(0)
)
// Extract CurveID from clientKeySharePrivate
func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID {
switch v := ks.(type) {
case kem.PrivateKey:
ret := kemSchemeKeyToCurveID(v.Scheme())
if ret == invalidCurveID {
panic("cfkem: internal error: don't know CurveID for this KEM")
}
return ret
case *kemPrivateKey:
return v.curveID
case *ecdh.PrivateKey:
ret, ok := curveIDForCurve(v.Curve())
if !ok {
@ -77,7 +65,7 @@ func curveIdToCirclScheme(id CurveID) kem.Scheme {
switch id {
case X25519Kyber512Draft00:
return hybrid.Kyber512X25519()
case X25519Kyber768Draft00:
case X25519Kyber768Draft00, X25519Kyber768Draft00Old:
return hybrid.Kyber768X25519()
case P256Kyber768Draft00:
return hybrid.P256Kyber768Draft00()
@ -102,12 +90,12 @@ func encapsulateForKem(scheme kem.Scheme, rnd io.Reader, ppk []byte) (
}
// Generate a new keypair using randomness from rnd.
func generateKemKeyPair(scheme kem.Scheme, rnd io.Reader) (
kem.PublicKey, kem.PrivateKey, error) {
func generateKemKeyPair(scheme kem.Scheme, curveID CurveID, rnd io.Reader) (
kem.PublicKey, *kemPrivateKey, error) {
seed := make([]byte, scheme.SeedSize())
if _, err := io.ReadFull(rnd, seed); err != nil {
return nil, nil, err
}
pk, sk := scheme.DeriveKeyPair(seed)
return pk, sk, nil
return pk, &kemPrivateKey{sk, curveID}, nil
}

View file

@ -7,28 +7,16 @@ import (
"context"
"fmt"
"testing"
"github.com/cloudflare/circl/kem"
"github.com/cloudflare/circl/kem/hybrid"
)
func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,
func testHybridKEX(t *testing.T, curveID CurveID, clientPQ, serverPQ,
clientTLS12, serverTLS12 bool) {
// var clientSelectedKEX *CurveID
// var retry bool
rsaCert := Certificate{
Certificate: [][]byte{testRSACertificate},
PrivateKey: testRSAPrivateKey,
}
serverCerts := []Certificate{rsaCert}
clientConfig := testConfig.Clone()
if clientPQ {
clientConfig.CurvePreferences = []CurveID{
kemSchemeKeyToCurveID(scheme),
X25519,
}
clientConfig.CurvePreferences = []CurveID{curveID, X25519}
}
// clientCFEventHandler := func(ev CFEvent) {
// switch e := ev.(type) {
@ -44,15 +32,13 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,
serverConfig := testConfig.Clone()
if serverPQ {
serverConfig.CurvePreferences = []CurveID{
kemSchemeKeyToCurveID(scheme),
X25519,
}
serverConfig.CurvePreferences = []CurveID{curveID, X25519}
} else {
serverConfig.CurvePreferences = []CurveID{X25519}
}
if serverTLS12 {
serverConfig.MaxVersion = VersionTLS12
}
serverConfig.Certificates = serverCerts
c, s := localPipe(t)
done := make(chan error)
@ -78,7 +64,7 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,
// var expectedRetry bool
// if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 {
// expectedKEX = kemSchemeKeyToCurveID(scheme)
// expectedKEX = curveID
// } else {
// expectedKEX = X25519
// }
@ -86,36 +72,36 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,
// expectedRetry = true
// }
// if clientSelectedKEX == nil {
// t.Error("No KEX happened?")
// }
// if *clientSelectedKEX != expectedKEX {
// t.Errorf("failed to negotiate: expected %d, got %d",
// expectedKEX, *clientSelectedKEX)
// }
// if expectedRetry != retry {
// t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry)
// }
// if clientSelectedKEX == nil {
// t.Error("No KEX happened?")
// } else if *clientSelectedKEX != expectedKEX {
// t.Errorf("failed to negotiate: expected %d, got %d",
// expectedKEX, *clientSelectedKEX)
// }
}
func TestHybridKEX(t *testing.T) {
run := func(scheme kem.Scheme, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) {
t.Run(fmt.Sprintf("%s serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", scheme.Name(),
run := func(curveID CurveID, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) {
t.Run(fmt.Sprintf("%#04x serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", uint16(curveID),
serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) {
testHybridKEX(t, scheme, clientPQ, serverPQ, clientTLS12, serverTLS12)
testHybridKEX(t, curveID, clientPQ, serverPQ, clientTLS12, serverTLS12)
})
}
for _, scheme := range []kem.Scheme{
hybrid.Kyber512X25519(),
hybrid.Kyber768X25519(),
hybrid.P256Kyber768Draft00(),
for _, curveID := range []CurveID{
X25519Kyber512Draft00,
X25519Kyber768Draft00,
X25519Kyber768Draft00Old,
P256Kyber768Draft00,
} {
run(scheme, true, true, false, false)
run(scheme, true, false, false, false)
run(scheme, false, true, false, false)
run(scheme, true, true, true, false)
run(scheme, true, true, false, true)
run(scheme, true, true, true, true)
run(curveID, true, true, false, false)
run(curveID, true, false, false, false)
run(curveID, false, true, false, false)
run(curveID, true, true, true, false)
run(curveID, true, true, false, true)
run(curveID, true, true, true, true)
}
}

View file

@ -22,7 +22,6 @@ import (
"strings"
"time"
circlKem "github.com/cloudflare/circl/kem"
circlSign "github.com/cloudflare/circl/sign"
)
@ -152,8 +151,10 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
}
curveID := config.curvePreferences()[0]
// [UTLS SECTION BEGINS]
// Ported from cloudflare/go with modifications to preserve crypto/tls compatibility
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
pk, sk, err := generateKemKeyPair(scheme, config.rand())
pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand())
if err != nil {
return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", scheme.Name(), err)
}
@ -174,6 +175,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
secret = key
}
// [UTLS SECTION ENDS]
}
if c.quic != nil {
@ -283,7 +285,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok {
hs.ecdheKey = ecdheKey
} else if kemKey, ok := keySharePrivate.(circlKem.PrivateKey); ok {
} else if kemKey, ok := keySharePrivate.(*kemPrivateKey); ok {
hs.kemKey = kemKey
} else {
return fmt.Errorf("tls: unknown key share type %T", keySharePrivate)

View file

@ -15,8 +15,6 @@ import (
"fmt"
"hash"
"time"
circlKem "github.com/cloudflare/circl/kem"
)
// [uTLS SECTION START]
@ -43,7 +41,7 @@ type clientHandshakeStateTLS13 struct {
hello *clientHelloMsg
ecdheKey *ecdh.PrivateKey
keySharesEcdheParams KeySharesEcdheParameters // [uTLS]
kemKey circlKem.PrivateKey // [uTLS]
kemKey *kemPrivateKey // [uTLS]
// keySharesCirclParams KeySharesCirclParameters // [uTLS] TODO: perhaps implement?
session *SessionState
@ -291,7 +289,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
}
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
pk, sk, err := generateKemKeyPair(scheme, c.config.rand())
pk, sk, err := generateKemKeyPair(scheme, curveID, c.config.rand())
if err != nil {
c.sendAlert(alertInternalError)
return fmt.Errorf("HRR generateKemKeyPair %s: %w",
@ -534,10 +532,11 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
return errors.New("tls: invalid server key share")
}
} else if hs.kemKey != nil {
sharedKey, err = hs.kemKey.Scheme().Decapsulate(hs.kemKey, hs.serverHello.serverShare.data)
sk := hs.kemKey.secretKey
sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data)
if err != nil {
c.sendAlert(alertIllegalParameter)
return fmt.Errorf("%s decaps: %w", hs.kemKey.Scheme().Name(), err)
return fmt.Errorf("%s decaps: %w", sk.Scheme().Name(), err)
}
} else {
c.sendAlert(alertInternalError)

View file

@ -17,8 +17,6 @@ import (
"io"
"net"
"strconv"
circlKem "github.com/cloudflare/circl/kem"
)
type UConn struct {
@ -87,10 +85,10 @@ func (uconn *UConn) BuildHandshakeState() error {
uconn.HandshakeState.Hello = hello.getPublicPtr()
if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok {
uconn.HandshakeState.State13.EcdheKey = ecdheKey
} else if kemKey, ok := keySharePrivate.(circlKem.PrivateKey); ok {
uconn.HandshakeState.State13.KEMKey = kemKey
} else if kemKey, ok := keySharePrivate.(*kemPrivateKey); ok {
uconn.HandshakeState.State13.KEMKey = kemKey.ToPublic()
} else {
return fmt.Errorf("uTLS: unknown keySharePrivate type: %T", keySharePrivate)
}
uconn.HandshakeState.C = uconn.Conn
} else {

View file

@ -11,7 +11,7 @@ import (
"hash"
"time"
circlKem "github.com/cloudflare/circl/kem"
"github.com/cloudflare/circl/kem"
)
// ClientHandshakeState includes both TLS 1.3-only and TLS 1.2-only states,
@ -41,7 +41,7 @@ type TLS13OnlyState struct {
Suite *PubCipherSuiteTLS13
EcdheKey *ecdh.PrivateKey
KeySharesEcdheParams KeySharesEcdheParameters
KEMKey circlKem.PrivateKey
KEMKey *KemPrivateKey
EarlySecret []byte
BinderKey []byte
CertReq *CertificateRequestMsgTLS13
@ -67,7 +67,7 @@ func (chs *PubClientHandshakeState) toPrivate13() *clientHandshakeStateTLS13 {
hello: chs.Hello.getPrivatePtr(),
ecdheKey: chs.State13.EcdheKey,
keySharesEcdheParams: chs.State13.KeySharesEcdheParams,
kemKey: chs.State13.KEMKey,
kemKey: chs.State13.KEMKey.ToPrivate(),
session: chs.Session,
earlySecret: chs.State13.EarlySecret,
@ -93,7 +93,7 @@ func (chs13 *clientHandshakeStateTLS13) toPublic13() *PubClientHandshakeState {
tls13State := TLS13OnlyState{
KeySharesEcdheParams: chs13.keySharesEcdheParams,
EcdheKey: chs13.ecdheKey,
KEMKey: chs13.kemKey,
KEMKey: chs13.kemKey.ToPublic(),
EarlySecret: chs13.earlySecret,
BinderKey: chs13.binderKey,
CertReq: chs13.certReq.toPublic(),
@ -739,3 +739,30 @@ func (TKS TicketKeys) ToPrivate() []ticketKey {
}
return tks
}
type KemPrivateKey struct {
SecretKey kem.PrivateKey
CurveID CurveID
}
func (kpk *KemPrivateKey) ToPrivate() *kemPrivateKey {
if kpk == nil {
return nil
} else {
return &kemPrivateKey{
secretKey: kpk.SecretKey,
curveID: kpk.CurveID,
}
}
}
func (kpk *kemPrivateKey) ToPublic() *KemPrivateKey {
if kpk == nil {
return nil
} else {
return &KemPrivateKey{
SecretKey: kpk.secretKey,
CurveID: kpk.curveID,
}
}
}