crypto/tls: add Config.VerifyConnection callback

Since the ConnectionState will now be available during
verification, some code was moved around in order to
initialize and make available as much of the fields on
Conn as possible before the ConnectionState is verified.

Fixes #36736

Change-Id: I0e3efa97565ead7de5c48bb8a87e3ea54fbde140
Reviewed-on: https://go-review.googlesource.com/c/go/+/229122
Run-TryBot: Katie Hockman <katie@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Katie Hockman 2020-04-20 17:55:37 -04:00
parent 47355c49eb
commit 2ad14d1adf
8 changed files with 266 additions and 71 deletions

View file

@ -219,7 +219,7 @@ type ConnectionState struct {
CipherSuite uint16 // cipher suite in use (TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, ...) CipherSuite uint16 // cipher suite in use (TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, ...)
NegotiatedProtocol string // negotiated next protocol (not guaranteed to be from Config.NextProtos) NegotiatedProtocol string // negotiated next protocol (not guaranteed to be from Config.NextProtos)
NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server (client side only) NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server (client side only)
ServerName string // server name requested by client, if any (server side only) ServerName string // server name requested by client, if any
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
SignedCertificateTimestamps [][]byte // SCTs from the peer, if any SignedCertificateTimestamps [][]byte // SCTs from the peer, if any
@ -520,6 +520,16 @@ type Config struct {
// be considered but the verifiedChains argument will always be nil. // be considered but the verifiedChains argument will always be nil.
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
// VerifyConnection, if not nil, is called after normal certificate
// verification and after VerifyPeerCertificate by either a TLS client
// or server. If it returns a non-nil error, the handshake is aborted
// and that error results.
//
// If normal verification fails then the handshake will abort before
// considering this callback. This callback will run for all connections
// regardless of InsecureSkipVerify or ClientAuth settings.
VerifyConnection func(ConnectionState) error
// RootCAs defines the set of root certificate authorities // RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates. // that clients use when verifying server certificates.
// If RootCAs is nil, TLS uses the host's root CA set. // If RootCAs is nil, TLS uses the host's root CA set.
@ -685,6 +695,7 @@ func (c *Config) Clone() *Config {
GetClientCertificate: c.GetClientCertificate, GetClientCertificate: c.GetClientCertificate,
GetConfigForClient: c.GetConfigForClient, GetConfigForClient: c.GetConfigForClient,
VerifyPeerCertificate: c.VerifyPeerCertificate, VerifyPeerCertificate: c.VerifyPeerCertificate,
VerifyConnection: c.VerifyConnection,
RootCAs: c.RootCAs, RootCAs: c.RootCAs,
NextProtos: c.NextProtos, NextProtos: c.NextProtos,
ServerName: c.ServerName, ServerName: c.ServerName,

43
conn.go
View file

@ -1379,35 +1379,34 @@ func (c *Conn) Handshake() error {
func (c *Conn) ConnectionState() ConnectionState { func (c *Conn) ConnectionState() ConnectionState {
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
return c.connectionStateLocked()
}
func (c *Conn) connectionStateLocked() ConnectionState {
var state ConnectionState var state ConnectionState
state.HandshakeComplete = c.handshakeComplete() state.HandshakeComplete = c.handshakeComplete()
state.Version = c.vers
state.NegotiatedProtocol = c.clientProtocol
state.DidResume = c.didResume
state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
state.ServerName = c.serverName state.ServerName = c.serverName
state.CipherSuite = c.cipherSuite
if state.HandshakeComplete { state.PeerCertificates = c.peerCertificates
state.Version = c.vers state.VerifiedChains = c.verifiedChains
state.NegotiatedProtocol = c.clientProtocol state.SignedCertificateTimestamps = c.scts
state.DidResume = c.didResume state.OCSPResponse = c.ocspResponse
state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback if !c.didResume && c.vers != VersionTLS13 {
state.CipherSuite = c.cipherSuite if c.clientFinishedIsFirst {
state.PeerCertificates = c.peerCertificates state.TLSUnique = c.clientFinished[:]
state.VerifiedChains = c.verifiedChains
state.SignedCertificateTimestamps = c.scts
state.OCSPResponse = c.ocspResponse
if !c.didResume && c.vers != VersionTLS13 {
if c.clientFinishedIsFirst {
state.TLSUnique = c.clientFinished[:]
} else {
state.TLSUnique = c.serverFinished[:]
}
}
if c.config.Renegotiation != RenegotiateNever {
state.ekm = noExportedKeyingMaterial
} else { } else {
state.ekm = c.ekm state.TLSUnique = c.serverFinished[:]
} }
} }
if c.config.Renegotiation != RenegotiateNever {
state.ekm = noExportedKeyingMaterial
} else {
state.ekm = c.ekm
}
return state return state
} }

View file

@ -146,6 +146,7 @@ func (c *Conn) clientHandshake() (err error) {
if err != nil { if err != nil {
return err return err
} }
c.serverName = hello.serverName
cacheKey, session, earlySecret, binderKey := c.loadSession(hello) cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
if cacheKey != "" && session != nil { if cacheKey != "" && session != nil {
@ -388,6 +389,7 @@ func (hs *clientHandshakeState) handshake() error {
hs.finishedHash.Write(hs.serverHello.marshal()) hs.finishedHash.Write(hs.serverHello.marshal())
c.buffering = true c.buffering = true
c.didResume = isResume
if isResume { if isResume {
if err := hs.establishKeys(); err != nil { if err := hs.establishKeys(); err != nil {
return err return err
@ -399,6 +401,15 @@ func (hs *clientHandshakeState) handshake() error {
return err return err
} }
c.clientFinishedIsFirst = false c.clientFinishedIsFirst = false
// Make sure the connection is still being verified whether or not this
// is a resumption. Resumptions currently don't reverify certificates so
// they don't call verifyServerCertificate. See Issue 31641.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if err := hs.sendFinished(c.clientFinished[:]); err != nil { if err := hs.sendFinished(c.clientFinished[:]); err != nil {
return err return err
} }
@ -428,7 +439,6 @@ func (hs *clientHandshakeState) handshake() error {
} }
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
c.didResume = isResume
atomic.StoreUint32(&c.handshakeStatus, 1) atomic.StoreUint32(&c.handshakeStatus, 1)
return nil return nil
@ -458,25 +468,6 @@ func (hs *clientHandshakeState) doFullHandshake() error {
} }
hs.finishedHash.Write(certMsg.marshal()) hs.finishedHash.Write(certMsg.marshal())
if c.handshakes == 0 {
// If this is the first handshake on a connection, process and
// (optionally) verify the server's certificates.
if err := c.verifyServerCertificate(certMsg.certificates); err != nil {
return err
}
} else {
// This is a renegotiation handshake. We require that the
// server's identity (i.e. leaf certificate) is unchanged and
// thus any previous trust decision is still valid.
//
// See https://mitls.org/pages/attacks/3SHAKE for the
// motivation behind this requirement.
if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
c.sendAlert(alertBadCertificate)
return errors.New("tls: server's identity changed during renegotiation")
}
}
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {
return err return err
@ -505,6 +496,25 @@ func (hs *clientHandshakeState) doFullHandshake() error {
} }
} }
if c.handshakes == 0 {
// If this is the first handshake on a connection, process and
// (optionally) verify the server's certificates.
if err := c.verifyServerCertificate(certMsg.certificates); err != nil {
return err
}
} else {
// This is a renegotiation handshake. We require that the
// server's identity (i.e. leaf certificate) is unchanged and
// thus any previous trust decision is still valid.
//
// See https://mitls.org/pages/attacks/3SHAKE for the
// motivation behind this requirement.
if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
c.sendAlert(alertBadCertificate)
return errors.New("tls: server's identity changed during renegotiation")
}
}
keyAgreement := hs.suite.ka(c.vers) keyAgreement := hs.suite.ka(c.vers)
skx, ok := msg.(*serverKeyExchangeMsg) skx, ok := msg.(*serverKeyExchangeMsg)
@ -831,13 +841,6 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
} }
} }
if c.config.VerifyPeerCertificate != nil {
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
switch certs[0].PublicKey.(type) { switch certs[0].PublicKey.(type) {
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
break break
@ -848,6 +851,20 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
c.peerCertificates = certs c.peerCertificates = certs
if c.config.VerifyPeerCertificate != nil {
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil return nil
} }

View file

@ -907,6 +907,9 @@ func testResumption(t *testing.T, version uint16) {
if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) { if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains) t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
} }
if got, want := hs.ServerName, clientConfig.ServerName; got != want {
t.Errorf("%s: server name %s, want %s", test, got, want)
}
} }
getTicket := func() []byte { getTicket := func() []byte {
@ -1458,7 +1461,7 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
sentinelErr := errors.New("TestVerifyPeerCertificate") sentinelErr := errors.New("TestVerifyPeerCertificate")
verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
if l := len(rawCerts); l != 1 { if l := len(rawCerts); l != 1 {
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
} }
@ -1468,6 +1471,19 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
*called = true *called = true
return nil return nil
} }
verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
if l := len(c.PeerCertificates); l != 1 {
return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
}
if len(c.VerifiedChains) == 0 {
return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
}
if isClient && len(c.OCSPResponse) == 0 {
return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
}
*called = true
return nil
}
tests := []struct { tests := []struct {
configureServer func(*Config, *bool) configureServer func(*Config, *bool)
@ -1478,13 +1494,13 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
configureServer: func(config *Config, called *bool) { configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return verifyCallback(called, rawCerts, validatedChains) return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
} }
}, },
configureClient: func(config *Config, called *bool) { configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return verifyCallback(called, rawCerts, validatedChains) return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
} }
}, },
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
@ -1565,6 +1581,116 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
} }
}, },
}, },
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = func(c ConnectionState) error {
return verifyConnectionCallback(called, false, c)
}
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = func(c ConnectionState) error {
return verifyConnectionCallback(called, true, c)
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if clientErr != nil {
t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
}
if serverErr != nil {
t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
}
if !clientCalled {
t.Errorf("test[%d]: client did not call callback", testNo)
}
if !serverCalled {
t.Errorf("test[%d]: server did not call callback", testNo)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = func(c ConnectionState) error {
return sentinelErr
}
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = nil
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if serverErr != sentinelErr {
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = nil
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = func(c ConnectionState) error {
return sentinelErr
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if clientErr != sentinelErr {
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
}
config.VerifyConnection = func(c ConnectionState) error {
return sentinelErr
}
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = nil
config.VerifyConnection = nil
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if serverErr != sentinelErr {
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
}
if !serverCalled {
t.Errorf("test[%d]: server did not call callback", testNo)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = nil
config.VerifyConnection = nil
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
}
config.VerifyConnection = func(c ConnectionState) error {
return sentinelErr
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if clientErr != sentinelErr {
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
}
if !clientCalled {
t.Errorf("test[%d]: client did not call callback", testNo)
}
},
},
} }
for i, test := range tests { for i, test := range tests {
@ -1580,6 +1706,11 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
config.ClientCAs = rootCAs config.ClientCAs = rootCAs
config.Time = now config.Time = now
config.MaxVersion = version config.MaxVersion = version
config.Certificates = make([]Certificate, 1)
config.Certificates[0].Certificate = [][]byte{testRSACertificate}
config.Certificates[0].PrivateKey = testRSAPrivateKey
config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
test.configureServer(config, &serverCalled) test.configureServer(config, &serverCalled)
err = Server(s, config).Handshake() err = Server(s, config).Handshake()

View file

@ -407,6 +407,15 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
// Either a PSK or a certificate is always used, but not both. // Either a PSK or a certificate is always used, but not both.
// See RFC 8446, Section 4.1.1. // See RFC 8446, Section 4.1.1.
if hs.usingPSK { if hs.usingPSK {
// Make sure the connection is still being verified whether or not this
// is a resumption. Resumptions currently don't reverify certificates so
// they don't call verifyServerCertificate. See Issue 31641.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil return nil
} }

View file

@ -68,6 +68,7 @@ func (hs *serverHandshakeState) handshake() error {
c.buffering = true c.buffering = true
if hs.checkForResumption() { if hs.checkForResumption() {
// The client has included a session ticket and so we do an abbreviated handshake. // The client has included a session ticket and so we do an abbreviated handshake.
c.didResume = true
if err := hs.doResumeHandshake(); err != nil { if err := hs.doResumeHandshake(); err != nil {
return err return err
} }
@ -92,7 +93,6 @@ func (hs *serverHandshakeState) handshake() error {
if err := hs.readFinished(nil); err != nil { if err := hs.readFinished(nil); err != nil {
return err return err
} }
c.didResume = true
} else { } else {
// The client didn't include a session ticket, or it wasn't // The client didn't include a session ticket, or it wasn't
// valid so we do a full handshake. // valid so we do a full handshake.
@ -553,6 +553,15 @@ func (hs *serverHandshakeState) doFullHandshake() error {
if err != nil { if err != nil {
return err return err
} }
} else {
// Make sure the connection is still being verified whether or not
// the server requested a client certificate.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
} }
// Get client key exchange // Get client key exchange
@ -771,6 +780,19 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
c.verifiedChains = chains c.verifiedChains = chains
} }
c.peerCertificates = certs
c.ocspResponse = certificate.OCSPStaple
c.scts = certificate.SignedCertificateTimestamps
if len(certs) > 0 {
switch certs[0].PublicKey.(type) {
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
default:
c.sendAlert(alertUnsupportedCertificate)
return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
}
}
if c.config.VerifyPeerCertificate != nil { if c.config.VerifyPeerCertificate != nil {
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
c.sendAlert(alertBadCertificate) c.sendAlert(alertBadCertificate)
@ -778,20 +800,13 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
} }
} }
if len(certs) == 0 { if c.config.VerifyConnection != nil {
return nil if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
} }
switch certs[0].PublicKey.(type) {
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
default:
c.sendAlert(alertUnsupportedCertificate)
return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
}
c.peerCertificates = certs
c.ocspResponse = certificate.OCSPStaple
c.scts = certificate.SignedCertificateTimestamps
return nil return nil
} }

View file

@ -306,6 +306,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
return errors.New("tls: invalid PSK binder") return errors.New("tls: invalid PSK binder")
} }
c.didResume = true
if err := c.processCertsFromClient(sessionState.certificate); err != nil { if err := c.processCertsFromClient(sessionState.certificate); err != nil {
return err return err
} }
@ -313,7 +314,6 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
hs.hello.selectedIdentityPresent = true hs.hello.selectedIdentityPresent = true
hs.hello.selectedIdentity = uint16(i) hs.hello.selectedIdentity = uint16(i)
hs.usingPSK = true hs.usingPSK = true
c.didResume = true
return nil return nil
} }
@ -753,6 +753,14 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
c := hs.c c := hs.c
if !hs.requestClientCert() { if !hs.requestClientCert() {
// Make sure the connection is still being verified whether or not
// the server requested a client certificate.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil return nil
} }

View file

@ -734,7 +734,7 @@ func TestWarningAlertFlood(t *testing.T) {
} }
func TestCloneFuncFields(t *testing.T) { func TestCloneFuncFields(t *testing.T) {
const expectedCount = 5 const expectedCount = 6
called := 0 called := 0
c1 := Config{ c1 := Config{
@ -758,6 +758,10 @@ func TestCloneFuncFields(t *testing.T) {
called |= 1 << 4 called |= 1 << 4
return nil return nil
}, },
VerifyConnection: func(ConnectionState) error {
called |= 1 << 5
return nil
},
} }
c2 := c1.Clone() c2 := c1.Clone()
@ -767,6 +771,7 @@ func TestCloneFuncFields(t *testing.T) {
c2.GetClientCertificate(nil) c2.GetClientCertificate(nil)
c2.GetConfigForClient(nil) c2.GetConfigForClient(nil)
c2.VerifyPeerCertificate(nil, nil) c2.VerifyPeerCertificate(nil, nil)
c2.VerifyConnection(ConnectionState{})
if called != (1<<expectedCount)-1 { if called != (1<<expectedCount)-1 {
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called) t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
@ -790,7 +795,7 @@ func TestCloneNonFuncFields(t *testing.T) {
switch fn := typ.Field(i).Name; fn { switch fn := typ.Field(i).Name; fn {
case "Rand": case "Rand":
f.Set(reflect.ValueOf(io.Reader(os.Stdin))) f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate": case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
// DeepEqual can't compare functions. If you add a // DeepEqual can't compare functions. If you add a
// function field to this list, you must also change // function field to this list, you must also change
// TestCloneFuncFields to ensure that the func field is // TestCloneFuncFields to ensure that the func field is
@ -1116,8 +1121,8 @@ func TestConnectionState(t *testing.T) {
if ss.ServerName != serverName { if ss.ServerName != serverName {
t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName) t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
} }
if cs.ServerName != "" { if cs.ServerName != serverName {
t.Errorf("Got unexpected server name on the client side") t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
} }
if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 { if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {