diff --git a/ech.go b/ech.go index 4f59e07..184c468 100644 --- a/ech.go +++ b/ech.go @@ -207,8 +207,16 @@ func pickECHCipherSuite(suites []echCipher) (echCipher, error) { return echCipher{}, errors.New("tls: no supported symmetric ciphersuites for ECH") } +// [uTLS SECTION BEGIN] func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, error) { - h, err := inner.marshalMsg(true) + return encodeInnerClientHelloReorderOuterExts(inner, maxNameLength, nil) +} + +// [uTLS SECTION END] + +// func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, error) { +func encodeInnerClientHelloReorderOuterExts(inner *clientHelloMsg, maxNameLength int, outerExts []uint16) ([]byte, error) { // uTLS + h, err := inner.marshalMsgReorderOuterExts(true, outerExts) if err != nil { return nil, err } diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 5b1fb18..07a1e21 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -75,9 +75,17 @@ func (hs *clientHandshakeStateTLS13) handshake() error { if hs.echContext != nil { hs.echContext.innerTranscript = hs.suite.hash.New() - if err := transcriptMsg(hs.echContext.innerHello, hs.echContext.innerTranscript); err != nil { - return err + // [uTLS SECTION BEGIN] + if hs.uconn != nil && hs.uconn.clientHelloBuildStatus == BuildByUtls { + if err := hs.uconn.echTranscriptMsg(hs.hello, hs.echContext); err != nil { + return err + } + } else { + if err := transcriptMsg(hs.echContext.innerHello, hs.echContext.innerTranscript); err != nil { + return err + } } + // [uTLS SECTION END] } if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { @@ -426,7 +434,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { hs.uconn.Extensions[cookieIndex:]...)...) } } - if err := hs.uconn.MarshalClientHello(); err != nil { + if err := hs.uconn.MarshalClientHelloNoECH(); err != nil { return err } hs.hello.original = hs.uconn.HandshakeState.Hello.Raw @@ -445,12 +453,25 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { // extension which may have changed is keyShares. hs.hello.keyShares = hello.keyShares hs.echContext.innerHello = hello - if err := transcriptMsg(hs.echContext.innerHello, hs.echContext.innerTranscript); err != nil { - return err - } + if hs.uconn != nil && hs.uconn.clientHelloBuildStatus == BuildByUtls { + if err := hs.uconn.computeAndUpdateOuterECHExtension(hs.echContext.innerHello, hs.echContext, false); err != nil { + return err + } - if err := computeAndUpdateOuterECHExtension(hs.hello, hs.echContext.innerHello, hs.echContext, false); err != nil { - return err + hs.hello.original = hs.uconn.HandshakeState.Hello.Raw + + if err := hs.uconn.echTranscriptMsg(hs.hello, hs.echContext); err != nil { + return err + } + + } else { + if err := transcriptMsg(hs.echContext.innerHello, hs.echContext.innerTranscript); err != nil { + return err + } + + if err := computeAndUpdateOuterECHExtension(hs.hello, hs.echContext.innerHello, hs.echContext, false); err != nil { + return err + } } } else { hs.hello = hello diff --git a/handshake_messages.go b/handshake_messages.go index 34f70d6..3b1d341 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -106,6 +106,12 @@ type clientHelloMsg struct { } func (m *clientHelloMsg) marshalMsg(echInner bool) ([]byte, error) { + // [uTLS SECTION BEGIN] + return m.marshalMsgReorderOuterExts(echInner, nil) +} + +func (m *clientHelloMsg) marshalMsgReorderOuterExts(echInner bool, outerExts []uint16) ([]byte, error) { + // [uTLS SECTION END] var exts cryptobyte.Builder if len(m.serverName) > 0 { // RFC 6066, Section 3 @@ -254,7 +260,7 @@ func (m *clientHelloMsg) marshalMsg(echInner bool) ([]byte, error) { } if len(m.supportedVersions) > 0 { // RFC 8446, Section 4.2.1 - if echInner { + if echInner && outerExts == nil { // uTLS echOuterExts = append(echOuterExts, extensionSupportedVersions) } else { exts.AddUint16(extensionSupportedVersions) @@ -311,6 +317,21 @@ func (m *clientHelloMsg) marshalMsg(echInner bool) ([]byte, error) { }) } } + // [uTLS SECTION BEGIN] + // reorder OuterExtensions according to their order in the spec + if echInner && outerExts != nil { + echOuterExtsReordered := slices.Collect(func(yield func(uint16) bool) { + for _, ext := range outerExts { + if slices.Contains(echOuterExts, ext) { + if !yield(ext) { + return + } + } + } + }) + echOuterExts = echOuterExtsReordered + } + // [uTLS SECTION END] if len(echOuterExts) > 0 && echInner { exts.AddUint16(extensionECHOuterExtensions) exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { diff --git a/handshake_test.go b/handshake_test.go index 8f62403..56126fa 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -476,11 +476,32 @@ func runMain(m *testing.M) int { } func testHandshake(t *testing.T, clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) { + // [uTLS SECTION BEGIN] + return testUtlsHandshake(t, clientConfig, serverConfig, nil) +} +func testUtlsHandshake(t *testing.T, clientConfig, serverConfig *Config, spec *ClientHelloSpec) (serverState, clientState ConnectionState, err error) { + // [uTLS SECTION END] const sentinel = "SENTINEL\n" c, s := localPipe(t) errChan := make(chan error, 1) go func() { - cli := Client(c, clientConfig) + // [uTLS SECTION BEGIN] + var cli interface { + Handshake() error + ConnectionState() ConnectionState + Close() error + io.Reader + } + if spec != nil { + ucli := UClient(c, clientConfig, HelloCustom) + if err = ucli.ApplyPreset(spec); err != nil { + return + } + cli = ucli + } else { + cli = Client(c, clientConfig) + } + // [uTLS SECTION END] err := cli.Handshake() if err != nil { errChan <- fmt.Errorf("client: %v", err) diff --git a/tls_test.go b/tls_test.go index 4e37c36..45da4bb 100644 --- a/tls_test.go +++ b/tls_test.go @@ -2068,6 +2068,10 @@ func TestLargeCertMsg(t *testing.T) { } func TestECH(t *testing.T) { + testECHSpec(t, nil, true) +} + +func testECHSpec(t *testing.T, spec *ClientHelloSpec, expectSuccess bool) { k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) @@ -2157,26 +2161,34 @@ func TestECH(t *testing.T) { {Config: echConfig, PrivateKey: echKey.Bytes(), SendAsRetry: true}, } - ss, cs, err := testHandshake(t, clientConfig, serverConfig) - if err != nil { - t.Fatalf("unexpected failure: %s", err) - } - if !ss.ECHAccepted { - t.Fatal("server ConnectionState shows ECH not accepted") - } - if !cs.ECHAccepted { - t.Fatal("client ConnectionState shows ECH not accepted") - } - if cs.ServerName != "secret.example" || ss.ServerName != "secret.example" { - t.Fatalf("unexpected ConnectionState.ServerName, want %q, got server:%q, client: %q", "secret.example", ss.ServerName, cs.ServerName) - } - if len(cs.VerifiedChains) != 1 { - t.Fatal("unexpect number of certificate chains") - } - if len(cs.VerifiedChains[0]) != 1 { - t.Fatal("unexpect number of certificates") - } - if !cs.VerifiedChains[0][0].Equal(secretCert) { - t.Fatal("unexpected certificate") + // [uTLS SECTION BEGIN] + ss, cs, err := testUtlsHandshake(t, clientConfig, serverConfig, spec) + if expectSuccess { + if err != nil { + t.Fatalf("unexpected failure: %s", err) + } + if !ss.ECHAccepted { + t.Fatal("server ConnectionState shows ECH not accepted") + } + if !cs.ECHAccepted { + t.Fatal("client ConnectionState shows ECH not accepted") + } + if cs.ServerName != "secret.example" || ss.ServerName != "secret.example" { + t.Fatalf("unexpected ConnectionState.ServerName, want %q, got server:%q, client: %q", "secret.example", ss.ServerName, cs.ServerName) + } + if len(cs.VerifiedChains) != 1 { + t.Fatal("unexpect number of certificate chains") + } + if len(cs.VerifiedChains[0]) != 1 { + t.Fatal("unexpect number of certificates") + } + if !cs.VerifiedChains[0][0].Equal(secretCert) { + t.Fatal("unexpected certificate") + } + } else { + if err == nil { + t.Fatalf("unexpected handshake success, expected failure") + } } + // [uTLS SECTION END] } diff --git a/u_conn.go b/u_conn.go index 6871a73..cc1f12d 100644 --- a/u_conn.go +++ b/u_conn.go @@ -14,7 +14,10 @@ import ( "fmt" "hash" "net" + "slices" "strconv" + + "golang.org/x/crypto/cryptobyte" ) type ClientHelloBuildStatus int @@ -487,15 +490,119 @@ func (uconn *UConn) ApplyConfig() error { return nil } -func (uconn *UConn) MarshalClientHello() error { - if len(uconn.config.ECHConfigs) > 0 && uconn.ech != nil { - if err := uconn.ech.Configure(uconn.config.ECHConfigs); err != nil { - return err - } - return uconn.ech.MarshalClientHello(uconn) +func (uconn *UConn) extensionsList() []uint16 { + + outerExts := []uint16{} + for _, ext := range uconn.Extensions { + buffer := cryptobyte.String(make([]byte, 2000)) + ext.Read(buffer) + var extension uint16 + buffer.ReadUint16(&extension) + outerExts = append(outerExts, extension) + } + return outerExts +} + +func (uconn *UConn) computeAndUpdateOuterECHExtension(inner *clientHelloMsg, ech *echClientContext, useKey bool) error { + // This function is mostly copied from + // https://github.com/refraction-networking/utls/blob/e430876b1d82fdf582efc57f3992d448e7ab3d8a/ech.go#L408 + var encapKey []byte + if useKey { + encapKey = ech.encapsulatedKey } - return uconn.MarshalClientHelloNoECH() // if no ECH pointer, just marshal normally + encodedInner, err := encodeInnerClientHelloReorderOuterExts(inner, int(ech.config.MaxNameLength), uconn.extensionsList()) + if err != nil { + return err + } + + encryptedLen := len(encodedInner) + 16 + outerECHExt, err := generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, make([]byte, encryptedLen)) + if err != nil { + return err + } + + echExtIdx := slices.IndexFunc(uconn.Extensions, func(ext TLSExtension) bool { + _, ok := ext.(EncryptedClientHelloExtension) + return ok + }) + if echExtIdx < 0 { + return fmt.Errorf("extension satisfying EncryptedClientHelloExtension not present") + } + oldExt := uconn.Extensions[echExtIdx] + + uconn.Extensions[echExtIdx] = &GenericExtension{ + Id: extensionEncryptedClientHello, + Data: outerECHExt, + } + + if err := uconn.MarshalClientHelloNoECH(); err != nil { + return err + } + + serializedOuter := uconn.HandshakeState.Hello.Raw + serializedOuter = serializedOuter[4:] + encryptedInner, err := ech.hpkeContext.Seal(serializedOuter, encodedInner) + if err != nil { + return err + } + outerECHExt, err = generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, encryptedInner) + if err != nil { + return err + } + uconn.Extensions[echExtIdx] = &GenericExtension{ + Id: extensionEncryptedClientHello, + Data: outerECHExt, + } + + if err := uconn.MarshalClientHelloNoECH(); err != nil { + return err + } + + uconn.Extensions[echExtIdx] = oldExt + return nil + +} + +func (uconn *UConn) MarshalClientHello() error { + if len(uconn.config.EncryptedClientHelloConfigList) > 0 { + inner, _, ech, err := uconn.makeClientHello() + if err != nil { + return err + } + + // copy compressed extensions to the ClientHelloInner + inner.keyShares = KeyShares(uconn.HandshakeState.Hello.KeyShares).ToPrivate() + inner.supportedSignatureAlgorithms = uconn.HandshakeState.Hello.SupportedSignatureAlgorithms + inner.sessionId = uconn.HandshakeState.Hello.SessionId + inner.supportedCurves = uconn.HandshakeState.Hello.SupportedCurves + + ech.innerHello = inner + + sniExtIdex := slices.IndexFunc(uconn.Extensions, func(ext TLSExtension) bool { + _, ok := ext.(*SNIExtension) + return ok + }) + if sniExtIdex < 0 { + return fmt.Errorf("sni extension missing while attempting ECH") + } + + uconn.Extensions[sniExtIdex] = &SNIExtension{ + ServerName: string(ech.config.PublicName), + } + + uconn.computeAndUpdateOuterECHExtension(inner, ech, true) + + uconn.echCtx = ech + return nil + } + + if err := uconn.MarshalClientHelloNoECH(); err != nil { + return err + } + + return nil + } // MarshalClientHelloNoECH marshals ClientHello as if there was no @@ -651,8 +758,10 @@ func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []T } uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers) - uconn.config.MinVersion = minTLSVers - uconn.config.MaxVersion = maxTLSVers + if uconn.config.EncryptedClientHelloConfigList == nil { + uconn.config.MinVersion = minTLSVers + uconn.config.MaxVersion = maxTLSVers + } return nil } diff --git a/u_conn_test.go b/u_conn_test.go index 4931b0d..be162c7 100644 --- a/u_conn_test.go +++ b/u_conn_test.go @@ -743,3 +743,111 @@ func TestUTLSMakeConnWithCompleteHandshake(t *testing.T) { serverTls.Write(serverMsg) } + +func TestUTLSECH(t *testing.T) { + chromeLatest, err := utlsIdToSpec(HelloChrome_Auto) + if err != nil { + t.Fatal(err) + } + + firefoxLatest, err := utlsIdToSpec(HelloFirefox_Auto) + if err != nil { + t.Fatal(err) + } + + for _, test := range []struct { + name string + spec *ClientHelloSpec + expectSuccess bool + }{ + { + name: "latest chrome", + spec: &chromeLatest, + expectSuccess: true, + }, + { + name: "latest firefox", + spec: &firefoxLatest, + expectSuccess: true, + }, + { + name: "ech extension missing", + spec: &ClientHelloSpec{ + CipherSuites: []uint16{ + GREASE_PLACEHOLDER, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + TLS_RSA_WITH_AES_128_CBC_SHA, + TLS_RSA_WITH_AES_256_CBC_SHA, + }, + CompressionMethods: []byte{ + 0x00, // compressionNone + }, + Extensions: ShuffleChromeTLSExtensions([]TLSExtension{ + &UtlsGREASEExtension{}, + &SNIExtension{}, + &ExtendedMasterSecretExtension{}, + &RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient}, + &SupportedCurvesExtension{[]CurveID{ + GREASE_PLACEHOLDER, + X25519Kyber768Draft00, + X25519, + CurveP256, + CurveP384, + }}, + &SupportedPointsExtension{SupportedPoints: []byte{ + 0x00, // pointFormatUncompressed + }}, + &SessionTicketExtension{}, + &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, + &StatusRequestExtension{}, + &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ + ECDSAWithP256AndSHA256, + PSSWithSHA256, + PKCS1WithSHA256, + ECDSAWithP384AndSHA384, + PSSWithSHA384, + PKCS1WithSHA384, + PSSWithSHA512, + PKCS1WithSHA512, + }}, + &SCTExtension{}, + &KeyShareExtension{[]KeyShare{ + {Group: CurveID(GREASE_PLACEHOLDER), Data: []byte{0}}, + {Group: X25519Kyber768Draft00}, + {Group: X25519}, + }}, + &PSKKeyExchangeModesExtension{[]uint8{ + PskModeDHE, + }}, + &SupportedVersionsExtension{[]uint16{ + GREASE_PLACEHOLDER, + VersionTLS13, + VersionTLS12, + }}, + &UtlsCompressCertExtension{[]CertCompressionAlgo{ + CertCompressionBrotli, + }}, + &ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}}, + &UtlsGREASEExtension{}, + }), + }, + expectSuccess: false, + }, + } { + t.Run(test.name, func(t *testing.T) { + testECHSpec(t, test.spec, test.expectSuccess) + }) + } +} diff --git a/u_handshake_client.go b/u_handshake_client.go index 8674caa..2cb6760 100644 --- a/u_handshake_client.go +++ b/u_handshake_client.go @@ -483,7 +483,7 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) { }() } - if ech != nil { + if ech != nil && c.clientHelloBuildStatus != BuildByUtls { // Split hello into inner and outer ech.innerHello = hello.clone() @@ -578,3 +578,23 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) { } return nil } + +func (c *UConn) echTranscriptMsg(outer *clientHelloMsg, echCtx *echClientContext) (err error) { + // Recreate the inner ClientHello from its compressed form using server's decodeInnerClientHello function. + // See https://github.com/refraction-networking/utls/blob/e430876b1d82fdf582efc57f3992d448e7ab3d8a/ech.go#L276-L283 + encodedInner, err := encodeInnerClientHelloReorderOuterExts(echCtx.innerHello, int(echCtx.config.MaxNameLength), c.extensionsList()) + if err != nil { + return err + } + + decodedInner, err := decodeInnerClientHello(outer, encodedInner) + if err != nil { + return err + } + + if err := transcriptMsg(decodedInner, echCtx.innerTranscript); err != nil { + return err + } + + return nil +}