crypto/tls: set CipherSuite for VerifyConnection

The ConnectionState's CipherSuite was not set prior
to the VerifyConnection callback in TLS 1.2 servers,
both for full handshakes and resumptions.

Change-Id: Iab91783eff84d1b42ca09c8df08e07861e18da30
Reviewed-on: https://go-review.googlesource.com/c/go/+/236558
Run-TryBot: Katie Hockman <katie@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Katie Hockman 2020-06-04 10:52:24 -04:00
parent 7950929c30
commit e90314e882
2 changed files with 19 additions and 15 deletions

View file

@ -1470,25 +1470,28 @@ func TestVerifyConnection(t *testing.T) {
} }
func testVerifyConnection(t *testing.T, version uint16) { func testVerifyConnection(t *testing.T, version uint16) {
checkFields := func(c ConnectionState, called *int) error { checkFields := func(c ConnectionState, called *int, errorType string) error {
if c.Version != version { if c.Version != version {
return fmt.Errorf("got Version %v, want %v", c.Version, version) return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
} }
if c.HandshakeComplete { if c.HandshakeComplete {
return fmt.Errorf("got HandshakeComplete, want false") return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
} }
if c.ServerName != "example.golang" { if c.ServerName != "example.golang" {
return fmt.Errorf("got ServerName %s, want %s", c.ServerName, "example.golang") return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
} }
if c.NegotiatedProtocol != "protocol1" { if c.NegotiatedProtocol != "protocol1" {
return fmt.Errorf("got NegotiatedProtocol %s, want %s", c.NegotiatedProtocol, "protocol1") return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
}
if c.CipherSuite == 0 {
return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
} }
wantDidResume := false wantDidResume := false
if *called == 2 { // if this is the second time, then it should be a resumption if *called == 2 { // if this is the second time, then it should be a resumption
wantDidResume = true wantDidResume = true
} }
if c.DidResume != wantDidResume { if c.DidResume != wantDidResume {
return fmt.Errorf("got DidResume %t, want %t", c.DidResume, wantDidResume) return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
} }
return nil return nil
} }
@ -1510,7 +1513,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
if len(c.VerifiedChains) == 0 { if len(c.VerifiedChains) == 0 {
return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero") return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
} }
return checkFields(c, called) return checkFields(c, called, "server")
} }
}, },
configureClient: func(config *Config, called *int) { configureClient: func(config *Config, called *int) {
@ -1533,7 +1536,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
if len(c.SignedCertificateTimestamps) == 0 { if len(c.SignedCertificateTimestamps) == 0 {
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
} }
return checkFields(c, called) return checkFields(c, called, "client")
} }
}, },
}, },
@ -1550,7 +1553,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
if c.VerifiedChains != nil { if c.VerifiedChains != nil {
return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains) return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
} }
return checkFields(c, called) return checkFields(c, called, "server")
} }
}, },
configureClient: func(config *Config, called *int) { configureClient: func(config *Config, called *int) {
@ -1574,7 +1577,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
if len(c.SignedCertificateTimestamps) == 0 { if len(c.SignedCertificateTimestamps) == 0 {
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
} }
return checkFields(c, called) return checkFields(c, called, "client")
} }
}, },
}, },
@ -1584,13 +1587,13 @@ func testVerifyConnection(t *testing.T, version uint16) {
config.ClientAuth = NoClientCert config.ClientAuth = NoClientCert
config.VerifyConnection = func(c ConnectionState) error { config.VerifyConnection = func(c ConnectionState) error {
*called++ *called++
return checkFields(c, called) return checkFields(c, called, "server")
} }
}, },
configureClient: func(config *Config, called *int) { configureClient: func(config *Config, called *int) {
config.VerifyConnection = func(c ConnectionState) error { config.VerifyConnection = func(c ConnectionState) error {
*called++ *called++
return checkFields(c, called) return checkFields(c, called, "client")
} }
}, },
}, },
@ -1600,7 +1603,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
config.ClientAuth = RequestClientCert config.ClientAuth = RequestClientCert
config.VerifyConnection = func(c ConnectionState) error { config.VerifyConnection = func(c ConnectionState) error {
*called++ *called++
return checkFields(c, called) return checkFields(c, called, "server")
} }
}, },
configureClient: func(config *Config, called *int) { configureClient: func(config *Config, called *int) {
@ -1624,7 +1627,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
if len(c.SignedCertificateTimestamps) == 0 { if len(c.SignedCertificateTimestamps) == 0 {
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
} }
return checkFields(c, called) return checkFields(c, called, "client")
} }
}, },
}, },

View file

@ -308,6 +308,7 @@ func (hs *serverHandshakeState) pickCipherSuite() error {
c.sendAlert(alertHandshakeFailure) c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no cipher suite supported by both client and server") return errors.New("tls: no cipher suite supported by both client and server")
} }
c.cipherSuite = hs.suite.id
for _, id := range hs.clientHello.cipherSuites { for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV { if id == TLS_FALLBACK_SCSV {
@ -407,6 +408,7 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
c := hs.c c := hs.c
hs.hello.cipherSuite = hs.suite.id hs.hello.cipherSuite = hs.suite.id
c.cipherSuite = hs.suite.id
// We echo the client's session ID in the ServerHello to let it know // We echo the client's session ID in the ServerHello to let it know
// that we're doing a resumption. // that we're doing a resumption.
hs.hello.sessionId = hs.clientHello.sessionId hs.hello.sessionId = hs.clientHello.sessionId
@ -743,7 +745,6 @@ func (hs *serverHandshakeState) sendFinished(out []byte) error {
return err return err
} }
c.cipherSuite = hs.suite.id
copy(out, finished.verifyData) copy(out, finished.verifyData)
return nil return nil