mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 12:27:37 +03:00
Refine buffer
This commit is contained in:
parent
31d4b88581
commit
f16dd7a336
30 changed files with 993 additions and 209 deletions
|
@ -62,7 +62,7 @@ func main() {
|
||||||
Short: "shadowsocks client",
|
Short: "shadowsocks client",
|
||||||
Version: sing.VersionStr,
|
Version: sing.VersionStr,
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
Run(cmd, f)
|
run(cmd, f)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ Only available with Linux kernel > 3.7.0.`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type LocalClient struct {
|
type client struct {
|
||||||
*mixed.Listener
|
*mixed.Listener
|
||||||
*geosite.Matcher
|
*geosite.Matcher
|
||||||
server *M.AddrPort
|
server *M.AddrPort
|
||||||
|
@ -104,7 +104,7 @@ type LocalClient struct {
|
||||||
bypass string
|
bypass string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLocalClient(f *flags) (*LocalClient, error) {
|
func newClient(f *flags) (*client, error) {
|
||||||
if f.ConfigFile != "" {
|
if f.ConfigFile != "" {
|
||||||
configFile, err := ioutil.ReadFile(f.ConfigFile)
|
configFile, err := ioutil.ReadFile(f.ConfigFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -159,13 +159,13 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
||||||
return nil, E.New("missing method")
|
return nil, E.New("missing method")
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &LocalClient{
|
c := &client{
|
||||||
server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort),
|
server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort),
|
||||||
bypass: f.Bypass,
|
bypass: f.Bypass,
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.Method == shadowsocks.MethodNone {
|
if f.Method == shadowsocks.MethodNone {
|
||||||
client.method = shadowsocks.NewNone()
|
c.method = shadowsocks.NewNone()
|
||||||
} else {
|
} else {
|
||||||
var pskList [][]byte
|
var pskList [][]byte
|
||||||
if f.Key != "" {
|
if f.Key != "" {
|
||||||
|
@ -183,7 +183,7 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
||||||
if f.UseSystemRNG {
|
if f.UseSystemRNG {
|
||||||
rng = random.System
|
rng = random.System
|
||||||
} else {
|
} else {
|
||||||
rng = random.Blake3KeyedHash()
|
rng = random.System
|
||||||
}
|
}
|
||||||
if f.ReducedSaltEntropy {
|
if f.ReducedSaltEntropy {
|
||||||
rng = &shadowsocks.ReducedEntropyReader{Reader: rng}
|
rng = &shadowsocks.ReducedEntropyReader{Reader: rng}
|
||||||
|
@ -200,17 +200,17 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
client.method = method
|
c.method = method
|
||||||
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
||||||
method, err := shadowaead_2022.New(f.Method, pskList, rng)
|
method, err := shadowaead_2022.New(f.Method, pskList, rng)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
client.method = method
|
c.method = method
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
client.dialer.Control = func(network, address string, c syscall.RawConn) error {
|
c.dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||||
var rawFd uintptr
|
var rawFd uintptr
|
||||||
err := c.Control(func(fd uintptr) {
|
err := c.Control(func(fd uintptr) {
|
||||||
rawFd = fd
|
rawFd = fd
|
||||||
|
@ -256,7 +256,7 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
||||||
bind = netip.IPv6Unspecified()
|
bind = netip.IPv6Unspecified()
|
||||||
}
|
}
|
||||||
|
|
||||||
client.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, client)
|
c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, c)
|
||||||
|
|
||||||
if f.Bypass != "" {
|
if f.Bypass != "" {
|
||||||
err := geoip.LoadMMDB("Country.mmdb")
|
err := geoip.LoadMMDB("Country.mmdb")
|
||||||
|
@ -278,11 +278,11 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
client.Matcher = geositeMatcher
|
c.Matcher = geositeMatcher
|
||||||
}
|
}
|
||||||
debug.FreeOSMemory()
|
debug.FreeOSMemory()
|
||||||
|
|
||||||
return client, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func bypass(conn net.Conn, destination *M.AddrPort) error {
|
func bypass(conn net.Conn, destination *M.AddrPort) error {
|
||||||
|
@ -302,7 +302,7 @@ func bypass(conn net.Conn, destination *M.AddrPort) error {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
func (c *client) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||||
if c.bypass != "" {
|
if c.bypass != "" {
|
||||||
if metadata.Destination.Addr.Family().IsFqdn() {
|
if metadata.Destination.Addr.Family().IsFqdn() {
|
||||||
if c.Match(metadata.Destination.Addr.Fqdn()) {
|
if c.Match(metadata.Destination.Addr.Fqdn()) {
|
||||||
|
@ -315,7 +315,7 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
logrus.Info("outbound ", metadata.Protocol, " TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
serverConn, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
|
serverConn, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
|
||||||
|
@ -339,7 +339,6 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||||
}
|
}
|
||||||
serverConn = c.method.DialEarlyConn(serverConn, metadata.Destination)
|
serverConn = c.method.DialEarlyConn(serverConn, metadata.Destination)
|
||||||
_, err = serverConn.Write(payload.Bytes())
|
_, err = serverConn.Write(payload.Bytes())
|
||||||
payload.Release()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "client handshake")
|
return E.Cause(err, "client handshake")
|
||||||
}
|
}
|
||||||
|
@ -347,7 +346,7 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||||
return rw.CopyConn(ctx, serverConn, conn)
|
return rw.CopyConn(ctx, serverConn, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
|
func (c *client) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String())
|
udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -371,32 +370,28 @@ func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) e
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func Run(cmd *cobra.Command, flags *flags) {
|
func run(cmd *cobra.Command, flags *flags) {
|
||||||
client, err := NewLocalClient(flags)
|
c, err := newClient(flags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n")
|
logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n")
|
||||||
cmd.Help()
|
cmd.Help()
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
err = client.Listener.Start()
|
err = c.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatal(err)
|
logrus.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
logrus.Info("mixed server started at ", c.TCPListener.Addr())
|
||||||
logrus.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logrus.Info("mixed server started at ", client.Listener.TCPListener.Addr())
|
|
||||||
|
|
||||||
osSignals := make(chan os.Signal, 1)
|
osSignals := make(chan os.Signal, 1)
|
||||||
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
||||||
<-osSignals
|
<-osSignals
|
||||||
|
|
||||||
client.Listener.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *LocalClient) HandleError(err error) {
|
func (c *client) HandleError(err error) {
|
||||||
common.Close(err)
|
common.Close(err)
|
||||||
if E.IsClosed(err) {
|
if E.IsClosed(err) {
|
||||||
return
|
return
|
||||||
|
|
151
cli/ss-server/main.go
Normal file
151
cli/ss-server/main.go
Normal file
|
@ -0,0 +1,151 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
"github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/common/random"
|
||||||
|
"github.com/sagernet/sing/common/rw"
|
||||||
|
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||||
|
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
|
||||||
|
"github.com/sagernet/sing/transport/tcp"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
type flags struct {
|
||||||
|
Bind string `json:"local_address"`
|
||||||
|
LocalPort uint16 `json:"local_port"`
|
||||||
|
// Password string `json:"password"`
|
||||||
|
Key string `json:"key"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
Verbose bool `json:"verbose"`
|
||||||
|
ConfigFile string
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
logrus.SetLevel(logrus.TraceLevel)
|
||||||
|
|
||||||
|
f := new(flags)
|
||||||
|
|
||||||
|
command := &cobra.Command{
|
||||||
|
Use: "ss-local",
|
||||||
|
Short: "shadowsocks client",
|
||||||
|
Version: sing.VersionStr,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
run(cmd, f)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.")
|
||||||
|
command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.")
|
||||||
|
command.Flags().StringVarP(&f.Key, "key", "k", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
|
||||||
|
|
||||||
|
var supportedCiphers []string
|
||||||
|
supportedCiphers = append(supportedCiphers, shadowsocks.MethodNone)
|
||||||
|
supportedCiphers = append(supportedCiphers, shadowaead_2022.List...)
|
||||||
|
|
||||||
|
command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", "Set the cipher.\n\nSupported ciphers:\n\n"+strings.Join(supportedCiphers, "\n"))
|
||||||
|
command.Flags().StringVarP(&f.ConfigFile, "config", "c", "", "Use a configuration file.")
|
||||||
|
command.Flags().BoolVarP(&f.Verbose, "verbose", "v", true, "Enable verbose mode.")
|
||||||
|
|
||||||
|
err := command.Execute()
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func run(cmd *cobra.Command, f *flags) {
|
||||||
|
s, err := newServer(f)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatal(err)
|
||||||
|
}
|
||||||
|
err = s.tcpIn.Start()
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatal(err)
|
||||||
|
}
|
||||||
|
logrus.Info("server started at ", s.tcpIn.TCPListener.Addr())
|
||||||
|
|
||||||
|
osSignals := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
||||||
|
<-osSignals
|
||||||
|
|
||||||
|
s.tcpIn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type server struct {
|
||||||
|
tcpIn *tcp.Listener
|
||||||
|
service shadowsocks.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServer(f *flags) (*server, error) {
|
||||||
|
s := new(server)
|
||||||
|
|
||||||
|
if f.Method == shadowsocks.MethodNone {
|
||||||
|
s.service = shadowsocks.NewNoneService(s)
|
||||||
|
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
||||||
|
var pskList [][]byte
|
||||||
|
if f.Key != "" {
|
||||||
|
keyStrList := strings.Split(f.Key, ":")
|
||||||
|
pskList = make([][]byte, len(keyStrList))
|
||||||
|
for i, keyStr := range keyStrList {
|
||||||
|
key, err := base64.StdEncoding.DecodeString(keyStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "decode key")
|
||||||
|
}
|
||||||
|
pskList[i] = key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rng := random.System
|
||||||
|
service, err := shadowaead_2022.NewService(f.Method, pskList[0], rng, s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.service = service
|
||||||
|
} else {
|
||||||
|
return nil, E.New("unsupported method " + f.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
var bind netip.Addr
|
||||||
|
if f.Bind != "" {
|
||||||
|
addr, err := netip.ParseAddr(f.Bind)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "bad local address")
|
||||||
|
}
|
||||||
|
bind = addr
|
||||||
|
} else {
|
||||||
|
bind = netip.IPv6Unspecified()
|
||||||
|
}
|
||||||
|
s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.LocalPort), s)
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||||
|
if metadata.Protocol != "shadowsocks" {
|
||||||
|
return s.service.NewConnection(conn, metadata)
|
||||||
|
}
|
||||||
|
logrus.Info("inbound TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||||
|
destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return rw.CopyConn(context.Background(), conn, destConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) HandleError(err error) {
|
||||||
|
if E.IsClosed(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logrus.Warn(err)
|
||||||
|
}
|
|
@ -31,6 +31,12 @@ func StackNew() *Buffer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func StackNewMax() *Buffer {
|
||||||
|
return &Buffer{
|
||||||
|
data: make([]byte, 65535),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func StackNewSize(size int) *Buffer {
|
func StackNewSize(size int) *Buffer {
|
||||||
return &Buffer{
|
return &Buffer{
|
||||||
data: Make(size),
|
data: Make(size),
|
||||||
|
@ -291,6 +297,14 @@ func (b *Buffer) Release() {
|
||||||
*b = Buffer{}
|
*b = Buffer{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) Cut(start int, end int) *Buffer {
|
||||||
|
b.start += start
|
||||||
|
b.end = len(b.data) - end
|
||||||
|
return &Buffer{
|
||||||
|
data: b.data[b.start:b.end],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b Buffer) Len() int {
|
func (b Buffer) Len() int {
|
||||||
return b.end - b.start
|
return b.end - b.start
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,81 +0,0 @@
|
||||||
package buf
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BufferedReader struct {
|
|
||||||
Reader io.Reader
|
|
||||||
Buffer *Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *BufferedReader) Upstream() io.Reader {
|
|
||||||
if r.Buffer != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return r.Reader
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *BufferedReader) Replaceable() bool {
|
|
||||||
return r.Buffer == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *BufferedReader) Read(p []byte) (n int, err error) {
|
|
||||||
if r.Buffer != nil {
|
|
||||||
n, err = r.Buffer.Read(p)
|
|
||||||
if r.Buffer.IsEmpty() {
|
|
||||||
r.Buffer.Release()
|
|
||||||
r.Buffer = nil
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return r.Reader.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
type BufferedWriter struct {
|
|
||||||
Writer io.Writer
|
|
||||||
Buffer *Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *BufferedWriter) Upstream() io.Writer {
|
|
||||||
return w.Writer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *BufferedWriter) Replaceable() bool {
|
|
||||||
return w.Buffer == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
|
||||||
if w.Buffer == nil {
|
|
||||||
return w.Writer.Write(p)
|
|
||||||
}
|
|
||||||
n, err = w.Buffer.Write(p)
|
|
||||||
if err == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return len(p), w.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *BufferedWriter) Flush() error {
|
|
||||||
if w.Buffer == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
buffer := w.Buffer
|
|
||||||
w.Buffer = nil
|
|
||||||
defer buffer.Release()
|
|
||||||
if buffer.IsEmpty() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return common.Error(w.Writer.Write(buffer.Bytes()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *BufferedWriter) Close() error {
|
|
||||||
buffer := w.Buffer
|
|
||||||
if buffer != nil {
|
|
||||||
w.Buffer = nil
|
|
||||||
buffer.Release()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -29,11 +29,11 @@ func (r *readWriteConn) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *readWriteConn) LocalAddr() net.Addr {
|
func (r *readWriteConn) LocalAddr() net.Addr {
|
||||||
return new(DummyAddr)
|
return &DummyAddr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *readWriteConn) RemoteAddr() net.Addr {
|
func (r *readWriteConn) RemoteAddr() net.Addr {
|
||||||
return new(DummyAddr)
|
return &DummyAddr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *readWriteConn) SetDeadline(t time.Time) error {
|
func (r *readWriteConn) SetDeadline(t time.Time) error {
|
||||||
|
@ -53,7 +53,7 @@ type readConn struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *readConn) Write(b []byte) (n int, err error) {
|
func (r *readConn) Write(b []byte) (n int, err error) {
|
||||||
return 0, new(ReadOnlyException)
|
return 0, &ReadOnlyException{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type writeConn struct {
|
type writeConn struct {
|
||||||
|
@ -62,23 +62,23 @@ type writeConn struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *writeConn) Read(p []byte) (n int, err error) {
|
func (w *writeConn) Read(p []byte) (n int, err error) {
|
||||||
return 0, new(WriteOnlyException)
|
return 0, &WriteOnlyException{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewReadConn(reader io.Reader) net.Conn {
|
func NewReadConn(reader io.Reader) net.Conn {
|
||||||
c := new(readConn)
|
c := &readConn{}
|
||||||
c.Reader = reader
|
c.Reader = reader
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWritConn(writer io.Writer) net.Conn {
|
func NewWritConn(writer io.Writer) net.Conn {
|
||||||
c := new(writeConn)
|
c := &writeConn{}
|
||||||
c.Writer = writer
|
c.Writer = writer
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewReadWriteConn(reader io.Reader, writer io.Writer) net.Conn {
|
func NewReadWriteConn(reader io.Reader, writer io.Writer) net.Conn {
|
||||||
c := new(readWriteConn)
|
c := &readConn{}
|
||||||
c.Reader = reader
|
c.Reader = reader
|
||||||
c.Writer = writer
|
c.Writer = writer
|
||||||
return c
|
return c
|
||||||
|
|
|
@ -52,8 +52,8 @@ func FlushVar(writerP *io.Writer) error {
|
||||||
writerBack = writer
|
writerBack = writer
|
||||||
*writerP = writer
|
*writerP = writer
|
||||||
continue
|
continue
|
||||||
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter {
|
} else if setter, hasSetter := writerBack.(UpstreamWriterSetter); hasSetter {
|
||||||
setter.SetWriter(writerBack)
|
setter.SetWriter(u.Upstream())
|
||||||
writer = u.Upstream()
|
writer = u.Upstream()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,4 @@
|
||||||
|
|
||||||
package lowmem
|
package lowmem
|
||||||
|
|
||||||
func init() {
|
const Enabled = true
|
||||||
Enabled = true
|
|
||||||
}
|
|
||||||
|
|
|
@ -4,8 +4,6 @@ import (
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Enabled = false
|
|
||||||
|
|
||||||
func Free() {
|
func Free() {
|
||||||
if Enabled {
|
if Enabled {
|
||||||
debug.FreeOSMemory()
|
debug.FreeOSMemory()
|
||||||
|
|
5
common/lowmem/stub.go
Normal file
5
common/lowmem/stub.go
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
//go:build !debug
|
||||||
|
|
||||||
|
package lowmem
|
||||||
|
|
||||||
|
const Enabled = false
|
|
@ -7,6 +7,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Metadata struct {
|
type Metadata struct {
|
||||||
|
Protocol string
|
||||||
Source *AddrPort
|
Source *AddrPort
|
||||||
Destination *AddrPort
|
Destination *AddrPort
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,16 @@
|
||||||
package common
|
package common
|
||||||
|
|
||||||
import "syscall"
|
import (
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TryFileDescriptor(conn any) (uintptr, error) {
|
||||||
|
if rawConn, isRaw := conn.(syscall.Conn); isRaw {
|
||||||
|
return GetFileDescriptor(rawConn)
|
||||||
|
}
|
||||||
|
return 0, os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
func GetFileDescriptor(conn syscall.Conn) (uintptr, error) {
|
func GetFileDescriptor(conn syscall.Conn) (uintptr, error) {
|
||||||
rawConn, err := conn.SyscallConn()
|
rawConn, err := conn.SyscallConn()
|
||||||
|
|
|
@ -3,10 +3,20 @@ package network
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ContextDialer interface {
|
type ContextDialer interface {
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
var SystemDialer ContextDialer = &net.Dialer{}
|
var SystemDialer ContextDialer = &DefaultDialer{}
|
||||||
|
|
||||||
|
type DefaultDialer struct {
|
||||||
|
net.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) {
|
||||||
|
return d.Dialer.DialContext(ctx, network, address.String())
|
||||||
|
}
|
||||||
|
|
|
@ -3,9 +3,10 @@
|
||||||
package redir
|
package redir
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
"net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TProxy(fd uintptr, isIPv6 bool) error {
|
func TProxy(fd uintptr, isIPv6 bool) error {
|
||||||
|
|
133
common/rw/buffer.go
Normal file
133
common/rw/buffer.go
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
package rw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BufferedWriter struct {
|
||||||
|
Writer io.Writer
|
||||||
|
Buffer *buf.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BufferedWriter) Upstream() io.Writer {
|
||||||
|
return w.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BufferedWriter) Replaceable() bool {
|
||||||
|
return w.Buffer == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||||
|
if w.Buffer == nil {
|
||||||
|
return w.Writer.Write(p)
|
||||||
|
}
|
||||||
|
n, err = w.Buffer.Write(p)
|
||||||
|
if n == len(p) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fd, err := common.TryFileDescriptor(w.Writer)
|
||||||
|
if err == nil {
|
||||||
|
_, err = WriteV(fd, w.Buffer.Bytes(), p[n:])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Buffer.Release()
|
||||||
|
w.Buffer = nil
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
_, err = w.Writer.Write(w.Buffer.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = w.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = w.Writer.Write(p[n:])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BufferedWriter) Flush() error {
|
||||||
|
if w.Buffer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if w.Buffer.IsEmpty() {
|
||||||
|
w.Buffer.Release()
|
||||||
|
w.Buffer = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, err := w.Writer.Write(w.Buffer.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Buffer.Release()
|
||||||
|
w.Buffer = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BufferedWriter) Close() error {
|
||||||
|
buffer := w.Buffer
|
||||||
|
if buffer != nil {
|
||||||
|
w.Buffer = nil
|
||||||
|
buffer.Release()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type HeaderWriter struct {
|
||||||
|
Writer io.Writer
|
||||||
|
Header *buf.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *HeaderWriter) Upstream() io.Writer {
|
||||||
|
return w.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *HeaderWriter) Replaceable() bool {
|
||||||
|
return w.Header == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *HeaderWriter) Write(p []byte) (n int, err error) {
|
||||||
|
if w.Header == nil {
|
||||||
|
return w.Writer.Write(p)
|
||||||
|
}
|
||||||
|
fd, err := common.TryFileDescriptor(w.Writer)
|
||||||
|
if err == nil {
|
||||||
|
_, err = WriteV(fd, w.Header.Bytes(), p)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header.Release()
|
||||||
|
w.Header = nil
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
cachedN, _ := w.Header.Write(p)
|
||||||
|
_, err = w.Writer.Write(w.Header.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header.Release()
|
||||||
|
w.Header = nil
|
||||||
|
if cachedN < len(p) {
|
||||||
|
_, err = w.Writer.Write(p[cachedN:])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *HeaderWriter) Close() error {
|
||||||
|
buffer := w.Header
|
||||||
|
if buffer != nil {
|
||||||
|
w.Header = nil
|
||||||
|
buffer.Release()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Skip(reader io.Reader) error {
|
func Skip(reader io.Reader) error {
|
||||||
|
@ -15,6 +16,9 @@ func SkipN(reader io.Reader, size int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadByte(reader io.Reader) (byte, error) {
|
func ReadByte(reader io.Reader) (byte, error) {
|
||||||
|
if br, isBr := reader.(io.ByteReader); isBr {
|
||||||
|
return br.ReadByte()
|
||||||
|
}
|
||||||
var b [1]byte
|
var b [1]byte
|
||||||
if err := common.Error(io.ReadFull(reader, b[:])); err != nil {
|
if err := common.Error(io.ReadFull(reader, b[:])); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
@ -37,3 +41,33 @@ func ReadString(reader io.Reader, size int) (string, error) {
|
||||||
}
|
}
|
||||||
return string(b), nil
|
return string(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ReaderFromWriter interface {
|
||||||
|
io.ReaderFrom
|
||||||
|
io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadFrom0(readerFrom ReaderFromWriter, reader io.Reader) (n int64, err error) {
|
||||||
|
n, err = CopyOnce(readerFrom, reader)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var rn int64
|
||||||
|
rn, err = readerFrom.ReadFrom(reader)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n += rn
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func CopyOnce(dest io.Writer, src io.Reader) (n int64, err error) {
|
||||||
|
_buffer := buf.StackNew()
|
||||||
|
buffer := common.Dup(_buffer)
|
||||||
|
n, err = buffer.ReadFrom(src)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = dest.Write(buffer.Bytes())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ type stubByteReader struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r stubByteReader) ReadByte() (byte, error) {
|
func (r stubByteReader) ReadByte() (byte, error) {
|
||||||
return ReadByte(r)
|
return ReadByte(r.Reader)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToByteReader(reader io.Reader) io.ByteReader {
|
func ToByteReader(reader io.Reader) io.ByteReader {
|
||||||
|
|
|
@ -19,7 +19,7 @@ func After(task func() error, after func() error) func() error {
|
||||||
|
|
||||||
func Run(ctx context.Context, tasks ...func() error) error {
|
func Run(ctx context.Context, tasks ...func() error) error {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
wg := new(sync.WaitGroup)
|
wg := &sync.WaitGroup{}
|
||||||
wg.Add(len(tasks))
|
wg.Add(len(tasks))
|
||||||
var retErr error
|
var retErr error
|
||||||
for _, task := range tasks {
|
for _, task := range tasks {
|
||||||
|
|
|
@ -15,7 +15,7 @@ func TestServerConn(t *testing.T) {
|
||||||
serverConn := NewServerConn(udpConn)
|
serverConn := NewServerConn(udpConn)
|
||||||
defer serverConn.Close()
|
defer serverConn.Close()
|
||||||
clientConn := NewClientConn(serverConn)
|
clientConn := NewClientConn(serverConn)
|
||||||
message := new(dnsmessage.Message)
|
message := &dnsmessage.Message{}
|
||||||
message.Header.ID = 1
|
message.Header.ID = 1
|
||||||
message.Header.RecursionDesired = true
|
message.Header.RecursionDesired = true
|
||||||
message.Questions = append(message.Questions, dnsmessage.Question{
|
message.Questions = append(message.Questions, dnsmessage.Question{
|
||||||
|
@ -30,6 +30,7 @@ func TestServerConn(t *testing.T) {
|
||||||
Port: 53,
|
Port: 53,
|
||||||
}))
|
}))
|
||||||
_buffer := buf.StackNew()
|
_buffer := buf.StackNew()
|
||||||
|
common.Use(_buffer)
|
||||||
buffer := common.Dup(_buffer)
|
buffer := common.Dup(_buffer)
|
||||||
common.Must2(buffer.ReadPacketFrom(clientConn))
|
common.Must2(buffer.ReadPacketFrom(clientConn))
|
||||||
common.Must(message.Unpack(buffer.Bytes()))
|
common.Must(message.Unpack(buffer.Bytes()))
|
||||||
|
|
|
@ -95,6 +95,7 @@ func HandleRequest(request *http.Request, conn net.Conn, authenticator auth.Auth
|
||||||
left, right := net.Pipe()
|
left, right := net.Pipe()
|
||||||
go func() {
|
go func() {
|
||||||
metadata.Destination = destination
|
metadata.Destination = destination
|
||||||
|
metadata.Protocol = "http"
|
||||||
err = handler.NewConnection(right, metadata)
|
err = handler.NewConnection(right, metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handler.HandleError(&tcp.Error{Conn: right, Cause: err})
|
handler.HandleError(&tcp.Error{Conn: right, Cause: err})
|
||||||
|
|
|
@ -117,7 +117,8 @@ func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
|
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
return c.Conn.(io.WriterTo).WriteTo(w)
|
return io.Copy(w, c.Conn)
|
||||||
|
// return c.Conn.(io.WriterTo).WriteTo(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *noneConn) RemoteAddr() net.Addr {
|
func (c *noneConn) RemoteAddr() net.Addr {
|
||||||
|
@ -138,7 +139,7 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||||
|
|
||||||
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||||
defer buffer.Release()
|
defer buffer.Release()
|
||||||
_header := buf.StackNew()
|
_header := buf.StackNewMax()
|
||||||
header := common.Dup(_header)
|
header := common.Dup(_header)
|
||||||
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
|
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
42
protocol/shadowsocks/service.go
Normal file
42
protocol/shadowsocks/service.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package shadowsocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
"github.com/sagernet/sing/protocol/socks"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Service interface {
|
||||||
|
M.TCPConnectionHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
type MultiUserService interface {
|
||||||
|
Service
|
||||||
|
AddUser(key []byte)
|
||||||
|
RemoveUser(key []byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Handler interface {
|
||||||
|
M.TCPConnectionHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
type NoneService struct {
|
||||||
|
handler Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNoneService(handler Handler) Service {
|
||||||
|
return &NoneService{
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *NoneService) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||||
|
destination, err := socks.AddressSerializer.ReadAddrPort(conn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
metadata.Protocol = "shadowsocks"
|
||||||
|
metadata.Destination = destination
|
||||||
|
return s.handler.NewConnection(conn, metadata)
|
||||||
|
}
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -92,6 +91,46 @@ func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Reader) readInternal() (err error) {
|
||||||
|
start := PacketLengthBufferSize + r.cipher.Overhead()
|
||||||
|
_, err = io.ReadFull(r.upstream, r.buffer[:start])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
increaseNonce(r.nonce)
|
||||||
|
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||||
|
end := length + r.cipher.Overhead()
|
||||||
|
_, err = io.ReadFull(r.upstream, r.buffer[:end])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
increaseNonce(r.nonce)
|
||||||
|
r.cached = length
|
||||||
|
r.index = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reader) ReadByte() (byte, error) {
|
||||||
|
if r.cached == 0 {
|
||||||
|
err := r.readInternal()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
index := r.index
|
||||||
|
r.index++
|
||||||
|
r.cached--
|
||||||
|
return r.buffer[index], nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Reader) Read(b []byte) (n int, err error) {
|
func (r *Reader) Read(b []byte) (n int, err error) {
|
||||||
if r.cached > 0 {
|
if r.cached > 0 {
|
||||||
n = copy(b, r.buffer[r.index:r.index+r.cached])
|
n = copy(b, r.buffer[r.index:r.index+r.cached])
|
||||||
|
@ -141,6 +180,24 @@ func (r *Reader) Read(b []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Reader) Discard(n int) error {
|
||||||
|
for {
|
||||||
|
if r.cached >= n {
|
||||||
|
r.cached -= n
|
||||||
|
r.index += n
|
||||||
|
return nil
|
||||||
|
} else if r.cached > 0 {
|
||||||
|
n -= r.cached
|
||||||
|
r.cached = 0
|
||||||
|
r.index = 0
|
||||||
|
}
|
||||||
|
err := r.readInternal()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Writer struct {
|
type Writer struct {
|
||||||
upstream io.Writer
|
upstream io.Writer
|
||||||
cipher cipher.AEAD
|
cipher cipher.AEAD
|
||||||
|
@ -197,10 +254,6 @@ func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = common.FlushVar(&w.upstream)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n += int64(readN)
|
n += int64(readN)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -227,6 +280,70 @@ func (w *Writer) Write(p []byte) (n int, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Writer) Buffer() *buf.Buffer {
|
||||||
|
return buf.With(w.buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Writer) BufferedWriter(reversed int) *BufferedWriter {
|
||||||
|
return &BufferedWriter{
|
||||||
|
upstream: w,
|
||||||
|
reversed: reversed,
|
||||||
|
data: w.buffer[PacketLengthBufferSize+w.cipher.Overhead() : len(w.buffer)-w.cipher.Overhead()],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type BufferedWriter struct {
|
||||||
|
upstream *Writer
|
||||||
|
data []byte
|
||||||
|
reversed int
|
||||||
|
index int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BufferedWriter) Upstream() io.Writer {
|
||||||
|
return w.upstream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BufferedWriter) Replaceable() bool {
|
||||||
|
return w.index == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||||
|
var index int
|
||||||
|
for {
|
||||||
|
cachedN := copy(w.data[w.reversed+w.index:], p[index:])
|
||||||
|
if cachedN == len(p[index:]) {
|
||||||
|
w.index += cachedN
|
||||||
|
return cachedN, nil
|
||||||
|
}
|
||||||
|
err = w.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
index += cachedN
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BufferedWriter) Flush() error {
|
||||||
|
if w.index == 0 {
|
||||||
|
if w.reversed > 0 {
|
||||||
|
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed])
|
||||||
|
w.reversed = 0
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
buffer := w.upstream.buffer[w.reversed:]
|
||||||
|
binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index))
|
||||||
|
w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil)
|
||||||
|
increaseNonce(w.upstream.nonce)
|
||||||
|
offset := w.upstream.cipher.Overhead() + PacketLengthBufferSize
|
||||||
|
packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil)
|
||||||
|
increaseNonce(w.upstream.nonce)
|
||||||
|
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)])
|
||||||
|
w.reversed = 0
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func increaseNonce(nonce []byte) {
|
func increaseNonce(nonce []byte) {
|
||||||
for i := range nonce {
|
for i := range nonce {
|
||||||
nonce[i]++
|
nonce[i]++
|
||||||
|
|
|
@ -189,56 +189,47 @@ type clientConn struct {
|
||||||
destination *M.AddrPort
|
destination *M.AddrPort
|
||||||
|
|
||||||
access sync.Mutex
|
access sync.Mutex
|
||||||
reader io.Reader
|
reader *Reader
|
||||||
writer io.Writer
|
writer *Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) writeRequest(payload []byte) error {
|
func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
_request := buf.StackNew()
|
_salt := make([]byte, c.method.keySaltLength)
|
||||||
request := common.Dup(_request)
|
salt := common.Dup(_salt)
|
||||||
|
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||||
|
|
||||||
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
key := Kdf(c.method.key, salt, c.method.keySaltLength)
|
||||||
|
writer := NewWriter(
|
||||||
var writer io.Writer = c.Conn
|
c.Conn,
|
||||||
writer = &buf.BufferedWriter{
|
c.method.constructor(common.Dup(key)),
|
||||||
Writer: writer,
|
|
||||||
Buffer: request,
|
|
||||||
}
|
|
||||||
writer = NewWriter(
|
|
||||||
writer,
|
|
||||||
c.method.constructor(Kdf(c.method.key, request.Bytes(), c.method.keySaltLength)),
|
|
||||||
MaxPacketSize,
|
MaxPacketSize,
|
||||||
)
|
)
|
||||||
|
header := writer.Buffer()
|
||||||
|
header.Write(salt)
|
||||||
|
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||||
|
|
||||||
if len(payload) > 0 {
|
if len(payload) > 0 {
|
||||||
_header := buf.StackNew()
|
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||||
header := common.Dup(_header)
|
|
||||||
|
|
||||||
writer = &buf.BufferedWriter{
|
|
||||||
Writer: writer,
|
|
||||||
Buffer: header,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = writer.Write(payload)
|
_, err = bufferedWriter.Write(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := common.FlushVar(&writer)
|
err := bufferedWriter.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.writer = writer
|
c.writer = writer
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -278,7 +269,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
if err = c.readResponse(); err != nil {
|
if err = c.readResponse(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return c.reader.(io.WriterTo).WriteTo(w)
|
return c.reader.WriteTo(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||||
|
@ -302,9 +293,9 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||||
|
|
||||||
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
if c.writer == nil {
|
if c.writer == nil {
|
||||||
panic("missing handshake")
|
return rw.ReadFrom0(c, r)
|
||||||
}
|
}
|
||||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
return c.writer.ReadFrom(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientPacketConn struct {
|
type clientPacketConn struct {
|
||||||
|
|
164
protocol/shadowsocks/shadowaead/service.go
Normal file
164
protocol/shadowsocks/shadowaead/service.go
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
package shadowaead
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/cipher"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
"github.com/sagernet/sing/common/replay"
|
||||||
|
"github.com/sagernet/sing/common/rw"
|
||||||
|
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||||
|
"github.com/sagernet/sing/protocol/socks"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
name string
|
||||||
|
keySaltLength int
|
||||||
|
constructor func(key []byte) cipher.AEAD
|
||||||
|
key []byte
|
||||||
|
secureRNG io.Reader
|
||||||
|
replayFilter replay.Filter
|
||||||
|
handler shadowsocks.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||||
|
s := &Service{
|
||||||
|
name: method,
|
||||||
|
secureRNG: secureRNG,
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
if replayFilter {
|
||||||
|
s.replayFilter = replay.NewBloomRing()
|
||||||
|
}
|
||||||
|
switch method {
|
||||||
|
case "aes-128-gcm":
|
||||||
|
s.keySaltLength = 16
|
||||||
|
s.constructor = newAESGCM
|
||||||
|
case "aes-192-gcm":
|
||||||
|
s.keySaltLength = 24
|
||||||
|
s.constructor = newAESGCM
|
||||||
|
case "aes-256-gcm":
|
||||||
|
s.keySaltLength = 32
|
||||||
|
s.constructor = newAESGCM
|
||||||
|
case "chacha20-ietf-poly1305":
|
||||||
|
s.keySaltLength = 32
|
||||||
|
s.constructor = func(key []byte) cipher.AEAD {
|
||||||
|
cipher, err := chacha20poly1305.New(key)
|
||||||
|
common.Must(err)
|
||||||
|
return cipher
|
||||||
|
}
|
||||||
|
case "xchacha20-ietf-poly1305":
|
||||||
|
s.keySaltLength = 32
|
||||||
|
s.constructor = func(key []byte) cipher.AEAD {
|
||||||
|
cipher, err := chacha20poly1305.NewX(key)
|
||||||
|
common.Must(err)
|
||||||
|
return cipher
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(key) == s.keySaltLength {
|
||||||
|
s.key = key
|
||||||
|
} else if len(key) > 0 {
|
||||||
|
return nil, ErrBadKey
|
||||||
|
} else if len(password) > 0 {
|
||||||
|
s.key = shadowsocks.Key(password, s.keySaltLength)
|
||||||
|
} else {
|
||||||
|
return nil, ErrMissingPassword
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||||
|
_salt := buf.Make(s.keySaltLength)
|
||||||
|
salt := common.Dup(_salt)
|
||||||
|
|
||||||
|
_, err := io.ReadFull(conn, salt)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read salt")
|
||||||
|
}
|
||||||
|
|
||||||
|
key := Kdf(s.key, salt, s.keySaltLength)
|
||||||
|
reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize)
|
||||||
|
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata.Protocol = "shadowsocks"
|
||||||
|
metadata.Destination = destination
|
||||||
|
|
||||||
|
return s.handler.NewConnection(&serverConn{
|
||||||
|
Service: s,
|
||||||
|
Conn: conn,
|
||||||
|
reader: reader,
|
||||||
|
}, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverConn struct {
|
||||||
|
*Service
|
||||||
|
net.Conn
|
||||||
|
access sync.Mutex
|
||||||
|
reader *Reader
|
||||||
|
writer *Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
||||||
|
_salt := buf.Make(c.keySaltLength)
|
||||||
|
salt := common.Dup(_salt)
|
||||||
|
common.Must1(io.ReadFull(c.secureRNG, salt))
|
||||||
|
|
||||||
|
key := Kdf(c.key, salt, c.keySaltLength)
|
||||||
|
writer := NewWriter(
|
||||||
|
c.Conn,
|
||||||
|
c.constructor(common.Dup(key)),
|
||||||
|
MaxPacketSize,
|
||||||
|
)
|
||||||
|
|
||||||
|
header := writer.Buffer()
|
||||||
|
header.Write(salt)
|
||||||
|
|
||||||
|
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||||
|
if len(payload) > 0 {
|
||||||
|
_, err = bufferedWriter.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bufferedWriter.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.writer = writer
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) Write(p []byte) (n int, err error) {
|
||||||
|
if c.writer != nil {
|
||||||
|
return c.writer.Write(p)
|
||||||
|
}
|
||||||
|
c.access.Lock()
|
||||||
|
if c.writer != nil {
|
||||||
|
c.access.Unlock()
|
||||||
|
return c.writer.Write(p)
|
||||||
|
}
|
||||||
|
defer c.access.Unlock()
|
||||||
|
return c.writeResponse(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
|
if c.writer != nil {
|
||||||
|
return rw.ReadFrom0(c, r)
|
||||||
|
}
|
||||||
|
return c.writer.ReadFrom(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
|
return c.reader.WriteTo(w)
|
||||||
|
}
|
|
@ -197,8 +197,8 @@ type clientConn struct {
|
||||||
|
|
||||||
requestSalt []byte
|
requestSalt []byte
|
||||||
|
|
||||||
reader io.Reader
|
reader *shadowaead.Reader
|
||||||
writer io.Writer
|
writer *shadowaead.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
|
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
|
||||||
|
@ -222,68 +222,56 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) writeRequest(payload []byte) error {
|
func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
_request := buf.StackNew()
|
|
||||||
request := common.Dup(_request)
|
|
||||||
|
|
||||||
salt := make([]byte, KeySaltSize)
|
salt := make([]byte, KeySaltSize)
|
||||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||||
common.Must1(request.Write(salt))
|
|
||||||
c.method.writeExtendedIdentityHeaders(request, salt)
|
|
||||||
|
|
||||||
var writer io.Writer
|
|
||||||
writer = &buf.BufferedWriter{
|
|
||||||
Writer: c.Conn,
|
|
||||||
Buffer: request,
|
|
||||||
}
|
|
||||||
|
|
||||||
key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength)
|
key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength)
|
||||||
writer = shadowaead.NewWriter(
|
writer := shadowaead.NewWriter(
|
||||||
writer,
|
c.Conn,
|
||||||
c.method.constructor(common.Dup(key)),
|
c.method.constructor(common.Dup(key)),
|
||||||
MaxPacketSize,
|
MaxPacketSize,
|
||||||
)
|
)
|
||||||
|
|
||||||
_header := buf.StackNew()
|
header := writer.Buffer()
|
||||||
header := common.Dup(_header)
|
header.Write(salt)
|
||||||
|
c.method.writeExtendedIdentityHeaders(header, salt)
|
||||||
|
|
||||||
writer = &buf.BufferedWriter{
|
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||||
Writer: writer,
|
|
||||||
Buffer: header,
|
|
||||||
}
|
|
||||||
|
|
||||||
common.Must(rw.WriteByte(writer, HeaderTypeClient))
|
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeClient))
|
||||||
common.Must(binary.Write(writer, binary.BigEndian, uint64(time.Now().Unix())))
|
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
|
||||||
|
|
||||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write destination")
|
return E.Cause(err, "write destination")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(payload) > 0 {
|
if len(payload) > 0 {
|
||||||
err = binary.Write(writer, binary.BigEndian, uint16(0))
|
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write padding length")
|
return E.Cause(err, "write padding length")
|
||||||
}
|
}
|
||||||
_, err = writer.Write(payload)
|
_, err = bufferedWriter.Write(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write payload")
|
return E.Cause(err, "write payload")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
pLen := rand.Intn(MaxPaddingLength + 1)
|
pLen := rand.Intn(MaxPaddingLength + 1)
|
||||||
err = binary.Write(writer, binary.BigEndian, uint16(pLen))
|
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(pLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write padding length")
|
return E.Cause(err, "write padding length")
|
||||||
}
|
}
|
||||||
_, err = io.CopyN(writer, c.method.secureRNG, int64(pLen))
|
_, err = io.CopyN(bufferedWriter, c.method.secureRNG, int64(pLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write padding")
|
return E.Cause(err, "write padding")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = common.FlushVar(&writer)
|
err = bufferedWriter.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "client handshake")
|
return E.Cause(err, "client handshake")
|
||||||
}
|
}
|
||||||
|
|
||||||
c.requestSalt = salt
|
c.requestSalt = salt
|
||||||
c.writer = writer
|
c.writer = writer
|
||||||
return nil
|
return nil
|
||||||
|
@ -363,7 +351,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
if err = c.readResponse(); err != nil {
|
if err = c.readResponse(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return c.reader.(io.WriterTo).WriteTo(w)
|
return c.reader.WriteTo(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||||
|
@ -389,10 +377,10 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||||
|
|
||||||
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
if c.writer == nil {
|
if c.writer == nil {
|
||||||
panic("missing client handshake")
|
return rw.ReadFrom0(c, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
return c.writer.ReadFrom(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientPacketConn struct {
|
type clientPacketConn struct {
|
||||||
|
@ -540,7 +528,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||||
c.session.lastFilter = c.session.filter
|
c.session.lastFilter = c.session.filter
|
||||||
c.session.lastRemoteSeen = time.Now().Unix()
|
c.session.lastRemoteSeen = time.Now().Unix()
|
||||||
c.session.lastRemoteCipher = c.session.remoteCipher
|
c.session.lastRemoteCipher = c.session.remoteCipher
|
||||||
c.session.filter = new(wgReplay.Filter)
|
c.session.filter = wgReplay.Filter{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.session.remoteSessionId = sessionId
|
c.session.remoteSessionId = sessionId
|
||||||
|
@ -577,8 +565,8 @@ type udpSession struct {
|
||||||
cipher cipher.AEAD
|
cipher cipher.AEAD
|
||||||
remoteCipher cipher.AEAD
|
remoteCipher cipher.AEAD
|
||||||
lastRemoteCipher cipher.AEAD
|
lastRemoteCipher cipher.AEAD
|
||||||
filter *wgReplay.Filter
|
filter wgReplay.Filter
|
||||||
lastFilter *wgReplay.Filter
|
lastFilter wgReplay.Filter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *udpSession) nextPacketId() uint64 {
|
func (s *udpSession) nextPacketId() uint64 {
|
||||||
|
@ -588,7 +576,6 @@ func (s *udpSession) nextPacketId() uint64 {
|
||||||
func (m *Method) newUDPSession() *udpSession {
|
func (m *Method) newUDPSession() *udpSession {
|
||||||
session := &udpSession{
|
session := &udpSession{
|
||||||
sessionId: rand.Uint64(),
|
sessionId: rand.Uint64(),
|
||||||
filter: new(wgReplay.Filter),
|
|
||||||
}
|
}
|
||||||
if m.udpCipher == nil {
|
if m.udpCipher == nil {
|
||||||
sessionId := make([]byte, 8)
|
sessionId := make([]byte, 8)
|
||||||
|
|
195
protocol/shadowsocks/shadowaead_2022/service.go
Normal file
195
protocol/shadowsocks/shadowaead_2022/service.go
Normal file
|
@ -0,0 +1,195 @@
|
||||||
|
package shadowaead_2022
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
"github.com/sagernet/sing/common/replay"
|
||||||
|
"github.com/sagernet/sing/common/rw"
|
||||||
|
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||||
|
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||||
|
"github.com/sagernet/sing/protocol/socks"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
name string
|
||||||
|
secureRNG io.Reader
|
||||||
|
keyLength int
|
||||||
|
constructor func(key []byte) cipher.AEAD
|
||||||
|
psk []byte
|
||||||
|
replayFilter replay.Filter
|
||||||
|
handler shadowsocks.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||||
|
s := &Service{
|
||||||
|
name: method,
|
||||||
|
psk: psk,
|
||||||
|
secureRNG: secureRNG,
|
||||||
|
replayFilter: replay.NewCuckoo(60),
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(psk) != KeySaltSize {
|
||||||
|
return nil, shadowaead.ErrBadKey
|
||||||
|
}
|
||||||
|
|
||||||
|
switch method {
|
||||||
|
case "2022-blake3-aes-128-gcm":
|
||||||
|
s.keyLength = 16
|
||||||
|
s.constructor = newAESGCM
|
||||||
|
// m.blockConstructor = newAES
|
||||||
|
// m.udpBlockCipher = newAES(m.psk)
|
||||||
|
case "2022-blake3-aes-256-gcm":
|
||||||
|
s.keyLength = 32
|
||||||
|
s.constructor = newAESGCM
|
||||||
|
// m.blockConstructor = newAES
|
||||||
|
// m.udpBlockCipher = newAES(m.psk)
|
||||||
|
case "2022-blake3-chacha20-poly1305":
|
||||||
|
s.keyLength = 32
|
||||||
|
s.constructor = newChacha20Poly1305
|
||||||
|
// m.udpCipher = newXChacha20Poly1305(m.psk)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||||
|
requestSalt := make([]byte, KeySaltSize)
|
||||||
|
_, err := io.ReadFull(conn, requestSalt)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read request salt")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.replayFilter.Check(requestSalt) {
|
||||||
|
return E.New("salt not unique")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestKey := Blake3DeriveKey(s.psk, requestSalt, s.keyLength)
|
||||||
|
reader := shadowaead.NewReader(
|
||||||
|
conn,
|
||||||
|
s.constructor(common.Dup(requestKey)),
|
||||||
|
MaxPacketSize,
|
||||||
|
)
|
||||||
|
|
||||||
|
headerType, err := rw.ReadByte(reader)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read header")
|
||||||
|
}
|
||||||
|
|
||||||
|
if headerType != HeaderTypeClient {
|
||||||
|
return ErrBadHeaderType
|
||||||
|
}
|
||||||
|
|
||||||
|
var epoch uint64
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &epoch)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read timestamp")
|
||||||
|
}
|
||||||
|
if math.Abs(float64(time.Now().Unix()-int64(epoch))) > 30 {
|
||||||
|
return ErrBadTimestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read destination")
|
||||||
|
}
|
||||||
|
|
||||||
|
var paddingLen uint16
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &paddingLen)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read padding length")
|
||||||
|
}
|
||||||
|
|
||||||
|
if paddingLen > 0 {
|
||||||
|
err = reader.Discard(int(paddingLen))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "discard padding")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata.Protocol = "shadowsocks"
|
||||||
|
metadata.Destination = destination
|
||||||
|
return s.handler.NewConnection(&serverConn{
|
||||||
|
Service: s,
|
||||||
|
Conn: conn,
|
||||||
|
reader: reader,
|
||||||
|
requestSalt: requestSalt,
|
||||||
|
}, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverConn struct {
|
||||||
|
*Service
|
||||||
|
net.Conn
|
||||||
|
access sync.Mutex
|
||||||
|
reader *shadowaead.Reader
|
||||||
|
writer *shadowaead.Writer
|
||||||
|
requestSalt []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
||||||
|
_salt := make([]byte, KeySaltSize)
|
||||||
|
salt := common.Dup(_salt)
|
||||||
|
common.Must1(io.ReadFull(c.secureRNG, salt))
|
||||||
|
key := Blake3DeriveKey(c.psk, salt, c.keyLength)
|
||||||
|
writer := shadowaead.NewWriter(
|
||||||
|
c.Conn,
|
||||||
|
c.constructor(common.Dup(key)),
|
||||||
|
MaxPacketSize,
|
||||||
|
)
|
||||||
|
header := writer.Buffer()
|
||||||
|
header.Write(salt)
|
||||||
|
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||||
|
|
||||||
|
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeServer))
|
||||||
|
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
|
||||||
|
common.Must1(bufferedWriter.Write(c.requestSalt))
|
||||||
|
c.requestSalt = nil
|
||||||
|
|
||||||
|
if len(payload) > 0 {
|
||||||
|
_, err = bufferedWriter.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bufferedWriter.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.writer = writer
|
||||||
|
n = len(payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) Write(p []byte) (n int, err error) {
|
||||||
|
if c.writer != nil {
|
||||||
|
return c.writer.Write(p)
|
||||||
|
}
|
||||||
|
c.access.Lock()
|
||||||
|
if c.writer != nil {
|
||||||
|
c.access.Unlock()
|
||||||
|
return c.writer.Write(p)
|
||||||
|
}
|
||||||
|
defer c.access.Unlock()
|
||||||
|
return c.writeResponse(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
|
if c.writer != nil {
|
||||||
|
return rw.ReadFrom0(c, r)
|
||||||
|
}
|
||||||
|
return c.writer.ReadFrom(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
|
return c.reader.WriteTo(w)
|
||||||
|
}
|
|
@ -2,13 +2,13 @@ package socks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/sagernet/sing/common/task"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
"github.com/sagernet/sing/common/task"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PacketConn interface {
|
type PacketConn interface {
|
||||||
|
@ -47,26 +47,32 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
|
||||||
|
|
||||||
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
|
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
|
||||||
return task.Run(ctx, func() error {
|
return task.Run(ctx, func() error {
|
||||||
_buffer := buf.StackNew()
|
_buffer := buf.StackNewMax()
|
||||||
buffer := common.Dup(_buffer)
|
buffer := common.Dup(_buffer)
|
||||||
|
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||||
for {
|
for {
|
||||||
destination, err := conn.ReadPacket(buffer)
|
data.FullReset()
|
||||||
|
destination, err := conn.ReadPacket(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
buffer.Truncate(data.Len())
|
||||||
err = dest.WritePacket(buffer, destination)
|
err = dest.WritePacket(buffer, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, func() error {
|
}, func() error {
|
||||||
_buffer := buf.StackNew()
|
_buffer := buf.StackNewMax()
|
||||||
buffer := common.Dup(_buffer)
|
buffer := common.Dup(_buffer)
|
||||||
|
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||||
for {
|
for {
|
||||||
destination, err := dest.ReadPacket(buffer)
|
data.FullReset()
|
||||||
|
destination, err := dest.ReadPacket(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
buffer.Truncate(data.Len())
|
||||||
err = conn.WritePacket(buffer, destination)
|
err = conn.WritePacket(buffer, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -125,7 +131,8 @@ func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
|
||||||
|
|
||||||
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||||
defer buffer.Release()
|
defer buffer.Release()
|
||||||
header := buf.New()
|
_header := buf.StackNew()
|
||||||
|
header := common.Dup(_header)
|
||||||
common.Must(header.WriteZeroN(3))
|
common.Must(header.WriteZeroN(3))
|
||||||
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
|
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
|
||||||
buffer = buffer.WriteBufferAtFirst(header)
|
buffer = buffer.WriteBufferAtFirst(header)
|
||||||
|
|
|
@ -83,7 +83,7 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, bind neti
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "read user auth request")
|
return E.Cause(err, "read user auth request")
|
||||||
}
|
}
|
||||||
response := new(UsernamePasswordAuthResponse)
|
response := &UsernamePasswordAuthResponse{}
|
||||||
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
|
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
|
||||||
response.Status = UsernamePasswordStatusSuccess
|
response.Status = UsernamePasswordStatusSuccess
|
||||||
} else {
|
} else {
|
||||||
|
@ -109,6 +109,7 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, bind neti
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write socks response")
|
return E.Cause(err, "write socks response")
|
||||||
}
|
}
|
||||||
|
metadata.Protocol = "socks"
|
||||||
metadata.Destination = request.Destination
|
metadata.Destination = request.Destination
|
||||||
return handler.NewConnection(conn, metadata)
|
return handler.NewConnection(conn, metadata)
|
||||||
case CommandUDPAssociate:
|
case CommandUDPAssociate:
|
||||||
|
|
|
@ -92,16 +92,22 @@ func (l *Listener) loop() {
|
||||||
}
|
}
|
||||||
switch l.trans {
|
switch l.trans {
|
||||||
case redir.ModeRedirect:
|
case redir.ModeRedirect:
|
||||||
metadata.Destination, _ = redir.GetOriginalDestination(tcpConn)
|
destination, err := redir.GetOriginalDestination(tcpConn)
|
||||||
|
if err == nil {
|
||||||
|
metadata.Protocol = "redirect"
|
||||||
|
metadata.Destination = destination
|
||||||
|
}
|
||||||
case redir.ModeTProxy:
|
case redir.ModeTProxy:
|
||||||
lAddr := tcpConn.LocalAddr().(*net.TCPAddr)
|
lAddr := tcpConn.LocalAddr().(*net.TCPAddr)
|
||||||
rAddr := tcpConn.RemoteAddr().(*net.TCPAddr)
|
rAddr := tcpConn.RemoteAddr().(*net.TCPAddr)
|
||||||
|
|
||||||
if lAddr.Port != l.lAddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() {
|
if lAddr.Port != l.lAddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() {
|
||||||
|
metadata.Protocol = "tproxy"
|
||||||
metadata.Destination = M.AddrPortFromNetAddr(lAddr)
|
metadata.Destination = M.AddrPortFromNetAddr(lAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
|
metadata.Protocol = "tcp"
|
||||||
hErr := l.handler.NewConnection(tcpConn, metadata)
|
hErr := l.handler.NewConnection(tcpConn, metadata)
|
||||||
if hErr != nil {
|
if hErr != nil {
|
||||||
l.handler.HandleError(&Error{Conn: tcpConn, Cause: hErr})
|
l.handler.HandleError(&Error{Conn: tcpConn, Cause: hErr})
|
||||||
|
|
|
@ -80,7 +80,8 @@ func (l *Listener) loop() {
|
||||||
}
|
}
|
||||||
buffer.Truncate(n)
|
buffer.Truncate(n)
|
||||||
err = l.handler.NewPacket(buffer, M.Metadata{
|
err = l.handler.NewPacket(buffer, M.Metadata{
|
||||||
Source: M.AddrPortFromNetAddr(addr),
|
Protocol: "udp",
|
||||||
|
Source: M.AddrPortFromNetAddr(addr),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
|
@ -104,6 +105,7 @@ func (l *Listener) loop() {
|
||||||
}
|
}
|
||||||
buffer.Truncate(n)
|
buffer.Truncate(n)
|
||||||
err = l.handler.NewPacket(buffer, M.Metadata{
|
err = l.handler.NewPacket(buffer, M.Metadata{
|
||||||
|
Protocol: "tproxy",
|
||||||
Source: M.AddrPortFromAddrPort(addr),
|
Source: M.AddrPortFromAddrPort(addr),
|
||||||
Destination: destination,
|
Destination: destination,
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue