utls/bogo_shim_test.go
Roland Shoemaker ce1cbd081a crypto/tls: add ech client support
This CL adds a (very opinionated) client-side ECH implementation.

In particular, if a user configures an ECHConfigList by setting
Config.EncryptedClientHelloConfigList, but we determine that none of
the configs are appropriate, we will not fall back to plaintext SNI and
will instead return an error. It is then up to the user to decide
whether they wish to fall back to plaintext themselves (by removing the
config list).

Additionally, if Config.EncryptedClientHelloConfigList is provided, we
will not offer TLS versions lower than 1.3, since negotiating any other
version while offering ECH is a hard error anyway. Similarly, if a
user wishes to fall back to plaintext SNI by using TLS 1.2, they may do
so by removing the config list.

With regard to PSK GREASE, we match the BoringSSL behavior, which does
not include PSK identities/binders in the outer hello when doing ECH.

If the server rejects ECH, we will return an ECHRejectionError. If the
server provided retry configs, its RetryConfigList field will contain
an ECHConfigList that should be used if the user wishes to retry. It is
up to the user to replace their existing
Config.EncryptedClientHelloConfigList with the retry config list.

Fixes #63369

Cq-Include-Trybots: luci.golang.try:gotip-linux-amd64-longtest
Change-Id: I9bc373c044064221a647a388ac61624efd6bbdbf
Reviewed-on: https://go-review.googlesource.com/c/go/+/578575
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Than McIntosh <thanm@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
Auto-Submit: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
2024-05-23 03:10:12 +00:00

365 lines
10 KiB
Go

package tls
import (
"bytes"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"flag"
"fmt"
"internal/byteorder"
"internal/testenv"
"io"
"log"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
)
// Command-line flags consumed by bogoShim. Every flag is registered on the
// default flag set at package initialization. Flags bound to the blank
// identifier are accepted so that the runner's command line parses, but
// their values are never read.

// Shim control: where to connect, which role to play, and capability probes.
var (
	port                  = flag.String("port", "", "")
	server                = flag.Bool("server", false, "")
	isHandshakerSupported = flag.Bool("is-handshaker-supported", false, "")
)

// Credentials: the shim's own key pair and the certificate it trusts.
var (
	keyfile   = flag.String("key-file", "", "")
	certfile  = flag.String("cert-file", "", "")
	trustCert = flag.String("trust-cert", "", "")
)

// Handshake parameters and per-connection behavior.
var (
	minVersion                  = flag.Int("min-version", VersionSSL30, "")
	maxVersion                  = flag.Int("max-version", VersionTLS13, "")
	noTLS13                     = flag.Bool("no-tls13", false, "")
	requireAnyClientCertificate = flag.Bool("require-any-client-certificate", false, "")
	shimWritesFirst             = flag.Bool("shim-writes-first", false, "")
	resumeCount                 = flag.Int("resume-count", 0, "")
	curves                      = flagStringSlice("curves", "")
	expectedCurve               = flag.String("expect-curve-id", "", "")
)

// Runner plumbing: the shim ID is echoed back to the runner on connect.
var (
	shimID = flag.Uint64("shim-id", 0, "")
	_      = flag.Bool("ipv6", false, "")
)

// ECH, HelloRetryRequest and resumption expectations.
var (
	echConfigListB64           = flag.String("ech-config-list", "", "")
	expectECHAccepted          = flag.Bool("expect-ech-accept", false, "")
	expectHRR                  = flag.Bool("expect-hrr", false, "")
	expectedECHRetryConfigs    = flag.String("expect-ech-retry-configs", "", "")
	expectNoECHRetryConfigs    = flag.Bool("expect-no-ech-retry-configs", false, "")
	onInitialExpectECHAccepted = flag.Bool("on-initial-expect-ech-accept", false, "")
	_                          = flag.Bool("expect-no-ech-name-override", false, "")
	_                          = flag.String("expect-ech-name-override", "", "")
	_                          = flag.Bool("reverify-on-resume", false, "")
	onResumeECHConfigListB64   = flag.String("on-resume-ech-config-list", "", "")
	_                          = flag.Bool("on-resume-expect-reject-early-data", false, "")
	onResumeExpectECHAccepted  = flag.Bool("on-resume-expect-ech-accept", false, "")
	_                          = flag.Bool("on-resume-expect-no-ech-name-override", false, "")
)

// Server name, session, early-data, ALPN and peer-verification options.
var (
	expectedServerName      = flag.String("expect-server-name", "", "")
	expectSessionMiss       = flag.Bool("expect-session-miss", false, "")
	_                       = flag.Bool("enable-early-data", false, "")
	_                       = flag.Bool("on-resume-expect-accept-early-data", false, "")
	_                       = flag.Bool("expect-ticket-supports-early-data", false, "")
	onResumeShimWritesFirst = flag.Bool("on-resume-shim-writes-first", false, "")
	advertiseALPN           = flag.String("advertise-alpn", "", "")
	expectALPN              = flag.String("expect-alpn", "", "")
	hostName                = flag.String("host-name", "", "")
	verifyPeer              = flag.Bool("verify-peer", false, "")
	_                       = flag.Bool("use-custom-verify-callback", false, "")
)
// stringSlice is a flag.Value for a repeatable flag: each occurrence on the
// command line appends one value, in the order given.
type stringSlice []string

// flagStringSlice registers a repeatable flag called name on the default
// flag set and returns a pointer to the slice that collects its values.
func flagStringSlice(name, usage string) *stringSlice {
	ss := &stringSlice{}
	flag.Var(ss, name, usage)
	return ss
}

// String returns the collected values joined with commas.
func (ss stringSlice) String() string {
	return strings.Join(ss, ",")
}

// Set appends s to the collected values.
//
// The receiver must be a pointer: with a value receiver the append only
// updates a local copy, and every value given on the command line would be
// silently discarded.
func (ss *stringSlice) Set(s string) error {
	*ss = append(*ss, s)
	return nil
}
func bogoShim() {
if *isHandshakerSupported {
fmt.Println("No")
return
}
cfg := &Config{
ServerName: "test",
MinVersion: uint16(*minVersion),
MaxVersion: uint16(*maxVersion),
ClientSessionCache: NewLRUClientSessionCache(0),
}
if *noTLS13 && cfg.MaxVersion == VersionTLS13 {
cfg.MaxVersion = VersionTLS12
}
if *advertiseALPN != "" {
alpns := *advertiseALPN
for len(alpns) > 0 {
alpnLen := int(alpns[0])
cfg.NextProtos = append(cfg.NextProtos, alpns[1:1+alpnLen])
alpns = alpns[alpnLen+1:]
}
}
if *hostName != "" {
cfg.ServerName = *hostName
}
if *keyfile != "" || *certfile != "" {
pair, err := LoadX509KeyPair(*certfile, *keyfile)
if err != nil {
log.Fatalf("load key-file err: %s", err)
}
cfg.Certificates = []Certificate{pair}
}
if *trustCert != "" {
pool := x509.NewCertPool()
certFile, err := os.ReadFile(*trustCert)
if err != nil {
log.Fatalf("load trust-cert err: %s", err)
}
block, _ := pem.Decode(certFile)
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Fatalf("parse trust-cert err: %s", err)
}
pool.AddCert(cert)
cfg.RootCAs = pool
}
if *requireAnyClientCertificate {
cfg.ClientAuth = RequireAnyClientCert
}
if *verifyPeer {
cfg.ClientAuth = VerifyClientCertIfGiven
}
if *echConfigListB64 != "" {
echConfigList, err := base64.StdEncoding.DecodeString(*echConfigListB64)
if err != nil {
log.Fatalf("parse ech-config-list err: %s", err)
}
cfg.EncryptedClientHelloConfigList = echConfigList
cfg.MinVersion = VersionTLS13
}
if len(*curves) != 0 {
for _, curveStr := range *curves {
id, err := strconv.Atoi(curveStr)
if err != nil {
log.Fatalf("failed to parse curve id %q: %s", curveStr, err)
}
cfg.CurvePreferences = append(cfg.CurvePreferences, CurveID(id))
}
}
for i := 0; i < *resumeCount+1; i++ {
if i > 0 && (*onResumeECHConfigListB64 != "") {
echConfigList, err := base64.StdEncoding.DecodeString(*onResumeECHConfigListB64)
if err != nil {
log.Fatalf("parse ech-config-list err: %s", err)
}
cfg.EncryptedClientHelloConfigList = echConfigList
}
conn, err := net.Dial("tcp", net.JoinHostPort("localhost", *port))
if err != nil {
log.Fatalf("dial err: %s", err)
}
defer conn.Close()
// Write the shim ID we were passed as a little endian uint64
shimIDBytes := make([]byte, 8)
byteorder.LePutUint64(shimIDBytes, *shimID)
if _, err := conn.Write(shimIDBytes); err != nil {
log.Fatalf("failed to write shim id: %s", err)
}
var tlsConn *Conn
if *server {
tlsConn = Server(conn, cfg)
} else {
tlsConn = Client(conn, cfg)
}
if i == 0 && *shimWritesFirst {
if _, err := tlsConn.Write([]byte("hello")); err != nil {
log.Fatalf("write err: %s", err)
}
}
for {
buf := make([]byte, 500)
var n int
n, err = tlsConn.Read(buf)
if err != nil {
break
}
buf = buf[:n]
for i := range buf {
buf[i] ^= 0xff
}
if _, err = tlsConn.Write(buf); err != nil {
break
}
}
if err != nil && err != io.EOF {
retryErr, ok := err.(*ECHRejectionError)
if !ok {
log.Fatalf("unexpected error type returned: %v", err)
}
if *expectNoECHRetryConfigs && len(retryErr.RetryConfigList) > 0 {
log.Fatalf("expected no ECH retry configs, got some")
}
if *expectedECHRetryConfigs != "" {
expectedRetryConfigs, err := base64.StdEncoding.DecodeString(*expectedECHRetryConfigs)
if err != nil {
log.Fatalf("failed to decode expected retry configs: %s", err)
}
if !bytes.Equal(retryErr.RetryConfigList, expectedRetryConfigs) {
log.Fatalf("unexpected retry list returned: got %x, want %x", retryErr.RetryConfigList, expectedRetryConfigs)
}
}
log.Fatalf("conn error: %s", err)
}
cs := tlsConn.ConnectionState()
if cs.HandshakeComplete {
if *expectALPN != "" && cs.NegotiatedProtocol != *expectALPN {
log.Fatalf("unexpected protocol negotiated: want %q, got %q", *expectALPN, cs.NegotiatedProtocol)
}
if *expectECHAccepted && !cs.ECHAccepted {
log.Fatal("expected ECH to be accepted, but connection state shows it was not")
} else if i == 0 && *onInitialExpectECHAccepted && !cs.ECHAccepted {
log.Fatal("expected ECH to be accepted, but connection state shows it was not")
} else if i > 0 && *onResumeExpectECHAccepted && !cs.ECHAccepted {
log.Fatal("expected ECH to be accepted on resumption, but connection state shows it was not")
} else if i == 0 && !*expectECHAccepted && cs.ECHAccepted {
log.Fatal("did not expect ECH, but it was accepted")
}
if *expectHRR && !cs.testingOnlyDidHRR {
log.Fatal("expected HRR but did not do it")
}
if *expectSessionMiss && cs.DidResume {
log.Fatal("unexpected session resumption")
}
if *expectedServerName != "" && cs.ServerName != *expectedServerName {
log.Fatalf("unexpected server name: got %q, want %q", cs.ServerName, *expectedServerName)
}
}
if *expectedCurve != "" {
expectedCurveID, err := strconv.Atoi(*expectedCurve)
if err != nil {
log.Fatalf("failed to parse -expect-curve-id: %s", err)
}
if tlsConn.curveID != CurveID(expectedCurveID) {
log.Fatalf("unexpected curve id: want %d, got %d", expectedCurveID, tlsConn.curveID)
}
}
}
}
// TestBogoSuite runs BoringSSL's BoGo TLS test runner against this package.
// The test binary itself acts as the shim: it is passed as -shim-path and
// started with -bogo-mode.
//
// The BoringSSL sources come from *bogoLocalDir when set. Otherwise the pinned
// module version below is fetched with "go mod download". When *bogoFilter is
// set, it restricts which BoGo tests run.
func TestBogoSuite(t *testing.T) {
	testenv.SkipIfShortAndSlow(t)
	testenv.MustHaveExternalNetwork(t)
	testenv.MustHaveGoRun(t)
	testenv.MustHaveExec(t)

	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	if testenv.Builder() != "" && runtime.GOOS == "windows" {
		t.Skip("#66913: windows network connections are flakey on builders")
	}

	// Use the same go command both to fetch BoringSSL and to run the runner,
	// rather than whatever "go" happens to be first in PATH.
	goCmd, err := testenv.GoTool()
	if err != nil {
		t.Fatal(err)
	}

	var bogoDir string
	if *bogoLocalDir != "" {
		bogoDir = *bogoLocalDir
	} else {
		const boringsslModVer = "v0.0.0-20240517213134-ba62c812f01f"
		// Only stdout carries the JSON; capturing stderr as well could
		// corrupt it with diagnostics.
		output, err := exec.Command(goCmd, "mod", "download", "-json", "github.com/google/boringssl@"+boringsslModVer).Output()
		if err != nil {
			var stderr []byte
			if ee, ok := err.(*exec.ExitError); ok {
				stderr = ee.Stderr
			}
			t.Fatalf("failed to download boringssl: %s\nstdout:\n%s\nstderr:\n%s", err, output, stderr)
		}
		var j struct {
			Dir string
		}
		if err := json.Unmarshal(output, &j); err != nil {
			t.Fatalf("failed to parse 'go mod download' output: %s", err)
		}
		if j.Dir == "" {
			t.Fatalf("'go mod download' did not report a module directory:\n%s", output)
		}
		bogoDir = j.Dir
	}

	cwd, err := os.Getwd()
	if err != nil {
		t.Fatal(err)
	}

	args := []string{
		"test",
		".",
		fmt.Sprintf("-shim-config=%s", filepath.Join(cwd, "bogo_config.json")),
		fmt.Sprintf("-shim-path=%s", os.Args[0]),
		"-shim-extra-flags=-bogo-mode",
		"-allow-unimplemented",
		"-loose-errors", // TODO(roland): this should be removed eventually
		"-pipe",
		"-v",
	}
	if *bogoFilter != "" {
		args = append(args, fmt.Sprintf("-test=%s", *bogoFilter))
	}

	cmd := exec.Command(goCmd, args...)
	out := &strings.Builder{}
	cmd.Stdout, cmd.Stderr = io.MultiWriter(os.Stdout, out), os.Stderr
	cmd.Dir = filepath.Join(bogoDir, "ssl/test/runner")
	if err := cmd.Run(); err != nil {
		t.Fatalf("bogo failed: %s", err)
	}

	// On a full run, require that specific tests actually passed. The runner
	// is invoked with -allow-unimplemented, so presumably a test that
	// silently did not run would not otherwise fail the suite.
	if *bogoFilter == "" {
		results := out.String()
		for _, name := range []string{
			"CurveTest-Client-Kyber-TLS13",
			"CurveTest-Server-Kyber-TLS13",
		} {
			if !strings.Contains(results, "PASSED ("+name+")\n") {
				t.Errorf("Expected test %s did not run", name)
			}
		}
	}
}