Init commit

This commit is contained in:
世界 2022-05-25 14:00:04 +08:00
commit 48809b0a99
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
20 changed files with 3828 additions and 0 deletions

40
.github/workflows/debug.yml vendored Normal file
View file

@ -0,0 +1,40 @@
name: Debug build
on:
push:
branches:
- main
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/debug.yml'
pull_request:
branches:
- main
jobs:
build:
name: Debug build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: ${{ steps.version.outputs.go_version }}
- name: Build and test
run: |
version=`git rev-parse HEAD`
mkdir build
pushd build
go mod init build
go get -v github.com/sagernet/sing-shadowsocks@$version
popd
go test -v ./...

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
/.idea/
/vendor/

14
LICENSE Normal file
View file

@ -0,0 +1,14 @@
Copyright (C) 2022 by nekohasekai <contact-sagernet@sekai.icu>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.

3
README.md Normal file
View file

@ -0,0 +1,3 @@
# sing-shadowsocks
Lightweight and efficient shadowsocks implementation with sing.

6
format.go Normal file
View file

@ -0,0 +1,6 @@
package shadowsocks
//go:generate go install -v mvdan.cc/gofumpt@latest
//go:generate go install -v github.com/daixiang0/gci@latest
//go:generate gofumpt -l -w .
//go:generate gci write .

16
go.mod Normal file
View file

@ -0,0 +1,16 @@
module github.com/sagernet/sing-shadowsocks
go 1.18
require (
github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d
github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34
golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898
golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d
lukechampine.com/blake3 v1.1.7
)
require (
github.com/klauspost/cpuid/v2 v2.0.12 // indirect
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
)

17
go.sum Normal file
View file

@ -0,0 +1,17 @@
github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d h1:CPqTNIigGweVPT4CYb+OO2E6XyRKFOmvTHwWRLgCAlE=
github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d/go.mod h1:QX5ZVULjAfZJux/W62Y91HvCh9hyW6enAwcrrv/sLj0=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE=
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/sagernet/sing v0.0.0-20220525062603-53c607b13ff2 h1:x7E53uloX7pU3rWOzb81IBCAmwMtE2u9x4ZJvJXaCnM=
github.com/sagernet/sing v0.0.0-20220525062603-53c607b13ff2/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg=
github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34 h1:/FfHfteLZo5mOtZbYOx/9ymDEYxlwBuM5iHO9reVe/E=
github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg=
golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898 h1:SLP7Q4Di66FONjDJbCYrCRrh97focO6sLogHO7/g8F0=
golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d h1:q4JksJ2n0fmbXC0Aj0eOs6E0AcPqnKglxWXWFqGD6x0=
golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d/go.mod h1:bVQfyl2sCM/QIIGHpWbFGfHPuDvqnCNkT6MQLTCjO/U=
lukechampine.com/blake3 v1.1.7 h1:GgRMhmdsuK8+ii6UZFDL8Nb+VyMwadAgcJyfYHxG6n0=
lukechampine.com/blake3 v1.1.7/go.mod h1:tkKEOtDkNtklkXtLNEOGNq5tcV90tJiA1vAA12R78LA=

241
none.go Normal file
View file

@ -0,0 +1,241 @@
package shadowsocks
import (
"context"
"io"
"net"
"net/netip"
"runtime"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
)
const MethodNone = "none"
type NoneMethod struct{}
func NewNone() Method {
return &NoneMethod{}
}
func (m *NoneMethod) Name() string {
return MethodNone
}
func (m *NoneMethod) KeyLength() int {
return 0
}
func (m *NoneMethod) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &noneConn{
Conn: conn,
handshake: true,
destination: destination,
}
return shadowsocksConn, shadowsocksConn.clientHandshake()
}
func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &noneConn{
Conn: conn,
destination: destination,
}
}
func (m *NoneMethod) DialPacketConn(conn net.Conn) N.NetPacketConn {
return &nonePacketConn{conn}
}
type noneConn struct {
net.Conn
access sync.Mutex
handshake bool
destination M.Socksaddr
}
func (c *noneConn) clientHandshake() error {
err := M.SocksaddrSerializer.WriteAddrPort(c.Conn, c.destination)
if err != nil {
return err
}
c.handshake = true
return nil
}
func (c *noneConn) Write(b []byte) (n int, err error) {
if c.handshake {
goto direct
}
c.access.Lock()
defer c.access.Unlock()
if c.handshake {
goto direct
}
{
if len(b) == 0 {
return 0, c.clientHandshake()
}
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
if err != nil {
return
}
bufN, _ := buffer.Write(b)
_, err = c.Conn.Write(buffer.Bytes())
runtime.KeepAlive(_buffer)
if err != nil {
return
}
if bufN < len(b) {
_, err = c.Conn.Write(b[bufN:])
if err != nil {
return
}
}
n = len(b)
}
direct:
return c.Conn.Write(b)
}
func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) {
if !c.handshake {
return rw.ReadFrom0(c, r)
}
return c.Conn.(io.ReaderFrom).ReadFrom(r)
}
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
return io.Copy(w, c.Conn)
}
func (c *noneConn) RemoteAddr() net.Addr {
return c.destination.TCPAddr()
}
type nonePacketConn struct {
net.Conn
}
func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
_, err := buffer.ReadFrom(c)
if err != nil {
return M.Socksaddr{}, err
}
return M.SocksaddrSerializer.ReadAddrPort(buffer)
}
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination)))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
return common.Error(buffer.WriteTo(c))
}
func (c *nonePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(p)
if err != nil {
return
}
buffer := buf.With(p[:n])
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return
}
addr = destination.UDPAddr()
n = copy(p, buffer.Bytes())
return
}
func (c *nonePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
_buffer := buf.Make(M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
defer runtime.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer))
err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
if err != nil {
return
}
_, err = buffer.Write(p)
if err != nil {
return
}
return len(p), nil
}
type NoneService struct {
handler Handler
udp *udpnat.Service[netip.AddrPort]
}
func NewNoneService(udpTimeout int64, handler Handler) Service {
s := &NoneService{
handler: handler,
}
s.udp = udpnat.New[netip.AddrPort](udpTimeout, s)
return s
}
func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(ctx, conn, metadata)
}
func (s *NoneService) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
s.udp.NewPacket(ctx, metadata.Source.AddrPort(), func() N.PacketWriter {
return &nonePacketWriter{conn, metadata.Source}
}, buffer, metadata)
return nil
}
type nonePacketWriter struct {
N.PacketConn
sourceAddr M.Socksaddr
}
func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination)))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
return s.PacketConn.WritePacket(buffer, s.sourceAddr)
}
func (s *NoneService) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(ctx, conn, metadata)
}
func (s *NoneService) HandleError(err error) {
s.handler.HandleError(err)
}

421
shadowaead/aead.go Normal file
View file

@ -0,0 +1,421 @@
package shadowaead
import (
"crypto/cipher"
"encoding/binary"
"io"
"github.com/sagernet/sing/common/buf"
)
// https://shadowsocks.org/en/wiki/AEAD-Ciphers.html
const (
MaxPacketSize = 16*1024 - 1
PacketLengthBufferSize = 2
)
const (
// NonceSize
// crypto/cipher.gcmStandardNonceSize
// golang.org/x/crypto/chacha20poly1305.NonceSize
NonceSize = 12
// Overhead
// crypto/cipher.gcmTagSize
// golang.org/x/crypto/chacha20poly1305.Overhead
Overhead = 16
)
type Reader struct {
upstream io.Reader
cipher cipher.AEAD
buffer []byte
nonce []byte
index int
cached int
}
func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reader {
return &Reader{
upstream: upstream,
cipher: cipher,
buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2),
nonce: make([]byte, NonceSize),
}
}
func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce []byte) *Reader {
return &Reader{
upstream: upstream,
cipher: cipher,
buffer: buffer,
nonce: nonce,
}
}
func (r *Reader) Upstream() any {
return r.upstream
}
func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
if r.cached > 0 {
writeN, writeErr := writer.Write(r.buffer[r.index : r.index+r.cached])
if writeErr != nil {
return int64(writeN), writeErr
}
n += int64(writeN)
}
for {
start := PacketLengthBufferSize + Overhead
_, err = io.ReadFull(r.upstream, r.buffer[:start])
if err != nil {
return
}
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
if err != nil {
return
}
increaseNonce(r.nonce)
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
end := length + Overhead
_, err = io.ReadFull(r.upstream, r.buffer[:end])
if err != nil {
return
}
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
if err != nil {
return
}
increaseNonce(r.nonce)
writeN, writeErr := writer.Write(r.buffer[:length])
if writeErr != nil {
return int64(writeN), writeErr
}
n += int64(writeN)
}
}
func (r *Reader) readInternal() (err error) {
start := PacketLengthBufferSize + Overhead
_, err = io.ReadFull(r.upstream, r.buffer[:start])
if err != nil {
return err
}
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
if err != nil {
return err
}
increaseNonce(r.nonce)
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
end := length + Overhead
_, err = io.ReadFull(r.upstream, r.buffer[:end])
if err != nil {
return err
}
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
if err != nil {
return err
}
increaseNonce(r.nonce)
r.cached = length
r.index = 0
return nil
}
func (r *Reader) ReadByte() (byte, error) {
if r.cached == 0 {
err := r.readInternal()
if err != nil {
return 0, err
}
}
index := r.index
r.index++
r.cached--
return r.buffer[index], nil
}
func (r *Reader) Read(b []byte) (n int, err error) {
if r.cached > 0 {
n = copy(b, r.buffer[r.index:r.index+r.cached])
r.cached -= n
r.index += n
return
}
start := PacketLengthBufferSize + Overhead
_, err = io.ReadFull(r.upstream, r.buffer[:start])
if err != nil {
return 0, err
}
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
if err != nil {
return 0, err
}
increaseNonce(r.nonce)
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
end := length + Overhead
if len(b) >= end {
data := b[:end]
_, err = io.ReadFull(r.upstream, data)
if err != nil {
return 0, err
}
_, err = r.cipher.Open(b[:0], r.nonce, data, nil)
if err != nil {
return 0, err
}
increaseNonce(r.nonce)
return length, nil
} else {
_, err = io.ReadFull(r.upstream, r.buffer[:end])
if err != nil {
return 0, err
}
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
if err != nil {
return 0, err
}
increaseNonce(r.nonce)
n = copy(b, r.buffer[:length])
r.cached = length - n
r.index = n
return
}
}
func (r *Reader) Discard(n int) error {
for {
if r.cached >= n {
r.cached -= n
r.index += n
return nil
} else if r.cached > 0 {
n -= r.cached
r.cached = 0
r.index = 0
}
err := r.readInternal()
if err != nil {
return err
}
}
}
func (r *Reader) Cached() int {
return r.cached
}
func (r *Reader) CachedSlice() []byte {
return r.buffer[r.index : r.index+r.cached]
}
func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error {
_, err := r.cipher.Open(r.buffer[:0], r.nonce, lengthChunk, nil)
if err != nil {
return err
}
increaseNonce(r.nonce)
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
end := length + Overhead
_, err = io.ReadFull(r.upstream, r.buffer[:end])
if err != nil {
return err
}
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
if err != nil {
return err
}
increaseNonce(r.nonce)
r.cached = length
r.index = 0
return nil
}
func (r *Reader) ReadWithLength(length uint16) error {
end := length + Overhead
_, err := io.ReadFull(r.upstream, r.buffer[:end])
if err != nil {
return err
}
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
if err != nil {
return err
}
increaseNonce(r.nonce)
r.cached = int(length)
r.index = 0
return nil
}
func (r *Reader) ReadChunk(chunk []byte) error {
bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil)
if err != nil {
return err
}
increaseNonce(r.nonce)
r.cached = len(bb)
r.index = 0
return nil
}
type Writer struct {
upstream io.Writer
cipher cipher.AEAD
maxPacketSize int
buffer []byte
nonce []byte
}
func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Writer {
return &Writer{
upstream: upstream,
cipher: cipher,
buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2),
nonce: make([]byte, cipher.NonceSize()),
maxPacketSize: maxPacketSize,
}
}
func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buffer []byte, nonce []byte) *Writer {
return &Writer{
upstream: upstream,
cipher: cipher,
maxPacketSize: maxPacketSize,
buffer: buffer,
nonce: nonce,
}
}
func (w *Writer) Upstream() any {
return w.upstream
}
func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
for {
offset := Overhead + PacketLengthBufferSize
readN, readErr := r.Read(w.buffer[offset : offset+w.maxPacketSize])
if readErr != nil {
return 0, readErr
}
binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(readN))
w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil)
increaseNonce(w.nonce)
packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, w.buffer[offset:offset+readN], nil)
increaseNonce(w.nonce)
_, err = w.upstream.Write(w.buffer[:offset+len(packet)])
if err != nil {
return
}
n += int64(readN)
}
}
func (w *Writer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return
}
for pLen := len(p); pLen > 0; {
var data []byte
if pLen > w.maxPacketSize {
data = p[:w.maxPacketSize]
p = p[w.maxPacketSize:]
} else {
data = p
p = nil
}
binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data)))
w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil)
increaseNonce(w.nonce)
offset := Overhead + PacketLengthBufferSize
packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil)
increaseNonce(w.nonce)
_, err = w.upstream.Write(w.buffer[:offset+len(packet)])
if err != nil {
return
}
n += len(data)
}
return
}
func (w *Writer) Buffer() *buf.Buffer {
return buf.With(w.buffer)
}
func (w *Writer) WriteChunk(buffer *buf.Buffer, chunk []byte) {
bb := w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, chunk, nil)
buffer.Extend(len(bb))
increaseNonce(w.nonce)
}
func (w *Writer) BufferedWriter(reversed int) *BufferedWriter {
return &BufferedWriter{
upstream: w,
reversed: reversed,
data: w.buffer[PacketLengthBufferSize+Overhead : len(w.buffer)-Overhead],
}
}
type BufferedWriter struct {
upstream *Writer
data []byte
reversed int
index int
}
func (w *BufferedWriter) UpstreamWriter() io.Writer {
return w.upstream
}
func (w *BufferedWriter) WriterReplaceable() bool {
return w.index == 0
}
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
var index int
for {
cachedN := copy(w.data[w.reversed+w.index:], p[index:])
if cachedN == len(p[index:]) {
w.index += cachedN
return cachedN, nil
}
err = w.Flush()
if err != nil {
return
}
index += cachedN
}
}
func (w *BufferedWriter) Flush() error {
if w.index == 0 {
if w.reversed > 0 {
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed])
w.reversed = 0
return err
}
return nil
}
buffer := w.upstream.buffer[w.reversed:]
binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index))
w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil)
increaseNonce(w.upstream.nonce)
offset := Overhead + PacketLengthBufferSize
packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil)
increaseNonce(w.upstream.nonce)
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)])
w.reversed = 0
return err
}
func increaseNonce(nonce []byte) {
for i := range nonce {
nonce[i]++
if nonce[i] != 0 {
return
}
}
}

361
shadowaead/protocol.go Normal file
View file

@ -0,0 +1,361 @@
package shadowaead
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"io"
"net"
"runtime"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
)
var List = []string{
"aes-128-gcm",
"aes-192-gcm",
"aes-256-gcm",
"chacha20-ietf-poly1305",
"xchacha20-ietf-poly1305",
}
func New(method string, key []byte, password string) (shadowsocks.Method, error) {
m := &Method{
name: method,
}
switch method {
case "aes-128-gcm":
m.keySaltLength = 16
m.constructor = newAESGCM
case "aes-192-gcm":
m.keySaltLength = 24
m.constructor = newAESGCM
case "aes-256-gcm":
m.keySaltLength = 32
m.constructor = newAESGCM
case "chacha20-ietf-poly1305":
m.keySaltLength = 32
m.constructor = func(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.New(key)
common.Must(err)
return cipher
}
case "xchacha20-ietf-poly1305":
m.keySaltLength = 32
m.constructor = func(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.NewX(key)
common.Must(err)
return cipher
}
}
if len(key) == m.keySaltLength {
m.key = key
} else if len(key) > 0 {
return nil, shadowsocks.ErrBadKey
} else if password == "" {
return nil, shadowsocks.ErrMissingPassword
} else {
m.key = shadowsocks.Key([]byte(password), m.keySaltLength)
}
return m, nil
}
func Kdf(key, iv []byte, keyLength int) []byte {
info := []byte("ss-subkey")
subKey := buf.Make(keyLength)
kdf := hkdf.New(sha1.New, key, iv, common.Dup(info))
runtime.KeepAlive(info)
common.Must1(io.ReadFull(kdf, common.Dup(subKey)))
return subKey
}
func newAESGCM(key []byte) cipher.AEAD {
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block)
common.Must(err)
return aead
}
type Method struct {
name string
keySaltLength int
constructor func(key []byte) cipher.AEAD
key []byte
}
func (m *Method) Name() string {
return m.name
}
func (m *Method) KeyLength() int {
return m.keySaltLength
}
func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
_salt := buf.Make(m.keySaltLength)
defer runtime.KeepAlive(_salt)
salt := common.Dup(_salt)
_, err := io.ReadFull(upstream, salt)
if err != nil {
return nil, E.Cause(err, "read salt")
}
key := Kdf(m.key, salt, m.keySaltLength)
defer runtime.KeepAlive(key)
return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
}
func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
_salt := buf.Make(m.keySaltLength)
defer runtime.KeepAlive(_salt)
salt := common.Dup(_salt)
common.Must1(io.ReadFull(rand.Reader, salt))
_, err := upstream.Write(salt)
if err != nil {
return nil, err
}
key := Kdf(m.key, salt, m.keySaltLength)
return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
}
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &clientConn{
Conn: conn,
method: m,
destination: destination,
}
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
}
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &clientConn{
Conn: conn,
method: m,
destination: destination,
}
}
func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
return &clientPacketConn{m, conn}
}
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key))
runtime.KeepAlive(key)
c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
buffer.Extend(Overhead)
return nil
}
func (m *Method) DecodePacket(buffer *buf.Buffer) error {
if buffer.Len() < m.keySaltLength {
return E.New("bad packet")
}
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key))
runtime.KeepAlive(key)
packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
if err != nil {
return err
}
buffer.Advance(m.keySaltLength)
buffer.Truncate(len(packet))
return nil
}
type clientConn struct {
net.Conn
method *Method
destination M.Socksaddr
reader *Reader
writer *Writer
}
func (c *clientConn) writeRequest(payload []byte) error {
_salt := make([]byte, c.method.keySaltLength)
salt := common.Dup(_salt)
common.Must1(io.ReadFull(rand.Reader, salt))
key := Kdf(c.method.key, salt, c.method.keySaltLength)
runtime.KeepAlive(_salt)
writer := NewWriter(
c.Conn,
c.method.constructor(common.Dup(key)),
MaxPacketSize,
)
runtime.KeepAlive(key)
header := writer.Buffer()
header.Write(salt)
bufferedWriter := writer.BufferedWriter(header.Len())
if len(payload) > 0 {
err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return err
}
_, err = bufferedWriter.Write(payload)
if err != nil {
return err
}
} else {
err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return err
}
}
err := bufferedWriter.Flush()
if err != nil {
return err
}
c.writer = writer
return nil
}
func (c *clientConn) readResponse() error {
if c.reader != nil {
return nil
}
_salt := buf.Make(c.method.keySaltLength)
defer runtime.KeepAlive(_salt)
salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
key := Kdf(c.method.key, salt, c.method.keySaltLength)
defer runtime.KeepAlive(key)
c.reader = NewReader(
c.Conn,
c.method.constructor(common.Dup(key)),
MaxPacketSize,
)
return nil
}
func (c *clientConn) Read(p []byte) (n int, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.Read(p)
}
func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.WriteTo(w)
}
func (c *clientConn) Write(p []byte) (n int, err error) {
if c.writer != nil {
return c.writer.Write(p)
}
err = c.writeRequest(p)
if err != nil {
return
}
return len(p), nil
}
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
return rw.ReadFrom0(c, r)
}
return c.writer.ReadFrom(r)
}
func (c *clientConn) Upstream() any {
return c.Conn
}
type clientPacketConn struct {
*Method
net.Conn
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(rand.Reader, header[:c.keySaltLength]))
err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
if err != nil {
return err
}
err = c.EncodePacket(buffer)
if err != nil {
return err
}
return common.Error(c.Write(buffer.Bytes()))
}
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return M.Socksaddr{}, err
}
buffer.Truncate(n)
err = c.DecodePacket(buffer)
if err != nil {
return M.Socksaddr{}, err
}
return M.SocksaddrSerializer.ReadAddrPort(buffer)
}
func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(p)
if err != nil {
return
}
b := buf.With(p[:n])
err = c.DecodePacket(b)
if err != nil {
return
}
destination, err := M.SocksaddrSerializer.ReadAddrPort(b)
if err != nil {
return
}
addr = destination.UDPAddr()
n = copy(p, b.Bytes())
return
}
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
}
_, err = buffer.Write(p)
if err != nil {
return
}
err = c.EncodePacket(buffer)
if err != nil {
return
}
_, err = c.Write(buffer.Bytes())
if err != nil {
return
}
return len(p), nil
}
func (c *clientPacketConn) Upstream() any {
return c.Conn
}

246
shadowaead/service.go Normal file
View file

@ -0,0 +1,246 @@
package shadowaead
import (
"context"
"crypto/cipher"
"crypto/rand"
"io"
"net"
"net/netip"
"runtime"
"sync"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"golang.org/x/crypto/chacha20poly1305"
)
var ErrBadHeader = E.New("bad header")
type Service struct {
name string
keySaltLength int
constructor func(key []byte) cipher.AEAD
key []byte
handler shadowsocks.Handler
udpNat *udpnat.Service[netip.AddrPort]
}
func NewService(method string, key []byte, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{
name: method,
handler: handler,
udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
}
switch method {
case "aes-128-gcm":
s.keySaltLength = 16
s.constructor = newAESGCM
case "aes-192-gcm":
s.keySaltLength = 24
s.constructor = newAESGCM
case "aes-256-gcm":
s.keySaltLength = 32
s.constructor = newAESGCM
case "chacha20-ietf-poly1305":
s.keySaltLength = 32
s.constructor = func(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.New(key)
common.Must(err)
return cipher
}
case "xchacha20-ietf-poly1305":
s.keySaltLength = 32
s.constructor = func(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.NewX(key)
common.Must(err)
return cipher
}
}
if len(key) == s.keySaltLength {
s.key = key
} else if len(key) > 0 {
return nil, shadowsocks.ErrBadKey
} else if password != "" {
s.key = shadowsocks.Key([]byte(password), s.keySaltLength)
} else {
return nil, shadowsocks.ErrMissingPassword
}
return s, nil
}
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
err := s.newConnection(ctx, conn, metadata)
if err != nil {
err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
}
return err
}
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
_header := buf.Make(s.keySaltLength + PacketLengthBufferSize + Overhead)
defer runtime.KeepAlive(_header)
header := common.Dup(_header)
n, err := conn.Read(header)
if err != nil {
return E.Cause(err, "read header")
} else if n < len(header) {
return ErrBadHeader
}
key := Kdf(s.key, header[:s.keySaltLength], s.keySaltLength)
reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize)
err = reader.ReadWithLengthChunk(header[s.keySaltLength:])
if err != nil {
return err
}
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(ctx, &serverConn{
Service: s,
Conn: conn,
reader: reader,
}, metadata)
}
type serverConn struct {
*Service
net.Conn
access sync.Mutex
reader *Reader
writer *Writer
}
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
_salt := buf.Make(c.keySaltLength)
salt := common.Dup(_salt)
common.Must1(io.ReadFull(rand.Reader, salt))
key := Kdf(c.key, salt, c.keySaltLength)
runtime.KeepAlive(_salt)
writer := NewWriter(
c.Conn,
c.constructor(common.Dup(key)),
MaxPacketSize,
)
runtime.KeepAlive(key)
header := writer.Buffer()
header.Write(salt)
bufferedWriter := writer.BufferedWriter(header.Len())
if len(payload) > 0 {
_, err = bufferedWriter.Write(payload)
if err != nil {
return
}
}
err = bufferedWriter.Flush()
if err != nil {
return
}
c.writer = writer
return
}
func (c *serverConn) Write(p []byte) (n int, err error) {
if c.writer != nil {
return c.writer.Write(p)
}
c.access.Lock()
if c.writer != nil {
c.access.Unlock()
return c.writer.Write(p)
}
defer c.access.Unlock()
return c.writeResponse(p)
}
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
return rw.ReadFrom0(c, r)
}
return c.writer.ReadFrom(r)
}
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}
func (c *serverConn) Upstream() any {
return c.Conn
}
func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
err := s.newPacket(ctx, conn, buffer, metadata)
if err != nil {
err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
}
return err
}
func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
if buffer.Len() < s.keySaltLength {
return E.New("bad packet")
}
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
c := s.constructor(common.Dup(key))
runtime.KeepAlive(key)
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err
}
buffer.Advance(s.keySaltLength)
buffer.Truncate(len(packet))
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
s.udpNat.NewPacket(ctx, metadata.Source.AddrPort(), func() N.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, buffer, metadata)
return nil
}
type serverPacketWriter struct {
*Service
N.PacketConn
source M.Socksaddr
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buffer.ExtendHeader(w.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(rand.Reader, header[:w.keySaltLength]))
err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
if err != nil {
return err
}
key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength)
c := w.constructor(common.Dup(key))
runtime.KeepAlive(key)
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
buffer.Extend(Overhead)
return w.PacketConn.WritePacket(buffer, w.source)
}

745
shadowaead_2022/protocol.go Normal file
View file

@ -0,0 +1,745 @@
package shadowaead_2022
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"io"
"math"
mRand "math/rand"
"net"
"os"
"runtime"
"strings"
"sync/atomic"
"time"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"golang.org/x/crypto/chacha20poly1305"
wgReplay "golang.zx2c4.com/wireguard/replay"
"lukechampine.com/blake3"
)
const (
HeaderTypeClient = 0
HeaderTypeServer = 1
MaxPaddingLength = 900
PacketNonceSize = 24
MaxPacketSize = 65535
RequestHeaderFixedChunkLength = 1 + 8 + 2
)
var (
ErrMissingPSK = E.New("missing psk")
ErrBadHeaderType = E.New("bad header type")
ErrBadTimestamp = E.New("bad timestamp")
ErrBadRequestSalt = E.New("bad request salt")
ErrBadClientSessionId = E.New("bad client session id")
ErrPacketIdNotUnique = E.New("packet id not unique")
ErrTooManyServerSessions = E.New("server session changed more than once during the last minute")
)
var List = []string{
"2022-blake3-aes-128-gcm",
"2022-blake3-aes-256-gcm",
"2022-blake3-chacha20-poly1305",
}
func NewWithPassword(method string, password string) (shadowsocks.Method, error) {
var pskList [][]byte
if password == "" {
return nil, ErrMissingPSK
}
keyStrList := strings.Split(password, ":")
pskList = make([][]byte, len(keyStrList))
for i, keyStr := range keyStrList {
kb, err := base64.StdEncoding.DecodeString(keyStr)
if err != nil {
return nil, E.Cause(err, "decode key")
}
pskList[i] = kb
}
return New(method, pskList)
}
func New(method string, pskList [][]byte) (shadowsocks.Method, error) {
m := &Method{
name: method,
replayFilter: replay.NewSimple(60 * time.Second),
}
switch method {
case "2022-blake3-aes-128-gcm":
m.keySaltLength = 16
m.constructor = newAESGCM
m.blockConstructor = newAES
case "2022-blake3-aes-256-gcm":
m.keySaltLength = 32
m.constructor = newAESGCM
m.blockConstructor = newAES
case "2022-blake3-chacha20-poly1305":
if len(pskList) > 1 {
return nil, os.ErrInvalid
}
m.keySaltLength = 32
m.constructor = newChacha20Poly1305
}
if len(pskList) == 0 {
return nil, ErrMissingPSK
}
for i, psk := range pskList {
if len(psk) < m.keySaltLength {
return nil, shadowsocks.ErrBadKey
} else if len(psk) > m.keySaltLength {
pskList[i] = Key(psk, m.keySaltLength)
}
}
if len(pskList) > 1 {
pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize)
for i, psk := range pskList {
if i == 0 {
continue
}
hash := blake3.Sum512(psk)
copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], hash[:aes.BlockSize])
}
m.pskHash = pskHash
}
switch method {
case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm":
m.udpBlockCipher = newAES(pskList[0])
case "2022-blake3-chacha20-poly1305":
m.udpCipher = newXChacha20Poly1305(pskList[0])
}
m.pskList = pskList
return m, nil
}
func Key(key []byte, keyLength int) []byte {
psk := sha256.Sum256(key)
return psk[:keyLength]
}
func SessionKey(psk []byte, salt []byte, keyLength int) []byte {
sessionKey := buf.Make(len(psk) + len(salt))
copy(sessionKey, psk)
copy(sessionKey[len(psk):], salt)
outKey := buf.Make(keyLength)
blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
return outKey
}
func newAES(key []byte) cipher.Block {
block, err := aes.NewCipher(key)
common.Must(err)
return block
}
func newAESGCM(key []byte) cipher.AEAD {
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block)
common.Must(err)
return aead
}
func newChacha20Poly1305(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.New(key)
common.Must(err)
return cipher
}
func newXChacha20Poly1305(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.NewX(key)
common.Must(err)
return cipher
}
type Method struct {
name string
keySaltLength int
constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
pskList [][]byte
pskHash []byte
replayFilter replay.Filter
}
func (m *Method) Name() string {
return m.name
}
func (m *Method) KeyLength() int {
return m.keySaltLength
}
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &clientConn{
Method: m,
Conn: conn,
destination: destination,
}
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
}
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &clientConn{
Method: m,
Conn: conn,
destination: destination,
}
}
func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
return &clientPacketConn{m, conn, m.newUDPSession()}
}
type clientConn struct {
*Method
net.Conn
destination M.Socksaddr
requestSalt []byte
reader *shadowaead.Reader
writer *shadowaead.Writer
}
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
pskLen := len(m.pskList)
if pskLen < 2 {
return
}
for i, psk := range m.pskList {
keyMaterial := buf.Make(m.keySaltLength * 2)
copy(keyMaterial, psk)
copy(keyMaterial[m.keySaltLength:], salt)
_identitySubkey := buf.Make(m.keySaltLength)
identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
header := request.Extend(16)
m.blockConstructor(identitySubkey).Encrypt(header, pskHash)
runtime.KeepAlive(_identitySubkey)
if i == pskLen-2 {
break
}
}
}
func (c *clientConn) writeRequest(payload []byte) error {
salt := buf.Make(c.keySaltLength)
common.Must1(io.ReadFull(rand.Reader, salt))
key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength)
writer := shadowaead.NewWriter(
c.Conn,
c.constructor(common.Dup(key)),
MaxPacketSize,
)
runtime.KeepAlive(key)
header := writer.Buffer()
header.Write(salt)
c.writeExtendedIdentityHeaders(header, salt)
var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte
fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:]))
common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix())))
var paddingLen int
if len(payload) == 0 {
paddingLen = mRand.Intn(MaxPaddingLength + 1)
}
variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen + len(payload)
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen)))
writer.WriteChunk(header, fixedLengthBuffer.Slice())
runtime.KeepAlive(_fixedLengthBuffer)
_variableLengthBuffer := buf.Make(variableLengthHeaderLen)
variableLengthBuffer := buf.With(common.Dup(_variableLengthBuffer))
common.Must(M.SocksaddrSerializer.WriteAddrPort(variableLengthBuffer, c.destination))
common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen)))
if paddingLen > 0 {
variableLengthBuffer.Extend(paddingLen)
} else {
common.Must1(variableLengthBuffer.Write(payload))
}
writer.WriteChunk(header, variableLengthBuffer.Slice())
runtime.KeepAlive(_variableLengthBuffer)
err := writer.BufferedWriter(header.Len()).Flush()
if err != nil {
return E.Cause(err, "client handshake")
}
c.requestSalt = salt
c.writer = writer
return nil
}
func (c *clientConn) readResponse() error {
if c.reader != nil {
return nil
}
_salt := buf.Make(c.keySaltLength)
salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
if !c.replayFilter.Check(salt) {
return ErrSaltNotUnique
}
key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength)
runtime.KeepAlive(_salt)
reader := shadowaead.NewReader(
c.Conn,
c.constructor(common.Dup(key)),
MaxPacketSize,
)
runtime.KeepAlive(key)
err = reader.ReadWithLength(uint16(1 + 8 + c.keySaltLength + 2))
if err != nil {
return E.Cause(err, "read response fixed length chunk")
}
headerType, err := rw.ReadByte(reader)
if err != nil {
return err
}
if headerType != HeaderTypeServer {
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
}
var epoch uint64
err = binary.Read(reader, binary.BigEndian, &epoch)
if err != nil {
return err
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
if diff > 30 {
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
}
_requestSalt := buf.Make(c.keySaltLength)
requestSalt := common.Dup(_requestSalt)
_, err = io.ReadFull(reader, requestSalt)
if err != nil {
return err
}
if bytes.Compare(requestSalt, c.requestSalt) > 0 {
return ErrBadRequestSalt
}
runtime.KeepAlive(_requestSalt)
var length uint16
err = binary.Read(reader, binary.BigEndian, &length)
if err != nil {
return err
}
err = reader.ReadWithLength(length)
if err != nil {
return err
}
c.requestSalt = nil
c.reader = reader
return nil
}
func (c *clientConn) Read(p []byte) (n int, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.Read(p)
}
func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.WriteTo(w)
}
func (c *clientConn) Write(p []byte) (n int, err error) {
if c.writer == nil {
err = c.writeRequest(p)
if err == nil {
n = len(p)
}
return
}
return c.writer.Write(p)
}
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
return rw.ReadFrom0(c, r)
}
return c.writer.ReadFrom(r)
}
func (c *clientConn) Upstream() any {
return c.Conn
}
type clientPacketConn struct {
*Method
net.Conn
session *udpSession
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
var hdrLen int
if c.udpCipher != nil {
hdrLen = PacketNonceSize
}
hdrLen += 16 // packet header
pskLen := len(c.pskList)
if c.udpCipher == nil && pskLen > 1 {
hdrLen += (pskLen - 1) * aes.BlockSize
}
hdrLen += 1 // header type
hdrLen += 8 // timestamp
hdrLen += 2 // padding length
hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
header := buf.With(buffer.ExtendHeader(hdrLen))
var dataIndex int
if c.udpCipher != nil {
common.Must1(header.ReadFullFrom(c.session.rng, PacketNonceSize))
if pskLen > 1 {
panic("unsupported chacha extended header")
}
dataIndex = PacketNonceSize
} else {
dataIndex = aes.BlockSize
}
common.Must(
binary.Write(header, binary.BigEndian, c.session.sessionId),
binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
)
if c.udpCipher == nil && pskLen > 1 {
for i, psk := range c.pskList {
dataIndex += aes.BlockSize
pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
identityHeader := header.Extend(aes.BlockSize)
for textI := 0; textI < aes.BlockSize; textI++ {
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
}
c.blockConstructor(psk).Encrypt(identityHeader, identityHeader)
if i == pskLen-2 {
break
}
}
}
common.Must(
header.WriteByte(HeaderTypeClient),
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
if c.udpCipher != nil {
c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
buffer.Extend(shadowaead.Overhead)
} else {
packetHeader := buffer.To(aes.BlockSize)
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
buffer.Extend(shadowaead.Overhead)
c.udpBlockCipher.Encrypt(packetHeader, packetHeader)
}
return common.Error(c.Write(buffer.Bytes()))
}
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return M.Socksaddr{}, err
}
buffer.Truncate(n)
var packetHeader []byte
if c.udpCipher != nil {
_, err = c.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
if err != nil {
return M.Socksaddr{}, E.Cause(err, "decrypt packet")
}
buffer.Advance(PacketNonceSize)
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
} else {
packetHeader = buffer.To(aes.BlockSize)
c.udpBlockCipher.Decrypt(packetHeader, packetHeader)
}
var sessionId, packetId uint64
err = binary.Read(buffer, binary.BigEndian, &sessionId)
if err != nil {
return M.Socksaddr{}, err
}
err = binary.Read(buffer, binary.BigEndian, &packetId)
if err != nil {
return M.Socksaddr{}, err
}
var remoteCipher cipher.AEAD
if packetHeader != nil {
if sessionId == c.session.remoteSessionId {
remoteCipher = c.session.remoteCipher
} else if sessionId == c.session.lastRemoteSessionId {
remoteCipher = c.session.lastRemoteCipher
} else {
key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength)
remoteCipher = c.constructor(common.Dup(key))
runtime.KeepAlive(key)
}
_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil {
return M.Socksaddr{}, E.Cause(err, "decrypt packet")
}
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
}
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {
return M.Socksaddr{}, err
}
if headerType != HeaderTypeServer {
return M.Socksaddr{}, E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
}
var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil {
return M.Socksaddr{}, err
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
if diff > 30 {
return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
}
if sessionId == c.session.remoteSessionId {
if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) {
return M.Socksaddr{}, ErrPacketIdNotUnique
}
} else if sessionId == c.session.lastRemoteSessionId {
if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) {
return M.Socksaddr{}, ErrPacketIdNotUnique
}
remoteCipher = c.session.lastRemoteCipher
c.session.lastRemoteSeen = time.Now().Unix()
} else {
if c.session.remoteSessionId != 0 {
if time.Now().Unix()-c.session.lastRemoteSeen < 60 {
return M.Socksaddr{}, ErrTooManyServerSessions
} else {
c.session.lastRemoteSessionId = c.session.remoteSessionId
c.session.lastFilter = c.session.filter
c.session.lastRemoteSeen = time.Now().Unix()
c.session.lastRemoteCipher = c.session.remoteCipher
c.session.filter = wgReplay.Filter{}
}
}
c.session.remoteSessionId = sessionId
c.session.remoteCipher = remoteCipher
c.session.filter.ValidateCounter(packetId, math.MaxUint64)
}
var clientSessionId uint64
err = binary.Read(buffer, binary.BigEndian, &clientSessionId)
if err != nil {
return M.Socksaddr{}, err
}
if clientSessionId != c.session.sessionId {
return M.Socksaddr{}, ErrBadClientSessionId
}
var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil {
return M.Socksaddr{}, E.Cause(err, "read padding length")
}
buffer.Advance(int(paddingLength))
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return M.Socksaddr{}, err
}
return destination, nil
}
func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
buffer := buf.With(p)
destination, err := c.ReadPacket(buffer)
if err != nil {
return
}
addr = destination.UDPAddr()
n = copy(p, buffer.Bytes())
return
}
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
var overHead int
if c.udpCipher != nil {
overHead = PacketNonceSize + shadowaead.Overhead
} else {
overHead = shadowaead.Overhead
}
overHead += 16 // packet header
pskLen := len(c.pskList)
if c.udpCipher == nil && pskLen > 1 {
overHead += (pskLen - 1) * aes.BlockSize
}
overHead += 1 // header type
overHead += 8 // timestamp
overHead += 2 // padding length
overHead += M.SocksaddrSerializer.AddrPortLen(destination)
_buffer := buf.Make(overHead + len(p))
defer runtime.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer))
var dataIndex int
if c.udpCipher != nil {
common.Must1(buffer.ReadFullFrom(c.session.rng, PacketNonceSize))
if pskLen > 1 {
panic("unsupported chacha extended header")
}
dataIndex = PacketNonceSize
} else {
dataIndex = aes.BlockSize
}
common.Must(
binary.Write(buffer, binary.BigEndian, c.session.sessionId),
binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()),
)
if c.udpCipher == nil && pskLen > 1 {
for i, psk := range c.pskList {
dataIndex += aes.BlockSize
pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
identityHeader := buffer.Extend(aes.BlockSize)
for textI := 0; textI < aes.BlockSize; textI++ {
identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI)
}
c.blockConstructor(psk).Encrypt(identityHeader, identityHeader)
if i == pskLen-2 {
break
}
}
}
common.Must(
buffer.WriteByte(HeaderTypeClient),
binary.Write(buffer, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(buffer, binary.BigEndian, uint16(0)), // padding length
)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
if err != nil {
return
}
if c.udpCipher != nil {
c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
buffer.Extend(shadowaead.Overhead)
} else {
packetHeader := buffer.To(aes.BlockSize)
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
buffer.Extend(shadowaead.Overhead)
c.udpBlockCipher.Encrypt(packetHeader, packetHeader)
}
err = common.Error(c.Write(buffer.Bytes()))
if err != nil {
return
}
return len(p), nil
}
type udpSession struct {
headerType byte
sessionId uint64
packetId uint64
remoteSessionId uint64
lastRemoteSessionId uint64
lastRemoteSeen int64
cipher cipher.AEAD
remoteCipher cipher.AEAD
lastRemoteCipher cipher.AEAD
filter wgReplay.Filter
lastFilter wgReplay.Filter
rng io.Reader
}
func (s *udpSession) nextPacketId() uint64 {
return atomic.AddUint64(&s.packetId, 1)
}
func (m *Method) newUDPSession() *udpSession {
session := &udpSession{}
if m.udpCipher != nil {
session.rng = Blake3KeyedHash(rand.Reader)
common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
} else {
common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
}
session.packetId--
if m.udpCipher == nil {
sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength)
session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key)
}
return session
}
func (c *clientPacketConn) Upstream() any {
return c.Conn
}
func Blake3KeyedHash(reader io.Reader) io.Reader {
key := make([]byte, 32)
common.Must1(io.ReadFull(reader, key))
h := blake3.New(1024, key)
return h.XOF()
}

233
shadowaead_2022/relay.go Normal file
View file

@ -0,0 +1,233 @@
package shadowaead_2022
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"io"
"net"
"os"
"runtime"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
"lukechampine.com/blake3"
)
type Relay[U comparable] struct {
name string
secureRNG io.Reader
keySaltLength int
handler shadowsocks.Handler
constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block
udpBlockCipher cipher.Block
iPSK []byte
uPSKHash map[U][aes.BlockSize]byte
uPSKHashR map[[aes.BlockSize]byte]U
uDestination map[U]M.Socksaddr
uCipher map[U]cipher.Block
udpNat *udpnat.Service[uint64]
udpSessions *cache.LruCache[uint64, *relayUDPSession]
}
func (s *Relay[U]) AddUser(user U, key []byte, destination M.Socksaddr) error {
if len(key) < s.keySaltLength {
return shadowsocks.ErrBadKey
} else if len(key) > s.keySaltLength {
key = Key(key, s.keySaltLength)
}
var uPSKHash [aes.BlockSize]byte
hash512 := blake3.Sum512(key)
copy(uPSKHash[:], hash512[:])
if oldHash, loaded := s.uPSKHash[user]; loaded {
delete(s.uPSKHashR, oldHash)
}
s.uPSKHash[user] = uPSKHash
s.uPSKHashR[uPSKHash] = user
s.uDestination[user] = destination
s.uCipher[user] = s.blockConstructor(key)
return nil
}
func (s *Relay[U]) RemoveUser(user U) {
if hash, loaded := s.uPSKHash[user]; loaded {
delete(s.uPSKHashR, hash)
}
delete(s.uPSKHash, user)
delete(s.uCipher, user)
}
func NewRelay[U comparable](method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (*Relay[U], error) {
s := &Relay[U]{
name: method,
secureRNG: secureRNG,
handler: handler,
uPSKHash: make(map[U][aes.BlockSize]byte),
uPSKHashR: make(map[[aes.BlockSize]byte]U),
uDestination: make(map[U]M.Socksaddr),
uCipher: make(map[U]cipher.Block),
udpNat: udpnat.New[uint64](udpTimeout, handler),
udpSessions: cache.New(
cache.WithAge[uint64, *relayUDPSession](udpTimeout),
cache.WithUpdateAgeOnGet[uint64, *relayUDPSession](),
),
}
switch method {
case "2022-blake3-aes-128-gcm":
s.keySaltLength = 16
s.constructor = newAESGCM
s.blockConstructor = newAES
case "2022-blake3-aes-256-gcm":
s.keySaltLength = 32
s.constructor = newAESGCM
s.blockConstructor = newAES
default:
return nil, os.ErrInvalid
}
if len(psk) != s.keySaltLength {
if len(psk) < s.keySaltLength {
return nil, shadowsocks.ErrBadKey
} else {
psk = Key(psk, s.keySaltLength)
}
}
s.udpBlockCipher = s.blockConstructor(psk)
return s, nil
}
func (s *Relay[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
err := s.newConnection(ctx, conn, metadata)
if err != nil {
err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
}
return err
}
func (s *Relay[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
_requestHeader := buf.StackNew()
defer runtime.KeepAlive(_requestHeader)
requestHeader := common.Dup(_requestHeader)
n, err := requestHeader.ReadFrom(conn)
if err != nil {
return err
} else if int(n) < s.keySaltLength+aes.BlockSize {
return shadowaead.ErrBadHeader
}
requestSalt := requestHeader.To(s.keySaltLength)
var _eiHeader [aes.BlockSize]byte
eiHeader := common.Dup(_eiHeader[:])
copy(eiHeader, requestHeader.Range(s.keySaltLength, s.keySaltLength+aes.BlockSize))
keyMaterial := buf.Make(s.keySaltLength * 2)
copy(keyMaterial, s.iPSK)
copy(keyMaterial[s.keySaltLength:], requestSalt)
_identitySubkey := buf.Make(s.keySaltLength)
identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader)
runtime.KeepAlive(_identitySubkey)
var user U
if u, loaded := s.uPSKHashR[_eiHeader]; loaded {
user = u
} else {
return E.New("invalid request")
}
runtime.KeepAlive(_eiHeader)
copy(requestHeader.Range(aes.BlockSize, aes.BlockSize+s.keySaltLength), requestHeader.To(s.keySaltLength))
requestHeader.Advance(aes.BlockSize)
ctx = shadowsocks.UserContext[U]{
ctx,
user,
}
metadata.Protocol = "shadowsocks-relay"
metadata.Destination = s.uDestination[user]
conn = &bufio.BufferedConn{
Conn: conn,
Buffer: requestHeader,
}
return s.handler.NewConnection(ctx, conn, metadata)
}
func (s *Relay[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
err := s.newPacket(ctx, conn, buffer, metadata)
if err != nil {
err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
}
return err
}
func (s *Relay[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
packetHeader := buffer.To(aes.BlockSize)
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
sessionId := binary.BigEndian.Uint64(packetHeader)
var _eiHeader [aes.BlockSize]byte
eiHeader := common.Dup(_eiHeader[:])
s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize))
for i := range eiHeader {
eiHeader[i] = eiHeader[i] ^ packetHeader[i]
}
var user U
if u, loaded := s.uPSKHashR[_eiHeader]; loaded {
user = u
} else {
return E.New("invalid request")
}
session, _ := s.udpSessions.LoadOrStore(sessionId, func() *relayUDPSession {
return new(relayUDPSession)
})
session.sourceAddr = metadata.Source
s.uCipher[user].Encrypt(packetHeader, packetHeader)
copy(buffer.Range(aes.BlockSize, 2*aes.BlockSize), packetHeader)
buffer.Advance(aes.BlockSize)
metadata.Protocol = "shadowsocks-relay"
metadata.Destination = s.uDestination[user]
s.udpNat.NewContextPacket(ctx, sessionId, func() (context.Context, N.PacketWriter) {
return &shadowsocks.UserContext[U]{
ctx,
user,
}, &relayPacketWriter[U]{conn, session}
}, buffer, metadata)
return nil
}
type relayUDPSession struct {
sourceAddr M.Socksaddr
}
type relayPacketWriter[U comparable] struct {
N.PacketConn
session *relayUDPSession
}
func (w *relayPacketWriter[U]) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
return w.PacketConn.WritePacket(buffer, w.session.sourceAddr)
}

489
shadowaead_2022/service.go Normal file
View file

@ -0,0 +1,489 @@
package shadowaead_2022
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"io"
"math"
"net"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
wgReplay "golang.zx2c4.com/wireguard/replay"
)
var (
ErrSaltNotUnique = E.New("bad request: salt not unique")
ErrNoPadding = E.New("bad request: missing payload or padding")
ErrBadPadding = E.New("bad request: damaged padding")
)
type Service struct {
name string
keySaltLength int
handler shadowsocks.Handler
constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
psk []byte
replayFilter replay.Filter
udpNat *udpnat.Service[uint64]
udpSessions *cache.LruCache[uint64, *serverUDPSession]
}
func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
if password == "" {
return nil, ErrMissingPSK
}
psk, err := base64.StdEncoding.DecodeString(password)
if err != nil {
return nil, E.Cause(err, "decode psk")
}
return NewService(method, psk, udpTimeout, handler)
}
func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{
name: method,
handler: handler,
replayFilter: replay.NewSimple(60 * time.Second),
udpNat: udpnat.New[uint64](udpTimeout, handler),
udpSessions: cache.New[uint64, *serverUDPSession](
cache.WithAge[uint64, *serverUDPSession](udpTimeout),
cache.WithUpdateAgeOnGet[uint64, *serverUDPSession](),
),
}
switch method {
case "2022-blake3-aes-128-gcm":
s.keySaltLength = 16
s.constructor = newAESGCM
s.blockConstructor = newAES
case "2022-blake3-aes-256-gcm":
s.keySaltLength = 32
s.constructor = newAESGCM
s.blockConstructor = newAES
case "2022-blake3-chacha20-poly1305":
s.keySaltLength = 32
s.constructor = newChacha20Poly1305
default:
return nil, os.ErrInvalid
}
if len(psk) != s.keySaltLength {
if len(psk) < s.keySaltLength {
return nil, shadowsocks.ErrBadKey
} else if len(psk) > s.keySaltLength {
psk = Key(psk, s.keySaltLength)
} else {
return nil, ErrMissingPSK
}
}
switch method {
case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm":
s.udpBlockCipher = newAES(psk)
case "2022-blake3-chacha20-poly1305":
s.udpCipher = newXChacha20Poly1305(psk)
}
s.psk = psk
return s, nil
}
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
err := s.newConnection(ctx, conn, metadata)
if err != nil {
err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
}
return err
}
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
header := buf.Make(s.keySaltLength + shadowaead.Overhead + RequestHeaderFixedChunkLength)
n, err := conn.Read(header)
if err != nil {
return E.Cause(err, "read header")
} else if n < len(header) {
return shadowaead.ErrBadHeader
}
requestSalt := header[:s.keySaltLength]
if !s.replayFilter.Check(requestSalt) {
return ErrSaltNotUnique
}
requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength)
reader := shadowaead.NewReader(
conn,
s.constructor(common.Dup(requestKey)),
MaxPacketSize,
)
runtime.KeepAlive(requestKey)
err = reader.ReadChunk(header[s.keySaltLength:])
if err != nil {
return err
}
headerType, err := reader.ReadByte()
if err != nil {
return E.Cause(err, "read header")
}
if headerType != HeaderTypeClient {
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
}
var epoch uint64
err = binary.Read(reader, binary.BigEndian, &epoch)
if err != nil {
return err
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
if diff > 30 {
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
}
var length uint16
err = binary.Read(reader, binary.BigEndian, &length)
if err != nil {
return err
}
err = reader.ReadWithLength(length)
if err != nil {
return err
}
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return err
}
var paddingLen uint16
err = binary.Read(reader, binary.BigEndian, &paddingLen)
if err != nil {
return err
}
if uint16(reader.Cached()) < paddingLen {
return ErrNoPadding
}
if paddingLen > 0 {
err = reader.Discard(int(paddingLen))
if err != nil {
return E.Cause(err, "discard padding")
}
} else if reader.Cached() == 0 {
return ErrNoPadding
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(ctx, &serverConn{
Service: s,
Conn: conn,
uPSK: s.psk,
reader: reader,
requestSalt: requestSalt,
}, metadata)
}
type serverConn struct {
*Service
net.Conn
uPSK []byte
access sync.Mutex
reader *shadowaead.Reader
writer *shadowaead.Writer
requestSalt []byte
}
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
_salt := buf.Make(c.keySaltLength)
salt := common.Dup(_salt[:])
common.Must1(io.ReadFull(rand.Reader, salt))
key := SessionKey(c.uPSK, salt, c.keySaltLength)
runtime.KeepAlive(_salt)
writer := shadowaead.NewWriter(
c.Conn,
c.constructor(common.Dup(key)),
MaxPacketSize,
)
runtime.KeepAlive(key)
header := writer.Buffer()
header.Write(salt)
_headerFixedChunk := buf.Make(1 + 8 + c.keySaltLength + 2)
headerFixedChunk := buf.With(common.Dup(_headerFixedChunk))
common.Must(headerFixedChunk.WriteByte(HeaderTypeServer))
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix())))
common.Must1(headerFixedChunk.Write(c.requestSalt))
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(len(payload))))
writer.WriteChunk(header, headerFixedChunk.Slice())
runtime.KeepAlive(_headerFixedChunk)
c.requestSalt = nil
if len(payload) > 0 {
writer.WriteChunk(header, payload)
}
err = writer.BufferedWriter(header.Len()).Flush()
if err != nil {
return
}
c.writer = writer
n = len(payload)
return
}
func (c *serverConn) Write(p []byte) (n int, err error) {
if c.writer != nil {
return c.writer.Write(p)
}
c.access.Lock()
if c.writer != nil {
c.access.Unlock()
return c.writer.Write(p)
}
defer c.access.Unlock()
return c.writeResponse(p)
}
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
return rw.ReadFrom0(c, r)
}
return c.writer.ReadFrom(r)
}
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}
func (c *serverConn) Upstream() any {
return c.Conn
}
func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
err := s.newPacket(ctx, conn, buffer, metadata)
if err != nil {
err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
}
return err
}
func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
var packetHeader []byte
if s.udpCipher != nil {
_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
if err != nil {
return E.Cause(err, "decrypt packet header")
}
buffer.Advance(PacketNonceSize)
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
} else {
packetHeader = buffer.To(aes.BlockSize)
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
}
var sessionId, packetId uint64
err := binary.Read(buffer, binary.BigEndian, &sessionId)
if err != nil {
return err
}
err = binary.Read(buffer, binary.BigEndian, &packetId)
if err != nil {
return err
}
session, loaded := s.udpSessions.LoadOrStore(sessionId, s.newUDPSession)
if !loaded {
session.remoteSessionId = sessionId
if packetHeader != nil {
key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength)
session.remoteCipher = s.constructor(common.Dup(key))
runtime.KeepAlive(key)
}
}
goto process
returnErr:
if !loaded {
s.udpSessions.Delete(sessionId)
}
return err
process:
if !session.filter.ValidateCounter(packetId, math.MaxUint64) {
err = ErrPacketIdNotUnique
goto returnErr
}
if packetHeader != nil {
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil {
err = E.Cause(err, "decrypt packet")
goto returnErr
}
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
}
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {
err = E.Cause(err, "decrypt packet")
goto returnErr
}
if headerType != HeaderTypeClient {
err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
goto returnErr
}
var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil {
goto returnErr
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
if diff > 30 {
err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
goto returnErr
}
var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil {
err = E.Cause(err, "read padding length")
goto returnErr
}
buffer.Advance(int(paddingLength))
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
goto returnErr
}
metadata.Destination = destination
session.remoteAddr = metadata.Source
s.udpNat.NewPacket(ctx, sessionId, func() N.PacketWriter {
return &serverPacketWriter{s, conn, session}
}, buffer, metadata)
return nil
}
type serverPacketWriter struct {
*Service
N.PacketConn
session *serverUDPSession
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
var hdrLen int
if w.udpCipher != nil {
hdrLen = PacketNonceSize
}
hdrLen += 16 // packet header
hdrLen += 1 // header type
hdrLen += 8 // timestamp
hdrLen += 8 // remote session id
hdrLen += 2 // padding length
hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
header := buf.With(buffer.ExtendHeader(hdrLen))
var dataIndex int
if w.udpCipher != nil {
common.Must1(header.ReadFullFrom(w.session.rng, PacketNonceSize))
dataIndex = PacketNonceSize
} else {
dataIndex = aes.BlockSize
}
common.Must(
binary.Write(header, binary.BigEndian, w.session.sessionId),
binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
header.WriteByte(HeaderTypeServer),
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
if w.udpCipher != nil {
w.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
buffer.Extend(shadowaead.Overhead)
} else {
packetHeader := buffer.To(aes.BlockSize)
w.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
buffer.Extend(shadowaead.Overhead)
w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
}
return w.PacketConn.WritePacket(buffer, w.session.remoteAddr)
}
type serverUDPSession struct {
sessionId uint64
remoteSessionId uint64
remoteAddr M.Socksaddr
packetId uint64
cipher cipher.AEAD
remoteCipher cipher.AEAD
filter wgReplay.Filter
rng io.Reader
}
func (s *serverUDPSession) nextPacketId() uint64 {
return atomic.AddUint64(&s.packetId, 1)
}
func (m *Service) newUDPSession() *serverUDPSession {
session := &serverUDPSession{}
if m.udpCipher != nil {
session.rng = Blake3KeyedHash(rand.Reader)
common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
} else {
common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
}
session.packetId--
if m.udpCipher == nil {
sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := SessionKey(m.psk, sessionId, m.keySaltLength)
session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key)
}
return session
}

View file

@ -0,0 +1,365 @@
package shadowaead_2022
import (
"context"
"crypto/aes"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"math"
"net"
"os"
"runtime"
"time"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"lukechampine.com/blake3"
)
type MultiService[U comparable] struct {
*Service
uPSK map[U][]byte
uPSKHash map[U][aes.BlockSize]byte
uPSKHashR map[[aes.BlockSize]byte]U
}
func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) {
if password == "" {
return nil, ErrMissingPSK
}
iPSK, err := base64.StdEncoding.DecodeString(password)
if err != nil {
return nil, E.Cause(err, "decode psk")
}
return NewMultiService[U](method, iPSK, udpTimeout, handler)
}
func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) {
switch method {
case "2022-blake3-aes-128-gcm":
case "2022-blake3-aes-256-gcm":
default:
return nil, os.ErrInvalid
}
ss, err := NewService(method, iPSK, udpTimeout, handler)
if err != nil {
return nil, err
}
s := &MultiService[U]{
Service: ss.(*Service),
uPSK: make(map[U][]byte),
uPSKHash: make(map[U][aes.BlockSize]byte),
uPSKHashR: make(map[[aes.BlockSize]byte]U),
}
return s, nil
}
func (s *MultiService[U]) AddUser(user U, key []byte) error {
if len(key) < s.keySaltLength {
return shadowsocks.ErrBadKey
} else if len(key) > s.keySaltLength {
key = Key(key, s.keySaltLength)
}
var uPSKHash [aes.BlockSize]byte
hash512 := blake3.Sum512(key)
copy(uPSKHash[:], hash512[:])
if oldHash, loaded := s.uPSKHash[user]; loaded {
delete(s.uPSKHashR, oldHash)
}
s.uPSKHash[user] = uPSKHash
s.uPSKHashR[uPSKHash] = user
s.uPSK[user] = key
return nil
}
func (s *MultiService[U]) AddUserWithPassword(user U, password string) error {
if password == "" {
return shadowsocks.ErrMissingPassword
}
psk, err := base64.StdEncoding.DecodeString(password)
if err != nil {
return E.Cause(err, "decode psk")
}
return s.AddUser(user, psk)
}
func (s *MultiService[U]) RemoveUser(user U) {
if hash, loaded := s.uPSKHash[user]; loaded {
delete(s.uPSKHashR, hash)
}
delete(s.uPSK, user)
delete(s.uPSKHash, user)
}
func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
err := s.newConnection(ctx, conn, metadata)
if err != nil {
err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
}
return err
}
func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength)
n, err := conn.Read(requestHeader)
if err != nil {
return err
} else if n < len(requestHeader) {
return shadowaead.ErrBadHeader
}
requestSalt := requestHeader[:s.keySaltLength]
if !s.replayFilter.Check(requestSalt) {
return ErrSaltNotUnique
}
var _eiHeader [aes.BlockSize]byte
eiHeader := common.Dup(_eiHeader[:])
copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize])
keyMaterial := buf.Make(s.keySaltLength * 2)
copy(keyMaterial, s.psk)
copy(keyMaterial[s.keySaltLength:], requestSalt)
_identitySubkey := buf.Make(s.keySaltLength)
identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader)
runtime.KeepAlive(_identitySubkey)
var user U
var uPSK []byte
if u, loaded := s.uPSKHashR[_eiHeader]; loaded {
user = u
uPSK = s.uPSK[u]
} else {
return E.New("invalid request")
}
runtime.KeepAlive(_eiHeader)
requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength)
reader := shadowaead.NewReader(
conn,
s.constructor(common.Dup(requestKey)),
MaxPacketSize,
)
err = reader.ReadChunk(requestHeader[s.keySaltLength+aes.BlockSize:])
if err != nil {
return err
}
headerType, err := rw.ReadByte(reader)
if err != nil {
return E.Cause(err, "read header")
}
if headerType != HeaderTypeClient {
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
}
var epoch uint64
err = binary.Read(reader, binary.BigEndian, &epoch)
if err != nil {
return E.Cause(err, "read timestamp")
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
if diff > 30 {
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
}
var length uint16
err = binary.Read(reader, binary.BigEndian, &length)
if err != nil {
return E.Cause(err, "read length")
}
err = reader.ReadWithLength(length)
if err != nil {
return err
}
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return E.Cause(err, "read destination")
}
var paddingLen uint16
err = binary.Read(reader, binary.BigEndian, &paddingLen)
if err != nil {
return E.Cause(err, "read padding length")
}
if reader.Cached() < int(paddingLen) {
return ErrBadPadding
} else if paddingLen > 0 {
err = reader.Discard(int(paddingLen))
if err != nil {
return E.Cause(err, "discard padding")
}
} else if reader.Cached() == 0 {
return ErrNoPadding
}
var userCtx shadowsocks.UserContext[U]
userCtx.Context = ctx
userCtx.User = user
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(&userCtx, &serverConn{
Service: s.Service,
Conn: conn,
uPSK: uPSK,
reader: reader,
requestSalt: requestSalt,
}, metadata)
}
func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
err := s.newPacket(ctx, conn, buffer, metadata)
if err != nil {
err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
}
return err
}
func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
packetHeader := buffer.To(aes.BlockSize)
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
var _eiHeader [aes.BlockSize]byte
eiHeader := common.Dup(_eiHeader[:])
s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize))
for i := range eiHeader {
eiHeader[i] = eiHeader[i] ^ packetHeader[i]
}
var user U
var uPSK []byte
if u, loaded := s.uPSKHashR[_eiHeader]; loaded {
user = u
uPSK = s.uPSK[u]
} else {
return E.New("invalid request")
}
var sessionId, packetId uint64
err := binary.Read(buffer, binary.BigEndian, &sessionId)
if err != nil {
return err
}
err = binary.Read(buffer, binary.BigEndian, &packetId)
if err != nil {
return err
}
session, loaded := s.udpSessions.LoadOrStore(sessionId, func() *serverUDPSession {
return s.newUDPSession(uPSK)
})
if !loaded {
session.remoteSessionId = sessionId
key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength)
session.remoteCipher = s.constructor(common.Dup(key))
runtime.KeepAlive(key)
}
goto process
returnErr:
if !loaded {
s.udpSessions.Delete(sessionId)
}
return err
process:
if !session.filter.ValidateCounter(packetId, math.MaxUint64) {
err = ErrPacketIdNotUnique
goto returnErr
}
if packetHeader != nil {
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil {
err = E.Cause(err, "decrypt packet")
goto returnErr
}
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
}
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {
err = E.Cause(err, "decrypt packet")
goto returnErr
}
if headerType != HeaderTypeClient {
err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
goto returnErr
}
var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil {
goto returnErr
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
if diff > 30 {
err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
goto returnErr
}
var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil {
err = E.Cause(err, "read padding length")
goto returnErr
}
buffer.Advance(int(paddingLength))
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
goto returnErr
}
metadata.Destination = destination
session.remoteAddr = metadata.Source
s.udpNat.NewContextPacket(ctx, sessionId, func() (context.Context, N.PacketWriter) {
return &shadowsocks.UserContext[U]{
ctx,
user,
}, &serverPacketWriter{s.Service, conn, session}
}, buffer, metadata)
return nil
}
func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession {
session := &serverUDPSession{}
if s.udpCipher != nil {
session.rng = Blake3KeyedHash(rand.Reader)
common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
} else {
common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
}
session.packetId--
sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := SessionKey(uPSK, sessionId, s.keySaltLength)
session.cipher = s.constructor(common.Dup(key))
runtime.KeepAlive(key)
return session
}

View file

@ -0,0 +1,75 @@
package shadowaead_2022_test
import (
"context"
"crypto/rand"
"net"
"sync"
"testing"
"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func TestMultiService(t *testing.T) {
method := "2022-blake3-aes-128-gcm"
var iPSK [16]byte
rand.Reader.Read(iPSK[:])
var wg sync.WaitGroup
multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg})
if err != nil {
t.Fatal(err)
}
var uPSK [16]byte
rand.Reader.Read(uPSK[:])
multiService.AddUser("my user", uPSK[:])
client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]})
if err != nil {
t.Fatal(err)
}
wg.Add(1)
serverConn, clientConn := net.Pipe()
defer common.Close(serverConn, clientConn)
go func() {
err := multiService.NewConnection(context.Background(), serverConn, M.Metadata{})
if err != nil {
serverConn.Close()
t.Error(E.Cause(err, "server"))
return
}
}()
_, err = client.DialConn(clientConn, M.ParseSocksaddr("test.com:443"))
if err != nil {
t.Fatal(err)
}
wg.Wait()
}
type multiHandler struct {
t *testing.T
wg *sync.WaitGroup
}
func (h *multiHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
if metadata.Destination.String() != "test.com:443" {
h.t.Error("bad destination")
}
h.wg.Done()
return nil
}
func (h *multiHandler) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
return nil
}
func (h *multiHandler) HandleError(err error) {
h.t.Error(err)
}

View file

@ -0,0 +1,49 @@
package shadowaead_2022_test
import (
"context"
"crypto/rand"
"net"
"sync"
"testing"
"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
func TestService(t *testing.T) {
method := "2022-blake3-aes-128-gcm"
var psk [16]byte
rand.Reader.Read(psk[:])
var wg sync.WaitGroup
service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg})
if err != nil {
t.Fatal(err)
}
client, err := shadowaead_2022.New(method, [][]byte{psk[:]})
if err != nil {
t.Fatal(err)
}
wg.Add(1)
serverConn, clientConn := net.Pipe()
defer common.Close(serverConn, clientConn)
go func() {
err := service.NewConnection(context.Background(), serverConn, M.Metadata{})
if err != nil {
serverConn.Close()
t.Error(E.Cause(err, "server"))
return
}
}()
_, err = client.DialConn(clientConn, M.ParseSocksaddr("test.com:443"))
if err != nil {
t.Fatal(err)
}
wg.Wait()
}

24
shadowimpl/fetcher.go Normal file
View file

