mirror of
https://github.com/SagerNet/sing-quic.git
synced 2025-04-04 20:37:41 +03:00
Update BBR and Hysteria congestion control
This commit is contained in:
parent
1ea488a342
commit
459406a10f
8 changed files with 102 additions and 35 deletions
|
@ -21,6 +21,8 @@ import (
|
|||
//
|
||||
|
||||
const (
|
||||
minBps = 65536 // 64 kbps
|
||||
|
||||
invalidPacketNumber = -1
|
||||
initialCongestionWindowPackets = 32
|
||||
|
||||
|
@ -284,10 +286,7 @@ func newBbrSender(
|
|||
maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow,
|
||||
maxDatagramSize: initialMaxDatagramSize,
|
||||
}
|
||||
b.pacer = NewPacer(func() congestion.ByteCount {
|
||||
// Pacer wants bytes per second, but Bandwidth is in bits per second.
|
||||
return congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond))
|
||||
})
|
||||
b.pacer = NewPacer(b.bandwidthForPacer)
|
||||
|
||||
/*
|
||||
if b.tracer != nil {
|
||||
|
@ -484,10 +483,19 @@ func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, even
|
|||
b.calculateRecoveryWindow(bytesAcked, bytesLost)
|
||||
|
||||
// Cleanup internal state.
|
||||
if len(lostPackets) != 0 {
|
||||
lastLostPacket := lostPackets[len(lostPackets)-1].PacketNumber
|
||||
b.sampler.RemoveObsoletePackets(lastLostPacket)
|
||||
// This is where we clean up obsolete (acked or lost) packets from the bandwidth sampler.
|
||||
// The "least unacked" should actually be FirstOutstanding, but since we are not passing
|
||||
// that through OnCongestionEventEx, we will only do an estimate using acked/lost packets
|
||||
// for now. Because of fast retransmission, they should differ by no more than 2 packets.
|
||||
// (this is controlled by packetThreshold in quic-go's sentPacketHandler)
|
||||
var leastUnacked congestion.PacketNumber
|
||||
if len(ackedPackets) != 0 {
|
||||
leastUnacked = ackedPackets[len(ackedPackets)-1].PacketNumber - 2
|
||||
} else {
|
||||
leastUnacked = lostPackets[len(lostPackets)-1].PacketNumber + 1
|
||||
}
|
||||
b.sampler.RemoveObsoletePackets(leastUnacked)
|
||||
|
||||
if isRoundStart {
|
||||
b.numLossEventsInRound = 0
|
||||
b.bytesLostInRound = 0
|
||||
|
@ -537,6 +545,17 @@ func (b *bbrSender) bandwidthEstimate() Bandwidth {
|
|||
return b.maxBandwidth.GetBest()
|
||||
}
|
||||
|
||||
func (b *bbrSender) bandwidthForPacer() congestion.ByteCount {
|
||||
bps := congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond))
|
||||
if bps < minBps {
|
||||
// We need to make sure that the bandwidth value for pacer is never zero,
|
||||
// otherwise it will go into an edge case where HasPacingBudget = false
|
||||
// but TimeUntilSend is before, causing the quic-go send loop to go crazy and get stuck.
|
||||
return minBps
|
||||
}
|
||||
return bps
|
||||
}
|
||||
|
||||
// Returns the current estimate of the RTT of the connection. Outside of the
|
||||
// edge cases, this is minimum RTT.
|
||||
func (b *bbrSender) getMinRtt() time.Duration {
|
||||
|
|
|
@ -43,6 +43,9 @@ func (p *Pacer) Budget(now time.Time) congestion.ByteCount {
|
|||
return p.maxBurstSize()
|
||||
}
|
||||
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||
if budget < 0 { // protect against overflows
|
||||
budget = congestion.ByteCount(1<<62 - 1)
|
||||
}
|
||||
return Min(p.maxBurstSize(), budget)
|
||||
}
|
||||
|
||||
|
|
2
go.mod
2
go.mod
|
@ -3,7 +3,7 @@ module github.com/sagernet/sing-quic
|
|||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb
|
||||
github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460
|
||||
github.com/sagernet/sing v0.2.13
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/exp v0.0.0-20231005195138-3e424a577f31
|
||||
|
|
4
go.sum
4
go.sum
|
@ -23,8 +23,8 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
|||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||
github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb h1:jlrVCepGBoob4QsPChIbe1j0d/lZSJkyVj2ukX3D4PE=
|
||||
github.com/sagernet/quic-go v0.0.0-20231001051131-0fc736a289bb/go.mod h1:uJGpmJCOcMQqMlHKc3P1Vz6uygmpz4bPeVIoOhdVQnM=
|
||||
github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460 h1:dAe4OIJAtE0nHOzTHhAReQteh3+sa63rvXbuIpbeOTY=
|
||||
github.com/sagernet/quic-go v0.0.0-20231008035953-32727fef9460/go.mod h1:uJGpmJCOcMQqMlHKc3P1Vz6uygmpz4bPeVIoOhdVQnM=
|
||||
github.com/sagernet/sing v0.2.13 h1:ohczGKWP+Yn3zlQXSvFn+6EKSELGggBi66D5rqpYRQ0=
|
||||
github.com/sagernet/sing v0.2.13/go.mod h1:AhNEHu0GXrpqkuzvTwvC8+j2cQUU/dh+zLEmq4C99pg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/sagernet/sing/common/baderror"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
|
@ -37,6 +38,8 @@ const (
|
|||
type ClientOptions struct {
|
||||
Context context.Context
|
||||
Dialer N.Dialer
|
||||
Logger logger.Logger
|
||||
BrutalDebug bool
|
||||
ServerAddress M.Socksaddr
|
||||
SendBPS uint64
|
||||
ReceiveBPS uint64
|
||||
|
@ -49,6 +52,8 @@ type ClientOptions struct {
|
|||
type Client struct {
|
||||
ctx context.Context
|
||||
dialer N.Dialer
|
||||
logger logger.Logger
|
||||
brutalDebug bool
|
||||
serverAddr M.Socksaddr
|
||||
sendBPS uint64
|
||||
receiveBPS uint64
|
||||
|
@ -76,6 +81,8 @@ func NewClient(options ClientOptions) (*Client, error) {
|
|||
return &Client{
|
||||
ctx: options.Context,
|
||||
dialer: options.Dialer,
|
||||
logger: options.Logger,
|
||||
brutalDebug: options.BrutalDebug,
|
||||
serverAddr: options.ServerAddress,
|
||||
sendBPS: options.SendBPS,
|
||||
receiveBPS: options.ReceiveBPS,
|
||||
|
@ -153,7 +160,7 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
|
|||
actualTx = c.sendBPS
|
||||
}
|
||||
if !authResponse.RxAuto && actualTx > 0 {
|
||||
quicConn.SetCongestionControl(hyCC.NewBrutalSender(actualTx))
|
||||
quicConn.SetCongestionControl(hyCC.NewBrutalSender(actualTx, c.brutalDebug, c.logger))
|
||||
} else {
|
||||
timeFunc := ntp.TimeFuncFromContext(c.ctx)
|
||||
if timeFunc == nil {
|
||||
|
|
|
@ -1,20 +1,24 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
initMaxDatagramSize = 1252
|
||||
|
||||
pktInfoSlotCount = 4
|
||||
pktInfoSlotCount = 5 // slot index is based on seconds, so this is basically how many seconds we sample
|
||||
minSampleCount = 50
|
||||
minAckRate = 0.8
|
||||
congestionWindowMultiplier = 2
|
||||
debugPrintInterval = 2
|
||||
)
|
||||
|
||||
var _ congestion.CongestionControl = &BrutalSender{}
|
||||
var _ congestion.CongestionControlEx = &BrutalSender{}
|
||||
|
||||
type BrutalSender struct {
|
||||
rttStats congestion.RTTStatsProvider
|
||||
|
@ -24,6 +28,9 @@ type BrutalSender struct {
|
|||
|
||||
pktInfoSlots [pktInfoSlotCount]pktInfo
|
||||
ackRate float64
|
||||
debug bool
|
||||
logger logger.Logger
|
||||
lastAckPrintTimestamp int64
|
||||
}
|
||||
|
||||
type pktInfo struct {
|
||||
|
@ -32,11 +39,13 @@ type pktInfo struct {
|
|||
LossCount uint64
|
||||
}
|
||||
|
||||
func NewBrutalSender(bps uint64) *BrutalSender {
|
||||
func NewBrutalSender(bps uint64, debug bool, logger logger.Logger) *BrutalSender {
|
||||
bs := &BrutalSender{
|
||||
bps: congestion.ByteCount(bps),
|
||||
maxDatagramSize: initMaxDatagramSize,
|
||||
ackRate: 1,
|
||||
debug: debug,
|
||||
logger: logger,
|
||||
}
|
||||
bs.pacer = newPacer(func() congestion.ByteCount {
|
||||
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
|
||||
|
@ -65,7 +74,7 @@ func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
|
|||
if rtt <= 0 {
|
||||
return 10240
|
||||
}
|
||||
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
|
||||
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
|
||||
|
@ -77,31 +86,26 @@ func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion
|
|||
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount, eventTime time.Time,
|
||||
) {
|
||||
currentTimestamp := eventTime.Unix()
|
||||
slot := currentTimestamp % pktInfoSlotCount
|
||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||
b.pktInfoSlots[slot].AckCount++
|
||||
} else {
|
||||
// uninitialized slot or too old, reset
|
||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||
b.pktInfoSlots[slot].AckCount = 1
|
||||
b.pktInfoSlots[slot].LossCount = 0
|
||||
}
|
||||
b.updateAckRate(currentTimestamp)
|
||||
// Stub
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount,
|
||||
) {
|
||||
currentTimestamp := time.Now().Unix()
|
||||
// Stub
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
|
||||
currentTimestamp := eventTime.Unix()
|
||||
slot := currentTimestamp % pktInfoSlotCount
|
||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||
b.pktInfoSlots[slot].LossCount++
|
||||
b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets))
|
||||
b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets))
|
||||
} else {
|
||||
// uninitialized slot or too old, reset
|
||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||
b.pktInfoSlots[slot].AckCount = 0
|
||||
b.pktInfoSlots[slot].LossCount = 1
|
||||
b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets))
|
||||
b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets))
|
||||
}
|
||||
b.updateAckRate(currentTimestamp)
|
||||
}
|
||||
|
@ -109,6 +113,9 @@ func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostByt
|
|||
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
|
||||
b.maxDatagramSize = size
|
||||
b.pacer.SetMaxDatagramSize(size)
|
||||
if b.debug {
|
||||
b.debugPrint("SetMaxDatagramSize: %d", size)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
|
||||
|
@ -123,12 +130,29 @@ func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
|
|||
}
|
||||
if ackCount+lossCount < minSampleCount {
|
||||
b.ackRate = 1
|
||||
if b.canPrintAckRate(currentTimestamp) {
|
||||
b.lastAckPrintTimestamp = currentTimestamp
|
||||
b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)",
|
||||
ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
|
||||
}
|
||||
return
|
||||
}
|
||||
rate := float64(ackCount) / float64(ackCount+lossCount)
|
||||
if rate < minAckRate {
|
||||
b.ackRate = minAckRate
|
||||
if b.canPrintAckRate(currentTimestamp) {
|
||||
b.lastAckPrintTimestamp = currentTimestamp
|
||||
b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
|
||||
rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
|
||||
}
|
||||
return
|
||||
}
|
||||
b.ackRate = rate
|
||||
if b.canPrintAckRate(currentTimestamp) {
|
||||
b.lastAckPrintTimestamp = currentTimestamp
|
||||
b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
|
||||
rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BrutalSender) InSlowStart() bool {
|
||||
|
@ -143,6 +167,14 @@ func (b *BrutalSender) MaybeExitSlowStart() {}
|
|||
|
||||
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
|
||||
|
||||
func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool {
|
||||
return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval
|
||||
}
|
||||
|
||||
func (b *BrutalSender) debugPrint(format string, a ...any) {
|
||||
b.logger.Debug("[brutal] ", fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
func maxDuration(a, b time.Duration) time.Duration {
|
||||
if a > b {
|
||||
return a
|
||||
|
|
|
@ -44,6 +44,9 @@ func (p *pacer) Budget(now time.Time) congestion.ByteCount {
|
|||
return p.maxBurstSize()
|
||||
}
|
||||
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||
if budget < 0 { // protect against overflows
|
||||
budget = congestion.ByteCount(1<<62 - 1)
|
||||
}
|
||||
return minByteCount(p.maxBurstSize(), budget)
|
||||
}
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ import (
|
|||
type ServiceOptions struct {
|
||||
Context context.Context
|
||||
Logger logger.Logger
|
||||
BrutalDebug bool
|
||||
SendBPS uint64
|
||||
ReceiveBPS uint64
|
||||
IgnoreClientBandwidth bool
|
||||
|
@ -51,6 +52,7 @@ type ServerHandler interface {
|
|||
type Service[U comparable] struct {
|
||||
ctx context.Context
|
||||
logger logger.Logger
|
||||
brutalDebug bool
|
||||
sendBPS uint64
|
||||
receiveBPS uint64
|
||||
ignoreClientBandwidth bool
|
||||
|
@ -82,6 +84,7 @@ func NewService[U comparable](options ServiceOptions) (*Service[U], error) {
|
|||
return &Service[U]{
|
||||
ctx: options.Context,
|
||||
logger: options.Logger,
|
||||
brutalDebug: options.BrutalDebug,
|
||||
sendBPS: options.SendBPS,
|
||||
receiveBPS: options.ReceiveBPS,
|
||||
ignoreClientBandwidth: options.IgnoreClientBandwidth,
|
||||
|
@ -199,7 +202,7 @@ func (s *serverSession[U]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
sendBps = request.Rx
|
||||
}
|
||||
format.ToString(1024 * 1024)
|
||||
s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(sendBps))
|
||||
s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(sendBps, s.brutalDebug, s.logger))
|
||||
} else {
|
||||
timeFunc := ntp.TimeFuncFromContext(s.ctx)
|
||||
if timeFunc == nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue