crypto/tls: change Conn.handshakeStatus type to atomic.Bool

Change the type of Conn.handshakeStatus from an atomically
accessed uint32 to an atomic.Bool. Change its name to
Conn.isHandshakeComplete to indicate it is a boolean value.
Eliminate the handshakeComplete() helper function, which checks
for equality with 1, in favor of the simpler
c.isHandshakeComplete.Load().

Change-Id: I084c83956fff266e2145847e8645372bef6ae9df
Reviewed-on: https://go-review.googlesource.com/c/go/+/422296
Auto-Submit: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Than McIntosh <thanm@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Ludi Rehak 2022-08-09 09:36:17 -07:00 committed by Gopher Robot
parent 057db2c48b
commit 8011ffeccb
5 changed files with 18 additions and 27 deletions

33
conn.go
View file

@ -30,11 +30,10 @@ type Conn struct {
isClient bool
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
// handshakeStatus is 1 if the connection is currently transferring
// isHandshakeComplete is true if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
// handshakeStatus == 1 implies handshakeErr == nil.
// This field is only to be accessed with sync/atomic.
handshakeStatus uint32
// isHandshakeComplete is true implies handshakeErr == nil.
isHandshakeComplete atomic.Bool
// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex
handshakeErr error // error resulting from handshake
@ -604,7 +603,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
if c.in.err != nil {
return c.in.err
}
handshakeComplete := c.handshakeComplete()
handshakeComplete := c.isHandshakeComplete.Load()
// This function modifies c.rawInput, which owns the c.input memory.
if c.input.Len() != 0 {
@ -1130,7 +1129,7 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, err
}
if !c.handshakeComplete() {
if !c.isHandshakeComplete.Load() {
return 0, alertInternalError
}
@ -1200,7 +1199,7 @@ func (c *Conn) handleRenegotiation() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
atomic.StoreUint32(&c.handshakeStatus, 0)
c.isHandshakeComplete.Store(false)
if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
c.handshakes++
}
@ -1337,7 +1336,7 @@ func (c *Conn) Close() error {
}
var alertErr error
if c.handshakeComplete() {
if c.isHandshakeComplete.Load() {
if err := c.closeNotify(); err != nil {
alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
}
@ -1355,7 +1354,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
// called once the handshake has completed and does not call CloseWrite on the
// underlying connection. Most callers should just use Close.
func (c *Conn) CloseWrite() error {
if !c.handshakeComplete() {
if !c.isHandshakeComplete.Load() {
return errEarlyCloseWrite
}
@ -1409,7 +1408,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
// Fast sync/atomic-based exit if there is no handshake in flight and the
// last one succeeded without an error. Avoids the expensive context setup
// and mutex for most Read and Write calls.
if c.handshakeComplete() {
if c.isHandshakeComplete.Load() {
return nil
}
@ -1452,7 +1451,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
if err := c.handshakeErr; err != nil {
return err
}
if c.handshakeComplete() {
if c.isHandshakeComplete.Load() {
return nil
}
@ -1468,10 +1467,10 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
c.flush()
}
if c.handshakeErr == nil && !c.handshakeComplete() {
if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
}
if c.handshakeErr != nil && c.handshakeComplete() {
if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
panic("tls: internal error: handshake returned an error but is marked successful")
}
@ -1487,7 +1486,7 @@ func (c *Conn) ConnectionState() ConnectionState {
func (c *Conn) connectionStateLocked() ConnectionState {
var state ConnectionState
state.HandshakeComplete = c.handshakeComplete()
state.HandshakeComplete = c.isHandshakeComplete.Load()
state.Version = c.vers
state.NegotiatedProtocol = c.clientProtocol
state.DidResume = c.didResume
@ -1531,7 +1530,7 @@ func (c *Conn) VerifyHostname(host string) error {
if !c.isClient {
return errors.New("tls: VerifyHostname called on TLS server connection")
}
if !c.handshakeComplete() {
if !c.isHandshakeComplete.Load() {
return errors.New("tls: handshake has not yet been performed")
}
if len(c.verifiedChains) == 0 {
@ -1539,7 +1538,3 @@ func (c *Conn) VerifyHostname(host string) error {
}
return c.peerCertificates[0].VerifyHostname(host)
}
func (c *Conn) handshakeComplete() bool {
return atomic.LoadUint32(&c.handshakeStatus) == 1
}

View file

@ -19,7 +19,6 @@ import (
"io"
"net"
"strings"
"sync/atomic"
"time"
)
@ -455,7 +454,7 @@ func (hs *clientHandshakeState) handshake() error {
}
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
atomic.StoreUint32(&c.handshakeStatus, 1)
c.isHandshakeComplete.Store(true)
return nil
}

View file

@ -12,7 +12,6 @@ import (
"crypto/rsa"
"errors"
"hash"
"sync/atomic"
"time"
)
@ -104,7 +103,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
return err
}
atomic.StoreUint32(&c.handshakeStatus, 1)
c.isHandshakeComplete.Store(true)
return nil
}

View file

@ -16,7 +16,6 @@ import (
"fmt"
"hash"
"io"
"sync/atomic"
"time"
)
@ -122,7 +121,7 @@ func (hs *serverHandshakeState) handshake() error {
}
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
atomic.StoreUint32(&c.handshakeStatus, 1)
c.isHandshakeComplete.Store(true)
return nil
}

View file

@ -14,7 +14,6 @@ import (
"errors"
"hash"
"io"
"sync/atomic"
"time"
)
@ -82,7 +81,7 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
return err
}
atomic.StoreUint32(&c.handshakeStatus, 1)
c.isHandshakeComplete.Store(true)
return nil
}