sync: merge changes from go 1.23.4

This commit is contained in:
Mingye Chen 2025-01-07 15:55:09 -07:00
commit cefe226467
98 changed files with 8089 additions and 4530 deletions

View file

@ -7,10 +7,14 @@ package tls
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
@ -26,6 +30,8 @@ import (
"strings"
"testing"
"time"
"github.com/refraction-networking/utls/internal/byteorder"
)
// Note: see comment in handshake_test.go for details of how the reference
@ -202,7 +208,7 @@ func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd,
var serverInfo bytes.Buffer
for _, ext := range test.extensions {
pem.Encode(&serverInfo, &pem.Block{
Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", byteorder.BeUint16(ext)),
Bytes: ext,
})
}
@ -278,7 +284,7 @@ func (test *clientTest) loadData() (flows [][]byte, err error) {
}
func (test *clientTest) run(t *testing.T, write bool) {
var clientConn, serverConn net.Conn
var clientConn net.Conn
var recordingConn *recordingConn
var childProcess *exec.Cmd
var stdin opensslInput
@ -297,178 +303,138 @@ func (test *clientTest) run(t *testing.T, write bool) {
}
}()
} else {
clientConn, serverConn = localPipe(t)
flows, err := test.loadData()
if err != nil {
t.Fatalf("failed to load data from %s: %v", test.dataPath(), err)
}
clientConn = &replayingConn{t: t, flows: flows, reading: false}
}
doneChan := make(chan bool)
defer func() {
clientConn.Close()
<-doneChan
}()
go func() {
defer close(doneChan)
config := test.config
if config == nil {
config = testConfig
}
client := Client(clientConn, config)
defer client.Close()
config := test.config
if config == nil {
config = testConfig
if _, err := client.Write([]byte("hello\n")); err != nil {
t.Errorf("Client.Write failed: %s", err)
return
}
for i := 1; i <= test.numRenegotiations; i++ {
// The initial handshake will generate a
// handshakeComplete signal which needs to be quashed.
if i == 1 && write {
<-stdout.handshakeComplete
}
client := Client(clientConn, config)
defer client.Close()
if _, err := client.Write([]byte("hello\n")); err != nil {
// OpenSSL will try to interleave application data and
// a renegotiation if we send both concurrently.
// Therefore: ask OpensSSL to start a renegotiation, run
// a goroutine to call client.Read and thus process the
// renegotiation request, watch for OpenSSL's stdout to
// indicate that the handshake is complete and,
// finally, have OpenSSL write something to cause
// client.Read to complete.
if write {
stdin <- opensslRenegotiate
}
signalChan := make(chan struct{})
go func() {
defer close(signalChan)
buf := make([]byte, 256)
n, err := client.Read(buf)
if test.checkRenegotiationError != nil {
newErr := test.checkRenegotiationError(i, err)
if err != nil && newErr == nil {
return
}
err = newErr
}
if err != nil {
t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
return
}
buf = buf[:n]
if !bytes.Equal([]byte(opensslSentinel), buf) {
t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
}
if expected := i + 1; client.handshakes != expected {
t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
}
}()
if write && test.renegotiationExpectedToFail != i {
<-stdout.handshakeComplete
stdin <- opensslSendSentinel
}
<-signalChan
}
if test.sendKeyUpdate {
if write {
<-stdout.handshakeComplete
stdin <- opensslKeyUpdate
}
doneRead := make(chan struct{})
go func() {
defer close(doneRead)
buf := make([]byte, 256)
n, err := client.Read(buf)
if err != nil {
t.Errorf("Client.Read failed after KeyUpdate: %s", err)
return
}
buf = buf[:n]
if !bytes.Equal([]byte(opensslSentinel), buf) {
t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
}
}()
if write {
// There's no real reason to wait for the client KeyUpdate to
// send data with the new server keys, except that s_server
// drops writes if they are sent at the wrong time.
<-stdout.readKeyUpdate
stdin <- opensslSendSentinel
}
<-doneRead
if _, err := client.Write([]byte("hello again\n")); err != nil {
t.Errorf("Client.Write failed: %s", err)
return
}
}
for i := 1; i <= test.numRenegotiations; i++ {
// The initial handshake will generate a
// handshakeComplete signal which needs to be quashed.
if i == 1 && write {
<-stdout.handshakeComplete
}
// OpenSSL will try to interleave application data and
// a renegotiation if we send both concurrently.
// Therefore: ask OpensSSL to start a renegotiation, run
// a goroutine to call client.Read and thus process the
// renegotiation request, watch for OpenSSL's stdout to
// indicate that the handshake is complete and,
// finally, have OpenSSL write something to cause
// client.Read to complete.
if write {
stdin <- opensslRenegotiate
}
signalChan := make(chan struct{})
go func() {
defer close(signalChan)
buf := make([]byte, 256)
n, err := client.Read(buf)
if test.checkRenegotiationError != nil {
newErr := test.checkRenegotiationError(i, err)
if err != nil && newErr == nil {
return
}
err = newErr
}
if err != nil {
t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
return
}
buf = buf[:n]
if !bytes.Equal([]byte(opensslSentinel), buf) {
t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
}
if expected := i + 1; client.handshakes != expected {
t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
}
}()
if write && test.renegotiationExpectedToFail != i {
<-stdout.handshakeComplete
stdin <- opensslSendSentinel
}
<-signalChan
}
if test.sendKeyUpdate {
if write {
<-stdout.handshakeComplete
stdin <- opensslKeyUpdate
}
doneRead := make(chan struct{})
go func() {
defer close(doneRead)
buf := make([]byte, 256)
n, err := client.Read(buf)
if err != nil {
t.Errorf("Client.Read failed after KeyUpdate: %s", err)
return
}
buf = buf[:n]
if !bytes.Equal([]byte(opensslSentinel), buf) {
t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
}
}()
if write {
// There's no real reason to wait for the client KeyUpdate to
// send data with the new server keys, except that s_server
// drops writes if they are sent at the wrong time.
<-stdout.readKeyUpdate
stdin <- opensslSendSentinel
}
<-doneRead
if _, err := client.Write([]byte("hello again\n")); err != nil {
t.Errorf("Client.Write failed: %s", err)
return
}
}
if test.validate != nil {
if err := test.validate(client.ConnectionState()); err != nil {
t.Errorf("validate callback returned error: %s", err)
}
}
// If the server sent us an alert after our last flight, give it a
// chance to arrive.
if write && test.renegotiationExpectedToFail == 0 {
if err := peekError(client); err != nil {
t.Errorf("final Read returned an error: %s", err)
}
}
}()
if !write {
flows, err := test.loadData()
if err != nil {
t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
}
for i, b := range flows {
if i%2 == 1 {
if *fast {
serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
} else {
serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
}
serverConn.Write(b)
continue
}
bb := make([]byte, len(b))
if *fast {
serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
} else {
serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
}
_, err := io.ReadFull(serverConn, bb)
if err != nil {
t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
}
if !bytes.Equal(b, bb) {
t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
}
if test.validate != nil {
if err := test.validate(client.ConnectionState()); err != nil {
t.Errorf("validate callback returned error: %s", err)
}
}
<-doneChan
if !write {
serverConn.Close()
// If the server sent us an alert after our last flight, give it a
// chance to arrive.
if write && test.renegotiationExpectedToFail == 0 {
if err := peekError(client); err != nil {
t.Errorf("final Read returned an error: %s", err)
}
}
if write {
clientConn.Close()
path := test.dataPath()
out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
@ -659,6 +625,12 @@ func TestHandshakeClientHelloRetryRequest(t *testing.T) {
name: "HelloRetryRequest",
args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
config: config,
validate: func(cs ConnectionState) error {
if !cs.testingOnlyDidHRR {
return errors.New("expected HelloRetryRequest")
}
return nil
},
}
runClientTestTLS13(t, test)
@ -881,6 +853,7 @@ func testResumption(t *testing.T, version uint16) {
MaxVersion: version,
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
Certificates: testConfig.Certificates,
Time: testTime, // [uTLS]
}
issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
@ -897,6 +870,7 @@ func testResumption(t *testing.T, version uint16) {
ClientSessionCache: NewLRUClientSessionCache(32),
RootCAs: rootCAs,
ServerName: "example.golang",
Time: testTime, // [uTLS]
}
testResumeState := func(test string, didResume bool) {
@ -917,7 +891,7 @@ func testResumption(t *testing.T, version uint16) {
}
getTicket := func() []byte {
return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.ticket
return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.session.ticket
}
deleteTicket := func() {
ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
@ -943,7 +917,7 @@ func testResumption(t *testing.T, version uint16) {
// An old session ticket is replaced with a ticket encrypted with a fresh key.
ticket = getTicket()
serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
serverConfig.Time = func() time.Time { return testTime().Add(24*time.Hour + time.Minute) } // [uTLS]
testResumeState("ResumeWithOldTicket", true)
if bytes.Equal(ticket, getTicket()) {
t.Fatal("old first ticket matches the fresh one")
@ -951,13 +925,13 @@ func testResumption(t *testing.T, version uint16) {
// Once the session master secret is expired, a full handshake should occur.
ticket = getTicket()
serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
serverConfig.Time = func() time.Time { return testTime().Add(24*8*time.Hour + time.Minute) } // [uTLS]
testResumeState("ResumeWithExpiredTicket", false)
if bytes.Equal(ticket, getTicket()) {
t.Fatal("expired first ticket matches the fresh one")
}
serverConfig.Time = func() time.Time { return time.Now() } // reset the time back
serverConfig.Time = testTime // reset the time back // [uTLS]
key1 := randomKey()
serverConfig.SetSessionTicketKeys([][32]byte{key1})
@ -974,11 +948,11 @@ func testResumption(t *testing.T, version uint16) {
testResumeState("KeyChangeFinish", true)
// Age the session ticket a bit, but not yet expired.
serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
serverConfig.Time = func() time.Time { return testTime().Add(24*time.Hour + time.Minute) } // [uTLS]
testResumeState("OldSessionTicket", true)
ticket = getTicket()
// Expire the session ticket, which would force a full handshake.
serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
serverConfig.Time = func() time.Time { return testTime().Add(24*8*time.Hour + 2*time.Minute) } // [uTLS]
testResumeState("ExpiredSessionTicket", false)
if bytes.Equal(ticket, getTicket()) {
t.Fatal("new ticket wasn't provided after old ticket expired")
@ -986,7 +960,7 @@ func testResumption(t *testing.T, version uint16) {
// Age the session ticket a bit at a time, but don't expire it.
d := 0 * time.Hour
serverConfig.Time = func() time.Time { return time.Now().Add(d) }
serverConfig.Time = func() time.Time { return testTime().Add(d) } // [uTLS]
deleteTicket()
testResumeState("GetFreshSessionTicket", false)
for i := 0; i < 13; i++ {
@ -997,7 +971,7 @@ func testResumption(t *testing.T, version uint16) {
// handshake occurs for TLS 1.2. Resumption should still occur for
// TLS 1.3 since the client should be using a fresh ticket sent over
// by the server.
d += 12 * time.Hour
d += 12*time.Hour + time.Minute // [uTLS]
if version == VersionTLS13 {
testResumeState("ExpiredSessionTicket", true)
} else {
@ -1013,6 +987,7 @@ func testResumption(t *testing.T, version uint16) {
MaxVersion: version,
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
Certificates: testConfig.Certificates,
Time: testTime, // [uTLS]
}
serverConfig.SetSessionTicketKeys([][32]byte{key2})
@ -1038,6 +1013,7 @@ func testResumption(t *testing.T, version uint16) {
CurvePreferences: []CurveID{CurveP521, CurveP384, CurveP256},
MaxVersion: version,
Certificates: testConfig.Certificates,
Time: testTime, // [uTLS]
}
testResumeState("InitialHandshake", false)
testResumeState("WithHelloRetryRequest", true)
@ -1047,6 +1023,7 @@ func testResumption(t *testing.T, version uint16) {
MaxVersion: version,
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
Certificates: testConfig.Certificates,
Time: testTime, // [uTLS]
}
}
@ -1101,6 +1078,10 @@ func (c *serializingClientCache) Get(sessionKey string) (session *ClientSessionS
}
func (c *serializingClientCache) Put(sessionKey string, cs *ClientSessionState) {
if cs == nil {
c.ticket, c.state = nil, nil
return
}
ticket, state, err := cs.ResumptionState()
if err != nil {
c.t.Error(err)
@ -1761,6 +1742,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
serverConfig := &Config{
MaxVersion: version,
Certificates: []Certificate{testConfig.Certificates[0]},
Time: testTime, // [uTLS]
ClientCAs: rootCAs,
NextProtos: []string{"protocol1"},
}
@ -1774,6 +1756,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
RootCAs: rootCAs,
ServerName: "example.golang",
Certificates: []Certificate{testConfig.Certificates[0]},
Time: testTime, // [uTLS]
NextProtos: []string{"protocol1"},
}
test.configureClient(clientConfig, &clientCalled)
@ -2564,6 +2547,7 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
ClientSessionCache: NewLRUClientSessionCache(32),
ServerName: "example.golang",
RootCAs: roots,
Time: testTime, // [uTLS]
}
serverConfig := testConfig.Clone()
serverConfig.MaxVersion = ver
@ -2799,3 +2783,123 @@ func TestHandshakeRSATooBig(t *testing.T) {
t.Errorf("Conn.processCertsFromClient unexpected error: want %q, got %q", expectedErr, err)
}
}
func TestTLS13ECHRejectionCallbacks(t *testing.T) {
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "test"},
DNSNames: []string{"example.golang"},
NotBefore: testConfig.Time().Add(-time.Hour),
NotAfter: testConfig.Time().Add(time.Hour),
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, k.Public(), k)
if err != nil {
t.Fatal(err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
t.Fatal(err)
}
clientConfig, serverConfig := testConfig.Clone(), testConfig.Clone()
serverConfig.Certificates = []Certificate{
{
Certificate: [][]byte{certDER},
PrivateKey: k,
},
}
serverConfig.MinVersion = VersionTLS13
clientConfig.RootCAs = x509.NewCertPool()
clientConfig.RootCAs.AddCert(cert)
clientConfig.MinVersion = VersionTLS13
clientConfig.EncryptedClientHelloConfigList, _ = hex.DecodeString("0041fe0d003d0100200020204bed0a11fc0dde595a9b78d966b0011128eb83f65d3c91c1cc5ac786cd246f000400010001ff0e6578616d706c652e676f6c616e670000")
clientConfig.ServerName = "example.golang"
for _, tc := range []struct {
name string
expectedErr string
verifyConnection func(ConnectionState) error
verifyPeerCertificate func([][]byte, [][]*x509.Certificate) error
encryptedClientHelloRejectionVerify func(ConnectionState) error
}{
{
name: "no callbacks",
expectedErr: "tls: server rejected ECH",
},
{
name: "EncryptedClientHelloRejectionVerify, no err",
encryptedClientHelloRejectionVerify: func(ConnectionState) error {
return nil
},
expectedErr: "tls: server rejected ECH",
},
{
name: "EncryptedClientHelloRejectionVerify, err",
encryptedClientHelloRejectionVerify: func(ConnectionState) error {
return errors.New("callback err")
},
// testHandshake returns the server side error, so we just need to
// check alertBadCertificate was sent
expectedErr: "callback err",
},
{
name: "VerifyConnection, err",
verifyConnection: func(ConnectionState) error {
return errors.New("callback err")
},
expectedErr: "tls: server rejected ECH",
},
{
name: "VerifyPeerCertificate, err",
verifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
return errors.New("callback err")
},
expectedErr: "tls: server rejected ECH",
},
} {
t.Run(tc.name, func(t *testing.T) {
c, s := localPipe(t)
done := make(chan error)
go func() {
serverErr := Server(s, serverConfig).Handshake()
s.Close()
done <- serverErr
}()
cConfig := clientConfig.Clone()
cConfig.VerifyConnection = tc.verifyConnection
cConfig.VerifyPeerCertificate = tc.verifyPeerCertificate
cConfig.EncryptedClientHelloRejectionVerify = tc.encryptedClientHelloRejectionVerify
clientErr := Client(c, cConfig).Handshake()
c.Close()
if tc.expectedErr == "" && clientErr != nil {
t.Fatalf("unexpected err: %s", clientErr)
} else if clientErr != nil && tc.expectedErr != clientErr.Error() {
t.Fatalf("unexpected err: got %q, want %q", clientErr, tc.expectedErr)
}
})
}
}
func TestECHTLS12Server(t *testing.T) {
clientConfig, serverConfig := testConfig.Clone(), testConfig.Clone()
serverConfig.MaxVersion = VersionTLS12
clientConfig.MinVersion = 0
clientConfig.EncryptedClientHelloConfigList, _ = hex.DecodeString("0041fe0d003d0100200020204bed0a11fc0dde595a9b78d966b0011128eb83f65d3c91c1cc5ac786cd246f000400010001ff0e6578616d706c652e676f6c616e670000")
expectedErr := "server: tls: client offered only unsupported versions: [304]\nclient: remote error: tls: protocol version not supported"
_, _, err := testHandshake(t, clientConfig, serverConfig)
if err == nil || err.Error() != expectedErr {
t.Fatalf("unexpected handshake error: got %q, want %q", err, expectedErr)
}
}