crypto/tls: add Config.VerifyConnection callback

Since the ConnectionState will now be available during
verification, some code was moved around in order to
initialize and make available as much of the fields on
Conn as possible before the ConnectionState is verified.

Fixes #36736

Change-Id: I0e3efa97565ead7de5c48bb8a87e3ea54fbde140
Reviewed-on: https://go-review.googlesource.com/c/go/+/229122
Run-TryBot: Katie Hockman <katie@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Katie Hockman 2020-04-20 17:55:37 -04:00
parent 47355c49eb
commit 2ad14d1adf
8 changed files with 266 additions and 71 deletions

View file

@ -734,7 +734,7 @@ func TestWarningAlertFlood(t *testing.T) {
}
func TestCloneFuncFields(t *testing.T) {
const expectedCount = 5
const expectedCount = 6
called := 0
c1 := Config{
@ -758,6 +758,10 @@ func TestCloneFuncFields(t *testing.T) {
called |= 1 << 4
return nil
},
VerifyConnection: func(ConnectionState) error {
called |= 1 << 5
return nil
},
}
c2 := c1.Clone()
@ -767,6 +771,7 @@ func TestCloneFuncFields(t *testing.T) {
c2.GetClientCertificate(nil)
c2.GetConfigForClient(nil)
c2.VerifyPeerCertificate(nil, nil)
c2.VerifyConnection(ConnectionState{})
if called != (1<<expectedCount)-1 {
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
@ -790,7 +795,7 @@ func TestCloneNonFuncFields(t *testing.T) {
switch fn := typ.Field(i).Name; fn {
case "Rand":
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
// DeepEqual can't compare functions. If you add a
// function field to this list, you must also change
// TestCloneFuncFields to ensure that the func field is
@ -1116,8 +1121,8 @@ func TestConnectionState(t *testing.T) {
if ss.ServerName != serverName {
t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
}
if cs.ServerName != "" {
t.Errorf("Got unexpected server name on the client side")
if cs.ServerName != serverName {
t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
}
if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {