From 2ad14d1adfa4fe330422dd7269d5e0daee4884b9 Mon Sep 17 00:00:00 2001 From: Katie Hockman Date: Mon, 20 Apr 2020 17:55:37 -0400 Subject: [PATCH] crypto/tls: add Config.VerifyConnection callback Since the ConnectionState will now be available during verification, some code was moved around in order to initialize and make available as much of the fields on Conn as possible before the ConnectionState is verified. Fixes #36736 Change-Id: I0e3efa97565ead7de5c48bb8a87e3ea54fbde140 Reviewed-on: https://go-review.googlesource.com/c/go/+/229122 Run-TryBot: Katie Hockman Reviewed-by: Filippo Valsorda TryBot-Result: Gobot Gobot --- common.go | 13 +++- conn.go | 43 ++++++------ handshake_client.go | 71 ++++++++++++-------- handshake_client_test.go | 137 +++++++++++++++++++++++++++++++++++++- handshake_client_tls13.go | 9 +++ handshake_server.go | 41 ++++++++---- handshake_server_tls13.go | 10 ++- tls_test.go | 13 ++-- 8 files changed, 266 insertions(+), 71 deletions(-) diff --git a/common.go b/common.go index 90846a3..fd21ae8 100644 --- a/common.go +++ b/common.go @@ -219,7 +219,7 @@ type ConnectionState struct { CipherSuite uint16 // cipher suite in use (TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, ...) NegotiatedProtocol string // negotiated next protocol (not guaranteed to be from Config.NextProtos) NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server (client side only) - ServerName string // server name requested by client, if any (server side only) + ServerName string // server name requested by client, if any PeerCertificates []*x509.Certificate // certificate chain presented by remote peer VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates SignedCertificateTimestamps [][]byte // SCTs from the peer, if any @@ -520,6 +520,16 @@ type Config struct { // be considered but the verifiedChains argument will always be nil. VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + // VerifyConnection, if not nil, is called after normal certificate + // verification and after VerifyPeerCertificate by either a TLS client + // or server. If it returns a non-nil error, the handshake is aborted + // and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. This callback will run for all connections + // regardless of InsecureSkipVerify or ClientAuth settings. + VerifyConnection func(ConnectionState) error + // RootCAs defines the set of root certificate authorities // that clients use when verifying server certificates. // If RootCAs is nil, TLS uses the host's root CA set. @@ -685,6 +695,7 @@ func (c *Config) Clone() *Config { GetClientCertificate: c.GetClientCertificate, GetConfigForClient: c.GetConfigForClient, VerifyPeerCertificate: c.VerifyPeerCertificate, + VerifyConnection: c.VerifyConnection, RootCAs: c.RootCAs, NextProtos: c.NextProtos, ServerName: c.ServerName, diff --git a/conn.go b/conn.go index d759986..edcfecf 100644 --- a/conn.go +++ b/conn.go @@ -1379,35 +1379,34 @@ func (c *Conn) Handshake() error { func (c *Conn) ConnectionState() ConnectionState { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() + return c.connectionStateLocked() +} +func (c *Conn) connectionStateLocked() ConnectionState { var state ConnectionState state.HandshakeComplete = c.handshakeComplete() + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback state.ServerName = c.serverName - - if state.HandshakeComplete { - state.Version = c.vers - state.NegotiatedProtocol = c.clientProtocol - state.DidResume = c.didResume - state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback - state.CipherSuite = c.cipherSuite - state.PeerCertificates = c.peerCertificates - state.VerifiedChains = c.verifiedChains - state.SignedCertificateTimestamps = c.scts - state.OCSPResponse = c.ocspResponse - if !c.didResume && c.vers != VersionTLS13 { - if c.clientFinishedIsFirst { - state.TLSUnique = c.clientFinished[:] - } else { - state.TLSUnique = c.serverFinished[:] - } - } - if c.config.Renegotiation != RenegotiateNever { - state.ekm = noExportedKeyingMaterial + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.SignedCertificateTimestamps = c.scts + state.OCSPResponse = c.ocspResponse + if !c.didResume && c.vers != VersionTLS13 { + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] } else { - state.ekm = c.ekm + state.TLSUnique = c.serverFinished[:] } } - + if c.config.Renegotiation != RenegotiateNever { + state.ekm = noExportedKeyingMaterial + } else { + state.ekm = c.ekm + } return state } diff --git a/handshake_client.go b/handshake_client.go index 210eece..40c8e02 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -146,6 +146,7 @@ func (c *Conn) clientHandshake() (err error) { if err != nil { return err } + c.serverName = hello.serverName cacheKey, session, earlySecret, binderKey := c.loadSession(hello) if cacheKey != "" && session != nil { @@ -388,6 +389,7 @@ func (hs *clientHandshakeState) handshake() error { hs.finishedHash.Write(hs.serverHello.marshal()) c.buffering = true + c.didResume = isResume if isResume { if err := hs.establishKeys(); err != nil { return err @@ -399,6 +401,15 @@ func (hs *clientHandshakeState) handshake() error { return err } c.clientFinishedIsFirst = false + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } if err := hs.sendFinished(c.clientFinished[:]); err != nil { return err } @@ -428,7 +439,6 @@ func (hs *clientHandshakeState) handshake() error { } c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) - c.didResume = isResume atomic.StoreUint32(&c.handshakeStatus, 1) return nil @@ -458,25 +468,6 @@ func (hs *clientHandshakeState) doFullHandshake() error { } hs.finishedHash.Write(certMsg.marshal()) - if c.handshakes == 0 { - // If this is the first handshake on a connection, process and - // (optionally) verify the server's certificates. - if err := c.verifyServerCertificate(certMsg.certificates); err != nil { - return err - } - } else { - // This is a renegotiation handshake. We require that the - // server's identity (i.e. leaf certificate) is unchanged and - // thus any previous trust decision is still valid. - // - // See https://mitls.org/pages/attacks/3SHAKE for the - // motivation behind this requirement. - if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { - c.sendAlert(alertBadCertificate) - return errors.New("tls: server's identity changed during renegotiation") - } - } - msg, err = c.readHandshake() if err != nil { return err @@ -505,6 +496,25 @@ func (hs *clientHandshakeState) doFullHandshake() error { } } + if c.handshakes == 0 { + // If this is the first handshake on a connection, process and + // (optionally) verify the server's certificates. + if err := c.verifyServerCertificate(certMsg.certificates); err != nil { + return err + } + } else { + // This is a renegotiation handshake. We require that the + // server's identity (i.e. leaf certificate) is unchanged and + // thus any previous trust decision is still valid. + // + // See https://mitls.org/pages/attacks/3SHAKE for the + // motivation behind this requirement. + if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: server's identity changed during renegotiation") + } + } + keyAgreement := hs.suite.ka(c.vers) skx, ok := msg.(*serverKeyExchangeMsg) @@ -831,13 +841,6 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { } } - if c.config.VerifyPeerCertificate != nil { - if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - switch certs[0].PublicKey.(type) { case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: break @@ -848,6 +851,20 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { c.peerCertificates = certs + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil } diff --git a/handshake_client_test.go b/handshake_client_test.go index cd387dc..313872c 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -907,6 +907,9 @@ func testResumption(t *testing.T, version uint16) { if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) { t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains) } + if got, want := hs.ServerName, clientConfig.ServerName; got != want { + t.Errorf("%s: server name %s, want %s", test, got, want) + } } getTicket := func() []byte { @@ -1458,7 +1461,7 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) { sentinelErr := errors.New("TestVerifyPeerCertificate") - verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { if l := len(rawCerts); l != 1 { return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) } @@ -1468,6 +1471,19 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) { *called = true return nil } + verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error { + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero") + } + if isClient && len(c.OCSPResponse) == 0 { + return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero") + } + *called = true + return nil + } tests := []struct { configureServer func(*Config, *bool) @@ -1478,13 +1494,13 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) { configureServer: func(config *Config, called *bool) { config.InsecureSkipVerify = false config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { - return verifyCallback(called, rawCerts, validatedChains) + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) } }, configureClient: func(config *Config, called *bool) { config.InsecureSkipVerify = false config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { - return verifyCallback(called, rawCerts, validatedChains) + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) } }, validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { @@ -1565,6 +1581,116 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) { } }, }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return verifyConnectionCallback(called, false, c) + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return verifyConnectionCallback(called, true, c) + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != nil { + t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) + } + if serverErr != nil { + t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + if !serverCalled { + t.Errorf("test[%d]: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = nil + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = nil + config.VerifyConnection = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + if !serverCalled { + t.Errorf("test[%d]: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = nil + config.VerifyConnection = nil + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + }, + }, } for i, test := range tests { @@ -1580,6 +1706,11 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) { config.ClientCAs = rootCAs config.Time = now config.MaxVersion = version + config.Certificates = make([]Certificate, 1) + config.Certificates[0].Certificate = [][]byte{testRSACertificate} + config.Certificates[0].PrivateKey = testRSAPrivateKey + config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} + config.Certificates[0].OCSPStaple = []byte("dummy ocsp") test.configureServer(config, &serverCalled) err = Server(s, config).Handshake() diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 97122bd..35a00f2 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -407,6 +407,15 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { // Either a PSK or a certificate is always used, but not both. // See RFC 8446, Section 4.1.1. if hs.usingPSK { + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } return nil } diff --git a/handshake_server.go b/handshake_server.go index 4885c69..6aacfa1 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -68,6 +68,7 @@ func (hs *serverHandshakeState) handshake() error { c.buffering = true if hs.checkForResumption() { // The client has included a session ticket and so we do an abbreviated handshake. + c.didResume = true if err := hs.doResumeHandshake(); err != nil { return err } @@ -92,7 +93,6 @@ func (hs *serverHandshakeState) handshake() error { if err := hs.readFinished(nil); err != nil { return err } - c.didResume = true } else { // The client didn't include a session ticket, or it wasn't // valid so we do a full handshake. @@ -553,6 +553,15 @@ func (hs *serverHandshakeState) doFullHandshake() error { if err != nil { return err } + } else { + // Make sure the connection is still being verified whether or not + // the server requested a client certificate. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } } // Get client key exchange @@ -771,6 +780,19 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { c.verifiedChains = chains } + c.peerCertificates = certs + c.ocspResponse = certificate.OCSPStaple + c.scts = certificate.SignedCertificateTimestamps + + if len(certs) > 0 { + switch certs[0].PublicKey.(type) { + case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) + } + } + if c.config.VerifyPeerCertificate != nil { if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { c.sendAlert(alertBadCertificate) @@ -778,20 +800,13 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { } } - if len(certs) == 0 { - return nil + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } } - switch certs[0].PublicKey.(type) { - case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: - default: - c.sendAlert(alertUnsupportedCertificate) - return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) - } - - c.peerCertificates = certs - c.ocspResponse = certificate.OCSPStaple - c.scts = certificate.SignedCertificateTimestamps return nil } diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 5432145..fb7f871 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -306,6 +306,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { return errors.New("tls: invalid PSK binder") } + c.didResume = true if err := c.processCertsFromClient(sessionState.certificate); err != nil { return err } @@ -313,7 +314,6 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { hs.hello.selectedIdentityPresent = true hs.hello.selectedIdentity = uint16(i) hs.usingPSK = true - c.didResume = true return nil } @@ -753,6 +753,14 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { c := hs.c if !hs.requestClientCert() { + // Make sure the connection is still being verified whether or not + // the server requested a client certificate. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } return nil } diff --git a/tls_test.go b/tls_test.go index 85005d4..9e34077 100644 --- a/tls_test.go +++ b/tls_test.go @@ -734,7 +734,7 @@ func TestWarningAlertFlood(t *testing.T) { } func TestCloneFuncFields(t *testing.T) { - const expectedCount = 5 + const expectedCount = 6 called := 0 c1 := Config{ @@ -758,6 +758,10 @@ func TestCloneFuncFields(t *testing.T) { called |= 1 << 4 return nil }, + VerifyConnection: func(ConnectionState) error { + called |= 1 << 5 + return nil + }, } c2 := c1.Clone() @@ -767,6 +771,7 @@ func TestCloneFuncFields(t *testing.T) { c2.GetClientCertificate(nil) c2.GetConfigForClient(nil) c2.VerifyPeerCertificate(nil, nil) + c2.VerifyConnection(ConnectionState{}) if called != (1<