Merge pull request #1075 from HynoR/feat/online

feat: Add getOnline feature
This commit is contained in:
Toby 2024-05-11 14:16:32 -07:00 committed by GitHub
commit a3c4cfa4b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 240 additions and 47 deletions

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.42.2. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package mocks

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.42.2. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package mocks

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package client
@ -24,6 +24,10 @@ func (_m *mockUDPIO) EXPECT() *mockUDPIO_Expecter {
func (_m *mockUDPIO) ReceiveMessage() (*protocol.UDPMessage, error) {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for ReceiveMessage")
}
var r0 *protocol.UDPMessage
var r1 error
if rf, ok := ret.Get(0).(func() (*protocol.UDPMessage, error)); ok {
@ -77,6 +81,10 @@ func (_c *mockUDPIO_ReceiveMessage_Call) RunAndReturn(run func() (*protocol.UDPM
func (_m *mockUDPIO) SendMessage(_a0 []byte, _a1 *protocol.UDPMessage) error {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for SendMessage")
}
var r0 error
if rf, ok := ret.Get(0).(func([]byte, *protocol.UDPMessage) error); ok {
r0 = rf(_a0, _a1)

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package mocks
@ -25,6 +25,10 @@ func (_m *MockAuthenticator) EXPECT() *MockAuthenticator_Expecter {
func (_m *MockAuthenticator) Authenticate(addr net.Addr, auth string, tx uint64) (bool, string) {
ret := _m.Called(addr, auth, tx)
if len(ret) == 0 {
panic("no return value specified for Authenticate")
}
var r0 bool
var r1 string
if rf, ok := ret.Get(0).(func(net.Addr, string, uint64) (bool, string)); ok {

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package mocks
@ -27,6 +27,10 @@ func (_m *MockConn) EXPECT() *MockConn_Expecter {
func (_m *MockConn) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
@ -68,6 +72,10 @@ func (_c *MockConn_Close_Call) RunAndReturn(run func() error) *MockConn_Close_Ca
func (_m *MockConn) LocalAddr() net.Addr {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for LocalAddr")
}
var r0 net.Addr
if rf, ok := ret.Get(0).(func() net.Addr); ok {
r0 = rf()
@ -111,6 +119,10 @@ func (_c *MockConn_LocalAddr_Call) RunAndReturn(run func() net.Addr) *MockConn_L
func (_m *MockConn) Read(b []byte) (int, error) {
ret := _m.Called(b)
if len(ret) == 0 {
panic("no return value specified for Read")
}
var r0 int
var r1 error
if rf, ok := ret.Get(0).(func([]byte) (int, error)); ok {
@ -163,6 +175,10 @@ func (_c *MockConn_Read_Call) RunAndReturn(run func([]byte) (int, error)) *MockC
func (_m *MockConn) RemoteAddr() net.Addr {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for RemoteAddr")
}
var r0 net.Addr
if rf, ok := ret.Get(0).(func() net.Addr); ok {
r0 = rf()
@ -206,6 +222,10 @@ func (_c *MockConn_RemoteAddr_Call) RunAndReturn(run func() net.Addr) *MockConn_
func (_m *MockConn) SetDeadline(t time.Time) error {
ret := _m.Called(t)
if len(ret) == 0 {
panic("no return value specified for SetDeadline")
}
var r0 error
if rf, ok := ret.Get(0).(func(time.Time) error); ok {
r0 = rf(t)
@ -248,6 +268,10 @@ func (_c *MockConn_SetDeadline_Call) RunAndReturn(run func(time.Time) error) *Mo
func (_m *MockConn) SetReadDeadline(t time.Time) error {
ret := _m.Called(t)
if len(ret) == 0 {
panic("no return value specified for SetReadDeadline")
}
var r0 error
if rf, ok := ret.Get(0).(func(time.Time) error); ok {
r0 = rf(t)
@ -290,6 +314,10 @@ func (_c *MockConn_SetReadDeadline_Call) RunAndReturn(run func(time.Time) error)
func (_m *MockConn) SetWriteDeadline(t time.Time) error {
ret := _m.Called(t)
if len(ret) == 0 {
panic("no return value specified for SetWriteDeadline")
}
var r0 error
if rf, ok := ret.Get(0).(func(time.Time) error); ok {
r0 = rf(t)
@ -332,6 +360,10 @@ func (_c *MockConn_SetWriteDeadline_Call) RunAndReturn(run func(time.Time) error
func (_m *MockConn) Write(b []byte) (int, error) {
ret := _m.Called(b)
if len(ret) == 0 {
panic("no return value specified for Write")
}
var r0 int
var r1 error
if rf, ok := ret.Get(0).(func([]byte) (int, error)); ok {

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package mocks

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package mocks
@ -27,6 +27,10 @@ func (_m *MockOutbound) EXPECT() *MockOutbound_Expecter {
func (_m *MockOutbound) TCP(reqAddr string) (net.Conn, error) {
ret := _m.Called(reqAddr)
if len(ret) == 0 {
panic("no return value specified for TCP")
}
var r0 net.Conn
var r1 error
if rf, ok := ret.Get(0).(func(string) (net.Conn, error)); ok {
@ -81,6 +85,10 @@ func (_c *MockOutbound_TCP_Call) RunAndReturn(run func(string) (net.Conn, error)
func (_m *MockOutbound) UDP(reqAddr string) (server.UDPConn, error) {
ret := _m.Called(reqAddr)
if len(ret) == 0 {
panic("no return value specified for UDP")
}
var r0 server.UDPConn
var r1 error
if rf, ok := ret.Get(0).(func(string) (server.UDPConn, error)); ok {

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package mocks
@ -17,10 +17,48 @@ func (_m *MockTrafficLogger) EXPECT() *MockTrafficLogger_Expecter {
return &MockTrafficLogger_Expecter{mock: &_m.Mock}
}
// Log provides a mock function with given fields: id, tx, rx
func (_m *MockTrafficLogger) Log(id string, tx uint64, rx uint64) bool {
// LogOnlineState provides a mock function with given fields: id, online
func (_m *MockTrafficLogger) LogOnlineState(id string, online bool) {
_m.Called(id, online)
}
// MockTrafficLogger_LogOnlineState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LogOnlineState'
type MockTrafficLogger_LogOnlineState_Call struct {
*mock.Call
}
// LogOnlineState is a helper method to define mock.On call
// - id string
// - online bool
func (_e *MockTrafficLogger_Expecter) LogOnlineState(id interface{}, online interface{}) *MockTrafficLogger_LogOnlineState_Call {
return &MockTrafficLogger_LogOnlineState_Call{Call: _e.mock.On("LogOnlineState", id, online)}
}
func (_c *MockTrafficLogger_LogOnlineState_Call) Run(run func(id string, online bool)) *MockTrafficLogger_LogOnlineState_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(bool))
})
return _c
}
func (_c *MockTrafficLogger_LogOnlineState_Call) Return() *MockTrafficLogger_LogOnlineState_Call {
_c.Call.Return()
return _c
}
func (_c *MockTrafficLogger_LogOnlineState_Call) RunAndReturn(run func(string, bool)) *MockTrafficLogger_LogOnlineState_Call {
_c.Call.Return(run)
return _c
}
// LogTraffic provides a mock function with given fields: id, tx, rx
func (_m *MockTrafficLogger) LogTraffic(id string, tx uint64, rx uint64) bool {
ret := _m.Called(id, tx, rx)
if len(ret) == 0 {
panic("no return value specified for LogTraffic")
}
var r0 bool
if rf, ok := ret.Get(0).(func(string, uint64, uint64) bool); ok {
r0 = rf(id, tx, rx)
@ -31,32 +69,32 @@ func (_m *MockTrafficLogger) Log(id string, tx uint64, rx uint64) bool {
return r0
}
// MockTrafficLogger_Log_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Log'
type MockTrafficLogger_Log_Call struct {
// MockTrafficLogger_LogTraffic_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LogTraffic'
type MockTrafficLogger_LogTraffic_Call struct {
*mock.Call
}
// Log is a helper method to define mock.On call
// LogTraffic is a helper method to define mock.On call
// - id string
// - tx uint64
// - rx uint64
func (_e *MockTrafficLogger_Expecter) Log(id interface{}, tx interface{}, rx interface{}) *MockTrafficLogger_Log_Call {
return &MockTrafficLogger_Log_Call{Call: _e.mock.On("Log", id, tx, rx)}
func (_e *MockTrafficLogger_Expecter) LogTraffic(id interface{}, tx interface{}, rx interface{}) *MockTrafficLogger_LogTraffic_Call {
return &MockTrafficLogger_LogTraffic_Call{Call: _e.mock.On("LogTraffic", id, tx, rx)}
}
func (_c *MockTrafficLogger_Log_Call) Run(run func(id string, tx uint64, rx uint64)) *MockTrafficLogger_Log_Call {
func (_c *MockTrafficLogger_LogTraffic_Call) Run(run func(id string, tx uint64, rx uint64)) *MockTrafficLogger_LogTraffic_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(uint64), args[2].(uint64))
})
return _c
}
func (_c *MockTrafficLogger_Log_Call) Return(ok bool) *MockTrafficLogger_Log_Call {
func (_c *MockTrafficLogger_LogTraffic_Call) Return(ok bool) *MockTrafficLogger_LogTraffic_Call {
_c.Call.Return(ok)
return _c
}
func (_c *MockTrafficLogger_Log_Call) RunAndReturn(run func(string, uint64, uint64) bool) *MockTrafficLogger_Log_Call {
func (_c *MockTrafficLogger_LogTraffic_Call) RunAndReturn(run func(string, uint64, uint64) bool) *MockTrafficLogger_LogTraffic_Call {
_c.Call.Return(run)
return _c
}

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package mocks
@ -21,6 +21,10 @@ func (_m *MockUDPConn) EXPECT() *MockUDPConn_Expecter {
func (_m *MockUDPConn) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
@ -62,6 +66,10 @@ func (_c *MockUDPConn_Close_Call) RunAndReturn(run func() error) *MockUDPConn_Cl
func (_m *MockUDPConn) ReadFrom(b []byte) (int, string, error) {
ret := _m.Called(b)
if len(ret) == 0 {
panic("no return value specified for ReadFrom")
}
var r0 int
var r1 string
var r2 error
@ -121,6 +129,10 @@ func (_c *MockUDPConn_ReadFrom_Call) RunAndReturn(run func([]byte) (int, string,
func (_m *MockUDPConn) WriteTo(b []byte, addr string) (int, error) {
ret := _m.Called(b, addr)
if len(ret) == 0 {
panic("no return value specified for WriteTo")
}
var r0 int
var r1 error
if rf, ok := ret.Get(0).(func([]byte, string) (int, error)); ok {

View file

@ -36,6 +36,7 @@ func TestClientServerTrafficLoggerTCP(t *testing.T) {
go s.Serve()
// Create client
trafficLogger.EXPECT().LogOnlineState("nobody", true).Return().Once()
c, _, err := client.NewClient(&client.Config{
ServerAddr: udpAddr,
TLSConfig: client.TLSConfig{InsecureSkipVerify: true},
@ -66,7 +67,7 @@ func TestClientServerTrafficLoggerTCP(t *testing.T) {
assert.NoError(t, err)
// Client reads from server
trafficLogger.EXPECT().Log("nobody", uint64(0), uint64(11)).Return(true).Once()
trafficLogger.EXPECT().LogTraffic("nobody", uint64(0), uint64(11)).Return(true).Once()
sobConnCh <- []byte("knock knock")
buf := make([]byte, 100)
n, err := conn.Read(buf)
@ -75,7 +76,7 @@ func TestClientServerTrafficLoggerTCP(t *testing.T) {
assert.Equal(t, "knock knock", string(buf[:n]))
// Client writes to server
trafficLogger.EXPECT().Log("nobody", uint64(12), uint64(0)).Return(true).Once()
trafficLogger.EXPECT().LogTraffic("nobody", uint64(12), uint64(0)).Return(true).Once()
sobConn.EXPECT().Write([]byte("who is there")).Return(12, nil).Once()
n, err = conn.Write([]byte("who is there"))
assert.NoError(t, err)
@ -83,7 +84,8 @@ func TestClientServerTrafficLoggerTCP(t *testing.T) {
time.Sleep(1 * time.Second) // Need some time for the server to receive the data
// Client reads from server again but blocked
trafficLogger.EXPECT().Log("nobody", uint64(0), uint64(4)).Return(false).Once()
trafficLogger.EXPECT().LogTraffic("nobody", uint64(0), uint64(4)).Return(false).Once()
trafficLogger.EXPECT().LogOnlineState("nobody", false).Return().Once()
sobConnCh <- []byte("nope")
n, err = conn.Read(buf)
assert.Zero(t, n)
@ -116,6 +118,7 @@ func TestClientServerTrafficLoggerUDP(t *testing.T) {
go s.Serve()
// Create client
trafficLogger.EXPECT().LogOnlineState("nobody", true).Return().Once()
c, _, err := client.NewClient(&client.Config{
ServerAddr: udpAddr,
TLSConfig: client.TLSConfig{InsecureSkipVerify: true},
@ -146,14 +149,14 @@ func TestClientServerTrafficLoggerUDP(t *testing.T) {
assert.NoError(t, err)
// Client writes to server
trafficLogger.EXPECT().Log("nobody", uint64(9), uint64(0)).Return(true).Once()
trafficLogger.EXPECT().LogTraffic("nobody", uint64(9), uint64(0)).Return(true).Once()
sobConn.EXPECT().WriteTo([]byte("small sad"), addr).Return(9, nil).Once()
err = conn.Send([]byte("small sad"), addr)
assert.NoError(t, err)
time.Sleep(1 * time.Second) // Need some time for the server to receive the data
// Client reads from server
trafficLogger.EXPECT().Log("nobody", uint64(0), uint64(7)).Return(true).Once()
trafficLogger.EXPECT().LogTraffic("nobody", uint64(0), uint64(7)).Return(true).Once()
sobConnCh <- []byte("big mad")
bs, rAddr, err := conn.Receive()
assert.NoError(t, err)
@ -161,7 +164,8 @@ func TestClientServerTrafficLoggerUDP(t *testing.T) {
assert.Equal(t, "big mad", string(bs))
// Client reads from server again but blocked
trafficLogger.EXPECT().Log("nobody", uint64(0), uint64(4)).Return(false).Once()
trafficLogger.EXPECT().LogTraffic("nobody", uint64(0), uint64(4)).Return(false).Once()
trafficLogger.EXPECT().LogOnlineState("nobody", false).Return().Once()
sobConnCh <- []byte("nope")
bs, rAddr, err = conn.Receive()
assert.Equal(t, err, io.EOF)

View file

@ -195,5 +195,6 @@ type EventLogger interface {
// bandwidth limits or post-connection authentication, for example.
// The implementation of this interface must be thread-safe.
type TrafficLogger interface {
Log(id string, tx, rx uint64) (ok bool)
LogTraffic(id string, tx, rx uint64) (ok bool)
LogOnlineState(id string, online bool)
}

View file

@ -35,12 +35,12 @@ func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l Traffic
errChan := make(chan error, 2)
go func() {
errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) bool {
return l.Log(id, 0, n)
return l.LogTraffic(id, 0, n)
})
}()
go func() {
errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) bool {
return l.Log(id, n, 0)
return l.LogTraffic(id, n, 0)
})
}()
// Block until one of the two goroutines returns

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package server
@ -21,6 +21,10 @@ func (_m *mockUDPConn) EXPECT() *mockUDPConn_Expecter {
func (_m *mockUDPConn) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
@ -62,6 +66,10 @@ func (_c *mockUDPConn_Close_Call) RunAndReturn(run func() error) *mockUDPConn_Cl
func (_m *mockUDPConn) ReadFrom(b []byte) (int, string, error) {
ret := _m.Called(b)
if len(ret) == 0 {
panic("no return value specified for ReadFrom")
}
var r0 int
var r1 string
var r2 error
@ -121,6 +129,10 @@ func (_c *mockUDPConn_ReadFrom_Call) RunAndReturn(run func([]byte) (int, string,
func (_m *mockUDPConn) WriteTo(b []byte, addr string) (int, error) {
ret := _m.Called(b, addr)
if len(ret) == 0 {
panic("no return value specified for WriteTo")
}
var r0 int
var r1 error
if rf, ok := ret.Get(0).(func([]byte, string) (int, error)); ok {

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package server

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package server
@ -24,6 +24,10 @@ func (_m *mockUDPIO) EXPECT() *mockUDPIO_Expecter {
func (_m *mockUDPIO) ReceiveMessage() (*protocol.UDPMessage, error) {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for ReceiveMessage")
}
var r0 *protocol.UDPMessage
var r1 error
if rf, ok := ret.Get(0).(func() (*protocol.UDPMessage, error)); ok {
@ -77,6 +81,10 @@ func (_c *mockUDPIO_ReceiveMessage_Call) RunAndReturn(run func() (*protocol.UDPM
func (_m *mockUDPIO) SendMessage(_a0 []byte, _a1 *protocol.UDPMessage) error {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for SendMessage")
}
var r0 error
if rf, ok := ret.Get(0).(func([]byte, *protocol.UDPMessage) error); ok {
r0 = rf(_a0, _a1)
@ -120,6 +128,10 @@ func (_c *mockUDPIO_SendMessage_Call) RunAndReturn(run func([]byte, *protocol.UD
func (_m *mockUDPIO) UDP(reqAddr string) (UDPConn, error) {
ret := _m.Called(reqAddr)
if len(ret) == 0 {
panic("no return value specified for UDP")
}
var r0 UDPConn
var r1 error
if rf, ok := ret.Get(0).(func(string) (UDPConn, error)); ok {

View file

@ -82,8 +82,13 @@ func (s *serverImpl) handleClient(conn quic.Connection) {
}
err := h3s.ServeQUICConn(conn)
// If the client is authenticated, we need to log the disconnect event
if handler.authenticated && s.config.EventLogger != nil {
s.config.EventLogger.Disconnect(conn.RemoteAddr(), handler.authID, err)
if handler.authenticated {
if tl := s.config.TrafficLogger; tl != nil {
tl.LogOnlineState(handler.authID, false)
}
if el := s.config.EventLogger; el != nil {
el.Disconnect(conn.RemoteAddr(), handler.authID, err)
}
}
_ = conn.CloseWithError(closeErrCodeOK, "")
}
@ -153,8 +158,11 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
})
w.WriteHeader(protocol.StatusAuthOK)
// Call event logger
if h.config.EventLogger != nil {
h.config.EventLogger.Connect(h.conn.RemoteAddr(), id, actualTx)
if tl := h.config.TrafficLogger; tl != nil {
tl.LogOnlineState(id, true)
}
if el := h.config.EventLogger; el != nil {
el.Connect(h.conn.RemoteAddr(), id, actualTx)
}
// Initialize UDP session manager (if UDP is enabled)
// We use sync.Once to make sure that only one goroutine is started,
@ -268,7 +276,7 @@ func (io *udpIOImpl) ReceiveMessage() (*protocol.UDPMessage, error) {
continue
}
if io.TrafficLogger != nil {
ok := io.TrafficLogger.Log(io.AuthID, uint64(len(udpMsg.Data)), 0)
ok := io.TrafficLogger.LogTraffic(io.AuthID, uint64(len(udpMsg.Data)), 0)
if !ok {
// TrafficLogger requested to disconnect the client
_ = io.Conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
@ -281,7 +289,7 @@ func (io *udpIOImpl) ReceiveMessage() (*protocol.UDPMessage, error) {
func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
if io.TrafficLogger != nil {
ok := io.TrafficLogger.Log(io.AuthID, 0, uint64(len(msg.Data)))
ok := io.TrafficLogger.LogTraffic(io.AuthID, 0, uint64(len(msg.Data)))
if !ok {
// TrafficLogger requested to disconnect the client
_ = io.Conn.CloseWithError(closeErrCodeTrafficLimitReached, "")

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package outbounds
@ -25,6 +25,10 @@ func (_m *mockPluggableOutbound) EXPECT() *mockPluggableOutbound_Expecter {
func (_m *mockPluggableOutbound) TCP(reqAddr *AddrEx) (net.Conn, error) {
ret := _m.Called(reqAddr)
if len(ret) == 0 {
panic("no return value specified for TCP")
}
var r0 net.Conn
var r1 error
if rf, ok := ret.Get(0).(func(*AddrEx) (net.Conn, error)); ok {
@ -79,6 +83,10 @@ func (_c *mockPluggableOutbound_TCP_Call) RunAndReturn(run func(*AddrEx) (net.Co
func (_m *mockPluggableOutbound) UDP(reqAddr *AddrEx) (UDPConn, error) {
ret := _m.Called(reqAddr)
if len(ret) == 0 {
panic("no return value specified for UDP")
}
var r0 UDPConn
var r1 error
if rf, ok := ret.Get(0).(func(*AddrEx) (UDPConn, error)); ok {

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.0. DO NOT EDIT.
// Code generated by mockery v2.43.0. DO NOT EDIT.
package outbounds
@ -21,6 +21,10 @@ func (_m *mockUDPConn) EXPECT() *mockUDPConn_Expecter {
func (_m *mockUDPConn) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
@ -62,6 +66,10 @@ func (_c *mockUDPConn_Close_Call) RunAndReturn(run func() error) *mockUDPConn_Cl
func (_m *mockUDPConn) ReadFrom(b []byte) (int, *AddrEx, error) {
ret := _m.Called(b)
if len(ret) == 0 {
panic("no return value specified for ReadFrom")
}
var r0 int
var r1 *AddrEx
var r2 error
@ -123,6 +131,10 @@ func (_c *mockUDPConn_ReadFrom_Call) RunAndReturn(run func([]byte) (int, *AddrEx
func (_m *mockUDPConn) WriteTo(b []byte, addr *AddrEx) (int, error) {
ret := _m.Called(b, addr)
if len(ret) == 0 {
panic("no return value specified for WriteTo")
}
var r0 int
var r1 error
if rf, ok := ret.Get(0).(func([]byte, *AddrEx) (int, error)); ok {

View file

@ -22,17 +22,19 @@ type TrafficStatsServer interface {
func NewTrafficStatsServer(secret string) TrafficStatsServer {
return &trafficStatsServerImpl{
StatsMap: make(map[string]*trafficStatsEntry),
KickMap: make(map[string]struct{}),
Secret: secret,
StatsMap: make(map[string]*trafficStatsEntry),
KickMap: make(map[string]struct{}),
OnlineMap: make(map[string]int),
Secret: secret,
}
}
type trafficStatsServerImpl struct {
Mutex sync.RWMutex
StatsMap map[string]*trafficStatsEntry
KickMap map[string]struct{}
Secret string
Mutex sync.RWMutex
StatsMap map[string]*trafficStatsEntry
OnlineMap map[string]int
KickMap map[string]struct{}
Secret string
}
type trafficStatsEntry struct {
@ -40,7 +42,7 @@ type trafficStatsEntry struct {
Rx uint64 `json:"rx"`
}
func (s *trafficStatsServerImpl) Log(id string, tx, rx uint64) (ok bool) {
func (s *trafficStatsServerImpl) LogTraffic(id string, tx, rx uint64) (ok bool) {
s.Mutex.Lock()
defer s.Mutex.Unlock()
@ -61,6 +63,21 @@ func (s *trafficStatsServerImpl) Log(id string, tx, rx uint64) (ok bool) {
return true
}
// LogOnlineStateChanged updates the online state to the online map.
func (s *trafficStatsServerImpl) LogOnlineState(id string, online bool) {
s.Mutex.Lock()
defer s.Mutex.Unlock()
if online {
s.OnlineMap[id]++
} else {
s.OnlineMap[id]--
if s.OnlineMap[id] <= 0 {
delete(s.OnlineMap, id)
}
}
}
func (s *trafficStatsServerImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s.Secret != "" && r.Header.Get("Authorization") != s.Secret {
http.Error(w, "unauthorized", http.StatusUnauthorized)
@ -78,6 +95,10 @@ func (s *trafficStatsServerImpl) ServeHTTP(w http.ResponseWriter, r *http.Reques
s.kick(w, r)
return
}
if r.Method == http.MethodGet && r.URL.Path == "/online" {
s.getOnline(w, r)
return
}
http.NotFound(w, r)
}
@ -103,6 +124,19 @@ func (s *trafficStatsServerImpl) getTraffic(w http.ResponseWriter, r *http.Reque
_, _ = w.Write(jb)
}
func (s *trafficStatsServerImpl) getOnline(w http.ResponseWriter, r *http.Request) {
s.Mutex.RLock()
defer s.Mutex.RUnlock()
jb, err := json.Marshal(s.OnlineMap)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
_, _ = w.Write(jb)
}
func (s *trafficStatsServerImpl) kick(w http.ResponseWriter, r *http.Request) {
var ids []string
err := json.NewDecoder(r.Body).Decode(&ids)