mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 11:57:39 +03:00
Add http support for sslocal
This commit is contained in:
parent
542fa4f975
commit
26e13e7beb
30 changed files with 1355 additions and 194 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -1,2 +1,3 @@
|
|||
/.idea/
|
||||
/sing_*
|
||||
/sing_*
|
||||
/*.json
|
|
@ -1,16 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/cli/sslocal"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := sslocal.MainCmd().Execute()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
243
cli/sslocal/sslocal.go
Normal file
243
cli/sslocal/sslocal.go
Normal file
|
@ -0,0 +1,243 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sagernet/sing/transport/system"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := MainCmd().Execute()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type Flags struct {
|
||||
Server string `json:"server"`
|
||||
ServerPort uint16 `json:"server_port"`
|
||||
LocalPort uint16 `json:"local_port"`
|
||||
Password string `json:"password"`
|
||||
Key string `json:"key"`
|
||||
Method string `json:"method"`
|
||||
TCPFastOpen bool `json:"fast_open"`
|
||||
Verbose bool `json:"verbose"`
|
||||
ConfigFile string
|
||||
}
|
||||
|
||||
func MainCmd() *cobra.Command {
|
||||
flags := new(Flags)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "sslocal",
|
||||
Short: "shadowsocks client as socks5 proxy, sing port",
|
||||
Version: sing.Version,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run(flags)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVarP(&flags.Server, "server", "s", "", "Set the server’s hostname or IP.")
|
||||
cmd.Flags().Uint16VarP(&flags.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
|
||||
cmd.Flags().Uint16VarP(&flags.LocalPort, "local-port", "l", 0, "Set the local port number.")
|
||||
cmd.Flags().StringVarP(&flags.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
|
||||
cmd.Flags().StringVar(&flags.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
|
||||
cmd.Flags().StringVarP(&flags.Method, "encrypt-method", "m", "", `Set the cipher.
|
||||
|
||||
Supported ciphers:
|
||||
|
||||
none
|
||||
aes-128-gcm
|
||||
aes-192-gcm
|
||||
aes-256-gcm
|
||||
chacha20-ietf-poly1305
|
||||
xchacha20-ietf-poly1305
|
||||
|
||||
The default cipher is chacha20-ietf-poly1305.`)
|
||||
cmd.Flags().BoolVar(&flags.TCPFastOpen, "fast-open", false, `Enable TCP fast open.
|
||||
Only available with Linux kernel > 3.7.0.`)
|
||||
cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.")
|
||||
cmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode.")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
type LocalClient struct {
|
||||
*system.MixedListener
|
||||
*shadowsocks.Client
|
||||
}
|
||||
|
||||
func NewLocalClient(flags *Flags) (*LocalClient, error) {
|
||||
if flags.ConfigFile != "" {
|
||||
configFile, err := ioutil.ReadFile(flags.ConfigFile)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "read config file")
|
||||
}
|
||||
flagsNew := new(Flags)
|
||||
err = json.Unmarshal(configFile, flagsNew)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "decode config file")
|
||||
}
|
||||
if flagsNew.Server != "" && flags.Server == "" {
|
||||
flags.Server = flagsNew.Server
|
||||
}
|
||||
if flagsNew.ServerPort != 0 && flags.ServerPort == 0 {
|
||||
flags.ServerPort = flagsNew.ServerPort
|
||||
}
|
||||
if flagsNew.LocalPort != 0 && flags.LocalPort == 0 {
|
||||
flags.LocalPort = flagsNew.LocalPort
|
||||
}
|
||||
if flagsNew.Password != "" && flags.Password == "" {
|
||||
flags.Password = flagsNew.Password
|
||||
}
|
||||
if flagsNew.Key != "" && flags.Key == "" {
|
||||
flags.Key = flagsNew.Key
|
||||
}
|
||||
if flagsNew.Method != "" && flags.Method == "" {
|
||||
flags.Method = flagsNew.Method
|
||||
}
|
||||
if flagsNew.TCPFastOpen {
|
||||
flags.TCPFastOpen = true
|
||||
}
|
||||
if flagsNew.Verbose {
|
||||
flags.Verbose = true
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
clientConfig := &shadowsocks.ClientConfig{
|
||||
Server: flags.Server,
|
||||
ServerPort: flags.ServerPort,
|
||||
Method: flags.Method,
|
||||
}
|
||||
|
||||
if flags.Key != "" {
|
||||
key, err := base64.URLEncoding.DecodeString(flags.Key)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "decode key")
|
||||
}
|
||||
clientConfig.Key = key
|
||||
} else if flags.Password != "" {
|
||||
clientConfig.Password = []byte(flags.Password)
|
||||
}
|
||||
|
||||
if flags.Verbose {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
|
||||
dialer := new(net.Dialer)
|
||||
|
||||
if flags.TCPFastOpen {
|
||||
dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
var rawFd uintptr
|
||||
err := c.Control(func(fd uintptr) {
|
||||
rawFd = fd
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return system.TCPFastOpen(rawFd)
|
||||
}
|
||||
}
|
||||
|
||||
shadowClient, err := shadowsocks.NewClient(dialer, clientConfig)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "create shadowsocks")
|
||||
}
|
||||
|
||||
client := &LocalClient{
|
||||
Client: shadowClient,
|
||||
}
|
||||
client.MixedListener = system.NewMixedListener(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), flags.LocalPort), &system.SocksConfig{}, client)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *LocalClient) Start() error {
|
||||
err := c.MixedListener.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logrus.Info("mixed server started at ", c.MixedListener.TCPListener.Addr())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LocalClient) NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error {
|
||||
logrus.Info("TCP ", conn.RemoteAddr(), " ==> ", net.JoinHostPort(addr.String(), strconv.Itoa(int(port))))
|
||||
|
||||
ctx := context.Background()
|
||||
serverConn, err := c.DialContextTCP(ctx, addr, port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return task.Run(ctx, func() error {
|
||||
defer rw.CloseRead(conn)
|
||||
defer rw.CloseWrite(serverConn)
|
||||
return common.Error(io.Copy(serverConn, conn))
|
||||
}, func() error {
|
||||
defer rw.CloseRead(serverConn)
|
||||
defer rw.CloseWrite(conn)
|
||||
return common.Error(io.Copy(conn, serverConn))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error {
|
||||
ctx := context.Background()
|
||||
serverConn := c.DialContextUDP(ctx)
|
||||
return task.Run(ctx, func() error {
|
||||
var init bool
|
||||
return socks.CopyPacketConn(serverConn, conn, func(size int) {
|
||||
if !init {
|
||||
init = true
|
||||
logrus.Info("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port))
|
||||
} else {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", socksaddr.JoinHostPort(addr, port))
|
||||
}
|
||||
})
|
||||
}, func() error {
|
||||
return socks.CopyPacketConn(conn, serverConn, func(size int) {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", socksaddr.JoinHostPort(addr, port))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func Run(flags *Flags) {
|
||||
client, err := NewLocalClient(flags)
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
err = client.Start()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
osSignals := make(chan os.Signal, 1)
|
||||
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
||||
<-osSignals
|
||||
client.Close()
|
||||
}
|
||||
|
||||
func (c *LocalClient) OnError(err error) {
|
||||
if exceptions.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
logrus.Warn(err)
|
||||
}
|
|
@ -1,15 +1,18 @@
|
|||
package sslocal
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
|
@ -27,13 +30,24 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := MainCmd().Execute()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type Flags struct {
|
||||
Server string
|
||||
ServerPort uint16
|
||||
LocalPort uint16
|
||||
Password string
|
||||
Key string
|
||||
Method string
|
||||
Server string `json:"server"`
|
||||
ServerPort uint16 `json:"server_port"`
|
||||
LocalPort uint16 `json:"local_port"`
|
||||
Password string `json:"password"`
|
||||
Key string `json:"key"`
|
||||
Method string `json:"method"`
|
||||
Timeout uint16 `json:"timeout"`
|
||||
TCPFastOpen bool `json:"fast_open"`
|
||||
Verbose bool `json:"verbose"`
|
||||
ConfigFile string
|
||||
}
|
||||
|
||||
func MainCmd() *cobra.Command {
|
||||
|
@ -65,13 +79,10 @@ chacha20-ietf-poly1305
|
|||
xchacha20-ietf-poly1305
|
||||
|
||||
The default cipher is chacha20-ietf-poly1305.`)
|
||||
// cmd.Flags().Uint16VarP(&flags.Timeout, "timeout", "t", 60, "Set the socket timeout in seconds.")
|
||||
// cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.")
|
||||
// cmd.Flags().Uint16VarP(&flags.MaxFD, "max-open-files", "n", 0, `Specify max number of open files.
|
||||
// Only available on Linux.`)
|
||||
// cmd.Flags().StringVarP(&flags.Interface, "interface", "i", "", `Send traffic through specific network interface.
|
||||
// For example, there are three interfaces in your device, which is lo (127.0.0.1), eth0 (192.168.0.1) and eth1 (192.168.0.2). Meanwhile, you configure ss-local to listen on 0.0.0.0:8388 and bind to eth1. That results the traffic go out through eth1, but not lo nor eth0. This option is useful to control traffic in multi-interface environment.`)
|
||||
// cmd.Flags().StringVarP(&flags.LocalAddress, "local-address", "b", "", "Specify the local address to use while this client is making outbound connections to the server.")
|
||||
cmd.Flags().BoolVar(&flags.TCPFastOpen, "fast-open", false, `Enable TCP fast open.
|
||||
Only available with Linux kernel > 3.7.0.`)
|
||||
cmd.Flags().StringVarP(&flags.ConfigFile, "config", "c", "", "Use a configuration file.")
|
||||
cmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode.")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
@ -81,10 +92,27 @@ type LocalClient struct {
|
|||
serverAddr netip.AddrPort
|
||||
cipher shadowsocks.Cipher
|
||||
key []byte
|
||||
dialer net.Dialer
|
||||
}
|
||||
|
||||
func NewLocalClient(flags *Flags) (*LocalClient, error) {
|
||||
client := new(LocalClient)
|
||||
if flags.ConfigFile != "" {
|
||||
configFile, err := os.Open(flags.ConfigFile)
|
||||
if err != nil {
|
||||
return nil, exceptions.CauseF(err, "unable to open config file ", flags.ConfigFile)
|
||||
}
|
||||
config, err := ioutil.ReadAll(configFile)
|
||||
configFile.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = json.Unmarshal(config, &flags)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "failed to decode config file")
|
||||
}
|
||||
}
|
||||
|
||||
client := &LocalClient{}
|
||||
client.tcpIn = system.NewTCPListener(netip.AddrPortFrom(netip.IPv4Unspecified(), flags.LocalPort), client)
|
||||
|
||||
if flags.Server == "" {
|
||||
|
@ -120,6 +148,32 @@ func NewLocalClient(flags *Flags) (*LocalClient, error) {
|
|||
return nil, exceptions.New("password not specified")
|
||||
}
|
||||
|
||||
if flags.Timeout > 0 {
|
||||
client.dialer.Timeout = time.Duration(flags.Timeout) * time.Second
|
||||
}
|
||||
|
||||
if flags.TCPFastOpen {
|
||||
client.dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
if strings.HasPrefix(network, "tcp") {
|
||||
var rawFd uintptr
|
||||
if err = c.Control(func(fd uintptr) {
|
||||
rawFd = fd
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
err = system.TCPFastOpen(rawFd)
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "set tcp fast open")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if flags.Verbose {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
|
@ -142,7 +196,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
|
||||
authRequest, err := socks.ReadAuthRequest(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
return exceptions.Cause(err, "read socks auth request")
|
||||
}
|
||||
|
||||
if !common.Contains(authRequest.Methods, socks.AuthTypeNotRequired) {
|
||||
|
@ -151,7 +205,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
Method: socks.AuthTypeNoAcceptedMethods,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return exceptions.Cause(err, "write socks auth response")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -160,12 +214,12 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
Method: socks.AuthTypeNotRequired,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return exceptions.Cause(err, "write socks auth response")
|
||||
}
|
||||
|
||||
request, err := socks.ReadRequest(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
return exceptions.Cause(err, "read socks request")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
@ -181,7 +235,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
case socks.CommandConnect:
|
||||
logrus.Info("CONNECT ", request.Addr, ":", request.Port)
|
||||
|
||||
serverConn, dialErr := system.Dial(ctx, "tcp", c.serverAddr.String())
|
||||
serverConn, dialErr := c.dialer.DialContext(ctx, "tcp", c.serverAddr.String())
|
||||
if dialErr != nil {
|
||||
failure()
|
||||
return exceptions.Cause(dialErr, "connect to server")
|
||||
|
@ -196,12 +250,12 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
Writer: serverConn,
|
||||
Buffer: saltBuffer,
|
||||
}
|
||||
writer, _ := c.cipher.CreateWriter(c.key, saltBuffer.Bytes(), serverWriter)
|
||||
writer := c.cipher.CreateWriter(c.key, saltBuffer.Bytes(), serverWriter)
|
||||
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
requestBuffer := buf.New()
|
||||
defer requestBuffer.Release()
|
||||
|
||||
err = shadowsocks.AddressSerializer.WriteAddressAndPort(header, request.Addr, request.Port)
|
||||
err = shadowsocks.AddressSerializer.WriteAddressAndPort(requestBuffer, request.Addr, request.Port)
|
||||
if err != nil {
|
||||
failure()
|
||||
return err
|
||||
|
@ -215,7 +269,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
BindPort: serverPort,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write response for ", request.Addr, "/", request.Port)
|
||||
return exceptions.Cause(err, "write socks response")
|
||||
}
|
||||
|
||||
return task.Run(ctx, func() error {
|
||||
|
@ -226,7 +280,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = header.ReadFrom(conn)
|
||||
_, err = requestBuffer.ReadFrom(conn)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
} else {
|
||||
|
@ -237,7 +291,7 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write(header.Bytes())
|
||||
_, err = writer.Write(requestBuffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -245,7 +299,8 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
if err != nil {
|
||||
return exceptions.Cause(err, "flush request")
|
||||
}
|
||||
_, err = io.Copy(writer, conn)
|
||||
requestBuffer.FullReset()
|
||||
_, err = io.CopyBuffer(writer, conn, requestBuffer.FreeBytes())
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "upload")
|
||||
}
|
||||
|
@ -275,10 +330,10 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
return nil
|
||||
})
|
||||
case socks.CommandUDPAssociate:
|
||||
serverConn, dialErr := system.Dial(ctx, "udp", c.serverAddr.String())
|
||||
serverConn, dialErr := c.dialer.DialContext(ctx, "udp", c.serverAddr.String())
|
||||
if dialErr != nil {
|
||||
failure()
|
||||
return exceptions.Cause(err, "connect to server")
|
||||
return exceptions.Cause(dialErr, "connect to server")
|
||||
}
|
||||
handler := &udpHandler{
|
||||
LocalClient: c,
|
||||
|
@ -300,9 +355,12 @@ func (c *LocalClient) HandleTCP(conn net.Conn) error {
|
|||
}
|
||||
go handler.loopInput()
|
||||
return common.Error(io.Copy(io.Discard, conn))
|
||||
default:
|
||||
return socks.WriteResponse(conn, &socks.Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: socks.ReplyCodeUnsupported,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type udpHandler struct {
|
||||
|
@ -313,7 +371,7 @@ type udpHandler struct {
|
|||
sourceAddr net.Addr
|
||||
}
|
||||
|
||||
func (c *udpHandler) HandleUDP(listener *system.UDPListener, buffer *buf.Buffer, sourceAddr net.Addr) error {
|
||||
func (c *udpHandler) HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error {
|
||||
c.sourceAddr = sourceAddr
|
||||
buffer.Advance(3)
|
||||
if c.cipher.SaltSize() > 0 {
|
||||
|
@ -369,5 +427,8 @@ func (c *udpHandler) Close() error {
|
|||
}
|
||||
|
||||
func (c *LocalClient) OnError(err error) {
|
||||
if exceptions.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
logrus.Warn(err)
|
||||
}
|
44
common/buf/bufconn.go
Normal file
44
common/buf/bufconn.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package buf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type BufferedConn struct {
|
||||
r *bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func NewBufferedConn(c net.Conn) *BufferedConn {
|
||||
return &BufferedConn{bufio.NewReader(c), c}
|
||||
}
|
||||
|
||||
func (c *BufferedConn) Reader() *bufio.Reader {
|
||||
return c.r
|
||||
}
|
||||
|
||||
func (c *BufferedConn) Peek(n int) ([]byte, error) {
|
||||
return c.r.Peek(n)
|
||||
}
|
||||
|
||||
func (c *BufferedConn) Read(p []byte) (int, error) {
|
||||
return c.r.Read(p)
|
||||
}
|
||||
|
||||
func (c *BufferedConn) ReadByte() (byte, error) {
|
||||
return c.r.ReadByte()
|
||||
}
|
||||
|
||||
func (c *BufferedConn) UnreadByte() error {
|
||||
return c.r.UnreadByte()
|
||||
}
|
||||
|
||||
func (c *BufferedConn) Buffered() int {
|
||||
return c.r.Buffered()
|
||||
}
|
||||
|
||||
func (c *BufferedConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.r.WriteTo(w)
|
||||
}
|
|
@ -25,6 +25,20 @@ func New() *Buffer {
|
|||
}
|
||||
}
|
||||
|
||||
func NewSize(size int) *Buffer {
|
||||
if size <= 128 || size > BufferSize {
|
||||
return &Buffer{
|
||||
data: make([]byte, size),
|
||||
}
|
||||
}
|
||||
return &Buffer{
|
||||
data: GetBytes(),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
func FullNew() *Buffer {
|
||||
return &Buffer{
|
||||
data: GetBytes(),
|
||||
|
@ -57,6 +71,20 @@ func As(data []byte) *Buffer {
|
|||
}
|
||||
}
|
||||
|
||||
func Or(data []byte, size int) *Buffer {
|
||||
max := cap(data)
|
||||
if size != max {
|
||||
data = data[:max]
|
||||
}
|
||||
if cap(data) >= size {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
}
|
||||
} else {
|
||||
return NewSize(size)
|
||||
}
|
||||
}
|
||||
|
||||
func With(data []byte) *Buffer {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
|
@ -346,7 +374,7 @@ func (b Buffer) Copy() []byte {
|
|||
func ReleaseMulti(mb *list.List[*Buffer]) {
|
||||
for entry := mb.Front(); entry != nil; entry = entry.Next() {
|
||||
// TODO: remove cast
|
||||
var buffer *Buffer = entry.Value
|
||||
buffer := entry.Value
|
||||
buffer.Release()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ package buf
|
|||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type BufferedReader struct {
|
||||
|
@ -16,13 +18,18 @@ func (r *BufferedReader) Upstream() io.Reader {
|
|||
return r.Reader
|
||||
}
|
||||
|
||||
func (r *BufferedReader) Replaceable() bool {
|
||||
return r.Buffer == nil
|
||||
}
|
||||
|
||||
func (r *BufferedReader) Read(p []byte) (n int, err error) {
|
||||
if r.Buffer != nil {
|
||||
n, err = r.Buffer.Read(p)
|
||||
if err == nil {
|
||||
return
|
||||
if r.Buffer.IsEmpty() {
|
||||
r.Buffer.Release()
|
||||
r.Buffer = nil
|
||||
}
|
||||
r.Buffer = nil
|
||||
return
|
||||
}
|
||||
return r.Reader.Read(p)
|
||||
}
|
||||
|
@ -33,12 +40,13 @@ type BufferedWriter struct {
|
|||
}
|
||||
|
||||
func (w *BufferedWriter) Upstream() io.Writer {
|
||||
if w.Buffer != nil {
|
||||
return nil
|
||||
}
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Replaceable() bool {
|
||||
return w.Buffer == nil
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||
if w.Buffer == nil {
|
||||
return w.Writer.Write(p)
|
||||
|
@ -47,13 +55,7 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
|||
if err == nil {
|
||||
return
|
||||
}
|
||||
n, err = w.Writer.Write(w.Buffer.Bytes())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.Buffer.Release()
|
||||
w.Buffer = nil
|
||||
return w.Writer.Write(p)
|
||||
return len(p), w.Flush()
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Flush() error {
|
||||
|
@ -66,6 +68,5 @@ func (w *BufferedWriter) Flush() error {
|
|||
if buffer.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
_, err := w.Writer.Write(buffer.Bytes())
|
||||
return err
|
||||
return common.Error(w.Writer.Write(buffer.Bytes()))
|
||||
}
|
|
@ -3,6 +3,9 @@ package exceptions
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type Exception interface {
|
||||
|
@ -10,11 +13,6 @@ type Exception interface {
|
|||
Cause() error
|
||||
}
|
||||
|
||||
type SuppressedException interface {
|
||||
error
|
||||
Suppressed() error
|
||||
}
|
||||
|
||||
type exception struct {
|
||||
message string
|
||||
cause error
|
||||
|
@ -31,10 +29,36 @@ func (e exception) Cause() error {
|
|||
return e.cause
|
||||
}
|
||||
|
||||
func (e exception) Unwrap() error {
|
||||
return e.cause
|
||||
}
|
||||
|
||||
func (e exception) Is(err error) bool {
|
||||
return e == err || errors.Is(e.cause, err)
|
||||
}
|
||||
|
||||
func New(message ...any) error {
|
||||
return errors.New(fmt.Sprint(message...))
|
||||
}
|
||||
|
||||
func Cause(cause error, message ...any) Exception {
|
||||
return &exception{fmt.Sprint(message...), cause}
|
||||
func Cause(cause error, message string) Exception {
|
||||
return exception{message, cause}
|
||||
}
|
||||
|
||||
func CauseF(cause error, message ...any) Exception {
|
||||
return exception{fmt.Sprint(message), cause}
|
||||
}
|
||||
|
||||
func IsClosed(err error) bool {
|
||||
return IsTimeout(err) || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE)
|
||||
}
|
||||
|
||||
func IsTimeout(err error) bool {
|
||||
if unwrapErr := errors.Unwrap(err); unwrapErr != nil {
|
||||
err = unwrapErr
|
||||
}
|
||||
if opErr, isOpErr := err.(*net.OpError); isOpErr {
|
||||
return opErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
package common
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type Flusher interface {
|
||||
Flush() error
|
||||
}
|
||||
|
||||
func Flush(writer io.Writer) error {
|
||||
writerBack := writer
|
||||
for {
|
||||
if f, ok := writer.(Flusher); ok {
|
||||
err := f.Flush()
|
||||
|
@ -15,6 +18,15 @@ func Flush(writer io.Writer) error {
|
|||
}
|
||||
}
|
||||
if u, ok := writer.(WriterWithUpstream); ok {
|
||||
if u.Replaceable() {
|
||||
if writerBack == writer {
|
||||
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter {
|
||||
setter.SetWriter(writerBack)
|
||||
writer = u.Upstream()
|
||||
continue
|
||||
}
|
||||
}
|
||||
writerBack = writer
|
||||
writer = u.Upstream()
|
||||
} else {
|
||||
break
|
||||
|
@ -22,3 +34,62 @@ func Flush(writer io.Writer) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func FlushVar(writerP *io.Writer) error {
|
||||
writer := *writerP
|
||||
writerBack := writer
|
||||
for {
|
||||
if f, ok := writer.(Flusher); ok {
|
||||
err := f.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if u, ok := writer.(WriterWithUpstream); ok {
|
||||
if u.Replaceable() {
|
||||
if writerBack == writer {
|
||||
writer = u.Upstream()
|
||||
writerBack = writer
|
||||
writerP = &writer
|
||||
continue
|
||||
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter {
|
||||
setter.SetWriter(writerBack)
|
||||
writer = u.Upstream()
|
||||
continue
|
||||
}
|
||||
}
|
||||
writerBack = writer
|
||||
writer = u.Upstream()
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type FlushOnceWriter struct {
|
||||
io.Writer
|
||||
flushed bool
|
||||
}
|
||||
|
||||
func (w *FlushOnceWriter) Upstream() io.Writer {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *FlushOnceWriter) Replaceable() bool {
|
||||
return w.flushed
|
||||
}
|
||||
|
||||
func (w *FlushOnceWriter) Write(p []byte) (n int, err error) {
|
||||
if w.flushed {
|
||||
return w.Writer.Write(p)
|
||||
}
|
||||
n, err = w.Writer.Write(p)
|
||||
if n > 0 {
|
||||
err = FlushVar(&w.Writer)
|
||||
}
|
||||
if err == nil {
|
||||
w.flushed = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -4,18 +4,40 @@ import (
|
|||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
func CopyConn(ctx context.Context, conn net.Conn, outConn net.Conn) error {
|
||||
return task.Run(ctx, func() error {
|
||||
return common.Error(io.Copy(conn, outConn))
|
||||
}, func() error {
|
||||
return common.Error(io.Copy(outConn, conn))
|
||||
})
|
||||
func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) {
|
||||
writer := *writerVar
|
||||
writerBack := writer
|
||||
for {
|
||||
if w, ok := writer.(io.ReaderFrom); ok {
|
||||
return w.ReadFrom(reader)
|
||||
}
|
||||
if f, ok := writer.(common.Flusher); ok {
|
||||
err := f.Flush()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if u, ok := writer.(common.WriterWithUpstream); ok {
|
||||
if u.Replaceable() && writerBack == writer {
|
||||
writer = u.Upstream()
|
||||
writerBack = writer
|
||||
writerVar = &writer
|
||||
continue
|
||||
}
|
||||
writer = u.Upstream()
|
||||
writerBack = writer
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {
|
||||
|
|
|
@ -3,6 +3,7 @@ package socksaddr
|
|||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Addr interface {
|
||||
|
@ -12,6 +13,26 @@ type Addr interface {
|
|||
String() string
|
||||
}
|
||||
|
||||
func ParseAddr(address string) Addr {
|
||||
addr, err := netip.ParseAddr(address)
|
||||
if err == nil {
|
||||
return AddrFromAddr(addr)
|
||||
}
|
||||
return AddrFromFqdn(address)
|
||||
}
|
||||
|
||||
func ParseAddrPort(address string) (Addr, uint16, error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return ParseAddr(host), uint16(portInt), nil
|
||||
}
|
||||
|
||||
func AddrFromIP(ip net.IP) Addr {
|
||||
addr, _ := netip.AddrFromSlice(ip)
|
||||
if addr.Is4() {
|
||||
|
@ -21,6 +42,14 @@ func AddrFromIP(ip net.IP) Addr {
|
|||
}
|
||||
}
|
||||
|
||||
func AddrFromAddr(addr netip.Addr) Addr {
|
||||
if addr.Is4() {
|
||||
return Addr4(addr.As4())
|
||||
} else {
|
||||
return Addr16(addr.As16())
|
||||
}
|
||||
}
|
||||
|
||||
func AddressFromNetAddr(netAddr net.Addr) (addr Addr, port uint16) {
|
||||
var ip net.IP
|
||||
switch addr := netAddr.(type) {
|
||||
|
@ -38,6 +67,10 @@ func AddrFromFqdn(fqdn string) Addr {
|
|||
return AddrFqdn(fqdn)
|
||||
}
|
||||
|
||||
func JoinHostPort(addr Addr, port uint16) string {
|
||||
return net.JoinHostPort(addr.String(), strconv.Itoa(int(port)))
|
||||
}
|
||||
|
||||
type Addr4 [4]byte
|
||||
|
||||
func (a Addr4) Family() Family {
|
||||
|
|
|
@ -6,8 +6,18 @@ import (
|
|||
|
||||
type ReaderWithUpstream interface {
|
||||
Upstream() io.Reader
|
||||
Replaceable() bool
|
||||
}
|
||||
|
||||
type UpstreamReaderSetter interface {
|
||||
SetUpstream(reader io.Reader)
|
||||
}
|
||||
|
||||
type WriterWithUpstream interface {
|
||||
Upstream() io.Writer
|
||||
Replaceable() bool
|
||||
}
|
||||
|
||||
type UpstreamWriterSetter interface {
|
||||
SetWriter(writer io.Writer)
|
||||
}
|
||||
|
|
1
go.mod
1
go.mod
|
@ -11,5 +11,6 @@ require (
|
|||
require (
|
||||
github.com/inconshreveable/mousetrap v1.0.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/stretchr/testify v1.7.1 // indirect
|
||||
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
|
||||
)
|
||||
|
|
7
go.sum
7
go.sum
|
@ -1,4 +1,5 @@
|
|||
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
|
||||
|
@ -12,8 +13,10 @@ github.com/spf13/cobra v1.4.0 h1:y+wJpx64xcgO1V+RcnwW0LEHxTKRi2ZDPSBjWnrg88Q=
|
|||
github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o=
|
||||
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
|
@ -21,3 +24,5 @@ golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0P
|
|||
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
@ -11,8 +11,8 @@ import (
|
|||
type Cipher interface {
|
||||
KeySize() int
|
||||
SaltSize() int
|
||||
CreateReader(key []byte, iv []byte, reader io.Reader) io.Reader
|
||||
CreateWriter(key []byte, iv []byte, writer io.Writer) (io.Writer, int)
|
||||
CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader
|
||||
CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer
|
||||
EncodePacket(key []byte, buffer *buf.Buffer) error
|
||||
DecodePacket(key []byte, buffer *buf.Buffer) error
|
||||
}
|
||||
|
|
|
@ -92,9 +92,9 @@ func (c *AEADCipher) CreateReader(key []byte, salt []byte, reader io.Reader) io.
|
|||
return NewAEADReader(reader, c.Constructor(Kdf(key, salt, c.KeyLength)))
|
||||
}
|
||||
|
||||
func (c *AEADCipher) CreateWriter(key []byte, salt []byte, writer io.Writer) (io.Writer, int) {
|
||||
func (c *AEADCipher) CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer {
|
||||
protocolWriter := NewAEADWriter(writer, c.Constructor(Kdf(key, salt, c.KeyLength)))
|
||||
return protocolWriter, protocolWriter.maxDataSize
|
||||
return protocolWriter
|
||||
}
|
||||
|
||||
func (c *AEADCipher) EncodePacket(key []byte, buffer *buf.Buffer) error {
|
||||
|
@ -132,12 +132,6 @@ func (c *AEADConn) Write(p []byte) (n int, err error) {
|
|||
return c.Writer.Write(p)
|
||||
}
|
||||
|
||||
func (c *AEADConn) Close() error {
|
||||
c.Reader.Close()
|
||||
c.Writer.Close()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
type AEADReader struct {
|
||||
upstream io.Reader
|
||||
cipher cipher.AEAD
|
||||
|
@ -151,7 +145,7 @@ func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader {
|
|||
return &AEADReader{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
data: buf.GetBytes(),
|
||||
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
|
||||
nonce: make([]byte, cipher.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
@ -160,6 +154,52 @@ func (r *AEADReader) Upstream() io.Reader {
|
|||
return r.upstream
|
||||
}
|
||||
|
||||
func (r *AEADReader) Replaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *AEADReader) SetUpstream(reader io.Reader) {
|
||||
r.upstream = reader
|
||||
}
|
||||
|
||||
func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) {
|
||||
if r.cached > 0 {
|
||||
writeN, writeErr := writer.Write(r.data[r.index : r.index+r.cached])
|
||||
if writeErr != nil {
|
||||
return int64(writeN), writeErr
|
||||
}
|
||||
n += int64(writeN)
|
||||
}
|
||||
for {
|
||||
start := PacketLengthBufferSize + r.cipher.Overhead()
|
||||
_, err = io.ReadFull(r.upstream, r.data[:start])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = r.cipher.Open(r.data[:0], r.nonce, r.data[:start], nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.data[:PacketLengthBufferSize]))
|
||||
end := length + r.cipher.Overhead()
|
||||
_, err = io.ReadFull(r.upstream, r.data[:end])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = r.cipher.Open(r.data[:0], r.nonce, r.data[:end], nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
writeN, writeErr := writer.Write(r.data[:length])
|
||||
if writeErr != nil {
|
||||
return int64(writeN), writeErr
|
||||
}
|
||||
n += int64(writeN)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *AEADReader) Read(b []byte) (n int, err error) {
|
||||
if r.cached > 0 {
|
||||
n = copy(b, r.data[r.index:r.index+r.cached])
|
||||
|
@ -209,29 +249,19 @@ func (r *AEADReader) Read(b []byte) (n int, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *AEADReader) Close() error {
|
||||
if r.data != nil {
|
||||
buf.PutBytes(r.data)
|
||||
r.data = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type AEADWriter struct {
|
||||
upstream io.Writer
|
||||
cipher cipher.AEAD
|
||||
data []byte
|
||||
nonce []byte
|
||||
maxDataSize int
|
||||
upstream io.Writer
|
||||
cipher cipher.AEAD
|
||||
data []byte
|
||||
nonce []byte
|
||||
}
|
||||
|
||||
func NewAEADWriter(upstream io.Writer, cipher cipher.AEAD) *AEADWriter {
|
||||
return &AEADWriter{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
data: buf.GetBytes(),
|
||||
nonce: make([]byte, cipher.NonceSize()),
|
||||
maxDataSize: MaxPacketSize - PacketLengthBufferSize - cipher.Overhead()*2,
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
|
||||
nonce: make([]byte, cipher.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -239,47 +269,47 @@ func (w *AEADWriter) Upstream() io.Writer {
|
|||
return w.upstream
|
||||
}
|
||||
|
||||
func (w *AEADWriter) Process(p []byte) (n int, buffer *buf.Buffer, flush bool, err error) {
|
||||
if len(p) > w.maxDataSize {
|
||||
n, err = w.Write(p)
|
||||
err = &rw.DirectException{
|
||||
Suppressed: err,
|
||||
func (w *AEADWriter) Replaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *AEADWriter) SetWriter(writer io.Writer) {
|
||||
w.upstream = writer
|
||||
}
|
||||
|
||||
func (w *AEADWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
for {
|
||||
offset := w.cipher.Overhead() + PacketLengthBufferSize
|
||||
readN, readErr := r.Read(w.data[offset : offset+MaxPacketSize])
|
||||
if readErr != nil {
|
||||
return 0, readErr
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(len(p)))
|
||||
encryptedLength := w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.nonce)
|
||||
start := len(encryptedLength)
|
||||
|
||||
/*
|
||||
no usage
|
||||
if cap(p) > len(p)+PacketLengthBufferSize+2*w.cipher.Overhead() {
|
||||
packet := w.cipher.Seal(p[:start], w.nonce, p, nil)
|
||||
increaseNonce(w.nonce)
|
||||
copy(p[:start], encryptedLength)
|
||||
n = start + len(packet)
|
||||
binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(readN))
|
||||
w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.nonce)
|
||||
packet := w.cipher.Seal(w.data[offset:offset], w.nonce, w.data[offset:offset+readN], nil)
|
||||
increaseNonce(w.nonce)
|
||||
_, err = w.upstream.Write(w.data[:offset+len(packet)])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
*/
|
||||
|
||||
packet := w.cipher.Seal(w.data[:start], w.nonce, p, nil)
|
||||
increaseNonce(w.nonce)
|
||||
return 0, buf.As(packet), false, err
|
||||
err = common.FlushVar(&w.upstream)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(readN)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *AEADWriter) Write(p []byte) (n int, err error) {
|
||||
for _, data := range buf.ForeachN(p, w.maxDataSize) {
|
||||
for _, data := range buf.ForeachN(p, MaxPacketSize) {
|
||||
binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(len(data)))
|
||||
w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.nonce)
|
||||
|
||||
start := w.cipher.Overhead() + PacketLengthBufferSize
|
||||
packet := w.cipher.Seal(w.data[:start], w.nonce, data, nil)
|
||||
offset := w.cipher.Overhead() + PacketLengthBufferSize
|
||||
packet := w.cipher.Seal(w.data[offset:offset], w.nonce, data, nil)
|
||||
increaseNonce(w.nonce)
|
||||
|
||||
_, err = w.upstream.Write(packet)
|
||||
_, err = w.upstream.Write(w.data[:offset+len(packet)])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -289,14 +319,6 @@ func (w *AEADWriter) Write(p []byte) (n int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (w *AEADWriter) Close() error {
|
||||
if w.data != nil {
|
||||
buf.PutBytes(w.data)
|
||||
w.data = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func increaseNonce(nonce []byte) {
|
||||
for i := range nonce {
|
||||
nonce[i]++
|
||||
|
|
|
@ -26,8 +26,8 @@ func (c *NoneCipher) CreateReader(_ []byte, _ []byte, reader io.Reader) io.Reade
|
|||
return reader
|
||||
}
|
||||
|
||||
func (c *NoneCipher) CreateWriter(_ []byte, _ []byte, writer io.Writer) (io.Writer, int) {
|
||||
return writer, 0
|
||||
func (c *NoneCipher) CreateWriter(key []byte, iv []byte, writer io.Writer) io.Writer {
|
||||
return writer
|
||||
}
|
||||
|
||||
func (c *NoneCipher) EncodePacket([]byte, *buf.Buffer) error {
|
||||
|
|
168
protocol/shadowsocks/client.go
Normal file
168
protocol/shadowsocks/client.go
Normal file
|
@ -0,0 +1,168 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBadKey = exceptions.New("bad key")
|
||||
ErrMissingPassword = exceptions.New("password not specified")
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
Server string `json:"server"`
|
||||
ServerPort uint16 `json:"server_port"`
|
||||
Method string `json:"method"`
|
||||
Password []byte `json:"password"`
|
||||
Key []byte `json:"key"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
dialer *net.Dialer
|
||||
cipher Cipher
|
||||
server string
|
||||
key []byte
|
||||
}
|
||||
|
||||
func NewClient(dialer *net.Dialer, config *ClientConfig) (*Client, error) {
|
||||
cipher, err := CreateCipher(config.Method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &Client{
|
||||
dialer: dialer,
|
||||
cipher: cipher,
|
||||
server: net.JoinHostPort(config.Server, strconv.Itoa(int(config.ServerPort))),
|
||||
}
|
||||
if keyLen := len(config.Key); keyLen > 0 {
|
||||
if keyLen == cipher.KeySize() {
|
||||
client.key = config.Key
|
||||
} else {
|
||||
return nil, ErrBadKey
|
||||
}
|
||||
} else if len(config.Password) > 0 {
|
||||
client.key = Key(config.Password, cipher.KeySize())
|
||||
} else {
|
||||
return nil, ErrMissingPassword
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) DialContextTCP(ctx context.Context, addr socksaddr.Addr, port uint16) (net.Conn, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, "tcp", c.server)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "connect to server")
|
||||
}
|
||||
return c.DialConn(conn, addr, port), nil
|
||||
}
|
||||
|
||||
func (c *Client) DialConn(conn net.Conn, addr socksaddr.Addr, port uint16) net.Conn {
|
||||
header := buf.New()
|
||||
header.WriteRandom(c.cipher.SaltSize())
|
||||
writer := &buf.BufferedWriter{
|
||||
Writer: conn,
|
||||
Buffer: header,
|
||||
}
|
||||
protocolWriter := c.cipher.CreateWriter(c.key, header.Bytes(), writer)
|
||||
requestBuffer := buf.New()
|
||||
contentWriter := &buf.BufferedWriter{
|
||||
Writer: protocolWriter,
|
||||
Buffer: requestBuffer,
|
||||
}
|
||||
common.Must(AddressSerializer.WriteAddressAndPort(contentWriter, addr, port))
|
||||
return &shadowsocksConn{
|
||||
Client: c,
|
||||
Conn: conn,
|
||||
Writer: &common.FlushOnceWriter{Writer: contentWriter},
|
||||
}
|
||||
}
|
||||
|
||||
type shadowsocksConn struct {
|
||||
*Client
|
||||
net.Conn
|
||||
io.Writer
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) Read(p []byte) (n int, err error) {
|
||||
if c.reader == nil {
|
||||
buffer := buf.Or(p, c.cipher.SaltSize())
|
||||
defer buffer.Release()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if c.reader == nil {
|
||||
buffer := buf.NewSize(c.cipher.SaltSize())
|
||||
defer buffer.Release()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) Write(p []byte) (n int, err error) {
|
||||
return c.Writer.Write(p)
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
return rw.ReadFromVar(&c.Writer, r)
|
||||
}
|
||||
|
||||
func (c *Client) DialContextUDP(ctx context.Context) socks.PacketConn {
|
||||
conn, err := c.dialer.DialContext(ctx, "udp", c.server)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &shadowsocksPacketConn{c, conn}
|
||||
}
|
||||
|
||||
type shadowsocksPacketConn struct {
|
||||
*Client
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *shadowsocksPacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
header.WriteRandom(c.cipher.SaltSize())
|
||||
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
err := c.cipher.EncodePacket(c.key, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(c.Conn.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (c *shadowsocksPacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
|
||||
n, err := c.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = c.cipher.DecodePacket(c.key, buffer)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return AddressSerializer.ReadAddressAndPort(buffer)
|
||||
}
|
1
protocol/shadowsocks/config.go
Normal file
1
protocol/shadowsocks/config.go
Normal file
|
@ -0,0 +1 @@
|
|||
package shadowsocks
|
78
protocol/socks/conn.go
Normal file
78
protocol/socks/conn.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package socks
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type PacketConn interface {
|
||||
ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error)
|
||||
WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error
|
||||
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
RemoteAddr() net.Addr
|
||||
SetDeadline(t time.Time) error
|
||||
SetReadDeadline(t time.Time) error
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(size int)) error {
|
||||
for {
|
||||
buffer := buf.New()
|
||||
addr, port, err := conn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
}
|
||||
size := buffer.Len()
|
||||
err = dest.WritePacket(buffer, addr, port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if onAction != nil {
|
||||
onAction(size)
|
||||
}
|
||||
buffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
type associatePacketConn struct {
|
||||
net.PacketConn
|
||||
conn net.Conn
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func NewPacketConn(conn net.Conn, packetConn net.PacketConn) PacketConn {
|
||||
return &associatePacketConn{
|
||||
PacketConn: packetConn,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
|
||||
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
c.addr = addr
|
||||
buffer.Truncate(n)
|
||||
buffer.Advance(3)
|
||||
return AddressSerializer.ReadAddressAndPort(buffer)
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
common.Must(header.WriteZeroN(3))
|
||||
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
|
||||
}
|
|
@ -49,7 +49,7 @@ func ReadAuthRequest(reader io.Reader) (*AuthRequest, error) {
|
|||
}
|
||||
methods, err := rw.ReadBytes(reader, int(methodLen))
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "read socks auth methods, length ", methodLen)
|
||||
return nil, exceptions.CauseF(err, "read socks auth methods, length ", methodLen)
|
||||
}
|
||||
request := &AuthRequest{
|
||||
version,
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
package system
|
||||
|
||||
import "syscall"
|
||||
|
||||
var ControlFunc func(fd uintptr) error
|
||||
|
||||
func Control(conn syscall.Conn) error {
|
||||
if ControlFunc == nil {
|
||||
return nil
|
||||
}
|
||||
rawConn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ControlRaw(rawConn)
|
||||
}
|
||||
|
||||
func ControlRaw(conn syscall.RawConn) error {
|
||||
if ControlFunc == nil {
|
||||
return nil
|
||||
}
|
||||
var rawFd uintptr
|
||||
err := conn.Control(func(fd uintptr) {
|
||||
rawFd = fd
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ControlFunc(rawFd)
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
var Dial DialFunc = new(net.Dialer).DialContext
|
10
transport/system/http.go
Normal file
10
transport/system/http.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net/http"
|
||||
_ "unsafe"
|
||||
)
|
||||
|
||||
//go:linkname readRequest net/http.readRequest
|
||||
func readRequest(b *bufio.Reader) (req *http.Request, err error)
|
223
transport/system/mixed.go
Normal file
223
transport/system/mixed.go
Normal file
|
@ -0,0 +1,223 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type MixedListener struct {
|
||||
*SocksListener
|
||||
}
|
||||
|
||||
func NewMixedListener(bind netip.AddrPort, config *SocksConfig, handler SocksHandler) *MixedListener {
|
||||
listener := &MixedListener{NewSocksListener(bind, config, handler)}
|
||||
listener.TCPListener.Handler = listener
|
||||
return listener
|
||||
}
|
||||
|
||||
func (l *MixedListener) HandleTCP(conn net.Conn) error {
|
||||
bufConn := buf.NewBufferedConn(conn)
|
||||
hdr, err := bufConn.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = bufConn.UnreadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hdr == socks.Version4 || hdr == socks.Version5 {
|
||||
return l.SocksListener.HandleTCP(bufConn)
|
||||
}
|
||||
|
||||
var httpClient *http.Client
|
||||
for {
|
||||
request, err := readRequest(bufConn.Reader())
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "read http request")
|
||||
}
|
||||
|
||||
if l.Username != "" {
|
||||
var authOk bool
|
||||
authorization := request.Header.Get("Proxy-Authorization")
|
||||
if strings.HasPrefix(authorization, "BASIC ") {
|
||||
userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
|
||||
if string(userPassword) == l.Username+":"+l.Password {
|
||||
authOk = true
|
||||
}
|
||||
}
|
||||
if !authOk {
|
||||
err = responseWith(request, http.StatusProxyAuthRequired).Write(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if request.Method == "CONNECT" {
|
||||
host := request.URL.Hostname()
|
||||
portStr := request.URL.Port()
|
||||
if portStr == "" {
|
||||
portStr = "80"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
err = responseWith(request, http.StatusBadRequest).Write(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err = fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established")
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write http response")
|
||||
}
|
||||
return l.Handler.NewConnection(socksaddr.ParseAddr(host), uint16(port), bufConn)
|
||||
}
|
||||
|
||||
keepAlive := strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
|
||||
|
||||
host := request.Header.Get("Host")
|
||||
if host != "" {
|
||||
request.Host = host
|
||||
}
|
||||
|
||||
request.RequestURI = ""
|
||||
|
||||
removeHopByHopHeaders(request.Header)
|
||||
removeExtraHTTPHostPort(request)
|
||||
|
||||
if request.URL.Scheme == "" || request.URL.Host == "" {
|
||||
return responseWith(request, http.StatusBadRequest).Write(conn)
|
||||
}
|
||||
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
DialContext: func(context context.Context, network, address string) (net.Conn, error) {
|
||||
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
||||
return nil, exceptions.New("unsupported network ", network)
|
||||
}
|
||||
|
||||
addr, port, err := socksaddr.ParseAddrPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
go func() {
|
||||
err = l.Handler.NewConnection(addr, port, right)
|
||||
if err != nil {
|
||||
l.OnError(err)
|
||||
}
|
||||
}()
|
||||
return left, nil
|
||||
},
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
l.OnError(exceptions.Cause(err, "http proxy"))
|
||||
return responseWith(request, http.StatusBadGateway).Write(conn)
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(response.Header)
|
||||
|
||||
if keepAlive {
|
||||
response.Header.Set("Proxy-Connection", "keep-alive")
|
||||
response.Header.Set("Connection", "keep-alive")
|
||||
response.Header.Set("Keep-Alive", "timeout=4")
|
||||
}
|
||||
|
||||
response.Close = !keepAlive
|
||||
|
||||
err = response.Write(conn)
|
||||
if err != nil {
|
||||
l.OnError(exceptions.Cause(err, "http proxy"))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// removeHopByHopHeaders remove hop-by-hop header
|
||||
func removeHopByHopHeaders(header http.Header) {
|
||||
// Strip hop-by-hop header based on RFC:
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
|
||||
// https://www.mnot.net/blog/2011/07/11/what_proxies_must_do
|
||||
|
||||
header.Del("Proxy-Connection")
|
||||
header.Del("Proxy-Authenticate")
|
||||
header.Del("Proxy-Authorization")
|
||||
header.Del("TE")
|
||||
header.Del("Trailers")
|
||||
header.Del("Transfer-Encoding")
|
||||
header.Del("Upgrade")
|
||||
|
||||
connections := header.Get("Connection")
|
||||
header.Del("Connection")
|
||||
if len(connections) == 0 {
|
||||
return
|
||||
}
|
||||
for _, h := range strings.Split(connections, ",") {
|
||||
header.Del(strings.TrimSpace(h))
|
||||
}
|
||||
}
|
||||
|
||||
// removeExtraHTTPHostPort remove extra host port (example.com:80 --> example.com)
|
||||
// It resolves the behavior of some HTTP servers that do not handle host:80 (e.g. baidu.com)
|
||||
func removeExtraHTTPHostPort(req *http.Request) {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
|
||||
if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" {
|
||||
host = pHost
|
||||
}
|
||||
|
||||
req.Host = host
|
||||
req.URL.Host = host
|
||||
}
|
||||
|
||||
func responseWith(request *http.Request, statusCode int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Status: http.StatusText(statusCode),
|
||||
Proto: request.Proto,
|
||||
ProtoMajor: request.ProtoMajor,
|
||||
ProtoMinor: request.ProtoMinor,
|
||||
Header: http.Header{},
|
||||
}
|
||||
}
|
||||
|
||||
func (l *MixedListener) Start() error {
|
||||
return l.TCPListener.Start()
|
||||
}
|
||||
|
||||
func (l *MixedListener) Close() error {
|
||||
return l.TCPListener.Close()
|
||||
}
|
||||
|
||||
func (l *MixedListener) OnError(err error) {
|
||||
l.Handler.OnError(exceptions.Cause(err, "mixed server"))
|
||||
}
|
14
transport/system/sockopt_linux.go
Normal file
14
transport/system/sockopt_linux.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const (
|
||||
TCP_FASTOPEN = 23
|
||||
TCP_FASTOPEN_CONNECT = 30
|
||||
)
|
||||
|
||||
func TCPFastOpen(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_TCP, TCP_FASTOPEN_CONNECT, 1)
|
||||
}
|
9
transport/system/sockopt_other.go
Normal file
9
transport/system/sockopt_other.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
//go:build !linux
|
||||
|
||||
package main
|
||||
|
||||
import "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
func TCPFastOpen(fd uintptr) error {
|
||||
return exceptions.New("only available on linux")
|
||||
}
|
148
transport/system/socks.go
Normal file
148
transport/system/socks.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type SocksHandler interface {
|
||||
NewConnection(addr socksaddr.Addr, port uint16, conn net.Conn) error
|
||||
NewPacketConnection(conn socks.PacketConn, addr socksaddr.Addr, port uint16) error
|
||||
OnError(err error)
|
||||
}
|
||||
|
||||
type SocksConfig struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type SocksListener struct {
|
||||
Handler SocksHandler
|
||||
*TCPListener
|
||||
*SocksConfig
|
||||
}
|
||||
|
||||
func NewSocksListener(bind netip.AddrPort, config *SocksConfig, handler SocksHandler) *SocksListener {
|
||||
listener := &SocksListener{
|
||||
SocksConfig: config,
|
||||
Handler: handler,
|
||||
}
|
||||
listener.TCPListener = NewTCPListener(bind, listener)
|
||||
return listener
|
||||
}
|
||||
|
||||
func (l *SocksListener) HandleTCP(conn net.Conn) error {
|
||||
authRequest, err := socks.ReadAuthRequest(conn)
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "read socks auth request")
|
||||
}
|
||||
var authMethod byte
|
||||
if l.Username == "" {
|
||||
authMethod = socks.AuthTypeNotRequired
|
||||
} else {
|
||||
authMethod = socks.AuthTypeUsernamePassword
|
||||
}
|
||||
if !common.Contains(authRequest.Methods, authMethod) {
|
||||
err = socks.WriteAuthResponse(conn, &socks.AuthResponse{
|
||||
Version: authRequest.Version,
|
||||
Method: socks.AuthTypeNoAcceptedMethods,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write socks auth response")
|
||||
}
|
||||
}
|
||||
err = socks.WriteAuthResponse(conn, &socks.AuthResponse{
|
||||
Version: authRequest.Version,
|
||||
Method: socks.AuthTypeNotRequired,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write socks auth response")
|
||||
}
|
||||
|
||||
if authMethod == socks.AuthTypeUsernamePassword {
|
||||
usernamePasswordAuthRequest, err := socks.ReadUsernamePasswordAuthRequest(conn)
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "read user auth request")
|
||||
}
|
||||
response := socks.UsernamePasswordAuthResponse{}
|
||||
if usernamePasswordAuthRequest.Username != l.Username {
|
||||
response.Status = socks.UsernamePasswordStatusFailure
|
||||
} else if usernamePasswordAuthRequest.Password != l.Password {
|
||||
response.Status = socks.UsernamePasswordStatusFailure
|
||||
} else {
|
||||
response.Status = socks.UsernamePasswordStatusSuccess
|
||||
}
|
||||
err = socks.WriteUsernamePasswordAuthResponse(conn, &response)
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write user auth response")
|
||||
}
|
||||
}
|
||||
|
||||
request, err := socks.ReadRequest(conn)
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "read socks request")
|
||||
}
|
||||
switch request.Command {
|
||||
case socks.CommandConnect:
|
||||
localAddr, localPort := socksaddr.AddressFromNetAddr(l.TCPListener.TCPListener.Addr())
|
||||
err = socks.WriteResponse(conn, &socks.Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: socks.ReplyCodeSuccess,
|
||||
BindAddr: localAddr,
|
||||
BindPort: localPort,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write socks response")
|
||||
}
|
||||
return l.Handler.NewConnection(request.Addr, request.Port, conn)
|
||||
case socks.CommandUDPAssociate:
|
||||
udpConn, err := net.ListenUDP("udp4", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
localAddr, localPort := socksaddr.AddressFromNetAddr(udpConn.LocalAddr())
|
||||
err = socks.WriteResponse(conn, &socks.Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: socks.ReplyCodeSuccess,
|
||||
BindAddr: localAddr,
|
||||
BindPort: localPort,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write socks response")
|
||||
}
|
||||
go func() {
|
||||
err := l.Handler.NewPacketConnection(socks.NewPacketConn(conn, udpConn), request.Addr, request.Port)
|
||||
if err != nil {
|
||||
l.OnError(err)
|
||||
}
|
||||
}()
|
||||
return common.Error(io.Copy(io.Discard, conn))
|
||||
default:
|
||||
err = socks.WriteResponse(conn, &socks.Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: socks.ReplyCodeUnsupported,
|
||||
})
|
||||
if err != nil {
|
||||
return exceptions.Cause(err, "write response")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *SocksListener) Start() error {
|
||||
return l.TCPListener.Start()
|
||||
}
|
||||
|
||||
func (l *SocksListener) Close() error {
|
||||
return l.TCPListener.Close()
|
||||
}
|
||||
|
||||
func (l *SocksListener) OnError(err error) {
|
||||
l.Handler.OnError(exceptions.Cause(err, "socks server"))
|
||||
}
|
0
transport/system/stub.s
Normal file
0
transport/system/stub.s
Normal file
|
@ -8,7 +8,7 @@ import (
|
|||
)
|
||||
|
||||
type UDPHandler interface {
|
||||
HandleUDP(listener *UDPListener, buffer *buf.Buffer, sourceAddr net.Addr) error
|
||||
HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error
|
||||
OnError(err error)
|
||||
}
|
||||
|
||||
|
@ -52,7 +52,7 @@ func (l *UDPListener) loop() {
|
|||
}
|
||||
buffer.Truncate(n)
|
||||
go func() {
|
||||
err := l.Handler.HandleUDP(l, buffer, addr)
|
||||
err := l.Handler.HandleUDP(buffer, addr)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
l.Handler.OnError(err)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue