diff --git a/handshake_client_test.go b/handshake_client_test.go index de93e1b..88c974f 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -1464,6 +1464,225 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { } } +func TestVerifyConnection(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) }) +} + +func testVerifyConnection(t *testing.T, version uint16) { + checkFields := func(c ConnectionState, called *int) error { + if c.Version != version { + return fmt.Errorf("got Version %v, want %v", c.Version, version) + } + if c.HandshakeComplete { + return fmt.Errorf("got HandshakeComplete, want false") + } + if c.ServerName != "example.golang" { + return fmt.Errorf("got ServerName %s, want %s", c.ServerName, "example.golang") + } + if c.NegotiatedProtocol != "protocol1" { + return fmt.Errorf("got NegotiatedProtocol %s, want %s", c.NegotiatedProtocol, "protocol1") + } + wantDidResume := false + if *called == 2 { // if this is the second time, then it should be a resumption + wantDidResume = true + } + if c.DidResume != wantDidResume { + return fmt.Errorf("got DidResume %t, want %t", c.DidResume, wantDidResume) + } + return nil + } + + tests := []struct { + name string + configureServer func(*Config, *int) + configureClient func(*Config, *int) + }{ + { + name: "RequireAndVerifyClientCert", + configureServer: func(config *Config, called *int) { + config.ClientAuth = RequireAndVerifyClientCert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero") + } + return checkFields(c, called) + } + }, + configureClient: func(config *Config, called *int) { + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero") + } + if c.DidResume { + return nil + // The SCTs and OCSP Responce are dropped on resumption. + // See http://golang.org/issue/39075. + } + if len(c.OCSPResponse) == 0 { + return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") + } + if len(c.SignedCertificateTimestamps) == 0 { + return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") + } + return checkFields(c, called) + } + }, + }, + { + name: "InsecureSkipVerify", + configureServer: func(config *Config, called *int) { + config.ClientAuth = RequireAnyClientCert + config.InsecureSkipVerify = true + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l) + } + if c.VerifiedChains != nil { + return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains) + } + return checkFields(c, called) + } + }, + configureClient: func(config *Config, called *int) { + config.InsecureSkipVerify = true + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) + } + if c.VerifiedChains != nil { + return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains) + } + if c.DidResume { + return nil + // The SCTs and OCSP Responce are dropped on resumption. + // See http://golang.org/issue/39075. + } + if len(c.OCSPResponse) == 0 { + return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") + } + if len(c.SignedCertificateTimestamps) == 0 { + return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") + } + return checkFields(c, called) + } + }, + }, + { + name: "NoClientCert", + configureServer: func(config *Config, called *int) { + config.ClientAuth = NoClientCert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + return checkFields(c, called) + } + }, + configureClient: func(config *Config, called *int) { + config.VerifyConnection = func(c ConnectionState) error { + *called++ + return checkFields(c, called) + } + }, + }, + { + name: "RequestClientCert", + configureServer: func(config *Config, called *int) { + config.ClientAuth = RequestClientCert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + return checkFields(c, called) + } + }, + configureClient: func(config *Config, called *int) { + config.Certificates = nil // clear the client cert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero") + } + if c.DidResume { + return nil + // The SCTs and OCSP Responce are dropped on resumption. + // See http://golang.org/issue/39075. + } + if len(c.OCSPResponse) == 0 { + return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") + } + if len(c.SignedCertificateTimestamps) == 0 { + return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") + } + return checkFields(c, called) + } + }, + }, + } + for _, test := range tests { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + rootCAs := x509.NewCertPool() + rootCAs.AddCert(issuer) + + var serverCalled, clientCalled int + + serverConfig := &Config{ + MaxVersion: version, + Certificates: []Certificate{testConfig.Certificates[0]}, + ClientCAs: rootCAs, + NextProtos: []string{"protocol1"}, + } + serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} + serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp") + test.configureServer(serverConfig, &serverCalled) + + clientConfig := &Config{ + MaxVersion: version, + ClientSessionCache: NewLRUClientSessionCache(32), + RootCAs: rootCAs, + ServerName: "example.golang", + Certificates: []Certificate{testConfig.Certificates[0]}, + NextProtos: []string{"protocol1"}, + } + test.configureClient(clientConfig, &clientCalled) + + testHandshakeState := func(name string, didResume bool) { + _, hs, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("%s: handshake failed: %s", name, err) + } + if hs.DidResume != didResume { + t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume) + } + wantCalled := 1 + if didResume { + wantCalled = 2 // resumption would mean this is the second time it was called in this test + } + if clientCalled != wantCalled { + t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled) + } + if serverCalled != wantCalled { + t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled) + } + } + testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false) + testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true) + } +} + func TestVerifyPeerCertificate(t *testing.T) { t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) }) t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) }) diff --git a/handshake_server.go b/handshake_server.go index 57fba10..2c2f0a4 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -425,6 +425,13 @@ func (hs *serverHandshakeState) doResumeHandshake() error { return err } + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + hs.masterSecret = hs.sessionState.masterSecret return nil @@ -548,14 +555,11 @@ func (hs *serverHandshakeState) doFullHandshake() error { if err != nil { return err } - } else { - // Make sure the connection is still being verified whether or not - // the server requested a client certificate. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } + } + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err } } @@ -805,13 +809,6 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { } } - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - return nil } diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index fb7f871..92d55e0 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -783,6 +783,13 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { return err } + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + if len(certMsg.certificate.Certificate) != 0 { msg, err = c.readHandshake() if err != nil {