@ -0,0 +1,24 @@
package shadowimpl
import (
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
"github.com/sagernet/sing-shadowsocks/shadowstream"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
func FetchMethod(method string, password string) (shadowsocks.Method, error) {
if method == "none" {
return shadowsocks.NewNone(), nil
} else if common.Contains(shadowstream.List, method) {
return shadowstream.New(method, nil, password)
} else if common.Contains(shadowaead.List, method) {
return shadowaead.New(method, nil, password)
} else if common.Contains(shadowaead_2022.List, method) {
return shadowaead_2022.NewWithPassword(method, password)
} else {
return nil, E.New("shadowsocks: unsupported method ", method)
}
}

89
shadowsocks.go Normal file
View file

@ -0,0 +1,89 @@
package shadowsocks
import (
"context"
"crypto/md5"
"fmt"
"net"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var (
ErrBadKey = E.New("bad key")
ErrMissingPassword = E.New("missing password")
)
type Method interface {
Name() string
KeyLength() int
DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error)
DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn
DialPacketConn(conn net.Conn) N.NetPacketConn
}
type Service interface {
N.TCPConnectionHandler
N.UDPHandler
}
type Handler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
E.Handler
}
type UserContext[U comparable] struct {
context.Context
User U
}
type ServerConnError struct {
net.Conn
Source M.Socksaddr
Cause error
}
func (e *ServerConnError) Close() error {
if conn, ok := common.Cast[*net.TCPConn](e.Conn); ok {
conn.SetLinger(0)
}
return e.Conn.Close()
}
func (e *ServerConnError) Unwrap() error {
return e.Cause
}
func (e *ServerConnError) Error() string {
return fmt.Sprint("shadowsocks: serve TCP from ", e.Source, ": ", e.Cause)
}
type ServerPacketError struct {
Source M.Socksaddr
Cause error
}
func (e *ServerPacketError) Unwrap() error {
return e.Cause
}
func (e *ServerPacketError) Error() string {
return fmt.Sprint("shadowsocks: serve UDP from ", e.Source, ": ", e.Cause)
}
func Key(password []byte, keySize int) []byte {
var b, prev []byte
h := md5.New()
for len(b) < keySize {
h.Write(prev)
h.Write([]byte(password))
b = h.Sum(b)
prev = b[len(b)-h.Size():]
h.Reset()
}
return b[:keySize]
}

392
shadowstream/protocol.go Normal file
View file

@ -0,0 +1,392 @@
package shadowstream
import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/md5"
"crypto/rand"
"crypto/rc4"
"io"
"net"
"os"
"runtime"
"github.com/dgryski/go-camellia"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.org/x/crypto/blowfish"
"golang.org/x/crypto/cast5"
"golang.org/x/crypto/chacha20"
)
var List = []string{
"aes-128-ctr",
"aes-192-ctr",
"aes-256-ctr",
"aes-128-cfb",
"aes-192-cfb",
"aes-256-cfb",
"camellia-128-cfb",
"camellia-192-cfb",
"camellia-256-cfb",
"bf-cfb",
"cast5-cfb",
"des-cfb",
"rc4",
"rc4-md5",
"chacha20",
"chacha20-ietf",
"xchacha20",
}
type Method struct {
name string
keyLength int
saltLength int
encryptConstructor func(key []byte, salt []byte) (cipher.Stream, error)
decryptConstructor func(key []byte, salt []byte) (cipher.Stream, error)
key []byte
}
func New(method string, key []byte, password string) (shadowsocks.Method, error) {
m := &Method{
name: method,
}
switch method {
case "aes-128-ctr":
m.keyLength = 16
m.saltLength = aes.BlockSize
m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
case "aes-192-ctr":
m.keyLength = 24
m.saltLength = aes.BlockSize
m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
case "aes-256-ctr":
m.keyLength = 32
m.saltLength = aes.BlockSize
m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
case "aes-128-cfb":
m.keyLength = 16
m.saltLength = aes.BlockSize
m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
case "aes-192-cfb":
m.keyLength = 24
m.saltLength = aes.BlockSize
m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
case "aes-256-cfb":
m.keyLength = 32
m.saltLength = aes.BlockSize
m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
case "camellia-128-cfb":
m.keyLength = 16
m.saltLength = camellia.BlockSize
m.encryptConstructor = blockStream(camellia.New, cipher.NewCFBEncrypter)
m.decryptConstructor = blockStream(camellia.New, cipher.NewCFBDecrypter)
case "camellia-192-cfb":
m.keyLength = 24
m.saltLength = camellia.BlockSize
m.encryptConstructor = blockStream(camellia.New, cipher.NewCFBEncrypter)
m.decryptConstructor = blockStream(camellia.New, cipher.NewCFBDecrypter)
case "camellia-256-cfb":
m.keyLength = 32
m.saltLength = camellia.BlockSize
m.encryptConstructor = blockStream(camellia.New, cipher.NewCFBEncrypter)
m.decryptConstructor = blockStream(camellia.New, cipher.NewCFBDecrypter)
case "bf-cfb":
m.keyLength = 16
m.saltLength = blowfish.BlockSize
m.encryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return blowfish.NewCipher(key) }, cipher.NewCFBEncrypter)
m.decryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return blowfish.NewCipher(key) }, cipher.NewCFBDecrypter)
case "cast5-cfb":
m.keyLength = 16
m.saltLength = cast5.BlockSize
m.encryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return cast5.NewCipher(key) }, cipher.NewCFBEncrypter)
m.decryptConstructor = blockStream(func(key []byte) (cipher.Block, error) { return cast5.NewCipher(key) }, cipher.NewCFBDecrypter)
case "des-cfb":
m.keyLength = 8
m.saltLength = des.BlockSize
m.encryptConstructor = blockStream(des.NewCipher, cipher.NewCFBEncrypter)
m.decryptConstructor = blockStream(des.NewCipher, cipher.NewCFBDecrypter)
case "rc4":
m.keyLength = 16
m.saltLength = 0
m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
return rc4.NewCipher(key)
}
m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
return rc4.NewCipher(key)
}
case "rc4-md5":
m.keyLength = 16
m.saltLength = 0
m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
h := md5.New()
h.Write(key)
h.Write(salt)
return rc4.NewCipher(h.Sum(nil))
}
m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
h := md5.New()
h.Write(key)
h.Write(salt)
return rc4.NewCipher(h.Sum(nil))
}
case "chacha20", "chacha20-ietf":
m.keyLength = chacha20.KeySize
m.saltLength = chacha20.NonceSize
m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
return chacha20.NewUnauthenticatedCipher(key, salt)
}
m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
return chacha20.NewUnauthenticatedCipher(key, salt)
}
case "xchacha20":
m.keyLength = chacha20.KeySize
m.saltLength = chacha20.NonceSizeX
m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
return chacha20.NewUnauthenticatedCipher(key, salt)
}
m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
return chacha20.NewUnauthenticatedCipher(key, salt)
}
default:
return nil, os.ErrInvalid
}
if len(key) == m.keyLength {
m.key = key
} else if len(key) > 0 {
return nil, shadowsocks.ErrBadKey
} else if password != "" {
m.key = shadowsocks.Key([]byte(password), m.keyLength)
} else {
return nil, shadowsocks.ErrMissingPassword
}
return m, nil
}
func blockStream(blockCreator func(key []byte) (cipher.Block, error), streamCreator func(block cipher.Block, iv []byte) cipher.Stream) func([]byte, []byte) (cipher.Stream, error) {
return func(key []byte, iv []byte) (cipher.Stream, error) {
block, err := blockCreator(key)
if err != nil {
return nil, err
}
return streamCreator(block, iv), err
}
}
func (m *Method) Name() string {
return m.name
}
func (m *Method) KeyLength() int {
return m.keyLength
}
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &clientConn{
Conn: conn,
method: m,
destination: destination,
}
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
}
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &clientConn{
Conn: conn,
method: m,
destination: destination,
}
}
func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
return &clientPacketConn{m, conn}
}
type clientConn struct {
net.Conn
method *Method
destination M.Socksaddr
readStream cipher.Stream
writeStream cipher.Stream
}
func (c *clientConn) writeRequest(payload []byte) error {
_buffer := buf.Make(c.method.keyLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(payload))
defer runtime.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer))
salt := buffer.Extend(c.method.keyLength)
common.Must1(io.ReadFull(rand.Reader, salt))
key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength)
writer, err := c.method.encryptConstructor(c.method.key, salt)
if err != nil {
return err
}
runtime.KeepAlive(key)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
if err != nil {
return err
}
_, err = buffer.Write(payload)
if err != nil {
return err
}
_, err = c.Conn.Write(buffer.Bytes())
if err != nil {
return err
}
c.writeStream = writer
return nil
}
func (c *clientConn) readResponse() error {
if c.readStream != nil {
return nil
}
_salt := buf.Make(c.method.keyLength)
defer runtime.KeepAlive(_salt)
salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength)
defer runtime.KeepAlive(key)
c.readStream, err = c.method.decryptConstructor(common.Dup(key), salt)
if err != nil {
return err
}
return nil
}
func (c *clientConn) Read(p []byte) (n int, err error) {
if err = c.readResponse(); err != nil {
return
}
n, err = c.Conn.Read(p)
if err != nil {
return 0, err
}
c.readStream.XORKeyStream(p[:n], p[:n])
return
}
func (c *clientConn) Write(p []byte) (n int, err error) {
if c.writeStream == nil {
err = c.writeRequest(p)
if err == nil {
n = len(p)
}
return
}
c.writeStream.XORKeyStream(p, p)
return c.Conn.Write(p)
}
func (c *clientConn) Upstream() any {
return c.Conn
}
type clientPacketConn struct {
*Method
net.Conn
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buf.With(buffer.ExtendHeader(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination)))
common.Must1(header.ReadFullFrom(rand.Reader, c.keyLength))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength))
if err != nil {
return err
}
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
return common.Error(c.Write(buffer.Bytes()))
}
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return M.Socksaddr{}, err
}
buffer.Truncate(n)
stream, err := c.decryptConstructor(c.key, buffer.To(c.keyLength))
if err != nil {
return M.Socksaddr{}, err
}
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
buffer.Advance(c.keyLength)
return M.SocksaddrSerializer.ReadAddrPort(buffer)
}
func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(p)
if err != nil {
return
}
stream, err := c.decryptConstructor(c.key, p[:c.keyLength])
if err != nil {
return
}
buffer := buf.With(p[c.keyLength:n])
stream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return
}
addr = destination.UDPAddr()
n = copy(p, buffer.Bytes())
return
}
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
_buffer := buf.Make(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
defer runtime.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer))
common.Must1(buffer.ReadFullFrom(rand.Reader, c.keyLength))
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
}
_, err = buffer.Write(p)
if err != nil {
return
}
stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength))
if err != nil {
return
}
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
_, err = c.Write(buffer.Bytes())
if err != nil {
return
}
return len(p), nil
}
func (c *clientPacketConn) Upstream() any {
return c.Conn
}