Refactor shadowsocks

This commit is contained in:
世界 2022-04-10 22:51:29 +08:00
parent 3f23b25edf
commit 00cd0d4b8f
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
75 changed files with 3169 additions and 1318 deletions

3
.gitignore vendored
View file

@ -2,4 +2,5 @@
/sing_*
/*.json
/Country.mmdb
/geosite.dat
/geosite.dat
/vendor/

View file

@ -1,3 +1,12 @@
# sing
Do you hear the people sing?
Do you hear the people sing?
```shell
# geo resources
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/get-geoip
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/gen-geosite
# ss-local
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-local
```

View file

@ -2,12 +2,12 @@ package main
import (
"encoding/binary"
"github.com/klauspost/compress/zstd"
"github.com/sagernet/sing/common/rw"
"io"
"net/http"
"os"
"github.com/klauspost/compress/zstd"
"github.com/sagernet/sing/common/rw"
"github.com/sirupsen/logrus"
"github.com/ulikunitz/xz"
"github.com/v2fly/v2ray-core/v5/app/router/routercommon"

View file

@ -1,10 +1,11 @@
package main
import (
"github.com/sirupsen/logrus"
"io"
"net/http"
"os"
"github.com/sirupsen/logrus"
)
func main() {

374
cli/ss-local/main.go Normal file
View file

@ -0,0 +1,374 @@
package main
import (
"context"
"encoding/base64"
"encoding/json"
"io"
"io/ioutil"
"net"
"net/netip"
"os"
"os/signal"
"runtime/debug"
"syscall"
"time"
"github.com/sagernet/sing"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/geoip"
"github.com/sagernet/sing/common/geosite"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/random"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/mixed"
"github.com/sagernet/sing/transport/system"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
type flags struct {
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
LocalPort uint16 `json:"local_port"`
Password string `json:"password"`
Key string `json:"key"`
Method string `json:"method"`
TCPFastOpen bool `json:"fast_open"`
Verbose bool `json:"verbose"`
Transproxy string `json:"transproxy"`
FWMark int `json:"fwmark"`
Bypass string `json:"bypass"`
UseSystemRNG bool `json:"use_system_rng"`
ReducedSaltEntropy bool `json:"reduced_salt_entropy"`
ConfigFile string
}
func main() {
f := new(flags)
command := &cobra.Command{
Use: "ss-local",
Short: "shadowsocks client",
Version: sing.VersionStr,
Run: func(cmd *cobra.Command, args []string) {
Run(cmd, f)
},
}
command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.")
command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.")
command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
command.Flags().StringVar(&f.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", `Set the cipher.
Supported ciphers:
none
aes-128-gcm
aes-192-gcm
aes-256-gcm
chacha20-ietf-poly1305
xchacha20-ietf-poly1305`)
command.Flags().BoolVar(&f.TCPFastOpen, "fast-open", false, `Enable TCP fast open.
Only available with Linux kernel > 3.7.0.`)
command.Flags().StringVarP(&f.Transproxy, "transproxy", "t", "", "Enable transparent proxy support. [possible values: redirect, tproxy]")
command.Flags().IntVar(&f.FWMark, "fwmark", 0, "Set outbound socket mark.")
command.Flags().StringVar(&f.Bypass, "bypass", "", "Set bypass country.")
command.Flags().StringVarP(&f.ConfigFile, "config", "c", "", "Use a configuration file.")
command.Flags().BoolVarP(&f.Verbose, "verbose", "v", false, "Enable verbose mode.")
command.Flags().BoolVar(&f.UseSystemRNG, "use-system-rng", false, "Use system random number generator.")
command.Flags().BoolVar(&f.ReducedSaltEntropy, "reduced-salt-entropy", false, "Remapping salt to printable chars.")
err := command.Execute()
if err != nil {
logrus.Fatal(err)
}
}
type LocalClient struct {
*mixed.Listener
*geosite.Matcher
server *M.AddrPort
method shadowsocks.Method
session shadowsocks.Session
dialer net.Dialer
bypass string
}
func NewLocalClient(f *flags) (*LocalClient, error) {
if f.ConfigFile != "" {
configFile, err := ioutil.ReadFile(f.ConfigFile)
if err != nil {
return nil, E.Cause(err, "read config file")
}
flagsNew := new(flags)
err = json.Unmarshal(configFile, flagsNew)
if err != nil {
return nil, E.Cause(err, "decode config file")
}
if flagsNew.Server != "" && f.Server == "" {
f.Server = flagsNew.Server
}
if flagsNew.ServerPort != 0 && f.ServerPort == 0 {
f.ServerPort = flagsNew.ServerPort
}
if flagsNew.LocalPort != 0 && f.LocalPort == 0 {
f.LocalPort = flagsNew.LocalPort
}
if flagsNew.Password != "" && f.Password == "" {
f.Password = flagsNew.Password
}
if flagsNew.Key != "" && f.Key == "" {
f.Key = flagsNew.Key
}
if flagsNew.Method != "" && f.Method == "" {
f.Method = flagsNew.Method
}
if flagsNew.Transproxy != "" && f.Transproxy == "" {
f.Transproxy = flagsNew.Transproxy
}
if flagsNew.TCPFastOpen {
f.TCPFastOpen = true
}
if flagsNew.Verbose {
f.Verbose = true
}
}
if f.Verbose {
logrus.SetLevel(logrus.TraceLevel)
}
if f.Server == "" {
return nil, E.New("missing server address")
} else if f.ServerPort == 0 {
return nil, E.New("missing server port")
} else if f.Method == "" {
return nil, E.New("missing method")
}
client := &LocalClient{
server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort),
bypass: f.Bypass,
}
if f.Method == shadowsocks.MethodNone {
client.method = shadowsocks.NewNone()
} else if common.Contains(shadowaead.List, f.Method) {
var key []byte
var rng io.Reader
if f.UseSystemRNG {
rng = random.System
} else {
rng = random.Blake3KeyedHash()
}
if f.ReducedSaltEntropy {
rng = &shadowsocks.ReducedEntropyReader{Reader: rng}
}
client.method = shadowaead.New(f.Method, rng)
keyLength := client.method.KeyLength()
if f.Key != "" {
decoded, err := base64.URLEncoding.DecodeString(f.Key)
if err != nil {
return nil, E.Cause(err, "decode key")
}
if len(decoded) != keyLength {
return nil, E.Cause(err, "bad key")
}
key = decoded
} else if f.Password != "" {
key = shadowsocks.Key([]byte(f.Password), keyLength)
} else {
return nil, E.New("missing password")
}
client.session = shadowaead.NewSession(key, false)
}
client.dialer.Control = func(network, address string, c syscall.RawConn) error {
var rawFd uintptr
err := c.Control(func(fd uintptr) {
rawFd = fd
})
if err != nil {
return err
}
if f.FWMark > 0 {
err = syscall.SetsockoptInt(int(rawFd), syscall.SOL_SOCKET, syscall.SO_MARK, f.FWMark)
if err != nil {
return err
}
}
if f.TCPFastOpen {
err = system.TCPFastOpen(rawFd)
if err != nil {
return err
}
}
return nil
}
var transproxyMode redir.TransproxyMode
switch f.Transproxy {
case "redirect":
transproxyMode = redir.ModeRedirect
case "tproxy":
transproxyMode = redir.ModeTProxy
case "":
transproxyMode = redir.ModeDisabled
default:
return nil, E.New("unknown transproxy mode ", f.Transproxy)
}
client.Listener = mixed.NewListener(netip.AddrPortFrom(netip.IPv6Unspecified(), f.LocalPort), nil, transproxyMode, client)
if f.Bypass != "" {
err := geoip.LoadMMDB("Country.mmdb")
if err != nil {
return nil, E.Cause(err, "load Country.mmdb")
}
geodata, err := os.Open("geosite.dat")
if err != nil {
return nil, E.Cause(err, "geosite.dat not found")
}
geositeMatcher, err := geosite.LoadGeositeMatcher(geodata, f.Bypass)
if err != nil {
return nil, err
}
client.Matcher = geositeMatcher
debug.FreeOSMemory()
}
return client, nil
}
func bypass(conn net.Conn, destination *M.AddrPort) error {
logrus.Info("BYPASS ", conn.RemoteAddr(), " ==> ", destination)
serverConn, err := net.Dial("tcp", destination.String())
if err != nil {
return err
}
return task.Run(context.Background(), func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(serverConn)
return common.Error(io.Copy(serverConn, conn))
}, func() error {
defer rw.CloseRead(serverConn)
defer rw.CloseWrite(conn)
return common.Error(io.Copy(conn, serverConn))
})
}
func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
if c.bypass != "" {
if metadata.Destination.Addr.Family().IsFqdn() {
if c.Match(metadata.Destination.Addr.Fqdn()) {
return bypass(conn, metadata.Destination)
}
} else {
if geoip.Match(c.bypass, metadata.Destination.Addr.Addr().AsSlice()) {
return bypass(conn, metadata.Destination)
}
}
}
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
ctx := context.Background()
var serverConn net.Conn
payload := buf.New()
err := task.Run(ctx, func() error {
sc, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
serverConn = sc
if err != nil {
return E.Cause(err, "connect to server")
}
return nil
}, func() error {
err := conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
if err != nil {
return err
}
_, err = payload.ReadFrom(conn)
if err != nil && !E.IsTimeout(err) {
return E.Cause(err, "read payload")
}
err = conn.SetReadDeadline(time.Time{})
return err
})
if err != nil {
payload.Release()
return err
}
serverConn = c.method.DialEarlyConn(c.session, serverConn, metadata.Destination)
_, err = serverConn.Write(payload.Bytes())
payload.Release()
if err != nil {
return E.Cause(err, "client handshake")
}
return rw.CopyConn(ctx, serverConn, conn)
}
func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
ctx := context.Background()
udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String())
if err != nil {
return err
}
serverConn := c.method.DialPacketConn(c.session, udpConn)
return task.Run(ctx, func() error {
var init bool
return socks.CopyPacketConn(serverConn, conn, func(destination *M.AddrPort, n int) {
if !init {
init = true
logrus.Info("UDP ", conn.LocalAddr(), " ==> ", destination)
} else {
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
}
})
}, func() error {
return socks.CopyPacketConn(conn, serverConn, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
})
})
}
func Run(cmd *cobra.Command, flags *flags) {
client, err := NewLocalClient(flags)
if err != nil {
logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n")
cmd.Help()
os.Exit(1)
}
err = client.Listener.Start()
if err != nil {
logrus.Fatal(err)
}
logrus.Info("mixed server started at ", client.Listener.TCPListener.Addr())
osSignals := make(chan os.Signal, 1)
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
<-osSignals
client.Listener.Close()
}
func (c *LocalClient) HandleError(err error) {
if E.IsClosed(err) {
return
}
logrus.Warn(err)
}

View file

@ -1,12 +0,0 @@
//go:build debug
package main
import (
"net/http"
_ "net/http/pprof"
)
func init() {
go http.ListenAndServe("127.0.0.1:8964", nil)
}

View file

@ -1,324 +0,0 @@
package main
import (
"context"
"encoding/base64"
"encoding/json"
"github.com/sagernet/sing/common/geoip"
"github.com/sagernet/sing/common/geosite"
"io"
"io/ioutil"
"net"
"net/netip"
"os"
"os/signal"
"runtime/debug"
"strconv"
"syscall"
"github.com/sagernet/sing"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/socksaddr"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/system"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
func main() {
err := MainCmd().Execute()
if err != nil {
logrus.Fatal(err)
}
}
type Flags struct {
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
LocalPort uint16 `json:"local_port"`
Password string `json:"password"`
Key string `json:"key"`
Method string `json:"method"`
TCPFastOpen bool `json:"fast_open"`
Verbose bool `json:"verbose"`
Redirect string `json:"redir"`
FWMark int `json:"fwmark"`
Bypass string `json:"bypass"`
ConfigFile string
}
func MainCmd() *cobra.Command {
flags := new(Flags)
cmd := &cobra.Command{
Use: "sslocal",
Short: "shadowsocks client as socks5 proxy, sing port",
Version: sing.Version,
Run: func(cmd *cobra.Command, args []string) {
Run(cmd, flags)
},
}
cmd.Flags().StringVarP(&flags.Server, "server", "s", "", "Set the server’s hostname or IP.")
cmd.Flags().Uint16VarP(&flags.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
cmd.Flags().Uint16VarP(&flags.LocalPort, "local-port", "l", 0, "Set the local port number.")
cmd.Flags().StringVarP(&flags.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
cmd.Flags().StringVar(&flags.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
cmd.Flags().StringVarP(&flags.Method, "encrypt-method", "m", "", `Set the cipher.
Supported ciphers:
none
aes-128-gcm
aes-192-gcm
aes-256-gcm
chacha20-ietf-poly1305
xchacha20-ietf-poly1305
The default cipher is chacha20-ietf-poly1305.`)
cmd.Flags().BoolVar(&flags.TCPFastOpen, "fast-open", false, `Enable TCP fast open.
Only available with Linux kernel > 3.7.0.`)
cmd.Flags().StringVar(&flags.Redirect, "redir", "", "Enable transparent proxy support. [possible values: redirect, tproxy]")
cmd.Flags().IntVar(&flags.FWMark, "fwmark", 0, "Set outbound socket mark.")
cmd.Flags().StringVar(&flags.Bypass, "bypass", "", "Set bypass country.")
cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.")
cmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode.")
return cmd
}
type LocalClient struct {
*system.MixedListener
*shadowsocks.Client
*geosite.Matcher
redirect bool
bypass string
}
func NewLocalClient(flags *Flags) (*LocalClient, error) {
if flags.ConfigFile != "" {
configFile, err := ioutil.ReadFile(flags.ConfigFile)
if err != nil {
return nil, exceptions.Cause(err, "read config file")
}
flagsNew := new(Flags)
err = json.Unmarshal(configFile, flagsNew)
if err != nil {
return nil, exceptions.Cause(err, "decode config file")
}
if flagsNew.Server != "" && flags.Server == "" {
flags.Server = flagsNew.Server
}
if flagsNew.ServerPort != 0 && flags.ServerPort == 0 {
flags.ServerPort = flagsNew.ServerPort
}
if flagsNew.LocalPort != 0 && flags.LocalPort == 0 {
flags.LocalPort = flagsNew.LocalPort
}
if flagsNew.Password != "" && flags.Password == "" {
flags.Password = flagsNew.Password
}
if flagsNew.Key != "" && flags.Key == "" {
flags.Key = flagsNew.Key
}
if flagsNew.Method != "" && flags.Method == "" {
flags.Method = flagsNew.Method
}
if flagsNew.Redirect != "" && flags.Redirect == "" {
flags.Redirect = flagsNew.Redirect
}
if flagsNew.TCPFastOpen {
flags.TCPFastOpen = true
}
if flagsNew.Verbose {
flags.Verbose = true
}
}
clientConfig := &shadowsocks.ClientConfig{
Server: flags.Server,
ServerPort: flags.ServerPort,
Method: flags.Method,
}
if flags.Key != "" {
key, err := base64.URLEncoding.DecodeString(flags.Key)
if err != nil {
return nil, exceptions.Cause(err, "decode key")
}
clientConfig.Key = key
} else if flags.Password != "" {
clientConfig.Password = []byte(flags.Password)
}
if flags.Verbose {
logrus.SetLevel(logrus.TraceLevel)
}
dialer := new(net.Dialer)
dialer.Control = func(network, address string, c syscall.RawConn) error {
var rawFd uintptr
err := c.Control(func(fd uintptr) {
rawFd = fd
})
if err != nil {
return err
}
if flags.FWMark > 0 {
err = syscall.SetsockoptInt(int(rawFd), syscall.SOL_SOCKET, syscall.SO_MARK, flags.FWMark)
if err != nil {
return err
}
}
if flags.TCPFastOpen {
err = system.TCPFastOpen(rawFd)
if err != nil {
return err
}
}
return nil
}
shadowClient, err := shadowsocks.NewClient(dialer, clientConfig)
if err != nil {
return nil, err
}
client := &LocalClient{
Client: shadowClient,
}
client.MixedListener = system.NewMixedListener(netip.AddrPortFrom(netip.IPv6Unspecified(), flags.LocalPort), &system.MixedConfig{
Redirect: flags.Redirect == "redirect",
TProxy: flags.Redirect == "tproxy",
}, client)
if flags.Bypass != "" {
client.bypass = flags.Bypass
err = geoip.LoadMMDB("Country.mmdb")
if err != nil {
return nil, exceptions.Cause(err, "load Country.mmdb")
}
geodata, err := os.Open("geosite.dat")
if err != nil {
return nil, exceptions.Cause(err, "geosite.dat not found")
}
geositeMatcher, err := geosite.LoadGeositeMatcher(geodata, flags.Bypass)
if err != nil {
return nil, err
}
client.Matcher = geositeMatcher
debug.FreeOSMemory()
}
return client, nil
}
func (c *LocalClient) Start() error {
err := c.MixedListener.Start()
if err != nil {
return err
}
logrus.Info("mixed server started at ", c.MixedListener.TCPListener.Addr())
return nil
}
func bypass(addr socksaddr.Addr, port uint16, conn net.Conn) error {
logrus.Info("BYPASS ", conn.RemoteAddr(), " ==> ", net.JoinHostPort(addr.String(), strconv.Itoa(int(port))))
serverConn, err := net.Dial("tcp", socksaddr.JoinHostPort(addr, port))
if err != nil {
return err
}
return task.Run(context.Background(), func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(serverConn)
return common.Error(io.Copy(serverConn, conn))
}, func() error {
defer rw.CloseRead(serverConn)
defer rw.CloseWrite(conn)
return common.Error(io.Copy(conn, serverConn))
})
}
func (c *LocalClient) NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error {
if c.bypass != "" {
if addr.Family().IsFqdn() {
if c.Match(addr.Fqdn()) {
return bypass(addr, port, conn)
}
} else {
if geoip.Match(c.bypass, addr.Addr().AsSlice()) {
return bypass(addr, port, conn)
}
}
}
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", net.JoinHostPort(addr.String(), strconv.Itoa(int(port))))
ctx := context.Background()
serverConn, err := c.DialContextTCP(ctx, addr, port)
if err != nil {
return err
}
return task.Run(ctx, func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(serverConn)
return common.Error(io.Copy(serverConn, conn))
}, func() error {
defer rw.CloseRead(serverConn)
defer rw.CloseWrite(conn)
return common.Error(io.Copy(conn, serverConn))
})
}
func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error {
ctx := context.Background()
serverConn := c.DialContextUDP(ctx)
return task.Run(ctx, func() error {
var init bool
return socks.CopyPacketConn(serverConn, conn, func(size int) {
if !init {
init = true
logrus.Info("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port))
} else {
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port))
}
})
}, func() error {
return socks.CopyPacketConn(conn, serverConn, func(size int) {
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", socksaddr.JoinHostPort(addr, port))
})
})
}
func Run(cmd *cobra.Command, flags *Flags) {
client, err := NewLocalClient(flags)
if err != nil {
logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n")
cmd.Help()
os.Exit(1)
}
err = client.Start()
if err != nil {
logrus.Fatal(err)
}
osSignals := make(chan os.Signal, 1)
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
<-osSignals
client.Close()
}
func (c *LocalClient) OnError(err error) {
if exceptions.IsClosed(err) {
return
}
logrus.Warn(err)
}

140
cli/uot-local/main.go Normal file
View file

@ -0,0 +1,140 @@
package main
import (
"context"
"net"
"net/netip"
"os"
"os/signal"
"syscall"
"github.com/sagernet/sing"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/mixed"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
var (
verbose bool
transproxy string
)
func main() {
command := cobra.Command{
Use: "uot-local <bind> <upstream>",
Short: "SUoT client.",
Long: "SUoT client. \n\nconverts a normal socks server to a SUoT mixed server.",
Example: "uot-local 0.0.0.0:2080 127.0.0.1:1080",
Version: sing.VersionStr,
Args: cobra.ExactArgs(2),
Run: run,
}
command.Flags().BoolVarP(&verbose, "verbose", "v", false, "Enable verbose mode")
command.Flags().StringVarP(&transproxy, "transproxy", "t", "", "Enable transparent proxy support [possible values: redirect, tproxy]")
err := command.Execute()
if err != nil {
logrus.Fatal(err)
}
}
func run(cmd *cobra.Command, args []string) {
if verbose {
logrus.SetLevel(logrus.TraceLevel)
}
bind, err := netip.ParseAddrPort(args[0])
if err != nil {
logrus.Fatal("bad bind address: ", err)
}
_, err = netip.ParseAddrPort(args[1])
if err != nil {
logrus.Fatal("bad upstream address: ", err)
}
var transproxyMode redir.TransproxyMode
switch transproxy {
case "redirect":
transproxyMode = redir.ModeRedirect
case "tproxy":
transproxyMode = redir.ModeTProxy
case "":
transproxyMode = redir.ModeDisabled
default:
logrus.Fatal("unknown transproxy mode ", transproxy)
}
client := &localClient{upstream: args[1]}
client.Listener = mixed.NewListener(bind, nil, transproxyMode, client)
err = client.Start()
if err != nil {
logrus.Fatal("start mixed server: ", err)
}
logrus.Info("mixed server started at ", client.TCPListener.Addr())
osSignals := make(chan os.Signal, 1)
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
<-osSignals
client.Close()
}
type localClient struct {
*mixed.Listener
upstream string
}
func (c *localClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
upstream, err := net.Dial("tcp", c.upstream)
if err != nil {
return E.Cause(err, "connect to upstream")
}
_, err = socks.ClientHandshake(upstream, socks.Version5, socks.CommandConnect, metadata.Destination, "", "")
if err != nil {
return E.Cause(err, "upstream handshake failed")
}
return rw.CopyConn(context.Background(), upstream, conn)
}
func (c *localClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
upstream, err := net.Dial("tcp", c.upstream)
if err != nil {
return E.Cause(err, "connect to upstream")
}
_, err = socks.ClientHandshake(upstream, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn(uot.UOTMagicAddress), 443), "", "")
if err != nil {
return E.Cause(err, "upstream handshake failed")
}
client := uot.NewClientConn(upstream)
return task.Run(context.Background(), func() error {
return socks.CopyPacketConn(client, conn, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
})
}, func() error {
return socks.CopyPacketConn(conn, client, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
})
})
}
func (c *localClient) OnError(err error) {
if E.IsClosed(err) {
return
}
logrus.Warn(err)
}

46
common/auth/auth.go Normal file
View file

@ -0,0 +1,46 @@
package auth
import (
"sync"
)
type Authenticator interface {
Verify(user string, pass string) bool
Users() []string
}
type AuthUser struct {
User string
Pass string
}
type inMemoryAuthenticator struct {
storage *sync.Map
usernames []string
}
func (au *inMemoryAuthenticator) Verify(user string, pass string) bool {
realPass, ok := au.storage.Load(user)
return ok && realPass == pass
}
func (au *inMemoryAuthenticator) Users() []string { return au.usernames }
func NewAuthenticator(users []AuthUser) Authenticator {
if len(users) == 0 {
return nil
}
au := &inMemoryAuthenticator{storage: &sync.Map{}}
for _, user := range users {
au.storage.Store(user.User, user.Pass)
}
usernames := make([]string, 0, len(users))
au.storage.Range(func(key, value interface{}) bool {
usernames = append(usernames, key.(string))
return true
})
au.usernames = usernames
return au
}

View file

@ -4,6 +4,7 @@ import (
"crypto/rand"
"fmt"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/list"
@ -188,6 +189,18 @@ func (b *Buffer) ReadFrom(r io.Reader) (int64, error) {
return int64(n), nil
}
func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
if b.IsFull() {
return 0, nil, io.ErrShortBuffer
}
n, addr, err := r.ReadFrom(b.FreeBytes())
if err != nil {
return 0, nil, err
}
b.end += n
return int64(n), addr, nil
}
func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
if min <= 0 {
return b.ReadFrom(r)

View file

@ -70,3 +70,12 @@ func (w *BufferedWriter) Flush() error {
}
return common.Error(w.Writer.Write(buffer.Bytes()))
}
func (w *BufferedWriter) Close() error {
buffer := w.Buffer
if buffer != nil {
w.Buffer = nil
buffer.Release()
}
return nil
}

106
common/cache/cache.go vendored Normal file
View file

@ -0,0 +1,106 @@
package cache
import (
"runtime"
"sync"
"time"
)
// Cache store element with a expired time
type Cache struct {
*cache
}
type cache struct {
mapping sync.Map
janitor *janitor
}
type element struct {
Expired time.Time
Payload interface{}
}
// Put element in Cache with its ttl
func (c *cache) Put(key interface{}, payload interface{}, ttl time.Duration) {
c.mapping.Store(key, &element{
Payload: payload,
Expired: time.Now().Add(ttl),
})
}
// Get element in Cache, and drop when it expired
func (c *cache) Get(key interface{}) interface{} {
item, exist := c.mapping.Load(key)
if !exist {
return nil
}
elm := item.(*element)
// expired
if time.Since(elm.Expired) > 0 {
c.mapping.Delete(key)
return nil
}
return elm.Payload
}
// GetWithExpire element in Cache with Expire Time
func (c *cache) GetWithExpire(key interface{}) (payload interface{}, expired time.Time) {
item, exist := c.mapping.Load(key)
if !exist {
return
}
elm := item.(*element)
// expired
if time.Since(elm.Expired) > 0 {
c.mapping.Delete(key)
return
}
return elm.Payload, elm.Expired
}
func (c *cache) cleanup() {
c.mapping.Range(func(k, v interface{}) bool {
key := k.(string)
elm := v.(*element)
if time.Since(elm.Expired) > 0 {
c.mapping.Delete(key)
}
return true
})
}
type janitor struct {
interval time.Duration
stop chan struct{}
}
func (j *janitor) process(c *cache) {
ticker := time.NewTicker(j.interval)
for {
select {
case <-ticker.C:
c.cleanup()
case <-j.stop:
ticker.Stop()
return
}
}
}
func stopJanitor(c *Cache) {
c.janitor.stop <- struct{}{}
}
// New return *Cache
func New(interval time.Duration) *Cache {
j := &janitor{
interval: interval,
stop: make(chan struct{}),
}
c := &cache{janitor: j}
go j.process(c)
C := &Cache{c}
runtime.SetFinalizer(C, stopJanitor)
return C
}

70
common/cache/cache_test.go vendored Normal file
View file

@ -0,0 +1,70 @@
package cache
import (
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCache_Basic(t *testing.T) {
interval := 200 * time.Millisecond
ttl := 20 * time.Millisecond
c := New(interval)
c.Put("int", 1, ttl)
c.Put("string", "a", ttl)
i := c.Get("int")
assert.Equal(t, i.(int), 1, "should recv 1")
s := c.Get("string")
assert.Equal(t, s.(string), "a", "should recv 'a'")
}
func TestCache_TTL(t *testing.T) {
interval := 200 * time.Millisecond
ttl := 20 * time.Millisecond
now := time.Now()
c := New(interval)
c.Put("int", 1, ttl)
c.Put("int2", 2, ttl)
i := c.Get("int")
_, expired := c.GetWithExpire("int2")
assert.Equal(t, i.(int), 1, "should recv 1")
assert.True(t, now.Before(expired))
time.Sleep(ttl * 2)
i = c.Get("int")
j, _ := c.GetWithExpire("int2")
assert.Nil(t, i, "should recv nil")
assert.Nil(t, j, "should recv nil")
}
func TestCache_AutoCleanup(t *testing.T) {
interval := 10 * time.Millisecond
ttl := 15 * time.Millisecond
c := New(interval)
c.Put("int", 1, ttl)
time.Sleep(ttl * 2)
i := c.Get("int")
j, _ := c.GetWithExpire("int")
assert.Nil(t, i, "should recv nil")
assert.Nil(t, j, "should recv nil")
}
func TestCache_AutoGC(t *testing.T) {
sign := make(chan struct{})
go func() {
interval := 10 * time.Millisecond
ttl := 15 * time.Millisecond
c := New(interval)
c.Put("int", 1, ttl)
sign <- struct{}{}
}()
<-sign
runtime.GC()
}

223
common/cache/lrucache.go vendored Normal file
View file

@ -0,0 +1,223 @@
package cache
// Modified by https://github.com/die-net/lrucache
import (
"container/list"
"sync"
"time"
)
// Option is part of Functional Options Pattern
type Option func(*LruCache)
// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback = func(key interface{}, value interface{})
// WithEvict set the evict callback
func WithEvict(cb EvictCallback) Option {
return func(l *LruCache) {
l.onEvict = cb
}
}
// WithUpdateAgeOnGet update expires when Get element
func WithUpdateAgeOnGet() Option {
return func(l *LruCache) {
l.updateAgeOnGet = true
}
}
// WithAge defined element max age (second)
func WithAge(maxAge int64) Option {
return func(l *LruCache) {
l.maxAge = maxAge
}
}
// WithSize defined max length of LruCache
func WithSize(maxSize int) Option {
return func(l *LruCache) {
l.maxSize = maxSize
}
}
// WithStale decide whether Stale return is enabled.
// If this feature is enabled, element will not get Evicted according to `WithAge`.
func WithStale(stale bool) Option {
return func(l *LruCache) {
l.staleReturn = stale
}
}
// LruCache is a thread-safe, in-memory lru-cache that evicts the
// least recently used entries from memory when (if set) the entries are
// older than maxAge (in seconds). Use the New constructor to create one.
type LruCache struct {
maxAge int64
maxSize int
mu sync.Mutex
cache map[interface{}]*list.Element
lru *list.List // Front is least-recent
updateAgeOnGet bool
staleReturn bool
onEvict EvictCallback
}
// NewLRUCache creates an LruCache
func NewLRUCache(options ...Option) *LruCache {
lc := &LruCache{
lru: list.New(),
cache: make(map[interface{}]*list.Element),
}
for _, option := range options {
option(lc)
}
return lc
}
// Get returns the interface{} representation of a cached response and a bool
// set to true if the key was found.
func (c *LruCache) Get(key interface{}) (interface{}, bool) {
entry := c.get(key)
if entry == nil {
return nil, false
}
value := entry.value
return value, true
}
// GetWithExpire returns the interface{} representation of a cached response,
// a time.Time Give expected expires,
// and a bool set to true if the key was found.
// This method will NOT check the maxAge of element and will NOT update the expires.
func (c *LruCache) GetWithExpire(key interface{}) (interface{}, time.Time, bool) {
entry := c.get(key)
if entry == nil {
return nil, time.Time{}, false
}
return entry.value, time.Unix(entry.expires, 0), true
}
// Exist returns if key exist in cache but not put item to the head of linked list
func (c *LruCache) Exist(key interface{}) bool {
c.mu.Lock()
defer c.mu.Unlock()
_, ok := c.cache[key]
return ok
}
// Set stores the interface{} representation of a response for a given key.
func (c *LruCache) Set(key interface{}, value interface{}) {
expires := int64(0)
if c.maxAge > 0 {
expires = time.Now().Unix() + c.maxAge
}
c.SetWithExpire(key, value, time.Unix(expires, 0))
}
// SetWithExpire stores the interface{} representation of a response for a given key and given expires.
// The expires time will round to second.
func (c *LruCache) SetWithExpire(key interface{}, value interface{}, expires time.Time) {
c.mu.Lock()
defer c.mu.Unlock()
if le, ok := c.cache[key]; ok {
c.lru.MoveToBack(le)
e := le.Value.(*entry)
e.value = value
e.expires = expires.Unix()
} else {
e := &entry{key: key, value: value, expires: expires.Unix()}
c.cache[key] = c.lru.PushBack(e)
if c.maxSize > 0 {
if len := c.lru.Len(); len > c.maxSize {
c.deleteElement(c.lru.Front())
}
}
}
c.maybeDeleteOldest()
}
// CloneTo clone and overwrite elements to another LruCache
func (c *LruCache) CloneTo(n *LruCache) {
c.mu.Lock()
defer c.mu.Unlock()
n.mu.Lock()
defer n.mu.Unlock()
n.lru = list.New()
n.cache = make(map[interface{}]*list.Element)
for e := c.lru.Front(); e != nil; e = e.Next() {
elm := e.Value.(*entry)
n.cache[elm.key] = n.lru.PushBack(elm)
}
}
func (c *LruCache) get(key interface{}) *entry {
c.mu.Lock()
defer c.mu.Unlock()
le, ok := c.cache[key]
if !ok {
return nil
}
if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() {
c.deleteElement(le)
c.maybeDeleteOldest()
return nil
}
c.lru.MoveToBack(le)
entry := le.Value.(*entry)
if c.maxAge > 0 && c.updateAgeOnGet {
entry.expires = time.Now().Unix() + c.maxAge
}
return entry
}
// Delete removes the value associated with a key.
func (c *LruCache) Delete(key interface{}) {
c.mu.Lock()
if le, ok := c.cache[key]; ok {
c.deleteElement(le)
}
c.mu.Unlock()
}
func (c *LruCache) maybeDeleteOldest() {
if !c.staleReturn && c.maxAge > 0 {
now := time.Now().Unix()
for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() {
c.deleteElement(le)
}
}
}
func (c *LruCache) deleteElement(le *list.Element) {
c.lru.Remove(le)
e := le.Value.(*entry)
delete(c.cache, e.key)
if c.onEvict != nil {
c.onEvict(e.key, e.value)
}
}
type entry struct {
key interface{}
value interface{}
expires int64
}

183
common/cache/lrucache_test.go vendored Normal file
View file

@ -0,0 +1,183 @@
package cache
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var entries = []struct {
key string
value string
}{
{"1", "one"},
{"2", "two"},
{"3", "three"},
{"4", "four"},
{"5", "five"},
}
func TestLRUCache(t *testing.T) {
c := NewLRUCache()
for _, e := range entries {
c.Set(e.key, e.value)
}
c.Delete("missing")
_, ok := c.Get("missing")
assert.False(t, ok)
for _, e := range entries {
value, ok := c.Get(e.key)
if assert.True(t, ok) {
assert.Equal(t, e.value, value.(string))
}
}
for _, e := range entries {
c.Delete(e.key)
_, ok := c.Get(e.key)
assert.False(t, ok)
}
}
func TestLRUMaxAge(t *testing.T) {
c := NewLRUCache(WithAge(86400))
now := time.Now().Unix()
expected := now + 86400
// Add one expired entry
c.Set("foo", "bar")
c.lru.Back().Value.(*entry).expires = now
// Reset
c.Set("foo", "bar")
e := c.lru.Back().Value.(*entry)
assert.True(t, e.expires >= now)
c.lru.Back().Value.(*entry).expires = now
// Set a few and verify expiration times
for _, s := range entries {
c.Set(s.key, s.value)
e := c.lru.Back().Value.(*entry)
assert.True(t, e.expires >= expected && e.expires <= expected+10)
}
// Make sure we can get them all
for _, s := range entries {
_, ok := c.Get(s.key)
assert.True(t, ok)
}
// Expire all entries
for _, s := range entries {
le, ok := c.cache[s.key]
if assert.True(t, ok) {
le.Value.(*entry).expires = now
}
}
// Get one expired entry, which should clear all expired entries
_, ok := c.Get("3")
assert.False(t, ok)
assert.Equal(t, c.lru.Len(), 0)
}
func TestLRUpdateOnGet(t *testing.T) {
c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet())
now := time.Now().Unix()
expires := now + 86400/2
// Add one expired entry
c.Set("foo", "bar")
c.lru.Back().Value.(*entry).expires = expires
_, ok := c.Get("foo")
assert.True(t, ok)
assert.True(t, c.lru.Back().Value.(*entry).expires > expires)
}
func TestMaxSize(t *testing.T) {
c := NewLRUCache(WithSize(2))
// Add one expired entry
c.Set("foo", "bar")
_, ok := c.Get("foo")
assert.True(t, ok)
c.Set("bar", "foo")
c.Set("baz", "foo")
_, ok = c.Get("foo")
assert.False(t, ok)
}
func TestExist(t *testing.T) {
c := NewLRUCache(WithSize(1))
c.Set(1, 2)
assert.True(t, c.Exist(1))
c.Set(2, 3)
assert.False(t, c.Exist(1))
}
func TestEvict(t *testing.T) {
temp := 0
evict := func(key interface{}, value interface{}) {
temp = key.(int) + value.(int)
}
c := NewLRUCache(WithEvict(evict), WithSize(1))
c.Set(1, 2)
c.Set(2, 3)
assert.Equal(t, temp, 3)
}
func TestSetWithExpire(t *testing.T) {
c := NewLRUCache(WithAge(1))
now := time.Now().Unix()
tenSecBefore := time.Unix(now-10, 0)
c.SetWithExpire(1, 2, tenSecBefore)
// res is expected not to exist, and expires should be empty time.Time
res, expires, exist := c.GetWithExpire(1)
assert.Equal(t, nil, res)
assert.Equal(t, time.Time{}, expires)
assert.Equal(t, false, exist)
}
func TestStale(t *testing.T) {
c := NewLRUCache(WithAge(1), WithStale(true))
now := time.Now().Unix()
tenSecBefore := time.Unix(now-10, 0)
c.SetWithExpire(1, 2, tenSecBefore)
res, expires, exist := c.GetWithExpire(1)
assert.Equal(t, 2, res)
assert.Equal(t, tenSecBefore, expires)
assert.Equal(t, true, exist)
}
func TestCloneTo(t *testing.T) {
o := NewLRUCache(WithSize(10))
o.Set("1", 1)
o.Set("2", 2)
n := NewLRUCache(WithSize(2))
n.Set("3", 3)
n.Set("4", 4)
o.CloneTo(n)
assert.False(t, n.Exist("3"))
assert.True(t, n.Exist("1"))
n.Set("5", 5)
assert.False(t, n.Exist("1"))
}

View file

@ -43,6 +43,16 @@ func Filter[T any](arr []T, block func(it T) bool) []T {
return retArr
}
func FilterIsInstance[T any, N any](arr []T, block func(it T) (N, bool)) []N {
var retArr []N
for _, it := range arr {
if n, isN := block(it); isN {
retArr = append(retArr, n)
}
}
return retArr
}
func Done(ctx context.Context) bool {
select {
case <-ctx.Done():

View file

@ -62,3 +62,7 @@ func IsTimeout(err error) bool {
}
return false
}
type Handler interface {
HandleError(err error)
}

View file

@ -50,7 +50,7 @@ func FlushVar(writerP *io.Writer) error {
if writerBack == writer {
writer = u.Upstream()
writerBack = writer
writerP = &writer
*writerP = writer
continue
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter {
setter.SetWriter(writerBack)

View file

@ -1,10 +1,11 @@
package geoip
import (
"github.com/oschwald/geoip2-golang"
"net"
"strings"
"sync"
"github.com/oschwald/geoip2-golang"
)
var (

View file

@ -3,12 +3,13 @@ package geosite
import (
"bufio"
"encoding/binary"
"github.com/klauspost/compress/zstd"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/trieset"
"io"
"strings"
"github.com/klauspost/compress/zstd"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/trieset"
)
type Matcher struct {
@ -25,7 +26,7 @@ func LoadGeositeMatcher(reader io.Reader, code string) (*Matcher, error) {
return nil, err
}
if version != 0 {
return nil, exceptions.New("bad geosite data")
return nil, E.New("bad geosite data")
}
decoder, err := zstd.NewReader(reader, zstd.WithDecoderLowmem(true), zstd.WithDecoderConcurrency(1))
if err != nil {
@ -72,5 +73,5 @@ func LoadGeositeMatcher(reader io.Reader, code string) (*Matcher, error) {
}
}
}
return nil, exceptions.New(code, " not found in geosite")
return nil, E.New(code, " not found in geosite")
}

View file

@ -201,7 +201,7 @@ func (e *entry[T]) storeLocked(i *T) {
// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
func (m *Map[K, V]) LoadOrStore(key K, value func() V) (actual V, loaded bool) {
// Avoid locking if it's a clean hit.
read, _ := m.read.Load().(readOnly[K, V])
if e, ok := read.m[key]; ok {
@ -228,8 +228,9 @@ func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
m.dirtyLocked()
m.read.Store(readOnly[K, V]{m: read.m, amended: true})
}
m.dirty[key] = newEntry(value)
actual, loaded = value, false
v := value()
m.dirty[key] = newEntry(v)
actual, loaded = v, false
}
m.mu.Unlock()
@ -241,7 +242,7 @@ func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
//
// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and
// returns with ok==false.
func (e *entry[T]) tryLoadOrStore(i T) (actual T, loaded, ok bool) {
func (e *entry[T]) tryLoadOrStore(i func() T) (actual T, loaded, ok bool) {
p := atomic.LoadPointer(&e.p)
if p == expunged {
var defaultValue T
@ -254,10 +255,10 @@ func (e *entry[T]) tryLoadOrStore(i T) (actual T, loaded, ok bool) {
// Copy the interface after the first load to make this method more amenable
// to escape analysis: if we hit the "load" path or the entry is expunged, we
// shouldn't bother heap-allocating.
ic := i
ic := i()
for {
if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) {
return i, false, true
return ic, false, true
}
p = atomic.LoadPointer(&e.p)
if p == expunged {

View file

@ -1,4 +1,4 @@
package socksaddr
package metadata
import (
"net"
@ -13,6 +13,35 @@ type Addr interface {
String() string
}
type AddrPort struct {
Addr Addr
Port uint16
}
func (ap AddrPort) IPAddr() *net.IPAddr {
return &net.IPAddr{
IP: ap.Addr.Addr().AsSlice(),
}
}
func (ap AddrPort) TCPAddr() *net.TCPAddr {
return &net.TCPAddr{
IP: ap.Addr.Addr().AsSlice(),
Port: int(ap.Port),
}
}
func (ap AddrPort) UDPAddr() *net.UDPAddr {
return &net.UDPAddr{
IP: ap.Addr.Addr().AsSlice(),
Port: int(ap.Port),
}
}
func (ap AddrPort) String() string {
return net.JoinHostPort(ap.Addr.String(), strconv.Itoa(int(ap.Port)))
}
func ParseAddr(address string) Addr {
addr, err := netip.ParseAddr(address)
if err == nil {
@ -21,16 +50,44 @@ func ParseAddr(address string) Addr {
return AddrFromFqdn(address)
}
func ParseAddrPort(address string) (Addr, uint16, error) {
func AddrPortFrom(addr Addr, port uint16) *AddrPort {
return &AddrPort{addr, port}
}
func ParseAddress(address string) (*AddrPort, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, 0, err
return nil, err
}
portInt, err := strconv.Atoi(port)
if err != nil {
return nil, 0, err
return nil, err
}
return ParseAddr(host), uint16(portInt), nil
return AddrPortFrom(ParseAddr(host), uint16(portInt)), nil
}
func ParseAddrPort(address string, port string) (*AddrPort, error) {
portInt, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
return AddrPortFrom(ParseAddr(address), uint16(portInt)), nil
}
func AddrPortFromNetAddr(netAddr net.Addr) *AddrPort {
var ip net.IP
var port uint16
switch addr := netAddr.(type) {
case *net.TCPAddr:
ip = addr.IP
port = uint16(addr.Port)
case *net.UDPAddr:
ip = addr.IP
port = uint16(addr.Port)
case *net.IPAddr:
ip = addr.IP
}
return AddrPortFrom(AddrFromIP(ip), port)
}
func AddrFromIP(ip net.IP) Addr {
@ -50,27 +107,14 @@ func AddrFromAddr(addr netip.Addr) Addr {
}
}
func AddrFromNetAddr(netAddr net.Addr) (addr Addr, port uint16) {
var ip net.IP
switch addr := netAddr.(type) {
case *net.TCPAddr:
ip = addr.IP
port = uint16(addr.Port)
case *net.UDPAddr:
ip = addr.IP
port = uint16(addr.Port)
}
return AddrFromIP(ip), port
func AddrPortFromAddrPort(addrPort netip.AddrPort) *AddrPort {
return AddrPortFrom(AddrFromAddr(addrPort.Addr()), addrPort.Port())
}
func AddrFromFqdn(fqdn string) Addr {
return AddrFqdn(fqdn)
}
func JoinHostPort(addr Addr, port uint16) string {
return net.JoinHostPort(addr.String(), strconv.Itoa(int(port)))
}
type Addr4 [4]byte
func (a Addr4) Family() Family {

View file

@ -1,4 +1,4 @@
package socksaddr
package metadata
import "fmt"

View file

@ -1,4 +1,4 @@
package socksaddr
package metadata
type Family byte

View file

@ -0,0 +1,20 @@
package metadata
import (
"net"
"github.com/sagernet/sing/common/buf"
)
type Metadata struct {
Source *AddrPort
Destination *AddrPort
}
type TCPConnectionHandler interface {
NewConnection(conn net.Conn, metadata Metadata) error
}
type UDPHandler interface {
NewPacket(packet *buf.Buffer, metadata Metadata) error
}

View file

@ -1,11 +1,11 @@
package socksaddr
package metadata
import (
"encoding/binary"
"io"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/exceptions"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
)
@ -24,16 +24,9 @@ func PortThenAddress() SerializerOption {
}
}
func WithFamilyParser(fp FamilyParser) SerializerOption {
return func(s *Serializer) {
s.familyParser = fp
}
}
type Serializer struct {
familyMap map[byte]Family
familyByteMap map[Family]byte
familyParser FamilyParser
portFirst bool
}
@ -66,20 +59,20 @@ func (s *Serializer) WritePort(writer io.Writer, port uint16) error {
return binary.Write(writer, binary.BigEndian, port)
}
func (s *Serializer) WriteAddressAndPort(writer io.Writer, addr Addr, port uint16) error {
func (s *Serializer) WriteAddrPort(writer io.Writer, addrPort *AddrPort) error {
var err error
if !s.portFirst {
err = s.WriteAddress(writer, addr)
err = s.WriteAddress(writer, addrPort.Addr)
} else {
err = s.WritePort(writer, port)
err = s.WritePort(writer, addrPort.Port)
}
if err != nil {
return err
}
if s.portFirst {
err = s.WriteAddress(writer, addr)
err = s.WriteAddress(writer, addrPort.Addr)
} else {
err = s.WritePort(writer, port)
err = s.WritePort(writer, addrPort.Port)
}
return err
}
@ -89,15 +82,12 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) {
if err != nil {
return nil, err
}
if s.familyParser != nil {
af = s.familyParser(af)
}
family := s.familyMap[af]
switch family {
case AddressFamilyFqdn:
fqdn, err := ReadString(reader)
if err != nil {
return nil, exceptions.Cause(err, "read fqdn")
return nil, E.Cause(err, "read fqdn")
}
return AddrFqdn(fqdn), nil
default:
@ -106,18 +96,18 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) {
var addr [4]byte
err = common.Error(reader.Read(addr[:]))
if err != nil {
return nil, exceptions.Cause(err, "read ipv4 address")
return nil, E.Cause(err, "read ipv4 address")
}
return Addr4(addr), nil
case AddressFamilyIPv6:
var addr [16]byte
err = common.Error(reader.Read(addr[:]))
if err != nil {
return nil, exceptions.Cause(err, "read ipv6 address")
return nil, E.Cause(err, "read ipv6 address")
}
return Addr16(addr), nil
default:
return nil, exceptions.New("unknown address family: ", af)
return nil, E.New("unknown address family: ", af)
}
}
}
@ -125,12 +115,14 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) {
func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
port, err := rw.ReadBytes(reader, 2)
if err != nil {
return 0, exceptions.Cause(err, "read port")
return 0, E.Cause(err, "read port")
}
return binary.BigEndian.Uint16(port), nil
}
func (s *Serializer) ReadAddressAndPort(reader io.Reader) (addr Addr, port uint16, err error) {
func (s *Serializer) ReadAddrPort(reader io.Reader) (addrPort *AddrPort, err error) {
var addr Addr
var port uint16
if !s.portFirst {
addr, err = s.ReadAddress(reader)
} else {
@ -144,7 +136,10 @@ func (s *Serializer) ReadAddressAndPort(reader io.Reader) (addr Addr, port uint1
} else {
port, err = s.ReadPort(reader)
}
return
if err != nil {
return
}
return AddrPortFrom(addr, port), nil
}
func ReadString(reader io.Reader) (string, error) {

15
common/net.go Normal file
View file

@ -0,0 +1,15 @@
package common
import "syscall"
func GetFileDescriptor(conn syscall.Conn) (uintptr, error) {
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, err
}
var rawFd uintptr
err = rawConn.Control(func(fd uintptr) {
rawFd = fd
})
return rawFd, err
}

18
common/random/rng.go Normal file
View file

@ -0,0 +1,18 @@
package random
import (
"crypto/rand"
"io"
"github.com/sagernet/sing/common"
"lukechampine.com/blake3"
)
var System = rand.Reader
func Blake3KeyedHash() io.Reader {
key := make([]byte, 32)
common.Must1(io.ReadFull(System, key))
h := blake3.New(1024, key)
return h.XOF()
}

9
common/redir/mode.go Normal file
View file

@ -0,0 +1,9 @@
package redir
type TransproxyMode uint8
const (
ModeDisabled TransproxyMode = iota
ModeRedirect
ModeTProxy
)

View file

@ -2,11 +2,12 @@ package redir
import (
"net"
"net/netip"
"syscall"
M "github.com/sagernet/sing/common/metadata"
)
func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err error) {
func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) {
rawConn, err := conn.(syscall.Conn).SyscallConn()
if err != nil {
return
@ -22,15 +23,14 @@ func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err erro
if conn.RemoteAddr().(*net.TCPAddr).IP.To4() != nil {
raw, err := syscall.GetsockoptIPv6Mreq(int(rawFd), syscall.IPPROTO_IP, SO_ORIGINAL_DST)
if err != nil {
return netip.AddrPort{}, err
return nil, err
}
addr, _ := netip.AddrFromSlice(raw.Multiaddr[4:8])
return netip.AddrPortFrom(addr, uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
return M.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
} else {
raw, err := syscall.GetsockoptIPv6MTUInfo(int(rawFd), syscall.IPPROTO_IPV6, SO_ORIGINAL_DST)
if err != nil {
return nil, err
}
return M.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), raw.Addr.Port), nil
}
raw, err := syscall.GetsockoptIPv6MTUInfo(int(rawFd), syscall.IPPROTO_IPV6, SO_ORIGINAL_DST)
if err != nil {
return
}
return netip.AddrPortFrom(netip.AddrFrom16(raw.Addr.Addr), raw.Addr.Port), nil
}

View file

@ -5,10 +5,10 @@ package redir
import (
"errors"
"net"
"net/netip"
M "github.com/sagernet/sing/common/metadata"
)
func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err error) {
err = errors.New("unsupported platform")
return
func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) {
return nil, errors.New("unsupported platform")
}

View file

@ -0,0 +1,131 @@
package redir
import (
"encoding/binary"
"fmt"
"net"
"os"
"strconv"
"syscall"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
func TProxy(fd uintptr, isIPv6 bool) error {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
if err != nil {
return err
}
if isIPv6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
}
return err
}
func TProxyUDP(fd uintptr, isIPv6 bool) error {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
if err != nil {
return err
}
return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
}
func GetOriginalDestinationFromOOB(oob []byte) (*M.AddrPort, error) {
controlMessages, err := unix.ParseSocketControlMessage(oob)
if err != nil {
return nil, err
}
for _, message := range controlMessages {
if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR {
return M.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
} else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR {
return M.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
}
}
return nil, E.New("not found")
}
func DialUDP(network string, lAddr *net.UDPAddr, rAddr *net.UDPAddr) (*net.UDPConn, error) {
rSockAddr, err := udpAddrToSockAddr(rAddr)
if err != nil {
return nil, err
}
lSockAddr, err := udpAddrToSockAddr(lAddr)
if err != nil {
return nil, err
}
fd, err := syscall.Socket(udpAddrFamily(network, lAddr, rAddr), syscall.SOCK_DGRAM, 0)
if err != nil {
return nil, err
}
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
syscall.Close(fd)
return nil, err
}
if err = syscall.SetsockoptInt(fd, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil {
syscall.Close(fd)
return nil, err
}
if err = syscall.Bind(fd, lSockAddr); err != nil {
syscall.Close(fd)
return nil, err
}
if err = syscall.Connect(fd, rSockAddr); err != nil {
syscall.Close(fd)
return nil, err
}
fdFile := os.NewFile(uintptr(fd), fmt.Sprintf("net-udp-dial-%s", rAddr.String()))
defer fdFile.Close()
c, err := net.FileConn(fdFile)
if err != nil {
syscall.Close(fd)
return nil, err
}
return c.(*net.UDPConn), nil
}
func udpAddrToSockAddr(addr *net.UDPAddr) (syscall.Sockaddr, error) {
switch {
case addr.IP.To4() != nil:
ip := [4]byte{}
copy(ip[:], addr.IP.To4())
return &syscall.SockaddrInet4{Addr: ip, Port: addr.Port}, nil
default:
ip := [16]byte{}
copy(ip[:], addr.IP.To16())
zoneID, err := strconv.ParseUint(addr.Zone, 10, 32)
if err != nil {
zoneID = 0
}
return &syscall.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil
}
}
func udpAddrFamily(net string, lAddr, rAddr *net.UDPAddr) int {
switch net[len(net)-1] {
case '4':
return syscall.AF_INET
case '6':
return syscall.AF_INET6
}
if (lAddr == nil || lAddr.IP.To4() != nil) && (rAddr == nil || lAddr.IP.To4() != nil) {
return syscall.AF_INET
}
return syscall.AF_INET6
}

View file

@ -0,0 +1,20 @@
//go:build !linux
package redir
import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
func TProxy(fd uintptr, isIPv6 bool) error {
return E.New("only available on linux")
}
func TProxyUDP(fd uintptr, isIPv6 bool) error {
return E.New("only available on linux")
}
func GetOriginalDestinationFromOOB(oob []byte) (*M.AddrPort, error) {
return nil, E.New("only available on linux")
}

View file

@ -0,0 +1,30 @@
package replay
import (
"github.com/v2fly/ss-bloomring"
"sync"
)
func NewBloomRing() Filter {
const (
DefaultSFCapacity = 1e6
DefaultSFFPR = 1e-6
DefaultSFSlot = 10
)
return &bloomRingFilter{BloomRing: ss_bloomring.NewBloomRing(DefaultSFSlot, DefaultSFCapacity, DefaultSFFPR)}
}
type bloomRingFilter struct {
sync.Mutex
*ss_bloomring.BloomRing
}
func (f *bloomRingFilter) Check(sum []byte) bool {
f.Lock()
defer f.Unlock()
if f.Test(sum) {
return false
}
f.Add(sum)
return true
}

50
common/replay/cuckoo.go Normal file
View file

@ -0,0 +1,50 @@
package replay
import (
"sync"
"time"
"github.com/seiflotfy/cuckoofilter"
)
func NewCuckoo(interval int64) Filter {
filter := &cuckooFilter{}
filter.interval = interval
return filter
}
type cuckooFilter struct {
lock sync.Mutex
poolA *cuckoo.Filter
poolB *cuckoo.Filter
poolSwap bool
lastSwap int64
interval int64
}
func (filter *cuckooFilter) Check(sum []byte) bool {
const defaultCapacity = 100000
filter.lock.Lock()
defer filter.lock.Unlock()
now := time.Now().Unix()
if filter.lastSwap == 0 {
filter.lastSwap = now
filter.poolA = cuckoo.NewFilter(defaultCapacity)
filter.poolB = cuckoo.NewFilter(defaultCapacity)
}
elapsed := now - filter.lastSwap
if elapsed >= filter.interval {
if filter.poolSwap {
filter.poolA.Reset()
} else {
filter.poolB.Reset()
}
filter.poolSwap = !filter.poolSwap
filter.lastSwap = now
}
return filter.poolA.InsertUnique(sum) && filter.poolB.InsertUnique(sum)
}

5
common/replay/filter.go Normal file
View file

@ -0,0 +1,5 @@
package replay
type Filter interface {
Check(sum []byte) bool
}

View file

@ -40,6 +40,18 @@ func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) {
return 0, os.ErrInvalid
}
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
return task.Run(context.Background(), func() error {
defer CloseRead(conn)
defer CloseWrite(dest)
return common.Error(io.Copy(dest, conn))
}, func() error {
defer CloseRead(dest)
defer CloseWrite(conn)
return common.Error(io.Copy(conn, dest))
})
}
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {
return task.Run(ctx, func() error {
buffer := buf.FullNew()

View file

@ -2,8 +2,9 @@ package rw
import (
"encoding/binary"
"github.com/sagernet/sing/common"
"io"
"github.com/sagernet/sing/common"
)
type InputStream interface {

View file

@ -5,7 +5,7 @@ import (
"strconv"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/socksaddr"
M "github.com/sagernet/sing/common/metadata"
)
type Network int
@ -20,8 +20,8 @@ type InstanceContext struct{}
type Context struct {
InstanceContext
Network Network
Source socksaddr.Addr
Destination socksaddr.Addr
Source M.Addr
Destination M.Addr
SourcePort uint16
DestinationPort uint16
}
@ -30,7 +30,7 @@ func (c Context) DestinationNetAddr() string {
return net.JoinHostPort(c.Destination.String(), strconv.Itoa(int(c.DestinationPort)))
}
func AddressFromNetAddr(netAddr net.Addr) (addr socksaddr.Addr, port uint16) {
func AddressFromNetAddr(netAddr net.Addr) (addr M.Addr, port uint16) {
var ip net.IP
switch addr := netAddr.(type) {
case *net.TCPAddr:
@ -40,7 +40,7 @@ func AddressFromNetAddr(netAddr net.Addr) (addr socksaddr.Addr, port uint16) {
ip = addr.IP
port = uint16(addr.Port)
}
return socksaddr.AddrFromIP(ip), port
return M.AddrFromIP(ip), port
}
type Conn struct {

108
common/udpnat/server.go Normal file
View file

@ -0,0 +1,108 @@
package udpnat
import (
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/protocol/socks"
)
type Handler interface {
socks.UDPConnectionHandler
E.Handler
}
type Server struct {
udpNat gsync.Map[string, *packetConn]
handler Handler
}
func NewServer(handler Handler) *Server {
return &Server{handler: handler}
}
func (s *Server) HandleUDP(buffer *buf.Buffer, metadata M.Metadata) error {
conn, loaded := s.udpNat.LoadOrStore(metadata.Source.String(), func() *packetConn {
return &packetConn{source: metadata.Source.UDPAddr(), in: make(chan *udpPacket)}
})
if !loaded {
go func() {
err := s.handler.NewPacketConnection(conn, metadata)
if err != nil {
s.handler.HandleError(err)
}
}()
}
conn.in <- &udpPacket{
buffer: buffer,
destination: metadata.Destination,
}
return nil
}
func (s *Server) OnError(err error) {
s.handler.HandleError(err)
}
func (s *Server) Close() error {
s.udpNat.Range(func(key string, conn *packetConn) bool {
conn.Close()
return true
})
s.udpNat = gsync.Map[string, *packetConn]{}
return nil
}
type packetConn struct {
socks.PacketConnStub
source *net.UDPAddr
in chan *udpPacket
}
type udpPacket struct {
buffer *buf.Buffer
destination *M.AddrPort
}
func (c *packetConn) LocalAddr() net.Addr {
return c.source
}
func (c *packetConn) Close() error {
select {
case <-c.in:
return io.ErrClosedPipe
default:
close(c.in)
}
return nil
}
func (c *packetConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
select {
case packet, ok := <-c.in:
if !ok {
return nil, io.ErrClosedPipe
}
defer packet.buffer.Release()
if buffer.FreeLen() < packet.buffer.Len() {
return nil, io.ErrShortBuffer
}
return packet.destination, common.Error(buffer.Write(packet.buffer.Bytes()))
}
}
func (c *packetConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), c.source)
if err != nil {
return E.Cause(err, "tproxy udp write back")
}
defer udpConn.Close()
return common.Error(udpConn.Write(buffer.Bytes()))
}

80
common/uot/client.go Normal file
View file

@ -0,0 +1,80 @@
package uot
import (
"encoding/binary"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
type ClientConn struct {
net.Conn
}
func NewClientConn(conn net.Conn) *ClientConn {
return &ClientConn{conn}
}
func (c *ClientConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
destination, err := AddrParser.ReadAddrPort(c)
if err != nil {
return nil, err
}
var length uint16
err = binary.Read(c, binary.BigEndian, &length)
if err != nil {
return nil, err
}
if buffer.FreeLen() < int(length) {
return nil, io.ErrShortBuffer
}
return destination, common.Error(buffer.ReadFullFrom(c, int(length)))
}
func (c *ClientConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
err := AddrParser.WriteAddrPort(c, destination)
if err != nil {
return err
}
err = binary.Write(c, binary.BigEndian, uint16(buffer.Len()))
if err != nil {
return err
}
return common.Error(c.Write(buffer.Bytes()))
}
func (c *ClientConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
addrPort, err := AddrParser.ReadAddrPort(c)
if err != nil {
return 0, nil, err
}
var length uint16
err = binary.Read(c, binary.BigEndian, &length)
if err != nil {
return 0, nil, err
}
if len(p) < int(length) {
return 0, nil, io.ErrShortBuffer
}
n, err = io.ReadAtLeast(c, p, int(length))
if err != nil {
return 0, nil, err
}
addr = addrPort.UDPAddr()
return
}
func (c *ClientConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
err = AddrParser.WriteAddrPort(c, M.AddrPortFromNetAddr(addr))
if err != nil {
return
}
err = binary.Write(c, binary.BigEndian, uint16(len(p)))
if err != nil {
return
}
return c.Write(p)
}

21
common/uot/resolver.go Normal file
View file

@ -0,0 +1,21 @@
package uot
import (
"context"
"net"
"time"
)
var LookupAddress func(domain string) (net.IP, error)
func init() {
LookupAddress = func(domain string) (net.IP, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", domain)
cancel()
if err != nil {
return nil, err
}
return ips[0], nil
}
}

108
common/uot/server.go Normal file
View file

@ -0,0 +1,108 @@
package uot
import (
"encoding/binary"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
type ServerConn struct {
net.PacketConn
inputReader, outputReader *io.PipeReader
inputWriter, outputWriter *io.PipeWriter
}
func NewServerConn(packetConn net.PacketConn) net.Conn {
c := &ServerConn{
PacketConn: packetConn,
}
c.inputReader, c.inputWriter = io.Pipe()
c.outputReader, c.outputWriter = io.Pipe()
go c.loopInput()
go c.loopOutput()
return c
}
func (c *ServerConn) Read(b []byte) (n int, err error) {
return c.outputReader.Read(b)
}
func (c *ServerConn) Write(b []byte) (n int, err error) {
return c.inputWriter.Write(b)
}
func (c *ServerConn) RemoteAddr() net.Addr {
return &common.DummyAddr{}
}
func (c *ServerConn) loopInput() {
buffer := buf.New()
defer buffer.Release()
for {
destination, err := AddrParser.ReadAddrPort(c.inputReader)
if err != nil {
break
}
if destination.Addr.Family().IsFqdn() {
ip, err := LookupAddress(destination.Addr.Fqdn())
if err != nil {
break
}
destination.Addr = M.AddrFromIP(ip)
}
var length uint16
err = binary.Read(c.inputReader, binary.BigEndian, &length)
if err != nil {
break
}
buffer.FullReset()
_, err = buffer.ReadFullFrom(c.inputReader, int(length))
if err != nil {
break
}
_, err = c.WriteTo(buffer.Bytes(), destination.UDPAddr())
if err != nil {
break
}
}
c.Close()
}
func (c *ServerConn) loopOutput() {
buffer := buf.New()
defer buffer.Release()
for {
buffer.FullReset()
n, addr, err := buffer.ReadPacketFrom(c)
if err != nil {
break
}
destination := M.AddrPortFromNetAddr(addr)
err = AddrParser.WriteAddrPort(c.outputWriter, destination)
if err != nil {
break
}
err = binary.Write(c.outputWriter, binary.BigEndian, uint16(n))
if err != nil {
break
}
_, err = buffer.WriteTo(c.outputWriter)
if err != nil {
break
}
}
c.Close()
}
func (c *ServerConn) Close() error {
c.inputReader.Close()
c.inputWriter.Close()
c.outputReader.Close()
c.outputWriter.Close()
c.PacketConn.Close()
return nil
}

13
common/uot/uot.go Normal file
View file

@ -0,0 +1,13 @@
package uot
import (
M "github.com/sagernet/sing/common/metadata"
)
const UOTMagicAddress = "sp.udp-over-tcp.arpa"
var AddrParser = M.NewSerializer(
M.AddressFamilyByte(0x00, M.AddressFamilyIPv4),
M.AddressFamilyByte(0x01, M.AddressFamilyIPv6),
M.AddressFamilyByte(0x02, M.AddressFamilyFqdn),
)

39
common/uot/uot_test.go Normal file
View file

@ -0,0 +1,39 @@
package uot
import (
"net"
"testing"
"github.com/sagernet/uot/common"
"github.com/sagernet/uot/common/buf"
"golang.org/x/net/dns/dnsmessage"
)
func TestServerConn(t *testing.T) {
udpConn, err := net.ListenUDP("udp", nil)
common.Must(err)
serverConn := NewServerConn(udpConn)
defer serverConn.Close()
clientConn := NewClientConn(serverConn)
message := new(dnsmessage.Message)
message.Header.ID = 1
message.Header.RecursionDesired = true
message.Questions = append(message.Questions, dnsmessage.Question{
Name: dnsmessage.MustNewName("google.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
})
packet, err := message.Pack()
common.Must(err)
common.Must1(clientConn.WriteTo(packet, &net.UDPAddr{
IP: net.IPv4(8, 8, 8, 8),
Port: 53,
}))
buffer := buf.New()
defer buffer.Release()
common.Must2(buffer.ReadPacketFrom(clientConn))
common.Must(message.Unpack(buffer.Bytes()))
for _, answer := range message.Answers {
t.Log("got answer :", answer.Body)
}
}

View file

@ -1,3 +1,6 @@
package sing
const Version = "v0.0.0-alpha.1"
const (
Version = "v0.0.0-alpha.1"
VersionStr = "sing " + Version
)

16
go.mod
View file

@ -6,20 +6,32 @@ require (
github.com/klauspost/compress v1.15.1
github.com/openacid/low v0.1.21
github.com/oschwald/geoip2-golang v1.7.0
github.com/sagernet/uot v0.0.0-20220403125237-bf82029ad617
github.com/samber/lo v1.11.0
github.com/seiflotfy/cuckoofilter v0.0.0-20201222105146-bc6005554a0c
github.com/sirupsen/logrus v1.8.1
github.com/spf13/cobra v1.4.0
github.com/stretchr/testify v1.7.1
github.com/ulikunitz/xz v0.5.10
github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e
github.com/v2fly/v2ray-core/v5 v5.0.3
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29
golang.org/x/crypto v0.0.0-20220408190544-5352b0902921
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3
golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f
google.golang.org/protobuf v1.28.0
lukechampine.com/blake3 v1.1.7
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
github.com/oschwald/maxminddb-golang v1.9.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)

30
go.sum
View file

@ -3,6 +3,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165 h1:BS21ZUJ/B5X2UVUbczfmdWH7GapPWAhxcMsDnjJTU1E=
github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
@ -13,6 +15,10 @@ github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NH
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/klauspost/compress v1.15.1 h1:y9FcTHGyrebwfP0ZZqFiaxTaiDnUrGkJkI+f583BL1A=
github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/openacid/errors v0.8.1/go.mod h1:GUQEJJOJE3W9skHm8E8Y4phdl2LLEN8iD7c5gcGgdx0=
github.com/openacid/low v0.1.21 h1:Tr2GNu4N/+rGRYdOsEHOE89cxUIaDViZbVmKz29uKGo=
github.com/openacid/low v0.1.21/go.mod h1:q+MsKI6Pz2xsCkzV4BLj7NR5M4EX0sGz5AqotpZDVh0=
@ -25,9 +31,15 @@ github.com/oschwald/maxminddb-golang v1.9.0/go.mod h1:TK+s/Z2oZq0rSl4PSeAEoP0bgm
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg=
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3/go.mod h1:HgjTstvQsPGkxUsCd2KWxErBblirPizecHcpD3ffK+s=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagernet/uot v0.0.0-20220403125237-bf82029ad617 h1:h46Ocvf7zWpatqOHcR4kw+k2GbGcMM7EzGjYG7wiGfM=
github.com/sagernet/uot v0.0.0-20220403125237-bf82029ad617/go.mod h1:T2LhXiIIvaoeKii21x1GONCee9u7N9Nnrqz5bY3SWsM=
github.com/samber/lo v1.11.0 h1:JfeYozXL1xfkhRUFOfH13ociyeiLSC/GRJjGKI668xM=
github.com/samber/lo v1.11.0/go.mod h1:2I7tgIv8Q1SG2xEIkRq0F2i2zgxVpnyPOP0d3Gj2r+A=
github.com/seiflotfy/cuckoofilter v0.0.0-20201222105146-bc6005554a0c h1:pqy40B3MQWYrza7YZXOXgl0Nf0QGFqrOC0BKae1UNAA=
github.com/seiflotfy/cuckoofilter v0.0.0-20201222105146-bc6005554a0c/go.mod h1:bR6DqgcAl1zTcOX8/pE2Qkj9XO00eCNqmKb7lXP8EAg=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/spf13/cobra v1.4.0 h1:y+wJpx64xcgO1V+RcnwW0LEHxTKRi2ZDPSBjWnrg88Q=
@ -40,18 +52,23 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M=
github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8=
github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e h1:5QefA066A1tF8gHIiADmOVOV5LS43gt3ONnlEl3xkwI=
github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e/go.mod h1:5t19P9LBIrNamL6AcMQOncg/r10y3Pc01AbHeMhwlpU=
github.com/v2fly/v2ray-core/v5 v5.0.3 h1:2rnJ9vZbBQ7V4upWsoUVYGoqZl4grrx8SxOReKx+jjc=
github.com/v2fly/v2ray-core/v5 v5.0.3/go.mod h1:zhDdsUJcNE8LcLRA3l7fEQ6QLuveD4/OLbQM2CceSHM=
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o=
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 h1:iU7T1X1J6yxDr0rda54sWGkHgOp5XJrqm79gcNlC2VM=
golang.org/x/crypto v0.0.0-20220408190544-5352b0902921/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 h1:EN5+DfgmRMvRUrMGERW2gQl3Vc+Z7ZMnI/xdEpPSf0c=
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0Pp5Fn9Qga0mrgxyERQM=
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f h1:8w7RhxzTVgUzw/AH/9mUV5q0vMgy40SQRursCcfmkCw=
golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
@ -59,6 +76,11 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
lukechampine.com/blake3 v1.1.7 h1:GgRMhmdsuK8+ii6UZFDL8Nb+VyMwadAgcJyfYHxG6n0=
lukechampine.com/blake3 v1.1.7/go.mod h1:tkKEOtDkNtklkXtLNEOGNq5tcV90tJiA1vAA12R78LA=

View file

@ -1,92 +1,42 @@
package system
package http
import (
"bufio"
"context"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/netip"
"strconv"
"strings"
"syscall"
"time"
_ "unsafe"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/socksaddr"
"github.com/sagernet/sing/protocol/socks"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/transport/tcp"
)
type MixedListener struct {
*SocksListener
UDPListener *UDPListener
mcfg *MixedConfig
laddr *net.TCPAddr
type Handler interface {
tcp.Handler
}
type MixedConfig struct {
SocksConfig
Redirect bool
TProxy bool
}
func NewMixedListener(bind netip.AddrPort, config *MixedConfig, handler SocksHandler) *MixedListener {
listener := &MixedListener{SocksListener: NewSocksListener(bind, &config.SocksConfig, handler), mcfg: config}
listener.TCPListener.Handler = listener
if config.TProxy {
listener.UDPListener = NewUDPListener(bind, listener)
}
return listener
}
func (l *MixedListener) HandleTCP(conn net.Conn) error {
if l.mcfg.Redirect {
var destination netip.AddrPort
destination, err := redir.GetOriginalDestination(conn)
if err == nil {
return l.Handler.NewConnection(socksaddr.AddrFromAddr(destination.Addr()), destination.Port(), conn)
}
} else if l.mcfg.TProxy {
lAddr := conn.LocalAddr().(*net.TCPAddr)
rAddr := conn.RemoteAddr().(*net.TCPAddr)
if lAddr.Port != l.laddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() {
addr, port := socksaddr.AddrFromNetAddr(lAddr)
return l.Handler.NewConnection(addr, port, conn)
}
}
bufConn := buf.NewBufferedConn(conn)
hdr, err := bufConn.ReadByte()
if err != nil {
return err
}
err = bufConn.UnreadByte()
if err != nil {
return err
}
if hdr == socks.Version4 || hdr == socks.Version5 {
return l.SocksListener.HandleTCP(bufConn)
}
func HandleConnection(conn *buf.BufferedConn, authenticator auth.Authenticator, handler Handler) error {
var httpClient *http.Client
for {
request, err := readRequest(bufConn.Reader())
request, err := readRequest(conn.Reader())
if err != nil {
return exceptions.Cause(err, "read http request")
return E.Cause(err, "read http request")
}
if l.Username != "" {
if authenticator != nil {
var authOk bool
authorization := request.Header.Get("Proxy-Authorization")
if strings.HasPrefix(authorization, "BASIC ") {
userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
if string(userPassword) == l.Username+":"+l.Password {
authOk = true
}
userPswdArr := strings.SplitN(string(userPassword), ":", 2)
authOk = authenticator.Verify(userPswdArr[0], userPswdArr[1])
}
if !authOk {
err = responseWith(request, http.StatusProxyAuthRequired).Write(conn)
@ -97,23 +47,23 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error {
}
if request.Method == "CONNECT" {
host := request.URL.Hostname()
portStr := request.URL.Port()
if portStr == "" {
portStr = "80"
}
port, err := strconv.Atoi(portStr)
destination, err := M.ParseAddrPort(request.URL.Hostname(), portStr)
if err != nil {
err = responseWith(request, http.StatusBadRequest).Write(conn)
if err != nil {
return err
}
}
_, err = fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established")
if err != nil {
return exceptions.Cause(err, "write http response")
return E.Cause(err, "write http response")
}
return l.Handler.NewConnection(socksaddr.ParseAddr(host), uint16(port), bufConn)
return handler.NewConnection(conn, M.Metadata{
Destination: destination,
})
}
keepAlive := strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
@ -141,19 +91,21 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error {
ExpectContinueTimeout: 1 * time.Second,
DialContext: func(context context.Context, network, address string) (net.Conn, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, exceptions.New("unsupported network ", network)
return nil, E.New("unsupported network ", network)
}
addr, port, err := socksaddr.ParseAddrPort(address)
destination, err := M.ParseAddress(address)
if err != nil {
return nil, err
}
left, right := net.Pipe()
go func() {
err = l.Handler.NewConnection(addr, port, right)
err = handler.NewConnection(right, M.Metadata{
Destination: destination,
})
if err != nil {
l.OnError(err)
handler.HandleError(err)
}
}()
return left, nil
@ -167,7 +119,7 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error {
response, err := httpClient.Do(request)
if err != nil {
l.OnError(exceptions.Cause(err, "http proxy"))
handler.HandleError(err)
return responseWith(request, http.StatusBadGateway).Write(conn)
}
@ -183,17 +135,14 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error {
err = response.Write(conn)
if err != nil {
l.OnError(exceptions.Cause(err, "http proxy"))
return err
}
}
}
func (l *MixedListener) HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error {
return nil
}
//go:linkname readRequest net/http.ReadRequest
func readRequest(b *bufio.Reader) (req *http.Request, err error)
// removeHopByHopHeaders remove hop-by-hop header
func removeHopByHopHeaders(header http.Header) {
// Strip hop-by-hop header based on RFC:
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
@ -217,8 +166,6 @@ func removeHopByHopHeaders(header http.Header) {
}
}
// removeExtraHTTPHostPort remove extra host port (example.com:80 --> example.com)
// It resolves the behavior of some HTTP servers that do not handle host:80 (e.g. baidu.com)
func removeExtraHTTPHostPort(req *http.Request) {
host := req.Host
if host == "" {
@ -243,39 +190,3 @@ func responseWith(request *http.Request, statusCode int) *http.Response {
Header: http.Header{},
}
}
func (l *MixedListener) Start() error {
err := l.TCPListener.Start()
if err != nil {
return err
}
if l.mcfg.TProxy {
rawConn, err := l.TCPListener.TCPListener.SyscallConn()
if err != nil {
return err
}
var rawFd uintptr
err = rawConn.Control(func(fd uintptr) {
rawFd = fd
})
if err != nil {
return err
}
err = syscall.SetsockoptInt(int(rawFd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
if err != nil {
return exceptions.Cause(err, "failed to configure TCP tproxy")
}
}
if l.mcfg.TProxy {
l.laddr = l.TCPListener.Addr().(*net.TCPAddr)
}
return nil
}
func (l *MixedListener) Close() error {
return l.TCPListener.Close()
}
func (l *MixedListener) OnError(err error) {
l.Handler.OnError(exceptions.Cause(err, "mixed server"))
}

View file

@ -1,47 +0,0 @@
package shadowsocks
import (
"io"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/list"
)
type Cipher interface {
KeySize() int
SaltSize() int
CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader
CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer
EncodePacket(key []byte, buffer *buf.Buffer) error
DecodePacket(key []byte, buffer *buf.Buffer) error
}
type CipherCreator func() Cipher
var (
cipherList *list.List[string]
cipherMap map[string]CipherCreator
)
func init() {
cipherList = new(list.List[string])
cipherMap = make(map[string]CipherCreator)
}
func RegisterCipher(method string, creator CipherCreator) {
cipherList.PushBack(method)
cipherMap[method] = creator
}
func CreateCipher(method string) (Cipher, error) {
creator := cipherMap[method]
if creator != nil {
return creator(), nil
}
return nil, exceptions.New("unsupported method: ", method)
}
func ListCiphers() []string {
return cipherList.Array()
}

View file

@ -1,39 +0,0 @@
package shadowsocks
import (
"io"
"github.com/sagernet/sing/common/buf"
)
func init() {
RegisterCipher("none", func() Cipher {
return (*NoneCipher)(nil)
})
}
type NoneCipher struct{}
func (c *NoneCipher) KeySize() int {
return 16
}
func (c *NoneCipher) SaltSize() int {
return 0
}
func (c *NoneCipher) CreateReader(_ []byte, _ []byte, reader io.Reader) io.Reader {
return reader
}
func (c *NoneCipher) CreateWriter(key []byte, iv []byte, writer io.Writer) io.Writer {
return writer
}
func (c *NoneCipher) EncodePacket([]byte, *buf.Buffer) error {
return nil
}
func (c *NoneCipher) DecodePacket([]byte, *buf.Buffer) error {
return nil
}

View file

@ -1,178 +0,0 @@
package shadowsocks
import (
"context"
"io"
"net"
"strconv"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/socksaddr"
"github.com/sagernet/sing/protocol/socks"
)
var (
ErrBadKey = exceptions.New("bad key")
ErrMissingPassword = exceptions.New("password not specified")
)
type ClientConfig struct {
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
Method string `json:"method"`
Password []byte `json:"password"`
Key []byte `json:"key"`
}
type Client struct {
dialer *net.Dialer
cipher Cipher
server string
key []byte
}
func NewClient(dialer *net.Dialer, config *ClientConfig) (*Client, error) {
if config.Server == "" {
return nil, exceptions.New("missing server address")
}
if config.ServerPort == 0 {
return nil, exceptions.New("missing server port")
}
if config.Method == "" {
return nil, exceptions.New("missing server method")
}
cipher, err := CreateCipher(config.Method)
if err != nil {
return nil, err
}
client := &Client{
dialer: dialer,
cipher: cipher,
server: net.JoinHostPort(config.Server, strconv.Itoa(int(config.ServerPort))),
}
if keyLen := len(config.Key); keyLen > 0 {
if keyLen == cipher.KeySize() {
client.key = config.Key
} else {
return nil, ErrBadKey
}
} else if len(config.Password) > 0 {
client.key = Key(config.Password, cipher.KeySize())
} else {
return nil, ErrMissingPassword
}
return client, nil
}
func (c *Client) DialContextTCP(ctx context.Context, addr socksaddr.Addr, port uint16) (net.Conn, error) {
conn, err := c.dialer.DialContext(ctx, "tcp", c.server)
if err != nil {
return nil, exceptions.Cause(err, "connect to server")
}
return c.DialConn(conn, addr, port), nil
}
func (c *Client) DialConn(conn net.Conn, addr socksaddr.Addr, port uint16) net.Conn {
header := buf.New()
header.WriteRandom(c.cipher.SaltSize())
writer := &buf.BufferedWriter{
Writer: conn,
Buffer: header,
}
protocolWriter := c.cipher.CreateWriter(c.key, header.Bytes(), writer)
requestBuffer := buf.New()
contentWriter := &buf.BufferedWriter{
Writer: protocolWriter,
Buffer: requestBuffer,
}
common.Must(AddressSerializer.WriteAddressAndPort(contentWriter, addr, port))
return &shadowsocksConn{
Client: c,
Conn: conn,
Writer: &common.FlushOnceWriter{Writer: contentWriter},
}
}
type shadowsocksConn struct {
*Client
net.Conn
io.Writer
reader io.Reader
}
func (c *shadowsocksConn) Read(p []byte) (n int, err error) {
if c.reader == nil {
buffer := buf.Or(p, c.cipher.SaltSize())
defer buffer.Release()
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
if err != nil {
return
}
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
}
return c.reader.Read(p)
}
func (c *shadowsocksConn) WriteTo(w io.Writer) (n int64, err error) {
if c.reader == nil {
buffer := buf.NewSize(c.cipher.SaltSize())
defer buffer.Release()
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
if err != nil {
return
}
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
}
return c.reader.(io.WriterTo).WriteTo(w)
}
func (c *shadowsocksConn) Write(p []byte) (n int, err error) {
return c.Writer.Write(p)
}
func (c *shadowsocksConn) ReadFrom(r io.Reader) (n int64, err error) {
return rw.ReadFromVar(&c.Writer, r)
}
func (c *Client) DialContextUDP(ctx context.Context) socks.PacketConn {
conn, err := c.dialer.DialContext(ctx, "udp", c.server)
if err != nil {
return nil
}
return &shadowsocksPacketConn{c, conn}
}
type shadowsocksPacketConn struct {
*Client
net.Conn
}
func (c *shadowsocksPacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
defer buffer.Release()
header := buf.New()
header.WriteRandom(c.cipher.SaltSize())
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
buffer = buffer.WriteBufferAtFirst(header)
err := c.cipher.EncodePacket(c.key, buffer)
if err != nil {
return err
}
return common.Error(c.Conn.Write(buffer.Bytes()))
}
func (c *shadowsocksPacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, 0, err
}
buffer.Truncate(n)
err = c.cipher.DecodePacket(c.key, buffer)
if err != nil {
return nil, 0, err
}
return AddressSerializer.ReadAddressAndPort(buffer)
}

View file

@ -0,0 +1,153 @@
package shadowsocks
import (
"io"
"net"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
)
const MethodNone = "none"
type NoneMethod struct{}
func NewNone() Method {
return &NoneMethod{}
}
func (m *NoneMethod) Name() string {
return MethodNone
}
func (m *NoneMethod) KeyLength() int {
return 0
}
func (m *NoneMethod) NewSession(key []byte) Session {
return nil
}
func (m *NoneMethod) DialConn(_ Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
shadowsocksConn := &noneConn{
Conn: conn,
handshake: true,
destination: destination,
}
return shadowsocksConn, shadowsocksConn.clientHandshake()
}
func (m *NoneMethod) DialEarlyConn(_ Session, conn net.Conn, destination *M.AddrPort) net.Conn {
return &noneConn{
Conn: conn,
destination: destination,
}
}
func (m *NoneMethod) DialPacketConn(_ Session, conn net.Conn) socks.PacketConn {
return &nonePacketConn{conn}
}
type noneConn struct {
net.Conn
access sync.Mutex
handshake bool
destination *M.AddrPort
}
func (c *noneConn) clientHandshake() error {
err := socks.AddressSerializer.WriteAddrPort(c.Conn, c.destination)
if err != nil {
return err
}
c.handshake = true
return nil
}
func (c *noneConn) Write(b []byte) (n int, err error) {
if c.handshake {
goto direct
}
c.access.Lock()
defer c.access.Unlock()
if c.handshake {
goto direct
}
{
if len(b) == 0 {
return 0, c.clientHandshake()
}
buffer := buf.New()
defer buffer.Release()
err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination)
if err != nil {
return
}
bufN, _ := buffer.Write(b)
_, err = c.Conn.Write(buffer.Bytes())
if err != nil {
return
}
if bufN < len(b) {
_, err = c.Conn.Write(b[bufN:])
if err != nil {
return
}
}
n = len(b)
}
direct:
return c.Conn.Write(b)
}
func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) {
if !c.handshake {
panic("missing client handshake")
}
return c.Conn.(io.ReaderFrom).ReadFrom(r)
}
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
return c.Conn.(io.WriterTo).WriteTo(w)
}
func (c *noneConn) RemoteAddr() net.Addr {
return c.destination.TCPAddr()
}
type nonePacketConn struct {
net.Conn
}
func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
_, err := buffer.ReadFrom(c)
if err != nil {
return nil, err
}
return socks.AddressSerializer.ReadAddrPort(buffer)
}
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
if err != nil {
header.Release()
return err
}
buffer = buffer.WriteBufferAtFirst(header)
return common.Error(buffer.WriteTo(c))
}

View file

@ -2,23 +2,26 @@ package shadowsocks
import (
"crypto/md5"
"crypto/sha1"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/protocol/socks"
"hash/crc32"
"io"
"math/rand"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/socksaddr"
"golang.org/x/crypto/hkdf"
"net"
)
const MaxPacketSize = 16*1024 - 1
type Session interface {
Key() []byte
ReplayFilter() replay.Filter
}
func Kdf(key, iv []byte, keyLength int) []byte {
subKey := make([]byte, keyLength)
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
common.Must1(io.ReadFull(kdf, subKey))
return subKey
type Method interface {
Name() string
KeyLength() int
DialConn(session Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error)
DialEarlyConn(session Session, conn net.Conn, destination *M.AddrPort) net.Conn
DialPacketConn(session Session, conn net.Conn) socks.PacketConn
}
func Key(password []byte, keySize int) []byte {
@ -43,19 +46,18 @@ func Key(password []byte, keySize int) []byte {
return m[:keySize]
}
func RemapToPrintable(input []byte) {
const charSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&()*+,./:;<=>?@[]^_`{|}~\\\""
seed := rand.New(rand.NewSource(int64(crc32.ChecksumIEEE(input))))
for i := range input {
input[i] = charSet[seed.Intn(len(charSet))]
}
type ReducedEntropyReader struct {
io.Reader
}
var AddressSerializer = socksaddr.NewSerializer(
socksaddr.AddressFamilyByte(0x01, socksaddr.AddressFamilyIPv4),
socksaddr.AddressFamilyByte(0x04, socksaddr.AddressFamilyIPv6),
socksaddr.AddressFamilyByte(0x03, socksaddr.AddressFamilyFqdn),
socksaddr.WithFamilyParser(func(b byte) byte {
return b & 0x0F
}),
)
func (r *ReducedEntropyReader) Read(p []byte) (n int, err error) {
n, err = r.Reader.Read(p)
if n > 6 {
const charSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&()*+,./:;<=>?@[]^_`{|}~\\\""
seed := rand.New(rand.NewSource(int64(crc32.ChecksumIEEE(p[:6]))))
for i := range p[:6] {
p[i] = charSet[seed.Intn(len(charSet))]
}
}
return
}

View file

@ -1,138 +1,20 @@
package shadowsocks
package shadowaead
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
"golang.org/x/crypto/chacha20poly1305"
)
const PacketLengthBufferSize = 2
const (
MaxPacketSize = 16*1024 - 1
PacketLengthBufferSize = 2
)
func init() {
RegisterCipher("aes-128-gcm", func() Cipher {
return &AEADCipher{
KeyLength: 16,
SaltLength: 16,
Constructor: aesGcm,
}
})
RegisterCipher("aes-192-gcm", func() Cipher {
return &AEADCipher{
KeyLength: 24,
SaltLength: 24,
Constructor: aesGcm,
}
})
RegisterCipher("aes-256-gcm", func() Cipher {
return &AEADCipher{
KeyLength: 32,
SaltLength: 32,
Constructor: aesGcm,
}
})
RegisterCipher("chacha20-ietf-poly1305", func() Cipher {
return &AEADCipher{
KeyLength: 32,
SaltLength: 32,
Constructor: chacha20Poly1305,
}
})
RegisterCipher("xchacha20-ietf-poly1305", func() Cipher {
return &AEADCipher{
KeyLength: 32,
SaltLength: 32,
Constructor: xchacha20Poly1305,
}
})
}
func aesGcm(key []byte) cipher.AEAD {
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block)
common.Must(err)
return aead
}
func chacha20Poly1305(key []byte) cipher.AEAD {
aead, err := chacha20poly1305.New(key)
common.Must(err)
return aead
}
func xchacha20Poly1305(key []byte) cipher.AEAD {
aead, err := chacha20poly1305.NewX(key)
common.Must(err)
return aead
}
type AEADCipher struct {
KeyLength int
SaltLength int
Constructor func(key []byte) cipher.AEAD
}
func (c *AEADCipher) KeySize() int {
return c.KeyLength
}
func (c *AEADCipher) SaltSize() int {
return c.SaltLength
}
func (c *AEADCipher) CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader {
return NewAEADReader(reader, c.Constructor(Kdf(key, salt, c.KeyLength)))
}
func (c *AEADCipher) CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer {
protocolWriter := NewAEADWriter(writer, c.Constructor(Kdf(key, salt, c.KeyLength)))
return protocolWriter
}
func (c *AEADCipher) EncodePacket(key []byte, buffer *buf.Buffer) error {
aead := c.Constructor(Kdf(key, buffer.To(c.SaltLength), c.KeyLength))
aead.Seal(buffer.From(c.SaltLength)[:0], rw.ZeroBytes[:aead.NonceSize()], buffer.From(c.SaltLength), nil)
buffer.Extend(aead.Overhead())
return nil
}
func (c *AEADCipher) DecodePacket(key []byte, buffer *buf.Buffer) error {
if buffer.Len() < c.SaltLength {
return exceptions.New("bad packet")
}
aead := c.Constructor(Kdf(key, buffer.To(c.SaltLength), c.KeyLength))
packet, err := aead.Open(buffer.Index(c.SaltLength), rw.ZeroBytes[:aead.NonceSize()], buffer.From(c.SaltLength), nil)
if err != nil {
return err
}
buffer.Advance(c.SaltLength)
buffer.Truncate(len(packet))
return nil
}
type AEADConn struct {
net.Conn
Reader *AEADReader
Writer *AEADWriter
}
func (c *AEADConn) Read(p []byte) (n int, err error) {
return c.Reader.Read(p)
}
func (c *AEADConn) Write(p []byte) (n int, err error) {
return c.Writer.Write(p)
}
type AEADReader struct {
type Reader struct {
upstream io.Reader
cipher cipher.AEAD
data []byte
@ -141,8 +23,8 @@ type AEADReader struct {
cached int
}
func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader {
return &AEADReader{
func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader {
return &Reader{
upstream: upstream,
cipher: cipher,
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
@ -150,19 +32,19 @@ func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader {
}
}
func (r *AEADReader) Upstream() io.Reader {
func (r *Reader) Upstream() io.Reader {
return r.upstream
}
func (r *AEADReader) Replaceable() bool {
func (r *Reader) Replaceable() bool {
return false
}
func (r *AEADReader) SetUpstream(reader io.Reader) {
func (r *Reader) SetUpstream(reader io.Reader) {
r.upstream = reader
}
func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) {
func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
if r.cached > 0 {
writeN, writeErr := writer.Write(r.data[r.index : r.index+r.cached])
if writeErr != nil {
@ -200,7 +82,7 @@ func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) {
}
}
func (r *AEADReader) Read(b []byte) (n int, err error) {
func (r *Reader) Read(b []byte) (n int, err error) {
if r.cached > 0 {
n = copy(b, r.data[r.index:r.index+r.cached])
r.cached -= n

View file

@ -0,0 +1,332 @@
package shadowaead
import (
"crypto/aes"
"crypto/cipher"
"crypto/sha1"
"github.com/sagernet/sing/common/replay"
"io"
"net"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
)
var List = []string{
"aes-128-gcm",
"aes-192-gcm",
"aes-256-gcm",
"chacha20-ietf-poly1305",
"xchacha20-ietf-poly1305",
}
func New(method string, secureRNG io.Reader) shadowsocks.Method {
m := &Method{
name: method,
secureRNG: secureRNG,
}
switch method {
case "aes-128-gcm":
m.keySaltLength = 16
m.constructor = newAESGCM
case "aes-192-gcm":
m.keySaltLength = 24
m.constructor = newAESGCM
case "aes-256-gcm":
m.keySaltLength = 32
m.constructor = newAESGCM
case "chacha20-ietf-poly1305":
m.keySaltLength = 32
m.constructor = func(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.New(key)
common.Must(err)
return cipher
}
case "xchacha20-ietf-poly1305":
m.keySaltLength = 32
m.constructor = func(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.NewX(key)
common.Must(err)
return cipher
}
}
return m
}
func NewSession(key []byte, replayFilter bool) shadowsocks.Session {
var filter replay.Filter
if replayFilter {
filter = replay.NewBloomRing()
}
return &session{key, filter}
}
func Kdf(key, iv []byte, keyLength int) []byte {
subKey := make([]byte, keyLength)
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
common.Must1(io.ReadFull(kdf, subKey))
return subKey
}
func newAESGCM(key []byte) cipher.AEAD {
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block)
common.Must(err)
return aead
}
type Method struct {
name string
keySaltLength int
constructor func(key []byte) cipher.AEAD
secureRNG io.Reader
}
func (m *Method) Name() string {
return m.name
}
func (m *Method) KeyLength() int {
return m.keySaltLength
}
func (m *Method) DialConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
shadowsocksConn := &aeadConn{
Conn: conn,
method: m,
key: account.Key(),
replayFilter: account.ReplayFilter(),
destination: destination,
}
return shadowsocksConn, shadowsocksConn.clientHandshake()
}
func (m *Method) DialEarlyConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) net.Conn {
return &aeadConn{
Conn: conn,
method: m,
key: account.Key(),
replayFilter: account.ReplayFilter(),
destination: destination,
}
}
func (m *Method) DialPacketConn(account shadowsocks.Session, conn net.Conn) socks.PacketConn {
return &aeadPacketConn{conn, account.Key(), m}
}
func (m *Method) EncodePacket(key []byte, buffer *buf.Buffer) error {
cipher := m.constructor(Kdf(key, buffer.To(m.keySaltLength), m.keySaltLength))
cipher.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:cipher.NonceSize()], buffer.From(m.keySaltLength), nil)
buffer.Extend(cipher.Overhead())
return nil
}
func (m *Method) DecodePacket(key []byte, buffer *buf.Buffer) error {
if buffer.Len() < m.keySaltLength {
return E.New("bad packet")
}
aead := m.constructor(Kdf(key, buffer.To(m.keySaltLength), m.keySaltLength))
packet, err := aead.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:aead.NonceSize()], buffer.From(m.keySaltLength), nil)
if err != nil {
return err
}
buffer.Advance(m.keySaltLength)
buffer.Truncate(len(packet))
return nil
}
type session struct {
key []byte
replayFilter replay.Filter
}
func (a *session) Key() []byte {
return a.key
}
func (a *session) ReplayFilter() replay.Filter {
return a.replayFilter
}
type aeadConn struct {
net.Conn
method *Method
key []byte
destination *M.AddrPort
access sync.Mutex
reader io.Reader
writer io.Writer
replayFilter replay.Filter
}
func (c *aeadConn) clientHandshake() error {
header := buf.New()
defer header.Release()
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
if c.replayFilter != nil {
c.replayFilter.Check(header.Bytes())
}
c.writer = NewAEADWriter(
&buf.BufferedWriter{
Writer: c.Conn,
Buffer: header,
},
c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)),
)
err := socks.AddressSerializer.WriteAddrPort(c.writer, c.destination)
if err != nil {
return err
}
return common.FlushVar(&c.writer)
}
func (c *aeadConn) serverHandshake() error {
if c.reader == nil {
salt := make([]byte, c.method.keySaltLength)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
if c.replayFilter != nil {
if !c.replayFilter.Check(salt) {
return E.New("salt is not unique")
}
}
c.reader = NewReader(c.Conn, c.method.constructor(Kdf(c.key, salt, c.method.keySaltLength)))
}
return nil
}
func (c *aeadConn) Read(p []byte) (n int, err error) {
if err = c.serverHandshake(); err != nil {
return
}
return c.reader.Read(p)
}
func (c *aeadConn) WriteTo(w io.Writer) (n int64, err error) {
if err = c.serverHandshake(); err != nil {
return
}
return c.reader.(io.WriterTo).WriteTo(w)
}
func (c *aeadConn) Write(p []byte) (n int, err error) {
if c.writer != nil {
goto direct
}
c.access.Lock()
defer c.access.Unlock()
if c.writer != nil {
goto direct
}
// client handshake
{
header := buf.New()
defer header.Release()
request := buf.New()
defer request.Release()
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
if c.replayFilter != nil {
c.replayFilter.Check(header.Bytes())
}
var writer io.Writer = c.Conn
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: header,
}
writer = NewAEADWriter(writer, c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)))
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: request,
}
err = socks.AddressSerializer.WriteAddrPort(writer, c.destination)
if err != nil {
return
}
if len(p) > 0 {
_, err = writer.Write(p)
if err != nil {
return
}
}
err = common.FlushVar(&writer)
if err != nil {
return
}
c.writer = writer
return len(p), nil
}
direct:
return c.writer.Write(p)
}
func (c *aeadConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
panic("missing client handshake")
}
return c.writer.(io.ReaderFrom).ReadFrom(r)
}
type aeadPacketConn struct {
net.Conn
key []byte
method *Method
}
func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
buffer = buffer.WriteBufferAtFirst(header)
err = c.method.EncodePacket(c.key, buffer)
if err != nil {
return err
}
return common.Error(c.Write(buffer.Bytes()))
}
func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, err
}
buffer.Truncate(n)
err = c.method.DecodePacket(c.key, buffer)
if err != nil {
return nil, err
}
return socks.AddressSerializer.ReadAddrPort(buffer)
}

View file

@ -6,12 +6,12 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/socksaddr"
M "github.com/sagernet/sing/common/metadata"
)
type PacketConn interface {
ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error)
WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error
ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error)
WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error
Close() error
LocalAddr() net.Addr
@ -21,23 +21,45 @@ type PacketConn interface {
SetWriteDeadline(t time.Time) error
}
func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(size int)) error {
type UDPConnectionHandler interface {
NewPacketConnection(conn PacketConn, metadata M.Metadata) error
}
type PacketConnStub struct{}
func (s *PacketConnStub) RemoteAddr() net.Addr {
return &common.DummyAddr{}
}
func (s *PacketConnStub) SetDeadline(t time.Time) error {
return nil
}
func (s *PacketConnStub) SetReadDeadline(t time.Time) error {
return nil
}
func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
return nil
}
func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error {
for {
buffer := buf.New()
addr, port, err := conn.ReadPacket(buffer)
destination, err := conn.ReadPacket(buffer)
if err != nil {
buffer.Release()
return err
}
size := buffer.Len()
err = dest.WritePacket(buffer, addr, port)
err = dest.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return err
}
if onAction != nil {
onAction(size)
onAction(destination, size)
}
buffer.Reset()
}
}
@ -58,22 +80,22 @@ func (c *associatePacketConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
if err != nil {
return nil, 0, err
return nil, err
}
c.addr = addr
buffer.Truncate(n)
buffer.Advance(3)
return AddressSerializer.ReadAddressAndPort(buffer)
return AddressSerializer.ReadAddrPort(buffer)
}
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
buffer = buffer.WriteBufferAtFirst(header)
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
}

View file

@ -3,7 +3,7 @@ package socks
import (
"strconv"
"github.com/sagernet/sing/common/socksaddr"
M "github.com/sagernet/sing/common/metadata"
)
const (
@ -69,8 +69,8 @@ func (code ReplyCode) String() string {
}
}
var AddressSerializer = socksaddr.NewSerializer(
socksaddr.AddressFamilyByte(0x01, socksaddr.AddressFamilyIPv4),
socksaddr.AddressFamilyByte(0x04, socksaddr.AddressFamilyIPv6),
socksaddr.AddressFamilyByte(0x03, socksaddr.AddressFamilyFqdn),
var AddressSerializer = M.NewSerializer(
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
M.AddressFamilyByte(0x04, M.AddressFamilyIPv6),
M.AddressFamilyByte(0x03, M.AddressFamilyFqdn),
)

View file

@ -4,11 +4,11 @@ import (
"io"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/socksaddr"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksaddr.Addr, port uint16, username string, password string) (*Response, error) {
func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination *M.AddrPort, username string, password string) (*Response, error) {
var method byte
if common.IsBlank(username) {
method = AuthTypeNotRequired
@ -27,7 +27,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa
return nil, err
}
if authResponse.Method != method {
return nil, exceptions.New("not requested method, request ", method, ", return ", method)
return nil, E.New("not requested method, request ", method, ", return ", method)
}
if method == AuthTypeUsernamePassword {
err = WriteUsernamePasswordAuthRequest(conn, &UsernamePasswordAuthRequest{
@ -46,10 +46,9 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa
}
}
err = WriteRequest(conn, &Request{
Version: version,
Command: command,
Addr: addr,
Port: port,
Version: version,
Command: command,
Destination: destination,
})
if err != nil {
return nil, err
@ -57,7 +56,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa
return ReadResponse(conn)
}
func ClientFastHandshake(writer io.Writer, version byte, command byte, addr socksaddr.Addr, port uint16, username string, password string) error {
func ClientFastHandshake(writer io.Writer, version byte, command byte, destination *M.AddrPort, username string, password string) error {
var method byte
if common.IsBlank(username) {
method = AuthTypeNotRequired
@ -81,10 +80,9 @@ func ClientFastHandshake(writer io.Writer, version byte, command byte, addr sock
}
}
return WriteRequest(writer, &Request{
Version: version,
Command: command,
Addr: addr,
Port: port,
Version: version,
Command: command,
Destination: destination,
})
}

148
protocol/socks/listener.go Normal file
View file

@ -0,0 +1,148 @@
package socks
import (
"io"
"net"
"net/netip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/transport/tcp"
)
type Handler interface {
tcp.Handler
UDPConnectionHandler
}
type Listener struct {
tcpListener *tcp.Listener
authenticator auth.Authenticator
handler Handler
}
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler Handler) *Listener {
listener := &Listener{
handler: handler,
authenticator: authenticator,
}
listener.tcpListener = tcp.NewTCPListener(bind, listener)
return listener
}
func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error {
return HandleConnection(conn, l.authenticator, l.handler)
}
func (l *Listener) Start() error {
return l.tcpListener.Start()
}
func (l *Listener) Close() error {
return l.tcpListener.Close()
}
func (l *Listener) HandleError(err error) {
l.handler.HandleError(err)
}
func HandleConnection(conn net.Conn, authenticator auth.Authenticator, handler Handler) error {
authRequest, err := ReadAuthRequest(conn)
if err != nil {
return E.Cause(err, "read socks auth request")
}
var authMethod byte
if authenticator == nil {
authMethod = AuthTypeNotRequired
} else {
authMethod = AuthTypeUsernamePassword
}
if !common.Contains(authRequest.Methods, authMethod) {
err = WriteAuthResponse(conn, &AuthResponse{
Version: authRequest.Version,
Method: AuthTypeNoAcceptedMethods,
})
if err != nil {
return E.Cause(err, "write socks auth response")
}
}
err = WriteAuthResponse(conn, &AuthResponse{
Version: authRequest.Version,
Method: AuthTypeNotRequired,
})
if err != nil {
return E.Cause(err, "write socks auth response")
}
if authMethod == AuthTypeUsernamePassword {
usernamePasswordAuthRequest, err := ReadUsernamePasswordAuthRequest(conn)
if err != nil {
return E.Cause(err, "read user auth request")
}
response := new(UsernamePasswordAuthResponse)
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
response.Status = UsernamePasswordStatusSuccess
} else {
response.Status = UsernamePasswordStatusFailure
}
err = WriteUsernamePasswordAuthResponse(conn, response)
if err != nil {
return E.Cause(err, "write user auth response")
}
}
request, err := ReadRequest(conn)
if err != nil {
return E.Cause(err, "read socks request")
}
switch request.Command {
case CommandConnect:
err = WriteResponse(conn, &Response{
Version: request.Version,
ReplyCode: ReplyCodeSuccess,
Bind: M.AddrPortFromNetAddr(conn.LocalAddr()),
})
if err != nil {
return E.Cause(err, "write socks response")
}
return handler.NewConnection(conn, M.Metadata{
Destination: request.Destination,
})
case CommandUDPAssociate:
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return err
}
defer udpConn.Close()
err = WriteResponse(conn, &Response{
Version: request.Version,
ReplyCode: ReplyCodeSuccess,
Bind: M.AddrPortFromNetAddr(udpConn.LocalAddr()),
})
if err != nil {
return E.Cause(err, "write socks response")
}
go func() {
err := handler.NewPacketConnection(NewPacketConn(conn, udpConn), M.Metadata{
Source: M.AddrPortFromNetAddr(conn.RemoteAddr()),
Destination: request.Destination,
})
if err != nil {
handler.HandleError(err)
}
conn.Close()
}()
return common.Error(io.Copy(io.Discard, conn))
default:
err = WriteResponse(conn, &Response{
Version: request.Version,
ReplyCode: ReplyCodeUnsupported,
})
if err != nil {
return E.Cause(err, "write response")
}
}
return nil
}

View file

@ -7,9 +7,9 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/exceptions"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/socksaddr"
)
//+----+----------+----------+
@ -45,11 +45,11 @@ func ReadAuthRequest(reader io.Reader) (*AuthRequest, error) {
}
methodLen, err := rw.ReadByte(reader)
if err != nil {
return nil, exceptions.Cause(err, "read socks auth methods length")
return nil, E.Cause(err, "read socks auth methods length")
}
methods, err := rw.ReadBytes(reader, int(methodLen))
if err != nil {
return nil, exceptions.CauseF(err, "read socks auth methods, length ", methodLen)
return nil, E.CauseF(err, "read socks auth methods, length ", methodLen)
}
request := &AuthRequest{
version,
@ -112,11 +112,11 @@ func WriteUsernamePasswordAuthRequest(writer io.Writer, request *UsernamePasswor
if err != nil {
return err
}
err = socksaddr.WriteString(writer, "username", request.Username)
err = M.WriteString(writer, "username", request.Username)
if err != nil {
return err
}
return socksaddr.WriteString(writer, "password", request.Password)
return M.WriteString(writer, "password", request.Password)
}
func ReadUsernamePasswordAuthRequest(reader io.Reader) (*UsernamePasswordAuthRequest, error) {
@ -127,11 +127,11 @@ func ReadUsernamePasswordAuthRequest(reader io.Reader) (*UsernamePasswordAuthReq
if version != UsernamePasswordVersion1 {
return nil, &UnsupportedVersionException{version}
}
username, err := socksaddr.ReadString(reader)
username, err := M.ReadString(reader)
if err != nil {
return nil, err
}
password, err := socksaddr.ReadString(reader)
password, err := M.ReadString(reader)
if err != nil {
return nil, err
}
@ -185,10 +185,9 @@ func ReadUsernamePasswordAuthResponse(reader io.Reader) (*UsernamePasswordAuthRe
//+----+-----+-------+------+----------+----------+
type Request struct {
Version byte
Command byte
Addr socksaddr.Addr
Port uint16
Version byte
Command byte
Destination *M.AddrPort
}
func WriteRequest(writer io.Writer, request *Request) error {
@ -204,7 +203,7 @@ func WriteRequest(writer io.Writer, request *Request) error {
if err != nil {
return err
}
return AddressSerializer.WriteAddressAndPort(writer, request.Addr, request.Port)
return AddressSerializer.WriteAddrPort(writer, request.Destination)
}
func ReadRequest(reader io.Reader) (*Request, error) {
@ -226,15 +225,14 @@ func ReadRequest(reader io.Reader) (*Request, error) {
if err != nil {
return nil, err
}
addr, port, err := AddressSerializer.ReadAddressAndPort(reader)
addrPort, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
request := &Request{
Version: version,
Command: command,
Addr: addr,
Port: port,
Version: version,
Command: command,
Destination: addrPort,
}
return request, nil
}
@ -248,8 +246,7 @@ func ReadRequest(reader io.Reader) (*Request, error) {
type Response struct {
Version byte
ReplyCode ReplyCode
BindAddr socksaddr.Addr
BindPort uint16
Bind *M.AddrPort
}
func WriteResponse(writer io.Writer, response *Response) error {
@ -265,10 +262,10 @@ func WriteResponse(writer io.Writer, response *Response) error {
if err != nil {
return err
}
if response.BindAddr == nil {
return AddressSerializer.WriteAddressAndPort(writer, socksaddr.AddrFromIP(net.IPv4zero), response.BindPort)
if response.Bind == nil {
return AddressSerializer.WriteAddrPort(writer, M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0))
}
return AddressSerializer.WriteAddressAndPort(writer, response.BindAddr, response.BindPort)
return AddressSerializer.WriteAddrPort(writer, response.Bind)
}
func ReadResponse(reader io.Reader) (*Response, error) {
@ -287,15 +284,14 @@ func ReadResponse(reader io.Reader) (*Response, error) {
if err != nil {
return nil, err
}
addr, port, err := AddressSerializer.ReadAddressAndPort(reader)
addrPort, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
response := &Response{
Version: version,
ReplyCode: ReplyCode(replyCode),
BindAddr: addr,
BindPort: port,
Bind: addrPort,
}
return response, nil
}
@ -307,15 +303,14 @@ func ReadResponse(reader io.Reader) (*Response, error) {
//+----+------+------+----------+----------+----------+
type AssociatePacket struct {
Fragment byte
Addr socksaddr.Addr
Port uint16
Data []byte
Fragment byte
Destination *M.AddrPort
Data []byte
}
func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) {
if buffer.Len() < 5 {
return nil, exceptions.New("insufficient length")
return nil, E.New("insufficient length")
}
fragment := buffer.Byte(2)
reader := bytes.NewReader(buffer.Bytes())
@ -323,16 +318,15 @@ func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) {
if err != nil {
return nil, err
}
addr, port, err := AddressSerializer.ReadAddressAndPort(reader)
addrPort, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
buffer.Advance(reader.Len())
packet := &AssociatePacket{
Fragment: fragment,
Addr: addr,
Port: port,
Data: buffer.Bytes(),
Fragment: fragment,
Destination: addrPort,
Data: buffer.Bytes(),
}
return packet, nil
}
@ -346,7 +340,7 @@ func EncodeAssociatePacket(packet *AssociatePacket, buffer *buf.Buffer) error {
if err != nil {
return err
}
err = AddressSerializer.WriteAddressAndPort(buffer, packet.Addr, packet.Port)
err = AddressSerializer.WriteAddrPort(buffer, packet.Destination)
if err != nil {
return err
}

View file

@ -5,7 +5,7 @@ import (
"sync"
"testing"
"github.com/sagernet/sing/common/socksaddr"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
)
@ -20,7 +20,7 @@ func TestHandshake(t *testing.T) {
method := socks.AuthTypeUsernamePassword
go func() {
response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, socksaddr.AddrFromFqdn("test"), 80, "user", "pswd")
response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd")
if err != nil {
t.Fatal(err)
}
@ -60,14 +60,13 @@ func TestHandshake(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Addr.Fqdn() != "test" || request.Port != 80 {
if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 {
t.Fatal(request)
}
err = socks.WriteResponse(server, &socks.Response{
Version: socks.Version5,
ReplyCode: socks.ReplyCodeSuccess,
BindAddr: socksaddr.AddrFromIP(net.IPv4zero),
BindPort: 0,
Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0),
})
if err != nil {
t.Fatal(err)

View file

@ -0,0 +1,86 @@
package mixed
import (
"net"
"net/netip"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/http"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/tcp"
"github.com/sagernet/sing/transport/udp"
)
type Handler interface {
socks.Handler
}
type Listener struct {
TCPListener *tcp.Listener
UDPListener *udp.Listener
handler Handler
authenticator auth.Authenticator
udpNat *udpnat.Server
}
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener {
listener := &Listener{
handler: handler,
authenticator: authenticator,
}
listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy))
if transproxy == redir.ModeTProxy {
listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy))
listener.udpNat = udpnat.NewServer(handler)
}
return listener
}
func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error {
if metadata.Destination != nil {
return l.handler.NewConnection(conn, metadata)
}
bufConn := buf.NewBufferedConn(conn)
header, err := bufConn.Peek(1)
if err != nil {
return err
}
switch header[0] {
case socks.Version4, socks.Version5:
return socks.HandleConnection(bufConn, l.authenticator, l.handler)
default:
return http.HandleConnection(bufConn, l.authenticator, l.handler)
}
}
func (l *Listener) NewPacket(packet *buf.Buffer, metadata M.Metadata) error {
return l.udpNat.HandleUDP(packet, metadata)
}
func (l *Listener) HandleError(err error) {
l.handler.HandleError(err)
}
func (l *Listener) Start() error {
err := l.TCPListener.Start()
if err != nil {
return err
}
if l.UDPListener != nil {
err = l.UDPListener.Start()
}
return err
}
func (l *Listener) Close() error {
l.TCPListener.Close()
if l.UDPListener != nil {
l.UDPListener.Close()
}
return nil
}

View file

@ -1,10 +0,0 @@
package system
import (
"bufio"
"net/http"
_ "unsafe"
)
//go:linkname readRequest net/http.readRequest
func readRequest(b *bufio.Reader) (req *http.Request, err error)

View file

@ -2,6 +2,8 @@ package system
import (
"syscall"
"golang.org/x/sys/unix"
)
const (
@ -12,3 +14,22 @@ const (
func TCPFastOpen(fd uintptr) error {
return syscall.SetsockoptInt(int(fd), syscall.SOL_TCP, TCP_FASTOPEN_CONNECT, 1)
}
func TProxy(fd uintptr, isIPv6 bool) error {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
if err != nil {
return err
}
if isIPv6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
}
return err
}
func TProxyUDP(fd uintptr, isIPv6 bool) error {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
if err != nil {
return err
}
return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
}

View file

@ -7,3 +7,7 @@ import "github.com/sagernet/sing/common/exceptions"
func TCPFastOpen(fd uintptr) error {
return exceptions.New("only available on linux")
}
func TProxy(fd uintptr, isIPv6 bool) error {
return exceptions.New("only available on linux")
}

View file

@ -1,149 +0,0 @@
package system
import (
"io"
"net"
"net/netip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/socksaddr"
"github.com/sagernet/sing/protocol/socks"
)
type SocksHandler interface {
NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error
NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error
OnError(err error)
}
type SocksConfig struct {
Username string
Password string
}
type SocksListener struct {
Handler SocksHandler
*TCPListener
*SocksConfig
}
func NewSocksListener(bind netip.AddrPort, config *SocksConfig, handler SocksHandler) *SocksListener {
listener := &SocksListener{
SocksConfig: config,
Handler: handler,
}
listener.TCPListener = NewTCPListener(bind, listener)
return listener
}
func (l *SocksListener) HandleTCP(conn net.Conn) error {
authRequest, err := socks.ReadAuthRequest(conn)
if err != nil {
return exceptions.Cause(err, "read socks auth request")
}
var authMethod byte
if l.Username == "" {
authMethod = socks.AuthTypeNotRequired
} else {
authMethod = socks.AuthTypeUsernamePassword
}
if !common.Contains(authRequest.Methods, authMethod) {
err = socks.WriteAuthResponse(conn, &socks.AuthResponse{
Version: authRequest.Version,
Method: socks.AuthTypeNoAcceptedMethods,
})
if err != nil {
return exceptions.Cause(err, "write socks auth response")
}
}
err = socks.WriteAuthResponse(conn, &socks.AuthResponse{
Version: authRequest.Version,
Method: socks.AuthTypeNotRequired,
})
if err != nil {
return exceptions.Cause(err, "write socks auth response")
}
if authMethod == socks.AuthTypeUsernamePassword {
usernamePasswordAuthRequest, err := socks.ReadUsernamePasswordAuthRequest(conn)
if err != nil {
return exceptions.Cause(err, "read user auth request")
}
response := socks.UsernamePasswordAuthResponse{}
if usernamePasswordAuthRequest.Username != l.Username {
response.Status = socks.UsernamePasswordStatusFailure
} else if usernamePasswordAuthRequest.Password != l.Password {
response.Status = socks.UsernamePasswordStatusFailure
} else {
response.Status = socks.UsernamePasswordStatusSuccess
}
err = socks.WriteUsernamePasswordAuthResponse(conn, &response)
if err != nil {
return exceptions.Cause(err, "write user auth response")
}
}
request, err := socks.ReadRequest(conn)
if err != nil {
return exceptions.Cause(err, "read socks request")
}
switch request.Command {
case socks.CommandConnect:
localAddr, localPort := socksaddr.AddrFromNetAddr(l.TCPListener.TCPListener.Addr())
err = socks.WriteResponse(conn, &socks.Response{
Version: request.Version,
ReplyCode: socks.ReplyCodeSuccess,
BindAddr: localAddr,
BindPort: localPort,
})
if err != nil {
return exceptions.Cause(err, "write socks response")
}
return l.Handler.NewConnection(request.Addr, request.Port, conn)
case socks.CommandUDPAssociate:
udpConn, err := net.ListenUDP("udp4", nil)
if err != nil {
return err
}
defer udpConn.Close()
localAddr, localPort := socksaddr.AddrFromNetAddr(udpConn.LocalAddr())
err = socks.WriteResponse(conn, &socks.Response{
Version: request.Version,
ReplyCode: socks.ReplyCodeSuccess,
BindAddr: localAddr,
BindPort: localPort,
})
if err != nil {
return exceptions.Cause(err, "write socks response")
}
go func() {
err := l.Handler.NewPacketConnection(socks.NewPacketConn(conn, udpConn), request.Addr, request.Port)
if err != nil {
l.OnError(err)
}
}()
return common.Error(io.Copy(io.Discard, conn))
default:
err = socks.WriteResponse(conn, &socks.Response{
Version: request.Version,
ReplyCode: socks.ReplyCodeUnsupported,
})
if err != nil {
return exceptions.Cause(err, "write response")
}
}
return nil
}
func (l *SocksListener) Start() error {
return l.TCPListener.Start()
}
func (l *SocksListener) Close() error {
return l.TCPListener.Close()
}
func (l *SocksListener) OnError(err error) {
l.Handler.OnError(exceptions.Cause(err, "socks server"))
}

View file

@ -1,57 +0,0 @@
package system
import (
"net"
"net/netip"
)
type TCPHandler interface {
HandleTCP(conn net.Conn) error
OnError(err error)
}
type TCPListener struct {
Listen netip.AddrPort
Handler TCPHandler
*net.TCPListener
}
func NewTCPListener(listen netip.AddrPort, handler TCPHandler) *TCPListener {
return &TCPListener{
Listen: listen,
Handler: handler,
}
}
func (l *TCPListener) Start() error {
tcpListener, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(l.Listen))
if err != nil {
return err
}
l.TCPListener = tcpListener
go l.loop()
return nil
}
func (l *TCPListener) Close() error {
if l == nil || l.TCPListener == nil {
return nil
}
return l.TCPListener.Close()
}
func (l *TCPListener) loop() {
for {
tcpConn, err := l.Accept()
if err != nil {
l.Close()
return
}
go func() {
err := l.Handler.HandleTCP(tcpConn)
if err != nil {
l.Handler.OnError(err)
}
}()
}
}

View file

@ -1,62 +0,0 @@
package system
import (
"net"
"net/netip"
"github.com/sagernet/sing/common/buf"
)
type UDPHandler interface {
HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error
OnError(err error)
}
type UDPListener struct {
Listen netip.AddrPort
Handler UDPHandler
*net.UDPConn
}
func NewUDPListener(listen netip.AddrPort, handler UDPHandler) *UDPListener {
return &UDPListener{
Listen: listen,
Handler: handler,
}
}
func (l *UDPListener) Start() error {
udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(l.Listen))
if err != nil {
return err
}
l.UDPConn = udpConn
go l.loop()
return nil
}
func (l *UDPListener) Close() error {
if l == nil || l.UDPConn == nil {
return nil
}
return l.UDPConn.Close()
}
func (l *UDPListener) loop() {
for {
buffer := buf.New()
n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize))
if err != nil {
buffer.Release()
return
}
buffer.Truncate(n)
go func() {
err := l.Handler.HandleUDP(buffer, addr)
if err != nil {
buffer.Release()
l.Handler.OnError(err)
}
}()
}
}

91
transport/tcp/handler.go Normal file
View file

@ -0,0 +1,91 @@
package tcp
import (
"net"
"net/netip"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir"
)
type Handler interface {
M.TCPConnectionHandler
E.Handler
}
type Listener struct {
bind netip.AddrPort
handler Handler
trans redir.TransproxyMode
lAddr *net.TCPAddr
*net.TCPListener
}
func NewTCPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener {
listener := &Listener{
bind: listen,
handler: handler,
}
for _, option := range options {
option(listener)
}
return listener
}
func (l *Listener) Start() error {
tcpListener, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(l.bind))
if err != nil {
return err
}
if l.trans == redir.ModeTProxy {
l.lAddr = tcpListener.Addr().(*net.TCPAddr)
fd, err := common.GetFileDescriptor(tcpListener)
if err != nil {
return err
}
err = redir.TProxy(fd, l.bind.Addr().Is6())
if err != nil {
return E.Cause(err, "configure tproxy")
}
}
l.TCPListener = tcpListener
go l.loop()
return nil
}
func (l *Listener) Close() error {
if l == nil || l.TCPListener == nil {
return nil
}
return l.TCPListener.Close()
}
func (l *Listener) loop() {
for {
tcpConn, err := l.Accept()
if err != nil {
l.Close()
return
}
var metadata M.Metadata
switch l.trans {
case redir.ModeRedirect:
metadata.Destination, _ = redir.GetOriginalDestination(tcpConn)
case redir.ModeTProxy:
lAddr := tcpConn.LocalAddr().(*net.TCPAddr)
rAddr := tcpConn.RemoteAddr().(*net.TCPAddr)
if lAddr.Port != l.lAddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() {
metadata.Destination = M.AddrPortFromNetAddr(lAddr)
}
}
go func() {
err := l.handler.NewConnection(tcpConn, metadata)
if err != nil {
l.handler.HandleError(err)
}
}()
}
}

11
transport/tcp/options.go Normal file
View file

@ -0,0 +1,11 @@
package tcp
import "github.com/sagernet/sing/common/redir"
type Option func(*Listener)
func WithTransproxyMode(mode redir.TransproxyMode) Option {
return func(listener *Listener) {
listener.trans = mode
}
}

11
transport/udp/options.go Normal file
View file

@ -0,0 +1,11 @@
package udp
import "github.com/sagernet/sing/common/redir"
type Option func(*Listener)
func WithTransproxyMode(mode redir.TransproxyMode) Option {
return func(listener *Listener) {
listener.tproxy = mode == redir.ModeTProxy
}
}

116
transport/udp/udp.go Normal file
View file

@ -0,0 +1,116 @@
package udp
import (
"net"
"net/netip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir"
)
type Handler interface {
M.UDPHandler
E.Handler
}
type Listener struct {
*net.UDPConn
handler Handler
bind *net.UDPAddr
tproxy bool
}
func NewUDPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener {
listener := &Listener{
handler: handler,
bind: net.UDPAddrFromAddrPort(listen),
}
for _, option := range options {
option(listener)
}
return listener
}
func (l *Listener) Start() error {
udpConn, err := net.ListenUDP("udp", l.bind)
if err != nil {
return err
}
if l.tproxy {
fd, err := common.GetFileDescriptor(udpConn)
if err != nil {
return err
}
err = redir.TProxy(fd, l.bind.AddrPort().Addr().Is6())
if err != nil {
return E.Cause(err, "configure tproxy")
}
err = redir.TProxyUDP(fd, l.bind.AddrPort().Addr().Is6())
if err != nil {
return E.Cause(err, "configure tproxy")
}
}
l.UDPConn = udpConn
go l.loop()
return nil
}
func (l *Listener) Close() error {
if l == nil || l.UDPConn == nil {
return nil
}
return l.UDPConn.Close()
}
func (l *Listener) loop() {
if !l.tproxy {
for {
buffer := buf.New()
n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize))
if err != nil {
buffer.Release()
l.handler.HandleError(err)
return
}
buffer.Truncate(n)
err = l.handler.NewPacket(buffer, M.Metadata{
Source: M.AddrPortFromNetAddr(addr),
})
if err != nil {
buffer.Release()
l.handler.HandleError(err)
}
}
} else {
oob := make([]byte, 1024)
for {
buffer := buf.New()
n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob)
if err != nil {
buffer.Release()
l.handler.HandleError(err)
return
}
destination, err := redir.GetOriginalDestinationFromOOB(oob[:oobN])
if err != nil {
l.handler.HandleError(E.Cause(err, "get original destination"))
return
}
buffer.Truncate(n)
err = l.handler.NewPacket(buffer, M.Metadata{
Source: M.AddrPortFromAddrPort(addr),
Destination: destination,
})
if err != nil {
buffer.Release()
l.handler.HandleError(err)
}
}
}
}