From 26e13e7bebb1b6ecbff50f90e74478e42f446245 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 7 Apr 2022 20:49:20 +0800 Subject: [PATCH] Add http support for sslocal --- .gitignore | 3 +- cli/sslocal/main/wrap.go | 16 -- cli/sslocal/sslocal.go | 243 ++++++++++++++++++ .../cmd.go => sslocal_raw/sslocal.go} | 127 ++++++--- common/buf/bufconn.go | 44 ++++ common/buf/buffer.go | 30 ++- common/buf/{conn.go => io.go} | 31 +-- common/exceptions/error.go | 38 ++- common/flush.go | 73 +++++- common/rw/copy.go | 34 ++- common/socksaddr/addr.go | 33 +++ common/upstream.go | 10 + go.mod | 1 + go.sum | 7 +- protocol/shadowsocks/cipher.go | 4 +- protocol/shadowsocks/cipher_aead.go | 154 ++++++----- protocol/shadowsocks/cipher_none.go | 4 +- protocol/shadowsocks/client.go | 168 ++++++++++++ protocol/shadowsocks/config.go | 1 + protocol/socks/conn.go | 78 ++++++ protocol/socks/protocol.go | 2 +- transport/system/control.go | 30 --- transport/system/dial.go | 10 - transport/system/http.go | 10 + transport/system/mixed.go | 223 ++++++++++++++++ transport/system/sockopt_linux.go | 14 + transport/system/sockopt_other.go | 9 + transport/system/socks.go | 148 +++++++++++ transport/system/stub.s | 0 transport/system/udp.go | 4 +- 30 files changed, 1355 insertions(+), 194 deletions(-) delete mode 100644 cli/sslocal/main/wrap.go create mode 100644 cli/sslocal/sslocal.go rename cli/{sslocal/cmd.go => sslocal_raw/sslocal.go} (73%) create mode 100644 common/buf/bufconn.go rename common/buf/{conn.go => io.go} (71%) create mode 100644 protocol/shadowsocks/client.go create mode 100644 protocol/shadowsocks/config.go create mode 100644 protocol/socks/conn.go delete mode 100644 transport/system/control.go delete mode 100644 transport/system/dial.go create mode 100644 transport/system/http.go create mode 100644 transport/system/mixed.go create mode 100644 transport/system/sockopt_linux.go create mode 100644 transport/system/sockopt_other.go create mode 100644 transport/system/socks.go create mode 100644 transport/system/stub.s diff --git a/.gitignore b/.gitignore index b865525..989d036 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /.idea/ -/sing_* \ No newline at end of file +/sing_* +/*.json \ No newline at end of file diff --git a/cli/sslocal/main/wrap.go b/cli/sslocal/main/wrap.go deleted file mode 100644 index 8b92c8d..0000000 --- a/cli/sslocal/main/wrap.go +++ /dev/null @@ -1,16 +0,0 @@ -package main - -import ( - "fmt" - "os" - - "github.com/sagernet/sing/cli/sslocal" -) - -func main() { - err := sslocal.MainCmd().Execute() - if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } -} diff --git a/cli/sslocal/sslocal.go b/cli/sslocal/sslocal.go new file mode 100644 index 0000000..7f22975 --- /dev/null +++ b/cli/sslocal/sslocal.go @@ -0,0 +1,243 @@ +package main + +import ( + "context" + "encoding/base64" + "encoding/json" + "io" + "io/ioutil" + "net" + "net/netip" + "os" + "os/signal" + "strconv" + "syscall" + + "github.com/sagernet/sing" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/socksaddr" + "github.com/sagernet/sing/common/task" + "github.com/sagernet/sing/protocol/shadowsocks" + "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/transport/system" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +func main() { + err := MainCmd().Execute() + if err != nil { + logrus.Fatal(err) + } +} + +type Flags struct { + Server string `json:"server"` + ServerPort uint16 `json:"server_port"` + LocalPort uint16 `json:"local_port"` + Password string `json:"password"` + Key string `json:"key"` + Method string `json:"method"` + TCPFastOpen bool `json:"fast_open"` + Verbose bool `json:"verbose"` + ConfigFile string +} + +func MainCmd() *cobra.Command { + flags := new(Flags) + + cmd := &cobra.Command{ + Use: "sslocal", + Short: "shadowsocks client as socks5 proxy, sing port", + Version: sing.Version, + Run: func(cmd *cobra.Command, args []string) { + Run(flags) + }, + } + + cmd.Flags().StringVarP(&flags.Server, "server", "s", "", "Set the server’s hostname or IP.") + cmd.Flags().Uint16VarP(&flags.ServerPort, "server-port", "p", 0, "Set the server’s port number.") + cmd.Flags().Uint16VarP(&flags.LocalPort, "local-port", "l", 0, "Set the local port number.") + cmd.Flags().StringVarP(&flags.Password, "password", "k", "", "Set the password. The server and the client should use the same password.") + cmd.Flags().StringVar(&flags.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.") + cmd.Flags().StringVarP(&flags.Method, "encrypt-method", "m", "", `Set the cipher. + +Supported ciphers: + +none +aes-128-gcm +aes-192-gcm +aes-256-gcm +chacha20-ietf-poly1305 +xchacha20-ietf-poly1305 + +The default cipher is chacha20-ietf-poly1305.`) + cmd.Flags().BoolVar(&flags.TCPFastOpen, "fast-open", false, `Enable TCP fast open. +Only available with Linux kernel > 3.7.0.`) + cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.") + cmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode.") + + return cmd +} + +type LocalClient struct { + *system.MixedListener + *shadowsocks.Client +} + +func NewLocalClient(flags *Flags) (*LocalClient, error) { + if flags.ConfigFile != "" { + configFile, err := ioutil.ReadFile(flags.ConfigFile) + if err != nil { + return nil, exceptions.Cause(err, "read config file") + } + flagsNew := new(Flags) + err = json.Unmarshal(configFile, flagsNew) + if err != nil { + return nil, exceptions.Cause(err, "decode config file") + } + if flagsNew.Server != "" && flags.Server == "" { + flags.Server = flagsNew.Server + } + if flagsNew.ServerPort != 0 && flags.ServerPort == 0 { + flags.ServerPort = flagsNew.ServerPort + } + if flagsNew.LocalPort != 0 && flags.LocalPort == 0 { + flags.LocalPort = flagsNew.LocalPort + } + if flagsNew.Password != "" && flags.Password == "" { + flags.Password = flagsNew.Password + } + if flagsNew.Key != "" && flags.Key == "" { + flags.Key = flagsNew.Key + } + if flagsNew.Method != "" && flags.Method == "" { + flags.Method = flagsNew.Method + } + if flagsNew.TCPFastOpen { + flags.TCPFastOpen = true + } + if flagsNew.Verbose { + flags.Verbose = true + } + + } + + clientConfig := &shadowsocks.ClientConfig{ + Server: flags.Server, + ServerPort: flags.ServerPort, + Method: flags.Method, + } + + if flags.Key != "" { + key, err := base64.URLEncoding.DecodeString(flags.Key) + if err != nil { + return nil, exceptions.Cause(err, "decode key") + } + clientConfig.Key = key + } else if flags.Password != "" { + clientConfig.Password = []byte(flags.Password) + } + + if flags.Verbose { + logrus.SetLevel(logrus.TraceLevel) + } + + dialer := new(net.Dialer) + + if flags.TCPFastOpen { + dialer.Control = func(network, address string, c syscall.RawConn) error { + var rawFd uintptr + err := c.Control(func(fd uintptr) { + rawFd = fd + }) + if err != nil { + return err + } + return system.TCPFastOpen(rawFd) + } + } + + shadowClient, err := shadowsocks.NewClient(dialer, clientConfig) + if err != nil { + return nil, exceptions.Cause(err, "create shadowsocks") + } + + client := &LocalClient{ + Client: shadowClient, + } + client.MixedListener = system.NewMixedListener(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), flags.LocalPort), &system.SocksConfig{}, client) + return client, nil +} + +func (c *LocalClient) Start() error { + err := c.MixedListener.Start() + if err != nil { + return err + } + logrus.Info("mixed server started at ", c.MixedListener.TCPListener.Addr()) + return nil +} + +func (c *LocalClient) NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error { + logrus.Info("TCP ", conn.RemoteAddr(), " ==> ", net.JoinHostPort(addr.String(), strconv.Itoa(int(port)))) + + ctx := context.Background() + serverConn, err := c.DialContextTCP(ctx, addr, port) + if err != nil { + return err + } + return task.Run(ctx, func() error { + defer rw.CloseRead(conn) + defer rw.CloseWrite(serverConn) + return common.Error(io.Copy(serverConn, conn)) + }, func() error { + defer rw.CloseRead(serverConn) + defer rw.CloseWrite(conn) + return common.Error(io.Copy(conn, serverConn)) + }) +} + +func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error { + ctx := context.Background() + serverConn := c.DialContextUDP(ctx) + return task.Run(ctx, func() error { + var init bool + return socks.CopyPacketConn(serverConn, conn, func(size int) { + if !init { + init = true + logrus.Info("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port)) + } else { + logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port)) + } + }) + }, func() error { + return socks.CopyPacketConn(conn, serverConn, func(size int) { + logrus.Trace("UDP ", conn.LocalAddr(), " <== ", socksaddr.JoinHostPort(addr, port)) + }) + }) +} + +func Run(flags *Flags) { + client, err := NewLocalClient(flags) + if err != nil { + logrus.Fatal(err) + } + err = client.Start() + if err != nil { + logrus.Fatal(err) + } + osSignals := make(chan os.Signal, 1) + signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) + <-osSignals + client.Close() +} + +func (c *LocalClient) OnError(err error) { + if exceptions.IsClosed(err) { + return + } + logrus.Warn(err) +} diff --git a/cli/sslocal/cmd.go b/cli/sslocal_raw/sslocal.go similarity index 73% rename from cli/sslocal/cmd.go rename to cli/sslocal_raw/sslocal.go index c60e129..1d79672 100644 --- a/cli/sslocal/cmd.go +++ b/cli/sslocal_raw/sslocal.go @@ -1,15 +1,18 @@ -package sslocal +package main import ( "context" "crypto/rand" "encoding/base64" + "encoding/json" "errors" "io" + "io/ioutil" "net" "net/netip" "os" "os/signal" + "strings" "syscall" "time" @@ -27,13 +30,24 @@ import ( "github.com/spf13/cobra" ) +func main() { + err := MainCmd().Execute() + if err != nil { + logrus.Fatal(err) + } +} + type Flags struct { - Server string - ServerPort uint16 - LocalPort uint16 - Password string - Key string - Method string + Server string `json:"server"` + ServerPort uint16 `json:"server_port"` + LocalPort uint16 `json:"local_port"` + Password string `json:"password"` + Key string `json:"key"` + Method string `json:"method"` + Timeout uint16 `json:"timeout"` + TCPFastOpen bool `json:"fast_open"` + Verbose bool `json:"verbose"` + ConfigFile string } func MainCmd() *cobra.Command { @@ -65,13 +79,10 @@ chacha20-ietf-poly1305 xchacha20-ietf-poly1305 The default cipher is chacha20-ietf-poly1305.`) - // cmd.Flags().Uint16VarP(&flags.Timeout, "timeout", "t", 60, "Set the socket timeout in seconds.") - // cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.") - // cmd.Flags().Uint16VarP(&flags.MaxFD, "max-open-files", "n", 0, `Specify max number of open files. - // Only available on Linux.`) - // cmd.Flags().StringVarP(&flags.Interface, "interface", "i", "", `Send traffic through specific network interface. - // For example, there are three interfaces in your device, which is lo (127.0.0.1), eth0 (192.168.0.1) and eth1 (192.168.0.2). Meanwhile, you configure ss-local to listen on 0.0.0.0:8388 and bind to eth1. That results the traffic go out through eth1, but not lo nor eth0. This option is useful to control traffic in multi-interface environment.`) - // cmd.Flags().StringVarP(&flags.LocalAddress, "local-address", "b", "", "Specify the local address to use while this client is making outbound connections to the server.") + cmd.Flags().BoolVar(&flags.TCPFastOpen, "fast-open", false, `Enable TCP fast open. +Only available with Linux kernel > 3.7.0.`) + cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.") + cmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode.") return cmd } @@ -81,10 +92,27 @@ type LocalClient struct { serverAddr netip.AddrPort cipher shadowsocks.Cipher key []byte + dialer net.Dialer } func NewLocalClient(flags *Flags) (*LocalClient, error) { - client := new(LocalClient) + if flags.ConfigFile != "" { + configFile, err := os.Open(flags.ConfigFile) + if err != nil { + return nil, exceptions.CauseF(err, "unable to open config file ", flags.ConfigFile) + } + config, err := ioutil.ReadAll(configFile) + configFile.Close() + if err != nil { + return nil, err + } + err = json.Unmarshal(config, &flags) + if err != nil { + return nil, exceptions.Cause(err, "failed to decode config file") + } + } + + client := &LocalClient{} client.tcpIn = system.NewTCPListener(netip.AddrPortFrom(netip.IPv4Unspecified(), flags.LocalPort), client) if flags.Server == "" { @@ -120,6 +148,32 @@ func NewLocalClient(flags *Flags) (*LocalClient, error) { return nil, exceptions.New("password not specified") } + if flags.Timeout > 0 { + client.dialer.Timeout = time.Duration(flags.Timeout) * time.Second + } + + if flags.TCPFastOpen { + client.dialer.Control = func(network, address string, c syscall.RawConn) error { + if strings.HasPrefix(network, "tcp") { + var rawFd uintptr + if err = c.Control(func(fd uintptr) { + rawFd = fd + }); err != nil { + return err + } + err = system.TCPFastOpen(rawFd) + if err != nil { + return exceptions.Cause(err, "set tcp fast open") + } + } + return nil + } + } + + if flags.Verbose { + logrus.SetLevel(logrus.TraceLevel) + } + return client, nil } @@ -142,7 +196,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { authRequest, err := socks.ReadAuthRequest(conn) if err != nil { - return err + return exceptions.Cause(err, "read socks auth request") } if !common.Contains(authRequest.Methods, socks.AuthTypeNotRequired) { @@ -151,7 +205,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { Method: socks.AuthTypeNoAcceptedMethods, }) if err != nil { - return err + return exceptions.Cause(err, "write socks auth response") } } @@ -160,12 +214,12 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { Method: socks.AuthTypeNotRequired, }) if err != nil { - return err + return exceptions.Cause(err, "write socks auth response") } request, err := socks.ReadRequest(conn) if err != nil { - return err + return exceptions.Cause(err, "read socks request") } ctx := context.Background() @@ -181,7 +235,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { case socks.CommandConnect: logrus.Info("CONNECT ", request.Addr, ":", request.Port) - serverConn, dialErr := system.Dial(ctx, "tcp", c.serverAddr.String()) + serverConn, dialErr := c.dialer.DialContext(ctx, "tcp", c.serverAddr.String()) if dialErr != nil { failure() return exceptions.Cause(dialErr, "connect to server") @@ -196,12 +250,12 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { Writer: serverConn, Buffer: saltBuffer, } - writer, _ := c.cipher.CreateWriter(c.key, saltBuffer.Bytes(), serverWriter) + writer := c.cipher.CreateWriter(c.key, saltBuffer.Bytes(), serverWriter) - header := buf.New() - defer header.Release() + requestBuffer := buf.New() + defer requestBuffer.Release() - err = shadowsocks.AddressSerializer.WriteAddressAndPort(header, request.Addr, request.Port) + err = shadowsocks.AddressSerializer.WriteAddressAndPort(requestBuffer, request.Addr, request.Port) if err != nil { failure() return err @@ -215,7 +269,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { BindPort: serverPort, }) if err != nil { - return exceptions.Cause(err, "write response for ", request.Addr, "/", request.Port) + return exceptions.Cause(err, "write socks response") } return task.Run(ctx, func() error { @@ -226,7 +280,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { if err != nil { return err } - _, err = header.ReadFrom(conn) + _, err = requestBuffer.ReadFrom(conn) if err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { } else { @@ -237,7 +291,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { if err != nil { return err } - _, err = writer.Write(header.Bytes()) + _, err = writer.Write(requestBuffer.Bytes()) if err != nil { return err } @@ -245,7 +299,8 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { if err != nil { return exceptions.Cause(err, "flush request") } - _, err = io.Copy(writer, conn) + requestBuffer.FullReset() + _, err = io.CopyBuffer(writer, conn, requestBuffer.FreeBytes()) if err != nil { return exceptions.Cause(err, "upload") } @@ -275,10 +330,10 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { return nil }) case socks.CommandUDPAssociate: - serverConn, dialErr := system.Dial(ctx, "udp", c.serverAddr.String()) + serverConn, dialErr := c.dialer.DialContext(ctx, "udp", c.serverAddr.String()) if dialErr != nil { failure() - return exceptions.Cause(err, "connect to server") + return exceptions.Cause(dialErr, "connect to server") } handler := &udpHandler{ LocalClient: c, @@ -300,9 +355,12 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error { } go handler.loopInput() return common.Error(io.Copy(io.Discard, conn)) + default: + return socks.WriteResponse(conn, &socks.Response{ + Version: request.Version, + ReplyCode: socks.ReplyCodeUnsupported, + }) } - - return nil } type udpHandler struct { @@ -313,7 +371,7 @@ type udpHandler struct { sourceAddr net.Addr } -func (c *udpHandler) HandleUDP(listener *system.UDPListener, buffer *buf.Buffer, sourceAddr net.Addr) error { +func (c *udpHandler) HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error { c.sourceAddr = sourceAddr buffer.Advance(3) if c.cipher.SaltSize() > 0 { @@ -369,5 +427,8 @@ func (c *udpHandler) Close() error { } func (c *LocalClient) OnError(err error) { + if exceptions.IsClosed(err) { + return + } logrus.Warn(err) } diff --git a/common/buf/bufconn.go b/common/buf/bufconn.go new file mode 100644 index 0000000..f39bbd6 --- /dev/null +++ b/common/buf/bufconn.go @@ -0,0 +1,44 @@ +package buf + +import ( + "bufio" + "io" + "net" +) + +type BufferedConn struct { + r *bufio.Reader + net.Conn +} + +func NewBufferedConn(c net.Conn) *BufferedConn { + return &BufferedConn{bufio.NewReader(c), c} +} + +func (c *BufferedConn) Reader() *bufio.Reader { + return c.r +} + +func (c *BufferedConn) Peek(n int) ([]byte, error) { + return c.r.Peek(n) +} + +func (c *BufferedConn) Read(p []byte) (int, error) { + return c.r.Read(p) +} + +func (c *BufferedConn) ReadByte() (byte, error) { + return c.r.ReadByte() +} + +func (c *BufferedConn) UnreadByte() error { + return c.r.UnreadByte() +} + +func (c *BufferedConn) Buffered() int { + return c.r.Buffered() +} + +func (c *BufferedConn) WriteTo(w io.Writer) (n int64, err error) { + return c.r.WriteTo(w) +} diff --git a/common/buf/buffer.go b/common/buf/buffer.go index a34608e..bc3c0cd 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -25,6 +25,20 @@ func New() *Buffer { } } +func NewSize(size int) *Buffer { + if size <= 128 || size > BufferSize { + return &Buffer{ + data: make([]byte, size), + } + } + return &Buffer{ + data: GetBytes(), + start: ReversedHeader, + end: ReversedHeader, + managed: true, + } +} + func FullNew() *Buffer { return &Buffer{ data: GetBytes(), @@ -57,6 +71,20 @@ func As(data []byte) *Buffer { } } +func Or(data []byte, size int) *Buffer { + max := cap(data) + if size != max { + data = data[:max] + } + if cap(data) >= size { + return &Buffer{ + data: data, + } + } else { + return NewSize(size) + } +} + func With(data []byte) *Buffer { return &Buffer{ data: data, @@ -346,7 +374,7 @@ func (b Buffer) Copy() []byte { func ReleaseMulti(mb *list.List[*Buffer]) { for entry := mb.Front(); entry != nil; entry = entry.Next() { // TODO: remove cast - var buffer *Buffer = entry.Value + buffer := entry.Value buffer.Release() } } diff --git a/common/buf/conn.go b/common/buf/io.go similarity index 71% rename from common/buf/conn.go rename to common/buf/io.go index 2e5d55c..78dcaf8 100644 --- a/common/buf/conn.go +++ b/common/buf/io.go @@ -2,6 +2,8 @@ package buf import ( "io" + + "github.com/sagernet/sing/common" ) type BufferedReader struct { @@ -16,13 +18,18 @@ func (r *BufferedReader) Upstream() io.Reader { return r.Reader } +func (r *BufferedReader) Replaceable() bool { + return r.Buffer == nil +} + func (r *BufferedReader) Read(p []byte) (n int, err error) { if r.Buffer != nil { n, err = r.Buffer.Read(p) - if err == nil { - return + if r.Buffer.IsEmpty() { + r.Buffer.Release() + r.Buffer = nil } - r.Buffer = nil + return } return r.Reader.Read(p) } @@ -33,12 +40,13 @@ type BufferedWriter struct { } func (w *BufferedWriter) Upstream() io.Writer { - if w.Buffer != nil { - return nil - } return w.Writer } +func (w *BufferedWriter) Replaceable() bool { + return w.Buffer == nil +} + func (w *BufferedWriter) Write(p []byte) (n int, err error) { if w.Buffer == nil { return w.Writer.Write(p) @@ -47,13 +55,7 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) { if err == nil { return } - n, err = w.Writer.Write(w.Buffer.Bytes()) - if err != nil { - return 0, err - } - w.Buffer.Release() - w.Buffer = nil - return w.Writer.Write(p) + return len(p), w.Flush() } func (w *BufferedWriter) Flush() error { @@ -66,6 +68,5 @@ func (w *BufferedWriter) Flush() error { if buffer.IsEmpty() { return nil } - _, err := w.Writer.Write(buffer.Bytes()) - return err + return common.Error(w.Writer.Write(buffer.Bytes())) } diff --git a/common/exceptions/error.go b/common/exceptions/error.go index 05f2e59..43caec4 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -3,6 +3,9 @@ package exceptions import ( "errors" "fmt" + "io" + "net" + "syscall" ) type Exception interface { @@ -10,11 +13,6 @@ type Exception interface { Cause() error } -type SuppressedException interface { - error - Suppressed() error -} - type exception struct { message string cause error @@ -31,10 +29,36 @@ func (e exception) Cause() error { return e.cause } +func (e exception) Unwrap() error { + return e.cause +} + +func (e exception) Is(err error) bool { + return e == err || errors.Is(e.cause, err) +} + func New(message ...any) error { return errors.New(fmt.Sprint(message...)) } -func Cause(cause error, message ...any) Exception { - return &exception{fmt.Sprint(message...), cause} +func Cause(cause error, message string) Exception { + return exception{message, cause} +} + +func CauseF(cause error, message ...any) Exception { + return exception{fmt.Sprint(message), cause} +} + +func IsClosed(err error) bool { + return IsTimeout(err) || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) +} + +func IsTimeout(err error) bool { + if unwrapErr := errors.Unwrap(err); unwrapErr != nil { + err = unwrapErr + } + if opErr, isOpErr := err.(*net.OpError); isOpErr { + return opErr.Timeout() + } + return false } diff --git a/common/flush.go b/common/flush.go index 91dabf9..aa235b1 100644 --- a/common/flush.go +++ b/common/flush.go @@ -1,12 +1,15 @@ package common -import "io" +import ( + "io" +) type Flusher interface { Flush() error } func Flush(writer io.Writer) error { + writerBack := writer for { if f, ok := writer.(Flusher); ok { err := f.Flush() @@ -15,6 +18,15 @@ func Flush(writer io.Writer) error { } } if u, ok := writer.(WriterWithUpstream); ok { + if u.Replaceable() { + if writerBack == writer { + } else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter { + setter.SetWriter(writerBack) + writer = u.Upstream() + continue + } + } + writerBack = writer writer = u.Upstream() } else { break @@ -22,3 +34,62 @@ func Flush(writer io.Writer) error { } return nil } + +func FlushVar(writerP *io.Writer) error { + writer := *writerP + writerBack := writer + for { + if f, ok := writer.(Flusher); ok { + err := f.Flush() + if err != nil { + return err + } + } + if u, ok := writer.(WriterWithUpstream); ok { + if u.Replaceable() { + if writerBack == writer { + writer = u.Upstream() + writerBack = writer + writerP = &writer + continue + } else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter { + setter.SetWriter(writerBack) + writer = u.Upstream() + continue + } + } + writerBack = writer + writer = u.Upstream() + } else { + break + } + } + return nil +} + +type FlushOnceWriter struct { + io.Writer + flushed bool +} + +func (w *FlushOnceWriter) Upstream() io.Writer { + return w.Writer +} + +func (w *FlushOnceWriter) Replaceable() bool { + return w.flushed +} + +func (w *FlushOnceWriter) Write(p []byte) (n int, err error) { + if w.flushed { + return w.Writer.Write(p) + } + n, err = w.Writer.Write(p) + if n > 0 { + err = FlushVar(&w.Writer) + } + if err == nil { + w.flushed = true + } + return +} diff --git a/common/rw/copy.go b/common/rw/copy.go index 60fcabb..0ae26b4 100644 --- a/common/rw/copy.go +++ b/common/rw/copy.go @@ -4,18 +4,40 @@ import ( "context" "io" "net" + "os" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/task" ) -func CopyConn(ctx context.Context, conn net.Conn, outConn net.Conn) error { - return task.Run(ctx, func() error { - return common.Error(io.Copy(conn, outConn)) - }, func() error { - return common.Error(io.Copy(outConn, conn)) - }) +func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) { + writer := *writerVar + writerBack := writer + for { + if w, ok := writer.(io.ReaderFrom); ok { + return w.ReadFrom(reader) + } + if f, ok := writer.(common.Flusher); ok { + err := f.Flush() + if err != nil { + return 0, err + } + } + if u, ok := writer.(common.WriterWithUpstream); ok { + if u.Replaceable() && writerBack == writer { + writer = u.Upstream() + writerBack = writer + writerVar = &writer + continue + } + writer = u.Upstream() + writerBack = writer + } else { + break + } + } + return 0, os.ErrInvalid } func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error { diff --git a/common/socksaddr/addr.go b/common/socksaddr/addr.go index d575685..737234c 100644 --- a/common/socksaddr/addr.go +++ b/common/socksaddr/addr.go @@ -3,6 +3,7 @@ package socksaddr import ( "net" "net/netip" + "strconv" ) type Addr interface { @@ -12,6 +13,26 @@ type Addr interface { String() string } +func ParseAddr(address string) Addr { + addr, err := netip.ParseAddr(address) + if err == nil { + return AddrFromAddr(addr) + } + return AddrFromFqdn(address) +} + +func ParseAddrPort(address string) (Addr, uint16, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, 0, err + } + portInt, err := strconv.Atoi(port) + if err != nil { + return nil, 0, err + } + return ParseAddr(host), uint16(portInt), nil +} + func AddrFromIP(ip net.IP) Addr { addr, _ := netip.AddrFromSlice(ip) if addr.Is4() { @@ -21,6 +42,14 @@ func AddrFromIP(ip net.IP) Addr { } } +func AddrFromAddr(addr netip.Addr) Addr { + if addr.Is4() { + return Addr4(addr.As4()) + } else { + return Addr16(addr.As16()) + } +} + func AddressFromNetAddr(netAddr net.Addr) (addr Addr, port uint16) { var ip net.IP switch addr := netAddr.(type) { @@ -38,6 +67,10 @@ func AddrFromFqdn(fqdn string) Addr { return AddrFqdn(fqdn) } +func JoinHostPort(addr Addr, port uint16) string { + return net.JoinHostPort(addr.String(), strconv.Itoa(int(port))) +} + type Addr4 [4]byte func (a Addr4) Family() Family { diff --git a/common/upstream.go b/common/upstream.go index 640a865..7bad06f 100644 --- a/common/upstream.go +++ b/common/upstream.go @@ -6,8 +6,18 @@ import ( type ReaderWithUpstream interface { Upstream() io.Reader + Replaceable() bool +} + +type UpstreamReaderSetter interface { + SetUpstream(reader io.Reader) } type WriterWithUpstream interface { Upstream() io.Writer + Replaceable() bool +} + +type UpstreamWriterSetter interface { + SetWriter(writer io.Writer) } diff --git a/go.mod b/go.mod index 64bc8b6..067d004 100644 --- a/go.mod +++ b/go.mod @@ -11,5 +11,6 @@ require ( require ( github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/testify v1.7.1 // indirect golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect ) diff --git a/go.sum b/go.sum index c326a99..23009a3 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,5 @@ github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= @@ -12,8 +13,10 @@ github.com/spf13/cobra v1.4.0 h1:y+wJpx64xcgO1V+RcnwW0LEHxTKRi2ZDPSBjWnrg88Q= github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o= golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -21,3 +24,5 @@ golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0P golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/protocol/shadowsocks/cipher.go b/protocol/shadowsocks/cipher.go index ccc03a1..22e3441 100644 --- a/protocol/shadowsocks/cipher.go +++ b/protocol/shadowsocks/cipher.go @@ -11,8 +11,8 @@ import ( type Cipher interface { KeySize() int SaltSize() int - CreateReader(key []byte, iv []byte, reader io.Reader) io.Reader - CreateWriter(key []byte, iv []byte, writer io.Writer) (io.Writer, int) + CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader + CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer EncodePacket(key []byte, buffer *buf.Buffer) error DecodePacket(key []byte, buffer *buf.Buffer) error } diff --git a/protocol/shadowsocks/cipher_aead.go b/protocol/shadowsocks/cipher_aead.go index 9db8773..6ff89d0 100644 --- a/protocol/shadowsocks/cipher_aead.go +++ b/protocol/shadowsocks/cipher_aead.go @@ -92,9 +92,9 @@ func (c *AEADCipher) CreateReader(key []byte, salt []byte, reader io.Reader) io. return NewAEADReader(reader, c.Constructor(Kdf(key, salt, c.KeyLength))) } -func (c *AEADCipher) CreateWriter(key []byte, salt []byte, writer io.Writer) (io.Writer, int) { +func (c *AEADCipher) CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer { protocolWriter := NewAEADWriter(writer, c.Constructor(Kdf(key, salt, c.KeyLength))) - return protocolWriter, protocolWriter.maxDataSize + return protocolWriter } func (c *AEADCipher) EncodePacket(key []byte, buffer *buf.Buffer) error { @@ -132,12 +132,6 @@ func (c *AEADConn) Write(p []byte) (n int, err error) { return c.Writer.Write(p) } -func (c *AEADConn) Close() error { - c.Reader.Close() - c.Writer.Close() - return c.Conn.Close() -} - type AEADReader struct { upstream io.Reader cipher cipher.AEAD @@ -151,7 +145,7 @@ func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader { return &AEADReader{ upstream: upstream, cipher: cipher, - data: buf.GetBytes(), + data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2), nonce: make([]byte, cipher.NonceSize()), } } @@ -160,6 +154,52 @@ func (r *AEADReader) Upstream() io.Reader { return r.upstream } +func (r *AEADReader) Replaceable() bool { + return false +} + +func (r *AEADReader) SetUpstream(reader io.Reader) { + r.upstream = reader +} + +func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) { + if r.cached > 0 { + writeN, writeErr := writer.Write(r.data[r.index : r.index+r.cached]) + if writeErr != nil { + return int64(writeN), writeErr + } + n += int64(writeN) + } + for { + start := PacketLengthBufferSize + r.cipher.Overhead() + _, err = io.ReadFull(r.upstream, r.data[:start]) + if err != nil { + return + } + _, err = r.cipher.Open(r.data[:0], r.nonce, r.data[:start], nil) + if err != nil { + return + } + increaseNonce(r.nonce) + length := int(binary.BigEndian.Uint16(r.data[:PacketLengthBufferSize])) + end := length + r.cipher.Overhead() + _, err = io.ReadFull(r.upstream, r.data[:end]) + if err != nil { + return + } + _, err = r.cipher.Open(r.data[:0], r.nonce, r.data[:end], nil) + if err != nil { + return + } + increaseNonce(r.nonce) + writeN, writeErr := writer.Write(r.data[:length]) + if writeErr != nil { + return int64(writeN), writeErr + } + n += int64(writeN) + } +} + func (r *AEADReader) Read(b []byte) (n int, err error) { if r.cached > 0 { n = copy(b, r.data[r.index:r.index+r.cached]) @@ -209,29 +249,19 @@ func (r *AEADReader) Read(b []byte) (n int, err error) { } } -func (r *AEADReader) Close() error { - if r.data != nil { - buf.PutBytes(r.data) - r.data = nil - } - return nil -} - type AEADWriter struct { - upstream io.Writer - cipher cipher.AEAD - data []byte - nonce []byte - maxDataSize int + upstream io.Writer + cipher cipher.AEAD + data []byte + nonce []byte } func NewAEADWriter(upstream io.Writer, cipher cipher.AEAD) *AEADWriter { return &AEADWriter{ - upstream: upstream, - cipher: cipher, - data: buf.GetBytes(), - nonce: make([]byte, cipher.NonceSize()), - maxDataSize: MaxPacketSize - PacketLengthBufferSize - cipher.Overhead()*2, + upstream: upstream, + cipher: cipher, + data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2), + nonce: make([]byte, cipher.NonceSize()), } } @@ -239,47 +269,47 @@ func (w *AEADWriter) Upstream() io.Writer { return w.upstream } -func (w *AEADWriter) Process(p []byte) (n int, buffer *buf.Buffer, flush bool, err error) { - if len(p) > w.maxDataSize { - n, err = w.Write(p) - err = &rw.DirectException{ - Suppressed: err, +func (w *AEADWriter) Replaceable() bool { + return false +} + +func (w *AEADWriter) SetWriter(writer io.Writer) { + w.upstream = writer +} + +func (w *AEADWriter) ReadFrom(r io.Reader) (n int64, err error) { + for { + offset := w.cipher.Overhead() + PacketLengthBufferSize + readN, readErr := r.Read(w.data[offset : offset+MaxPacketSize]) + if readErr != nil { + return 0, readErr } - return - } - - binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(len(p))) - encryptedLength := w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil) - increaseNonce(w.nonce) - start := len(encryptedLength) - - /* - no usage - if cap(p) > len(p)+PacketLengthBufferSize+2*w.cipher.Overhead() { - packet := w.cipher.Seal(p[:start], w.nonce, p, nil) - increaseNonce(w.nonce) - copy(p[:start], encryptedLength) - n = start + len(packet) + binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(readN)) + w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil) + increaseNonce(w.nonce) + packet := w.cipher.Seal(w.data[offset:offset], w.nonce, w.data[offset:offset+readN], nil) + increaseNonce(w.nonce) + _, err = w.upstream.Write(w.data[:offset+len(packet)]) + if err != nil { return } - */ - - packet := w.cipher.Seal(w.data[:start], w.nonce, p, nil) - increaseNonce(w.nonce) - return 0, buf.As(packet), false, err + err = common.FlushVar(&w.upstream) + if err != nil { + return + } + n += int64(readN) + } } func (w *AEADWriter) Write(p []byte) (n int, err error) { - for _, data := range buf.ForeachN(p, w.maxDataSize) { + for _, data := range buf.ForeachN(p, MaxPacketSize) { binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(len(data))) w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil) increaseNonce(w.nonce) - - start := w.cipher.Overhead() + PacketLengthBufferSize - packet := w.cipher.Seal(w.data[:start], w.nonce, data, nil) + offset := w.cipher.Overhead() + PacketLengthBufferSize + packet := w.cipher.Seal(w.data[offset:offset], w.nonce, data, nil) increaseNonce(w.nonce) - - _, err = w.upstream.Write(packet) + _, err = w.upstream.Write(w.data[:offset+len(packet)]) if err != nil { return } @@ -289,14 +319,6 @@ func (w *AEADWriter) Write(p []byte) (n int, err error) { return } -func (w *AEADWriter) Close() error { - if w.data != nil { - buf.PutBytes(w.data) - w.data = nil - } - return nil -} - func increaseNonce(nonce []byte) { for i := range nonce { nonce[i]++ diff --git a/protocol/shadowsocks/cipher_none.go b/protocol/shadowsocks/cipher_none.go index 842d036..5636039 100644 --- a/protocol/shadowsocks/cipher_none.go +++ b/protocol/shadowsocks/cipher_none.go @@ -26,8 +26,8 @@ func (c *NoneCipher) CreateReader(_ []byte, _ []byte, reader io.Reader) io.Reade return reader } -func (c *NoneCipher) CreateWriter(_ []byte, _ []byte, writer io.Writer) (io.Writer, int) { - return writer, 0 +func (c *NoneCipher) CreateWriter(key []byte, iv []byte, writer io.Writer) io.Writer { + return writer } func (c *NoneCipher) EncodePacket([]byte, *buf.Buffer) error { diff --git a/protocol/shadowsocks/client.go b/protocol/shadowsocks/client.go new file mode 100644 index 0000000..3df974d --- /dev/null +++ b/protocol/shadowsocks/client.go @@ -0,0 +1,168 @@ +package shadowsocks + +import ( + "context" + "github.com/sagernet/sing/protocol/socks" + "io" + "net" + "strconv" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/socksaddr" +) + +var ( + ErrBadKey = exceptions.New("bad key") + ErrMissingPassword = exceptions.New("password not specified") +) + +type ClientConfig struct { + Server string `json:"server"` + ServerPort uint16 `json:"server_port"` + Method string `json:"method"` + Password []byte `json:"password"` + Key []byte `json:"key"` +} + +type Client struct { + dialer *net.Dialer + cipher Cipher + server string + key []byte +} + +func NewClient(dialer *net.Dialer, config *ClientConfig) (*Client, error) { + cipher, err := CreateCipher(config.Method) + if err != nil { + return nil, err + } + client := &Client{ + dialer: dialer, + cipher: cipher, + server: net.JoinHostPort(config.Server, strconv.Itoa(int(config.ServerPort))), + } + if keyLen := len(config.Key); keyLen > 0 { + if keyLen == cipher.KeySize() { + client.key = config.Key + } else { + return nil, ErrBadKey + } + } else if len(config.Password) > 0 { + client.key = Key(config.Password, cipher.KeySize()) + } else { + return nil, ErrMissingPassword + } + return client, nil +} + +func (c *Client) DialContextTCP(ctx context.Context, addr socksaddr.Addr, port uint16) (net.Conn, error) { + conn, err := c.dialer.DialContext(ctx, "tcp", c.server) + if err != nil { + return nil, exceptions.Cause(err, "connect to server") + } + return c.DialConn(conn, addr, port), nil +} + +func (c *Client) DialConn(conn net.Conn, addr socksaddr.Addr, port uint16) net.Conn { + header := buf.New() + header.WriteRandom(c.cipher.SaltSize()) + writer := &buf.BufferedWriter{ + Writer: conn, + Buffer: header, + } + protocolWriter := c.cipher.CreateWriter(c.key, header.Bytes(), writer) + requestBuffer := buf.New() + contentWriter := &buf.BufferedWriter{ + Writer: protocolWriter, + Buffer: requestBuffer, + } + common.Must(AddressSerializer.WriteAddressAndPort(contentWriter, addr, port)) + return &shadowsocksConn{ + Client: c, + Conn: conn, + Writer: &common.FlushOnceWriter{Writer: contentWriter}, + } +} + +type shadowsocksConn struct { + *Client + net.Conn + io.Writer + reader io.Reader +} + +func (c *shadowsocksConn) Read(p []byte) (n int, err error) { + if c.reader == nil { + buffer := buf.Or(p, c.cipher.SaltSize()) + defer buffer.Release() + _, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize()) + if err != nil { + return + } + c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn) + } + return c.reader.Read(p) +} + +func (c *shadowsocksConn) WriteTo(w io.Writer) (n int64, err error) { + if c.reader == nil { + buffer := buf.NewSize(c.cipher.SaltSize()) + defer buffer.Release() + _, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize()) + if err != nil { + return + } + c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn) + } + return c.reader.(io.WriterTo).WriteTo(w) +} + +func (c *shadowsocksConn) Write(p []byte) (n int, err error) { + return c.Writer.Write(p) +} + +func (c *shadowsocksConn) ReadFrom(r io.Reader) (n int64, err error) { + return rw.ReadFromVar(&c.Writer, r) +} + +func (c *Client) DialContextUDP(ctx context.Context) socks.PacketConn { + conn, err := c.dialer.DialContext(ctx, "udp", c.server) + if err != nil { + return nil + } + return &shadowsocksPacketConn{c, conn} +} + +type shadowsocksPacketConn struct { + *Client + net.Conn +} + +func (c *shadowsocksPacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error { + defer buffer.Release() + header := buf.New() + header.WriteRandom(c.cipher.SaltSize()) + common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port)) + buffer = buffer.WriteBufferAtFirst(header) + err := c.cipher.EncodePacket(c.key, buffer) + if err != nil { + return err + } + return common.Error(c.Conn.Write(buffer.Bytes())) +} + +func (c *shadowsocksPacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) { + n, err := c.Read(buffer.FreeBytes()) + if err != nil { + return nil, 0, err + } + buffer.Truncate(n) + err = c.cipher.DecodePacket(c.key, buffer) + if err != nil { + return nil, 0, err + } + return AddressSerializer.ReadAddressAndPort(buffer) +} diff --git a/protocol/shadowsocks/config.go b/protocol/shadowsocks/config.go new file mode 100644 index 0000000..4b35346 --- /dev/null +++ b/protocol/shadowsocks/config.go @@ -0,0 +1 @@ +package shadowsocks diff --git a/protocol/socks/conn.go b/protocol/socks/conn.go new file mode 100644 index 0000000..3b576e6 --- /dev/null +++ b/protocol/socks/conn.go @@ -0,0 +1,78 @@ +package socks + +import ( + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/socksaddr" + "net" + "time" +) + +type PacketConn interface { + ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) + WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error + + Close() error + LocalAddr() net.Addr + RemoteAddr() net.Addr + SetDeadline(t time.Time) error + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error +} + +func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(size int)) error { + for { + buffer := buf.New() + addr, port, err := conn.ReadPacket(buffer) + if err != nil { + buffer.Release() + return err + } + size := buffer.Len() + err = dest.WritePacket(buffer, addr, port) + if err != nil { + return err + } + if onAction != nil { + onAction(size) + } + buffer.Reset() + } +} + +type associatePacketConn struct { + net.PacketConn + conn net.Conn + addr net.Addr +} + +func NewPacketConn(conn net.Conn, packetConn net.PacketConn) PacketConn { + return &associatePacketConn{ + PacketConn: packetConn, + conn: conn, + } +} + +func (c *associatePacketConn) RemoteAddr() net.Addr { + return c.addr +} + +func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) { + n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes()) + if err != nil { + return nil, 0, err + } + c.addr = addr + buffer.Truncate(n) + buffer.Advance(3) + return AddressSerializer.ReadAddressAndPort(buffer) +} + +func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error { + defer buffer.Release() + header := buf.New() + common.Must(header.WriteZeroN(3)) + common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port)) + buffer = buffer.WriteBufferAtFirst(header) + return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr)) +} diff --git a/protocol/socks/protocol.go b/protocol/socks/protocol.go index 15c8427..2270c98 100644 --- a/protocol/socks/protocol.go +++ b/protocol/socks/protocol.go @@ -49,7 +49,7 @@ func ReadAuthRequest(reader io.Reader) (*AuthRequest, error) { } methods, err := rw.ReadBytes(reader, int(methodLen)) if err != nil { - return nil, exceptions.Cause(err, "read socks auth methods, length ", methodLen) + return nil, exceptions.CauseF(err, "read socks auth methods, length ", methodLen) } request := &AuthRequest{ version, diff --git a/transport/system/control.go b/transport/system/control.go deleted file mode 100644 index 20a222f..0000000 --- a/transport/system/control.go +++ /dev/null @@ -1,30 +0,0 @@ -package system - -import "syscall" - -var ControlFunc func(fd uintptr) error - -func Control(conn syscall.Conn) error { - if ControlFunc == nil { - return nil - } - rawConn, err := conn.SyscallConn() - if err != nil { - return err - } - return ControlRaw(rawConn) -} - -func ControlRaw(conn syscall.RawConn) error { - if ControlFunc == nil { - return nil - } - var rawFd uintptr - err := conn.Control(func(fd uintptr) { - rawFd = fd - }) - if err != nil { - return err - } - return ControlFunc(rawFd) -} diff --git a/transport/system/dial.go b/transport/system/dial.go deleted file mode 100644 index 312ad73..0000000 --- a/transport/system/dial.go +++ /dev/null @@ -1,10 +0,0 @@ -package system - -import ( - "context" - "net" -) - -type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) - -var Dial DialFunc = new(net.Dialer).DialContext diff --git a/transport/system/http.go b/transport/system/http.go new file mode 100644 index 0000000..5bfa2e7 --- /dev/null +++ b/transport/system/http.go @@ -0,0 +1,10 @@ +package system + +import ( + "bufio" + "net/http" + _ "unsafe" +) + +//go:linkname readRequest net/http.readRequest +func readRequest(b *bufio.Reader) (req *http.Request, err error) diff --git a/transport/system/mixed.go b/transport/system/mixed.go new file mode 100644 index 0000000..1861dc1 --- /dev/null +++ b/transport/system/mixed.go @@ -0,0 +1,223 @@ +package system + +import ( + "context" + "encoding/base64" + "fmt" + "net" + "net/http" + "net/netip" + "strconv" + "strings" + "time" + + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/socksaddr" + "github.com/sagernet/sing/protocol/socks" +) + +type MixedListener struct { + *SocksListener +} + +func NewMixedListener(bind netip.AddrPort, config *SocksConfig, handler SocksHandler) *MixedListener { + listener := &MixedListener{NewSocksListener(bind, config, handler)} + listener.TCPListener.Handler = listener + return listener +} + +func (l *MixedListener) HandleTCP(conn net.Conn) error { + bufConn := buf.NewBufferedConn(conn) + hdr, err := bufConn.ReadByte() + if err != nil { + return err + } + err = bufConn.UnreadByte() + if err != nil { + return err + } + + if hdr == socks.Version4 || hdr == socks.Version5 { + return l.SocksListener.HandleTCP(bufConn) + } + + var httpClient *http.Client + for { + request, err := readRequest(bufConn.Reader()) + if err != nil { + return exceptions.Cause(err, "read http request") + } + + if l.Username != "" { + var authOk bool + authorization := request.Header.Get("Proxy-Authorization") + if strings.HasPrefix(authorization, "BASIC ") { + userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:]) + if string(userPassword) == l.Username+":"+l.Password { + authOk = true + } + } + if !authOk { + err = responseWith(request, http.StatusProxyAuthRequired).Write(conn) + if err != nil { + return err + } + } + } + + if request.Method == "CONNECT" { + host := request.URL.Hostname() + portStr := request.URL.Port() + if portStr == "" { + portStr = "80" + } + port, err := strconv.Atoi(portStr) + if err != nil { + err = responseWith(request, http.StatusBadRequest).Write(conn) + if err != nil { + return err + } + } + _, err = fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established") + if err != nil { + return exceptions.Cause(err, "write http response") + } + return l.Handler.NewConnection(socksaddr.ParseAddr(host), uint16(port), bufConn) + } + + keepAlive := strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" + + host := request.Header.Get("Host") + if host != "" { + request.Host = host + } + + request.RequestURI = "" + + removeHopByHopHeaders(request.Header) + removeExtraHTTPHostPort(request) + + if request.URL.Scheme == "" || request.URL.Host == "" { + return responseWith(request, http.StatusBadRequest).Write(conn) + } + + if httpClient == nil { + httpClient = &http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DialContext: func(context context.Context, network, address string) (net.Conn, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, exceptions.New("unsupported network ", network) + } + + addr, port, err := socksaddr.ParseAddrPort(address) + if err != nil { + return nil, err + } + + left, right := net.Pipe() + go func() { + err = l.Handler.NewConnection(addr, port, right) + if err != nil { + l.OnError(err) + } + }() + return left, nil + }, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + } + + response, err := httpClient.Do(request) + if err != nil { + l.OnError(exceptions.Cause(err, "http proxy")) + return responseWith(request, http.StatusBadGateway).Write(conn) + } + + removeHopByHopHeaders(response.Header) + + if keepAlive { + response.Header.Set("Proxy-Connection", "keep-alive") + response.Header.Set("Connection", "keep-alive") + response.Header.Set("Keep-Alive", "timeout=4") + } + + response.Close = !keepAlive + + err = response.Write(conn) + if err != nil { + l.OnError(exceptions.Cause(err, "http proxy")) + return err + } + } +} + +// removeHopByHopHeaders remove hop-by-hop header +func removeHopByHopHeaders(header http.Header) { + // Strip hop-by-hop header based on RFC: + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1 + // https://www.mnot.net/blog/2011/07/11/what_proxies_must_do + + header.Del("Proxy-Connection") + header.Del("Proxy-Authenticate") + header.Del("Proxy-Authorization") + header.Del("TE") + header.Del("Trailers") + header.Del("Transfer-Encoding") + header.Del("Upgrade") + + connections := header.Get("Connection") + header.Del("Connection") + if len(connections) == 0 { + return + } + for _, h := range strings.Split(connections, ",") { + header.Del(strings.TrimSpace(h)) + } +} + +// removeExtraHTTPHostPort remove extra host port (example.com:80 --> example.com) +// It resolves the behavior of some HTTP servers that do not handle host:80 (e.g. baidu.com) +func removeExtraHTTPHostPort(req *http.Request) { + host := req.Host + if host == "" { + host = req.URL.Host + } + + if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" { + host = pHost + } + + req.Host = host + req.URL.Host = host +} + +func responseWith(request *http.Request, statusCode int) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Status: http.StatusText(statusCode), + Proto: request.Proto, + ProtoMajor: request.ProtoMajor, + ProtoMinor: request.ProtoMinor, + Header: http.Header{}, + } +} + +func (l *MixedListener) Start() error { + return l.TCPListener.Start() +} + +func (l *MixedListener) Close() error { + return l.TCPListener.Close() +} + +func (l *MixedListener) OnError(err error) { + l.Handler.OnError(exceptions.Cause(err, "mixed server")) +} diff --git a/transport/system/sockopt_linux.go b/transport/system/sockopt_linux.go new file mode 100644 index 0000000..2f7335a --- /dev/null +++ b/transport/system/sockopt_linux.go @@ -0,0 +1,14 @@ +package system + +import ( + "syscall" +) + +const ( + TCP_FASTOPEN = 23 + TCP_FASTOPEN_CONNECT = 30 +) + +func TCPFastOpen(fd uintptr) error { + return syscall.SetsockoptInt(int(fd), syscall.SOL_TCP, TCP_FASTOPEN_CONNECT, 1) +} diff --git a/transport/system/sockopt_other.go b/transport/system/sockopt_other.go new file mode 100644 index 0000000..2eca8e0 --- /dev/null +++ b/transport/system/sockopt_other.go @@ -0,0 +1,9 @@ +//go:build !linux + +package main + +import "github.com/sagernet/sing/common/exceptions" + +func TCPFastOpen(fd uintptr) error { + return exceptions.New("only available on linux") +} diff --git a/transport/system/socks.go b/transport/system/socks.go new file mode 100644 index 0000000..8b93ac4 --- /dev/null +++ b/transport/system/socks.go @@ -0,0 +1,148 @@ +package system + +import ( + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/socksaddr" + "github.com/sagernet/sing/protocol/socks" + "io" + "net" + "net/netip" +) + +type SocksHandler interface { + NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error + NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error + OnError(err error) +} + +type SocksConfig struct { + Username string + Password string +} + +type SocksListener struct { + Handler SocksHandler + *TCPListener + *SocksConfig +} + +func NewSocksListener(bind netip.AddrPort, config *SocksConfig, handler SocksHandler) *SocksListener { + listener := &SocksListener{ + SocksConfig: config, + Handler: handler, + } + listener.TCPListener = NewTCPListener(bind, listener) + return listener +} + +func (l *SocksListener) HandleTCP(conn net.Conn) error { + authRequest, err := socks.ReadAuthRequest(conn) + if err != nil { + return exceptions.Cause(err, "read socks auth request") + } + var authMethod byte + if l.Username == "" { + authMethod = socks.AuthTypeNotRequired + } else { + authMethod = socks.AuthTypeUsernamePassword + } + if !common.Contains(authRequest.Methods, authMethod) { + err = socks.WriteAuthResponse(conn, &socks.AuthResponse{ + Version: authRequest.Version, + Method: socks.AuthTypeNoAcceptedMethods, + }) + if err != nil { + return exceptions.Cause(err, "write socks auth response") + } + } + err = socks.WriteAuthResponse(conn, &socks.AuthResponse{ + Version: authRequest.Version, + Method: socks.AuthTypeNotRequired, + }) + if err != nil { + return exceptions.Cause(err, "write socks auth response") + } + + if authMethod == socks.AuthTypeUsernamePassword { + usernamePasswordAuthRequest, err := socks.ReadUsernamePasswordAuthRequest(conn) + if err != nil { + return exceptions.Cause(err, "read user auth request") + } + response := socks.UsernamePasswordAuthResponse{} + if usernamePasswordAuthRequest.Username != l.Username { + response.Status = socks.UsernamePasswordStatusFailure + } else if usernamePasswordAuthRequest.Password != l.Password { + response.Status = socks.UsernamePasswordStatusFailure + } else { + response.Status = socks.UsernamePasswordStatusSuccess + } + err = socks.WriteUsernamePasswordAuthResponse(conn, &response) + if err != nil { + return exceptions.Cause(err, "write user auth response") + } + } + + request, err := socks.ReadRequest(conn) + if err != nil { + return exceptions.Cause(err, "read socks request") + } + switch request.Command { + case socks.CommandConnect: + localAddr, localPort := socksaddr.AddressFromNetAddr(l.TCPListener.TCPListener.Addr()) + err = socks.WriteResponse(conn, &socks.Response{ + Version: request.Version, + ReplyCode: socks.ReplyCodeSuccess, + BindAddr: localAddr, + BindPort: localPort, + }) + if err != nil { + return exceptions.Cause(err, "write socks response") + } + return l.Handler.NewConnection(request.Addr, request.Port, conn) + case socks.CommandUDPAssociate: + udpConn, err := net.ListenUDP("udp4", nil) + if err != nil { + return err + } + defer udpConn.Close() + localAddr, localPort := socksaddr.AddressFromNetAddr(udpConn.LocalAddr()) + err = socks.WriteResponse(conn, &socks.Response{ + Version: request.Version, + ReplyCode: socks.ReplyCodeSuccess, + BindAddr: localAddr, + BindPort: localPort, + }) + if err != nil { + return exceptions.Cause(err, "write socks response") + } + go func() { + err := l.Handler.NewPacketConnection(socks.NewPacketConn(conn, udpConn), request.Addr, request.Port) + if err != nil { + l.OnError(err) + } + }() + return common.Error(io.Copy(io.Discard, conn)) + default: + err = socks.WriteResponse(conn, &socks.Response{ + Version: request.Version, + ReplyCode: socks.ReplyCodeUnsupported, + }) + if err != nil { + return exceptions.Cause(err, "write response") + } + } + return nil +} + +func (l *SocksListener) Start() error { + return l.TCPListener.Start() +} + +func (l *SocksListener) Close() error { + return l.TCPListener.Close() +} + +func (l *SocksListener) OnError(err error) { + l.Handler.OnError(exceptions.Cause(err, "socks server")) +} diff --git a/transport/system/stub.s b/transport/system/stub.s new file mode 100644 index 0000000..e69de29 diff --git a/transport/system/udp.go b/transport/system/udp.go index 3d90bd8..8cb6ab6 100644 --- a/transport/system/udp.go +++ b/transport/system/udp.go @@ -8,7 +8,7 @@ import ( ) type UDPHandler interface { - HandleUDP(listener *UDPListener, buffer *buf.Buffer, sourceAddr net.Addr) error + HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error OnError(err error) } @@ -52,7 +52,7 @@ func (l *UDPListener) loop() { } buffer.Truncate(n) go func() { - err := l.Handler.HandleUDP(l, buffer, addr) + err := l.Handler.HandleUDP(buffer, addr) if err != nil { buffer.Release() l.Handler.OnError(err)