From 78840814285b232653ae96b13244c66dd0c4ee54 Mon Sep 17 00:00:00 2001 From: tobyxdd Date: Sat, 4 Feb 2023 11:56:34 -0800 Subject: [PATCH] feat: auth conn limiter --- app/auth/limiter.go | 41 +++++++++++++++++++++++++++++++++++++++++ app/cmd/config.go | 5 +++-- app/cmd/server.go | 23 +++++++++++++++-------- 3 files changed, 59 insertions(+), 10 deletions(-) create mode 100644 app/auth/limiter.go diff --git a/app/auth/limiter.go b/app/auth/limiter.go new file mode 100644 index 0000000..8c3cb4d --- /dev/null +++ b/app/auth/limiter.go @@ -0,0 +1,41 @@ +package auth + +import "sync" + +type ConnLimiter struct { + MaxConn int // <= 0 means no limit + + connMap map[string]int + mutex sync.RWMutex +} + +func (l *ConnLimiter) Connect(auth []byte) bool { + if l.MaxConn <= 0 { + return true + } + l.mutex.Lock() + defer l.mutex.Unlock() + if l.connMap == nil { + l.connMap = make(map[string]int) + } + authStr := string(auth) + if l.connMap[authStr] >= l.MaxConn { + return false + } + l.connMap[authStr]++ + return true +} + +func (l *ConnLimiter) Disconnect(auth []byte) { + if l.MaxConn <= 0 { + return + } + l.mutex.Lock() + defer l.mutex.Unlock() + authStr := string(auth) + if l.connMap[authStr] > 1 { + l.connMap[authStr]-- + } else { + delete(l.connMap, authStr) + } +} diff --git a/app/cmd/config.go b/app/cmd/config.go index 65f7e4e..c09a29c 100644 --- a/app/cmd/config.go +++ b/app/cmd/config.go @@ -54,8 +54,9 @@ type serverConfig struct { MMDB string `json:"mmdb"` Obfs string `json:"obfs"` Auth struct { - Mode string `json:"mode"` - Config json5.RawMessage `json:"config"` + Mode string `json:"mode"` + Config json5.RawMessage `json:"config"` + ConnLimit int `json:"conn_limit"` } `json:"auth"` ALPN string `json:"alpn"` PrometheusListen string `json:"prometheus_listen"` diff --git a/app/cmd/server.go b/app/cmd/server.go index c6909d3..ef40c90 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -123,6 +123,7 @@ func server(config *serverConfig) { default: logrus.WithField("mode", config.Auth.Mode).Fatal("Unsupported authentication mode") } + connLimiter := auth.ConnLimiter{MaxConn: config.Auth.ConnLimit} connectFunc := func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) { ok, msg := authFunc(addr, auth, sSend, sRecv) if !ok { @@ -130,13 +131,26 @@ func server(config *serverConfig) { "src": defaultIPMasker.Mask(addr.String()), "msg": msg, }).Info("Authentication failed, client rejected") - } else { + } else if connLimiter.Connect(auth) { logrus.WithFields(logrus.Fields{ "src": defaultIPMasker.Mask(addr.String()), }).Info("Client connected") + } else { + logrus.WithFields(logrus.Fields{ + "src": defaultIPMasker.Mask(addr.String()), + }).Info("Client rejected due to connection limit") + ok = false + msg = "too many connections" } return ok, msg } + disconnectFunc := func(addr net.Addr, auth []byte, err error) { + connLimiter.Disconnect(auth) + logrus.WithFields(logrus.Fields{ + "src": defaultIPMasker.Mask(addr.String()), + "error": err, + }).Info("Client disconnected") + } // Resolve preference if len(config.ResolvePreference) > 0 { pref, err := transport.ResolvePreferenceFromString(config.ResolvePreference) @@ -230,13 +244,6 @@ func server(config *serverConfig) { logrus.WithField("error", err).Fatal("Server shutdown") } -func disconnectFunc(addr net.Addr, auth []byte, err error) { - logrus.WithFields(logrus.Fields{ - "src": defaultIPMasker.Mask(addr.String()), - "error": err, - }).Info("Client disconnected") -} - func tcpRequestFunc(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) { logrus.WithFields(logrus.Fields{ "src": defaultIPMasker.Mask(addr.String()),