mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-04 04:27:36 +03:00
sync: merge changes from go 1.23.4
This commit is contained in:
commit
cefe226467
98 changed files with 8089 additions and 4530 deletions
|
@ -7,10 +7,14 @@ package tls
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -26,6 +30,8 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/refraction-networking/utls/internal/byteorder"
|
||||
)
|
||||
|
||||
// Note: see comment in handshake_test.go for details of how the reference
|
||||
|
@ -202,7 +208,7 @@ func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd,
|
|||
var serverInfo bytes.Buffer
|
||||
for _, ext := range test.extensions {
|
||||
pem.Encode(&serverInfo, &pem.Block{
|
||||
Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
|
||||
Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", byteorder.BeUint16(ext)),
|
||||
Bytes: ext,
|
||||
})
|
||||
}
|
||||
|
@ -278,7 +284,7 @@ func (test *clientTest) loadData() (flows [][]byte, err error) {
|
|||
}
|
||||
|
||||
func (test *clientTest) run(t *testing.T, write bool) {
|
||||
var clientConn, serverConn net.Conn
|
||||
var clientConn net.Conn
|
||||
var recordingConn *recordingConn
|
||||
var childProcess *exec.Cmd
|
||||
var stdin opensslInput
|
||||
|
@ -297,178 +303,138 @@ func (test *clientTest) run(t *testing.T, write bool) {
|
|||
}
|
||||
}()
|
||||
} else {
|
||||
clientConn, serverConn = localPipe(t)
|
||||
flows, err := test.loadData()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load data from %s: %v", test.dataPath(), err)
|
||||
}
|
||||
clientConn = &replayingConn{t: t, flows: flows, reading: false}
|
||||
}
|
||||
|
||||
doneChan := make(chan bool)
|
||||
defer func() {
|
||||
clientConn.Close()
|
||||
<-doneChan
|
||||
}()
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
config := test.config
|
||||
if config == nil {
|
||||
config = testConfig
|
||||
}
|
||||
client := Client(clientConn, config)
|
||||
defer client.Close()
|
||||
|
||||
config := test.config
|
||||
if config == nil {
|
||||
config = testConfig
|
||||
if _, err := client.Write([]byte("hello\n")); err != nil {
|
||||
t.Errorf("Client.Write failed: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 1; i <= test.numRenegotiations; i++ {
|
||||
// The initial handshake will generate a
|
||||
// handshakeComplete signal which needs to be quashed.
|
||||
if i == 1 && write {
|
||||
<-stdout.handshakeComplete
|
||||
}
|
||||
client := Client(clientConn, config)
|
||||
defer client.Close()
|
||||
|
||||
if _, err := client.Write([]byte("hello\n")); err != nil {
|
||||
// OpenSSL will try to interleave application data and
|
||||
// a renegotiation if we send both concurrently.
|
||||
// Therefore: ask OpensSSL to start a renegotiation, run
|
||||
// a goroutine to call client.Read and thus process the
|
||||
// renegotiation request, watch for OpenSSL's stdout to
|
||||
// indicate that the handshake is complete and,
|
||||
// finally, have OpenSSL write something to cause
|
||||
// client.Read to complete.
|
||||
if write {
|
||||
stdin <- opensslRenegotiate
|
||||
}
|
||||
|
||||
signalChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(signalChan)
|
||||
|
||||
buf := make([]byte, 256)
|
||||
n, err := client.Read(buf)
|
||||
|
||||
if test.checkRenegotiationError != nil {
|
||||
newErr := test.checkRenegotiationError(i, err)
|
||||
if err != nil && newErr == nil {
|
||||
return
|
||||
}
|
||||
err = newErr
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
buf = buf[:n]
|
||||
if !bytes.Equal([]byte(opensslSentinel), buf) {
|
||||
t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
|
||||
}
|
||||
|
||||
if expected := i + 1; client.handshakes != expected {
|
||||
t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
|
||||
}
|
||||
}()
|
||||
|
||||
if write && test.renegotiationExpectedToFail != i {
|
||||
<-stdout.handshakeComplete
|
||||
stdin <- opensslSendSentinel
|
||||
}
|
||||
<-signalChan
|
||||
}
|
||||
|
||||
if test.sendKeyUpdate {
|
||||
if write {
|
||||
<-stdout.handshakeComplete
|
||||
stdin <- opensslKeyUpdate
|
||||
}
|
||||
|
||||
doneRead := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(doneRead)
|
||||
|
||||
buf := make([]byte, 256)
|
||||
n, err := client.Read(buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Client.Read failed after KeyUpdate: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
buf = buf[:n]
|
||||
if !bytes.Equal([]byte(opensslSentinel), buf) {
|
||||
t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
|
||||
}
|
||||
}()
|
||||
|
||||
if write {
|
||||
// There's no real reason to wait for the client KeyUpdate to
|
||||
// send data with the new server keys, except that s_server
|
||||
// drops writes if they are sent at the wrong time.
|
||||
<-stdout.readKeyUpdate
|
||||
stdin <- opensslSendSentinel
|
||||
}
|
||||
<-doneRead
|
||||
|
||||
if _, err := client.Write([]byte("hello again\n")); err != nil {
|
||||
t.Errorf("Client.Write failed: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for i := 1; i <= test.numRenegotiations; i++ {
|
||||
// The initial handshake will generate a
|
||||
// handshakeComplete signal which needs to be quashed.
|
||||
if i == 1 && write {
|
||||
<-stdout.handshakeComplete
|
||||
}
|
||||
|
||||
// OpenSSL will try to interleave application data and
|
||||
// a renegotiation if we send both concurrently.
|
||||
// Therefore: ask OpensSSL to start a renegotiation, run
|
||||
// a goroutine to call client.Read and thus process the
|
||||
// renegotiation request, watch for OpenSSL's stdout to
|
||||
// indicate that the handshake is complete and,
|
||||
// finally, have OpenSSL write something to cause
|
||||
// client.Read to complete.
|
||||
if write {
|
||||
stdin <- opensslRenegotiate
|
||||
}
|
||||
|
||||
signalChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(signalChan)
|
||||
|
||||
buf := make([]byte, 256)
|
||||
n, err := client.Read(buf)
|
||||
|
||||
if test.checkRenegotiationError != nil {
|
||||
newErr := test.checkRenegotiationError(i, err)
|
||||
if err != nil && newErr == nil {
|
||||
return
|
||||
}
|
||||
err = newErr
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
buf = buf[:n]
|
||||
if !bytes.Equal([]byte(opensslSentinel), buf) {
|
||||
t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
|
||||
}
|
||||
|
||||
if expected := i + 1; client.handshakes != expected {
|
||||
t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
|
||||
}
|
||||
}()
|
||||
|
||||
if write && test.renegotiationExpectedToFail != i {
|
||||
<-stdout.handshakeComplete
|
||||
stdin <- opensslSendSentinel
|
||||
}
|
||||
<-signalChan
|
||||
}
|
||||
|
||||
if test.sendKeyUpdate {
|
||||
if write {
|
||||
<-stdout.handshakeComplete
|
||||
stdin <- opensslKeyUpdate
|
||||
}
|
||||
|
||||
doneRead := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(doneRead)
|
||||
|
||||
buf := make([]byte, 256)
|
||||
n, err := client.Read(buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Client.Read failed after KeyUpdate: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
buf = buf[:n]
|
||||
if !bytes.Equal([]byte(opensslSentinel), buf) {
|
||||
t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
|
||||
}
|
||||
}()
|
||||
|
||||
if write {
|
||||
// There's no real reason to wait for the client KeyUpdate to
|
||||
// send data with the new server keys, except that s_server
|
||||
// drops writes if they are sent at the wrong time.
|
||||
<-stdout.readKeyUpdate
|
||||
stdin <- opensslSendSentinel
|
||||
}
|
||||
<-doneRead
|
||||
|
||||
if _, err := client.Write([]byte("hello again\n")); err != nil {
|
||||
t.Errorf("Client.Write failed: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if test.validate != nil {
|
||||
if err := test.validate(client.ConnectionState()); err != nil {
|
||||
t.Errorf("validate callback returned error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// If the server sent us an alert after our last flight, give it a
|
||||
// chance to arrive.
|
||||
if write && test.renegotiationExpectedToFail == 0 {
|
||||
if err := peekError(client); err != nil {
|
||||
t.Errorf("final Read returned an error: %s", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if !write {
|
||||
flows, err := test.loadData()
|
||||
if err != nil {
|
||||
t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
|
||||
}
|
||||
for i, b := range flows {
|
||||
if i%2 == 1 {
|
||||
if *fast {
|
||||
serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
|
||||
} else {
|
||||
serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
|
||||
}
|
||||
serverConn.Write(b)
|
||||
continue
|
||||
}
|
||||
bb := make([]byte, len(b))
|
||||
if *fast {
|
||||
serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
} else {
|
||||
serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
|
||||
}
|
||||
_, err := io.ReadFull(serverConn, bb)
|
||||
if err != nil {
|
||||
t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
|
||||
}
|
||||
if !bytes.Equal(b, bb) {
|
||||
t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
|
||||
}
|
||||
if test.validate != nil {
|
||||
if err := test.validate(client.ConnectionState()); err != nil {
|
||||
t.Errorf("validate callback returned error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
<-doneChan
|
||||
if !write {
|
||||
serverConn.Close()
|
||||
// If the server sent us an alert after our last flight, give it a
|
||||
// chance to arrive.
|
||||
if write && test.renegotiationExpectedToFail == 0 {
|
||||
if err := peekError(client); err != nil {
|
||||
t.Errorf("final Read returned an error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if write {
|
||||
clientConn.Close()
|
||||
path := test.dataPath()
|
||||
out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
|
@ -659,6 +625,12 @@ func TestHandshakeClientHelloRetryRequest(t *testing.T) {
|
|||
name: "HelloRetryRequest",
|
||||
args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
|
||||
config: config,
|
||||
validate: func(cs ConnectionState) error {
|
||||
if !cs.testingOnlyDidHRR {
|
||||
return errors.New("expected HelloRetryRequest")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
runClientTestTLS13(t, test)
|
||||
|
@ -881,6 +853,7 @@ func testResumption(t *testing.T, version uint16) {
|
|||
MaxVersion: version,
|
||||
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
|
||||
Certificates: testConfig.Certificates,
|
||||
Time: testTime, // [uTLS]
|
||||
}
|
||||
|
||||
issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
|
||||
|
@ -897,6 +870,7 @@ func testResumption(t *testing.T, version uint16) {
|
|||
ClientSessionCache: NewLRUClientSessionCache(32),
|
||||
RootCAs: rootCAs,
|
||||
ServerName: "example.golang",
|
||||
Time: testTime, // [uTLS]
|
||||
}
|
||||
|
||||
testResumeState := func(test string, didResume bool) {
|
||||
|
@ -917,7 +891,7 @@ func testResumption(t *testing.T, version uint16) {
|
|||
}
|
||||
|
||||
getTicket := func() []byte {
|
||||
return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.ticket
|
||||
return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.session.ticket
|
||||
}
|
||||
deleteTicket := func() {
|
||||
ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
|
||||
|
@ -943,7 +917,7 @@ func testResumption(t *testing.T, version uint16) {
|
|||
|
||||
// An old session ticket is replaced with a ticket encrypted with a fresh key.
|
||||
ticket = getTicket()
|
||||
serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
|
||||
serverConfig.Time = func() time.Time { return testTime().Add(24*time.Hour + time.Minute) } // [uTLS]
|
||||
testResumeState("ResumeWithOldTicket", true)
|
||||
if bytes.Equal(ticket, getTicket()) {
|
||||
t.Fatal("old first ticket matches the fresh one")
|
||||
|
@ -951,13 +925,13 @@ func testResumption(t *testing.T, version uint16) {
|
|||
|
||||
// Once the session master secret is expired, a full handshake should occur.
|
||||
ticket = getTicket()
|
||||
serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
|
||||
serverConfig.Time = func() time.Time { return testTime().Add(24*8*time.Hour + time.Minute) } // [uTLS]
|
||||
testResumeState("ResumeWithExpiredTicket", false)
|
||||
if bytes.Equal(ticket, getTicket()) {
|
||||
t.Fatal("expired first ticket matches the fresh one")
|
||||
}
|
||||
|
||||
serverConfig.Time = func() time.Time { return time.Now() } // reset the time back
|
||||
serverConfig.Time = testTime // reset the time back // [uTLS]
|
||||
key1 := randomKey()
|
||||
serverConfig.SetSessionTicketKeys([][32]byte{key1})
|
||||
|
||||
|
@ -974,11 +948,11 @@ func testResumption(t *testing.T, version uint16) {
|
|||
testResumeState("KeyChangeFinish", true)
|
||||
|
||||
// Age the session ticket a bit, but not yet expired.
|
||||
serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
|
||||
serverConfig.Time = func() time.Time { return testTime().Add(24*time.Hour + time.Minute) } // [uTLS]
|
||||
testResumeState("OldSessionTicket", true)
|
||||
ticket = getTicket()
|
||||
// Expire the session ticket, which would force a full handshake.
|
||||
serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
|
||||
serverConfig.Time = func() time.Time { return testTime().Add(24*8*time.Hour + 2*time.Minute) } // [uTLS]
|
||||
testResumeState("ExpiredSessionTicket", false)
|
||||
if bytes.Equal(ticket, getTicket()) {
|
||||
t.Fatal("new ticket wasn't provided after old ticket expired")
|
||||
|
@ -986,7 +960,7 @@ func testResumption(t *testing.T, version uint16) {
|
|||
|
||||
// Age the session ticket a bit at a time, but don't expire it.
|
||||
d := 0 * time.Hour
|
||||
serverConfig.Time = func() time.Time { return time.Now().Add(d) }
|
||||
serverConfig.Time = func() time.Time { return testTime().Add(d) } // [uTLS]
|
||||
deleteTicket()
|
||||
testResumeState("GetFreshSessionTicket", false)
|
||||
for i := 0; i < 13; i++ {
|
||||
|
@ -997,7 +971,7 @@ func testResumption(t *testing.T, version uint16) {
|
|||
// handshake occurs for TLS 1.2. Resumption should still occur for
|
||||
// TLS 1.3 since the client should be using a fresh ticket sent over
|
||||
// by the server.
|
||||
d += 12 * time.Hour
|
||||
d += 12*time.Hour + time.Minute // [uTLS]
|
||||
if version == VersionTLS13 {
|
||||
testResumeState("ExpiredSessionTicket", true)
|
||||
} else {
|
||||
|
@ -1013,6 +987,7 @@ func testResumption(t *testing.T, version uint16) {
|
|||
MaxVersion: version,
|
||||
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
|
||||
Certificates: testConfig.Certificates,
|
||||
Time: testTime, // [uTLS]
|
||||
}
|
||||
serverConfig.SetSessionTicketKeys([][32]byte{key2})
|
||||
|
||||
|
@ -1038,6 +1013,7 @@ func testResumption(t *testing.T, version uint16) {
|
|||
CurvePreferences: []CurveID{CurveP521, CurveP384, CurveP256},
|
||||
MaxVersion: version,
|
||||
Certificates: testConfig.Certificates,
|
||||
Time: testTime, // [uTLS]
|
||||
}
|
||||
testResumeState("InitialHandshake", false)
|
||||
testResumeState("WithHelloRetryRequest", true)
|
||||
|
@ -1047,6 +1023,7 @@ func testResumption(t *testing.T, version uint16) {
|
|||
MaxVersion: version,
|
||||
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
|
||||
Certificates: testConfig.Certificates,
|
||||
Time: testTime, // [uTLS]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1101,6 +1078,10 @@ func (c *serializingClientCache) Get(sessionKey string) (session *ClientSessionS
|
|||
}
|
||||
|
||||
func (c *serializingClientCache) Put(sessionKey string, cs *ClientSessionState) {
|
||||
if cs == nil {
|
||||
c.ticket, c.state = nil, nil
|
||||
return
|
||||
}
|
||||
ticket, state, err := cs.ResumptionState()
|
||||
if err != nil {
|
||||
c.t.Error(err)
|
||||
|
@ -1761,6 +1742,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
|
|||
serverConfig := &Config{
|
||||
MaxVersion: version,
|
||||
Certificates: []Certificate{testConfig.Certificates[0]},
|
||||
Time: testTime, // [uTLS]
|
||||
ClientCAs: rootCAs,
|
||||
NextProtos: []string{"protocol1"},
|
||||
}
|
||||
|
@ -1774,6 +1756,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
|
|||
RootCAs: rootCAs,
|
||||
ServerName: "example.golang",
|
||||
Certificates: []Certificate{testConfig.Certificates[0]},
|
||||
Time: testTime, // [uTLS]
|
||||
NextProtos: []string{"protocol1"},
|
||||
}
|
||||
test.configureClient(clientConfig, &clientCalled)
|
||||
|
@ -2564,6 +2547,7 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
|
|||
ClientSessionCache: NewLRUClientSessionCache(32),
|
||||
ServerName: "example.golang",
|
||||
RootCAs: roots,
|
||||
Time: testTime, // [uTLS]
|
||||
}
|
||||
serverConfig := testConfig.Clone()
|
||||
serverConfig.MaxVersion = ver
|
||||
|
@ -2799,3 +2783,123 @@ func TestHandshakeRSATooBig(t *testing.T) {
|
|||
t.Errorf("Conn.processCertsFromClient unexpected error: want %q, got %q", expectedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLS13ECHRejectionCallbacks(t *testing.T) {
|
||||
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "test"},
|
||||
DNSNames: []string{"example.golang"},
|
||||
NotBefore: testConfig.Time().Add(-time.Hour),
|
||||
NotAfter: testConfig.Time().Add(time.Hour),
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, k.Public(), k)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientConfig, serverConfig := testConfig.Clone(), testConfig.Clone()
|
||||
serverConfig.Certificates = []Certificate{
|
||||
{
|
||||
Certificate: [][]byte{certDER},
|
||||
PrivateKey: k,
|
||||
},
|
||||
}
|
||||
serverConfig.MinVersion = VersionTLS13
|
||||
clientConfig.RootCAs = x509.NewCertPool()
|
||||
clientConfig.RootCAs.AddCert(cert)
|
||||
clientConfig.MinVersion = VersionTLS13
|
||||
clientConfig.EncryptedClientHelloConfigList, _ = hex.DecodeString("0041fe0d003d0100200020204bed0a11fc0dde595a9b78d966b0011128eb83f65d3c91c1cc5ac786cd246f000400010001ff0e6578616d706c652e676f6c616e670000")
|
||||
clientConfig.ServerName = "example.golang"
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
expectedErr string
|
||||
|
||||
verifyConnection func(ConnectionState) error
|
||||
verifyPeerCertificate func([][]byte, [][]*x509.Certificate) error
|
||||
encryptedClientHelloRejectionVerify func(ConnectionState) error
|
||||
}{
|
||||
{
|
||||
name: "no callbacks",
|
||||
expectedErr: "tls: server rejected ECH",
|
||||
},
|
||||
{
|
||||
name: "EncryptedClientHelloRejectionVerify, no err",
|
||||
encryptedClientHelloRejectionVerify: func(ConnectionState) error {
|
||||
return nil
|
||||
},
|
||||
expectedErr: "tls: server rejected ECH",
|
||||
},
|
||||
{
|
||||
name: "EncryptedClientHelloRejectionVerify, err",
|
||||
encryptedClientHelloRejectionVerify: func(ConnectionState) error {
|
||||
return errors.New("callback err")
|
||||
},
|
||||
// testHandshake returns the server side error, so we just need to
|
||||
// check alertBadCertificate was sent
|
||||
expectedErr: "callback err",
|
||||
},
|
||||
{
|
||||
name: "VerifyConnection, err",
|
||||
verifyConnection: func(ConnectionState) error {
|
||||
return errors.New("callback err")
|
||||
},
|
||||
expectedErr: "tls: server rejected ECH",
|
||||
},
|
||||
{
|
||||
name: "VerifyPeerCertificate, err",
|
||||
verifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
|
||||
return errors.New("callback err")
|
||||
},
|
||||
expectedErr: "tls: server rejected ECH",
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, s := localPipe(t)
|
||||
done := make(chan error)
|
||||
|
||||
go func() {
|
||||
serverErr := Server(s, serverConfig).Handshake()
|
||||
s.Close()
|
||||
done <- serverErr
|
||||
}()
|
||||
|
||||
cConfig := clientConfig.Clone()
|
||||
cConfig.VerifyConnection = tc.verifyConnection
|
||||
cConfig.VerifyPeerCertificate = tc.verifyPeerCertificate
|
||||
cConfig.EncryptedClientHelloRejectionVerify = tc.encryptedClientHelloRejectionVerify
|
||||
|
||||
clientErr := Client(c, cConfig).Handshake()
|
||||
c.Close()
|
||||
|
||||
if tc.expectedErr == "" && clientErr != nil {
|
||||
t.Fatalf("unexpected err: %s", clientErr)
|
||||
} else if clientErr != nil && tc.expectedErr != clientErr.Error() {
|
||||
t.Fatalf("unexpected err: got %q, want %q", clientErr, tc.expectedErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestECHTLS12Server(t *testing.T) {
|
||||
clientConfig, serverConfig := testConfig.Clone(), testConfig.Clone()
|
||||
|
||||
serverConfig.MaxVersion = VersionTLS12
|
||||
clientConfig.MinVersion = 0
|
||||
|
||||
clientConfig.EncryptedClientHelloConfigList, _ = hex.DecodeString("0041fe0d003d0100200020204bed0a11fc0dde595a9b78d966b0011128eb83f65d3c91c1cc5ac786cd246f000400010001ff0e6578616d706c652e676f6c616e670000")
|
||||
|
||||
expectedErr := "server: tls: client offered only unsupported versions: [304]\nclient: remote error: tls: protocol version not supported"
|
||||
_, _, err := testHandshake(t, clientConfig, serverConfig)
|
||||
if err == nil || err.Error() != expectedErr {
|
||||
t.Fatalf("unexpected handshake error: got %q, want %q", err, expectedErr)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue