utls/handshake_messages_test.go
Gaukas Wang 86e9b69fdd
sync: Go 1.21 with QUIC support (#208)
* sync: Go 1.21rc3, QUIC support added (#207)

* sync: merge with upstream tag/go-1.21rc3 (#11)

* fix: all tests pass

* impl: UQUIC Transport

* deps: bump up min Go version

* new: uquic

* fix: add QUICTransportParameter

* deprecated: Go 1.19 no longer supported

Go 1.19 will fail to build or pass the test once we bump up to the new version.

* sync: crypto/tls: restrict RSA keys in certificates to <= 8192 bits (#209)

* [release-branch.go1.21] crypto/tls: restrict RSA keys in certificates to <= 8192 bits

Extremely large RSA keys in certificate chains can cause a client/server
to expend significant CPU time verifying signatures. Limit this by
restricting the size of RSA keys transmitted during handshakes to <=
8192 bits.

Based on a survey of publicly trusted RSA keys, there are currently only
three certificates in circulation with keys larger than this, and all
three appear to be test certificates that are not actively deployed. It
is possible there are larger keys in use in private PKIs, but we target
the web PKI, so causing breakage here in the interests of increasing the
default safety of users of crypto/tls seems reasonable.

Thanks to Mateusz Poliwczak for reporting this issue.

Fixes CVE-2023-29409

* build: [ci skip] boring not included

* fix: typo [ci skip]

* docs: replenish readme [ci skip]

replace old build status badge with new ones, bump up required version noted in docs, update developer contact to reflect current status.
2023-08-03 23:22:53 -06:00

567 lines
15 KiB
Go

// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto/x509"
"encoding/hex"
"math"
"math/rand"
"reflect"
"strings"
"testing"
"testing/quick"
"time"
)
// tests holds one zero-valued instance of each handshake message type under
// test. TestMarshalUnmarshal uses each entry's dynamic type to drive
// quick.Value and uses the entry itself as the unmarshal target; TestFuzz
// feeds random bytes to each entry's unmarshal.
//
// Order matters: TestMarshalUnmarshal skips its prefix-rejection check for
// indices 0-2 (ClientHello, ServerHello, Finished), so those three entries
// must stay first.
//
// NOTE(review): utlsCompressedCertificateMsg defines a Generate method in this
// file but is not listed here, so it is not round-tripped — confirm whether
// that is intentional.
var tests = []handshakeMessage{
&clientHelloMsg{},
&serverHelloMsg{},
&finishedMsg{},
&certificateMsg{},
&certificateRequestMsg{},
// hasSignatureAlgorithm is preset on the unmarshal target, presumably because
// unmarshal consults it to decide whether a signature algorithm is present
// (Generate always sets it to true) — confirm.
&certificateVerifyMsg{
hasSignatureAlgorithm: true,
},
&certificateStatusMsg{},
&clientKeyExchangeMsg{},
&newSessionTicketMsg{},
&encryptedExtensionsMsg{},
&endOfEarlyDataMsg{},
&keyUpdateMsg{},
&newSessionTicketMsgTLS13{},
&certificateRequestMsgTLS13{},
&certificateMsgTLS13{},
// SessionState satisfies handshakeMessage via the test-only marshal and
// unmarshal adapter methods defined in this file.
&SessionState{},
}
// mustMarshal returns the wire encoding of msg. If marshaling fails, it stops
// the calling test immediately with t.Fatal and does not return.
func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
	t.Helper()
	encoded, err := msg.marshal()
	if err == nil {
		return encoded
	}
	t.Fatal(err)
	return nil // unreachable: t.Fatal exits the test goroutine
}
// TestMarshalUnmarshal round-trips randomly generated values of every message
// type in tests. Each value goes through marshal and then unmarshal, and the
// result must be deeply equal to the generated value. For every type except
// the first three, it also checks that unmarshal rejects every strict prefix
// of the encoding.
func TestMarshalUnmarshal(t *testing.T) {
	// The seed is time-based so each run covers new inputs. It is logged so a
	// failing run can be reproduced; t.Logf output is shown on failure or
	// with -v.
	seed := time.Now().UnixNano()
	t.Logf("random seed: %d", seed)
	rnd := rand.New(rand.NewSource(seed))

	iterations := 100
	if testing.Short() {
		iterations = 5
	}

	for i, m := range tests {
		ty := reflect.ValueOf(m).Type()
		for j := 0; j < iterations; j++ {
			v, ok := quick.Value(ty, rnd)
			if !ok {
				t.Errorf("#%d: failed to create value", i)
				break
			}
			m1 := v.Interface().(handshakeMessage)
			marshaled := mustMarshal(t, m1)
			if !m.unmarshal(marshaled) {
				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
				break
			}
			// m1 was marshaled by mustMarshal. Marshal m too, so that any
			// marshal cache in the message is filled on both sides before
			// DeepEqual. A failure here is reported instead of ignored.
			if _, err := m.marshal(); err != nil {
				t.Errorf("#%d failed to re-marshal %#v: %v", i, m, err)
				break
			}
			if ss, isSession := m.(*SessionState); isSession {
				// Generate never sets activeCertHandles. It is presumably
				// populated during parsing, so clear it to compare only the
				// generated fields.
				ss.activeCertHandles = nil
			}
			if !reflect.DeepEqual(m1, m) {
				t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
				break
			}
			if i >= 3 {
				// The first three message types (ClientHello,
				// ServerHello and Finished) are allowed to
				// have parsable prefixes because the extension
				// data is optional and the length of the
				// Finished varies across versions.
				for k := 0; k < len(marshaled); k++ {
					if m.unmarshal(marshaled[:k]) {
						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, k, m1)
						break
					}
				}
			}
		}
	}
}
// TestFuzz feeds 1000 random byte strings, each up to 999 bytes long, to the
// unmarshal method of every message in tests. It only detects crashes (for
// example out-of-range slicing). Whether unmarshal accepts the input is
// deliberately ignored.
func TestFuzz(t *testing.T) {
	// A fixed seed keeps the fuzz corpus identical across runs.
	rnd := rand.New(rand.NewSource(0))
	for _, m := range tests {
		for j := 0; j < 1000; j++ {
			// Named so as not to shadow the builtin len or the bytes package.
			size := rnd.Intn(1000)
			input := randomBytes(size, rnd)
			m.unmarshal(input)
		}
	}
}
// randomBytes returns a newly allocated slice of n bytes taken from the given
// pseudo-random generator. It panics if the generator reports a read error.
func randomBytes(n int, rand *rand.Rand) []byte {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		panic("rand.Read failed: " + err.Error())
	}
	return buf
}
// randomString returns a string built from n pseudo-random bytes. The result
// is not necessarily valid UTF-8.
func randomString(n int, rand *rand.Rand) string {
	return string(randomBytes(n, rand))
}
// Generate implements quick.Generator, returning a random *clientHelloMsg for
// TestMarshalUnmarshal. Most optional fields are populated independently with
// probability 4/10 (rand.Intn(10) > 5). The order of the rand calls fixes the
// value produced for a given seed, so reordering statements changes the output.
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &clientHelloMsg{}
m.vers = uint16(rand.Intn(65536))
m.random = randomBytes(32, rand)
m.sessionId = randomBytes(rand.Intn(32), rand)
m.cipherSuites = make([]uint16, rand.Intn(63)+1)
for i := 0; i < len(m.cipherSuites); i++ {
cs := uint16(rand.Int31())
// Skip the renegotiation SCSV value. Presumably unmarshal treats it as a
// signal rather than an ordinary suite, which would break the DeepEqual
// round-trip — confirm.
if cs == scsvRenegotiation {
cs += 1
}
m.cipherSuites[i] = cs
}
m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
if rand.Intn(10) > 5 {
m.serverName = randomString(rand.Intn(255), rand)
// Strip trailing dots. Presumably the SNI codec does not preserve them,
// so they would not survive the round-trip — confirm.
for strings.HasSuffix(m.serverName, ".") {
m.serverName = m.serverName[:len(m.serverName)-1]
}
}
m.ocspStapling = rand.Intn(10) > 5
m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
for i := range m.supportedCurves {
m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
}
if rand.Intn(10) > 5 {
m.ticketSupported = true
// An empty ticket is a non-nil zero-length slice, not nil. The
// distinction matters to reflect.DeepEqual.
if rand.Intn(10) > 5 {
m.sessionTicket = randomBytes(rand.Intn(300), rand)
} else {
m.sessionTicket = make([]byte, 0)
}
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
}
// Note: loop bounds of the form `i < rand.Intn(k)` are re-drawn on every
// iteration. Counts are therefore skewed towards small values.
for i := 0; i < rand.Intn(5); i++ {
m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
}
if rand.Intn(10) > 5 {
m.scts = true
}
if rand.Intn(10) > 5 {
m.secureRenegotiationSupported = true
m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
}
if rand.Intn(10) > 5 {
m.extendedMasterSecret = true
}
for i := 0; i < rand.Intn(5); i++ {
m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
}
if rand.Intn(10) > 5 {
m.cookie = randomBytes(rand.Intn(500)+1, rand)
}
for i := 0; i < rand.Intn(5); i++ {
var ks keyShare
ks.group = CurveID(rand.Intn(30000) + 1)
ks.data = randomBytes(rand.Intn(200)+1, rand)
m.keyShares = append(m.keyShares, ks)
}
// PSK modes are omitted (case 0), DHE only, or DHE plus plain.
switch rand.Intn(3) {
case 1:
m.pskModes = []uint8{pskModeDHE}
case 2:
m.pskModes = []uint8{pskModeDHE, pskModePlain}
}
// Identities and binders are appended in lockstep, so their counts always
// match. Each binder is 32-81 bytes long.
for i := 0; i < rand.Intn(5); i++ {
var psk pskIdentity
psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
psk.label = randomBytes(rand.Intn(500)+1, rand)
m.pskIdentities = append(m.pskIdentities, psk)
m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
}
if rand.Intn(10) > 5 {
m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
}
if rand.Intn(10) > 5 {
m.earlyData = true
}
return reflect.ValueOf(m)
}
// Generate implements quick.Generator, returning a random *serverHelloMsg for
// TestMarshalUnmarshal. Optional fields are each populated with probability
// 4/10 (rand.Intn(10) > 5).
func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	m := &serverHelloMsg{}
	m.vers = uint16(rand.Intn(65536))
	m.random = randomBytes(32, rand)
	m.sessionId = randomBytes(rand.Intn(32), rand)
	m.cipherSuite = uint16(rand.Int31())
	m.compressionMethod = uint8(rand.Intn(256))
	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
	if rand.Intn(10) > 5 {
		m.ocspStapling = true
	}
	if rand.Intn(10) > 5 {
		m.ticketSupported = true
	}
	if rand.Intn(10) > 5 {
		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
	}
	// The bound is re-drawn on every iteration, so this yields 0-3 SCTs,
	// skewed towards fewer.
	for i := 0; i < rand.Intn(4); i++ {
		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
	}
	if rand.Intn(10) > 5 {
		m.secureRenegotiationSupported = true
		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
	}
	if rand.Intn(10) > 5 {
		m.extendedMasterSecret = true
	}
	if rand.Intn(10) > 5 {
		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
	}
	if rand.Intn(10) > 5 {
		m.cookie = randomBytes(rand.Intn(500)+1, rand)
	}
	// The key share and the selected group are generated as mutually exclusive
	// alternatives. serverShare is a single value, so it is filled exactly
	// once. Filling it inside a randomly bounded loop would only overwrite it,
	// and would leave it empty whenever the loop ran zero times.
	if rand.Intn(10) > 5 {
		m.serverShare.group = CurveID(rand.Intn(30000) + 1)
		m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
	} else if rand.Intn(10) > 5 {
		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
	}
	if rand.Intn(10) > 5 {
		m.selectedIdentityPresent = true
		m.selectedIdentity = uint16(rand.Intn(0xffff))
	}
	return reflect.ValueOf(m)
}
// Generate implements quick.Generator for encryptedExtensionsMsg. It
// optionally sets an ALPN protocol and, independently, the early-data flag,
// each with probability 4/10.
func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	msg := new(encryptedExtensionsMsg)
	if withALPN := rand.Intn(10) > 5; withALPN {
		msg.alpnProtocol = randomString(rand.Intn(32)+1, rand)
	}
	msg.earlyData = rand.Intn(10) > 5
	return reflect.ValueOf(msg)
}
// Generate implements quick.Generator for certificateMsg. It produces between
// 0 and 19 opaque certificates, each 1-10 random bytes long.
func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	certs := make([][]byte, rand.Intn(20))
	for idx := range certs {
		certs[idx] = randomBytes(rand.Intn(10)+1, rand)
	}
	return reflect.ValueOf(&certificateMsg{certificates: certs})
}
// Generate implements quick.Generator for certificateRequestMsg. It sets 1-5
// random certificate types and a random number of short CA names. The loop
// bound is re-drawn each iteration, so small counts dominate.
func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	msg := &certificateRequestMsg{
		certificateTypes: randomBytes(rand.Intn(5)+1, rand),
	}
	for n := 0; n < rand.Intn(100); n++ {
		ca := randomBytes(rand.Intn(15)+1, rand)
		msg.certificateAuthorities = append(msg.certificateAuthorities, ca)
	}
	return reflect.ValueOf(msg)
}
// Generate implements quick.Generator for certificateVerifyMsg. It always
// includes a signature algorithm (below 30000) and a 1-15 byte signature.
func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	alg := SignatureScheme(rand.Intn(30000))
	sig := randomBytes(rand.Intn(15)+1, rand)
	return reflect.ValueOf(&certificateVerifyMsg{
		hasSignatureAlgorithm: true,
		signatureAlgorithm:    alg,
		signature:             sig,
	})
}
// Generate implements quick.Generator for certificateStatusMsg, producing a
// 1-10 byte opaque response.
func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&certificateStatusMsg{
		response: randomBytes(rand.Intn(10)+1, rand),
	})
}
// Generate implements quick.Generator for clientKeyExchangeMsg, producing a
// 1-1000 byte opaque ciphertext.
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&clientKeyExchangeMsg{
		ciphertext: randomBytes(rand.Intn(1000)+1, rand),
	})
}
// Generate implements quick.Generator for finishedMsg, producing a fixed
// 12-byte random verifyData.
func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&finishedMsg{verifyData: randomBytes(12, rand)})
}
// Generate implements quick.Generator for newSessionTicketMsg, producing a
// ticket of 0-3 random bytes.
func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&newSessionTicketMsg{
		ticket: randomBytes(rand.Intn(4), rand),
	})
}
// sessionTestCerts holds the certificates parsed from testRSACertificate and
// testRSACertificateIssuer, in that order. SessionState.Generate uses either
// the whole slice or only its first element as a peer certificate chain.
var sessionTestCerts []*x509.Certificate

// init parses the DER test fixtures once. It panics on a parse error, since
// the fixtures are expected to be valid.
func init() {
	for _, der := range [][]byte{testRSACertificate, testRSACertificateIssuer} {
		parsed, err := x509.ParseCertificate(der)
		if err != nil {
			panic(err)
		}
		sessionTestCerts = append(sessionTestCerts, parsed)
	}
}
// Generate implements quick.Generator, returning a random *SessionState for
// TestMarshalUnmarshal. The field dependencies enforced below are:
//   - certificate-related fields only when peer certificates exist,
//   - alpnProtocol only with EarlyData,
//   - useBy/ageAdd only for TLS 1.3 clients.
// They presumably mirror what Bytes and ParseSessionState can round-trip —
// confirm against their implementation. The order of the rand calls fixes the
// value produced for a given seed.
func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
s := &SessionState{}
// TLS 1.3 with probability 4/10; otherwise any version number below
// VersionTLS13.
isTLS13 := rand.Intn(10) > 5
if isTLS13 {
s.version = VersionTLS13
} else {
s.version = uint16(rand.Intn(VersionTLS13))
}
s.isClient = rand.Intn(10) > 5
s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
s.createdAt = uint64(rand.Int63())
s.secret = randomBytes(rand.Intn(100)+1, rand)
// Zero to two opaque Extra entries (the count n is drawn once).
for n, i := rand.Intn(3), 0; i < n; i++ {
s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
}
if rand.Intn(10) > 5 {
s.EarlyData = true
}
if rand.Intn(10) > 5 {
s.extMasterSecret = true
}
// Client sessions always get peer certificates. Server sessions get them
// with probability 4/10. Because of short-circuiting, no rand value is
// drawn here for clients. The chain is either both test certificates or
// just the first.
if s.isClient || rand.Intn(10) > 5 {
if rand.Intn(10) > 5 {
s.peerCertificates = sessionTestCerts
} else {
s.peerCertificates = sessionTestCerts[:1]
}
}
// The OCSP response and SCTs are generated only alongside peer
// certificates. The rand draw happens first, so it is consumed either way.
if rand.Intn(10) > 5 && s.peerCertificates != nil {
s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
}
if rand.Intn(10) > 5 && s.peerCertificates != nil {
for i := 0; i < rand.Intn(2)+1; i++ {
s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
}
}
// Verified chains reuse the peer certificate slices (the full chain or just
// the first certificate). The loop bound is re-drawn on every iteration.
if len(s.peerCertificates) > 0 {
for i := 0; i < rand.Intn(3); i++ {
if rand.Intn(10) > 5 {
s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
} else {
s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
}
}
}
// An ALPN protocol is set only on early-data sessions.
if rand.Intn(10) > 5 && s.EarlyData {
s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
}
// useBy and ageAdd are set only for TLS 1.3 client sessions.
if s.isClient {
if isTLS13 {
s.useBy = uint64(rand.Int63())
s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
}
}
return reflect.ValueOf(s)
}
// marshal adapts SessionState to the handshakeMessage interface used by the
// round-trip tests by delegating to Bytes.
func (s *SessionState) marshal() ([]byte, error) {
	return s.Bytes()
}
// unmarshal adapts ParseSessionState to the handshakeMessage interface. On
// success it overwrites *s with the parsed state and reports true. On failure
// it leaves *s untouched and reports false.
func (s *SessionState) unmarshal(b []byte) bool {
	parsed, err := ParseSessionState(b)
	if err == nil {
		*s = *parsed
	}
	return err == nil
}
// Generate implements quick.Generator for endOfEarlyDataMsg. Only the zero
// value is produced; the random source is not consulted.
func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(new(endOfEarlyDataMsg))
}
// Generate implements quick.Generator for keyUpdateMsg. It requests a
// reciprocal update with probability 4/10.
func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&keyUpdateMsg{updateRequested: rand.Intn(10) > 5})
}
// Generate implements quick.Generator for newSessionTicketMsgTLS13. It
// produces a random lifetime, ageAdd, nonce and label. It sets maxEarlyData
// with probability 4/10.
func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	lifetime := uint32(rand.Intn(500000))
	ageAdd := uint32(rand.Intn(500000))
	nonce := randomBytes(rand.Intn(100), rand)
	label := randomBytes(rand.Intn(1000), rand)
	msg := &newSessionTicketMsgTLS13{
		lifetime: lifetime,
		ageAdd:   ageAdd,
		nonce:    nonce,
		label:    label,
	}
	if rand.Intn(10) > 5 {
		msg.maxEarlyData = uint32(rand.Intn(500000))
	}
	return reflect.ValueOf(msg)
}
// Generate implements quick.Generator for certificateRequestMsgTLS13. Each
// optional feature is enabled independently with probability 4/10. When CA
// names are present, there are exactly three, each 1-10 bytes long.
func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	msg := new(certificateRequestMsgTLS13)
	msg.ocspStapling = rand.Intn(10) > 5
	msg.scts = rand.Intn(10) > 5
	if rand.Intn(10) > 5 {
		msg.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
	}
	if rand.Intn(10) > 5 {
		msg.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
	}
	if rand.Intn(10) > 5 {
		authorities := make([][]byte, 3)
		for idx := range authorities {
			authorities[idx] = randomBytes(rand.Intn(10)+1, rand)
		}
		msg.certificateAuthorities = authorities
	}
	return reflect.ValueOf(msg)
}
// Generate implements quick.Generator for certificateMsgTLS13. It always
// includes at least one certificate. It optionally attaches an OCSP staple and
// SCTs, setting the matching flag whenever that data is present. Loop bounds
// are re-drawn on every iteration.
func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	msg := &certificateMsgTLS13{}
	for n := 0; n < rand.Intn(2)+1; n++ {
		der := randomBytes(rand.Intn(500)+1, rand)
		msg.certificate.Certificate = append(msg.certificate.Certificate, der)
	}
	if withOCSP := rand.Intn(10) > 5; withOCSP {
		msg.ocspStapling = true
		msg.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
	}
	if withSCTs := rand.Intn(10) > 5; withSCTs {
		msg.scts = true
		for n := 0; n < rand.Intn(2)+1; n++ {
			sct := randomBytes(rand.Intn(500)+1, rand)
			msg.certificate.SignedCertificateTimestamps = append(msg.certificate.SignedCertificateTimestamps, sct)
		}
	}
	return reflect.ValueOf(msg)
}
// [UTLS]
// Generate implements quick.Generator for utlsCompressedCertificateMsg. The
// algorithm is drawn from the full uint16 range and uncompressedLength from
// [0, 1<<24). The compressed payload is 1-500 random bytes.
//
// NOTE(review): this type does not appear in the tests slice of this file, so
// TestMarshalUnmarshal does not exercise this generator — confirm intent.
func (*utlsCompressedCertificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	const (
		algorithmLimit = 1 << 16 // equal to 2 << 15
		lengthLimit    = 1 << 24 // equal to 2 << 23
	)
	msg := &utlsCompressedCertificateMsg{}
	msg.algorithm = uint16(rand.Intn(algorithmLimit))
	msg.uncompressedLength = uint32(rand.Intn(lengthLimit))
	msg.compressedCertificateMessage = randomBytes(rand.Intn(500)+1, rand)
	return reflect.ValueOf(msg)
}
// TestRejectEmptySCTList checks that unmarshal rejects a ServerHello whose SCT
// extension carries an empty SCT list. It first marshals a valid TLS 1.2
// ServerHello holding one 4-byte SCT. It then splices the encoding so the SCT
// list is empty, patches the affected length fields, and expects unmarshal to
// fail.
func TestRejectEmptySCTList(t *testing.T) {
// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
var random [32]byte
// The 0x42 bytes serve as a marker that bytes.Index can locate in the
// marshaled message.
sct := []byte{0x42, 0x42, 0x42, 0x42}
serverHello := &serverHelloMsg{
vers: VersionTLS12,
random: random[:],
scts: [][]byte{sct},
}
serverHelloBytes := mustMarshal(t, serverHello)
// Sanity check: the unmodified message must parse.
var serverHelloCopy serverHelloMsg
if !serverHelloCopy.unmarshal(serverHelloBytes) {
t.Fatal("Failed to unmarshal initial message")
}
// Change serverHelloBytes so that the SCT list is empty
i := bytes.Index(serverHelloBytes, sct)
if i < 0 {
t.Fatal("Cannot find SCT in ServerHello")
}
// Working back from the SCT at index i, assuming the usual encoding
// (ext type, ext len, list len, SCT len, SCT; each length is 2 bytes):
// the SCT length is at i-2, the list length at i-4 and the extension
// length at i-6. Keep everything before the extension length.
var serverHelloEmptySCT []byte
serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
// Append the extension length and SCT list length for an empty list.
serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
// Resume after the 4-byte SCT. This drops the old lengths and the SCT.
serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
// Update the handshake message length.
// (The 24-bit length lives in bytes 1-3 and excludes the 4-byte header.)
serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
// Update the extensions length
// Offset 42 assumes the standard ServerHello layout: a 4-byte handshake
// header, a 2-byte version, a 32-byte random, a 1-byte session ID length
// (the session ID is empty here), a 2-byte cipher suite and a 1-byte
// compression method. The extension data starts after this 2-byte field,
// at offset 44.
serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
t.Fatal("Unmarshaled ServerHello with empty SCT list")
}
}
// TestRejectEmptySCT checks that unmarshal rejects a ServerHello whose SCT
// list contains a zero-length SCT element.
func TestRejectEmptySCT(t *testing.T) {
	// Not only must the SCT list be non-empty, but the SCT elements must
	// not be zero length.
	msg := &serverHelloMsg{
		vers:   VersionTLS12,
		random: make([]byte, 32),
		scts:   [][]byte{nil},
	}
	encoded := mustMarshal(t, msg)
	var parsed serverHelloMsg
	if accepted := parsed.unmarshal(encoded); accepted {
		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
	}
}
// TestRejectDuplicateExtensions checks that both ClientHello and ServerHello
// parsing reject a message that repeats the same extension.
func TestRejectDuplicateExtensions(t *testing.T) {
	// Hand-encoded hello messages whose extension blocks contain one
	// extension twice.
	const (
		dupClientHelloHex = "010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f"
		dupServerHelloHex = "02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000"
	)

	raw, err := hex.DecodeString(dupClientHelloHex)
	if err != nil {
		t.Fatalf("failed to decode test ClientHello: %s", err)
	}
	if new(clientHelloMsg).unmarshal(raw) {
		t.Error("Unmarshaled ClientHello with duplicate extensions")
	}

	raw, err = hex.DecodeString(dupServerHelloHex)
	if err != nil {
		t.Fatalf("failed to decode test ServerHello: %s", err)
	}
	if new(serverHelloMsg).unmarshal(raw) {
		t.Fatal("Unmarshaled ServerHello with duplicate extensions")
	}
}