From 459406a10fd809f25cd292cb0fbf6e7c6ad398cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 8 Oct 2023 12:17:25 +0800 Subject: [PATCH] Update BBR and Hysteria congestion control --- congestion_meta2/bbr_sender.go | 33 +++++++++++--- congestion_meta2/pacer.go | 3 ++ go.mod | 2 +- go.sum | 4 +- hysteria2/client.go | 9 +++- hysteria2/congestion/brutal.go | 78 ++++++++++++++++++++++++---------- hysteria2/congestion/pacer.go | 3 ++ hysteria2/service.go | 5 ++- 8 files changed, 102 insertions(+), 35 deletions(-) diff --git a/congestion_meta2/bbr_sender.go b/congestion_meta2/bbr_sender.go index d888082..c377877 100644 --- a/congestion_meta2/bbr_sender.go +++ b/congestion_meta2/bbr_sender.go @@ -21,6 +21,8 @@ import ( // const ( + minBps = 65536 // 64 kbps + invalidPacketNumber = -1 initialCongestionWindowPackets = 32 @@ -284,10 +286,7 @@ func newBbrSender( maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow, maxDatagramSize: initialMaxDatagramSize, } - b.pacer = NewPacer(func() congestion.ByteCount { - // Pacer wants bytes per second, but Bandwidth is in bits per second. - return congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond)) - }) + b.pacer = NewPacer(b.bandwidthForPacer) /* if b.tracer != nil { @@ -484,10 +483,19 @@ func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, even b.calculateRecoveryWindow(bytesAcked, bytesLost) // Cleanup internal state. - if len(lostPackets) != 0 { - lastLostPacket := lostPackets[len(lostPackets)-1].PacketNumber - b.sampler.RemoveObsoletePackets(lastLostPacket) + // This is where we clean up obsolete (acked or lost) packets from the bandwidth sampler. + // The "least unacked" should actually be FirstOutstanding, but since we are not passing + // that through OnCongestionEventEx, we will only do an estimate using acked/lost packets + // for now. Because of fast retransmission, they should differ by no more than 2 packets. + // (this is controlled by packetThreshold in quic-go's sentPacketHandler) + var leastUnacked congestion.PacketNumber + if len(ackedPackets) != 0 { + leastUnacked = ackedPackets[len(ackedPackets)-1].PacketNumber - 2 + } else { + leastUnacked = lostPackets[len(lostPackets)-1].PacketNumber + 1 } + b.sampler.RemoveObsoletePackets(leastUnacked) + if isRoundStart { b.numLossEventsInRound = 0 b.bytesLostInRound = 0 @@ -537,6 +545,17 @@ func (b *bbrSender) bandwidthEstimate() Bandwidth { return b.maxBandwidth.GetBest() } +func (b *bbrSender) bandwidthForPacer() congestion.ByteCount { + bps := congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond)) + if bps < minBps { + // We need to make sure that the bandwidth value for pacer is never zero, + // otherwise it will go into an edge case where HasPacingBudget = false + // but TimeUntilSend is before, causing the quic-go send loop to go crazy and get stuck. + return minBps + } + return bps +} + // Returns the current estimate of the RTT of the connection. Outside of the // edge cases, this is minimum RTT. func (b *bbrSender) getMinRtt() time.Duration { diff --git a/congestion_meta2/pacer.go b/congestion_meta2/pacer.go index 2d05d82..5fae021 100644 --- a/congestion_meta2/pacer.go +++ b/congestion_meta2/pacer.go @@ -43,6 +43,9 @@ func (p *Pacer) Budget(now time.Time) congestion.ByteCount { return p.maxBurstSize() } budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + if budget < 0 { // protect against overflows + budget = congestion.ByteCount(1<<62 - 1) + } return Min(p.maxBurstSize(), budget) } diff --git a/go.mod b/go.mod index bf80010..f6a8553 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/sagernet/sing-quic go 1.20 require ( - github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb + github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460 github.com/sagernet/sing v0.2.13 golang.org/x/crypto v0.14.0 golang.org/x/exp v0.0.0-20231005195138-3e424a577f31 diff --git a/go.sum b/go.sum index 6807be7..c98a266 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg= github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= -github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb h1:jlrVCepGBoob4QsPChIbe1j0d/lZSJkyVj2ukX3D4PE= -github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb/go.mod h1:uJGpmJCOcMQqMlHKc3P1Vz6uygmpz4bPeVIoOhdVQnM= +github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460 h1:dAe4OIJAtE0nHOzTHhAReQteh3+sa63rvXbuIpbeOTY= +github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460/go.mod h1:uJGpmJCOcMQqMlHKc3P1Vz6uygmpz4bPeVIoOhdVQnM= github.com/sagernet/sing v0.2.13 h1:ohczGKWP+Yn3zlQXSvFn+6EKSELGggBi66D5rqpYRQ0= github.com/sagernet/sing v0.2.13/go.mod h1:AhNEHu0GXrpqkuzvTwvC8+j2cQUU/dh+zLEmq4C99pg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/hysteria2/client.go b/hysteria2/client.go index 8cb43ad..29b05d2 100644 --- a/hysteria2/client.go +++ b/hysteria2/client.go @@ -21,6 +21,7 @@ import ( "github.com/sagernet/sing/common/baderror" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" @@ -37,6 +38,8 @@ const ( type ClientOptions struct { Context context.Context Dialer N.Dialer + Logger logger.Logger + BrutalDebug bool ServerAddress M.Socksaddr SendBPS uint64 ReceiveBPS uint64 @@ -49,6 +52,8 @@ type ClientOptions struct { type Client struct { ctx context.Context dialer N.Dialer + logger logger.Logger + brutalDebug bool serverAddr M.Socksaddr sendBPS uint64 receiveBPS uint64 @@ -76,6 +81,8 @@ func NewClient(options ClientOptions) (*Client, error) { return &Client{ ctx: options.Context, dialer: options.Dialer, + logger: options.Logger, + brutalDebug: options.BrutalDebug, serverAddr: options.ServerAddress, sendBPS: options.SendBPS, receiveBPS: options.ReceiveBPS, @@ -153,7 +160,7 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { actualTx = c.sendBPS } if !authResponse.RxAuto && actualTx > 0 { - quicConn.SetCongestionControl(hyCC.NewBrutalSender(actualTx)) + quicConn.SetCongestionControl(hyCC.NewBrutalSender(actualTx, c.brutalDebug, c.logger)) } else { timeFunc := ntp.TimeFuncFromContext(c.ctx) if timeFunc == nil { diff --git a/hysteria2/congestion/brutal.go b/hysteria2/congestion/brutal.go index fc4cf55..27f2411 100644 --- a/hysteria2/congestion/brutal.go +++ b/hysteria2/congestion/brutal.go @@ -1,20 +1,24 @@ package congestion import ( + "fmt" "time" "github.com/sagernet/quic-go/congestion" + "github.com/sagernet/sing/common/logger" ) const ( initMaxDatagramSize = 1252 - pktInfoSlotCount = 4 - minSampleCount = 50 - minAckRate = 0.8 + pktInfoSlotCount = 5 // slot index is based on seconds, so this is basically how many seconds we sample + minSampleCount = 50 + minAckRate = 0.8 + congestionWindowMultiplier = 2 + debugPrintInterval = 2 ) -var _ congestion.CongestionControl = &BrutalSender{} +var _ congestion.CongestionControlEx = &BrutalSender{} type BrutalSender struct { rttStats congestion.RTTStatsProvider @@ -22,8 +26,11 @@ type BrutalSender struct { maxDatagramSize congestion.ByteCount pacer *pacer - pktInfoSlots [pktInfoSlotCount]pktInfo - ackRate float64 + pktInfoSlots [pktInfoSlotCount]pktInfo + ackRate float64 + debug bool + logger logger.Logger + lastAckPrintTimestamp int64 } type pktInfo struct { @@ -32,11 +39,13 @@ type pktInfo struct { LossCount uint64 } -func NewBrutalSender(bps uint64) *BrutalSender { +func NewBrutalSender(bps uint64, debug bool, logger logger.Logger) *BrutalSender { bs := &BrutalSender{ bps: congestion.ByteCount(bps), maxDatagramSize: initMaxDatagramSize, ackRate: 1, + debug: debug, + logger: logger, } bs.pacer = newPacer(func() congestion.ByteCount { return congestion.ByteCount(float64(bs.bps) / bs.ackRate) @@ -65,7 +74,7 @@ func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount { if rtt <= 0 { return 10240 } - return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate) + return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate) } func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, @@ -77,31 +86,26 @@ func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, priorInFlight congestion.ByteCount, eventTime time.Time, ) { - currentTimestamp := eventTime.Unix() - slot := currentTimestamp % pktInfoSlotCount - if b.pktInfoSlots[slot].Timestamp == currentTimestamp { - b.pktInfoSlots[slot].AckCount++ - } else { - // uninitialized slot or too old, reset - b.pktInfoSlots[slot].Timestamp = currentTimestamp - b.pktInfoSlots[slot].AckCount = 1 - b.pktInfoSlots[slot].LossCount = 0 - } - b.updateAckRate(currentTimestamp) + // Stub } func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount, priorInFlight congestion.ByteCount, ) { - currentTimestamp := time.Now().Unix() + // Stub +} + +func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { + currentTimestamp := eventTime.Unix() slot := currentTimestamp % pktInfoSlotCount if b.pktInfoSlots[slot].Timestamp == currentTimestamp { - b.pktInfoSlots[slot].LossCount++ + b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets)) + b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets)) } else { // uninitialized slot or too old, reset b.pktInfoSlots[slot].Timestamp = currentTimestamp - b.pktInfoSlots[slot].AckCount = 0 - b.pktInfoSlots[slot].LossCount = 1 + b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets)) + b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets)) } b.updateAckRate(currentTimestamp) } @@ -109,6 +113,9 @@ func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostByt func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) { b.maxDatagramSize = size b.pacer.SetMaxDatagramSize(size) + if b.debug { + b.debugPrint("SetMaxDatagramSize: %d", size) + } } func (b *BrutalSender) updateAckRate(currentTimestamp int64) { @@ -123,12 +130,29 @@ func (b *BrutalSender) updateAckRate(currentTimestamp int64) { } if ackCount+lossCount < minSampleCount { b.ackRate = 1 + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)", + ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } + return } rate := float64(ackCount) / float64(ackCount+lossCount) if rate < minAckRate { b.ackRate = minAckRate + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)", + rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } + return } b.ackRate = rate + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)", + rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } } func (b *BrutalSender) InSlowStart() bool { @@ -143,6 +167,14 @@ func (b *BrutalSender) MaybeExitSlowStart() {} func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {} +func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool { + return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval +} + +func (b *BrutalSender) debugPrint(format string, a ...any) { + b.logger.Debug("[brutal] ", fmt.Sprintf(format, a...)) +} + func maxDuration(a, b time.Duration) time.Duration { if a > b { return a diff --git a/hysteria2/congestion/pacer.go b/hysteria2/congestion/pacer.go index 878985e..a648f98 100644 --- a/hysteria2/congestion/pacer.go +++ b/hysteria2/congestion/pacer.go @@ -44,6 +44,9 @@ func (p *pacer) Budget(now time.Time) congestion.ByteCount { return p.maxBurstSize() } budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + if budget < 0 { // protect against overflows + budget = congestion.ByteCount(1<<62 - 1) + } return minByteCount(p.maxBurstSize(), budget) } diff --git a/hysteria2/service.go b/hysteria2/service.go index 8f07c56..7c3866d 100644 --- a/hysteria2/service.go +++ b/hysteria2/service.go @@ -33,6 +33,7 @@ import ( type ServiceOptions struct { Context context.Context Logger logger.Logger + BrutalDebug bool SendBPS uint64 ReceiveBPS uint64 IgnoreClientBandwidth bool @@ -51,6 +52,7 @@ type ServerHandler interface { type Service[U comparable] struct { ctx context.Context logger logger.Logger + brutalDebug bool sendBPS uint64 receiveBPS uint64 ignoreClientBandwidth bool @@ -82,6 +84,7 @@ func NewService[U comparable](options ServiceOptions) (*Service[U], error) { return &Service[U]{ ctx: options.Context, logger: options.Logger, + brutalDebug: options.BrutalDebug, sendBPS: options.SendBPS, receiveBPS: options.ReceiveBPS, ignoreClientBandwidth: options.IgnoreClientBandwidth, @@ -199,7 +202,7 @@ func (s *serverSession[U]) ServeHTTP(w http.ResponseWriter, r *http.Request) { sendBps = request.Rx } format.ToString(1024 * 1024) - s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(sendBps)) + s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(sendBps, s.brutalDebug, s.logger)) } else { timeFunc := ntp.TimeFuncFromContext(s.ctx) if timeFunc == nil {