crypto/tls: add new X25519Kyber768Draft00 code point

Ported from cloudflare/go to support the upcoming new post-quantum keyshare.

----

* Point tls.X25519Kyber768Draft00 to the new 0x6399 identifier while the
  old 0xfe31 identifier is available as tls.X25519Kyber768Draft00Old.
* Make sure that the kem.PrivateKey can always be mapped to the CurveID
  that was linked to it. This is needed since we now have two ID
  aliasing to the same scheme, and clients need to be able to detect
  whether the key share presented by the server actually matches the key
  share that the client originally sent.
* Update tests, add the new identifier and remove unnecessary code.

Link: https://mailarchive.ietf.org/arch/msg/tls/HAWpNpgptl--UZNSYuvsjB-Pc2k/
Link: https://datatracker.ietf.org/doc/draft-tls-westerbaan-xyber768d00/02/
Co-Authored-By: Peter Wu <peter@lekensteyn.nl>
Co-Authored-By: Bas Westerbaan <bas@westerbaan.name>
This commit is contained in:
Gaukas Wang 2023-08-12 16:19:18 -06:00
parent c4c1fcb925
commit bc2a3c2388
No known key found for this signature in database
GPG key ID: 9E2F8986D76F8B5D
6 changed files with 91 additions and 91 deletions

View file

@ -6,13 +6,12 @@
// To enable set CurvePreferences with the desired scheme as the first element: // To enable set CurvePreferences with the desired scheme as the first element:
// //
// import ( // import (
// "github.com/cloudflare/circl/kem/tls" // "crypto/tls"
// "github.com/cloudflare/circl/kem/hybrid"
// //
// [...] // [...]
// //
// config.CurvePreferences = []tls.CurveID{ // config.CurvePreferences = []tls.CurveID{
// hybrid.X25519Kyber512Draft00().(tls.TLSScheme).TLSCurveID(), // tls.X25519Kyber768Draft00,
// tls.X25519, // tls.X25519,
// tls.P256, // tls.P256,
// } // }
@ -29,38 +28,27 @@ import (
"github.com/cloudflare/circl/kem/hybrid" "github.com/cloudflare/circl/kem/hybrid"
) )
// Either ecdheParameters or kem.PrivateKey // Either *ecdh.PrivateKey or *kemPrivateKey
type clientKeySharePrivate interface{} type clientKeySharePrivate interface{}
var ( type kemPrivateKey struct {
X25519Kyber512Draft00 = CurveID(0xfe30) secretKey kem.PrivateKey
X25519Kyber768Draft00 = CurveID(0xfe31) curveID CurveID
P256Kyber768Draft00 = CurveID(0xfe32)
invalidCurveID = CurveID(0)
)
func kemSchemeKeyToCurveID(s kem.Scheme) CurveID {
switch s.Name() {
case "Kyber512-X25519":
return X25519Kyber512Draft00
case "Kyber768-X25519":
return X25519Kyber768Draft00
case "P256Kyber768Draft00":
return P256Kyber768Draft00
default:
return invalidCurveID
}
} }
var (
X25519Kyber512Draft00 = CurveID(0xfe30)
X25519Kyber768Draft00 = CurveID(0x6399)
X25519Kyber768Draft00Old = CurveID(0xfe31)
P256Kyber768Draft00 = CurveID(0xfe32)
invalidCurveID = CurveID(0)
)
// Extract CurveID from clientKeySharePrivate // Extract CurveID from clientKeySharePrivate
func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID { func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID {
switch v := ks.(type) { switch v := ks.(type) {
case kem.PrivateKey: case *kemPrivateKey:
ret := kemSchemeKeyToCurveID(v.Scheme()) return v.curveID
if ret == invalidCurveID {
panic("cfkem: internal error: don't know CurveID for this KEM")
}
return ret
case *ecdh.PrivateKey: case *ecdh.PrivateKey:
ret, ok := curveIDForCurve(v.Curve()) ret, ok := curveIDForCurve(v.Curve())
if !ok { if !ok {
@ -77,7 +65,7 @@ func curveIdToCirclScheme(id CurveID) kem.Scheme {
switch id { switch id {
case X25519Kyber512Draft00: case X25519Kyber512Draft00:
return hybrid.Kyber512X25519() return hybrid.Kyber512X25519()
case X25519Kyber768Draft00: case X25519Kyber768Draft00, X25519Kyber768Draft00Old:
return hybrid.Kyber768X25519() return hybrid.Kyber768X25519()
case P256Kyber768Draft00: case P256Kyber768Draft00:
return hybrid.P256Kyber768Draft00() return hybrid.P256Kyber768Draft00()
@ -102,12 +90,12 @@ func encapsulateForKem(scheme kem.Scheme, rnd io.Reader, ppk []byte) (
} }
// Generate a new keypair using randomness from rnd. // Generate a new keypair using randomness from rnd.
func generateKemKeyPair(scheme kem.Scheme, rnd io.Reader) ( func generateKemKeyPair(scheme kem.Scheme, curveID CurveID, rnd io.Reader) (
kem.PublicKey, kem.PrivateKey, error) { kem.PublicKey, *kemPrivateKey, error) {
seed := make([]byte, scheme.SeedSize()) seed := make([]byte, scheme.SeedSize())
if _, err := io.ReadFull(rnd, seed); err != nil { if _, err := io.ReadFull(rnd, seed); err != nil {
return nil, nil, err return nil, nil, err
} }
pk, sk := scheme.DeriveKeyPair(seed) pk, sk := scheme.DeriveKeyPair(seed)
return pk, sk, nil return pk, &kemPrivateKey{sk, curveID}, nil
} }

View file

@ -7,28 +7,16 @@ import (
"context" "context"
"fmt" "fmt"
"testing" "testing"
"github.com/cloudflare/circl/kem"
"github.com/cloudflare/circl/kem/hybrid"
) )
func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ, func testHybridKEX(t *testing.T, curveID CurveID, clientPQ, serverPQ,
clientTLS12, serverTLS12 bool) { clientTLS12, serverTLS12 bool) {
// var clientSelectedKEX *CurveID // var clientSelectedKEX *CurveID
// var retry bool // var retry bool
rsaCert := Certificate{
Certificate: [][]byte{testRSACertificate},
PrivateKey: testRSAPrivateKey,
}
serverCerts := []Certificate{rsaCert}
clientConfig := testConfig.Clone() clientConfig := testConfig.Clone()
if clientPQ { if clientPQ {
clientConfig.CurvePreferences = []CurveID{ clientConfig.CurvePreferences = []CurveID{curveID, X25519}
kemSchemeKeyToCurveID(scheme),
X25519,
}
} }
// clientCFEventHandler := func(ev CFEvent) { // clientCFEventHandler := func(ev CFEvent) {
// switch e := ev.(type) { // switch e := ev.(type) {
@ -44,15 +32,13 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,
serverConfig := testConfig.Clone() serverConfig := testConfig.Clone()
if serverPQ { if serverPQ {
serverConfig.CurvePreferences = []CurveID{ serverConfig.CurvePreferences = []CurveID{curveID, X25519}
kemSchemeKeyToCurveID(scheme), } else {
X25519, serverConfig.CurvePreferences = []CurveID{X25519}
}
} }
if serverTLS12 { if serverTLS12 {
serverConfig.MaxVersion = VersionTLS12 serverConfig.MaxVersion = VersionTLS12
} }
serverConfig.Certificates = serverCerts
c, s := localPipe(t) c, s := localPipe(t)
done := make(chan error) done := make(chan error)
@ -78,7 +64,7 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,
// var expectedRetry bool // var expectedRetry bool
// if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 { // if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 {
// expectedKEX = kemSchemeKeyToCurveID(scheme) // expectedKEX = curveID
// } else { // } else {
// expectedKEX = X25519 // expectedKEX = X25519
// } // }
@ -86,36 +72,36 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,
// expectedRetry = true // expectedRetry = true
// } // }
// if clientSelectedKEX == nil {
// t.Error("No KEX happened?")
// }
// if *clientSelectedKEX != expectedKEX {
// t.Errorf("failed to negotiate: expected %d, got %d",
// expectedKEX, *clientSelectedKEX)
// }
// if expectedRetry != retry { // if expectedRetry != retry {
// t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry) // t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry)
// } // }
// if clientSelectedKEX == nil {
// t.Error("No KEX happened?")
// } else if *clientSelectedKEX != expectedKEX {
// t.Errorf("failed to negotiate: expected %d, got %d",
// expectedKEX, *clientSelectedKEX)
// }
} }
func TestHybridKEX(t *testing.T) { func TestHybridKEX(t *testing.T) {
run := func(scheme kem.Scheme, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) { run := func(curveID CurveID, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) {
t.Run(fmt.Sprintf("%s serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", scheme.Name(), t.Run(fmt.Sprintf("%#04x serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", uint16(curveID),
serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) { serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) {
testHybridKEX(t, scheme, clientPQ, serverPQ, clientTLS12, serverTLS12) testHybridKEX(t, curveID, clientPQ, serverPQ, clientTLS12, serverTLS12)
}) })
} }
for _, scheme := range []kem.Scheme{ for _, curveID := range []CurveID{
hybrid.Kyber512X25519(), X25519Kyber512Draft00,
hybrid.Kyber768X25519(), X25519Kyber768Draft00,
hybrid.P256Kyber768Draft00(), X25519Kyber768Draft00Old,
P256Kyber768Draft00,
} { } {
run(scheme, true, true, false, false) run(curveID, true, true, false, false)
run(scheme, true, false, false, false) run(curveID, true, false, false, false)
run(scheme, false, true, false, false) run(curveID, false, true, false, false)
run(scheme, true, true, true, false) run(curveID, true, true, true, false)
run(scheme, true, true, false, true) run(curveID, true, true, false, true)
run(scheme, true, true, true, true) run(curveID, true, true, true, true)
} }
} }

View file

@ -22,7 +22,6 @@ import (
"strings" "strings"
"time" "time"
circlKem "github.com/cloudflare/circl/kem"
circlSign "github.com/cloudflare/circl/sign" circlSign "github.com/cloudflare/circl/sign"
) )
@ -152,8 +151,10 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
} }
curveID := config.curvePreferences()[0] curveID := config.curvePreferences()[0]
// [UTLS SECTION BEGINS]
// Ported from cloudflare/go with modifications to preserve crypto/tls compatibility
if scheme := curveIdToCirclScheme(curveID); scheme != nil { if scheme := curveIdToCirclScheme(curveID); scheme != nil {
pk, sk, err := generateKemKeyPair(scheme, config.rand()) pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand())
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", scheme.Name(), err) return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", scheme.Name(), err)
} }
@ -174,6 +175,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
secret = key secret = key
} }
// [UTLS SECTION ENDS]
} }
if c.quic != nil { if c.quic != nil {
@ -283,7 +285,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok { if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok {
hs.ecdheKey = ecdheKey hs.ecdheKey = ecdheKey
} else if kemKey, ok := keySharePrivate.(circlKem.PrivateKey); ok { } else if kemKey, ok := keySharePrivate.(*kemPrivateKey); ok {
hs.kemKey = kemKey hs.kemKey = kemKey
} else { } else {
return fmt.Errorf("tls: unknown key share type %T", keySharePrivate) return fmt.Errorf("tls: unknown key share type %T", keySharePrivate)

View file

@ -15,8 +15,6 @@ import (
"fmt" "fmt"
"hash" "hash"
"time" "time"
circlKem "github.com/cloudflare/circl/kem"
) )
// [uTLS SECTION START] // [uTLS SECTION START]
@ -43,7 +41,7 @@ type clientHandshakeStateTLS13 struct {
hello *clientHelloMsg hello *clientHelloMsg
ecdheKey *ecdh.PrivateKey ecdheKey *ecdh.PrivateKey
keySharesEcdheParams KeySharesEcdheParameters // [uTLS] keySharesEcdheParams KeySharesEcdheParameters // [uTLS]
kemKey circlKem.PrivateKey // [uTLS] kemKey *kemPrivateKey // [uTLS]
// keySharesCirclParams KeySharesCirclParameters // [uTLS] TODO: perhaps implement? // keySharesCirclParams KeySharesCirclParameters // [uTLS] TODO: perhaps implement?
session *SessionState session *SessionState
@ -291,7 +289,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
} }
if scheme := curveIdToCirclScheme(curveID); scheme != nil { if scheme := curveIdToCirclScheme(curveID); scheme != nil {
pk, sk, err := generateKemKeyPair(scheme, c.config.rand()) pk, sk, err := generateKemKeyPair(scheme, curveID, c.config.rand())
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return fmt.Errorf("HRR generateKemKeyPair %s: %w", return fmt.Errorf("HRR generateKemKeyPair %s: %w",
@ -534,10 +532,11 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
return errors.New("tls: invalid server key share") return errors.New("tls: invalid server key share")
} }
} else if hs.kemKey != nil { } else if hs.kemKey != nil {
sharedKey, err = hs.kemKey.Scheme().Decapsulate(hs.kemKey, hs.serverHello.serverShare.data) sk := hs.kemKey.secretKey
sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data)
if err != nil { if err != nil {
c.sendAlert(alertIllegalParameter) c.sendAlert(alertIllegalParameter)
return fmt.Errorf("%s decaps: %w", hs.kemKey.Scheme().Name(), err) return fmt.Errorf("%s decaps: %w", sk.Scheme().Name(), err)
} }
} else { } else {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)

View file

@ -17,8 +17,6 @@ import (
"io" "io"
"net" "net"
"strconv" "strconv"
circlKem "github.com/cloudflare/circl/kem"
) )
type UConn struct { type UConn struct {
@ -87,10 +85,10 @@ func (uconn *UConn) BuildHandshakeState() error {
uconn.HandshakeState.Hello = hello.getPublicPtr() uconn.HandshakeState.Hello = hello.getPublicPtr()
if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok { if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok {
uconn.HandshakeState.State13.EcdheKey = ecdheKey uconn.HandshakeState.State13.EcdheKey = ecdheKey
} else if kemKey, ok := keySharePrivate.(circlKem.PrivateKey); ok { } else if kemKey, ok := keySharePrivate.(*kemPrivateKey); ok {
uconn.HandshakeState.State13.KEMKey = kemKey uconn.HandshakeState.State13.KEMKey = kemKey.ToPublic()
} else { } else {
return fmt.Errorf("uTLS: unknown keySharePrivate type: %T", keySharePrivate)
} }
uconn.HandshakeState.C = uconn.Conn uconn.HandshakeState.C = uconn.Conn
} else { } else {

View file

@ -11,7 +11,7 @@ import (
"hash" "hash"
"time" "time"
circlKem "github.com/cloudflare/circl/kem" "github.com/cloudflare/circl/kem"
) )
// ClientHandshakeState includes both TLS 1.3-only and TLS 1.2-only states, // ClientHandshakeState includes both TLS 1.3-only and TLS 1.2-only states,
@ -41,7 +41,7 @@ type TLS13OnlyState struct {
Suite *PubCipherSuiteTLS13 Suite *PubCipherSuiteTLS13
EcdheKey *ecdh.PrivateKey EcdheKey *ecdh.PrivateKey
KeySharesEcdheParams KeySharesEcdheParameters KeySharesEcdheParams KeySharesEcdheParameters
KEMKey circlKem.PrivateKey KEMKey *KemPrivateKey
EarlySecret []byte EarlySecret []byte
BinderKey []byte BinderKey []byte
CertReq *CertificateRequestMsgTLS13 CertReq *CertificateRequestMsgTLS13
@ -67,7 +67,7 @@ func (chs *PubClientHandshakeState) toPrivate13() *clientHandshakeStateTLS13 {
hello: chs.Hello.getPrivatePtr(), hello: chs.Hello.getPrivatePtr(),
ecdheKey: chs.State13.EcdheKey, ecdheKey: chs.State13.EcdheKey,
keySharesEcdheParams: chs.State13.KeySharesEcdheParams, keySharesEcdheParams: chs.State13.KeySharesEcdheParams,
kemKey: chs.State13.KEMKey, kemKey: chs.State13.KEMKey.ToPrivate(),
session: chs.Session, session: chs.Session,
earlySecret: chs.State13.EarlySecret, earlySecret: chs.State13.EarlySecret,
@ -93,7 +93,7 @@ func (chs13 *clientHandshakeStateTLS13) toPublic13() *PubClientHandshakeState {
tls13State := TLS13OnlyState{ tls13State := TLS13OnlyState{
KeySharesEcdheParams: chs13.keySharesEcdheParams, KeySharesEcdheParams: chs13.keySharesEcdheParams,
EcdheKey: chs13.ecdheKey, EcdheKey: chs13.ecdheKey,
KEMKey: chs13.kemKey, KEMKey: chs13.kemKey.ToPublic(),
EarlySecret: chs13.earlySecret, EarlySecret: chs13.earlySecret,
BinderKey: chs13.binderKey, BinderKey: chs13.binderKey,
CertReq: chs13.certReq.toPublic(), CertReq: chs13.certReq.toPublic(),
@ -739,3 +739,30 @@ func (TKS TicketKeys) ToPrivate() []ticketKey {
} }
return tks return tks
} }
type KemPrivateKey struct {
SecretKey kem.PrivateKey
CurveID CurveID
}
func (kpk *KemPrivateKey) ToPrivate() *kemPrivateKey {
if kpk == nil {
return nil
} else {
return &kemPrivateKey{
secretKey: kpk.SecretKey,
curveID: kpk.CurveID,
}
}
}
func (kpk *kemPrivateKey) ToPublic() *KemPrivateKey {
if kpk == nil {
return nil
} else {
return &KemPrivateKey{
SecretKey: kpk.secretKey,
CurveID: kpk.curveID,
}
}
}