crypto/tls: add VerifyPeerCertificate to tls.Config

VerifyPeerCertificate returns an error if the peer should not be
trusted. It will be called after the initial handshake and before
any other verification checks on the cert or chain are performed.
This provides the callee an opportunity to augment the certificate
verification.

If VerifyPeerCertificate is not nil and returns an error,
then the handshake will fail.

Fixes #16363

Change-Id: I6a22f199f0e81b6f5d5f37c54d85ab878216bb22
Reviewed-on: https://go-review.googlesource.com/26654
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Joshua Boelter 2016-07-13 16:22:28 -06:00 committed by Adam Langley
parent 50401d1e36
commit 8f3c635434
5 changed files with 182 additions and 1 deletions

View file

@ -1067,6 +1067,160 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
}
}
func TestVerifyPeerCertificate(t *testing.T) {
issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
if err != nil {
panic(err)
}
rootCAs := x509.NewCertPool()
rootCAs.AddCert(issuer)
now := func() time.Time { return time.Unix(1476984729, 0) }
sentinelErr := errors.New("TestVerifyPeerCertificate")
verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
if l := len(rawCerts); l != 1 {
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
}
if len(validatedChains) == 0 {
return errors.New("got len(validatedChains) = 0, wanted non-zero")
}
*called = true
return nil
}
tests := []struct {
configureServer func(*Config, *bool)
configureClient func(*Config, *bool)
validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
}{
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return verifyCallback(called, rawCerts, validatedChains)
}
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return verifyCallback(called, rawCerts, validatedChains)
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if clientErr != nil {
t.Errorf("#%d: client handshake failed: %v", testNo, clientErr)
}
if serverErr != nil {
t.Errorf("#%d: server handshake failed: %v", testNo, serverErr)
}
if !clientCalled {
t.Error("#%d: client did not call callback", testNo)
}
if !serverCalled {
t.Error("#%d: server did not call callback", testNo)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return sentinelErr
}
},
configureClient: func(config *Config, called *bool) {
config.VerifyPeerCertificate = nil
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if serverErr != sentinelErr {
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
},
configureClient: func(config *Config, called *bool) {
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return sentinelErr
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if clientErr != sentinelErr {
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = true
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
if l := len(rawCerts); l != 1 {
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
}
// With InsecureSkipVerify set, this
// callback should still be called but
// validatedChains must be empty.
if l := len(validatedChains); l != 0 {
return errors.New("got len(validatedChains) = 0, wanted zero")
}
*called = true
return nil
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if clientErr != nil {
t.Errorf("#%d: client handshake failed: %v", testNo, clientErr)
}
if serverErr != nil {
t.Errorf("#%d: server handshake failed: %v", testNo, serverErr)
}
if !clientCalled {
t.Error("#%d: client did not call callback", testNo)
}
},
},
}
for i, test := range tests {
c, s := net.Pipe()
done := make(chan error)
var clientCalled, serverCalled bool
go func() {
config := testConfig.Clone()
config.ServerName = "example.golang"
config.ClientAuth = RequireAndVerifyClientCert
config.ClientCAs = rootCAs
config.Time = now
test.configureServer(config, &serverCalled)
err = Server(s, config).Handshake()
s.Close()
done <- err
}()
config := testConfig.Clone()
config.ServerName = "example.golang"
config.RootCAs = rootCAs
config.Time = now
test.configureClient(config, &clientCalled)
clientErr := Client(c, config).Handshake()
c.Close()
serverErr := <-done
test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
}
}
// brokenConn wraps a net.Conn and causes all Writes after a certain number to
// fail with brokenConnErr.
type brokenConn struct {