Refine buffer

This commit is contained in:
世界 2022-04-28 07:08:50 +08:00
parent 31d4b88581
commit f16dd7a336
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
30 changed files with 993 additions and 209 deletions

View file

@ -62,7 +62,7 @@ func main() {
Short: "shadowsocks client",
Version: sing.VersionStr,
Run: func(cmd *cobra.Command, args []string) {
Run(cmd, f)
run(cmd, f)
},
}
@ -95,7 +95,7 @@ Only available with Linux kernel > 3.7.0.`)
}
}
type LocalClient struct {
type client struct {
*mixed.Listener
*geosite.Matcher
server *M.AddrPort
@ -104,7 +104,7 @@ type LocalClient struct {
bypass string
}
func NewLocalClient(f *flags) (*LocalClient, error) {
func newClient(f *flags) (*client, error) {
if f.ConfigFile != "" {
configFile, err := ioutil.ReadFile(f.ConfigFile)
if err != nil {
@ -159,13 +159,13 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
return nil, E.New("missing method")
}
client := &LocalClient{
c := &client{
server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort),
bypass: f.Bypass,
}
if f.Method == shadowsocks.MethodNone {
client.method = shadowsocks.NewNone()
c.method = shadowsocks.NewNone()
} else {
var pskList [][]byte
if f.Key != "" {
@ -183,7 +183,7 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
if f.UseSystemRNG {
rng = random.System
} else {
rng = random.Blake3KeyedHash()
rng = random.System
}
if f.ReducedSaltEntropy {
rng = &shadowsocks.ReducedEntropyReader{Reader: rng}
@ -200,17 +200,17 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
if err != nil {
return nil, err
}
client.method = method
c.method = method
} else if common.Contains(shadowaead_2022.List, f.Method) {
method, err := shadowaead_2022.New(f.Method, pskList, rng)
if err != nil {
return nil, err
}
client.method = method
c.method = method
}
}
client.dialer.Control = func(network, address string, c syscall.RawConn) error {
c.dialer.Control = func(network, address string, c syscall.RawConn) error {
var rawFd uintptr
err := c.Control(func(fd uintptr) {
rawFd = fd
@ -256,7 +256,7 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
bind = netip.IPv6Unspecified()
}
client.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, client)
c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, c)
if f.Bypass != "" {
err := geoip.LoadMMDB("Country.mmdb")
@ -278,11 +278,11 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
if err != nil {
return nil, err
}
client.Matcher = geositeMatcher
c.Matcher = geositeMatcher
}
debug.FreeOSMemory()
return client, nil
return c, nil
}
func bypass(conn net.Conn, destination *M.AddrPort) error {
@ -302,7 +302,7 @@ func bypass(conn net.Conn, destination *M.AddrPort) error {
})
}
func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
func (c *client) NewConnection(conn net.Conn, metadata M.Metadata) error {
if c.bypass != "" {
if metadata.Destination.Addr.Family().IsFqdn() {
if c.Match(metadata.Destination.Addr.Fqdn()) {
@ -315,7 +315,7 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
}
}
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
logrus.Info("outbound ", metadata.Protocol, " TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
ctx := context.Background()
serverConn, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
@ -339,7 +339,6 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
}
serverConn = c.method.DialEarlyConn(serverConn, metadata.Destination)
_, err = serverConn.Write(payload.Bytes())
payload.Release()
if err != nil {
return E.Cause(err, "client handshake")
}
@ -347,7 +346,7 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
return rw.CopyConn(ctx, serverConn, conn)
}
func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
func (c *client) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
ctx := context.Background()
udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String())
if err != nil {
@ -371,32 +370,28 @@ func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) e
})
}
func Run(cmd *cobra.Command, flags *flags) {
client, err := NewLocalClient(flags)
func run(cmd *cobra.Command, flags *flags) {
c, err := newClient(flags)
if err != nil {
logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n")
cmd.Help()
os.Exit(1)
}
err = client.Listener.Start()
err = c.Start()
if err != nil {
logrus.Fatal(err)
}
if err != nil {
logrus.Fatal(err)
}
logrus.Info("mixed server started at ", client.Listener.TCPListener.Addr())
logrus.Info("mixed server started at ", c.TCPListener.Addr())
osSignals := make(chan os.Signal, 1)
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
<-osSignals
client.Listener.Close()
c.Close()
}
func (c *LocalClient) HandleError(err error) {
func (c *client) HandleError(err error) {
common.Close(err)
if E.IsClosed(err) {
return

151
cli/ss-server/main.go Normal file
View file

@ -0,0 +1,151 @@
package main
import (
"context"
"encoding/base64"
"net"
"net/netip"
"os"
"os/signal"
"strings"
"syscall"
"github.com/sagernet/sing"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/random"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
"github.com/sagernet/sing/transport/tcp"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
type flags struct {
Bind string `json:"local_address"`
LocalPort uint16 `json:"local_port"`
// Password string `json:"password"`
Key string `json:"key"`
Method string `json:"method"`
Verbose bool `json:"verbose"`
ConfigFile string
}
func main() {
logrus.SetLevel(logrus.TraceLevel)
f := new(flags)
command := &cobra.Command{
Use: "ss-local",
Short: "shadowsocks client",
Version: sing.VersionStr,
Run: func(cmd *cobra.Command, args []string) {
run(cmd, f)
},
}
command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.")
command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.")
command.Flags().StringVarP(&f.Key, "key", "k", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
var supportedCiphers []string
supportedCiphers = append(supportedCiphers, shadowsocks.MethodNone)
supportedCiphers = append(supportedCiphers, shadowaead_2022.List...)
command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", "Set the cipher.\n\nSupported ciphers:\n\n"+strings.Join(supportedCiphers, "\n"))
command.Flags().StringVarP(&f.ConfigFile, "config", "c", "", "Use a configuration file.")
command.Flags().BoolVarP(&f.Verbose, "verbose", "v", true, "Enable verbose mode.")
err := command.Execute()
if err != nil {
logrus.Fatal(err)
}
}
func run(cmd *cobra.Command, f *flags) {
s, err := newServer(f)
if err != nil {
logrus.Fatal(err)
}
err = s.tcpIn.Start()
if err != nil {
logrus.Fatal(err)
}
logrus.Info("server started at ", s.tcpIn.TCPListener.Addr())
osSignals := make(chan os.Signal, 1)
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
<-osSignals
s.tcpIn.Close()
}
type server struct {
tcpIn *tcp.Listener
service shadowsocks.Service
}
func newServer(f *flags) (*server, error) {
s := new(server)
if f.Method == shadowsocks.MethodNone {
s.service = shadowsocks.NewNoneService(s)
} else if common.Contains(shadowaead_2022.List, f.Method) {
var pskList [][]byte
if f.Key != "" {
keyStrList := strings.Split(f.Key, ":")
pskList = make([][]byte, len(keyStrList))
for i, keyStr := range keyStrList {
key, err := base64.StdEncoding.DecodeString(keyStr)
if err != nil {
return nil, E.Cause(err, "decode key")
}
pskList[i] = key
}
}
rng := random.System
service, err := shadowaead_2022.NewService(f.Method, pskList[0], rng, s)
if err != nil {
return nil, err
}
s.service = service
} else {
return nil, E.New("unsupported method " + f.Method)
}
var bind netip.Addr
if f.Bind != "" {
addr, err := netip.ParseAddr(f.Bind)
if err != nil {
return nil, E.Cause(err, "bad local address")
}
bind = addr
} else {
bind = netip.IPv6Unspecified()
}
s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.LocalPort), s)
return s, nil
}
func (s *server) NewConnection(conn net.Conn, metadata M.Metadata) error {
if metadata.Protocol != "shadowsocks" {
return s.service.NewConnection(conn, metadata)
}
logrus.Info("inbound TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
if err != nil {
return err
}
return rw.CopyConn(context.Background(), conn, destConn)
}
func (s *server) HandleError(err error) {
if E.IsClosed(err) {
return
}
logrus.Warn(err)
}

View file

@ -31,6 +31,12 @@ func StackNew() *Buffer {
}
}
func StackNewMax() *Buffer {
return &Buffer{
data: make([]byte, 65535),
}
}
func StackNewSize(size int) *Buffer {
return &Buffer{
data: Make(size),
@ -291,6 +297,14 @@ func (b *Buffer) Release() {
*b = Buffer{}
}
func (b *Buffer) Cut(start int, end int) *Buffer {
b.start += start
b.end = len(b.data) - end
return &Buffer{
data: b.data[b.start:b.end],
}
}
func (b Buffer) Len() int {
return b.end - b.start
}

View file

@ -1,81 +0,0 @@
package buf
import (
"io"
"github.com/sagernet/sing/common"
)
type BufferedReader struct {
Reader io.Reader
Buffer *Buffer
}
func (r *BufferedReader) Upstream() io.Reader {
if r.Buffer != nil {
return nil
}
return r.Reader
}
func (r *BufferedReader) Replaceable() bool {
return r.Buffer == nil
}
func (r *BufferedReader) Read(p []byte) (n int, err error) {
if r.Buffer != nil {
n, err = r.Buffer.Read(p)
if r.Buffer.IsEmpty() {
r.Buffer.Release()
r.Buffer = nil
}
return
}
return r.Reader.Read(p)
}
type BufferedWriter struct {
Writer io.Writer
Buffer *Buffer
}
func (w *BufferedWriter) Upstream() io.Writer {
return w.Writer
}
func (w *BufferedWriter) Replaceable() bool {
return w.Buffer == nil
}
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
if w.Buffer == nil {
return w.Writer.Write(p)
}
n, err = w.Buffer.Write(p)
if err == nil {
return
}
return len(p), w.Flush()
}
func (w *BufferedWriter) Flush() error {
if w.Buffer == nil {
return nil
}
buffer := w.Buffer
w.Buffer = nil
defer buffer.Release()
if buffer.IsEmpty() {
return nil
}
return common.Error(w.Writer.Write(buffer.Bytes()))
}
func (w *BufferedWriter) Close() error {
buffer := w.Buffer
if buffer != nil {
w.Buffer = nil
buffer.Release()
}
return nil
}

View file

@ -29,11 +29,11 @@ func (r *readWriteConn) Close() error {
}
func (r *readWriteConn) LocalAddr() net.Addr {
return new(DummyAddr)
return &DummyAddr{}
}
func (r *readWriteConn) RemoteAddr() net.Addr {
return new(DummyAddr)
return &DummyAddr{}
}
func (r *readWriteConn) SetDeadline(t time.Time) error {
@ -53,7 +53,7 @@ type readConn struct {
}
func (r *readConn) Write(b []byte) (n int, err error) {
return 0, new(ReadOnlyException)
return 0, &ReadOnlyException{}
}
type writeConn struct {
@ -62,23 +62,23 @@ type writeConn struct {
}
func (w *writeConn) Read(p []byte) (n int, err error) {
return 0, new(WriteOnlyException)
return 0, &WriteOnlyException{}
}
func NewReadConn(reader io.Reader) net.Conn {
c := new(readConn)
c := &readConn{}
c.Reader = reader
return c
}
func NewWritConn(writer io.Writer) net.Conn {
c := new(writeConn)
c := &writeConn{}
c.Writer = writer
return c
}
func NewReadWriteConn(reader io.Reader, writer io.Writer) net.Conn {
c := new(readWriteConn)
c := &readConn{}
c.Reader = reader
c.Writer = writer
return c

View file

@ -52,8 +52,8 @@ func FlushVar(writerP *io.Writer) error {
writerBack = writer
*writerP = writer
continue
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter {
setter.SetWriter(writerBack)
} else if setter, hasSetter := writerBack.(UpstreamWriterSetter); hasSetter {
setter.SetWriter(u.Upstream())
writer = u.Upstream()
continue
}

View file

@ -2,6 +2,4 @@
package lowmem
func init() {
Enabled = true
}
const Enabled = true

View file

@ -4,8 +4,6 @@ import (
"runtime/debug"
)
var Enabled = false
func Free() {
if Enabled {
debug.FreeOSMemory()

5
common/lowmem/stub.go Normal file
View file

@ -0,0 +1,5 @@
//go:build !debug
package lowmem
const Enabled = false

View file

@ -7,6 +7,7 @@ import (
)
type Metadata struct {
Protocol string
Source *AddrPort
Destination *AddrPort
}

View file

@ -1,6 +1,16 @@
package common
import "syscall"
import (
"os"
"syscall"
)
func TryFileDescriptor(conn any) (uintptr, error) {
if rawConn, isRaw := conn.(syscall.Conn); isRaw {
return GetFileDescriptor(rawConn)
}
return 0, os.ErrInvalid
}
func GetFileDescriptor(conn syscall.Conn) (uintptr, error) {
rawConn, err := conn.SyscallConn()

View file

@ -3,10 +3,20 @@ package network
import (
"context"
"net"
M "github.com/sagernet/sing/common/metadata"
)
type ContextDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error)
}
var SystemDialer ContextDialer = &net.Dialer{}
var SystemDialer ContextDialer = &DefaultDialer{}
type DefaultDialer struct {
net.Dialer
}
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) {
return d.Dialer.DialContext(ctx, network, address.String())
}

View file

@ -3,9 +3,10 @@
package redir
import (
"net"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"net"
)
func TProxy(fd uintptr, isIPv6 bool) error {

133
common/rw/buffer.go Normal file
View file

@ -0,0 +1,133 @@
package rw
import (
"io"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
)
type BufferedWriter struct {
Writer io.Writer
Buffer *buf.Buffer
}
func (w *BufferedWriter) Upstream() io.Writer {
return w.Writer
}
func (w *BufferedWriter) Replaceable() bool {
return w.Buffer == nil
}
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
if w.Buffer == nil {
return w.Writer.Write(p)
}
n, err = w.Buffer.Write(p)
if n == len(p) {
return
}
fd, err := common.TryFileDescriptor(w.Writer)
if err == nil {
_, err = WriteV(fd, w.Buffer.Bytes(), p[n:])
if err != nil {
return
}
w.Buffer.Release()
w.Buffer = nil
return len(p), nil
}
_, err = w.Writer.Write(w.Buffer.Bytes())
if err != nil {
return
}
err = w.Flush()
if err != nil {
return
}
_, err = w.Writer.Write(p[n:])
if err != nil {
return
}
return len(p), nil
}
func (w *BufferedWriter) Flush() error {
if w.Buffer == nil {
return nil
}
if w.Buffer.IsEmpty() {
w.Buffer.Release()
w.Buffer = nil
return nil
}
_, err := w.Writer.Write(w.Buffer.Bytes())
if err != nil {
return err
}
w.Buffer.Release()
w.Buffer = nil
return nil
}
func (w *BufferedWriter) Close() error {
buffer := w.Buffer
if buffer != nil {
w.Buffer = nil
buffer.Release()
}
return nil
}
type HeaderWriter struct {
Writer io.Writer
Header *buf.Buffer
}
func (w *HeaderWriter) Upstream() io.Writer {
return w.Writer
}
func (w *HeaderWriter) Replaceable() bool {
return w.Header == nil
}
func (w *HeaderWriter) Write(p []byte) (n int, err error) {
if w.Header == nil {
return w.Writer.Write(p)
}
fd, err := common.TryFileDescriptor(w.Writer)
if err == nil {
_, err = WriteV(fd, w.Header.Bytes(), p)
if err != nil {
return
}
w.Header.Release()
w.Header = nil
return len(p), nil
}
cachedN, _ := w.Header.Write(p)
_, err = w.Writer.Write(w.Header.Bytes())
if err != nil {
return
}
w.Header.Release()
w.Header = nil
if cachedN < len(p) {
_, err = w.Writer.Write(p[cachedN:])
if err != nil {
return
}
}
return len(p), nil
}
func (w *HeaderWriter) Close() error {
buffer := w.Header
if buffer != nil {
w.Header = nil
buffer.Release()
}
return nil
}

View file

@ -4,6 +4,7 @@ import (
"io"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
)
func Skip(reader io.Reader) error {
@ -15,6 +16,9 @@ func SkipN(reader io.Reader, size int) error {
}
func ReadByte(reader io.Reader) (byte, error) {
if br, isBr := reader.(io.ByteReader); isBr {
return br.ReadByte()
}
var b [1]byte
if err := common.Error(io.ReadFull(reader, b[:])); err != nil {
return 0, err
@ -37,3 +41,33 @@ func ReadString(reader io.Reader, size int) (string, error) {
}
return string(b), nil
}
type ReaderFromWriter interface {
io.ReaderFrom
io.Writer
}
func ReadFrom0(readerFrom ReaderFromWriter, reader io.Reader) (n int64, err error) {
n, err = CopyOnce(readerFrom, reader)
if err != nil {
return
}
var rn int64
rn, err = readerFrom.ReadFrom(reader)
if err != nil {
return
}
n += rn
return
}
func CopyOnce(dest io.Writer, src io.Reader) (n int64, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
n, err = buffer.ReadFrom(src)
if err != nil {
return
}
_, err = dest.Write(buffer.Bytes())
return
}

View file

@ -12,7 +12,7 @@ type stubByteReader struct {
}
func (r stubByteReader) ReadByte() (byte, error) {
return ReadByte(r)
return ReadByte(r.Reader)
}
func ToByteReader(reader io.Reader) io.ByteReader {

View file

@ -19,7 +19,7 @@ func After(task func() error, after func() error) func() error {
func Run(ctx context.Context, tasks ...func() error) error {
ctx, cancel := context.WithCancel(ctx)
wg := new(sync.WaitGroup)
wg := &sync.WaitGroup{}
wg.Add(len(tasks))
var retErr error
for _, task := range tasks {

View file

@ -15,7 +15,7 @@ func TestServerConn(t *testing.T) {
serverConn := NewServerConn(udpConn)
defer serverConn.Close()
clientConn := NewClientConn(serverConn)
message := new(dnsmessage.Message)
message := &dnsmessage.Message{}
message.Header.ID = 1
message.Header.RecursionDesired = true
message.Questions = append(message.Questions, dnsmessage.Question{
@ -30,6 +30,7 @@ func TestServerConn(t *testing.T) {
Port: 53,
}))
_buffer := buf.StackNew()
common.Use(_buffer)
buffer := common.Dup(_buffer)
common.Must2(buffer.ReadPacketFrom(clientConn))
common.Must(message.Unpack(buffer.Bytes()))

View file

@ -95,6 +95,7 @@ func HandleRequest(request *http.Request, conn net.Conn, authenticator auth.Auth
left, right := net.Pipe()
go func() {
metadata.Destination = destination
metadata.Protocol = "http"
err = handler.NewConnection(right, metadata)
if err != nil {
handler.HandleError(&tcp.Error{Conn: right, Cause: err})

View file

@ -117,7 +117,8 @@ func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) {
}
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
return c.Conn.(io.WriterTo).WriteTo(w)
return io.Copy(w, c.Conn)
// return c.Conn.(io.WriterTo).WriteTo(w)
}
func (c *noneConn) RemoteAddr() net.Addr {
@ -138,7 +139,7 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
defer buffer.Release()
_header := buf.StackNew()
_header := buf.StackNewMax()
header := common.Dup(_header)
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
if err != nil {

View file

@ -0,0 +1,42 @@
package shadowsocks
import (
"net"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
)
type Service interface {
M.TCPConnectionHandler
}
type MultiUserService interface {
Service
AddUser(key []byte)
RemoveUser(key []byte)
}
type Handler interface {
M.TCPConnectionHandler
}
type NoneService struct {
handler Handler
}
func NewNoneService(handler Handler) Service {
return &NoneService{
handler: handler,
}
}
func (s *NoneService) NewConnection(conn net.Conn, metadata M.Metadata) error {
destination, err := socks.AddressSerializer.ReadAddrPort(conn)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(conn, metadata)
}

View file

@ -5,7 +5,6 @@ import (
"encoding/binary"
"io"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
)
@ -92,6 +91,46 @@ func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
}
}
func (r *Reader) readInternal() (err error) {
start := PacketLengthBufferSize + r.cipher.Overhead()
_, err = io.ReadFull(r.upstream, r.buffer[:start])
if err != nil {
return err
}
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
if err != nil {
return err
}
increaseNonce(r.nonce)
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
end := length + r.cipher.Overhead()
_, err = io.ReadFull(r.upstream, r.buffer[:end])
if err != nil {
return err
}
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
if err != nil {
return err
}
increaseNonce(r.nonce)
r.cached = length
r.index = 0
return nil
}
func (r *Reader) ReadByte() (byte, error) {
if r.cached == 0 {
err := r.readInternal()
if err != nil {
return 0, err
}
}
index := r.index
r.index++
r.cached--
return r.buffer[index], nil
}
func (r *Reader) Read(b []byte) (n int, err error) {
if r.cached > 0 {
n = copy(b, r.buffer[r.index:r.index+r.cached])
@ -141,6 +180,24 @@ func (r *Reader) Read(b []byte) (n int, err error) {
}
}
func (r *Reader) Discard(n int) error {
for {
if r.cached >= n {
r.cached -= n
r.index += n
return nil
} else if r.cached > 0 {
n -= r.cached
r.cached = 0
r.index = 0
}
err := r.readInternal()
if err != nil {
return err
}
}
}
type Writer struct {
upstream io.Writer
cipher cipher.AEAD
@ -197,10 +254,6 @@ func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
if err != nil {
return
}
err = common.FlushVar(&w.upstream)
if err != nil {
return
}
n += int64(readN)
}
}
@ -227,6 +280,70 @@ func (w *Writer) Write(p []byte) (n int, err error) {
return
}
func (w *Writer) Buffer() *buf.Buffer {
return buf.With(w.buffer)
}
func (w *Writer) BufferedWriter(reversed int) *BufferedWriter {
return &BufferedWriter{
upstream: w,
reversed: reversed,
data: w.buffer[PacketLengthBufferSize+w.cipher.Overhead() : len(w.buffer)-w.cipher.Overhead()],
}
}
type BufferedWriter struct {
upstream *Writer
data []byte
reversed int
index int
}
func (w *BufferedWriter) Upstream() io.Writer {
return w.upstream
}
func (w *BufferedWriter) Replaceable() bool {
return w.index == 0
}
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
var index int
for {
cachedN := copy(w.data[w.reversed+w.index:], p[index:])
if cachedN == len(p[index:]) {
w.index += cachedN
return cachedN, nil
}
err = w.Flush()
if err != nil {
return
}
index += cachedN
}
}
func (w *BufferedWriter) Flush() error {
if w.index == 0 {
if w.reversed > 0 {
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed])
w.reversed = 0
return err
}
return nil
}
buffer := w.upstream.buffer[w.reversed:]
binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index))
w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil)
increaseNonce(w.upstream.nonce)
offset := w.upstream.cipher.Overhead() + PacketLengthBufferSize
packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil)
increaseNonce(w.upstream.nonce)
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)])
w.reversed = 0
return err
}
func increaseNonce(nonce []byte) {
for i := range nonce {
nonce[i]++

View file

@ -189,56 +189,47 @@ type clientConn struct {
destination *M.AddrPort
access sync.Mutex
reader io.Reader
writer io.Writer
reader *Reader
writer *Writer
}
func (c *clientConn) writeRequest(payload []byte) error {
_request := buf.StackNew()
request := common.Dup(_request)
_salt := make([]byte, c.method.keySaltLength)
salt := common.Dup(_salt)
common.Must1(io.ReadFull(c.method.secureRNG, salt))
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
var writer io.Writer = c.Conn
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: request,
}
writer = NewWriter(
writer,
c.method.constructor(Kdf(c.method.key, request.Bytes(), c.method.keySaltLength)),
key := Kdf(c.method.key, salt, c.method.keySaltLength)
writer := NewWriter(
c.Conn,
c.method.constructor(common.Dup(key)),
MaxPacketSize,
)
header := writer.Buffer()
header.Write(salt)
bufferedWriter := writer.BufferedWriter(header.Len())
if len(payload) > 0 {
_header := buf.StackNew()
header := common.Dup(_header)
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: header,
}
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return err
}
_, err = writer.Write(payload)
_, err = bufferedWriter.Write(payload)
if err != nil {
return err
}
} else {
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return err
}
}
err := common.FlushVar(&writer)
err := bufferedWriter.Flush()
if err != nil {
return err
}
c.writer = writer
return nil
}
@ -278,7 +269,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.(io.WriterTo).WriteTo(w)
return c.reader.WriteTo(w)
}
func (c *clientConn) Write(p []byte) (n int, err error) {
@ -302,9 +293,9 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
panic("missing handshake")
return rw.ReadFrom0(c, r)
}
return c.writer.(io.ReaderFrom).ReadFrom(r)
return c.writer.ReadFrom(r)
}
type clientPacketConn struct {

View file

@ -0,0 +1,164 @@
package shadowaead
import (
"crypto/cipher"
"io"
"net"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
"golang.org/x/crypto/chacha20poly1305"
)
type Service struct {
name string
keySaltLength int
constructor func(key []byte) cipher.AEAD
key []byte
secureRNG io.Reader
replayFilter replay.Filter
handler shadowsocks.Handler
}
func NewService(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{
name: method,
secureRNG: secureRNG,
handler: handler,
}
if replayFilter {
s.replayFilter = replay.NewBloomRing()
}
switch method {
case "aes-128-gcm":
s.keySaltLength = 16
s.constructor = newAESGCM
case "aes-192-gcm":
s.keySaltLength = 24
s.constructor = newAESGCM
case "aes-256-gcm":
s.keySaltLength = 32
s.constructor = newAESGCM
case "chacha20-ietf-poly1305":
s.keySaltLength = 32
s.constructor = func(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.New(key)
common.Must(err)
return cipher
}
case "xchacha20-ietf-poly1305":
s.keySaltLength = 32
s.constructor = func(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.NewX(key)
common.Must(err)
return cipher
}
}
if len(key) == s.keySaltLength {
s.key = key
} else if len(key) > 0 {
return nil, ErrBadKey
} else if len(password) > 0 {
s.key = shadowsocks.Key(password, s.keySaltLength)
} else {
return nil, ErrMissingPassword
}
return s, nil
}
func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error {
_salt := buf.Make(s.keySaltLength)
salt := common.Dup(_salt)
_, err := io.ReadFull(conn, salt)
if err != nil {
return E.Cause(err, "read salt")
}
key := Kdf(s.key, salt, s.keySaltLength)
reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize)
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(&serverConn{
Service: s,
Conn: conn,
reader: reader,
}, metadata)
}
type serverConn struct {
*Service
net.Conn
access sync.Mutex
reader *Reader
writer *Writer
}
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
_salt := buf.Make(c.keySaltLength)
salt := common.Dup(_salt)
common.Must1(io.ReadFull(c.secureRNG, salt))
key := Kdf(c.key, salt, c.keySaltLength)
writer := NewWriter(
c.Conn,
c.constructor(common.Dup(key)),
MaxPacketSize,
)
header := writer.Buffer()
header.Write(salt)
bufferedWriter := writer.BufferedWriter(header.Len())
if len(payload) > 0 {
_, err = bufferedWriter.Write(payload)
if err != nil {
return
}
}
err = bufferedWriter.Flush()
if err != nil {
return
}
c.writer = writer
return
}
func (c *serverConn) Write(p []byte) (n int, err error) {
if c.writer != nil {
return c.writer.Write(p)
}
c.access.Lock()
if c.writer != nil {
c.access.Unlock()
return c.writer.Write(p)
}
defer c.access.Unlock()
return c.writeResponse(p)
}
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer != nil {
return rw.ReadFrom0(c, r)
}
return c.writer.ReadFrom(r)
}
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}

View file

@ -197,8 +197,8 @@ type clientConn struct {
requestSalt []byte
reader io.Reader
writer io.Writer
reader *shadowaead.Reader
writer *shadowaead.Writer
}
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
@ -222,68 +222,56 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
}
func (c *clientConn) writeRequest(payload []byte) error {
_request := buf.StackNew()
request := common.Dup(_request)
salt := make([]byte, KeySaltSize)
common.Must1(io.ReadFull(c.method.secureRNG, salt))
common.Must1(request.Write(salt))
c.method.writeExtendedIdentityHeaders(request, salt)
var writer io.Writer
writer = &buf.BufferedWriter{
Writer: c.Conn,
Buffer: request,
}
key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength)
writer = shadowaead.NewWriter(
writer,
writer := shadowaead.NewWriter(
c.Conn,
c.method.constructor(common.Dup(key)),
MaxPacketSize,
)
_header := buf.StackNew()
header := common.Dup(_header)
header := writer.Buffer()
header.Write(salt)
c.method.writeExtendedIdentityHeaders(header, salt)
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: header,
}
bufferedWriter := writer.BufferedWriter(header.Len())
common.Must(rw.WriteByte(writer, HeaderTypeClient))
common.Must(binary.Write(writer, binary.BigEndian, uint64(time.Now().Unix())))
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeClient))
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return E.Cause(err, "write destination")
}
if len(payload) > 0 {
err = binary.Write(writer, binary.BigEndian, uint16(0))
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(0))
if err != nil {
return E.Cause(err, "write padding length")
}
_, err = writer.Write(payload)
_, err = bufferedWriter.Write(payload)
if err != nil {
return E.Cause(err, "write payload")
}
} else {
pLen := rand.Intn(MaxPaddingLength + 1)
err = binary.Write(writer, binary.BigEndian, uint16(pLen))
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(pLen))
if err != nil {
return E.Cause(err, "write padding length")
}
_, err = io.CopyN(writer, c.method.secureRNG, int64(pLen))
_, err = io.CopyN(bufferedWriter, c.method.secureRNG, int64(pLen))
if err != nil {
return E.Cause(err, "write padding")
}
}
err = common.FlushVar(&writer)
err = bufferedWriter.Flush()
if err != nil {
return E.Cause(err, "client handshake")
}
c.requestSalt = salt
c.writer = writer
return nil
@ -363,7 +351,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.(io.WriterTo).WriteTo(w)
return c.reader.WriteTo(w)
}
func (c *clientConn) Write(p []byte) (n int, err error) {
@ -389,10 +377,10 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
panic("missing client handshake")
return rw.ReadFrom0(c, r)
}
return c.writer.(io.ReaderFrom).ReadFrom(r)
return c.writer.ReadFrom(r)
}
type clientPacketConn struct {
@ -540,7 +528,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
c.session.lastFilter = c.session.filter
c.session.lastRemoteSeen = time.Now().Unix()
c.session.lastRemoteCipher = c.session.remoteCipher
c.session.filter = new(wgReplay.Filter)
c.session.filter = wgReplay.Filter{}
}
}
c.session.remoteSessionId = sessionId
@ -577,8 +565,8 @@ type udpSession struct {
cipher cipher.AEAD
remoteCipher cipher.AEAD
lastRemoteCipher cipher.AEAD
filter *wgReplay.Filter
lastFilter *wgReplay.Filter
filter wgReplay.Filter
lastFilter wgReplay.Filter
}
func (s *udpSession) nextPacketId() uint64 {
@ -588,7 +576,6 @@ func (s *udpSession) nextPacketId() uint64 {
func (m *Method) newUDPSession() *udpSession {
session := &udpSession{
sessionId: rand.Uint64(),
filter: new(wgReplay.Filter),
}
if m.udpCipher == nil {
sessionId := make([]byte, 8)

View file

@ -0,0 +1,195 @@
package shadowaead_2022
import (
"crypto/cipher"
"encoding/binary"
"io"
"math"
"net"
"sync"
"time"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
)
type Service struct {
name string
secureRNG io.Reader
keyLength int
constructor func(key []byte) cipher.AEAD
psk []byte
replayFilter replay.Filter
handler shadowsocks.Handler
}
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{
name: method,
psk: psk,
secureRNG: secureRNG,
replayFilter: replay.NewCuckoo(60),
handler: handler,
}
if len(psk) != KeySaltSize {
return nil, shadowaead.ErrBadKey
}
switch method {
case "2022-blake3-aes-128-gcm":
s.keyLength = 16
s.constructor = newAESGCM
// m.blockConstructor = newAES
// m.udpBlockCipher = newAES(m.psk)
case "2022-blake3-aes-256-gcm":
s.keyLength = 32
s.constructor = newAESGCM
// m.blockConstructor = newAES
// m.udpBlockCipher = newAES(m.psk)
case "2022-blake3-chacha20-poly1305":
s.keyLength = 32
s.constructor = newChacha20Poly1305
// m.udpCipher = newXChacha20Poly1305(m.psk)
}
return s, nil
}
func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error {
requestSalt := make([]byte, KeySaltSize)
_, err := io.ReadFull(conn, requestSalt)
if err != nil {
return E.Cause(err, "read request salt")
}
if !s.replayFilter.Check(requestSalt) {
return E.New("salt not unique")
}
requestKey := Blake3DeriveKey(s.psk, requestSalt, s.keyLength)
reader := shadowaead.NewReader(
conn,
s.constructor(common.Dup(requestKey)),
MaxPacketSize,
)
headerType, err := rw.ReadByte(reader)
if err != nil {
return E.Cause(err, "read header")
}
if headerType != HeaderTypeClient {
return ErrBadHeaderType
}
var epoch uint64
err = binary.Read(reader, binary.BigEndian, &epoch)
if err != nil {
return E.Cause(err, "read timestamp")
}
if math.Abs(float64(time.Now().Unix()-int64(epoch))) > 30 {
return ErrBadTimestamp
}
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
if err != nil {
return E.Cause(err, "read destination")
}
var paddingLen uint16
err = binary.Read(reader, binary.BigEndian, &paddingLen)
if err != nil {
return E.Cause(err, "read padding length")
}
if paddingLen > 0 {
err = reader.Discard(int(paddingLen))
if err != nil {
return E.Cause(err, "discard padding")
}
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(&serverConn{
Service: s,
Conn: conn,
reader: reader,
requestSalt: requestSalt,
}, metadata)
}
type serverConn struct {
*Service
net.Conn
access sync.Mutex
reader *shadowaead.Reader
writer *shadowaead.Writer
requestSalt []byte
}
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
_salt := make([]byte, KeySaltSize)
salt := common.Dup(_salt)
common.Must1(io.ReadFull(c.secureRNG, salt))
key := Blake3DeriveKey(c.psk, salt, c.keyLength)
writer := shadowaead.NewWriter(
c.Conn,
c.constructor(common.Dup(key)),
MaxPacketSize,
)
header := writer.Buffer()
header.Write(salt)
bufferedWriter := writer.BufferedWriter(header.Len())
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeServer))
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
common.Must1(bufferedWriter.Write(c.requestSalt))
c.requestSalt = nil
if len(payload) > 0 {
_, err = bufferedWriter.Write(payload)
if err != nil {
return
}
}
err = bufferedWriter.Flush()
if err != nil {
return
}
c.writer = writer
n = len(payload)
return
}
func (c *serverConn) Write(p []byte) (n int, err error) {
if c.writer != nil {
return c.writer.Write(p)
}
c.access.Lock()
if c.writer != nil {
c.access.Unlock()
return c.writer.Write(p)
}
defer c.access.Unlock()
return c.writeResponse(p)
}
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer != nil {
return rw.ReadFrom0(c, r)
}
return c.writer.ReadFrom(r)
}
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}

View file

@ -2,13 +2,13 @@ package socks
import (
"context"
"github.com/sagernet/sing/common/task"
"net"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task"
)
type PacketConn interface {
@ -47,26 +47,32 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
return task.Run(ctx, func() error {
_buffer := buf.StackNew()
_buffer := buf.StackNewMax()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
for {
destination, err := conn.ReadPacket(buffer)
data.FullReset()
destination, err := conn.ReadPacket(data)
if err != nil {
return err
}
buffer.Truncate(data.Len())
err = dest.WritePacket(buffer, destination)
if err != nil {
return err
}
}
}, func() error {
_buffer := buf.StackNew()
_buffer := buf.StackNewMax()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
for {
destination, err := dest.ReadPacket(buffer)
data.FullReset()
destination, err := dest.ReadPacket(data)
if err != nil {
return err
}
buffer.Truncate(data.Len())
err = conn.WritePacket(buffer, destination)
if err != nil {
return err
@ -125,7 +131,8 @@ func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
_header := buf.StackNew()
header := common.Dup(_header)
common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
buffer = buffer.WriteBufferAtFirst(header)

View file

@ -83,7 +83,7 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, bind neti
if err != nil {
return E.Cause(err, "read user auth request")
}
response := new(UsernamePasswordAuthResponse)
response := &UsernamePasswordAuthResponse{}
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
response.Status = UsernamePasswordStatusSuccess
} else {
@ -109,6 +109,7 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, bind neti
if err != nil {
return E.Cause(err, "write socks response")
}
metadata.Protocol = "socks"
metadata.Destination = request.Destination
return handler.NewConnection(conn, metadata)
case CommandUDPAssociate:

View file

@ -92,16 +92,22 @@ func (l *Listener) loop() {
}
switch l.trans {
case redir.ModeRedirect:
metadata.Destination, _ = redir.GetOriginalDestination(tcpConn)
destination, err := redir.GetOriginalDestination(tcpConn)
if err == nil {
metadata.Protocol = "redirect"
metadata.Destination = destination
}
case redir.ModeTProxy:
lAddr := tcpConn.LocalAddr().(*net.TCPAddr)
rAddr := tcpConn.RemoteAddr().(*net.TCPAddr)
if lAddr.Port != l.lAddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() {
metadata.Protocol = "tproxy"
metadata.Destination = M.AddrPortFromNetAddr(lAddr)
}
}
go func() {
metadata.Protocol = "tcp"
hErr := l.handler.NewConnection(tcpConn, metadata)
if hErr != nil {
l.handler.HandleError(&Error{Conn: tcpConn, Cause: hErr})

View file

@ -80,7 +80,8 @@ func (l *Listener) loop() {
}
buffer.Truncate(n)
err = l.handler.NewPacket(buffer, M.Metadata{
Source: M.AddrPortFromNetAddr(addr),
Protocol: "udp",
Source: M.AddrPortFromNetAddr(addr),
})
if err != nil {
buffer.Release()
@ -104,6 +105,7 @@ func (l *Listener) loop() {
}
buffer.Truncate(n)
err = l.handler.NewPacket(buffer, M.Metadata{
Protocol: "tproxy",
Source: M.AddrPortFromAddrPort(addr),
Destination: destination,
})