Add http support for sslocal

This commit is contained in:
世界 2022-04-07 20:49:20 +08:00
parent 542fa4f975
commit 26e13e7beb
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
30 changed files with 1355 additions and 194 deletions

3
.gitignore vendored
View file

@ -1,2 +1,3 @@
/.idea/
/sing_*
/sing_*
/*.json

View file

@ -1,16 +0,0 @@
package main
import (
"fmt"
"os"
"github.com/sagernet/sing/cli/sslocal"
)
func main() {
err := sslocal.MainCmd().Execute()
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}

243
cli/sslocal/sslocal.go Normal file
View file

@ -0,0 +1,243 @@
package main
import (
"context"
"encoding/base64"
"encoding/json"
"io"
"io/ioutil"
"net"
"net/netip"
"os"
"os/signal"
"strconv"
"syscall"
"github.com/sagernet/sing"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/socksaddr"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/system"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
func main() {
err := MainCmd().Execute()
if err != nil {
logrus.Fatal(err)
}
}
type Flags struct {
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
LocalPort uint16 `json:"local_port"`
Password string `json:"password"`
Key string `json:"key"`
Method string `json:"method"`
TCPFastOpen bool `json:"fast_open"`
Verbose bool `json:"verbose"`
ConfigFile string
}
func MainCmd() *cobra.Command {
flags := new(Flags)
cmd := &cobra.Command{
Use: "sslocal",
Short: "shadowsocks client as socks5 proxy, sing port",
Version: sing.Version,
Run: func(cmd *cobra.Command, args []string) {
Run(flags)
},
}
cmd.Flags().StringVarP(&flags.Server, "server", "s", "", "Set the server’s hostname or IP.")
cmd.Flags().Uint16VarP(&flags.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
cmd.Flags().Uint16VarP(&flags.LocalPort, "local-port", "l", 0, "Set the local port number.")
cmd.Flags().StringVarP(&flags.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
cmd.Flags().StringVar(&flags.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
cmd.Flags().StringVarP(&flags.Method, "encrypt-method", "m", "", `Set the cipher.
Supported ciphers:
none
aes-128-gcm
aes-192-gcm
aes-256-gcm
chacha20-ietf-poly1305
xchacha20-ietf-poly1305
The default cipher is chacha20-ietf-poly1305.`)
cmd.Flags().BoolVar(&flags.TCPFastOpen, "fast-open", false, `Enable TCP fast open.
Only available with Linux kernel > 3.7.0.`)
cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.")
cmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode.")
return cmd
}
type LocalClient struct {
*system.MixedListener
*shadowsocks.Client
}
func NewLocalClient(flags *Flags) (*LocalClient, error) {
if flags.ConfigFile != "" {
configFile, err := ioutil.ReadFile(flags.ConfigFile)
if err != nil {
return nil, exceptions.Cause(err, "read config file")
}
flagsNew := new(Flags)
err = json.Unmarshal(configFile, flagsNew)
if err != nil {
return nil, exceptions.Cause(err, "decode config file")
}
if flagsNew.Server != "" && flags.Server == "" {
flags.Server = flagsNew.Server
}
if flagsNew.ServerPort != 0 && flags.ServerPort == 0 {
flags.ServerPort = flagsNew.ServerPort
}
if flagsNew.LocalPort != 0 && flags.LocalPort == 0 {
flags.LocalPort = flagsNew.LocalPort
}
if flagsNew.Password != "" && flags.Password == "" {
flags.Password = flagsNew.Password
}
if flagsNew.Key != "" && flags.Key == "" {
flags.Key = flagsNew.Key
}
if flagsNew.Method != "" && flags.Method == "" {
flags.Method = flagsNew.Method
}
if flagsNew.TCPFastOpen {
flags.TCPFastOpen = true
}
if flagsNew.Verbose {
flags.Verbose = true
}
}
clientConfig := &shadowsocks.ClientConfig{
Server: flags.Server,
ServerPort: flags.ServerPort,
Method: flags.Method,
}
if flags.Key != "" {
key, err := base64.URLEncoding.DecodeString(flags.Key)
if err != nil {
return nil, exceptions.Cause(err, "decode key")
}
clientConfig.Key = key
} else if flags.Password != "" {
clientConfig.Password = []byte(flags.Password)
}
if flags.Verbose {
logrus.SetLevel(logrus.TraceLevel)
}
dialer := new(net.Dialer)
if flags.TCPFastOpen {
dialer.Control = func(network, address string, c syscall.RawConn) error {
var rawFd uintptr
err := c.Control(func(fd uintptr) {
rawFd = fd
})
if err != nil {
return err
}
return system.TCPFastOpen(rawFd)
}
}
shadowClient, err := shadowsocks.NewClient(dialer, clientConfig)
if err != nil {
return nil, exceptions.Cause(err, "create shadowsocks")
}
client := &LocalClient{
Client: shadowClient,
}
client.MixedListener = system.NewMixedListener(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), flags.LocalPort), &system.SocksConfig{}, client)
return client, nil
}
func (c *LocalClient) Start() error {
err := c.MixedListener.Start()
if err != nil {
return err
}
logrus.Info("mixed server started at ", c.MixedListener.TCPListener.Addr())
return nil
}
func (c *LocalClient) NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error {
logrus.Info("TCP ", conn.RemoteAddr(), " ==> ", net.JoinHostPort(addr.String(), strconv.Itoa(int(port))))
ctx := context.Background()
serverConn, err := c.DialContextTCP(ctx, addr, port)
if err != nil {
return err
}
return task.Run(ctx, func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(serverConn)
return common.Error(io.Copy(serverConn, conn))
}, func() error {
defer rw.CloseRead(serverConn)
defer rw.CloseWrite(conn)
return common.Error(io.Copy(conn, serverConn))
})
}
func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error {
ctx := context.Background()
serverConn := c.DialContextUDP(ctx)
return task.Run(ctx, func() error {
var init bool
return socks.CopyPacketConn(serverConn, conn, func(size int) {
if !init {
init = true
logrus.Info("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port))
} else {
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port))
}
})
}, func() error {
return socks.CopyPacketConn(conn, serverConn, func(size int) {
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", socksaddr.JoinHostPort(addr, port))
})
})
}
func Run(flags *Flags) {
client, err := NewLocalClient(flags)
if err != nil {
logrus.Fatal(err)
}
err = client.Start()
if err != nil {
logrus.Fatal(err)
}
osSignals := make(chan os.Signal, 1)
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
<-osSignals
client.Close()
}
func (c *LocalClient) OnError(err error) {
if exceptions.IsClosed(err) {
return
}
logrus.Warn(err)
}

View file

@ -1,15 +1,18 @@
package sslocal
package main
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net"
"net/netip"
"os"
"os/signal"
"strings"
"syscall"
"time"
@ -27,13 +30,24 @@ import (
"github.com/spf13/cobra"
)
func main() {
err := MainCmd().Execute()
if err != nil {
logrus.Fatal(err)
}
}
type Flags struct {
Server string
ServerPort uint16
LocalPort uint16
Password string
Key string
Method string
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
LocalPort uint16 `json:"local_port"`
Password string `json:"password"`
Key string `json:"key"`
Method string `json:"method"`
Timeout uint16 `json:"timeout"`
TCPFastOpen bool `json:"fast_open"`
Verbose bool `json:"verbose"`
ConfigFile string
}
func MainCmd() *cobra.Command {
@ -65,13 +79,10 @@ chacha20-ietf-poly1305
xchacha20-ietf-poly1305
The default cipher is chacha20-ietf-poly1305.`)
// cmd.Flags().Uint16VarP(&flags.Timeout, "timeout", "t", 60, "Set the socket timeout in seconds.")
// cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.")
// cmd.Flags().Uint16VarP(&flags.MaxFD, "max-open-files", "n", 0, `Specify max number of open files.
// Only available on Linux.`)
// cmd.Flags().StringVarP(&flags.Interface, "interface", "i", "", `Send traffic through specific network interface.
// For example, there are three interfaces in your device, which is lo (127.0.0.1), eth0 (192.168.0.1) and eth1 (192.168.0.2). Meanwhile, you configure ss-local to listen on 0.0.0.0:8388 and bind to eth1. That results the traffic go out through eth1, but not lo nor eth0. This option is useful to control traffic in multi-interface environment.`)
// cmd.Flags().StringVarP(&flags.LocalAddress, "local-address", "b", "", "Specify the local address to use while this client is making outbound connections to the server.")
cmd.Flags().BoolVar(&flags.TCPFastOpen, "fast-open", false, `Enable TCP fast open.
Only available with Linux kernel > 3.7.0.`)
cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.")
cmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode.")
return cmd
}
@ -81,10 +92,27 @@ type LocalClient struct {
serverAddr netip.AddrPort
cipher shadowsocks.Cipher
key []byte
dialer net.Dialer
}
func NewLocalClient(flags *Flags) (*LocalClient, error) {
client := new(LocalClient)
if flags.ConfigFile != "" {
configFile, err := os.Open(flags.ConfigFile)
if err != nil {
return nil, exceptions.CauseF(err, "unable to open config file ", flags.ConfigFile)
}
config, err := ioutil.ReadAll(configFile)
configFile.Close()
if err != nil {
return nil, err
}
err = json.Unmarshal(config, &flags)
if err != nil {
return nil, exceptions.Cause(err, "failed to decode config file")
}
}
client := &LocalClient{}
client.tcpIn = system.NewTCPListener(netip.AddrPortFrom(netip.IPv4Unspecified(), flags.LocalPort), client)
if flags.Server == "" {
@ -120,6 +148,32 @@ func NewLocalClient(flags *Flags) (*LocalClient, error) {
return nil, exceptions.New("password not specified")
}
if flags.Timeout > 0 {
client.dialer.Timeout = time.Duration(flags.Timeout) * time.Second
}
if flags.TCPFastOpen {
client.dialer.Control = func(network, address string, c syscall.RawConn) error {
if strings.HasPrefix(network, "tcp") {
var rawFd uintptr
if err = c.Control(func(fd uintptr) {
rawFd = fd
}); err != nil {
return err
}
err = system.TCPFastOpen(rawFd)
if err != nil {
return exceptions.Cause(err, "set tcp fast open")
}
}
return nil
}
}
if flags.Verbose {
logrus.SetLevel(logrus.TraceLevel)
}
return client, nil
}
@ -142,7 +196,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
authRequest, err := socks.ReadAuthRequest(conn)
if err != nil {
return err
return exceptions.Cause(err, "read socks auth request")
}
if !common.Contains(authRequest.Methods, socks.AuthTypeNotRequired) {
@ -151,7 +205,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
Method: socks.AuthTypeNoAcceptedMethods,
})
if err != nil {
return err
return exceptions.Cause(err, "write socks auth response")
}
}
@ -160,12 +214,12 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
Method: socks.AuthTypeNotRequired,
})
if err != nil {
return err
return exceptions.Cause(err, "write socks auth response")
}
request, err := socks.ReadRequest(conn)
if err != nil {
return err
return exceptions.Cause(err, "read socks request")
}
ctx := context.Background()
@ -181,7 +235,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
case socks.CommandConnect:
logrus.Info("CONNECT ", request.Addr, ":", request.Port)
serverConn, dialErr := system.Dial(ctx, "tcp", c.serverAddr.String())
serverConn, dialErr := c.dialer.DialContext(ctx, "tcp", c.serverAddr.String())
if dialErr != nil {
failure()
return exceptions.Cause(dialErr, "connect to server")
@ -196,12 +250,12 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
Writer: serverConn,
Buffer: saltBuffer,
}
writer, _ := c.cipher.CreateWriter(c.key, saltBuffer.Bytes(), serverWriter)
writer := c.cipher.CreateWriter(c.key, saltBuffer.Bytes(), serverWriter)
header := buf.New()
defer header.Release()
requestBuffer := buf.New()
defer requestBuffer.Release()
err = shadowsocks.AddressSerializer.WriteAddressAndPort(header, request.Addr, request.Port)
err = shadowsocks.AddressSerializer.WriteAddressAndPort(requestBuffer, request.Addr, request.Port)
if err != nil {
failure()
return err
@ -215,7 +269,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
BindPort: serverPort,
})
if err != nil {
return exceptions.Cause(err, "write response for ", request.Addr, "/", request.Port)
return exceptions.Cause(err, "write socks response")
}
return task.Run(ctx, func() error {
@ -226,7 +280,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
if err != nil {
return err
}
_, err = header.ReadFrom(conn)
_, err = requestBuffer.ReadFrom(conn)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
} else {
@ -237,7 +291,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
if err != nil {
return err
}
_, err = writer.Write(header.Bytes())
_, err = writer.Write(requestBuffer.Bytes())
if err != nil {
return err
}
@ -245,7 +299,8 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
if err != nil {
return exceptions.Cause(err, "flush request")
}
_, err = io.Copy(writer, conn)
requestBuffer.FullReset()
_, err = io.CopyBuffer(writer, conn, requestBuffer.FreeBytes())
if err != nil {
return exceptions.Cause(err, "upload")
}
@ -275,10 +330,10 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
return nil
})
case socks.CommandUDPAssociate:
serverConn, dialErr := system.Dial(ctx, "udp", c.serverAddr.String())
serverConn, dialErr := c.dialer.DialContext(ctx, "udp", c.serverAddr.String())
if dialErr != nil {
failure()
return exceptions.Cause(err, "connect to server")
return exceptions.Cause(dialErr, "connect to server")
}
handler := &udpHandler{
LocalClient: c,
@ -300,9 +355,12 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
}
go handler.loopInput()
return common.Error(io.Copy(io.Discard, conn))
default:
return socks.WriteResponse(conn, &socks.Response{
Version: request.Version,
ReplyCode: socks.ReplyCodeUnsupported,
})
}
return nil
}
type udpHandler struct {
@ -313,7 +371,7 @@ type udpHandler struct {
sourceAddr net.Addr
}
func (c *udpHandler) HandleUDP(listener *system.UDPListener, buffer *buf.Buffer, sourceAddr net.Addr) error {
func (c *udpHandler) HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error {
c.sourceAddr = sourceAddr
buffer.Advance(3)
if c.cipher.SaltSize() > 0 {
@ -369,5 +427,8 @@ func (c *udpHandler) Close() error {
}
func (c *LocalClient) OnError(err error) {
if exceptions.IsClosed(err) {
return
}
logrus.Warn(err)
}

44
common/buf/bufconn.go Normal file
View file

@ -0,0 +1,44 @@
package buf
import (
"bufio"
"io"
"net"
)
type BufferedConn struct {
r *bufio.Reader
net.Conn
}
func NewBufferedConn(c net.Conn) *BufferedConn {
return &BufferedConn{bufio.NewReader(c), c}
}
func (c *BufferedConn) Reader() *bufio.Reader {
return c.r
}
func (c *BufferedConn) Peek(n int) ([]byte, error) {
return c.r.Peek(n)
}
func (c *BufferedConn) Read(p []byte) (int, error) {
return c.r.Read(p)
}
func (c *BufferedConn) ReadByte() (byte, error) {
return c.r.ReadByte()
}
func (c *BufferedConn) UnreadByte() error {
return c.r.UnreadByte()
}
func (c *BufferedConn) Buffered() int {
return c.r.Buffered()
}
func (c *BufferedConn) WriteTo(w io.Writer) (n int64, err error) {
return c.r.WriteTo(w)
}

View file

@ -25,6 +25,20 @@ func New() *Buffer {
}
}
func NewSize(size int) *Buffer {
if size <= 128 || size > BufferSize {
return &Buffer{
data: make([]byte, size),
}
}
return &Buffer{
data: GetBytes(),
start: ReversedHeader,
end: ReversedHeader,
managed: true,
}
}
func FullNew() *Buffer {
return &Buffer{
data: GetBytes(),
@ -57,6 +71,20 @@ func As(data []byte) *Buffer {
}
}
func Or(data []byte, size int) *Buffer {
max := cap(data)
if size != max {
data = data[:max]
}
if cap(data) >= size {
return &Buffer{
data: data,
}
} else {
return NewSize(size)
}
}
func With(data []byte) *Buffer {
return &Buffer{
data: data,
@ -346,7 +374,7 @@ func (b Buffer) Copy() []byte {
func ReleaseMulti(mb *list.List[*Buffer]) {
for entry := mb.Front(); entry != nil; entry = entry.Next() {
// TODO: remove cast
var buffer *Buffer = entry.Value
buffer := entry.Value
buffer.Release()
}
}

View file

@ -2,6 +2,8 @@ package buf
import (
"io"
"github.com/sagernet/sing/common"
)
type BufferedReader struct {
@ -16,13 +18,18 @@ func (r *BufferedReader) Upstream() io.Reader {
return r.Reader
}
func (r *BufferedReader) Replaceable() bool {
return r.Buffer == nil
}
func (r *BufferedReader) Read(p []byte) (n int, err error) {
if r.Buffer != nil {
n, err = r.Buffer.Read(p)
if err == nil {
return
if r.Buffer.IsEmpty() {
r.Buffer.Release()
r.Buffer = nil
}
r.Buffer = nil
return
}
return r.Reader.Read(p)
}
@ -33,12 +40,13 @@ type BufferedWriter struct {
}
func (w *BufferedWriter) Upstream() io.Writer {
if w.Buffer != nil {
return nil
}
return w.Writer
}
func (w *BufferedWriter) Replaceable() bool {
return w.Buffer == nil
}
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
if w.Buffer == nil {
return w.Writer.Write(p)
@ -47,13 +55,7 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
if err == nil {
return
}
n, err = w.Writer.Write(w.Buffer.Bytes())
if err != nil {
return 0, err
}
w.Buffer.Release()
w.Buffer = nil
return w.Writer.Write(p)
return len(p), w.Flush()
}
func (w *BufferedWriter) Flush() error {
@ -66,6 +68,5 @@ func (w *BufferedWriter) Flush() error {
if buffer.IsEmpty() {
return nil
}
_, err := w.Writer.Write(buffer.Bytes())
return err
return common.Error(w.Writer.Write(buffer.Bytes()))
}

View file

@ -3,6 +3,9 @@ package exceptions
import (
"errors"
"fmt"
"io"
"net"
"syscall"
)
type Exception interface {
@ -10,11 +13,6 @@ type Exception interface {
Cause() error
}
type SuppressedException interface {
error
Suppressed() error
}
type exception struct {
message string
cause error
@ -31,10 +29,36 @@ func (e exception) Cause() error {
return e.cause
}
func (e exception) Unwrap() error {
return e.cause
}
func (e exception) Is(err error) bool {
return e == err || errors.Is(e.cause, err)
}
func New(message ...any) error {
return errors.New(fmt.Sprint(message...))
}
func Cause(cause error, message ...any) Exception {
return &exception{fmt.Sprint(message...), cause}
func Cause(cause error, message string) Exception {
return exception{message, cause}
}
func CauseF(cause error, message ...any) Exception {
return exception{fmt.Sprint(message), cause}
}
func IsClosed(err error) bool {
return IsTimeout(err) || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE)
}
func IsTimeout(err error) bool {
if unwrapErr := errors.Unwrap(err); unwrapErr != nil {
err = unwrapErr
}
if opErr, isOpErr := err.(*net.OpError); isOpErr {
return opErr.Timeout()
}
return false
}

View file

@ -1,12 +1,15 @@
package common
import "io"
import (
"io"
)
type Flusher interface {
Flush() error
}
func Flush(writer io.Writer) error {
writerBack := writer
for {
if f, ok := writer.(Flusher); ok {
err := f.Flush()
@ -15,6 +18,15 @@ func Flush(writer io.Writer) error {
}
}
if u, ok := writer.(WriterWithUpstream); ok {
if u.Replaceable() {
if writerBack == writer {
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter {
setter.SetWriter(writerBack)
writer = u.Upstream()
continue
}
}
writerBack = writer
writer = u.Upstream()
} else {
break
@ -22,3 +34,62 @@ func Flush(writer io.Writer) error {
}
return nil
}
func FlushVar(writerP *io.Writer) error {
writer := *writerP
writerBack := writer
for {
if f, ok := writer.(Flusher); ok {
err := f.Flush()
if err != nil {
return err
}
}
if u, ok := writer.(WriterWithUpstream); ok {
if u.Replaceable() {
if writerBack == writer {
writer = u.Upstream()
writerBack = writer
writerP = &writer
continue
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter {
setter.SetWriter(writerBack)
writer = u.Upstream()
continue
}
}
writerBack = writer
writer = u.Upstream()
} else {
break
}
}
return nil
}
type FlushOnceWriter struct {
io.Writer
flushed bool
}
func (w *FlushOnceWriter) Upstream() io.Writer {
return w.Writer
}
func (w *FlushOnceWriter) Replaceable() bool {
return w.flushed
}
func (w *FlushOnceWriter) Write(p []byte) (n int, err error) {
if w.flushed {
return w.Writer.Write(p)
}
n, err = w.Writer.Write(p)
if n > 0 {
err = FlushVar(&w.Writer)
}
if err == nil {
w.flushed = true
}
return
}

View file

@ -4,18 +4,40 @@ import (
"context"
"io"
"net"
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/task"
)
func CopyConn(ctx context.Context, conn net.Conn, outConn net.Conn) error {
return task.Run(ctx, func() error {
return common.Error(io.Copy(conn, outConn))
}, func() error {
return common.Error(io.Copy(outConn, conn))
})
func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) {
writer := *writerVar
writerBack := writer
for {
if w, ok := writer.(io.ReaderFrom); ok {
return w.ReadFrom(reader)
}
if f, ok := writer.(common.Flusher); ok {
err := f.Flush()
if err != nil {
return 0, err
}
}
if u, ok := writer.(common.WriterWithUpstream); ok {
if u.Replaceable() && writerBack == writer {
writer = u.Upstream()
writerBack = writer
writerVar = &writer
continue
}
writer = u.Upstream()
writerBack = writer
} else {
break
}
}
return 0, os.ErrInvalid
}
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {

View file

@ -3,6 +3,7 @@ package socksaddr
import (
"net"
"net/netip"
"strconv"
)
type Addr interface {
@ -12,6 +13,26 @@ type Addr interface {
String() string
}
func ParseAddr(address string) Addr {
addr, err := netip.ParseAddr(address)
if err == nil {
return AddrFromAddr(addr)
}
return AddrFromFqdn(address)
}
func ParseAddrPort(address string) (Addr, uint16, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, 0, err
}
portInt, err := strconv.Atoi(port)
if err != nil {
return nil, 0, err
}
return ParseAddr(host), uint16(portInt), nil
}
func AddrFromIP(ip net.IP) Addr {
addr, _ := netip.AddrFromSlice(ip)
if addr.Is4() {
@ -21,6 +42,14 @@ func AddrFromIP(ip net.IP) Addr {
}
}
func AddrFromAddr(addr netip.Addr) Addr {
if addr.Is4() {
return Addr4(addr.As4())
} else {
return Addr16(addr.As16())
}
}
func AddressFromNetAddr(netAddr net.Addr) (addr Addr, port uint16) {
var ip net.IP
switch addr := netAddr.(type) {
@ -38,6 +67,10 @@ func AddrFromFqdn(fqdn string) Addr {
return AddrFqdn(fqdn)
}
func JoinHostPort(addr Addr, port uint16) string {
return net.JoinHostPort(addr.String(), strconv.Itoa(int(port)))
}
type Addr4 [4]byte
func (a Addr4) Family() Family {

View file

@ -6,8 +6,18 @@ import (
type ReaderWithUpstream interface {
Upstream() io.Reader
Replaceable() bool
}
type UpstreamReaderSetter interface {
SetUpstream(reader io.Reader)
}
type WriterWithUpstream interface {
Upstream() io.Writer
Replaceable() bool
}
type UpstreamWriterSetter interface {
SetWriter(writer io.Writer)
}

1
go.mod
View file

@ -11,5 +11,6 @@ require (
require (
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/testify v1.7.1 // indirect
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
)

7
go.sum
View file

@ -1,4 +1,5 @@
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
@ -12,8 +13,10 @@ github.com/spf13/cobra v1.4.0 h1:y+wJpx64xcgO1V+RcnwW0LEHxTKRi2ZDPSBjWnrg88Q=
github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o=
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -21,3 +24,5 @@ golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0P
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -11,8 +11,8 @@ import (
type Cipher interface {
KeySize() int
SaltSize() int
CreateReader(key []byte, iv []byte, reader io.Reader) io.Reader
CreateWriter(key []byte, iv []byte, writer io.Writer) (io.Writer, int)
CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader
CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer
EncodePacket(key []byte, buffer *buf.Buffer) error
DecodePacket(key []byte, buffer *buf.Buffer) error
}

View file

@ -92,9 +92,9 @@ func (c *AEADCipher) CreateReader(key []byte, salt []byte, reader io.Reader) io.
return NewAEADReader(reader, c.Constructor(Kdf(key, salt, c.KeyLength)))
}
func (c *AEADCipher) CreateWriter(key []byte, salt []byte, writer io.Writer) (io.Writer, int) {
func (c *AEADCipher) CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer {
protocolWriter := NewAEADWriter(writer, c.Constructor(Kdf(key, salt, c.KeyLength)))
return protocolWriter, protocolWriter.maxDataSize
return protocolWriter
}
func (c *AEADCipher) EncodePacket(key []byte, buffer *buf.Buffer) error {
@ -132,12 +132,6 @@ func (c *AEADConn) Write(p []byte) (n int, err error) {
return c.Writer.Write(p)
}
func (c *AEADConn) Close() error {
c.Reader.Close()
c.Writer.Close()
return c.Conn.Close()
}
type AEADReader struct {
upstream io.Reader
cipher cipher.AEAD
@ -151,7 +145,7 @@ func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader {
return &AEADReader{
upstream: upstream,
cipher: cipher,
data: buf.GetBytes(),
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
nonce: make([]byte, cipher.NonceSize()),
}
}
@ -160,6 +154,52 @@ func (r *AEADReader) Upstream() io.Reader {
return r.upstream
}
func (r *AEADReader) Replaceable() bool {
return false
}
func (r *AEADReader) SetUpstream(reader io.Reader) {
r.upstream = reader
}
func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) {
if r.cached > 0 {
writeN, writeErr := writer.Write(r.data[r.index : r.index+r.cached])
if writeErr != nil {
return int64(writeN), writeErr
}
n += int64(writeN)
}
for {
start := PacketLengthBufferSize + r.cipher.Overhead()
_, err = io.ReadFull(r.upstream, r.data[:start])
if err != nil {
return
}
_, err = r.cipher.Open(r.data[:0], r.nonce, r.data[:start], nil)
if err != nil {
return
}
increaseNonce(r.nonce)
length := int(binary.BigEndian.Uint16(r.data[:PacketLengthBufferSize]))
end := length + r.cipher.Overhead()
_, err = io.ReadFull(r.upstream, r.data[:end])
if err != nil {
return
}
_, err = r.cipher.Open(r.data[:0], r.nonce, r.data[:end], nil)
if err != nil {
return
}
increaseNonce(r.nonce)
writeN, writeErr := writer.Write(r.data[:length])
if writeErr != nil {
return int64(writeN), writeErr
}
n += int64(writeN)
}
}
func (r *AEADReader) Read(b []byte) (n int, err error) {
if r.cached > 0 {
n = copy(b, r.data[r.index:r.index+r.cached])
@ -209,29 +249,19 @@ func (r *AEADReader) Read(b []byte) (n int, err error) {
}
}
func (r *AEADReader) Close() error {
if r.data != nil {
buf.PutBytes(r.data)
r.data = nil
}
return nil
}
type AEADWriter struct {
upstream io.Writer
cipher cipher.AEAD
data []byte
nonce []byte
maxDataSize int
upstream io.Writer
cipher cipher.AEAD
data []byte
nonce []byte
}
func NewAEADWriter(upstream io.Writer, cipher cipher.AEAD) *AEADWriter {
return &AEADWriter{
upstream: upstream,
cipher: cipher,
data: buf.GetBytes(),
nonce: make([]byte, cipher.NonceSize()),
maxDataSize: MaxPacketSize - PacketLengthBufferSize - cipher.Overhead()*2,
upstream: upstream,
cipher: cipher,
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
nonce: make([]byte, cipher.NonceSize()),
}
}
@ -239,47 +269,47 @@ func (w *AEADWriter) Upstream() io.Writer {
return w.upstream
}
func (w *AEADWriter) Process(p []byte) (n int, buffer *buf.Buffer, flush bool, err error) {
if len(p) > w.maxDataSize {
n, err = w.Write(p)
err = &rw.DirectException{
Suppressed: err,
func (w *AEADWriter) Replaceable() bool {
return false
}
func (w *AEADWriter) SetWriter(writer io.Writer) {
w.upstream = writer
}
func (w *AEADWriter) ReadFrom(r io.Reader) (n int64, err error) {
for {
offset := w.cipher.Overhead() + PacketLengthBufferSize
readN, readErr := r.Read(w.data[offset : offset+MaxPacketSize])
if readErr != nil {
return 0, readErr
}
return
}
binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(len(p)))
encryptedLength := w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil)
increaseNonce(w.nonce)
start := len(encryptedLength)
/*
no usage
if cap(p) > len(p)+PacketLengthBufferSize+2*w.cipher.Overhead() {
packet := w.cipher.Seal(p[:start], w.nonce, p, nil)
increaseNonce(w.nonce)
copy(p[:start], encryptedLength)
n = start + len(packet)
binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(readN))
w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil)
increaseNonce(w.nonce)
packet := w.cipher.Seal(w.data[offset:offset], w.nonce, w.data[offset:offset+readN], nil)
increaseNonce(w.nonce)
_, err = w.upstream.Write(w.data[:offset+len(packet)])
if err != nil {
return
}
*/
packet := w.cipher.Seal(w.data[:start], w.nonce, p, nil)
increaseNonce(w.nonce)
return 0, buf.As(packet), false, err
err = common.FlushVar(&w.upstream)
if err != nil {
return
}
n += int64(readN)
}
}
func (w *AEADWriter) Write(p []byte) (n int, err error) {
for _, data := range buf.ForeachN(p, w.maxDataSize) {
for _, data := range buf.ForeachN(p, MaxPacketSize) {
binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(len(data)))
w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil)
increaseNonce(w.nonce)
start := w.cipher.Overhead() + PacketLengthBufferSize
packet := w.cipher.Seal(w.data[:start], w.nonce, data, nil)
offset := w.cipher.Overhead() + PacketLengthBufferSize
packet := w.cipher.Seal(w.data[offset:offset], w.nonce, data, nil)
increaseNonce(w.nonce)
_, err = w.upstream.Write(packet)
_, err = w.upstream.Write(w.data[:offset+len(packet)])
if err != nil {
return
}
@ -289,14 +319,6 @@ func (w *AEADWriter) Write(p []byte) (n int, err error) {
return
}
func (w *AEADWriter) Close() error {
if w.data != nil {
buf.PutBytes(w.data)
w.data = nil
}
return nil
}
func increaseNonce(nonce []byte) {
for i := range nonce {
nonce[i]++

View file

@ -26,8 +26,8 @@ func (c *NoneCipher) CreateReader(_ []byte, _ []byte, reader io.Reader) io.Reade
return reader
}
func (c *NoneCipher) CreateWriter(_ []byte, _ []byte, writer io.Writer) (io.Writer, int) {
return writer, 0
func (c *NoneCipher) CreateWriter(key []byte, iv []byte, writer io.Writer) io.Writer {
return writer
}
func (c *NoneCipher) EncodePacket([]byte, *buf.Buffer) error {

View file

@ -0,0 +1,168 @@
package shadowsocks
import (
"context"
"github.com/sagernet/sing/protocol/socks"
"io"
"net"
"strconv"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/socksaddr"
)
var (
ErrBadKey = exceptions.New("bad key")
ErrMissingPassword = exceptions.New("password not specified")
)
type ClientConfig struct {
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
Method string `json:"method"`
Password []byte `json:"password"`
Key []byte `json:"key"`
}
type Client struct {
dialer *net.Dialer
cipher Cipher
server string
key []byte
}
func NewClient(dialer *net.Dialer, config *ClientConfig) (*Client, error) {
cipher, err := CreateCipher(config.Method)
if err != nil {
return nil, err
}
client := &Client{
dialer: dialer,
cipher: cipher,
server: net.JoinHostPort(config.Server, strconv.Itoa(int(config.ServerPort))),
}
if keyLen := len(config.Key); keyLen > 0 {
if keyLen == cipher.KeySize() {
client.key = config.Key
} else {
return nil, ErrBadKey
}
} else if len(config.Password) > 0 {
client.key = Key(config.Password, cipher.KeySize())
} else {
return nil, ErrMissingPassword
}
return client, nil
}
func (c *Client) DialContextTCP(ctx context.Context, addr socksaddr.Addr, port uint16) (net.Conn, error) {
conn, err := c.dialer.DialContext(ctx, "tcp", c.server)
if err != nil {
return nil, exceptions.Cause(err, "connect to server")
}
return c.DialConn(conn, addr, port), nil
}
func (c *Client) DialConn(conn net.Conn, addr socksaddr.Addr, port uint16) net.Conn {
header := buf.New()
header.WriteRandom(c.cipher.SaltSize())
writer := &buf.BufferedWriter{
Writer: conn,
Buffer: header,
}
protocolWriter := c.cipher.CreateWriter(c.key, header.Bytes(), writer)
requestBuffer := buf.New()
contentWriter := &buf.BufferedWriter{
Writer: protocolWriter,
Buffer: requestBuffer,
}
common.Must(AddressSerializer.WriteAddressAndPort(contentWriter, addr, port))
return &shadowsocksConn{
Client: c,
Conn: conn,
Writer: &common.FlushOnceWriter{Writer: contentWriter},
}
}
type shadowsocksConn struct {
*Client
net.Conn
io.Writer
reader io.Reader
}
func (c *shadowsocksConn) Read(p []byte) (n int, err error) {
if c.reader == nil {
buffer := buf.Or(p, c.cipher.SaltSize())
defer buffer.Release()
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
if err != nil {
return
}
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
}
return c.reader.Read(p)
}
func (c *shadowsocksConn) WriteTo(w io.Writer) (n int64, err error) {
if c.reader == nil {
buffer := buf.NewSize(c.cipher.SaltSize())
defer buffer.Release()
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
if err != nil {
return
}
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
}
return c.reader.(io.WriterTo).WriteTo(w)
}
func (c *shadowsocksConn) Write(p []byte) (n int, err error) {
return c.Writer.Write(p)
}
func (c *shadowsocksConn) ReadFrom(r io.Reader) (n int64, err error) {
return rw.ReadFromVar(&c.Writer, r)
}
func (c *Client) DialContextUDP(ctx context.Context) socks.PacketConn {
conn, err := c.dialer.DialContext(ctx, "udp", c.server)
if err != nil {
return nil
}
return &shadowsocksPacketConn{c, conn}
}
type shadowsocksPacketConn struct {
*Client
net.Conn
}
func (c *shadowsocksPacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
defer buffer.Release()
header := buf.New()
header.WriteRandom(c.cipher.SaltSize())
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
buffer = buffer.WriteBufferAtFirst(header)
err := c.cipher.EncodePacket(c.key, buffer)
if err != nil {
return err
}
return common.Error(c.Conn.Write(buffer.Bytes()))
}
func (c *shadowsocksPacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, 0, err
}
buffer.Truncate(n)
err = c.cipher.DecodePacket(c.key, buffer)
if err != nil {
return nil, 0, err
}
return AddressSerializer.ReadAddressAndPort(buffer)
}

View file

@ -0,0 +1 @@
package shadowsocks

78
protocol/socks/conn.go Normal file
View file

@ -0,0 +1,78 @@
package socks
import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/socksaddr"
"net"
"time"
)
type PacketConn interface {
ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error)
WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
SetDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
}
func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(size int)) error {
for {
buffer := buf.New()
addr, port, err := conn.ReadPacket(buffer)
if err != nil {
buffer.Release()
return err
}
size := buffer.Len()
err = dest.WritePacket(buffer, addr, port)
if err != nil {
return err
}
if onAction != nil {
onAction(size)
}
buffer.Reset()
}
}
type associatePacketConn struct {
net.PacketConn
conn net.Conn
addr net.Addr
}
func NewPacketConn(conn net.Conn, packetConn net.PacketConn) PacketConn {
return &associatePacketConn{
PacketConn: packetConn,
conn: conn,
}
}
func (c *associatePacketConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
if err != nil {
return nil, 0, err
}
c.addr = addr
buffer.Truncate(n)
buffer.Advance(3)
return AddressSerializer.ReadAddressAndPort(buffer)
}
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
defer buffer.Release()
header := buf.New()
common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
buffer = buffer.WriteBufferAtFirst(header)
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
}

View file

@ -49,7 +49,7 @@ func ReadAuthRequest(reader io.Reader) (*AuthRequest, error) {
}
methods, err := rw.ReadBytes(reader, int(methodLen))
if err != nil {
return nil, exceptions.Cause(err, "read socks auth methods, length ", methodLen)
return nil, exceptions.CauseF(err, "read socks auth methods, length ", methodLen)
}
request := &AuthRequest{
version,

View file

@ -1,30 +0,0 @@
package system
import "syscall"
var ControlFunc func(fd uintptr) error
func Control(conn syscall.Conn) error {
if ControlFunc == nil {
return nil
}
rawConn, err := conn.SyscallConn()
if err != nil {
return err
}
return ControlRaw(rawConn)
}
func ControlRaw(conn syscall.RawConn) error {
if ControlFunc == nil {
return nil
}
var rawFd uintptr
err := conn.Control(func(fd uintptr) {
rawFd = fd
})
if err != nil {
return err
}
return ControlFunc(rawFd)
}

View file

@ -1,10 +0,0 @@
package system
import (
"context"
"net"
)
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
var Dial DialFunc = new(net.Dialer).DialContext

10
transport/system/http.go Normal file
View file

@ -0,0 +1,10 @@
package system
import (
"bufio"
"net/http"
_ "unsafe"
)
//go:linkname readRequest net/http.readRequest
func readRequest(b *bufio.Reader) (req *http.Request, err error)

223
transport/system/mixed.go Normal file
View file

@ -0,0 +1,223 @@
package system
import (
"context"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/netip"
"strconv"
"strings"
"time"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/socksaddr"
"github.com/sagernet/sing/protocol/socks"
)
type MixedListener struct {
*SocksListener
}
func NewMixedListener(bind netip.AddrPort, config *SocksConfig, handler SocksHandler) *MixedListener {
listener := &MixedListener{NewSocksListener(bind, config, handler)}
listener.TCPListener.Handler = listener
return listener
}
func (l *MixedListener) HandleTCP(conn net.Conn) error {
bufConn := buf.NewBufferedConn(conn)
hdr, err := bufConn.ReadByte()
if err != nil {
return err
}
err = bufConn.UnreadByte()
if err != nil {
return err
}
if hdr == socks.Version4 || hdr == socks.Version5 {
return l.SocksListener.HandleTCP(bufConn)
}
var httpClient *http.Client
for {
request, err := readRequest(bufConn.Reader())
if err != nil {
return exceptions.Cause(err, "read http request")
}
if l.Username != "" {
var authOk bool
authorization := request.Header.Get("Proxy-Authorization")
if strings.HasPrefix(authorization, "BASIC ") {
userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
if string(userPassword) == l.Username+":"+l.Password {
authOk = true
}
}
if !authOk {
err = responseWith(request, http.StatusProxyAuthRequired).Write(conn)
if err != nil {
return err
}
}
}
if request.Method == "CONNECT" {
host := request.URL.Hostname()
portStr := request.URL.Port()
if portStr == "" {
portStr = "80"
}
port, err := strconv.Atoi(portStr)
if err != nil {
err = responseWith(request, http.StatusBadRequest).Write(conn)
if err != nil {
return err
}
}
_, err = fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established")
if err != nil {
return exceptions.Cause(err, "write http response")
}
return l.Handler.NewConnection(socksaddr.ParseAddr(host), uint16(port), bufConn)
}
keepAlive := strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
host := request.Header.Get("Host")
if host != "" {
request.Host = host
}
request.RequestURI = ""
removeHopByHopHeaders(request.Header)
removeExtraHTTPHostPort(request)
if request.URL.Scheme == "" || request.URL.Host == "" {
return responseWith(request, http.StatusBadRequest).Write(conn)
}
if httpClient == nil {
httpClient = &http.Client{
Transport: &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: func(context context.Context, network, address string) (net.Conn, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, exceptions.New("unsupported network ", network)
}
addr, port, err := socksaddr.ParseAddrPort(address)
if err != nil {
return nil, err
}
left, right := net.Pipe()
go func() {
err = l.Handler.NewConnection(addr, port, right)
if err != nil {
l.OnError(err)
}
}()
return left, nil
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}
response, err := httpClient.Do(request)
if err != nil {
l.OnError(exceptions.Cause(err, "http proxy"))
return responseWith(request, http.StatusBadGateway).Write(conn)
}
removeHopByHopHeaders(response.Header)
if keepAlive {
response.Header.Set("Proxy-Connection", "keep-alive")
response.Header.Set("Connection", "keep-alive")
response.Header.Set("Keep-Alive", "timeout=4")
}
response.Close = !keepAlive
err = response.Write(conn)
if err != nil {
l.OnError(exceptions.Cause(err, "http proxy"))
return err
}
}
}
// removeHopByHopHeaders remove hop-by-hop header
func removeHopByHopHeaders(header http.Header) {
// Strip hop-by-hop header based on RFC:
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
// https://www.mnot.net/blog/2011/07/11/what_proxies_must_do
header.Del("Proxy-Connection")
header.Del("Proxy-Authenticate")
header.Del("Proxy-Authorization")
header.Del("TE")
header.Del("Trailers")
header.Del("Transfer-Encoding")
header.Del("Upgrade")
connections := header.Get("Connection")
header.Del("Connection")
if len(connections) == 0 {
return
}
for _, h := range strings.Split(connections, ",") {
header.Del(strings.TrimSpace(h))
}
}
// removeExtraHTTPHostPort remove extra host port (example.com:80 --> example.com)
// It resolves the behavior of some HTTP servers that do not handle host:80 (e.g. baidu.com)
func removeExtraHTTPHostPort(req *http.Request) {
host := req.Host
if host == "" {
host = req.URL.Host
}
if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" {
host = pHost
}
req.Host = host
req.URL.Host = host
}
func responseWith(request *http.Request, statusCode int) *http.Response {
return &http.Response{
StatusCode: statusCode,
Status: http.StatusText(statusCode),
Proto: request.Proto,
ProtoMajor: request.ProtoMajor,
ProtoMinor: request.ProtoMinor,
Header: http.Header{},
}
}
func (l *MixedListener) Start() error {
return l.TCPListener.Start()
}
func (l *MixedListener) Close() error {
return l.TCPListener.Close()
}
func (l *MixedListener) OnError(err error) {
l.Handler.OnError(exceptions.Cause(err, "mixed server"))
}

View file

@ -0,0 +1,14 @@
package system
import (
"syscall"
)
const (
TCP_FASTOPEN = 23
TCP_FASTOPEN_CONNECT = 30
)
func TCPFastOpen(fd uintptr) error {
return syscall.SetsockoptInt(int(fd), syscall.SOL_TCP, TCP_FASTOPEN_CONNECT, 1)
}

View file

@ -0,0 +1,9 @@
//go:build !linux
package main
import "github.com/sagernet/sing/common/exceptions"
func TCPFastOpen(fd uintptr) error {
return exceptions.New("only available on linux")
}

148
transport/system/socks.go Normal file
View file

@ -0,0 +1,148 @@
package system
import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/socksaddr"
"github.com/sagernet/sing/protocol/socks"
"io"
"net"
"net/netip"
)
type SocksHandler interface {
NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error
NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error
OnError(err error)
}
type SocksConfig struct {
Username string
Password string
}
type SocksListener struct {
Handler SocksHandler
*TCPListener
*SocksConfig
}
func NewSocksListener(bind netip.AddrPort, config *SocksConfig, handler SocksHandler) *SocksListener {
listener := &SocksListener{
SocksConfig: config,
Handler: handler,
}
listener.TCPListener = NewTCPListener(bind, listener)
return listener
}
func (l *SocksListener) HandleTCP(conn net.Conn) error {
authRequest, err := socks.ReadAuthRequest(conn)
if err != nil {
return exceptions.Cause(err, "read socks auth request")
}
var authMethod byte
if l.Username == "" {
authMethod = socks.AuthTypeNotRequired
} else {
authMethod = socks.AuthTypeUsernamePassword
}
if !common.Contains(authRequest.Methods, authMethod) {
err = socks.WriteAuthResponse(conn, &socks.AuthResponse{
Version: authRequest.Version,
Method: socks.AuthTypeNoAcceptedMethods,
})
if err != nil {
return exceptions.Cause(err, "write socks auth response")
}
}
err = socks.WriteAuthResponse(conn, &socks.AuthResponse{
Version: authRequest.Version,
Method: socks.AuthTypeNotRequired,
})
if err != nil {
return exceptions.Cause(err, "write socks auth response")
}
if authMethod == socks.AuthTypeUsernamePassword {
usernamePasswordAuthRequest, err := socks.ReadUsernamePasswordAuthRequest(conn)
if err != nil {
return exceptions.Cause(err, "read user auth request")
}
response := socks.UsernamePasswordAuthResponse{}
if usernamePasswordAuthRequest.Username != l.Username {
response.Status = socks.UsernamePasswordStatusFailure
} else if usernamePasswordAuthRequest.Password != l.Password {
response.Status = socks.UsernamePasswordStatusFailure
} else {
response.Status = socks.UsernamePasswordStatusSuccess
}
err = socks.WriteUsernamePasswordAuthResponse(conn, &response)
if err != nil {
return exceptions.Cause(err, "write user auth response")
}
}
request, err := socks.ReadRequest(conn)
if err != nil {
return exceptions.Cause(err, "read socks request")
}
switch request.Command {
case socks.CommandConnect:
localAddr, localPort := socksaddr.AddressFromNetAddr(l.TCPListener.TCPListener.Addr())
err = socks.WriteResponse(conn, &socks.Response{
Version: request.Version,
ReplyCode: socks.ReplyCodeSuccess,
BindAddr: localAddr,
BindPort: localPort,
})
if err != nil {
return exceptions.Cause(err, "write socks response")
}
return l.Handler.NewConnection(request.Addr, request.Port, conn)
case socks.CommandUDPAssociate:
udpConn, err := net.ListenUDP("udp4", nil)
if err != nil {
return err
}
defer udpConn.Close()
localAddr, localPort := socksaddr.AddressFromNetAddr(udpConn.LocalAddr())
err = socks.WriteResponse(conn, &socks.Response{
Version: request.Version,
ReplyCode: socks.ReplyCodeSuccess,
BindAddr: localAddr,
BindPort: localPort,
})
if err != nil {
return exceptions.Cause(err, "write socks response")
}
go func() {
err := l.Handler.NewPacketConnection(socks.NewPacketConn(conn, udpConn), request.Addr, request.Port)
if err != nil {
l.OnError(err)
}
}()
return common.Error(io.Copy(io.Discard, conn))
default:
err = socks.WriteResponse(conn, &socks.Response{
Version: request.Version,
ReplyCode: socks.ReplyCodeUnsupported,
})
if err != nil {
return exceptions.Cause(err, "write response")
}
}
return nil
}
func (l *SocksListener) Start() error {
return l.TCPListener.Start()
}
func (l *SocksListener) Close() error {
return l.TCPListener.Close()
}
func (l *SocksListener) OnError(err error) {
l.Handler.OnError(exceptions.Cause(err, "socks server"))
}

0
transport/system/stub.s Normal file
View file

View file

@ -8,7 +8,7 @@ import (
)
type UDPHandler interface {
HandleUDP(listener *UDPListener, buffer *buf.Buffer, sourceAddr net.Addr) error
HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error
OnError(err error)
}
@ -52,7 +52,7 @@ func (l *UDPListener) loop() {
}
buffer.Truncate(n)
go func() {
err := l.Handler.HandleUDP(l, buffer, addr)
err := l.Handler.HandleUDP(buffer, addr)
if err != nil {
buffer.Release()
l.Handler.OnError(err)