crypto/tls: add Config.VerifyConnection callback

Since the ConnectionState will now be available during
verification, some code was moved around in order to
initialize and make available as much of the fields on
Conn as possible before the ConnectionState is verified.

Fixes #36736

Change-Id: I0e3efa97565ead7de5c48bb8a87e3ea54fbde140
Reviewed-on: https://go-review.googlesource.com/c/go/+/229122
Run-TryBot: Katie Hockman <katie@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Katie Hockman 2020-04-20 17:55:37 -04:00
parent 47355c49eb
commit 2ad14d1adf
8 changed files with 266 additions and 71 deletions

View file

@ -907,6 +907,9 @@ func testResumption(t *testing.T, version uint16) {
if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
}
if got, want := hs.ServerName, clientConfig.ServerName; got != want {
t.Errorf("%s: server name %s, want %s", test, got, want)
}
}
getTicket := func() []byte {
@ -1458,7 +1461,7 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
sentinelErr := errors.New("TestVerifyPeerCertificate")
verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
if l := len(rawCerts); l != 1 {
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
}
@ -1468,6 +1471,19 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
*called = true
return nil
}
verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
if l := len(c.PeerCertificates); l != 1 {
return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
}
if len(c.VerifiedChains) == 0 {
return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
}
if isClient && len(c.OCSPResponse) == 0 {
return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
}
*called = true
return nil
}
tests := []struct {
configureServer func(*Config, *bool)
@ -1478,13 +1494,13 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return verifyCallback(called, rawCerts, validatedChains)
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
}
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return verifyCallback(called, rawCerts, validatedChains)
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
@ -1565,6 +1581,116 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = func(c ConnectionState) error {
return verifyConnectionCallback(called, false, c)
}
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = func(c ConnectionState) error {
return verifyConnectionCallback(called, true, c)
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if clientErr != nil {
t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
}
if serverErr != nil {
t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
}
if !clientCalled {
t.Errorf("test[%d]: client did not call callback", testNo)
}
if !serverCalled {
t.Errorf("test[%d]: server did not call callback", testNo)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = func(c ConnectionState) error {
return sentinelErr
}
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = nil
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if serverErr != sentinelErr {
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = nil
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyConnection = func(c ConnectionState) error {
return sentinelErr
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if clientErr != sentinelErr {
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
}
config.VerifyConnection = func(c ConnectionState) error {
return sentinelErr
}
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = nil
config.VerifyConnection = nil
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if serverErr != sentinelErr {
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
}
if !serverCalled {
t.Errorf("test[%d]: server did not call callback", testNo)
}
},
},
{
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = nil
config.VerifyConnection = nil
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
}
config.VerifyConnection = func(c ConnectionState) error {
return sentinelErr
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
if clientErr != sentinelErr {
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
}
if !clientCalled {
t.Errorf("test[%d]: client did not call callback", testNo)
}
},
},
}
for i, test := range tests {
@ -1580,6 +1706,11 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
config.ClientCAs = rootCAs
config.Time = now
config.MaxVersion = version
config.Certificates = make([]Certificate, 1)
config.Certificates[0].Certificate = [][]byte{testRSACertificate}
config.Certificates[0].PrivateKey = testRSAPrivateKey
config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
test.configureServer(config, &serverCalled)
err = Server(s, config).Handshake()