mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-05 12:57:38 +03:00
Improve user context
This commit is contained in:
parent
bd79d31e3b
commit
6795d518e1
6 changed files with 23 additions and 52 deletions
14
common/auth/context.go
Normal file
14
common/auth/context.go
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type userKey struct{}
|
||||||
|
|
||||||
|
func ContextWithUser[T any](ctx context.Context, user T) context.Context {
|
||||||
|
return context.WithValue(ctx, (*userKey)(nil), user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UserFromContext[T any](ctx context.Context) (T, bool) {
|
||||||
|
user, loaded := ctx.Value((*userKey)(nil)).(T)
|
||||||
|
return user, loaded
|
||||||
|
}
|
|
@ -36,6 +36,9 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
|
||||||
userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
|
userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
|
||||||
userPswdArr := strings.SplitN(string(userPassword), ":", 2)
|
userPswdArr := strings.SplitN(string(userPassword), ":", 2)
|
||||||
authOk = authenticator.Verify(userPswdArr[0], userPswdArr[1])
|
authOk = authenticator.Verify(userPswdArr[0], userPswdArr[1])
|
||||||
|
if authOk {
|
||||||
|
ctx = auth.ContextWithUser(ctx, userPswdArr[0])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if !authOk {
|
if !authOk {
|
||||||
err = responseWith(request, http.StatusProxyAuthRequired).Write(conn)
|
err = responseWith(request, http.StatusProxyAuthRequired).Write(conn)
|
||||||
|
|
|
@ -123,11 +123,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
|
||||||
}
|
}
|
||||||
metadata.Protocol = "socks4"
|
metadata.Protocol = "socks4"
|
||||||
metadata.Destination = request.Destination
|
metadata.Destination = request.Destination
|
||||||
ctx = &socks4.UserContext{
|
return handler.NewConnection(auth.ContextWithUser(ctx, request.Username), conn, metadata)
|
||||||
Context: ctx,
|
|
||||||
Username: request.Username,
|
|
||||||
}
|
|
||||||
return handler.NewConnection(ctx, conn, metadata)
|
|
||||||
default:
|
default:
|
||||||
err = socks4.WriteResponse(conn, socks4.Response{
|
err = socks4.WriteResponse(conn, socks4.Response{
|
||||||
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
|
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
|
||||||
|
@ -163,16 +159,12 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
userCtx := &socks5.UserContext{
|
|
||||||
Context: ctx,
|
|
||||||
}
|
|
||||||
if authMethod == socks5.AuthTypeUsernamePassword {
|
if authMethod == socks5.AuthTypeUsernamePassword {
|
||||||
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(conn)
|
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
userCtx.Username = usernamePasswordAuthRequest.Username
|
ctx = auth.ContextWithUser(ctx, usernamePasswordAuthRequest.Username)
|
||||||
userCtx.Password = usernamePasswordAuthRequest.Password
|
|
||||||
response := socks5.UsernamePasswordAuthResponse{}
|
response := socks5.UsernamePasswordAuthResponse{}
|
||||||
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
|
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
|
||||||
response.Status = socks5.UsernamePasswordStatusSuccess
|
response.Status = socks5.UsernamePasswordStatusSuccess
|
||||||
|
@ -184,7 +176,6 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx = userCtx
|
|
||||||
request, err := socks5.ReadRequest(conn)
|
request, err := socks5.ReadRequest(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -1,12 +0,0 @@
|
||||||
package socks4
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
type UserContext struct {
|
|
||||||
context.Context
|
|
||||||
Username string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *UserContext) Upstream() any {
|
|
||||||
return c.Context
|
|
||||||
}
|
|
|
@ -1,13 +0,0 @@
|
||||||
package socks5
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
type UserContext struct {
|
|
||||||
context.Context
|
|
||||||
Username string
|
|
||||||
Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *UserContext) Upstream() any {
|
|
||||||
return c.Context
|
|
||||||
}
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/auth"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
F "github.com/sagernet/sing/common/format"
|
F "github.com/sagernet/sing/common/format"
|
||||||
|
@ -19,16 +20,6 @@ type Handler interface {
|
||||||
N.UDPConnectionHandler
|
N.UDPConnectionHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
type Context[K comparable] struct {
|
|
||||||
context.Context
|
|
||||||
User K
|
|
||||||
Key [KeyLength]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context[K]) Upstream() any {
|
|
||||||
return ctx.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
type Service[K comparable] struct {
|
type Service[K comparable] struct {
|
||||||
handler Handler
|
handler Handler
|
||||||
keys map[[56]byte]K
|
keys map[[56]byte]K
|
||||||
|
@ -91,11 +82,8 @@ returnErr:
|
||||||
|
|
||||||
process:
|
process:
|
||||||
|
|
||||||
var userCtx Context[K]
|
|
||||||
userCtx.Context = ctx
|
|
||||||
if user, loaded := s.keys[key]; loaded {
|
if user, loaded := s.keys[key]; loaded {
|
||||||
userCtx.User = user
|
ctx = auth.ContextWithUser(ctx, user)
|
||||||
userCtx.Key = key
|
|
||||||
} else {
|
} else {
|
||||||
err = E.New("bad request")
|
err = E.New("bad request")
|
||||||
goto returnErr
|
goto returnErr
|
||||||
|
@ -134,9 +122,9 @@ process:
|
||||||
metadata.Destination = destination
|
metadata.Destination = destination
|
||||||
|
|
||||||
if command == CommandTCP {
|
if command == CommandTCP {
|
||||||
return s.handler.NewConnection(&userCtx, conn, metadata)
|
return s.handler.NewConnection(ctx, conn, metadata)
|
||||||
} else {
|
} else {
|
||||||
return s.handler.NewPacketConnection(&userCtx, &PacketConn{conn}, metadata)
|
return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue