diff --git a/cli/portal/portal-v2board/main.go b/cli/portal/portal-v2board/main.go index 38a0d22..9507d5e 100644 --- a/cli/portal/portal-v2board/main.go +++ b/cli/portal/portal-v2board/main.go @@ -4,8 +4,6 @@ import ( "context" "crypto/tls" "encoding/json" - "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" "io/ioutil" "net" "os" @@ -18,7 +16,9 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/random" + "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/trojan" transTLS "github.com/sagernet/sing/transport/tls" diff --git a/cli/ss-local/main.go b/cli/ss-local/main.go index 63e2db4..68ca16f 100644 --- a/cli/ss-local/main.go +++ b/cli/ss-local/main.go @@ -325,7 +325,7 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me } _payload := buf.StackNew() payload := common.Dup(_payload) - err = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + err = conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) if err != nil { return err } diff --git a/cli/trojan-local/main.go b/cli/trojan-local/main.go new file mode 100644 index 0000000..9f9fad3 --- /dev/null +++ b/cli/trojan-local/main.go @@ -0,0 +1,338 @@ +package main + +import ( + "context" + cTLS "crypto/tls" + "encoding/json" + "io/ioutil" + "net" + "net/netip" + "os" + "os/signal" + "syscall" + "time" + + "github.com/refraction-networking/utls" + "github.com/sagernet/sing" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/redir" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/trojan" + "github.com/sagernet/sing/transport/mixed" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +const udpTimeout = 5 * 60 + +type flags struct { + Server string `json:"server"` + ServerPort uint16 `json:"server_port"` + ServerName string `json:"server_name"` + Bind string `json:"local_address"` + LocalPort uint16 `json:"local_port"` + Password string `json:"password"` + Verbose bool `json:"verbose"` + Insecure bool `json:"insecure"` + ConfigFile string +} + +func main() { + f := new(flags) + + command := &cobra.Command{ + Use: "trojan-local", + Short: "trojan client", + Version: sing.VersionStr, + Run: func(cmd *cobra.Command, args []string) { + run(cmd, f) + }, + } + + command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.") + command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.") + command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.") + command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.") + command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.") + command.Flags().BoolVarP(&f.Insecure, "insecure", "i", false, "Set insecure.") + command.Flags().StringVarP(&f.ConfigFile, "config", "c", "", "Use a configuration file.") + command.Flags().BoolVarP(&f.Verbose, "verbose", "v", false, "Set verbose mode.") + + err := command.Execute() + if err != nil { + logrus.Fatal(err) + } +} + +func run(cmd *cobra.Command, f *flags) { + c, err := newClient(f) + if err != nil { + logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n") + cmd.Help() + os.Exit(1) + } + err = c.Start() + if err != nil { + logrus.Fatal(err) + } + + logrus.Info("mixed server started at ", c.TCPListener.Addr()) + + osSignals := make(chan os.Signal, 1) + signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) + <-osSignals + + c.Close() +} + +type client struct { + *mixed.Listener + server string + key [trojan.KeyLength]byte + sni string + insecure bool + dialer net.Dialer +} + +func newClient(f *flags) (*client, error) { + if f.ConfigFile != "" { + configFile, err := ioutil.ReadFile(f.ConfigFile) + if err != nil { + return nil, E.Cause(err, "read config file") + } + flagsNew := new(flags) + err = json.Unmarshal(configFile, flagsNew) + if err != nil { + return nil, E.Cause(err, "decode config file") + } + if flagsNew.Server != "" && f.Server == "" { + f.Server = flagsNew.Server + } + if flagsNew.ServerPort != 0 && f.ServerPort == 0 { + f.ServerPort = flagsNew.ServerPort + } + if flagsNew.Bind != "" && f.Bind == "" { + f.Bind = flagsNew.Bind + } + if flagsNew.LocalPort != 0 && f.LocalPort == 0 { + f.LocalPort = flagsNew.LocalPort + } + if flagsNew.Password != "" && f.Password == "" { + f.Password = flagsNew.Password + } + if flagsNew.Insecure { + f.Insecure = true + } + if flagsNew.Verbose { + f.Verbose = true + } + } + + if f.Verbose { + logrus.SetLevel(logrus.TraceLevel) + } + + if f.Server == "" { + return nil, E.New("missing server address") + } else if f.ServerPort == 0 { + return nil, E.New("missing server port") + } + + c := &client{ + server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort).String(), + key: trojan.Key(f.Password), + sni: f.ServerName, + insecure: f.Insecure, + } + if c.sni == "" { + c.sni = f.Server + } + + var bind netip.Addr + if f.Bind != "" { + addr, err := netip.ParseAddr(f.Bind) + if err != nil { + return nil, E.Cause(err, "bad local address") + } + bind = addr + } else { + bind = netip.IPv6Unspecified() + } + + c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, redir.ModeDisabled, udpTimeout, c) + return c, nil +} + +func (c *client) connect(ctx context.Context) (*cTLS.Conn, error) { + tcpConn, err := c.dialer.DialContext(ctx, "tcp", c.server) + if err != nil { + return nil, err + } + + tlsConn := cTLS.Client(tcpConn, &cTLS.Config{ + ServerName: c.sni, + InsecureSkipVerify: c.insecure, + }) + return tlsConn, nil +} + +func (c *client) connectUTLS(ctx context.Context) (*tls.UConn, error) { + tcpConn, err := c.dialer.DialContext(ctx, "tcp", c.server) + if err != nil { + return nil, err + } + + tlsConn := tls.UClient(tcpConn, &tls.Config{ + ServerName: c.sni, + InsecureSkipVerify: c.insecure, + }, tls.HelloCustom) + clientHelloSpec := tls.ClientHelloSpec{ + CipherSuites: []uint16{ + tls.GREASE_PLACEHOLDER, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.DISABLED_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.DISABLED_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + tls.DISABLED_TLS_RSA_WITH_AES_256_CBC_SHA256, + tls.TLS_RSA_WITH_AES_128_CBC_SHA256, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + 0xc008, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + }, + CompressionMethods: []byte{ + 0x00, // compressionNone + }, + Extensions: []tls.TLSExtension{ + &tls.UtlsGREASEExtension{}, + &tls.SNIExtension{}, + &tls.UtlsExtendedMasterSecretExtension{}, + &tls.RenegotiationInfoExtension{Renegotiation: tls.RenegotiateOnceAsClient}, + &tls.SupportedCurvesExtension{Curves: []tls.CurveID{ + tls.CurveID(tls.GREASE_PLACEHOLDER), + tls.X25519, + tls.CurveP256, + tls.CurveP384, + tls.CurveP521, + }}, + &tls.SupportedPointsExtension{SupportedPoints: []byte{ + 0x00, // pointFormatUncompressed + }}, + &tls.ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, + &tls.StatusRequestExtension{}, + &tls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []tls.SignatureScheme{ + tls.ECDSAWithP256AndSHA256, + tls.PSSWithSHA256, + tls.PKCS1WithSHA256, + tls.ECDSAWithP384AndSHA384, + tls.ECDSAWithSHA1, + tls.PSSWithSHA384, + tls.PSSWithSHA384, + tls.PKCS1WithSHA384, + tls.PSSWithSHA512, + tls.PKCS1WithSHA512, + tls.PKCS1WithSHA1, + }}, + &tls.SCTExtension{}, + &tls.KeyShareExtension{KeyShares: []tls.KeyShare{ + {Group: tls.CurveID(tls.GREASE_PLACEHOLDER), Data: []byte{0}}, + {Group: tls.X25519}, + }}, + &tls.PSKKeyExchangeModesExtension{Modes: []uint8{ + tls.PskModeDHE, + }}, + &tls.SupportedVersionsExtension{Versions: []uint16{ + tls.GREASE_PLACEHOLDER, + tls.VersionTLS13, + tls.VersionTLS12, + tls.VersionTLS11, + tls.VersionTLS10, + }}, + &tls.UtlsGREASEExtension{}, + &tls.UtlsPaddingExtension{GetPaddingLen: tls.BoringPaddingStyle}, + }, + } + err = tlsConn.ApplyPreset(&clientHelloSpec) + if err != nil { + tcpConn.Close() + return nil, err + } + return tlsConn, nil +} + +func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + logrus.Info("outbound ", metadata.Protocol, " TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination) + + tlsConn, err := c.connect(ctx) + if err != nil { + return err + } + clientConn := trojan.NewClientConn(tlsConn, c.key, metadata.Destination) + + err = conn.SetReadDeadline(time.Now().Add(300 * time.Millisecond)) + if err != nil { + return err + } + + _request := buf.StackNew() + request := common.Dup(_request) + _, err = request.ReadFrom(conn) + if err != nil && !E.IsTimeout(err) { + return E.Cause(err, "read payload") + } + + err = conn.SetReadDeadline(time.Time{}) + if err != nil { + return err + } + + _, err = clientConn.Write(request.Bytes()) + if err != nil { + return E.Cause(err, "client handshake") + } + + return rw.CopyConn(ctx, clientConn, conn) +} + +func (c *client) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error { + logrus.Info("outbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination) + + tlsConn, err := c.connect(ctx) + if err != nil { + return err + } + /*err = trojan.ClientHandshakeRaw(tlsConn, c.key, trojan.CommandUDP, metadata.Destination, nil) + if err != nil { + return err + } + return socks.CopyPacketConn(ctx, &trojan.PacketConn{Conn: tlsConn}, conn)*/ + clientConn := trojan.NewClientPacketConn(tlsConn, c.key) + return socks.CopyPacketConn(ctx, clientConn, conn) +} + +func (c *client) HandleError(err error) { + if E.IsClosed(err) { + return + } + logrus.Warn(err) +} diff --git a/cli/trojan-server/main.go b/cli/trojan-server/main.go new file mode 100644 index 0000000..e633f51 --- /dev/null +++ b/cli/trojan-server/main.go @@ -0,0 +1,192 @@ +package main + +import ( + "context" + "crypto/tls" + "encoding/json" + "io/ioutil" + "net" + "net/netip" + "os" + "os/signal" + "syscall" + + "github.com/sagernet/sing" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/random" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/trojan" + "github.com/sagernet/sing/transport/tcp" + transTLS "github.com/sagernet/sing/transport/tls" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +const udpTimeout = 5 * 60 + +type flags struct { + Server string `json:"server"` + ServerPort uint16 `json:"server_port"` + ServerName string `json:"server_name"` + Bind string `json:"local_address"` + LocalPort uint16 `json:"local_port"` + Password string `json:"password"` + Verbose bool `json:"verbose"` + Insecure bool `json:"insecure"` + ConfigFile string +} + +func main() { + f := new(flags) + + command := &cobra.Command{ + Use: "trojan-local", + Short: "trojan client", + Version: sing.VersionStr, + Run: func(cmd *cobra.Command, args []string) { + run(cmd, f) + }, + } + + command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.") + command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.") + command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.") + command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.") + command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.") + command.Flags().BoolVarP(&f.Insecure, "insecure", "i", false, "Set insecure.") + command.Flags().StringVarP(&f.ConfigFile, "config", "c", "", "Use a configuration file.") + command.Flags().BoolVarP(&f.Verbose, "verbose", "v", false, "Set verbose mode.") + + err := command.Execute() + if err != nil { + logrus.Fatal(err) + } +} + +func run(cmd *cobra.Command, f *flags) { + c, err := newServer(f) + if err != nil { + logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n") + cmd.Help() + os.Exit(1) + } + err = c.tcpIn.Start() + if err != nil { + logrus.Fatal(err) + } + + logrus.Info("server started at ", c.tcpIn.Addr()) + + osSignals := make(chan os.Signal, 1) + signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) + <-osSignals + + c.tcpIn.Close() +} + +type server struct { + tcpIn *tcp.Listener + service trojan.Service[int] +} + +func newServer(f *flags) (*server, error) { + s := new(server) + + if f.ConfigFile != "" { + configFile, err := ioutil.ReadFile(f.ConfigFile) + if err != nil { + return nil, E.Cause(err, "read config file") + } + flagsNew := new(flags) + err = json.Unmarshal(configFile, flagsNew) + if err != nil { + return nil, E.Cause(err, "decode config file") + } + if flagsNew.Server != "" && f.Server == "" { + f.Server = flagsNew.Server + } + if flagsNew.ServerPort != 0 && f.ServerPort == 0 { + f.ServerPort = flagsNew.ServerPort + } + if flagsNew.Bind != "" && f.Bind == "" { + f.Bind = flagsNew.Bind + } + if flagsNew.LocalPort != 0 && f.LocalPort == 0 { + f.LocalPort = flagsNew.LocalPort + } + if flagsNew.Password != "" && f.Password == "" { + f.Password = flagsNew.Password + } + if flagsNew.Insecure { + f.Insecure = true + } + if flagsNew.Verbose { + f.Verbose = true + } + } + + if f.Verbose { + logrus.SetLevel(logrus.TraceLevel) + } + + if f.Server == "" { + return nil, E.New("missing server address") + } else if f.ServerPort == 0 { + return nil, E.New("missing server port") + } + + var bind netip.Addr + if f.Server != "" { + addr, err := netip.ParseAddr(f.Server) + if err != nil { + return nil, E.Cause(err, "bad server address") + } + bind = addr + } else { + bind = netip.IPv6Unspecified() + } + s.service = trojan.NewService[int](s) + common.Must(s.service.AddUser(0, f.Password)) + s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.ServerPort), s) + return s, nil +} + +func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + if metadata.Protocol != "trojan" { + logrus.Trace("inbound raw TCP from ", metadata.Source) + tlsConn := tls.Server(conn, &tls.Config{ + Rand: random.Blake3KeyedHash(), + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + return transTLS.GenerateCertificate(info.ServerName) + }, + }) + return s.service.NewConnection(ctx, tlsConn, metadata) + } + destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination) + if err != nil { + return err + } + logrus.Info("inbound TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination) + return rw.CopyConn(ctx, conn, destConn) +} + +func (s *server) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error { + logrus.Info("inbound UDP ", metadata.Source, " ==> ", metadata.Destination) + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return err + } + return socks.CopyNetPacketConn(ctx, udpConn, conn) +} + +func (s *server) HandleError(err error) { + common.Close(err) + if E.IsClosed(err) { + return + } + logrus.Warn(err) +} diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 2553d60..b178b97 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -94,6 +94,9 @@ func (b *Buffer) Write(data []byte) (n int, err error) { if b.IsFull() { return 0, io.ErrShortBuffer } + if b.end+len(data) > b.Cap() { + panic("buffer overflow") + } n = copy(b.data[b.end:], data) b.end += n return @@ -102,7 +105,7 @@ func (b *Buffer) Write(data []byte) (n int, err error) { func (b *Buffer) ExtendHeader(size int) []byte { if b.start >= size { b.start -= size - return b.data[b.start-size : b.start] + return b.data[b.start : b.start+size] } else { /*offset := size - b.start end := b.end + size @@ -326,7 +329,7 @@ func (b Buffer) Len() int { } func (b Buffer) Cap() int { - return cap(b.data) + return len(b.data) } func (b Buffer) Bytes() []byte { diff --git a/common/metadata/serializer.go b/common/metadata/serializer.go index ef764e7..d0c8e25 100644 --- a/common/metadata/serializer.go +++ b/common/metadata/serializer.go @@ -62,7 +62,7 @@ func (s *Serializer) AddressLen(addr Addr) int { case AddressFamilyIPv6: return 17 default: - return 1 + len(addr.Fqdn()) + return 2 + len(addr.Fqdn()) } } diff --git a/common/rw/copy.go b/common/rw/copy.go index bcb2b48..bc0827c 100644 --- a/common/rw/copy.go +++ b/common/rw/copy.go @@ -44,17 +44,42 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { err := task.Run(ctx, func() error { defer CloseRead(conn) defer CloseWrite(dest) - return common.Error(io.Copy(dest, conn)) + return common.Error(Copy(dest, conn)) }, func() error { defer CloseRead(dest) defer CloseWrite(conn) - return common.Error(io.Copy(conn, dest)) + return common.Error(Copy(conn, dest)) }) conn.Close() dest.Close() return err } +func Copy(dst io.Writer, src io.Reader) (n int64, err error) { + if wt, ok := src.(io.WriterTo); ok { + return wt.WriteTo(dst) + } + if rt, ok := dst.(io.ReaderFrom); ok { + return rt.ReadFrom(src) + } + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) + for { + buffer.FullReset() + _, err = buffer.ReadFrom(src) + if err != nil { + break + } + var cn int + cn, err = dst.Write(buffer.Bytes()) + if err != nil { + return + } + n += int64(cn) + } + return +} + func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error { return task.Run(ctx, func() error { _buffer := buf.With(make([]byte, buf.UDPBufferSize)) diff --git a/go.mod b/go.mod index 1aec491..f6993e8 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/oschwald/maxminddb-golang v1.9.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/refraction-networking/utls v1.1.0 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect diff --git a/go.sum b/go.sum index 61aebe6..c07378a 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/refraction-networking/utls v1.1.0 h1:dKXJwSqni/t5csYJ+aQcEgqB7AMWYi6EUc9u3bEmjX0= +github.com/refraction-networking/utls v1.1.0/go.mod h1:tz9gX959MEFfFN5whTIocCLUG57WiILqtdVxI8c6Wj0= github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg= github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3/go.mod h1:HgjTstvQsPGkxUsCd2KWxErBblirPizecHcpD3ffK+s= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= diff --git a/protocol/trojan/protocol.go b/protocol/trojan/protocol.go index 79c4279..1d564ca 100644 --- a/protocol/trojan/protocol.go +++ b/protocol/trojan/protocol.go @@ -113,16 +113,45 @@ func Key(password string) [KeyLength]byte { return key } -func ClientHandshake(conn net.Conn, key [56]byte, destination *M.AddrPort, payload []byte) error { - bufferLen := KeyLength + socks.AddressSerializer.AddrPortLen(destination) + 5 +func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination *M.AddrPort, payload []byte) error { + _, err := conn.Write(key[:]) + if err != nil { + return err + } + _, err = conn.Write(CRLF[:]) + if err != nil { + return err + } + _, err = conn.Write([]byte{command}) + if err != nil { + return err + } + err = socks.AddressSerializer.WriteAddrPort(conn, destination) + if err != nil { + return err + } + _, err = conn.Write(CRLF[:]) + if err != nil { + return err + } + if len(payload) > 0 { + _, err = conn.Write(payload) + if err != nil { + return err + } + } + return nil +} + +func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort, payload []byte) error { + headerLen := KeyLength + socks.AddressSerializer.AddrPortLen(destination) + 5 var header *buf.Buffer var writeHeader bool - if len(payload) > 0 && bufferLen+len(payload) < 65535 { - buffer := buf.Make(bufferLen + len(payload)) - copy(buffer[bufferLen:], payload) + if len(payload) > 0 && headerLen+len(payload) < 65535 { + buffer := buf.Make(headerLen + len(payload)) header = buf.With(common.Dup(buffer)) } else { - buffer := buf.Make(bufferLen) + buffer := buf.Make(headerLen) header = buf.With(common.Dup(buffer)) writeHeader = true } @@ -131,11 +160,13 @@ func ClientHandshake(conn net.Conn, key [56]byte, destination *M.AddrPort, paylo common.Must(header.WriteByte(CommandTCP)) common.Must(socks.AddressSerializer.WriteAddrPort(header, destination)) common.Must1(header.Write(CRLF)) + common.Must1(header.Write(payload)) _, err := conn.Write(header.Bytes()) if err != nil { return E.Cause(err, "write request") } + if writeHeader { _, err = conn.Write(payload) if err != nil { @@ -145,14 +176,15 @@ func ClientHandshake(conn net.Conn, key [56]byte, destination *M.AddrPort, paylo return nil } -func ClientHandshakePacket(conn net.Conn, key [56]byte, destination *M.AddrPort, payload *buf.Buffer) error { - bufferLen := KeyLength + 2*socks.AddressSerializer.AddrPortLen(destination) + 9 +func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort, payload *buf.Buffer) error { + headerLen := KeyLength + 2*socks.AddressSerializer.AddrPortLen(destination) + 9 + payloadLen := payload.Len() var header *buf.Buffer var writeHeader bool - if payload.Start() >= bufferLen { - header = buf.With(payload.ExtendHeader(bufferLen)) + if payload.Start() >= headerLen { + header = buf.With(payload.ExtendHeader(headerLen)) } else { - buffer := buf.Make(bufferLen) + buffer := buf.Make(headerLen) header = buf.With(common.Dup(buffer)) writeHeader = true } @@ -162,19 +194,20 @@ func ClientHandshakePacket(conn net.Conn, key [56]byte, destination *M.AddrPort, common.Must(socks.AddressSerializer.WriteAddrPort(header, destination)) common.Must1(header.Write(CRLF)) common.Must(socks.AddressSerializer.WriteAddrPort(header, destination)) - common.Must(binary.Write(header, binary.BigEndian, uint16(payload.Len()))) + common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen))) common.Must1(header.Write(CRLF)) - _, err := conn.Write(header.Bytes()) - if err != nil { - return E.Cause(err, "write request") - } if writeHeader { - _, err = conn.Write(payload.Bytes()) + _, err := conn.Write(header.Bytes()) if err != nil { - return E.Cause(err, "write payload") + return E.Cause(err, "write request") } } + + _, err := conn.Write(payload.Bytes()) + if err != nil { + return E.Cause(err, "write payload") + } return nil } @@ -207,6 +240,7 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination *M.AddrPort) err headerOverload := socks.AddressSerializer.AddrPortLen(destination) + 4 var header *buf.Buffer var writeHeader bool + bufferLen := buffer.Len() if buffer.Start() >= headerOverload { header = buf.With(buffer.ExtendHeader(headerOverload)) } else { @@ -214,25 +248,16 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination *M.AddrPort) err _buffer := buf.Make(headerOverload) header = buf.With(common.Dup(_buffer)) } - err := socks.AddressSerializer.WriteAddrPort(header, destination) - if err != nil { - return E.Cause(err, "write socks addr") - } - err = binary.Write(header, binary.BigEndian, uint16(buffer.Len())) - if err != nil { - return E.Cause(err, "write chunk length") - } - _, err = header.Write(CRLF) - if err != nil { - return E.Cause(err, "write crlf") - } + common.Must(socks.AddressSerializer.WriteAddrPort(header, destination)) + common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen))) + common.Must1(header.Write(CRLF)) if writeHeader { - _, err = conn.Write(header.Bytes()) + _, err := conn.Write(header.Bytes()) if err != nil { return E.Cause(err, "write packet header") } } - _, err = conn.Write(buffer.Bytes()) + _, err := conn.Write(buffer.Bytes()) if err != nil { return E.Cause(err, "write packet") } diff --git a/protocol/trojan/service.go b/protocol/trojan/service.go index fe64aad..3a47730 100644 --- a/protocol/trojan/service.go +++ b/protocol/trojan/service.go @@ -133,20 +133,20 @@ process: if command == CommandTCP { return s.handler.NewConnection(&userCtx, conn, metadata) } else { - return s.handler.NewPacketConnection(&userCtx, &packetConn{conn}, metadata) + return s.handler.NewPacketConnection(&userCtx, &PacketConn{conn}, metadata) } } -type packetConn struct { +type PacketConn struct { net.Conn } -func (c *packetConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { - return ReadPacket(c, buffer) +func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { + return ReadPacket(c.Conn, buffer) } -func (c *packetConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { - return WritePacket(c, buffer, destination) +func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + return WritePacket(c.Conn, buffer, destination) } type Error struct { diff --git a/transport/tls/cert.go b/transport/tls/cert.go index c018c84..0f08ba6 100644 --- a/transport/tls/cert.go +++ b/transport/tls/cert.go @@ -43,8 +43,8 @@ func GenerateCertificate(hosts ...string) (*tls.Certificate, error) { Subject: pkix.Name{ Organization: []string{"Cloudflare, Inc."}, }, - NotBefore: endAt, - NotAfter: createAt, + NotBefore: createAt, + NotAfter: endAt, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},