mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 04:37:36 +03:00
use a synchronous API for the crypto setup (#3939)
This commit is contained in:
parent
2c0e7e02b0
commit
469a6153b6
18 changed files with 696 additions and 1032 deletions
|
@ -6,7 +6,6 @@ import (
|
|||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -30,16 +29,15 @@ type cryptoSetup struct {
|
|||
tlsConf *tls.Config
|
||||
conn *qtls.QUICConn
|
||||
|
||||
events []Event
|
||||
|
||||
version protocol.VersionNumber
|
||||
|
||||
ourParams *wire.TransportParameters
|
||||
peerParams *wire.TransportParameters
|
||||
|
||||
runner handshakeRunner
|
||||
|
||||
zeroRTTParameters *wire.TransportParameters
|
||||
zeroRTTParametersChan chan<- *wire.TransportParameters
|
||||
allow0RTT bool
|
||||
zeroRTTParameters *wire.TransportParameters
|
||||
allow0RTT bool
|
||||
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
|
@ -55,17 +53,14 @@ type cryptoSetup struct {
|
|||
zeroRTTOpener LongHeaderOpener // only set for the server
|
||||
zeroRTTSealer LongHeaderSealer // only set for the client
|
||||
|
||||
initialStream io.Writer
|
||||
initialOpener LongHeaderOpener
|
||||
initialSealer LongHeaderSealer
|
||||
|
||||
handshakeStream io.Writer
|
||||
handshakeOpener LongHeaderOpener
|
||||
handshakeSealer LongHeaderSealer
|
||||
|
||||
used0RTT atomic.Bool
|
||||
|
||||
oneRTTStream io.Writer
|
||||
aead *updatableAEAD
|
||||
has1RTTSealer bool
|
||||
has1RTTOpener bool
|
||||
|
@ -75,24 +70,18 @@ var _ CryptoSetup = &cryptoSetup{}
|
|||
|
||||
// NewCryptoSetupClient creates a new crypto setup for the client
|
||||
func NewCryptoSetupClient(
|
||||
initialStream, handshakeStream, oneRTTStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
tp *wire.TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
enable0RTT bool,
|
||||
rttStats *utils.RTTStats,
|
||||
tracer logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
|
||||
cs, clientHelloWritten := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
oneRTTStream,
|
||||
) CryptoSetup {
|
||||
cs := newCryptoSetup(
|
||||
connID,
|
||||
tp,
|
||||
runner,
|
||||
rttStats,
|
||||
tracer,
|
||||
logger,
|
||||
|
@ -109,15 +98,13 @@ func NewCryptoSetupClient(
|
|||
cs.conn = qtls.QUICClient(quicConf)
|
||||
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
|
||||
|
||||
return cs, clientHelloWritten
|
||||
return cs
|
||||
}
|
||||
|
||||
// NewCryptoSetupServer creates a new crypto setup for the server
|
||||
func NewCryptoSetupServer(
|
||||
initialStream, handshakeStream, oneRTTStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
tp *wire.TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
allow0RTT bool,
|
||||
rttStats *utils.RTTStats,
|
||||
|
@ -125,13 +112,9 @@ func NewCryptoSetupServer(
|
|||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
) CryptoSetup {
|
||||
cs, _ := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
oneRTTStream,
|
||||
cs := newCryptoSetup(
|
||||
connID,
|
||||
tp,
|
||||
runner,
|
||||
rttStats,
|
||||
tracer,
|
||||
logger,
|
||||
|
@ -150,38 +133,31 @@ func NewCryptoSetupServer(
|
|||
}
|
||||
|
||||
func newCryptoSetup(
|
||||
initialStream, handshakeStream, oneRTTStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
tp *wire.TransportParameters,
|
||||
runner handshakeRunner,
|
||||
rttStats *utils.RTTStats,
|
||||
tracer logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
version protocol.VersionNumber,
|
||||
) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
|
||||
) *cryptoSetup {
|
||||
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
|
||||
if tracer != nil {
|
||||
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
|
||||
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
|
||||
}
|
||||
zeroRTTParametersChan := make(chan *wire.TransportParameters, 1)
|
||||
return &cryptoSetup{
|
||||
initialStream: initialStream,
|
||||
initialSealer: initialSealer,
|
||||
initialOpener: initialOpener,
|
||||
handshakeStream: handshakeStream,
|
||||
oneRTTStream: oneRTTStream,
|
||||
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
|
||||
runner: runner,
|
||||
ourParams: tp,
|
||||
rttStats: rttStats,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
zeroRTTParametersChan: zeroRTTParametersChan,
|
||||
version: version,
|
||||
}, zeroRTTParametersChan
|
||||
initialSealer: initialSealer,
|
||||
initialOpener: initialOpener,
|
||||
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
|
||||
events: make([]Event, 0, 16),
|
||||
ourParams: tp,
|
||||
rttStats: rttStats,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
|
||||
|
@ -216,10 +192,9 @@ func (h *cryptoSetup) StartHandshake() error {
|
|||
if h.perspective == protocol.PerspectiveClient {
|
||||
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
|
||||
h.logger.Debugf("Doing 0-RTT.")
|
||||
h.zeroRTTParametersChan <- h.zeroRTTParameters
|
||||
h.events = append(h.events, Event{Kind: EventRestoredTransportParameters, TransportParameters: h.zeroRTTParameters})
|
||||
} else {
|
||||
h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
|
||||
h.zeroRTTParametersChan <- nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -275,7 +250,8 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
|
|||
h.rejected0RTT()
|
||||
return false, nil
|
||||
case qtls.QUICWriteData:
|
||||
return false, h.WriteRecord(ev.Level, ev.Data)
|
||||
h.WriteRecord(ev.Level, ev.Data)
|
||||
return false, nil
|
||||
case qtls.QUICHandshakeDone:
|
||||
h.handshakeComplete()
|
||||
return false, nil
|
||||
|
@ -284,13 +260,22 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) NextEvent() Event {
|
||||
if len(h.events) == 0 {
|
||||
return Event{Kind: EventNoEvent}
|
||||
}
|
||||
ev := h.events[0]
|
||||
h.events = h.events[1:]
|
||||
return ev
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleTransportParameters(data []byte) error {
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
|
||||
return err
|
||||
}
|
||||
h.peerParams = &tp
|
||||
h.runner.OnReceivedParams(h.peerParams)
|
||||
h.events = append(h.events, Event{Kind: EventReceivedTransportParameters, TransportParameters: h.peerParams})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -392,7 +377,7 @@ func (h *cryptoSetup) rejected0RTT() {
|
|||
h.mutex.Unlock()
|
||||
|
||||
if had0RTTKeys {
|
||||
h.runner.DropKeys(protocol.Encryption0RTT)
|
||||
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -414,11 +399,9 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr
|
|||
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
h.handshakeOpener = newHandshakeOpener(
|
||||
h.handshakeOpener = newLongHeaderOpener(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
h.dropInitialKeys,
|
||||
h.perspective,
|
||||
)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
|
@ -433,7 +416,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr
|
|||
panic("unexpected read encryption level")
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
h.runner.OnReceivedReadKeys()
|
||||
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
|
||||
}
|
||||
|
@ -462,11 +445,9 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t
|
|||
// don't set used0RTT here. 0-RTT might still get rejected.
|
||||
return
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
h.handshakeSealer = newHandshakeSealer(
|
||||
h.handshakeSealer = newLongHeaderSealer(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
h.dropInitialKeys,
|
||||
h.perspective,
|
||||
)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
|
@ -496,40 +477,34 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t
|
|||
}
|
||||
|
||||
// WriteRecord is called when TLS writes data
|
||||
func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) error {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
var str io.Writer
|
||||
func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) {
|
||||
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
|
||||
switch encLevel {
|
||||
case qtls.QUICEncryptionLevelInitial:
|
||||
// assume that the first WriteRecord call contains the ClientHello
|
||||
str = h.initialStream
|
||||
h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p})
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
str = h.handshakeStream
|
||||
h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p})
|
||||
case qtls.QUICEncryptionLevelApplication:
|
||||
str = h.oneRTTStream
|
||||
panic("unexpected write")
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
|
||||
}
|
||||
_, err := str.Write(p)
|
||||
return err
|
||||
}
|
||||
|
||||
// used a callback in the handshakeSealer and handshakeOpener
|
||||
func (h *cryptoSetup) dropInitialKeys() {
|
||||
func (h *cryptoSetup) DiscardInitialKeys() {
|
||||
h.mutex.Lock()
|
||||
dropped := h.initialOpener != nil
|
||||
h.initialOpener = nil
|
||||
h.initialSealer = nil
|
||||
h.mutex.Unlock()
|
||||
h.runner.DropKeys(protocol.EncryptionInitial)
|
||||
h.logger.Debugf("Dropping Initial keys.")
|
||||
if dropped {
|
||||
h.logger.Debugf("Dropping Initial keys.")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handshakeComplete() {
|
||||
h.handshakeCompleteTime = time.Now()
|
||||
h.runner.OnHandshakeComplete()
|
||||
h.events = append(h.events, Event{Kind: EventHandshakeComplete})
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetHandshakeConfirmed() {
|
||||
|
@ -544,7 +519,6 @@ func (h *cryptoSetup) SetHandshakeConfirmed() {
|
|||
}
|
||||
h.mutex.Unlock()
|
||||
if dropped {
|
||||
h.runner.DropKeys(protocol.EncryptionHandshake)
|
||||
h.logger.Debugf("Dropping Handshake keys.")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue