mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-03 20:17:36 +03:00
crypto/tls: add Config.VerifyConnection callback
Since the ConnectionState will now be available during verification, some code was moved around in order to initialize and make available as much of the fields on Conn as possible before the ConnectionState is verified. Fixes #36736 Change-Id: I0e3efa97565ead7de5c48bb8a87e3ea54fbde140 Reviewed-on: https://go-review.googlesource.com/c/go/+/229122 Run-TryBot: Katie Hockman <katie@golang.org> Reviewed-by: Filippo Valsorda <filippo@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
47355c49eb
commit
2ad14d1adf
8 changed files with 266 additions and 71 deletions
13
common.go
13
common.go
|
@ -219,7 +219,7 @@ type ConnectionState struct {
|
|||
CipherSuite uint16 // cipher suite in use (TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, ...)
|
||||
NegotiatedProtocol string // negotiated next protocol (not guaranteed to be from Config.NextProtos)
|
||||
NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server (client side only)
|
||||
ServerName string // server name requested by client, if any (server side only)
|
||||
ServerName string // server name requested by client, if any
|
||||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
|
||||
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
|
||||
SignedCertificateTimestamps [][]byte // SCTs from the peer, if any
|
||||
|
@ -520,6 +520,16 @@ type Config struct {
|
|||
// be considered but the verifiedChains argument will always be nil.
|
||||
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
|
||||
|
||||
// VerifyConnection, if not nil, is called after normal certificate
|
||||
// verification and after VerifyPeerCertificate by either a TLS client
|
||||
// or server. If it returns a non-nil error, the handshake is aborted
|
||||
// and that error results.
|
||||
//
|
||||
// If normal verification fails then the handshake will abort before
|
||||
// considering this callback. This callback will run for all connections
|
||||
// regardless of InsecureSkipVerify or ClientAuth settings.
|
||||
VerifyConnection func(ConnectionState) error
|
||||
|
||||
// RootCAs defines the set of root certificate authorities
|
||||
// that clients use when verifying server certificates.
|
||||
// If RootCAs is nil, TLS uses the host's root CA set.
|
||||
|
@ -685,6 +695,7 @@ func (c *Config) Clone() *Config {
|
|||
GetClientCertificate: c.GetClientCertificate,
|
||||
GetConfigForClient: c.GetConfigForClient,
|
||||
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||
VerifyConnection: c.VerifyConnection,
|
||||
RootCAs: c.RootCAs,
|
||||
NextProtos: c.NextProtos,
|
||||
ServerName: c.ServerName,
|
||||
|
|
43
conn.go
43
conn.go
|
@ -1379,35 +1379,34 @@ func (c *Conn) Handshake() error {
|
|||
func (c *Conn) ConnectionState() ConnectionState {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
return c.connectionStateLocked()
|
||||
}
|
||||
|
||||
func (c *Conn) connectionStateLocked() ConnectionState {
|
||||
var state ConnectionState
|
||||
state.HandshakeComplete = c.handshakeComplete()
|
||||
state.Version = c.vers
|
||||
state.NegotiatedProtocol = c.clientProtocol
|
||||
state.DidResume = c.didResume
|
||||
state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
|
||||
state.ServerName = c.serverName
|
||||
|
||||
if state.HandshakeComplete {
|
||||
state.Version = c.vers
|
||||
state.NegotiatedProtocol = c.clientProtocol
|
||||
state.DidResume = c.didResume
|
||||
state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
|
||||
state.CipherSuite = c.cipherSuite
|
||||
state.PeerCertificates = c.peerCertificates
|
||||
state.VerifiedChains = c.verifiedChains
|
||||
state.SignedCertificateTimestamps = c.scts
|
||||
state.OCSPResponse = c.ocspResponse
|
||||
if !c.didResume && c.vers != VersionTLS13 {
|
||||
if c.clientFinishedIsFirst {
|
||||
state.TLSUnique = c.clientFinished[:]
|
||||
} else {
|
||||
state.TLSUnique = c.serverFinished[:]
|
||||
}
|
||||
}
|
||||
if c.config.Renegotiation != RenegotiateNever {
|
||||
state.ekm = noExportedKeyingMaterial
|
||||
state.CipherSuite = c.cipherSuite
|
||||
state.PeerCertificates = c.peerCertificates
|
||||
state.VerifiedChains = c.verifiedChains
|
||||
state.SignedCertificateTimestamps = c.scts
|
||||
state.OCSPResponse = c.ocspResponse
|
||||
if !c.didResume && c.vers != VersionTLS13 {
|
||||
if c.clientFinishedIsFirst {
|
||||
state.TLSUnique = c.clientFinished[:]
|
||||
} else {
|
||||
state.ekm = c.ekm
|
||||
state.TLSUnique = c.serverFinished[:]
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.Renegotiation != RenegotiateNever {
|
||||
state.ekm = noExportedKeyingMaterial
|
||||
} else {
|
||||
state.ekm = c.ekm
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
|
|
|
@ -146,6 +146,7 @@ func (c *Conn) clientHandshake() (err error) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.serverName = hello.serverName
|
||||
|
||||
cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
|
||||
if cacheKey != "" && session != nil {
|
||||
|
@ -388,6 +389,7 @@ func (hs *clientHandshakeState) handshake() error {
|
|||
hs.finishedHash.Write(hs.serverHello.marshal())
|
||||
|
||||
c.buffering = true
|
||||
c.didResume = isResume
|
||||
if isResume {
|
||||
if err := hs.establishKeys(); err != nil {
|
||||
return err
|
||||
|
@ -399,6 +401,15 @@ func (hs *clientHandshakeState) handshake() error {
|
|||
return err
|
||||
}
|
||||
c.clientFinishedIsFirst = false
|
||||
// Make sure the connection is still being verified whether or not this
|
||||
// is a resumption. Resumptions currently don't reverify certificates so
|
||||
// they don't call verifyServerCertificate. See Issue 31641.
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := hs.sendFinished(c.clientFinished[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -428,7 +439,6 @@ func (hs *clientHandshakeState) handshake() error {
|
|||
}
|
||||
|
||||
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
|
||||
c.didResume = isResume
|
||||
atomic.StoreUint32(&c.handshakeStatus, 1)
|
||||
|
||||
return nil
|
||||
|
@ -458,25 +468,6 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
|||
}
|
||||
hs.finishedHash.Write(certMsg.marshal())
|
||||
|
||||
if c.handshakes == 0 {
|
||||
// If this is the first handshake on a connection, process and
|
||||
// (optionally) verify the server's certificates.
|
||||
if err := c.verifyServerCertificate(certMsg.certificates); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// This is a renegotiation handshake. We require that the
|
||||
// server's identity (i.e. leaf certificate) is unchanged and
|
||||
// thus any previous trust decision is still valid.
|
||||
//
|
||||
// See https://mitls.org/pages/attacks/3SHAKE for the
|
||||
// motivation behind this requirement.
|
||||
if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return errors.New("tls: server's identity changed during renegotiation")
|
||||
}
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -505,6 +496,25 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
|||
}
|
||||
}
|
||||
|
||||
if c.handshakes == 0 {
|
||||
// If this is the first handshake on a connection, process and
|
||||
// (optionally) verify the server's certificates.
|
||||
if err := c.verifyServerCertificate(certMsg.certificates); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// This is a renegotiation handshake. We require that the
|
||||
// server's identity (i.e. leaf certificate) is unchanged and
|
||||
// thus any previous trust decision is still valid.
|
||||
//
|
||||
// See https://mitls.org/pages/attacks/3SHAKE for the
|
||||
// motivation behind this requirement.
|
||||
if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return errors.New("tls: server's identity changed during renegotiation")
|
||||
}
|
||||
}
|
||||
|
||||
keyAgreement := hs.suite.ka(c.vers)
|
||||
|
||||
skx, ok := msg.(*serverKeyExchangeMsg)
|
||||
|
@ -831,13 +841,6 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
|
|||
}
|
||||
}
|
||||
|
||||
if c.config.VerifyPeerCertificate != nil {
|
||||
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch certs[0].PublicKey.(type) {
|
||||
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
|
||||
break
|
||||
|
@ -848,6 +851,20 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
|
|||
|
||||
c.peerCertificates = certs
|
||||
|
||||
if c.config.VerifyPeerCertificate != nil {
|
||||
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -907,6 +907,9 @@ func testResumption(t *testing.T, version uint16) {
|
|||
if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
|
||||
t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
|
||||
}
|
||||
if got, want := hs.ServerName, clientConfig.ServerName; got != want {
|
||||
t.Errorf("%s: server name %s, want %s", test, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
getTicket := func() []byte {
|
||||
|
@ -1458,7 +1461,7 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
|
|||
|
||||
sentinelErr := errors.New("TestVerifyPeerCertificate")
|
||||
|
||||
verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
if l := len(rawCerts); l != 1 {
|
||||
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
|
||||
}
|
||||
|
@ -1468,6 +1471,19 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
|
|||
*called = true
|
||||
return nil
|
||||
}
|
||||
verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
|
||||
if l := len(c.PeerCertificates); l != 1 {
|
||||
return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
|
||||
}
|
||||
if len(c.VerifiedChains) == 0 {
|
||||
return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
|
||||
}
|
||||
if isClient && len(c.OCSPResponse) == 0 {
|
||||
return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
|
||||
}
|
||||
*called = true
|
||||
return nil
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
configureServer func(*Config, *bool)
|
||||
|
@ -1478,13 +1494,13 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
|
|||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
return verifyCallback(called, rawCerts, validatedChains)
|
||||
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
return verifyCallback(called, rawCerts, validatedChains)
|
||||
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
|
@ -1565,6 +1581,116 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
|
|||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return verifyConnectionCallback(called, false, c)
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return verifyConnectionCallback(called, true, c)
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
if clientErr != nil {
|
||||
t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
|
||||
}
|
||||
if serverErr != nil {
|
||||
t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
|
||||
}
|
||||
if !clientCalled {
|
||||
t.Errorf("test[%d]: client did not call callback", testNo)
|
||||
}
|
||||
if !serverCalled {
|
||||
t.Errorf("test[%d]: server did not call callback", testNo)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return sentinelErr
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = nil
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
if serverErr != sentinelErr {
|
||||
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = nil
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return sentinelErr
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
if clientErr != sentinelErr {
|
||||
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
|
||||
}
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return sentinelErr
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = nil
|
||||
config.VerifyConnection = nil
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
if serverErr != sentinelErr {
|
||||
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
|
||||
}
|
||||
if !serverCalled {
|
||||
t.Errorf("test[%d]: server did not call callback", testNo)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = nil
|
||||
config.VerifyConnection = nil
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
|
||||
}
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return sentinelErr
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
if clientErr != sentinelErr {
|
||||
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
|
||||
}
|
||||
if !clientCalled {
|
||||
t.Errorf("test[%d]: client did not call callback", testNo)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
|
@ -1580,6 +1706,11 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
|
|||
config.ClientCAs = rootCAs
|
||||
config.Time = now
|
||||
config.MaxVersion = version
|
||||
config.Certificates = make([]Certificate, 1)
|
||||
config.Certificates[0].Certificate = [][]byte{testRSACertificate}
|
||||
config.Certificates[0].PrivateKey = testRSAPrivateKey
|
||||
config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
|
||||
config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
|
||||
test.configureServer(config, &serverCalled)
|
||||
|
||||
err = Server(s, config).Handshake()
|
||||
|
|
|
@ -407,6 +407,15 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
|
|||
// Either a PSK or a certificate is always used, but not both.
|
||||
// See RFC 8446, Section 4.1.1.
|
||||
if hs.usingPSK {
|
||||
// Make sure the connection is still being verified whether or not this
|
||||
// is a resumption. Resumptions currently don't reverify certificates so
|
||||
// they don't call verifyServerCertificate. See Issue 31641.
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -68,6 +68,7 @@ func (hs *serverHandshakeState) handshake() error {
|
|||
c.buffering = true
|
||||
if hs.checkForResumption() {
|
||||
// The client has included a session ticket and so we do an abbreviated handshake.
|
||||
c.didResume = true
|
||||
if err := hs.doResumeHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -92,7 +93,6 @@ func (hs *serverHandshakeState) handshake() error {
|
|||
if err := hs.readFinished(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
c.didResume = true
|
||||
} else {
|
||||
// The client didn't include a session ticket, or it wasn't
|
||||
// valid so we do a full handshake.
|
||||
|
@ -553,6 +553,15 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Make sure the connection is still being verified whether or not
|
||||
// the server requested a client certificate.
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get client key exchange
|
||||
|
@ -771,6 +780,19 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
|
|||
c.verifiedChains = chains
|
||||
}
|
||||
|
||||
c.peerCertificates = certs
|
||||
c.ocspResponse = certificate.OCSPStaple
|
||||
c.scts = certificate.SignedCertificateTimestamps
|
||||
|
||||
if len(certs) > 0 {
|
||||
switch certs[0].PublicKey.(type) {
|
||||
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
|
||||
default:
|
||||
c.sendAlert(alertUnsupportedCertificate)
|
||||
return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.VerifyPeerCertificate != nil {
|
||||
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
|
@ -778,20 +800,13 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
|
|||
}
|
||||
}
|
||||
|
||||
if len(certs) == 0 {
|
||||
return nil
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch certs[0].PublicKey.(type) {
|
||||
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
|
||||
default:
|
||||
c.sendAlert(alertUnsupportedCertificate)
|
||||
return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
|
||||
}
|
||||
|
||||
c.peerCertificates = certs
|
||||
c.ocspResponse = certificate.OCSPStaple
|
||||
c.scts = certificate.SignedCertificateTimestamps
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -306,6 +306,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
|
|||
return errors.New("tls: invalid PSK binder")
|
||||
}
|
||||
|
||||
c.didResume = true
|
||||
if err := c.processCertsFromClient(sessionState.certificate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -313,7 +314,6 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
|
|||
hs.hello.selectedIdentityPresent = true
|
||||
hs.hello.selectedIdentity = uint16(i)
|
||||
hs.usingPSK = true
|
||||
c.didResume = true
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -753,6 +753,14 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
|
|||
c := hs.c
|
||||
|
||||
if !hs.requestClientCert() {
|
||||
// Make sure the connection is still being verified whether or not
|
||||
// the server requested a client certificate.
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
13
tls_test.go
13
tls_test.go
|
@ -734,7 +734,7 @@ func TestWarningAlertFlood(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCloneFuncFields(t *testing.T) {
|
||||
const expectedCount = 5
|
||||
const expectedCount = 6
|
||||
called := 0
|
||||
|
||||
c1 := Config{
|
||||
|
@ -758,6 +758,10 @@ func TestCloneFuncFields(t *testing.T) {
|
|||
called |= 1 << 4
|
||||
return nil
|
||||
},
|
||||
VerifyConnection: func(ConnectionState) error {
|
||||
called |= 1 << 5
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
c2 := c1.Clone()
|
||||
|
@ -767,6 +771,7 @@ func TestCloneFuncFields(t *testing.T) {
|
|||
c2.GetClientCertificate(nil)
|
||||
c2.GetConfigForClient(nil)
|
||||
c2.VerifyPeerCertificate(nil, nil)
|
||||
c2.VerifyConnection(ConnectionState{})
|
||||
|
||||
if called != (1<<expectedCount)-1 {
|
||||
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
|
||||
|
@ -790,7 +795,7 @@ func TestCloneNonFuncFields(t *testing.T) {
|
|||
switch fn := typ.Field(i).Name; fn {
|
||||
case "Rand":
|
||||
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
|
||||
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
|
||||
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
|
||||
// DeepEqual can't compare functions. If you add a
|
||||
// function field to this list, you must also change
|
||||
// TestCloneFuncFields to ensure that the func field is
|
||||
|
@ -1116,8 +1121,8 @@ func TestConnectionState(t *testing.T) {
|
|||
if ss.ServerName != serverName {
|
||||
t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
|
||||
}
|
||||
if cs.ServerName != "" {
|
||||
t.Errorf("Got unexpected server name on the client side")
|
||||
if cs.ServerName != serverName {
|
||||
t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
|
||||
}
|
||||
|
||||
if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue