From 00cd0d4b8f41657587d2489753ca911b53b407f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 10 Apr 2022 22:51:29 +0800 Subject: [PATCH] Refactor shadowsocks --- .gitignore | 3 +- README.md | 11 +- cli/{geosite_gen => gen-geosite}/main.go | 4 +- cli/{get_geoip => get-geoip}/main.go | 3 +- cli/ss-local/main.go | 374 ++++++++++++++++++ cli/sslocal/debug.go | 12 - cli/sslocal/sslocal.go | 324 --------------- cli/uot-local/main.go | 140 +++++++ common/auth/auth.go | 46 +++ common/buf/buffer.go | 13 + common/buf/io.go | 9 + common/cache/cache.go | 106 +++++ common/cache/cache_test.go | 70 ++++ common/cache/lrucache.go | 223 +++++++++++ common/cache/lrucache_test.go | 183 +++++++++ common/cond.go | 10 + common/exceptions/error.go | 4 + common/flush.go | 2 +- common/geoip/matcher.go | 3 +- common/geosite/matcher.go | 13 +- common/gsync/map.go | 13 +- common/{socksaddr => metadata}/addr.go | 84 +++- common/{socksaddr => metadata}/exception.go | 2 +- common/{socksaddr => metadata}/family.go | 2 +- common/metadata/metadata.go | 20 + common/{socksaddr => metadata}/serializer.go | 43 +- common/net.go | 15 + common/random/rng.go | 18 + common/redir/mode.go | 9 + common/redir/redir_linux.go | 22 +- common/redir/redir_other.go | 8 +- common/redir/tproxy_linux.go | 131 ++++++ common/redir/tproxy_other.go | 20 + common/replay/bloomring.go | 30 ++ common/replay/cuckoo.go | 50 +++ common/replay/filter.go | 5 + common/rw/copy.go | 12 + common/rw/varinat.go | 3 +- common/session/context.go | 10 +- common/udpnat/server.go | 108 +++++ common/uot/client.go | 80 ++++ common/uot/resolver.go | 21 + common/uot/server.go | 108 +++++ common/uot/uot.go | 13 + common/uot/uot_test.go | 39 ++ core.go | 5 +- go.mod | 16 +- go.sum | 30 +- .../mixed.go => protocol/http/listener.go | 147 ++----- {transport/system => protocol/http}/stub.s | 0 protocol/shadowsocks/cipher.go | 47 --- protocol/shadowsocks/cipher_none.go | 39 -- protocol/shadowsocks/client.go | 178 --------- protocol/shadowsocks/{config.go => method.go} | 0 protocol/shadowsocks/none.go | 153 +++++++ protocol/shadowsocks/protocol.go | 52 +-- .../{cipher_aead.go => shadowaead/aead.go} | 144 +------ protocol/shadowsocks/shadowaead/method.go | 332 ++++++++++++++++ protocol/socks/conn.go | 48 ++- protocol/socks/constant.go | 10 +- protocol/socks/handshake.go | 24 +- protocol/socks/listener.go | 148 +++++++ protocol/socks/protocol.go | 68 ++-- protocol/socks/protocol_test.go | 9 +- transport/mixed/listener.go | 86 ++++ transport/system/http.go | 10 - transport/system/sockopt_linux.go | 21 + transport/system/sockopt_other.go | 4 + transport/system/socks.go | 149 ------- transport/system/tcp.go | 57 --- transport/system/udp.go | 62 --- transport/tcp/handler.go | 91 +++++ transport/tcp/options.go | 11 + transport/udp/options.go | 11 + transport/udp/udp.go | 116 ++++++ 75 files changed, 3169 insertions(+), 1318 deletions(-) rename cli/{geosite_gen => gen-geosite}/main.go (100%) rename cli/{get_geoip => get-geoip}/main.go (99%) create mode 100644 cli/ss-local/main.go delete mode 100644 cli/sslocal/debug.go delete mode 100644 cli/sslocal/sslocal.go create mode 100644 cli/uot-local/main.go create mode 100644 common/auth/auth.go create mode 100644 common/cache/cache.go create mode 100644 common/cache/cache_test.go create mode 100644 common/cache/lrucache.go create mode 100644 common/cache/lrucache_test.go rename common/{socksaddr => metadata}/addr.go (57%) rename common/{socksaddr => metadata}/exception.go (91%) rename common/{socksaddr => metadata}/family.go (95%) create mode 100644 common/metadata/metadata.go rename common/{socksaddr => metadata}/serializer.go (76%) create mode 100644 common/net.go create mode 100644 common/random/rng.go create mode 100644 common/redir/mode.go create mode 100644 common/redir/tproxy_linux.go create mode 100644 common/redir/tproxy_other.go create mode 100644 common/replay/bloomring.go create mode 100644 common/replay/cuckoo.go create mode 100644 common/replay/filter.go create mode 100644 common/udpnat/server.go create mode 100644 common/uot/client.go create mode 100644 common/uot/resolver.go create mode 100644 common/uot/server.go create mode 100644 common/uot/uot.go create mode 100644 common/uot/uot_test.go rename transport/system/mixed.go => protocol/http/listener.go (50%) rename {transport/system => protocol/http}/stub.s (100%) delete mode 100644 protocol/shadowsocks/cipher.go delete mode 100644 protocol/shadowsocks/cipher_none.go delete mode 100644 protocol/shadowsocks/client.go rename protocol/shadowsocks/{config.go => method.go} (100%) create mode 100644 protocol/shadowsocks/none.go rename protocol/shadowsocks/{cipher_aead.go => shadowaead/aead.go} (56%) create mode 100644 protocol/shadowsocks/shadowaead/method.go create mode 100644 protocol/socks/listener.go create mode 100644 transport/mixed/listener.go delete mode 100644 transport/system/http.go delete mode 100644 transport/system/socks.go delete mode 100644 transport/system/tcp.go delete mode 100644 transport/system/udp.go create mode 100644 transport/tcp/handler.go create mode 100644 transport/tcp/options.go create mode 100644 transport/udp/options.go create mode 100644 transport/udp/udp.go diff --git a/.gitignore b/.gitignore index 9cf5737..968e0d9 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ /sing_* /*.json /Country.mmdb -/geosite.dat \ No newline at end of file +/geosite.dat +/vendor/ \ No newline at end of file diff --git a/README.md b/README.md index 95c7df5..a6be57f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,12 @@ # sing -Do you hear the people sing? \ No newline at end of file +Do you hear the people sing? + +```shell +# geo resources +go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/get-geoip +go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/gen-geosite + +# ss-local +go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-local +``` \ No newline at end of file diff --git a/cli/geosite_gen/main.go b/cli/gen-geosite/main.go similarity index 100% rename from cli/geosite_gen/main.go rename to cli/gen-geosite/main.go index 8da6a12..b35d236 100644 --- a/cli/geosite_gen/main.go +++ b/cli/gen-geosite/main.go @@ -2,12 +2,12 @@ package main import ( "encoding/binary" - "github.com/klauspost/compress/zstd" - "github.com/sagernet/sing/common/rw" "io" "net/http" "os" + "github.com/klauspost/compress/zstd" + "github.com/sagernet/sing/common/rw" "github.com/sirupsen/logrus" "github.com/ulikunitz/xz" "github.com/v2fly/v2ray-core/v5/app/router/routercommon" diff --git a/cli/get_geoip/main.go b/cli/get-geoip/main.go similarity index 99% rename from cli/get_geoip/main.go rename to cli/get-geoip/main.go index 56618e7..7897d41 100644 --- a/cli/get_geoip/main.go +++ b/cli/get-geoip/main.go @@ -1,10 +1,11 @@ package main import ( - "github.com/sirupsen/logrus" "io" "net/http" "os" + + "github.com/sirupsen/logrus" ) func main() { diff --git a/cli/ss-local/main.go b/cli/ss-local/main.go new file mode 100644 index 0000000..cf2464e --- /dev/null +++ b/cli/ss-local/main.go @@ -0,0 +1,374 @@ +package main + +import ( + "context" + "encoding/base64" + "encoding/json" + "io" + "io/ioutil" + "net" + "net/netip" + "os" + "os/signal" + "runtime/debug" + "syscall" + "time" + + "github.com/sagernet/sing" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/geoip" + "github.com/sagernet/sing/common/geosite" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/random" + "github.com/sagernet/sing/common/redir" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/task" + "github.com/sagernet/sing/protocol/shadowsocks" + "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" + "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/transport/mixed" + "github.com/sagernet/sing/transport/system" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +type flags struct { + Server string `json:"server"` + ServerPort uint16 `json:"server_port"` + LocalPort uint16 `json:"local_port"` + Password string `json:"password"` + Key string `json:"key"` + Method string `json:"method"` + TCPFastOpen bool `json:"fast_open"` + Verbose bool `json:"verbose"` + Transproxy string `json:"transproxy"` + FWMark int `json:"fwmark"` + Bypass string `json:"bypass"` + UseSystemRNG bool `json:"use_system_rng"` + ReducedSaltEntropy bool `json:"reduced_salt_entropy"` + ConfigFile string +} + +func main() { + f := new(flags) + + command := &cobra.Command{ + Use: "ss-local", + Short: "shadowsocks client", + Version: sing.VersionStr, + Run: func(cmd *cobra.Command, args []string) { + Run(cmd, f) + }, + } + + command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.") + command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.") + command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.") + command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.") + command.Flags().StringVar(&f.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.") + command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", `Set the cipher. + +Supported ciphers: + +none +aes-128-gcm +aes-192-gcm +aes-256-gcm +chacha20-ietf-poly1305 +xchacha20-ietf-poly1305`) + command.Flags().BoolVar(&f.TCPFastOpen, "fast-open", false, `Enable TCP fast open. +Only available with Linux kernel > 3.7.0.`) + command.Flags().StringVarP(&f.Transproxy, "transproxy", "t", "", "Enable transparent proxy support. [possible values: redirect, tproxy]") + command.Flags().IntVar(&f.FWMark, "fwmark", 0, "Set outbound socket mark.") + command.Flags().StringVar(&f.Bypass, "bypass", "", "Set bypass country.") + command.Flags().StringVarP(&f.ConfigFile, "config", "c", "", "Use a configuration file.") + command.Flags().BoolVarP(&f.Verbose, "verbose", "v", false, "Enable verbose mode.") + command.Flags().BoolVar(&f.UseSystemRNG, "use-system-rng", false, "Use system random number generator.") + command.Flags().BoolVar(&f.ReducedSaltEntropy, "reduced-salt-entropy", false, "Remapping salt to printable chars.") + + err := command.Execute() + if err != nil { + logrus.Fatal(err) + } +} + +type LocalClient struct { + *mixed.Listener + *geosite.Matcher + server *M.AddrPort + method shadowsocks.Method + session shadowsocks.Session + dialer net.Dialer + bypass string +} + +func NewLocalClient(f *flags) (*LocalClient, error) { + if f.ConfigFile != "" { + configFile, err := ioutil.ReadFile(f.ConfigFile) + if err != nil { + return nil, E.Cause(err, "read config file") + } + flagsNew := new(flags) + err = json.Unmarshal(configFile, flagsNew) + if err != nil { + return nil, E.Cause(err, "decode config file") + } + if flagsNew.Server != "" && f.Server == "" { + f.Server = flagsNew.Server + } + if flagsNew.ServerPort != 0 && f.ServerPort == 0 { + f.ServerPort = flagsNew.ServerPort + } + if flagsNew.LocalPort != 0 && f.LocalPort == 0 { + f.LocalPort = flagsNew.LocalPort + } + if flagsNew.Password != "" && f.Password == "" { + f.Password = flagsNew.Password + } + if flagsNew.Key != "" && f.Key == "" { + f.Key = flagsNew.Key + } + if flagsNew.Method != "" && f.Method == "" { + f.Method = flagsNew.Method + } + if flagsNew.Transproxy != "" && f.Transproxy == "" { + f.Transproxy = flagsNew.Transproxy + } + if flagsNew.TCPFastOpen { + f.TCPFastOpen = true + } + if flagsNew.Verbose { + f.Verbose = true + } + } + + if f.Verbose { + logrus.SetLevel(logrus.TraceLevel) + } + + if f.Server == "" { + return nil, E.New("missing server address") + } else if f.ServerPort == 0 { + return nil, E.New("missing server port") + } else if f.Method == "" { + return nil, E.New("missing method") + } + + client := &LocalClient{ + server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort), + bypass: f.Bypass, + } + + if f.Method == shadowsocks.MethodNone { + client.method = shadowsocks.NewNone() + } else if common.Contains(shadowaead.List, f.Method) { + var key []byte + var rng io.Reader + if f.UseSystemRNG { + rng = random.System + } else { + rng = random.Blake3KeyedHash() + } + if f.ReducedSaltEntropy { + rng = &shadowsocks.ReducedEntropyReader{Reader: rng} + } + client.method = shadowaead.New(f.Method, rng) + keyLength := client.method.KeyLength() + + if f.Key != "" { + decoded, err := base64.URLEncoding.DecodeString(f.Key) + if err != nil { + return nil, E.Cause(err, "decode key") + } + if len(decoded) != keyLength { + return nil, E.Cause(err, "bad key") + } + key = decoded + } else if f.Password != "" { + key = shadowsocks.Key([]byte(f.Password), keyLength) + } else { + return nil, E.New("missing password") + } + client.session = shadowaead.NewSession(key, false) + } + + client.dialer.Control = func(network, address string, c syscall.RawConn) error { + var rawFd uintptr + err := c.Control(func(fd uintptr) { + rawFd = fd + }) + if err != nil { + return err + } + if f.FWMark > 0 { + err = syscall.SetsockoptInt(int(rawFd), syscall.SOL_SOCKET, syscall.SO_MARK, f.FWMark) + if err != nil { + return err + } + } + if f.TCPFastOpen { + err = system.TCPFastOpen(rawFd) + if err != nil { + return err + } + } + return nil + } + + var transproxyMode redir.TransproxyMode + switch f.Transproxy { + case "redirect": + transproxyMode = redir.ModeRedirect + case "tproxy": + transproxyMode = redir.ModeTProxy + case "": + transproxyMode = redir.ModeDisabled + default: + return nil, E.New("unknown transproxy mode ", f.Transproxy) + } + + client.Listener = mixed.NewListener(netip.AddrPortFrom(netip.IPv6Unspecified(), f.LocalPort), nil, transproxyMode, client) + + if f.Bypass != "" { + err := geoip.LoadMMDB("Country.mmdb") + if err != nil { + return nil, E.Cause(err, "load Country.mmdb") + } + + geodata, err := os.Open("geosite.dat") + if err != nil { + return nil, E.Cause(err, "geosite.dat not found") + } + + geositeMatcher, err := geosite.LoadGeositeMatcher(geodata, f.Bypass) + if err != nil { + return nil, err + } + client.Matcher = geositeMatcher + debug.FreeOSMemory() + } + + return client, nil +} + +func bypass(conn net.Conn, destination *M.AddrPort) error { + logrus.Info("BYPASS ", conn.RemoteAddr(), " ==> ", destination) + serverConn, err := net.Dial("tcp", destination.String()) + if err != nil { + return err + } + return task.Run(context.Background(), func() error { + defer rw.CloseRead(conn) + defer rw.CloseWrite(serverConn) + return common.Error(io.Copy(serverConn, conn)) + }, func() error { + defer rw.CloseRead(serverConn) + defer rw.CloseWrite(conn) + return common.Error(io.Copy(conn, serverConn)) + }) +} + +func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error { + if c.bypass != "" { + if metadata.Destination.Addr.Family().IsFqdn() { + if c.Match(metadata.Destination.Addr.Fqdn()) { + return bypass(conn, metadata.Destination) + } + } else { + if geoip.Match(c.bypass, metadata.Destination.Addr.Addr().AsSlice()) { + return bypass(conn, metadata.Destination) + } + } + } + + logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination) + ctx := context.Background() + + var serverConn net.Conn + payload := buf.New() + err := task.Run(ctx, func() error { + sc, err := c.dialer.DialContext(ctx, "tcp", c.server.String()) + serverConn = sc + if err != nil { + return E.Cause(err, "connect to server") + } + return nil + }, func() error { + err := conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + if err != nil { + return err + } + _, err = payload.ReadFrom(conn) + if err != nil && !E.IsTimeout(err) { + return E.Cause(err, "read payload") + } + err = conn.SetReadDeadline(time.Time{}) + return err + }) + if err != nil { + payload.Release() + return err + } + serverConn = c.method.DialEarlyConn(c.session, serverConn, metadata.Destination) + _, err = serverConn.Write(payload.Bytes()) + payload.Release() + if err != nil { + return E.Cause(err, "client handshake") + } + + return rw.CopyConn(ctx, serverConn, conn) +} + +func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error { + ctx := context.Background() + udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String()) + if err != nil { + return err + } + serverConn := c.method.DialPacketConn(c.session, udpConn) + return task.Run(ctx, func() error { + var init bool + return socks.CopyPacketConn(serverConn, conn, func(destination *M.AddrPort, n int) { + if !init { + init = true + logrus.Info("UDP ", conn.LocalAddr(), " ==> ", destination) + } else { + logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination) + } + }) + }, func() error { + return socks.CopyPacketConn(conn, serverConn, func(destination *M.AddrPort, n int) { + logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination) + }) + }) +} + +func Run(cmd *cobra.Command, flags *flags) { + client, err := NewLocalClient(flags) + if err != nil { + logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n") + cmd.Help() + os.Exit(1) + } + err = client.Listener.Start() + if err != nil { + logrus.Fatal(err) + } + + logrus.Info("mixed server started at ", client.Listener.TCPListener.Addr()) + + osSignals := make(chan os.Signal, 1) + signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) + <-osSignals + + client.Listener.Close() +} + +func (c *LocalClient) HandleError(err error) { + if E.IsClosed(err) { + return + } + logrus.Warn(err) +} diff --git a/cli/sslocal/debug.go b/cli/sslocal/debug.go deleted file mode 100644 index 4510862..0000000 --- a/cli/sslocal/debug.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build debug - -package main - -import ( - "net/http" - _ "net/http/pprof" -) - -func init() { - go http.ListenAndServe("127.0.0.1:8964", nil) -} diff --git a/cli/sslocal/sslocal.go b/cli/sslocal/sslocal.go deleted file mode 100644 index 5277659..0000000 --- a/cli/sslocal/sslocal.go +++ /dev/null @@ -1,324 +0,0 @@ -package main - -import ( - "context" - "encoding/base64" - "encoding/json" - "github.com/sagernet/sing/common/geoip" - "github.com/sagernet/sing/common/geosite" - "io" - "io/ioutil" - "net" - "net/netip" - "os" - "os/signal" - "runtime/debug" - "strconv" - "syscall" - - "github.com/sagernet/sing" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/common/socksaddr" - "github.com/sagernet/sing/common/task" - "github.com/sagernet/sing/protocol/shadowsocks" - "github.com/sagernet/sing/protocol/socks" - "github.com/sagernet/sing/transport/system" - "github.com/sirupsen/logrus" - "github.com/spf13/cobra" -) - -func main() { - err := MainCmd().Execute() - if err != nil { - logrus.Fatal(err) - } -} - -type Flags struct { - Server string `json:"server"` - ServerPort uint16 `json:"server_port"` - LocalPort uint16 `json:"local_port"` - Password string `json:"password"` - Key string `json:"key"` - Method string `json:"method"` - TCPFastOpen bool `json:"fast_open"` - Verbose bool `json:"verbose"` - Redirect string `json:"redir"` - FWMark int `json:"fwmark"` - Bypass string `json:"bypass"` - ConfigFile string -} - -func MainCmd() *cobra.Command { - flags := new(Flags) - - cmd := &cobra.Command{ - Use: "sslocal", - Short: "shadowsocks client as socks5 proxy, sing port", - Version: sing.Version, - Run: func(cmd *cobra.Command, args []string) { - Run(cmd, flags) - }, - } - - cmd.Flags().StringVarP(&flags.Server, "server", "s", "", "Set the server’s hostname or IP.") - cmd.Flags().Uint16VarP(&flags.ServerPort, "server-port", "p", 0, "Set the server’s port number.") - cmd.Flags().Uint16VarP(&flags.LocalPort, "local-port", "l", 0, "Set the local port number.") - cmd.Flags().StringVarP(&flags.Password, "password", "k", "", "Set the password. The server and the client should use the same password.") - cmd.Flags().StringVar(&flags.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.") - cmd.Flags().StringVarP(&flags.Method, "encrypt-method", "m", "", `Set the cipher. - -Supported ciphers: - -none -aes-128-gcm -aes-192-gcm -aes-256-gcm -chacha20-ietf-poly1305 -xchacha20-ietf-poly1305 - -The default cipher is chacha20-ietf-poly1305.`) - cmd.Flags().BoolVar(&flags.TCPFastOpen, "fast-open", false, `Enable TCP fast open. -Only available with Linux kernel > 3.7.0.`) - cmd.Flags().StringVar(&flags.Redirect, "redir", "", "Enable transparent proxy support. [possible values: redirect, tproxy]") - cmd.Flags().IntVar(&flags.FWMark, "fwmark", 0, "Set outbound socket mark.") - cmd.Flags().StringVar(&flags.Bypass, "bypass", "", "Set bypass country.") - cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.") - cmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode.") - - return cmd -} - -type LocalClient struct { - *system.MixedListener - *shadowsocks.Client - *geosite.Matcher - redirect bool - bypass string -} - -func NewLocalClient(flags *Flags) (*LocalClient, error) { - if flags.ConfigFile != "" { - configFile, err := ioutil.ReadFile(flags.ConfigFile) - if err != nil { - return nil, exceptions.Cause(err, "read config file") - } - flagsNew := new(Flags) - err = json.Unmarshal(configFile, flagsNew) - if err != nil { - return nil, exceptions.Cause(err, "decode config file") - } - if flagsNew.Server != "" && flags.Server == "" { - flags.Server = flagsNew.Server - } - if flagsNew.ServerPort != 0 && flags.ServerPort == 0 { - flags.ServerPort = flagsNew.ServerPort - } - if flagsNew.LocalPort != 0 && flags.LocalPort == 0 { - flags.LocalPort = flagsNew.LocalPort - } - if flagsNew.Password != "" && flags.Password == "" { - flags.Password = flagsNew.Password - } - if flagsNew.Key != "" && flags.Key == "" { - flags.Key = flagsNew.Key - } - if flagsNew.Method != "" && flags.Method == "" { - flags.Method = flagsNew.Method - } - if flagsNew.Redirect != "" && flags.Redirect == "" { - flags.Redirect = flagsNew.Redirect - } - if flagsNew.TCPFastOpen { - flags.TCPFastOpen = true - } - if flagsNew.Verbose { - flags.Verbose = true - } - - } - - clientConfig := &shadowsocks.ClientConfig{ - Server: flags.Server, - ServerPort: flags.ServerPort, - Method: flags.Method, - } - - if flags.Key != "" { - key, err := base64.URLEncoding.DecodeString(flags.Key) - if err != nil { - return nil, exceptions.Cause(err, "decode key") - } - clientConfig.Key = key - } else if flags.Password != "" { - clientConfig.Password = []byte(flags.Password) - } - - if flags.Verbose { - logrus.SetLevel(logrus.TraceLevel) - } - - dialer := new(net.Dialer) - - dialer.Control = func(network, address string, c syscall.RawConn) error { - var rawFd uintptr - err := c.Control(func(fd uintptr) { - rawFd = fd - }) - if err != nil { - return err - } - if flags.FWMark > 0 { - err = syscall.SetsockoptInt(int(rawFd), syscall.SOL_SOCKET, syscall.SO_MARK, flags.FWMark) - if err != nil { - return err - } - } - if flags.TCPFastOpen { - err = system.TCPFastOpen(rawFd) - if err != nil { - return err - } - } - return nil - } - - shadowClient, err := shadowsocks.NewClient(dialer, clientConfig) - if err != nil { - return nil, err - } - - client := &LocalClient{ - Client: shadowClient, - } - client.MixedListener = system.NewMixedListener(netip.AddrPortFrom(netip.IPv6Unspecified(), flags.LocalPort), &system.MixedConfig{ - Redirect: flags.Redirect == "redirect", - TProxy: flags.Redirect == "tproxy", - }, client) - - if flags.Bypass != "" { - client.bypass = flags.Bypass - - err = geoip.LoadMMDB("Country.mmdb") - if err != nil { - return nil, exceptions.Cause(err, "load Country.mmdb") - } - - geodata, err := os.Open("geosite.dat") - if err != nil { - return nil, exceptions.Cause(err, "geosite.dat not found") - } - - geositeMatcher, err := geosite.LoadGeositeMatcher(geodata, flags.Bypass) - if err != nil { - return nil, err - } - client.Matcher = geositeMatcher - debug.FreeOSMemory() - } - - return client, nil -} - -func (c *LocalClient) Start() error { - err := c.MixedListener.Start() - if err != nil { - return err - } - logrus.Info("mixed server started at ", c.MixedListener.TCPListener.Addr()) - return nil -} - -func bypass(addr socksaddr.Addr, port uint16, conn net.Conn) error { - logrus.Info("BYPASS ", conn.RemoteAddr(), " ==> ", net.JoinHostPort(addr.String(), strconv.Itoa(int(port)))) - serverConn, err := net.Dial("tcp", socksaddr.JoinHostPort(addr, port)) - if err != nil { - return err - } - return task.Run(context.Background(), func() error { - defer rw.CloseRead(conn) - defer rw.CloseWrite(serverConn) - return common.Error(io.Copy(serverConn, conn)) - }, func() error { - defer rw.CloseRead(serverConn) - defer rw.CloseWrite(conn) - return common.Error(io.Copy(conn, serverConn)) - }) -} - -func (c *LocalClient) NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error { - if c.bypass != "" { - if addr.Family().IsFqdn() { - if c.Match(addr.Fqdn()) { - return bypass(addr, port, conn) - } - } else { - if geoip.Match(c.bypass, addr.Addr().AsSlice()) { - return bypass(addr, port, conn) - } - } - } - - logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", net.JoinHostPort(addr.String(), strconv.Itoa(int(port)))) - - ctx := context.Background() - serverConn, err := c.DialContextTCP(ctx, addr, port) - if err != nil { - return err - } - return task.Run(ctx, func() error { - defer rw.CloseRead(conn) - defer rw.CloseWrite(serverConn) - return common.Error(io.Copy(serverConn, conn)) - }, func() error { - defer rw.CloseRead(serverConn) - defer rw.CloseWrite(conn) - return common.Error(io.Copy(conn, serverConn)) - }) -} - -func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error { - ctx := context.Background() - serverConn := c.DialContextUDP(ctx) - return task.Run(ctx, func() error { - var init bool - return socks.CopyPacketConn(serverConn, conn, func(size int) { - if !init { - init = true - logrus.Info("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port)) - } else { - logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port)) - } - }) - }, func() error { - return socks.CopyPacketConn(conn, serverConn, func(size int) { - logrus.Trace("UDP ", conn.LocalAddr(), " <== ", socksaddr.JoinHostPort(addr, port)) - }) - }) -} - -func Run(cmd *cobra.Command, flags *Flags) { - client, err := NewLocalClient(flags) - if err != nil { - logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n") - cmd.Help() - os.Exit(1) - } - err = client.Start() - if err != nil { - logrus.Fatal(err) - } - osSignals := make(chan os.Signal, 1) - signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) - <-osSignals - client.Close() -} - -func (c *LocalClient) OnError(err error) { - if exceptions.IsClosed(err) { - return - } - logrus.Warn(err) -} diff --git a/cli/uot-local/main.go b/cli/uot-local/main.go new file mode 100644 index 0000000..944a0e5 --- /dev/null +++ b/cli/uot-local/main.go @@ -0,0 +1,140 @@ +package main + +import ( + "context" + "net" + "net/netip" + "os" + "os/signal" + "syscall" + + "github.com/sagernet/sing" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/redir" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/task" + "github.com/sagernet/sing/common/uot" + "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/transport/mixed" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +var ( + verbose bool + transproxy string +) + +func main() { + command := cobra.Command{ + Use: "uot-local ", + Short: "SUoT client.", + Long: "SUoT client. \n\nconverts a normal socks server to a SUoT mixed server.", + Example: "uot-local 0.0.0.0:2080 127.0.0.1:1080", + Version: sing.VersionStr, + Args: cobra.ExactArgs(2), + Run: run, + } + command.Flags().BoolVarP(&verbose, "verbose", "v", false, "Enable verbose mode") + command.Flags().StringVarP(&transproxy, "transproxy", "t", "", "Enable transparent proxy support [possible values: redirect, tproxy]") + err := command.Execute() + if err != nil { + logrus.Fatal(err) + } +} + +func run(cmd *cobra.Command, args []string) { + if verbose { + logrus.SetLevel(logrus.TraceLevel) + } + + bind, err := netip.ParseAddrPort(args[0]) + if err != nil { + logrus.Fatal("bad bind address: ", err) + } + + _, err = netip.ParseAddrPort(args[1]) + if err != nil { + logrus.Fatal("bad upstream address: ", err) + } + + var transproxyMode redir.TransproxyMode + switch transproxy { + case "redirect": + transproxyMode = redir.ModeRedirect + case "tproxy": + transproxyMode = redir.ModeTProxy + case "": + transproxyMode = redir.ModeDisabled + default: + logrus.Fatal("unknown transproxy mode ", transproxy) + } + + client := &localClient{upstream: args[1]} + client.Listener = mixed.NewListener(bind, nil, transproxyMode, client) + + err = client.Start() + if err != nil { + logrus.Fatal("start mixed server: ", err) + } + + logrus.Info("mixed server started at ", client.TCPListener.Addr()) + + osSignals := make(chan os.Signal, 1) + signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) + <-osSignals + + client.Close() +} + +type localClient struct { + *mixed.Listener + upstream string +} + +func (c *localClient) NewConnection(conn net.Conn, metadata M.Metadata) error { + logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination) + + upstream, err := net.Dial("tcp", c.upstream) + if err != nil { + return E.Cause(err, "connect to upstream") + } + + _, err = socks.ClientHandshake(upstream, socks.Version5, socks.CommandConnect, metadata.Destination, "", "") + if err != nil { + return E.Cause(err, "upstream handshake failed") + } + + return rw.CopyConn(context.Background(), upstream, conn) +} + +func (c *localClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error { + upstream, err := net.Dial("tcp", c.upstream) + if err != nil { + return E.Cause(err, "connect to upstream") + } + + _, err = socks.ClientHandshake(upstream, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn(uot.UOTMagicAddress), 443), "", "") + if err != nil { + return E.Cause(err, "upstream handshake failed") + } + + client := uot.NewClientConn(upstream) + return task.Run(context.Background(), func() error { + return socks.CopyPacketConn(client, conn, func(destination *M.AddrPort, n int) { + logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination) + }) + }, func() error { + return socks.CopyPacketConn(conn, client, func(destination *M.AddrPort, n int) { + logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination) + }) + }) +} + +func (c *localClient) OnError(err error) { + if E.IsClosed(err) { + return + } + logrus.Warn(err) +} diff --git a/common/auth/auth.go b/common/auth/auth.go new file mode 100644 index 0000000..98414d8 --- /dev/null +++ b/common/auth/auth.go @@ -0,0 +1,46 @@ +package auth + +import ( + "sync" +) + +type Authenticator interface { + Verify(user string, pass string) bool + Users() []string +} + +type AuthUser struct { + User string + Pass string +} + +type inMemoryAuthenticator struct { + storage *sync.Map + usernames []string +} + +func (au *inMemoryAuthenticator) Verify(user string, pass string) bool { + realPass, ok := au.storage.Load(user) + return ok && realPass == pass +} + +func (au *inMemoryAuthenticator) Users() []string { return au.usernames } + +func NewAuthenticator(users []AuthUser) Authenticator { + if len(users) == 0 { + return nil + } + + au := &inMemoryAuthenticator{storage: &sync.Map{}} + for _, user := range users { + au.storage.Store(user.User, user.Pass) + } + usernames := make([]string, 0, len(users)) + au.storage.Range(func(key, value interface{}) bool { + usernames = append(usernames, key.(string)) + return true + }) + au.usernames = usernames + + return au +} diff --git a/common/buf/buffer.go b/common/buf/buffer.go index bc3c0cd..c13a5fa 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "fmt" "io" + "net" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/list" @@ -188,6 +189,18 @@ func (b *Buffer) ReadFrom(r io.Reader) (int64, error) { return int64(n), nil } +func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) { + if b.IsFull() { + return 0, nil, io.ErrShortBuffer + } + n, addr, err := r.ReadFrom(b.FreeBytes()) + if err != nil { + return 0, nil, err + } + b.end += n + return int64(n), addr, nil +} + func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) { if min <= 0 { return b.ReadFrom(r) diff --git a/common/buf/io.go b/common/buf/io.go index 78dcaf8..12fc43f 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -70,3 +70,12 @@ func (w *BufferedWriter) Flush() error { } return common.Error(w.Writer.Write(buffer.Bytes())) } + +func (w *BufferedWriter) Close() error { + buffer := w.Buffer + if buffer != nil { + w.Buffer = nil + buffer.Release() + } + return nil +} diff --git a/common/cache/cache.go b/common/cache/cache.go new file mode 100644 index 0000000..6b252c2 --- /dev/null +++ b/common/cache/cache.go @@ -0,0 +1,106 @@ +package cache + +import ( + "runtime" + "sync" + "time" +) + +// Cache store element with a expired time +type Cache struct { + *cache +} + +type cache struct { + mapping sync.Map + janitor *janitor +} + +type element struct { + Expired time.Time + Payload interface{} +} + +// Put element in Cache with its ttl +func (c *cache) Put(key interface{}, payload interface{}, ttl time.Duration) { + c.mapping.Store(key, &element{ + Payload: payload, + Expired: time.Now().Add(ttl), + }) +} + +// Get element in Cache, and drop when it expired +func (c *cache) Get(key interface{}) interface{} { + item, exist := c.mapping.Load(key) + if !exist { + return nil + } + elm := item.(*element) + // expired + if time.Since(elm.Expired) > 0 { + c.mapping.Delete(key) + return nil + } + return elm.Payload +} + +// GetWithExpire element in Cache with Expire Time +func (c *cache) GetWithExpire(key interface{}) (payload interface{}, expired time.Time) { + item, exist := c.mapping.Load(key) + if !exist { + return + } + elm := item.(*element) + // expired + if time.Since(elm.Expired) > 0 { + c.mapping.Delete(key) + return + } + return elm.Payload, elm.Expired +} + +func (c *cache) cleanup() { + c.mapping.Range(func(k, v interface{}) bool { + key := k.(string) + elm := v.(*element) + if time.Since(elm.Expired) > 0 { + c.mapping.Delete(key) + } + return true + }) +} + +type janitor struct { + interval time.Duration + stop chan struct{} +} + +func (j *janitor) process(c *cache) { + ticker := time.NewTicker(j.interval) + for { + select { + case <-ticker.C: + c.cleanup() + case <-j.stop: + ticker.Stop() + return + } + } +} + +func stopJanitor(c *Cache) { + c.janitor.stop <- struct{}{} +} + +// New return *Cache +func New(interval time.Duration) *Cache { + j := &janitor{ + interval: interval, + stop: make(chan struct{}), + } + c := &cache{janitor: j} + go j.process(c) + C := &Cache{c} + runtime.SetFinalizer(C, stopJanitor) + return C +} diff --git a/common/cache/cache_test.go b/common/cache/cache_test.go new file mode 100644 index 0000000..cf4a391 --- /dev/null +++ b/common/cache/cache_test.go @@ -0,0 +1,70 @@ +package cache + +import ( + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCache_Basic(t *testing.T) { + interval := 200 * time.Millisecond + ttl := 20 * time.Millisecond + c := New(interval) + c.Put("int", 1, ttl) + c.Put("string", "a", ttl) + + i := c.Get("int") + assert.Equal(t, i.(int), 1, "should recv 1") + + s := c.Get("string") + assert.Equal(t, s.(string), "a", "should recv 'a'") +} + +func TestCache_TTL(t *testing.T) { + interval := 200 * time.Millisecond + ttl := 20 * time.Millisecond + now := time.Now() + c := New(interval) + c.Put("int", 1, ttl) + c.Put("int2", 2, ttl) + + i := c.Get("int") + _, expired := c.GetWithExpire("int2") + assert.Equal(t, i.(int), 1, "should recv 1") + assert.True(t, now.Before(expired)) + + time.Sleep(ttl * 2) + i = c.Get("int") + j, _ := c.GetWithExpire("int2") + assert.Nil(t, i, "should recv nil") + assert.Nil(t, j, "should recv nil") +} + +func TestCache_AutoCleanup(t *testing.T) { + interval := 10 * time.Millisecond + ttl := 15 * time.Millisecond + c := New(interval) + c.Put("int", 1, ttl) + + time.Sleep(ttl * 2) + i := c.Get("int") + j, _ := c.GetWithExpire("int") + assert.Nil(t, i, "should recv nil") + assert.Nil(t, j, "should recv nil") +} + +func TestCache_AutoGC(t *testing.T) { + sign := make(chan struct{}) + go func() { + interval := 10 * time.Millisecond + ttl := 15 * time.Millisecond + c := New(interval) + c.Put("int", 1, ttl) + sign <- struct{}{} + }() + + <-sign + runtime.GC() +} diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go new file mode 100644 index 0000000..4269d86 --- /dev/null +++ b/common/cache/lrucache.go @@ -0,0 +1,223 @@ +package cache + +// Modified by https://github.com/die-net/lrucache + +import ( + "container/list" + "sync" + "time" +) + +// Option is part of Functional Options Pattern +type Option func(*LruCache) + +// EvictCallback is used to get a callback when a cache entry is evicted +type EvictCallback = func(key interface{}, value interface{}) + +// WithEvict set the evict callback +func WithEvict(cb EvictCallback) Option { + return func(l *LruCache) { + l.onEvict = cb + } +} + +// WithUpdateAgeOnGet update expires when Get element +func WithUpdateAgeOnGet() Option { + return func(l *LruCache) { + l.updateAgeOnGet = true + } +} + +// WithAge defined element max age (second) +func WithAge(maxAge int64) Option { + return func(l *LruCache) { + l.maxAge = maxAge + } +} + +// WithSize defined max length of LruCache +func WithSize(maxSize int) Option { + return func(l *LruCache) { + l.maxSize = maxSize + } +} + +// WithStale decide whether Stale return is enabled. +// If this feature is enabled, element will not get Evicted according to `WithAge`. +func WithStale(stale bool) Option { + return func(l *LruCache) { + l.staleReturn = stale + } +} + +// LruCache is a thread-safe, in-memory lru-cache that evicts the +// least recently used entries from memory when (if set) the entries are +// older than maxAge (in seconds). Use the New constructor to create one. +type LruCache struct { + maxAge int64 + maxSize int + mu sync.Mutex + cache map[interface{}]*list.Element + lru *list.List // Front is least-recent + updateAgeOnGet bool + staleReturn bool + onEvict EvictCallback +} + +// NewLRUCache creates an LruCache +func NewLRUCache(options ...Option) *LruCache { + lc := &LruCache{ + lru: list.New(), + cache: make(map[interface{}]*list.Element), + } + + for _, option := range options { + option(lc) + } + + return lc +} + +// Get returns the interface{} representation of a cached response and a bool +// set to true if the key was found. +func (c *LruCache) Get(key interface{}) (interface{}, bool) { + entry := c.get(key) + if entry == nil { + return nil, false + } + value := entry.value + + return value, true +} + +// GetWithExpire returns the interface{} representation of a cached response, +// a time.Time Give expected expires, +// and a bool set to true if the key was found. +// This method will NOT check the maxAge of element and will NOT update the expires. +func (c *LruCache) GetWithExpire(key interface{}) (interface{}, time.Time, bool) { + entry := c.get(key) + if entry == nil { + return nil, time.Time{}, false + } + + return entry.value, time.Unix(entry.expires, 0), true +} + +// Exist returns if key exist in cache but not put item to the head of linked list +func (c *LruCache) Exist(key interface{}) bool { + c.mu.Lock() + defer c.mu.Unlock() + + _, ok := c.cache[key] + return ok +} + +// Set stores the interface{} representation of a response for a given key. +func (c *LruCache) Set(key interface{}, value interface{}) { + expires := int64(0) + if c.maxAge > 0 { + expires = time.Now().Unix() + c.maxAge + } + c.SetWithExpire(key, value, time.Unix(expires, 0)) +} + +// SetWithExpire stores the interface{} representation of a response for a given key and given expires. +// The expires time will round to second. +func (c *LruCache) SetWithExpire(key interface{}, value interface{}, expires time.Time) { + c.mu.Lock() + defer c.mu.Unlock() + + if le, ok := c.cache[key]; ok { + c.lru.MoveToBack(le) + e := le.Value.(*entry) + e.value = value + e.expires = expires.Unix() + } else { + e := &entry{key: key, value: value, expires: expires.Unix()} + c.cache[key] = c.lru.PushBack(e) + + if c.maxSize > 0 { + if len := c.lru.Len(); len > c.maxSize { + c.deleteElement(c.lru.Front()) + } + } + } + + c.maybeDeleteOldest() +} + +// CloneTo clone and overwrite elements to another LruCache +func (c *LruCache) CloneTo(n *LruCache) { + c.mu.Lock() + defer c.mu.Unlock() + + n.mu.Lock() + defer n.mu.Unlock() + + n.lru = list.New() + n.cache = make(map[interface{}]*list.Element) + + for e := c.lru.Front(); e != nil; e = e.Next() { + elm := e.Value.(*entry) + n.cache[elm.key] = n.lru.PushBack(elm) + } +} + +func (c *LruCache) get(key interface{}) *entry { + c.mu.Lock() + defer c.mu.Unlock() + + le, ok := c.cache[key] + if !ok { + return nil + } + + if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() { + c.deleteElement(le) + c.maybeDeleteOldest() + + return nil + } + + c.lru.MoveToBack(le) + entry := le.Value.(*entry) + if c.maxAge > 0 && c.updateAgeOnGet { + entry.expires = time.Now().Unix() + c.maxAge + } + return entry +} + +// Delete removes the value associated with a key. +func (c *LruCache) Delete(key interface{}) { + c.mu.Lock() + + if le, ok := c.cache[key]; ok { + c.deleteElement(le) + } + + c.mu.Unlock() +} + +func (c *LruCache) maybeDeleteOldest() { + if !c.staleReturn && c.maxAge > 0 { + now := time.Now().Unix() + for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() { + c.deleteElement(le) + } + } +} + +func (c *LruCache) deleteElement(le *list.Element) { + c.lru.Remove(le) + e := le.Value.(*entry) + delete(c.cache, e.key) + if c.onEvict != nil { + c.onEvict(e.key, e.value) + } +} + +type entry struct { + key interface{} + value interface{} + expires int64 +} diff --git a/common/cache/lrucache_test.go b/common/cache/lrucache_test.go new file mode 100644 index 0000000..8a04f74 --- /dev/null +++ b/common/cache/lrucache_test.go @@ -0,0 +1,183 @@ +package cache + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var entries = []struct { + key string + value string +}{ + {"1", "one"}, + {"2", "two"}, + {"3", "three"}, + {"4", "four"}, + {"5", "five"}, +} + +func TestLRUCache(t *testing.T) { + c := NewLRUCache() + + for _, e := range entries { + c.Set(e.key, e.value) + } + + c.Delete("missing") + _, ok := c.Get("missing") + assert.False(t, ok) + + for _, e := range entries { + value, ok := c.Get(e.key) + if assert.True(t, ok) { + assert.Equal(t, e.value, value.(string)) + } + } + + for _, e := range entries { + c.Delete(e.key) + + _, ok := c.Get(e.key) + assert.False(t, ok) + } +} + +func TestLRUMaxAge(t *testing.T) { + c := NewLRUCache(WithAge(86400)) + + now := time.Now().Unix() + expected := now + 86400 + + // Add one expired entry + c.Set("foo", "bar") + c.lru.Back().Value.(*entry).expires = now + + // Reset + c.Set("foo", "bar") + e := c.lru.Back().Value.(*entry) + assert.True(t, e.expires >= now) + c.lru.Back().Value.(*entry).expires = now + + // Set a few and verify expiration times + for _, s := range entries { + c.Set(s.key, s.value) + e := c.lru.Back().Value.(*entry) + assert.True(t, e.expires >= expected && e.expires <= expected+10) + } + + // Make sure we can get them all + for _, s := range entries { + _, ok := c.Get(s.key) + assert.True(t, ok) + } + + // Expire all entries + for _, s := range entries { + le, ok := c.cache[s.key] + if assert.True(t, ok) { + le.Value.(*entry).expires = now + } + } + + // Get one expired entry, which should clear all expired entries + _, ok := c.Get("3") + assert.False(t, ok) + assert.Equal(t, c.lru.Len(), 0) +} + +func TestLRUpdateOnGet(t *testing.T) { + c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet()) + + now := time.Now().Unix() + expires := now + 86400/2 + + // Add one expired entry + c.Set("foo", "bar") + c.lru.Back().Value.(*entry).expires = expires + + _, ok := c.Get("foo") + assert.True(t, ok) + assert.True(t, c.lru.Back().Value.(*entry).expires > expires) +} + +func TestMaxSize(t *testing.T) { + c := NewLRUCache(WithSize(2)) + // Add one expired entry + c.Set("foo", "bar") + _, ok := c.Get("foo") + assert.True(t, ok) + + c.Set("bar", "foo") + c.Set("baz", "foo") + + _, ok = c.Get("foo") + assert.False(t, ok) +} + +func TestExist(t *testing.T) { + c := NewLRUCache(WithSize(1)) + c.Set(1, 2) + assert.True(t, c.Exist(1)) + c.Set(2, 3) + assert.False(t, c.Exist(1)) +} + +func TestEvict(t *testing.T) { + temp := 0 + evict := func(key interface{}, value interface{}) { + temp = key.(int) + value.(int) + } + + c := NewLRUCache(WithEvict(evict), WithSize(1)) + c.Set(1, 2) + c.Set(2, 3) + + assert.Equal(t, temp, 3) +} + +func TestSetWithExpire(t *testing.T) { + c := NewLRUCache(WithAge(1)) + now := time.Now().Unix() + + tenSecBefore := time.Unix(now-10, 0) + c.SetWithExpire(1, 2, tenSecBefore) + + // res is expected not to exist, and expires should be empty time.Time + res, expires, exist := c.GetWithExpire(1) + assert.Equal(t, nil, res) + assert.Equal(t, time.Time{}, expires) + assert.Equal(t, false, exist) +} + +func TestStale(t *testing.T) { + c := NewLRUCache(WithAge(1), WithStale(true)) + now := time.Now().Unix() + + tenSecBefore := time.Unix(now-10, 0) + c.SetWithExpire(1, 2, tenSecBefore) + + res, expires, exist := c.GetWithExpire(1) + assert.Equal(t, 2, res) + assert.Equal(t, tenSecBefore, expires) + assert.Equal(t, true, exist) +} + +func TestCloneTo(t *testing.T) { + o := NewLRUCache(WithSize(10)) + o.Set("1", 1) + o.Set("2", 2) + + n := NewLRUCache(WithSize(2)) + n.Set("3", 3) + n.Set("4", 4) + + o.CloneTo(n) + + assert.False(t, n.Exist("3")) + assert.True(t, n.Exist("1")) + + n.Set("5", 5) + assert.False(t, n.Exist("1")) +} diff --git a/common/cond.go b/common/cond.go index 7d97ab9..39d8722 100644 --- a/common/cond.go +++ b/common/cond.go @@ -43,6 +43,16 @@ func Filter[T any](arr []T, block func(it T) bool) []T { return retArr } +func FilterIsInstance[T any, N any](arr []T, block func(it T) (N, bool)) []N { + var retArr []N + for _, it := range arr { + if n, isN := block(it); isN { + retArr = append(retArr, n) + } + } + return retArr +} + func Done(ctx context.Context) bool { select { case <-ctx.Done(): diff --git a/common/exceptions/error.go b/common/exceptions/error.go index 43caec4..510ba90 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -62,3 +62,7 @@ func IsTimeout(err error) bool { } return false } + +type Handler interface { + HandleError(err error) +} diff --git a/common/flush.go b/common/flush.go index aa235b1..16064cc 100644 --- a/common/flush.go +++ b/common/flush.go @@ -50,7 +50,7 @@ func FlushVar(writerP *io.Writer) error { if writerBack == writer { writer = u.Upstream() writerBack = writer - writerP = &writer + *writerP = writer continue } else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter { setter.SetWriter(writerBack) diff --git a/common/geoip/matcher.go b/common/geoip/matcher.go index a311031..cdd5452 100644 --- a/common/geoip/matcher.go +++ b/common/geoip/matcher.go @@ -1,10 +1,11 @@ package geoip import ( - "github.com/oschwald/geoip2-golang" "net" "strings" "sync" + + "github.com/oschwald/geoip2-golang" ) var ( diff --git a/common/geosite/matcher.go b/common/geosite/matcher.go index f071f51..8fb5cc4 100644 --- a/common/geosite/matcher.go +++ b/common/geosite/matcher.go @@ -3,12 +3,13 @@ package geosite import ( "bufio" "encoding/binary" - "github.com/klauspost/compress/zstd" - "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/common/trieset" "io" "strings" + + "github.com/klauspost/compress/zstd" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/trieset" ) type Matcher struct { @@ -25,7 +26,7 @@ func LoadGeositeMatcher(reader io.Reader, code string) (*Matcher, error) { return nil, err } if version != 0 { - return nil, exceptions.New("bad geosite data") + return nil, E.New("bad geosite data") } decoder, err := zstd.NewReader(reader, zstd.WithDecoderLowmem(true), zstd.WithDecoderConcurrency(1)) if err != nil { @@ -72,5 +73,5 @@ func LoadGeositeMatcher(reader io.Reader, code string) (*Matcher, error) { } } } - return nil, exceptions.New(code, " not found in geosite") + return nil, E.New(code, " not found in geosite") } diff --git a/common/gsync/map.go b/common/gsync/map.go index 61af105..c009c85 100644 --- a/common/gsync/map.go +++ b/common/gsync/map.go @@ -201,7 +201,7 @@ func (e *entry[T]) storeLocked(i *T) { // LoadOrStore returns the existing value for the key if present. // Otherwise, it stores and returns the given value. // The loaded result is true if the value was loaded, false if stored. -func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { +func (m *Map[K, V]) LoadOrStore(key K, value func() V) (actual V, loaded bool) { // Avoid locking if it's a clean hit. read, _ := m.read.Load().(readOnly[K, V]) if e, ok := read.m[key]; ok { @@ -228,8 +228,9 @@ func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { m.dirtyLocked() m.read.Store(readOnly[K, V]{m: read.m, amended: true}) } - m.dirty[key] = newEntry(value) - actual, loaded = value, false + v := value() + m.dirty[key] = newEntry(v) + actual, loaded = v, false } m.mu.Unlock() @@ -241,7 +242,7 @@ func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { // // If the entry is expunged, tryLoadOrStore leaves the entry unchanged and // returns with ok==false. -func (e *entry[T]) tryLoadOrStore(i T) (actual T, loaded, ok bool) { +func (e *entry[T]) tryLoadOrStore(i func() T) (actual T, loaded, ok bool) { p := atomic.LoadPointer(&e.p) if p == expunged { var defaultValue T @@ -254,10 +255,10 @@ func (e *entry[T]) tryLoadOrStore(i T) (actual T, loaded, ok bool) { // Copy the interface after the first load to make this method more amenable // to escape analysis: if we hit the "load" path or the entry is expunged, we // shouldn't bother heap-allocating. - ic := i + ic := i() for { if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) { - return i, false, true + return ic, false, true } p = atomic.LoadPointer(&e.p) if p == expunged { diff --git a/common/socksaddr/addr.go b/common/metadata/addr.go similarity index 57% rename from common/socksaddr/addr.go rename to common/metadata/addr.go index f0473fb..0e1dd4c 100644 --- a/common/socksaddr/addr.go +++ b/common/metadata/addr.go @@ -1,4 +1,4 @@ -package socksaddr +package metadata import ( "net" @@ -13,6 +13,35 @@ type Addr interface { String() string } +type AddrPort struct { + Addr Addr + Port uint16 +} + +func (ap AddrPort) IPAddr() *net.IPAddr { + return &net.IPAddr{ + IP: ap.Addr.Addr().AsSlice(), + } +} + +func (ap AddrPort) TCPAddr() *net.TCPAddr { + return &net.TCPAddr{ + IP: ap.Addr.Addr().AsSlice(), + Port: int(ap.Port), + } +} + +func (ap AddrPort) UDPAddr() *net.UDPAddr { + return &net.UDPAddr{ + IP: ap.Addr.Addr().AsSlice(), + Port: int(ap.Port), + } +} + +func (ap AddrPort) String() string { + return net.JoinHostPort(ap.Addr.String(), strconv.Itoa(int(ap.Port))) +} + func ParseAddr(address string) Addr { addr, err := netip.ParseAddr(address) if err == nil { @@ -21,16 +50,44 @@ func ParseAddr(address string) Addr { return AddrFromFqdn(address) } -func ParseAddrPort(address string) (Addr, uint16, error) { +func AddrPortFrom(addr Addr, port uint16) *AddrPort { + return &AddrPort{addr, port} +} + +func ParseAddress(address string) (*AddrPort, error) { host, port, err := net.SplitHostPort(address) if err != nil { - return nil, 0, err + return nil, err } portInt, err := strconv.Atoi(port) if err != nil { - return nil, 0, err + return nil, err } - return ParseAddr(host), uint16(portInt), nil + return AddrPortFrom(ParseAddr(host), uint16(portInt)), nil +} + +func ParseAddrPort(address string, port string) (*AddrPort, error) { + portInt, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + return AddrPortFrom(ParseAddr(address), uint16(portInt)), nil +} + +func AddrPortFromNetAddr(netAddr net.Addr) *AddrPort { + var ip net.IP + var port uint16 + switch addr := netAddr.(type) { + case *net.TCPAddr: + ip = addr.IP + port = uint16(addr.Port) + case *net.UDPAddr: + ip = addr.IP + port = uint16(addr.Port) + case *net.IPAddr: + ip = addr.IP + } + return AddrPortFrom(AddrFromIP(ip), port) } func AddrFromIP(ip net.IP) Addr { @@ -50,27 +107,14 @@ func AddrFromAddr(addr netip.Addr) Addr { } } -func AddrFromNetAddr(netAddr net.Addr) (addr Addr, port uint16) { - var ip net.IP - switch addr := netAddr.(type) { - case *net.TCPAddr: - ip = addr.IP - port = uint16(addr.Port) - case *net.UDPAddr: - ip = addr.IP - port = uint16(addr.Port) - } - return AddrFromIP(ip), port +func AddrPortFromAddrPort(addrPort netip.AddrPort) *AddrPort { + return AddrPortFrom(AddrFromAddr(addrPort.Addr()), addrPort.Port()) } func AddrFromFqdn(fqdn string) Addr { return AddrFqdn(fqdn) } -func JoinHostPort(addr Addr, port uint16) string { - return net.JoinHostPort(addr.String(), strconv.Itoa(int(port))) -} - type Addr4 [4]byte func (a Addr4) Family() Family { diff --git a/common/socksaddr/exception.go b/common/metadata/exception.go similarity index 91% rename from common/socksaddr/exception.go rename to common/metadata/exception.go index 6e9def5..b69e52f 100644 --- a/common/socksaddr/exception.go +++ b/common/metadata/exception.go @@ -1,4 +1,4 @@ -package socksaddr +package metadata import "fmt" diff --git a/common/socksaddr/family.go b/common/metadata/family.go similarity index 95% rename from common/socksaddr/family.go rename to common/metadata/family.go index 2e3ced0..80011ee 100644 --- a/common/socksaddr/family.go +++ b/common/metadata/family.go @@ -1,4 +1,4 @@ -package socksaddr +package metadata type Family byte diff --git a/common/metadata/metadata.go b/common/metadata/metadata.go new file mode 100644 index 0000000..0b2854c --- /dev/null +++ b/common/metadata/metadata.go @@ -0,0 +1,20 @@ +package metadata + +import ( + "net" + + "github.com/sagernet/sing/common/buf" +) + +type Metadata struct { + Source *AddrPort + Destination *AddrPort +} + +type TCPConnectionHandler interface { + NewConnection(conn net.Conn, metadata Metadata) error +} + +type UDPHandler interface { + NewPacket(packet *buf.Buffer, metadata Metadata) error +} diff --git a/common/socksaddr/serializer.go b/common/metadata/serializer.go similarity index 76% rename from common/socksaddr/serializer.go rename to common/metadata/serializer.go index 60fad02..045966d 100644 --- a/common/socksaddr/serializer.go +++ b/common/metadata/serializer.go @@ -1,11 +1,11 @@ -package socksaddr +package metadata import ( "encoding/binary" "io" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/exceptions" + E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/rw" ) @@ -24,16 +24,9 @@ func PortThenAddress() SerializerOption { } } -func WithFamilyParser(fp FamilyParser) SerializerOption { - return func(s *Serializer) { - s.familyParser = fp - } -} - type Serializer struct { familyMap map[byte]Family familyByteMap map[Family]byte - familyParser FamilyParser portFirst bool } @@ -66,20 +59,20 @@ func (s *Serializer) WritePort(writer io.Writer, port uint16) error { return binary.Write(writer, binary.BigEndian, port) } -func (s *Serializer) WriteAddressAndPort(writer io.Writer, addr Addr, port uint16) error { +func (s *Serializer) WriteAddrPort(writer io.Writer, addrPort *AddrPort) error { var err error if !s.portFirst { - err = s.WriteAddress(writer, addr) + err = s.WriteAddress(writer, addrPort.Addr) } else { - err = s.WritePort(writer, port) + err = s.WritePort(writer, addrPort.Port) } if err != nil { return err } if s.portFirst { - err = s.WriteAddress(writer, addr) + err = s.WriteAddress(writer, addrPort.Addr) } else { - err = s.WritePort(writer, port) + err = s.WritePort(writer, addrPort.Port) } return err } @@ -89,15 +82,12 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) { if err != nil { return nil, err } - if s.familyParser != nil { - af = s.familyParser(af) - } family := s.familyMap[af] switch family { case AddressFamilyFqdn: fqdn, err := ReadString(reader) if err != nil { - return nil, exceptions.Cause(err, "read fqdn") + return nil, E.Cause(err, "read fqdn") } return AddrFqdn(fqdn), nil default: @@ -106,18 +96,18 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) { var addr [4]byte err = common.Error(reader.Read(addr[:])) if err != nil { - return nil, exceptions.Cause(err, "read ipv4 address") + return nil, E.Cause(err, "read ipv4 address") } return Addr4(addr), nil case AddressFamilyIPv6: var addr [16]byte err = common.Error(reader.Read(addr[:])) if err != nil { - return nil, exceptions.Cause(err, "read ipv6 address") + return nil, E.Cause(err, "read ipv6 address") } return Addr16(addr), nil default: - return nil, exceptions.New("unknown address family: ", af) + return nil, E.New("unknown address family: ", af) } } } @@ -125,12 +115,14 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) { func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) { port, err := rw.ReadBytes(reader, 2) if err != nil { - return 0, exceptions.Cause(err, "read port") + return 0, E.Cause(err, "read port") } return binary.BigEndian.Uint16(port), nil } -func (s *Serializer) ReadAddressAndPort(reader io.Reader) (addr Addr, port uint16, err error) { +func (s *Serializer) ReadAddrPort(reader io.Reader) (addrPort *AddrPort, err error) { + var addr Addr + var port uint16 if !s.portFirst { addr, err = s.ReadAddress(reader) } else { @@ -144,7 +136,10 @@ func (s *Serializer) ReadAddressAndPort(reader io.Reader) (addr Addr, port uint1 } else { port, err = s.ReadPort(reader) } - return + if err != nil { + return + } + return AddrPortFrom(addr, port), nil } func ReadString(reader io.Reader) (string, error) { diff --git a/common/net.go b/common/net.go new file mode 100644 index 0000000..7a14da1 --- /dev/null +++ b/common/net.go @@ -0,0 +1,15 @@ +package common + +import "syscall" + +func GetFileDescriptor(conn syscall.Conn) (uintptr, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + var rawFd uintptr + err = rawConn.Control(func(fd uintptr) { + rawFd = fd + }) + return rawFd, err +} diff --git a/common/random/rng.go b/common/random/rng.go new file mode 100644 index 0000000..97e4db8 --- /dev/null +++ b/common/random/rng.go @@ -0,0 +1,18 @@ +package random + +import ( + "crypto/rand" + "io" + + "github.com/sagernet/sing/common" + "lukechampine.com/blake3" +) + +var System = rand.Reader + +func Blake3KeyedHash() io.Reader { + key := make([]byte, 32) + common.Must1(io.ReadFull(System, key)) + h := blake3.New(1024, key) + return h.XOF() +} diff --git a/common/redir/mode.go b/common/redir/mode.go new file mode 100644 index 0000000..a265ab7 --- /dev/null +++ b/common/redir/mode.go @@ -0,0 +1,9 @@ +package redir + +type TransproxyMode uint8 + +const ( + ModeDisabled TransproxyMode = iota + ModeRedirect + ModeTProxy +) diff --git a/common/redir/redir_linux.go b/common/redir/redir_linux.go index 3b8e451..61ea559 100644 --- a/common/redir/redir_linux.go +++ b/common/redir/redir_linux.go @@ -2,11 +2,12 @@ package redir import ( "net" - "net/netip" "syscall" + + M "github.com/sagernet/sing/common/metadata" ) -func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err error) { +func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) { rawConn, err := conn.(syscall.Conn).SyscallConn() if err != nil { return @@ -22,15 +23,14 @@ func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err erro if conn.RemoteAddr().(*net.TCPAddr).IP.To4() != nil { raw, err := syscall.GetsockoptIPv6Mreq(int(rawFd), syscall.IPPROTO_IP, SO_ORIGINAL_DST) if err != nil { - return netip.AddrPort{}, err + return nil, err } - addr, _ := netip.AddrFromSlice(raw.Multiaddr[4:8]) - return netip.AddrPortFrom(addr, uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil - + return M.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil + } else { + raw, err := syscall.GetsockoptIPv6MTUInfo(int(rawFd), syscall.IPPROTO_IPV6, SO_ORIGINAL_DST) + if err != nil { + return nil, err + } + return M.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), raw.Addr.Port), nil } - raw, err := syscall.GetsockoptIPv6MTUInfo(int(rawFd), syscall.IPPROTO_IPV6, SO_ORIGINAL_DST) - if err != nil { - return - } - return netip.AddrPortFrom(netip.AddrFrom16(raw.Addr.Addr), raw.Addr.Port), nil } diff --git a/common/redir/redir_other.go b/common/redir/redir_other.go index 76bbdb6..379bc98 100644 --- a/common/redir/redir_other.go +++ b/common/redir/redir_other.go @@ -5,10 +5,10 @@ package redir import ( "errors" "net" - "net/netip" + + M "github.com/sagernet/sing/common/metadata" ) -func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err error) { - err = errors.New("unsupported platform") - return +func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) { + return nil, errors.New("unsupported platform") } diff --git a/common/redir/tproxy_linux.go b/common/redir/tproxy_linux.go new file mode 100644 index 0000000..f5af7b7 --- /dev/null +++ b/common/redir/tproxy_linux.go @@ -0,0 +1,131 @@ +package redir + +import ( + "encoding/binary" + "fmt" + "net" + "os" + "strconv" + "syscall" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "golang.org/x/sys/unix" +) + +func TProxy(fd uintptr, isIPv6 bool) error { + err := syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1) + if err != nil { + return err + } + if isIPv6 { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1) + } + return err +} + +func TProxyUDP(fd uintptr, isIPv6 bool) error { + err := syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1) + if err != nil { + return err + } + return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1) +} + +func GetOriginalDestinationFromOOB(oob []byte) (*M.AddrPort, error) { + controlMessages, err := unix.ParseSocketControlMessage(oob) + if err != nil { + return nil, err + } + for _, message := range controlMessages { + if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR { + return M.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil + } else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR { + return M.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil + } + } + return nil, E.New("not found") +} + +func DialUDP(network string, lAddr *net.UDPAddr, rAddr *net.UDPAddr) (*net.UDPConn, error) { + rSockAddr, err := udpAddrToSockAddr(rAddr) + if err != nil { + return nil, err + } + + lSockAddr, err := udpAddrToSockAddr(lAddr) + if err != nil { + return nil, err + } + + fd, err := syscall.Socket(udpAddrFamily(network, lAddr, rAddr), syscall.SOCK_DGRAM, 0) + if err != nil { + return nil, err + } + + if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { + syscall.Close(fd) + return nil, err + } + + if err = syscall.SetsockoptInt(fd, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil { + syscall.Close(fd) + return nil, err + } + + if err = syscall.Bind(fd, lSockAddr); err != nil { + syscall.Close(fd) + return nil, err + } + + if err = syscall.Connect(fd, rSockAddr); err != nil { + syscall.Close(fd) + return nil, err + } + + fdFile := os.NewFile(uintptr(fd), fmt.Sprintf("net-udp-dial-%s", rAddr.String())) + defer fdFile.Close() + + c, err := net.FileConn(fdFile) + if err != nil { + syscall.Close(fd) + return nil, err + } + + return c.(*net.UDPConn), nil +} + +func udpAddrToSockAddr(addr *net.UDPAddr) (syscall.Sockaddr, error) { + switch { + case addr.IP.To4() != nil: + ip := [4]byte{} + copy(ip[:], addr.IP.To4()) + + return &syscall.SockaddrInet4{Addr: ip, Port: addr.Port}, nil + + default: + ip := [16]byte{} + copy(ip[:], addr.IP.To16()) + + zoneID, err := strconv.ParseUint(addr.Zone, 10, 32) + if err != nil { + zoneID = 0 + } + + return &syscall.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil + } +} + +func udpAddrFamily(net string, lAddr, rAddr *net.UDPAddr) int { + switch net[len(net)-1] { + case '4': + return syscall.AF_INET + case '6': + return syscall.AF_INET6 + } + + if (lAddr == nil || lAddr.IP.To4() != nil) && (rAddr == nil || lAddr.IP.To4() != nil) { + return syscall.AF_INET + } + return syscall.AF_INET6 +} diff --git a/common/redir/tproxy_other.go b/common/redir/tproxy_other.go new file mode 100644 index 0000000..5bf1c0b --- /dev/null +++ b/common/redir/tproxy_other.go @@ -0,0 +1,20 @@ +//go:build !linux + +package redir + +import ( + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" +) + +func TProxy(fd uintptr, isIPv6 bool) error { + return E.New("only available on linux") +} + +func TProxyUDP(fd uintptr, isIPv6 bool) error { + return E.New("only available on linux") +} + +func GetOriginalDestinationFromOOB(oob []byte) (*M.AddrPort, error) { + return nil, E.New("only available on linux") +} diff --git a/common/replay/bloomring.go b/common/replay/bloomring.go new file mode 100644 index 0000000..6d347d9 --- /dev/null +++ b/common/replay/bloomring.go @@ -0,0 +1,30 @@ +package replay + +import ( + "github.com/v2fly/ss-bloomring" + "sync" +) + +func NewBloomRing() Filter { + const ( + DefaultSFCapacity = 1e6 + DefaultSFFPR = 1e-6 + DefaultSFSlot = 10 + ) + return &bloomRingFilter{BloomRing: ss_bloomring.NewBloomRing(DefaultSFSlot, DefaultSFCapacity, DefaultSFFPR)} +} + +type bloomRingFilter struct { + sync.Mutex + *ss_bloomring.BloomRing +} + +func (f *bloomRingFilter) Check(sum []byte) bool { + f.Lock() + defer f.Unlock() + if f.Test(sum) { + return false + } + f.Add(sum) + return true +} diff --git a/common/replay/cuckoo.go b/common/replay/cuckoo.go new file mode 100644 index 0000000..117e27b --- /dev/null +++ b/common/replay/cuckoo.go @@ -0,0 +1,50 @@ +package replay + +import ( + "sync" + "time" + + "github.com/seiflotfy/cuckoofilter" +) + +func NewCuckoo(interval int64) Filter { + filter := &cuckooFilter{} + filter.interval = interval + return filter +} + +type cuckooFilter struct { + lock sync.Mutex + poolA *cuckoo.Filter + poolB *cuckoo.Filter + poolSwap bool + lastSwap int64 + interval int64 +} + +func (filter *cuckooFilter) Check(sum []byte) bool { + const defaultCapacity = 100000 + + filter.lock.Lock() + defer filter.lock.Unlock() + + now := time.Now().Unix() + if filter.lastSwap == 0 { + filter.lastSwap = now + filter.poolA = cuckoo.NewFilter(defaultCapacity) + filter.poolB = cuckoo.NewFilter(defaultCapacity) + } + + elapsed := now - filter.lastSwap + if elapsed >= filter.interval { + if filter.poolSwap { + filter.poolA.Reset() + } else { + filter.poolB.Reset() + } + filter.poolSwap = !filter.poolSwap + filter.lastSwap = now + } + + return filter.poolA.InsertUnique(sum) && filter.poolB.InsertUnique(sum) +} diff --git a/common/replay/filter.go b/common/replay/filter.go new file mode 100644 index 0000000..2c82ae9 --- /dev/null +++ b/common/replay/filter.go @@ -0,0 +1,5 @@ +package replay + +type Filter interface { + Check(sum []byte) bool +} diff --git a/common/rw/copy.go b/common/rw/copy.go index 0ae26b4..8d101cc 100644 --- a/common/rw/copy.go +++ b/common/rw/copy.go @@ -40,6 +40,18 @@ func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) { return 0, os.ErrInvalid } +func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { + return task.Run(context.Background(), func() error { + defer CloseRead(conn) + defer CloseWrite(dest) + return common.Error(io.Copy(dest, conn)) + }, func() error { + defer CloseRead(dest) + defer CloseWrite(conn) + return common.Error(io.Copy(conn, dest)) + }) +} + func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error { return task.Run(ctx, func() error { buffer := buf.FullNew() diff --git a/common/rw/varinat.go b/common/rw/varinat.go index efb7468..d4823e6 100644 --- a/common/rw/varinat.go +++ b/common/rw/varinat.go @@ -2,8 +2,9 @@ package rw import ( "encoding/binary" - "github.com/sagernet/sing/common" "io" + + "github.com/sagernet/sing/common" ) type InputStream interface { diff --git a/common/session/context.go b/common/session/context.go index 675fea5..2e85fab 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -5,7 +5,7 @@ import ( "strconv" "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/socksaddr" + M "github.com/sagernet/sing/common/metadata" ) type Network int @@ -20,8 +20,8 @@ type InstanceContext struct{} type Context struct { InstanceContext Network Network - Source socksaddr.Addr - Destination socksaddr.Addr + Source M.Addr + Destination M.Addr SourcePort uint16 DestinationPort uint16 } @@ -30,7 +30,7 @@ func (c Context) DestinationNetAddr() string { return net.JoinHostPort(c.Destination.String(), strconv.Itoa(int(c.DestinationPort))) } -func AddressFromNetAddr(netAddr net.Addr) (addr socksaddr.Addr, port uint16) { +func AddressFromNetAddr(netAddr net.Addr) (addr M.Addr, port uint16) { var ip net.IP switch addr := netAddr.(type) { case *net.TCPAddr: @@ -40,7 +40,7 @@ func AddressFromNetAddr(netAddr net.Addr) (addr socksaddr.Addr, port uint16) { ip = addr.IP port = uint16(addr.Port) } - return socksaddr.AddrFromIP(ip), port + return M.AddrFromIP(ip), port } type Conn struct { diff --git a/common/udpnat/server.go b/common/udpnat/server.go new file mode 100644 index 0000000..41938b6 --- /dev/null +++ b/common/udpnat/server.go @@ -0,0 +1,108 @@ +package udpnat + +import ( + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/gsync" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/redir" + "github.com/sagernet/sing/protocol/socks" +) + +type Handler interface { + socks.UDPConnectionHandler + E.Handler +} + +type Server struct { + udpNat gsync.Map[string, *packetConn] + handler Handler +} + +func NewServer(handler Handler) *Server { + return &Server{handler: handler} +} + +func (s *Server) HandleUDP(buffer *buf.Buffer, metadata M.Metadata) error { + conn, loaded := s.udpNat.LoadOrStore(metadata.Source.String(), func() *packetConn { + return &packetConn{source: metadata.Source.UDPAddr(), in: make(chan *udpPacket)} + }) + if !loaded { + go func() { + err := s.handler.NewPacketConnection(conn, metadata) + if err != nil { + s.handler.HandleError(err) + } + }() + } + conn.in <- &udpPacket{ + buffer: buffer, + destination: metadata.Destination, + } + return nil +} + +func (s *Server) OnError(err error) { + s.handler.HandleError(err) +} + +func (s *Server) Close() error { + s.udpNat.Range(func(key string, conn *packetConn) bool { + conn.Close() + return true + }) + s.udpNat = gsync.Map[string, *packetConn]{} + return nil +} + +type packetConn struct { + socks.PacketConnStub + source *net.UDPAddr + in chan *udpPacket +} + +type udpPacket struct { + buffer *buf.Buffer + destination *M.AddrPort +} + +func (c *packetConn) LocalAddr() net.Addr { + return c.source +} + +func (c *packetConn) Close() error { + select { + case <-c.in: + return io.ErrClosedPipe + default: + close(c.in) + } + return nil +} + +func (c *packetConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { + select { + case packet, ok := <-c.in: + if !ok { + return nil, io.ErrClosedPipe + } + defer packet.buffer.Release() + if buffer.FreeLen() < packet.buffer.Len() { + return nil, io.ErrShortBuffer + } + return packet.destination, common.Error(buffer.Write(packet.buffer.Bytes())) + } +} + +func (c *packetConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), c.source) + if err != nil { + return E.Cause(err, "tproxy udp write back") + } + defer udpConn.Close() + return common.Error(udpConn.Write(buffer.Bytes())) +} diff --git a/common/uot/client.go b/common/uot/client.go new file mode 100644 index 0000000..d04d748 --- /dev/null +++ b/common/uot/client.go @@ -0,0 +1,80 @@ +package uot + +import ( + "encoding/binary" + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +type ClientConn struct { + net.Conn +} + +func NewClientConn(conn net.Conn) *ClientConn { + return &ClientConn{conn} +} + +func (c *ClientConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { + destination, err := AddrParser.ReadAddrPort(c) + if err != nil { + return nil, err + } + var length uint16 + err = binary.Read(c, binary.BigEndian, &length) + if err != nil { + return nil, err + } + if buffer.FreeLen() < int(length) { + return nil, io.ErrShortBuffer + } + return destination, common.Error(buffer.ReadFullFrom(c, int(length))) +} + +func (c *ClientConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + err := AddrParser.WriteAddrPort(c, destination) + if err != nil { + return err + } + err = binary.Write(c, binary.BigEndian, uint16(buffer.Len())) + if err != nil { + return err + } + return common.Error(c.Write(buffer.Bytes())) +} + +func (c *ClientConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + addrPort, err := AddrParser.ReadAddrPort(c) + if err != nil { + return 0, nil, err + } + var length uint16 + err = binary.Read(c, binary.BigEndian, &length) + if err != nil { + return 0, nil, err + } + if len(p) < int(length) { + return 0, nil, io.ErrShortBuffer + } + n, err = io.ReadAtLeast(c, p, int(length)) + if err != nil { + return 0, nil, err + } + addr = addrPort.UDPAddr() + return +} + +func (c *ClientConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + err = AddrParser.WriteAddrPort(c, M.AddrPortFromNetAddr(addr)) + if err != nil { + return + } + err = binary.Write(c, binary.BigEndian, uint16(len(p))) + if err != nil { + return + } + return c.Write(p) +} diff --git a/common/uot/resolver.go b/common/uot/resolver.go new file mode 100644 index 0000000..6c3a7fb --- /dev/null +++ b/common/uot/resolver.go @@ -0,0 +1,21 @@ +package uot + +import ( + "context" + "net" + "time" +) + +var LookupAddress func(domain string) (net.IP, error) + +func init() { + LookupAddress = func(domain string) (net.IP, error) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ips, err := net.DefaultResolver.LookupIP(ctx, "ip", domain) + cancel() + if err != nil { + return nil, err + } + return ips[0], nil + } +} diff --git a/common/uot/server.go b/common/uot/server.go new file mode 100644 index 0000000..f9b9b46 --- /dev/null +++ b/common/uot/server.go @@ -0,0 +1,108 @@ +package uot + +import ( + "encoding/binary" + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +type ServerConn struct { + net.PacketConn + inputReader, outputReader *io.PipeReader + inputWriter, outputWriter *io.PipeWriter +} + +func NewServerConn(packetConn net.PacketConn) net.Conn { + c := &ServerConn{ + PacketConn: packetConn, + } + c.inputReader, c.inputWriter = io.Pipe() + c.outputReader, c.outputWriter = io.Pipe() + go c.loopInput() + go c.loopOutput() + return c +} + +func (c *ServerConn) Read(b []byte) (n int, err error) { + return c.outputReader.Read(b) +} + +func (c *ServerConn) Write(b []byte) (n int, err error) { + return c.inputWriter.Write(b) +} + +func (c *ServerConn) RemoteAddr() net.Addr { + return &common.DummyAddr{} +} + +func (c *ServerConn) loopInput() { + buffer := buf.New() + defer buffer.Release() + for { + destination, err := AddrParser.ReadAddrPort(c.inputReader) + if err != nil { + break + } + if destination.Addr.Family().IsFqdn() { + ip, err := LookupAddress(destination.Addr.Fqdn()) + if err != nil { + break + } + destination.Addr = M.AddrFromIP(ip) + } + var length uint16 + err = binary.Read(c.inputReader, binary.BigEndian, &length) + if err != nil { + break + } + buffer.FullReset() + _, err = buffer.ReadFullFrom(c.inputReader, int(length)) + if err != nil { + break + } + _, err = c.WriteTo(buffer.Bytes(), destination.UDPAddr()) + if err != nil { + break + } + } + c.Close() +} + +func (c *ServerConn) loopOutput() { + buffer := buf.New() + defer buffer.Release() + for { + buffer.FullReset() + n, addr, err := buffer.ReadPacketFrom(c) + if err != nil { + break + } + destination := M.AddrPortFromNetAddr(addr) + err = AddrParser.WriteAddrPort(c.outputWriter, destination) + if err != nil { + break + } + err = binary.Write(c.outputWriter, binary.BigEndian, uint16(n)) + if err != nil { + break + } + _, err = buffer.WriteTo(c.outputWriter) + if err != nil { + break + } + } + c.Close() +} + +func (c *ServerConn) Close() error { + c.inputReader.Close() + c.inputWriter.Close() + c.outputReader.Close() + c.outputWriter.Close() + c.PacketConn.Close() + return nil +} diff --git a/common/uot/uot.go b/common/uot/uot.go new file mode 100644 index 0000000..dc98fda --- /dev/null +++ b/common/uot/uot.go @@ -0,0 +1,13 @@ +package uot + +import ( + M "github.com/sagernet/sing/common/metadata" +) + +const UOTMagicAddress = "sp.udp-over-tcp.arpa" + +var AddrParser = M.NewSerializer( + M.AddressFamilyByte(0x00, M.AddressFamilyIPv4), + M.AddressFamilyByte(0x01, M.AddressFamilyIPv6), + M.AddressFamilyByte(0x02, M.AddressFamilyFqdn), +) diff --git a/common/uot/uot_test.go b/common/uot/uot_test.go new file mode 100644 index 0000000..753bfaa --- /dev/null +++ b/common/uot/uot_test.go @@ -0,0 +1,39 @@ +package uot + +import ( + "net" + "testing" + + "github.com/sagernet/uot/common" + "github.com/sagernet/uot/common/buf" + "golang.org/x/net/dns/dnsmessage" +) + +func TestServerConn(t *testing.T) { + udpConn, err := net.ListenUDP("udp", nil) + common.Must(err) + serverConn := NewServerConn(udpConn) + defer serverConn.Close() + clientConn := NewClientConn(serverConn) + message := new(dnsmessage.Message) + message.Header.ID = 1 + message.Header.RecursionDesired = true + message.Questions = append(message.Questions, dnsmessage.Question{ + Name: dnsmessage.MustNewName("google.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + packet, err := message.Pack() + common.Must(err) + common.Must1(clientConn.WriteTo(packet, &net.UDPAddr{ + IP: net.IPv4(8, 8, 8, 8), + Port: 53, + })) + buffer := buf.New() + defer buffer.Release() + common.Must2(buffer.ReadPacketFrom(clientConn)) + common.Must(message.Unpack(buffer.Bytes())) + for _, answer := range message.Answers { + t.Log("got answer :", answer.Body) + } +} diff --git a/core.go b/core.go index 6f59e87..7a44bf3 100644 --- a/core.go +++ b/core.go @@ -1,3 +1,6 @@ package sing -const Version = "v0.0.0-alpha.1" +const ( + Version = "v0.0.0-alpha.1" + VersionStr = "sing " + Version +) diff --git a/go.mod b/go.mod index 9f2cedc..d57c196 100644 --- a/go.mod +++ b/go.mod @@ -6,20 +6,32 @@ require ( github.com/klauspost/compress v1.15.1 github.com/openacid/low v0.1.21 github.com/oschwald/geoip2-golang v1.7.0 + github.com/sagernet/uot v0.0.0-20220403125237-bf82029ad617 github.com/samber/lo v1.11.0 + github.com/seiflotfy/cuckoofilter v0.0.0-20201222105146-bc6005554a0c github.com/sirupsen/logrus v1.8.1 github.com/spf13/cobra v1.4.0 + github.com/stretchr/testify v1.7.1 github.com/ulikunitz/xz v0.5.10 + github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e github.com/v2fly/v2ray-core/v5 v5.0.3 - golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 + golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 + golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 + golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f google.golang.org/protobuf v1.28.0 + lukechampine.com/blake3 v1.1.7 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/oschwald/maxminddb-golang v1.9.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/spf13/pflag v1.0.5 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect - golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index 249812e..c5db731 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165 h1:BS21ZUJ/B5X2UVUbczfmdWH7GapPWAhxcMsDnjJTU1E= +github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= @@ -13,6 +15,10 @@ github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NH github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/klauspost/compress v1.15.1 h1:y9FcTHGyrebwfP0ZZqFiaxTaiDnUrGkJkI+f583BL1A= github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/openacid/errors v0.8.1/go.mod h1:GUQEJJOJE3W9skHm8E8Y4phdl2LLEN8iD7c5gcGgdx0= github.com/openacid/low v0.1.21 h1:Tr2GNu4N/+rGRYdOsEHOE89cxUIaDViZbVmKz29uKGo= github.com/openacid/low v0.1.21/go.mod h1:q+MsKI6Pz2xsCkzV4BLj7NR5M4EX0sGz5AqotpZDVh0= @@ -25,9 +31,15 @@ github.com/oschwald/maxminddb-golang v1.9.0/go.mod h1:TK+s/Z2oZq0rSl4PSeAEoP0bgm github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg= +github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3/go.mod h1:HgjTstvQsPGkxUsCd2KWxErBblirPizecHcpD3ffK+s= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagernet/uot v0.0.0-20220403125237-bf82029ad617 h1:h46Ocvf7zWpatqOHcR4kw+k2GbGcMM7EzGjYG7wiGfM= +github.com/sagernet/uot v0.0.0-20220403125237-bf82029ad617/go.mod h1:T2LhXiIIvaoeKii21x1GONCee9u7N9Nnrqz5bY3SWsM= github.com/samber/lo v1.11.0 h1:JfeYozXL1xfkhRUFOfH13ociyeiLSC/GRJjGKI668xM= github.com/samber/lo v1.11.0/go.mod h1:2I7tgIv8Q1SG2xEIkRq0F2i2zgxVpnyPOP0d3Gj2r+A= +github.com/seiflotfy/cuckoofilter v0.0.0-20201222105146-bc6005554a0c h1:pqy40B3MQWYrza7YZXOXgl0Nf0QGFqrOC0BKae1UNAA= +github.com/seiflotfy/cuckoofilter v0.0.0-20201222105146-bc6005554a0c/go.mod h1:bR6DqgcAl1zTcOX8/pE2Qkj9XO00eCNqmKb7lXP8EAg= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/spf13/cobra v1.4.0 h1:y+wJpx64xcgO1V+RcnwW0LEHxTKRi2ZDPSBjWnrg88Q= @@ -40,18 +52,23 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M= github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8= github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e h1:5QefA066A1tF8gHIiADmOVOV5LS43gt3ONnlEl3xkwI= +github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e/go.mod h1:5t19P9LBIrNamL6AcMQOncg/r10y3Pc01AbHeMhwlpU= github.com/v2fly/v2ray-core/v5 v5.0.3 h1:2rnJ9vZbBQ7V4upWsoUVYGoqZl4grrx8SxOReKx+jjc= github.com/v2fly/v2ray-core/v5 v5.0.3/go.mod h1:zhDdsUJcNE8LcLRA3l7fEQ6QLuveD4/OLbQM2CceSHM= -golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o= -golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 h1:iU7T1X1J6yxDr0rda54sWGkHgOp5XJrqm79gcNlC2VM= +golang.org/x/crypto v0.0.0-20220408190544-5352b0902921/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= +golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 h1:EN5+DfgmRMvRUrMGERW2gQl3Vc+Z7ZMnI/xdEpPSf0c= +golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0Pp5Fn9Qga0mrgxyERQM= -golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f h1:8w7RhxzTVgUzw/AH/9mUV5q0vMgy40SQRursCcfmkCw= +golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= @@ -59,6 +76,11 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/blake3 v1.1.7 h1:GgRMhmdsuK8+ii6UZFDL8Nb+VyMwadAgcJyfYHxG6n0= +lukechampine.com/blake3 v1.1.7/go.mod h1:tkKEOtDkNtklkXtLNEOGNq5tcV90tJiA1vAA12R78LA= diff --git a/transport/system/mixed.go b/protocol/http/listener.go similarity index 50% rename from transport/system/mixed.go rename to protocol/http/listener.go index 880efdd..874af97 100644 --- a/transport/system/mixed.go +++ b/protocol/http/listener.go @@ -1,92 +1,42 @@ -package system +package http import ( + "bufio" "context" "encoding/base64" "fmt" "net" "net/http" - "net/netip" - "strconv" "strings" - "syscall" "time" + _ "unsafe" + "github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/redir" - "github.com/sagernet/sing/common/socksaddr" - "github.com/sagernet/sing/protocol/socks" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/transport/tcp" ) -type MixedListener struct { - *SocksListener - UDPListener *UDPListener - mcfg *MixedConfig - laddr *net.TCPAddr +type Handler interface { + tcp.Handler } -type MixedConfig struct { - SocksConfig - Redirect bool - TProxy bool -} - -func NewMixedListener(bind netip.AddrPort, config *MixedConfig, handler SocksHandler) *MixedListener { - listener := &MixedListener{SocksListener: NewSocksListener(bind, &config.SocksConfig, handler), mcfg: config} - listener.TCPListener.Handler = listener - if config.TProxy { - listener.UDPListener = NewUDPListener(bind, listener) - } - return listener -} - -func (l *MixedListener) HandleTCP(conn net.Conn) error { - if l.mcfg.Redirect { - var destination netip.AddrPort - destination, err := redir.GetOriginalDestination(conn) - if err == nil { - return l.Handler.NewConnection(socksaddr.AddrFromAddr(destination.Addr()), destination.Port(), conn) - } - } else if l.mcfg.TProxy { - lAddr := conn.LocalAddr().(*net.TCPAddr) - rAddr := conn.RemoteAddr().(*net.TCPAddr) - - if lAddr.Port != l.laddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() { - addr, port := socksaddr.AddrFromNetAddr(lAddr) - return l.Handler.NewConnection(addr, port, conn) - } - } - - bufConn := buf.NewBufferedConn(conn) - hdr, err := bufConn.ReadByte() - if err != nil { - return err - } - err = bufConn.UnreadByte() - if err != nil { - return err - } - - if hdr == socks.Version4 || hdr == socks.Version5 { - return l.SocksListener.HandleTCP(bufConn) - } - +func HandleConnection(conn *buf.BufferedConn, authenticator auth.Authenticator, handler Handler) error { var httpClient *http.Client for { - request, err := readRequest(bufConn.Reader()) + request, err := readRequest(conn.Reader()) if err != nil { - return exceptions.Cause(err, "read http request") + return E.Cause(err, "read http request") } - if l.Username != "" { + if authenticator != nil { var authOk bool authorization := request.Header.Get("Proxy-Authorization") if strings.HasPrefix(authorization, "BASIC ") { userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:]) - if string(userPassword) == l.Username+":"+l.Password { - authOk = true - } + userPswdArr := strings.SplitN(string(userPassword), ":", 2) + authOk = authenticator.Verify(userPswdArr[0], userPswdArr[1]) } if !authOk { err = responseWith(request, http.StatusProxyAuthRequired).Write(conn) @@ -97,23 +47,23 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error { } if request.Method == "CONNECT" { - host := request.URL.Hostname() portStr := request.URL.Port() if portStr == "" { portStr = "80" } - port, err := strconv.Atoi(portStr) + destination, err := M.ParseAddrPort(request.URL.Hostname(), portStr) if err != nil { - err = responseWith(request, http.StatusBadRequest).Write(conn) if err != nil { return err } } _, err = fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established") if err != nil { - return exceptions.Cause(err, "write http response") + return E.Cause(err, "write http response") } - return l.Handler.NewConnection(socksaddr.ParseAddr(host), uint16(port), bufConn) + return handler.NewConnection(conn, M.Metadata{ + Destination: destination, + }) } keepAlive := strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" @@ -141,19 +91,21 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error { ExpectContinueTimeout: 1 * time.Second, DialContext: func(context context.Context, network, address string) (net.Conn, error) { if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, exceptions.New("unsupported network ", network) + return nil, E.New("unsupported network ", network) } - addr, port, err := socksaddr.ParseAddrPort(address) + destination, err := M.ParseAddress(address) if err != nil { return nil, err } left, right := net.Pipe() go func() { - err = l.Handler.NewConnection(addr, port, right) + err = handler.NewConnection(right, M.Metadata{ + Destination: destination, + }) if err != nil { - l.OnError(err) + handler.HandleError(err) } }() return left, nil @@ -167,7 +119,7 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error { response, err := httpClient.Do(request) if err != nil { - l.OnError(exceptions.Cause(err, "http proxy")) + handler.HandleError(err) return responseWith(request, http.StatusBadGateway).Write(conn) } @@ -183,17 +135,14 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error { err = response.Write(conn) if err != nil { - l.OnError(exceptions.Cause(err, "http proxy")) return err } } } -func (l *MixedListener) HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error { - return nil -} +//go:linkname readRequest net/http.ReadRequest +func readRequest(b *bufio.Reader) (req *http.Request, err error) -// removeHopByHopHeaders remove hop-by-hop header func removeHopByHopHeaders(header http.Header) { // Strip hop-by-hop header based on RFC: // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1 @@ -217,8 +166,6 @@ func removeHopByHopHeaders(header http.Header) { } } -// removeExtraHTTPHostPort remove extra host port (example.com:80 --> example.com) -// It resolves the behavior of some HTTP servers that do not handle host:80 (e.g. baidu.com) func removeExtraHTTPHostPort(req *http.Request) { host := req.Host if host == "" { @@ -243,39 +190,3 @@ func responseWith(request *http.Request, statusCode int) *http.Response { Header: http.Header{}, } } - -func (l *MixedListener) Start() error { - err := l.TCPListener.Start() - if err != nil { - return err - } - if l.mcfg.TProxy { - rawConn, err := l.TCPListener.TCPListener.SyscallConn() - if err != nil { - return err - } - var rawFd uintptr - err = rawConn.Control(func(fd uintptr) { - rawFd = fd - }) - if err != nil { - return err - } - err = syscall.SetsockoptInt(int(rawFd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1) - if err != nil { - return exceptions.Cause(err, "failed to configure TCP tproxy") - } - } - if l.mcfg.TProxy { - l.laddr = l.TCPListener.Addr().(*net.TCPAddr) - } - return nil -} - -func (l *MixedListener) Close() error { - return l.TCPListener.Close() -} - -func (l *MixedListener) OnError(err error) { - l.Handler.OnError(exceptions.Cause(err, "mixed server")) -} diff --git a/transport/system/stub.s b/protocol/http/stub.s similarity index 100% rename from transport/system/stub.s rename to protocol/http/stub.s diff --git a/protocol/shadowsocks/cipher.go b/protocol/shadowsocks/cipher.go deleted file mode 100644 index 22e3441..0000000 --- a/protocol/shadowsocks/cipher.go +++ /dev/null @@ -1,47 +0,0 @@ -package shadowsocks - -import ( - "io" - - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/list" -) - -type Cipher interface { - KeySize() int - SaltSize() int - CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader - CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer - EncodePacket(key []byte, buffer *buf.Buffer) error - DecodePacket(key []byte, buffer *buf.Buffer) error -} - -type CipherCreator func() Cipher - -var ( - cipherList *list.List[string] - cipherMap map[string]CipherCreator -) - -func init() { - cipherList = new(list.List[string]) - cipherMap = make(map[string]CipherCreator) -} - -func RegisterCipher(method string, creator CipherCreator) { - cipherList.PushBack(method) - cipherMap[method] = creator -} - -func CreateCipher(method string) (Cipher, error) { - creator := cipherMap[method] - if creator != nil { - return creator(), nil - } - return nil, exceptions.New("unsupported method: ", method) -} - -func ListCiphers() []string { - return cipherList.Array() -} diff --git a/protocol/shadowsocks/cipher_none.go b/protocol/shadowsocks/cipher_none.go deleted file mode 100644 index 5636039..0000000 --- a/protocol/shadowsocks/cipher_none.go +++ /dev/null @@ -1,39 +0,0 @@ -package shadowsocks - -import ( - "io" - - "github.com/sagernet/sing/common/buf" -) - -func init() { - RegisterCipher("none", func() Cipher { - return (*NoneCipher)(nil) - }) -} - -type NoneCipher struct{} - -func (c *NoneCipher) KeySize() int { - return 16 -} - -func (c *NoneCipher) SaltSize() int { - return 0 -} - -func (c *NoneCipher) CreateReader(_ []byte, _ []byte, reader io.Reader) io.Reader { - return reader -} - -func (c *NoneCipher) CreateWriter(key []byte, iv []byte, writer io.Writer) io.Writer { - return writer -} - -func (c *NoneCipher) EncodePacket([]byte, *buf.Buffer) error { - return nil -} - -func (c *NoneCipher) DecodePacket([]byte, *buf.Buffer) error { - return nil -} diff --git a/protocol/shadowsocks/client.go b/protocol/shadowsocks/client.go deleted file mode 100644 index 0b37f02..0000000 --- a/protocol/shadowsocks/client.go +++ /dev/null @@ -1,178 +0,0 @@ -package shadowsocks - -import ( - "context" - "io" - "net" - "strconv" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/common/socksaddr" - "github.com/sagernet/sing/protocol/socks" -) - -var ( - ErrBadKey = exceptions.New("bad key") - ErrMissingPassword = exceptions.New("password not specified") -) - -type ClientConfig struct { - Server string `json:"server"` - ServerPort uint16 `json:"server_port"` - Method string `json:"method"` - Password []byte `json:"password"` - Key []byte `json:"key"` -} - -type Client struct { - dialer *net.Dialer - cipher Cipher - server string - key []byte -} - -func NewClient(dialer *net.Dialer, config *ClientConfig) (*Client, error) { - if config.Server == "" { - return nil, exceptions.New("missing server address") - } - if config.ServerPort == 0 { - return nil, exceptions.New("missing server port") - } - if config.Method == "" { - return nil, exceptions.New("missing server method") - } - - cipher, err := CreateCipher(config.Method) - if err != nil { - return nil, err - } - client := &Client{ - dialer: dialer, - cipher: cipher, - server: net.JoinHostPort(config.Server, strconv.Itoa(int(config.ServerPort))), - } - if keyLen := len(config.Key); keyLen > 0 { - if keyLen == cipher.KeySize() { - client.key = config.Key - } else { - return nil, ErrBadKey - } - } else if len(config.Password) > 0 { - client.key = Key(config.Password, cipher.KeySize()) - } else { - return nil, ErrMissingPassword - } - return client, nil -} - -func (c *Client) DialContextTCP(ctx context.Context, addr socksaddr.Addr, port uint16) (net.Conn, error) { - conn, err := c.dialer.DialContext(ctx, "tcp", c.server) - if err != nil { - return nil, exceptions.Cause(err, "connect to server") - } - return c.DialConn(conn, addr, port), nil -} - -func (c *Client) DialConn(conn net.Conn, addr socksaddr.Addr, port uint16) net.Conn { - header := buf.New() - header.WriteRandom(c.cipher.SaltSize()) - writer := &buf.BufferedWriter{ - Writer: conn, - Buffer: header, - } - protocolWriter := c.cipher.CreateWriter(c.key, header.Bytes(), writer) - requestBuffer := buf.New() - contentWriter := &buf.BufferedWriter{ - Writer: protocolWriter, - Buffer: requestBuffer, - } - common.Must(AddressSerializer.WriteAddressAndPort(contentWriter, addr, port)) - return &shadowsocksConn{ - Client: c, - Conn: conn, - Writer: &common.FlushOnceWriter{Writer: contentWriter}, - } -} - -type shadowsocksConn struct { - *Client - net.Conn - io.Writer - reader io.Reader -} - -func (c *shadowsocksConn) Read(p []byte) (n int, err error) { - if c.reader == nil { - buffer := buf.Or(p, c.cipher.SaltSize()) - defer buffer.Release() - _, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize()) - if err != nil { - return - } - c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn) - } - return c.reader.Read(p) -} - -func (c *shadowsocksConn) WriteTo(w io.Writer) (n int64, err error) { - if c.reader == nil { - buffer := buf.NewSize(c.cipher.SaltSize()) - defer buffer.Release() - _, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize()) - if err != nil { - return - } - c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn) - } - return c.reader.(io.WriterTo).WriteTo(w) -} - -func (c *shadowsocksConn) Write(p []byte) (n int, err error) { - return c.Writer.Write(p) -} - -func (c *shadowsocksConn) ReadFrom(r io.Reader) (n int64, err error) { - return rw.ReadFromVar(&c.Writer, r) -} - -func (c *Client) DialContextUDP(ctx context.Context) socks.PacketConn { - conn, err := c.dialer.DialContext(ctx, "udp", c.server) - if err != nil { - return nil - } - return &shadowsocksPacketConn{c, conn} -} - -type shadowsocksPacketConn struct { - *Client - net.Conn -} - -func (c *shadowsocksPacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error { - defer buffer.Release() - header := buf.New() - header.WriteRandom(c.cipher.SaltSize()) - common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port)) - buffer = buffer.WriteBufferAtFirst(header) - err := c.cipher.EncodePacket(c.key, buffer) - if err != nil { - return err - } - return common.Error(c.Conn.Write(buffer.Bytes())) -} - -func (c *shadowsocksPacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) { - n, err := c.Read(buffer.FreeBytes()) - if err != nil { - return nil, 0, err - } - buffer.Truncate(n) - err = c.cipher.DecodePacket(c.key, buffer) - if err != nil { - return nil, 0, err - } - return AddressSerializer.ReadAddressAndPort(buffer) -} diff --git a/protocol/shadowsocks/config.go b/protocol/shadowsocks/method.go similarity index 100% rename from protocol/shadowsocks/config.go rename to protocol/shadowsocks/method.go diff --git a/protocol/shadowsocks/none.go b/protocol/shadowsocks/none.go new file mode 100644 index 0000000..0ded5f9 --- /dev/null +++ b/protocol/shadowsocks/none.go @@ -0,0 +1,153 @@ +package shadowsocks + +import ( + "io" + "net" + "sync" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/protocol/socks" +) + +const MethodNone = "none" + +type NoneMethod struct{} + +func NewNone() Method { + return &NoneMethod{} +} + +func (m *NoneMethod) Name() string { + return MethodNone +} + +func (m *NoneMethod) KeyLength() int { + return 0 +} + +func (m *NoneMethod) NewSession(key []byte) Session { + return nil +} + +func (m *NoneMethod) DialConn(_ Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) { + shadowsocksConn := &noneConn{ + Conn: conn, + handshake: true, + destination: destination, + } + return shadowsocksConn, shadowsocksConn.clientHandshake() +} + +func (m *NoneMethod) DialEarlyConn(_ Session, conn net.Conn, destination *M.AddrPort) net.Conn { + return &noneConn{ + Conn: conn, + destination: destination, + } +} + +func (m *NoneMethod) DialPacketConn(_ Session, conn net.Conn) socks.PacketConn { + return &nonePacketConn{conn} +} + +type noneConn struct { + net.Conn + + access sync.Mutex + handshake bool + destination *M.AddrPort +} + +func (c *noneConn) clientHandshake() error { + err := socks.AddressSerializer.WriteAddrPort(c.Conn, c.destination) + if err != nil { + return err + } + c.handshake = true + return nil +} + +func (c *noneConn) Write(b []byte) (n int, err error) { + if c.handshake { + goto direct + } + + c.access.Lock() + defer c.access.Unlock() + + if c.handshake { + goto direct + } + + { + if len(b) == 0 { + return 0, c.clientHandshake() + } + + buffer := buf.New() + defer buffer.Release() + + err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination) + if err != nil { + return + } + + bufN, _ := buffer.Write(b) + _, err = c.Conn.Write(buffer.Bytes()) + if err != nil { + return + } + + if bufN < len(b) { + _, err = c.Conn.Write(b[bufN:]) + if err != nil { + return + } + } + + n = len(b) + } + +direct: + return c.Conn.Write(b) +} + +func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) { + if !c.handshake { + panic("missing client handshake") + } + return c.Conn.(io.ReaderFrom).ReadFrom(r) +} + +func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) { + return c.Conn.(io.WriterTo).WriteTo(w) +} + +func (c *noneConn) RemoteAddr() net.Addr { + return c.destination.TCPAddr() +} + +type nonePacketConn struct { + net.Conn +} + +func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { + _, err := buffer.ReadFrom(c) + if err != nil { + return nil, err + } + return socks.AddressSerializer.ReadAddrPort(buffer) +} + +func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error { + defer buffer.Release() + header := buf.New() + err := socks.AddressSerializer.WriteAddrPort(header, addrPort) + if err != nil { + header.Release() + return err + } + buffer = buffer.WriteBufferAtFirst(header) + return common.Error(buffer.WriteTo(c)) +} diff --git a/protocol/shadowsocks/protocol.go b/protocol/shadowsocks/protocol.go index 9ebc227..2ffb059 100644 --- a/protocol/shadowsocks/protocol.go +++ b/protocol/shadowsocks/protocol.go @@ -2,23 +2,26 @@ package shadowsocks import ( "crypto/md5" - "crypto/sha1" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/replay" + "github.com/sagernet/sing/protocol/socks" "hash/crc32" "io" "math/rand" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/socksaddr" - "golang.org/x/crypto/hkdf" + "net" ) -const MaxPacketSize = 16*1024 - 1 +type Session interface { + Key() []byte + ReplayFilter() replay.Filter +} -func Kdf(key, iv []byte, keyLength int) []byte { - subKey := make([]byte, keyLength) - kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey")) - common.Must1(io.ReadFull(kdf, subKey)) - return subKey +type Method interface { + Name() string + KeyLength() int + DialConn(session Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) + DialEarlyConn(session Session, conn net.Conn, destination *M.AddrPort) net.Conn + DialPacketConn(session Session, conn net.Conn) socks.PacketConn } func Key(password []byte, keySize int) []byte { @@ -43,19 +46,18 @@ func Key(password []byte, keySize int) []byte { return m[:keySize] } -func RemapToPrintable(input []byte) { - const charSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&()*+,./:;<=>?@[]^_`{|}~\\\"" - seed := rand.New(rand.NewSource(int64(crc32.ChecksumIEEE(input)))) - for i := range input { - input[i] = charSet[seed.Intn(len(charSet))] - } +type ReducedEntropyReader struct { + io.Reader } -var AddressSerializer = socksaddr.NewSerializer( - socksaddr.AddressFamilyByte(0x01, socksaddr.AddressFamilyIPv4), - socksaddr.AddressFamilyByte(0x04, socksaddr.AddressFamilyIPv6), - socksaddr.AddressFamilyByte(0x03, socksaddr.AddressFamilyFqdn), - socksaddr.WithFamilyParser(func(b byte) byte { - return b & 0x0F - }), -) +func (r *ReducedEntropyReader) Read(p []byte) (n int, err error) { + n, err = r.Reader.Read(p) + if n > 6 { + const charSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&()*+,./:;<=>?@[]^_`{|}~\\\"" + seed := rand.New(rand.NewSource(int64(crc32.ChecksumIEEE(p[:6])))) + for i := range p[:6] { + p[i] = charSet[seed.Intn(len(charSet))] + } + } + return +} diff --git a/protocol/shadowsocks/cipher_aead.go b/protocol/shadowsocks/shadowaead/aead.go similarity index 56% rename from protocol/shadowsocks/cipher_aead.go rename to protocol/shadowsocks/shadowaead/aead.go index 6ff89d0..529eede 100644 --- a/protocol/shadowsocks/cipher_aead.go +++ b/protocol/shadowsocks/shadowaead/aead.go @@ -1,138 +1,20 @@ -package shadowsocks +package shadowaead import ( - "crypto/aes" "crypto/cipher" "encoding/binary" "io" - "net" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" - "golang.org/x/crypto/chacha20poly1305" ) -const PacketLengthBufferSize = 2 +const ( + MaxPacketSize = 16*1024 - 1 + PacketLengthBufferSize = 2 +) -func init() { - RegisterCipher("aes-128-gcm", func() Cipher { - return &AEADCipher{ - KeyLength: 16, - SaltLength: 16, - Constructor: aesGcm, - } - }) - RegisterCipher("aes-192-gcm", func() Cipher { - return &AEADCipher{ - KeyLength: 24, - SaltLength: 24, - Constructor: aesGcm, - } - }) - RegisterCipher("aes-256-gcm", func() Cipher { - return &AEADCipher{ - KeyLength: 32, - SaltLength: 32, - Constructor: aesGcm, - } - }) - RegisterCipher("chacha20-ietf-poly1305", func() Cipher { - return &AEADCipher{ - KeyLength: 32, - SaltLength: 32, - Constructor: chacha20Poly1305, - } - }) - RegisterCipher("xchacha20-ietf-poly1305", func() Cipher { - return &AEADCipher{ - KeyLength: 32, - SaltLength: 32, - Constructor: xchacha20Poly1305, - } - }) -} - -func aesGcm(key []byte) cipher.AEAD { - block, err := aes.NewCipher(key) - common.Must(err) - aead, err := cipher.NewGCM(block) - common.Must(err) - return aead -} - -func chacha20Poly1305(key []byte) cipher.AEAD { - aead, err := chacha20poly1305.New(key) - common.Must(err) - return aead -} - -func xchacha20Poly1305(key []byte) cipher.AEAD { - aead, err := chacha20poly1305.NewX(key) - common.Must(err) - return aead -} - -type AEADCipher struct { - KeyLength int - SaltLength int - Constructor func(key []byte) cipher.AEAD -} - -func (c *AEADCipher) KeySize() int { - return c.KeyLength -} - -func (c *AEADCipher) SaltSize() int { - return c.SaltLength -} - -func (c *AEADCipher) CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader { - return NewAEADReader(reader, c.Constructor(Kdf(key, salt, c.KeyLength))) -} - -func (c *AEADCipher) CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer { - protocolWriter := NewAEADWriter(writer, c.Constructor(Kdf(key, salt, c.KeyLength))) - return protocolWriter -} - -func (c *AEADCipher) EncodePacket(key []byte, buffer *buf.Buffer) error { - aead := c.Constructor(Kdf(key, buffer.To(c.SaltLength), c.KeyLength)) - aead.Seal(buffer.From(c.SaltLength)[:0], rw.ZeroBytes[:aead.NonceSize()], buffer.From(c.SaltLength), nil) - buffer.Extend(aead.Overhead()) - return nil -} - -func (c *AEADCipher) DecodePacket(key []byte, buffer *buf.Buffer) error { - if buffer.Len() < c.SaltLength { - return exceptions.New("bad packet") - } - aead := c.Constructor(Kdf(key, buffer.To(c.SaltLength), c.KeyLength)) - packet, err := aead.Open(buffer.Index(c.SaltLength), rw.ZeroBytes[:aead.NonceSize()], buffer.From(c.SaltLength), nil) - if err != nil { - return err - } - buffer.Advance(c.SaltLength) - buffer.Truncate(len(packet)) - return nil -} - -type AEADConn struct { - net.Conn - Reader *AEADReader - Writer *AEADWriter -} - -func (c *AEADConn) Read(p []byte) (n int, err error) { - return c.Reader.Read(p) -} - -func (c *AEADConn) Write(p []byte) (n int, err error) { - return c.Writer.Write(p) -} - -type AEADReader struct { +type Reader struct { upstream io.Reader cipher cipher.AEAD data []byte @@ -141,8 +23,8 @@ type AEADReader struct { cached int } -func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader { - return &AEADReader{ +func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader { + return &Reader{ upstream: upstream, cipher: cipher, data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2), @@ -150,19 +32,19 @@ func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader { } } -func (r *AEADReader) Upstream() io.Reader { +func (r *Reader) Upstream() io.Reader { return r.upstream } -func (r *AEADReader) Replaceable() bool { +func (r *Reader) Replaceable() bool { return false } -func (r *AEADReader) SetUpstream(reader io.Reader) { +func (r *Reader) SetUpstream(reader io.Reader) { r.upstream = reader } -func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) { +func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) { if r.cached > 0 { writeN, writeErr := writer.Write(r.data[r.index : r.index+r.cached]) if writeErr != nil { @@ -200,7 +82,7 @@ func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) { } } -func (r *AEADReader) Read(b []byte) (n int, err error) { +func (r *Reader) Read(b []byte) (n int, err error) { if r.cached > 0 { n = copy(b, r.data[r.index:r.index+r.cached]) r.cached -= n diff --git a/protocol/shadowsocks/shadowaead/method.go b/protocol/shadowsocks/shadowaead/method.go new file mode 100644 index 0000000..4862fa9 --- /dev/null +++ b/protocol/shadowsocks/shadowaead/method.go @@ -0,0 +1,332 @@ +package shadowaead + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "github.com/sagernet/sing/common/replay" + "io" + "net" + "sync" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/protocol/shadowsocks" + "github.com/sagernet/sing/protocol/socks" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/hkdf" +) + +var List = []string{ + "aes-128-gcm", + "aes-192-gcm", + "aes-256-gcm", + "chacha20-ietf-poly1305", + "xchacha20-ietf-poly1305", +} + +func New(method string, secureRNG io.Reader) shadowsocks.Method { + m := &Method{ + name: method, + secureRNG: secureRNG, + } + switch method { + case "aes-128-gcm": + m.keySaltLength = 16 + m.constructor = newAESGCM + case "aes-192-gcm": + m.keySaltLength = 24 + m.constructor = newAESGCM + case "aes-256-gcm": + m.keySaltLength = 32 + m.constructor = newAESGCM + case "chacha20-ietf-poly1305": + m.keySaltLength = 32 + m.constructor = func(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.New(key) + common.Must(err) + return cipher + } + case "xchacha20-ietf-poly1305": + m.keySaltLength = 32 + m.constructor = func(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.NewX(key) + common.Must(err) + return cipher + } + } + return m +} + +func NewSession(key []byte, replayFilter bool) shadowsocks.Session { + var filter replay.Filter + if replayFilter { + filter = replay.NewBloomRing() + } + return &session{key, filter} +} + +func Kdf(key, iv []byte, keyLength int) []byte { + subKey := make([]byte, keyLength) + kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey")) + common.Must1(io.ReadFull(kdf, subKey)) + return subKey +} + +func newAESGCM(key []byte) cipher.AEAD { + block, err := aes.NewCipher(key) + common.Must(err) + aead, err := cipher.NewGCM(block) + common.Must(err) + return aead +} + +type Method struct { + name string + keySaltLength int + constructor func(key []byte) cipher.AEAD + secureRNG io.Reader +} + +func (m *Method) Name() string { + return m.name +} + +func (m *Method) KeyLength() int { + return m.keySaltLength +} + +func (m *Method) DialConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) { + shadowsocksConn := &aeadConn{ + Conn: conn, + method: m, + key: account.Key(), + replayFilter: account.ReplayFilter(), + destination: destination, + } + return shadowsocksConn, shadowsocksConn.clientHandshake() +} + +func (m *Method) DialEarlyConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) net.Conn { + return &aeadConn{ + Conn: conn, + method: m, + key: account.Key(), + replayFilter: account.ReplayFilter(), + destination: destination, + } +} + +func (m *Method) DialPacketConn(account shadowsocks.Session, conn net.Conn) socks.PacketConn { + return &aeadPacketConn{conn, account.Key(), m} +} + +func (m *Method) EncodePacket(key []byte, buffer *buf.Buffer) error { + cipher := m.constructor(Kdf(key, buffer.To(m.keySaltLength), m.keySaltLength)) + cipher.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:cipher.NonceSize()], buffer.From(m.keySaltLength), nil) + buffer.Extend(cipher.Overhead()) + return nil +} + +func (m *Method) DecodePacket(key []byte, buffer *buf.Buffer) error { + if buffer.Len() < m.keySaltLength { + return E.New("bad packet") + } + aead := m.constructor(Kdf(key, buffer.To(m.keySaltLength), m.keySaltLength)) + packet, err := aead.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:aead.NonceSize()], buffer.From(m.keySaltLength), nil) + if err != nil { + return err + } + buffer.Advance(m.keySaltLength) + buffer.Truncate(len(packet)) + return nil +} + +type session struct { + key []byte + replayFilter replay.Filter +} + +func (a *session) Key() []byte { + return a.key +} + +func (a *session) ReplayFilter() replay.Filter { + return a.replayFilter +} + +type aeadConn struct { + net.Conn + + method *Method + key []byte + destination *M.AddrPort + + access sync.Mutex + reader io.Reader + writer io.Writer + replayFilter replay.Filter +} + +func (c *aeadConn) clientHandshake() error { + header := buf.New() + defer header.Release() + + common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength)) + if c.replayFilter != nil { + c.replayFilter.Check(header.Bytes()) + } + + c.writer = NewAEADWriter( + &buf.BufferedWriter{ + Writer: c.Conn, + Buffer: header, + }, + c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)), + ) + + err := socks.AddressSerializer.WriteAddrPort(c.writer, c.destination) + if err != nil { + return err + } + + return common.FlushVar(&c.writer) +} + +func (c *aeadConn) serverHandshake() error { + if c.reader == nil { + salt := make([]byte, c.method.keySaltLength) + _, err := io.ReadFull(c.Conn, salt) + if err != nil { + return err + } + if c.replayFilter != nil { + if !c.replayFilter.Check(salt) { + return E.New("salt is not unique") + } + } + c.reader = NewReader(c.Conn, c.method.constructor(Kdf(c.key, salt, c.method.keySaltLength))) + } + return nil +} + +func (c *aeadConn) Read(p []byte) (n int, err error) { + if err = c.serverHandshake(); err != nil { + return + } + return c.reader.Read(p) +} + +func (c *aeadConn) WriteTo(w io.Writer) (n int64, err error) { + if err = c.serverHandshake(); err != nil { + return + } + return c.reader.(io.WriterTo).WriteTo(w) +} + +func (c *aeadConn) Write(p []byte) (n int, err error) { + if c.writer != nil { + goto direct + } + + c.access.Lock() + defer c.access.Unlock() + + if c.writer != nil { + goto direct + } + + // client handshake + + { + header := buf.New() + defer header.Release() + + request := buf.New() + defer request.Release() + + common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength)) + if c.replayFilter != nil { + c.replayFilter.Check(header.Bytes()) + } + + var writer io.Writer = c.Conn + writer = &buf.BufferedWriter{ + Writer: writer, + Buffer: header, + } + writer = NewAEADWriter(writer, c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength))) + writer = &buf.BufferedWriter{ + Writer: writer, + Buffer: request, + } + + err = socks.AddressSerializer.WriteAddrPort(writer, c.destination) + if err != nil { + return + } + + if len(p) > 0 { + _, err = writer.Write(p) + if err != nil { + return + } + } + + err = common.FlushVar(&writer) + if err != nil { + return + } + + c.writer = writer + return len(p), nil + } + +direct: + return c.writer.Write(p) +} + +func (c *aeadConn) ReadFrom(r io.Reader) (n int64, err error) { + if c.writer == nil { + panic("missing client handshake") + } + return c.writer.(io.ReaderFrom).ReadFrom(r) +} + +type aeadPacketConn struct { + net.Conn + key []byte + method *Method +} + +func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + defer buffer.Release() + header := buf.New() + common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength)) + err := socks.AddressSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + buffer = buffer.WriteBufferAtFirst(header) + err = c.method.EncodePacket(c.key, buffer) + if err != nil { + return err + } + return common.Error(c.Write(buffer.Bytes())) +} + +func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { + n, err := c.Read(buffer.FreeBytes()) + if err != nil { + return nil, err + } + buffer.Truncate(n) + err = c.method.DecodePacket(c.key, buffer) + if err != nil { + return nil, err + } + return socks.AddressSerializer.ReadAddrPort(buffer) +} diff --git a/protocol/socks/conn.go b/protocol/socks/conn.go index 3001272..9f78ec9 100644 --- a/protocol/socks/conn.go +++ b/protocol/socks/conn.go @@ -6,12 +6,12 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/socksaddr" + M "github.com/sagernet/sing/common/metadata" ) type PacketConn interface { - ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) - WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error + ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) + WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error Close() error LocalAddr() net.Addr @@ -21,23 +21,45 @@ type PacketConn interface { SetWriteDeadline(t time.Time) error } -func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(size int)) error { +type UDPConnectionHandler interface { + NewPacketConnection(conn PacketConn, metadata M.Metadata) error +} + +type PacketConnStub struct{} + +func (s *PacketConnStub) RemoteAddr() net.Addr { + return &common.DummyAddr{} +} + +func (s *PacketConnStub) SetDeadline(t time.Time) error { + return nil +} + +func (s *PacketConnStub) SetReadDeadline(t time.Time) error { + return nil +} + +func (s *PacketConnStub) SetWriteDeadline(t time.Time) error { + return nil +} + +func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error { for { buffer := buf.New() - addr, port, err := conn.ReadPacket(buffer) + destination, err := conn.ReadPacket(buffer) if err != nil { buffer.Release() return err } size := buffer.Len() - err = dest.WritePacket(buffer, addr, port) + err = dest.WritePacket(buffer, destination) if err != nil { + buffer.Release() return err } if onAction != nil { - onAction(size) + onAction(destination, size) } - buffer.Reset() } } @@ -58,22 +80,22 @@ func (c *associatePacketConn) RemoteAddr() net.Addr { return c.addr } -func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) { +func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes()) if err != nil { - return nil, 0, err + return nil, err } c.addr = addr buffer.Truncate(n) buffer.Advance(3) - return AddressSerializer.ReadAddressAndPort(buffer) + return AddressSerializer.ReadAddrPort(buffer) } -func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error { +func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error { defer buffer.Release() header := buf.New() common.Must(header.WriteZeroN(3)) - common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port)) + common.Must(AddressSerializer.WriteAddrPort(header, addrPort)) buffer = buffer.WriteBufferAtFirst(header) return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr)) } diff --git a/protocol/socks/constant.go b/protocol/socks/constant.go index 64e81fb..b9b75c5 100644 --- a/protocol/socks/constant.go +++ b/protocol/socks/constant.go @@ -3,7 +3,7 @@ package socks import ( "strconv" - "github.com/sagernet/sing/common/socksaddr" + M "github.com/sagernet/sing/common/metadata" ) const ( @@ -69,8 +69,8 @@ func (code ReplyCode) String() string { } } -var AddressSerializer = socksaddr.NewSerializer( - socksaddr.AddressFamilyByte(0x01, socksaddr.AddressFamilyIPv4), - socksaddr.AddressFamilyByte(0x04, socksaddr.AddressFamilyIPv6), - socksaddr.AddressFamilyByte(0x03, socksaddr.AddressFamilyFqdn), +var AddressSerializer = M.NewSerializer( + M.AddressFamilyByte(0x01, M.AddressFamilyIPv4), + M.AddressFamilyByte(0x04, M.AddressFamilyIPv6), + M.AddressFamilyByte(0x03, M.AddressFamilyFqdn), ) diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 47c0c22..a7a119c 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -4,11 +4,11 @@ import ( "io" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/socksaddr" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" ) -func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksaddr.Addr, port uint16, username string, password string) (*Response, error) { +func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination *M.AddrPort, username string, password string) (*Response, error) { var method byte if common.IsBlank(username) { method = AuthTypeNotRequired @@ -27,7 +27,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa return nil, err } if authResponse.Method != method { - return nil, exceptions.New("not requested method, request ", method, ", return ", method) + return nil, E.New("not requested method, request ", method, ", return ", method) } if method == AuthTypeUsernamePassword { err = WriteUsernamePasswordAuthRequest(conn, &UsernamePasswordAuthRequest{ @@ -46,10 +46,9 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa } } err = WriteRequest(conn, &Request{ - Version: version, - Command: command, - Addr: addr, - Port: port, + Version: version, + Command: command, + Destination: destination, }) if err != nil { return nil, err @@ -57,7 +56,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa return ReadResponse(conn) } -func ClientFastHandshake(writer io.Writer, version byte, command byte, addr socksaddr.Addr, port uint16, username string, password string) error { +func ClientFastHandshake(writer io.Writer, version byte, command byte, destination *M.AddrPort, username string, password string) error { var method byte if common.IsBlank(username) { method = AuthTypeNotRequired @@ -81,10 +80,9 @@ func ClientFastHandshake(writer io.Writer, version byte, command byte, addr sock } } return WriteRequest(writer, &Request{ - Version: version, - Command: command, - Addr: addr, - Port: port, + Version: version, + Command: command, + Destination: destination, }) } diff --git a/protocol/socks/listener.go b/protocol/socks/listener.go new file mode 100644 index 0000000..69134c4 --- /dev/null +++ b/protocol/socks/listener.go @@ -0,0 +1,148 @@ +package socks + +import ( + "io" + "net" + "net/netip" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/transport/tcp" +) + +type Handler interface { + tcp.Handler + UDPConnectionHandler +} + +type Listener struct { + tcpListener *tcp.Listener + authenticator auth.Authenticator + handler Handler +} + +func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler Handler) *Listener { + listener := &Listener{ + handler: handler, + authenticator: authenticator, + } + listener.tcpListener = tcp.NewTCPListener(bind, listener) + return listener +} + +func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error { + return HandleConnection(conn, l.authenticator, l.handler) +} + +func (l *Listener) Start() error { + return l.tcpListener.Start() +} + +func (l *Listener) Close() error { + return l.tcpListener.Close() +} + +func (l *Listener) HandleError(err error) { + l.handler.HandleError(err) +} + +func HandleConnection(conn net.Conn, authenticator auth.Authenticator, handler Handler) error { + authRequest, err := ReadAuthRequest(conn) + if err != nil { + return E.Cause(err, "read socks auth request") + } + var authMethod byte + if authenticator == nil { + authMethod = AuthTypeNotRequired + } else { + authMethod = AuthTypeUsernamePassword + } + if !common.Contains(authRequest.Methods, authMethod) { + err = WriteAuthResponse(conn, &AuthResponse{ + Version: authRequest.Version, + Method: AuthTypeNoAcceptedMethods, + }) + if err != nil { + return E.Cause(err, "write socks auth response") + } + } + err = WriteAuthResponse(conn, &AuthResponse{ + Version: authRequest.Version, + Method: AuthTypeNotRequired, + }) + if err != nil { + return E.Cause(err, "write socks auth response") + } + + if authMethod == AuthTypeUsernamePassword { + usernamePasswordAuthRequest, err := ReadUsernamePasswordAuthRequest(conn) + if err != nil { + return E.Cause(err, "read user auth request") + } + response := new(UsernamePasswordAuthResponse) + if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) { + response.Status = UsernamePasswordStatusSuccess + } else { + response.Status = UsernamePasswordStatusFailure + } + err = WriteUsernamePasswordAuthResponse(conn, response) + if err != nil { + return E.Cause(err, "write user auth response") + } + } + + request, err := ReadRequest(conn) + if err != nil { + return E.Cause(err, "read socks request") + } + switch request.Command { + case CommandConnect: + err = WriteResponse(conn, &Response{ + Version: request.Version, + ReplyCode: ReplyCodeSuccess, + Bind: M.AddrPortFromNetAddr(conn.LocalAddr()), + }) + if err != nil { + return E.Cause(err, "write socks response") + } + return handler.NewConnection(conn, M.Metadata{ + Destination: request.Destination, + }) + case CommandUDPAssociate: + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return err + } + defer udpConn.Close() + err = WriteResponse(conn, &Response{ + Version: request.Version, + ReplyCode: ReplyCodeSuccess, + Bind: M.AddrPortFromNetAddr(udpConn.LocalAddr()), + }) + if err != nil { + return E.Cause(err, "write socks response") + } + go func() { + err := handler.NewPacketConnection(NewPacketConn(conn, udpConn), M.Metadata{ + Source: M.AddrPortFromNetAddr(conn.RemoteAddr()), + Destination: request.Destination, + }) + if err != nil { + handler.HandleError(err) + } + conn.Close() + }() + return common.Error(io.Copy(io.Discard, conn)) + default: + err = WriteResponse(conn, &Response{ + Version: request.Version, + ReplyCode: ReplyCodeUnsupported, + }) + if err != nil { + return E.Cause(err, "write response") + } + } + return nil +} diff --git a/protocol/socks/protocol.go b/protocol/socks/protocol.go index 2270c98..acd163d 100644 --- a/protocol/socks/protocol.go +++ b/protocol/socks/protocol.go @@ -7,9 +7,9 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/exceptions" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/common/socksaddr" ) //+----+----------+----------+ @@ -45,11 +45,11 @@ func ReadAuthRequest(reader io.Reader) (*AuthRequest, error) { } methodLen, err := rw.ReadByte(reader) if err != nil { - return nil, exceptions.Cause(err, "read socks auth methods length") + return nil, E.Cause(err, "read socks auth methods length") } methods, err := rw.ReadBytes(reader, int(methodLen)) if err != nil { - return nil, exceptions.CauseF(err, "read socks auth methods, length ", methodLen) + return nil, E.CauseF(err, "read socks auth methods, length ", methodLen) } request := &AuthRequest{ version, @@ -112,11 +112,11 @@ func WriteUsernamePasswordAuthRequest(writer io.Writer, request *UsernamePasswor if err != nil { return err } - err = socksaddr.WriteString(writer, "username", request.Username) + err = M.WriteString(writer, "username", request.Username) if err != nil { return err } - return socksaddr.WriteString(writer, "password", request.Password) + return M.WriteString(writer, "password", request.Password) } func ReadUsernamePasswordAuthRequest(reader io.Reader) (*UsernamePasswordAuthRequest, error) { @@ -127,11 +127,11 @@ func ReadUsernamePasswordAuthRequest(reader io.Reader) (*UsernamePasswordAuthReq if version != UsernamePasswordVersion1 { return nil, &UnsupportedVersionException{version} } - username, err := socksaddr.ReadString(reader) + username, err := M.ReadString(reader) if err != nil { return nil, err } - password, err := socksaddr.ReadString(reader) + password, err := M.ReadString(reader) if err != nil { return nil, err } @@ -185,10 +185,9 @@ func ReadUsernamePasswordAuthResponse(reader io.Reader) (*UsernamePasswordAuthRe //+----+-----+-------+------+----------+----------+ type Request struct { - Version byte - Command byte - Addr socksaddr.Addr - Port uint16 + Version byte + Command byte + Destination *M.AddrPort } func WriteRequest(writer io.Writer, request *Request) error { @@ -204,7 +203,7 @@ func WriteRequest(writer io.Writer, request *Request) error { if err != nil { return err } - return AddressSerializer.WriteAddressAndPort(writer, request.Addr, request.Port) + return AddressSerializer.WriteAddrPort(writer, request.Destination) } func ReadRequest(reader io.Reader) (*Request, error) { @@ -226,15 +225,14 @@ func ReadRequest(reader io.Reader) (*Request, error) { if err != nil { return nil, err } - addr, port, err := AddressSerializer.ReadAddressAndPort(reader) + addrPort, err := AddressSerializer.ReadAddrPort(reader) if err != nil { return nil, err } request := &Request{ - Version: version, - Command: command, - Addr: addr, - Port: port, + Version: version, + Command: command, + Destination: addrPort, } return request, nil } @@ -248,8 +246,7 @@ func ReadRequest(reader io.Reader) (*Request, error) { type Response struct { Version byte ReplyCode ReplyCode - BindAddr socksaddr.Addr - BindPort uint16 + Bind *M.AddrPort } func WriteResponse(writer io.Writer, response *Response) error { @@ -265,10 +262,10 @@ func WriteResponse(writer io.Writer, response *Response) error { if err != nil { return err } - if response.BindAddr == nil { - return AddressSerializer.WriteAddressAndPort(writer, socksaddr.AddrFromIP(net.IPv4zero), response.BindPort) + if response.Bind == nil { + return AddressSerializer.WriteAddrPort(writer, M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0)) } - return AddressSerializer.WriteAddressAndPort(writer, response.BindAddr, response.BindPort) + return AddressSerializer.WriteAddrPort(writer, response.Bind) } func ReadResponse(reader io.Reader) (*Response, error) { @@ -287,15 +284,14 @@ func ReadResponse(reader io.Reader) (*Response, error) { if err != nil { return nil, err } - addr, port, err := AddressSerializer.ReadAddressAndPort(reader) + addrPort, err := AddressSerializer.ReadAddrPort(reader) if err != nil { return nil, err } response := &Response{ Version: version, ReplyCode: ReplyCode(replyCode), - BindAddr: addr, - BindPort: port, + Bind: addrPort, } return response, nil } @@ -307,15 +303,14 @@ func ReadResponse(reader io.Reader) (*Response, error) { //+----+------+------+----------+----------+----------+ type AssociatePacket struct { - Fragment byte - Addr socksaddr.Addr - Port uint16 - Data []byte + Fragment byte + Destination *M.AddrPort + Data []byte } func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) { if buffer.Len() < 5 { - return nil, exceptions.New("insufficient length") + return nil, E.New("insufficient length") } fragment := buffer.Byte(2) reader := bytes.NewReader(buffer.Bytes()) @@ -323,16 +318,15 @@ func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) { if err != nil { return nil, err } - addr, port, err := AddressSerializer.ReadAddressAndPort(reader) + addrPort, err := AddressSerializer.ReadAddrPort(reader) if err != nil { return nil, err } buffer.Advance(reader.Len()) packet := &AssociatePacket{ - Fragment: fragment, - Addr: addr, - Port: port, - Data: buffer.Bytes(), + Fragment: fragment, + Destination: addrPort, + Data: buffer.Bytes(), } return packet, nil } @@ -346,7 +340,7 @@ func EncodeAssociatePacket(packet *AssociatePacket, buffer *buf.Buffer) error { if err != nil { return err } - err = AddressSerializer.WriteAddressAndPort(buffer, packet.Addr, packet.Port) + err = AddressSerializer.WriteAddrPort(buffer, packet.Destination) if err != nil { return err } diff --git a/protocol/socks/protocol_test.go b/protocol/socks/protocol_test.go index fda38b3..0588cc2 100644 --- a/protocol/socks/protocol_test.go +++ b/protocol/socks/protocol_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/sagernet/sing/common/socksaddr" + M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/protocol/socks" ) @@ -20,7 +20,7 @@ func TestHandshake(t *testing.T) { method := socks.AuthTypeUsernamePassword go func() { - response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, socksaddr.AddrFromFqdn("test"), 80, "user", "pswd") + response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd") if err != nil { t.Fatal(err) } @@ -60,14 +60,13 @@ func TestHandshake(t *testing.T) { if err != nil { t.Fatal(err) } - if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Addr.Fqdn() != "test" || request.Port != 80 { + if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 { t.Fatal(request) } err = socks.WriteResponse(server, &socks.Response{ Version: socks.Version5, ReplyCode: socks.ReplyCodeSuccess, - BindAddr: socksaddr.AddrFromIP(net.IPv4zero), - BindPort: 0, + Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0), }) if err != nil { t.Fatal(err) diff --git a/transport/mixed/listener.go b/transport/mixed/listener.go new file mode 100644 index 0000000..3be4432 --- /dev/null +++ b/transport/mixed/listener.go @@ -0,0 +1,86 @@ +package mixed + +import ( + "net" + "net/netip" + + "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/redir" + "github.com/sagernet/sing/common/udpnat" + "github.com/sagernet/sing/protocol/http" + "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/transport/tcp" + "github.com/sagernet/sing/transport/udp" +) + +type Handler interface { + socks.Handler +} + +type Listener struct { + TCPListener *tcp.Listener + UDPListener *udp.Listener + handler Handler + authenticator auth.Authenticator + udpNat *udpnat.Server +} + +func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener { + listener := &Listener{ + handler: handler, + authenticator: authenticator, + } + + listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy)) + if transproxy == redir.ModeTProxy { + listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy)) + listener.udpNat = udpnat.NewServer(handler) + } + return listener +} + +func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error { + if metadata.Destination != nil { + return l.handler.NewConnection(conn, metadata) + } + bufConn := buf.NewBufferedConn(conn) + header, err := bufConn.Peek(1) + if err != nil { + return err + } + switch header[0] { + case socks.Version4, socks.Version5: + return socks.HandleConnection(bufConn, l.authenticator, l.handler) + default: + return http.HandleConnection(bufConn, l.authenticator, l.handler) + } +} + +func (l *Listener) NewPacket(packet *buf.Buffer, metadata M.Metadata) error { + return l.udpNat.HandleUDP(packet, metadata) +} + +func (l *Listener) HandleError(err error) { + l.handler.HandleError(err) +} + +func (l *Listener) Start() error { + err := l.TCPListener.Start() + if err != nil { + return err + } + if l.UDPListener != nil { + err = l.UDPListener.Start() + } + return err +} + +func (l *Listener) Close() error { + l.TCPListener.Close() + if l.UDPListener != nil { + l.UDPListener.Close() + } + return nil +} diff --git a/transport/system/http.go b/transport/system/http.go deleted file mode 100644 index 5bfa2e7..0000000 --- a/transport/system/http.go +++ /dev/null @@ -1,10 +0,0 @@ -package system - -import ( - "bufio" - "net/http" - _ "unsafe" -) - -//go:linkname readRequest net/http.readRequest -func readRequest(b *bufio.Reader) (req *http.Request, err error) diff --git a/transport/system/sockopt_linux.go b/transport/system/sockopt_linux.go index 2f7335a..c7326f8 100644 --- a/transport/system/sockopt_linux.go +++ b/transport/system/sockopt_linux.go @@ -2,6 +2,8 @@ package system import ( "syscall" + + "golang.org/x/sys/unix" ) const ( @@ -12,3 +14,22 @@ const ( func TCPFastOpen(fd uintptr) error { return syscall.SetsockoptInt(int(fd), syscall.SOL_TCP, TCP_FASTOPEN_CONNECT, 1) } + +func TProxy(fd uintptr, isIPv6 bool) error { + err := syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1) + if err != nil { + return err + } + if isIPv6 { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1) + } + return err +} + +func TProxyUDP(fd uintptr, isIPv6 bool) error { + err := syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1) + if err != nil { + return err + } + return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1) +} diff --git a/transport/system/sockopt_other.go b/transport/system/sockopt_other.go index 2eca8e0..b4b8891 100644 --- a/transport/system/sockopt_other.go +++ b/transport/system/sockopt_other.go @@ -7,3 +7,7 @@ import "github.com/sagernet/sing/common/exceptions" func TCPFastOpen(fd uintptr) error { return exceptions.New("only available on linux") } + +func TProxy(fd uintptr, isIPv6 bool) error { + return exceptions.New("only available on linux") +} diff --git a/transport/system/socks.go b/transport/system/socks.go deleted file mode 100644 index 57538b4..0000000 --- a/transport/system/socks.go +++ /dev/null @@ -1,149 +0,0 @@ -package system - -import ( - "io" - "net" - "net/netip" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/socksaddr" - "github.com/sagernet/sing/protocol/socks" -) - -type SocksHandler interface { - NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error - NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error - OnError(err error) -} - -type SocksConfig struct { - Username string - Password string -} - -type SocksListener struct { - Handler SocksHandler - *TCPListener - *SocksConfig -} - -func NewSocksListener(bind netip.AddrPort, config *SocksConfig, handler SocksHandler) *SocksListener { - listener := &SocksListener{ - SocksConfig: config, - Handler: handler, - } - listener.TCPListener = NewTCPListener(bind, listener) - return listener -} - -func (l *SocksListener) HandleTCP(conn net.Conn) error { - authRequest, err := socks.ReadAuthRequest(conn) - if err != nil { - return exceptions.Cause(err, "read socks auth request") - } - var authMethod byte - if l.Username == "" { - authMethod = socks.AuthTypeNotRequired - } else { - authMethod = socks.AuthTypeUsernamePassword - } - if !common.Contains(authRequest.Methods, authMethod) { - err = socks.WriteAuthResponse(conn, &socks.AuthResponse{ - Version: authRequest.Version, - Method: socks.AuthTypeNoAcceptedMethods, - }) - if err != nil { - return exceptions.Cause(err, "write socks auth response") - } - } - err = socks.WriteAuthResponse(conn, &socks.AuthResponse{ - Version: authRequest.Version, - Method: socks.AuthTypeNotRequired, - }) - if err != nil { - return exceptions.Cause(err, "write socks auth response") - } - - if authMethod == socks.AuthTypeUsernamePassword { - usernamePasswordAuthRequest, err := socks.ReadUsernamePasswordAuthRequest(conn) - if err != nil { - return exceptions.Cause(err, "read user auth request") - } - response := socks.UsernamePasswordAuthResponse{} - if usernamePasswordAuthRequest.Username != l.Username { - response.Status = socks.UsernamePasswordStatusFailure - } else if usernamePasswordAuthRequest.Password != l.Password { - response.Status = socks.UsernamePasswordStatusFailure - } else { - response.Status = socks.UsernamePasswordStatusSuccess - } - err = socks.WriteUsernamePasswordAuthResponse(conn, &response) - if err != nil { - return exceptions.Cause(err, "write user auth response") - } - } - - request, err := socks.ReadRequest(conn) - if err != nil { - return exceptions.Cause(err, "read socks request") - } - switch request.Command { - case socks.CommandConnect: - localAddr, localPort := socksaddr.AddrFromNetAddr(l.TCPListener.TCPListener.Addr()) - err = socks.WriteResponse(conn, &socks.Response{ - Version: request.Version, - ReplyCode: socks.ReplyCodeSuccess, - BindAddr: localAddr, - BindPort: localPort, - }) - if err != nil { - return exceptions.Cause(err, "write socks response") - } - return l.Handler.NewConnection(request.Addr, request.Port, conn) - case socks.CommandUDPAssociate: - udpConn, err := net.ListenUDP("udp4", nil) - if err != nil { - return err - } - defer udpConn.Close() - localAddr, localPort := socksaddr.AddrFromNetAddr(udpConn.LocalAddr()) - err = socks.WriteResponse(conn, &socks.Response{ - Version: request.Version, - ReplyCode: socks.ReplyCodeSuccess, - BindAddr: localAddr, - BindPort: localPort, - }) - if err != nil { - return exceptions.Cause(err, "write socks response") - } - go func() { - err := l.Handler.NewPacketConnection(socks.NewPacketConn(conn, udpConn), request.Addr, request.Port) - if err != nil { - l.OnError(err) - } - }() - return common.Error(io.Copy(io.Discard, conn)) - default: - err = socks.WriteResponse(conn, &socks.Response{ - Version: request.Version, - ReplyCode: socks.ReplyCodeUnsupported, - }) - if err != nil { - return exceptions.Cause(err, "write response") - } - } - return nil -} - -func (l *SocksListener) Start() error { - return l.TCPListener.Start() -} - -func (l *SocksListener) Close() error { - return l.TCPListener.Close() -} - -func (l *SocksListener) OnError(err error) { - l.Handler.OnError(exceptions.Cause(err, "socks server")) -} diff --git a/transport/system/tcp.go b/transport/system/tcp.go deleted file mode 100644 index fcd335e..0000000 --- a/transport/system/tcp.go +++ /dev/null @@ -1,57 +0,0 @@ -package system - -import ( - "net" - "net/netip" -) - -type TCPHandler interface { - HandleTCP(conn net.Conn) error - OnError(err error) -} - -type TCPListener struct { - Listen netip.AddrPort - Handler TCPHandler - *net.TCPListener -} - -func NewTCPListener(listen netip.AddrPort, handler TCPHandler) *TCPListener { - return &TCPListener{ - Listen: listen, - Handler: handler, - } -} - -func (l *TCPListener) Start() error { - tcpListener, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(l.Listen)) - if err != nil { - return err - } - l.TCPListener = tcpListener - go l.loop() - return nil -} - -func (l *TCPListener) Close() error { - if l == nil || l.TCPListener == nil { - return nil - } - return l.TCPListener.Close() -} - -func (l *TCPListener) loop() { - for { - tcpConn, err := l.Accept() - if err != nil { - l.Close() - return - } - go func() { - err := l.Handler.HandleTCP(tcpConn) - if err != nil { - l.Handler.OnError(err) - } - }() - } -} diff --git a/transport/system/udp.go b/transport/system/udp.go deleted file mode 100644 index 8cb6ab6..0000000 --- a/transport/system/udp.go +++ /dev/null @@ -1,62 +0,0 @@ -package system - -import ( - "net" - "net/netip" - - "github.com/sagernet/sing/common/buf" -) - -type UDPHandler interface { - HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error - OnError(err error) -} - -type UDPListener struct { - Listen netip.AddrPort - Handler UDPHandler - *net.UDPConn -} - -func NewUDPListener(listen netip.AddrPort, handler UDPHandler) *UDPListener { - return &UDPListener{ - Listen: listen, - Handler: handler, - } -} - -func (l *UDPListener) Start() error { - udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(l.Listen)) - if err != nil { - return err - } - l.UDPConn = udpConn - go l.loop() - return nil -} - -func (l *UDPListener) Close() error { - if l == nil || l.UDPConn == nil { - return nil - } - return l.UDPConn.Close() -} - -func (l *UDPListener) loop() { - for { - buffer := buf.New() - n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize)) - if err != nil { - buffer.Release() - return - } - buffer.Truncate(n) - go func() { - err := l.Handler.HandleUDP(buffer, addr) - if err != nil { - buffer.Release() - l.Handler.OnError(err) - } - }() - } -} diff --git a/transport/tcp/handler.go b/transport/tcp/handler.go new file mode 100644 index 0000000..0e17ec1 --- /dev/null +++ b/transport/tcp/handler.go @@ -0,0 +1,91 @@ +package tcp + +import ( + "net" + "net/netip" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/redir" +) + +type Handler interface { + M.TCPConnectionHandler + E.Handler +} + +type Listener struct { + bind netip.AddrPort + handler Handler + trans redir.TransproxyMode + lAddr *net.TCPAddr + *net.TCPListener +} + +func NewTCPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener { + listener := &Listener{ + bind: listen, + handler: handler, + } + for _, option := range options { + option(listener) + } + return listener +} + +func (l *Listener) Start() error { + tcpListener, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(l.bind)) + if err != nil { + return err + } + if l.trans == redir.ModeTProxy { + l.lAddr = tcpListener.Addr().(*net.TCPAddr) + fd, err := common.GetFileDescriptor(tcpListener) + if err != nil { + return err + } + err = redir.TProxy(fd, l.bind.Addr().Is6()) + if err != nil { + return E.Cause(err, "configure tproxy") + } + } + l.TCPListener = tcpListener + go l.loop() + return nil +} + +func (l *Listener) Close() error { + if l == nil || l.TCPListener == nil { + return nil + } + return l.TCPListener.Close() +} + +func (l *Listener) loop() { + for { + tcpConn, err := l.Accept() + if err != nil { + l.Close() + return + } + var metadata M.Metadata + switch l.trans { + case redir.ModeRedirect: + metadata.Destination, _ = redir.GetOriginalDestination(tcpConn) + case redir.ModeTProxy: + lAddr := tcpConn.LocalAddr().(*net.TCPAddr) + rAddr := tcpConn.RemoteAddr().(*net.TCPAddr) + + if lAddr.Port != l.lAddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() { + metadata.Destination = M.AddrPortFromNetAddr(lAddr) + } + } + go func() { + err := l.handler.NewConnection(tcpConn, metadata) + if err != nil { + l.handler.HandleError(err) + } + }() + } +} diff --git a/transport/tcp/options.go b/transport/tcp/options.go new file mode 100644 index 0000000..7d12a18 --- /dev/null +++ b/transport/tcp/options.go @@ -0,0 +1,11 @@ +package tcp + +import "github.com/sagernet/sing/common/redir" + +type Option func(*Listener) + +func WithTransproxyMode(mode redir.TransproxyMode) Option { + return func(listener *Listener) { + listener.trans = mode + } +} diff --git a/transport/udp/options.go b/transport/udp/options.go new file mode 100644 index 0000000..2dab220 --- /dev/null +++ b/transport/udp/options.go @@ -0,0 +1,11 @@ +package udp + +import "github.com/sagernet/sing/common/redir" + +type Option func(*Listener) + +func WithTransproxyMode(mode redir.TransproxyMode) Option { + return func(listener *Listener) { + listener.tproxy = mode == redir.ModeTProxy + } +} diff --git a/transport/udp/udp.go b/transport/udp/udp.go new file mode 100644 index 0000000..2f0b0e0 --- /dev/null +++ b/transport/udp/udp.go @@ -0,0 +1,116 @@ +package udp + +import ( + "net" + "net/netip" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/redir" +) + +type Handler interface { + M.UDPHandler + E.Handler +} + +type Listener struct { + *net.UDPConn + handler Handler + bind *net.UDPAddr + tproxy bool +} + +func NewUDPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener { + listener := &Listener{ + handler: handler, + bind: net.UDPAddrFromAddrPort(listen), + } + for _, option := range options { + option(listener) + } + return listener +} + +func (l *Listener) Start() error { + udpConn, err := net.ListenUDP("udp", l.bind) + if err != nil { + return err + } + + if l.tproxy { + fd, err := common.GetFileDescriptor(udpConn) + if err != nil { + return err + } + err = redir.TProxy(fd, l.bind.AddrPort().Addr().Is6()) + if err != nil { + return E.Cause(err, "configure tproxy") + } + err = redir.TProxyUDP(fd, l.bind.AddrPort().Addr().Is6()) + if err != nil { + return E.Cause(err, "configure tproxy") + } + } + + l.UDPConn = udpConn + go l.loop() + return nil +} + +func (l *Listener) Close() error { + if l == nil || l.UDPConn == nil { + return nil + } + return l.UDPConn.Close() +} + +func (l *Listener) loop() { + if !l.tproxy { + for { + buffer := buf.New() + n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize)) + if err != nil { + buffer.Release() + l.handler.HandleError(err) + return + } + buffer.Truncate(n) + err = l.handler.NewPacket(buffer, M.Metadata{ + Source: M.AddrPortFromNetAddr(addr), + }) + if err != nil { + buffer.Release() + l.handler.HandleError(err) + } + } + } else { + oob := make([]byte, 1024) + for { + buffer := buf.New() + n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob) + if err != nil { + buffer.Release() + l.handler.HandleError(err) + return + } + destination, err := redir.GetOriginalDestinationFromOOB(oob[:oobN]) + if err != nil { + l.handler.HandleError(E.Cause(err, "get original destination")) + return + } + buffer.Truncate(n) + err = l.handler.NewPacket(buffer, M.Metadata{ + Source: M.AddrPortFromAddrPort(addr), + Destination: destination, + }) + if err != nil { + buffer.Release() + l.handler.HandleError(err) + } + } + + } +}