Update BBR and Hysteria congestion control

世界 2023-10-08 12:17:25 +08:00
parent 1ea488a342
commit 459406a10f
GPG key ID: CD109927C34A63C4
8 changed files with 102 additions and 35 deletions


@@ -21,6 +21,8 @@ import (
//
const (
minBps = 65536 // 64 KB/s; the pacer rate is in bytes per second
invalidPacketNumber = -1
initialCongestionWindowPackets = 32
@@ -284,10 +286,7 @@ func newBbrSender(
maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow,
maxDatagramSize: initialMaxDatagramSize,
}
b.pacer = NewPacer(func() congestion.ByteCount {
// Pacer wants bytes per second, but Bandwidth is in bits per second.
return congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond))
})
b.pacer = NewPacer(b.bandwidthForPacer)
/*
if b.tracer != nil {
@@ -484,10 +483,19 @@ func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, even
b.calculateRecoveryWindow(bytesAcked, bytesLost)
// Cleanup internal state.
if len(lostPackets) != 0 {
lastLostPacket := lostPackets[len(lostPackets)-1].PacketNumber
b.sampler.RemoveObsoletePackets(lastLostPacket)
// Clean up obsolete (acked or lost) packets from the bandwidth sampler.
// The "least unacked" should really be FirstOutstanding, but since that is not
// passed through OnCongestionEventEx, we estimate it from the acked/lost
// packets for now. Because of fast retransmission, the two should differ by no
// more than 2 packets (controlled by packetThreshold in quic-go's
// sentPacketHandler).
var leastUnacked congestion.PacketNumber
if len(ackedPackets) != 0 {
leastUnacked = ackedPackets[len(ackedPackets)-1].PacketNumber - 2
} else {
leastUnacked = lostPackets[len(lostPackets)-1].PacketNumber + 1
}
b.sampler.RemoveObsoletePackets(leastUnacked)
if isRoundStart {
b.numLossEventsInRound = 0
b.bytesLostInRound = 0
@@ -537,6 +545,17 @@ func (b *bbrSender) bandwidthEstimate() Bandwidth {
return b.maxBandwidth.GetBest()
}
func (b *bbrSender) bandwidthForPacer() congestion.ByteCount {
bps := congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond))
if bps < minBps {
// Make sure the bandwidth reported to the pacer is never zero; otherwise we
// hit an edge case where HasPacingBudget returns false but TimeUntilSend is
// already in the past, and the quic-go send loop spins without ever sending.
return minBps
}
return bps
}
// Returns the current estimate of the RTT of the connection. Outside of the
// edge cases, this is minimum RTT.
func (b *bbrSender) getMinRtt() time.Duration {

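For reference, the conversion in bandwidthForPacer can be checked in isolation: as the removed comment above notes, the pacer wants bytes per second while Bandwidth is in bits per second, so the estimate is scaled by the gain, divided by BytesPerSecond (8), and clamped to the minBps floor. A minimal standalone sketch, not part of the commit; the constant names here are local stand-ins:

package main

import "fmt"

const (
	bitsPerByte = 8     // stands in for BytesPerSecond above
	minBps      = 65536 // the same floor, in bytes per second
)

// bandwidthForPacer mirrors the method above: convert a bits-per-second
// estimate into a bytes-per-second pacing rate, then clamp so the pacer
// never sees a zero rate.
func bandwidthForPacer(estimateBits, gain float64) int64 {
	rate := int64(estimateBits * gain / bitsPerByte)
	if rate < minBps {
		return minBps
	}
	return rate
}

func main() {
	fmt.Println(bandwidthForPacer(10_000_000, 2.0)) // 10 Mbit/s -> 2500000 B/s
	fmt.Println(bandwidthForPacer(100, 2.0))        // tiny estimate -> 65536 B/s
}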

@@ -43,6 +43,9 @@ func (p *Pacer) Budget(now time.Time) congestion.ByteCount {
return p.maxBurstSize()
}
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
if budget < 0 { // protect against overflows
budget = congestion.ByteCount(1<<62 - 1)
}
return Min(p.maxBurstSize(), budget)
}

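The new guard matters because the budget term multiplies a bytes-per-second rate by elapsed nanoseconds before dividing by 1e9, and that intermediate product can exceed the int64 range behind congestion.ByteCount; the same guard appears again in the Hysteria pacer below. A minimal sketch, not part of the commit, with plain int64 standing in for congestion.ByteCount:

package main

import "fmt"

// budget mirrors the computation above.
func budget(budgetAtLastSent, bandwidth, elapsedNs int64) int64 {
	b := budgetAtLastSent + (bandwidth*elapsedNs)/1e9
	if b < 0 { // the product wrapped around; saturate instead
		b = 1<<62 - 1
	}
	return b
}

func main() {
	// ~1 GiB/s idle for 10 s: 2^30 * 10^10 is about 1.07e19, which exceeds
	// MaxInt64 (~9.22e18), so without the guard the budget comes out negative.
	fmt.Println(budget(0, 1<<30, 10_000_000_000))
}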
go.mod

@@ -3,7 +3,7 @@ module github.com/sagernet/sing-quic
go 1.20
require (
github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb
github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460
github.com/sagernet/sing v0.2.13
golang.org/x/crypto v0.14.0
golang.org/x/exp v0.0.0-20231005195138-3e424a577f31

go.sum

@@ -23,8 +23,8 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg=
github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb h1:jlrVCepGBoob4QsPChIbe1j0d/lZSJkyVj2ukX3D4PE=
github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb/go.mod h1:uJGpmJCOcMQqMlHKc3P1Vz6uygmpz4bPeVIoOhdVQnM=
github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460 h1:dAe4OIJAtE0nHOzTHhAReQteh3+sa63rvXbuIpbeOTY=
github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460/go.mod h1:uJGpmJCOcMQqMlHKc3P1Vz6uygmpz4bPeVIoOhdVQnM=
github.com/sagernet/sing v0.2.13 h1:ohczGKWP+Yn3zlQXSvFn+6EKSELGggBi66D5rqpYRQ0=
github.com/sagernet/sing v0.2.13/go.mod h1:AhNEHu0GXrpqkuzvTwvC8+j2cQUU/dh+zLEmq4C99pg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=


@@ -21,6 +21,7 @@ import (
"github.com/sagernet/sing/common/baderror"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
@@ -37,6 +38,8 @@ const (
type ClientOptions struct {
Context context.Context
Dialer N.Dialer
Logger logger.Logger
BrutalDebug bool
ServerAddress M.Socksaddr
SendBPS uint64
ReceiveBPS uint64
@@ -49,6 +52,8 @@ type ClientOptions struct {
type Client struct {
ctx context.Context
dialer N.Dialer
logger logger.Logger
brutalDebug bool
serverAddr M.Socksaddr
sendBPS uint64
receiveBPS uint64
@@ -76,6 +81,8 @@ func NewClient(options ClientOptions) (*Client, error) {
return &Client{
ctx: options.Context,
dialer: options.Dialer,
logger: options.Logger,
brutalDebug: options.BrutalDebug,
serverAddr: options.ServerAddress,
sendBPS: options.SendBPS,
receiveBPS: options.ReceiveBPS,
@@ -153,7 +160,7 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
actualTx = c.sendBPS
}
if !authResponse.RxAuto && actualTx > 0 {
quicConn.SetCongestionControl(hyCC.NewBrutalSender(actualTx))
quicConn.SetCongestionControl(hyCC.NewBrutalSender(actualTx, c.brutalDebug, c.logger))
} else {
timeFunc := ntp.TimeFuncFromContext(c.ctx)
if timeFunc == nil {

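Callers now provide the logger and debug flag through the options. A hypothetical wiring sketch, not part of the commit; newHysteriaClient, ctx, dialer, log, and serverAddr are placeholders, and the remaining ClientOptions fields are unchanged and elided:

// Hypothetical helper showing the new fields in use.
func newHysteriaClient(ctx context.Context, dialer N.Dialer, log logger.Logger, serverAddr M.Socksaddr) (*Client, error) {
	return NewClient(ClientOptions{
		Context:       ctx,
		Dialer:        dialer,
		Logger:        log,  // consumed by the Brutal sender's debug output
		BrutalDebug:   true, // enables the rate-limited ACK-rate logging
		ServerAddress: serverAddr,
		SendBPS:       12_500_000, // bytes per second (~100 Mbit/s); NewBrutalSender stores it as a ByteCount
		ReceiveBPS:    12_500_000,
	})
}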

@@ -1,20 +1,24 @@
package congestion
import (
"fmt"
"time"
"github.com/sagernet/quic-go/congestion"
"github.com/sagernet/sing/common/logger"
)
const (
initMaxDatagramSize = 1252
pktInfoSlotCount = 4
pktInfoSlotCount = 5 // slots are indexed by Unix time in seconds, so this is the length of the sampling window in seconds
minSampleCount = 50
minAckRate = 0.8
congestionWindowMultiplier = 2
debugPrintInterval = 2
)
var _ congestion.CongestionControl = &BrutalSender{}
var _ congestion.CongestionControlEx = &BrutalSender{}
type BrutalSender struct {
rttStats congestion.RTTStatsProvider
@@ -24,6 +28,9 @@ type BrutalSender struct {
pktInfoSlots [pktInfoSlotCount]pktInfo
ackRate float64
debug bool
logger logger.Logger
lastAckPrintTimestamp int64
}
type pktInfo struct {
@@ -32,11 +39,13 @@ type pktInfo struct {
LossCount uint64
}
func NewBrutalSender(bps uint64) *BrutalSender {
func NewBrutalSender(bps uint64, debug bool, logger logger.Logger) *BrutalSender {
bs := &BrutalSender{
bps: congestion.ByteCount(bps),
maxDatagramSize: initMaxDatagramSize,
ackRate: 1,
debug: debug,
logger: logger,
}
bs.pacer = newPacer(func() congestion.ByteCount {
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
@@ -65,7 +74,7 @@ func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
if rtt <= 0 {
return 10240
}
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate)
}
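A quick numeric check of the window formula above, with made-up numbers: a 100 Mbit/s (12.5 MB/s) target at 50 ms RTT, with the ACK rate at its 0.8 floor, gives 12,500,000 * 0.05 * 2 / 0.8 = 1,562,500 bytes of window. The multiplier doubles the bandwidth-delay product, and dividing by the ACK rate inflates the window to compensate for expected loss.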
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
@@ -77,31 +86,26 @@ func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
priorInFlight congestion.ByteCount, eventTime time.Time,
) {
currentTimestamp := eventTime.Unix()
slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].AckCount++
} else {
// uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = 1
b.pktInfoSlots[slot].LossCount = 0
}
b.updateAckRate(currentTimestamp)
// Stub
}
func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount,
priorInFlight congestion.ByteCount,
) {
currentTimestamp := time.Now().Unix()
// Stub
}
func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
currentTimestamp := eventTime.Unix()
slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].LossCount++
b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets))
b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets))
} else {
// uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = 0
b.pktInfoSlots[slot].LossCount = 1
b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets))
b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets))
}
b.updateAckRate(currentTimestamp)
}
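The summation that turns these slots into ackCount and lossCount sits above the hunk below and is not shown in this diff. A plausible sketch in terms of the pktInfo slots above, assuming entries older than the window are treated as stale; sumSamples is a hypothetical name, not part of the commit:

// Each slot holds one second of samples, so an entry whose timestamp is more
// than pktInfoSlotCount seconds old belongs to a previous wrap of the ring
// and is skipped.
func sumSamples(slots [pktInfoSlotCount]pktInfo, currentTimestamp int64) (ackCount, lossCount uint64) {
	for _, slot := range slots {
		if currentTimestamp-slot.Timestamp < pktInfoSlotCount {
			ackCount += slot.AckCount
			lossCount += slot.LossCount
		}
	}
	return
}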
@@ -109,6 +113,9 @@ func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostByt
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
b.maxDatagramSize = size
b.pacer.SetMaxDatagramSize(size)
if b.debug {
b.debugPrint("SetMaxDatagramSize: %d", size)
}
}
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
@@ -123,12 +130,29 @@ func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
}
if ackCount+lossCount < minSampleCount {
b.ackRate = 1
if b.canPrintAckRate(currentTimestamp) {
b.lastAckPrintTimestamp = currentTimestamp
b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)",
ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
}
return
}
rate := float64(ackCount) / float64(ackCount+lossCount)
if rate < minAckRate {
b.ackRate = minAckRate
if b.canPrintAckRate(currentTimestamp) {
b.lastAckPrintTimestamp = currentTimestamp
b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
}
return
}
b.ackRate = rate
if b.canPrintAckRate(currentTimestamp) {
b.lastAckPrintTimestamp = currentTimestamp
b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
}
}
func (b *BrutalSender) InSlowStart() bool {
@@ -143,6 +167,14 @@ func (b *BrutalSender) MaybeExitSlowStart() {}
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool {
return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval
}
func (b *BrutalSender) debugPrint(format string, a ...any) {
b.logger.Debug("[brutal] ", fmt.Sprintf(format, a...))
}
func maxDuration(a, b time.Duration) time.Duration {
if a > b {
return a


@@ -44,6 +44,9 @@ func (p *pacer) Budget(now time.Time) congestion.ByteCount {
return p.maxBurstSize()
}
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
if budget < 0 { // protect against overflows
budget = congestion.ByteCount(1<<62 - 1)
}
return minByteCount(p.maxBurstSize(), budget)
}


@@ -33,6 +33,7 @@ type ServiceOptions struct {
type ServiceOptions struct {
Context context.Context
Logger logger.Logger
BrutalDebug bool
SendBPS uint64
ReceiveBPS uint64
IgnoreClientBandwidth bool
@@ -51,6 +52,7 @@ type ServerHandler interface {
type Service[U comparable] struct {
ctx context.Context
logger logger.Logger
brutalDebug bool
sendBPS uint64
receiveBPS uint64
ignoreClientBandwidth bool
@@ -82,6 +84,7 @@ func NewService[U comparable](options ServiceOptions) (*Service[U], error) {
return &Service[U]{
ctx: options.Context,
logger: options.Logger,
brutalDebug: options.BrutalDebug,
sendBPS: options.SendBPS,
receiveBPS: options.ReceiveBPS,
ignoreClientBandwidth: options.IgnoreClientBandwidth,
@@ -199,7 +202,7 @@ func (s *serverSession[U]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
sendBps = request.Rx
}
s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(sendBps))
s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(sendBps, s.brutalDebug, s.logger))
} else {
timeFunc := ntp.TimeFuncFromContext(s.ctx)
if timeFunc == nil {