diff --git a/auth.go b/auth.go index 7c5675c..7726630 100644 --- a/auth.go +++ b/auth.go @@ -15,6 +15,9 @@ import ( "fmt" "hash" "io" + + circlPki "github.com/cloudflare/circl/pki" + circlSign "github.com/cloudflare/circl/sign" ) // verifyHandshakeSignature verifies a signature against pre-hashed @@ -55,7 +58,20 @@ func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc c return err } default: - return errors.New("internal error: unknown signature type") + // [UTLS SECTION BEGINS] + // Ported from cloudflare/go + scheme := circlSchemeBySigType(sigType) + if scheme == nil { + return errors.New("internal error: unknown signature type") + } + pubKey, ok := pubkey.(circlSign.PublicKey) + if !ok { + return fmt.Errorf("expected a %s public key, got %T", scheme.Name(), pubkey) + } + if !scheme.Verify(pubKey, signed, sig, nil) { + return fmt.Errorf("%s verification failure", scheme.Name()) + } + // [UTLS SECTION ENDS] } return nil } @@ -106,7 +122,18 @@ func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType case Ed25519: sigType = signatureEd25519 default: - return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + // [UTLS SECTION BEGINS] + // Ported from cloudflare/go + scheme := circlPki.SchemeByTLSID(uint(signatureAlgorithm)) + if scheme == nil { + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + sigType = sigTypeByCirclScheme(scheme) + if sigType == 0 { + return 0, 0, fmt.Errorf("circl scheme %s not supported", + scheme.Name()) + } + // [UTLS SECTION ENDS] } switch signatureAlgorithm { case PKCS1WithSHA1, ECDSAWithSHA1: @@ -120,7 +147,14 @@ func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType case Ed25519: hash = directSigning default: - return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + // [UTLS SECTION BEGINS] + // Ported from cloudflare/go + scheme := circlPki.SchemeByTLSID(uint(signatureAlgorithm)) + if scheme == nil { + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + hash = directSigning + // [UTLS SECTION ENDS] } return sigType, hash, nil } @@ -140,6 +174,11 @@ func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash c // full signature, and not even OpenSSL bothers with the // complexity, so we can't even test it properly. return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") + // [UTLS SECTION BEGINS] + // Ported from cloudflare/go + case circlSign.PublicKey: + return 0, 0, fmt.Errorf("tls: circl public keys are not supported before TLS 1.2") + // [UTLS SECTION ENDS] default: return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) } @@ -210,6 +249,16 @@ func signatureSchemesForCertificate(version uint16, cert *Certificate) []Signatu } case ed25519.PublicKey: sigAlgs = []SignatureScheme{Ed25519} + // [UTLS SECTION BEGINS] + // Ported from cloudflare/go + case circlSign.PublicKey: + scheme := pub.Scheme() + tlsScheme, ok := scheme.(circlPki.TLSScheme) + if !ok { + return nil + } + sigAlgs = []SignatureScheme{SignatureScheme(tlsScheme.TLSIdentifier())} + // [UTLS SECTION ENDS] default: return nil } diff --git a/auth_test.go b/auth_test.go index c23d93f..54cf15d 100644 --- a/auth_test.go +++ b/auth_test.go @@ -7,6 +7,8 @@ package tls import ( "crypto" "testing" + + circlPki "github.com/cloudflare/circl/pki" ) func TestSignatureSelection(t *testing.T) { @@ -161,7 +163,7 @@ func TestSupportedSignatureAlgorithms(t *testing.T) { if sigType == 0 { t.Errorf("%v: missing signature type", sigAlg) } - if hash == 0 && sigAlg != Ed25519 { + if hash == 0 && sigAlg != Ed25519 && circlPki.SchemeByTLSID(uint(sigAlg)) == nil { // [UTLS] ported from cloudflare/go t.Errorf("%v: missing hash", sigAlg) } } diff --git a/cfkem.go b/cfkem.go new file mode 100644 index 0000000..8d440e4 --- /dev/null +++ b/cfkem.go @@ -0,0 +1,101 @@ +// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code +// is governed by a BSD-style license that can be found in the LICENSE file. +// +// Glue to add Circl's (post-quantum) hybrid KEMs. +// +// To enable set CurvePreferences with the desired scheme as the first element: +// +// import ( +// "crypto/tls" +// +// [...] +// +// config.CurvePreferences = []tls.CurveID{ +// tls.X25519Kyber768Draft00, +// tls.X25519, +// tls.P256, +// } + +package tls + +import ( + "fmt" + "io" + + "crypto/ecdh" + + "github.com/cloudflare/circl/kem" + "github.com/cloudflare/circl/kem/hybrid" +) + +// Either *ecdh.PrivateKey or *kemPrivateKey +type clientKeySharePrivate interface{} + +type kemPrivateKey struct { + secretKey kem.PrivateKey + curveID CurveID +} + +var ( + X25519Kyber512Draft00 = CurveID(0xfe30) + X25519Kyber768Draft00 = CurveID(0x6399) + X25519Kyber768Draft00Old = CurveID(0xfe31) + P256Kyber768Draft00 = CurveID(0xfe32) + invalidCurveID = CurveID(0) +) + +// Extract CurveID from clientKeySharePrivate +func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID { + switch v := ks.(type) { + case *kemPrivateKey: + return v.curveID + case *ecdh.PrivateKey: + ret, ok := curveIDForCurve(v.Curve()) + if !ok { + panic("cfkem: internal error: unknown curve") + } + return ret + default: + panic("cfkem: internal error: unknown clientKeySharePrivate") + } +} + +// Returns scheme by CurveID if supported by Circl +func curveIdToCirclScheme(id CurveID) kem.Scheme { + switch id { + case X25519Kyber512Draft00: + return hybrid.Kyber512X25519() + case X25519Kyber768Draft00, X25519Kyber768Draft00Old: + return hybrid.Kyber768X25519() + case P256Kyber768Draft00: + return hybrid.P256Kyber768Draft00() + } + return nil +} + +// Generate a new shared secret and encapsulates it for the packed +// public key in ppk using randomness from rnd. +func encapsulateForKem(scheme kem.Scheme, rnd io.Reader, ppk []byte) ( + ct, ss []byte, alert alert, err error) { + pk, err := scheme.UnmarshalBinaryPublicKey(ppk) + if err != nil { + return nil, nil, alertIllegalParameter, fmt.Errorf("unpack pk: %w", err) + } + seed := make([]byte, scheme.EncapsulationSeedSize()) + if _, err := io.ReadFull(rnd, seed); err != nil { + return nil, nil, alertInternalError, fmt.Errorf("random: %w", err) + } + ct, ss, err = scheme.EncapsulateDeterministically(pk, seed) + return ct, ss, alertIllegalParameter, err +} + +// Generate a new keypair using randomness from rnd. +func generateKemKeyPair(scheme kem.Scheme, curveID CurveID, rnd io.Reader) ( + kem.PublicKey, *kemPrivateKey, error) { + seed := make([]byte, scheme.SeedSize()) + if _, err := io.ReadFull(rnd, seed); err != nil { + return nil, nil, err + } + pk, sk := scheme.DeriveKeyPair(seed) + return pk, &kemPrivateKey{sk, curveID}, nil +} diff --git a/cfkem_test.go b/cfkem_test.go new file mode 100644 index 0000000..7043d5f --- /dev/null +++ b/cfkem_test.go @@ -0,0 +1,107 @@ +// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code +// is governed by a BSD-style license that can be found in the LICENSE file. + +package tls + +import ( + "context" + "fmt" + "testing" +) + +func testHybridKEX(t *testing.T, curveID CurveID, clientPQ, serverPQ, + clientTLS12, serverTLS12 bool) { + // var clientSelectedKEX *CurveID + // var retry bool + + clientConfig := testConfig.Clone() + if clientPQ { + clientConfig.CurvePreferences = []CurveID{curveID, X25519} + } + // clientCFEventHandler := func(ev CFEvent) { + // switch e := ev.(type) { + // case CFEventTLSNegotiatedNamedKEX: + // clientSelectedKEX = &e.KEX + // case CFEventTLS13HRR: + // retry = true + // } + // } + if clientTLS12 { + clientConfig.MaxVersion = VersionTLS12 + } + + serverConfig := testConfig.Clone() + if serverPQ { + serverConfig.CurvePreferences = []CurveID{curveID, X25519} + } else { + serverConfig.CurvePreferences = []CurveID{X25519} + } + if serverTLS12 { + serverConfig.MaxVersion = VersionTLS12 + } + + c, s := localPipe(t) + done := make(chan error) + defer c.Close() + + go func() { + defer s.Close() + done <- Server(s, serverConfig).Handshake() + }() + + cli := Client(c, clientConfig) + // cCtx := context.WithValue(context.Background(), CFEventHandlerContextKey{}, clientCFEventHandler) + clientErr := cli.HandshakeContext(context.Background()) + serverErr := <-done + if clientErr != nil { + t.Errorf("client error: %s", clientErr) + } + if serverErr != nil { + t.Errorf("server error: %s", serverErr) + } + + // var expectedKEX CurveID + // var expectedRetry bool + + // if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 { + // expectedKEX = curveID + // } else { + // expectedKEX = X25519 + // } + // if !clientTLS12 && clientPQ && !serverPQ { + // expectedRetry = true + // } + + // if expectedRetry != retry { + // t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry) + // } + + // if clientSelectedKEX == nil { + // t.Error("No KEX happened?") + // } else if *clientSelectedKEX != expectedKEX { + // t.Errorf("failed to negotiate: expected %d, got %d", + // expectedKEX, *clientSelectedKEX) + // } +} + +func TestHybridKEX(t *testing.T) { + run := func(curveID CurveID, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) { + t.Run(fmt.Sprintf("%#04x serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", uint16(curveID), + serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) { + testHybridKEX(t, curveID, clientPQ, serverPQ, clientTLS12, serverTLS12) + }) + } + for _, curveID := range []CurveID{ + X25519Kyber512Draft00, + X25519Kyber768Draft00, + X25519Kyber768Draft00Old, + P256Kyber768Draft00, + } { + run(curveID, true, true, false, false) + run(curveID, true, false, false, false) + run(curveID, false, true, false, false) + run(curveID, true, true, true, false) + run(curveID, true, true, false, true) + run(curveID, true, true, true, true) + } +} diff --git a/common.go b/common.go index 7c6eaf0..a636dde 100644 --- a/common.go +++ b/common.go @@ -189,6 +189,7 @@ const ( signatureRSAPSS signatureECDSA signatureEd25519 + signatureEdDilithium3 ) // directSigning is a standard Hash value that signals that no pre-hashing @@ -780,6 +781,11 @@ type Config struct { // its key share in TLS 1.3. This may change in the future. CurvePreferences []CurveID + // PQSignatureSchemesEnabled controls whether additional post-quantum + // signature schemes are supported for peer certificates. For available + // signature schemes, see tls_cf.go. + PQSignatureSchemesEnabled bool // [UTLS] ported from cloudflare/go + // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. // When true, the largest possible TLS record size is always used. When // false, the size of TLS records may be adjusted in an attempt to @@ -885,6 +891,7 @@ func (c *Config) Clone() *Config { MinVersion: c.MinVersion, MaxVersion: c.MaxVersion, CurvePreferences: c.CurvePreferences, + PQSignatureSchemesEnabled: c.PQSignatureSchemesEnabled, // [UTLS] DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, Renegotiation: c.Renegotiation, KeyLogWriter: c.KeyLogWriter, diff --git a/generate_cert.go b/generate_cert.go index cd4bfc5..dce68f7 100644 --- a/generate_cert.go +++ b/generate_cert.go @@ -25,6 +25,9 @@ import ( "os" "strings" "time" + + circlSign "github.com/cloudflare/circl/sign" + circlSchemes "github.com/cloudflare/circl/sign/schemes" ) var ( @@ -35,6 +38,7 @@ var ( rsaBits = flag.Int("rsa-bits", 2048, "Size of RSA key to generate. Ignored if --ecdsa-curve is set") ecdsaCurve = flag.String("ecdsa-curve", "", "ECDSA curve to use to generate a key. Valid values are P224, P256 (recommended), P384, P521") ed25519Key = flag.Bool("ed25519", false, "Generate an Ed25519 key") + circlKey = flag.String("circl", "", "Generate a key supported by Circl") // [UTLS] ported from cloudflare/go ) func publicKey(priv any) any { @@ -45,6 +49,11 @@ func publicKey(priv any) any { return &k.PublicKey case ed25519.PrivateKey: return k.Public().(ed25519.PublicKey) + // [UTLS SECTION BEGINS] + // Ported from cloudflare/go + case circlSign.PrivateKey: + return k.Public() + // [UTLS SECTION ENDS] default: return nil } @@ -63,6 +72,15 @@ func main() { case "": if *ed25519Key { _, priv, err = ed25519.GenerateKey(rand.Reader) + // [UTLS SECTION BEGINS] + // Ported from cloudflare/go + } else if *circlKey != "" { + scheme := circlSchemes.ByName(*circlKey) + if scheme == nil { + log.Fatalf("No such Circl scheme: %s", *circlKey) + } + _, priv, err = scheme.GenerateKey() + // [UTLS SECTION ENDS] } else { priv, err = rsa.GenerateKey(rand.Reader, *rsaBits) } diff --git a/go.mod b/go.mod index 7be48a1..cef850a 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ retract ( require ( github.com/andybalholm/brotli v1.0.5 + github.com/cloudflare/circl v1.3.3 github.com/gaukas/godicttls v0.0.4 github.com/klauspost/compress v1.16.7 github.com/quic-go/quic-go v0.37.4 diff --git a/go.sum b/go.sum index b807fd5..00b89e1 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/cloudflare/circl v1.3.3 h1:fE/Qz0QdIGqeWfnwq0RE0R7MI51s0M2E4Ga9kq5AEMs= +github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk= github.com/gaukas/godicttls v0.0.4/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= diff --git a/handshake_client.go b/handshake_client.go index 088baa5..f0df1ef 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -21,6 +21,8 @@ import ( "net" "strings" "time" + + circlSign "github.com/cloudflare/circl/sign" ) type clientHandshakeState struct { @@ -39,7 +41,7 @@ type clientHandshakeState struct { var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme -func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { +func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error) { config := c.config // [UTLS SECTION START] @@ -130,13 +132,13 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { } if hello.vers >= VersionTLS12 { - hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + hello.supportedSignatureAlgorithms = config.supportedSignatureAlgorithms() // [UTLS] ported from cloudflare/go } if testingOnlyForceClientHelloSignatureAlgorithms != nil { hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms } - var key *ecdh.PrivateKey + var secret clientKeySharePrivate if hello.supportedVersions[0] == VersionTLS13 { // Reset the list of ciphers when the client only supports TLS 1.3. if len(hello.supportedVersions) == 1 { @@ -149,14 +151,31 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { } curveID := config.curvePreferences()[0] - if _, ok := curveForCurveID(curveID); !ok { - return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + // [UTLS SECTION BEGINS] + // Ported from cloudflare/go with modifications to preserve crypto/tls compatibility + if scheme := curveIdToCirclScheme(curveID); scheme != nil { + pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand()) + if err != nil { + return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", scheme.Name(), err) + } + packedPk, err := pk.MarshalBinary() + if err != nil { + return nil, nil, fmt.Errorf("pack circl public key %s: %w", scheme.Name(), err) + } + hello.keyShares = []keyShare{{group: curveID, data: packedPk}} + secret = sk + } else { + if _, ok := curveForCurveID(curveID); !ok { + return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + key, err := generateECDHEKey(config.rand(), curveID) + if err != nil { + return nil, nil, err + } + hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} + secret = key } - key, err = generateECDHEKey(config.rand(), curveID) - if err != nil { - return nil, nil, err - } - hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} + // [UTLS SECTION ENDS] } if c.quic != nil { @@ -170,7 +189,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { hello.quicTransportParameters = p } - return hello, key, nil + return hello, secret, nil } func (c *Conn) clientHandshake(ctx context.Context) (err error) { @@ -182,7 +201,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { // need to be reset. c.didResume = false - hello, ecdheKey, err := c.makeClientHello() + hello, keySharePrivate, err := c.makeClientHello() if err != nil { return err } @@ -256,7 +275,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { ctx: ctx, serverHello: serverHello, hello: hello, - ecdheKey: ecdheKey, + // ecdheKey: ecdheKey, session: session, earlySecret: earlySecret, binderKey: binderKey, @@ -264,6 +283,14 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { keySharesEcdheParams: make(KeySharesEcdheParameters, 2), // [uTLS] } + if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok { + hs.ecdheKey = ecdheKey + } else if kemKey, ok := keySharePrivate.(*kemPrivateKey); ok { + hs.kemKey = kemKey + } else { + return fmt.Errorf("tls: unknown key share type %T", keySharePrivate) + } + // In TLS 1.3, session tickets are delivered after the handshake. return hs.handshake() } @@ -1019,7 +1046,7 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { } switch certs[0].PublicKey.(type) { - case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: + case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey, circlSign.PublicKey: // [UTLS] ported from cloudflare/go break default: c.sendAlert(alertUnsupportedCertificate) diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 7e307c1..858f0b8 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -35,17 +35,19 @@ func (keymap KeySharesEcdheParameters) GetPublicEcdheParams(curveID CurveID) (pa // [uTLS SECTION END] type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - ecdheKey *ecdh.PrivateKey - + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + ecdheKey *ecdh.PrivateKey keySharesEcdheParams KeySharesEcdheParameters // [uTLS] + kemKey *kemPrivateKey // [uTLS] + // keySharesCirclParams KeySharesCirclParameters // [uTLS] TODO: perhaps implement? - session *SessionState - earlySecret []byte - binderKey []byte + session *SessionState + earlySecret []byte + binderKey []byte + selectedGroup CurveID certReq *certificateRequestMsgTLS13 usingPSK bool @@ -83,7 +85,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { // [uTLS SECTION END] // Consistency check on the presence of a keyShare and its parameters. - if hs.ecdheKey == nil || len(hs.hello.keyShares) < 1 { // [uTLS] + if (hs.ecdheKey == nil && hs.kemKey == nil) || len(hs.hello.keyShares) < 1 { // [uTLS] // keyshares "< 1" instead of "!= 1", as uTLS may send multiple return c.sendAlert(alertInternalError) } @@ -268,21 +270,55 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server selected unsupported group") } - if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); sentID == curveID { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") - } - if _, ok := curveForCurveID(curveID); !ok { + + // [UTLS SECTION BEGINS] + // ported from cloudflare/go, slightly modified to maintain compatibility with crypto/tls upstream + if hs.ecdheKey != nil { + if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); sentID == curveID { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") + } + } else if hs.kemKey != nil { + if clientKeySharePrivateCurveID(hs.kemKey) == curveID { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") + } + } else { c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") + return errors.New("tls: ecdheKey and kemKey are both nil") } - key, err := generateECDHEKey(c.config.rand(), curveID) - if err != nil { - c.sendAlert(alertInternalError) - return err + + if scheme := curveIdToCirclScheme(curveID); scheme != nil { + pk, sk, err := generateKemKeyPair(scheme, curveID, c.config.rand()) + if err != nil { + c.sendAlert(alertInternalError) + return fmt.Errorf("HRR generateKemKeyPair %s: %w", + scheme.Name(), err) + } + packedPk, err := pk.MarshalBinary() + if err != nil { + c.sendAlert(alertInternalError) + return fmt.Errorf("HRR pack circl public key %s: %w", + scheme.Name(), err) + } + hs.kemKey = sk + hs.ecdheKey = nil // unset ecdheKey if any + hs.hello.keyShares = []keyShare{{group: curveID, data: packedPk}} + } else { + if _, ok := curveForCurveID(curveID); !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + key, err := generateECDHEKey(c.config.rand(), curveID) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.ecdheKey = key + hs.kemKey = nil // unset kemKey if any + hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} } - hs.ecdheKey = key - hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} + // [UTLS SECTION ENDS] } hs.hello.raw = nil @@ -430,9 +466,19 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server did not send a key share") } - if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") + if hs.ecdheKey != nil { + if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + } else if hs.kemKey != nil { + if clientKeySharePrivateCurveID(hs.kemKey) != hs.serverHello.serverShare.group { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + } else { + c.sendAlert(alertInternalError) + return errors.New("tls: ecdheKey and kemKey are both nil") } if !hs.serverHello.selectedIdentityPresent { @@ -469,16 +515,34 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { c := hs.c - peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") - } - sharedKey, err := hs.ecdheKey.ECDH(peerKey) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") + // [UTLS SECTION BEGINS] + // ported from cloudflare/go, slightly modified to maintain compatibility with crypto/tls upstream + var sharedKey []byte + var err error + + if hs.ecdheKey != nil { + peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data) + if err != nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + sharedKey, err = hs.ecdheKey.ECDH(peerKey) + if err != nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + } else if hs.kemKey != nil { + sk := hs.kemKey.secretKey + sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data) + if err != nil { + c.sendAlert(alertIllegalParameter) + return fmt.Errorf("%s decaps: %w", sk.Scheme().Name(), err) + } + } else { + c.sendAlert(alertInternalError) + return errors.New("tls: ecdheKey and circlKey are both nil") } + // [UTLS SECTION ENDS] earlySecret := hs.earlySecret if !hs.usingPSK { @@ -680,7 +744,7 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { } // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, c.config.supportedSignatureAlgorithms()) { // [UTLS] ported from cloudflare/go c.sendAlert(alertIllegalParameter) return errors.New("tls: certificate used with invalid signature algorithm") } diff --git a/handshake_server.go b/handshake_server.go index dcbda21..c29e9a3 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -17,6 +17,8 @@ import ( "hash" "io" "time" + + circlSign "github.com/cloudflare/circl/sign" ) // serverHandshakeState contains details of a server handshake in progress. @@ -593,7 +595,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { } if c.vers >= VersionTLS12 { certReq.hasSignatureAlgorithm = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + certReq.supportedSignatureAlgorithms = c.config.supportedSignatureAlgorithms() // [UTLS] ported from cloudflare/go } // An empty list of certificateAuthorities signals to @@ -917,7 +919,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { if len(certs) > 0 { switch certs[0].PublicKey.(type) { - case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: + case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey, circlSign.PublicKey: // [UTLS] ported from cloudflare/go default: c.sendAlert(alertUnsupportedCertificate) return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 07b1a38..43192b3 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -12,6 +12,7 @@ import ( "crypto/rsa" "encoding/binary" "errors" + "fmt" "hash" "io" "time" @@ -33,6 +34,7 @@ type serverHandshakeStateTLS13 struct { suite *cipherSuiteTLS13 cert *Certificate sigAlg SignatureScheme + selectedGroup CurveID earlySecret []byte sharedKey []byte handshakeSecret []byte @@ -211,23 +213,31 @@ GroupSelection: clientKeyShare = &hs.clientHello.keyShares[0] } - if _, ok := curveForCurveID(selectedGroup); !ok { + if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && curveIdToCirclScheme(selectedGroup) == nil && !ok { c.sendAlert(alertInternalError) return errors.New("tls: CurvePreferences includes unsupported curve") } - key, err := generateECDHEKey(c.config.rand(), selectedGroup) - if err != nil { - c.sendAlert(alertInternalError) - return err + if kem := curveIdToCirclScheme(selectedGroup); kem != nil { + ct, ss, alert, err := encapsulateForKem(kem, c.config.rand(), clientKeyShare.data) + if err != nil { + c.sendAlert(alert) + return fmt.Errorf("%s encap: %w", kem.Name(), err) + } + hs.hello.serverShare = keyShare{group: selectedGroup, data: ct} + hs.sharedKey = ss + } else { + key, err := generateECDHEKey(c.config.rand(), selectedGroup) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()} + peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data) + if err == nil { + hs.sharedKey, _ = key.ECDH(peerKey) + } } - hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()} - peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid client key share") - } - hs.sharedKey, err = key.ECDH(peerKey) - if err != nil { + if hs.sharedKey == nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid client key share") } @@ -670,7 +680,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { certReq := new(certificateRequestMsgTLS13) certReq.ocspStapling = true certReq.scts = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + certReq.supportedSignatureAlgorithms = c.config.supportedSignatureAlgorithms() // [UTLS] ported from cloudflare/go if c.config.ClientCAs != nil { certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } @@ -932,7 +942,7 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { } // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, c.config.supportedSignatureAlgorithms()) { // [UTLS] ported from cloudflare/go c.sendAlert(alertIllegalParameter) return errors.New("tls: client certificate used with invalid signature algorithm") } diff --git a/key_agreement.go b/key_agreement.go index 2c8c5b8..3c73345 100644 --- a/key_agreement.go +++ b/key_agreement.go @@ -130,7 +130,7 @@ func md5SHA1Hash(slices [][]byte) []byte { // the sigType (for earlier TLS versions). For Ed25519 signatures, which don't // do pre-hashing, it returns the concatenation of the slices. func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { - if sigType == signatureEd25519 { + if sigType == signatureEd25519 || circlSchemeBySigType(sigType) != nil { // [UTLS] ported from cloudflare/go var signed []byte for _, slice := range slices { signed = append(signed, slice...) @@ -169,7 +169,7 @@ type ecdheKeyAgreement struct { func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { var curveID CurveID for _, c := range clientHello.supportedCurves { - if config.supportsCurve(c) { + if config.supportsCurve(c) && curveIdToCirclScheme(c) == nil { curveID = c break } diff --git a/prf.go b/prf.go index 20bac96..c2a024f 100644 --- a/prf.go +++ b/prf.go @@ -225,11 +225,11 @@ func (h finishedHash) serverSum(masterSecret []byte) []byte { // hashForClientCertificate returns the handshake messages so far, pre-hashed if // necessary, suitable for signing by a TLS client certificate. func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte { - if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { + if (h.version >= VersionTLS12 || sigType == signatureEd25519 || circlSchemeBySigType(sigType) != nil) && h.buffer == nil { // [UTLS] ported from cloudflare/go panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") } - if sigType == signatureEd25519 { + if sigType == signatureEd25519 || circlSchemeBySigType(sigType) != nil { // [UTLS] ported from cloudflare/go return h.buffer } diff --git a/tls.go b/tls.go index b529c70..c4f6f39 100644 --- a/tls.go +++ b/tls.go @@ -25,6 +25,8 @@ import ( "net" "os" "strings" + + circlSign "github.com/cloudflare/circl/sign" ) // Server returns a new TLS server side connection @@ -326,6 +328,20 @@ func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { return fail(errors.New("tls: private key does not match public key")) } + // [UTLS SECTION BEGINS] + // Ported from cloudflare/go + case circlSign.PublicKey: + priv, ok := cert.PrivateKey.(circlSign.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + pkBytes, err := priv.Public().(circlSign.PublicKey).MarshalBinary() + pkBytes2, err2 := pub.MarshalBinary() + + if err != nil || err2 != nil || !bytes.Equal(pkBytes, pkBytes2) { + return fail(errors.New("tls: private key does not match public key")) + } + // [UTLS SECTION ENDS] default: return fail(errors.New("tls: unknown public key algorithm")) } @@ -342,7 +358,7 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { } if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { switch key := key.(type) { - case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: + case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey, circlSign.PrivateKey: return key, nil default: return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") diff --git a/tls_cf.go b/tls_cf.go new file mode 100644 index 0000000..8160be0 --- /dev/null +++ b/tls_cf.go @@ -0,0 +1,66 @@ +// Copyright 2021 Cloudflare, Inc. All rights reserved. Use of this source code +// is governed by a BSD-style license that can be found in the LICENSE file. + +package tls + +import ( + circlPki "github.com/cloudflare/circl/pki" + circlSign "github.com/cloudflare/circl/sign" + "github.com/cloudflare/circl/sign/eddilithium3" +) + +// To add a signature scheme from Circl +// +// 1. make sure it implements TLSScheme and CertificateScheme, +// 2. follow the instructions in crypto/x509/x509_cf.go +// 3. add a signature to the iota in common.go +// 4. add row in the circlSchemes lists below + +var circlSchemes = [...]struct { + sigType uint8 + scheme circlSign.Scheme +}{ + {signatureEdDilithium3, eddilithium3.Scheme()}, +} + +func circlSchemeBySigType(sigType uint8) circlSign.Scheme { + for _, cs := range circlSchemes { + if cs.sigType == sigType { + return cs.scheme + } + } + return nil +} + +func sigTypeByCirclScheme(scheme circlSign.Scheme) uint8 { + for _, cs := range circlSchemes { + if cs.scheme == scheme { + return cs.sigType + } + } + return 0 +} + +var supportedSignatureAlgorithmsWithCircl []SignatureScheme + +// supportedSignatureAlgorithms returns enabled signature schemes. PQ signature +// schemes are only included when tls.Config#PQSignatureSchemesEnabled is set +// and FIPS-only mode is not enabled. +func (c *Config) supportedSignatureAlgorithms() []SignatureScheme { + // If FIPS-only mode is requested, do not add other algos. + if needFIPS() { + return supportedSignatureAlgorithms() + } + if c != nil && c.PQSignatureSchemesEnabled { + return supportedSignatureAlgorithmsWithCircl + } + return defaultSupportedSignatureAlgorithms +} + +func init() { + supportedSignatureAlgorithmsWithCircl = append([]SignatureScheme{}, defaultSupportedSignatureAlgorithms...) + for _, cs := range circlSchemes { + supportedSignatureAlgorithmsWithCircl = append(supportedSignatureAlgorithmsWithCircl, + SignatureScheme(cs.scheme.(circlPki.TLSScheme).TLSIdentifier())) + } +} diff --git a/tls_test.go b/tls_test.go index f4e282f..923fb29 100644 --- a/tls_test.go +++ b/tls_test.go @@ -866,6 +866,8 @@ func TestCloneNonFuncFields(t *testing.T) { f.Set(reflect.ValueOf([]uint16{1, 2})) case "CurvePreferences": f.Set(reflect.ValueOf([]CurveID{CurveP256})) + case "PQSignatureSchemesEnabled": // [UTLS] ported from cloudflare/go + f.Set(reflect.ValueOf(true)) case "Renegotiation": f.Set(reflect.ValueOf(RenegotiateOnceAsClient)) case "mutex", "autoSessionTicketKeys", "sessionTicketKeys": diff --git a/u_conn.go b/u_conn.go index ef79b58..d3348a1 100644 --- a/u_conn.go +++ b/u_conn.go @@ -9,6 +9,7 @@ import ( "bytes" "context" "crypto/cipher" + "crypto/ecdh" "encoding/binary" "errors" "fmt" @@ -76,13 +77,19 @@ func (uconn *UConn) BuildHandshakeState() error { } // use default Golang ClientHello. - hello, ecdheKey, err := uconn.makeClientHello() + hello, keySharePrivate, err := uconn.makeClientHello() if err != nil { return err } uconn.HandshakeState.Hello = hello.getPublicPtr() - uconn.HandshakeState.State13.EcdheKey = ecdheKey + if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok { + uconn.HandshakeState.State13.EcdheKey = ecdheKey + } else if kemKey, ok := keySharePrivate.(*kemPrivateKey); ok { + uconn.HandshakeState.State13.KEMKey = kemKey.ToPublic() + } else { + return fmt.Errorf("uTLS: unknown keySharePrivate type: %T", keySharePrivate) + } uconn.HandshakeState.C = uconn.Conn } else { if !uconn.ClientHelloBuilt { diff --git a/u_public.go b/u_public.go index 3699b80..ba83cea 100644 --- a/u_public.go +++ b/u_public.go @@ -10,6 +10,8 @@ import ( "crypto/x509" "hash" "time" + + "github.com/cloudflare/circl/kem" ) // ClientHandshakeState includes both TLS 1.3-only and TLS 1.2-only states, @@ -39,6 +41,7 @@ type TLS13OnlyState struct { Suite *PubCipherSuiteTLS13 EcdheKey *ecdh.PrivateKey KeySharesEcdheParams KeySharesEcdheParameters + KEMKey *KemPrivateKey EarlySecret []byte BinderKey []byte CertReq *CertificateRequestMsgTLS13 @@ -64,6 +67,7 @@ func (chs *PubClientHandshakeState) toPrivate13() *clientHandshakeStateTLS13 { hello: chs.Hello.getPrivatePtr(), ecdheKey: chs.State13.EcdheKey, keySharesEcdheParams: chs.State13.KeySharesEcdheParams, + kemKey: chs.State13.KEMKey.ToPrivate(), session: chs.Session, earlySecret: chs.State13.EarlySecret, @@ -89,6 +93,7 @@ func (chs13 *clientHandshakeStateTLS13) toPublic13() *PubClientHandshakeState { tls13State := TLS13OnlyState{ KeySharesEcdheParams: chs13.keySharesEcdheParams, EcdheKey: chs13.ecdheKey, + KEMKey: chs13.kemKey.ToPublic(), EarlySecret: chs13.earlySecret, BinderKey: chs13.binderKey, CertReq: chs13.certReq.toPublic(), @@ -734,3 +739,30 @@ func (TKS TicketKeys) ToPrivate() []ticketKey { } return tks } + +type KemPrivateKey struct { + SecretKey kem.PrivateKey + CurveID CurveID +} + +func (kpk *KemPrivateKey) ToPrivate() *kemPrivateKey { + if kpk == nil { + return nil + } else { + return &kemPrivateKey{ + secretKey: kpk.SecretKey, + curveID: kpk.CurveID, + } + } +} + +func (kpk *kemPrivateKey) ToPublic() *KemPrivateKey { + if kpk == nil { + return nil + } else { + return &KemPrivateKey{ + SecretKey: kpk.secretKey, + CurveID: kpk.curveID, + } + } +}