crypto/tls: fix duplicate calls to VerifyConnection

Also add a test that could reproduce this error and
ensure it doesn't occur in other configurations.

Fixes #39012

Change-Id: If792b5131f312c269fd2c5f08c9ed5c00188d1af
Reviewed-on: https://go-review.googlesource.com/c/go/+/233957
Run-TryBot: Katie Hockman <katie@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Katie Hockman 2020-05-13 17:44:20 -04:00
parent 5e9ce971e5
commit f505878117
3 changed files with 238 additions and 15 deletions

View file

@ -1464,6 +1464,225 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
}
}
func TestVerifyConnection(t *testing.T) {
t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
}
func testVerifyConnection(t *testing.T, version uint16) {
checkFields := func(c ConnectionState, called *int) error {
if c.Version != version {
return fmt.Errorf("got Version %v, want %v", c.Version, version)
}
if c.HandshakeComplete {
return fmt.Errorf("got HandshakeComplete, want false")
}
if c.ServerName != "example.golang" {
return fmt.Errorf("got ServerName %s, want %s", c.ServerName, "example.golang")
}
if c.NegotiatedProtocol != "protocol1" {
return fmt.Errorf("got NegotiatedProtocol %s, want %s", c.NegotiatedProtocol, "protocol1")
}
wantDidResume := false
if *called == 2 { // if this is the second time, then it should be a resumption
wantDidResume = true
}
if c.DidResume != wantDidResume {
return fmt.Errorf("got DidResume %t, want %t", c.DidResume, wantDidResume)
}
return nil
}
tests := []struct {
name string
configureServer func(*Config, *int)
configureClient func(*Config, *int)
}{
{
name: "RequireAndVerifyClientCert",
configureServer: func(config *Config, called *int) {
config.ClientAuth = RequireAndVerifyClientCert
config.VerifyConnection = func(c ConnectionState) error {
*called++
if l := len(c.PeerCertificates); l != 1 {
return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
}
if len(c.VerifiedChains) == 0 {
return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
}
return checkFields(c, called)
}
},
configureClient: func(config *Config, called *int) {
config.VerifyConnection = func(c ConnectionState) error {
*called++
if l := len(c.PeerCertificates); l != 1 {
return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
}
if len(c.VerifiedChains) == 0 {
return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
}
if c.DidResume {
return nil
// The SCTs and OCSP Responce are dropped on resumption.
// See http://golang.org/issue/39075.
}
if len(c.OCSPResponse) == 0 {
return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
}
if len(c.SignedCertificateTimestamps) == 0 {
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
}
return checkFields(c, called)
}
},
},
{
name: "InsecureSkipVerify",
configureServer: func(config *Config, called *int) {
config.ClientAuth = RequireAnyClientCert
config.InsecureSkipVerify = true
config.VerifyConnection = func(c ConnectionState) error {
*called++
if l := len(c.PeerCertificates); l != 1 {
return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
}
if c.VerifiedChains != nil {
return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
}
return checkFields(c, called)
}
},
configureClient: func(config *Config, called *int) {
config.InsecureSkipVerify = true
config.VerifyConnection = func(c ConnectionState) error {
*called++
if l := len(c.PeerCertificates); l != 1 {
return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
}
if c.VerifiedChains != nil {
return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
}
if c.DidResume {
return nil
// The SCTs and OCSP Responce are dropped on resumption.
// See http://golang.org/issue/39075.
}
if len(c.OCSPResponse) == 0 {
return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
}
if len(c.SignedCertificateTimestamps) == 0 {
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
}
return checkFields(c, called)
}
},
},
{
name: "NoClientCert",
configureServer: func(config *Config, called *int) {
config.ClientAuth = NoClientCert
config.VerifyConnection = func(c ConnectionState) error {
*called++
return checkFields(c, called)
}
},
configureClient: func(config *Config, called *int) {
config.VerifyConnection = func(c ConnectionState) error {
*called++
return checkFields(c, called)
}
},
},
{
name: "RequestClientCert",
configureServer: func(config *Config, called *int) {
config.ClientAuth = RequestClientCert
config.VerifyConnection = func(c ConnectionState) error {
*called++
return checkFields(c, called)
}
},
configureClient: func(config *Config, called *int) {
config.Certificates = nil // clear the client cert
config.VerifyConnection = func(c ConnectionState) error {
*called++
if l := len(c.PeerCertificates); l != 1 {
return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
}
if len(c.VerifiedChains) == 0 {
return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
}
if c.DidResume {
return nil
// The SCTs and OCSP Responce are dropped on resumption.
// See http://golang.org/issue/39075.
}
if len(c.OCSPResponse) == 0 {
return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
}
if len(c.SignedCertificateTimestamps) == 0 {
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
}
return checkFields(c, called)
}
},
},
}
for _, test := range tests {
issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
if err != nil {
panic(err)
}
rootCAs := x509.NewCertPool()
rootCAs.AddCert(issuer)
var serverCalled, clientCalled int
serverConfig := &Config{
MaxVersion: version,
Certificates: []Certificate{testConfig.Certificates[0]},
ClientCAs: rootCAs,
NextProtos: []string{"protocol1"},
}
serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
test.configureServer(serverConfig, &serverCalled)
clientConfig := &Config{
MaxVersion: version,
ClientSessionCache: NewLRUClientSessionCache(32),
RootCAs: rootCAs,
ServerName: "example.golang",
Certificates: []Certificate{testConfig.Certificates[0]},
NextProtos: []string{"protocol1"},
}
test.configureClient(clientConfig, &clientCalled)
testHandshakeState := func(name string, didResume bool) {
_, hs, err := testHandshake(t, clientConfig, serverConfig)
if err != nil {
t.Fatalf("%s: handshake failed: %s", name, err)
}
if hs.DidResume != didResume {
t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
}
wantCalled := 1
if didResume {
wantCalled = 2 // resumption would mean this is the second time it was called in this test
}
if clientCalled != wantCalled {
t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
}
if serverCalled != wantCalled {
t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
}
}
testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
}
}
func TestVerifyPeerCertificate(t *testing.T) {
t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })

View file

@ -425,6 +425,13 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
hs.masterSecret = hs.sessionState.masterSecret
return nil
@ -548,14 +555,11 @@ func (hs *serverHandshakeState) doFullHandshake() error {
if err != nil {
return err
}
} else {
// Make sure the connection is still being verified whether or not
// the server requested a client certificate.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
@ -805,13 +809,6 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
}
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}

View file

@ -783,6 +783,13 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if len(certMsg.certificate.Certificate) != 0 {
msg, err = c.readHandshake()
if err != nil {