diff --git a/conn.go b/conn.go index 1861a31..b1a7dcc 100644 --- a/conn.go +++ b/conn.go @@ -30,11 +30,10 @@ type Conn struct { isClient bool handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake - // handshakeStatus is 1 if the connection is currently transferring + // isHandshakeComplete is true if the connection is currently transferring // application data (i.e. is not currently processing a handshake). - // handshakeStatus == 1 implies handshakeErr == nil. - // This field is only to be accessed with sync/atomic. - handshakeStatus uint32 + // isHandshakeComplete is true implies handshakeErr == nil. + isHandshakeComplete atomic.Bool // constant after handshake; protected by handshakeMutex handshakeMutex sync.Mutex handshakeErr error // error resulting from handshake @@ -604,7 +603,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { if c.in.err != nil { return c.in.err } - handshakeComplete := c.handshakeComplete() + handshakeComplete := c.isHandshakeComplete.Load() // This function modifies c.rawInput, which owns the c.input memory. if c.input.Len() != 0 { @@ -1130,7 +1129,7 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, err } - if !c.handshakeComplete() { + if !c.isHandshakeComplete.Load() { return 0, alertInternalError } @@ -1200,7 +1199,7 @@ func (c *Conn) handleRenegotiation() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - atomic.StoreUint32(&c.handshakeStatus, 0) + c.isHandshakeComplete.Store(false) if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { c.handshakes++ } @@ -1337,7 +1336,7 @@ func (c *Conn) Close() error { } var alertErr error - if c.handshakeComplete() { + if c.isHandshakeComplete.Load() { if err := c.closeNotify(); err != nil { alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) } @@ -1355,7 +1354,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com // called once the handshake has completed and does not call CloseWrite on the // underlying connection. Most callers should just use Close. func (c *Conn) CloseWrite() error { - if !c.handshakeComplete() { + if !c.isHandshakeComplete.Load() { return errEarlyCloseWrite } @@ -1409,7 +1408,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { // Fast sync/atomic-based exit if there is no handshake in flight and the // last one succeeded without an error. Avoids the expensive context setup // and mutex for most Read and Write calls. - if c.handshakeComplete() { + if c.isHandshakeComplete.Load() { return nil } @@ -1452,7 +1451,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { if err := c.handshakeErr; err != nil { return err } - if c.handshakeComplete() { + if c.isHandshakeComplete.Load() { return nil } @@ -1468,10 +1467,10 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { c.flush() } - if c.handshakeErr == nil && !c.handshakeComplete() { + if c.handshakeErr == nil && !c.isHandshakeComplete.Load() { c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") } - if c.handshakeErr != nil && c.handshakeComplete() { + if c.handshakeErr != nil && c.isHandshakeComplete.Load() { panic("tls: internal error: handshake returned an error but is marked successful") } @@ -1487,7 +1486,7 @@ func (c *Conn) ConnectionState() ConnectionState { func (c *Conn) connectionStateLocked() ConnectionState { var state ConnectionState - state.HandshakeComplete = c.handshakeComplete() + state.HandshakeComplete = c.isHandshakeComplete.Load() state.Version = c.vers state.NegotiatedProtocol = c.clientProtocol state.DidResume = c.didResume @@ -1531,7 +1530,7 @@ func (c *Conn) VerifyHostname(host string) error { if !c.isClient { return errors.New("tls: VerifyHostname called on TLS server connection") } - if !c.handshakeComplete() { + if !c.isHandshakeComplete.Load() { return errors.New("tls: handshake has not yet been performed") } if len(c.verifiedChains) == 0 { @@ -1539,7 +1538,3 @@ func (c *Conn) VerifyHostname(host string) error { } return c.peerCertificates[0].VerifyHostname(host) } - -func (c *Conn) handshakeComplete() bool { - return atomic.LoadUint32(&c.handshakeStatus) == 1 -} diff --git a/handshake_client.go b/handshake_client.go index e61e3eb..e07cf79 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -19,7 +19,6 @@ import ( "io" "net" "strings" - "sync/atomic" "time" ) @@ -455,7 +454,7 @@ func (hs *clientHandshakeState) handshake() error { } c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) - atomic.StoreUint32(&c.handshakeStatus, 1) + c.isHandshakeComplete.Store(true) return nil } diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index c798986..ac783af 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -12,7 +12,6 @@ import ( "crypto/rsa" "errors" "hash" - "sync/atomic" "time" ) @@ -104,7 +103,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { return err } - atomic.StoreUint32(&c.handshakeStatus, 1) + c.isHandshakeComplete.Store(true) return nil } diff --git a/handshake_server.go b/handshake_server.go index 7606305..844e887 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -16,7 +16,6 @@ import ( "fmt" "hash" "io" - "sync/atomic" "time" ) @@ -122,7 +121,7 @@ func (hs *serverHandshakeState) handshake() error { } c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) - atomic.StoreUint32(&c.handshakeStatus, 1) + c.isHandshakeComplete.Store(true) return nil } diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 03a477f..712f358 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -14,7 +14,6 @@ import ( "errors" "hash" "io" - "sync/atomic" "time" ) @@ -82,7 +81,7 @@ func (hs *serverHandshakeStateTLS13) handshake() error { return err } - atomic.StoreUint32(&c.handshakeStatus, 1) + c.isHandshakeComplete.Store(true) return nil }