crypto/tls: add server-side ECH

Adds support for server-side ECH.

We make a couple of implementation decisions that are not completely
in-line with the spec. In particular, we don't enforce that the SNI
matches the ECHConfig public_name, and we implement a hybrid
shared/backend mode (rather than shared or split mode, as described in
Section 7). Both of these match the behavior of BoringSSL.

The hybrid server mode will either act as a shared mode server, where-in
the server accepts "outer" client hellos and unwraps them before
processing the "inner" hello, or accepts bare "inner" hellos initially.
This lets the server operate either transparently as a shared mode
server, or a backend server, in Section 7 terminology. This seems like
the best implementation choice for a TLS library.

Fixes #68500

Change-Id: Ife69db7c1886610742e95e76b0ca92587e6d7ed4
Reviewed-on: https://go-review.googlesource.com/c/go/+/623576
Reviewed-by: Filippo Valsorda <filippo@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
Auto-Submit: Roland Shoemaker <roland@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
This commit is contained in:
Roland Shoemaker 2024-10-29 20:22:27 -07:00 committed by Gopher Robot
parent 83cefcdeed
commit 212bbb2c77
12 changed files with 770 additions and 95 deletions

472
ech.go
View file

@ -5,13 +5,28 @@
package tls
import (
"bytes"
"crypto/internal/hpke"
"errors"
"fmt"
"slices"
"strings"
"golang.org/x/crypto/cryptobyte"
)
// sortedSupportedAEADs is just a sorted version of hpke.SupportedAEADS.
// We need this so that when we insert them into ECHConfigs the ordering
// is stable.
var sortedSupportedAEADs []uint16
func init() {
for aeadID := range hpke.SupportedAEADs {
sortedSupportedAEADs = append(sortedSupportedAEADs, aeadID)
}
slices.Sort(sortedSupportedAEADs)
}
type echCipher struct {
KDFID uint16
AEADID uint16
@ -40,12 +55,77 @@ type echConfig struct {
var errMalformedECHConfig = errors.New("tls: malformed ECHConfigList")
func parseECHConfig(enc []byte) (skip bool, ec echConfig, err error) {
s := cryptobyte.String(enc)
ec.raw = []byte(enc)
if !s.ReadUint16(&ec.Version) {
return false, echConfig{}, errMalformedECHConfig
}
if !s.ReadUint16(&ec.Length) {
return false, echConfig{}, errMalformedECHConfig
}
if len(ec.raw) < int(ec.Length)+4 {
return false, echConfig{}, errMalformedECHConfig
}
ec.raw = ec.raw[:ec.Length+4]
if ec.Version != extensionEncryptedClientHello {
s.Skip(int(ec.Length))
return true, echConfig{}, nil
}
if !s.ReadUint8(&ec.ConfigID) {
return false, echConfig{}, errMalformedECHConfig
}
if !s.ReadUint16(&ec.KemID) {
return false, echConfig{}, errMalformedECHConfig
}
if !readUint16LengthPrefixed(&s, &ec.PublicKey) {
return false, echConfig{}, errMalformedECHConfig
}
var cipherSuites cryptobyte.String
if !s.ReadUint16LengthPrefixed(&cipherSuites) {
return false, echConfig{}, errMalformedECHConfig
}
for !cipherSuites.Empty() {
var c echCipher
if !cipherSuites.ReadUint16(&c.KDFID) {
return false, echConfig{}, errMalformedECHConfig
}
if !cipherSuites.ReadUint16(&c.AEADID) {
return false, echConfig{}, errMalformedECHConfig
}
ec.SymmetricCipherSuite = append(ec.SymmetricCipherSuite, c)
}
if !s.ReadUint8(&ec.MaxNameLength) {
return false, echConfig{}, errMalformedECHConfig
}
var publicName cryptobyte.String
if !s.ReadUint8LengthPrefixed(&publicName) {
return false, echConfig{}, errMalformedECHConfig
}
ec.PublicName = publicName
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) {
return false, echConfig{}, errMalformedECHConfig
}
for !extensions.Empty() {
var e echExtension
if !extensions.ReadUint16(&e.Type) {
return false, echConfig{}, errMalformedECHConfig
}
if !extensions.ReadUint16LengthPrefixed((*cryptobyte.String)(&e.Data)) {
return false, echConfig{}, errMalformedECHConfig
}
ec.Extensions = append(ec.Extensions, e)
}
return false, ec, nil
}
// parseECHConfigList parses a draft-ietf-tls-esni-18 ECHConfigList, returning a
// slice of parsed ECHConfigs, in the same order they were parsed, or an error
// if the list is malformed.
func parseECHConfigList(data []byte) ([]echConfig, error) {
s := cryptobyte.String(data)
// Skip the length prefix
var length uint16
if !s.ReadUint16(&length) {
return nil, errMalformedECHConfig
@ -55,69 +135,18 @@ func parseECHConfigList(data []byte) ([]echConfig, error) {
}
var configs []echConfig
for len(s) > 0 {
var ec echConfig
ec.raw = []byte(s)
if !s.ReadUint16(&ec.Version) {
return nil, errMalformedECHConfig
if len(s) < 4 {
return nil, errors.New("tls: malformed ECHConfig")
}
if !s.ReadUint16(&ec.Length) {
return nil, errMalformedECHConfig
configLen := uint16(s[2])<<8 | uint16(s[3])
skip, ec, err := parseECHConfig(s)
if err != nil {
return nil, err
}
if len(ec.raw) < int(ec.Length)+4 {
return nil, errMalformedECHConfig
s = s[configLen+4:]
if !skip {
configs = append(configs, ec)
}
ec.raw = ec.raw[:ec.Length+4]
if ec.Version != extensionEncryptedClientHello {
s.Skip(int(ec.Length))
continue
}
if !s.ReadUint8(&ec.ConfigID) {
return nil, errMalformedECHConfig
}
if !s.ReadUint16(&ec.KemID) {
return nil, errMalformedECHConfig
}
if !s.ReadUint16LengthPrefixed((*cryptobyte.String)(&ec.PublicKey)) {
return nil, errMalformedECHConfig
}
var cipherSuites cryptobyte.String
if !s.ReadUint16LengthPrefixed(&cipherSuites) {
return nil, errMalformedECHConfig
}
for !cipherSuites.Empty() {
var c echCipher
if !cipherSuites.ReadUint16(&c.KDFID) {
return nil, errMalformedECHConfig
}
if !cipherSuites.ReadUint16(&c.AEADID) {
return nil, errMalformedECHConfig
}
ec.SymmetricCipherSuite = append(ec.SymmetricCipherSuite, c)
}
if !s.ReadUint8(&ec.MaxNameLength) {
return nil, errMalformedECHConfig
}
var publicName cryptobyte.String
if !s.ReadUint8LengthPrefixed(&publicName) {
return nil, errMalformedECHConfig
}
ec.PublicName = publicName
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) {
return nil, errMalformedECHConfig
}
for !extensions.Empty() {
var e echExtension
if !extensions.ReadUint16(&e.Type) {
return nil, errMalformedECHConfig
}
if !extensions.ReadUint16LengthPrefixed((*cryptobyte.String)(&e.Data)) {
return nil, errMalformedECHConfig
}
ec.Extensions = append(ec.Extensions, e)
}
configs = append(configs, ec)
}
return configs, nil
}
@ -195,6 +224,175 @@ func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, e
return append(h, make([]byte, paddingLen)...), nil
}
func skipUint8LengthPrefixed(s *cryptobyte.String) bool {
var skip uint8
if !s.ReadUint8(&skip) {
return false
}
return s.Skip(int(skip))
}
func skipUint16LengthPrefixed(s *cryptobyte.String) bool {
var skip uint16
if !s.ReadUint16(&skip) {
return false
}
return s.Skip(int(skip))
}
type rawExtension struct {
extType uint16
data []byte
}
func extractRawExtensions(hello *clientHelloMsg) ([]rawExtension, error) {
s := cryptobyte.String(hello.original)
if !s.Skip(4+2+32) || // header, version, random
!skipUint8LengthPrefixed(&s) || // session ID
!skipUint16LengthPrefixed(&s) || // cipher suites
!skipUint8LengthPrefixed(&s) { // compression methods
return nil, errors.New("tls: malformed outer client hello")
}
var rawExtensions []rawExtension
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) {
return nil, errors.New("tls: malformed outer client hello")
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return nil, errors.New("tls: invalid inner client hello")
}
rawExtensions = append(rawExtensions, rawExtension{extension, extData})
}
return rawExtensions, nil
}
func decodeInnerClientHello(outer *clientHelloMsg, encoded []byte) (*clientHelloMsg, error) {
// Reconstructing the inner client hello from its encoded form is somewhat
// complicated. It is missing its header (message type and length), session
// ID, and the extensions may be compressed. Since we need to put the
// extensions back in the same order as they were in the raw outer hello,
// and since we don't store the raw extensions, or the order we parsed them
// in, we need to reparse the raw extensions from the outer hello in order
// to properly insert them into the inner hello. This _should_ result in raw
// bytes which match the hello as it was generated by the client.
innerReader := cryptobyte.String(encoded)
var versionAndRandom, sessionID, cipherSuites, compressionMethods []byte
var extensions cryptobyte.String
if !innerReader.ReadBytes(&versionAndRandom, 2+32) ||
!readUint8LengthPrefixed(&innerReader, &sessionID) ||
len(sessionID) != 0 ||
!readUint16LengthPrefixed(&innerReader, &cipherSuites) ||
!readUint8LengthPrefixed(&innerReader, &compressionMethods) ||
!innerReader.ReadUint16LengthPrefixed(&extensions) {
return nil, errors.New("tls: invalid inner client hello")
}
// The specification says we must verify that the trailing padding is all
// zeros. This is kind of weird for TLS messages, where we generally just
// throw away any trailing garbage.
for _, p := range innerReader {
if p != 0 {
return nil, errors.New("tls: invalid inner client hello")
}
}
rawOuterExts, err := extractRawExtensions(outer)
if err != nil {
return nil, err
}
recon := cryptobyte.NewBuilder(nil)
recon.AddUint8(typeClientHello)
recon.AddUint24LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(versionAndRandom)
recon.AddUint8LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(outer.sessionId)
})
recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(cipherSuites)
})
recon.AddUint8LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(compressionMethods)
})
recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
recon.SetError(errors.New("tls: invalid inner client hello"))
return
}
if extension == extensionECHOuterExtensions {
if !extData.ReadUint8LengthPrefixed(&extData) {
recon.SetError(errors.New("tls: invalid inner client hello"))
return
}
var i int
for !extData.Empty() {
var extType uint16
if !extData.ReadUint16(&extType) {
recon.SetError(errors.New("tls: invalid inner client hello"))
return
}
if extType == extensionEncryptedClientHello {
recon.SetError(errors.New("tls: invalid outer extensions"))
return
}
for ; i <= len(rawOuterExts); i++ {
if i == len(rawOuterExts) {
recon.SetError(errors.New("tls: invalid outer extensions"))
return
}
if rawOuterExts[i].extType == extType {
break
}
}
recon.AddUint16(rawOuterExts[i].extType)
recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(rawOuterExts[i].data)
})
}
} else {
recon.AddUint16(extension)
recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(extData)
})
}
}
})
})
reconBytes, err := recon.Bytes()
if err != nil {
return nil, err
}
inner := &clientHelloMsg{}
if !inner.unmarshal(reconBytes) {
return nil, errors.New("tls: invalid reconstructed inner client hello")
}
if !bytes.Equal(inner.encryptedClientHello, []byte{uint8(innerECHExt)}) {
return nil, errors.New("tls: client sent invalid encrypted_client_hello extension")
}
if len(inner.supportedVersions) != 1 || (len(inner.supportedVersions) >= 1 && inner.supportedVersions[0] != VersionTLS13) {
return nil, errors.New("tls: client sent encrypted_client_hello extension and offered incompatible versions")
}
return inner, nil
}
func decryptECHPayload(context *hpke.Receipient, hello, payload []byte) ([]byte, error) {
outerAAD := bytes.Replace(hello[4:], payload, make([]byte, len(payload)), 1)
return context.Open(outerAAD, payload)
}
func generateOuterECHExt(id uint8, kdfID, aeadID uint16, encodedKey []byte, payload []byte) ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(0) // outer
@ -206,7 +404,7 @@ func generateOuterECHExt(id uint8, kdfID, aeadID uint16, encodedKey []byte, payl
return b.Bytes()
}
func computeAndUpdateOuterECHExtension(outer, inner *clientHelloMsg, ech *echContext, useKey bool) error {
func computeAndUpdateOuterECHExtension(outer, inner *clientHelloMsg, ech *echClientContext, useKey bool) error {
var encapKey []byte
if useKey {
encapKey = ech.encapsulatedKey
@ -281,3 +479,153 @@ type ECHRejectionError struct {
func (e *ECHRejectionError) Error() string {
return "tls: server rejected ECH"
}
var errMalformedECHExt = errors.New("tls: malformed encrypted_client_hello extension")
type echExtType uint8
const (
innerECHExt echExtType = 1
outerECHExt echExtType = 0
)
func parseECHExt(ext []byte) (echType echExtType, cs echCipher, configID uint8, encap []byte, payload []byte, err error) {
data := make([]byte, len(ext))
copy(data, ext)
s := cryptobyte.String(data)
var echInt uint8
if !s.ReadUint8(&echInt) {
err = errMalformedECHExt
return
}
echType = echExtType(echInt)
if echType == innerECHExt {
if !s.Empty() {
err = errMalformedECHExt
return
}
return echType, cs, 0, nil, nil, nil
}
if echType != outerECHExt {
err = errMalformedECHExt
return
}
if !s.ReadUint16(&cs.KDFID) {
err = errMalformedECHExt
return
}
if !s.ReadUint16(&cs.AEADID) {
err = errMalformedECHExt
return
}
if !s.ReadUint8(&configID) {
err = errMalformedECHExt
return
}
if !readUint16LengthPrefixed(&s, &encap) {
err = errMalformedECHExt
return
}
if !readUint16LengthPrefixed(&s, &payload) {
err = errMalformedECHExt
return
}
// NOTE: clone encap and payload so that mutating them does not mutate the
// raw extension bytes.
return echType, cs, configID, bytes.Clone(encap), bytes.Clone(payload), nil
}
func marshalEncryptedClientHelloConfigList(configs []EncryptedClientHelloKey) ([]byte, error) {
builder := cryptobyte.NewBuilder(nil)
builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
for _, c := range configs {
builder.AddBytes(c.Config)
}
})
return builder.Bytes()
}
func (c *Conn) processECHClientHello(outer *clientHelloMsg) (*clientHelloMsg, *echServerContext, error) {
echType, echCiphersuite, configID, encap, payload, err := parseECHExt(outer.encryptedClientHello)
if err != nil {
c.sendAlert(alertDecodeError)
return nil, nil, errors.New("tls: client sent invalid encrypted_client_hello extension")
}
if echType == innerECHExt {
return outer, &echServerContext{inner: true}, nil
}
if len(c.config.EncryptedClientHelloKeys) == 0 {
return outer, nil, nil
}
for _, echKey := range c.config.EncryptedClientHelloKeys {
skip, config, err := parseECHConfig(echKey.Config)
if err != nil || skip {
c.sendAlert(alertInternalError)
return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys Config: %s", err)
}
if skip {
continue
}
echPriv, err := hpke.ParseHPKEPrivateKey(config.KemID, echKey.PrivateKey)
if err != nil {
c.sendAlert(alertInternalError)
return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys PrivateKey: %s", err)
}
info := append([]byte("tls ech\x00"), echKey.Config...)
hpkeContext, err := hpke.SetupReceipient(hpke.DHKEM_X25519_HKDF_SHA256, echCiphersuite.KDFID, echCiphersuite.AEADID, echPriv, info, encap)
if err != nil {
// attempt next trial decryption
continue
}
encodedInner, err := decryptECHPayload(hpkeContext, outer.original, payload)
if err != nil {
// attempt next trial decryption
continue
}
// NOTE: we do not enforce that the sent server_name matches the ECH
// configs PublicName, since this is not particularly important, and
// the client already had to know what it was in order to properly
// encrypt the payload. This is only a MAY in the spec, so we're not
// doing anything revolutionary.
echInner, err := decodeInnerClientHello(outer, encodedInner)
if err != nil {
c.sendAlert(alertIllegalParameter)
return nil, nil, errors.New("tls: client sent invalid encrypted_client_hello extension")
}
c.echAccepted = true
return echInner, &echServerContext{
hpkeContext: hpkeContext,
configID: configID,
ciphersuite: echCiphersuite,
}, nil
}
return outer, nil, nil
}
func buildRetryConfigList(keys []EncryptedClientHelloKey) ([]byte, error) {
var atLeastOneRetryConfig bool
var retryBuilder cryptobyte.Builder
retryBuilder.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, c := range keys {
if !c.SendAsRetry {
continue
}
atLeastOneRetryConfig = true
b.AddBytes(c.Config)
}
})
if !atLeastOneRetryConfig {
return nil, nil
}
return retryBuilder.Bytes()
}