handshake: remove unneeded mutex in cryptoSetup (#4227)

This commit is contained in:
Marten Seemann 2024-01-02 14:52:08 +07:00 committed by GitHub
parent 22b7f7744e
commit 1083d1fb8f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -8,7 +8,6 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -48,8 +47,6 @@ type cryptoSetup struct {
perspective protocol.Perspective perspective protocol.Perspective
mutex sync.Mutex // protects all members below
handshakeCompleteTime time.Time handshakeCompleteTime time.Time
zeroRTTOpener LongHeaderOpener // only set for the server zeroRTTOpener LongHeaderOpener // only set for the server
@ -434,10 +431,8 @@ func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bo
func (h *cryptoSetup) rejected0RTT() { func (h *cryptoSetup) rejected0RTT() {
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
h.mutex.Lock()
had0RTTKeys := h.zeroRTTSealer != nil had0RTTKeys := h.zeroRTTSealer != nil
h.zeroRTTSealer = nil h.zeroRTTSealer = nil
h.mutex.Unlock()
if had0RTTKeys { if had0RTTKeys {
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys}) h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
@ -446,7 +441,6 @@ func (h *cryptoSetup) rejected0RTT() {
func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID) suite := getCipherSuite(suiteID)
h.mutex.Lock()
//nolint:exhaustive // The TLS stack doesn't export Initial keys. //nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el { switch el {
case tls.QUICEncryptionLevelEarly: case tls.QUICEncryptionLevelEarly:
@ -478,7 +472,6 @@ func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra
default: default:
panic("unexpected read encryption level") panic("unexpected read encryption level")
} }
h.mutex.Unlock()
h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
@ -487,7 +480,6 @@ func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra
func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID) suite := getCipherSuite(suiteID)
h.mutex.Lock()
//nolint:exhaustive // The TLS stack doesn't export Initial keys. //nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el { switch el {
case tls.QUICEncryptionLevelEarly: case tls.QUICEncryptionLevelEarly:
@ -498,7 +490,6 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
createAEAD(suite, trafficSecret, h.version), createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version), newHeaderProtector(suite, trafficSecret, true, h.version),
) )
h.mutex.Unlock()
if h.logger.Debug() { if h.logger.Debug() {
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
} }
@ -533,7 +524,6 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
default: default:
panic("unexpected write encryption level") panic("unexpected write encryption level")
} }
h.mutex.Unlock()
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
} }
@ -555,11 +545,9 @@ func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) {
} }
func (h *cryptoSetup) DiscardInitialKeys() { func (h *cryptoSetup) DiscardInitialKeys() {
h.mutex.Lock()
dropped := h.initialOpener != nil dropped := h.initialOpener != nil
h.initialOpener = nil h.initialOpener = nil
h.initialSealer = nil h.initialSealer = nil
h.mutex.Unlock()
if dropped { if dropped {
h.logger.Debugf("Dropping Initial keys.") h.logger.Debugf("Dropping Initial keys.")
} }
@ -574,22 +562,17 @@ func (h *cryptoSetup) SetHandshakeConfirmed() {
h.aead.SetHandshakeConfirmed() h.aead.SetHandshakeConfirmed()
// drop Handshake keys // drop Handshake keys
var dropped bool var dropped bool
h.mutex.Lock()
if h.handshakeOpener != nil { if h.handshakeOpener != nil {
h.handshakeOpener = nil h.handshakeOpener = nil
h.handshakeSealer = nil h.handshakeSealer = nil
dropped = true dropped = true
} }
h.mutex.Unlock()
if dropped { if dropped {
h.logger.Debugf("Dropping Handshake keys.") h.logger.Debugf("Dropping Handshake keys.")
} }
} }
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialSealer == nil { if h.initialSealer == nil {
return nil, ErrKeysDropped return nil, ErrKeysDropped
} }
@ -597,9 +580,6 @@ func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
} }
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTSealer == nil { if h.zeroRTTSealer == nil {
return nil, ErrKeysDropped return nil, ErrKeysDropped
} }
@ -607,9 +587,6 @@ func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
} }
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeSealer == nil { if h.handshakeSealer == nil {
if h.initialSealer == nil { if h.initialSealer == nil {
return nil, ErrKeysDropped return nil, ErrKeysDropped
@ -620,9 +597,6 @@ func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
} }
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.has1RTTSealer { if !h.has1RTTSealer {
return nil, ErrKeysNotYetAvailable return nil, ErrKeysNotYetAvailable
} }
@ -630,9 +604,6 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
} }
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialOpener == nil { if h.initialOpener == nil {
return nil, ErrKeysDropped return nil, ErrKeysDropped
} }
@ -640,9 +611,6 @@ func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
} }
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener == nil { if h.zeroRTTOpener == nil {
if h.initialOpener != nil { if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable return nil, ErrKeysNotYetAvailable
@ -654,9 +622,6 @@ func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
} }
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeOpener == nil { if h.handshakeOpener == nil {
if h.initialOpener != nil { if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable return nil, ErrKeysNotYetAvailable
@ -668,9 +633,6 @@ func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
} }
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
h.zeroRTTOpener = nil h.zeroRTTOpener = nil
h.logger.Debugf("Dropping 0-RTT keys.") h.logger.Debugf("Dropping 0-RTT keys.")