// Mirror of https://github.com/apernet/hysteria.git
// (synced 2025-04-03 04:27:39 +03:00; 185 lines, 5.3 KiB, Go).
package brutal
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/apernet/hysteria/core/v2/internal/congestion/common"
|
|
|
|
"github.com/apernet/quic-go/congestion"
|
|
)
|
|
|
|
// Tuning parameters and debug switches for BrutalSender.
const (
	// pktInfoSlotCount is the size of the per-second sampling ring. Slots are
	// chosen by Unix second, so this is roughly how many seconds of ACK/loss
	// history feed the ACK-rate estimate.
	pktInfoSlotCount = 5

	// minSampleCount is the minimum number of ACK+loss samples needed before
	// the measured ACK rate is used; with fewer, the ACK rate is reset to 1.
	minSampleCount = 50

	// minAckRate is the lower bound the measured ACK rate is clamped to.
	minAckRate = 0.8

	// congestionWindowMultiplier scales bandwidth x smoothed RTT when sizing
	// the congestion window.
	congestionWindowMultiplier = 2

	// debugEnv names the environment variable (parsed with strconv.ParseBool)
	// that turns on debug output.
	debugEnv = "HYSTERIA_BRUTAL_DEBUG"

	// debugPrintInterval is the minimum gap, in seconds, between two ACK-rate
	// debug lines.
	debugPrintInterval = 2
)
|
|
|
|
// Compile-time check that *BrutalSender implements quic-go's
// congestion.CongestionControl interface.
var _ congestion.CongestionControl = (*BrutalSender)(nil)
|
|
|
|
type BrutalSender struct {
|
|
rttStats congestion.RTTStatsProvider
|
|
bps congestion.ByteCount
|
|
maxDatagramSize congestion.ByteCount
|
|
pacer *common.Pacer
|
|
|
|
pktInfoSlots [pktInfoSlotCount]pktInfo
|
|
ackRate float64
|
|
|
|
debug bool
|
|
lastAckPrintTimestamp int64
|
|
}
|
|
|
|
// pktInfo accumulates the ACK and loss counts observed during one Unix second.
type pktInfo struct {
	Timestamp           int64  // Unix second these counters belong to
	AckCount, LossCount uint64 // packets acknowledged / declared lost during that second
}
|
|
|
|
func NewBrutalSender(bps uint64) *BrutalSender {
|
|
debug, _ := strconv.ParseBool(os.Getenv(debugEnv))
|
|
bs := &BrutalSender{
|
|
bps: congestion.ByteCount(bps),
|
|
maxDatagramSize: congestion.InitialPacketSizeIPv4,
|
|
ackRate: 1,
|
|
debug: debug,
|
|
}
|
|
bs.pacer = common.NewPacer(func() congestion.ByteCount {
|
|
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
|
|
})
|
|
return bs
|
|
}
|
|
|
|
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
|
|
b.rttStats = rttStats
|
|
}
|
|
|
|
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
|
|
return b.pacer.TimeUntilSend()
|
|
}
|
|
|
|
func (b *BrutalSender) HasPacingBudget(now time.Time) bool {
|
|
return b.pacer.Budget(now) >= b.maxDatagramSize
|
|
}
|
|
|
|
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
|
return bytesInFlight <= b.GetCongestionWindow()
|
|
}
|
|
|
|
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
|
|
rtt := b.rttStats.SmoothedRTT()
|
|
if rtt <= 0 {
|
|
return 10240
|
|
}
|
|
cwnd := congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate)
|
|
if cwnd < b.maxDatagramSize {
|
|
cwnd = b.maxDatagramSize
|
|
}
|
|
return cwnd
|
|
}
|
|
|
|
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
|
|
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
|
|
) {
|
|
b.pacer.SentPacket(sentTime, bytes)
|
|
}
|
|
|
|
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
|
|
priorInFlight congestion.ByteCount, eventTime time.Time,
|
|
) {
|
|
// Stub
|
|
}
|
|
|
|
func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount,
|
|
priorInFlight congestion.ByteCount,
|
|
) {
|
|
// Stub
|
|
}
|
|
|
|
func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
|
|
currentTimestamp := eventTime.Unix()
|
|
slot := currentTimestamp % pktInfoSlotCount
|
|
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
|
b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets))
|
|
b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets))
|
|
} else {
|
|
// uninitialized slot or too old, reset
|
|
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
|
b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets))
|
|
b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets))
|
|
}
|
|
b.updateAckRate(currentTimestamp)
|
|
}
|
|
|
|
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
|
|
b.maxDatagramSize = size
|
|
b.pacer.SetMaxDatagramSize(size)
|
|
if b.debug {
|
|
b.debugPrint("SetMaxDatagramSize: %d", size)
|
|
}
|
|
}
|
|
|
|
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
|
|
minTimestamp := currentTimestamp - pktInfoSlotCount
|
|
var ackCount, lossCount uint64
|
|
for _, info := range b.pktInfoSlots {
|
|
if info.Timestamp < minTimestamp {
|
|
continue
|
|
}
|
|
ackCount += info.AckCount
|
|
lossCount += info.LossCount
|
|
}
|
|
if ackCount+lossCount < minSampleCount {
|
|
b.ackRate = 1
|
|
if b.canPrintAckRate(currentTimestamp) {
|
|
b.lastAckPrintTimestamp = currentTimestamp
|
|
b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)",
|
|
ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
|
|
}
|
|
return
|
|
}
|
|
rate := float64(ackCount) / float64(ackCount+lossCount)
|
|
if rate < minAckRate {
|
|
b.ackRate = minAckRate
|
|
if b.canPrintAckRate(currentTimestamp) {
|
|
b.lastAckPrintTimestamp = currentTimestamp
|
|
b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
|
|
rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
|
|
}
|
|
return
|
|
}
|
|
b.ackRate = rate
|
|
if b.canPrintAckRate(currentTimestamp) {
|
|
b.lastAckPrintTimestamp = currentTimestamp
|
|
b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
|
|
rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
|
|
}
|
|
}
|
|
|
|
func (b *BrutalSender) InSlowStart() bool {
|
|
return false
|
|
}
|
|
|
|
func (b *BrutalSender) InRecovery() bool {
|
|
return false
|
|
}
|
|
|
|
// MaybeExitSlowStart is a no-op, since BrutalSender never enters slow start.
func (b *BrutalSender) MaybeExitSlowStart() {
}
|
|
|
|
// OnRetransmissionTimeout is a no-op: retransmission timeouts do not change
// BrutalSender's state.
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
}
|
|
|
|
func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool {
|
|
return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval
|
|
}
|
|
|
|
func (b *BrutalSender) debugPrint(format string, a ...any) {
|
|
fmt.Printf("[BrutalSender] [%s] %s\n",
|
|
time.Now().Format("15:04:05"),
|
|
fmt.Sprintf(format, a...))
|
|
}
|