diff --git a/handshake_client.go b/handshake_client.go index f0df1ef..a0671ed 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -138,7 +138,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error) hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms } - var secret clientKeySharePrivate + var secret clientKeySharePrivate // [UTLS] if hello.supportedVersions[0] == VersionTLS13 { // Reset the list of ciphers when the client only supports TLS 1.3. if len(hello.supportedVersions) == 1 { @@ -280,7 +280,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { earlySecret: earlySecret, binderKey: binderKey, - keySharesEcdheParams: make(KeySharesEcdheParameters, 2), // [uTLS] + keySharesParams: NewKeySharesParameters(), // [uTLS] } if ecdheKey, ok := keySharePrivate.(*ecdh.PrivateKey); ok { diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 858f0b8..52c7139 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -15,34 +15,74 @@ import ( "fmt" "hash" "time" + + "github.com/cloudflare/circl/kem" ) // [uTLS SECTION START] -type KeySharesEcdheParameters map[CurveID]*ecdh.PrivateKey +// KeySharesParameters serves as a in-memory storage for generated keypairs by UTLS when generating +// ClientHello. It is used to store both ecdhe and kem keypairs. +type KeySharesParameters struct { + ecdhePrivKeymap map[CurveID]*ecdh.PrivateKey + ecdhePubKeymap map[CurveID]*ecdh.PublicKey -func (keymap KeySharesEcdheParameters) AddEcdheParams(curveID CurveID, ecdheKey *ecdh.PrivateKey) { - keymap[curveID] = ecdheKey + // based on cloudflare/go + kemPrivKeymap map[CurveID]kem.PrivateKey + kemPubKeymap map[CurveID]kem.PublicKey } -func (keymap KeySharesEcdheParameters) GetEcdheParams(curveID CurveID) (ecdheKey *ecdh.PrivateKey, ok bool) { - ecdheKey, ok = keymap[curveID] + +func NewKeySharesParameters() *KeySharesParameters { + return &KeySharesParameters{ + ecdhePrivKeymap: make(map[CurveID]*ecdh.PrivateKey), + ecdhePubKeymap: make(map[CurveID]*ecdh.PublicKey), + + kemPrivKeymap: make(map[CurveID]kem.PrivateKey), + kemPubKeymap: make(map[CurveID]kem.PublicKey), + } +} + +func (ksp *KeySharesParameters) AddEcdheKeypair(curveID CurveID, ecdheKey *ecdh.PrivateKey, ecdhePubKey *ecdh.PublicKey) { + ksp.ecdhePrivKeymap[curveID] = ecdheKey + ksp.ecdhePubKeymap[curveID] = ecdhePubKey +} + +func (ksp *KeySharesParameters) GetEcdheKey(curveID CurveID) (ecdheKey *ecdh.PrivateKey, ok bool) { + ecdheKey, ok = ksp.ecdhePrivKeymap[curveID] return } -func (keymap KeySharesEcdheParameters) GetPublicEcdheParams(curveID CurveID) (params *ecdh.PrivateKey, ok bool) { - params, ok = keymap[curveID] + +func (ksp *KeySharesParameters) GetEcdhePubkey(curveID CurveID) (params *ecdh.PublicKey, ok bool) { + params, ok = ksp.ecdhePubKeymap[curveID] + return +} + +func (ksp *KeySharesParameters) AddKemKeypair(curveID CurveID, kemKey kem.PrivateKey, kemPubKey kem.PublicKey) { + if curveIdToCirclScheme(curveID) != nil { // only store for circl schemes + ksp.kemPrivKeymap[curveID] = kemKey + ksp.kemPubKeymap[curveID] = kemPubKey + } +} + +func (ksp *KeySharesParameters) GetKemKey(curveID CurveID) (kemKey kem.PrivateKey, ok bool) { + kemKey, ok = ksp.kemPrivKeymap[curveID] + return +} + +func (ksp *KeySharesParameters) GetKemPubkey(curveID CurveID) (params kem.PublicKey, ok bool) { + params, ok = ksp.kemPubKeymap[curveID] return } // [uTLS SECTION END] type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - ecdheKey *ecdh.PrivateKey - keySharesEcdheParams KeySharesEcdheParameters // [uTLS] - kemKey *kemPrivateKey // [uTLS] - // keySharesCirclParams KeySharesCirclParameters // [uTLS] TODO: perhaps implement? + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + ecdheKey *ecdh.PrivateKey + kemKey *kemPrivateKey // [uTLS] ported from cloudflare/go + keySharesParams *KeySharesParameters // [uTLS] support both ecdhe and kem session *SessionState earlySecret []byte @@ -77,10 +117,18 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } // [uTLS SECTION START] - // set echdheParams to what we received from server - if ecdheKey, ok := hs.keySharesEcdheParams.GetEcdheParams(hs.serverHello.serverShare.group); ok { + if ecdheKey, ok := hs.keySharesParams.GetEcdheKey(hs.serverHello.serverShare.group); ok { hs.ecdheKey = ecdheKey + hs.kemKey = nil // unset kemKey if any + } + // set kemParams to what we received from server + if kemKey, ok := hs.keySharesParams.GetKemKey(hs.serverHello.serverShare.group); ok { + hs.kemKey = &kemPrivateKey{ + secretKey: kemKey, + curveID: hs.serverHello.serverShare.group, + } + hs.ecdheKey = nil // unset ecdheKey if any } // [uTLS SECTION END] @@ -466,20 +514,22 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server did not send a key share") } - if hs.ecdheKey != nil { - if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") + + // [UTLS SECTION BEGINS] + var supportedGroupCompatible bool + if hs.ecdheKey != nil { // if we did send ECDHE KeyShare + if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group == sentID { // and server selected ECDHE KeyShare + supportedGroupCompatible = true } - } else if hs.kemKey != nil { - if clientKeySharePrivateCurveID(hs.kemKey) != hs.serverHello.serverShare.group { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") - } - } else { - c.sendAlert(alertInternalError) - return errors.New("tls: ecdheKey and kemKey are both nil") } + if hs.kemKey != nil && clientKeySharePrivateCurveID(hs.kemKey) == hs.serverHello.serverShare.group { // we did send KEM KeyShare and server selected KEM KeyShare + supportedGroupCompatible = true + } + if !supportedGroupCompatible { // none matched + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + // [UTLS SECTION ENDS] if !hs.serverHello.selectedIdentityPresent { return nil @@ -521,24 +571,28 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { var err error if hs.ecdheKey != nil { - peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") + if ecdheCurveID, _ := curveIDForCurve(hs.ecdheKey.Curve()); ecdheCurveID == hs.serverHello.serverShare.group { + peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data) + if err != nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + sharedKey, err = hs.ecdheKey.ECDH(peerKey) + if err != nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } } - sharedKey, err = hs.ecdheKey.ECDH(peerKey) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") - } - } else if hs.kemKey != nil { + } + if sharedKey == nil && hs.kemKey != nil && clientKeySharePrivateCurveID(hs.kemKey) == hs.serverHello.serverShare.group { sk := hs.kemKey.secretKey sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data) if err != nil { c.sendAlert(alertIllegalParameter) return fmt.Errorf("%s decaps: %w", sk.Scheme().Name(), err) } - } else { + } + if sharedKey == nil { c.sendAlert(alertInternalError) return errors.New("tls: ecdheKey and circlKey are both nil") } diff --git a/u_handshake_client.go b/u_handshake_client.go index 7ff80d8..7b9a98f 100644 --- a/u_handshake_client.go +++ b/u_handshake_client.go @@ -7,7 +7,6 @@ package tls import ( "bytes" "compress/zlib" - "crypto/ecdh" "errors" "fmt" "io" @@ -164,7 +163,7 @@ func (hs *clientHandshakeStateTLS13) utlsReadServerParameters(encryptedExtension return nil } -func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, *ecdh.PrivateKey, error) { +func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, clientKeySharePrivate, error) { config := c.config // [UTLS SECTION START] @@ -261,7 +260,7 @@ func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, *ecdh.PrivateKe hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms } - var key *ecdh.PrivateKey + var secret clientKeySharePrivate // [UTLS] if hello.supportedVersions[0] == VersionTLS13 { // Reset the list of ciphers when the client only supports TLS 1.3. if len(hello.supportedVersions) == 1 { @@ -273,15 +272,32 @@ func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, *ecdh.PrivateKe hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) } - curveID := config.curvePreferences()[0] - if _, ok := curveForCurveID(curveID); !ok { - return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - key, err = generateECDHEKey(config.rand(), curveID) - if err != nil { - return nil, nil, err - } - hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} + // curveID := config.curvePreferences()[0] + // // [UTLS SECTION BEGINS] + // // Ported from cloudflare/go with modifications to preserve crypto/tls compatibility + // if scheme := curveIdToCirclScheme(curveID); scheme != nil { + // pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand()) + // if err != nil { + // return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", scheme.Name(), err) + // } + // packedPk, err := pk.MarshalBinary() + // if err != nil { + // return nil, nil, fmt.Errorf("pack circl public key %s: %w", scheme.Name(), err) + // } + // hello.keyShares = []keyShare{{group: curveID, data: packedPk}} + // secret = sk + // } else { + // if _, ok := curveForCurveID(curveID); !ok { + // return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + // } + // key, err := generateECDHEKey(config.rand(), curveID) + // if err != nil { + // return nil, nil, err + // } + // hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} + // secret = key + // } + // // [UTLS SECTION ENDS] } // [UTLS] We don't need this, since it is not ready yet @@ -296,5 +312,5 @@ func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, *ecdh.PrivateKe // hello.quicTransportParameters = p // } - return hello, key, nil + return hello, secret, nil } diff --git a/u_parrots.go b/u_parrots.go index 7e2660f..a696259 100644 --- a/u_parrots.go +++ b/u_parrots.go @@ -5,6 +5,7 @@ package tls import ( + "crypto/ecdh" "crypto/sha256" "encoding/binary" "errors" @@ -2013,13 +2014,17 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error { return err } - privateHello, ecdheKey, err := uconn.makeClientHelloForApplyPreset() + privateHello, clientKeySharePrivate, err := uconn.makeClientHelloForApplyPreset() if err != nil { return err } uconn.HandshakeState.Hello = privateHello.getPublicPtr() - uconn.HandshakeState.State13.EcdheKey = ecdheKey - uconn.HandshakeState.State13.KeySharesEcdheParams = make(KeySharesEcdheParameters, 2) + if ecdheKey, ok := clientKeySharePrivate.(*ecdh.PrivateKey); ok { + uconn.HandshakeState.State13.EcdheKey = ecdheKey + } else if kemKey, ok := clientKeySharePrivate.(*kemPrivateKey); ok { + uconn.HandshakeState.State13.KEMKey = kemKey.ToPublic() + } + uconn.HandshakeState.State13.KeySharesParams = NewKeySharesParameters() hello := uconn.HandshakeState.Hello session := uconn.HandshakeState.Session @@ -2119,17 +2124,37 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error { continue } - ecdheKey, err := generateECDHEKey(uconn.config.rand(), curveID) - if err != nil { - return fmt.Errorf("unsupported Curve in KeyShareExtension: %v."+ - "To mimic it, fill the Data(key) field manually", curveID) - } - uconn.HandshakeState.State13.KeySharesEcdheParams.AddEcdheParams(curveID, ecdheKey) - ext.KeyShares[i].Data = ecdheKey.PublicKey().Bytes() - if !preferredCurveIsSet { - // only do this once for the first non-grease curve - uconn.HandshakeState.State13.EcdheKey = ecdheKey - preferredCurveIsSet = true + if scheme := curveIdToCirclScheme(curveID); scheme != nil { + pk, sk, err := generateKemKeyPair(scheme, curveID, uconn.config.rand()) + if err != nil { + return fmt.Errorf("HRR generateKemKeyPair %s: %w", + scheme.Name(), err) + } + packedPk, err := pk.MarshalBinary() + if err != nil { + return fmt.Errorf("HRR pack circl public key %s: %w", + scheme.Name(), err) + } + uconn.HandshakeState.State13.KeySharesParams.AddKemKeypair(curveID, sk.secretKey, pk) + ext.KeyShares[i].Data = packedPk + if !preferredCurveIsSet { + // only do this once for the first non-grease curve + uconn.HandshakeState.State13.KEMKey = sk.ToPublic() + preferredCurveIsSet = true + } + } else { + ecdheKey, err := generateECDHEKey(uconn.config.rand(), curveID) + if err != nil { + return fmt.Errorf("unsupported Curve in KeyShareExtension: %v."+ + "To mimic it, fill the Data(key) field manually", curveID) + } + uconn.HandshakeState.State13.KeySharesParams.AddEcdheKeypair(curveID, ecdheKey, ecdheKey.PublicKey()) + ext.KeyShares[i].Data = ecdheKey.PublicKey().Bytes() + if !preferredCurveIsSet { + // only do this once for the first non-grease curve + uconn.HandshakeState.State13.EcdheKey = ecdheKey + preferredCurveIsSet = true + } } } case *SupportedVersionsExtension: diff --git a/u_public.go b/u_public.go index ba83cea..55caa43 100644 --- a/u_public.go +++ b/u_public.go @@ -38,17 +38,17 @@ type PubClientHandshakeState struct { // TLS 1.3 only type TLS13OnlyState struct { - Suite *PubCipherSuiteTLS13 - EcdheKey *ecdh.PrivateKey - KeySharesEcdheParams KeySharesEcdheParameters - KEMKey *KemPrivateKey - EarlySecret []byte - BinderKey []byte - CertReq *CertificateRequestMsgTLS13 - UsingPSK bool - SentDummyCCS bool - Transcript hash.Hash - TrafficSecret []byte // client_application_traffic_secret_0 + Suite *PubCipherSuiteTLS13 + EcdheKey *ecdh.PrivateKey + KeySharesParams *KeySharesParameters + KEMKey *KemPrivateKey + EarlySecret []byte + BinderKey []byte + CertReq *CertificateRequestMsgTLS13 + UsingPSK bool + SentDummyCCS bool + Transcript hash.Hash + TrafficSecret []byte // client_application_traffic_secret_0 } // TLS 1.2 and before only @@ -62,12 +62,12 @@ func (chs *PubClientHandshakeState) toPrivate13() *clientHandshakeStateTLS13 { return nil } else { return &clientHandshakeStateTLS13{ - c: chs.C, - serverHello: chs.ServerHello.getPrivatePtr(), - hello: chs.Hello.getPrivatePtr(), - ecdheKey: chs.State13.EcdheKey, - keySharesEcdheParams: chs.State13.KeySharesEcdheParams, - kemKey: chs.State13.KEMKey.ToPrivate(), + c: chs.C, + serverHello: chs.ServerHello.getPrivatePtr(), + hello: chs.Hello.getPrivatePtr(), + ecdheKey: chs.State13.EcdheKey, + keySharesParams: chs.State13.KeySharesParams, + kemKey: chs.State13.KEMKey.ToPrivate(), session: chs.Session, earlySecret: chs.State13.EarlySecret, @@ -91,17 +91,17 @@ func (chs13 *clientHandshakeStateTLS13) toPublic13() *PubClientHandshakeState { return nil } else { tls13State := TLS13OnlyState{ - KeySharesEcdheParams: chs13.keySharesEcdheParams, - EcdheKey: chs13.ecdheKey, - KEMKey: chs13.kemKey.ToPublic(), - EarlySecret: chs13.earlySecret, - BinderKey: chs13.binderKey, - CertReq: chs13.certReq.toPublic(), - UsingPSK: chs13.usingPSK, - SentDummyCCS: chs13.sentDummyCCS, - Suite: chs13.suite.toPublic(), - TrafficSecret: chs13.trafficSecret, - Transcript: chs13.transcript, + KeySharesParams: chs13.keySharesParams, + EcdheKey: chs13.ecdheKey, + KEMKey: chs13.kemKey.ToPublic(), + EarlySecret: chs13.earlySecret, + BinderKey: chs13.binderKey, + CertReq: chs13.certReq.toPublic(), + UsingPSK: chs13.usingPSK, + SentDummyCCS: chs13.sentDummyCCS, + Suite: chs13.suite.toPublic(), + TrafficSecret: chs13.trafficSecret, + Transcript: chs13.transcript, } return &PubClientHandshakeState{ C: chs13.c,