new: enable PQ parrots

* Redesign KeySharesEcdheParameters into KeySharesParameters which supports multiple types of keys.

* Optimize program logic to prevent using unwanted keys
This commit is contained in:
Gaukas Wang 2023-08-12 20:20:29 -06:00
parent da99cf7fe2
commit db12e32ba7
No known key found for this signature in database
GPG key ID: 9E2F8986D76F8B5D
5 changed files with 192 additions and 97 deletions

View file

@ -138,7 +138,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms
} }
var secret clientKeySharePrivate var secret clientKeySharePrivate // [UTLS]
if hello.supportedVersions[0] == VersionTLS13 { if hello.supportedVersions[0] == VersionTLS13 {
// Reset the list of ciphers when the client only supports TLS 1.3. // Reset the list of ciphers when the client only supports TLS 1.3.
if len(hello.supportedVersions) == 1 { if len(hello.supportedVersions) == 1 {
@ -280,7 +280,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
earlySecret: earlySecret, earlySecret: earlySecret,
binderKey: binderKey, binderKey: binderKey,
keySharesEcdheParams: make(KeySharesEcdheParameters, 2), // [uTLS] keySharesParams: NewKeySharesParameters(), // [uTLS]
} }
if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok { if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok {

View file

@ -15,34 +15,74 @@ import (
"fmt" "fmt"
"hash" "hash"
"time" "time"
"github.com/cloudflare/circl/kem"
) )
// [uTLS SECTION START] // [uTLS SECTION START]
type KeySharesEcdheParameters map[CurveID]*ecdh.PrivateKey // KeySharesParameters serves as a in-memory storage for generated keypairs by UTLS when generating
// ClientHello. It is used to store both ecdhe and kem keypairs.
type KeySharesParameters struct {
ecdhePrivKeymap map[CurveID]*ecdh.PrivateKey
ecdhePubKeymap map[CurveID]*ecdh.PublicKey
func (keymap KeySharesEcdheParameters) AddEcdheParams(curveID CurveID, ecdheKey *ecdh.PrivateKey) { // based on cloudflare/go
keymap[curveID] = ecdheKey kemPrivKeymap map[CurveID]kem.PrivateKey
kemPubKeymap map[CurveID]kem.PublicKey
} }
func (keymap KeySharesEcdheParameters) GetEcdheParams(curveID CurveID) (ecdheKey *ecdh.PrivateKey, ok bool) {
ecdheKey, ok = keymap[curveID] func NewKeySharesParameters() *KeySharesParameters {
return &KeySharesParameters{
ecdhePrivKeymap: make(map[CurveID]*ecdh.PrivateKey),
ecdhePubKeymap: make(map[CurveID]*ecdh.PublicKey),
kemPrivKeymap: make(map[CurveID]kem.PrivateKey),
kemPubKeymap: make(map[CurveID]kem.PublicKey),
}
}
func (ksp *KeySharesParameters) AddEcdheKeypair(curveID CurveID, ecdheKey *ecdh.PrivateKey, ecdhePubKey *ecdh.PublicKey) {
ksp.ecdhePrivKeymap[curveID] = ecdheKey
ksp.ecdhePubKeymap[curveID] = ecdhePubKey
}
func (ksp *KeySharesParameters) GetEcdheKey(curveID CurveID) (ecdheKey *ecdh.PrivateKey, ok bool) {
ecdheKey, ok = ksp.ecdhePrivKeymap[curveID]
return return
} }
func (keymap KeySharesEcdheParameters) GetPublicEcdheParams(curveID CurveID) (params *ecdh.PrivateKey, ok bool) {
params, ok = keymap[curveID] func (ksp *KeySharesParameters) GetEcdhePubkey(curveID CurveID) (params *ecdh.PublicKey, ok bool) {
params, ok = ksp.ecdhePubKeymap[curveID]
return
}
func (ksp *KeySharesParameters) AddKemKeypair(curveID CurveID, kemKey kem.PrivateKey, kemPubKey kem.PublicKey) {
if curveIdToCirclScheme(curveID) != nil { // only store for circl schemes
ksp.kemPrivKeymap[curveID] = kemKey
ksp.kemPubKeymap[curveID] = kemPubKey
}
}
func (ksp *KeySharesParameters) GetKemKey(curveID CurveID) (kemKey kem.PrivateKey, ok bool) {
kemKey, ok = ksp.kemPrivKeymap[curveID]
return
}
func (ksp *KeySharesParameters) GetKemPubkey(curveID CurveID) (params kem.PublicKey, ok bool) {
params, ok = ksp.kemPubKeymap[curveID]
return return
} }
// [uTLS SECTION END] // [uTLS SECTION END]
type clientHandshakeStateTLS13 struct { type clientHandshakeStateTLS13 struct {
c *Conn c *Conn
ctx context.Context ctx context.Context
serverHello *serverHelloMsg serverHello *serverHelloMsg
hello *clientHelloMsg hello *clientHelloMsg
ecdheKey *ecdh.PrivateKey ecdheKey *ecdh.PrivateKey
keySharesEcdheParams KeySharesEcdheParameters // [uTLS] kemKey *kemPrivateKey // [uTLS] ported from cloudflare/go
kemKey *kemPrivateKey // [uTLS] keySharesParams *KeySharesParameters // [uTLS] support both ecdhe and kem
// keySharesCirclParams KeySharesCirclParameters // [uTLS] TODO: perhaps implement?
session *SessionState session *SessionState
earlySecret []byte earlySecret []byte
@ -77,10 +117,18 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
} }
// [uTLS SECTION START] // [uTLS SECTION START]
// set echdheParams to what we received from server // set echdheParams to what we received from server
if ecdheKey, ok := hs.keySharesEcdheParams.GetEcdheParams(hs.serverHello.serverShare.group); ok { if ecdheKey, ok := hs.keySharesParams.GetEcdheKey(hs.serverHello.serverShare.group); ok {
hs.ecdheKey = ecdheKey hs.ecdheKey = ecdheKey
hs.kemKey = nil // unset kemKey if any
}
// set kemParams to what we received from server
if kemKey, ok := hs.keySharesParams.GetKemKey(hs.serverHello.serverShare.group); ok {
hs.kemKey = &kemPrivateKey{
secretKey: kemKey,
curveID: hs.serverHello.serverShare.group,
}
hs.ecdheKey = nil // unset ecdheKey if any
} }
// [uTLS SECTION END] // [uTLS SECTION END]
@ -466,20 +514,22 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
c.sendAlert(alertIllegalParameter) c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not send a key share") return errors.New("tls: server did not send a key share")
} }
if hs.ecdheKey != nil {
if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID { // [UTLS SECTION BEGINS]
c.sendAlert(alertIllegalParameter) var supportedGroupCompatible bool
return errors.New("tls: server selected unsupported group") if hs.ecdheKey != nil { // if we did send ECDHE KeyShare
if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group == sentID { // and server selected ECDHE KeyShare
supportedGroupCompatible = true
} }
} else if hs.kemKey != nil {
if clientKeySharePrivateCurveID(hs.kemKey) != hs.serverHello.serverShare.group {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
} else {
c.sendAlert(alertInternalError)
return errors.New("tls: ecdheKey and kemKey are both nil")
} }
if hs.kemKey != nil && clientKeySharePrivateCurveID(hs.kemKey) == hs.serverHello.serverShare.group { // we did send KEM KeyShare and server selected KEM KeyShare
supportedGroupCompatible = true
}
if !supportedGroupCompatible { // none matched
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
// [UTLS SECTION ENDS]
if !hs.serverHello.selectedIdentityPresent { if !hs.serverHello.selectedIdentityPresent {
return nil return nil
@ -521,24 +571,28 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
var err error var err error
if hs.ecdheKey != nil { if hs.ecdheKey != nil {
peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data) if ecdheCurveID, _ := curveIDForCurve(hs.ecdheKey.Curve()); ecdheCurveID == hs.serverHello.serverShare.group {
if err != nil { peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data)
c.sendAlert(alertIllegalParameter) if err != nil {
return errors.New("tls: invalid server key share") c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
sharedKey, err = hs.ecdheKey.ECDH(peerKey)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
} }
sharedKey, err = hs.ecdheKey.ECDH(peerKey) }
if err != nil { if sharedKey == nil && hs.kemKey != nil && clientKeySharePrivateCurveID(hs.kemKey) == hs.serverHello.serverShare.group {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
} else if hs.kemKey != nil {
sk := hs.kemKey.secretKey sk := hs.kemKey.secretKey
sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data) sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data)
if err != nil { if err != nil {
c.sendAlert(alertIllegalParameter) c.sendAlert(alertIllegalParameter)
return fmt.Errorf("%s decaps: %w", sk.Scheme().Name(), err) return fmt.Errorf("%s decaps: %w", sk.Scheme().Name(), err)
} }
} else { }
if sharedKey == nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return errors.New("tls: ecdheKey and circlKey are both nil") return errors.New("tls: ecdheKey and circlKey are both nil")
} }

View file

@ -7,7 +7,6 @@ package tls
import ( import (
"bytes" "bytes"
"compress/zlib" "compress/zlib"
"crypto/ecdh"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -164,7 +163,7 @@ func (hs *clientHandshakeStateTLS13) utlsReadServerParameters(encryptedExtension
return nil return nil
} }
func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, *ecdh.PrivateKey, error) { func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, clientKeySharePrivate, error) {
config := c.config config := c.config
// [UTLS SECTION START] // [UTLS SECTION START]
@ -261,7 +260,7 @@ func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, *ecdh.PrivateKe
hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms
} }
var key *ecdh.PrivateKey var secret clientKeySharePrivate // [UTLS]
if hello.supportedVersions[0] == VersionTLS13 { if hello.supportedVersions[0] == VersionTLS13 {
// Reset the list of ciphers when the client only supports TLS 1.3. // Reset the list of ciphers when the client only supports TLS 1.3.
if len(hello.supportedVersions) == 1 { if len(hello.supportedVersions) == 1 {
@ -273,15 +272,32 @@ func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, *ecdh.PrivateKe
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...)
} }
curveID := config.curvePreferences()[0] // curveID := config.curvePreferences()[0]
if _, ok := curveForCurveID(curveID); !ok { // // [UTLS SECTION BEGINS]
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") // // Ported from cloudflare/go with modifications to preserve crypto/tls compatibility
} // if scheme := curveIdToCirclScheme(curveID); scheme != nil {
key, err = generateECDHEKey(config.rand(), curveID) // pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand())
if err != nil { // if err != nil {
return nil, nil, err // return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", scheme.Name(), err)
} // }
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} // packedPk, err := pk.MarshalBinary()
// if err != nil {
// return nil, nil, fmt.Errorf("pack circl public key %s: %w", scheme.Name(), err)
// }
// hello.keyShares = []keyShare{{group: curveID, data: packedPk}}
// secret = sk
// } else {
// if _, ok := curveForCurveID(curveID); !ok {
// return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
// }
// key, err := generateECDHEKey(config.rand(), curveID)
// if err != nil {
// return nil, nil, err
// }
// hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
// secret = key
// }
// // [UTLS SECTION ENDS]
} }
// [UTLS] We don't need this, since it is not ready yet // [UTLS] We don't need this, since it is not ready yet
@ -296,5 +312,5 @@ func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, *ecdh.PrivateKe
// hello.quicTransportParameters = p // hello.quicTransportParameters = p
// } // }
return hello, key, nil return hello, secret, nil
} }

View file

@ -5,6 +5,7 @@
package tls package tls
import ( import (
"crypto/ecdh"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"errors" "errors"
@ -2013,13 +2014,17 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
return err return err
} }
privateHello, ecdheKey, err := uconn.makeClientHelloForApplyPreset() privateHello, clientKeySharePrivate, err := uconn.makeClientHelloForApplyPreset()
if err != nil { if err != nil {
return err return err
} }
uconn.HandshakeState.Hello = privateHello.getPublicPtr() uconn.HandshakeState.Hello = privateHello.getPublicPtr()
uconn.HandshakeState.State13.EcdheKey = ecdheKey if ecdheKey, ok := clientKeySharePrivate.(*ecdh.PrivateKey); ok {
uconn.HandshakeState.State13.KeySharesEcdheParams = make(KeySharesEcdheParameters, 2) uconn.HandshakeState.State13.EcdheKey = ecdheKey
} else if kemKey, ok := clientKeySharePrivate.(*kemPrivateKey); ok {
uconn.HandshakeState.State13.KEMKey = kemKey.ToPublic()
}
uconn.HandshakeState.State13.KeySharesParams = NewKeySharesParameters()
hello := uconn.HandshakeState.Hello hello := uconn.HandshakeState.Hello
session := uconn.HandshakeState.Session session := uconn.HandshakeState.Session
@ -2119,17 +2124,37 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
continue continue
} }
ecdheKey, err := generateECDHEKey(uconn.config.rand(), curveID) if scheme := curveIdToCirclScheme(curveID); scheme != nil {
if err != nil { pk, sk, err := generateKemKeyPair(scheme, curveID, uconn.config.rand())
return fmt.Errorf("unsupported Curve in KeyShareExtension: %v."+ if err != nil {
"To mimic it, fill the Data(key) field manually", curveID) return fmt.Errorf("HRR generateKemKeyPair %s: %w",
} scheme.Name(), err)
uconn.HandshakeState.State13.KeySharesEcdheParams.AddEcdheParams(curveID, ecdheKey) }
ext.KeyShares[i].Data = ecdheKey.PublicKey().Bytes() packedPk, err := pk.MarshalBinary()
if !preferredCurveIsSet { if err != nil {
// only do this once for the first non-grease curve return fmt.Errorf("HRR pack circl public key %s: %w",
uconn.HandshakeState.State13.EcdheKey = ecdheKey scheme.Name(), err)
preferredCurveIsSet = true }
uconn.HandshakeState.State13.KeySharesParams.AddKemKeypair(curveID, sk.secretKey, pk)
ext.KeyShares[i].Data = packedPk
if !preferredCurveIsSet {
// only do this once for the first non-grease curve
uconn.HandshakeState.State13.KEMKey = sk.ToPublic()
preferredCurveIsSet = true
}
} else {
ecdheKey, err := generateECDHEKey(uconn.config.rand(), curveID)
if err != nil {
return fmt.Errorf("unsupported Curve in KeyShareExtension: %v."+
"To mimic it, fill the Data(key) field manually", curveID)
}
uconn.HandshakeState.State13.KeySharesParams.AddEcdheKeypair(curveID, ecdheKey, ecdheKey.PublicKey())
ext.KeyShares[i].Data = ecdheKey.PublicKey().Bytes()
if !preferredCurveIsSet {
// only do this once for the first non-grease curve
uconn.HandshakeState.State13.EcdheKey = ecdheKey
preferredCurveIsSet = true
}
} }
} }
case *SupportedVersionsExtension: case *SupportedVersionsExtension:

View file

@ -38,17 +38,17 @@ type PubClientHandshakeState struct {
// TLS 1.3 only // TLS 1.3 only
type TLS13OnlyState struct { type TLS13OnlyState struct {
Suite *PubCipherSuiteTLS13 Suite *PubCipherSuiteTLS13
EcdheKey *ecdh.PrivateKey EcdheKey *ecdh.PrivateKey
KeySharesEcdheParams KeySharesEcdheParameters KeySharesParams *KeySharesParameters
KEMKey *KemPrivateKey KEMKey *KemPrivateKey
EarlySecret []byte EarlySecret []byte
BinderKey []byte BinderKey []byte
CertReq *CertificateRequestMsgTLS13 CertReq *CertificateRequestMsgTLS13
UsingPSK bool UsingPSK bool
SentDummyCCS bool SentDummyCCS bool
Transcript hash.Hash Transcript hash.Hash
TrafficSecret []byte // client_application_traffic_secret_0 TrafficSecret []byte // client_application_traffic_secret_0
} }
// TLS 1.2 and before only // TLS 1.2 and before only
@ -62,12 +62,12 @@ func (chs *PubClientHandshakeState) toPrivate13() *clientHandshakeStateTLS13 {
return nil return nil
} else { } else {
return &clientHandshakeStateTLS13{ return &clientHandshakeStateTLS13{
c: chs.C, c: chs.C,
serverHello: chs.ServerHello.getPrivatePtr(), serverHello: chs.ServerHello.getPrivatePtr(),
hello: chs.Hello.getPrivatePtr(), hello: chs.Hello.getPrivatePtr(),
ecdheKey: chs.State13.EcdheKey, ecdheKey: chs.State13.EcdheKey,
keySharesEcdheParams: chs.State13.KeySharesEcdheParams, keySharesParams: chs.State13.KeySharesParams,
kemKey: chs.State13.KEMKey.ToPrivate(), kemKey: chs.State13.KEMKey.ToPrivate(),
session: chs.Session, session: chs.Session,
earlySecret: chs.State13.EarlySecret, earlySecret: chs.State13.EarlySecret,
@ -91,17 +91,17 @@ func (chs13 *clientHandshakeStateTLS13) toPublic13() *PubClientHandshakeState {
return nil return nil
} else { } else {
tls13State := TLS13OnlyState{ tls13State := TLS13OnlyState{
KeySharesEcdheParams: chs13.keySharesEcdheParams, KeySharesParams: chs13.keySharesParams,
EcdheKey: chs13.ecdheKey, EcdheKey: chs13.ecdheKey,
KEMKey: chs13.kemKey.ToPublic(), KEMKey: chs13.kemKey.ToPublic(),
EarlySecret: chs13.earlySecret, EarlySecret: chs13.earlySecret,
BinderKey: chs13.binderKey, BinderKey: chs13.binderKey,
CertReq: chs13.certReq.toPublic(), CertReq: chs13.certReq.toPublic(),
UsingPSK: chs13.usingPSK, UsingPSK: chs13.usingPSK,
SentDummyCCS: chs13.sentDummyCCS, SentDummyCCS: chs13.sentDummyCCS,
Suite: chs13.suite.toPublic(), Suite: chs13.suite.toPublic(),
TrafficSecret: chs13.trafficSecret, TrafficSecret: chs13.trafficSecret,
Transcript: chs13.transcript, Transcript: chs13.transcript,
} }
return &PubClientHandshakeState{ return &PubClientHandshakeState{
C: chs13.c, C: chs13.c,