utls/handshake_messages_test.go
Roland Shoemaker ce1cbd081a crypto/tls: add ech client support
This CL adds a (very opinionated) client-side ECH implementation.

In particular, if a user configures an ECHConfigList, by setting the
Config.EncryptedClientHelloConfigList, but we determine that none of
the configs are appropriate, we will not fall back to plaintext SNI, and
will instead return an error. It is then up to the user to decide if
they wish to fall back to plaintext themselves (by removing the config
list).

Additionally, if Config.EncryptedClientHelloConfigList is provided, we
will not offer TLS versions lower than 1.3, since negotiating any other
version while offering ECH is a hard error anyway. Similarly, if a
user wishes to fall back to plaintext SNI by using TLS 1.2, they may do so
by removing the config list.

With regard to PSK GREASE, we match the BoringSSL behavior, which does
not include PSK identities/binders in the outer hello when doing ECH.

If the server rejects ECH, we will return an ECHRejectionError. If the
server provided retry configs, the error's RetryConfigList field will
contain an ECHConfigList that should be used if the user wishes to retry.
It is up to the user to replace their existing
Config.EncryptedClientHelloConfigList with the retry config list.

Fixes #63369

Cq-Include-Trybots: luci.golang.try:gotip-linux-amd64-longtest
Change-Id: I9bc373c044064221a647a388ac61624efd6bbdbf
Reviewed-on: https://go-review.googlesource.com/c/go/+/578575
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Than McIntosh <thanm@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
Auto-Submit: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
2024-05-23 03:10:12 +00:00

576 lines
16 KiB
Go

// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto/x509"
"encoding/hex"
"math"
"math/rand"
"reflect"
"strings"
"testing"
"testing/quick"
"time"
)
// tests lists one instance of every handshakeMessage type exercised by
// TestMarshalUnmarshal and TestFuzz. TestMarshalUnmarshal uses each entry's
// dynamic type to generate random values (through the type's Generate
// method) and uses the entry itself as the unmarshal destination.
//
// Order matters: TestMarshalUnmarshal skips its strict-prefix check for the
// first three entries (ClientHello, ServerHello and Finished).
var tests = []handshakeMessage{
	&clientHelloMsg{},
	&serverHelloMsg{},
	&finishedMsg{},
	&certificateMsg{},
	&certificateRequestMsg{},
	&certificateVerifyMsg{
		// Generated certificateVerifyMsg values always set this flag, and
		// the destination is compared to them with reflect.DeepEqual after
		// unmarshal, so it must be preset here. Presumably it also tells
		// unmarshal to expect a signature algorithm field; verify.
		hasSignatureAlgorithm: true,
	},
	&certificateStatusMsg{},
	&clientKeyExchangeMsg{},
	&newSessionTicketMsg{},
	&encryptedExtensionsMsg{},
	&endOfEarlyDataMsg{},
	&keyUpdateMsg{},
	&newSessionTicketMsgTLS13{},
	&certificateRequestMsgTLS13{},
	&certificateMsgTLS13{},
	&SessionState{},
}
// mustMarshal serializes msg and aborts the calling test on failure, so
// callers can use the encoding directly.
func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
	t.Helper()
	data, err := msg.marshal()
	if err == nil {
		return data
	}
	t.Fatal(err) // does not return
	return nil
}
// TestMarshalUnmarshal checks, for every message type in tests, that randomly
// generated messages survive a marshal/unmarshal round trip unchanged. For all
// but the first three types it also checks that unmarshal rejects every strict
// prefix of the encoding.
func TestMarshalUnmarshal(t *testing.T) {
	// Use a fresh seed on each run for broader coverage, but log it so a
	// failing run can be reproduced by hard-coding the seed.
	seed := time.Now().UnixNano()
	t.Logf("random seed: %d", seed)
	rnd := rand.New(rand.NewSource(seed))
	for i, m := range tests {
		ty := reflect.ValueOf(m).Type()
		t.Run(ty.String(), func(t *testing.T) {
			n := 100
			if testing.Short() {
				n = 5
			}
			for j := 0; j < n; j++ {
				v, ok := quick.Value(ty, rnd)
				if !ok {
					t.Errorf("#%d: failed to create value", i)
					break
				}
				m1 := v.Interface().(handshakeMessage)
				marshaled := mustMarshal(t, m1)
				if !m.unmarshal(marshaled) {
					t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
					break
				}
				// The generator never sets activeCertHandles (presumably it is
				// filled in while parsing), so clear it before comparing.
				if ss, ok := m.(*SessionState); ok {
					ss.activeCertHandles = nil
				}
				// clientHelloMsg and serverHelloMsg, when unmarshalled, store
				// their original representation, for later use in the handshake
				// transcript. In order to prevent DeepEqual from failing since
				// we didn't create the original message via unmarshalling, nil
				// the field.
				switch msg := m.(type) {
				case *clientHelloMsg:
					msg.original = nil
				case *serverHelloMsg:
					msg.original = nil
				}
				if !reflect.DeepEqual(m1, m) {
					t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
					break
				}
				if i >= 3 {
					// The first three message types (ClientHello,
					// ServerHello and Finished) are allowed to
					// have parsable prefixes because the extension
					// data is optional and the length of the
					// Finished varies across versions.
					for k := 0; k < len(marshaled); k++ {
						if m.unmarshal(marshaled[:k]) {
							t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, k, m1)
							break
						}
					}
				}
			}
		})
	}
}
// TestFuzz feeds random byte strings of up to 999 bytes to every message
// type's unmarshal method. The results are ignored: the test only fails if a
// parser panics (for example on an out-of-range slice index).
func TestFuzz(t *testing.T) {
	// A fixed seed keeps the inputs identical across runs.
	r := rand.New(rand.NewSource(0))
	for _, m := range tests {
		for j := 0; j < 1000; j++ {
			// Avoid naming these len/bytes, which would shadow the builtin
			// and the imported package.
			data := randomBytes(r.Intn(1000), r)
			m.unmarshal(data)
		}
	}
}
// randomBytes returns a new slice of n bytes read from rand. It panics if the
// read fails, which (*rand.Rand).Read is documented never to do.
func randomBytes(n int, rand *rand.Rand) []byte {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		panic("rand.Read failed: " + err.Error())
	}
	return buf
}
// randomString returns n random bytes from rand as a string. The result is
// arbitrary binary data and need not be valid UTF-8.
func randomString(n int, rand *rand.Rand) string {
	return string(randomBytes(n, rand))
}
// Generate implements quick.Generator for clientHelloMsg, producing a random
// message for the marshal/unmarshal round-trip test. Fixed fields are always
// filled; most optional fields are set with roughly 40% probability
// (rand.Intn(10) > 5). The size hint is ignored.
//
// Several list lengths come from loops of the form `i < rand.Intn(5)`, which
// re-draw the bound on every iteration, so those counts are not uniformly
// distributed.
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	m := &clientHelloMsg{}
	m.vers = uint16(rand.Intn(65536))
	m.random = randomBytes(32, rand)
	m.sessionId = randomBytes(rand.Intn(32), rand)
	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
	for i := 0; i < len(m.cipherSuites); i++ {
		cs := uint16(rand.Int31())
		// Never emit the renegotiation SCSV. Presumably unmarshal treats it
		// specially rather than keeping it in cipherSuites, which would break
		// the round trip; verify against the parser.
		if cs == scsvRenegotiation {
			cs += 1
		}
		m.cipherSuites[i] = cs
	}
	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
	if rand.Intn(10) > 5 {
		m.serverName = randomString(rand.Intn(255), rand)
		// Strip trailing dots. NOTE(review): presumably the SNI parser
		// rejects or normalizes a trailing dot; confirm.
		for strings.HasSuffix(m.serverName, ".") {
			m.serverName = m.serverName[:len(m.serverName)-1]
		}
	}
	m.ocspStapling = rand.Intn(10) > 5
	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
	for i := range m.supportedCurves {
		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
	}
	if rand.Intn(10) > 5 {
		m.ticketSupported = true
		// An empty ticket uses a non-nil empty slice. Presumably this matches
		// what unmarshal produces, so reflect.DeepEqual succeeds; verify.
		if rand.Intn(10) > 5 {
			m.sessionTicket = randomBytes(rand.Intn(300), rand)
		} else {
			m.sessionTicket = make([]byte, 0)
		}
	}
	if rand.Intn(10) > 5 {
		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
	}
	if rand.Intn(10) > 5 {
		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
	}
	for i := 0; i < rand.Intn(5); i++ {
		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
	}
	if rand.Intn(10) > 5 {
		m.scts = true
	}
	if rand.Intn(10) > 5 {
		m.secureRenegotiationSupported = true
		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
	}
	if rand.Intn(10) > 5 {
		m.extendedMasterSecret = true
	}
	for i := 0; i < rand.Intn(5); i++ {
		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
	}
	if rand.Intn(10) > 5 {
		m.cookie = randomBytes(rand.Intn(500)+1, rand)
	}
	for i := 0; i < rand.Intn(5); i++ {
		var ks keyShare
		ks.group = CurveID(rand.Intn(30000) + 1)
		ks.data = randomBytes(rand.Intn(200)+1, rand)
		m.keyShares = append(m.keyShares, ks)
	}
	// One third each: no PSK modes, DHE only, or DHE plus plain.
	switch rand.Intn(3) {
	case 1:
		m.pskModes = []uint8{pskModeDHE}
	case 2:
		m.pskModes = []uint8{pskModeDHE, pskModePlain}
	}
	// Identities and binders are appended in lockstep, so each identity gets
	// its own 32-81 byte binder.
	for i := 0; i < rand.Intn(5); i++ {
		var psk pskIdentity
		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
		psk.label = randomBytes(rand.Intn(500)+1, rand)
		m.pskIdentities = append(m.pskIdentities, psk)
		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
	}
	if rand.Intn(10) > 5 {
		m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
	}
	if rand.Intn(10) > 5 {
		m.earlyData = true
	}
	return reflect.ValueOf(m)
}
// Generate implements quick.Generator for serverHelloMsg, producing a random
// message for the marshal/unmarshal round-trip test. Fixed fields are always
// filled; most optional fields are set with roughly 40% probability
// (rand.Intn(10) > 5). The size hint is ignored.
func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	m := &serverHelloMsg{}
	m.vers = uint16(rand.Intn(65536))
	m.random = randomBytes(32, rand)
	m.sessionId = randomBytes(rand.Intn(32), rand)
	m.cipherSuite = uint16(rand.Int31())
	m.compressionMethod = uint8(rand.Intn(256))
	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
	if rand.Intn(10) > 5 {
		m.ocspStapling = true
	}
	if rand.Intn(10) > 5 {
		m.ticketSupported = true
	}
	if rand.Intn(10) > 5 {
		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
	}
	// Draw the SCT count once so it is uniform over 0-3. A bound re-drawn in
	// the loop condition would bias it towards short lists.
	numSCTs := rand.Intn(4)
	for i := 0; i < numSCTs; i++ {
		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
	}
	if rand.Intn(10) > 5 {
		m.secureRenegotiationSupported = true
		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
	}
	if rand.Intn(10) > 5 {
		m.extendedMasterSecret = true
	}
	if rand.Intn(10) > 5 {
		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
	}
	if rand.Intn(10) > 5 {
		m.cookie = randomBytes(rand.Intn(500)+1, rand)
	}
	// serverShare holds a single key share, so it is assigned exactly once.
	// Looping over it would only overwrite earlier values, and with zero
	// iterations it would leave this branch without a share. The generator
	// never sets both serverShare and selectedGroup.
	if rand.Intn(10) > 5 {
		m.serverShare.group = CurveID(rand.Intn(30000) + 1)
		m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
	} else if rand.Intn(10) > 5 {
		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
	}
	if rand.Intn(10) > 5 {
		m.selectedIdentityPresent = true
		m.selectedIdentity = uint16(rand.Intn(0xffff))
	}
	if rand.Intn(10) > 5 {
		m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
	}
	if rand.Intn(10) > 5 {
		m.serverNameAck = rand.Intn(2) == 1
	}
	return reflect.ValueOf(m)
}
// Generate implements quick.Generator for encryptedExtensionsMsg. The ALPN
// protocol (1-32 random bytes) and the early-data flag are each present with
// roughly 40% probability.
func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	msg := new(encryptedExtensionsMsg)
	if rand.Intn(10) > 5 {
		msg.alpnProtocol = randomString(rand.Intn(32)+1, rand)
	}
	msg.earlyData = rand.Intn(10) > 5
	return reflect.ValueOf(msg)
}
// Generate implements quick.Generator for certificateMsg, producing between 0
// and 19 random certificates of 1-10 bytes each.
func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	certs := make([][]byte, rand.Intn(20))
	for idx := range certs {
		certs[idx] = randomBytes(rand.Intn(10)+1, rand)
	}
	return reflect.ValueOf(&certificateMsg{certificates: certs})
}
// Generate implements quick.Generator for certificateRequestMsg, producing 1-5
// random certificate types and 0-99 random certificate authority names of
// 1-15 bytes each.
func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	m := &certificateRequestMsg{}
	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
	// Draw the count once. Calling rand.Intn in the loop condition would
	// re-roll the bound on every iteration and heavily bias the list towards
	// short lengths.
	numCAs := rand.Intn(100)
	for i := 0; i < numCAs; i++ {
		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
	}
	return reflect.ValueOf(m)
}
// Generate implements quick.Generator for certificateVerifyMsg. The result
// always carries a signature algorithm (below 30000) and a 1-15 byte
// signature.
func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	// Composite-literal elements are evaluated left to right, so rand is
	// consumed in the same order as sequential field assignments.
	return reflect.ValueOf(&certificateVerifyMsg{
		hasSignatureAlgorithm: true,
		signatureAlgorithm:    SignatureScheme(rand.Intn(30000)),
		signature:             randomBytes(rand.Intn(15)+1, rand),
	})
}
// Generate implements quick.Generator for certificateStatusMsg, producing a
// 1-10 byte random response.
func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&certificateStatusMsg{
		response: randomBytes(rand.Intn(10)+1, rand),
	})
}
// Generate implements quick.Generator for clientKeyExchangeMsg, producing a
// 1-1000 byte random ciphertext.
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&clientKeyExchangeMsg{
		ciphertext: randomBytes(rand.Intn(1000)+1, rand),
	})
}
// Generate implements quick.Generator for finishedMsg, producing 12 bytes of
// random verify data.
func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&finishedMsg{verifyData: randomBytes(12, rand)})
}
// Generate implements quick.Generator for newSessionTicketMsg, producing a
// random ticket of 0-3 bytes.
func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&newSessionTicketMsg{ticket: randomBytes(rand.Intn(4), rand)})
}
// sessionTestCerts holds the parsed test RSA leaf certificate followed by its
// issuer. SessionState.Generate uses it as a peer certificate chain.
var sessionTestCerts []*x509.Certificate

// init parses the DER test certificates into sessionTestCerts, in leaf then
// issuer order, and panics if either fails to parse.
func init() {
	for _, der := range [][]byte{testRSACertificate, testRSACertificateIssuer} {
		parsed, err := x509.ParseCertificate(der)
		if err != nil {
			panic(err)
		}
		sessionTestCerts = append(sessionTestCerts, parsed)
	}
}
// Generate implements quick.Generator for SessionState, producing a random
// session for the serialization round-trip test. Fields are chosen so that
// their combinations stay mutually consistent. For example, certificate-
// derived data is only added when peer certificates exist, and ALPN only
// when early data is enabled. The size hint is ignored.
func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
	s := &SessionState{}
	// Either exactly TLS 1.3, or an arbitrary value below it.
	isTLS13 := rand.Intn(10) > 5
	if isTLS13 {
		s.version = VersionTLS13
	} else {
		s.version = uint16(rand.Intn(VersionTLS13))
	}
	s.isClient = rand.Intn(10) > 5
	s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
	s.createdAt = uint64(rand.Int63())
	s.secret = randomBytes(rand.Intn(100)+1, rand)
	// 0-2 opaque Extra entries; the count is drawn once, up front.
	for n, i := rand.Intn(3), 0; i < n; i++ {
		s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
	}
	if rand.Intn(10) > 5 {
		s.EarlyData = true
	}
	if rand.Intn(10) > 5 {
		s.extMasterSecret = true
	}
	// Client sessions always get peer certificates; server sessions only
	// sometimes. The chain is either leaf+issuer or the leaf alone.
	if s.isClient || rand.Intn(10) > 5 {
		if rand.Intn(10) > 5 {
			s.peerCertificates = sessionTestCerts
		} else {
			s.peerCertificates = sessionTestCerts[:1]
		}
	}
	// OCSP responses and SCTs are only generated alongside peer
	// certificates. NOTE(review): presumably the encoding only stores them
	// with a certificate; confirm against SessionState.Bytes.
	if rand.Intn(10) > 5 && s.peerCertificates != nil {
		s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
	}
	if rand.Intn(10) > 5 && s.peerCertificates != nil {
		for i := 0; i < rand.Intn(2)+1; i++ {
			s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
		}
	}
	// Each verified chain is a prefix of peerCertificates: either the full
	// chain or just the leaf.
	if len(s.peerCertificates) > 0 {
		for i := 0; i < rand.Intn(3); i++ {
			if rand.Intn(10) > 5 {
				s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
			} else {
				s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
			}
		}
	}
	// ALPN is only set for early-data sessions. Presumably the encoding only
	// carries it in that case; verify.
	if rand.Intn(10) > 5 && s.EarlyData {
		s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
	}
	// useBy and ageAdd are only populated for TLS 1.3 client sessions.
	if s.isClient {
		if isTLS13 {
			s.useBy = uint64(rand.Int63())
			s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
		}
	}
	return reflect.ValueOf(s)
}
// marshal lets SessionState satisfy handshakeMessage in these tests by
// delegating to its public serialization, Bytes.
func (s *SessionState) marshal() ([]byte, error) {
	return s.Bytes()
}

