Update BBR and Hysteria congestion control

This commit is contained in:
世界 2023-10-08 12:17:25 +08:00
parent 1ea488a342
commit 459406a10f
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
8 changed files with 102 additions and 35 deletions

View file

@ -21,6 +21,8 @@ import (
// //
const ( const (
minBps = 65536 // 64 kbps
invalidPacketNumber = -1 invalidPacketNumber = -1
initialCongestionWindowPackets = 32 initialCongestionWindowPackets = 32
@ -284,10 +286,7 @@ func newBbrSender(
maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow, maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow,
maxDatagramSize: initialMaxDatagramSize, maxDatagramSize: initialMaxDatagramSize,
} }
b.pacer = NewPacer(func() congestion.ByteCount { b.pacer = NewPacer(b.bandwidthForPacer)
// Pacer wants bytes per second, but Bandwidth is in bits per second.
return congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond))
})
/* /*
if b.tracer != nil { if b.tracer != nil {
@ -484,10 +483,19 @@ func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, even
b.calculateRecoveryWindow(bytesAcked, bytesLost) b.calculateRecoveryWindow(bytesAcked, bytesLost)
// Cleanup internal state. // Cleanup internal state.
if len(lostPackets) != 0 { // This is where we clean up obsolete (acked or lost) packets from the bandwidth sampler.
lastLostPacket := lostPackets[len(lostPackets)-1].PacketNumber // The "least unacked" should actually be FirstOutstanding, but since we are not passing
b.sampler.RemoveObsoletePackets(lastLostPacket) // that through OnCongestionEventEx, we will only do an estimate using acked/lost packets
// for now. Because of fast retransmission, they should differ by no more than 2 packets.
// (this is controlled by packetThreshold in quic-go's sentPacketHandler)
var leastUnacked congestion.PacketNumber
if len(ackedPackets) != 0 {
leastUnacked = ackedPackets[len(ackedPackets)-1].PacketNumber - 2
} else {
leastUnacked = lostPackets[len(lostPackets)-1].PacketNumber + 1
} }
b.sampler.RemoveObsoletePackets(leastUnacked)
if isRoundStart { if isRoundStart {
b.numLossEventsInRound = 0 b.numLossEventsInRound = 0
b.bytesLostInRound = 0 b.bytesLostInRound = 0
@ -537,6 +545,17 @@ func (b *bbrSender) bandwidthEstimate() Bandwidth {
return b.maxBandwidth.GetBest() return b.maxBandwidth.GetBest()
} }
func (b *bbrSender) bandwidthForPacer() congestion.ByteCount {
bps := congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond))
if bps < minBps {
// We need to make sure that the bandwidth value for pacer is never zero,
// otherwise it will go into an edge case where HasPacingBudget = false
// but TimeUntilSend is before, causing the quic-go send loop to go crazy and get stuck.
return minBps
}
return bps
}
// Returns the current estimate of the RTT of the connection. Outside of the // Returns the current estimate of the RTT of the connection. Outside of the
// edge cases, this is minimum RTT. // edge cases, this is minimum RTT.
func (b *bbrSender) getMinRtt() time.Duration { func (b *bbrSender) getMinRtt() time.Duration {

View file

@ -43,6 +43,9 @@ func (p *Pacer) Budget(now time.Time) congestion.ByteCount {
return p.maxBurstSize() return p.maxBurstSize()
} }
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
if budget < 0 { // protect against overflows
budget = congestion.ByteCount(1<<62 - 1)
}
return Min(p.maxBurstSize(), budget) return Min(p.maxBurstSize(), budget)
} }

2
go.mod
View file

@ -3,7 +3,7 @@ module github.com/sagernet/sing-quic
go 1.20 go 1.20
require ( require (
github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460
github.com/sagernet/sing v0.2.13 github.com/sagernet/sing v0.2.13
golang.org/x/crypto v0.14.0 golang.org/x/crypto v0.14.0
golang.org/x/exp v0.0.0-20231005195138-3e424a577f31 golang.org/x/exp v0.0.0-20231005195138-3e424a577f31

4
go.sum
View file

@ -23,8 +23,8 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg= github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg=
github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb h1:jlrVCepGBoob4QsPChIbe1j0d/lZSJkyVj2ukX3D4PE= github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460 h1:dAe4OIJAtE0nHOzTHhAReQteh3+sa63rvXbuIpbeOTY=
github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb/go.mod h1:uJGpmJCOcMQqMlHKc3P1Vz6uygmpz4bPeVIoOhdVQnM= github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460/go.mod h1:uJGpmJCOcMQqMlHKc3P1Vz6uygmpz4bPeVIoOhdVQnM=
github.com/sagernet/sing v0.2.13 h1:ohczGKWP+Yn3zlQXSvFn+6EKSELGggBi66D5rqpYRQ0= github.com/sagernet/sing v0.2.13 h1:ohczGKWP+Yn3zlQXSvFn+6EKSELGggBi66D5rqpYRQ0=
github.com/sagernet/sing v0.2.13/go.mod h1:AhNEHu0GXrpqkuzvTwvC8+j2cQUU/dh+zLEmq4C99pg= github.com/sagernet/sing v0.2.13/go.mod h1:AhNEHu0GXrpqkuzvTwvC8+j2cQUU/dh+zLEmq4C99pg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View file

@ -21,6 +21,7 @@ import (
"github.com/sagernet/sing/common/baderror" "github.com/sagernet/sing/common/baderror"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/ntp"
@ -37,6 +38,8 @@ const (
type ClientOptions struct { type ClientOptions struct {
Context context.Context Context context.Context
Dialer N.Dialer Dialer N.Dialer
Logger logger.Logger
BrutalDebug bool
ServerAddress M.Socksaddr ServerAddress M.Socksaddr
SendBPS uint64 SendBPS uint64
ReceiveBPS uint64 ReceiveBPS uint64
@ -49,6 +52,8 @@ type ClientOptions struct {
type Client struct { type Client struct {
ctx context.Context ctx context.Context
dialer N.Dialer dialer N.Dialer
logger logger.Logger
brutalDebug bool
serverAddr M.Socksaddr serverAddr M.Socksaddr
sendBPS uint64 sendBPS uint64
receiveBPS uint64 receiveBPS uint64
@ -76,6 +81,8 @@ func NewClient(options ClientOptions) (*Client, error) {
return &Client{ return &Client{
ctx: options.Context, ctx: options.Context,
dialer: options.Dialer, dialer: options.Dialer,
logger: options.Logger,
brutalDebug: options.BrutalDebug,
serverAddr: options.ServerAddress, serverAddr: options.ServerAddress,
sendBPS: options.SendBPS, sendBPS: options.SendBPS,
receiveBPS: options.ReceiveBPS, receiveBPS: options.ReceiveBPS,
@ -153,7 +160,7 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
actualTx = c.sendBPS actualTx = c.sendBPS
} }
if !authResponse.RxAuto && actualTx > 0 { if !authResponse.RxAuto && actualTx > 0 {
quicConn.SetCongestionControl(hyCC.NewBrutalSender(actualTx)) quicConn.SetCongestionControl(hyCC.NewBrutalSender(actualTx, c.brutalDebug, c.logger))
} else { } else {
timeFunc := ntp.TimeFuncFromContext(c.ctx) timeFunc := ntp.TimeFuncFromContext(c.ctx)
if timeFunc == nil { if timeFunc == nil {

View file

@ -1,20 +1,24 @@
package congestion package congestion
import ( import (
"fmt"
"time" "time"
"github.com/sagernet/quic-go/congestion" "github.com/sagernet/quic-go/congestion"
"github.com/sagernet/sing/common/logger"
) )
const ( const (
initMaxDatagramSize = 1252 initMaxDatagramSize = 1252
pktInfoSlotCount = 4 pktInfoSlotCount = 5 // slot index is based on seconds, so this is basically how many seconds we sample
minSampleCount = 50 minSampleCount = 50
minAckRate = 0.8 minAckRate = 0.8
congestionWindowMultiplier = 2
debugPrintInterval = 2
) )
var _ congestion.CongestionControl = &BrutalSender{} var _ congestion.CongestionControlEx = &BrutalSender{}
type BrutalSender struct { type BrutalSender struct {
rttStats congestion.RTTStatsProvider rttStats congestion.RTTStatsProvider
@ -22,8 +26,11 @@ type BrutalSender struct {
maxDatagramSize congestion.ByteCount maxDatagramSize congestion.ByteCount
pacer *pacer pacer *pacer
pktInfoSlots [pktInfoSlotCount]pktInfo pktInfoSlots [pktInfoSlotCount]pktInfo
ackRate float64 ackRate float64
debug bool
logger logger.Logger
lastAckPrintTimestamp int64
} }
type pktInfo struct { type pktInfo struct {
@ -32,11 +39,13 @@ type pktInfo struct {
LossCount uint64 LossCount uint64
} }
func NewBrutalSender(bps uint64) *BrutalSender { func NewBrutalSender(bps uint64, debug bool, logger logger.Logger) *BrutalSender {
bs := &BrutalSender{ bs := &BrutalSender{
bps: congestion.ByteCount(bps), bps: congestion.ByteCount(bps),
maxDatagramSize: initMaxDatagramSize, maxDatagramSize: initMaxDatagramSize,
ackRate: 1, ackRate: 1,
debug: debug,
logger: logger,
} }
bs.pacer = newPacer(func() congestion.ByteCount { bs.pacer = newPacer(func() congestion.ByteCount {
return congestion.ByteCount(float64(bs.bps) / bs.ackRate) return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
@ -65,7 +74,7 @@ func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
if rtt <= 0 { if rtt <= 0 {
return 10240 return 10240
} }
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate) return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate)
} }
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
@ -77,31 +86,26 @@ func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
priorInFlight congestion.ByteCount, eventTime time.Time, priorInFlight congestion.ByteCount, eventTime time.Time,
) { ) {
currentTimestamp := eventTime.Unix() // Stub
slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].AckCount++
} else {
// uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = 1
b.pktInfoSlots[slot].LossCount = 0
}
b.updateAckRate(currentTimestamp)
} }
func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount, func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount,
priorInFlight congestion.ByteCount, priorInFlight congestion.ByteCount,
) { ) {
currentTimestamp := time.Now().Unix() // Stub
}
func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
currentTimestamp := eventTime.Unix()
slot := currentTimestamp % pktInfoSlotCount slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp { if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].LossCount++ b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets))
b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets))
} else { } else {
// uninitialized slot or too old, reset // uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = 0 b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets))
b.pktInfoSlots[slot].LossCount = 1 b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets))
} }
b.updateAckRate(currentTimestamp) b.updateAckRate(currentTimestamp)
} }
@ -109,6 +113,9 @@ func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostByt
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) { func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
b.maxDatagramSize = size b.maxDatagramSize = size
b.pacer.SetMaxDatagramSize(size) b.pacer.SetMaxDatagramSize(size)
if b.debug {
b.debugPrint("SetMaxDatagramSize: %d", size)
}
} }
func (b *BrutalSender) updateAckRate(currentTimestamp int64) { func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
@ -123,12 +130,29 @@ func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
} }
if ackCount+lossCount < minSampleCount { if ackCount+lossCount < minSampleCount {
b.ackRate = 1 b.ackRate = 1
if b.canPrintAckRate(currentTimestamp) {
b.lastAckPrintTimestamp = currentTimestamp
b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)",
ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
}
return
} }
rate := float64(ackCount) / float64(ackCount+lossCount) rate := float64(ackCount) / float64(ackCount+lossCount)
if rate < minAckRate { if rate < minAckRate {
b.ackRate = minAckRate b.ackRate = minAckRate
if b.canPrintAckRate(currentTimestamp) {
b.lastAckPrintTimestamp = currentTimestamp
b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
}
return
} }
b.ackRate = rate b.ackRate = rate
if b.canPrintAckRate(currentTimestamp) {
b.lastAckPrintTimestamp = currentTimestamp
b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
}
} }
func (b *BrutalSender) InSlowStart() bool { func (b *BrutalSender) InSlowStart() bool {
@ -143,6 +167,14 @@ func (b *BrutalSender) MaybeExitSlowStart() {}
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {} func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool {
return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval
}
func (b *BrutalSender) debugPrint(format string, a ...any) {
b.logger.Debug("[brutal] ", fmt.Sprintf(format, a...))
}
func maxDuration(a, b time.Duration) time.Duration { func maxDuration(a, b time.Duration) time.Duration {
if a > b { if a > b {
return a return a

View file

@ -44,6 +44,9 @@ func (p *pacer) Budget(now time.Time) congestion.ByteCount {
return p.maxBurstSize() return p.maxBurstSize()
} }
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
if budget < 0 { // protect against overflows
budget = congestion.ByteCount(1<<62 - 1)
}
return minByteCount(p.maxBurstSize(), budget) return minByteCount(p.maxBurstSize(), budget)
} }

View file

@ -33,6 +33,7 @@ import (
type ServiceOptions struct { type ServiceOptions struct {
Context context.Context Context context.Context
Logger logger.Logger Logger logger.Logger
BrutalDebug bool
SendBPS uint64 SendBPS uint64
ReceiveBPS uint64 ReceiveBPS uint64
IgnoreClientBandwidth bool IgnoreClientBandwidth bool
@ -51,6 +52,7 @@ type ServerHandler interface {
type Service[U comparable] struct { type Service[U comparable] struct {
ctx context.Context ctx context.Context
logger logger.Logger logger logger.Logger
brutalDebug bool
sendBPS uint64 sendBPS uint64
receiveBPS uint64 receiveBPS uint64
ignoreClientBandwidth bool ignoreClientBandwidth bool
@ -82,6 +84,7 @@ func NewService[U comparable](options ServiceOptions) (*Service[U], error) {
return &Service[U]{ return &Service[U]{
ctx: options.Context, ctx: options.Context,
logger: options.Logger, logger: options.Logger,
brutalDebug: options.BrutalDebug,
sendBPS: options.SendBPS, sendBPS: options.SendBPS,
receiveBPS: options.ReceiveBPS, receiveBPS: options.ReceiveBPS,
ignoreClientBandwidth: options.IgnoreClientBandwidth, ignoreClientBandwidth: options.IgnoreClientBandwidth,
@ -199,7 +202,7 @@ func (s *serverSession[U]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
sendBps = request.Rx sendBps = request.Rx
} }
format.ToString(1024 * 1024) format.ToString(1024 * 1024)
s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(sendBps)) s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(sendBps, s.brutalDebug, s.logger))
} else { } else {
timeFunc := ntp.TimeFuncFromContext(s.ctx) timeFunc := ntp.TimeFuncFromContext(s.ctx)
if timeFunc == nil { if timeFunc == nil {