diff --git a/internal/ackhandler/u_ackhandler.go b/internal/ackhandler/u_ackhandler.go index 00a2a8c0..8a4bdeb6 100644 --- a/internal/ackhandler/u_ackhandler.go +++ b/internal/ackhandler/u_ackhandler.go @@ -20,5 +20,5 @@ func NewUAckHandler( sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) return &uSentPacketHandler{ sentPacketHandler: sph, - }, newReceivedPacketHandler(sph, rttStats, logger) + }, newReceivedPacketHandler(sph, logger) } diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 4a250467..a7cb5f18 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -276,7 +276,6 @@ func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) { return false, h.handleTransportParameters(ev.Data) case tls.QUICTransportParametersRequired: h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) - // [UQUIC] doesn't expect this and may fail return false, nil case tls.QUICRejectedEarlyData: h.rejected0RTT() diff --git a/internal/handshake/u_crypto_setup.go b/internal/handshake/u_crypto_setup.go index d6dc16b1..8c7cd5b9 100644 --- a/internal/handshake/u_crypto_setup.go +++ b/internal/handshake/u_crypto_setup.go @@ -1,11 +1,9 @@ package handshake import ( - "bytes" "context" "fmt" "strings" - "sync" "sync/atomic" "time" @@ -20,11 +18,11 @@ import ( type uCryptoSetup struct { tlsConf *tls.Config - conn *qtls.UQUICConn + conn *tls.UQUICConn events []Event - version protocol.VersionNumber + version protocol.Version ourParams *wire.TransportParameters peerParams *wire.TransportParameters @@ -39,8 +37,6 @@ type uCryptoSetup struct { perspective protocol.Perspective - mutex sync.Mutex // protects all members below - handshakeCompleteTime time.Time zeroRTTOpener LongHeaderOpener // only set for the server @@ -71,7 +67,7 @@ func NewUCryptoSetupClient( rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, - version protocol.VersionNumber, + version protocol.Version, chs *tls.ClientHelloSpec, ) CryptoSetup { cs := newUCryptoSetup( @@ -86,11 +82,16 @@ func NewUCryptoSetupClient( tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + quicConf := &tls.QUICConfig{TLSConfig: tlsConf} qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) cs.tlsConf = tlsConf - cs.conn = qtls.UQUICClient(quicConf, chs) + // [UQUIC] + cs.conn = tls.UQUICClient(quicConf, tls.HelloCustom) + if err := cs.conn.ApplyPreset(chs); err != nil { + panic(err) + } + // cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) // [UQUIC] doesn't require this return cs @@ -103,7 +104,7 @@ func newUCryptoSetup( tracer *logging.ConnectionTracer, logger utils.Logger, perspective protocol.Perspective, - version protocol.VersionNumber, + version protocol.Version, ) *uCryptoSetup { initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) if tracer != nil { @@ -195,29 +196,29 @@ func (h *uCryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLe } } -func (h *uCryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) { +func (h *uCryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) { switch ev.Kind { - case qtls.QUICNoEvent: + case tls.QUICNoEvent: return true, nil - case qtls.QUICSetReadSecret: - h.SetReadKey(ev.Level, ev.Suite, ev.Data) + case tls.QUICSetReadSecret: + h.setReadKey(ev.Level, ev.Suite, ev.Data) return false, nil - case qtls.QUICSetWriteSecret: - h.SetWriteKey(ev.Level, ev.Suite, ev.Data) + case tls.QUICSetWriteSecret: + h.setWriteKey(ev.Level, ev.Suite, ev.Data) return false, nil - case qtls.QUICTransportParameters: + case tls.QUICTransportParameters: return false, h.handleTransportParameters(ev.Data) - case qtls.QUICTransportParametersRequired: + case tls.QUICTransportParametersRequired: h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) // [UQUIC] doesn't expect this and may fail return false, nil - case qtls.QUICRejectedEarlyData: + case tls.QUICRejectedEarlyData: h.rejected0RTT() return false, nil - case qtls.QUICWriteData: - h.WriteRecord(ev.Level, ev.Data) + case tls.QUICWriteData: + h.writeRecord(ev.Level, ev.Data) return false, nil - case qtls.QUICHandshakeDone: + case tls.QUICHandshakeDone: h.handshakeComplete() return false, nil default: @@ -245,48 +246,41 @@ func (h *uCryptoSetup) handleTransportParameters(data []byte) error { } // must be called after receiving the transport parameters -func (h *uCryptoSetup) marshalDataForSessionState() []byte { +func (h *uCryptoSetup) marshalDataForSessionState(earlyData bool) []byte { b := make([]byte, 0, 256) b = quicvarint.Append(b, clientSessionStateRevision) b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds())) - return h.peerParams.MarshalForSessionTicket(b) + if earlyData { + // only save the transport parameters for 0-RTT enabled session tickets + return h.peerParams.MarshalForSessionTicket(b) + } + return b } -func (h *uCryptoSetup) handleDataFromSessionState(data []byte) { - tp, err := h.handleDataFromSessionStateImpl(data) +func (h *uCryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (allowEarlyData bool) { + rtt, tp, err := decodeDataFromSessionState(data, earlyData) if err != nil { h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) return } - h.zeroRTTParameters = tp -} - -func (h *uCryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { - r := bytes.NewReader(data) - ver, err := quicvarint.Read(r) - if err != nil { - return nil, err + h.rttStats.SetInitialRTT(rtt) + // The session ticket might have been saved from a connection that allowed 0-RTT, + // and therefore contain transport parameters. + // Only use them if 0-RTT is actually used on the new connection. + if tp != nil && h.allow0RTT { + h.zeroRTTParameters = tp + return true } - if ver != clientSessionStateRevision { - return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) - } - rtt, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) - var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return nil, err - } - return &tp, nil + return false } // GetSessionTicket generates a new session ticket. // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. func (h *uCryptoSetup) GetSessionTicket() ([]byte, error) { - if err := qtls.SendSessionTicket(h.conn, h.allow0RTT); err != nil { + if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{ + EarlyData: h.allow0RTT, + }); err != nil { // Session tickets might be disabled by tls.Config.SessionTicketsDisabled. // We can't check h.tlsConfig here, since the actual config might have been obtained from // the GetConfigForClient callback. @@ -298,11 +292,11 @@ func (h *uCryptoSetup) GetSessionTicket() ([]byte, error) { return nil, err } ev := h.conn.NextEvent() - if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication { + if ev.Kind != tls.QUICWriteData || ev.Level != tls.QUICEncryptionLevelApplication { panic("crypto/tls bug: where's my session ticket?") } ticket := ev.Data - if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent { + if ev := h.conn.NextEvent(); ev.Kind != tls.QUICNoEvent { panic("crypto/tls bug: why more than one ticket?") } return ticket, nil @@ -312,22 +306,19 @@ func (h *uCryptoSetup) GetSessionTicket() ([]byte, error) { func (h *uCryptoSetup) rejected0RTT() { h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") - h.mutex.Lock() had0RTTKeys := h.zeroRTTSealer != nil h.zeroRTTSealer = nil - h.mutex.Unlock() if had0RTTKeys { h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys}) } } -func (h *uCryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { +func (h *uCryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) - h.mutex.Lock() //nolint:exhaustive // The TLS stack doesn't export Initial keys. switch el { - case qtls.QUICEncryptionLevelEarly: + case tls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveClient { panic("Received 0-RTT read key for the client") } @@ -339,7 +330,7 @@ func (h *uCryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, t if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.handshakeOpener = newLongHeaderOpener( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -347,7 +338,7 @@ func (h *uCryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, t if h.logger.Debug() { h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: h.aead.SetReadKey(suite, trafficSecret) h.has1RTTOpener = true if h.logger.Debug() { @@ -356,19 +347,17 @@ func (h *uCryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, t default: panic("unexpected read encryption level") } - h.mutex.Unlock() h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } } -func (h *uCryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { +func (h *uCryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) - h.mutex.Lock() //nolint:exhaustive // The TLS stack doesn't export Initial keys. switch el { - case qtls.QUICEncryptionLevelEarly: + case tls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveServer { panic("Received 0-RTT write key for the server") } @@ -376,16 +365,15 @@ func (h *uCryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), ) - h.mutex.Unlock() if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) } // don't set used0RTT here. 0-RTT might still get rejected. return - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.handshakeSealer = newLongHeaderSealer( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -393,7 +381,7 @@ func (h *uCryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, if h.logger.Debug() { h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: h.aead.SetWriteKey(suite, trafficSecret) h.has1RTTSealer = true if h.logger.Debug() { @@ -404,28 +392,27 @@ func (h *uCryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, h.used0RTT.Store(true) h.zeroRTTSealer = nil h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil { h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) } } default: panic("unexpected write encryption level") } - h.mutex.Unlock() - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) } } -// WriteRecord is called when TLS writes data -func (h *uCryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) { +// writeRecord is called when TLS writes data +func (h *uCryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) { //nolint:exhaustive // handshake records can only be written for Initial and Handshake. switch encLevel { - case qtls.QUICEncryptionLevelInitial: + case tls.QUICEncryptionLevelInitial: h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p}) - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p}) - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: panic("unexpected write") default: panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel)) @@ -433,11 +420,9 @@ func (h *uCryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) } func (h *uCryptoSetup) DiscardInitialKeys() { - h.mutex.Lock() dropped := h.initialOpener != nil h.initialOpener = nil h.initialSealer = nil - h.mutex.Unlock() if dropped { h.logger.Debugf("Dropping Initial keys.") } @@ -452,22 +437,17 @@ func (h *uCryptoSetup) SetHandshakeConfirmed() { h.aead.SetHandshakeConfirmed() // drop Handshake keys var dropped bool - h.mutex.Lock() if h.handshakeOpener != nil { h.handshakeOpener = nil h.handshakeSealer = nil dropped = true } - h.mutex.Unlock() if dropped { h.logger.Debugf("Dropping Handshake keys.") } } func (h *uCryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.initialSealer == nil { return nil, ErrKeysDropped } @@ -475,9 +455,6 @@ func (h *uCryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { } func (h *uCryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTSealer == nil { return nil, ErrKeysDropped } @@ -485,9 +462,6 @@ func (h *uCryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { } func (h *uCryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.handshakeSealer == nil { if h.initialSealer == nil { return nil, ErrKeysDropped @@ -498,9 +472,6 @@ func (h *uCryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { } func (h *uCryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if !h.has1RTTSealer { return nil, ErrKeysNotYetAvailable } @@ -508,9 +479,6 @@ func (h *uCryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { } func (h *uCryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.initialOpener == nil { return nil, ErrKeysDropped } @@ -518,9 +486,6 @@ func (h *uCryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { } func (h *uCryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTOpener == nil { if h.initialOpener != nil { return nil, ErrKeysNotYetAvailable @@ -532,9 +497,6 @@ func (h *uCryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { } func (h *uCryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.handshakeOpener == nil { if h.initialOpener != nil { return nil, ErrKeysNotYetAvailable @@ -546,9 +508,6 @@ func (h *uCryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { } func (h *uCryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { h.zeroRTTOpener = nil h.logger.Debugf("Dropping 0-RTT keys.") diff --git a/internal/qtls/cipher_suite_test.go b/internal/qtls/cipher_suite_test.go index 716d8217..c76763f5 100644 --- a/internal/qtls/cipher_suite_test.go +++ b/internal/qtls/cipher_suite_test.go @@ -1,11 +1,11 @@ package qtls import ( - "crypto/tls" "fmt" "net" - "github.com/quic-go/quic-go/internal/testdata" + "github.com/refraction-networking/uquic/internal/testdata" + tls "github.com/refraction-networking/utls" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" diff --git a/internal/qtls/qtls.go b/internal/qtls/qtls.go index 3642973d..0425b874 100644 --- a/internal/qtls/qtls.go +++ b/internal/qtls/qtls.go @@ -9,15 +9,6 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" ) -// [UQUIC] -func UQUICClient(config *tls.QUICConfig, clientHelloSpec *tls.ClientHelloSpec) *UQUICConn { - uqc := tls.UQUICClient(config, tls.HelloCustom) - if err := uqc.ApplyPreset(clientHelloSpec); err != nil { - panic(err) - } - return uqc -} - func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) { conf := qconf.TLSConfig diff --git a/u_client.go b/u_client.go index 19b049aa..f6822e07 100644 --- a/u_client.go +++ b/u_client.go @@ -131,8 +131,8 @@ func (c *uClient) dial(ctx context.Context) error { select { case <-ctx.Done(): - c.conn.shutdown() - return ctx.Err() + c.conn.destroy(nil) + return context.Cause(ctx) case err := <-errorChan: return err case recreateErr := <-recreateChan: diff --git a/u_connection.go b/u_connection.go index 7c11749c..e85275c4 100644 --- a/u_connection.go +++ b/u_connection.go @@ -27,11 +27,7 @@ var newUClientConnection = func( tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, - v protocol.VersionNumber, - // chs *tls.ClientHelloSpec, - // initPktNbrLen PacketNumberLen, - // qfs QUICFrames, - // udpDatagramMinSize int, + v protocol.Version, uSpec *QUICSpec, // [UQUIC] ) quicConn { s := &connection{ @@ -130,7 +126,7 @@ var newUClientConnection = func( InitialSourceConnectionID: srcConnID, } if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + params.MaxDatagramFrameSize = wire.MaxDatagramSize } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } diff --git a/u_packet_packer.go b/u_packet_packer.go index d4b85991..be23c3b1 100644 --- a/u_packet_packer.go +++ b/u_packet_packer.go @@ -35,7 +35,7 @@ func newUPacketPacker( // PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. -func (p *uPacketPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *uPacketPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { var ( initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload @@ -165,7 +165,7 @@ func (p *uPacketPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol } // [UQUIC] -func (p *uPacketPacker) appendInitialPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) { +func (p *uPacketPacker) appendInitialPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.Version) (*longHeaderPacket, error) { // Shouldn't need this? // if p.uSpec.InitialPacketSpec.InitPacketNumberLength > 0 { // header.PacketNumberLen = p.uSpec.InitialPacketSpec.InitPacketNumberLength @@ -220,7 +220,7 @@ func (p *uPacketPacker) appendInitialPacket(buffer *packetBuffer, header *wire.E }, nil } -func (p *uPacketPacker) MarshalInitialPacketPayload(pl payload, v protocol.VersionNumber) ([]byte, error) { +func (p *uPacketPacker) MarshalInitialPacketPayload(pl payload, v protocol.Version) ([]byte, error) { var originalFrameBytes []byte for _, f := range pl.frames {