// unmarshal parses b with ParseSessionState and, on success, overwrites s
// with the result. It reports whether parsing succeeded and leaves s
// untouched on failure.
func (s *SessionState) unmarshal(b []byte) bool {
	parsed, err := ParseSessionState(b)
	if err == nil {
		*s = *parsed
	}
	return err == nil
}
// Generate implements quick.Generator by returning a zero endOfEarlyDataMsg;
// rand is not consulted.
func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(new(endOfEarlyDataMsg))
}
// Generate implements quick.Generator for keyUpdateMsg; updateRequested is
// set with roughly 40% probability.
func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&keyUpdateMsg{updateRequested: rand.Intn(10) > 5})
}
// Generate implements quick.Generator for newSessionTicketMsgTLS13, filling
// lifetime, ageAdd, nonce and label, and setting maxEarlyData with roughly
// 40% probability.
func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	// Composite-literal elements are evaluated left to right, so rand is
	// consumed in the same order as sequential field assignments.
	msg := &newSessionTicketMsgTLS13{
		lifetime: uint32(rand.Intn(500000)),
		ageAdd:   uint32(rand.Intn(500000)),
		nonce:    randomBytes(rand.Intn(100), rand),
		label:    randomBytes(rand.Intn(1000), rand),
	}
	if rand.Intn(10) > 5 {
		msg.maxEarlyData = uint32(rand.Intn(500000))
	}
	return reflect.ValueOf(msg)
}
// Generate implements quick.Generator for certificateRequestMsgTLS13. Each
// optional field is populated independently with roughly 40% probability.
// When present, certificateAuthorities holds exactly three random names of
// 1-10 bytes each.
func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	msg := &certificateRequestMsgTLS13{}
	msg.ocspStapling = rand.Intn(10) > 5
	msg.scts = rand.Intn(10) > 5
	if rand.Intn(10) > 5 {
		msg.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
	}
	if rand.Intn(10) > 5 {
		msg.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
	}
	if rand.Intn(10) > 5 {
		cas := make([][]byte, 3)
		for idx := range cas {
			cas[idx] = randomBytes(rand.Intn(10)+1, rand)
		}
		msg.certificateAuthorities = cas
	}
	return reflect.ValueOf(msg)
}
// Generate implements quick.Generator for certificateMsgTLS13, producing one
// or two random certificates. With roughly 40% probability each, it adds an
// OCSP staple and one or two SCTs, setting the matching flags.
func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	msg := &certificateMsgTLS13{}
	for n := 0; n < rand.Intn(2)+1; n++ {
		der := randomBytes(rand.Intn(500)+1, rand)
		msg.certificate.Certificate = append(msg.certificate.Certificate, der)
	}
	if rand.Intn(10) > 5 {
		msg.ocspStapling = true
		msg.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
	}
	if rand.Intn(10) > 5 {
		msg.scts = true
		for n := 0; n < rand.Intn(2)+1; n++ {
			sct := randomBytes(rand.Intn(500)+1, rand)
			msg.certificate.SignedCertificateTimestamps = append(msg.certificate.SignedCertificateTimestamps, sct)
		}
	}
	return reflect.ValueOf(msg)
}
// TestRejectEmptySCTList builds a valid ServerHello carrying one 4-byte SCT.
// It then rewrites the encoding so the SCT extension holds an empty SCT list,
// patches the enclosing length fields, and checks that unmarshal rejects the
// result.
func TestRejectEmptySCTList(t *testing.T) {
	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
	var random [32]byte
	sct := []byte{0x42, 0x42, 0x42, 0x42}
	serverHello := &serverHelloMsg{
		vers:   VersionTLS12,
		random: random[:],
		scts:   [][]byte{sct},
	}
	serverHelloBytes := mustMarshal(t, serverHello)
	var serverHelloCopy serverHelloMsg
	if !serverHelloCopy.unmarshal(serverHelloBytes) {
		t.Fatal("Failed to unmarshal initial message")
	}
	// Change serverHelloBytes so that the SCT list is empty
	i := bytes.Index(serverHelloBytes, sct)
	if i < 0 {
		t.Fatal("Cannot find SCT in ServerHello")
	}
	var serverHelloEmptySCT []byte
	// The 6 bytes before the SCT are the extension length, the SCT list
	// length and the SCT's own length (2 bytes each). Keep everything up to
	// and including the extension type.
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
	// Append the extension length and SCT list length for an empty list.
	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
	// Skip the 4 SCT bytes and keep whatever followed them.
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
	// Update the handshake message length.
	// It is a 3-byte big-endian length that excludes the 4-byte header.
	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
	// Update the extensions length
	// Offset 42 = header (4) + version (2) + random (32) + session ID length
	// (1, with an empty session ID) + cipher suite (2) + compression (1). The
	// 2-byte extensions length counts everything from offset 44 onward.
	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
		t.Fatal("Unmarshaled ServerHello with empty SCT list")
	}
}
// TestRejectEmptySCT checks that a ServerHello whose SCT list is non-empty but
// contains a zero-length SCT fails to unmarshal.
func TestRejectEmptySCT(t *testing.T) {
	// Not only must the SCT list be non-empty, but the SCT elements must
	// not be zero length.
	hello := &serverHelloMsg{
		vers:   VersionTLS12,
		random: make([]byte, 32),
		scts:   [][]byte{nil},
	}
	var parsed serverHelloMsg
	if parsed.unmarshal(mustMarshal(t, hello)) {
		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
	}
}
// TestRejectDuplicateExtensions checks that both the ClientHello and the
// ServerHello parsers refuse a message in which the same extension appears
// twice.
func TestRejectDuplicateExtensions(t *testing.T) {
	// ClientHello whose extension block contains the same extension (type
	// 0x0000 carrying the name "hello") two times.
	const dupClientHello = "010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f"
	// ServerHello whose extension block contains extension type 0x0005 with
	// an empty body two times.
	const dupServerHello = "02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000"

	raw, err := hex.DecodeString(dupClientHello)
	if err != nil {
		t.Fatalf("failed to decode test ClientHello: %s", err)
	}
	var ch clientHelloMsg
	if ch.unmarshal(raw) {
		t.Error("Unmarshaled ClientHello with duplicate extensions")
	}

	raw, err = hex.DecodeString(dupServerHello)
	if err != nil {
		t.Fatalf("failed to decode test ServerHello: %s", err)
	}
	var sh serverHelloMsg
	if sh.unmarshal(raw) {
		t.Fatal("Unmarshaled ServerHello with duplicate extensions")
	}
}