mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-03 03:57:36 +03:00
crypto/tls: change Conn.handshakeStatus type to atomic.Bool
Change the type of Conn.handshakeStatus from an atomically accessed uint32 to an atomic.Bool. Change its name to Conn.isHandshakeComplete to indicate it is a boolean value. Eliminate the handshakeComplete() helper function, which checks for equality with 1, in favor of the simpler c.isHandshakeComplete.Load(). Change-Id: I084c83956fff266e2145847e8645372bef6ae9df Reviewed-on: https://go-review.googlesource.com/c/go/+/422296 Auto-Submit: Filippo Valsorda <filippo@golang.org> TryBot-Result: Gopher Robot <gobot@golang.org> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> Reviewed-by: Than McIntosh <thanm@google.com> Reviewed-by: Filippo Valsorda <filippo@golang.org> Run-TryBot: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
parent
057db2c48b
commit
8011ffeccb
5 changed files with 18 additions and 27 deletions
33
conn.go
33
conn.go
|
@ -30,11 +30,10 @@ type Conn struct {
|
|||
isClient bool
|
||||
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
|
||||
|
||||
// handshakeStatus is 1 if the connection is currently transferring
|
||||
// isHandshakeComplete is true if the connection is currently transferring
|
||||
// application data (i.e. is not currently processing a handshake).
|
||||
// handshakeStatus == 1 implies handshakeErr == nil.
|
||||
// This field is only to be accessed with sync/atomic.
|
||||
handshakeStatus uint32
|
||||
// isHandshakeComplete is true implies handshakeErr == nil.
|
||||
isHandshakeComplete atomic.Bool
|
||||
// constant after handshake; protected by handshakeMutex
|
||||
handshakeMutex sync.Mutex
|
||||
handshakeErr error // error resulting from handshake
|
||||
|
@ -604,7 +603,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
|
|||
if c.in.err != nil {
|
||||
return c.in.err
|
||||
}
|
||||
handshakeComplete := c.handshakeComplete()
|
||||
handshakeComplete := c.isHandshakeComplete.Load()
|
||||
|
||||
// This function modifies c.rawInput, which owns the c.input memory.
|
||||
if c.input.Len() != 0 {
|
||||
|
@ -1130,7 +1129,7 @@ func (c *Conn) Write(b []byte) (int, error) {
|
|||
return 0, err
|
||||
}
|
||||
|
||||
if !c.handshakeComplete() {
|
||||
if !c.isHandshakeComplete.Load() {
|
||||
return 0, alertInternalError
|
||||
}
|
||||
|
||||
|
@ -1200,7 +1199,7 @@ func (c *Conn) handleRenegotiation() error {
|
|||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
atomic.StoreUint32(&c.handshakeStatus, 0)
|
||||
c.isHandshakeComplete.Store(false)
|
||||
if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
|
||||
c.handshakes++
|
||||
}
|
||||
|
@ -1337,7 +1336,7 @@ func (c *Conn) Close() error {
|
|||
}
|
||||
|
||||
var alertErr error
|
||||
if c.handshakeComplete() {
|
||||
if c.isHandshakeComplete.Load() {
|
||||
if err := c.closeNotify(); err != nil {
|
||||
alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
|
||||
}
|
||||
|
@ -1355,7 +1354,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
|
|||
// called once the handshake has completed and does not call CloseWrite on the
|
||||
// underlying connection. Most callers should just use Close.
|
||||
func (c *Conn) CloseWrite() error {
|
||||
if !c.handshakeComplete() {
|
||||
if !c.isHandshakeComplete.Load() {
|
||||
return errEarlyCloseWrite
|
||||
}
|
||||
|
||||
|
@ -1409,7 +1408,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
|
|||
// Fast sync/atomic-based exit if there is no handshake in flight and the
|
||||
// last one succeeded without an error. Avoids the expensive context setup
|
||||
// and mutex for most Read and Write calls.
|
||||
if c.handshakeComplete() {
|
||||
if c.isHandshakeComplete.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1452,7 +1451,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
|
|||
if err := c.handshakeErr; err != nil {
|
||||
return err
|
||||
}
|
||||
if c.handshakeComplete() {
|
||||
if c.isHandshakeComplete.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1468,10 +1467,10 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
|
|||
c.flush()
|
||||
}
|
||||
|
||||
if c.handshakeErr == nil && !c.handshakeComplete() {
|
||||
if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
|
||||
c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
|
||||
}
|
||||
if c.handshakeErr != nil && c.handshakeComplete() {
|
||||
if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
|
||||
panic("tls: internal error: handshake returned an error but is marked successful")
|
||||
}
|
||||
|
||||
|
@ -1487,7 +1486,7 @@ func (c *Conn) ConnectionState() ConnectionState {
|
|||
|
||||
func (c *Conn) connectionStateLocked() ConnectionState {
|
||||
var state ConnectionState
|
||||
state.HandshakeComplete = c.handshakeComplete()
|
||||
state.HandshakeComplete = c.isHandshakeComplete.Load()
|
||||
state.Version = c.vers
|
||||
state.NegotiatedProtocol = c.clientProtocol
|
||||
state.DidResume = c.didResume
|
||||
|
@ -1531,7 +1530,7 @@ func (c *Conn) VerifyHostname(host string) error {
|
|||
if !c.isClient {
|
||||
return errors.New("tls: VerifyHostname called on TLS server connection")
|
||||
}
|
||||
if !c.handshakeComplete() {
|
||||
if !c.isHandshakeComplete.Load() {
|
||||
return errors.New("tls: handshake has not yet been performed")
|
||||
}
|
||||
if len(c.verifiedChains) == 0 {
|
||||
|
@ -1539,7 +1538,3 @@ func (c *Conn) VerifyHostname(host string) error {
|
|||
}
|
||||
return c.peerCertificates[0].VerifyHostname(host)
|
||||
}
|
||||
|
||||
func (c *Conn) handshakeComplete() bool {
|
||||
return atomic.LoadUint32(&c.handshakeStatus) == 1
|
||||
}
|
||||
|
|
|
@ -19,7 +19,6 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -455,7 +454,7 @@ func (hs *clientHandshakeState) handshake() error {
|
|||
}
|
||||
|
||||
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
|
||||
atomic.StoreUint32(&c.handshakeStatus, 1)
|
||||
c.isHandshakeComplete.Store(true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"crypto/rsa"
|
||||
"errors"
|
||||
"hash"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -104,7 +103,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
|
|||
return err
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&c.handshakeStatus, 1)
|
||||
c.isHandshakeComplete.Store(true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -16,7 +16,6 @@ import (
|
|||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -122,7 +121,7 @@ func (hs *serverHandshakeState) handshake() error {
|
|||
}
|
||||
|
||||
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
|
||||
atomic.StoreUint32(&c.handshakeStatus, 1)
|
||||
c.isHandshakeComplete.Store(true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"errors"
|
||||
"hash"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -82,7 +81,7 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
|
|||
return err
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&c.handshakeStatus, 1)
|
||||
c.isHandshakeComplete.Store(true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue