package trafficlogger
import (
	"cmp"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"net/http"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/apernet/hysteria/core/v2/server"
	"github.com/apernet/quic-go"
)
// indexHTML is the body returned for GET / so that a browser pointed at the
// API address gets a short explanation instead of a 404.
const indexHTML = `
Hysteria Traffic Stats API Server This is a Hysteria Traffic Stats API server.
Check the documentation for usage.
`
// TrafficStatsServer is both an http.Handler and a server.TrafficLogger.
// The Hysteria server feeds it per-user statistics through the logger side.
// The handler side exposes those statistics over a small HTTP API.
type TrafficStatsServer interface {
	http.Handler
	server.TrafficLogger
}
// NewTrafficStatsServer returns a TrafficStatsServer whose statistics start
// out empty.
//
// When secret is non-empty, every HTTP request must send it as the value of
// the Authorization header. An empty secret disables that check.
func NewTrafficStatsServer(secret string) TrafficStatsServer {
	impl := &trafficStatsServerImpl{Secret: secret}
	impl.StatsMap = make(map[string]*trafficStatsEntry)
	impl.OnlineMap = make(map[string]int)
	impl.StreamMap = make(map[quic.Stream]*server.StreamStats)
	impl.KickMap = make(map[string]struct{})
	return impl
}
// trafficStatsServerImpl is the concrete TrafficStatsServer. All of its maps
// are shared between the logger callbacks and the HTTP handlers, so every
// access goes through Mutex.
type trafficStatsServerImpl struct {
	// Mutex guards every map below.
	Mutex sync.RWMutex

	// StatsMap holds accumulated TX/RX byte counts keyed by user ID. It is
	// replaced with an empty map by GET /traffic?clear=true.
	StatsMap map[string]*trafficStatsEntry

	// OnlineMap holds a per-ID counter that is incremented when the ID goes
	// online and decremented when it goes offline. An entry is deleted once
	// its counter is no longer positive.
	OnlineMap map[string]int

	// StreamMap holds the streams registered by TraceStream and not yet
	// removed by UntraceStream.
	StreamMap map[quic.Stream]*server.StreamStats

	// KickMap is the set of IDs queued by POST /kick. Each entry is consumed by
	// the next LogTraffic call for that ID.
	KickMap map[string]struct{}

	// Secret is the required Authorization header value. Empty means no auth.
	Secret string
}
// trafficStatsEntry is the accumulated traffic of a single user, serialized
// as {"tx":...,"rx":...} by GET /traffic.
type trafficStatsEntry struct {
	Tx uint64 `json:"tx"` // total bytes reported as tx
	Rx uint64 `json:"rx"` // total bytes reported as rx
}
// LogTraffic adds tx and rx to the running totals of user id and returns
// true.
//
// If a kick is pending for id (queued through POST /kick), the pending entry
// is consumed instead. In that case the bytes are not counted and the call
// returns false. Presumably the caller then drops the connection; see
// server.TrafficLogger.
func (s *trafficStatsServerImpl) LogTraffic(id string, tx, rx uint64) (ok bool) {
	s.Mutex.Lock()
	defer s.Mutex.Unlock()

	if _, kicked := s.KickMap[id]; kicked {
		delete(s.KickMap, id)
		return false
	}

	entry, found := s.StatsMap[id]
	if !found {
		entry = new(trafficStatsEntry)
		s.StatsMap[id] = entry
	}
	entry.Tx += tx
	entry.Rx += rx
	return true
}
// LogOnlineState records that user id went online (online == true) or
// offline (online == false).
//
// OnlineMap keeps a per-ID counter: +1 for every online event and -1 for
// every offline event. The entry is deleted as soon as the counter would drop
// to zero or below, so GET /online only lists IDs with a positive count.
func (s *trafficStatsServerImpl) LogOnlineState(id string, online bool) {
	s.Mutex.Lock()
	defer s.Mutex.Unlock()

	if online {
		s.OnlineMap[id]++
		return
	}
	if remaining := s.OnlineMap[id] - 1; remaining > 0 {
		s.OnlineMap[id] = remaining
	} else {
		delete(s.OnlineMap, id)
	}
}
// TraceStream registers stream together with its live stats. The stream is
// listed by GET /dump/streams until UntraceStream removes it.
func (s *trafficStatsServerImpl) TraceStream(stream quic.Stream, stats *server.StreamStats) {
	s.Mutex.Lock()
	defer s.Mutex.Unlock()
	streams := s.StreamMap
	streams[stream] = stats
}
// UntraceStream removes stream from the set dumped by GET /dump/streams.
// Removing a stream that was never traced is a no-op.
func (s *trafficStatsServerImpl) UntraceStream(stream quic.Stream) {
	s.Mutex.Lock()
	defer s.Mutex.Unlock()
	streams := s.StreamMap
	delete(streams, stream)
}
// ServeHTTP routes requests of the stats API:
//
//	GET  /              short description text
//	GET  /traffic       per-user traffic totals (optional ?clear=true)
//	POST /kick          queue user IDs to be disconnected
//	GET  /online        per-user online counters
//	GET  /dump/streams  currently traced streams (JSON or text/plain)
//
// Any other method/path combination gets a 404. When Secret is set, a
// request whose Authorization header does not equal it gets a 401.
func (s *trafficStatsServerImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The header is attacker-controlled. Compare it in constant time so that
	// response timing does not reveal how many leading bytes of the secret
	// matched. ConstantTimeCompare still returns early on a length mismatch,
	// which leaks only the secret's length.
	if s.Secret != "" &&
		subtle.ConstantTimeCompare([]byte(r.Header.Get("Authorization")), []byte(s.Secret)) != 1 {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}

	switch {
	case r.Method == http.MethodGet && r.URL.Path == "/":
		_, _ = w.Write([]byte(indexHTML))
	case r.Method == http.MethodGet && r.URL.Path == "/traffic":
		s.getTraffic(w, r)
	case r.Method == http.MethodPost && r.URL.Path == "/kick":
		s.kick(w, r)
	case r.Method == http.MethodGet && r.URL.Path == "/online":
		s.getOnline(w, r)
	case r.Method == http.MethodGet && r.URL.Path == "/dump/streams":
		s.getDumpStreams(w, r)
	default:
		http.NotFound(w, r)
	}
}
// getTraffic handles GET /traffic and responds with a JSON object mapping
// user IDs to their {"tx","rx"} totals.
//
// With ?clear=true, the totals are reset atomically with taking the
// snapshot. This requires the write lock. Otherwise a read lock is enough.
// An unparsable "clear" value is treated as false.
func (s *trafficStatsServerImpl) getTraffic(w http.ResponseWriter, r *http.Request) {
	var (
		payload []byte
		err     error
	)
	if reset, _ := strconv.ParseBool(r.URL.Query().Get("clear")); reset {
		s.Mutex.Lock()
		payload, err = json.Marshal(s.StatsMap)
		s.StatsMap = make(map[string]*trafficStatsEntry)
		s.Mutex.Unlock()
	} else {
		s.Mutex.RLock()
		payload, err = json.Marshal(s.StatsMap)
		s.Mutex.RUnlock()
	}

	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	_, _ = w.Write(payload)
}
// getOnline handles GET /online. It responds with a JSON object that maps
// every ID currently in OnlineMap to its online counter.
func (s *trafficStatsServerImpl) getOnline(w http.ResponseWriter, r *http.Request) {
	// Hold the read lock only while marshaling, not while writing the
	// response. Writing to a slow client can block indefinitely. A pending
	// Lock (taken by LogTraffic and LogOnlineState on every call) would then
	// wait behind this reader and stall traffic accounting.
	s.Mutex.RLock()
	jb, err := json.Marshal(s.OnlineMap)
	s.Mutex.RUnlock()

	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	_, _ = w.Write(jb)
}
// dumpStreamEntry is one row of GET /dump/streams: a point-in-time copy of a
// traced stream's statistics.
type dumpStreamEntry struct {
	State         string `json:"state"`           // stream state name
	Auth          string `json:"auth"`            // authenticated user ID
	Connection    uint32 `json:"connection"`      // connection ID
	Stream        uint64 `json:"stream"`          // QUIC stream ID
	ReqAddr       string `json:"req_addr"`        // address requested by the client
	HookedReqAddr string `json:"hooked_req_addr"` // requested address after hooks
	Tx            uint64 `json:"tx"`              // bytes counted as tx
	Rx            uint64 `json:"rx"`              // bytes counted as rx
	InitialAt     string `json:"initial_at"`      // InitialTime, RFC 3339 with nanoseconds
	LastActiveAt  string `json:"last_active_at"`  // LastActiveTime, RFC 3339 with nanoseconds

	// Raw timestamps kept (unexported, so not serialized) for the text/plain
	// renderer, which shows durations relative to now.
	initialTime    time.Time
	lastActiveTime time.Time
}
// fromStreamStats overwrites every field of e with a snapshot of stream and
// its stats. Atomic fields are loaded once each. The formatted timestamp
// strings are derived from the same loaded values that are kept for the
// text renderer.
func (e *dumpStreamEntry) fromStreamStats(stream quic.Stream, s *server.StreamStats) {
	*e = dumpStreamEntry{
		State:          s.State.Load().String(),
		Auth:           s.AuthID,
		Connection:     s.ConnID,
		Stream:         uint64(stream.StreamID()),
		ReqAddr:        s.ReqAddr.Load(),
		HookedReqAddr:  s.HookedReqAddr.Load(),
		Tx:             s.Tx.Load(),
		Rx:             s.Rx.Load(),
		initialTime:    s.InitialTime,
		lastActiveTime: s.LastActiveTime.Load(),
	}
	e.InitialAt = e.initialTime.Format(time.RFC3339Nano)
	e.LastActiveAt = e.lastActiveTime.Format(time.RFC3339Nano)
}
func formatDumpStreamLine(state, auth, connection, stream, reqAddr, hookedReqAddr, tx, rx, lifetime, lastActive string) string {
return fmt.Sprintf("%-8s %-12s %12s %8s %12s %12s %12s %12s %-16s %s", state, auth, connection, stream, tx, rx, lifetime, lastActive, reqAddr, hookedReqAddr)
}
// String renders e as one row of the text/plain stream dump:
//   - the state is upper-cased;
//   - the connection ID is printed as 8 upper-case hex digits;
//   - empty addresses are printed as "-";
//   - lifetime and last-active are durations measured from now, rounded to
//     milliseconds when under 10 minutes and to seconds otherwise.
func (e *dumpStreamEntry) String() string {
	// Read the clock once so both durations share the same reference instant.
	now := time.Now()

	// roundAge keeps short durations precise and long ones compact.
	roundAge := func(d time.Duration) time.Duration {
		if d < 10*time.Minute {
			return d.Round(time.Millisecond)
		}
		return d.Round(time.Second)
	}
	// orDash keeps an empty column visible in the whitespace-separated table.
	orDash := func(v string) string {
		if v == "" {
			return "-"
		}
		return v
	}

	return formatDumpStreamLine(
		strings.ToUpper(e.State),
		e.Auth,
		fmt.Sprintf("%08X", e.Connection),
		strconv.FormatUint(e.Stream, 10),
		orDash(e.ReqAddr),
		orDash(e.HookedReqAddr),
		strconv.FormatUint(e.Tx, 10),
		strconv.FormatUint(e.Rx, 10),
		roundAge(now.Sub(e.initialTime)).String(),
		roundAge(now.Sub(e.lastActiveTime)).String(),
	)
}
func (s *trafficStatsServerImpl) getDumpStreams(w http.ResponseWriter, r *http.Request) {
var entries []dumpStreamEntry
s.Mutex.RLock()
entries = make([]dumpStreamEntry, len(s.StreamMap))
index := 0
for stream, stats := range s.StreamMap {
entries[index].fromStreamStats(stream, stats)
index++
}
s.Mutex.RUnlock()
slices.SortFunc(entries, func(lhs, rhs dumpStreamEntry) int {
if ret := cmp.Compare(lhs.Auth, rhs.Auth); ret != 0 {
return ret
}
if ret := cmp.Compare(lhs.Connection, rhs.Connection); ret != 0 {
return ret
}
if ret := cmp.Compare(lhs.Stream, rhs.Stream); ret != 0 {
return ret
}
return 0
})
accept := r.Header.Get("Accept")
if strings.Contains(accept, "text/plain") {
// Generate netstat-like output for humans
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
// Print table header
_, _ = fmt.Fprintln(w, formatDumpStreamLine("State", "Auth", "Connection", "Stream", "Req-Addr", "Hooked-Req-Addr", "TX-Bytes", "RX-Bytes", "Lifetime", "Last-Active"))
for _, entry := range entries {
_, _ = fmt.Fprintln(w, entry.String())
}
return
}
// Response with json by default
wrapper := struct {
Streams []dumpStreamEntry `json:"streams"`
}{entries}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
err := json.NewEncoder(w).Encode(&wrapper)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
// kick handles POST /kick. The body must be a JSON array of user IDs. A body
// that does not decode gets a 400. Each ID is added to KickMap, and that
// entry is consumed by the next LogTraffic call for the same ID, which then
// returns false.
//
// NOTE(review): an ID that never logs traffic again (e.g. it is already
// offline) stays queued and will hit its next future connection. Confirm
// this is intended.
func (s *trafficStatsServerImpl) kick(w http.ResponseWriter, r *http.Request) {
	var userIDs []string
	if err := json.NewDecoder(r.Body).Decode(&userIDs); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	s.Mutex.Lock()
	for _, userID := range userIDs {
		s.KickMap[userID] = struct{}{}
	}
	s.Mutex.Unlock()

	w.WriteHeader(http.StatusOK)
}