mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-03 20:17:36 +03:00
crypto/tls: add Config.VerifyConnection callback
Since the ConnectionState will now be available during verification, some code was moved around in order to initialize and make available as much of the fields on Conn as possible before the ConnectionState is verified. Fixes #36736 Change-Id: I0e3efa97565ead7de5c48bb8a87e3ea54fbde140 Reviewed-on: https://go-review.googlesource.com/c/go/+/229122 Run-TryBot: Katie Hockman <katie@golang.org> Reviewed-by: Filippo Valsorda <filippo@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
47355c49eb
commit
2ad14d1adf
8 changed files with 266 additions and 71 deletions
|
@ -907,6 +907,9 @@ func testResumption(t *testing.T, version uint16) {
|
|||
if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
|
||||
t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
|
||||
}
|
||||
if got, want := hs.ServerName, clientConfig.ServerName; got != want {
|
||||
t.Errorf("%s: server name %s, want %s", test, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
getTicket := func() []byte {
|
||||
|
@ -1458,7 +1461,7 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
|
|||
|
||||
sentinelErr := errors.New("TestVerifyPeerCertificate")
|
||||
|
||||
verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
if l := len(rawCerts); l != 1 {
|
||||
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
|
||||
}
|
||||
|
@ -1468,6 +1471,19 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
|
|||
*called = true
|
||||
return nil
|
||||
}
|
||||
verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
|
||||
if l := len(c.PeerCertificates); l != 1 {
|
||||
return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
|
||||
}
|
||||
if len(c.VerifiedChains) == 0 {
|
||||
return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
|
||||
}
|
||||
if isClient && len(c.OCSPResponse) == 0 {
|
||||
return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
|
||||
}
|
||||
*called = true
|
||||
return nil
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
configureServer func(*Config, *bool)
|
||||
|
@ -1478,13 +1494,13 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
|
|||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
return verifyCallback(called, rawCerts, validatedChains)
|
||||
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
return verifyCallback(called, rawCerts, validatedChains)
|
||||
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
|
@ -1565,6 +1581,116 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
|
|||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return verifyConnectionCallback(called, false, c)
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return verifyConnectionCallback(called, true, c)
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
if clientErr != nil {
|
||||
t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
|
||||
}
|
||||
if serverErr != nil {
|
||||
t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
|
||||
}
|
||||
if !clientCalled {
|
||||
t.Errorf("test[%d]: client did not call callback", testNo)
|
||||
}
|
||||
if !serverCalled {
|
||||
t.Errorf("test[%d]: server did not call callback", testNo)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return sentinelErr
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = nil
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
if serverErr != sentinelErr {
|
||||
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = nil
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return sentinelErr
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
if clientErr != sentinelErr {
|
||||
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
|
||||
}
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return sentinelErr
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = nil
|
||||
config.VerifyConnection = nil
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
if serverErr != sentinelErr {
|
||||
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
|
||||
}
|
||||
if !serverCalled {
|
||||
t.Errorf("test[%d]: server did not call callback", testNo)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
configureServer: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = nil
|
||||
config.VerifyConnection = nil
|
||||
},
|
||||
configureClient: func(config *Config, called *bool) {
|
||||
config.InsecureSkipVerify = false
|
||||
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||
return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
|
||||
}
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
return sentinelErr
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||
if clientErr != sentinelErr {
|
||||
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
|
||||
}
|
||||
if !clientCalled {
|
||||
t.Errorf("test[%d]: client did not call callback", testNo)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
|
@ -1580,6 +1706,11 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
|
|||
config.ClientCAs = rootCAs
|
||||
config.Time = now
|
||||
config.MaxVersion = version
|
||||
config.Certificates = make([]Certificate, 1)
|
||||
config.Certificates[0].Certificate = [][]byte{testRSACertificate}
|
||||
config.Certificates[0].PrivateKey = testRSAPrivateKey
|
||||
config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
|
||||
config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
|
||||
test.configureServer(config, &serverCalled)
|
||||
|
||||
err = Server(s, config).Handshake()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue