mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 20:47:38 +03:00
feat: auth conn limiter
This commit is contained in:
parent
cb4daac18d
commit
7884081428
3 changed files with 59 additions and 10 deletions
41
app/auth/limiter.go
Normal file
41
app/auth/limiter.go
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
type ConnLimiter struct {
|
||||||
|
MaxConn int // <= 0 means no limit
|
||||||
|
|
||||||
|
connMap map[string]int
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *ConnLimiter) Connect(auth []byte) bool {
|
||||||
|
if l.MaxConn <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
l.mutex.Lock()
|
||||||
|
defer l.mutex.Unlock()
|
||||||
|
if l.connMap == nil {
|
||||||
|
l.connMap = make(map[string]int)
|
||||||
|
}
|
||||||
|
authStr := string(auth)
|
||||||
|
if l.connMap[authStr] >= l.MaxConn {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
l.connMap[authStr]++
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *ConnLimiter) Disconnect(auth []byte) {
|
||||||
|
if l.MaxConn <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.mutex.Lock()
|
||||||
|
defer l.mutex.Unlock()
|
||||||
|
authStr := string(auth)
|
||||||
|
if l.connMap[authStr] > 1 {
|
||||||
|
l.connMap[authStr]--
|
||||||
|
} else {
|
||||||
|
delete(l.connMap, authStr)
|
||||||
|
}
|
||||||
|
}
|
|
@ -56,6 +56,7 @@ type serverConfig struct {
|
||||||
Auth struct {
|
Auth struct {
|
||||||
Mode string `json:"mode"`
|
Mode string `json:"mode"`
|
||||||
Config json5.RawMessage `json:"config"`
|
Config json5.RawMessage `json:"config"`
|
||||||
|
ConnLimit int `json:"conn_limit"`
|
||||||
} `json:"auth"`
|
} `json:"auth"`
|
||||||
ALPN string `json:"alpn"`
|
ALPN string `json:"alpn"`
|
||||||
PrometheusListen string `json:"prometheus_listen"`
|
PrometheusListen string `json:"prometheus_listen"`
|
||||||
|
|
|
@ -123,6 +123,7 @@ func server(config *serverConfig) {
|
||||||
default:
|
default:
|
||||||
logrus.WithField("mode", config.Auth.Mode).Fatal("Unsupported authentication mode")
|
logrus.WithField("mode", config.Auth.Mode).Fatal("Unsupported authentication mode")
|
||||||
}
|
}
|
||||||
|
connLimiter := auth.ConnLimiter{MaxConn: config.Auth.ConnLimit}
|
||||||
connectFunc := func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) {
|
connectFunc := func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) {
|
||||||
ok, msg := authFunc(addr, auth, sSend, sRecv)
|
ok, msg := authFunc(addr, auth, sSend, sRecv)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -130,13 +131,26 @@ func server(config *serverConfig) {
|
||||||
"src": defaultIPMasker.Mask(addr.String()),
|
"src": defaultIPMasker.Mask(addr.String()),
|
||||||
"msg": msg,
|
"msg": msg,
|
||||||
}).Info("Authentication failed, client rejected")
|
}).Info("Authentication failed, client rejected")
|
||||||
} else {
|
} else if connLimiter.Connect(auth) {
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"src": defaultIPMasker.Mask(addr.String()),
|
"src": defaultIPMasker.Mask(addr.String()),
|
||||||
}).Info("Client connected")
|
}).Info("Client connected")
|
||||||
|
} else {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"src": defaultIPMasker.Mask(addr.String()),
|
||||||
|
}).Info("Client rejected due to connection limit")
|
||||||
|
ok = false
|
||||||
|
msg = "too many connections"
|
||||||
}
|
}
|
||||||
return ok, msg
|
return ok, msg
|
||||||
}
|
}
|
||||||
|
disconnectFunc := func(addr net.Addr, auth []byte, err error) {
|
||||||
|
connLimiter.Disconnect(auth)
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"src": defaultIPMasker.Mask(addr.String()),
|
||||||
|
"error": err,
|
||||||
|
}).Info("Client disconnected")
|
||||||
|
}
|
||||||
// Resolve preference
|
// Resolve preference
|
||||||
if len(config.ResolvePreference) > 0 {
|
if len(config.ResolvePreference) > 0 {
|
||||||
pref, err := transport.ResolvePreferenceFromString(config.ResolvePreference)
|
pref, err := transport.ResolvePreferenceFromString(config.ResolvePreference)
|
||||||
|
@ -230,13 +244,6 @@ func server(config *serverConfig) {
|
||||||
logrus.WithField("error", err).Fatal("Server shutdown")
|
logrus.WithField("error", err).Fatal("Server shutdown")
|
||||||
}
|
}
|
||||||
|
|
||||||
func disconnectFunc(addr net.Addr, auth []byte, err error) {
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"src": defaultIPMasker.Mask(addr.String()),
|
|
||||||
"error": err,
|
|
||||||
}).Info("Client disconnected")
|
|
||||||
}
|
|
||||||
|
|
||||||
func tcpRequestFunc(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) {
|
func tcpRequestFunc(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) {
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"src": defaultIPMasker.Mask(addr.String()),
|
"src": defaultIPMasker.Mask(addr.String()),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue