commit 48809b0a994924106ed089fc186e160c46be265f Author: 世界 Date: Wed May 25 14:00:04 2022 +0800 Init commit diff --git a/.github/workflows/debug.yml b/.github/workflows/debug.yml new file mode 100644 index 0000000..f82dc4c --- /dev/null +++ b/.github/workflows/debug.yml @@ -0,0 +1,40 @@ +name: Debug build + +on: + push: + branches: + - main + paths-ignore: + - '**.md' + - '.github/**' + - '!.github/workflows/debug.yml' + pull_request: + branches: + - main + +jobs: + build: + name: Debug build + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Get latest go version + id: version + run: | + echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g') + - name: Setup Go + uses: actions/setup-go@v2 + with: + go-version: ${{ steps.version.outputs.go_version }} + - name: Build and test + run: | + version=`git rev-parse HEAD` + mkdir build + pushd build + go mod init build + go get -v github.com/sagernet/sing-shadowsocks@$version + popd + go test -v ./... \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f7f8ac3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/.idea/ +/vendor/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3e3e29e --- /dev/null +++ b/LICENSE @@ -0,0 +1,14 @@ +Copyright (C) 2022 by nekohasekai + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e7f3d0e --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# sing-shadowsocks + +Lightweight and efficient shadowsocks implementation with sing. \ No newline at end of file diff --git a/format.go b/format.go new file mode 100644 index 0000000..ccdf305 --- /dev/null +++ b/format.go @@ -0,0 +1,6 @@ +package shadowsocks + +//go:generate go install -v mvdan.cc/gofumpt@latest +//go:generate go install -v github.com/daixiang0/gci@latest +//go:generate gofumpt -l -w . +//go:generate gci write . diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a3c06a3 --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module github.com/sagernet/sing-shadowsocks + +go 1.18 + +require ( + github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d + github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34 + golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898 + golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d + lukechampine.com/blake3 v1.1.7 +) + +require ( + github.com/klauspost/cpuid/v2 v2.0.12 // indirect + golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..45dd320 --- /dev/null +++ b/go.sum @@ -0,0 +1,17 @@ +github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d h1:CPqTNIigGweVPT4CYb+OO2E6XyRKFOmvTHwWRLgCAlE= +github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d/go.mod h1:QX5ZVULjAfZJux/W62Y91HvCh9hyW6enAwcrrv/sLj0= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE= +github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= +github.com/sagernet/sing v0.0.0-20220525062603-53c607b13ff2 h1:x7E53uloX7pU3rWOzb81IBCAmwMtE2u9x4ZJvJXaCnM= +github.com/sagernet/sing v0.0.0-20220525062603-53c607b13ff2/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg= +github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34 h1:/FfHfteLZo5mOtZbYOx/9ymDEYxlwBuM5iHO9reVe/E= +github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg= +golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898 h1:SLP7Q4Di66FONjDJbCYrCRrh97focO6sLogHO7/g8F0= +golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d h1:q4JksJ2n0fmbXC0Aj0eOs6E0AcPqnKglxWXWFqGD6x0= +golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d/go.mod h1:bVQfyl2sCM/QIIGHpWbFGfHPuDvqnCNkT6MQLTCjO/U= +lukechampine.com/blake3 v1.1.7 h1:GgRMhmdsuK8+ii6UZFDL8Nb+VyMwadAgcJyfYHxG6n0= +lukechampine.com/blake3 v1.1.7/go.mod h1:tkKEOtDkNtklkXtLNEOGNq5tcV90tJiA1vAA12R78LA= diff --git a/none.go b/none.go new file mode 100644 index 0000000..e286f1d --- /dev/null +++ b/none.go @@ -0,0 +1,241 @@ +package shadowsocks + +import ( + "context" + "io" + "net" + "net/netip" + "runtime" + "sync" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/udpnat" +) + +const MethodNone = "none" + +type NoneMethod struct{} + +func NewNone() Method { + return &NoneMethod{} +} + +func (m *NoneMethod) Name() string { + return MethodNone +} + +func (m *NoneMethod) KeyLength() int { + return 0 +} + +func (m *NoneMethod) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { + shadowsocksConn := &noneConn{ + Conn: conn, + handshake: true, + destination: destination, + } + return shadowsocksConn, shadowsocksConn.clientHandshake() +} + +func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { + return &noneConn{ + Conn: conn, + destination: destination, + } +} + +func (m *NoneMethod) DialPacketConn(conn net.Conn) N.NetPacketConn { + return &nonePacketConn{conn} +} + +type noneConn struct { + net.Conn + + access sync.Mutex + handshake bool + destination M.Socksaddr +} + +func (c *noneConn) clientHandshake() error { + err := M.SocksaddrSerializer.WriteAddrPort(c.Conn, c.destination) + if err != nil { + return err + } + c.handshake = true + return nil +} + +func (c *noneConn) Write(b []byte) (n int, err error) { + if c.handshake { + goto direct + } + + c.access.Lock() + defer c.access.Unlock() + + if c.handshake { + goto direct + } + + { + if len(b) == 0 { + return 0, c.clientHandshake() + } + + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) + + err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination) + if err != nil { + return + } + + bufN, _ := buffer.Write(b) + _, err = c.Conn.Write(buffer.Bytes()) + runtime.KeepAlive(_buffer) + if err != nil { + return + } + + if bufN < len(b) { + _, err = c.Conn.Write(b[bufN:]) + if err != nil { + return + } + } + + n = len(b) + } + +direct: + return c.Conn.Write(b) +} + +func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) { + if !c.handshake { + return rw.ReadFrom0(c, r) + } + return c.Conn.(io.ReaderFrom).ReadFrom(r) +} + +func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) { + return io.Copy(w, c.Conn) +} + +func (c *noneConn) RemoteAddr() net.Addr { + return c.destination.TCPAddr() +} + +type nonePacketConn struct { + net.Conn +} + +func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + _, err := buffer.ReadFrom(c) + if err != nil { + return M.Socksaddr{}, err + } + return M.SocksaddrSerializer.ReadAddrPort(buffer) +} + +func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + return common.Error(buffer.WriteTo(c)) +} + +func (c *nonePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + if err != nil { + return + } + buffer := buf.With(p[:n]) + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return + } + addr = destination.UDPAddr() + n = copy(p, buffer.Bytes()) + return +} + +func (c *nonePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + destination := M.SocksaddrFromNet(addr) + _buffer := buf.Make(M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) + defer runtime.KeepAlive(_buffer) + buffer := buf.With(common.Dup(_buffer)) + err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) + if err != nil { + return + } + _, err = buffer.Write(p) + if err != nil { + return + } + return len(p), nil +} + +type NoneService struct { + handler Handler + udp *udpnat.Service[netip.AddrPort] +} + +func NewNoneService(udpTimeout int64, handler Handler) Service { + s := &NoneService{ + handler: handler, + } + s.udp = udpnat.New[netip.AddrPort](udpTimeout, s) + return s +} + +func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) + if err != nil { + return err + } + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + return s.handler.NewConnection(ctx, conn, metadata) +} + +func (s *NoneService) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return err + } + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + s.udp.NewPacket(ctx, metadata.Source.AddrPort(), func() N.PacketWriter { + return &nonePacketWriter{conn, metadata.Source} + }, buffer, metadata) + return nil +} + +type nonePacketWriter struct { + N.PacketConn + sourceAddr M.Socksaddr +} + +func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + return s.PacketConn.WritePacket(buffer, s.sourceAddr) +} + +func (s *NoneService) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { + return s.handler.NewPacketConnection(ctx, conn, metadata) +} + +func (s *NoneService) HandleError(err error) { + s.handler.HandleError(err) +} diff --git a/shadowaead/aead.go b/shadowaead/aead.go new file mode 100644 index 0000000..658cf61 --- /dev/null +++ b/shadowaead/aead.go @@ -0,0 +1,421 @@ +package shadowaead + +import ( + "crypto/cipher" + "encoding/binary" + "io" + + "github.com/sagernet/sing/common/buf" +) + +// https://shadowsocks.org/en/wiki/AEAD-Ciphers.html +const ( + MaxPacketSize = 16*1024 - 1 + PacketLengthBufferSize = 2 +) + +const ( + // NonceSize + // crypto/cipher.gcmStandardNonceSize + // golang.org/x/crypto/chacha20poly1305.NonceSize + NonceSize = 12 + + // Overhead + // crypto/cipher.gcmTagSize + // golang.org/x/crypto/chacha20poly1305.Overhead + Overhead = 16 +) + +type Reader struct { + upstream io.Reader + cipher cipher.AEAD + buffer []byte + nonce []byte + index int + cached int +} + +func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reader { + return &Reader{ + upstream: upstream, + cipher: cipher, + buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2), + nonce: make([]byte, NonceSize), + } +} + +func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce []byte) *Reader { + return &Reader{ + upstream: upstream, + cipher: cipher, + buffer: buffer, + nonce: nonce, + } +} + +func (r *Reader) Upstream() any { + return r.upstream +} + +func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) { + if r.cached > 0 { + writeN, writeErr := writer.Write(r.buffer[r.index : r.index+r.cached]) + if writeErr != nil { + return int64(writeN), writeErr + } + n += int64(writeN) + } + for { + start := PacketLengthBufferSize + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:start]) + if err != nil { + return + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) + if err != nil { + return + } + increaseNonce(r.nonce) + length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) + end := length + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return + } + increaseNonce(r.nonce) + writeN, writeErr := writer.Write(r.buffer[:length]) + if writeErr != nil { + return int64(writeN), writeErr + } + n += int64(writeN) + } +} + +func (r *Reader) readInternal() (err error) { + start := PacketLengthBufferSize + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:start]) + if err != nil { + return err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) + end := length + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = length + r.index = 0 + return nil +} + +func (r *Reader) ReadByte() (byte, error) { + if r.cached == 0 { + err := r.readInternal() + if err != nil { + return 0, err + } + } + index := r.index + r.index++ + r.cached-- + return r.buffer[index], nil +} + +func (r *Reader) Read(b []byte) (n int, err error) { + if r.cached > 0 { + n = copy(b, r.buffer[r.index:r.index+r.cached]) + r.cached -= n + r.index += n + return + } + start := PacketLengthBufferSize + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:start]) + if err != nil { + return 0, err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) + if err != nil { + return 0, err + } + increaseNonce(r.nonce) + length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) + end := length + Overhead + + if len(b) >= end { + data := b[:end] + _, err = io.ReadFull(r.upstream, data) + if err != nil { + return 0, err + } + _, err = r.cipher.Open(b[:0], r.nonce, data, nil) + if err != nil { + return 0, err + } + increaseNonce(r.nonce) + return length, nil + } else { + _, err = io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return 0, err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return 0, err + } + increaseNonce(r.nonce) + n = copy(b, r.buffer[:length]) + r.cached = length - n + r.index = n + return + } +} + +func (r *Reader) Discard(n int) error { + for { + if r.cached >= n { + r.cached -= n + r.index += n + return nil + } else if r.cached > 0 { + n -= r.cached + r.cached = 0 + r.index = 0 + } + err := r.readInternal() + if err != nil { + return err + } + } +} + +func (r *Reader) Cached() int { + return r.cached +} + +func (r *Reader) CachedSlice() []byte { + return r.buffer[r.index : r.index+r.cached] +} + +func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error { + _, err := r.cipher.Open(r.buffer[:0], r.nonce, lengthChunk, nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) + end := length + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = length + r.index = 0 + return nil +} + +func (r *Reader) ReadWithLength(length uint16) error { + end := length + Overhead + _, err := io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = int(length) + r.index = 0 + return nil +} + +func (r *Reader) ReadChunk(chunk []byte) error { + bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = len(bb) + r.index = 0 + return nil +} + +type Writer struct { + upstream io.Writer + cipher cipher.AEAD + maxPacketSize int + buffer []byte + nonce []byte +} + +func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Writer { + return &Writer{ + upstream: upstream, + cipher: cipher, + buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2), + nonce: make([]byte, cipher.NonceSize()), + maxPacketSize: maxPacketSize, + } +} + +func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buffer []byte, nonce []byte) *Writer { + return &Writer{ + upstream: upstream, + cipher: cipher, + maxPacketSize: maxPacketSize, + buffer: buffer, + nonce: nonce, + } +} + +func (w *Writer) Upstream() any { + return w.upstream +} + +func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) { + for { + offset := Overhead + PacketLengthBufferSize + readN, readErr := r.Read(w.buffer[offset : offset+w.maxPacketSize]) + if readErr != nil { + return 0, readErr + } + binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(readN)) + w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) + increaseNonce(w.nonce) + packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, w.buffer[offset:offset+readN], nil) + increaseNonce(w.nonce) + _, err = w.upstream.Write(w.buffer[:offset+len(packet)]) + if err != nil { + return + } + n += int64(readN) + } +} + +func (w *Writer) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + + for pLen := len(p); pLen > 0; { + var data []byte + if pLen > w.maxPacketSize { + data = p[:w.maxPacketSize] + p = p[w.maxPacketSize:] + } else { + data = p + p = nil + } + binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data))) + w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) + increaseNonce(w.nonce) + offset := Overhead + PacketLengthBufferSize + packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil) + increaseNonce(w.nonce) + _, err = w.upstream.Write(w.buffer[:offset+len(packet)]) + if err != nil { + return + } + n += len(data) + } + + return +} + +func (w *Writer) Buffer() *buf.Buffer { + return buf.With(w.buffer) +} + +func (w *Writer) WriteChunk(buffer *buf.Buffer, chunk []byte) { + bb := w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, chunk, nil) + buffer.Extend(len(bb)) + increaseNonce(w.nonce) +} + +func (w *Writer) BufferedWriter(reversed int) *BufferedWriter { + return &BufferedWriter{ + upstream: w, + reversed: reversed, + data: w.buffer[PacketLengthBufferSize+Overhead : len(w.buffer)-Overhead], + } +} + +type BufferedWriter struct { + upstream *Writer + data []byte + reversed int + index int +} + +func (w *BufferedWriter) UpstreamWriter() io.Writer { + return w.upstream +} + +func (w *BufferedWriter) WriterReplaceable() bool { + return w.index == 0 +} + +func (w *BufferedWriter) Write(p []byte) (n int, err error) { + var index int + for { + cachedN := copy(w.data[w.reversed+w.index:], p[index:]) + if cachedN == len(p[index:]) { + w.index += cachedN + return cachedN, nil + } + err = w.Flush() + if err != nil { + return + } + index += cachedN + } +} + +func (w *BufferedWriter) Flush() error { + if w.index == 0 { + if w.reversed > 0 { + _, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed]) + w.reversed = 0 + return err + } + return nil + } + buffer := w.upstream.buffer[w.reversed:] + binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index)) + w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil) + increaseNonce(w.upstream.nonce) + offset := Overhead + PacketLengthBufferSize + packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil) + increaseNonce(w.upstream.nonce) + _, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)]) + w.reversed = 0 + return err +} + +func increaseNonce(nonce []byte) { + for i := range nonce { + nonce[i]++ + if nonce[i] != 0 { + return + } + } +} diff --git a/shadowaead/protocol.go b/shadowaead/protocol.go new file mode 100644 index 0000000..1271952 --- /dev/null +++ b/shadowaead/protocol.go @@ -0,0 +1,361 @@ +package shadowaead + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha1" + "io" + "net" + "runtime" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/hkdf" +) + +var List = []string{ + "aes-128-gcm", + "aes-192-gcm", + "aes-256-gcm", + "chacha20-ietf-poly1305", + "xchacha20-ietf-poly1305", +} + +func New(method string, key []byte, password string) (shadowsocks.Method, error) { + m := &Method{ + name: method, + } + switch method { + case "aes-128-gcm": + m.keySaltLength = 16 + m.constructor = newAESGCM + case "aes-192-gcm": + m.keySaltLength = 24 + m.constructor = newAESGCM + case "aes-256-gcm": + m.keySaltLength = 32 + m.constructor = newAESGCM + case "chacha20-ietf-poly1305": + m.keySaltLength = 32 + m.constructor = func(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.New(key) + common.Must(err) + return cipher + } + case "xchacha20-ietf-poly1305": + m.keySaltLength = 32 + m.constructor = func(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.NewX(key) + common.Must(err) + return cipher + } + } + if len(key) == m.keySaltLength { + m.key = key + } else if len(key) > 0 { + return nil, shadowsocks.ErrBadKey + } else if password == "" { + return nil, shadowsocks.ErrMissingPassword + } else { + m.key = shadowsocks.Key([]byte(password), m.keySaltLength) + } + return m, nil +} + +func Kdf(key, iv []byte, keyLength int) []byte { + info := []byte("ss-subkey") + subKey := buf.Make(keyLength) + kdf := hkdf.New(sha1.New, key, iv, common.Dup(info)) + runtime.KeepAlive(info) + common.Must1(io.ReadFull(kdf, common.Dup(subKey))) + return subKey +} + +func newAESGCM(key []byte) cipher.AEAD { + block, err := aes.NewCipher(key) + common.Must(err) + aead, err := cipher.NewGCM(block) + common.Must(err) + return aead +} + +type Method struct { + name string + keySaltLength int + constructor func(key []byte) cipher.AEAD + key []byte +} + +func (m *Method) Name() string { + return m.name +} + +func (m *Method) KeyLength() int { + return m.keySaltLength +} + +func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) { + _salt := buf.Make(m.keySaltLength) + defer runtime.KeepAlive(_salt) + salt := common.Dup(_salt) + _, err := io.ReadFull(upstream, salt) + if err != nil { + return nil, E.Cause(err, "read salt") + } + key := Kdf(m.key, salt, m.keySaltLength) + defer runtime.KeepAlive(key) + return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil +} + +func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) { + _salt := buf.Make(m.keySaltLength) + defer runtime.KeepAlive(_salt) + salt := common.Dup(_salt) + common.Must1(io.ReadFull(rand.Reader, salt)) + _, err := upstream.Write(salt) + if err != nil { + return nil, err + } + key := Kdf(m.key, salt, m.keySaltLength) + return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil +} + +func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { + shadowsocksConn := &clientConn{ + Conn: conn, + method: m, + destination: destination, + } + return shadowsocksConn, shadowsocksConn.writeRequest(nil) +} + +func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { + return &clientConn{ + Conn: conn, + method: m, + destination: destination, + } +} + +func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { + return &clientPacketConn{m, conn} +} + +func (m *Method) EncodePacket(buffer *buf.Buffer) error { + key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) + c := m.constructor(common.Dup(key)) + runtime.KeepAlive(key) + c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) + buffer.Extend(Overhead) + return nil +} + +func (m *Method) DecodePacket(buffer *buf.Buffer) error { + if buffer.Len() < m.keySaltLength { + return E.New("bad packet") + } + key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) + c := m.constructor(common.Dup(key)) + runtime.KeepAlive(key) + packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) + if err != nil { + return err + } + buffer.Advance(m.keySaltLength) + buffer.Truncate(len(packet)) + return nil +} + +type clientConn struct { + net.Conn + method *Method + destination M.Socksaddr + reader *Reader + writer *Writer +} + +func (c *clientConn) writeRequest(payload []byte) error { + _salt := make([]byte, c.method.keySaltLength) + salt := common.Dup(_salt) + common.Must1(io.ReadFull(rand.Reader, salt)) + + key := Kdf(c.method.key, salt, c.method.keySaltLength) + runtime.KeepAlive(_salt) + writer := NewWriter( + c.Conn, + c.method.constructor(common.Dup(key)), + MaxPacketSize, + ) + runtime.KeepAlive(key) + header := writer.Buffer() + header.Write(salt) + bufferedWriter := writer.BufferedWriter(header.Len()) + + if len(payload) > 0 { + err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination) + if err != nil { + return err + } + + _, err = bufferedWriter.Write(payload) + if err != nil { + return err + } + } else { + err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination) + if err != nil { + return err + } + } + + err := bufferedWriter.Flush() + if err != nil { + return err + } + + c.writer = writer + return nil +} + +func (c *clientConn) readResponse() error { + if c.reader != nil { + return nil + } + _salt := buf.Make(c.method.keySaltLength) + defer runtime.KeepAlive(_salt) + salt := common.Dup(_salt) + _, err := io.ReadFull(c.Conn, salt) + if err != nil { + return err + } + key := Kdf(c.method.key, salt, c.method.keySaltLength) + defer runtime.KeepAlive(key) + c.reader = NewReader( + c.Conn, + c.method.constructor(common.Dup(key)), + MaxPacketSize, + ) + return nil +} + +func (c *clientConn) Read(p []byte) (n int, err error) { + if err = c.readResponse(); err != nil { + return + } + return c.reader.Read(p) +} + +func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { + if err = c.readResponse(); err != nil { + return + } + return c.reader.WriteTo(w) +} + +func (c *clientConn) Write(p []byte) (n int, err error) { + if c.writer != nil { + return c.writer.Write(p) + } + + err = c.writeRequest(p) + if err != nil { + return + } + return len(p), nil +} + +func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { + if c.writer == nil { + return rw.ReadFrom0(c, r) + } + return c.writer.ReadFrom(r) +} + +func (c *clientConn) Upstream() any { + return c.Conn +} + +type clientPacketConn struct { + *Method + net.Conn +} + +func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination)) + common.Must1(io.ReadFull(rand.Reader, header[:c.keySaltLength])) + err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination) + if err != nil { + return err + } + err = c.EncodePacket(buffer) + if err != nil { + return err + } + return common.Error(c.Write(buffer.Bytes())) +} + +func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + n, err := c.Read(buffer.FreeBytes()) + if err != nil { + return M.Socksaddr{}, err + } + buffer.Truncate(n) + err = c.DecodePacket(buffer) + if err != nil { + return M.Socksaddr{}, err + } + return M.SocksaddrSerializer.ReadAddrPort(buffer) +} + +func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + if err != nil { + return + } + b := buf.With(p[:n]) + err = c.DecodePacket(b) + if err != nil { + return + } + destination, err := M.SocksaddrSerializer.ReadAddrPort(b) + if err != nil { + return + } + addr = destination.UDPAddr() + n = copy(p, b.Bytes()) + return +} + +func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr)) + if err != nil { + return + } + _, err = buffer.Write(p) + if err != nil { + return + } + err = c.EncodePacket(buffer) + if err != nil { + return + } + _, err = c.Write(buffer.Bytes()) + if err != nil { + return + } + return len(p), nil +} + +func (c *clientPacketConn) Upstream() any { + return c.Conn +} diff --git a/shadowaead/service.go b/shadowaead/service.go new file mode 100644 index 0000000..5708e2f --- /dev/null +++ b/shadowaead/service.go @@ -0,0 +1,246 @@ +package shadowaead + +import ( + "context" + "crypto/cipher" + "crypto/rand" + "io" + "net" + "net/netip" + "runtime" + "sync" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/udpnat" + "golang.org/x/crypto/chacha20poly1305" +) + +var ErrBadHeader = E.New("bad header") + +type Service struct { + name string + keySaltLength int + constructor func(key []byte) cipher.AEAD + key []byte + handler shadowsocks.Handler + udpNat *udpnat.Service[netip.AddrPort] +} + +func NewService(method string, key []byte, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { + s := &Service{ + name: method, + handler: handler, + udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), + } + switch method { + case "aes-128-gcm": + s.keySaltLength = 16 + s.constructor = newAESGCM + case "aes-192-gcm": + s.keySaltLength = 24 + s.constructor = newAESGCM + case "aes-256-gcm": + s.keySaltLength = 32 + s.constructor = newAESGCM + case "chacha20-ietf-poly1305": + s.keySaltLength = 32 + s.constructor = func(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.New(key) + common.Must(err) + return cipher + } + case "xchacha20-ietf-poly1305": + s.keySaltLength = 32 + s.constructor = func(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.NewX(key) + common.Must(err) + return cipher + } + } + if len(key) == s.keySaltLength { + s.key = key + } else if len(key) > 0 { + return nil, shadowsocks.ErrBadKey + } else if password != "" { + s.key = shadowsocks.Key([]byte(password), s.keySaltLength) + } else { + return nil, shadowsocks.ErrMissingPassword + } + return s, nil +} + +func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + err := s.newConnection(ctx, conn, metadata) + if err != nil { + err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + _header := buf.Make(s.keySaltLength + PacketLengthBufferSize + Overhead) + defer runtime.KeepAlive(_header) + header := common.Dup(_header) + + n, err := conn.Read(header) + if err != nil { + return E.Cause(err, "read header") + } else if n < len(header) { + return ErrBadHeader + } + + key := Kdf(s.key, header[:s.keySaltLength], s.keySaltLength) + reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize) + + err = reader.ReadWithLengthChunk(header[s.keySaltLength:]) + if err != nil { + return err + } + + destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) + if err != nil { + return err + } + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + + return s.handler.NewConnection(ctx, &serverConn{ + Service: s, + Conn: conn, + reader: reader, + }, metadata) +} + +type serverConn struct { + *Service + net.Conn + access sync.Mutex + reader *Reader + writer *Writer +} + +func (c *serverConn) writeResponse(payload []byte) (n int, err error) { + _salt := buf.Make(c.keySaltLength) + salt := common.Dup(_salt) + common.Must1(io.ReadFull(rand.Reader, salt)) + + key := Kdf(c.key, salt, c.keySaltLength) + runtime.KeepAlive(_salt) + + writer := NewWriter( + c.Conn, + c.constructor(common.Dup(key)), + MaxPacketSize, + ) + runtime.KeepAlive(key) + + header := writer.Buffer() + header.Write(salt) + + bufferedWriter := writer.BufferedWriter(header.Len()) + if len(payload) > 0 { + _, err = bufferedWriter.Write(payload) + if err != nil { + return + } + } + + err = bufferedWriter.Flush() + if err != nil { + return + } + + c.writer = writer + return +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + if c.writer != nil { + return c.writer.Write(p) + } + c.access.Lock() + if c.writer != nil { + c.access.Unlock() + return c.writer.Write(p) + } + defer c.access.Unlock() + return c.writeResponse(p) +} + +func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) { + if c.writer == nil { + return rw.ReadFrom0(c, r) + } + return c.writer.ReadFrom(r) +} + +func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { + return c.reader.WriteTo(w) +} + +func (c *serverConn) Upstream() any { + return c.Conn +} + +func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + err := s.newPacket(ctx, conn, buffer, metadata) + if err != nil { + err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + if buffer.Len() < s.keySaltLength { + return E.New("bad packet") + } + key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength) + c := s.constructor(common.Dup(key)) + runtime.KeepAlive(key) + packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil) + if err != nil { + return err + } + buffer.Advance(s.keySaltLength) + buffer.Truncate(len(packet)) + + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return err + } + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + s.udpNat.NewPacket(ctx, metadata.Source.AddrPort(), func() N.PacketWriter { + return &serverPacketWriter{s, conn, metadata.Source} + }, buffer, metadata) + return nil +} + +type serverPacketWriter struct { + *Service + N.PacketConn + source M.Socksaddr +} + +func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buffer.ExtendHeader(w.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination)) + common.Must1(io.ReadFull(rand.Reader, header[:w.keySaltLength])) + err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination) + if err != nil { + return err + } + key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength) + c := w.constructor(common.Dup(key)) + runtime.KeepAlive(key) + c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil) + buffer.Extend(Overhead) + return w.PacketConn.WritePacket(buffer, w.source) +} diff --git a/shadowaead_2022/protocol.go b/shadowaead_2022/protocol.go new file mode 100644 index 0000000..f990d0b --- /dev/null +++ b/shadowaead_2022/protocol.go @@ -0,0 +1,745 @@ +package shadowaead_2022 + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "io" + "math" + mRand "math/rand" + "net" + "os" + "runtime" + "strings" + "sync/atomic" + "time" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/replay" + "github.com/sagernet/sing/common/rw" + "golang.org/x/crypto/chacha20poly1305" + wgReplay "golang.zx2c4.com/wireguard/replay" + "lukechampine.com/blake3" +) + +const ( + HeaderTypeClient = 0 + HeaderTypeServer = 1 + MaxPaddingLength = 900 + PacketNonceSize = 24 + MaxPacketSize = 65535 + RequestHeaderFixedChunkLength = 1 + 8 + 2 +) + +var ( + ErrMissingPSK = E.New("missing psk") + ErrBadHeaderType = E.New("bad header type") + ErrBadTimestamp = E.New("bad timestamp") + ErrBadRequestSalt = E.New("bad request salt") + ErrBadClientSessionId = E.New("bad client session id") + ErrPacketIdNotUnique = E.New("packet id not unique") + ErrTooManyServerSessions = E.New("server session changed more than once during the last minute") +) + +var List = []string{ + "2022-blake3-aes-128-gcm", + "2022-blake3-aes-256-gcm", + "2022-blake3-chacha20-poly1305", +} + +func NewWithPassword(method string, password string) (shadowsocks.Method, error) { + var pskList [][]byte + if password == "" { + return nil, ErrMissingPSK + } + keyStrList := strings.Split(password, ":") + pskList = make([][]byte, len(keyStrList)) + for i, keyStr := range keyStrList { + kb, err := base64.StdEncoding.DecodeString(keyStr) + if err != nil { + return nil, E.Cause(err, "decode key") + } + pskList[i] = kb + } + return New(method, pskList) +} + +func New(method string, pskList [][]byte) (shadowsocks.Method, error) { + m := &Method{ + name: method, + replayFilter: replay.NewSimple(60 * time.Second), + } + + switch method { + case "2022-blake3-aes-128-gcm": + m.keySaltLength = 16 + m.constructor = newAESGCM + m.blockConstructor = newAES + case "2022-blake3-aes-256-gcm": + m.keySaltLength = 32 + m.constructor = newAESGCM + m.blockConstructor = newAES + case "2022-blake3-chacha20-poly1305": + if len(pskList) > 1 { + return nil, os.ErrInvalid + } + m.keySaltLength = 32 + m.constructor = newChacha20Poly1305 + } + + if len(pskList) == 0 { + return nil, ErrMissingPSK + } + + for i, psk := range pskList { + if len(psk) < m.keySaltLength { + return nil, shadowsocks.ErrBadKey + } else if len(psk) > m.keySaltLength { + pskList[i] = Key(psk, m.keySaltLength) + } + } + + if len(pskList) > 1 { + pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize) + for i, psk := range pskList { + if i == 0 { + continue + } + hash := blake3.Sum512(psk) + copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], hash[:aes.BlockSize]) + } + m.pskHash = pskHash + } + + switch method { + case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": + m.udpBlockCipher = newAES(pskList[0]) + case "2022-blake3-chacha20-poly1305": + m.udpCipher = newXChacha20Poly1305(pskList[0]) + } + + m.pskList = pskList + return m, nil +} + +func Key(key []byte, keyLength int) []byte { + psk := sha256.Sum256(key) + return psk[:keyLength] +} + +func SessionKey(psk []byte, salt []byte, keyLength int) []byte { + sessionKey := buf.Make(len(psk) + len(salt)) + copy(sessionKey, psk) + copy(sessionKey[len(psk):], salt) + outKey := buf.Make(keyLength) + blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey) + return outKey +} + +func newAES(key []byte) cipher.Block { + block, err := aes.NewCipher(key) + common.Must(err) + return block +} + +func newAESGCM(key []byte) cipher.AEAD { + block, err := aes.NewCipher(key) + common.Must(err) + aead, err := cipher.NewGCM(block) + common.Must(err) + return aead +} + +func newChacha20Poly1305(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.New(key) + common.Must(err) + return cipher +} + +func newXChacha20Poly1305(key []byte) cipher.AEAD { + cipher, err := chacha20poly1305.NewX(key) + common.Must(err) + return cipher +} + +type Method struct { + name string + keySaltLength int + constructor func(key []byte) cipher.AEAD + blockConstructor func(key []byte) cipher.Block + udpCipher cipher.AEAD + udpBlockCipher cipher.Block + pskList [][]byte + pskHash []byte + replayFilter replay.Filter +} + +func (m *Method) Name() string { + return m.name +} + +func (m *Method) KeyLength() int { + return m.keySaltLength +} + +func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { + shadowsocksConn := &clientConn{ + Method: m, + Conn: conn, + destination: destination, + } + return shadowsocksConn, shadowsocksConn.writeRequest(nil) +} + +func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { + return &clientConn{ + Method: m, + Conn: conn, + destination: destination, + } +} + +func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { + return &clientPacketConn{m, conn, m.newUDPSession()} +} + +type clientConn struct { + *Method + net.Conn + destination M.Socksaddr + requestSalt []byte + reader *shadowaead.Reader + writer *shadowaead.Writer +} + +func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) { + pskLen := len(m.pskList) + if pskLen < 2 { + return + } + for i, psk := range m.pskList { + keyMaterial := buf.Make(m.keySaltLength * 2) + copy(keyMaterial, psk) + copy(keyMaterial[m.keySaltLength:], salt) + _identitySubkey := buf.Make(m.keySaltLength) + identitySubkey := common.Dup(_identitySubkey) + blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) + + pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] + + header := request.Extend(16) + m.blockConstructor(identitySubkey).Encrypt(header, pskHash) + runtime.KeepAlive(_identitySubkey) + if i == pskLen-2 { + break + } + } +} + +func (c *clientConn) writeRequest(payload []byte) error { + salt := buf.Make(c.keySaltLength) + common.Must1(io.ReadFull(rand.Reader, salt)) + + key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength) + writer := shadowaead.NewWriter( + c.Conn, + c.constructor(common.Dup(key)), + MaxPacketSize, + ) + runtime.KeepAlive(key) + + header := writer.Buffer() + header.Write(salt) + c.writeExtendedIdentityHeaders(header, salt) + + var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte + fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:])) + common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient)) + common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix()))) + var paddingLen int + if len(payload) == 0 { + paddingLen = mRand.Intn(MaxPaddingLength + 1) + } + variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen + len(payload) + common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen))) + writer.WriteChunk(header, fixedLengthBuffer.Slice()) + runtime.KeepAlive(_fixedLengthBuffer) + + _variableLengthBuffer := buf.Make(variableLengthHeaderLen) + variableLengthBuffer := buf.With(common.Dup(_variableLengthBuffer)) + common.Must(M.SocksaddrSerializer.WriteAddrPort(variableLengthBuffer, c.destination)) + common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen))) + if paddingLen > 0 { + variableLengthBuffer.Extend(paddingLen) + } else { + common.Must1(variableLengthBuffer.Write(payload)) + } + writer.WriteChunk(header, variableLengthBuffer.Slice()) + runtime.KeepAlive(_variableLengthBuffer) + + err := writer.BufferedWriter(header.Len()).Flush() + if err != nil { + return E.Cause(err, "client handshake") + } + + c.requestSalt = salt + c.writer = writer + return nil +} + +func (c *clientConn) readResponse() error { + if c.reader != nil { + return nil + } + + _salt := buf.Make(c.keySaltLength) + salt := common.Dup(_salt) + _, err := io.ReadFull(c.Conn, salt) + if err != nil { + return err + } + + if !c.replayFilter.Check(salt) { + return ErrSaltNotUnique + } + + key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength) + runtime.KeepAlive(_salt) + reader := shadowaead.NewReader( + c.Conn, + c.constructor(common.Dup(key)), + MaxPacketSize, + ) + runtime.KeepAlive(key) + + err = reader.ReadWithLength(uint16(1 + 8 + c.keySaltLength + 2)) + if err != nil { + return E.Cause(err, "read response fixed length chunk") + } + + headerType, err := rw.ReadByte(reader) + if err != nil { + return err + } + if headerType != HeaderTypeServer { + return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) + } + + var epoch uint64 + err = binary.Read(reader, binary.BigEndian, &epoch) + if err != nil { + return err + } + + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + } + + _requestSalt := buf.Make(c.keySaltLength) + requestSalt := common.Dup(_requestSalt) + _, err = io.ReadFull(reader, requestSalt) + if err != nil { + return err + } + + if bytes.Compare(requestSalt, c.requestSalt) > 0 { + return ErrBadRequestSalt + } + runtime.KeepAlive(_requestSalt) + + var length uint16 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return err + } + + err = reader.ReadWithLength(length) + if err != nil { + return err + } + + c.requestSalt = nil + c.reader = reader + + return nil +} + +func (c *clientConn) Read(p []byte) (n int, err error) { + if err = c.readResponse(); err != nil { + return + } + return c.reader.Read(p) +} + +func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { + if err = c.readResponse(); err != nil { + return + } + return c.reader.WriteTo(w) +} + +func (c *clientConn) Write(p []byte) (n int, err error) { + if c.writer == nil { + err = c.writeRequest(p) + if err == nil { + n = len(p) + } + return + } + return c.writer.Write(p) +} + +func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { + if c.writer == nil { + return rw.ReadFrom0(c, r) + } + return c.writer.ReadFrom(r) +} + +func (c *clientConn) Upstream() any { + return c.Conn +} + +type clientPacketConn struct { + *Method + net.Conn + session *udpSession +} + +func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + var hdrLen int + if c.udpCipher != nil { + hdrLen = PacketNonceSize + } + hdrLen += 16 // packet header + pskLen := len(c.pskList) + if c.udpCipher == nil && pskLen > 1 { + hdrLen += (pskLen - 1) * aes.BlockSize + } + hdrLen += 1 // header type + hdrLen += 8 // timestamp + hdrLen += 2 // padding length + hdrLen += M.SocksaddrSerializer.AddrPortLen(destination) + header := buf.With(buffer.ExtendHeader(hdrLen)) + + var dataIndex int + if c.udpCipher != nil { + common.Must1(header.ReadFullFrom(c.session.rng, PacketNonceSize)) + if pskLen > 1 { + panic("unsupported chacha extended header") + } + dataIndex = PacketNonceSize + } else { + dataIndex = aes.BlockSize + } + + common.Must( + binary.Write(header, binary.BigEndian, c.session.sessionId), + binary.Write(header, binary.BigEndian, c.session.nextPacketId()), + ) + + if c.udpCipher == nil && pskLen > 1 { + for i, psk := range c.pskList { + dataIndex += aes.BlockSize + pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] + + identityHeader := header.Extend(aes.BlockSize) + for textI := 0; textI < aes.BlockSize; textI++ { + identityHeader[textI] = pskHash[textI] ^ header.Byte(textI) + } + c.blockConstructor(psk).Encrypt(identityHeader, identityHeader) + + if i == pskLen-2 { + break + } + } + } + common.Must( + header.WriteByte(HeaderTypeClient), + binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())), + binary.Write(header, binary.BigEndian, uint16(0)), // padding length + ) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + if c.udpCipher != nil { + c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + } else { + packetHeader := buffer.To(aes.BlockSize) + c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + c.udpBlockCipher.Encrypt(packetHeader, packetHeader) + } + return common.Error(c.Write(buffer.Bytes())) +} + +func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + n, err := c.Read(buffer.FreeBytes()) + if err != nil { + return M.Socksaddr{}, err + } + buffer.Truncate(n) + + var packetHeader []byte + if c.udpCipher != nil { + _, err = c.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) + if err != nil { + return M.Socksaddr{}, E.Cause(err, "decrypt packet") + } + buffer.Advance(PacketNonceSize) + buffer.Truncate(buffer.Len() - shadowaead.Overhead) + } else { + packetHeader = buffer.To(aes.BlockSize) + c.udpBlockCipher.Decrypt(packetHeader, packetHeader) + } + + var sessionId, packetId uint64 + err = binary.Read(buffer, binary.BigEndian, &sessionId) + if err != nil { + return M.Socksaddr{}, err + } + err = binary.Read(buffer, binary.BigEndian, &packetId) + if err != nil { + return M.Socksaddr{}, err + } + + var remoteCipher cipher.AEAD + if packetHeader != nil { + if sessionId == c.session.remoteSessionId { + remoteCipher = c.session.remoteCipher + } else if sessionId == c.session.lastRemoteSessionId { + remoteCipher = c.session.lastRemoteCipher + } else { + key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength) + remoteCipher = c.constructor(common.Dup(key)) + runtime.KeepAlive(key) + } + _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) + if err != nil { + return M.Socksaddr{}, E.Cause(err, "decrypt packet") + } + buffer.Truncate(buffer.Len() - shadowaead.Overhead) + } + + var headerType byte + headerType, err = buffer.ReadByte() + if err != nil { + return M.Socksaddr{}, err + } + if headerType != HeaderTypeServer { + return M.Socksaddr{}, E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) + } + + var epoch uint64 + err = binary.Read(buffer, binary.BigEndian, &epoch) + if err != nil { + return M.Socksaddr{}, err + } + + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + } + + if sessionId == c.session.remoteSessionId { + if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) { + return M.Socksaddr{}, ErrPacketIdNotUnique + } + } else if sessionId == c.session.lastRemoteSessionId { + if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) { + return M.Socksaddr{}, ErrPacketIdNotUnique + } + remoteCipher = c.session.lastRemoteCipher + c.session.lastRemoteSeen = time.Now().Unix() + } else { + if c.session.remoteSessionId != 0 { + if time.Now().Unix()-c.session.lastRemoteSeen < 60 { + return M.Socksaddr{}, ErrTooManyServerSessions + } else { + c.session.lastRemoteSessionId = c.session.remoteSessionId + c.session.lastFilter = c.session.filter + c.session.lastRemoteSeen = time.Now().Unix() + c.session.lastRemoteCipher = c.session.remoteCipher + c.session.filter = wgReplay.Filter{} + } + } + c.session.remoteSessionId = sessionId + c.session.remoteCipher = remoteCipher + c.session.filter.ValidateCounter(packetId, math.MaxUint64) + } + + var clientSessionId uint64 + err = binary.Read(buffer, binary.BigEndian, &clientSessionId) + if err != nil { + return M.Socksaddr{}, err + } + + if clientSessionId != c.session.sessionId { + return M.Socksaddr{}, ErrBadClientSessionId + } + + var paddingLength uint16 + err = binary.Read(buffer, binary.BigEndian, &paddingLength) + if err != nil { + return M.Socksaddr{}, E.Cause(err, "read padding length") + } + buffer.Advance(int(paddingLength)) + + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return M.Socksaddr{}, err + } + return destination, nil +} + +func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + buffer := buf.With(p) + destination, err := c.ReadPacket(buffer) + if err != nil { + return + } + addr = destination.UDPAddr() + n = copy(p, buffer.Bytes()) + return +} + +func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + destination := M.SocksaddrFromNet(addr) + var overHead int + if c.udpCipher != nil { + overHead = PacketNonceSize + shadowaead.Overhead + } else { + overHead = shadowaead.Overhead + } + overHead += 16 // packet header + pskLen := len(c.pskList) + if c.udpCipher == nil && pskLen > 1 { + overHead += (pskLen - 1) * aes.BlockSize + } + overHead += 1 // header type + overHead += 8 // timestamp + overHead += 2 // padding length + overHead += M.SocksaddrSerializer.AddrPortLen(destination) + + _buffer := buf.Make(overHead + len(p)) + defer runtime.KeepAlive(_buffer) + buffer := buf.With(common.Dup(_buffer)) + + var dataIndex int + if c.udpCipher != nil { + common.Must1(buffer.ReadFullFrom(c.session.rng, PacketNonceSize)) + if pskLen > 1 { + panic("unsupported chacha extended header") + } + dataIndex = PacketNonceSize + } else { + dataIndex = aes.BlockSize + } + + common.Must( + binary.Write(buffer, binary.BigEndian, c.session.sessionId), + binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()), + ) + + if c.udpCipher == nil && pskLen > 1 { + for i, psk := range c.pskList { + dataIndex += aes.BlockSize + pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] + + identityHeader := buffer.Extend(aes.BlockSize) + for textI := 0; textI < aes.BlockSize; textI++ { + identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI) + } + c.blockConstructor(psk).Encrypt(identityHeader, identityHeader) + + if i == pskLen-2 { + break + } + } + } + common.Must( + buffer.WriteByte(HeaderTypeClient), + binary.Write(buffer, binary.BigEndian, uint64(time.Now().Unix())), + binary.Write(buffer, binary.BigEndian, uint16(0)), // padding length + ) + err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) + if err != nil { + return + } + if c.udpCipher != nil { + c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + } else { + packetHeader := buffer.To(aes.BlockSize) + c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + c.udpBlockCipher.Encrypt(packetHeader, packetHeader) + } + err = common.Error(c.Write(buffer.Bytes())) + if err != nil { + return + } + return len(p), nil +} + +type udpSession struct { + headerType byte + sessionId uint64 + packetId uint64 + remoteSessionId uint64 + lastRemoteSessionId uint64 + lastRemoteSeen int64 + cipher cipher.AEAD + remoteCipher cipher.AEAD + lastRemoteCipher cipher.AEAD + filter wgReplay.Filter + lastFilter wgReplay.Filter + rng io.Reader +} + +func (s *udpSession) nextPacketId() uint64 { + return atomic.AddUint64(&s.packetId, 1) +} + +func (m *Method) newUDPSession() *udpSession { + session := &udpSession{} + if m.udpCipher != nil { + session.rng = Blake3KeyedHash(rand.Reader) + common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) + } else { + common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) + } + session.packetId-- + if m.udpCipher == nil { + sessionId := make([]byte, 8) + binary.BigEndian.PutUint64(sessionId, session.sessionId) + key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength) + session.cipher = m.constructor(common.Dup(key)) + runtime.KeepAlive(key) + } + return session +} + +func (c *clientPacketConn) Upstream() any { + return c.Conn +} + +func Blake3KeyedHash(reader io.Reader) io.Reader { + key := make([]byte, 32) + common.Must1(io.ReadFull(reader, key)) + h := blake3.New(1024, key) + return h.XOF() +} diff --git a/shadowaead_2022/relay.go b/shadowaead_2022/relay.go new file mode 100644 index 0000000..61da950 --- /dev/null +++ b/shadowaead_2022/relay.go @@ -0,0 +1,233 @@ +package shadowaead_2022 + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "encoding/binary" + "io" + "net" + "os" + "runtime" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/cache" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/udpnat" + "lukechampine.com/blake3" +) + +type Relay[U comparable] struct { + name string + secureRNG io.Reader + keySaltLength int + handler shadowsocks.Handler + + constructor func(key []byte) cipher.AEAD + blockConstructor func(key []byte) cipher.Block + udpBlockCipher cipher.Block + + iPSK []byte + uPSKHash map[U][aes.BlockSize]byte + uPSKHashR map[[aes.BlockSize]byte]U + uDestination map[U]M.Socksaddr + uCipher map[U]cipher.Block + udpNat *udpnat.Service[uint64] + udpSessions *cache.LruCache[uint64, *relayUDPSession] +} + +func (s *Relay[U]) AddUser(user U, key []byte, destination M.Socksaddr) error { + if len(key) < s.keySaltLength { + return shadowsocks.ErrBadKey + } else if len(key) > s.keySaltLength { + key = Key(key, s.keySaltLength) + } + + var uPSKHash [aes.BlockSize]byte + hash512 := blake3.Sum512(key) + copy(uPSKHash[:], hash512[:]) + + if oldHash, loaded := s.uPSKHash[user]; loaded { + delete(s.uPSKHashR, oldHash) + } + + s.uPSKHash[user] = uPSKHash + s.uPSKHashR[uPSKHash] = user + s.uDestination[user] = destination + s.uCipher[user] = s.blockConstructor(key) + + return nil +} + +func (s *Relay[U]) RemoveUser(user U) { + if hash, loaded := s.uPSKHash[user]; loaded { + delete(s.uPSKHashR, hash) + } + delete(s.uPSKHash, user) + delete(s.uCipher, user) +} + +func NewRelay[U comparable](method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (*Relay[U], error) { + s := &Relay[U]{ + name: method, + secureRNG: secureRNG, + handler: handler, + + uPSKHash: make(map[U][aes.BlockSize]byte), + uPSKHashR: make(map[[aes.BlockSize]byte]U), + uDestination: make(map[U]M.Socksaddr), + uCipher: make(map[U]cipher.Block), + + udpNat: udpnat.New[uint64](udpTimeout, handler), + udpSessions: cache.New( + cache.WithAge[uint64, *relayUDPSession](udpTimeout), + cache.WithUpdateAgeOnGet[uint64, *relayUDPSession](), + ), + } + + switch method { + case "2022-blake3-aes-128-gcm": + s.keySaltLength = 16 + s.constructor = newAESGCM + s.blockConstructor = newAES + case "2022-blake3-aes-256-gcm": + s.keySaltLength = 32 + s.constructor = newAESGCM + s.blockConstructor = newAES + default: + return nil, os.ErrInvalid + } + if len(psk) != s.keySaltLength { + if len(psk) < s.keySaltLength { + return nil, shadowsocks.ErrBadKey + } else { + psk = Key(psk, s.keySaltLength) + } + } + s.udpBlockCipher = s.blockConstructor(psk) + return s, nil +} + +func (s *Relay[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + err := s.newConnection(ctx, conn, metadata) + if err != nil { + err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Relay[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + _requestHeader := buf.StackNew() + defer runtime.KeepAlive(_requestHeader) + requestHeader := common.Dup(_requestHeader) + n, err := requestHeader.ReadFrom(conn) + if err != nil { + return err + } else if int(n) < s.keySaltLength+aes.BlockSize { + return shadowaead.ErrBadHeader + } + requestSalt := requestHeader.To(s.keySaltLength) + var _eiHeader [aes.BlockSize]byte + eiHeader := common.Dup(_eiHeader[:]) + copy(eiHeader, requestHeader.Range(s.keySaltLength, s.keySaltLength+aes.BlockSize)) + + keyMaterial := buf.Make(s.keySaltLength * 2) + copy(keyMaterial, s.iPSK) + copy(keyMaterial[s.keySaltLength:], requestSalt) + _identitySubkey := buf.Make(s.keySaltLength) + identitySubkey := common.Dup(_identitySubkey) + blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) + s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader) + runtime.KeepAlive(_identitySubkey) + + var user U + if u, loaded := s.uPSKHashR[_eiHeader]; loaded { + user = u + } else { + return E.New("invalid request") + } + runtime.KeepAlive(_eiHeader) + + copy(requestHeader.Range(aes.BlockSize, aes.BlockSize+s.keySaltLength), requestHeader.To(s.keySaltLength)) + requestHeader.Advance(aes.BlockSize) + + ctx = shadowsocks.UserContext[U]{ + ctx, + user, + } + metadata.Protocol = "shadowsocks-relay" + metadata.Destination = s.uDestination[user] + conn = &bufio.BufferedConn{ + Conn: conn, + Buffer: requestHeader, + } + return s.handler.NewConnection(ctx, conn, metadata) +} + +func (s *Relay[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + err := s.newPacket(ctx, conn, buffer, metadata) + if err != nil { + err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Relay[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + packetHeader := buffer.To(aes.BlockSize) + s.udpBlockCipher.Decrypt(packetHeader, packetHeader) + + sessionId := binary.BigEndian.Uint64(packetHeader) + + var _eiHeader [aes.BlockSize]byte + eiHeader := common.Dup(_eiHeader[:]) + s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize)) + + for i := range eiHeader { + eiHeader[i] = eiHeader[i] ^ packetHeader[i] + } + + var user U + if u, loaded := s.uPSKHashR[_eiHeader]; loaded { + user = u + } else { + return E.New("invalid request") + } + + session, _ := s.udpSessions.LoadOrStore(sessionId, func() *relayUDPSession { + return new(relayUDPSession) + }) + session.sourceAddr = metadata.Source + + s.uCipher[user].Encrypt(packetHeader, packetHeader) + copy(buffer.Range(aes.BlockSize, 2*aes.BlockSize), packetHeader) + buffer.Advance(aes.BlockSize) + + metadata.Protocol = "shadowsocks-relay" + metadata.Destination = s.uDestination[user] + s.udpNat.NewContextPacket(ctx, sessionId, func() (context.Context, N.PacketWriter) { + return &shadowsocks.UserContext[U]{ + ctx, + user, + }, &relayPacketWriter[U]{conn, session} + }, buffer, metadata) + return nil +} + +type relayUDPSession struct { + sourceAddr M.Socksaddr +} + +type relayPacketWriter[U comparable] struct { + N.PacketConn + session *relayUDPSession +} + +func (w *relayPacketWriter[U]) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error { + return w.PacketConn.WritePacket(buffer, w.session.sourceAddr) +} diff --git a/shadowaead_2022/service.go b/shadowaead_2022/service.go new file mode 100644 index 0000000..6a4c750 --- /dev/null +++ b/shadowaead_2022/service.go @@ -0,0 +1,489 @@ +package shadowaead_2022 + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/binary" + "io" + "math" + "net" + "os" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/cache" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/replay" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/udpnat" + wgReplay "golang.zx2c4.com/wireguard/replay" +) + +var ( + ErrSaltNotUnique = E.New("bad request: salt not unique") + ErrNoPadding = E.New("bad request: missing payload or padding") + ErrBadPadding = E.New("bad request: damaged padding") +) + +type Service struct { + name string + keySaltLength int + handler shadowsocks.Handler + + constructor func(key []byte) cipher.AEAD + blockConstructor func(key []byte) cipher.Block + udpCipher cipher.AEAD + udpBlockCipher cipher.Block + psk []byte + + replayFilter replay.Filter + udpNat *udpnat.Service[uint64] + udpSessions *cache.LruCache[uint64, *serverUDPSession] +} + +func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { + if password == "" { + return nil, ErrMissingPSK + } + psk, err := base64.StdEncoding.DecodeString(password) + if err != nil { + return nil, E.Cause(err, "decode psk") + } + return NewService(method, psk, udpTimeout, handler) +} + +func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { + s := &Service{ + name: method, + handler: handler, + + replayFilter: replay.NewSimple(60 * time.Second), + udpNat: udpnat.New[uint64](udpTimeout, handler), + udpSessions: cache.New[uint64, *serverUDPSession]( + cache.WithAge[uint64, *serverUDPSession](udpTimeout), + cache.WithUpdateAgeOnGet[uint64, *serverUDPSession](), + ), + } + + switch method { + case "2022-blake3-aes-128-gcm": + s.keySaltLength = 16 + s.constructor = newAESGCM + s.blockConstructor = newAES + case "2022-blake3-aes-256-gcm": + s.keySaltLength = 32 + s.constructor = newAESGCM + s.blockConstructor = newAES + case "2022-blake3-chacha20-poly1305": + s.keySaltLength = 32 + s.constructor = newChacha20Poly1305 + default: + return nil, os.ErrInvalid + } + + if len(psk) != s.keySaltLength { + if len(psk) < s.keySaltLength { + return nil, shadowsocks.ErrBadKey + } else if len(psk) > s.keySaltLength { + psk = Key(psk, s.keySaltLength) + } else { + return nil, ErrMissingPSK + } + } + + switch method { + case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": + s.udpBlockCipher = newAES(psk) + case "2022-blake3-chacha20-poly1305": + s.udpCipher = newXChacha20Poly1305(psk) + } + + s.psk = psk + return s, nil +} + +func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + err := s.newConnection(ctx, conn, metadata) + if err != nil { + err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + header := buf.Make(s.keySaltLength + shadowaead.Overhead + RequestHeaderFixedChunkLength) + + n, err := conn.Read(header) + if err != nil { + return E.Cause(err, "read header") + } else if n < len(header) { + return shadowaead.ErrBadHeader + } + + requestSalt := header[:s.keySaltLength] + + if !s.replayFilter.Check(requestSalt) { + return ErrSaltNotUnique + } + + requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength) + reader := shadowaead.NewReader( + conn, + s.constructor(common.Dup(requestKey)), + MaxPacketSize, + ) + runtime.KeepAlive(requestKey) + + err = reader.ReadChunk(header[s.keySaltLength:]) + if err != nil { + return err + } + + headerType, err := reader.ReadByte() + if err != nil { + return E.Cause(err, "read header") + } + + if headerType != HeaderTypeClient { + return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) + } + + var epoch uint64 + err = binary.Read(reader, binary.BigEndian, &epoch) + if err != nil { + return err + } + + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + } + + var length uint16 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return err + } + + err = reader.ReadWithLength(length) + if err != nil { + return err + } + + destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) + if err != nil { + return err + } + + var paddingLen uint16 + err = binary.Read(reader, binary.BigEndian, &paddingLen) + if err != nil { + return err + } + + if uint16(reader.Cached()) < paddingLen { + return ErrNoPadding + } + + if paddingLen > 0 { + err = reader.Discard(int(paddingLen)) + if err != nil { + return E.Cause(err, "discard padding") + } + } else if reader.Cached() == 0 { + return ErrNoPadding + } + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + return s.handler.NewConnection(ctx, &serverConn{ + Service: s, + Conn: conn, + uPSK: s.psk, + reader: reader, + requestSalt: requestSalt, + }, metadata) +} + +type serverConn struct { + *Service + net.Conn + uPSK []byte + access sync.Mutex + reader *shadowaead.Reader + writer *shadowaead.Writer + requestSalt []byte +} + +func (c *serverConn) writeResponse(payload []byte) (n int, err error) { + _salt := buf.Make(c.keySaltLength) + salt := common.Dup(_salt[:]) + common.Must1(io.ReadFull(rand.Reader, salt)) + key := SessionKey(c.uPSK, salt, c.keySaltLength) + runtime.KeepAlive(_salt) + writer := shadowaead.NewWriter( + c.Conn, + c.constructor(common.Dup(key)), + MaxPacketSize, + ) + runtime.KeepAlive(key) + header := writer.Buffer() + header.Write(salt) + + _headerFixedChunk := buf.Make(1 + 8 + c.keySaltLength + 2) + headerFixedChunk := buf.With(common.Dup(_headerFixedChunk)) + common.Must(headerFixedChunk.WriteByte(HeaderTypeServer)) + common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix()))) + common.Must1(headerFixedChunk.Write(c.requestSalt)) + common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(len(payload)))) + + writer.WriteChunk(header, headerFixedChunk.Slice()) + runtime.KeepAlive(_headerFixedChunk) + c.requestSalt = nil + + if len(payload) > 0 { + writer.WriteChunk(header, payload) + } + + err = writer.BufferedWriter(header.Len()).Flush() + if err != nil { + return + } + + c.writer = writer + n = len(payload) + return +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + if c.writer != nil { + return c.writer.Write(p) + } + c.access.Lock() + if c.writer != nil { + c.access.Unlock() + return c.writer.Write(p) + } + defer c.access.Unlock() + return c.writeResponse(p) +} + +func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) { + if c.writer == nil { + return rw.ReadFrom0(c, r) + } + return c.writer.ReadFrom(r) +} + +func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { + return c.reader.WriteTo(w) +} + +func (c *serverConn) Upstream() any { + return c.Conn +} + +func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + err := s.newPacket(ctx, conn, buffer, metadata) + if err != nil { + err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} + } + return err +} + +func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + var packetHeader []byte + if s.udpCipher != nil { + _, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) + if err != nil { + return E.Cause(err, "decrypt packet header") + } + buffer.Advance(PacketNonceSize) + buffer.Truncate(buffer.Len() - shadowaead.Overhead) + } else { + packetHeader = buffer.To(aes.BlockSize) + s.udpBlockCipher.Decrypt(packetHeader, packetHeader) + } + + var sessionId, packetId uint64 + err := binary.Read(buffer, binary.BigEndian, &sessionId) + if err != nil { + return err + } + err = binary.Read(buffer, binary.BigEndian, &packetId) + if err != nil { + return err + } + + session, loaded := s.udpSessions.LoadOrStore(sessionId, s.newUDPSession) + if !loaded { + session.remoteSessionId = sessionId + if packetHeader != nil { + key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength) + session.remoteCipher = s.constructor(common.Dup(key)) + runtime.KeepAlive(key) + } + } + goto process + +returnErr: + if !loaded { + s.udpSessions.Delete(sessionId) + } + return err + +process: + if !session.filter.ValidateCounter(packetId, math.MaxUint64) { + err = ErrPacketIdNotUnique + goto returnErr + } + + if packetHeader != nil { + _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) + if err != nil { + err = E.Cause(err, "decrypt packet") + goto returnErr + } + buffer.Truncate(buffer.Len() - shadowaead.Overhead) + } + + var headerType byte + headerType, err = buffer.ReadByte() + if err != nil { + err = E.Cause(err, "decrypt packet") + goto returnErr + } + if headerType != HeaderTypeClient { + err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) + goto returnErr + } + + var epoch uint64 + err = binary.Read(buffer, binary.BigEndian, &epoch) + if err != nil { + goto returnErr + } + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + goto returnErr + } + + var paddingLength uint16 + err = binary.Read(buffer, binary.BigEndian, &paddingLength) + if err != nil { + err = E.Cause(err, "read padding length") + goto returnErr + } + buffer.Advance(int(paddingLength)) + + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + goto returnErr + } + metadata.Destination = destination + + session.remoteAddr = metadata.Source + s.udpNat.NewPacket(ctx, sessionId, func() N.PacketWriter { + return &serverPacketWriter{s, conn, session} + }, buffer, metadata) + return nil +} + +type serverPacketWriter struct { + *Service + N.PacketConn + session *serverUDPSession +} + +func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + var hdrLen int + if w.udpCipher != nil { + hdrLen = PacketNonceSize + } + hdrLen += 16 // packet header + hdrLen += 1 // header type + hdrLen += 8 // timestamp + hdrLen += 8 // remote session id + hdrLen += 2 // padding length + hdrLen += M.SocksaddrSerializer.AddrPortLen(destination) + header := buf.With(buffer.ExtendHeader(hdrLen)) + + var dataIndex int + if w.udpCipher != nil { + common.Must1(header.ReadFullFrom(w.session.rng, PacketNonceSize)) + dataIndex = PacketNonceSize + } else { + dataIndex = aes.BlockSize + } + + common.Must( + binary.Write(header, binary.BigEndian, w.session.sessionId), + binary.Write(header, binary.BigEndian, w.session.nextPacketId()), + header.WriteByte(HeaderTypeServer), + binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())), + binary.Write(header, binary.BigEndian, w.session.remoteSessionId), + binary.Write(header, binary.BigEndian, uint16(0)), // padding length + ) + + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + + if w.udpCipher != nil { + w.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + } else { + packetHeader := buffer.To(aes.BlockSize) + w.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) + buffer.Extend(shadowaead.Overhead) + w.udpBlockCipher.Encrypt(packetHeader, packetHeader) + } + return w.PacketConn.WritePacket(buffer, w.session.remoteAddr) +} + +type serverUDPSession struct { + sessionId uint64 + remoteSessionId uint64 + remoteAddr M.Socksaddr + packetId uint64 + cipher cipher.AEAD + remoteCipher cipher.AEAD + filter wgReplay.Filter + rng io.Reader +} + +func (s *serverUDPSession) nextPacketId() uint64 { + return atomic.AddUint64(&s.packetId, 1) +} + +func (m *Service) newUDPSession() *serverUDPSession { + session := &serverUDPSession{} + if m.udpCipher != nil { + session.rng = Blake3KeyedHash(rand.Reader) + common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) + } else { + common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) + } + session.packetId-- + if m.udpCipher == nil { + sessionId := make([]byte, 8) + binary.BigEndian.PutUint64(sessionId, session.sessionId) + key := SessionKey(m.psk, sessionId, m.keySaltLength) + session.cipher = m.constructor(common.Dup(key)) + runtime.KeepAlive(key) + } + return session +} diff --git a/shadowaead_2022/service_multi.go b/shadowaead_2022/service_multi.go new file mode 100644 index 0000000..c98c6b9 --- /dev/null +++ b/shadowaead_2022/service_multi.go @@ -0,0 +1,365 @@ +package shadowaead_2022 + +import ( + "context" + "crypto/aes" + "crypto/rand" + "encoding/base64" + "encoding/binary" + "math" + "net" + "os" + "runtime" + "time" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "lukechampine.com/blake3" +) + +type MultiService[U comparable] struct { + *Service + + uPSK map[U][]byte + uPSKHash map[U][aes.BlockSize]byte + uPSKHashR map[[aes.BlockSize]byte]U +} + +func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { + if password == "" { + return nil, ErrMissingPSK + } + iPSK, err := base64.StdEncoding.DecodeString(password) + if err != nil { + return nil, E.Cause(err, "decode psk") + } + return NewMultiService[U](method, iPSK, udpTimeout, handler) +} + +func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { + switch method { + case "2022-blake3-aes-128-gcm": + case "2022-blake3-aes-256-gcm": + default: + return nil, os.ErrInvalid + } + + ss, err := NewService(method, iPSK, udpTimeout, handler) + if err != nil { + return nil, err + } + + s := &MultiService[U]{ + Service: ss.(*Service), + + uPSK: make(map[U][]byte), + uPSKHash: make(map[U][aes.BlockSize]byte), + uPSKHashR: make(map[[aes.BlockSize]byte]U), + } + return s, nil +} + +func (s *MultiService[U]) AddUser(user U, key []byte) error { + if len(key) < s.keySaltLength { + return shadowsocks.ErrBadKey + } else if len(key) > s.keySaltLength { + key = Key(key, s.keySaltLength) + } + + var uPSKHash [aes.BlockSize]byte + hash512 := blake3.Sum512(key) + copy(uPSKHash[:], hash512[:]) + + if oldHash, loaded := s.uPSKHash[user]; loaded { + delete(s.uPSKHashR, oldHash) + } + + s.uPSKHash[user] = uPSKHash + s.uPSKHashR[uPSKHash] = user + s.uPSK[user] = key + + return nil +} + +func (s *MultiService[U]) AddUserWithPassword(user U, password string) error { + if password == "" { + return shadowsocks.ErrMissingPassword + } + psk, err := base64.StdEncoding.DecodeString(password) + if err != nil { + return E.Cause(err, "decode psk") + } + return s.AddUser(user, psk) +} + +func (s *MultiService[U]) RemoveUser(user U) { + if hash, loaded := s.uPSKHash[user]; loaded { + delete(s.uPSKHashR, hash) + } + delete(s.uPSK, user) + delete(s.uPSKHash, user) +} + +func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + err := s.newConnection(ctx, conn, metadata) + if err != nil { + err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} + } + return err +} + +func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength) + n, err := conn.Read(requestHeader) + if err != nil { + return err + } else if n < len(requestHeader) { + return shadowaead.ErrBadHeader + } + requestSalt := requestHeader[:s.keySaltLength] + if !s.replayFilter.Check(requestSalt) { + return ErrSaltNotUnique + } + + var _eiHeader [aes.BlockSize]byte + eiHeader := common.Dup(_eiHeader[:]) + copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize]) + + keyMaterial := buf.Make(s.keySaltLength * 2) + copy(keyMaterial, s.psk) + copy(keyMaterial[s.keySaltLength:], requestSalt) + _identitySubkey := buf.Make(s.keySaltLength) + identitySubkey := common.Dup(_identitySubkey) + blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) + s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader) + runtime.KeepAlive(_identitySubkey) + + var user U + var uPSK []byte + if u, loaded := s.uPSKHashR[_eiHeader]; loaded { + user = u + uPSK = s.uPSK[u] + } else { + return E.New("invalid request") + } + runtime.KeepAlive(_eiHeader) + + requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength) + reader := shadowaead.NewReader( + conn, + s.constructor(common.Dup(requestKey)), + MaxPacketSize, + ) + + err = reader.ReadChunk(requestHeader[s.keySaltLength+aes.BlockSize:]) + if err != nil { + return err + } + + headerType, err := rw.ReadByte(reader) + if err != nil { + return E.Cause(err, "read header") + } + + if headerType != HeaderTypeClient { + return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) + } + + var epoch uint64 + err = binary.Read(reader, binary.BigEndian, &epoch) + if err != nil { + return E.Cause(err, "read timestamp") + } + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + } + var length uint16 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return E.Cause(err, "read length") + } + + err = reader.ReadWithLength(length) + if err != nil { + return err + } + + destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) + if err != nil { + return E.Cause(err, "read destination") + } + + var paddingLen uint16 + err = binary.Read(reader, binary.BigEndian, &paddingLen) + if err != nil { + return E.Cause(err, "read padding length") + } + + if reader.Cached() < int(paddingLen) { + return ErrBadPadding + } else if paddingLen > 0 { + err = reader.Discard(int(paddingLen)) + if err != nil { + return E.Cause(err, "discard padding") + } + } else if reader.Cached() == 0 { + return ErrNoPadding + } + + var userCtx shadowsocks.UserContext[U] + userCtx.Context = ctx + userCtx.User = user + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + return s.handler.NewConnection(&userCtx, &serverConn{ + Service: s.Service, + Conn: conn, + uPSK: uPSK, + reader: reader, + requestSalt: requestSalt, + }, metadata) +} + +func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + err := s.newPacket(ctx, conn, buffer, metadata) + if err != nil { + err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} + } + return err +} + +func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + packetHeader := buffer.To(aes.BlockSize) + s.udpBlockCipher.Decrypt(packetHeader, packetHeader) + + var _eiHeader [aes.BlockSize]byte + eiHeader := common.Dup(_eiHeader[:]) + s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize)) + + for i := range eiHeader { + eiHeader[i] = eiHeader[i] ^ packetHeader[i] + } + + var user U + var uPSK []byte + if u, loaded := s.uPSKHashR[_eiHeader]; loaded { + user = u + uPSK = s.uPSK[u] + } else { + return E.New("invalid request") + } + + var sessionId, packetId uint64 + err := binary.Read(buffer, binary.BigEndian, &sessionId) + if err != nil { + return err + } + err = binary.Read(buffer, binary.BigEndian, &packetId) + if err != nil { + return err + } + + session, loaded := s.udpSessions.LoadOrStore(sessionId, func() *serverUDPSession { + return s.newUDPSession(uPSK) + }) + if !loaded { + session.remoteSessionId = sessionId + key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength) + session.remoteCipher = s.constructor(common.Dup(key)) + runtime.KeepAlive(key) + } + + goto process + +returnErr: + if !loaded { + s.udpSessions.Delete(sessionId) + } + return err + +process: + if !session.filter.ValidateCounter(packetId, math.MaxUint64) { + err = ErrPacketIdNotUnique + goto returnErr + } + + if packetHeader != nil { + _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) + if err != nil { + err = E.Cause(err, "decrypt packet") + goto returnErr + } + buffer.Truncate(buffer.Len() - shadowaead.Overhead) + } + + var headerType byte + headerType, err = buffer.ReadByte() + if err != nil { + err = E.Cause(err, "decrypt packet") + goto returnErr + } + if headerType != HeaderTypeClient { + err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) + goto returnErr + } + + var epoch uint64 + err = binary.Read(buffer, binary.BigEndian, &epoch) + if err != nil { + goto returnErr + } + diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + if diff > 30 { + err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") + goto returnErr + } + + var paddingLength uint16 + err = binary.Read(buffer, binary.BigEndian, &paddingLength) + if err != nil { + err = E.Cause(err, "read padding length") + goto returnErr + } + buffer.Advance(int(paddingLength)) + + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + goto returnErr + } + + metadata.Destination = destination + session.remoteAddr = metadata.Source + + s.udpNat.NewContextPacket(ctx, sessionId, func() (context.Context, N.PacketWriter) { + return &shadowsocks.UserContext[U]{ + ctx, + user, + }, &serverPacketWriter{s.Service, conn, session} + }, buffer, metadata) + return nil +} + +func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession { + session := &serverUDPSession{} + if s.udpCipher != nil { + session.rng = Blake3KeyedHash(rand.Reader) + common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) + } else { + common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) + } + session.packetId-- + sessionId := make([]byte, 8) + binary.BigEndian.PutUint64(sessionId, session.sessionId) + key := SessionKey(uPSK, sessionId, s.keySaltLength) + session.cipher = s.constructor(common.Dup(key)) + runtime.KeepAlive(key) + return session +} diff --git a/shadowaead_2022/service_multi_test.go b/shadowaead_2022/service_multi_test.go new file mode 100644 index 0000000..6565cd2 --- /dev/null +++ b/shadowaead_2022/service_multi_test.go @@ -0,0 +1,75 @@ +package shadowaead_2022_test + +import ( + "context" + "crypto/rand" + "net" + "sync" + "testing" + + "github.com/sagernet/sing-shadowsocks/shadowaead_2022" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func TestMultiService(t *testing.T) { + method := "2022-blake3-aes-128-gcm" + var iPSK [16]byte + rand.Reader.Read(iPSK[:]) + + var wg sync.WaitGroup + + multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg}) + if err != nil { + t.Fatal(err) + } + + var uPSK [16]byte + rand.Reader.Read(uPSK[:]) + multiService.AddUser("my user", uPSK[:]) + + client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}) + if err != nil { + t.Fatal(err) + } + wg.Add(1) + + serverConn, clientConn := net.Pipe() + defer common.Close(serverConn, clientConn) + go func() { + err := multiService.NewConnection(context.Background(), serverConn, M.Metadata{}) + if err != nil { + serverConn.Close() + t.Error(E.Cause(err, "server")) + return + } + }() + _, err = client.DialConn(clientConn, M.ParseSocksaddr("test.com:443")) + if err != nil { + t.Fatal(err) + } + wg.Wait() +} + +type multiHandler struct { + t *testing.T + wg *sync.WaitGroup +} + +func (h *multiHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + if metadata.Destination.String() != "test.com:443" { + h.t.Error("bad destination") + } + h.wg.Done() + return nil +} + +func (h *multiHandler) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { + return nil +} + +func (h *multiHandler) HandleError(err error) { + h.t.Error(err) +} diff --git a/shadowaead_2022/service_test.go b/shadowaead_2022/service_test.go new file mode 100644 index 0000000..d58c90a --- /dev/null +++ b/shadowaead_2022/service_test.go @@ -0,0 +1,49 @@ +package shadowaead_2022_test + +import ( + "context" + "crypto/rand" + "net" + "sync" + "testing" + + "github.com/sagernet/sing-shadowsocks/shadowaead_2022" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" +) + +func TestService(t *testing.T) { + method := "2022-blake3-aes-128-gcm" + var psk [16]byte + rand.Reader.Read(psk[:]) + + var wg sync.WaitGroup + + service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg}) + if err != nil { + t.Fatal(err) + } + + client, err := shadowaead_2022.New(method, [][]byte{psk[:]}) + if err != nil { + t.Fatal(err) + } + wg.Add(1) + + serverConn, clientConn := net.Pipe() + defer common.Close(serverConn, clientConn) + go func() { + err := service.NewConnection(context.Background(), serverConn, M.Metadata{}) + if err != nil { + serverConn.Close() + t.Error(E.Cause(err, "server")) + return + } + }() + _, err = client.DialConn(clientConn, M.ParseSocksaddr("test.com:443")) + if err != nil { + t.Fatal(err) + } + wg.Wait() +} diff --git a/shadowimpl/fetcher.go b/shadowimpl/fetcher.go new file mode 100644 index 0000000..0b15861 --- /dev/null +++ b/shadowimpl/fetcher.go @@ -0,0 +1,24 @@ +package shadowimpl + +import ( + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing-shadowsocks/shadowaead_2022" + "github.com/sagernet/sing-shadowsocks/shadowstream" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +func FetchMethod(method string, password string) (shadowsocks.Method, error) { + if method == "none" { + return shadowsocks.NewNone(), nil + } else if common.Contains(shadowstream.List, method) { + return shadowstream.New(method, nil, password) + } else if common.Contains(shadowaead.List, method) { + return shadowaead.New(method, nil, password) + } else if common.Contains(shadowaead_2022.List, method) { + return shadowaead_2022.NewWithPassword(method, password) + } else { + return nil, E.New("shadowsocks: unsupported method ", method) + } +} diff --git a/shadowsocks.go b/shadowsocks.go new file mode 100644 index 0000000..7e98679 --- /dev/null +++ b/shadowsocks.go @@ -0,0 +1,89 @@ +package shadowsocks + +import ( + "context" + "crypto/md5" + "fmt" + "net" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var ( + ErrBadKey = E.New("bad key") + ErrMissingPassword = E.New("missing password") +) + +type Method interface { + Name() string + KeyLength() int + DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) + DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn + DialPacketConn(conn net.Conn) N.NetPacketConn +} + +type Service interface { + N.TCPConnectionHandler + N.UDPHandler +} + +type Handler interface { + N.TCPConnectionHandler + N.UDPConnectionHandler + E.Handler +} + +type UserContext[U comparable] struct { + context.Context + User U +} + +type ServerConnError struct { + net.Conn + Source M.Socksaddr + Cause error +} + +func (e *ServerConnError) Close() error { + if conn, ok := common.Cast[*net.TCPConn](e.Conn); ok { + conn.SetLinger(0) + } + return e.Conn.Close() +} + +func (e *ServerConnError) Unwrap() error { + return e.Cause +} + +func (e *ServerConnError) Error() string { + return fmt.Sprint("shadowsocks: serve TCP from ", e.Source, ": ", e.Cause) +} + +type ServerPacketError struct { + Source M.Socksaddr + Cause error +} + +func (e *ServerPacketError) Unwrap() error { + return e.Cause +} + +func (e *ServerPacketError) Error() string { + return fmt.Sprint("shadowsocks: serve UDP from ", e.Source, ": ", e.Cause) +} + +func Key(password []byte, keySize int) []byte { + var b, prev []byte + h := md5.New() + for len(b) < keySize { + h.Write(prev) + h.Write([]byte(password)) + b = h.Sum(b) + prev = b[len(b)-h.Size():] + h.Reset() + } + return b[:keySize] +} diff --git a/shadowstream/protocol.go b/shadowstream/protocol.go new file mode 100644 index 0000000..1626de6 --- /dev/null +++ b/shadowstream/protocol.go @@ -0,0 +1,392 @@ +package shadowstream + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/md5" + "crypto/rand" + "crypto/rc4" + "io" + "net" + "os" + "runtime" + + "github.com/dgryski/go-camellia" + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "golang.org/x/crypto/blowfish" + "golang.org/x/crypto/cast5" + "golang.org/x/crypto/chacha20" +) + +var List = []string{ + "aes-128-ctr", + "aes-192-ctr", + "aes-256-ctr", + "aes-128-cfb", + "aes-192-cfb", + "aes-256-cfb", + "camellia-128-cfb", + "camellia-192-cfb", + "camellia-256-cfb", + "bf-cfb", + "cast5-cfb", + "des-cfb", + "rc4", + "rc4-md5", + "chacha20", + "chacha20-ietf", + "xchacha20", +} + +type Method struct { + name string + keyLength int + saltLength int + encryptConstructor func(key []byte, salt []byte) (cipher.Stream, error) + decryptConstructor func(key []byte, salt []byte) (cipher.Stream, error) + key []byte +} + +func New(method string, key []byte, password string) (shadowsocks.Method, error) { + m := &Method{ + name: method, + } + switch method { + case "aes-128-ctr": + m.keyLength = 16 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + case "aes-192-ctr": + m.keyLength = 24 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + case "aes-256-ctr": + m.keyLength = 32 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) + case "aes-128-cfb": + m.keyLength = 16 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) + case "aes-192-cfb": + m.keyLength = 24 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) + case "aes-256-cfb": + m.keyLength = 32 + m.saltLength = aes.BlockSize + m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) + case "camellia-128-cfb": + m.keyLength = 16 + m.saltLength = camellia.BlockSize + m.encryptConstructor = blockStream(camellia.New, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(camellia.New, cipher.NewCFBDecrypter) + case "camellia-192-cfb": + m.keyLength = 24 + m.saltLength = camellia.BlockSize + m.encryptConstructor = blockStream(camellia.New, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(camellia.New, cipher.NewCFBDecrypter) + case "camellia-256-cfb": + m.keyLength = 32 + m.saltLength = camellia.BlockSize + m.encryptConstructor = blockStream(camellia.New, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(camellia.New, cipher.NewCFBDecrypter) + case "bf-cfb": + m.keyLength = 16 + m.saltLength = blowfish.BlockSize + m.encryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return blowfish.NewCipher(key) }, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return blowfish.NewCipher(key) }, cipher.NewCFBDecrypter) + case "cast5-cfb": + m.keyLength = 16 + m.saltLength = cast5.BlockSize + m.encryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return cast5.NewCipher(key) }, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return cast5.NewCipher(key) }, cipher.NewCFBDecrypter) + case "des-cfb": + m.keyLength = 8 + m.saltLength = des.BlockSize + m.encryptConstructor = blockStream(des.NewCipher, cipher.NewCFBEncrypter) + m.decryptConstructor = blockStream(des.NewCipher, cipher.NewCFBDecrypter) + case "rc4": + m.keyLength = 16 + m.saltLength = 0 + m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) + } + m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) + } + case "rc4-md5": + m.keyLength = 16 + m.saltLength = 0 + m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + h := md5.New() + h.Write(key) + h.Write(salt) + return rc4.NewCipher(h.Sum(nil)) + } + m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + h := md5.New() + h.Write(key) + h.Write(salt) + return rc4.NewCipher(h.Sum(nil)) + } + case "chacha20", "chacha20-ietf": + m.keyLength = chacha20.KeySize + m.saltLength = chacha20.NonceSize + m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return chacha20.NewUnauthenticatedCipher(key, salt) + } + m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return chacha20.NewUnauthenticatedCipher(key, salt) + } + case "xchacha20": + m.keyLength = chacha20.KeySize + m.saltLength = chacha20.NonceSizeX + m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return chacha20.NewUnauthenticatedCipher(key, salt) + } + m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { + return chacha20.NewUnauthenticatedCipher(key, salt) + } + default: + return nil, os.ErrInvalid + } + if len(key) == m.keyLength { + m.key = key + } else if len(key) > 0 { + return nil, shadowsocks.ErrBadKey + } else if password != "" { + m.key = shadowsocks.Key([]byte(password), m.keyLength) + } else { + return nil, shadowsocks.ErrMissingPassword + } + return m, nil +} + +func blockStream(blockCreator func(key []byte) (cipher.Block, error), streamCreator func(block cipher.Block, iv []byte) cipher.Stream) func([]byte, []byte) (cipher.Stream, error) { + return func(key []byte, iv []byte) (cipher.Stream, error) { + block, err := blockCreator(key) + if err != nil { + return nil, err + } + return streamCreator(block, iv), err + } +} + +func (m *Method) Name() string { + return m.name +} + +func (m *Method) KeyLength() int { + return m.keyLength +} + +func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { + shadowsocksConn := &clientConn{ + Conn: conn, + method: m, + destination: destination, + } + return shadowsocksConn, shadowsocksConn.writeRequest(nil) +} + +func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { + return &clientConn{ + Conn: conn, + method: m, + destination: destination, + } +} + +func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { + return &clientPacketConn{m, conn} +} + +type clientConn struct { + net.Conn + + method *Method + destination M.Socksaddr + + readStream cipher.Stream + writeStream cipher.Stream +} + +func (c *clientConn) writeRequest(payload []byte) error { + _buffer := buf.Make(c.method.keyLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(payload)) + defer runtime.KeepAlive(_buffer) + buffer := buf.With(common.Dup(_buffer)) + + salt := buffer.Extend(c.method.keyLength) + common.Must1(io.ReadFull(rand.Reader, salt)) + + key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength) + writer, err := c.method.encryptConstructor(c.method.key, salt) + if err != nil { + return err + } + runtime.KeepAlive(key) + + err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination) + if err != nil { + return err + } + _, err = buffer.Write(payload) + if err != nil { + return err + } + + _, err = c.Conn.Write(buffer.Bytes()) + if err != nil { + return err + } + + c.writeStream = writer + return nil +} + +func (c *clientConn) readResponse() error { + if c.readStream != nil { + return nil + } + _salt := buf.Make(c.method.keyLength) + defer runtime.KeepAlive(_salt) + salt := common.Dup(_salt) + _, err := io.ReadFull(c.Conn, salt) + if err != nil { + return err + } + key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength) + defer runtime.KeepAlive(key) + c.readStream, err = c.method.decryptConstructor(common.Dup(key), salt) + if err != nil { + return err + } + return nil +} + +func (c *clientConn) Read(p []byte) (n int, err error) { + if err = c.readResponse(); err != nil { + return + } + n, err = c.Conn.Read(p) + if err != nil { + return 0, err + } + c.readStream.XORKeyStream(p[:n], p[:n]) + return +} + +func (c *clientConn) Write(p []byte) (n int, err error) { + if c.writeStream == nil { + err = c.writeRequest(p) + if err == nil { + n = len(p) + } + return + } + + c.writeStream.XORKeyStream(p, p) + return c.Conn.Write(p) +} + +func (c *clientConn) Upstream() any { + return c.Conn +} + +type clientPacketConn struct { + *Method + net.Conn +} + +func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buf.With(buffer.ExtendHeader(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination))) + common.Must1(header.ReadFullFrom(rand.Reader, c.keyLength)) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength)) + if err != nil { + return err + } + stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength)) + return common.Error(c.Write(buffer.Bytes())) +} + +func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + n, err := c.Read(buffer.FreeBytes()) + if err != nil { + return M.Socksaddr{}, err + } + buffer.Truncate(n) + stream, err := c.decryptConstructor(c.key, buffer.To(c.keyLength)) + if err != nil { + return M.Socksaddr{}, err + } + stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength)) + buffer.Advance(c.keyLength) + return M.SocksaddrSerializer.ReadAddrPort(buffer) +} + +func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + if err != nil { + return + } + stream, err := c.decryptConstructor(c.key, p[:c.keyLength]) + if err != nil { + return + } + buffer := buf.With(p[c.keyLength:n]) + stream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return + } + addr = destination.UDPAddr() + n = copy(p, buffer.Bytes()) + return +} + +func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + destination := M.SocksaddrFromNet(addr) + _buffer := buf.Make(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) + defer runtime.KeepAlive(_buffer) + buffer := buf.With(common.Dup(_buffer)) + common.Must1(buffer.ReadFullFrom(rand.Reader, c.keyLength)) + err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr)) + if err != nil { + return + } + _, err = buffer.Write(p) + if err != nil { + return + } + stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength)) + if err != nil { + return + } + stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength)) + _, err = c.Write(buffer.Bytes()) + if err != nil { + return + } + return len(p), nil +} + +func (c *clientPacketConn) Upstream() any { + return c.Conn +}