mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-02 03:57:38 +03:00
feat: auth conn limiter
This commit is contained in:
parent
cb4daac18d
commit
7884081428
3 changed files with 59 additions and 10 deletions
41
app/auth/limiter.go
Normal file
41
app/auth/limiter.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package auth
|
||||
|
||||
import "sync"
|
||||
|
||||
type ConnLimiter struct {
|
||||
MaxConn int // <= 0 means no limit
|
||||
|
||||
connMap map[string]int
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (l *ConnLimiter) Connect(auth []byte) bool {
|
||||
if l.MaxConn <= 0 {
|
||||
return true
|
||||
}
|
||||
l.mutex.Lock()
|
||||
defer l.mutex.Unlock()
|
||||
if l.connMap == nil {
|
||||
l.connMap = make(map[string]int)
|
||||
}
|
||||
authStr := string(auth)
|
||||
if l.connMap[authStr] >= l.MaxConn {
|
||||
return false
|
||||
}
|
||||
l.connMap[authStr]++
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *ConnLimiter) Disconnect(auth []byte) {
|
||||
if l.MaxConn <= 0 {
|
||||
return
|
||||
}
|
||||
l.mutex.Lock()
|
||||
defer l.mutex.Unlock()
|
||||
authStr := string(auth)
|
||||
if l.connMap[authStr] > 1 {
|
||||
l.connMap[authStr]--
|
||||
} else {
|
||||
delete(l.connMap, authStr)
|
||||
}
|
||||
}
|
|
@ -54,8 +54,9 @@ type serverConfig struct {
|
|||
MMDB string `json:"mmdb"`
|
||||
Obfs string `json:"obfs"`
|
||||
Auth struct {
|
||||
Mode string `json:"mode"`
|
||||
Config json5.RawMessage `json:"config"`
|
||||
Mode string `json:"mode"`
|
||||
Config json5.RawMessage `json:"config"`
|
||||
ConnLimit int `json:"conn_limit"`
|
||||
} `json:"auth"`
|
||||
ALPN string `json:"alpn"`
|
||||
PrometheusListen string `json:"prometheus_listen"`
|
||||
|
|
|
@ -123,6 +123,7 @@ func server(config *serverConfig) {
|
|||
default:
|
||||
logrus.WithField("mode", config.Auth.Mode).Fatal("Unsupported authentication mode")
|
||||
}
|
||||
connLimiter := auth.ConnLimiter{MaxConn: config.Auth.ConnLimit}
|
||||
connectFunc := func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) {
|
||||
ok, msg := authFunc(addr, auth, sSend, sRecv)
|
||||
if !ok {
|
||||
|
@ -130,13 +131,26 @@ func server(config *serverConfig) {
|
|||
"src": defaultIPMasker.Mask(addr.String()),
|
||||
"msg": msg,
|
||||
}).Info("Authentication failed, client rejected")
|
||||
} else {
|
||||
} else if connLimiter.Connect(auth) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"src": defaultIPMasker.Mask(addr.String()),
|
||||
}).Info("Client connected")
|
||||
} else {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"src": defaultIPMasker.Mask(addr.String()),
|
||||
}).Info("Client rejected due to connection limit")
|
||||
ok = false
|
||||
msg = "too many connections"
|
||||
}
|
||||
return ok, msg
|
||||
}
|
||||
disconnectFunc := func(addr net.Addr, auth []byte, err error) {
|
||||
connLimiter.Disconnect(auth)
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"src": defaultIPMasker.Mask(addr.String()),
|
||||
"error": err,
|
||||
}).Info("Client disconnected")
|
||||
}
|
||||
// Resolve preference
|
||||
if len(config.ResolvePreference) > 0 {
|
||||
pref, err := transport.ResolvePreferenceFromString(config.ResolvePreference)
|
||||
|
@ -230,13 +244,6 @@ func server(config *serverConfig) {
|
|||
logrus.WithField("error", err).Fatal("Server shutdown")
|
||||
}
|
||||
|
||||
func disconnectFunc(addr net.Addr, auth []byte, err error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"src": defaultIPMasker.Mask(addr.String()),
|
||||
"error": err,
|
||||
}).Info("Client disconnected")
|
||||
}
|
||||
|
||||
func tcpRequestFunc(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"src": defaultIPMasker.Mask(addr.String()),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue