mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Refine buffer
This commit is contained in:
parent
31d4b88581
commit
f16dd7a336
30 changed files with 993 additions and 209 deletions
|
@ -62,7 +62,7 @@ func main() {
|
|||
Short: "shadowsocks client",
|
||||
Version: sing.VersionStr,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run(cmd, f)
|
||||
run(cmd, f)
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -95,7 +95,7 @@ Only available with Linux kernel > 3.7.0.`)
|
|||
}
|
||||
}
|
||||
|
||||
type LocalClient struct {
|
||||
type client struct {
|
||||
*mixed.Listener
|
||||
*geosite.Matcher
|
||||
server *M.AddrPort
|
||||
|
@ -104,7 +104,7 @@ type LocalClient struct {
|
|||
bypass string
|
||||
}
|
||||
|
||||
func NewLocalClient(f *flags) (*LocalClient, error) {
|
||||
func newClient(f *flags) (*client, error) {
|
||||
if f.ConfigFile != "" {
|
||||
configFile, err := ioutil.ReadFile(f.ConfigFile)
|
||||
if err != nil {
|
||||
|
@ -159,13 +159,13 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
|||
return nil, E.New("missing method")
|
||||
}
|
||||
|
||||
client := &LocalClient{
|
||||
c := &client{
|
||||
server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort),
|
||||
bypass: f.Bypass,
|
||||
}
|
||||
|
||||
if f.Method == shadowsocks.MethodNone {
|
||||
client.method = shadowsocks.NewNone()
|
||||
c.method = shadowsocks.NewNone()
|
||||
} else {
|
||||
var pskList [][]byte
|
||||
if f.Key != "" {
|
||||
|
@ -183,7 +183,7 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
|||
if f.UseSystemRNG {
|
||||
rng = random.System
|
||||
} else {
|
||||
rng = random.Blake3KeyedHash()
|
||||
rng = random.System
|
||||
}
|
||||
if f.ReducedSaltEntropy {
|
||||
rng = &shadowsocks.ReducedEntropyReader{Reader: rng}
|
||||
|
@ -200,17 +200,17 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.method = method
|
||||
c.method = method
|
||||
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
||||
method, err := shadowaead_2022.New(f.Method, pskList, rng)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.method = method
|
||||
c.method = method
|
||||
}
|
||||
}
|
||||
|
||||
client.dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
c.dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
var rawFd uintptr
|
||||
err := c.Control(func(fd uintptr) {
|
||||
rawFd = fd
|
||||
|
@ -256,7 +256,7 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
|||
bind = netip.IPv6Unspecified()
|
||||
}
|
||||
|
||||
client.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, client)
|
||||
c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, c)
|
||||
|
||||
if f.Bypass != "" {
|
||||
err := geoip.LoadMMDB("Country.mmdb")
|
||||
|
@ -278,11 +278,11 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.Matcher = geositeMatcher
|
||||
c.Matcher = geositeMatcher
|
||||
}
|
||||
debug.FreeOSMemory()
|
||||
|
||||
return client, nil
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func bypass(conn net.Conn, destination *M.AddrPort) error {
|
||||
|
@ -302,7 +302,7 @@ func bypass(conn net.Conn, destination *M.AddrPort) error {
|
|||
})
|
||||
}
|
||||
|
||||
func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
func (c *client) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
if c.bypass != "" {
|
||||
if metadata.Destination.Addr.Family().IsFqdn() {
|
||||
if c.Match(metadata.Destination.Addr.Fqdn()) {
|
||||
|
@ -315,7 +315,7 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
|||
}
|
||||
}
|
||||
|
||||
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||
logrus.Info("outbound ", metadata.Protocol, " TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||
ctx := context.Background()
|
||||
|
||||
serverConn, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
|
||||
|
@ -339,7 +339,6 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
|||
}
|
||||
serverConn = c.method.DialEarlyConn(serverConn, metadata.Destination)
|
||||
_, err = serverConn.Write(payload.Bytes())
|
||||
payload.Release()
|
||||
if err != nil {
|
||||
return E.Cause(err, "client handshake")
|
||||
}
|
||||
|
@ -347,7 +346,7 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
|||
return rw.CopyConn(ctx, serverConn, conn)
|
||||
}
|
||||
|
||||
func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
|
||||
func (c *client) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
|
||||
ctx := context.Background()
|
||||
udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String())
|
||||
if err != nil {
|
||||
|
@ -371,32 +370,28 @@ func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) e
|
|||
})
|
||||
}
|
||||
|
||||
func Run(cmd *cobra.Command, flags *flags) {
|
||||
client, err := NewLocalClient(flags)
|
||||
func run(cmd *cobra.Command, flags *flags) {
|
||||
c, err := newClient(flags)
|
||||
if err != nil {
|
||||
logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n")
|
||||
cmd.Help()
|
||||
os.Exit(1)
|
||||
}
|
||||
err = client.Listener.Start()
|
||||
err = c.Start()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
||||
logrus.Info("mixed server started at ", client.Listener.TCPListener.Addr())
|
||||
logrus.Info("mixed server started at ", c.TCPListener.Addr())
|
||||
|
||||
osSignals := make(chan os.Signal, 1)
|
||||
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
||||
<-osSignals
|
||||
|
||||
client.Listener.Close()
|
||||
c.Close()
|
||||
}
|
||||
|
||||
func (c *LocalClient) HandleError(err error) {
|
||||
func (c *client) HandleError(err error) {
|
||||
common.Close(err)
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
|
|
151
cli/ss-server/main.go
Normal file
151
cli/ss-server/main.go
Normal file
|
@ -0,0 +1,151 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/random"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
|
||||
"github.com/sagernet/sing/transport/tcp"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type flags struct {
|
||||
Bind string `json:"local_address"`
|
||||
LocalPort uint16 `json:"local_port"`
|
||||
// Password string `json:"password"`
|
||||
Key string `json:"key"`
|
||||
Method string `json:"method"`
|
||||
Verbose bool `json:"verbose"`
|
||||
ConfigFile string
|
||||
}
|
||||
|
||||
func main() {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
|
||||
f := new(flags)
|
||||
|
||||
command := &cobra.Command{
|
||||
Use: "ss-local",
|
||||
Short: "shadowsocks client",
|
||||
Version: sing.VersionStr,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
run(cmd, f)
|
||||
},
|
||||
}
|
||||
|
||||
command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.")
|
||||
command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.")
|
||||
command.Flags().StringVarP(&f.Key, "key", "k", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
|
||||
|
||||
var supportedCiphers []string
|
||||
supportedCiphers = append(supportedCiphers, shadowsocks.MethodNone)
|
||||
supportedCiphers = append(supportedCiphers, shadowaead_2022.List...)
|
||||
|
||||
command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", "Set the cipher.\n\nSupported ciphers:\n\n"+strings.Join(supportedCiphers, "\n"))
|
||||
command.Flags().StringVarP(&f.ConfigFile, "config", "c", "", "Use a configuration file.")
|
||||
command.Flags().BoolVarP(&f.Verbose, "verbose", "v", true, "Enable verbose mode.")
|
||||
|
||||
err := command.Execute()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func run(cmd *cobra.Command, f *flags) {
|
||||
s, err := newServer(f)
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
err = s.tcpIn.Start()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
logrus.Info("server started at ", s.tcpIn.TCPListener.Addr())
|
||||
|
||||
osSignals := make(chan os.Signal, 1)
|
||||
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
||||
<-osSignals
|
||||
|
||||
s.tcpIn.Close()
|
||||
}
|
||||
|
||||
type server struct {
|
||||
tcpIn *tcp.Listener
|
||||
service shadowsocks.Service
|
||||
}
|
||||
|
||||
func newServer(f *flags) (*server, error) {
|
||||
s := new(server)
|
||||
|
||||
if f.Method == shadowsocks.MethodNone {
|
||||
s.service = shadowsocks.NewNoneService(s)
|
||||
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
||||
var pskList [][]byte
|
||||
if f.Key != "" {
|
||||
keyStrList := strings.Split(f.Key, ":")
|
||||
pskList = make([][]byte, len(keyStrList))
|
||||
for i, keyStr := range keyStrList {
|
||||
key, err := base64.StdEncoding.DecodeString(keyStr)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode key")
|
||||
}
|
||||
pskList[i] = key
|
||||
}
|
||||
}
|
||||
rng := random.System
|
||||
service, err := shadowaead_2022.NewService(f.Method, pskList[0], rng, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.service = service
|
||||
} else {
|
||||
return nil, E.New("unsupported method " + f.Method)
|
||||
}
|
||||
|
||||
var bind netip.Addr
|
||||
if f.Bind != "" {
|
||||
addr, err := netip.ParseAddr(f.Bind)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "bad local address")
|
||||
}
|
||||
bind = addr
|
||||
} else {
|
||||
bind = netip.IPv6Unspecified()
|
||||
}
|
||||
s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.LocalPort), s)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *server) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
if metadata.Protocol != "shadowsocks" {
|
||||
return s.service.NewConnection(conn, metadata)
|
||||
}
|
||||
logrus.Info("inbound TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||
destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return rw.CopyConn(context.Background(), conn, destConn)
|
||||
}
|
||||
|
||||
func (s *server) HandleError(err error) {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
logrus.Warn(err)
|
||||
}
|
|
@ -31,6 +31,12 @@ func StackNew() *Buffer {
|
|||
}
|
||||
}
|
||||
|
||||
func StackNewMax() *Buffer {
|
||||
return &Buffer{
|
||||
data: make([]byte, 65535),
|
||||
}
|
||||
}
|
||||
|
||||
func StackNewSize(size int) *Buffer {
|
||||
return &Buffer{
|
||||
data: Make(size),
|
||||
|
@ -291,6 +297,14 @@ func (b *Buffer) Release() {
|
|||
*b = Buffer{}
|
||||
}
|
||||
|
||||
func (b *Buffer) Cut(start int, end int) *Buffer {
|
||||
b.start += start
|
||||
b.end = len(b.data) - end
|
||||
return &Buffer{
|
||||
data: b.data[b.start:b.end],
|
||||
}
|
||||
}
|
||||
|
||||
func (b Buffer) Len() int {
|
||||
return b.end - b.start
|
||||
}
|
||||
|
|
|
@ -1,81 +0,0 @@
|
|||
package buf
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type BufferedReader struct {
|
||||
Reader io.Reader
|
||||
Buffer *Buffer
|
||||
}
|
||||
|
||||
func (r *BufferedReader) Upstream() io.Reader {
|
||||
if r.Buffer != nil {
|
||||
return nil
|
||||
}
|
||||
return r.Reader
|
||||
}
|
||||
|
||||
func (r *BufferedReader) Replaceable() bool {
|
||||
return r.Buffer == nil
|
||||
}
|
||||
|
||||
func (r *BufferedReader) Read(p []byte) (n int, err error) {
|
||||
if r.Buffer != nil {
|
||||
n, err = r.Buffer.Read(p)
|
||||
if r.Buffer.IsEmpty() {
|
||||
r.Buffer.Release()
|
||||
r.Buffer = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
return r.Reader.Read(p)
|
||||
}
|
||||
|
||||
type BufferedWriter struct {
|
||||
Writer io.Writer
|
||||
Buffer *Buffer
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Upstream() io.Writer {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Replaceable() bool {
|
||||
return w.Buffer == nil
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||
if w.Buffer == nil {
|
||||
return w.Writer.Write(p)
|
||||
}
|
||||
n, err = w.Buffer.Write(p)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
return len(p), w.Flush()
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Flush() error {
|
||||
if w.Buffer == nil {
|
||||
return nil
|
||||
}
|
||||
buffer := w.Buffer
|
||||
w.Buffer = nil
|
||||
defer buffer.Release()
|
||||
if buffer.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
return common.Error(w.Writer.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Close() error {
|
||||
buffer := w.Buffer
|
||||
if buffer != nil {
|
||||
w.Buffer = nil
|
||||
buffer.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -29,11 +29,11 @@ func (r *readWriteConn) Close() error {
|
|||
}
|
||||
|
||||
func (r *readWriteConn) LocalAddr() net.Addr {
|
||||
return new(DummyAddr)
|
||||
return &DummyAddr{}
|
||||
}
|
||||
|
||||
func (r *readWriteConn) RemoteAddr() net.Addr {
|
||||
return new(DummyAddr)
|
||||
return &DummyAddr{}
|
||||
}
|
||||
|
||||
func (r *readWriteConn) SetDeadline(t time.Time) error {
|
||||
|
@ -53,7 +53,7 @@ type readConn struct {
|
|||
}
|
||||
|
||||
func (r *readConn) Write(b []byte) (n int, err error) {
|
||||
return 0, new(ReadOnlyException)
|
||||
return 0, &ReadOnlyException{}
|
||||
}
|
||||
|
||||
type writeConn struct {
|
||||
|
@ -62,23 +62,23 @@ type writeConn struct {
|
|||
}
|
||||
|
||||
func (w *writeConn) Read(p []byte) (n int, err error) {
|
||||
return 0, new(WriteOnlyException)
|
||||
return 0, &WriteOnlyException{}
|
||||
}
|
||||
|
||||
func NewReadConn(reader io.Reader) net.Conn {
|
||||
c := new(readConn)
|
||||
c := &readConn{}
|
||||
c.Reader = reader
|
||||
return c
|
||||
}
|
||||
|
||||
func NewWritConn(writer io.Writer) net.Conn {
|
||||
c := new(writeConn)
|
||||
c := &writeConn{}
|
||||
c.Writer = writer
|
||||
return c
|
||||
}
|
||||
|
||||
func NewReadWriteConn(reader io.Reader, writer io.Writer) net.Conn {
|
||||
c := new(readWriteConn)
|
||||
c := &readConn{}
|
||||
c.Reader = reader
|
||||
c.Writer = writer
|
||||
return c
|
||||
|
|
|
@ -52,8 +52,8 @@ func FlushVar(writerP *io.Writer) error {
|
|||
writerBack = writer
|
||||
*writerP = writer
|
||||
continue
|
||||
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter {
|
||||
setter.SetWriter(writerBack)
|
||||
} else if setter, hasSetter := writerBack.(UpstreamWriterSetter); hasSetter {
|
||||
setter.SetWriter(u.Upstream())
|
||||
writer = u.Upstream()
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -2,6 +2,4 @@
|
|||
|
||||
package lowmem
|
||||
|
||||
func init() {
|
||||
Enabled = true
|
||||
}
|
||||
const Enabled = true
|
||||
|
|
|
@ -4,8 +4,6 @@ import (
|
|||
"runtime/debug"
|
||||
)
|
||||
|
||||
var Enabled = false
|
||||
|
||||
func Free() {
|
||||
if Enabled {
|
||||
debug.FreeOSMemory()
|
||||
|
|
5
common/lowmem/stub.go
Normal file
5
common/lowmem/stub.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
//go:build !debug
|
||||
|
||||
package lowmem
|
||||
|
||||
const Enabled = false
|
|
@ -7,6 +7,7 @@ import (
|
|||
)
|
||||
|
||||
type Metadata struct {
|
||||
Protocol string
|
||||
Source *AddrPort
|
||||
Destination *AddrPort
|
||||
}
|
||||
|
|
|
@ -1,6 +1,16 @@
|
|||
package common
|
||||
|
||||
import "syscall"
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func TryFileDescriptor(conn any) (uintptr, error) {
|
||||
if rawConn, isRaw := conn.(syscall.Conn); isRaw {
|
||||
return GetFileDescriptor(rawConn)
|
||||
}
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func GetFileDescriptor(conn syscall.Conn) (uintptr, error) {
|
||||
rawConn, err := conn.SyscallConn()
|
||||
|
|
|
@ -3,10 +3,20 @@ package network
|
|||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type ContextDialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error)
|
||||
}
|
||||
|
||||
var SystemDialer ContextDialer = &net.Dialer{}
|
||||
var SystemDialer ContextDialer = &DefaultDialer{}
|
||||
|
||||
type DefaultDialer struct {
|
||||
net.Dialer
|
||||
}
|
||||
|
||||
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) {
|
||||
return d.Dialer.DialContext(ctx, network, address.String())
|
||||
}
|
||||
|
|
|
@ -3,9 +3,10 @@
|
|||
package redir
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"net"
|
||||
)
|
||||
|
||||
func TProxy(fd uintptr, isIPv6 bool) error {
|
||||
|
|
133
common/rw/buffer.go
Normal file
133
common/rw/buffer.go
Normal file
|
@ -0,0 +1,133 @@
|
|||
package rw
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
type BufferedWriter struct {
|
||||
Writer io.Writer
|
||||
Buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Upstream() io.Writer {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Replaceable() bool {
|
||||
return w.Buffer == nil
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||
if w.Buffer == nil {
|
||||
return w.Writer.Write(p)
|
||||
}
|
||||
n, err = w.Buffer.Write(p)
|
||||
if n == len(p) {
|
||||
return
|
||||
}
|
||||
fd, err := common.TryFileDescriptor(w.Writer)
|
||||
if err == nil {
|
||||
_, err = WriteV(fd, w.Buffer.Bytes(), p[n:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Buffer.Release()
|
||||
w.Buffer = nil
|
||||
return len(p), nil
|
||||
}
|
||||
_, err = w.Writer.Write(w.Buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = w.Flush()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = w.Writer.Write(p[n:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Flush() error {
|
||||
if w.Buffer == nil {
|
||||
return nil
|
||||
}
|
||||
if w.Buffer.IsEmpty() {
|
||||
w.Buffer.Release()
|
||||
w.Buffer = nil
|
||||
return nil
|
||||
}
|
||||
_, err := w.Writer.Write(w.Buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Buffer.Release()
|
||||
w.Buffer = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Close() error {
|
||||
buffer := w.Buffer
|
||||
if buffer != nil {
|
||||
w.Buffer = nil
|
||||
buffer.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type HeaderWriter struct {
|
||||
Writer io.Writer
|
||||
Header *buf.Buffer
|
||||
}
|
||||
|
||||
func (w *HeaderWriter) Upstream() io.Writer {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *HeaderWriter) Replaceable() bool {
|
||||
return w.Header == nil
|
||||
}
|
||||
|
||||
func (w *HeaderWriter) Write(p []byte) (n int, err error) {
|
||||
if w.Header == nil {
|
||||
return w.Writer.Write(p)
|
||||
}
|
||||
fd, err := common.TryFileDescriptor(w.Writer)
|
||||
if err == nil {
|
||||
_, err = WriteV(fd, w.Header.Bytes(), p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Header.Release()
|
||||
w.Header = nil
|
||||
return len(p), nil
|
||||
}
|
||||
cachedN, _ := w.Header.Write(p)
|
||||
_, err = w.Writer.Write(w.Header.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Header.Release()
|
||||
w.Header = nil
|
||||
if cachedN < len(p) {
|
||||
_, err = w.Writer.Write(p[cachedN:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *HeaderWriter) Close() error {
|
||||
buffer := w.Header
|
||||
if buffer != nil {
|
||||
w.Header = nil
|
||||
buffer.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
func Skip(reader io.Reader) error {
|
||||
|
@ -15,6 +16,9 @@ func SkipN(reader io.Reader, size int) error {
|
|||
}
|
||||
|
||||
func ReadByte(reader io.Reader) (byte, error) {
|
||||
if br, isBr := reader.(io.ByteReader); isBr {
|
||||
return br.ReadByte()
|
||||
}
|
||||
var b [1]byte
|
||||
if err := common.Error(io.ReadFull(reader, b[:])); err != nil {
|
||||
return 0, err
|
||||
|
@ -37,3 +41,33 @@ func ReadString(reader io.Reader, size int) (string, error) {
|
|||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
type ReaderFromWriter interface {
|
||||
io.ReaderFrom
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func ReadFrom0(readerFrom ReaderFromWriter, reader io.Reader) (n int64, err error) {
|
||||
n, err = CopyOnce(readerFrom, reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var rn int64
|
||||
rn, err = readerFrom.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += rn
|
||||
return
|
||||
}
|
||||
|
||||
func CopyOnce(dest io.Writer, src io.Reader) (n int64, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
n, err = buffer.ReadFrom(src)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = dest.Write(buffer.Bytes())
|
||||
return
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ type stubByteReader struct {
|
|||
}
|
||||
|
||||
func (r stubByteReader) ReadByte() (byte, error) {
|
||||
return ReadByte(r)
|
||||
return ReadByte(r.Reader)
|
||||
}
|
||||
|
||||
func ToByteReader(reader io.Reader) io.ByteReader {
|
||||
|
|
|
@ -19,7 +19,7 @@ func After(task func() error, after func() error) func() error {
|
|||
|
||||
func Run(ctx context.Context, tasks ...func() error) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
wg := new(sync.WaitGroup)
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(len(tasks))
|
||||
var retErr error
|
||||
for _, task := range tasks {
|
||||
|
|
|
@ -15,7 +15,7 @@ func TestServerConn(t *testing.T) {
|
|||
serverConn := NewServerConn(udpConn)
|
||||
defer serverConn.Close()
|
||||
clientConn := NewClientConn(serverConn)
|
||||
message := new(dnsmessage.Message)
|
||||
message := &dnsmessage.Message{}
|
||||
message.Header.ID = 1
|
||||
message.Header.RecursionDesired = true
|
||||
message.Questions = append(message.Questions, dnsmessage.Question{
|
||||
|
@ -30,6 +30,7 @@ func TestServerConn(t *testing.T) {
|
|||
Port: 53,
|
||||
}))
|
||||
_buffer := buf.StackNew()
|
||||
common.Use(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must2(buffer.ReadPacketFrom(clientConn))
|
||||
common.Must(message.Unpack(buffer.Bytes()))
|
||||
|
|
|
@ -95,6 +95,7 @@ func HandleRequest(request *http.Request, conn net.Conn, authenticator auth.Auth
|
|||
left, right := net.Pipe()
|
||||
go func() {
|
||||
metadata.Destination = destination
|
||||
metadata.Protocol = "http"
|
||||
err = handler.NewConnection(right, metadata)
|
||||
if err != nil {
|
||||
handler.HandleError(&tcp.Error{Conn: right, Cause: err})
|
||||
|
|
|
@ -117,7 +117,8 @@ func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
}
|
||||
|
||||
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.Conn.(io.WriterTo).WriteTo(w)
|
||||
return io.Copy(w, c.Conn)
|
||||
// return c.Conn.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *noneConn) RemoteAddr() net.Addr {
|
||||
|
@ -138,7 +139,7 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
|||
|
||||
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
_header := buf.StackNew()
|
||||
_header := buf.StackNewMax()
|
||||
header := common.Dup(_header)
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
|
||||
if err != nil {
|
||||
|
|
42
protocol/shadowsocks/service.go
Normal file
42
protocol/shadowsocks/service.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
M.TCPConnectionHandler
|
||||
}
|
||||
|
||||
type MultiUserService interface {
|
||||
Service
|
||||
AddUser(key []byte)
|
||||
RemoveUser(key []byte)
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
M.TCPConnectionHandler
|
||||
}
|
||||
|
||||
type NoneService struct {
|
||||
handler Handler
|
||||
}
|
||||
|
||||
func NewNoneService(handler Handler) Service {
|
||||
return &NoneService{
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NoneService) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
destination, err := socks.AddressSerializer.ReadAddrPort(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
return s.handler.NewConnection(conn, metadata)
|
||||
}
|
|
@ -5,7 +5,6 @@ import (
|
|||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
|
@ -92,6 +91,46 @@ func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Reader) readInternal() (err error) {
|
||||
start := PacketLengthBufferSize + r.cipher.Overhead()
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:start])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||
end := length + r.cipher.Overhead()
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
r.cached = length
|
||||
r.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) ReadByte() (byte, error) {
|
||||
if r.cached == 0 {
|
||||
err := r.readInternal()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
index := r.index
|
||||
r.index++
|
||||
r.cached--
|
||||
return r.buffer[index], nil
|
||||
}
|
||||
|
||||
func (r *Reader) Read(b []byte) (n int, err error) {
|
||||
if r.cached > 0 {
|
||||
n = copy(b, r.buffer[r.index:r.index+r.cached])
|
||||
|
@ -141,6 +180,24 @@ func (r *Reader) Read(b []byte) (n int, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Reader) Discard(n int) error {
|
||||
for {
|
||||
if r.cached >= n {
|
||||
r.cached -= n
|
||||
r.index += n
|
||||
return nil
|
||||
} else if r.cached > 0 {
|
||||
n -= r.cached
|
||||
r.cached = 0
|
||||
r.index = 0
|
||||
}
|
||||
err := r.readInternal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
upstream io.Writer
|
||||
cipher cipher.AEAD
|
||||
|
@ -197,10 +254,6 @@ func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = common.FlushVar(&w.upstream)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(readN)
|
||||
}
|
||||
}
|
||||
|
@ -227,6 +280,70 @@ func (w *Writer) Write(p []byte) (n int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (w *Writer) Buffer() *buf.Buffer {
|
||||
return buf.With(w.buffer)
|
||||
}
|
||||
|
||||
func (w *Writer) BufferedWriter(reversed int) *BufferedWriter {
|
||||
return &BufferedWriter{
|
||||
upstream: w,
|
||||
reversed: reversed,
|
||||
data: w.buffer[PacketLengthBufferSize+w.cipher.Overhead() : len(w.buffer)-w.cipher.Overhead()],
|
||||
}
|
||||
}
|
||||
|
||||
type BufferedWriter struct {
|
||||
upstream *Writer
|
||||
data []byte
|
||||
reversed int
|
||||
index int
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Upstream() io.Writer {
|
||||
return w.upstream
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Replaceable() bool {
|
||||
return w.index == 0
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||
var index int
|
||||
for {
|
||||
cachedN := copy(w.data[w.reversed+w.index:], p[index:])
|
||||
if cachedN == len(p[index:]) {
|
||||
w.index += cachedN
|
||||
return cachedN, nil
|
||||
}
|
||||
err = w.Flush()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
index += cachedN
|
||||
}
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Flush() error {
|
||||
if w.index == 0 {
|
||||
if w.reversed > 0 {
|
||||
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed])
|
||||
w.reversed = 0
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
buffer := w.upstream.buffer[w.reversed:]
|
||||
binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index))
|
||||
w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.upstream.nonce)
|
||||
offset := w.upstream.cipher.Overhead() + PacketLengthBufferSize
|
||||
packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil)
|
||||
increaseNonce(w.upstream.nonce)
|
||||
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)])
|
||||
w.reversed = 0
|
||||
return err
|
||||
}
|
||||
|
||||
func increaseNonce(nonce []byte) {
|
||||
for i := range nonce {
|
||||
nonce[i]++
|
||||
|
|
|
@ -189,56 +189,47 @@ type clientConn struct {
|
|||
destination *M.AddrPort
|
||||
|
||||
access sync.Mutex
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
reader *Reader
|
||||
writer *Writer
|
||||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
_request := buf.StackNew()
|
||||
request := common.Dup(_request)
|
||||
_salt := make([]byte, c.method.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||
|
||||
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
|
||||
var writer io.Writer = c.Conn
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: request,
|
||||
}
|
||||
writer = NewWriter(
|
||||
writer,
|
||||
c.method.constructor(Kdf(c.method.key, request.Bytes(), c.method.keySaltLength)),
|
||||
key := Kdf(c.method.key, salt, c.method.keySaltLength)
|
||||
writer := NewWriter(
|
||||
c.Conn,
|
||||
c.method.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
|
||||
if len(payload) > 0 {
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: header,
|
||||
}
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = writer.Write(payload)
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := common.FlushVar(&writer)
|
||||
err := bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
return nil
|
||||
}
|
||||
|
@ -278,7 +269,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||
|
@ -302,9 +293,9 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
|
|||
|
||||
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer == nil {
|
||||
panic("missing handshake")
|
||||
return rw.ReadFrom0(c, r)
|
||||
}
|
||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
||||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
|
|
164
protocol/shadowsocks/shadowaead/service.go
Normal file
164
protocol/shadowsocks/shadowaead/service.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
package shadowaead
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
name string
|
||||
keySaltLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
key []byte
|
||||
secureRNG io.Reader
|
||||
replayFilter replay.Filter
|
||||
handler shadowsocks.Handler
|
||||
}
|
||||
|
||||
func NewService(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||
s := &Service{
|
||||
name: method,
|
||||
secureRNG: secureRNG,
|
||||
handler: handler,
|
||||
}
|
||||
if replayFilter {
|
||||
s.replayFilter = replay.NewBloomRing()
|
||||
}
|
||||
switch method {
|
||||
case "aes-128-gcm":
|
||||
s.keySaltLength = 16
|
||||
s.constructor = newAESGCM
|
||||
case "aes-192-gcm":
|
||||
s.keySaltLength = 24
|
||||
s.constructor = newAESGCM
|
||||
case "aes-256-gcm":
|
||||
s.keySaltLength = 32
|
||||
s.constructor = newAESGCM
|
||||
case "chacha20-ietf-poly1305":
|
||||
s.keySaltLength = 32
|
||||
s.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
case "xchacha20-ietf-poly1305":
|
||||
s.keySaltLength = 32
|
||||
s.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
}
|
||||
if len(key) == s.keySaltLength {
|
||||
s.key = key
|
||||
} else if len(key) > 0 {
|
||||
return nil, ErrBadKey
|
||||
} else if len(password) > 0 {
|
||||
s.key = shadowsocks.Key(password, s.keySaltLength)
|
||||
} else {
|
||||
return nil, ErrMissingPassword
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
_salt := buf.Make(s.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
|
||||
_, err := io.ReadFull(conn, salt)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read salt")
|
||||
}
|
||||
|
||||
key := Kdf(s.key, salt, s.keySaltLength)
|
||||
reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize)
|
||||
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
|
||||
return s.handler.NewConnection(&serverConn{
|
||||
Service: s,
|
||||
Conn: conn,
|
||||
reader: reader,
|
||||
}, metadata)
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
*Service
|
||||
net.Conn
|
||||
access sync.Mutex
|
||||
reader *Reader
|
||||
writer *Writer
|
||||
}
|
||||
|
||||
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
||||
_salt := buf.Make(c.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(c.secureRNG, salt))
|
||||
|
||||
key := Kdf(c.key, salt, c.keySaltLength)
|
||||
writer := NewWriter(
|
||||
c.Conn,
|
||||
c.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
if len(payload) > 0 {
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer != nil {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
c.access.Lock()
|
||||
if c.writer != nil {
|
||||
c.access.Unlock()
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.writeResponse(p)
|
||||
}
|
||||
|
||||
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer != nil {
|
||||
return rw.ReadFrom0(c, r)
|
||||
}
|
||||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
|
@ -197,8 +197,8 @@ type clientConn struct {
|
|||
|
||||
requestSalt []byte
|
||||
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
reader *shadowaead.Reader
|
||||
writer *shadowaead.Writer
|
||||
}
|
||||
|
||||
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
|
||||
|
@ -222,68 +222,56 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
|
|||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
_request := buf.StackNew()
|
||||
request := common.Dup(_request)
|
||||
|
||||
salt := make([]byte, KeySaltSize)
|
||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||
common.Must1(request.Write(salt))
|
||||
c.method.writeExtendedIdentityHeaders(request, salt)
|
||||
|
||||
var writer io.Writer
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: c.Conn,
|
||||
Buffer: request,
|
||||
}
|
||||
|
||||
key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength)
|
||||
writer = shadowaead.NewWriter(
|
||||
writer,
|
||||
writer := shadowaead.NewWriter(
|
||||
c.Conn,
|
||||
c.method.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
c.method.writeExtendedIdentityHeaders(header, salt)
|
||||
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: header,
|
||||
}
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
|
||||
common.Must(rw.WriteByte(writer, HeaderTypeClient))
|
||||
common.Must(binary.Write(writer, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeClient))
|
||||
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write destination")
|
||||
}
|
||||
|
||||
if len(payload) > 0 {
|
||||
err = binary.Write(writer, binary.BigEndian, uint16(0))
|
||||
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(0))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding length")
|
||||
}
|
||||
_, err = writer.Write(payload)
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write payload")
|
||||
}
|
||||
} else {
|
||||
pLen := rand.Intn(MaxPaddingLength + 1)
|
||||
err = binary.Write(writer, binary.BigEndian, uint16(pLen))
|
||||
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(pLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding length")
|
||||
}
|
||||
_, err = io.CopyN(writer, c.method.secureRNG, int64(pLen))
|
||||
_, err = io.CopyN(bufferedWriter, c.method.secureRNG, int64(pLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding")
|
||||
}
|
||||
}
|
||||
|
||||
err = common.FlushVar(&writer)
|
||||
err = bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return E.Cause(err, "client handshake")
|
||||
}
|
||||
|
||||
c.requestSalt = salt
|
||||
c.writer = writer
|
||||
return nil
|
||||
|
@ -363,7 +351,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||
|
@ -389,10 +377,10 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
|
|||
|
||||
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer == nil {
|
||||
panic("missing client handshake")
|
||||
return rw.ReadFrom0(c, r)
|
||||
}
|
||||
|
||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
||||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
|
@ -540,7 +528,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
|||
c.session.lastFilter = c.session.filter
|
||||
c.session.lastRemoteSeen = time.Now().Unix()
|
||||
c.session.lastRemoteCipher = c.session.remoteCipher
|
||||
c.session.filter = new(wgReplay.Filter)
|
||||
c.session.filter = wgReplay.Filter{}
|
||||
}
|
||||
}
|
||||
c.session.remoteSessionId = sessionId
|
||||
|
@ -577,8 +565,8 @@ type udpSession struct {
|
|||
cipher cipher.AEAD
|
||||
remoteCipher cipher.AEAD
|
||||
lastRemoteCipher cipher.AEAD
|
||||
filter *wgReplay.Filter
|
||||
lastFilter *wgReplay.Filter
|
||||
filter wgReplay.Filter
|
||||
lastFilter wgReplay.Filter
|
||||
}
|
||||
|
||||
func (s *udpSession) nextPacketId() uint64 {
|
||||
|
@ -588,7 +576,6 @@ func (s *udpSession) nextPacketId() uint64 {
|
|||
func (m *Method) newUDPSession() *udpSession {
|
||||
session := &udpSession{
|
||||
sessionId: rand.Uint64(),
|
||||
filter: new(wgReplay.Filter),
|
||||
}
|
||||
if m.udpCipher == nil {
|
||||
sessionId := make([]byte, 8)
|
||||
|
|
195
protocol/shadowsocks/shadowaead_2022/service.go
Normal file
195
protocol/shadowsocks/shadowaead_2022/service.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package shadowaead_2022
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
name string
|
||||
secureRNG io.Reader
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
psk []byte
|
||||
replayFilter replay.Filter
|
||||
handler shadowsocks.Handler
|
||||
}
|
||||
|
||||
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||
s := &Service{
|
||||
name: method,
|
||||
psk: psk,
|
||||
secureRNG: secureRNG,
|
||||
replayFilter: replay.NewCuckoo(60),
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
if len(psk) != KeySaltSize {
|
||||
return nil, shadowaead.ErrBadKey
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "2022-blake3-aes-128-gcm":
|
||||
s.keyLength = 16
|
||||
s.constructor = newAESGCM
|
||||
// m.blockConstructor = newAES
|
||||
// m.udpBlockCipher = newAES(m.psk)
|
||||
case "2022-blake3-aes-256-gcm":
|
||||
s.keyLength = 32
|
||||
s.constructor = newAESGCM
|
||||
// m.blockConstructor = newAES
|
||||
// m.udpBlockCipher = newAES(m.psk)
|
||||
case "2022-blake3-chacha20-poly1305":
|
||||
s.keyLength = 32
|
||||
s.constructor = newChacha20Poly1305
|
||||
// m.udpCipher = newXChacha20Poly1305(m.psk)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
requestSalt := make([]byte, KeySaltSize)
|
||||
_, err := io.ReadFull(conn, requestSalt)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read request salt")
|
||||
}
|
||||
|
||||
if !s.replayFilter.Check(requestSalt) {
|
||||
return E.New("salt not unique")
|
||||
}
|
||||
|
||||
requestKey := Blake3DeriveKey(s.psk, requestSalt, s.keyLength)
|
||||
reader := shadowaead.NewReader(
|
||||
conn,
|
||||
s.constructor(common.Dup(requestKey)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
headerType, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read header")
|
||||
}
|
||||
|
||||
if headerType != HeaderTypeClient {
|
||||
return ErrBadHeaderType
|
||||
}
|
||||
|
||||
var epoch uint64
|
||||
err = binary.Read(reader, binary.BigEndian, &epoch)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read timestamp")
|
||||
}
|
||||
if math.Abs(float64(time.Now().Unix()-int64(epoch))) > 30 {
|
||||
return ErrBadTimestamp
|
||||
}
|
||||
|
||||
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read destination")
|
||||
}
|
||||
|
||||
var paddingLen uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &paddingLen)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read padding length")
|
||||
}
|
||||
|
||||
if paddingLen > 0 {
|
||||
err = reader.Discard(int(paddingLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "discard padding")
|
||||
}
|
||||
}
|
||||
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
return s.handler.NewConnection(&serverConn{
|
||||
Service: s,
|
||||
Conn: conn,
|
||||
reader: reader,
|
||||
requestSalt: requestSalt,
|
||||
}, metadata)
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
*Service
|
||||
net.Conn
|
||||
access sync.Mutex
|
||||
reader *shadowaead.Reader
|
||||
writer *shadowaead.Writer
|
||||
requestSalt []byte
|
||||
}
|
||||
|
||||
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
||||
_salt := make([]byte, KeySaltSize)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(c.secureRNG, salt))
|
||||
key := Blake3DeriveKey(c.psk, salt, c.keyLength)
|
||||
writer := shadowaead.NewWriter(
|
||||
c.Conn,
|
||||
c.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
|
||||
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeServer))
|
||||
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
common.Must1(bufferedWriter.Write(c.requestSalt))
|
||||
c.requestSalt = nil
|
||||
|
||||
if len(payload) > 0 {
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
n = len(payload)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer != nil {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
c.access.Lock()
|
||||
if c.writer != nil {
|
||||
c.access.Unlock()
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.writeResponse(p)
|
||||
}
|
||||
|
||||
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer != nil {
|
||||
return rw.ReadFrom0(c, r)
|
||||
}
|
||||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
|
@ -2,13 +2,13 @@ package socks
|
|||
|
||||
import (
|
||||
"context"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
type PacketConn interface {
|
||||
|
@ -47,26 +47,32 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
|
|||
|
||||
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
|
||||
return task.Run(ctx, func() error {
|
||||
_buffer := buf.StackNew()
|
||||
_buffer := buf.StackNewMax()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||
for {
|
||||
destination, err := conn.ReadPacket(buffer)
|
||||
data.FullReset()
|
||||
destination, err := conn.ReadPacket(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Truncate(data.Len())
|
||||
err = dest.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}, func() error {
|
||||
_buffer := buf.StackNew()
|
||||
_buffer := buf.StackNewMax()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||
for {
|
||||
destination, err := dest.ReadPacket(buffer)
|
||||
data.FullReset()
|
||||
destination, err := dest.ReadPacket(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Truncate(data.Len())
|
||||
err = conn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -125,7 +131,8 @@ func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
|
|||
|
||||
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
common.Must(header.WriteZeroN(3))
|
||||
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
|
|
|
@ -83,7 +83,7 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, bind neti
|
|||
if err != nil {
|
||||
return E.Cause(err, "read user auth request")
|
||||
}
|
||||
response := new(UsernamePasswordAuthResponse)
|
||||
response := &UsernamePasswordAuthResponse{}
|
||||
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
|
||||
response.Status = UsernamePasswordStatusSuccess
|
||||
} else {
|
||||
|
@ -109,6 +109,7 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, bind neti
|
|||
if err != nil {
|
||||
return E.Cause(err, "write socks response")
|
||||
}
|
||||
metadata.Protocol = "socks"
|
||||
metadata.Destination = request.Destination
|
||||
return handler.NewConnection(conn, metadata)
|
||||
case CommandUDPAssociate:
|
||||
|
|
|
@ -92,16 +92,22 @@ func (l *Listener) loop() {
|
|||
}
|
||||
switch l.trans {
|
||||
case redir.ModeRedirect:
|
||||
metadata.Destination, _ = redir.GetOriginalDestination(tcpConn)
|
||||
destination, err := redir.GetOriginalDestination(tcpConn)
|
||||
if err == nil {
|
||||
metadata.Protocol = "redirect"
|
||||
metadata.Destination = destination
|
||||
}
|
||||
case redir.ModeTProxy:
|
||||
lAddr := tcpConn.LocalAddr().(*net.TCPAddr)
|
||||
rAddr := tcpConn.RemoteAddr().(*net.TCPAddr)
|
||||
|
||||
if lAddr.Port != l.lAddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() {
|
||||
metadata.Protocol = "tproxy"
|
||||
metadata.Destination = M.AddrPortFromNetAddr(lAddr)
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
metadata.Protocol = "tcp"
|
||||
hErr := l.handler.NewConnection(tcpConn, metadata)
|
||||
if hErr != nil {
|
||||
l.handler.HandleError(&Error{Conn: tcpConn, Cause: hErr})
|
||||
|
|
|
@ -80,7 +80,8 @@ func (l *Listener) loop() {
|
|||
}
|
||||
buffer.Truncate(n)
|
||||
err = l.handler.NewPacket(buffer, M.Metadata{
|
||||
Source: M.AddrPortFromNetAddr(addr),
|
||||
Protocol: "udp",
|
||||
Source: M.AddrPortFromNetAddr(addr),
|
||||
})
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
|
@ -104,6 +105,7 @@ func (l *Listener) loop() {
|
|||
}
|
||||
buffer.Truncate(n)
|
||||
err = l.handler.NewPacket(buffer, M.Metadata{
|
||||
Protocol: "tproxy",
|
||||
Source: M.AddrPortFromAddrPort(addr),
|
||||
Destination: destination,
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue