diff --git a/handshake_client.go b/handshake_client.go index a4ca5d3..284a776 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -29,51 +29,40 @@ type clientHandshakeState struct { session *ClientSessionState } -// c.out.Mutex <= L; c.handshakeMutex <= L. -func (c *Conn) clientHandshake() error { - if c.config == nil { - c.config = defaultConfig() +func makeClientHello(config *Config) (*clientHelloMsg, error) { + if config == nil { + config = defaultConfig() } - - // This may be a renegotiation handshake, in which case some fields - // need to be reset. - c.didResume = false - - if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify { - return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") + if len(config.ServerName) == 0 && !config.InsecureSkipVerify { + return nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") } nextProtosLength := 0 - for _, proto := range c.config.NextProtos { + for _, proto := range config.NextProtos { if l := len(proto); l == 0 || l > 255 { - return errors.New("tls: invalid NextProtos value") + return nil, errors.New("tls: invalid NextProtos value") } else { nextProtosLength += 1 + l } } if nextProtosLength > 0xffff { - return errors.New("tls: NextProtos values too large") + return nil, errors.New("tls: NextProtos values too large") } hello := &clientHelloMsg{ - vers: c.config.maxVersion(), + vers: config.maxVersion(), compressionMethods: []uint8{compressionNone}, random: make([]byte, 32), ocspStapling: true, scts: true, - serverName: hostnameInSNI(c.config.ServerName), - supportedCurves: c.config.curvePreferences(), + serverName: hostnameInSNI(config.ServerName), + supportedCurves: config.curvePreferences(), supportedPoints: []uint8{pointFormatUncompressed}, - nextProtoNeg: len(c.config.NextProtos) > 0, + nextProtoNeg: len(config.NextProtos) > 0, secureRenegotiationSupported: true, - alpnProtocols: c.config.NextProtos, + alpnProtocols: config.NextProtos, } - - if c.handshakes > 0 { - hello.secureRenegotiation = c.clientFinished[:] - } - - possibleCipherSuites := c.config.cipherSuites() + possibleCipherSuites := config.cipherSuites() hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) NextCipherSuite: @@ -92,15 +81,36 @@ NextCipherSuite: } } - _, err := io.ReadFull(c.config.rand(), hello.random) + _, err := io.ReadFull(config.rand(), hello.random) if err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: short read from Rand: " + err.Error()) + return nil, errors.New("tls: short read from Rand: " + err.Error()) } if hello.vers >= VersionTLS12 { hello.signatureAndHashes = supportedSignatureAlgorithms } + return hello, nil + +} + +// c.out.Mutex <= L; c.handshakeMutex <= L. +func (c *Conn) clientHandshake() error { + if c.config == nil { + c.config = defaultConfig() + } + + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false + + hello, err := makeClientHello(c.config) + if err != nil { + return err + } + + if c.handshakes > 0 { + hello.secureRenegotiation = c.clientFinished[:] + } var session *ClientSessionState var cacheKey string @@ -147,54 +157,70 @@ NextCipherSuite: // (see RFC 5077). hello.sessionId = make([]byte, 16) if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil { - c.sendAlert(alertInternalError) return errors.New("tls: short read from Rand: " + err.Error()) } } - if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + hs := &clientHandshakeState{ + c: c, + hello: hello, + session: session, + } + + if err = hs.handshake(); err != nil { return err } - msg, err := c.readHandshake() + // If we had a successful handshake and hs.session is different from the one already cached - cache a new one + if sessionCache != nil && hs.session != nil && session != hs.session { + sessionCache.Put(cacheKey, hs.session) + } + return nil +} + +// Does the handshake, either a full one or resumes old session. +// Requires hs.c, hs.hello, and, optionally, hs.session to be set. +func (hs *clientHandshakeState) handshake() error { + if hs.c == nil { + return errors.New("tls: corrupted clientHandshakeState: hs.c is unset") + } + c := hs.c + + if hs.hello == nil { + return errors.New("tls: corrupted clientHandshakeState: hs.hello is unset") + } + + // send ClientHello + if _, err := hs.c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + msg, err := hs.c.readHandshake() if err != nil { return err } + serverHello, ok := msg.(*serverHelloMsg) if !ok { - c.sendAlert(alertUnexpectedMessage) + hs.c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(serverHello, msg) } + hs.serverHello = serverHello - vers, ok := c.config.mutualVersion(serverHello.vers) - if !ok || vers < VersionTLS10 { - // TLS 1.0 is the minimum version supported as a client. - c.sendAlert(alertProtocolVersion) - return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers) - } - c.vers = vers - c.haveVers = true - - suite := mutualCipherSuite(hello.cipherSuites, serverHello.cipherSuite) - if suite == nil { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: server chose an unconfigured cipher suite") - } - - hs := &clientHandshakeState{ - c: c, - serverHello: serverHello, - hello: hello, - suite: suite, - finishedHash: newFinishedHash(c.vers, suite), - session: session, - } - - isResume, err := hs.processServerHello() - if err != nil { + if err = hs.pickTLSVersion(); err != nil { return err } + if err = hs.pickCipherSuite(); err != nil { + return err + } + + var isResume bool + if isResume, err = hs.processServerHello(); err != nil { + return err + } + + hs.finishedHash = newFinishedHash(c.vers, hs.suite) // No signatures of the handshake are needed in a resumption. // Otherwise, in a full handshake, if we don't have any certificates // configured then we will never send a CertificateVerify message and @@ -246,13 +272,32 @@ NextCipherSuite: } } - if sessionCache != nil && hs.session != nil && session != hs.session { - sessionCache.Put(cacheKey, hs.session) - } - c.didResume = isResume c.handshakeComplete = true - c.cipherSuite = suite.id + + return nil +} + +func (hs *clientHandshakeState) pickTLSVersion() error { + vers, ok := hs.c.config.mutualVersion(hs.serverHello.vers) + if !ok || vers < VersionTLS10 { + // TLS 1.0 is the minimum version supported as a client. + hs.c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", hs.serverHello.vers) + } + hs.c.vers = vers + hs.c.haveVers = true + return nil +} + +func (hs *clientHandshakeState) pickCipherSuite() error { + suite := mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite) + if suite == nil { + hs.c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server chose an unconfigured cipher suite") + } + hs.suite = suite + hs.c.cipherSuite = suite.id return nil }