mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 03:47:38 +03:00
Refactor shadowsocks
This commit is contained in:
parent
3f23b25edf
commit
00cd0d4b8f
75 changed files with 3169 additions and 1318 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -2,4 +2,5 @@
|
|||
/sing_*
|
||||
/*.json
|
||||
/Country.mmdb
|
||||
/geosite.dat
|
||||
/geosite.dat
|
||||
/vendor/
|
11
README.md
11
README.md
|
@ -1,3 +1,12 @@
|
|||
# sing
|
||||
|
||||
Do you hear the people sing?
|
||||
Do you hear the people sing?
|
||||
|
||||
```shell
|
||||
# geo resources
|
||||
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/get-geoip
|
||||
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/gen-geosite
|
||||
|
||||
# ss-local
|
||||
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-local
|
||||
```
|
|
@ -2,12 +2,12 @@ package main
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/ulikunitz/xz"
|
||||
"github.com/v2fly/v2ray-core/v5/app/router/routercommon"
|
|
@ -1,10 +1,11 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func main() {
|
374
cli/ss-local/main.go
Normal file
374
cli/ss-local/main.go
Normal file
|
@ -0,0 +1,374 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime/debug"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/geoip"
|
||||
"github.com/sagernet/sing/common/geosite"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/random"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sagernet/sing/transport/mixed"
|
||||
"github.com/sagernet/sing/transport/system"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type flags struct {
|
||||
Server string `json:"server"`
|
||||
ServerPort uint16 `json:"server_port"`
|
||||
LocalPort uint16 `json:"local_port"`
|
||||
Password string `json:"password"`
|
||||
Key string `json:"key"`
|
||||
Method string `json:"method"`
|
||||
TCPFastOpen bool `json:"fast_open"`
|
||||
Verbose bool `json:"verbose"`
|
||||
Transproxy string `json:"transproxy"`
|
||||
FWMark int `json:"fwmark"`
|
||||
Bypass string `json:"bypass"`
|
||||
UseSystemRNG bool `json:"use_system_rng"`
|
||||
ReducedSaltEntropy bool `json:"reduced_salt_entropy"`
|
||||
ConfigFile string
|
||||
}
|
||||
|
||||
func main() {
|
||||
f := new(flags)
|
||||
|
||||
command := &cobra.Command{
|
||||
Use: "ss-local",
|
||||
Short: "shadowsocks client",
|
||||
Version: sing.VersionStr,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run(cmd, f)
|
||||
},
|
||||
}
|
||||
|
||||
command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.")
|
||||
command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
|
||||
command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.")
|
||||
command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
|
||||
command.Flags().StringVar(&f.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
|
||||
command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", `Set the cipher.
|
||||
|
||||
Supported ciphers:
|
||||
|
||||
none
|
||||
aes-128-gcm
|
||||
aes-192-gcm
|
||||
aes-256-gcm
|
||||
chacha20-ietf-poly1305
|
||||
xchacha20-ietf-poly1305`)
|
||||
command.Flags().BoolVar(&f.TCPFastOpen, "fast-open", false, `Enable TCP fast open.
|
||||
Only available with Linux kernel > 3.7.0.`)
|
||||
command.Flags().StringVarP(&f.Transproxy, "transproxy", "t", "", "Enable transparent proxy support. [possible values: redirect, tproxy]")
|
||||
command.Flags().IntVar(&f.FWMark, "fwmark", 0, "Set outbound socket mark.")
|
||||
command.Flags().StringVar(&f.Bypass, "bypass", "", "Set bypass country.")
|
||||
command.Flags().StringVarP(&f.ConfigFile, "config", "c", "", "Use a configuration file.")
|
||||
command.Flags().BoolVarP(&f.Verbose, "verbose", "v", false, "Enable verbose mode.")
|
||||
command.Flags().BoolVar(&f.UseSystemRNG, "use-system-rng", false, "Use system random number generator.")
|
||||
command.Flags().BoolVar(&f.ReducedSaltEntropy, "reduced-salt-entropy", false, "Remapping salt to printable chars.")
|
||||
|
||||
err := command.Execute()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type LocalClient struct {
|
||||
*mixed.Listener
|
||||
*geosite.Matcher
|
||||
server *M.AddrPort
|
||||
method shadowsocks.Method
|
||||
session shadowsocks.Session
|
||||
dialer net.Dialer
|
||||
bypass string
|
||||
}
|
||||
|
||||
func NewLocalClient(f *flags) (*LocalClient, error) {
|
||||
if f.ConfigFile != "" {
|
||||
configFile, err := ioutil.ReadFile(f.ConfigFile)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read config file")
|
||||
}
|
||||
flagsNew := new(flags)
|
||||
err = json.Unmarshal(configFile, flagsNew)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode config file")
|
||||
}
|
||||
if flagsNew.Server != "" && f.Server == "" {
|
||||
f.Server = flagsNew.Server
|
||||
}
|
||||
if flagsNew.ServerPort != 0 && f.ServerPort == 0 {
|
||||
f.ServerPort = flagsNew.ServerPort
|
||||
}
|
||||
if flagsNew.LocalPort != 0 && f.LocalPort == 0 {
|
||||
f.LocalPort = flagsNew.LocalPort
|
||||
}
|
||||
if flagsNew.Password != "" && f.Password == "" {
|
||||
f.Password = flagsNew.Password
|
||||
}
|
||||
if flagsNew.Key != "" && f.Key == "" {
|
||||
f.Key = flagsNew.Key
|
||||
}
|
||||
if flagsNew.Method != "" && f.Method == "" {
|
||||
f.Method = flagsNew.Method
|
||||
}
|
||||
if flagsNew.Transproxy != "" && f.Transproxy == "" {
|
||||
f.Transproxy = flagsNew.Transproxy
|
||||
}
|
||||
if flagsNew.TCPFastOpen {
|
||||
f.TCPFastOpen = true
|
||||
}
|
||||
if flagsNew.Verbose {
|
||||
f.Verbose = true
|
||||
}
|
||||
}
|
||||
|
||||
if f.Verbose {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
|
||||
if f.Server == "" {
|
||||
return nil, E.New("missing server address")
|
||||
} else if f.ServerPort == 0 {
|
||||
return nil, E.New("missing server port")
|
||||
} else if f.Method == "" {
|
||||
return nil, E.New("missing method")
|
||||
}
|
||||
|
||||
client := &LocalClient{
|
||||
server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort),
|
||||
bypass: f.Bypass,
|
||||
}
|
||||
|
||||
if f.Method == shadowsocks.MethodNone {
|
||||
client.method = shadowsocks.NewNone()
|
||||
} else if common.Contains(shadowaead.List, f.Method) {
|
||||
var key []byte
|
||||
var rng io.Reader
|
||||
if f.UseSystemRNG {
|
||||
rng = random.System
|
||||
} else {
|
||||
rng = random.Blake3KeyedHash()
|
||||
}
|
||||
if f.ReducedSaltEntropy {
|
||||
rng = &shadowsocks.ReducedEntropyReader{Reader: rng}
|
||||
}
|
||||
client.method = shadowaead.New(f.Method, rng)
|
||||
keyLength := client.method.KeyLength()
|
||||
|
||||
if f.Key != "" {
|
||||
decoded, err := base64.URLEncoding.DecodeString(f.Key)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode key")
|
||||
}
|
||||
if len(decoded) != keyLength {
|
||||
return nil, E.Cause(err, "bad key")
|
||||
}
|
||||
key = decoded
|
||||
} else if f.Password != "" {
|
||||
key = shadowsocks.Key([]byte(f.Password), keyLength)
|
||||
} else {
|
||||
return nil, E.New("missing password")
|
||||
}
|
||||
client.session = shadowaead.NewSession(key, false)
|
||||
}
|
||||
|
||||
client.dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
var rawFd uintptr
|
||||
err := c.Control(func(fd uintptr) {
|
||||
rawFd = fd
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f.FWMark > 0 {
|
||||
err = syscall.SetsockoptInt(int(rawFd), syscall.SOL_SOCKET, syscall.SO_MARK, f.FWMark)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if f.TCPFastOpen {
|
||||
err = system.TCPFastOpen(rawFd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var transproxyMode redir.TransproxyMode
|
||||
switch f.Transproxy {
|
||||
case "redirect":
|
||||
transproxyMode = redir.ModeRedirect
|
||||
case "tproxy":
|
||||
transproxyMode = redir.ModeTProxy
|
||||
case "":
|
||||
transproxyMode = redir.ModeDisabled
|
||||
default:
|
||||
return nil, E.New("unknown transproxy mode ", f.Transproxy)
|
||||
}
|
||||
|
||||
client.Listener = mixed.NewListener(netip.AddrPortFrom(netip.IPv6Unspecified(), f.LocalPort), nil, transproxyMode, client)
|
||||
|
||||
if f.Bypass != "" {
|
||||
err := geoip.LoadMMDB("Country.mmdb")
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "load Country.mmdb")
|
||||
}
|
||||
|
||||
geodata, err := os.Open("geosite.dat")
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "geosite.dat not found")
|
||||
}
|
||||
|
||||
geositeMatcher, err := geosite.LoadGeositeMatcher(geodata, f.Bypass)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.Matcher = geositeMatcher
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func bypass(conn net.Conn, destination *M.AddrPort) error {
|
||||
logrus.Info("BYPASS ", conn.RemoteAddr(), " ==> ", destination)
|
||||
serverConn, err := net.Dial("tcp", destination.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return task.Run(context.Background(), func() error {
|
||||
defer rw.CloseRead(conn)
|
||||
defer rw.CloseWrite(serverConn)
|
||||
return common.Error(io.Copy(serverConn, conn))
|
||||
}, func() error {
|
||||
defer rw.CloseRead(serverConn)
|
||||
defer rw.CloseWrite(conn)
|
||||
return common.Error(io.Copy(conn, serverConn))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
if c.bypass != "" {
|
||||
if metadata.Destination.Addr.Family().IsFqdn() {
|
||||
if c.Match(metadata.Destination.Addr.Fqdn()) {
|
||||
return bypass(conn, metadata.Destination)
|
||||
}
|
||||
} else {
|
||||
if geoip.Match(c.bypass, metadata.Destination.Addr.Addr().AsSlice()) {
|
||||
return bypass(conn, metadata.Destination)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||
ctx := context.Background()
|
||||
|
||||
var serverConn net.Conn
|
||||
payload := buf.New()
|
||||
err := task.Run(ctx, func() error {
|
||||
sc, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
|
||||
serverConn = sc
|
||||
if err != nil {
|
||||
return E.Cause(err, "connect to server")
|
||||
}
|
||||
return nil
|
||||
}, func() error {
|
||||
err := conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = payload.ReadFrom(conn)
|
||||
if err != nil && !E.IsTimeout(err) {
|
||||
return E.Cause(err, "read payload")
|
||||
}
|
||||
err = conn.SetReadDeadline(time.Time{})
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
payload.Release()
|
||||
return err
|
||||
}
|
||||
serverConn = c.method.DialEarlyConn(c.session, serverConn, metadata.Destination)
|
||||
_, err = serverConn.Write(payload.Bytes())
|
||||
payload.Release()
|
||||
if err != nil {
|
||||
return E.Cause(err, "client handshake")
|
||||
}
|
||||
|
||||
return rw.CopyConn(ctx, serverConn, conn)
|
||||
}
|
||||
|
||||
func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
|
||||
ctx := context.Background()
|
||||
udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverConn := c.method.DialPacketConn(c.session, udpConn)
|
||||
return task.Run(ctx, func() error {
|
||||
var init bool
|
||||
return socks.CopyPacketConn(serverConn, conn, func(destination *M.AddrPort, n int) {
|
||||
if !init {
|
||||
init = true
|
||||
logrus.Info("UDP ", conn.LocalAddr(), " ==> ", destination)
|
||||
} else {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
|
||||
}
|
||||
})
|
||||
}, func() error {
|
||||
return socks.CopyPacketConn(conn, serverConn, func(destination *M.AddrPort, n int) {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func Run(cmd *cobra.Command, flags *flags) {
|
||||
client, err := NewLocalClient(flags)
|
||||
if err != nil {
|
||||
logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n")
|
||||
cmd.Help()
|
||||
os.Exit(1)
|
||||
}
|
||||
err = client.Listener.Start()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
||||
logrus.Info("mixed server started at ", client.Listener.TCPListener.Addr())
|
||||
|
||||
osSignals := make(chan os.Signal, 1)
|
||||
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
||||
<-osSignals
|
||||
|
||||
client.Listener.Close()
|
||||
}
|
||||
|
||||
func (c *LocalClient) HandleError(err error) {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
logrus.Warn(err)
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
//go:build debug
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
func init() {
|
||||
go http.ListenAndServe("127.0.0.1:8964", nil)
|
||||
}
|
|
@ -1,324 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"github.com/sagernet/sing/common/geoip"
|
||||
"github.com/sagernet/sing/common/geosite"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sagernet/sing/transport/system"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := MainCmd().Execute()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type Flags struct {
|
||||
Server string `json:"server"`
|
||||
ServerPort uint16 `json:"server_port"`
|
||||
LocalPort uint16 `json:"local_port"`
|
||||
Password string `json:"password"`
|
||||
Key string `json:"key"`
|
||||
Method string `json:"method"`
|
||||
TCPFastOpen bool `json:"fast_open"`
|
||||
Verbose bool `json:"verbose"`
|
||||
Redirect string `json:"redir"`
|
||||
FWMark int `json:"fwmark"`
|
||||
Bypass string `json:"bypass"`
|
||||
ConfigFile string
|
||||
}
|
||||
|
||||
func MainCmd() *cobra.Command {
|
||||
flags := new(Flags)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "sslocal",
|
||||
Short: "shadowsocks client as socks5 proxy, sing port",
|
||||
Version: sing.Version,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run(cmd, flags)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVarP(&flags.Server, "server", "s", "", "Set the server’s hostname or IP.")
|
||||
cmd.Flags().Uint16VarP(&flags.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
|
||||
cmd.Flags().Uint16VarP(&flags.LocalPort, "local-port", "l", 0, "Set the local port number.")
|
||||
cmd.Flags().StringVarP(&flags.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
|
||||
cmd.Flags().StringVar(&flags.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
|
||||
cmd.Flags().StringVarP(&flags.Method, "encrypt-method", "m", "", `Set the cipher.
|
||||
|
||||
Supported ciphers:
|
||||
|
||||
none
|
||||
aes-128-gcm
|
||||
aes-192-gcm
|
||||
aes-256-gcm
|
||||
chacha20-ietf-poly1305
|
||||
xchacha20-ietf-poly1305
|
||||
|
||||
The default cipher is chacha20-ietf-poly1305.`)
|
||||
cmd.Flags().BoolVar(&flags.TCPFastOpen, "fast-open", false, `Enable TCP fast open.
|
||||
Only available with Linux kernel > 3.7.0.`)
|
||||
cmd.Flags().StringVar(&flags.Redirect, "redir", "", "Enable transparent proxy support. [possible values: redirect, tproxy]")
|
||||
cmd.Flags().IntVar(&flags.FWMark, "fwmark", 0, "Set outbound socket mark.")
|
||||
cmd.Flags().StringVar(&flags.Bypass, "bypass", "", "Set bypass country.")
|
||||
cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.")
|
||||
cmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode.")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
type LocalClient struct {
|
||||
*system.MixedListener
|
||||
*shadowsocks.Client
|
||||
*geosite.Matcher
|
||||
redirect bool
|
||||
bypass string
|
||||
}
|
||||
|
||||
func NewLocalClient(flags *Flags) (*LocalClient, error) {
|
||||
if flags.ConfigFile != "" {
|
||||
configFile, err := ioutil.ReadFile(flags.ConfigFile)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "read config file")
|
||||
}
|
||||
flagsNew := new(Flags)
|
||||
err = json.Unmarshal(configFile, flagsNew)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "decode config file")
|
||||
}
|
||||
if flagsNew.Server != "" && flags.Server == "" {
|
||||
flags.Server = flagsNew.Server
|
||||
}
|
||||
if flagsNew.ServerPort != 0 && flags.ServerPort == 0 {
|
||||
flags.ServerPort = flagsNew.ServerPort
|
||||
}
|
||||
if flagsNew.LocalPort != 0 && flags.LocalPort == 0 {
|
||||
flags.LocalPort = flagsNew.LocalPort
|
||||
}
|
||||
if flagsNew.Password != "" && flags.Password == "" {
|
||||
flags.Password = flagsNew.Password
|
||||
}
|
||||
if flagsNew.Key != "" && flags.Key == "" {
|
||||
flags.Key = flagsNew.Key
|
||||
}
|
||||
if flagsNew.Method != "" && flags.Method == "" {
|
||||
flags.Method = flagsNew.Method
|
||||
}
|
||||
if flagsNew.Redirect != "" && flags.Redirect == "" {
|
||||
flags.Redirect = flagsNew.Redirect
|
||||
}
|
||||
if flagsNew.TCPFastOpen {
|
||||
flags.TCPFastOpen = true
|
||||
}
|
||||
if flagsNew.Verbose {
|
||||
flags.Verbose = true
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
clientConfig := &shadowsocks.ClientConfig{
|
||||
Server: flags.Server,
|
||||
ServerPort: flags.ServerPort,
|
||||
Method: flags.Method,
|
||||
}
|
||||
|
||||
if flags.Key != "" {
|
||||
key, err := base64.URLEncoding.DecodeString(flags.Key)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "decode key")
|
||||
}
|
||||
clientConfig.Key = key
|
||||
} else if flags.Password != "" {
|
||||
clientConfig.Password = []byte(flags.Password)
|
||||
}
|
||||
|
||||
if flags.Verbose {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
|
||||
dialer := new(net.Dialer)
|
||||
|
||||
dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
var rawFd uintptr
|
||||
err := c.Control(func(fd uintptr) {
|
||||
rawFd = fd
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if flags.FWMark > 0 {
|
||||
err = syscall.SetsockoptInt(int(rawFd), syscall.SOL_SOCKET, syscall.SO_MARK, flags.FWMark)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if flags.TCPFastOpen {
|
||||
err = system.TCPFastOpen(rawFd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
shadowClient, err := shadowsocks.NewClient(dialer, clientConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &LocalClient{
|
||||
Client: shadowClient,
|
||||
}
|
||||
client.MixedListener = system.NewMixedListener(netip.AddrPortFrom(netip.IPv6Unspecified(), flags.LocalPort), &system.MixedConfig{
|
||||
Redirect: flags.Redirect == "redirect",
|
||||
TProxy: flags.Redirect == "tproxy",
|
||||
}, client)
|
||||
|
||||
if flags.Bypass != "" {
|
||||
client.bypass = flags.Bypass
|
||||
|
||||
err = geoip.LoadMMDB("Country.mmdb")
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "load Country.mmdb")
|
||||
}
|
||||
|
||||
geodata, err := os.Open("geosite.dat")
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "geosite.dat not found")
|
||||
}
|
||||
|
||||
geositeMatcher, err := geosite.LoadGeositeMatcher(geodata, flags.Bypass)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.Matcher = geositeMatcher
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *LocalClient) Start() error {
|
||||
err := c.MixedListener.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logrus.Info("mixed server started at ", c.MixedListener.TCPListener.Addr())
|
||||
return nil
|
||||
}
|
||||
|
||||
func bypass(addr socksaddr.Addr, port uint16, conn net.Conn) error {
|
||||
logrus.Info("BYPASS ", conn.RemoteAddr(), " ==> ", net.JoinHostPort(addr.String(), strconv.Itoa(int(port))))
|
||||
serverConn, err := net.Dial("tcp", socksaddr.JoinHostPort(addr, port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return task.Run(context.Background(), func() error {
|
||||
defer rw.CloseRead(conn)
|
||||
defer rw.CloseWrite(serverConn)
|
||||
return common.Error(io.Copy(serverConn, conn))
|
||||
}, func() error {
|
||||
defer rw.CloseRead(serverConn)
|
||||
defer rw.CloseWrite(conn)
|
||||
return common.Error(io.Copy(conn, serverConn))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *LocalClient) NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error {
|
||||
if c.bypass != "" {
|
||||
if addr.Family().IsFqdn() {
|
||||
if c.Match(addr.Fqdn()) {
|
||||
return bypass(addr, port, conn)
|
||||
}
|
||||
} else {
|
||||
if geoip.Match(c.bypass, addr.Addr().AsSlice()) {
|
||||
return bypass(addr, port, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", net.JoinHostPort(addr.String(), strconv.Itoa(int(port))))
|
||||
|
||||
ctx := context.Background()
|
||||
serverConn, err := c.DialContextTCP(ctx, addr, port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return task.Run(ctx, func() error {
|
||||
defer rw.CloseRead(conn)
|
||||
defer rw.CloseWrite(serverConn)
|
||||
return common.Error(io.Copy(serverConn, conn))
|
||||
}, func() error {
|
||||
defer rw.CloseRead(serverConn)
|
||||
defer rw.CloseWrite(conn)
|
||||
return common.Error(io.Copy(conn, serverConn))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error {
|
||||
ctx := context.Background()
|
||||
serverConn := c.DialContextUDP(ctx)
|
||||
return task.Run(ctx, func() error {
|
||||
var init bool
|
||||
return socks.CopyPacketConn(serverConn, conn, func(size int) {
|
||||
if !init {
|
||||
init = true
|
||||
logrus.Info("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port))
|
||||
} else {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port))
|
||||
}
|
||||
})
|
||||
}, func() error {
|
||||
return socks.CopyPacketConn(conn, serverConn, func(size int) {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", socksaddr.JoinHostPort(addr, port))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func Run(cmd *cobra.Command, flags *Flags) {
|
||||
client, err := NewLocalClient(flags)
|
||||
if err != nil {
|
||||
logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n")
|
||||
cmd.Help()
|
||||
os.Exit(1)
|
||||
}
|
||||
err = client.Start()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
osSignals := make(chan os.Signal, 1)
|
||||
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
||||
<-osSignals
|
||||
client.Close()
|
||||
}
|
||||
|
||||
func (c *LocalClient) OnError(err error) {
|
||||
if exceptions.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
logrus.Warn(err)
|
||||
}
|
140
cli/uot-local/main.go
Normal file
140
cli/uot-local/main.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"github.com/sagernet/sing/common/uot"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sagernet/sing/transport/mixed"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
verbose bool
|
||||
transproxy string
|
||||
)
|
||||
|
||||
func main() {
|
||||
command := cobra.Command{
|
||||
Use: "uot-local <bind> <upstream>",
|
||||
Short: "SUoT client.",
|
||||
Long: "SUoT client. \n\nconverts a normal socks server to a SUoT mixed server.",
|
||||
Example: "uot-local 0.0.0.0:2080 127.0.0.1:1080",
|
||||
Version: sing.VersionStr,
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: run,
|
||||
}
|
||||
command.Flags().BoolVarP(&verbose, "verbose", "v", false, "Enable verbose mode")
|
||||
command.Flags().StringVarP(&transproxy, "transproxy", "t", "", "Enable transparent proxy support [possible values: redirect, tproxy]")
|
||||
err := command.Execute()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func run(cmd *cobra.Command, args []string) {
|
||||
if verbose {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
|
||||
bind, err := netip.ParseAddrPort(args[0])
|
||||
if err != nil {
|
||||
logrus.Fatal("bad bind address: ", err)
|
||||
}
|
||||
|
||||
_, err = netip.ParseAddrPort(args[1])
|
||||
if err != nil {
|
||||
logrus.Fatal("bad upstream address: ", err)
|
||||
}
|
||||
|
||||
var transproxyMode redir.TransproxyMode
|
||||
switch transproxy {
|
||||
case "redirect":
|
||||
transproxyMode = redir.ModeRedirect
|
||||
case "tproxy":
|
||||
transproxyMode = redir.ModeTProxy
|
||||
case "":
|
||||
transproxyMode = redir.ModeDisabled
|
||||
default:
|
||||
logrus.Fatal("unknown transproxy mode ", transproxy)
|
||||
}
|
||||
|
||||
client := &localClient{upstream: args[1]}
|
||||
client.Listener = mixed.NewListener(bind, nil, transproxyMode, client)
|
||||
|
||||
err = client.Start()
|
||||
if err != nil {
|
||||
logrus.Fatal("start mixed server: ", err)
|
||||
}
|
||||
|
||||
logrus.Info("mixed server started at ", client.TCPListener.Addr())
|
||||
|
||||
osSignals := make(chan os.Signal, 1)
|
||||
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
||||
<-osSignals
|
||||
|
||||
client.Close()
|
||||
}
|
||||
|
||||
type localClient struct {
|
||||
*mixed.Listener
|
||||
upstream string
|
||||
}
|
||||
|
||||
func (c *localClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||
|
||||
upstream, err := net.Dial("tcp", c.upstream)
|
||||
if err != nil {
|
||||
return E.Cause(err, "connect to upstream")
|
||||
}
|
||||
|
||||
_, err = socks.ClientHandshake(upstream, socks.Version5, socks.CommandConnect, metadata.Destination, "", "")
|
||||
if err != nil {
|
||||
return E.Cause(err, "upstream handshake failed")
|
||||
}
|
||||
|
||||
return rw.CopyConn(context.Background(), upstream, conn)
|
||||
}
|
||||
|
||||
func (c *localClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
|
||||
upstream, err := net.Dial("tcp", c.upstream)
|
||||
if err != nil {
|
||||
return E.Cause(err, "connect to upstream")
|
||||
}
|
||||
|
||||
_, err = socks.ClientHandshake(upstream, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn(uot.UOTMagicAddress), 443), "", "")
|
||||
if err != nil {
|
||||
return E.Cause(err, "upstream handshake failed")
|
||||
}
|
||||
|
||||
client := uot.NewClientConn(upstream)
|
||||
return task.Run(context.Background(), func() error {
|
||||
return socks.CopyPacketConn(client, conn, func(destination *M.AddrPort, n int) {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
|
||||
})
|
||||
}, func() error {
|
||||
return socks.CopyPacketConn(conn, client, func(destination *M.AddrPort, n int) {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (c *localClient) OnError(err error) {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
logrus.Warn(err)
|
||||
}
|
46
common/auth/auth.go
Normal file
46
common/auth/auth.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Authenticator interface {
|
||||
Verify(user string, pass string) bool
|
||||
Users() []string
|
||||
}
|
||||
|
||||
type AuthUser struct {
|
||||
User string
|
||||
Pass string
|
||||
}
|
||||
|
||||
type inMemoryAuthenticator struct {
|
||||
storage *sync.Map
|
||||
usernames []string
|
||||
}
|
||||
|
||||
func (au *inMemoryAuthenticator) Verify(user string, pass string) bool {
|
||||
realPass, ok := au.storage.Load(user)
|
||||
return ok && realPass == pass
|
||||
}
|
||||
|
||||
func (au *inMemoryAuthenticator) Users() []string { return au.usernames }
|
||||
|
||||
func NewAuthenticator(users []AuthUser) Authenticator {
|
||||
if len(users) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
au := &inMemoryAuthenticator{storage: &sync.Map{}}
|
||||
for _, user := range users {
|
||||
au.storage.Store(user.User, user.Pass)
|
||||
}
|
||||
usernames := make([]string, 0, len(users))
|
||||
au.storage.Range(func(key, value interface{}) bool {
|
||||
usernames = append(usernames, key.(string))
|
||||
return true
|
||||
})
|
||||
au.usernames = usernames
|
||||
|
||||
return au
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/list"
|
||||
|
@ -188,6 +189,18 @@ func (b *Buffer) ReadFrom(r io.Reader) (int64, error) {
|
|||
return int64(n), nil
|
||||
}
|
||||
|
||||
func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
|
||||
if b.IsFull() {
|
||||
return 0, nil, io.ErrShortBuffer
|
||||
}
|
||||
n, addr, err := r.ReadFrom(b.FreeBytes())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
b.end += n
|
||||
return int64(n), addr, nil
|
||||
}
|
||||
|
||||
func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
|
||||
if min <= 0 {
|
||||
return b.ReadFrom(r)
|
||||
|
|
|
@ -70,3 +70,12 @@ func (w *BufferedWriter) Flush() error {
|
|||
}
|
||||
return common.Error(w.Writer.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Close() error {
|
||||
buffer := w.Buffer
|
||||
if buffer != nil {
|
||||
w.Buffer = nil
|
||||
buffer.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
106
common/cache/cache.go
vendored
Normal file
106
common/cache/cache.go
vendored
Normal file
|
@ -0,0 +1,106 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache store element with a expired time
|
||||
type Cache struct {
|
||||
*cache
|
||||
}
|
||||
|
||||
type cache struct {
|
||||
mapping sync.Map
|
||||
janitor *janitor
|
||||
}
|
||||
|
||||
type element struct {
|
||||
Expired time.Time
|
||||
Payload interface{}
|
||||
}
|
||||
|
||||
// Put element in Cache with its ttl
|
||||
func (c *cache) Put(key interface{}, payload interface{}, ttl time.Duration) {
|
||||
c.mapping.Store(key, &element{
|
||||
Payload: payload,
|
||||
Expired: time.Now().Add(ttl),
|
||||
})
|
||||
}
|
||||
|
||||
// Get element in Cache, and drop when it expired
|
||||
func (c *cache) Get(key interface{}) interface{} {
|
||||
item, exist := c.mapping.Load(key)
|
||||
if !exist {
|
||||
return nil
|
||||
}
|
||||
elm := item.(*element)
|
||||
// expired
|
||||
if time.Since(elm.Expired) > 0 {
|
||||
c.mapping.Delete(key)
|
||||
return nil
|
||||
}
|
||||
return elm.Payload
|
||||
}
|
||||
|
||||
// GetWithExpire element in Cache with Expire Time
|
||||
func (c *cache) GetWithExpire(key interface{}) (payload interface{}, expired time.Time) {
|
||||
item, exist := c.mapping.Load(key)
|
||||
if !exist {
|
||||
return
|
||||
}
|
||||
elm := item.(*element)
|
||||
// expired
|
||||
if time.Since(elm.Expired) > 0 {
|
||||
c.mapping.Delete(key)
|
||||
return
|
||||
}
|
||||
return elm.Payload, elm.Expired
|
||||
}
|
||||
|
||||
func (c *cache) cleanup() {
|
||||
c.mapping.Range(func(k, v interface{}) bool {
|
||||
key := k.(string)
|
||||
elm := v.(*element)
|
||||
if time.Since(elm.Expired) > 0 {
|
||||
c.mapping.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
type janitor struct {
|
||||
interval time.Duration
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
func (j *janitor) process(c *cache) {
|
||||
ticker := time.NewTicker(j.interval)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.cleanup()
|
||||
case <-j.stop:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stopJanitor(c *Cache) {
|
||||
c.janitor.stop <- struct{}{}
|
||||
}
|
||||
|
||||
// New return *Cache
|
||||
func New(interval time.Duration) *Cache {
|
||||
j := &janitor{
|
||||
interval: interval,
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
c := &cache{janitor: j}
|
||||
go j.process(c)
|
||||
C := &Cache{c}
|
||||
runtime.SetFinalizer(C, stopJanitor)
|
||||
return C
|
||||
}
|
70
common/cache/cache_test.go
vendored
Normal file
70
common/cache/cache_test.go
vendored
Normal file
|
@ -0,0 +1,70 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCache_Basic(t *testing.T) {
|
||||
interval := 200 * time.Millisecond
|
||||
ttl := 20 * time.Millisecond
|
||||
c := New(interval)
|
||||
c.Put("int", 1, ttl)
|
||||
c.Put("string", "a", ttl)
|
||||
|
||||
i := c.Get("int")
|
||||
assert.Equal(t, i.(int), 1, "should recv 1")
|
||||
|
||||
s := c.Get("string")
|
||||
assert.Equal(t, s.(string), "a", "should recv 'a'")
|
||||
}
|
||||
|
||||
func TestCache_TTL(t *testing.T) {
|
||||
interval := 200 * time.Millisecond
|
||||
ttl := 20 * time.Millisecond
|
||||
now := time.Now()
|
||||
c := New(interval)
|
||||
c.Put("int", 1, ttl)
|
||||
c.Put("int2", 2, ttl)
|
||||
|
||||
i := c.Get("int")
|
||||
_, expired := c.GetWithExpire("int2")
|
||||
assert.Equal(t, i.(int), 1, "should recv 1")
|
||||
assert.True(t, now.Before(expired))
|
||||
|
||||
time.Sleep(ttl * 2)
|
||||
i = c.Get("int")
|
||||
j, _ := c.GetWithExpire("int2")
|
||||
assert.Nil(t, i, "should recv nil")
|
||||
assert.Nil(t, j, "should recv nil")
|
||||
}
|
||||
|
||||
func TestCache_AutoCleanup(t *testing.T) {
|
||||
interval := 10 * time.Millisecond
|
||||
ttl := 15 * time.Millisecond
|
||||
c := New(interval)
|
||||
c.Put("int", 1, ttl)
|
||||
|
||||
time.Sleep(ttl * 2)
|
||||
i := c.Get("int")
|
||||
j, _ := c.GetWithExpire("int")
|
||||
assert.Nil(t, i, "should recv nil")
|
||||
assert.Nil(t, j, "should recv nil")
|
||||
}
|
||||
|
||||
func TestCache_AutoGC(t *testing.T) {
|
||||
sign := make(chan struct{})
|
||||
go func() {
|
||||
interval := 10 * time.Millisecond
|
||||
ttl := 15 * time.Millisecond
|
||||
c := New(interval)
|
||||
c.Put("int", 1, ttl)
|
||||
sign <- struct{}{}
|
||||
}()
|
||||
|
||||
<-sign
|
||||
runtime.GC()
|
||||
}
|
223
common/cache/lrucache.go
vendored
Normal file
223
common/cache/lrucache.go
vendored
Normal file
|
@ -0,0 +1,223 @@
|
|||
package cache
|
||||
|
||||
// Modified by https://github.com/die-net/lrucache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Option is part of Functional Options Pattern
|
||||
type Option func(*LruCache)
|
||||
|
||||
// EvictCallback is used to get a callback when a cache entry is evicted
|
||||
type EvictCallback = func(key interface{}, value interface{})
|
||||
|
||||
// WithEvict set the evict callback
|
||||
func WithEvict(cb EvictCallback) Option {
|
||||
return func(l *LruCache) {
|
||||
l.onEvict = cb
|
||||
}
|
||||
}
|
||||
|
||||
// WithUpdateAgeOnGet update expires when Get element
|
||||
func WithUpdateAgeOnGet() Option {
|
||||
return func(l *LruCache) {
|
||||
l.updateAgeOnGet = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAge defined element max age (second)
|
||||
func WithAge(maxAge int64) Option {
|
||||
return func(l *LruCache) {
|
||||
l.maxAge = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
// WithSize defined max length of LruCache
|
||||
func WithSize(maxSize int) Option {
|
||||
return func(l *LruCache) {
|
||||
l.maxSize = maxSize
|
||||
}
|
||||
}
|
||||
|
||||
// WithStale decide whether Stale return is enabled.
|
||||
// If this feature is enabled, element will not get Evicted according to `WithAge`.
|
||||
func WithStale(stale bool) Option {
|
||||
return func(l *LruCache) {
|
||||
l.staleReturn = stale
|
||||
}
|
||||
}
|
||||
|
||||
// LruCache is a thread-safe, in-memory lru-cache that evicts the
|
||||
// least recently used entries from memory when (if set) the entries are
|
||||
// older than maxAge (in seconds). Use the New constructor to create one.
|
||||
type LruCache struct {
|
||||
maxAge int64
|
||||
maxSize int
|
||||
mu sync.Mutex
|
||||
cache map[interface{}]*list.Element
|
||||
lru *list.List // Front is least-recent
|
||||
updateAgeOnGet bool
|
||||
staleReturn bool
|
||||
onEvict EvictCallback
|
||||
}
|
||||
|
||||
// NewLRUCache creates an LruCache
|
||||
func NewLRUCache(options ...Option) *LruCache {
|
||||
lc := &LruCache{
|
||||
lru: list.New(),
|
||||
cache: make(map[interface{}]*list.Element),
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option(lc)
|
||||
}
|
||||
|
||||
return lc
|
||||
}
|
||||
|
||||
// Get returns the interface{} representation of a cached response and a bool
|
||||
// set to true if the key was found.
|
||||
func (c *LruCache) Get(key interface{}) (interface{}, bool) {
|
||||
entry := c.get(key)
|
||||
if entry == nil {
|
||||
return nil, false
|
||||
}
|
||||
value := entry.value
|
||||
|
||||
return value, true
|
||||
}
|
||||
|
||||
// GetWithExpire returns the interface{} representation of a cached response,
|
||||
// a time.Time Give expected expires,
|
||||
// and a bool set to true if the key was found.
|
||||
// This method will NOT check the maxAge of element and will NOT update the expires.
|
||||
func (c *LruCache) GetWithExpire(key interface{}) (interface{}, time.Time, bool) {
|
||||
entry := c.get(key)
|
||||
if entry == nil {
|
||||
return nil, time.Time{}, false
|
||||
}
|
||||
|
||||
return entry.value, time.Unix(entry.expires, 0), true
|
||||
}
|
||||
|
||||
// Exist returns if key exist in cache but not put item to the head of linked list
|
||||
func (c *LruCache) Exist(key interface{}) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
_, ok := c.cache[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Set stores the interface{} representation of a response for a given key.
|
||||
func (c *LruCache) Set(key interface{}, value interface{}) {
|
||||
expires := int64(0)
|
||||
if c.maxAge > 0 {
|
||||
expires = time.Now().Unix() + c.maxAge
|
||||
}
|
||||
c.SetWithExpire(key, value, time.Unix(expires, 0))
|
||||
}
|
||||
|
||||
// SetWithExpire stores the interface{} representation of a response for a given key and given expires.
|
||||
// The expires time will round to second.
|
||||
func (c *LruCache) SetWithExpire(key interface{}, value interface{}, expires time.Time) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if le, ok := c.cache[key]; ok {
|
||||
c.lru.MoveToBack(le)
|
||||
e := le.Value.(*entry)
|
||||
e.value = value
|
||||
e.expires = expires.Unix()
|
||||
} else {
|
||||
e := &entry{key: key, value: value, expires: expires.Unix()}
|
||||
c.cache[key] = c.lru.PushBack(e)
|
||||
|
||||
if c.maxSize > 0 {
|
||||
if len := c.lru.Len(); len > c.maxSize {
|
||||
c.deleteElement(c.lru.Front())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.maybeDeleteOldest()
|
||||
}
|
||||
|
||||
// CloneTo clone and overwrite elements to another LruCache
|
||||
func (c *LruCache) CloneTo(n *LruCache) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
|
||||
n.lru = list.New()
|
||||
n.cache = make(map[interface{}]*list.Element)
|
||||
|
||||
for e := c.lru.Front(); e != nil; e = e.Next() {
|
||||
elm := e.Value.(*entry)
|
||||
n.cache[elm.key] = n.lru.PushBack(elm)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LruCache) get(key interface{}) *entry {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
le, ok := c.cache[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() {
|
||||
c.deleteElement(le)
|
||||
c.maybeDeleteOldest()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
c.lru.MoveToBack(le)
|
||||
entry := le.Value.(*entry)
|
||||
if c.maxAge > 0 && c.updateAgeOnGet {
|
||||
entry.expires = time.Now().Unix() + c.maxAge
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
// Delete removes the value associated with a key.
|
||||
func (c *LruCache) Delete(key interface{}) {
|
||||
c.mu.Lock()
|
||||
|
||||
if le, ok := c.cache[key]; ok {
|
||||
c.deleteElement(le)
|
||||
}
|
||||
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *LruCache) maybeDeleteOldest() {
|
||||
if !c.staleReturn && c.maxAge > 0 {
|
||||
now := time.Now().Unix()
|
||||
for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() {
|
||||
c.deleteElement(le)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LruCache) deleteElement(le *list.Element) {
|
||||
c.lru.Remove(le)
|
||||
e := le.Value.(*entry)
|
||||
delete(c.cache, e.key)
|
||||
if c.onEvict != nil {
|
||||
c.onEvict(e.key, e.value)
|
||||
}
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
key interface{}
|
||||
value interface{}
|
||||
expires int64
|
||||
}
|
183
common/cache/lrucache_test.go
vendored
Normal file
183
common/cache/lrucache_test.go
vendored
Normal file
|
@ -0,0 +1,183 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var entries = []struct {
|
||||
key string
|
||||
value string
|
||||
}{
|
||||
{"1", "one"},
|
||||
{"2", "two"},
|
||||
{"3", "three"},
|
||||
{"4", "four"},
|
||||
{"5", "five"},
|
||||
}
|
||||
|
||||
func TestLRUCache(t *testing.T) {
|
||||
c := NewLRUCache()
|
||||
|
||||
for _, e := range entries {
|
||||
c.Set(e.key, e.value)
|
||||
}
|
||||
|
||||
c.Delete("missing")
|
||||
_, ok := c.Get("missing")
|
||||
assert.False(t, ok)
|
||||
|
||||
for _, e := range entries {
|
||||
value, ok := c.Get(e.key)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, e.value, value.(string))
|
||||
}
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
c.Delete(e.key)
|
||||
|
||||
_, ok := c.Get(e.key)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRUMaxAge(t *testing.T) {
|
||||
c := NewLRUCache(WithAge(86400))
|
||||
|
||||
now := time.Now().Unix()
|
||||
expected := now + 86400
|
||||
|
||||
// Add one expired entry
|
||||
c.Set("foo", "bar")
|
||||
c.lru.Back().Value.(*entry).expires = now
|
||||
|
||||
// Reset
|
||||
c.Set("foo", "bar")
|
||||
e := c.lru.Back().Value.(*entry)
|
||||
assert.True(t, e.expires >= now)
|
||||
c.lru.Back().Value.(*entry).expires = now
|
||||
|
||||
// Set a few and verify expiration times
|
||||
for _, s := range entries {
|
||||
c.Set(s.key, s.value)
|
||||
e := c.lru.Back().Value.(*entry)
|
||||
assert.True(t, e.expires >= expected && e.expires <= expected+10)
|
||||
}
|
||||
|
||||
// Make sure we can get them all
|
||||
for _, s := range entries {
|
||||
_, ok := c.Get(s.key)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
// Expire all entries
|
||||
for _, s := range entries {
|
||||
le, ok := c.cache[s.key]
|
||||
if assert.True(t, ok) {
|
||||
le.Value.(*entry).expires = now
|
||||
}
|
||||
}
|
||||
|
||||
// Get one expired entry, which should clear all expired entries
|
||||
_, ok := c.Get("3")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, c.lru.Len(), 0)
|
||||
}
|
||||
|
||||
func TestLRUpdateOnGet(t *testing.T) {
|
||||
c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet())
|
||||
|
||||
now := time.Now().Unix()
|
||||
expires := now + 86400/2
|
||||
|
||||
// Add one expired entry
|
||||
c.Set("foo", "bar")
|
||||
c.lru.Back().Value.(*entry).expires = expires
|
||||
|
||||
_, ok := c.Get("foo")
|
||||
assert.True(t, ok)
|
||||
assert.True(t, c.lru.Back().Value.(*entry).expires > expires)
|
||||
}
|
||||
|
||||
func TestMaxSize(t *testing.T) {
|
||||
c := NewLRUCache(WithSize(2))
|
||||
// Add one expired entry
|
||||
c.Set("foo", "bar")
|
||||
_, ok := c.Get("foo")
|
||||
assert.True(t, ok)
|
||||
|
||||
c.Set("bar", "foo")
|
||||
c.Set("baz", "foo")
|
||||
|
||||
_, ok = c.Get("foo")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestExist(t *testing.T) {
|
||||
c := NewLRUCache(WithSize(1))
|
||||
c.Set(1, 2)
|
||||
assert.True(t, c.Exist(1))
|
||||
c.Set(2, 3)
|
||||
assert.False(t, c.Exist(1))
|
||||
}
|
||||
|
||||
func TestEvict(t *testing.T) {
|
||||
temp := 0
|
||||
evict := func(key interface{}, value interface{}) {
|
||||
temp = key.(int) + value.(int)
|
||||
}
|
||||
|
||||
c := NewLRUCache(WithEvict(evict), WithSize(1))
|
||||
c.Set(1, 2)
|
||||
c.Set(2, 3)
|
||||
|
||||
assert.Equal(t, temp, 3)
|
||||
}
|
||||
|
||||
func TestSetWithExpire(t *testing.T) {
|
||||
c := NewLRUCache(WithAge(1))
|
||||
now := time.Now().Unix()
|
||||
|
||||
tenSecBefore := time.Unix(now-10, 0)
|
||||
c.SetWithExpire(1, 2, tenSecBefore)
|
||||
|
||||
// res is expected not to exist, and expires should be empty time.Time
|
||||
res, expires, exist := c.GetWithExpire(1)
|
||||
assert.Equal(t, nil, res)
|
||||
assert.Equal(t, time.Time{}, expires)
|
||||
assert.Equal(t, false, exist)
|
||||
}
|
||||
|
||||
func TestStale(t *testing.T) {
|
||||
c := NewLRUCache(WithAge(1), WithStale(true))
|
||||
now := time.Now().Unix()
|
||||
|
||||
tenSecBefore := time.Unix(now-10, 0)
|
||||
c.SetWithExpire(1, 2, tenSecBefore)
|
||||
|
||||
res, expires, exist := c.GetWithExpire(1)
|
||||
assert.Equal(t, 2, res)
|
||||
assert.Equal(t, tenSecBefore, expires)
|
||||
assert.Equal(t, true, exist)
|
||||
}
|
||||
|
||||
func TestCloneTo(t *testing.T) {
|
||||
o := NewLRUCache(WithSize(10))
|
||||
o.Set("1", 1)
|
||||
o.Set("2", 2)
|
||||
|
||||
n := NewLRUCache(WithSize(2))
|
||||
n.Set("3", 3)
|
||||
n.Set("4", 4)
|
||||
|
||||
o.CloneTo(n)
|
||||
|
||||
assert.False(t, n.Exist("3"))
|
||||
assert.True(t, n.Exist("1"))
|
||||
|
||||
n.Set("5", 5)
|
||||
assert.False(t, n.Exist("1"))
|
||||
}
|
|
@ -43,6 +43,16 @@ func Filter[T any](arr []T, block func(it T) bool) []T {
|
|||
return retArr
|
||||
}
|
||||
|
||||
func FilterIsInstance[T any, N any](arr []T, block func(it T) (N, bool)) []N {
|
||||
var retArr []N
|
||||
for _, it := range arr {
|
||||
if n, isN := block(it); isN {
|
||||
retArr = append(retArr, n)
|
||||
}
|
||||
}
|
||||
return retArr
|
||||
}
|
||||
|
||||
func Done(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
|
|
@ -62,3 +62,7 @@ func IsTimeout(err error) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
HandleError(err error)
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@ func FlushVar(writerP *io.Writer) error {
|
|||
if writerBack == writer {
|
||||
writer = u.Upstream()
|
||||
writerBack = writer
|
||||
writerP = &writer
|
||||
*writerP = writer
|
||||
continue
|
||||
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter {
|
||||
setter.SetWriter(writerBack)
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
package geoip
|
||||
|
||||
import (
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -3,12 +3,13 @@ package geosite
|
|||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/trieset"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/trieset"
|
||||
)
|
||||
|
||||
type Matcher struct {
|
||||
|
@ -25,7 +26,7 @@ func LoadGeositeMatcher(reader io.Reader, code string) (*Matcher, error) {
|
|||
return nil, err
|
||||
}
|
||||
if version != 0 {
|
||||
return nil, exceptions.New("bad geosite data")
|
||||
return nil, E.New("bad geosite data")
|
||||
}
|
||||
decoder, err := zstd.NewReader(reader, zstd.WithDecoderLowmem(true), zstd.WithDecoderConcurrency(1))
|
||||
if err != nil {
|
||||
|
@ -72,5 +73,5 @@ func LoadGeositeMatcher(reader io.Reader, code string) (*Matcher, error) {
|
|||
}
|
||||
}
|
||||
}
|
||||
return nil, exceptions.New(code, " not found in geosite")
|
||||
return nil, E.New(code, " not found in geosite")
|
||||
}
|
||||
|
|
|
@ -201,7 +201,7 @@ func (e *entry[T]) storeLocked(i *T) {
|
|||
// LoadOrStore returns the existing value for the key if present.
|
||||
// Otherwise, it stores and returns the given value.
|
||||
// The loaded result is true if the value was loaded, false if stored.
|
||||
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
|
||||
func (m *Map[K, V]) LoadOrStore(key K, value func() V) (actual V, loaded bool) {
|
||||
// Avoid locking if it's a clean hit.
|
||||
read, _ := m.read.Load().(readOnly[K, V])
|
||||
if e, ok := read.m[key]; ok {
|
||||
|
@ -228,8 +228,9 @@ func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
|
|||
m.dirtyLocked()
|
||||
m.read.Store(readOnly[K, V]{m: read.m, amended: true})
|
||||
}
|
||||
m.dirty[key] = newEntry(value)
|
||||
actual, loaded = value, false
|
||||
v := value()
|
||||
m.dirty[key] = newEntry(v)
|
||||
actual, loaded = v, false
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
|
@ -241,7 +242,7 @@ func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
|
|||
//
|
||||
// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and
|
||||
// returns with ok==false.
|
||||
func (e *entry[T]) tryLoadOrStore(i T) (actual T, loaded, ok bool) {
|
||||
func (e *entry[T]) tryLoadOrStore(i func() T) (actual T, loaded, ok bool) {
|
||||
p := atomic.LoadPointer(&e.p)
|
||||
if p == expunged {
|
||||
var defaultValue T
|
||||
|
@ -254,10 +255,10 @@ func (e *entry[T]) tryLoadOrStore(i T) (actual T, loaded, ok bool) {
|
|||
// Copy the interface after the first load to make this method more amenable
|
||||
// to escape analysis: if we hit the "load" path or the entry is expunged, we
|
||||
// shouldn't bother heap-allocating.
|
||||
ic := i
|
||||
ic := i()
|
||||
for {
|
||||
if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) {
|
||||
return i, false, true
|
||||
return ic, false, true
|
||||
}
|
||||
p = atomic.LoadPointer(&e.p)
|
||||
if p == expunged {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package socksaddr
|
||||
package metadata
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
@ -13,6 +13,35 @@ type Addr interface {
|
|||
String() string
|
||||
}
|
||||
|
||||
type AddrPort struct {
|
||||
Addr Addr
|
||||
Port uint16
|
||||
}
|
||||
|
||||
func (ap AddrPort) IPAddr() *net.IPAddr {
|
||||
return &net.IPAddr{
|
||||
IP: ap.Addr.Addr().AsSlice(),
|
||||
}
|
||||
}
|
||||
|
||||
func (ap AddrPort) TCPAddr() *net.TCPAddr {
|
||||
return &net.TCPAddr{
|
||||
IP: ap.Addr.Addr().AsSlice(),
|
||||
Port: int(ap.Port),
|
||||
}
|
||||
}
|
||||
|
||||
func (ap AddrPort) UDPAddr() *net.UDPAddr {
|
||||
return &net.UDPAddr{
|
||||
IP: ap.Addr.Addr().AsSlice(),
|
||||
Port: int(ap.Port),
|
||||
}
|
||||
}
|
||||
|
||||
func (ap AddrPort) String() string {
|
||||
return net.JoinHostPort(ap.Addr.String(), strconv.Itoa(int(ap.Port)))
|
||||
}
|
||||
|
||||
func ParseAddr(address string) Addr {
|
||||
addr, err := netip.ParseAddr(address)
|
||||
if err == nil {
|
||||
|
@ -21,16 +50,44 @@ func ParseAddr(address string) Addr {
|
|||
return AddrFromFqdn(address)
|
||||
}
|
||||
|
||||
func ParseAddrPort(address string) (Addr, uint16, error) {
|
||||
func AddrPortFrom(addr Addr, port uint16) *AddrPort {
|
||||
return &AddrPort{addr, port}
|
||||
}
|
||||
|
||||
func ParseAddress(address string) (*AddrPort, error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
return ParseAddr(host), uint16(portInt), nil
|
||||
return AddrPortFrom(ParseAddr(host), uint16(portInt)), nil
|
||||
}
|
||||
|
||||
func ParseAddrPort(address string, port string) (*AddrPort, error) {
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return AddrPortFrom(ParseAddr(address), uint16(portInt)), nil
|
||||
}
|
||||
|
||||
func AddrPortFromNetAddr(netAddr net.Addr) *AddrPort {
|
||||
var ip net.IP
|
||||
var port uint16
|
||||
switch addr := netAddr.(type) {
|
||||
case *net.TCPAddr:
|
||||
ip = addr.IP
|
||||
port = uint16(addr.Port)
|
||||
case *net.UDPAddr:
|
||||
ip = addr.IP
|
||||
port = uint16(addr.Port)
|
||||
case *net.IPAddr:
|
||||
ip = addr.IP
|
||||
}
|
||||
return AddrPortFrom(AddrFromIP(ip), port)
|
||||
}
|
||||
|
||||
func AddrFromIP(ip net.IP) Addr {
|
||||
|
@ -50,27 +107,14 @@ func AddrFromAddr(addr netip.Addr) Addr {
|
|||
}
|
||||
}
|
||||
|
||||
func AddrFromNetAddr(netAddr net.Addr) (addr Addr, port uint16) {
|
||||
var ip net.IP
|
||||
switch addr := netAddr.(type) {
|
||||
case *net.TCPAddr:
|
||||
ip = addr.IP
|
||||
port = uint16(addr.Port)
|
||||
case *net.UDPAddr:
|
||||
ip = addr.IP
|
||||
port = uint16(addr.Port)
|
||||
}
|
||||
return AddrFromIP(ip), port
|
||||
func AddrPortFromAddrPort(addrPort netip.AddrPort) *AddrPort {
|
||||
return AddrPortFrom(AddrFromAddr(addrPort.Addr()), addrPort.Port())
|
||||
}
|
||||
|
||||
func AddrFromFqdn(fqdn string) Addr {
|
||||
return AddrFqdn(fqdn)
|
||||
}
|
||||
|
||||
func JoinHostPort(addr Addr, port uint16) string {
|
||||
return net.JoinHostPort(addr.String(), strconv.Itoa(int(port)))
|
||||
}
|
||||
|
||||
type Addr4 [4]byte
|
||||
|
||||
func (a Addr4) Family() Family {
|
|
@ -1,4 +1,4 @@
|
|||
package socksaddr
|
||||
package metadata
|
||||
|
||||
import "fmt"
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package socksaddr
|
||||
package metadata
|
||||
|
||||
type Family byte
|
||||
|
20
common/metadata/metadata.go
Normal file
20
common/metadata/metadata.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package metadata
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
type Metadata struct {
|
||||
Source *AddrPort
|
||||
Destination *AddrPort
|
||||
}
|
||||
|
||||
type TCPConnectionHandler interface {
|
||||
NewConnection(conn net.Conn, metadata Metadata) error
|
||||
}
|
||||
|
||||
type UDPHandler interface {
|
||||
NewPacket(packet *buf.Buffer, metadata Metadata) error
|
||||
}
|
|
@ -1,11 +1,11 @@
|
|||
package socksaddr
|
||||
package metadata
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
|
@ -24,16 +24,9 @@ func PortThenAddress() SerializerOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithFamilyParser(fp FamilyParser) SerializerOption {
|
||||
return func(s *Serializer) {
|
||||
s.familyParser = fp
|
||||
}
|
||||
}
|
||||
|
||||
type Serializer struct {
|
||||
familyMap map[byte]Family
|
||||
familyByteMap map[Family]byte
|
||||
familyParser FamilyParser
|
||||
portFirst bool
|
||||
}
|
||||
|
||||
|
@ -66,20 +59,20 @@ func (s *Serializer) WritePort(writer io.Writer, port uint16) error {
|
|||
return binary.Write(writer, binary.BigEndian, port)
|
||||
}
|
||||
|
||||
func (s *Serializer) WriteAddressAndPort(writer io.Writer, addr Addr, port uint16) error {
|
||||
func (s *Serializer) WriteAddrPort(writer io.Writer, addrPort *AddrPort) error {
|
||||
var err error
|
||||
if !s.portFirst {
|
||||
err = s.WriteAddress(writer, addr)
|
||||
err = s.WriteAddress(writer, addrPort.Addr)
|
||||
} else {
|
||||
err = s.WritePort(writer, port)
|
||||
err = s.WritePort(writer, addrPort.Port)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.portFirst {
|
||||
err = s.WriteAddress(writer, addr)
|
||||
err = s.WriteAddress(writer, addrPort.Addr)
|
||||
} else {
|
||||
err = s.WritePort(writer, port)
|
||||
err = s.WritePort(writer, addrPort.Port)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -89,15 +82,12 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.familyParser != nil {
|
||||
af = s.familyParser(af)
|
||||
}
|
||||
family := s.familyMap[af]
|
||||
switch family {
|
||||
case AddressFamilyFqdn:
|
||||
fqdn, err := ReadString(reader)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "read fqdn")
|
||||
return nil, E.Cause(err, "read fqdn")
|
||||
}
|
||||
return AddrFqdn(fqdn), nil
|
||||
default:
|
||||
|
@ -106,18 +96,18 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) {
|
|||
var addr [4]byte
|
||||
err = common.Error(reader.Read(addr[:]))
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "read ipv4 address")
|
||||
return nil, E.Cause(err, "read ipv4 address")
|
||||
}
|
||||
return Addr4(addr), nil
|
||||
case AddressFamilyIPv6:
|
||||
var addr [16]byte
|
||||
err = common.Error(reader.Read(addr[:]))
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "read ipv6 address")
|
||||
return nil, E.Cause(err, "read ipv6 address")
|
||||
}
|
||||
return Addr16(addr), nil
|
||||
default:
|
||||
return nil, exceptions.New("unknown address family: ", af)
|
||||
return nil, E.New("unknown address family: ", af)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -125,12 +115,14 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) {
|
|||
func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
|
||||
port, err := rw.ReadBytes(reader, 2)
|
||||
if err != nil {
|
||||
return 0, exceptions.Cause(err, "read port")
|
||||
return 0, E.Cause(err, "read port")
|
||||
}
|
||||
return binary.BigEndian.Uint16(port), nil
|
||||
}
|
||||
|
||||
func (s *Serializer) ReadAddressAndPort(reader io.Reader) (addr Addr, port uint16, err error) {
|
||||
func (s *Serializer) ReadAddrPort(reader io.Reader) (addrPort *AddrPort, err error) {
|
||||
var addr Addr
|
||||
var port uint16
|
||||
if !s.portFirst {
|
||||
addr, err = s.ReadAddress(reader)
|
||||
} else {
|
||||
|
@ -144,7 +136,10 @@ func (s *Serializer) ReadAddressAndPort(reader io.Reader) (addr Addr, port uint1
|
|||
} else {
|
||||
port, err = s.ReadPort(reader)
|
||||
}
|
||||
return
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return AddrPortFrom(addr, port), nil
|
||||
}
|
||||
|
||||
func ReadString(reader io.Reader) (string, error) {
|
15
common/net.go
Normal file
15
common/net.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package common
|
||||
|
||||
import "syscall"
|
||||
|
||||
func GetFileDescriptor(conn syscall.Conn) (uintptr, error) {
|
||||
rawConn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var rawFd uintptr
|
||||
err = rawConn.Control(func(fd uintptr) {
|
||||
rawFd = fd
|
||||
})
|
||||
return rawFd, err
|
||||
}
|
18
common/random/rng.go
Normal file
18
common/random/rng.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package random
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"lukechampine.com/blake3"
|
||||
)
|
||||
|
||||
var System = rand.Reader
|
||||
|
||||
func Blake3KeyedHash() io.Reader {
|
||||
key := make([]byte, 32)
|
||||
common.Must1(io.ReadFull(System, key))
|
||||
h := blake3.New(1024, key)
|
||||
return h.XOF()
|
||||
}
|
9
common/redir/mode.go
Normal file
9
common/redir/mode.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package redir
|
||||
|
||||
type TransproxyMode uint8
|
||||
|
||||
const (
|
||||
ModeDisabled TransproxyMode = iota
|
||||
ModeRedirect
|
||||
ModeTProxy
|
||||
)
|
|
@ -2,11 +2,12 @@ package redir
|
|||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err error) {
|
||||
func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) {
|
||||
rawConn, err := conn.(syscall.Conn).SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -22,15 +23,14 @@ func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err erro
|
|||
if conn.RemoteAddr().(*net.TCPAddr).IP.To4() != nil {
|
||||
raw, err := syscall.GetsockoptIPv6Mreq(int(rawFd), syscall.IPPROTO_IP, SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
return nil, err
|
||||
}
|
||||
addr, _ := netip.AddrFromSlice(raw.Multiaddr[4:8])
|
||||
return netip.AddrPortFrom(addr, uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
|
||||
|
||||
return M.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
|
||||
} else {
|
||||
raw, err := syscall.GetsockoptIPv6MTUInfo(int(rawFd), syscall.IPPROTO_IPV6, SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return M.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), raw.Addr.Port), nil
|
||||
}
|
||||
raw, err := syscall.GetsockoptIPv6MTUInfo(int(rawFd), syscall.IPPROTO_IPV6, SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return netip.AddrPortFrom(netip.AddrFrom16(raw.Addr.Addr), raw.Addr.Port), nil
|
||||
}
|
||||
|
|
|
@ -5,10 +5,10 @@ package redir
|
|||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err error) {
|
||||
err = errors.New("unsupported platform")
|
||||
return
|
||||
func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) {
|
||||
return nil, errors.New("unsupported platform")
|
||||
}
|
||||
|
|
131
common/redir/tproxy_linux.go
Normal file
131
common/redir/tproxy_linux.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
package redir
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func TProxy(fd uintptr, isIPv6 bool) error {
|
||||
err := syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isIPv6 {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func TProxyUDP(fd uintptr, isIPv6 bool) error {
|
||||
err := syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
|
||||
}
|
||||
|
||||
func GetOriginalDestinationFromOOB(oob []byte) (*M.AddrPort, error) {
|
||||
controlMessages, err := unix.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, message := range controlMessages {
|
||||
if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR {
|
||||
return M.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
|
||||
} else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR {
|
||||
return M.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
|
||||
}
|
||||
}
|
||||
return nil, E.New("not found")
|
||||
}
|
||||
|
||||
func DialUDP(network string, lAddr *net.UDPAddr, rAddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
rSockAddr, err := udpAddrToSockAddr(rAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lSockAddr, err := udpAddrToSockAddr(lAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fd, err := syscall.Socket(udpAddrFamily(network, lAddr, rAddr), syscall.SOCK_DGRAM, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
|
||||
syscall.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = syscall.SetsockoptInt(fd, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil {
|
||||
syscall.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = syscall.Bind(fd, lSockAddr); err != nil {
|
||||
syscall.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = syscall.Connect(fd, rSockAddr); err != nil {
|
||||
syscall.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fdFile := os.NewFile(uintptr(fd), fmt.Sprintf("net-udp-dial-%s", rAddr.String()))
|
||||
defer fdFile.Close()
|
||||
|
||||
c, err := net.FileConn(fdFile)
|
||||
if err != nil {
|
||||
syscall.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.(*net.UDPConn), nil
|
||||
}
|
||||
|
||||
func udpAddrToSockAddr(addr *net.UDPAddr) (syscall.Sockaddr, error) {
|
||||
switch {
|
||||
case addr.IP.To4() != nil:
|
||||
ip := [4]byte{}
|
||||
copy(ip[:], addr.IP.To4())
|
||||
|
||||
return &syscall.SockaddrInet4{Addr: ip, Port: addr.Port}, nil
|
||||
|
||||
default:
|
||||
ip := [16]byte{}
|
||||
copy(ip[:], addr.IP.To16())
|
||||
|
||||
zoneID, err := strconv.ParseUint(addr.Zone, 10, 32)
|
||||
if err != nil {
|
||||
zoneID = 0
|
||||
}
|
||||
|
||||
return &syscall.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func udpAddrFamily(net string, lAddr, rAddr *net.UDPAddr) int {
|
||||
switch net[len(net)-1] {
|
||||
case '4':
|
||||
return syscall.AF_INET
|
||||
case '6':
|
||||
return syscall.AF_INET6
|
||||
}
|
||||
|
||||
if (lAddr == nil || lAddr.IP.To4() != nil) && (rAddr == nil || lAddr.IP.To4() != nil) {
|
||||
return syscall.AF_INET
|
||||
}
|
||||
return syscall.AF_INET6
|
||||
}
|
20
common/redir/tproxy_other.go
Normal file
20
common/redir/tproxy_other.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
//go:build !linux
|
||||
|
||||
package redir
|
||||
|
||||
import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func TProxy(fd uintptr, isIPv6 bool) error {
|
||||
return E.New("only available on linux")
|
||||
}
|
||||
|
||||
func TProxyUDP(fd uintptr, isIPv6 bool) error {
|
||||
return E.New("only available on linux")
|
||||
}
|
||||
|
||||
func GetOriginalDestinationFromOOB(oob []byte) (*M.AddrPort, error) {
|
||||
return nil, E.New("only available on linux")
|
||||
}
|
30
common/replay/bloomring.go
Normal file
30
common/replay/bloomring.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package replay
|
||||
|
||||
import (
|
||||
"github.com/v2fly/ss-bloomring"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func NewBloomRing() Filter {
|
||||
const (
|
||||
DefaultSFCapacity = 1e6
|
||||
DefaultSFFPR = 1e-6
|
||||
DefaultSFSlot = 10
|
||||
)
|
||||
return &bloomRingFilter{BloomRing: ss_bloomring.NewBloomRing(DefaultSFSlot, DefaultSFCapacity, DefaultSFFPR)}
|
||||
}
|
||||
|
||||
type bloomRingFilter struct {
|
||||
sync.Mutex
|
||||
*ss_bloomring.BloomRing
|
||||
}
|
||||
|
||||
func (f *bloomRingFilter) Check(sum []byte) bool {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
if f.Test(sum) {
|
||||
return false
|
||||
}
|
||||
f.Add(sum)
|
||||
return true
|
||||
}
|
50
common/replay/cuckoo.go
Normal file
50
common/replay/cuckoo.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package replay
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/seiflotfy/cuckoofilter"
|
||||
)
|
||||
|
||||
func NewCuckoo(interval int64) Filter {
|
||||
filter := &cuckooFilter{}
|
||||
filter.interval = interval
|
||||
return filter
|
||||
}
|
||||
|
||||
type cuckooFilter struct {
|
||||
lock sync.Mutex
|
||||
poolA *cuckoo.Filter
|
||||
poolB *cuckoo.Filter
|
||||
poolSwap bool
|
||||
lastSwap int64
|
||||
interval int64
|
||||
}
|
||||
|
||||
func (filter *cuckooFilter) Check(sum []byte) bool {
|
||||
const defaultCapacity = 100000
|
||||
|
||||
filter.lock.Lock()
|
||||
defer filter.lock.Unlock()
|
||||
|
||||
now := time.Now().Unix()
|
||||
if filter.lastSwap == 0 {
|
||||
filter.lastSwap = now
|
||||
filter.poolA = cuckoo.NewFilter(defaultCapacity)
|
||||
filter.poolB = cuckoo.NewFilter(defaultCapacity)
|
||||
}
|
||||
|
||||
elapsed := now - filter.lastSwap
|
||||
if elapsed >= filter.interval {
|
||||
if filter.poolSwap {
|
||||
filter.poolA.Reset()
|
||||
} else {
|
||||
filter.poolB.Reset()
|
||||
}
|
||||
filter.poolSwap = !filter.poolSwap
|
||||
filter.lastSwap = now
|
||||
}
|
||||
|
||||
return filter.poolA.InsertUnique(sum) && filter.poolB.InsertUnique(sum)
|
||||
}
|
5
common/replay/filter.go
Normal file
5
common/replay/filter.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package replay
|
||||
|
||||
type Filter interface {
|
||||
Check(sum []byte) bool
|
||||
}
|
|
@ -40,6 +40,18 @@ func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) {
|
|||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
||||
return task.Run(context.Background(), func() error {
|
||||
defer CloseRead(conn)
|
||||
defer CloseWrite(dest)
|
||||
return common.Error(io.Copy(dest, conn))
|
||||
}, func() error {
|
||||
defer CloseRead(dest)
|
||||
defer CloseWrite(conn)
|
||||
return common.Error(io.Copy(conn, dest))
|
||||
})
|
||||
}
|
||||
|
||||
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {
|
||||
return task.Run(ctx, func() error {
|
||||
buffer := buf.FullNew()
|
||||
|
|
|
@ -2,8 +2,9 @@ package rw
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"github.com/sagernet/sing/common"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type InputStream interface {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type Network int
|
||||
|
@ -20,8 +20,8 @@ type InstanceContext struct{}
|
|||
type Context struct {
|
||||
InstanceContext
|
||||
Network Network
|
||||
Source socksaddr.Addr
|
||||
Destination socksaddr.Addr
|
||||
Source M.Addr
|
||||
Destination M.Addr
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ func (c Context) DestinationNetAddr() string {
|
|||
return net.JoinHostPort(c.Destination.String(), strconv.Itoa(int(c.DestinationPort)))
|
||||
}
|
||||
|
||||
func AddressFromNetAddr(netAddr net.Addr) (addr socksaddr.Addr, port uint16) {
|
||||
func AddressFromNetAddr(netAddr net.Addr) (addr M.Addr, port uint16) {
|
||||
var ip net.IP
|
||||
switch addr := netAddr.(type) {
|
||||
case *net.TCPAddr:
|
||||
|
@ -40,7 +40,7 @@ func AddressFromNetAddr(netAddr net.Addr) (addr socksaddr.Addr, port uint16) {
|
|||
ip = addr.IP
|
||||
port = uint16(addr.Port)
|
||||
}
|
||||
return socksaddr.AddrFromIP(ip), port
|
||||
return M.AddrFromIP(ip), port
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
|
|
108
common/udpnat/server.go
Normal file
108
common/udpnat/server.go
Normal file
|
@ -0,0 +1,108 @@
|
|||
package udpnat
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/gsync"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
socks.UDPConnectionHandler
|
||||
E.Handler
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
udpNat gsync.Map[string, *packetConn]
|
||||
handler Handler
|
||||
}
|
||||
|
||||
func NewServer(handler Handler) *Server {
|
||||
return &Server{handler: handler}
|
||||
}
|
||||
|
||||
func (s *Server) HandleUDP(buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
conn, loaded := s.udpNat.LoadOrStore(metadata.Source.String(), func() *packetConn {
|
||||
return &packetConn{source: metadata.Source.UDPAddr(), in: make(chan *udpPacket)}
|
||||
})
|
||||
if !loaded {
|
||||
go func() {
|
||||
err := s.handler.NewPacketConnection(conn, metadata)
|
||||
if err != nil {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
conn.in <- &udpPacket{
|
||||
buffer: buffer,
|
||||
destination: metadata.Destination,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) OnError(err error) {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
s.udpNat.Range(func(key string, conn *packetConn) bool {
|
||||
conn.Close()
|
||||
return true
|
||||
})
|
||||
s.udpNat = gsync.Map[string, *packetConn]{}
|
||||
return nil
|
||||
}
|
||||
|
||||
type packetConn struct {
|
||||
socks.PacketConnStub
|
||||
source *net.UDPAddr
|
||||
in chan *udpPacket
|
||||
}
|
||||
|
||||
type udpPacket struct {
|
||||
buffer *buf.Buffer
|
||||
destination *M.AddrPort
|
||||
}
|
||||
|
||||
func (c *packetConn) LocalAddr() net.Addr {
|
||||
return c.source
|
||||
}
|
||||
|
||||
func (c *packetConn) Close() error {
|
||||
select {
|
||||
case <-c.in:
|
||||
return io.ErrClosedPipe
|
||||
default:
|
||||
close(c.in)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *packetConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
select {
|
||||
case packet, ok := <-c.in:
|
||||
if !ok {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
defer packet.buffer.Release()
|
||||
if buffer.FreeLen() < packet.buffer.Len() {
|
||||
return nil, io.ErrShortBuffer
|
||||
}
|
||||
return packet.destination, common.Error(buffer.Write(packet.buffer.Bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *packetConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), c.source)
|
||||
if err != nil {
|
||||
return E.Cause(err, "tproxy udp write back")
|
||||
}
|
||||
defer udpConn.Close()
|
||||
return common.Error(udpConn.Write(buffer.Bytes()))
|
||||
}
|
80
common/uot/client.go
Normal file
80
common/uot/client.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package uot
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type ClientConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func NewClientConn(conn net.Conn) *ClientConn {
|
||||
return &ClientConn{conn}
|
||||
}
|
||||
|
||||
func (c *ClientConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
destination, err := AddrParser.ReadAddrPort(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if buffer.FreeLen() < int(length) {
|
||||
return nil, io.ErrShortBuffer
|
||||
}
|
||||
return destination, common.Error(buffer.ReadFullFrom(c, int(length)))
|
||||
}
|
||||
|
||||
func (c *ClientConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
err := AddrParser.WriteAddrPort(c, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Write(c, binary.BigEndian, uint16(buffer.Len()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(c.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (c *ClientConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
addrPort, err := AddrParser.ReadAddrPort(c)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if len(p) < int(length) {
|
||||
return 0, nil, io.ErrShortBuffer
|
||||
}
|
||||
n, err = io.ReadAtLeast(c, p, int(length))
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
addr = addrPort.UDPAddr()
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
err = AddrParser.WriteAddrPort(c, M.AddrPortFromNetAddr(addr))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Write(c, binary.BigEndian, uint16(len(p)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.Write(p)
|
||||
}
|
21
common/uot/resolver.go
Normal file
21
common/uot/resolver.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package uot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
var LookupAddress func(domain string) (net.IP, error)
|
||||
|
||||
func init() {
|
||||
LookupAddress = func(domain string) (net.IP, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", domain)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ips[0], nil
|
||||
}
|
||||
}
|
108
common/uot/server.go
Normal file
108
common/uot/server.go
Normal file
|
@ -0,0 +1,108 @@
|
|||
package uot
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type ServerConn struct {
|
||||
net.PacketConn
|
||||
inputReader, outputReader *io.PipeReader
|
||||
inputWriter, outputWriter *io.PipeWriter
|
||||
}
|
||||
|
||||
func NewServerConn(packetConn net.PacketConn) net.Conn {
|
||||
c := &ServerConn{
|
||||
PacketConn: packetConn,
|
||||
}
|
||||
c.inputReader, c.inputWriter = io.Pipe()
|
||||
c.outputReader, c.outputWriter = io.Pipe()
|
||||
go c.loopInput()
|
||||
go c.loopOutput()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ServerConn) Read(b []byte) (n int, err error) {
|
||||
return c.outputReader.Read(b)
|
||||
}
|
||||
|
||||
func (c *ServerConn) Write(b []byte) (n int, err error) {
|
||||
return c.inputWriter.Write(b)
|
||||
}
|
||||
|
||||
func (c *ServerConn) RemoteAddr() net.Addr {
|
||||
return &common.DummyAddr{}
|
||||
}
|
||||
|
||||
func (c *ServerConn) loopInput() {
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
for {
|
||||
destination, err := AddrParser.ReadAddrPort(c.inputReader)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if destination.Addr.Family().IsFqdn() {
|
||||
ip, err := LookupAddress(destination.Addr.Fqdn())
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
destination.Addr = M.AddrFromIP(ip)
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.inputReader, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
buffer.FullReset()
|
||||
_, err = buffer.ReadFullFrom(c.inputReader, int(length))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
_, err = c.WriteTo(buffer.Bytes(), destination.UDPAddr())
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
|
||||
func (c *ServerConn) loopOutput() {
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
for {
|
||||
buffer.FullReset()
|
||||
n, addr, err := buffer.ReadPacketFrom(c)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
destination := M.AddrPortFromNetAddr(addr)
|
||||
err = AddrParser.WriteAddrPort(c.outputWriter, destination)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = binary.Write(c.outputWriter, binary.BigEndian, uint16(n))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
_, err = buffer.WriteTo(c.outputWriter)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
|
||||
func (c *ServerConn) Close() error {
|
||||
c.inputReader.Close()
|
||||
c.inputWriter.Close()
|
||||
c.outputReader.Close()
|
||||
c.outputWriter.Close()
|
||||
c.PacketConn.Close()
|
||||
return nil
|
||||
}
|
13
common/uot/uot.go
Normal file
13
common/uot/uot.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package uot
|
||||
|
||||
import (
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
const UOTMagicAddress = "sp.udp-over-tcp.arpa"
|
||||
|
||||
var AddrParser = M.NewSerializer(
|
||||
M.AddressFamilyByte(0x00, M.AddressFamilyIPv4),
|
||||
M.AddressFamilyByte(0x01, M.AddressFamilyIPv6),
|
||||
M.AddressFamilyByte(0x02, M.AddressFamilyFqdn),
|
||||
)
|
39
common/uot/uot_test.go
Normal file
39
common/uot/uot_test.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package uot
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/uot/common"
|
||||
"github.com/sagernet/uot/common/buf"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
func TestServerConn(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
common.Must(err)
|
||||
serverConn := NewServerConn(udpConn)
|
||||
defer serverConn.Close()
|
||||
clientConn := NewClientConn(serverConn)
|
||||
message := new(dnsmessage.Message)
|
||||
message.Header.ID = 1
|
||||
message.Header.RecursionDesired = true
|
||||
message.Questions = append(message.Questions, dnsmessage.Question{
|
||||
Name: dnsmessage.MustNewName("google.com."),
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
packet, err := message.Pack()
|
||||
common.Must(err)
|
||||
common.Must1(clientConn.WriteTo(packet, &net.UDPAddr{
|
||||
IP: net.IPv4(8, 8, 8, 8),
|
||||
Port: 53,
|
||||
}))
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
common.Must2(buffer.ReadPacketFrom(clientConn))
|
||||
common.Must(message.Unpack(buffer.Bytes()))
|
||||
for _, answer := range message.Answers {
|
||||
t.Log("got answer :", answer.Body)
|
||||
}
|
||||
}
|
5
core.go
5
core.go
|
@ -1,3 +1,6 @@
|
|||
package sing
|
||||
|
||||
const Version = "v0.0.0-alpha.1"
|
||||
const (
|
||||
Version = "v0.0.0-alpha.1"
|
||||
VersionStr = "sing " + Version
|
||||
)
|
||||
|
|
16
go.mod
16
go.mod
|
@ -6,20 +6,32 @@ require (
|
|||
github.com/klauspost/compress v1.15.1
|
||||
github.com/openacid/low v0.1.21
|
||||
github.com/oschwald/geoip2-golang v1.7.0
|
||||
github.com/sagernet/uot v0.0.0-20220403125237-bf82029ad617
|
||||
github.com/samber/lo v1.11.0
|
||||
github.com/seiflotfy/cuckoofilter v0.0.0-20201222105146-bc6005554a0c
|
||||
github.com/sirupsen/logrus v1.8.1
|
||||
github.com/spf13/cobra v1.4.0
|
||||
github.com/stretchr/testify v1.7.1
|
||||
github.com/ulikunitz/xz v0.5.10
|
||||
github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e
|
||||
github.com/v2fly/v2ray-core/v5 v5.0.3
|
||||
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29
|
||||
golang.org/x/crypto v0.0.0-20220408190544-5352b0902921
|
||||
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3
|
||||
golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f
|
||||
google.golang.org/protobuf v1.28.0
|
||||
lukechampine.com/blake3 v1.1.7
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165 // indirect
|
||||
github.com/golang/protobuf v1.5.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.0.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
|
||||
github.com/oschwald/maxminddb-golang v1.9.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
|
||||
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
|
||||
)
|
||||
|
|
30
go.sum
30
go.sum
|
@ -3,6 +3,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
|
|||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165 h1:BS21ZUJ/B5X2UVUbczfmdWH7GapPWAhxcMsDnjJTU1E=
|
||||
github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
|
||||
|
@ -13,6 +15,10 @@ github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NH
|
|||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||
github.com/klauspost/compress v1.15.1 h1:y9FcTHGyrebwfP0ZZqFiaxTaiDnUrGkJkI+f583BL1A=
|
||||
github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||
github.com/openacid/errors v0.8.1/go.mod h1:GUQEJJOJE3W9skHm8E8Y4phdl2LLEN8iD7c5gcGgdx0=
|
||||
github.com/openacid/low v0.1.21 h1:Tr2GNu4N/+rGRYdOsEHOE89cxUIaDViZbVmKz29uKGo=
|
||||
github.com/openacid/low v0.1.21/go.mod h1:q+MsKI6Pz2xsCkzV4BLj7NR5M4EX0sGz5AqotpZDVh0=
|
||||
|
@ -25,9 +31,15 @@ github.com/oschwald/maxminddb-golang v1.9.0/go.mod h1:TK+s/Z2oZq0rSl4PSeAEoP0bgm
|
|||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg=
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3/go.mod h1:HgjTstvQsPGkxUsCd2KWxErBblirPizecHcpD3ffK+s=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sagernet/uot v0.0.0-20220403125237-bf82029ad617 h1:h46Ocvf7zWpatqOHcR4kw+k2GbGcMM7EzGjYG7wiGfM=
|
||||
github.com/sagernet/uot v0.0.0-20220403125237-bf82029ad617/go.mod h1:T2LhXiIIvaoeKii21x1GONCee9u7N9Nnrqz5bY3SWsM=
|
||||
github.com/samber/lo v1.11.0 h1:JfeYozXL1xfkhRUFOfH13ociyeiLSC/GRJjGKI668xM=
|
||||
github.com/samber/lo v1.11.0/go.mod h1:2I7tgIv8Q1SG2xEIkRq0F2i2zgxVpnyPOP0d3Gj2r+A=
|
||||
github.com/seiflotfy/cuckoofilter v0.0.0-20201222105146-bc6005554a0c h1:pqy40B3MQWYrza7YZXOXgl0Nf0QGFqrOC0BKae1UNAA=
|
||||
github.com/seiflotfy/cuckoofilter v0.0.0-20201222105146-bc6005554a0c/go.mod h1:bR6DqgcAl1zTcOX8/pE2Qkj9XO00eCNqmKb7lXP8EAg=
|
||||
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
|
||||
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
|
||||
github.com/spf13/cobra v1.4.0 h1:y+wJpx64xcgO1V+RcnwW0LEHxTKRi2ZDPSBjWnrg88Q=
|
||||
|
@ -40,18 +52,23 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
|
|||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M=
|
||||
github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8=
|
||||
github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e h1:5QefA066A1tF8gHIiADmOVOV5LS43gt3ONnlEl3xkwI=
|
||||
github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e/go.mod h1:5t19P9LBIrNamL6AcMQOncg/r10y3Pc01AbHeMhwlpU=
|
||||
github.com/v2fly/v2ray-core/v5 v5.0.3 h1:2rnJ9vZbBQ7V4upWsoUVYGoqZl4grrx8SxOReKx+jjc=
|
||||
github.com/v2fly/v2ray-core/v5 v5.0.3/go.mod h1:zhDdsUJcNE8LcLRA3l7fEQ6QLuveD4/OLbQM2CceSHM=
|
||||
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o=
|
||||
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 h1:iU7T1X1J6yxDr0rda54sWGkHgOp5XJrqm79gcNlC2VM=
|
||||
golang.org/x/crypto v0.0.0-20220408190544-5352b0902921/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
|
||||
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 h1:EN5+DfgmRMvRUrMGERW2gQl3Vc+Z7ZMnI/xdEpPSf0c=
|
||||
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0Pp5Fn9Qga0mrgxyERQM=
|
||||
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f h1:8w7RhxzTVgUzw/AH/9mUV5q0vMgy40SQRursCcfmkCw=
|
||||
golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
|
@ -59,6 +76,11 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
|
|||
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
lukechampine.com/blake3 v1.1.7 h1:GgRMhmdsuK8+ii6UZFDL8Nb+VyMwadAgcJyfYHxG6n0=
|
||||
lukechampine.com/blake3 v1.1.7/go.mod h1:tkKEOtDkNtklkXtLNEOGNq5tcV90tJiA1vAA12R78LA=
|
||||
|
|
|
@ -1,92 +1,42 @@
|
|||
package system
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common/auth"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/transport/tcp"
|
||||
)
|
||||
|
||||
type MixedListener struct {
|
||||
*SocksListener
|
||||
UDPListener *UDPListener
|
||||
mcfg *MixedConfig
|
||||
laddr *net.TCPAddr
|
||||
type Handler interface {
|
||||
tcp.Handler
|
||||
}
|
||||
|
||||
type MixedConfig struct {
|
||||
SocksConfig
|
||||
Redirect bool
|
||||
TProxy bool
|
||||
}
|
||||
|
||||
func NewMixedListener(bind netip.AddrPort, config *MixedConfig, handler SocksHandler) *MixedListener {
|
||||
listener := &MixedListener{SocksListener: NewSocksListener(bind, &config.SocksConfig, handler), mcfg: config}
|
||||
listener.TCPListener.Handler = listener
|
||||
if config.TProxy {
|
||||
listener.UDPListener = NewUDPListener(bind, listener)
|
||||
}
|
||||
return listener
|
||||
}
|
||||
|
||||
func (l *MixedListener) HandleTCP(conn net.Conn) error {
|
||||
if l.mcfg.Redirect {
|
||||
var destination netip.AddrPort
|
||||
destination, err := redir.GetOriginalDestination(conn)
|
||||
if err == nil {
|
||||
return l.Handler.NewConnection(socksaddr.AddrFromAddr(destination.Addr()), destination.Port(), conn)
|
||||
}
|
||||
} else if l.mcfg.TProxy {
|
||||
lAddr := conn.LocalAddr().(*net.TCPAddr)
|
||||
rAddr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
|
||||
if lAddr.Port != l.laddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() {
|
||||
addr, port := socksaddr.AddrFromNetAddr(lAddr)
|
||||
return l.Handler.NewConnection(addr, port, conn)
|
||||
}
|
||||
}
|
||||
|
||||
bufConn := buf.NewBufferedConn(conn)
|
||||
hdr, err := bufConn.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = bufConn.UnreadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hdr == socks.Version4 || hdr == socks.Version5 {
|
||||
return l.SocksListener.HandleTCP(bufConn)
|
||||
}
|
||||
|
||||
func HandleConnection(conn *buf.BufferedConn, authenticator auth.Authenticator, handler Handler) error {
|
||||
var httpClient *http.Client
|
||||
for {
|
||||
request, err := readRequest(bufConn.Reader())
|
||||
request, err := readRequest(conn.Reader())
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "read http request")
|
||||
return E.Cause(err, "read http request")
|
||||
}
|
||||
|
||||
if l.Username != "" {
|
||||
if authenticator != nil {
|
||||
var authOk bool
|
||||
authorization := request.Header.Get("Proxy-Authorization")
|
||||
if strings.HasPrefix(authorization, "BASIC ") {
|
||||
userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
|
||||
if string(userPassword) == l.Username+":"+l.Password {
|
||||
authOk = true
|
||||
}
|
||||
userPswdArr := strings.SplitN(string(userPassword), ":", 2)
|
||||
authOk = authenticator.Verify(userPswdArr[0], userPswdArr[1])
|
||||
}
|
||||
if !authOk {
|
||||
err = responseWith(request, http.StatusProxyAuthRequired).Write(conn)
|
||||
|
@ -97,23 +47,23 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error {
|
|||
}
|
||||
|
||||
if request.Method == "CONNECT" {
|
||||
host := request.URL.Hostname()
|
||||
portStr := request.URL.Port()
|
||||
if portStr == "" {
|
||||
portStr = "80"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
destination, err := M.ParseAddrPort(request.URL.Hostname(), portStr)
|
||||
if err != nil {
|
||||
err = responseWith(request, http.StatusBadRequest).Write(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err = fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established")
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write http response")
|
||||
return E.Cause(err, "write http response")
|
||||
}
|
||||
return l.Handler.NewConnection(socksaddr.ParseAddr(host), uint16(port), bufConn)
|
||||
return handler.NewConnection(conn, M.Metadata{
|
||||
Destination: destination,
|
||||
})
|
||||
}
|
||||
|
||||
keepAlive := strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
|
||||
|
@ -141,19 +91,21 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error {
|
|||
ExpectContinueTimeout: 1 * time.Second,
|
||||
DialContext: func(context context.Context, network, address string) (net.Conn, error) {
|
||||
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
||||
return nil, exceptions.New("unsupported network ", network)
|
||||
return nil, E.New("unsupported network ", network)
|
||||
}
|
||||
|
||||
addr, port, err := socksaddr.ParseAddrPort(address)
|
||||
destination, err := M.ParseAddress(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
go func() {
|
||||
err = l.Handler.NewConnection(addr, port, right)
|
||||
err = handler.NewConnection(right, M.Metadata{
|
||||
Destination: destination,
|
||||
})
|
||||
if err != nil {
|
||||
l.OnError(err)
|
||||
handler.HandleError(err)
|
||||
}
|
||||
}()
|
||||
return left, nil
|
||||
|
@ -167,7 +119,7 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error {
|
|||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
l.OnError(exceptions.Cause(err, "http proxy"))
|
||||
handler.HandleError(err)
|
||||
return responseWith(request, http.StatusBadGateway).Write(conn)
|
||||
}
|
||||
|
||||
|
@ -183,17 +135,14 @@ func (l *MixedListener) HandleTCP(conn net.Conn) error {
|
|||
|
||||
err = response.Write(conn)
|
||||
if err != nil {
|
||||
l.OnError(exceptions.Cause(err, "http proxy"))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *MixedListener) HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
//go:linkname readRequest net/http.ReadRequest
|
||||
func readRequest(b *bufio.Reader) (req *http.Request, err error)
|
||||
|
||||
// removeHopByHopHeaders remove hop-by-hop header
|
||||
func removeHopByHopHeaders(header http.Header) {
|
||||
// Strip hop-by-hop header based on RFC:
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
|
||||
|
@ -217,8 +166,6 @@ func removeHopByHopHeaders(header http.Header) {
|
|||
}
|
||||
}
|
||||
|
||||
// removeExtraHTTPHostPort remove extra host port (example.com:80 --> example.com)
|
||||
// It resolves the behavior of some HTTP servers that do not handle host:80 (e.g. baidu.com)
|
||||
func removeExtraHTTPHostPort(req *http.Request) {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
|
@ -243,39 +190,3 @@ func responseWith(request *http.Request, statusCode int) *http.Response {
|
|||
Header: http.Header{},
|
||||
}
|
||||
}
|
||||
|
||||
func (l *MixedListener) Start() error {
|
||||
err := l.TCPListener.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if l.mcfg.TProxy {
|
||||
rawConn, err := l.TCPListener.TCPListener.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var rawFd uintptr
|
||||
err = rawConn.Control(func(fd uintptr) {
|
||||
rawFd = fd
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = syscall.SetsockoptInt(int(rawFd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "failed to configure TCP tproxy")
|
||||
}
|
||||
}
|
||||
if l.mcfg.TProxy {
|
||||
l.laddr = l.TCPListener.Addr().(*net.TCPAddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *MixedListener) Close() error {
|
||||
return l.TCPListener.Close()
|
||||
}
|
||||
|
||||
func (l *MixedListener) OnError(err error) {
|
||||
l.Handler.OnError(exceptions.Cause(err, "mixed server"))
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/list"
|
||||
)
|
||||
|
||||
type Cipher interface {
|
||||
KeySize() int
|
||||
SaltSize() int
|
||||
CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader
|
||||
CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer
|
||||
EncodePacket(key []byte, buffer *buf.Buffer) error
|
||||
DecodePacket(key []byte, buffer *buf.Buffer) error
|
||||
}
|
||||
|
||||
type CipherCreator func() Cipher
|
||||
|
||||
var (
|
||||
cipherList *list.List[string]
|
||||
cipherMap map[string]CipherCreator
|
||||
)
|
||||
|
||||
func init() {
|
||||
cipherList = new(list.List[string])
|
||||
cipherMap = make(map[string]CipherCreator)
|
||||
}
|
||||
|
||||
func RegisterCipher(method string, creator CipherCreator) {
|
||||
cipherList.PushBack(method)
|
||||
cipherMap[method] = creator
|
||||
}
|
||||
|
||||
func CreateCipher(method string) (Cipher, error) {
|
||||
creator := cipherMap[method]
|
||||
if creator != nil {
|
||||
return creator(), nil
|
||||
}
|
||||
return nil, exceptions.New("unsupported method: ", method)
|
||||
}
|
||||
|
||||
func ListCiphers() []string {
|
||||
return cipherList.Array()
|
||||
}
|
|
@ -1,39 +0,0 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterCipher("none", func() Cipher {
|
||||
return (*NoneCipher)(nil)
|
||||
})
|
||||
}
|
||||
|
||||
type NoneCipher struct{}
|
||||
|
||||
func (c *NoneCipher) KeySize() int {
|
||||
return 16
|
||||
}
|
||||
|
||||
func (c *NoneCipher) SaltSize() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *NoneCipher) CreateReader(_ []byte, _ []byte, reader io.Reader) io.Reader {
|
||||
return reader
|
||||
}
|
||||
|
||||
func (c *NoneCipher) CreateWriter(key []byte, iv []byte, writer io.Writer) io.Writer {
|
||||
return writer
|
||||
}
|
||||
|
||||
func (c *NoneCipher) EncodePacket([]byte, *buf.Buffer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NoneCipher) DecodePacket([]byte, *buf.Buffer) error {
|
||||
return nil
|
||||
}
|
|
@ -1,178 +0,0 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBadKey = exceptions.New("bad key")
|
||||
ErrMissingPassword = exceptions.New("password not specified")
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
Server string `json:"server"`
|
||||
ServerPort uint16 `json:"server_port"`
|
||||
Method string `json:"method"`
|
||||
Password []byte `json:"password"`
|
||||
Key []byte `json:"key"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
dialer *net.Dialer
|
||||
cipher Cipher
|
||||
server string
|
||||
key []byte
|
||||
}
|
||||
|
||||
func NewClient(dialer *net.Dialer, config *ClientConfig) (*Client, error) {
|
||||
if config.Server == "" {
|
||||
return nil, exceptions.New("missing server address")
|
||||
}
|
||||
if config.ServerPort == 0 {
|
||||
return nil, exceptions.New("missing server port")
|
||||
}
|
||||
if config.Method == "" {
|
||||
return nil, exceptions.New("missing server method")
|
||||
}
|
||||
|
||||
cipher, err := CreateCipher(config.Method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &Client{
|
||||
dialer: dialer,
|
||||
cipher: cipher,
|
||||
server: net.JoinHostPort(config.Server, strconv.Itoa(int(config.ServerPort))),
|
||||
}
|
||||
if keyLen := len(config.Key); keyLen > 0 {
|
||||
if keyLen == cipher.KeySize() {
|
||||
client.key = config.Key
|
||||
} else {
|
||||
return nil, ErrBadKey
|
||||
}
|
||||
} else if len(config.Password) > 0 {
|
||||
client.key = Key(config.Password, cipher.KeySize())
|
||||
} else {
|
||||
return nil, ErrMissingPassword
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) DialContextTCP(ctx context.Context, addr socksaddr.Addr, port uint16) (net.Conn, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, "tcp", c.server)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "connect to server")
|
||||
}
|
||||
return c.DialConn(conn, addr, port), nil
|
||||
}
|
||||
|
||||
func (c *Client) DialConn(conn net.Conn, addr socksaddr.Addr, port uint16) net.Conn {
|
||||
header := buf.New()
|
||||
header.WriteRandom(c.cipher.SaltSize())
|
||||
writer := &buf.BufferedWriter{
|
||||
Writer: conn,
|
||||
Buffer: header,
|
||||
}
|
||||
protocolWriter := c.cipher.CreateWriter(c.key, header.Bytes(), writer)
|
||||
requestBuffer := buf.New()
|
||||
contentWriter := &buf.BufferedWriter{
|
||||
Writer: protocolWriter,
|
||||
Buffer: requestBuffer,
|
||||
}
|
||||
common.Must(AddressSerializer.WriteAddressAndPort(contentWriter, addr, port))
|
||||
return &shadowsocksConn{
|
||||
Client: c,
|
||||
Conn: conn,
|
||||
Writer: &common.FlushOnceWriter{Writer: contentWriter},
|
||||
}
|
||||
}
|
||||
|
||||
type shadowsocksConn struct {
|
||||
*Client
|
||||
net.Conn
|
||||
io.Writer
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) Read(p []byte) (n int, err error) {
|
||||
if c.reader == nil {
|
||||
buffer := buf.Or(p, c.cipher.SaltSize())
|
||||
defer buffer.Release()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if c.reader == nil {
|
||||
buffer := buf.NewSize(c.cipher.SaltSize())
|
||||
defer buffer.Release()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) Write(p []byte) (n int, err error) {
|
||||
return c.Writer.Write(p)
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
return rw.ReadFromVar(&c.Writer, r)
|
||||
}
|
||||
|
||||
func (c *Client) DialContextUDP(ctx context.Context) socks.PacketConn {
|
||||
conn, err := c.dialer.DialContext(ctx, "udp", c.server)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &shadowsocksPacketConn{c, conn}
|
||||
}
|
||||
|
||||
type shadowsocksPacketConn struct {
|
||||
*Client
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *shadowsocksPacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
header.WriteRandom(c.cipher.SaltSize())
|
||||
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
err := c.cipher.EncodePacket(c.key, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(c.Conn.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (c *shadowsocksPacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
|
||||
n, err := c.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = c.cipher.DecodePacket(c.key, buffer)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return AddressSerializer.ReadAddressAndPort(buffer)
|
||||
}
|
153
protocol/shadowsocks/none.go
Normal file
153
protocol/shadowsocks/none.go
Normal file
|
@ -0,0 +1,153 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
const MethodNone = "none"
|
||||
|
||||
type NoneMethod struct{}
|
||||
|
||||
func NewNone() Method {
|
||||
return &NoneMethod{}
|
||||
}
|
||||
|
||||
func (m *NoneMethod) Name() string {
|
||||
return MethodNone
|
||||
}
|
||||
|
||||
func (m *NoneMethod) KeyLength() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *NoneMethod) NewSession(key []byte) Session {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *NoneMethod) DialConn(_ Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
shadowsocksConn := &noneConn{
|
||||
Conn: conn,
|
||||
handshake: true,
|
||||
destination: destination,
|
||||
}
|
||||
return shadowsocksConn, shadowsocksConn.clientHandshake()
|
||||
}
|
||||
|
||||
func (m *NoneMethod) DialEarlyConn(_ Session, conn net.Conn, destination *M.AddrPort) net.Conn {
|
||||
return &noneConn{
|
||||
Conn: conn,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *NoneMethod) DialPacketConn(_ Session, conn net.Conn) socks.PacketConn {
|
||||
return &nonePacketConn{conn}
|
||||
}
|
||||
|
||||
type noneConn struct {
|
||||
net.Conn
|
||||
|
||||
access sync.Mutex
|
||||
handshake bool
|
||||
destination *M.AddrPort
|
||||
}
|
||||
|
||||
func (c *noneConn) clientHandshake() error {
|
||||
err := socks.AddressSerializer.WriteAddrPort(c.Conn, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.handshake = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *noneConn) Write(b []byte) (n int, err error) {
|
||||
if c.handshake {
|
||||
goto direct
|
||||
}
|
||||
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
if c.handshake {
|
||||
goto direct
|
||||
}
|
||||
|
||||
{
|
||||
if len(b) == 0 {
|
||||
return 0, c.clientHandshake()
|
||||
}
|
||||
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
|
||||
err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
bufN, _ := buffer.Write(b)
|
||||
_, err = c.Conn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if bufN < len(b) {
|
||||
_, err = c.Conn.Write(b[bufN:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
n = len(b)
|
||||
}
|
||||
|
||||
direct:
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if !c.handshake {
|
||||
panic("missing client handshake")
|
||||
}
|
||||
return c.Conn.(io.ReaderFrom).ReadFrom(r)
|
||||
}
|
||||
|
||||
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.Conn.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *noneConn) RemoteAddr() net.Addr {
|
||||
return c.destination.TCPAddr()
|
||||
}
|
||||
|
||||
type nonePacketConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
_, err := buffer.ReadFrom(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return socks.AddressSerializer.ReadAddrPort(buffer)
|
||||
}
|
||||
|
||||
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
|
||||
if err != nil {
|
||||
header.Release()
|
||||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
return common.Error(buffer.WriteTo(c))
|
||||
}
|
|
@ -2,23 +2,26 @@ package shadowsocks
|
|||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"math/rand"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
"net"
|
||||
)
|
||||
|
||||
const MaxPacketSize = 16*1024 - 1
|
||||
type Session interface {
|
||||
Key() []byte
|
||||
ReplayFilter() replay.Filter
|
||||
}
|
||||
|
||||
func Kdf(key, iv []byte, keyLength int) []byte {
|
||||
subKey := make([]byte, keyLength)
|
||||
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
|
||||
common.Must1(io.ReadFull(kdf, subKey))
|
||||
return subKey
|
||||
type Method interface {
|
||||
Name() string
|
||||
KeyLength() int
|
||||
DialConn(session Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error)
|
||||
DialEarlyConn(session Session, conn net.Conn, destination *M.AddrPort) net.Conn
|
||||
DialPacketConn(session Session, conn net.Conn) socks.PacketConn
|
||||
}
|
||||
|
||||
func Key(password []byte, keySize int) []byte {
|
||||
|
@ -43,19 +46,18 @@ func Key(password []byte, keySize int) []byte {
|
|||
return m[:keySize]
|
||||
}
|
||||
|
||||
func RemapToPrintable(input []byte) {
|
||||
const charSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&()*+,./:;<=>?@[]^_`{|}~\\\""
|
||||
seed := rand.New(rand.NewSource(int64(crc32.ChecksumIEEE(input))))
|
||||
for i := range input {
|
||||
input[i] = charSet[seed.Intn(len(charSet))]
|
||||
}
|
||||
type ReducedEntropyReader struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
var AddressSerializer = socksaddr.NewSerializer(
|
||||
socksaddr.AddressFamilyByte(0x01, socksaddr.AddressFamilyIPv4),
|
||||
socksaddr.AddressFamilyByte(0x04, socksaddr.AddressFamilyIPv6),
|
||||
socksaddr.AddressFamilyByte(0x03, socksaddr.AddressFamilyFqdn),
|
||||
socksaddr.WithFamilyParser(func(b byte) byte {
|
||||
return b & 0x0F
|
||||
}),
|
||||
)
|
||||
func (r *ReducedEntropyReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.Reader.Read(p)
|
||||
if n > 6 {
|
||||
const charSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&()*+,./:;<=>?@[]^_`{|}~\\\""
|
||||
seed := rand.New(rand.NewSource(int64(crc32.ChecksumIEEE(p[:6]))))
|
||||
for i := range p[:6] {
|
||||
p[i] = charSet[seed.Intn(len(charSet))]
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,138 +1,20 @@
|
|||
package shadowsocks
|
||||
package shadowaead
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
const PacketLengthBufferSize = 2
|
||||
const (
|
||||
MaxPacketSize = 16*1024 - 1
|
||||
PacketLengthBufferSize = 2
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterCipher("aes-128-gcm", func() Cipher {
|
||||
return &AEADCipher{
|
||||
KeyLength: 16,
|
||||
SaltLength: 16,
|
||||
Constructor: aesGcm,
|
||||
}
|
||||
})
|
||||
RegisterCipher("aes-192-gcm", func() Cipher {
|
||||
return &AEADCipher{
|
||||
KeyLength: 24,
|
||||
SaltLength: 24,
|
||||
Constructor: aesGcm,
|
||||
}
|
||||
})
|
||||
RegisterCipher("aes-256-gcm", func() Cipher {
|
||||
return &AEADCipher{
|
||||
KeyLength: 32,
|
||||
SaltLength: 32,
|
||||
Constructor: aesGcm,
|
||||
}
|
||||
})
|
||||
RegisterCipher("chacha20-ietf-poly1305", func() Cipher {
|
||||
return &AEADCipher{
|
||||
KeyLength: 32,
|
||||
SaltLength: 32,
|
||||
Constructor: chacha20Poly1305,
|
||||
}
|
||||
})
|
||||
RegisterCipher("xchacha20-ietf-poly1305", func() Cipher {
|
||||
return &AEADCipher{
|
||||
KeyLength: 32,
|
||||
SaltLength: 32,
|
||||
Constructor: xchacha20Poly1305,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func aesGcm(key []byte) cipher.AEAD {
|
||||
block, err := aes.NewCipher(key)
|
||||
common.Must(err)
|
||||
aead, err := cipher.NewGCM(block)
|
||||
common.Must(err)
|
||||
return aead
|
||||
}
|
||||
|
||||
func chacha20Poly1305(key []byte) cipher.AEAD {
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return aead
|
||||
}
|
||||
|
||||
func xchacha20Poly1305(key []byte) cipher.AEAD {
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return aead
|
||||
}
|
||||
|
||||
type AEADCipher struct {
|
||||
KeyLength int
|
||||
SaltLength int
|
||||
Constructor func(key []byte) cipher.AEAD
|
||||
}
|
||||
|
||||
func (c *AEADCipher) KeySize() int {
|
||||
return c.KeyLength
|
||||
}
|
||||
|
||||
func (c *AEADCipher) SaltSize() int {
|
||||
return c.SaltLength
|
||||
}
|
||||
|
||||
func (c *AEADCipher) CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader {
|
||||
return NewAEADReader(reader, c.Constructor(Kdf(key, salt, c.KeyLength)))
|
||||
}
|
||||
|
||||
func (c *AEADCipher) CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer {
|
||||
protocolWriter := NewAEADWriter(writer, c.Constructor(Kdf(key, salt, c.KeyLength)))
|
||||
return protocolWriter
|
||||
}
|
||||
|
||||
func (c *AEADCipher) EncodePacket(key []byte, buffer *buf.Buffer) error {
|
||||
aead := c.Constructor(Kdf(key, buffer.To(c.SaltLength), c.KeyLength))
|
||||
aead.Seal(buffer.From(c.SaltLength)[:0], rw.ZeroBytes[:aead.NonceSize()], buffer.From(c.SaltLength), nil)
|
||||
buffer.Extend(aead.Overhead())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AEADCipher) DecodePacket(key []byte, buffer *buf.Buffer) error {
|
||||
if buffer.Len() < c.SaltLength {
|
||||
return exceptions.New("bad packet")
|
||||
}
|
||||
aead := c.Constructor(Kdf(key, buffer.To(c.SaltLength), c.KeyLength))
|
||||
packet, err := aead.Open(buffer.Index(c.SaltLength), rw.ZeroBytes[:aead.NonceSize()], buffer.From(c.SaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Advance(c.SaltLength)
|
||||
buffer.Truncate(len(packet))
|
||||
return nil
|
||||
}
|
||||
|
||||
type AEADConn struct {
|
||||
net.Conn
|
||||
Reader *AEADReader
|
||||
Writer *AEADWriter
|
||||
}
|
||||
|
||||
func (c *AEADConn) Read(p []byte) (n int, err error) {
|
||||
return c.Reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *AEADConn) Write(p []byte) (n int, err error) {
|
||||
return c.Writer.Write(p)
|
||||
}
|
||||
|
||||
type AEADReader struct {
|
||||
type Reader struct {
|
||||
upstream io.Reader
|
||||
cipher cipher.AEAD
|
||||
data []byte
|
||||
|
@ -141,8 +23,8 @@ type AEADReader struct {
|
|||
cached int
|
||||
}
|
||||
|
||||
func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader {
|
||||
return &AEADReader{
|
||||
func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader {
|
||||
return &Reader{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
|
||||
|
@ -150,19 +32,19 @@ func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *AEADReader) Upstream() io.Reader {
|
||||
func (r *Reader) Upstream() io.Reader {
|
||||
return r.upstream
|
||||
}
|
||||
|
||||
func (r *AEADReader) Replaceable() bool {
|
||||
func (r *Reader) Replaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *AEADReader) SetUpstream(reader io.Reader) {
|
||||
func (r *Reader) SetUpstream(reader io.Reader) {
|
||||
r.upstream = reader
|
||||
}
|
||||
|
||||
func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) {
|
||||
func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
|
||||
if r.cached > 0 {
|
||||
writeN, writeErr := writer.Write(r.data[r.index : r.index+r.cached])
|
||||
if writeErr != nil {
|
||||
|
@ -200,7 +82,7 @@ func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *AEADReader) Read(b []byte) (n int, err error) {
|
||||
func (r *Reader) Read(b []byte) (n int, err error) {
|
||||
if r.cached > 0 {
|
||||
n = copy(b, r.data[r.index:r.index+r.cached])
|
||||
r.cached -= n
|
332
protocol/shadowsocks/shadowaead/method.go
Normal file
332
protocol/shadowsocks/shadowaead/method.go
Normal file
|
@ -0,0 +1,332 @@
|
|||
package shadowaead
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha1"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
var List = []string{
|
||||
"aes-128-gcm",
|
||||
"aes-192-gcm",
|
||||
"aes-256-gcm",
|
||||
"chacha20-ietf-poly1305",
|
||||
"xchacha20-ietf-poly1305",
|
||||
}
|
||||
|
||||
func New(method string, secureRNG io.Reader) shadowsocks.Method {
|
||||
m := &Method{
|
||||
name: method,
|
||||
secureRNG: secureRNG,
|
||||
}
|
||||
switch method {
|
||||
case "aes-128-gcm":
|
||||
m.keySaltLength = 16
|
||||
m.constructor = newAESGCM
|
||||
case "aes-192-gcm":
|
||||
m.keySaltLength = 24
|
||||
m.constructor = newAESGCM
|
||||
case "aes-256-gcm":
|
||||
m.keySaltLength = 32
|
||||
m.constructor = newAESGCM
|
||||
case "chacha20-ietf-poly1305":
|
||||
m.keySaltLength = 32
|
||||
m.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
case "xchacha20-ietf-poly1305":
|
||||
m.keySaltLength = 32
|
||||
m.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func NewSession(key []byte, replayFilter bool) shadowsocks.Session {
|
||||
var filter replay.Filter
|
||||
if replayFilter {
|
||||
filter = replay.NewBloomRing()
|
||||
}
|
||||
return &session{key, filter}
|
||||
}
|
||||
|
||||
func Kdf(key, iv []byte, keyLength int) []byte {
|
||||
subKey := make([]byte, keyLength)
|
||||
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
|
||||
common.Must1(io.ReadFull(kdf, subKey))
|
||||
return subKey
|
||||
}
|
||||
|
||||
func newAESGCM(key []byte) cipher.AEAD {
|
||||
block, err := aes.NewCipher(key)
|
||||
common.Must(err)
|
||||
aead, err := cipher.NewGCM(block)
|
||||
common.Must(err)
|
||||
return aead
|
||||
}
|
||||
|
||||
type Method struct {
|
||||
name string
|
||||
keySaltLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
secureRNG io.Reader
|
||||
}
|
||||
|
||||
func (m *Method) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Method) KeyLength() int {
|
||||
return m.keySaltLength
|
||||
}
|
||||
|
||||
func (m *Method) DialConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
shadowsocksConn := &aeadConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
key: account.Key(),
|
||||
replayFilter: account.ReplayFilter(),
|
||||
destination: destination,
|
||||
}
|
||||
return shadowsocksConn, shadowsocksConn.clientHandshake()
|
||||
}
|
||||
|
||||
func (m *Method) DialEarlyConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) net.Conn {
|
||||
return &aeadConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
key: account.Key(),
|
||||
replayFilter: account.ReplayFilter(),
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Method) DialPacketConn(account shadowsocks.Session, conn net.Conn) socks.PacketConn {
|
||||
return &aeadPacketConn{conn, account.Key(), m}
|
||||
}
|
||||
|
||||
func (m *Method) EncodePacket(key []byte, buffer *buf.Buffer) error {
|
||||
cipher := m.constructor(Kdf(key, buffer.To(m.keySaltLength), m.keySaltLength))
|
||||
cipher.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:cipher.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
buffer.Extend(cipher.Overhead())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Method) DecodePacket(key []byte, buffer *buf.Buffer) error {
|
||||
if buffer.Len() < m.keySaltLength {
|
||||
return E.New("bad packet")
|
||||
}
|
||||
aead := m.constructor(Kdf(key, buffer.To(m.keySaltLength), m.keySaltLength))
|
||||
packet, err := aead.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:aead.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Advance(m.keySaltLength)
|
||||
buffer.Truncate(len(packet))
|
||||
return nil
|
||||
}
|
||||
|
||||
type session struct {
|
||||
key []byte
|
||||
replayFilter replay.Filter
|
||||
}
|
||||
|
||||
func (a *session) Key() []byte {
|
||||
return a.key
|
||||
}
|
||||
|
||||
func (a *session) ReplayFilter() replay.Filter {
|
||||
return a.replayFilter
|
||||
}
|
||||
|
||||
type aeadConn struct {
|
||||
net.Conn
|
||||
|
||||
method *Method
|
||||
key []byte
|
||||
destination *M.AddrPort
|
||||
|
||||
access sync.Mutex
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
replayFilter replay.Filter
|
||||
}
|
||||
|
||||
func (c *aeadConn) clientHandshake() error {
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
if c.replayFilter != nil {
|
||||
c.replayFilter.Check(header.Bytes())
|
||||
}
|
||||
|
||||
c.writer = NewAEADWriter(
|
||||
&buf.BufferedWriter{
|
||||
Writer: c.Conn,
|
||||
Buffer: header,
|
||||
},
|
||||
c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)),
|
||||
)
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(c.writer, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return common.FlushVar(&c.writer)
|
||||
}
|
||||
|
||||
func (c *aeadConn) serverHandshake() error {
|
||||
if c.reader == nil {
|
||||
salt := make([]byte, c.method.keySaltLength)
|
||||
_, err := io.ReadFull(c.Conn, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.replayFilter != nil {
|
||||
if !c.replayFilter.Check(salt) {
|
||||
return E.New("salt is not unique")
|
||||
}
|
||||
}
|
||||
c.reader = NewReader(c.Conn, c.method.constructor(Kdf(c.key, salt, c.method.keySaltLength)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) Read(p []byte) (n int, err error) {
|
||||
if err = c.serverHandshake(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *aeadConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if err = c.serverHandshake(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *aeadConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer != nil {
|
||||
goto direct
|
||||
}
|
||||
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
if c.writer != nil {
|
||||
goto direct
|
||||
}
|
||||
|
||||
// client handshake
|
||||
|
||||
{
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
|
||||
request := buf.New()
|
||||
defer request.Release()
|
||||
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
if c.replayFilter != nil {
|
||||
c.replayFilter.Check(header.Bytes())
|
||||
}
|
||||
|
||||
var writer io.Writer = c.Conn
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: header,
|
||||
}
|
||||
writer = NewAEADWriter(writer, c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)))
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: request,
|
||||
}
|
||||
|
||||
err = socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(p) > 0 {
|
||||
_, err = writer.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = common.FlushVar(&writer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
direct:
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
|
||||
func (c *aeadConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer == nil {
|
||||
panic("missing client handshake")
|
||||
}
|
||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
||||
}
|
||||
|
||||
type aeadPacketConn struct {
|
||||
net.Conn
|
||||
key []byte
|
||||
method *Method
|
||||
}
|
||||
|
||||
func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
err = c.method.EncodePacket(c.key, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(c.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, err := c.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = c.method.DecodePacket(c.key, buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return socks.AddressSerializer.ReadAddrPort(buffer)
|
||||
}
|
|
@ -6,12 +6,12 @@ import (
|
|||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type PacketConn interface {
|
||||
ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error)
|
||||
WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error
|
||||
ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error)
|
||||
WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error
|
||||
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
|
@ -21,23 +21,45 @@ type PacketConn interface {
|
|||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(size int)) error {
|
||||
type UDPConnectionHandler interface {
|
||||
NewPacketConnection(conn PacketConn, metadata M.Metadata) error
|
||||
}
|
||||
|
||||
type PacketConnStub struct{}
|
||||
|
||||
func (s *PacketConnStub) RemoteAddr() net.Addr {
|
||||
return &common.DummyAddr{}
|
||||
}
|
||||
|
||||
func (s *PacketConnStub) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PacketConnStub) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error {
|
||||
for {
|
||||
buffer := buf.New()
|
||||
addr, port, err := conn.ReadPacket(buffer)
|
||||
destination, err := conn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
}
|
||||
size := buffer.Len()
|
||||
err = dest.WritePacket(buffer, addr, port)
|
||||
err = dest.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
}
|
||||
if onAction != nil {
|
||||
onAction(size)
|
||||
onAction(destination, size)
|
||||
}
|
||||
buffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,22 +80,22 @@ func (c *associatePacketConn) RemoteAddr() net.Addr {
|
|||
return c.addr
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
|
||||
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
c.addr = addr
|
||||
buffer.Truncate(n)
|
||||
buffer.Advance(3)
|
||||
return AddressSerializer.ReadAddressAndPort(buffer)
|
||||
return AddressSerializer.ReadAddrPort(buffer)
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
|
||||
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
common.Must(header.WriteZeroN(3))
|
||||
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
|
||||
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package socks
|
|||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -69,8 +69,8 @@ func (code ReplyCode) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
var AddressSerializer = socksaddr.NewSerializer(
|
||||
socksaddr.AddressFamilyByte(0x01, socksaddr.AddressFamilyIPv4),
|
||||
socksaddr.AddressFamilyByte(0x04, socksaddr.AddressFamilyIPv6),
|
||||
socksaddr.AddressFamilyByte(0x03, socksaddr.AddressFamilyFqdn),
|
||||
var AddressSerializer = M.NewSerializer(
|
||||
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
|
||||
M.AddressFamilyByte(0x04, M.AddressFamilyIPv6),
|
||||
M.AddressFamilyByte(0x03, M.AddressFamilyFqdn),
|
||||
)
|
||||
|
|
|
@ -4,11 +4,11 @@ import (
|
|||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksaddr.Addr, port uint16, username string, password string) (*Response, error) {
|
||||
func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination *M.AddrPort, username string, password string) (*Response, error) {
|
||||
var method byte
|
||||
if common.IsBlank(username) {
|
||||
method = AuthTypeNotRequired
|
||||
|
@ -27,7 +27,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa
|
|||
return nil, err
|
||||
}
|
||||
if authResponse.Method != method {
|
||||
return nil, exceptions.New("not requested method, request ", method, ", return ", method)
|
||||
return nil, E.New("not requested method, request ", method, ", return ", method)
|
||||
}
|
||||
if method == AuthTypeUsernamePassword {
|
||||
err = WriteUsernamePasswordAuthRequest(conn, &UsernamePasswordAuthRequest{
|
||||
|
@ -46,10 +46,9 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa
|
|||
}
|
||||
}
|
||||
err = WriteRequest(conn, &Request{
|
||||
Version: version,
|
||||
Command: command,
|
||||
Addr: addr,
|
||||
Port: port,
|
||||
Version: version,
|
||||
Command: command,
|
||||
Destination: destination,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -57,7 +56,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa
|
|||
return ReadResponse(conn)
|
||||
}
|
||||
|
||||
func ClientFastHandshake(writer io.Writer, version byte, command byte, addr socksaddr.Addr, port uint16, username string, password string) error {
|
||||
func ClientFastHandshake(writer io.Writer, version byte, command byte, destination *M.AddrPort, username string, password string) error {
|
||||
var method byte
|
||||
if common.IsBlank(username) {
|
||||
method = AuthTypeNotRequired
|
||||
|
@ -81,10 +80,9 @@ func ClientFastHandshake(writer io.Writer, version byte, command byte, addr sock
|
|||
}
|
||||
}
|
||||
return WriteRequest(writer, &Request{
|
||||
Version: version,
|
||||
Command: command,
|
||||
Addr: addr,
|
||||
Port: port,
|
||||
Version: version,
|
||||
Command: command,
|
||||
Destination: destination,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
148
protocol/socks/listener.go
Normal file
148
protocol/socks/listener.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
package socks
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/auth"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/transport/tcp"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
tcp.Handler
|
||||
UDPConnectionHandler
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
tcpListener *tcp.Listener
|
||||
authenticator auth.Authenticator
|
||||
handler Handler
|
||||
}
|
||||
|
||||
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler Handler) *Listener {
|
||||
listener := &Listener{
|
||||
handler: handler,
|
||||
authenticator: authenticator,
|
||||
}
|
||||
listener.tcpListener = tcp.NewTCPListener(bind, listener)
|
||||
return listener
|
||||
}
|
||||
|
||||
func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
return HandleConnection(conn, l.authenticator, l.handler)
|
||||
}
|
||||
|
||||
func (l *Listener) Start() error {
|
||||
return l.tcpListener.Start()
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
return l.tcpListener.Close()
|
||||
}
|
||||
|
||||
func (l *Listener) HandleError(err error) {
|
||||
l.handler.HandleError(err)
|
||||
}
|
||||
|
||||
func HandleConnection(conn net.Conn, authenticator auth.Authenticator, handler Handler) error {
|
||||
authRequest, err := ReadAuthRequest(conn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read socks auth request")
|
||||
}
|
||||
var authMethod byte
|
||||
if authenticator == nil {
|
||||
authMethod = AuthTypeNotRequired
|
||||
} else {
|
||||
authMethod = AuthTypeUsernamePassword
|
||||
}
|
||||
if !common.Contains(authRequest.Methods, authMethod) {
|
||||
err = WriteAuthResponse(conn, &AuthResponse{
|
||||
Version: authRequest.Version,
|
||||
Method: AuthTypeNoAcceptedMethods,
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "write socks auth response")
|
||||
}
|
||||
}
|
||||
err = WriteAuthResponse(conn, &AuthResponse{
|
||||
Version: authRequest.Version,
|
||||
Method: AuthTypeNotRequired,
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "write socks auth response")
|
||||
}
|
||||
|
||||
if authMethod == AuthTypeUsernamePassword {
|
||||
usernamePasswordAuthRequest, err := ReadUsernamePasswordAuthRequest(conn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read user auth request")
|
||||
}
|
||||
response := new(UsernamePasswordAuthResponse)
|
||||
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
|
||||
response.Status = UsernamePasswordStatusSuccess
|
||||
} else {
|
||||
response.Status = UsernamePasswordStatusFailure
|
||||
}
|
||||
err = WriteUsernamePasswordAuthResponse(conn, response)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write user auth response")
|
||||
}
|
||||
}
|
||||
|
||||
request, err := ReadRequest(conn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read socks request")
|
||||
}
|
||||
switch request.Command {
|
||||
case CommandConnect:
|
||||
err = WriteResponse(conn, &Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: ReplyCodeSuccess,
|
||||
Bind: M.AddrPortFromNetAddr(conn.LocalAddr()),
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "write socks response")
|
||||
}
|
||||
return handler.NewConnection(conn, M.Metadata{
|
||||
Destination: request.Destination,
|
||||
})
|
||||
case CommandUDPAssociate:
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
err = WriteResponse(conn, &Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: ReplyCodeSuccess,
|
||||
Bind: M.AddrPortFromNetAddr(udpConn.LocalAddr()),
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "write socks response")
|
||||
}
|
||||
go func() {
|
||||
err := handler.NewPacketConnection(NewPacketConn(conn, udpConn), M.Metadata{
|
||||
Source: M.AddrPortFromNetAddr(conn.RemoteAddr()),
|
||||
Destination: request.Destination,
|
||||
})
|
||||
if err != nil {
|
||||
handler.HandleError(err)
|
||||
}
|
||||
conn.Close()
|
||||
}()
|
||||
return common.Error(io.Copy(io.Discard, conn))
|
||||
default:
|
||||
err = WriteResponse(conn, &Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: ReplyCodeUnsupported,
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "write response")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -7,9 +7,9 @@ import (
|
|||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
)
|
||||
|
||||
//+----+----------+----------+
|
||||
|
@ -45,11 +45,11 @@ func ReadAuthRequest(reader io.Reader) (*AuthRequest, error) {
|
|||
}
|
||||
methodLen, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "read socks auth methods length")
|
||||
return nil, E.Cause(err, "read socks auth methods length")
|
||||
}
|
||||
methods, err := rw.ReadBytes(reader, int(methodLen))
|
||||
if err != nil {
|
||||
return nil, exceptions.CauseF(err, "read socks auth methods, length ", methodLen)
|
||||
return nil, E.CauseF(err, "read socks auth methods, length ", methodLen)
|
||||
}
|
||||
request := &AuthRequest{
|
||||
version,
|
||||
|
@ -112,11 +112,11 @@ func WriteUsernamePasswordAuthRequest(writer io.Writer, request *UsernamePasswor
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = socksaddr.WriteString(writer, "username", request.Username)
|
||||
err = M.WriteString(writer, "username", request.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return socksaddr.WriteString(writer, "password", request.Password)
|
||||
return M.WriteString(writer, "password", request.Password)
|
||||
}
|
||||
|
||||
func ReadUsernamePasswordAuthRequest(reader io.Reader) (*UsernamePasswordAuthRequest, error) {
|
||||
|
@ -127,11 +127,11 @@ func ReadUsernamePasswordAuthRequest(reader io.Reader) (*UsernamePasswordAuthReq
|
|||
if version != UsernamePasswordVersion1 {
|
||||
return nil, &UnsupportedVersionException{version}
|
||||
}
|
||||
username, err := socksaddr.ReadString(reader)
|
||||
username, err := M.ReadString(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
password, err := socksaddr.ReadString(reader)
|
||||
password, err := M.ReadString(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -185,10 +185,9 @@ func ReadUsernamePasswordAuthResponse(reader io.Reader) (*UsernamePasswordAuthRe
|
|||
//+----+-----+-------+------+----------+----------+
|
||||
|
||||
type Request struct {
|
||||
Version byte
|
||||
Command byte
|
||||
Addr socksaddr.Addr
|
||||
Port uint16
|
||||
Version byte
|
||||
Command byte
|
||||
Destination *M.AddrPort
|
||||
}
|
||||
|
||||
func WriteRequest(writer io.Writer, request *Request) error {
|
||||
|
@ -204,7 +203,7 @@ func WriteRequest(writer io.Writer, request *Request) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return AddressSerializer.WriteAddressAndPort(writer, request.Addr, request.Port)
|
||||
return AddressSerializer.WriteAddrPort(writer, request.Destination)
|
||||
}
|
||||
|
||||
func ReadRequest(reader io.Reader) (*Request, error) {
|
||||
|
@ -226,15 +225,14 @@ func ReadRequest(reader io.Reader) (*Request, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addr, port, err := AddressSerializer.ReadAddressAndPort(reader)
|
||||
addrPort, err := AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request := &Request{
|
||||
Version: version,
|
||||
Command: command,
|
||||
Addr: addr,
|
||||
Port: port,
|
||||
Version: version,
|
||||
Command: command,
|
||||
Destination: addrPort,
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
@ -248,8 +246,7 @@ func ReadRequest(reader io.Reader) (*Request, error) {
|
|||
type Response struct {
|
||||
Version byte
|
||||
ReplyCode ReplyCode
|
||||
BindAddr socksaddr.Addr
|
||||
BindPort uint16
|
||||
Bind *M.AddrPort
|
||||
}
|
||||
|
||||
func WriteResponse(writer io.Writer, response *Response) error {
|
||||
|
@ -265,10 +262,10 @@ func WriteResponse(writer io.Writer, response *Response) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if response.BindAddr == nil {
|
||||
return AddressSerializer.WriteAddressAndPort(writer, socksaddr.AddrFromIP(net.IPv4zero), response.BindPort)
|
||||
if response.Bind == nil {
|
||||
return AddressSerializer.WriteAddrPort(writer, M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0))
|
||||
}
|
||||
return AddressSerializer.WriteAddressAndPort(writer, response.BindAddr, response.BindPort)
|
||||
return AddressSerializer.WriteAddrPort(writer, response.Bind)
|
||||
}
|
||||
|
||||
func ReadResponse(reader io.Reader) (*Response, error) {
|
||||
|
@ -287,15 +284,14 @@ func ReadResponse(reader io.Reader) (*Response, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addr, port, err := AddressSerializer.ReadAddressAndPort(reader)
|
||||
addrPort, err := AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response := &Response{
|
||||
Version: version,
|
||||
ReplyCode: ReplyCode(replyCode),
|
||||
BindAddr: addr,
|
||||
BindPort: port,
|
||||
Bind: addrPort,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
@ -307,15 +303,14 @@ func ReadResponse(reader io.Reader) (*Response, error) {
|
|||
//+----+------+------+----------+----------+----------+
|
||||
|
||||
type AssociatePacket struct {
|
||||
Fragment byte
|
||||
Addr socksaddr.Addr
|
||||
Port uint16
|
||||
Data []byte
|
||||
Fragment byte
|
||||
Destination *M.AddrPort
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) {
|
||||
if buffer.Len() < 5 {
|
||||
return nil, exceptions.New("insufficient length")
|
||||
return nil, E.New("insufficient length")
|
||||
}
|
||||
fragment := buffer.Byte(2)
|
||||
reader := bytes.NewReader(buffer.Bytes())
|
||||
|
@ -323,16 +318,15 @@ func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addr, port, err := AddressSerializer.ReadAddressAndPort(reader)
|
||||
addrPort, err := AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.Advance(reader.Len())
|
||||
packet := &AssociatePacket{
|
||||
Fragment: fragment,
|
||||
Addr: addr,
|
||||
Port: port,
|
||||
Data: buffer.Bytes(),
|
||||
Fragment: fragment,
|
||||
Destination: addrPort,
|
||||
Data: buffer.Bytes(),
|
||||
}
|
||||
return packet, nil
|
||||
}
|
||||
|
@ -346,7 +340,7 @@ func EncodeAssociatePacket(packet *AssociatePacket, buffer *buf.Buffer) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = AddressSerializer.WriteAddressAndPort(buffer, packet.Addr, packet.Port)
|
||||
err = AddressSerializer.WriteAddrPort(buffer, packet.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
|
@ -20,7 +20,7 @@ func TestHandshake(t *testing.T) {
|
|||
method := socks.AuthTypeUsernamePassword
|
||||
|
||||
go func() {
|
||||
response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, socksaddr.AddrFromFqdn("test"), 80, "user", "pswd")
|
||||
response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -60,14 +60,13 @@ func TestHandshake(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Addr.Fqdn() != "test" || request.Port != 80 {
|
||||
if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 {
|
||||
t.Fatal(request)
|
||||
}
|
||||
err = socks.WriteResponse(server, &socks.Response{
|
||||
Version: socks.Version5,
|
||||
ReplyCode: socks.ReplyCodeSuccess,
|
||||
BindAddr: socksaddr.AddrFromIP(net.IPv4zero),
|
||||
BindPort: 0,
|
||||
Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
86
transport/mixed/listener.go
Normal file
86
transport/mixed/listener.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package mixed
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common/auth"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
"github.com/sagernet/sing/protocol/http"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sagernet/sing/transport/tcp"
|
||||
"github.com/sagernet/sing/transport/udp"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
socks.Handler
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
TCPListener *tcp.Listener
|
||||
UDPListener *udp.Listener
|
||||
handler Handler
|
||||
authenticator auth.Authenticator
|
||||
udpNat *udpnat.Server
|
||||
}
|
||||
|
||||
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener {
|
||||
listener := &Listener{
|
||||
handler: handler,
|
||||
authenticator: authenticator,
|
||||
}
|
||||
|
||||
listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy))
|
||||
if transproxy == redir.ModeTProxy {
|
||||
listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy))
|
||||
listener.udpNat = udpnat.NewServer(handler)
|
||||
}
|
||||
return listener
|
||||
}
|
||||
|
||||
func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
if metadata.Destination != nil {
|
||||
return l.handler.NewConnection(conn, metadata)
|
||||
}
|
||||
bufConn := buf.NewBufferedConn(conn)
|
||||
header, err := bufConn.Peek(1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch header[0] {
|
||||
case socks.Version4, socks.Version5:
|
||||
return socks.HandleConnection(bufConn, l.authenticator, l.handler)
|
||||
default:
|
||||
return http.HandleConnection(bufConn, l.authenticator, l.handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) NewPacket(packet *buf.Buffer, metadata M.Metadata) error {
|
||||
return l.udpNat.HandleUDP(packet, metadata)
|
||||
}
|
||||
|
||||
func (l *Listener) HandleError(err error) {
|
||||
l.handler.HandleError(err)
|
||||
}
|
||||
|
||||
func (l *Listener) Start() error {
|
||||
err := l.TCPListener.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if l.UDPListener != nil {
|
||||
err = l.UDPListener.Start()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
l.TCPListener.Close()
|
||||
if l.UDPListener != nil {
|
||||
l.UDPListener.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net/http"
|
||||
_ "unsafe"
|
||||
)
|
||||
|
||||
//go:linkname readRequest net/http.readRequest
|
||||
func readRequest(b *bufio.Reader) (req *http.Request, err error)
|
|
@ -2,6 +2,8 @@ package system
|
|||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -12,3 +14,22 @@ const (
|
|||
func TCPFastOpen(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_TCP, TCP_FASTOPEN_CONNECT, 1)
|
||||
}
|
||||
|
||||
func TProxy(fd uintptr, isIPv6 bool) error {
|
||||
err := syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isIPv6 {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func TProxyUDP(fd uintptr, isIPv6 bool) error {
|
||||
err := syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
|
||||
}
|
||||
|
|
|
@ -7,3 +7,7 @@ import "github.com/sagernet/sing/common/exceptions"
|
|||
func TCPFastOpen(fd uintptr) error {
|
||||
return exceptions.New("only available on linux")
|
||||
}
|
||||
|
||||
func TProxy(fd uintptr, isIPv6 bool) error {
|
||||
return exceptions.New("only available on linux")
|
||||
}
|
||||
|
|
|
@ -1,149 +0,0 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type SocksHandler interface {
|
||||
NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error
|
||||
NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error
|
||||
OnError(err error)
|
||||
}
|
||||
|
||||
type SocksConfig struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type SocksListener struct {
|
||||
Handler SocksHandler
|
||||
*TCPListener
|
||||
*SocksConfig
|
||||
}
|
||||
|
||||
func NewSocksListener(bind netip.AddrPort, config *SocksConfig, handler SocksHandler) *SocksListener {
|
||||
listener := &SocksListener{
|
||||
SocksConfig: config,
|
||||
Handler: handler,
|
||||
}
|
||||
listener.TCPListener = NewTCPListener(bind, listener)
|
||||
return listener
|
||||
}
|
||||
|
||||
func (l *SocksListener) HandleTCP(conn net.Conn) error {
|
||||
authRequest, err := socks.ReadAuthRequest(conn)
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "read socks auth request")
|
||||
}
|
||||
var authMethod byte
|
||||
if l.Username == "" {
|
||||
authMethod = socks.AuthTypeNotRequired
|
||||
} else {
|
||||
authMethod = socks.AuthTypeUsernamePassword
|
||||
}
|
||||
if !common.Contains(authRequest.Methods, authMethod) {
|
||||
err = socks.WriteAuthResponse(conn, &socks.AuthResponse{
|
||||
Version: authRequest.Version,
|
||||
Method: socks.AuthTypeNoAcceptedMethods,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write socks auth response")
|
||||
}
|
||||
}
|
||||
err = socks.WriteAuthResponse(conn, &socks.AuthResponse{
|
||||
Version: authRequest.Version,
|
||||
Method: socks.AuthTypeNotRequired,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write socks auth response")
|
||||
}
|
||||
|
||||
if authMethod == socks.AuthTypeUsernamePassword {
|
||||
usernamePasswordAuthRequest, err := socks.ReadUsernamePasswordAuthRequest(conn)
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "read user auth request")
|
||||
}
|
||||
response := socks.UsernamePasswordAuthResponse{}
|
||||
if usernamePasswordAuthRequest.Username != l.Username {
|
||||
response.Status = socks.UsernamePasswordStatusFailure
|
||||
} else if usernamePasswordAuthRequest.Password != l.Password {
|
||||
response.Status = socks.UsernamePasswordStatusFailure
|
||||
} else {
|
||||
response.Status = socks.UsernamePasswordStatusSuccess
|
||||
}
|
||||
err = socks.WriteUsernamePasswordAuthResponse(conn, &response)
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write user auth response")
|
||||
}
|
||||
}
|
||||
|
||||
request, err := socks.ReadRequest(conn)
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "read socks request")
|
||||
}
|
||||
switch request.Command {
|
||||
case socks.CommandConnect:
|
||||
localAddr, localPort := socksaddr.AddrFromNetAddr(l.TCPListener.TCPListener.Addr())
|
||||
err = socks.WriteResponse(conn, &socks.Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: socks.ReplyCodeSuccess,
|
||||
BindAddr: localAddr,
|
||||
BindPort: localPort,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write socks response")
|
||||
}
|
||||
return l.Handler.NewConnection(request.Addr, request.Port, conn)
|
||||
case socks.CommandUDPAssociate:
|
||||
udpConn, err := net.ListenUDP("udp4", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
localAddr, localPort := socksaddr.AddrFromNetAddr(udpConn.LocalAddr())
|
||||
err = socks.WriteResponse(conn, &socks.Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: socks.ReplyCodeSuccess,
|
||||
BindAddr: localAddr,
|
||||
BindPort: localPort,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write socks response")
|
||||
}
|
||||
go func() {
|
||||
err := l.Handler.NewPacketConnection(socks.NewPacketConn(conn, udpConn), request.Addr, request.Port)
|
||||
if err != nil {
|
||||
l.OnError(err)
|
||||
}
|
||||
}()
|
||||
return common.Error(io.Copy(io.Discard, conn))
|
||||
default:
|
||||
err = socks.WriteResponse(conn, &socks.Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: socks.ReplyCodeUnsupported,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write response")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *SocksListener) Start() error {
|
||||
return l.TCPListener.Start()
|
||||
}
|
||||
|
||||
func (l *SocksListener) Close() error {
|
||||
return l.TCPListener.Close()
|
||||
}
|
||||
|
||||
func (l *SocksListener) OnError(err error) {
|
||||
l.Handler.OnError(exceptions.Cause(err, "socks server"))
|
||||
}
|
|
@ -1,57 +0,0 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type TCPHandler interface {
|
||||
HandleTCP(conn net.Conn) error
|
||||
OnError(err error)
|
||||
}
|
||||
|
||||
type TCPListener struct {
|
||||
Listen netip.AddrPort
|
||||
Handler TCPHandler
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func NewTCPListener(listen netip.AddrPort, handler TCPHandler) *TCPListener {
|
||||
return &TCPListener{
|
||||
Listen: listen,
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *TCPListener) Start() error {
|
||||
tcpListener, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(l.Listen))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.TCPListener = tcpListener
|
||||
go l.loop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *TCPListener) Close() error {
|
||||
if l == nil || l.TCPListener == nil {
|
||||
return nil
|
||||
}
|
||||
return l.TCPListener.Close()
|
||||
}
|
||||
|
||||
func (l *TCPListener) loop() {
|
||||
for {
|
||||
tcpConn, err := l.Accept()
|
||||
if err != nil {
|
||||
l.Close()
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
err := l.Handler.HandleTCP(tcpConn)
|
||||
if err != nil {
|
||||
l.Handler.OnError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
type UDPHandler interface {
|
||||
HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error
|
||||
OnError(err error)
|
||||
}
|
||||
|
||||
type UDPListener struct {
|
||||
Listen netip.AddrPort
|
||||
Handler UDPHandler
|
||||
*net.UDPConn
|
||||
}
|
||||
|
||||
func NewUDPListener(listen netip.AddrPort, handler UDPHandler) *UDPListener {
|
||||
return &UDPListener{
|
||||
Listen: listen,
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *UDPListener) Start() error {
|
||||
udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(l.Listen))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.UDPConn = udpConn
|
||||
go l.loop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *UDPListener) Close() error {
|
||||
if l == nil || l.UDPConn == nil {
|
||||
return nil
|
||||
}
|
||||
return l.UDPConn.Close()
|
||||
}
|
||||
|
||||
func (l *UDPListener) loop() {
|
||||
for {
|
||||
buffer := buf.New()
|
||||
n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
go func() {
|
||||
err := l.Handler.HandleUDP(buffer, addr)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
l.Handler.OnError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
91
transport/tcp/handler.go
Normal file
91
transport/tcp/handler.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
M.TCPConnectionHandler
|
||||
E.Handler
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
bind netip.AddrPort
|
||||
handler Handler
|
||||
trans redir.TransproxyMode
|
||||
lAddr *net.TCPAddr
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func NewTCPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener {
|
||||
listener := &Listener{
|
||||
bind: listen,
|
||||
handler: handler,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(listener)
|
||||
}
|
||||
return listener
|
||||
}
|
||||
|
||||
func (l *Listener) Start() error {
|
||||
tcpListener, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(l.bind))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if l.trans == redir.ModeTProxy {
|
||||
l.lAddr = tcpListener.Addr().(*net.TCPAddr)
|
||||
fd, err := common.GetFileDescriptor(tcpListener)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = redir.TProxy(fd, l.bind.Addr().Is6())
|
||||
if err != nil {
|
||||
return E.Cause(err, "configure tproxy")
|
||||
}
|
||||
}
|
||||
l.TCPListener = tcpListener
|
||||
go l.loop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
if l == nil || l.TCPListener == nil {
|
||||
return nil
|
||||
}
|
||||
return l.TCPListener.Close()
|
||||
}
|
||||
|
||||
func (l *Listener) loop() {
|
||||
for {
|
||||
tcpConn, err := l.Accept()
|
||||
if err != nil {
|
||||
l.Close()
|
||||
return
|
||||
}
|
||||
var metadata M.Metadata
|
||||
switch l.trans {
|
||||
case redir.ModeRedirect:
|
||||
metadata.Destination, _ = redir.GetOriginalDestination(tcpConn)
|
||||
case redir.ModeTProxy:
|
||||
lAddr := tcpConn.LocalAddr().(*net.TCPAddr)
|
||||
rAddr := tcpConn.RemoteAddr().(*net.TCPAddr)
|
||||
|
||||
if lAddr.Port != l.lAddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() {
|
||||
metadata.Destination = M.AddrPortFromNetAddr(lAddr)
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
err := l.handler.NewConnection(tcpConn, metadata)
|
||||
if err != nil {
|
||||
l.handler.HandleError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
11
transport/tcp/options.go
Normal file
11
transport/tcp/options.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package tcp
|
||||
|
||||
import "github.com/sagernet/sing/common/redir"
|
||||
|
||||
type Option func(*Listener)
|
||||
|
||||
func WithTransproxyMode(mode redir.TransproxyMode) Option {
|
||||
return func(listener *Listener) {
|
||||
listener.trans = mode
|
||||
}
|
||||
}
|
11
transport/udp/options.go
Normal file
11
transport/udp/options.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package udp
|
||||
|
||||
import "github.com/sagernet/sing/common/redir"
|
||||
|
||||
type Option func(*Listener)
|
||||
|
||||
func WithTransproxyMode(mode redir.TransproxyMode) Option {
|
||||
return func(listener *Listener) {
|
||||
listener.tproxy = mode == redir.ModeTProxy
|
||||
}
|
||||
}
|
116
transport/udp/udp.go
Normal file
116
transport/udp/udp.go
Normal file
|
@ -0,0 +1,116 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
M.UDPHandler
|
||||
E.Handler
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
*net.UDPConn
|
||||
handler Handler
|
||||
bind *net.UDPAddr
|
||||
tproxy bool
|
||||
}
|
||||
|
||||
func NewUDPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener {
|
||||
listener := &Listener{
|
||||
handler: handler,
|
||||
bind: net.UDPAddrFromAddrPort(listen),
|
||||
}
|
||||
for _, option := range options {
|
||||
option(listener)
|
||||
}
|
||||
return listener
|
||||
}
|
||||
|
||||
func (l *Listener) Start() error {
|
||||
udpConn, err := net.ListenUDP("udp", l.bind)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if l.tproxy {
|
||||
fd, err := common.GetFileDescriptor(udpConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = redir.TProxy(fd, l.bind.AddrPort().Addr().Is6())
|
||||
if err != nil {
|
||||
return E.Cause(err, "configure tproxy")
|
||||
}
|
||||
err = redir.TProxyUDP(fd, l.bind.AddrPort().Addr().Is6())
|
||||
if err != nil {
|
||||
return E.Cause(err, "configure tproxy")
|
||||
}
|
||||
}
|
||||
|
||||
l.UDPConn = udpConn
|
||||
go l.loop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
if l == nil || l.UDPConn == nil {
|
||||
return nil
|
||||
}
|
||||
return l.UDPConn.Close()
|
||||
}
|
||||
|
||||
func (l *Listener) loop() {
|
||||
if !l.tproxy {
|
||||
for {
|
||||
buffer := buf.New()
|
||||
n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
l.handler.HandleError(err)
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = l.handler.NewPacket(buffer, M.Metadata{
|
||||
Source: M.AddrPortFromNetAddr(addr),
|
||||
})
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
l.handler.HandleError(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
oob := make([]byte, 1024)
|
||||
for {
|
||||
buffer := buf.New()
|
||||
n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
l.handler.HandleError(err)
|
||||
return
|
||||
}
|
||||
destination, err := redir.GetOriginalDestinationFromOOB(oob[:oobN])
|
||||
if err != nil {
|
||||
l.handler.HandleError(E.Cause(err, "get original destination"))
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = l.handler.NewPacket(buffer, M.Metadata{
|
||||
Source: M.AddrPortFromAddrPort(addr),
|
||||
Destination: destination,
|
||||
})
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
l.handler.HandleError(err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue