diff --git a/common/auth/context.go b/common/auth/context.go new file mode 100644 index 0000000..c0899fe --- /dev/null +++ b/common/auth/context.go @@ -0,0 +1,14 @@ +package auth + +import "context" + +type userKey struct{} + +func ContextWithUser[T any](ctx context.Context, user T) context.Context { + return context.WithValue(ctx, (*userKey)(nil), user) +} + +func UserFromContext[T any](ctx context.Context) (T, bool) { + user, loaded := ctx.Value((*userKey)(nil)).(T) + return user, loaded +} diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 00087c0..0a90717 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -36,6 +36,9 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:]) userPswdArr := strings.SplitN(string(userPassword), ":", 2) authOk = authenticator.Verify(userPswdArr[0], userPswdArr[1]) + if authOk { + ctx = auth.ContextWithUser(ctx, userPswdArr[0]) + } } if !authOk { err = responseWith(request, http.StatusProxyAuthRequired).Write(conn) diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 98715a2..6cc891d 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -123,11 +123,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent } metadata.Protocol = "socks4" metadata.Destination = request.Destination - ctx = &socks4.UserContext{ - Context: ctx, - Username: request.Username, - } - return handler.NewConnection(ctx, conn, metadata) + return handler.NewConnection(auth.ContextWithUser(ctx, request.Username), conn, metadata) default: err = socks4.WriteResponse(conn, socks4.Response{ ReplyCode: socks4.ReplyCodeRejectedOrFailed, @@ -163,16 +159,12 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent if err != nil { return err } - userCtx := &socks5.UserContext{ - Context: ctx, - } if authMethod == socks5.AuthTypeUsernamePassword { usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(conn) if err != nil { return err } - userCtx.Username = usernamePasswordAuthRequest.Username - userCtx.Password = usernamePasswordAuthRequest.Password + ctx = auth.ContextWithUser(ctx, usernamePasswordAuthRequest.Username) response := socks5.UsernamePasswordAuthResponse{} if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) { response.Status = socks5.UsernamePasswordStatusSuccess @@ -184,7 +176,6 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent return err } } - ctx = userCtx request, err := socks5.ReadRequest(conn) if err != nil { return err diff --git a/protocol/socks/socks4/ctx.go b/protocol/socks/socks4/ctx.go deleted file mode 100644 index 554ff27..0000000 --- a/protocol/socks/socks4/ctx.go +++ /dev/null @@ -1,12 +0,0 @@ -package socks4 - -import "context" - -type UserContext struct { - context.Context - Username string -} - -func (c *UserContext) Upstream() any { - return c.Context -} diff --git a/protocol/socks/socks5/ctx.go b/protocol/socks/socks5/ctx.go deleted file mode 100644 index 81ba4b2..0000000 --- a/protocol/socks/socks5/ctx.go +++ /dev/null @@ -1,13 +0,0 @@ -package socks5 - -import "context" - -type UserContext struct { - context.Context - Username string - Password string -} - -func (c *UserContext) Upstream() any { - return c.Context -} diff --git a/protocol/trojan/service.go b/protocol/trojan/service.go index f8ef7a7..f0b51c5 100644 --- a/protocol/trojan/service.go +++ b/protocol/trojan/service.go @@ -6,6 +6,7 @@ import ( "net" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" @@ -19,16 +20,6 @@ type Handler interface { N.UDPConnectionHandler } -type Context[K comparable] struct { - context.Context - User K - Key [KeyLength]byte -} - -func (ctx *Context[K]) Upstream() any { - return ctx.Context -} - type Service[K comparable] struct { handler Handler keys map[[56]byte]K @@ -91,11 +82,8 @@ returnErr: process: - var userCtx Context[K] - userCtx.Context = ctx if user, loaded := s.keys[key]; loaded { - userCtx.User = user - userCtx.Key = key + ctx = auth.ContextWithUser(ctx, user) } else { err = E.New("bad request") goto returnErr @@ -134,9 +122,9 @@ process: metadata.Destination = destination if command == CommandTCP { - return s.handler.NewConnection(&userCtx, conn, metadata) + return s.handler.NewConnection(ctx, conn, metadata) } else { - return s.handler.NewPacketConnection(&userCtx, &PacketConn{conn}, metadata) + return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata) } }