feat: add support for ECH when using custom clienthello specs

Use the standard library's client hello marshaller to generate the ClientHelloInner and replace any extension that implements the EncryptedClientHelloExtension interface in the extensions list with it.
This commit is contained in:
Mingye Chen 2025-03-08 16:42:47 -07:00
parent 227c44e79f
commit 9b2ccad920
5 changed files with 150 additions and 18 deletions

10
ech.go
View file

@ -207,8 +207,16 @@ func pickECHCipherSuite(suites []echCipher) (echCipher, error) {
return echCipher{}, errors.New("tls: no supported symmetric ciphersuites for ECH")
}
// [uTLS SECTION BEGIN]
func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, error) {
h, err := inner.marshalMsg(true)
return encodeInnerClientHelloReorderOuterExts(inner, maxNameLength, nil)
}
// [uTLS SECTION END]
// func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, error) {
func encodeInnerClientHelloReorderOuterExts(inner *clientHelloMsg, maxNameLength int, outerExts []uint16) ([]byte, error) { // uTLS
h, err := inner.marshalMsgReorderOuterExts(true, outerExts)
if err != nil {
return nil, err
}

View file

@ -75,9 +75,21 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
if hs.echContext != nil {
hs.echContext.innerTranscript = hs.suite.hash.New()
if err := transcriptMsg(hs.echContext.innerHello, hs.echContext.innerTranscript); err != nil {
// [uTLS SECTION BEGIN]
encodedInner, err := encodeInnerClientHelloReorderOuterExts(hs.echContext.innerHello, int(hs.echContext.config.MaxNameLength), hs.uconn.extensionsList())
if err != nil {
return err
}
decodedInner, err := decodeInnerClientHello(hs.hello, encodedInner)
if err != nil {
return err
}
if err := transcriptMsg(decodedInner, hs.echContext.innerTranscript); err != nil {
return err
}
// [uTLS SECTION END]
}
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {

View file

@ -105,7 +105,15 @@ type clientHelloMsg struct {
nextProtoNeg bool
}
// [uTLS SECTION BEGIN]
func (m *clientHelloMsg) marshalMsg(echInner bool) ([]byte, error) {
return m.marshalMsgReorderOuterExts(echInner, nil)
}
// [uTLS SECTION END]
// func (m *clientHelloMsg) marshalMsg(echInner bool) ([]byte, error) {
func (m *clientHelloMsg) marshalMsgReorderOuterExts(echInner bool, outerExts []uint16) ([]byte, error) { // uTLS
var exts cryptobyte.Builder
if len(m.serverName) > 0 {
// RFC 6066, Section 3
@ -254,18 +262,14 @@ func (m *clientHelloMsg) marshalMsg(echInner bool) ([]byte, error) {
}
if len(m.supportedVersions) > 0 {
// RFC 8446, Section 4.2.1
if echInner {
echOuterExts = append(echOuterExts, extensionSupportedVersions)
} else {
exts.AddUint16(extensionSupportedVersions)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, vers := range m.supportedVersions {
exts.AddUint16(vers)
}
})
exts.AddUint16(extensionSupportedVersions)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, vers := range m.supportedVersions {
exts.AddUint16(vers)
}
})
}
})
}
if len(m.cookie) > 0 {
// RFC 8446, Section 4.2.2
@ -311,6 +315,20 @@ func (m *clientHelloMsg) marshalMsg(echInner bool) ([]byte, error) {
})
}
}
// [uTLS SECTION BEGIN]
if echInner && outerExts != nil {
echOuterExtsReordered := slices.Collect(func(yield func(uint16) bool) {
for _, ext := range outerExts {
if slices.Contains(echOuterExts, ext) {
if !yield(ext) {
return
}
}
}
})
echOuterExts = echOuterExtsReordered
}
// [uTLS SECTION END]
if len(echOuterExts) > 0 && echInner {
exts.AddUint16(extensionECHOuterExtensions)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {

102
u_conn.go
View file

@ -14,7 +14,10 @@ import (
"fmt"
"hash"
"net"
"slices"
"strconv"
"golang.org/x/crypto/cryptobyte"
)
type ClientHelloBuildStatus int
@ -487,15 +490,106 @@ func (uconn *UConn) ApplyConfig() error {
return nil
}
func (uconn *UConn) extensionsList() []uint16 {
outerExts := []uint16{}
for _, ext := range uconn.Extensions {
buffer := cryptobyte.String(make([]byte, 2000))
ext.Read(buffer)
var extension uint16
buffer.ReadUint16(&extension)
outerExts = append(outerExts, extension)
}
return outerExts
}
func (uconn *UConn) MarshalClientHello() error {
if len(uconn.config.ECHConfigs) > 0 && uconn.ech != nil {
if err := uconn.ech.Configure(uconn.config.ECHConfigs); err != nil {
if len(uconn.config.EncryptedClientHelloConfigList) > 0 {
inner, _, ech, err := uconn.makeClientHello()
if err != nil {
return err
}
return uconn.ech.MarshalClientHello(uconn)
inner.keyShares = KeyShares(uconn.HandshakeState.Hello.KeyShares).ToPrivate()
inner.supportedSignatureAlgorithms = uconn.HandshakeState.Hello.SupportedSignatureAlgorithms
inner.sessionId = uconn.HandshakeState.Hello.SessionId
inner.supportedCurves = uconn.HandshakeState.Hello.SupportedCurves
inner.supportedVersions = []uint16{VersionTLS13}
ech.innerHello = inner
encapKey := ech.encapsulatedKey
encodedInner, err := encodeInnerClientHelloReorderOuterExts(inner, int(ech.config.MaxNameLength), uconn.extensionsList())
if err != nil {
return err
}
// NOTE: the tag lengths for all of the supported AEADs are the same (16
// bytes), so we have hardcoded it here. If we add support for another AEAD
// with a different tag length, we will need to change this.
encryptedLen := len(encodedInner) + 16 // AEAD tag length
outerECHExt, err := generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, make([]byte, encryptedLen))
if err != nil {
return err
}
sniExtIdex := slices.IndexFunc(uconn.Extensions, func(ext TLSExtension) bool {
_, ok := ext.(*SNIExtension)
return ok
})
uconn.Extensions[sniExtIdex] = &SNIExtension{
ServerName: string(ech.config.PublicName),
}
echExtIdx := slices.IndexFunc(uconn.Extensions, func(ext TLSExtension) bool {
_, ok := ext.(EncryptedClientHelloExtension)
return ok
})
uconn.Extensions[echExtIdx] = &GenericExtension{
Id: extensionEncryptedClientHello,
Data: outerECHExt,
}
// uconn.HandshakeState.Hello.Random = make([]byte, 32)
// _, err = io.ReadFull(uconn.config.rand(), uconn.HandshakeState.Hello.Random)
// if err != nil {
// return errors.New("tls: short read from Rand: " + err.Error())
// }
if err := uconn.MarshalClientHelloNoECH(); err != nil {
return err
}
serializedOuter := uconn.HandshakeState.Hello.Raw
serializedOuter = serializedOuter[4:] // strip the four byte prefix
encryptedInner, err := ech.hpkeContext.Seal(serializedOuter, encodedInner)
if err != nil {
return err
}
outerECHExt, err = generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, encryptedInner)
if err != nil {
return err
}
uconn.Extensions[echExtIdx] = &GenericExtension{
Id: extensionEncryptedClientHello,
Data: outerECHExt,
}
if err := uconn.MarshalClientHelloNoECH(); err != nil {
return err
}
uconn.echCtx = ech
return nil
}
return uconn.MarshalClientHelloNoECH() // if no ECH pointer, just marshal normally
if err := uconn.MarshalClientHelloNoECH(); err != nil {
return err
}
return nil
}
// MarshalClientHelloNoECH marshals ClientHello as if there was no

View file

@ -483,7 +483,7 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
}()
}
if ech != nil {
if ech != nil && c.clientHelloBuildStatus == BuildByGoTLS {
// Split hello into inner and outer
ech.innerHello = hello.clone()