diff --git a/protocol/trojan/protocol.go b/protocol/trojan/protocol.go deleted file mode 100644 index be45cae..0000000 --- a/protocol/trojan/protocol.go +++ /dev/null @@ -1,312 +0,0 @@ -package trojan - -import ( - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "io" - "net" - "os" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" -) - -const ( - KeyLength = 56 - CommandTCP = 1 - CommandUDP = 3 -) - -var CRLF = []byte{'\r', '\n'} - -type ClientConn struct { - N.ExtendedConn - key [KeyLength]byte - destination M.Socksaddr - headerWritten bool -} - -func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn { - return &ClientConn{ - ExtendedConn: bufio.NewExtendedConn(conn), - key: key, - destination: destination, - } -} - -func (c *ClientConn) Write(p []byte) (n int, err error) { - if c.headerWritten { - return c.ExtendedConn.Write(p) - } - err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p) - if err != nil { - return - } - n = len(p) - c.headerWritten = true - return -} - -func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error { - if c.headerWritten { - return c.ExtendedConn.WriteBuffer(buffer) - } - err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer) - if err != nil { - return err - } - c.headerWritten = true - return nil -} - -func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) { - if !c.headerWritten { - return bufio.ReadFrom0(c, r) - } - return bufio.Copy(c.ExtendedConn, r) -} - -func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) { - return bufio.Copy(w, c.ExtendedConn) -} - -func (c *ClientConn) FrontHeadroom() int { - if !c.headerWritten { - return KeyLength + 5 + M.MaxSocksaddrLength - } - return 0 -} - -func (c *ClientConn) Upstream() any { - return c.ExtendedConn -} - -type ClientPacketConn struct { - net.Conn - key [KeyLength]byte - headerWritten bool -} - -func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn { - return &ClientPacketConn{ - Conn: conn, - key: key, - } -} - -func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { - return ReadPacket(c.Conn, buffer) -} - -func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - if !c.headerWritten { - err := ClientHandshakePacket(c.Conn, c.key, destination, buffer) - c.headerWritten = true - return err - } - return WritePacket(c.Conn, buffer, destination) -} - -func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - buffer := buf.With(p) - destination, err := c.ReadPacket(buffer) - if err != nil { - return - } - n = buffer.Len() - addr = destination.UDPAddr() - return -} - -func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return bufio.WritePacket(c, p, addr) -} - -func (c *ClientPacketConn) Read(p []byte) (n int, err error) { - n, _, err = c.ReadFrom(p) - return -} - -func (c *ClientPacketConn) Write(p []byte) (n int, err error) { - return 0, os.ErrInvalid -} - -func (c *ClientPacketConn) FrontHeadroom() int { - if !c.headerWritten { - return KeyLength + 2*M.MaxSocksaddrLength + 9 - } - return M.MaxSocksaddrLength + 4 -} - -func (c *ClientPacketConn) Upstream() any { - return c.Conn -} - -func Key(password string) [KeyLength]byte { - var key [KeyLength]byte - hash := sha256.New224() - common.Must1(hash.Write([]byte(password))) - hex.Encode(key[:], hash.Sum(nil)) - return key -} - -func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error { - _, err := conn.Write(key[:]) - if err != nil { - return err - } - _, err = conn.Write(CRLF) - if err != nil { - return err - } - _, err = conn.Write([]byte{command}) - if err != nil { - return err - } - err = M.SocksaddrSerializer.WriteAddrPort(conn, destination) - if err != nil { - return err - } - _, err = conn.Write(CRLF) - if err != nil { - return err - } - if len(payload) > 0 { - _, err = conn.Write(payload) - if err != nil { - return err - } - } - return nil -} - -func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error { - headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5 - var header *buf.Buffer - defer header.Release() - var writeHeader bool - if len(payload) > 0 && headerLen+len(payload) < 65535 { - buffer := buf.StackNewSize(headerLen + len(payload)) - defer common.KeepAlive(buffer) - header = common.Dup(buffer) - } else { - buffer := buf.StackNewSize(headerLen) - defer common.KeepAlive(buffer) - header = common.Dup(buffer) - writeHeader = true - } - common.Must1(header.Write(key[:])) - common.Must1(header.Write(CRLF)) - common.Must(header.WriteByte(CommandTCP)) - common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) - common.Must1(header.Write(CRLF)) - if !writeHeader { - common.Must1(header.Write(payload)) - } - - _, err := conn.Write(header.Bytes()) - if err != nil { - return E.Cause(err, "write request") - } - - if writeHeader { - _, err = conn.Write(payload) - if err != nil { - return E.Cause(err, "write payload") - } - } - return nil -} - -func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { - header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5)) - common.Must1(header.Write(key[:])) - common.Must1(header.Write(CRLF)) - common.Must(header.WriteByte(CommandTCP)) - common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) - common.Must1(header.Write(CRLF)) - - _, err := conn.Write(payload.Bytes()) - if err != nil { - return E.Cause(err, "write request") - } - return nil -} - -func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { - headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9 - payloadLen := payload.Len() - var header *buf.Buffer - defer header.Release() - var writeHeader bool - if payload.Start() >= headerLen { - header = buf.With(payload.ExtendHeader(headerLen)) - } else { - buffer := buf.StackNewSize(headerLen) - defer common.KeepAlive(buffer) - header = common.Dup(buffer) - writeHeader = true - } - common.Must1(header.Write(key[:])) - common.Must1(header.Write(CRLF)) - common.Must(header.WriteByte(CommandUDP)) - common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) - common.Must1(header.Write(CRLF)) - common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) - common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen))) - common.Must1(header.Write(CRLF)) - - if writeHeader { - _, err := conn.Write(header.Bytes()) - if err != nil { - return E.Cause(err, "write request") - } - } - - _, err := conn.Write(payload.Bytes()) - if err != nil { - return E.Cause(err, "write payload") - } - return nil -} - -func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) { - destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) - if err != nil { - return M.Socksaddr{}, E.Cause(err, "read destination") - } - - var length uint16 - err = binary.Read(conn, binary.BigEndian, &length) - if err != nil { - return M.Socksaddr{}, E.Cause(err, "read chunk length") - } - - err = rw.SkipN(conn, 2) - if err != nil { - return M.Socksaddr{}, E.Cause(err, "skip crlf") - } - - _, err = buffer.ReadFullFrom(conn, int(length)) - return destination, err -} - -func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error { - defer buffer.Release() - bufferLen := buffer.Len() - header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4)) - common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) - common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen))) - common.Must1(header.Write(CRLF)) - _, err := conn.Write(buffer.Bytes()) - if err != nil { - return E.Cause(err, "write packet") - } - return nil -} diff --git a/protocol/trojan/service.go b/protocol/trojan/service.go deleted file mode 100644 index 8916487..0000000 --- a/protocol/trojan/service.go +++ /dev/null @@ -1,130 +0,0 @@ -package trojan - -import ( - "context" - "net" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/auth" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" -) - -type Handler interface { - N.TCPConnectionHandler - N.UDPConnectionHandler -} - -type Service[K comparable] struct { - users map[K][56]byte - keys map[[56]byte]K - handler Handler - fallbackHandler N.TCPConnectionHandler -} - -func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandler) *Service[K] { - return &Service[K]{ - users: make(map[K][56]byte), - keys: make(map[[56]byte]K), - handler: handler, - fallbackHandler: fallbackHandler, - } -} - -var ErrUserExists = E.New("user already exists") - -func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error { - users := make(map[K][56]byte) - keys := make(map[[56]byte]K) - for i, user := range userList { - if _, loaded := users[user]; loaded { - return ErrUserExists - } - key := Key(passwordList[i]) - if oldUser, loaded := keys[key]; loaded { - return E.Extend(ErrUserExists, "password used by ", oldUser) - } - users[user] = key - keys[key] = user - } - s.users = users - s.keys = keys - return nil -} - -func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - var key [KeyLength]byte - n, err := conn.Read(common.Dup(key[:])) - if err != nil { - return err - } else if n != KeyLength { - return s.fallback(ctx, conn, metadata, key[:n], E.New("bad request size")) - } - - if user, loaded := s.keys[key]; loaded { - ctx = auth.ContextWithUser(ctx, user) - } else { - return s.fallback(ctx, conn, metadata, key[:], E.New("bad request")) - } - - err = rw.SkipN(conn, 2) - if err != nil { - return E.Cause(err, "skip crlf") - } - - command, err := rw.ReadByte(conn) - if err != nil { - return E.Cause(err, "read command") - } - - if command != CommandTCP && command != CommandUDP { - return E.New("unknown command ", command) - } - - destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) - if err != nil { - return E.Cause(err, "read destination") - } - - err = rw.SkipN(conn, 2) - if err != nil { - return E.Cause(err, "skip crlf") - } - - metadata.Protocol = "trojan" - metadata.Destination = destination - - if command == CommandTCP { - return s.handler.NewConnection(ctx, conn, metadata) - } else { - return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata) - } -} - -func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Metadata, header []byte, err error) error { - if s.fallbackHandler == nil { - return E.Extend(err, "fallback disabled") - } - conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned()) - return s.fallbackHandler.NewConnection(ctx, conn, metadata) -} - -type PacketConn struct { - net.Conn -} - -func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { - return ReadPacket(c.Conn, buffer) -} - -func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - return WritePacket(c.Conn, buffer, destination) -} - -func (c *PacketConn) FrontHeadroom() int { - return M.MaxSocksaddrLength + 4 -}