From 48809b0a994924106ed089fc186e160c46be265f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 25 May 2022 14:00:04 +0800 Subject: [PATCH] Init commit --- .github/workflows/debug.yml | 40 ++ .gitignore | 2 + LICENSE | 14 + README.md | 3 + format.go | 6 + go.mod | 16 + go.sum | 17 + none.go | 241 +++++++++ shadowaead/aead.go | 421 +++++++++++++++ shadowaead/protocol.go | 361 +++++++++++++ shadowaead/service.go | 246 +++++++++ shadowaead_2022/protocol.go | 745 ++++++++++++++++++++++++++ shadowaead_2022/relay.go | 233 ++++++++ shadowaead_2022/service.go | 489 +++++++++++++++++ shadowaead_2022/service_multi.go | 365 +++++++++++++ shadowaead_2022/service_multi_test.go | 75 +++ shadowaead_2022/service_test.go | 49 ++ shadowimpl/fetcher.go | 24 + shadowsocks.go | 89 +++ shadowstream/protocol.go | 392 ++++++++++++++ 20 files changed, 3828 insertions(+) create mode 100644 .github/workflows/debug.yml create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 format.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 none.go create mode 100644 shadowaead/aead.go create mode 100644 shadowaead/protocol.go create mode 100644 shadowaead/service.go create mode 100644 shadowaead_2022/protocol.go create mode 100644 shadowaead_2022/relay.go create mode 100644 shadowaead_2022/service.go create mode 100644 shadowaead_2022/service_multi.go create mode 100644 shadowaead_2022/service_multi_test.go create mode 100644 shadowaead_2022/service_test.go create mode 100644 shadowimpl/fetcher.go create mode 100644 shadowsocks.go create mode 100644 shadowstream/protocol.go diff --git a/.github/workflows/debug.yml b/.github/workflows/debug.yml new file mode 100644 index 0000000..f82dc4c --- /dev/null +++ b/.github/workflows/debug.yml @@ -0,0 +1,40 @@ +name: Debug build + +on: + push: + branches: + - main + paths-ignore: + - '**.md' + - '.github/**' + - '!.github/workflows/debug.yml' + pull_request: + branches: + - main + +jobs: + build: + name: Debug build + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Get latest go version + id: version + run: | + echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g') + - name: Setup Go + uses: actions/setup-go@v2 + with: + go-version: ${{ steps.version.outputs.go_version }} + - name: Build and test + run: | + version=`git rev-parse HEAD` + mkdir build + pushd build + go mod init build + go get -v github.com/sagernet/sing-shadowsocks@$version + popd + go test -v ./... \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f7f8ac3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/.idea/ +/vendor/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3e3e29e --- /dev/null +++ b/LICENSE @@ -0,0 +1,14 @@ +Copyright (C) 2022 by nekohasekai + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e7f3d0e --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# sing-shadowsocks + +Lightweight and efficient shadowsocks implementation with sing. \ No newline at end of file diff --git a/format.go b/format.go new file mode 100644 index 0000000..ccdf305 --- /dev/null +++ b/format.go @@ -0,0 +1,6 @@ +package shadowsocks + +//go:generate go install -v mvdan.cc/gofumpt@latest +//go:generate go install -v github.com/daixiang0/gci@latest +//go:generate gofumpt -l -w . +//go:generate gci write . diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a3c06a3 --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module github.com/sagernet/sing-shadowsocks + +go 1.18 + +require ( + github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d + github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34 + golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898 + golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d + lukechampine.com/blake3 v1.1.7 +) + +require ( + github.com/klauspost/cpuid/v2 v2.0.12 // indirect + golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..45dd320 --- /dev/null +++ b/go.sum @@ -0,0 +1,17 @@ +github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d h1:CPqTNIigGweVPT4CYb+OO2E6XyRKFOmvTHwWRLgCAlE= +github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d/go.mod h1:QX5ZVULjAfZJux/W62Y91HvCh9hyW6enAwcrrv/sLj0= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE= +github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= +github.com/sagernet/sing v0.0.0-20220525062603-53c607b13ff2 h1:x7E53uloX7pU3rWOzb81IBCAmwMtE2u9x4ZJvJXaCnM= +github.com/sagernet/sing v0.0.0-20220525062603-53c607b13ff2/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg= +github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34 h1:/FfHfteLZo5mOtZbYOx/9ymDEYxlwBuM5iHO9reVe/E= +github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg= +golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898 h1:SLP7Q4Di66FONjDJbCYrCRrh97focO6sLogHO7/g8F0= +golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d h1:q4JksJ2n0fmbXC0Aj0eOs6E0AcPqnKglxWXWFqGD6x0= +golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d/go.mod h1:bVQfyl2sCM/QIIGHpWbFGfHPuDvqnCNkT6MQLTCjO/U= +lukechampine.com/blake3 v1.1.7 h1:GgRMhmdsuK8+ii6UZFDL8Nb+VyMwadAgcJyfYHxG6n0= +lukechampine.com/blake3 v1.1.7/go.mod h1:tkKEOtDkNtklkXtLNEOGNq5tcV90tJiA1vAA12R78LA= diff --git a/none.go b/none.go new file mode 100644 index 0000000..e286f1d --- /dev/null +++ b/none.go @@ -0,0 +1,241 @@ +package shadowsocks + +import ( + "context" + "io" + "net" + "net/netip" + "runtime" + "sync" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/udpnat" +) + +const MethodNone = "none" + +type NoneMethod struct{} + +func NewNone() Method { + return &NoneMethod{} +} + +func (m *NoneMethod) Name() string { + return MethodNone +} + +func (m *NoneMethod) KeyLength() int { + return 0 +} + +func (m *NoneMethod) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { + shadowsocksConn := &noneConn{ + Conn: conn, + handshake: true, + destination: destination, + } + return shadowsocksConn, shadowsocksConn.clientHandshake() +} + +func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { + return &noneConn{ + Conn: conn, + destination: destination, + } +} + +func (m *NoneMethod) DialPacketConn(conn net.Conn) N.NetPacketConn { + return &nonePacketConn{conn} +} + +type noneConn struct { + net.Conn + + access sync.Mutex + handshake bool + destination M.Socksaddr +} + +func (c *noneConn) clientHandshake() error { + err := M.SocksaddrSerializer.WriteAddrPort(c.Conn, c.destination) + if err != nil { + return err + } + c.handshake = true + return nil +} + +func (c *noneConn) Write(b []byte) (n int, err error) { + if c.handshake { + goto direct + } + + c.access.Lock() + defer c.access.Unlock() + + if c.handshake { + goto direct + } + + { + if len(b) == 0 { + return 0, c.clientHandshake() + } + + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) + + err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination) + if err != nil { + return + } + + bufN, _ := buffer.Write(b) + _, err = c.Conn.Write(buffer.Bytes()) + runtime.KeepAlive(_buffer) + if err != nil { + return + } + + if bufN < len(b) { + _, err = c.Conn.Write(b[bufN:]) + if err != nil { + return + } + } + + n = len(b) + } + +direct: + return c.Conn.Write(b) +} + +func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) { + if !c.handshake { + return rw.ReadFrom0(c, r) + } + return c.Conn.(io.ReaderFrom).ReadFrom(r) +} + +func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) { + return io.Copy(w, c.Conn) +} + +func (c *noneConn) RemoteAddr() net.Addr { + return c.destination.TCPAddr() +} + +type nonePacketConn struct { + net.Conn +} + +func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + _, err := buffer.ReadFrom(c) + if err != nil { + return M.Socksaddr{}, err + } + return M.SocksaddrSerializer.ReadAddrPort(buffer) +} + +func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + return common.Error(buffer.WriteTo(c)) +} + +func (c *nonePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + if err != nil { + return + } + buffer := buf.With(p[:n]) + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return + } + addr = destination.UDPAddr() + n = copy(p, buffer.Bytes()) + return +} + +func (c *nonePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + destination := M.SocksaddrFromNet(addr) + _buffer := buf.Make(M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) + defer runtime.KeepAlive(_buffer) + buffer := buf.With(common.Dup(_buffer)) + err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) + if err != nil { + return + } + _, err = buffer.Write(p) + if err != nil { + return + } + return len(p), nil +} + +type NoneService struct { + handler Handler + udp *udpnat.Service[netip.AddrPort] +} + +func NewNoneService(udpTimeout int64, handler Handler) Service { + s := &NoneService{ + handler: handler, + } + s.udp = udpnat.New[netip.AddrPort](udpTimeout, s) + return s +} + +func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) + if err != nil { + return err + } + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + return s.handler.NewConnection(ctx, conn, metadata) +} + +func (s *NoneService) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return err + } + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + s.udp.NewPacket(ctx, metadata.Source.AddrPort(), func() N.PacketWriter { + return &nonePacketWriter{conn, metadata.Source} + }, buffer, metadata) + return nil +} + +type nonePacketWriter struct { + N.PacketConn + sourceAddr M.Socksaddr +} + +func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + return s.PacketConn.WritePacket(buffer, s.sourceAddr) +} + +func (s *NoneService) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { + return s.handler.NewPacketConnection(ctx, conn, metadata) +} + +func (s *NoneService) HandleError(err error) { + s.handler.HandleError(err) +} diff --git a/shadowaead/aead.go b/shadowaead/aead.go new file mode 100644 index 0000000..658cf61 --- /dev/null +++ b/shadowaead/aead.go @@ -0,0 +1,421 @@ +package shadowaead + +import ( + "crypto/cipher" + "encoding/binary" + "io" + + "github.com/sagernet/sing/common/buf" +) + +// https://shadowsocks.org/en/wiki/AEAD-Ciphers.html +const ( + MaxPacketSize = 16*1024 - 1 + PacketLengthBufferSize = 2 +) + +const ( + // NonceSize + // crypto/cipher.gcmStandardNonceSize + // golang.org/x/crypto/chacha20poly1305.NonceSize + NonceSize = 12 + + // Overhead + // crypto/cipher.gcmTagSize + // golang.org/x/crypto/chacha20poly1305.Overhead + Overhead = 16 +) + +type Reader struct { + upstream io.Reader + cipher cipher.AEAD + buffer []byte + nonce []byte + index int + cached int +} + +func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reader { + return &Reader{ + upstream: upstream, + cipher: cipher, + buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2), + nonce: make([]byte, NonceSize), + } +} + +func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce []byte) *Reader { + return &Reader{ + upstream: upstream, + cipher: cipher, + buffer: buffer, + nonce: nonce, + } +} + +func (r *Reader) Upstream() any { + return r.upstream +} + +func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) { + if r.cached > 0 { + writeN, writeErr := writer.Write(r.buffer[r.index : r.index+r.cached]) + if writeErr != nil { + return int64(writeN), writeErr + } + n += int64(writeN) + } + for { + start := PacketLengthBufferSize + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:start]) + if err != nil { + return + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) + if err != nil { + return + } + increaseNonce(r.nonce) + length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) + end := length + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return + } + increaseNonce(r.nonce) + writeN, writeErr := writer.Write(r.buffer[:length]) + if writeErr != nil { + return int64(writeN), writeErr + } + n += int64(writeN) + } +} + +func (r *Reader) readInternal() (err error) { + start := PacketLengthBufferSize + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:start]) + if err != nil { + return err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) + end := length + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = length + r.index = 0 + return nil +} + +func (r *Reader) ReadByte() (byte, error) { + if r.cached == 0 { + err := r.readInternal() + if err != nil { + return 0, err + } + } + index := r.index + r.index++ + r.cached-- + return r.buffer[index], nil +} + +func (r *Reader) Read(b []byte) (n int, err error) { + if r.cached > 0 { + n = copy(b, r.buffer[r.index:r.index+r.cached]) + r.cached -= n + r.index += n + return + } + start := PacketLengthBufferSize + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:start]) + if err != nil { + return 0, err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) + if err != nil { + return 0, err + } + increaseNonce(r.nonce) + length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) + end := length + Overhead + + if len(b) >= end { + data := b[:end] + _, err = io.ReadFull(r.upstream, data) + if err != nil { + return 0, err + } + _, err = r.cipher.Open(b[:0], r.nonce, data, nil) + if err != nil { + return 0, err + } + increaseNonce(r.nonce) + return length, nil + } else { + _, err = io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return 0, err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return 0, err + } + increaseNonce(r.nonce) + n = copy(b, r.buffer[:length]) + r.cached = length - n + r.index = n + return + } +} + +func (r *Reader) Discard(n int) error { + for { + if r.cached >= n { + r.cached -= n + r.index += n + return nil + } else if r.cached > 0 { + n -= r.cached + r.cached = 0 + r.index = 0 + } + err := r.readInternal() + if err != nil { + return err + } + } +} + +func (r *Reader) Cached() int { + return r.cached +} + +func (r *Reader) CachedSlice() []byte { + return r.buffer[r.index : r.index+r.cached] +} + +func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error { + _, err := r.cipher.Open(r.buffer[:0], r.nonce, lengthChunk, nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) + end := length + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = length + r.index = 0 + return nil +} + +func (r *Reader) ReadWithLength(length uint16) error { + end := length + Overhead + _, err := io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = int(length) + r.index = 0 + return nil +} + +func (r *Reader) ReadChunk(chunk []byte) error { + bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = len(bb) + r.index = 0 + return nil +} + +type Writer struct { + upstream io.Writer + cipher cipher.AEAD + maxPacketSize int + buffer []byte + nonce []byte +} + +func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Writer { + return &Writer{ + upstream: upstream, + cipher: cipher, + buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2), + nonce: make([]byte, cipher.NonceSize()), + maxPacketSize: maxPacketSize, + } +} + +func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buffer []byte, nonce []byte) *Writer { + return &Writer{ + upstream: upstream, + cipher: cipher, + maxPacketSize: maxPacketSize, + buffer: buffer, + nonce: nonce, + } +} + +func (w *Writer) Upstream() any { + return w.upstream +} + +func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) { + for { + offset := Overhead + PacketLengthBufferSize + readN, readErr := r.Read(w.buffer[offset : offset+w.maxPacketSize]) + if readErr != nil { + return 0, readErr + } + binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(readN)) + w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) + increaseNonce(w.nonce) + packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, w.buffer[offset:offset+readN], nil) + increaseNonce(w.nonce) + _, err = w.upstream.Write(w.buffer[:offset+len(packet)]) + if err != nil { + return + } + n += int64(readN) + } +} + +func (w *Writer) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + + for pLen := len(p); pLen > 0; { + var data []byte + if pLen > w.maxPacketSize { + data = p[:w.maxPacketSize] + p = p[w.maxPacketSize:] + } else { + data = p + p = nil + } + binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data))) + w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) + increaseNonce(w.nonce) + offset := Overhead + PacketLengthBufferSize + packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil) + increaseNonce(w.nonce) + _, err = w.upstream.Write(w.buffer[:offset+len(packet)]) + if err != nil { + return + } + n += len(data) + } + + return +} + +func (w *Writer) Buffer() *buf.Buffer { + return buf.With(w.buffer) +} + +func (w *Writer) WriteChunk(buffer *buf.Buffer, chunk []byte) { + bb := w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, chunk, nil) + buffer.Extend(len(bb)) + increaseNonce(w.nonce) +} + +func (w *Writer) BufferedWriter(reversed int) *BufferedWriter { + return &BufferedWriter{ + upstream: w, + reversed: reversed, + data: w.buffer[PacketLengthBufferSize+Overhead : len(w.buffer)-Overhead], + } +} + +type BufferedWriter struct { + upstream *Writer + data []byte + reversed int + index int +} + +func (w *BufferedWriter) UpstreamWriter() io.Writer { + return w.upstream +} + +func (w *BufferedWriter) WriterReplaceable() bool { + return w.index == 0 +} + +func (w *BufferedWriter) Write(p []byte) (n int, err error) { + var index int + for { + cachedN := copy(w.data[w.reversed+w.index:], p[index:]) + if cachedN == len(p[index:]) { + w.index += cachedN + return cachedN, nil + } + err = w.Flush() + if err != nil { + return + } + index += cachedN + } +} + +func (w *BufferedWriter) Flush() error { + if w.index == 0 { + if w.reversed > 0 { + _, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed]) + w.reversed = 0 + return err + } + return nil + } + buffer := w.upstream.buffer[w.reversed:] + binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index)) + w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil) + increaseNonce(w.upstream.nonce) + offset := Overhead + PacketLengthBufferSize + packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil) + increaseNonce(w.upstream.nonce) + _, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)]) + w.reversed = 0 + return err +} + +func increaseNonce(nonce []byte) { + for i := range nonce { + nonce[i]++ + if nonce[i] != 0 { + return + } + } +} diff --git a/shadowaead/protocol.go b/shadowaead/protocol.go new file mode 100644 index 0000000..1271952 --- /dev/null +++ b/shadowaead/protocol.go @@ -0,0 +1,361 @@ +package shadowaead + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha1" + "io" + "net" + "runtime" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/hkdf" +) + +var List = []string{ + "aes-128-gcm", + "aes-192-gcm", + "aes-256-gcm", + "chacha20-ietf-poly1305", + "xchacha20-ietf-poly1305", +} + +func New(method string, key []byte, password string) (shadowsocks.Method, error) { + m := &Method{ + name: method, + } + switch method { + case "aes-128-gcm": + m.keySaltLength = 16 + m.constructor = newAESGCM + case "aes-192-gcm": + m.keySaltLength = 24 + m.constructor = newAESGCM + case "aes-256-gcm": + m.keySaltLength = 32 + m.constructor = newAESGCM + case "chacha20-ietf-poly1305": + m.keySaltLength = 32 + m.constructor = func(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.New(key) + common.Must(err) + return cipher + } + case "xchacha20-ietf-poly1305": + m.keySaltLength = 32 + m.constructor = func(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.NewX(key) + common.Must(err) + return cipher + } + } + if len(key) == m.keySaltLength { + m.key = key + } else if len(key) > 0 { + return nil, shadowsocks.ErrBadKey + } else if password == "" { + return nil, shadowsocks.ErrMissingPassword + } else { + m.key = shadowsocks.Key([]byte(password), m.keySaltLength) + } + return m, nil +} + +func Kdf(key, iv []byte, keyLength int) []byte { + info := []byte("ss-subkey") + subKey := buf.Make(keyLength) + kdf := hkdf.New(sha1.New, key, iv, common.Dup(info)) + runtime.KeepAlive(info) + common.Must1(io.ReadFull(kdf, common.Dup(subKey))) + return subKey +} + +func newAESGCM(key []byte) cipher.AEAD { + block, err := aes.NewCipher(key) + common.Must(err) + aead, err := cipher.NewGCM(block) + common.Must(err) + return aead +} + +type Method struct { + name string + keySaltLength int + constructor func(key []byte) cipher.AEAD + key []byte +} + +func (m *Method) Name() string { + return m.name +} + +func (m *Method) KeyLength() int { + return m.keySaltLength +} + +func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) { + _salt := buf.Make(m.keySaltLength) + defer runtime.KeepAlive(_salt) + salt := common.Dup(_salt) + _, err := io.ReadFull(upstream, salt) + if err != nil { + return nil, E.Cause(err, "read salt") + } + key := Kdf(m.key, salt, m.keySaltLength) + defer runtime.KeepAlive(key) + return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil +} + +func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) { + _salt := buf.Make(m.keySaltLength) + defer runtime.KeepAlive(_salt) + salt := common.Dup(_salt) + common.Must1(io.ReadFull(rand.Reader, salt)) + _, err := upstream.Write(salt) + if err != nil { + return nil, err + } + key := Kdf(m.key, salt, m.keySaltLength) + return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil +} + +func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { + shadowsocksConn := &clientConn{ + Conn: conn, + method: m, + destination: destination, + } + return shadowsocksConn, shadowsocksConn.writeRequest(nil) +} + +func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { + return &clientConn{ + Conn: conn, + method: m, + destination: destination, + } +} + +func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { + return &clientPacketConn{m, conn} +} + +func (m *Method) EncodePacket(buffer *buf.Buffer) error { + key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) + c := m.constructor(common.Dup(key)) + runtime.KeepAlive(key) + c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) + buffer.Extend(Overhead) + return nil +} + +func (m *Method) DecodePacket(buffer *buf.Buffer) error { + if buffer.Len() < m.keySaltLength { + return E.New("bad packet") + } + key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) + c := m.constructor(common.Dup(key)) + runtime.KeepAlive(key) + packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) + if err != nil { + return err + } + buffer.Advance(m.keySaltLength) + buffer.Truncate(len(packet)) + return nil +} + +type clientConn struct { + net.Conn + method *Method + destination M.Socksaddr + reader *Reader + writer *Writer +} + +func (c *clientConn) writeRequest(payload []byte) error { + _salt := make([]byte, c.method.keySaltLength) + salt := common.Dup(_salt) + common.Must1(io.ReadFull(rand.Reader, salt)) + + key := Kdf(c.method.key, salt, c.method.keySaltLength) + runtime.KeepAlive(_salt) + writer := NewWriter( + c.Conn, + c.method.constructor(common.Dup(key)), + MaxPacketSize, + ) + runtime.KeepAlive(key) + header := writer.Buffer() + header.Write(salt) + bufferedWriter := writer.BufferedWriter(header.Len()) + + if len(payload) > 0 { + err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination) + if err != nil { + return err + } + + _, err = bufferedWriter.Write(payload) + if err != nil { + return err + } + } else { + err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination) + if err != nil { + return err + } + } + + err := bufferedWriter.Flush() + if err != nil { + return err + } + + c.writer = writer + return nil +} + +func (c *clientConn) readResponse() error { + if c.reader != nil { + return nil + } + _salt := buf.Make(c.method.keySaltLength) + defer runtime.KeepAlive(_salt) + salt := common.Dup(_salt) + _, err := io.ReadFull(c.Conn, salt) + if err != nil { + return err + } + key := Kdf(c.method.key, salt, c.method.keySaltLength) + defer runtime.KeepAlive(key) + c.reader = NewReader( + c.Conn, + c.method.constructor(common.Dup(key)), + MaxPacketSize, + ) + return nil +} + +func (c *clientConn) Read(p []byte) (n int, err error) { + if err = c.readResponse(); err != nil { + return + } + return c.reader.Read(p) +} + +func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { + if err = c.readResponse(); err != nil { + return + } + return c.reader.WriteTo(w) +} + +func (c *clientConn) Write(p []byte) (n int, err error) { + if c.writer != nil { + return c.writer.Write(p) + } + + err = c.writeRequest(p) + if err != nil { + return + } + return len(p), nil +} + +func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { + if c.writer == nil { + return rw.ReadFrom0(c, r) + } + return c.writer.ReadFrom(r) +} + +func (c *clientConn) Upstream() any { + return c.Conn +} + +type clientPacketConn struct { + *Method + net.Conn +} + +func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination)) + common.Must1(io.ReadFull(rand.Reader, header[:c.keySaltLength])) + err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination) + if err != nil { + return err + } + err = c.EncodePacket(buffer) + if err != nil { + return err + } + return common.Error(c.Write(buffer.Bytes())) +} + +func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + n, err := c.Read(buffer.FreeBytes()) + if err != nil { + return M.Socksaddr{}, err + } + buffer.Truncate(n) + err = c.DecodePacket(buffer) + if err != nil { + return M.Socksaddr{}, err + } + return M.SocksaddrSerializer.ReadAddrPort(buffer) +} + +func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + if err != nil { + return + } + b := buf.With(p[:n]) + err = c.DecodePacket(b) + if err != nil { + return + } + destination, err := M.SocksaddrSerializer.ReadAddrPort(b) + if err != nil { + return + } + addr = destination.UDPAddr() + n = copy(p, b.Bytes()) + return +} + +func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr)) + if err != nil { + return + } + _, err = buffer.Write(p) + if err != nil { + return + } + err = c.EncodePacket(buffer) + if err != nil { + return + } + _, err = c.Write(buffer.Bytes()) + if err != nil { + return + } + return len(p), nil +} + +func (c *clientPacketConn) Upstream() any { + return c.Conn +} diff --git a/shadowaead/service.go b/shadowaead/service.go new file mode 100644 index 0000000..5708e2f --- /dev/null +++ b/shadowaead/service.go @@ -0,0 +1,246 @@ +package shadowaead + +import ( + "context" + "crypto/cipher" + "crypto/rand" + "io" + "net" + "net/netip" + "runtime" + "sync" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/udpnat" + "golang.org/x/crypto/chacha20poly1305" +) + +var ErrBadHeader = E.New("bad header") + +type Service struct { + name string + keySaltLength int + constructor func(key []byte) cipher.AEAD + key []byte + handler shadowsocks.Handler + udpNat *udpnat.Service[netip.AddrPort] +} + +func NewService(method string, key []byte, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { + s := &Service{ + name: method, + handler: handler, + udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), + } + switch method { + case "aes-128-gcm": + s.keySaltLength = 16 + s.constructor = newAESGCM + case "aes-192-gcm": + s.keySaltLength = 24 + s.constructor = newAESGCM + case "aes-256-gcm": + s.keySaltLength = 32 + s.constructor = newAESGCM + case "chacha20-ietf-poly1305": + s.keySaltLength = 32 + s.constructor = func(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.New(key) + common.Must(err) + return cipher + } + case "xchacha20-ietf-poly1305": + s.keySaltLength = 32 + s.constructor = func(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.NewX(key) + common.Must(err) + return cipher + } + } + if len(key) == s.keySaltLength { + s.key = key + } else if len(key) > 0 { + return nil, shadowsocks.ErrBadKey + } else if password != "" { + s.key = shadowsocks.Key([]byte(password), s.keySaltLength) + } else { + return nil, shadowsocks.ErrMissingPassword + } + return s, nil +} + +func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + err := s.newConnection(ctx, conn, metadata) + if err != nil { + err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + _header := buf.Make(s.keySaltLength + PacketLengthBufferSize + Overhead) + defer runtime.KeepAlive(_header) + header := common.Dup(_header) + + n, err := conn.Read(header) + if err != nil { + return E.Cause(err, "read header") + } else if n < len(header) { + return ErrBadHeader + } + + key := Kdf(s.key, header[:s.keySaltLength], s.keySaltLength) + reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize) + + err = reader.ReadWithLengthChunk(header[s.keySaltLength:]) + if err != nil { + return err + } + + destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) + if err != nil { + return err + } + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + + return s.handler.NewConnection(ctx, &serverConn{ + Service: s, + Conn: conn, + reader: reader, + }, metadata) +} + +type serverConn struct { + *Service + net.Conn + access sync.Mutex + reader *Reader + writer *Writer +} + +func (c *serverConn) writeResponse(payload []byte) (n int, err error) { + _salt := buf.Make(c.keySaltLength) + salt := common.Dup(_salt) + common.Must1(io.ReadFull(rand.Reader, salt)) + + key := Kdf(c.key, salt, c.keySaltLength) + runtime.KeepAlive(_salt) + + writer := NewWriter( + c.Conn, + c.constructor(common.Dup(key)), + MaxPacketSize, + ) + runtime.KeepAlive(key) + + header := writer.Buffer() + header.Write(salt) + + bufferedWriter := writer.BufferedWriter(header.Len()) + if len(payload) > 0 { + _, err = bufferedWriter.Write(payload) + if err != nil { + return + } + } + + err = bufferedWriter.Flush() + if err != nil { + return + } + + c.writer = writer + return +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + if c.writer != nil { + return c.writer.Write(p) + } + c.access.Lock() + if c.writer != nil { + c.access.Unlock() + return c.writer.Write(p) + } + defer c.access.Unlock() + return c.writeResponse(p) +} + +func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) { + if c.writer == nil { + return rw.ReadFrom0(c, r) + } + return c.writer.ReadFrom(r) +} + +func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { + return c.reader.WriteTo(w) +} + +func (c *serverConn) Upstream() any { + return c.Conn +} + +func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + err := s.newPacket(ctx, conn, buffer, metadata) + if err != nil { + err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + if buffer.Len() < s.keySaltLength { + return E.New("bad packet") + } + key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength) + c := s.constructor(common.Dup(key)) + runtime.KeepAlive(key) + packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil) + if err != nil { + return err + } + buffer.Advance(s.keySaltLength) + buffer.Truncate(len(packet)) + + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return err + } + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + s.udpNat.NewPacket(ctx, metadata.Source.AddrPort(), func() N.PacketWriter { + return &serverPacketWriter{s, conn, metadata.Source} + }, buffer, metadata) + return nil +} + +type serverPacketWriter struct { + *Service + N.PacketConn + source M.Socksaddr +} + +func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buffer.ExtendHeader(w.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination)) + common.Must1(io.ReadFull(rand.Reader, header[:w.keySaltLength])) + err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination) + if err != nil { + return err + } + key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength) + c := w.constructor(common.Dup(key)) + runtime.KeepAlive(key) + c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil) + buffer.Extend(Overhead) + return w.PacketConn.WritePacket(buffer, w.source) +} diff --git a/shadowaead_2022/protocol.go b/shadowaead_2022/protocol.go new file mode 100644 index 0000000..f990d0b --- /dev/null +++ b/shadowaead_2022/protocol.go @@ -0,0 +1,745 @@ +package shadowaead_2022 + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "io" + "math" + mRand "math/rand" + "net" + "os" + "runtime" + "strings" + "sync/atomic" + "time" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/replay" + "github.com/sagernet/sing/common/rw" + "golang.org/x/crypto/chacha20poly1305" + wgReplay "golang.zx2c4.com/wireguard/replay" + "lukechampine.com/blake3" +) + +const ( + HeaderTypeClient = 0 + HeaderTypeServer = 1 + MaxPaddingLength = 900 + PacketNonceSize = 24 + MaxPacketSize = 65535 + RequestHeaderFixedChunkLength = 1 + 8 + 2 +) + +var ( + ErrMissingPSK = E.New("missing psk") + ErrBadHeaderType = E.New("bad header type") + ErrBadTimestamp = E.New("bad timestamp") + ErrBadRequestSalt = E.New("bad request salt") + ErrBadClientSessionId = E.New("bad client session id") + ErrPacketIdNotUnique = E.New("packet id not unique") + ErrTooManyServerSessions = E.New("server session changed more than once during the last minute") +) + +var List = []string{ + "2022-blake3-aes-128-gcm", + "2022-blake3-aes-256-gcm", + "2022-blake3-chacha20-poly1305", +} + +func NewWithPassword(method string, password string) (shadowsocks.Method, error) { + var pskList [][]byte + if password == "" { + return nil, ErrMissingPSK + } + keyStrList := strings.Split(password, ":") + pskList = make([][]byte, len(keyStrList)) + for i, keyStr := range keyStrList { + kb, err := base64.StdEncoding.DecodeString(keyStr) + if err != nil { + return nil, E.Cause(err, "decode key") + } + pskList[i] = kb + } + return New(method, pskList) +} + +func New(method string, pskList [][]byte) (shadowsocks.Method, error) { + m := &Method{ + name: method, + replayFilter: replay.NewSimple(60 * time.Second), + } + + switch method { + case "2022-blake3-aes-128-gcm": + m.keySaltLength = 16 + m.constructor = newAESGCM + m.blockConstructor = newAES + case "2022-blake3-aes-256-gcm": + m.keySaltLength = 32 + m.constructor = newAESGCM + m.blockConstructor = newAES + case "2022-blake3-chacha20-poly1305": + if len(pskList) > 1 { + return nil, os.ErrInvalid + } + m.keySaltLength = 32 + m.constructor = newChacha20Poly1305 + } + + if len(pskList) == 0 { + return nil, ErrMissingPSK + } + + for i, psk := range pskList { + if len(psk) < m.keySaltLength { + return nil, shadowsocks.ErrBadKey + } else if len(psk) > m.keySaltLength { + pskList[i] = Key(psk, m.keySaltLength) + } + } + + if len(pskList) > 1 { + pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize) + for i, psk := range pskList { + if i == 0 { + continue + } + hash := blake3.Sum512(psk) + copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], hash[:aes.BlockSize]) + } + m.pskHash = pskHash + } + + switch method { + case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": + m.udpBlockCipher = newAES(pskList[0]) + case "2022-blake3-chacha20-poly1305": + m.udpCipher = newXChacha20Poly1305(pskList[0]) + } + + m.pskList = pskList + return m, nil +} + +func Key(key []byte, keyLength int) []byte { + psk := sha256.Sum256(key) + return psk[:keyLength] +} + +func SessionKey(psk []byte, salt []byte, keyLength int) []byte { + sessionKey := buf.Make(len(psk) + len(salt)) + copy(sessionKey, psk) + copy(sessionKey[len(psk):], salt) + outKey := buf.Make(keyLength) + blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey) + return outKey +} + +func newAES(key []byte) cipher.Block { + block, err := aes.NewCipher(key) + common.Must(err) + return block +} + +func newAESGCM(key []byte) cipher.AEAD { + block, err := aes.NewCipher(key) + common.Must(err) + aead, err := cipher.NewGCM(block) + common.Must(err) + return aead +} + +func newChacha20Poly1305(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.New(key) + common.Must(err) + return cipher +} + +func newXChacha20Poly1305(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.NewX(key) + common.Must(err) + return cipher +} + +type Method struct { + name string + keySaltLength int + constructor func(key []byte) cipher.AEAD + blockConstructor func(key []byte) cipher.Block + udpCipher cipher.AEAD + udpBlockCipher cipher.Block + pskList [][]byte + pskHash []byte + replayFilter replay.Filter +} + +func (m *Method) Name() string { + return m.name +} + +func (m *Method) KeyLength() int { + return m.keySaltLength +} + +func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { + shadowsocksConn := &clientConn{ + Method: m, + Conn: conn, + destination: destination, + } + return shadowsocksConn, shadowsocksConn.writeRequest(nil) +} + +func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { + return &clientConn{ + Method: m, + Conn: conn, + destination: destination, + } +} + +func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { + return &clientPacketConn{m, conn, m.newUDPSession()} +} + +type clientConn struct { + *Method + net.Conn + destination M.Socksaddr + requestSalt []byte + reader *shadowaead.Reader + writer *shadowaead.Writer +} + +func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) { + pskLen := len(m.pskList) + if pskLen < 2 { + return + } + for i, psk := range m.pskList { + keyMaterial := buf.Make(m.keySaltLength * 2) + copy(keyMaterial, psk) + copy(keyMaterial[m.keySaltLength:], salt) + _identitySubkey := buf.Make(m.keySaltLength) + identitySubkey := common.Dup(_identitySubkey) + blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) + + pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] + + header := request.Extend(16) + m.blockConstructor(identitySubkey).Encrypt(header, pskHash) + runtime.KeepAlive(_identitySubkey) + if i == pskLen-2 { + break + } + } +} + +func (c *clientConn) writeRequest(payload []byte) error { + salt := buf.Make(c.keySaltLength) + common.Must1(io.ReadFull(rand.Reader, salt)) + + key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength) + writer := shadowaead.NewWriter( + c.Conn, + c.constructor(common.Dup(key)), + MaxPacketSize, + ) + runtime.KeepAlive(key) + + header := writer.Buffer() + header.Write(salt) + c.writeExtendedIdentityHeaders(header, salt) + + var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte + fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:])) + common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient)) + common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix()))) + var paddingLen int + if len(payload) == 0 { + paddingLen = mRand.Intn(MaxPaddingLength + 1) + } + variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen + len(payload) + common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen))) + writer.WriteChunk(header, fixedLengthBuffer.Slice()) + runtime.KeepAlive(_fixedLengthBuffer) + + _variableLengthBuffer := buf.Make(variableLengthHeaderLen) + variableLengthBuffer := buf.With(common.Dup(_variableLengthBuffer)) + common.Must(M.SocksaddrSerializer.WriteAddrPort(variableLengthBuffer, c.destination)) + common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen))) + if paddingLen > 0 { + variableLengthBuffer.Extend(paddingLen) + } else { + common.Must1(variableLengthBuffer.Write(payload)) + } + writer.WriteChunk(header, variableLengthBuffer.Slice()) + runtime.KeepAlive(_variableLengthBuffer) + + err := writer.BufferedWriter(header.Len()).Flush() + if err != nil { + return E.Cause(err, "client handshake") + } + + c.requestSalt = salt + c.writer = writer + return nil +} + +func (c *clientConn) readResponse() error { + if c.reader != nil { + return nil + } + + _salt := buf.Make(c.keySaltLength) + salt := common.Dup(_salt) + _, err := io.ReadFull(c.Conn, salt) + if err != nil { + return err + } + + if !c.replayFilter.Check(salt) { + return ErrSaltNotUnique + } + + key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength) + runtime.KeepAlive(_salt) + reader := shadowaead.NewReader( + c.Conn, + c.constructor(common.Dup(key)), + MaxPacketSize, + ) + runtime.KeepAlive(key) + + err = reader.ReadWithLength(uint16(1 + 8 + c.keySaltLength + 2)) + if err != nil { + return E.Cause(err, "read response fixed length chunk") + } + + headerType, err := rw.ReadByte(reader) + if err != nil { + return err + } + if headerType != HeaderTypeServer { + return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) + } + + var epoch uint64 + err = binary.Read(reader, binary.BigEndian, &epoch) + if err != nil { + return err + } + + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + } + + _requestSalt := buf.Make(c.keySaltLength) + requestSalt := common.Dup(_requestSalt) + _, err = io.ReadFull(reader, requestSalt) + if err != nil { + return err + } + + if bytes.Compare(requestSalt, c.requestSalt) > 0 { + return ErrBadRequestSalt + } + runtime.KeepAlive(_requestSalt) + + var length uint16 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return err + } + + err = reader.ReadWithLength(length) + if err != nil { + return err + } + + c.requestSalt = nil + c.reader = reader + + return nil +} + +func (c *clientConn) Read(p []byte) (n int, err error) { + if err = c.readResponse(); err != nil { + return + } + return c.reader.Read(p) +} + +func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { + if err = c.readResponse(); err != nil { + return + } + return c.reader.WriteTo(w) +} + +func (c *clientConn) Write(p []byte) (n int, err error) { + if c.writer == nil { + err = c.writeRequest(p) + if err == nil { + n = len(p) + } + return + } + return c.writer.Write(p) +} + +func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { + if c.writer == nil { + return rw.ReadFrom0(c, r) + } + return c.writer.ReadFrom(r) +} + +func (c *clientConn) Upstream() any { + return c.Conn +} + +type clientPacketConn struct { + *Method + net.Conn + session *udpSession +} + +func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + var hdrLen int + if c.udpCipher != nil { + hdrLen = PacketNonceSize + } + hdrLen += 16 // packet header + pskLen := len(c.pskList) + if c.udpCipher == nil && pskLen > 1 { + hdrLen += (pskLen - 1) * aes.BlockSize + } + hdrLen += 1 // header type + hdrLen += 8 // timestamp + hdrLen += 2 // padding length + hdrLen += M.SocksaddrSerializer.AddrPortLen(destination) + header := buf.With(buffer.ExtendHeader(hdrLen)) + + var dataIndex int + if c.udpCipher != nil { + common.Must1(header.ReadFullFrom(c.session.rng, PacketNonceSize)) + if pskLen > 1 { + panic("unsupported chacha extended header") + } + dataIndex = PacketNonceSize + } else { + dataIndex = aes.BlockSize + } + + common.Must( + binary.Write(header, binary.BigEndian, c.session.sessionId), + binary.Write(header, binary.BigEndian, c.session.nextPacketId()), + ) + + if c.udpCipher == nil && pskLen > 1 { + for i, psk := range c.pskList { + dataIndex += aes.BlockSize + pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] + + identityHeader := header.Extend(aes.BlockSize) + for textI := 0; textI < aes.BlockSize; textI++ { + identityHeader[textI] = pskHash[textI] ^ header.Byte(textI) + } + c.blockConstructor(psk).Encrypt(identityHeader, identityHeader) + + if i == pskLen-2 { + break + } + } + } + common.Must( + header.WriteByte(HeaderTypeClient), + binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())), + binary.Write(header, binary.BigEndian, uint16(0)), // padding length + ) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + if c.udpCipher != nil { + c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + } else { + packetHeader := buffer.To(aes.BlockSize) + c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + c.udpBlockCipher.Encrypt(packetHeader, packetHeader) + } + return common.Error(c.Write(buffer.Bytes())) +} + +func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + n, err := c.Read(buffer.FreeBytes()) + if err != nil { + return M.Socksaddr{}, err + } + buffer.Truncate(n) + + var packetHeader []byte + if c.udpCipher != nil { + _, err = c.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) + if err != nil { + return M.Socksaddr{}, E.Cause(err, "decrypt packet") + } + buffer.Advance(PacketNonceSize) + buffer.Truncate(buffer.Len() - shadowaead.Overhead) + } else { + packetHeader = buffer.To(aes.BlockSize) + c.udpBlockCipher.Decrypt(packetHeader, packetHeader) + } + + var sessionId, packetId uint64 + err = binary.Read(buffer, binary.BigEndian, &sessionId) + if err != nil { + return M.Socksaddr{}, err + } + err = binary.Read(buffer, binary.BigEndian, &packetId) + if err != nil { + return M.Socksaddr{}, err + } + + var remoteCipher cipher.AEAD + if packetHeader != nil { + if sessionId == c.session.remoteSessionId { + remoteCipher = c.session.remoteCipher + } else if sessionId == c.session.lastRemoteSessionId { + remoteCipher = c.session.lastRemoteCipher + } else { + key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength) + remoteCipher = c.constructor(common.Dup(key)) + runtime.KeepAlive(key) + } + _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) + if err != nil { + return M.Socksaddr{}, E.Cause(err, "decrypt packet") + } + buffer.Truncate(buffer.Len() - shadowaead.Overhead) + } + + var headerType byte + headerType, err = buffer.ReadByte() + if err != nil { + return M.Socksaddr{}, err + } + if headerType != HeaderTypeServer { + return M.Socksaddr{}, E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) + } + + var epoch uint64 + err = binary.Read(buffer, binary.BigEndian, &epoch) + if err != nil { + return M.Socksaddr{}, err + } + + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + } + + if sessionId == c.session.remoteSessionId { + if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) { + return M.Socksaddr{}, ErrPacketIdNotUnique + } + } else if sessionId == c.session.lastRemoteSessionId { + if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) { + return M.Socksaddr{}, ErrPacketIdNotUnique + } + remoteCipher = c.session.lastRemoteCipher + c.session.lastRemoteSeen = time.Now().Unix() + } else { + if c.session.remoteSessionId != 0 { + if time.Now().Unix()-c.session.lastRemoteSeen < 60 { + return M.Socksaddr{}, ErrTooManyServerSessions + } else { + c.session.lastRemoteSessionId = c.session.remoteSessionId + c.session.lastFilter = c.session.filter + c.session.lastRemoteSeen = time.Now().Unix() + c.session.lastRemoteCipher = c.session.remoteCipher + c.session.filter = wgReplay.Filter{} + } + } + c.session.remoteSessionId = sessionId + c.session.remoteCipher = remoteCipher + c.session.filter.ValidateCounter(packetId, math.MaxUint64) + } + + var clientSessionId uint64 + err = binary.Read(buffer, binary.BigEndian, &clientSessionId) + if err != nil { + return M.Socksaddr{}, err + } + + if clientSessionId != c.session.sessionId { + return M.Socksaddr{}, ErrBadClientSessionId + } + + var paddingLength uint16 + err = binary.Read(buffer, binary.BigEndian, &paddingLength) + if err != nil { + return M.Socksaddr{}, E.Cause(err, "read padding length") + } + buffer.Advance(int(paddingLength)) + + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return M.Socksaddr{}, err + } + return destination, nil +} + +func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + buffer := buf.With(p) + destination, err := c.ReadPacket(buffer) + if err != nil { + return + } + addr = destination.UDPAddr() + n = copy(p, buffer.Bytes()) + return +} + +func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + destination := M.SocksaddrFromNet(addr) + var overHead int + if c.udpCipher != nil { + overHead = PacketNonceSize + shadowaead.Overhead + } else { + overHead = shadowaead.Overhead + } + overHead += 16 // packet header + pskLen := len(c.pskList) + if c.udpCipher == nil && pskLen > 1 { + overHead += (pskLen - 1) * aes.BlockSize + } + overHead += 1 // header type + overHead += 8 // timestamp + overHead += 2 // padding length + overHead += M.SocksaddrSerializer.AddrPortLen(destination) + + _buffer := buf.Make(overHead + len(p)) + defer runtime.KeepAlive(_buffer) + buffer := buf.With(common.Dup(_buffer)) + + var dataIndex int + if c.udpCipher != nil { + common.Must1(buffer.ReadFullFrom(c.session.rng, PacketNonceSize)) + if pskLen > 1 { + panic("unsupported chacha extended header") + } + dataIndex = PacketNonceSize + } else { + dataIndex = aes.BlockSize + } + + common.Must( + binary.Write(buffer, binary.BigEndian, c.session.sessionId), + binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()), + ) + + if c.udpCipher == nil && pskLen > 1 { + for i, psk := range c.pskList { + dataIndex += aes.BlockSize + pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] + + identityHeader := buffer.Extend(aes.BlockSize) + for textI := 0; textI < aes.BlockSize; textI++ { + identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI) + } + c.blockConstructor(psk).Encrypt(identityHeader, identityHeader) + + if i == pskLen-2 { + break + } + } + } + common.Must( + buffer.WriteByte(HeaderTypeClient), + binary.Write(buffer, binary.BigEndian, uint64(time.Now().Unix())), + binary.Write(buffer, binary.BigEndian, uint16(0)), // padding length + ) + err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) + if err != nil { + return + } + if c.udpCipher != nil { + c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + } else { + packetHeader := buffer.To(aes.BlockSize) + c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + c.udpBlockCipher.Encrypt(packetHeader, packetHeader) + } + err = common.Error(c.Write(buffer.Bytes())) + if err != nil { + return + } + return len(p), nil +} + +type udpSession struct { + headerType byte + sessionId uint64 + packetId uint64 + remoteSessionId uint64 + lastRemoteSessionId uint64 + lastRemoteSeen int64 + cipher cipher.AEAD + remoteCipher cipher.AEAD + lastRemoteCipher cipher.AEAD + filter wgReplay.Filter + lastFilter wgReplay.Filter + rng io.Reader +} + +func (s *udpSession) nextPacketId() uint64 { + return atomic.AddUint64(&s.packetId, 1) +} + +func (m *Method) newUDPSession() *udpSession { + session := &udpSession{} + if m.udpCipher != nil { + session.rng = Blake3KeyedHash(rand.Reader) + common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) + } else { + common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) + } + session.packetId-- + if m.udpCipher == nil { + sessionId := make([]byte, 8) + binary.BigEndian.PutUint64(sessionId, session.sessionId) + key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength) + session.cipher = m.constructor(common.Dup(key)) + runtime.KeepAlive(key) + } + return session +} + +func (c *clientPacketConn) Upstream() any { + return c.Conn +} + +func Blake3KeyedHash(reader io.Reader) io.Reader { + key := make([]byte, 32) + common.Must1(io.ReadFull(reader, key)) + h := blake3.New(1024, key) + return h.XOF() +} diff --git a/shadowaead_2022/relay.go b/shadowaead_2022/relay.go new file mode 100644 index 0000000..61da950 --- /dev/null +++ b/shadowaead_2022/relay.go @@ -0,0 +1,233 @@ +package shadowaead_2022 + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "encoding/binary" + "io" + "net" + "os" + "runtime" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/cache" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/udpnat" + "lukechampine.com/blake3" +) + +type Relay[U comparable] struct { + name string + secureRNG io.Reader + keySaltLength int + handler shadowsocks.Handler + + constructor func(key []byte) cipher.AEAD + blockConstructor func(key []byte) cipher.Block + udpBlockCipher cipher.Block + + iPSK []byte + uPSKHash map[U][aes.BlockSize]byte + uPSKHashR map[[aes.BlockSize]byte]U + uDestination map[U]M.Socksaddr + uCipher map[U]cipher.Block + udpNat *udpnat.Service[uint64] + udpSessions *cache.LruCache[uint64, *relayUDPSession] +} + +func (s *Relay[U]) AddUser(user U, key []byte, destination M.Socksaddr) error { + if len(key) < s.keySaltLength { + return shadowsocks.ErrBadKey + } else if len(key) > s.keySaltLength { + key = Key(key, s.keySaltLength) + } + + var uPSKHash [aes.BlockSize]byte + hash512 := blake3.Sum512(key) + copy(uPSKHash[:], hash512[:]) + + if oldHash, loaded := s.uPSKHash[user]; loaded { + delete(s.uPSKHashR, oldHash) + } + + s.uPSKHash[user] = uPSKHash + s.uPSKHashR[uPSKHash] = user + s.uDestination[user] = destination + s.uCipher[user] = s.blockConstructor(key) + + return nil +} + +func (s *Relay[U]) RemoveUser(user U) { + if hash, loaded := s.uPSKHash[user]; loaded { + delete(s.uPSKHashR, hash) + } + delete(s.uPSKHash, user) + delete(s.uCipher, user) +} + +func NewRelay[U comparable](method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (*Relay[U], error) { + s := &Relay[U]{ + name: method, + secureRNG: secureRNG, + handler: handler, + + uPSKHash: make(map[U][aes.BlockSize]byte), + uPSKHashR: make(map[[aes.BlockSize]byte]U), + uDestination: make(map[U]M.Socksaddr), + uCipher: make(map[U]cipher.Block), + + udpNat: udpnat.New[uint64](udpTimeout, handler), + udpSessions: cache.New( + cache.WithAge[uint64, *relayUDPSession](udpTimeout), + cache.WithUpdateAgeOnGet[uint64, *relayUDPSession](), + ), + } + + switch method { + case "2022-blake3-aes-128-gcm": + s.keySaltLength = 16 + s.constructor = newAESGCM + s.blockConstructor = newAES + case "2022-blake3-aes-256-gcm": + s.keySaltLength = 32 + s.constructor = newAESGCM + s.blockConstructor = newAES + default: + return nil, os.ErrInvalid + } + if len(psk) != s.keySaltLength { + if len(psk) < s.keySaltLength { + return nil, shadowsocks.ErrBadKey + } else { + psk = Key(psk, s.keySaltLength) + } + } + s.udpBlockCipher = s.blockConstructor(psk) + return s, nil +} + +func (s *Relay[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + err := s.newConnection(ctx, conn, metadata) + if err != nil { + err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Relay[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + _requestHeader := buf.StackNew() + defer runtime.KeepAlive(_requestHeader) + requestHeader := common.Dup(_requestHeader) + n, err := requestHeader.ReadFrom(conn) + if err != nil { + return err + } else if int(n) < s.keySaltLength+aes.BlockSize { + return shadowaead.ErrBadHeader + } + requestSalt := requestHeader.To(s.keySaltLength) + var _eiHeader [aes.BlockSize]byte + eiHeader := common.Dup(_eiHeader[:]) + copy(eiHeader, requestHeader.Range(s.keySaltLength, s.keySaltLength+aes.BlockSize)) + + keyMaterial := buf.Make(s.keySaltLength * 2) + copy(keyMaterial, s.iPSK) + copy(keyMaterial[s.keySaltLength:], requestSalt) + _identitySubkey := buf.Make(s.keySaltLength) + identitySubkey := common.Dup(_identitySubkey) + blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) + s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader) + runtime.KeepAlive(_identitySubkey) + + var user U + if u, loaded := s.uPSKHashR[_eiHeader]; loaded { + user = u + } else { + return E.New("invalid request") + } + runtime.KeepAlive(_eiHeader) + + copy(requestHeader.Range(aes.BlockSize, aes.BlockSize+s.keySaltLength), requestHeader.To(s.keySaltLength)) + requestHeader.Advance(aes.BlockSize) + + ctx = shadowsocks.UserContext[U]{ + ctx, + user, + } + metadata.Protocol = "shadowsocks-relay" + metadata.Destination = s.uDestination[user] + conn = &bufio.BufferedConn{ + Conn: conn, + Buffer: requestHeader, + } + return s.handler.NewConnection(ctx, conn, metadata) +} + +func (s *Relay[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + err := s.newPacket(ctx, conn, buffer, metadata) + if err != nil { + err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Relay[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + packetHeader := buffer.To(aes.BlockSize) + s.udpBlockCipher.Decrypt(packetHeader, packetHeader) + + sessionId := binary.BigEndian.Uint64(packetHeader) + + var _eiHeader [aes.BlockSize]byte + eiHeader := common.Dup(_eiHeader[:]) + s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize)) + + for i := range eiHeader { + eiHeader[i] = eiHeader[i] ^ packetHeader[i] + } + + var user U + if u, loaded := s.uPSKHashR[_eiHeader]; loaded { + user = u + } else { + return E.New("invalid request") + } + + session, _ := s.udpSessions.LoadOrStore(sessionId, func() *relayUDPSession { + return new(relayUDPSession) + }) + session.sourceAddr = metadata.Source + + s.uCipher[user].Encrypt(packetHeader, packetHeader) + copy(buffer.Range(aes.BlockSize, 2*aes.BlockSize), packetHeader) + buffer.Advance(aes.BlockSize) + + metadata.Protocol = "shadowsocks-relay" + metadata.Destination = s.uDestination[user] + s.udpNat.NewContextPacket(ctx, sessionId, func() (context.Context, N.PacketWriter) { + return &shadowsocks.UserContext[U]{ + ctx, + user, + }, &relayPacketWriter[U]{conn, session} + }, buffer, metadata) + return nil +} + +type relayUDPSession struct { + sourceAddr M.Socksaddr +} + +type relayPacketWriter[U comparable] struct { + N.PacketConn + session *relayUDPSession +} + +func (w *relayPacketWriter[U]) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error { + return w.PacketConn.WritePacket(buffer, w.session.sourceAddr) +} diff --git a/shadowaead_2022/service.go b/shadowaead_2022/service.go new file mode 100644 index 0000000..6a4c750 --- /dev/null +++ b/shadowaead_2022/service.go @@ -0,0 +1,489 @@ +package shadowaead_2022 + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/binary" + "io" + "math" + "net" + "os" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/cache" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/replay" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/udpnat" + wgReplay "golang.zx2c4.com/wireguard/replay" +) + +var ( + ErrSaltNotUnique = E.New("bad request: salt not unique") + ErrNoPadding = E.New("bad request: missing payload or padding") + ErrBadPadding = E.New("bad request: damaged padding") +) + +type Service struct { + name string + keySaltLength int + handler shadowsocks.Handler + + constructor func(key []byte) cipher.AEAD + blockConstructor func(key []byte) cipher.Block + udpCipher cipher.AEAD + udpBlockCipher cipher.Block + psk []byte + + replayFilter replay.Filter + udpNat *udpnat.Service[uint64] + udpSessions *cache.LruCache[uint64, *serverUDPSession] +} + +func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { + if password == "" { + return nil, ErrMissingPSK + } + psk, err := base64.StdEncoding.DecodeString(password) + if err != nil { + return nil, E.Cause(err, "decode psk") + } + return NewService(method, psk, udpTimeout, handler) +} + +func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { + s := &Service{ + name: method, + handler: handler, + + replayFilter: replay.NewSimple(60 * time.Second), + udpNat: udpnat.New[uint64](udpTimeout, handler), + udpSessions: cache.New[uint64, *serverUDPSession]( + cache.WithAge[uint64, *serverUDPSession](udpTimeout), + cache.WithUpdateAgeOnGet[uint64, *serverUDPSession](), + ), + } + + switch method { + case "2022-blake3-aes-128-gcm": + s.keySaltLength = 16 + s.constructor = newAESGCM + s.blockConstructor = newAES + case "2022-blake3-aes-256-gcm": + s.keySaltLength = 32 + s.constructor = newAESGCM + s.blockConstructor = newAES + case "2022-blake3-chacha20-poly1305": + s.keySaltLength = 32 + s.constructor = newChacha20Poly1305 + default: + return nil, os.ErrInvalid + } + + if len(psk) != s.keySaltLength { + if len(psk) < s.keySaltLength { + return nil, shadowsocks.ErrBadKey + } else if len(psk) > s.keySaltLength { + psk = Key(psk, s.keySaltLength) + } else { + return nil, ErrMissingPSK + } + } + + switch method { + case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": + s.udpBlockCipher = newAES(psk) + case "2022-blake3-chacha20-poly1305": + s.udpCipher = newXChacha20Poly1305(psk) + } + + s.psk = psk + return s, nil +} + +func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + err := s.newConnection(ctx, conn, metadata) + if err != nil { + err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + header := buf.Make(s.keySaltLength + shadowaead.Overhead + RequestHeaderFixedChunkLength) + + n, err := conn.Read(header) + if err != nil { + return E.Cause(err, "read header") + } else if n < len(header) { + return shadowaead.ErrBadHeader + } + + requestSalt := header[:s.keySaltLength] + + if !s.replayFilter.Check(requestSalt) { + return ErrSaltNotUnique + } + + requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength) + reader := shadowaead.NewReader( + conn, + s.constructor(common.Dup(requestKey)), + MaxPacketSize, + ) + runtime.KeepAlive(requestKey) + + err = reader.ReadChunk(header[s.keySaltLength:]) + if err != nil { + return err + } + + headerType, err := reader.ReadByte() + if err != nil { + return E.Cause(err, "read header") + } + + if headerType != HeaderTypeClient { + return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) + } + + var epoch uint64 + err = binary.Read(reader, binary.BigEndian, &epoch) + if err != nil { + return err + } + + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + } + + var length uint16 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return err + } + + err = reader.ReadWithLength(length) + if err != nil { + return err + } + + destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) + if err != nil { + return err + } + + var paddingLen uint16 + err = binary.Read(reader, binary.BigEndian, &paddingLen) + if err != nil { + return err + } + + if uint16(reader.Cached()) < paddingLen { + return ErrNoPadding + } + + if paddingLen > 0 { + err = reader.Discard(int(paddingLen)) + if err != nil { + return E.Cause(err, "discard padding") + } + } else if reader.Cached() == 0 { + return ErrNoPadding + } + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + return s.handler.NewConnection(ctx, &serverConn{ + Service: s, + Conn: conn, + uPSK: s.psk, + reader: reader, + requestSalt: requestSalt, + }, metadata) +} + +type serverConn struct { + *Service + net.Conn + uPSK []byte + access sync.Mutex + reader *shadowaead.Reader + writer *shadowaead.Writer + requestSalt []byte +} + +func (c *serverConn) writeResponse(payload []byte) (n int, err error) { + _salt := buf.Make(c.keySaltLength) + salt := common.Dup(_salt[:]) + common.Must1(io.ReadFull(rand.Reader, salt)) + key := SessionKey(c.uPSK, salt, c.keySaltLength) + runtime.KeepAlive(_salt) + writer := shadowaead.NewWriter( + c.Conn, + c.constructor(common.Dup(key)), + MaxPacketSize, + ) + runtime.KeepAlive(key) + header := writer.Buffer() + header.Write(salt) + + _headerFixedChunk := buf.Make(1 + 8 + c.keySaltLength + 2) + headerFixedChunk := buf.With(common.Dup(_headerFixedChunk)) + common.Must(headerFixedChunk.WriteByte(HeaderTypeServer)) + common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix()))) + common.Must1(headerFixedChunk.Write(c.requestSalt)) + common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(len(payload)))) + + writer.WriteChunk(header, headerFixedChunk.Slice()) + runtime.KeepAlive(_headerFixedChunk) + c.requestSalt = nil + + if len(payload) > 0 { + writer.WriteChunk(header, payload) + } + + err = writer.BufferedWriter(header.Len()).Flush() + if err != nil { + return + } + + c.writer = writer + n = len(payload) + return +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + if c.writer != nil { + return c.writer.Write(p) + } + c.access.Lock() + if c.writer != nil { + c.access.Unlock() + return c.writer.Write(p) + } + defer c.access.Unlock() + return c.writeResponse(p) +} + +func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) { + if c.writer == nil { + return rw.ReadFrom0(c, r) + } + return c.writer.ReadFrom(r) +} + +func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { + return c.reader.WriteTo(w) +} + +func (c *serverConn) Upstream() any { + return c.Conn +} + +func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + err := s.newPacket(ctx, conn, buffer, metadata) + if err != nil { + err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + var packetHeader []byte + if s.udpCipher != nil { + _, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) + if err != nil { + return E.Cause(err, "decrypt packet header") + } + buffer.Advance(PacketNonceSize) + buffer.Truncate(buffer.Len() - shadowaead.Overhead) + } else { + packetHeader = buffer.To(aes.BlockSize) + s.udpBlockCipher.Decrypt(packetHeader, packetHeader) + } + + var sessionId, packetId uint64 + err := binary.Read(buffer, binary.BigEndian, &sessionId) + if err != nil { + return err + } + err = binary.Read(buffer, binary.BigEndian, &packetId) + if err != nil { + return err + } + + session, loaded := s.udpSessions.LoadOrStore(sessionId, s.newUDPSession) + if !loaded { + session.remoteSessionId = sessionId + if packetHeader != nil { + key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength) + session.remoteCipher = s.constructor(common.Dup(key)) + runtime.KeepAlive(key) + } + } + goto process + +returnErr: + if !loaded { + s.udpSessions.Delete(sessionId) + } + return err + +process: + if !session.filter.ValidateCounter(packetId, math.MaxUint64) { + err = ErrPacketIdNotUnique + goto returnErr + } + + if packetHeader != nil { + _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) + if err != nil { + err = E.Cause(err, "decrypt packet") + goto returnErr + } + buffer.Truncate(buffer.Len() - shadowaead.Overhead) + } + + var headerType byte + headerType, err = buffer.ReadByte() + if err != nil { + err = E.Cause(err, "decrypt packet") + goto returnErr + } + if headerType != HeaderTypeClient { + err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) + goto returnErr + } + + var epoch uint64 + err = binary.Read(buffer, binary.BigEndian, &epoch) + if err != nil { + goto returnErr + } + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + goto returnErr + } + + var paddingLength uint16 + err = binary.Read(buffer, binary.BigEndian, &paddingLength) + if err != nil { + err = E.Cause(err, "read padding length") + goto returnErr + } + buffer.Advance(int(paddingLength)) + + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + goto returnErr + } + metadata.Destination = destination + + session.remoteAddr = metadata.Source + s.udpNat.NewPacket(ctx, sessionId, func() N.PacketWriter { + return &serverPacketWriter{s, conn, session} + }, buffer, metadata) + return nil +} + +type serverPacketWriter struct { + *Service + N.PacketConn + session *serverUDPSession +} + +func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + var hdrLen int + if w.udpCipher != nil { + hdrLen = PacketNonceSize + } + hdrLen += 16 // packet header + hdrLen += 1 // header type + hdrLen += 8 // timestamp + hdrLen += 8 // remote session id + hdrLen += 2 // padding length + hdrLen += M.SocksaddrSerializer.AddrPortLen(destination) + header := buf.With(buffer.ExtendHeader(hdrLen)) + + var dataIndex int + if w.udpCipher != nil { + common.Must1(header.ReadFullFrom(w.session.rng, PacketNonceSize)) + dataIndex = PacketNonceSize + } else { + dataIndex = aes.BlockSize + } + + common.Must( + binary.Write(header, binary.BigEndian, w.session.sessionId), + binary.Write(header, binary.BigEndian, w.session.nextPacketId()), + header.WriteByte(HeaderTypeServer), + binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())), + binary.Write(header, binary.BigEndian, w.session.remoteSessionId), + binary.Write(header, binary.BigEndian, uint16(0)), // padding length + ) + + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + + if w.udpCipher != nil { + w.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + } else { + packetHeader := buffer.To(aes.BlockSize) + w.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + w.udpBlockCipher.Encrypt(packetHeader, packetHeader) + } + return w.PacketConn.WritePacket(buffer, w.session.remoteAddr) +} + +type serverUDPSession struct { + sessionId uint64 + remoteSessionId uint64 + remoteAddr M.Socksaddr + packetId uint64 + cipher cipher.AEAD + remoteCipher cipher.AEAD + filter wgReplay.Filter + rng io.Reader +} + +func (s *serverUDPSession) nextPacketId() uint64 { + return atomic.AddUint64(&s.packetId, 1) +} + +func (m *Service) newUDPSession() *serverUDPSession { + session := &serverUDPSession{} + if m.udpCipher != nil { + session.rng = Blake3KeyedHash(rand.Reader) + common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) + } else { + common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) + } + session.packetId-- + if m.udpCipher == nil { + sessionId := make([]byte, 8) + binary.BigEndian.PutUint64(sessionId, session.sessionId) + key := SessionKey(m.psk, sessionId, m.keySaltLength) + session.cipher = m.constructor(common.Dup(key)) + runtime.KeepAlive(key) + } + return session +} diff --git a/shadowaead_2022/service_multi.go b/shadowaead_2022/service_multi.go new file mode 100644 index 0000000..c98c6b9 --- /dev/null +++ b/shadowaead_2022/service_multi.go @@ -0,0 +1,365 @@ +package shadowaead_2022 + +import ( + "context" + "crypto/aes" + "crypto/rand" + "encoding/base64" + "encoding/binary" + "math" + "net" + "os" + "runtime" + "time" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "lukechampine.com/blake3" +) + +type MultiService[U comparable] struct { + *Service + + uPSK map[U][]byte + uPSKHash map[U][aes.BlockSize]byte + uPSKHashR map[[aes.BlockSize]byte]U +} + +func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { + if password == "" { + return nil, ErrMissingPSK + } + iPSK, err := base64.StdEncoding.DecodeString(password) + if err != nil { + return nil, E.Cause(err, "decode psk") + } + return NewMultiService[U](method, iPSK, udpTimeout, handler) +} + +func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { + switch method { + case "2022-blake3-aes-128-gcm": + case "2022-blake3-aes-256-gcm": + default: + return nil, os.ErrInvalid + } + + ss, err := NewService(method, iPSK, udpTimeout, handler) + if err != nil { + return nil, err + } + + s := &MultiService[U]{ + Service: ss.(*Service), + + uPSK: make(map[U][]byte), + uPSKHash: make(map[U][aes.BlockSize]byte), + uPSKHashR: make(map[[aes.BlockSize]byte]U), + } + return s, nil +} + +func (s *MultiService[U]) AddUser(user U, key []byte) error { + if len(key) < s.keySaltLength { + return shadowsocks.ErrBadKey + } else if len(key) > s.keySaltLength { + key = Key(key, s.keySaltLength) + } + + var uPSKHash [aes.BlockSize]byte + hash512 := blake3.Sum512(key) + copy(uPSKHash[:], hash512[:]) + + if oldHash, loaded := s.uPSKHash[user]; loaded { + delete(s.uPSKHashR, oldHash) + } + + s.uPSKHash[user] = uPSKHash + s.uPSKHashR[uPSKHash] = user + s.uPSK[user] = key + + return nil +} + +func (s *MultiService[U]) AddUserWithPassword(user U, password string) error { + if password == "" { + return shadowsocks.ErrMissingPassword + } + psk, err := base64.StdEncoding.DecodeString(password) + if err != nil { + return E.Cause(err, "decode psk") + } + return s.AddUser(user, psk) +} + +func (s *MultiService[U]) RemoveUser(user U) { + if hash, loaded := s.uPSKHash[user]; loaded { + delete(s.uPSKHashR, hash) + } + delete(s.uPSK, user) + delete(s.uPSKHash, user) +} + +func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + err := s.newConnection(ctx, conn, metadata) + if err != nil { + err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} + } + return err +} + +func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength) + n, err := conn.Read(requestHeader) + if err != nil { + return err + } else if n < len(requestHeader) { + return shadowaead.ErrBadHeader + } + requestSalt := requestHeader[:s.keySaltLength] + if !s.replayFilter.Check(requestSalt) { + return ErrSaltNotUnique + } + + var _eiHeader [aes.BlockSize]byte + eiHeader := common.Dup(_eiHeader[:]) + copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize]) + + keyMaterial := buf.Make(s.keySaltLength * 2) + copy(keyMaterial, s.psk) + copy(keyMaterial[s.keySaltLength:], requestSalt) + _identitySubkey := buf.Make(s.keySaltLength) + identitySubkey := common.Dup(_identitySubkey) + blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) + s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader) + runtime.KeepAlive(_identitySubkey) + + var user U + var uPSK []byte + if u, loaded := s.uPSKHashR[_eiHeader]; loaded { + user = u + uPSK = s.uPSK[u] + } else { + return E.New("invalid request") + } + runtime.KeepAlive(_eiHeader) + + requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength) + reader := shadowaead.NewReader( + conn, + s.constructor(common.Dup(requestKey)), + MaxPacketSize, + ) + + err = reader.ReadChunk(requestHeader[s.keySaltLength+aes.BlockSize:]) + if err != nil { + return err + } + + headerType, err := rw.ReadByte(reader) + if err != nil { + return E.Cause(err, "read header") + } + + if headerType != HeaderTypeClient { + return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) + } + + var epoch uint64 + err = binary.Read(reader, binary.BigEndian, &epoch) + if err != nil { + return E.Cause(err, "read timestamp") + } + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + } + var length uint16 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return E.Cause(err, "read length") + } + + err = reader.ReadWithLength(length) + if err != nil { + return err + } + + destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) + if err != nil { + return E.Cause(err, "read destination") + } + + var paddingLen uint16 + err = binary.Read(reader, binary.BigEndian, &paddingLen) + if err != nil { + return E.Cause(err, "read padding length") + } + + if reader.Cached() < int(paddingLen) { + return ErrBadPadding + } else if paddingLen > 0 { + err = reader.Discard(int(paddingLen)) + if err != nil { + return E.Cause(err, "discard padding") + } + } else if reader.Cached() == 0 { + return ErrNoPadding + } + + var userCtx shadowsocks.UserContext[U] + userCtx.Context = ctx + userCtx.User = user + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + return s.handler.NewConnection(&userCtx, &serverConn{ + Service: s.Service, + Conn: conn, + uPSK: uPSK, + reader: reader, + requestSalt: requestSalt, + }, metadata) +} + +func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + err := s.newPacket(ctx, conn, buffer, metadata) + if err != nil { + err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} + } + return err +} + +func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + packetHeader := buffer.To(aes.BlockSize) + s.udpBlockCipher.Decrypt(packetHeader, packetHeader) + + var _eiHeader [aes.BlockSize]byte + eiHeader := common.Dup(_eiHeader[:]) + s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize)) + + for i := range eiHeader { + eiHeader[i] = eiHeader[i] ^ packetHeader[i] + } + + var user U + var uPSK []byte + if u, loaded := s.uPSKHashR[_eiHeader]; loaded { + user = u + uPSK = s.uPSK[u] + } else { + return E.New("invalid request") + } + + var sessionId, packetId uint64 + err := binary.Read(buffer, binary.BigEndian, &sessionId) + if err != nil { + return err + } + err = binary.Read(buffer, binary.BigEndian, &packetId) + if err != nil { + return err + } + + session, loaded := s.udpSessions.LoadOrStore(sessionId, func() *serverUDPSession { + return s.newUDPSession(uPSK) + }) + if !loaded { + session.remoteSessionId = sessionId + key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength) + session.remoteCipher = s.constructor(common.Dup(key)) + runtime.KeepAlive(key) + } + + goto process + +returnErr: + if !loaded { + s.udpSessions.Delete(sessionId) + } + return err + +process: + if !session.filter.ValidateCounter(packetId, math.MaxUint64) { + err = ErrPacketIdNotUnique + goto returnErr + } + + if packetHeader != nil { + _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) + if err != nil { + err = E.Cause(err, "decrypt packet") + goto returnErr + } + buffer.Truncate(buffer.Len() - shadowaead.Overhead) + } + + var headerType byte + headerType, err = buffer.ReadByte() + if err != nil { + err = E.Cause(err, "decrypt packet") + goto returnErr + } + if headerType != HeaderTypeClient { + err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) + goto returnErr + } + + var epoch uint64 + err = binary.Read(buffer, binary.BigEndian, &epoch) + if err != nil { + goto returnErr + } + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + goto returnErr + } + + var paddingLength uint16 + err = binary.Read(buffer, binary.BigEndian, &paddingLength) + if err != nil { + err = E.Cause(err, "read padding length") + goto returnErr + } + buffer.Advance(int(paddingLength)) + + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + goto returnErr + } + + metadata.Destination = destination + session.remoteAddr = metadata.Source + + s.udpNat.NewContextPacket(ctx, sessionId, func() (context.Context, N.PacketWriter) { + return &shadowsocks.UserContext[U]{ + ctx, + user, + }, &serverPacketWriter{s.Service, conn, session} + }, buffer, metadata) + return nil +} + +func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession { + session := &serverUDPSession{} + if s.udpCipher != nil { + session.rng = Blake3KeyedHash(rand.Reader) + common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) + } else { + common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) + } + session.packetId-- + sessionId := make([]byte, 8) + binary.BigEndian.PutUint64(sessionId, session.sessionId) + key := SessionKey(uPSK, sessionId, s.keySaltLength) + session.cipher = s.constructor(common.Dup(key)) + runtime.KeepAlive(key) + return session +} diff --git a/shadowaead_2022/service_multi_test.go b/shadowaead_2022/service_multi_test.go new file mode 100644 index 0000000..6565cd2 --- /dev/null +++ b/shadowaead_2022/service_multi_test.go @@ -0,0 +1,75 @@ +package shadowaead_2022_test + +import ( + "context" + "crypto/rand" + "net" + "sync" + "testing" + + "github.com/sagernet/sing-shadowsocks/shadowaead_2022" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func TestMultiService(t *testing.T) { + method := "2022-blake3-aes-128-gcm" + var iPSK [16]byte + rand.Reader.Read(iPSK[:]) + + var wg sync.WaitGroup + + multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg}) + if err != nil { + t.Fatal(err) + } + + var uPSK [16]byte + rand.Reader.Read(uPSK[:]) + multiService.AddUser("my user", uPSK[:]) + + client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}) + if err != nil { + t.Fatal(err) + } + wg.Add(1) + + serverConn, clientConn := net.Pipe() + defer common.Close(serverConn, clientConn) + go func() { + err := multiService.NewConnection(context.Background(), serverConn, M.Metadata{}) + if err != nil { + serverConn.Close() + t.Error(E.Cause(err, "server")) + return + } + }() + _, err = client.DialConn(clientConn, M.ParseSocksaddr("test.com:443")) + if err != nil { + t.Fatal(err) + } + wg.Wait() +} + +type multiHandler struct { + t *testing.T + wg *sync.WaitGroup +} + +func (h *multiHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + if metadata.Destination.String() != "test.com:443" { + h.t.Error("bad destination") + } + h.wg.Done() + return nil +} + +func (h *multiHandler) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { + return nil +} + +func (h *multiHandler) HandleError(err error) { + h.t.Error(err) +} diff --git a/shadowaead_2022/service_test.go b/shadowaead_2022/service_test.go new file mode 100644 index 0000000..d58c90a --- /dev/null +++ b/shadowaead_2022/service_test.go @@ -0,0 +1,49 @@ +package shadowaead_2022_test + +import ( + "context" + "crypto/rand" + "net" + "sync" + "testing" + + "github.com/sagernet/sing-shadowsocks/shadowaead_2022" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" +) + +func TestService(t *testing.T) { + method := "2022-blake3-aes-128-gcm" + var psk [16]byte + rand.Reader.Read(psk[:]) + + var wg sync.WaitGroup + + service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg}) + if err != nil { + t.Fatal(err) + } + + client, err := shadowaead_2022.New(method, [][]byte{psk[:]}) + if err != nil { + t.Fatal(err) + } + wg.Add(1) + + serverConn, clientConn := net.Pipe() + defer common.Close(serverConn, clientConn) + go func() { + err := service.NewConnection(context.Background(), serverConn, M.Metadata{}) + if err != nil { + serverConn.Close() + t.Error(E.Cause(err, "server")) + return + } + }() + _, err = client.DialConn(clientConn, M.ParseSocksaddr("test.com:443")) + if err != nil { + t.Fatal(err) + } + wg.Wait() +} diff --git a/shadowimpl/fetcher.go b/shadowimpl/fetcher.go new file mode 100644 index 0000000..0b15861 --- /dev/null +++ b/shadowimpl/fetcher.go @@ -0,0 +1,24 @@ +package shadowimpl + +import ( + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing-shadowsocks/shadowaead_2022" + "github.com/sagernet/sing-shadowsocks/shadowstream" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +func FetchMethod(method string, password string) (shadowsocks.Method, error) { + if method == "none" { + return shadowsocks.NewNone(), nil + } else if common.Contains(shadowstream.List, method) { + return shadowstream.New(method, nil, password) + } else if common.Contains(shadowaead.List, method) { + return shadowaead.New(method, nil, password) + } else if common.Contains(shadowaead_2022.List, method) { + return shadowaead_2022.NewWithPassword(method, password) + } else { + return nil, E.New("shadowsocks: unsupported method ", method) + } +} diff --git a/shadowsocks.go b/shadowsocks.go new file mode 100644 index 0000000..7e98679 --- /dev/null +++ b/shadowsocks.go @@ -0,0 +1,89 @@ +package shadowsocks + +import ( + "context" + "crypto/md5" + "fmt" + "net" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var ( + ErrBadKey = E.New("bad key") + ErrMissingPassword = E.New("missing password") +) + +type Method interface { + Name() string + KeyLength() int + DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) + DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn + DialPacketConn(conn net.Conn) N.NetPacketConn +} + +type Service interface { + N.TCPConnectionHandler + N.UDPHandler +} + +type Handler interface { + N.TCPConnectionHandler + N.UDPConnectionHandler + E.Handler +} + +type UserContext[U comparable] struct { + context.Context + User U +} + +type ServerConnError struct { + net.Conn + Source M.Socksaddr + Cause error +} + +func (e *ServerConnError) Close() error { + if conn, ok := common.Cast[*net.TCPConn](e.Conn); ok { + conn.SetLinger(0) + } + return e.Conn.Close() +} + +func (e *ServerConnError) Unwrap() error { + return e.Cause +} + +func (e *ServerConnError) Error() string { + return fmt.Sprint("shadowsocks: serve TCP from ", e.Source, ": ", e.Cause) +} + +type ServerPacketError struct { + Source M.Socksaddr + Cause error +} + +func (e *ServerPacketError) Unwrap() error { + return e.Cause +} + +func (e *ServerPacketError) Error() string { + return fmt.Sprint("shadowsocks: serve UDP from ", e.Source, ": ", e.Cause) +} + +func Key(password []byte, keySize int) []byte { + var b, prev []byte + h := md5.New() + for len(b) < keySize { + h.Write(prev) + h.Write([]byte(password)) + b = h.Sum(b) + prev = b[len(b)-h.Size():] + h.Reset() + } + return b[:keySize] +} diff --git a/shadowstream/protocol.go b/shadowstream/protocol.go new file mode 100644 index 0000000..1626de6 --- /dev/null +++ b/shadowstream/protocol.go @@ -0,0 +1,392 @@ +package shadowstream + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/md5" + "crypto/rand" + "crypto/rc4" + "io" + "net" + "os" + "runtime" + + "github.com/dgryski/go-camellia" + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "golang.org/x/crypto/blowfish" + "golang.org/x/crypto/cast5" + "golang.org/x/crypto/chacha20" +) + +var List = []string{ + "aes-128-ctr", + "aes-192-ctr", + "aes-256-ctr", + "aes-128-cfb", + "aes-192-cfb", + "aes-256-cfb", + "camellia-128-cfb", + "camellia-192-cfb", + "camellia-256-cfb", + "bf-cfb", + "cast5-cfb", + "des-cfb", + "rc4", + "rc4-md5", + "chacha20", + "chacha20-ietf", + "xchacha20", +} + +type Method struct { + name string + keyLength int + saltLength int + encryptConstructor func(key []byte, salt []byte) (cipher.Stream, error) + decryptConstructor func(key []byte, salt []byte) (cipher.Stream, error) + key []byte +} + +func New(method string, key []byte, password string) (shadowsocks.Method, error) { + m := &Method{ + name: method, + } + switch method { + case "aes-128-ctr": + m.keyLength = 16 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + case "aes-192-ctr": + m.keyLength = 24 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + case "aes-256-ctr": + m.keyLength = 32 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + case "aes-128-cfb": + m.keyLength = 16 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) + case "aes-192-cfb": + m.keyLength = 24 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) + case "aes-256-cfb": + m.keyLength = 32 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) + case "camellia-128-cfb": + m.keyLength = 16 + m.saltLength = camellia.BlockSize + m.encryptConstructor = blockStream(camellia.New, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(camellia.New, cipher.NewCFBDecrypter) + case "camellia-192-cfb": + m.keyLength = 24 + m.saltLength = camellia.BlockSize + m.encryptConstructor = blockStream(camellia.New, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(camellia.New, cipher.NewCFBDecrypter) + case "camellia-256-cfb": + m.keyLength = 32 + m.saltLength = camellia.BlockSize + m.encryptConstructor = blockStream(camellia.New, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(camellia.New, cipher.NewCFBDecrypter) + case "bf-cfb": + m.keyLength = 16 + m.saltLength = blowfish.BlockSize + m.encryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return blowfish.NewCipher(key) }, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return blowfish.NewCipher(key) }, cipher.NewCFBDecrypter) + case "cast5-cfb": + m.keyLength = 16 + m.saltLength = cast5.BlockSize + m.encryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return cast5.NewCipher(key) }, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return cast5.NewCipher(key) }, cipher.NewCFBDecrypter) + case "des-cfb": + m.keyLength = 8 + m.saltLength = des.BlockSize + m.encryptConstructor = blockStream(des.NewCipher, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(des.NewCipher, cipher.NewCFBDecrypter) + case "rc4": + m.keyLength = 16 + m.saltLength = 0 + m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) + } + m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) + } + case "rc4-md5": + m.keyLength = 16 + m.saltLength = 0 + m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + h := md5.New() + h.Write(key) + h.Write(salt) + return rc4.NewCipher(h.Sum(nil)) + } + m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + h := md5.New() + h.Write(key) + h.Write(salt) + return rc4.NewCipher(h.Sum(nil)) + } + case "chacha20", "chacha20-ietf": + m.keyLength = chacha20.KeySize + m.saltLength = chacha20.NonceSize + m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return chacha20.NewUnauthenticatedCipher(key, salt) + } + m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return chacha20.NewUnauthenticatedCipher(key, salt) + } + case "xchacha20": + m.keyLength = chacha20.KeySize + m.saltLength = chacha20.NonceSizeX + m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return chacha20.NewUnauthenticatedCipher(key, salt) + } + m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return chacha20.NewUnauthenticatedCipher(key, salt) + } + default: + return nil, os.ErrInvalid + } + if len(key) == m.keyLength { + m.key = key + } else if len(key) > 0 { + return nil, shadowsocks.ErrBadKey + } else if password != "" { + m.key = shadowsocks.Key([]byte(password), m.keyLength) + } else { + return nil, shadowsocks.ErrMissingPassword + } + return m, nil +} + +func blockStream(blockCreator func(key []byte) (cipher.Block, error), streamCreator func(block cipher.Block, iv []byte) cipher.Stream) func([]byte, []byte) (cipher.Stream, error) { + return func(key []byte, iv []byte) (cipher.Stream, error) { + block, err := blockCreator(key) + if err != nil { + return nil, err + } + return streamCreator(block, iv), err + } +} + +func (m *Method) Name() string { + return m.name +} + +func (m *Method) KeyLength() int { + return m.keyLength +} + +func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { + shadowsocksConn := &clientConn{ + Conn: conn, + method: m, + destination: destination, + } + return shadowsocksConn, shadowsocksConn.writeRequest(nil) +} + +func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { + return &clientConn{ + Conn: conn, + method: m, + destination: destination, + } +} + +func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { + return &clientPacketConn{m, conn} +} + +type clientConn struct { + net.Conn + + method *Method + destination M.Socksaddr + + readStream cipher.Stream + writeStream cipher.Stream +} + +func (c *clientConn) writeRequest(payload []byte) error { + _buffer := buf.Make(c.method.keyLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(payload)) + defer runtime.KeepAlive(_buffer) + buffer := buf.With(common.Dup(_buffer)) + + salt := buffer.Extend(c.method.keyLength) + common.Must1(io.ReadFull(rand.Reader, salt)) + + key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength) + writer, err := c.method.encryptConstructor(c.method.key, salt) + if err != nil { + return err + } + runtime.KeepAlive(key) + + err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination) + if err != nil { + return err + } + _, err = buffer.Write(payload) + if err != nil { + return err + } + + _, err = c.Conn.Write(buffer.Bytes()) + if err != nil { + return err + } + + c.writeStream = writer + return nil +} + +func (c *clientConn) readResponse() error { + if c.readStream != nil { + return nil + } + _salt := buf.Make(c.method.keyLength) + defer runtime.KeepAlive(_salt) + salt := common.Dup(_salt) + _, err := io.ReadFull(c.Conn, salt) + if err != nil { + return err + } + key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength) + defer runtime.KeepAlive(key) + c.readStream, err = c.method.decryptConstructor(common.Dup(key), salt) + if err != nil { + return err + } + return nil +} + +func (c *clientConn) Read(p []byte) (n int, err error) { + if err = c.readResponse(); err != nil { + return + } + n, err = c.Conn.Read(p) + if err != nil { + return 0, err + } + c.readStream.XORKeyStream(p[:n], p[:n]) + return +} + +func (c *clientConn) Write(p []byte) (n int, err error) { + if c.writeStream == nil { + err = c.writeRequest(p) + if err == nil { + n = len(p) + } + return + } + + c.writeStream.XORKeyStream(p, p) + return c.Conn.Write(p) +} + +func (c *clientConn) Upstream() any { + return c.Conn +} + +type clientPacketConn struct { + *Method + net.Conn +} + +func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buf.With(buffer.ExtendHeader(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination))) + common.Must1(header.ReadFullFrom(rand.Reader, c.keyLength)) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength)) + if err != nil { + return err + } + stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength)) + return common.Error(c.Write(buffer.Bytes())) +} + +func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + n, err := c.Read(buffer.FreeBytes()) + if err != nil { + return M.Socksaddr{}, err + } + buffer.Truncate(n) + stream, err := c.decryptConstructor(c.key, buffer.To(c.keyLength)) + if err != nil { + return M.Socksaddr{}, err + } + stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength)) + buffer.Advance(c.keyLength) + return M.SocksaddrSerializer.ReadAddrPort(buffer) +} + +func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + if err != nil { + return + } + stream, err := c.decryptConstructor(c.key, p[:c.keyLength]) + if err != nil { + return + } + buffer := buf.With(p[c.keyLength:n]) + stream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return + } + addr = destination.UDPAddr() + n = copy(p, buffer.Bytes()) + return +} + +func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + destination := M.SocksaddrFromNet(addr) + _buffer := buf.Make(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) + defer runtime.KeepAlive(_buffer) + buffer := buf.With(common.Dup(_buffer)) + common.Must1(buffer.ReadFullFrom(rand.Reader, c.keyLength)) + err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr)) + if err != nil { + return + } + _, err = buffer.Write(p) + if err != nil { + return + } + stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength)) + if err != nil { + return + } + stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength)) + _, err = c.Write(buffer.Bytes()) + if err != nil { + return + } + return len(p), nil +} + +func (c *clientPacketConn) Upstream() any { + return c.Conn +}