mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-03 20:17:36 +03:00
crypto/tls: fix duplicate calls to VerifyConnection
Also add a test that could reproduce this error and ensure it doesn't occur in other configurations. Fixes #39012 Change-Id: If792b5131f312c269fd2c5f08c9ed5c00188d1af Reviewed-on: https://go-review.googlesource.com/c/go/+/233957 Run-TryBot: Katie Hockman <katie@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
parent
5e9ce971e5
commit
f505878117
3 changed files with 238 additions and 15 deletions
|
@ -1464,6 +1464,225 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestVerifyConnection(t *testing.T) {
|
||||
t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
|
||||
t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
|
||||
}
|
||||
|
||||
func testVerifyConnection(t *testing.T, version uint16) {
|
||||
checkFields := func(c ConnectionState, called *int) error {
|
||||
if c.Version != version {
|
||||
return fmt.Errorf("got Version %v, want %v", c.Version, version)
|
||||
}
|
||||
if c.HandshakeComplete {
|
||||
return fmt.Errorf("got HandshakeComplete, want false")
|
||||
}
|
||||
if c.ServerName != "example.golang" {
|
||||
return fmt.Errorf("got ServerName %s, want %s", c.ServerName, "example.golang")
|
||||
}
|
||||
if c.NegotiatedProtocol != "protocol1" {
|
||||
return fmt.Errorf("got NegotiatedProtocol %s, want %s", c.NegotiatedProtocol, "protocol1")
|
||||
}
|
||||
wantDidResume := false
|
||||
if *called == 2 { // if this is the second time, then it should be a resumption
|
||||
wantDidResume = true
|
||||
}
|
||||
if c.DidResume != wantDidResume {
|
||||
return fmt.Errorf("got DidResume %t, want %t", c.DidResume, wantDidResume)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
configureServer func(*Config, *int)
|
||||
configureClient func(*Config, *int)
|
||||
}{
|
||||
{
|
||||
name: "RequireAndVerifyClientCert",
|
||||
configureServer: func(config *Config, called *int) {
|
||||
config.ClientAuth = RequireAndVerifyClientCert
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
*called++
|
||||
if l := len(c.PeerCertificates); l != 1 {
|
||||
return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
|
||||
}
|
||||
if len(c.VerifiedChains) == 0 {
|
||||
return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
|
||||
}
|
||||
return checkFields(c, called)
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *int) {
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
*called++
|
||||
if l := len(c.PeerCertificates); l != 1 {
|
||||
return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
|
||||
}
|
||||
if len(c.VerifiedChains) == 0 {
|
||||
return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
|
||||
}
|
||||
if c.DidResume {
|
||||
return nil
|
||||
// The SCTs and OCSP Responce are dropped on resumption.
|
||||
// See http://golang.org/issue/39075.
|
||||
}
|
||||
if len(c.OCSPResponse) == 0 {
|
||||
return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
|
||||
}
|
||||
if len(c.SignedCertificateTimestamps) == 0 {
|
||||
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
|
||||
}
|
||||
return checkFields(c, called)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "InsecureSkipVerify",
|
||||
configureServer: func(config *Config, called *int) {
|
||||
config.ClientAuth = RequireAnyClientCert
|
||||
config.InsecureSkipVerify = true
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
*called++
|
||||
if l := len(c.PeerCertificates); l != 1 {
|
||||
return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
|
||||
}
|
||||
if c.VerifiedChains != nil {
|
||||
return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
|
||||
}
|
||||
return checkFields(c, called)
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *int) {
|
||||
config.InsecureSkipVerify = true
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
*called++
|
||||
if l := len(c.PeerCertificates); l != 1 {
|
||||
return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
|
||||
}
|
||||
if c.VerifiedChains != nil {
|
||||
return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
|
||||
}
|
||||
if c.DidResume {
|
||||
return nil
|
||||
// The SCTs and OCSP Responce are dropped on resumption.
|
||||
// See http://golang.org/issue/39075.
|
||||
}
|
||||
if len(c.OCSPResponse) == 0 {
|
||||
return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
|
||||
}
|
||||
if len(c.SignedCertificateTimestamps) == 0 {
|
||||
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
|
||||
}
|
||||
return checkFields(c, called)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NoClientCert",
|
||||
configureServer: func(config *Config, called *int) {
|
||||
config.ClientAuth = NoClientCert
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
*called++
|
||||
return checkFields(c, called)
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *int) {
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
*called++
|
||||
return checkFields(c, called)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RequestClientCert",
|
||||
configureServer: func(config *Config, called *int) {
|
||||
config.ClientAuth = RequestClientCert
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
*called++
|
||||
return checkFields(c, called)
|
||||
}
|
||||
},
|
||||
configureClient: func(config *Config, called *int) {
|
||||
config.Certificates = nil // clear the client cert
|
||||
config.VerifyConnection = func(c ConnectionState) error {
|
||||
*called++
|
||||
if l := len(c.PeerCertificates); l != 1 {
|
||||
return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
|
||||
}
|
||||
if len(c.VerifiedChains) == 0 {
|
||||
return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
|
||||
}
|
||||
if c.DidResume {
|
||||
return nil
|
||||
// The SCTs and OCSP Responce are dropped on resumption.
|
||||
// See http://golang.org/issue/39075.
|
||||
}
|
||||
if len(c.OCSPResponse) == 0 {
|
||||
return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
|
||||
}
|
||||
if len(c.SignedCertificateTimestamps) == 0 {
|
||||
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
|
||||
}
|
||||
return checkFields(c, called)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(issuer)
|
||||
|
||||
var serverCalled, clientCalled int
|
||||
|
||||
serverConfig := &Config{
|
||||
MaxVersion: version,
|
||||
Certificates: []Certificate{testConfig.Certificates[0]},
|
||||
ClientCAs: rootCAs,
|
||||
NextProtos: []string{"protocol1"},
|
||||
}
|
||||
serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
|
||||
serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
|
||||
test.configureServer(serverConfig, &serverCalled)
|
||||
|
||||
clientConfig := &Config{
|
||||
MaxVersion: version,
|
||||
ClientSessionCache: NewLRUClientSessionCache(32),
|
||||
RootCAs: rootCAs,
|
||||
ServerName: "example.golang",
|
||||
Certificates: []Certificate{testConfig.Certificates[0]},
|
||||
NextProtos: []string{"protocol1"},
|
||||
}
|
||||
test.configureClient(clientConfig, &clientCalled)
|
||||
|
||||
testHandshakeState := func(name string, didResume bool) {
|
||||
_, hs, err := testHandshake(t, clientConfig, serverConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("%s: handshake failed: %s", name, err)
|
||||
}
|
||||
if hs.DidResume != didResume {
|
||||
t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
|
||||
}
|
||||
wantCalled := 1
|
||||
if didResume {
|
||||
wantCalled = 2 // resumption would mean this is the second time it was called in this test
|
||||
}
|
||||
if clientCalled != wantCalled {
|
||||
t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
|
||||
}
|
||||
if serverCalled != wantCalled {
|
||||
t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
|
||||
}
|
||||
}
|
||||
testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
|
||||
testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPeerCertificate(t *testing.T) {
|
||||
t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
|
||||
t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
|
||||
|
|
|
@ -425,6 +425,13 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hs.masterSecret = hs.sessionState.masterSecret
|
||||
|
||||
return nil
|
||||
|
@ -548,14 +555,11 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Make sure the connection is still being verified whether or not
|
||||
// the server requested a client certificate.
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -805,13 +809,6 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
|
|||
}
|
||||
}
|
||||
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -783,6 +783,13 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(certMsg.certificate.Certificate) != 0 {
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue