This commit is contained in:
世界 2024-11-23 22:34:02 +08:00
parent 68781387fe
commit 0c66888691
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
32 changed files with 263 additions and 267 deletions

View file

@ -4,17 +4,19 @@ import (
std_bufio "bufio"
"context"
"net"
"os"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/smux"
)
func HandleMuxConnection(ctx context.Context, conn net.Conn, metadata M.Metadata, handler Handler, logger logger.ContextLogger) error {
func HandleMuxConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, handler Handler, logger logger.ContextLogger, onClose N.CloseHandlerFunc) error {
session, err := smux.Server(conn, smuxConfig())
if err != nil {
return err
@ -27,29 +29,32 @@ func HandleMuxConnection(ctx context.Context, conn net.Conn, metadata M.Metadata
if err != nil {
return err
}
go newMuxConnection(ctx, stream, metadata, handler, logger)
go newMuxConnection(ctx, stream, source, handler, logger)
}
})
group.Cleanup(func() {
session.Close()
if onClose != nil {
onClose(os.ErrClosed)
}
})
return group.Run(ctx)
}
func newMuxConnection(ctx context.Context, conn net.Conn, metadata M.Metadata, handler Handler, logger logger.ContextLogger) {
err := newMuxConnection0(ctx, conn, metadata, handler)
func newMuxConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, handler Handler, logger logger.ContextLogger) {
err := newMuxConnection0(ctx, conn, source, handler)
if err != nil {
logger.ErrorContext(ctx, E.Cause(err, "process trojan-go multiplex connection"))
}
}
func newMuxConnection0(ctx context.Context, conn net.Conn, metadata M.Metadata, handler Handler) error {
func newMuxConnection0(ctx context.Context, conn net.Conn, source M.Socksaddr, handler Handler) error {
reader := std_bufio.NewReader(conn)
command, err := reader.ReadByte()
if err != nil {
return E.Cause(err, "read command")
}
metadata.Destination, err = M.SocksaddrSerializer.ReadAddrPort(reader)
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return E.Cause(err, "read destination")
}
@ -63,12 +68,13 @@ func newMuxConnection0(ctx context.Context, conn net.Conn, metadata M.Metadata,
}
switch command {
case CommandTCP:
return handler.NewConnection(ctx, conn, metadata)
handler.NewConnectionEx(ctx, conn, source, destination, nil)
case CommandUDP:
return handler.NewPacketConnection(ctx, &PacketConn{Conn: conn}, metadata)
handler.NewPacketConnectionEx(ctx, &PacketConn{Conn: conn}, source, destination, nil)
default:
return E.New("unknown command ", command)
}
return nil
}
func smuxConfig() *smux.Config {

View file

@ -16,19 +16,19 @@ import (
)
type Handler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
N.TCPConnectionHandlerEx
N.UDPConnectionHandlerEx
}
type Service[K comparable] struct {
users map[K][56]byte
keys map[[56]byte]K
handler Handler
fallbackHandler N.TCPConnectionHandler
fallbackHandler N.TCPConnectionHandlerEx
logger logger.ContextLogger
}
func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandler, logger logger.ContextLogger) *Service[K] {
func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandlerEx, logger logger.ContextLogger) *Service[K] {
return &Service[K]{
users: make(map[K][56]byte),
keys: make(map[[56]byte]K),
@ -59,19 +59,19 @@ func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error {
return nil
}
func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc) error {
var key [KeyLength]byte
n, err := conn.Read(key[:])
if err != nil {
return err
} else if n != KeyLength {
return s.fallback(ctx, conn, metadata, key[:n], E.New("bad request size"))
return s.fallback(ctx, conn, source, key[:n], E.New("bad request size"), onClose)
}
if user, loaded := s.keys[key]; loaded {
ctx = auth.ContextWithUser(ctx, user)
} else {
return s.fallback(ctx, conn, metadata, key[:], E.New("bad request"))
return s.fallback(ctx, conn, source, key[:], E.New("bad request"), onClose)
}
err = rw.SkipN(conn, 2)
@ -102,26 +102,25 @@ func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata
return E.Cause(err, "skip crlf")
}
metadata.Protocol = "trojan"
metadata.Destination = destination
switch command {
case CommandTCP:
return s.handler.NewConnection(ctx, conn, metadata)
s.handler.NewConnectionEx(ctx, conn, source, destination, onClose)
case CommandUDP:
return s.handler.NewPacketConnection(ctx, &PacketConn{Conn: conn}, metadata)
s.handler.NewPacketConnectionEx(ctx, &PacketConn{Conn: conn}, source, destination, onClose)
// case CommandMux:
default:
return HandleMuxConnection(ctx, conn, metadata, s.handler, s.logger)
return HandleMuxConnection(ctx, conn, source, s.handler, s.logger, onClose)
}
return nil
}
func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Metadata, header []byte, err error) error {
func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, source M.Socksaddr, header []byte, err error, onClose N.CloseHandlerFunc) error {
if s.fallbackHandler == nil {
return E.Extend(err, "fallback disabled")
}
conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned())
return s.fallbackHandler.NewConnection(ctx, conn, metadata)
s.fallbackHandler.NewConnectionEx(ctx, conn, source, M.Socksaddr{}, onClose)
return nil
}
type PacketConn struct {