mirror of
https://github.com/SagerNet/sing-shadowsocks.git
synced 2025-04-03 03:47:40 +03:00
635 lines
16 KiB
Go
635 lines
16 KiB
Go
package shadowaead_2022
|
|
|
|
import (
|
|
"context"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"io"
|
|
"math"
|
|
mRand "math/rand"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing-shadowsocks"
|
|
"github.com/sagernet/sing-shadowsocks/shadowaead"
|
|
"github.com/sagernet/sing/common"
|
|
"github.com/sagernet/sing/common/buf"
|
|
"github.com/sagernet/sing/common/bufio"
|
|
"github.com/sagernet/sing/common/cache"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
N "github.com/sagernet/sing/common/network"
|
|
"github.com/sagernet/sing/common/replay"
|
|
"github.com/sagernet/sing/common/udpnat"
|
|
|
|
"golang.org/x/crypto/chacha20poly1305"
|
|
)
|
|
|
|
var (
|
|
ErrNoPadding = E.New("bad request: missing payload or padding")
|
|
ErrBadPadding = E.New("bad request: damaged padding")
|
|
)
|
|
|
|
var _ shadowsocks.Service = (*Service)(nil)
|
|
|
|
type Service struct {
|
|
name string
|
|
keySaltLength int
|
|
handler shadowsocks.Handler
|
|
timeFunc func() time.Time
|
|
|
|
constructor func(key []byte) (cipher.AEAD, error)
|
|
blockConstructor func(key []byte) (cipher.Block, error)
|
|
udpCipher cipher.AEAD
|
|
udpBlockCipher cipher.Block
|
|
psk []byte
|
|
|
|
replayFilter replay.Filter
|
|
udpNat *udpnat.Service[uint64]
|
|
udpSessions *cache.LruCache[uint64, *serverUDPSession]
|
|
}
|
|
|
|
func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
|
|
if password == "" {
|
|
return nil, ErrMissingPSK
|
|
}
|
|
psk, err := base64.StdEncoding.DecodeString(password)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "decode psk")
|
|
}
|
|
return NewService(method, psk, udpTimeout, handler, timeFunc)
|
|
}
|
|
|
|
func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
|
|
s := &Service{
|
|
name: method,
|
|
handler: handler,
|
|
timeFunc: timeFunc,
|
|
|
|
replayFilter: replay.NewSimple(60 * time.Second),
|
|
udpNat: udpnat.New[uint64](udpTimeout, handler),
|
|
udpSessions: cache.New[uint64, *serverUDPSession](
|
|
cache.WithAge[uint64, *serverUDPSession](udpTimeout),
|
|
cache.WithUpdateAgeOnGet[uint64, *serverUDPSession](),
|
|
),
|
|
}
|
|
|
|
switch method {
|
|
case "2022-blake3-aes-128-gcm":
|
|
s.keySaltLength = 16
|
|
s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
|
|
s.blockConstructor = aes.NewCipher
|
|
case "2022-blake3-aes-256-gcm":
|
|
s.keySaltLength = 32
|
|
s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
|
|
s.blockConstructor = aes.NewCipher
|
|
case "2022-blake3-chacha20-poly1305":
|
|
s.keySaltLength = 32
|
|
s.constructor = chacha20poly1305.New
|
|
default:
|
|
return nil, os.ErrInvalid
|
|
}
|
|
|
|
if len(psk) != s.keySaltLength {
|
|
if len(psk) < s.keySaltLength {
|
|
return nil, shadowsocks.ErrBadKey
|
|
} else if len(psk) > s.keySaltLength {
|
|
psk = Key(psk, s.keySaltLength)
|
|
} else {
|
|
return nil, ErrMissingPSK
|
|
}
|
|
}
|
|
|
|
var err error
|
|
switch method {
|
|
case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm":
|
|
s.udpBlockCipher, err = aes.NewCipher(psk)
|
|
case "2022-blake3-chacha20-poly1305":
|
|
s.udpCipher, err = chacha20poly1305.NewX(psk)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.psk = psk
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Service) Name() string {
|
|
return s.name
|
|
}
|
|
|
|
func (s *Service) Password() string {
|
|
return base64.StdEncoding.EncodeToString(s.psk)
|
|
}
|
|
|
|
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
|
err := s.newConnection(ctx, conn, metadata)
|
|
if err != nil {
|
|
err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Service) time() time.Time {
|
|
if s.timeFunc != nil {
|
|
return s.timeFunc()
|
|
} else {
|
|
return time.Now()
|
|
}
|
|
}
|
|
|
|
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
|
header := make([]byte, s.keySaltLength+shadowaead.Overhead+RequestHeaderFixedChunkLength)
|
|
|
|
n, err := conn.Read(header)
|
|
if err != nil {
|
|
return E.Cause(err, "read header")
|
|
} else if n < len(header) {
|
|
return shadowaead.ErrBadHeader
|
|
}
|
|
|
|
requestSalt := header[:s.keySaltLength]
|
|
|
|
if !s.replayFilter.Check(requestSalt) {
|
|
return ErrSaltNotUnique
|
|
}
|
|
|
|
requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength)
|
|
readCipher, err := s.constructor(common.Dup(requestKey))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
reader := shadowaead.NewReader(
|
|
conn,
|
|
readCipher,
|
|
MaxPacketSize,
|
|
)
|
|
common.KeepAlive(requestKey)
|
|
|
|
err = reader.ReadExternalChunk(header[s.keySaltLength:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
headerType, err := reader.ReadByte()
|
|
if err != nil {
|
|
return E.Cause(err, "read header")
|
|
}
|
|
|
|
if headerType != HeaderTypeClient {
|
|
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
|
|
}
|
|
|
|
var epoch uint64
|
|
err = binary.Read(reader, binary.BigEndian, &epoch)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
|
|
if diff > 30 {
|
|
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
|
|
}
|
|
|
|
var length uint16
|
|
err = binary.Read(reader, binary.BigEndian, &length)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = reader.ReadWithLength(length)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var paddingLen uint16
|
|
err = binary.Read(reader, binary.BigEndian, &paddingLen)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if uint16(reader.Cached()) < paddingLen {
|
|
return ErrNoPadding
|
|
}
|
|
|
|
if paddingLen > 0 {
|
|
err = reader.Discard(int(paddingLen))
|
|
if err != nil {
|
|
return E.Cause(err, "discard padding")
|
|
}
|
|
} else if reader.Cached() == 0 {
|
|
return ErrNoPadding
|
|
}
|
|
|
|
protocolConn := &serverConn{
|
|
Service: s,
|
|
Conn: conn,
|
|
uPSK: s.psk,
|
|
headerType: headerType,
|
|
requestSalt: requestSalt,
|
|
}
|
|
|
|
protocolConn.reader = reader
|
|
|
|
metadata.Protocol = "shadowsocks"
|
|
metadata.Destination = destination
|
|
return s.handler.NewConnection(ctx, protocolConn, metadata)
|
|
}
|
|
|
|
type serverConn struct {
|
|
*Service
|
|
net.Conn
|
|
uPSK []byte
|
|
access sync.Mutex
|
|
headerType byte
|
|
reader *shadowaead.Reader
|
|
writer *shadowaead.Writer
|
|
requestSalt []byte
|
|
}
|
|
|
|
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
|
_salt := buf.StackNewSize(c.keySaltLength)
|
|
salt := common.Dup(_salt)
|
|
salt.WriteRandom(salt.FreeLen())
|
|
|
|
key := SessionKey(c.uPSK, salt.Bytes(), c.keySaltLength)
|
|
common.KeepAlive(_salt)
|
|
writeCipher, err := c.constructor(common.Dup(key))
|
|
if err != nil {
|
|
salt.Release()
|
|
common.KeepAlive(_salt)
|
|
return
|
|
}
|
|
writer := shadowaead.NewWriter(
|
|
c.Conn,
|
|
writeCipher,
|
|
MaxPacketSize,
|
|
)
|
|
common.KeepAlive(key)
|
|
header := writer.Buffer()
|
|
header.Write(salt.Bytes())
|
|
|
|
salt.Release()
|
|
common.KeepAlive(_salt)
|
|
|
|
headerType := byte(HeaderTypeServer)
|
|
payloadLen := len(payload)
|
|
|
|
_headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2)
|
|
headerFixedChunk := common.Dup(_headerFixedChunk)
|
|
common.Must(headerFixedChunk.WriteByte(headerType))
|
|
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(c.time().Unix())))
|
|
common.Must1(headerFixedChunk.Write(c.requestSalt))
|
|
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen)))
|
|
|
|
writer.WriteChunk(header, headerFixedChunk.Slice())
|
|
headerFixedChunk.Release()
|
|
common.KeepAlive(_headerFixedChunk)
|
|
c.requestSalt = nil
|
|
|
|
if payloadLen > 0 {
|
|
writer.WriteChunk(header, payload[:payloadLen])
|
|
}
|
|
|
|
err = writer.BufferedWriter(header.Len()).Flush()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
switch headerType {
|
|
case HeaderTypeServer:
|
|
c.writer = writer
|
|
// case HeaderTypeServerEncrypted:
|
|
// encryptedWriter := NewTLSEncryptedStreamWriter(writer)
|
|
// if payloadLen < len(payload) {
|
|
// _, err = encryptedWriter.Write(payload[payloadLen:])
|
|
// if err != nil {
|
|
// return
|
|
// }
|
|
// }
|
|
// c.writer = encryptedWriter
|
|
}
|
|
|
|
n = len(payload)
|
|
return
|
|
}
|
|
|
|
func (c *serverConn) Read(b []byte) (n int, err error) {
|
|
return c.reader.Read(b)
|
|
}
|
|
|
|
func (c *serverConn) Write(p []byte) (n int, err error) {
|
|
if c.writer != nil {
|
|
return c.writer.Write(p)
|
|
}
|
|
c.access.Lock()
|
|
if c.writer != nil {
|
|
c.access.Unlock()
|
|
return c.writer.Write(p)
|
|
}
|
|
defer c.access.Unlock()
|
|
return c.writeResponse(p)
|
|
}
|
|
|
|
func (c *serverConn) WriteVectorised(buffers []*buf.Buffer) error {
|
|
if c.writer != nil {
|
|
return c.writer.WriteVectorised(buffers)
|
|
}
|
|
c.access.Lock()
|
|
if c.writer != nil {
|
|
c.access.Unlock()
|
|
return c.writer.WriteVectorised(buffers)
|
|
}
|
|
defer c.access.Unlock()
|
|
_, err := c.writeResponse(buffers[0].Bytes())
|
|
if err != nil {
|
|
buf.ReleaseMulti(buffers)
|
|
return err
|
|
}
|
|
buffers[0].Release()
|
|
return c.writer.WriteVectorised(buffers[1:])
|
|
}
|
|
|
|
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|
if c.writer == nil {
|
|
return bufio.ReadFrom0(c, r)
|
|
}
|
|
return bufio.Copy(c.writer, r)
|
|
}
|
|
|
|
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
|
return bufio.Copy(w, c.reader)
|
|
}
|
|
|
|
func (c *serverConn) Close() error {
|
|
return common.Close(
|
|
c.Conn,
|
|
common.PtrOrNil(c.reader),
|
|
common.PtrOrNil(c.writer),
|
|
)
|
|
}
|
|
|
|
func (c *serverConn) Upstream() any {
|
|
return c.Conn
|
|
}
|
|
|
|
func (s *Service) WriteIsThreadUnsafe() {
|
|
}
|
|
|
|
func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
|
err := s.newPacket(ctx, conn, buffer, metadata)
|
|
if err != nil {
|
|
err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
|
var packetHeader []byte
|
|
if s.udpCipher != nil {
|
|
if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize {
|
|
return ErrPacketTooShort
|
|
}
|
|
_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
|
if err != nil {
|
|
return E.Cause(err, "decrypt packet header")
|
|
}
|
|
buffer.Advance(PacketNonceSize)
|
|
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
|
|
} else {
|
|
if buffer.Len() < PacketMinimalHeaderSize {
|
|
return ErrPacketTooShort
|
|
}
|
|
packetHeader = buffer.To(aes.BlockSize)
|
|
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
|
|
}
|
|
|
|
var sessionId, packetId uint64
|
|
err := binary.Read(buffer, binary.BigEndian, &sessionId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = binary.Read(buffer, binary.BigEndian, &packetId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
session, loaded := s.udpSessions.LoadOrStore(sessionId, s.newUDPSession)
|
|
if !loaded {
|
|
session.remoteSessionId = sessionId
|
|
if packetHeader != nil {
|
|
key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength)
|
|
session.remoteCipher, err = s.constructor(common.Dup(key))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
common.KeepAlive(key)
|
|
}
|
|
}
|
|
goto process
|
|
|
|
returnErr:
|
|
if !loaded {
|
|
s.udpSessions.Delete(sessionId)
|
|
}
|
|
return err
|
|
|
|
process:
|
|
if !session.window.Check(packetId) {
|
|
err = ErrPacketIdNotUnique
|
|
goto returnErr
|
|
}
|
|
|
|
if packetHeader != nil {
|
|
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
|
|
if err != nil {
|
|
err = E.Cause(err, "decrypt packet")
|
|
goto returnErr
|
|
}
|
|
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
|
|
}
|
|
|
|
session.window.Add(packetId)
|
|
|
|
var headerType byte
|
|
headerType, err = buffer.ReadByte()
|
|
if err != nil {
|
|
err = E.Cause(err, "decrypt packet")
|
|
goto returnErr
|
|
}
|
|
if headerType != HeaderTypeClient {
|
|
err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
|
|
goto returnErr
|
|
}
|
|
|
|
var epoch uint64
|
|
err = binary.Read(buffer, binary.BigEndian, &epoch)
|
|
if err != nil {
|
|
goto returnErr
|
|
}
|
|
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
|
|
if diff > 30 {
|
|
err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
|
|
goto returnErr
|
|
}
|
|
|
|
var paddingLen uint16
|
|
err = binary.Read(buffer, binary.BigEndian, &paddingLen)
|
|
if err != nil {
|
|
err = E.Cause(err, "read padding length")
|
|
goto returnErr
|
|
}
|
|
buffer.Advance(int(paddingLen))
|
|
|
|
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
|
|
if err != nil {
|
|
goto returnErr
|
|
}
|
|
metadata.Protocol = "shadowsocks"
|
|
metadata.Destination = destination
|
|
s.udpNat.NewPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) N.PacketWriter {
|
|
return &serverPacketWriter{s, conn, natConn, session, s.udpBlockCipher}
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) NewError(ctx context.Context, err error) {
|
|
s.handler.NewError(ctx, err)
|
|
}
|
|
|
|
type serverPacketWriter struct {
|
|
*Service
|
|
source N.PacketConn
|
|
nat N.PacketConn
|
|
session *serverUDPSession
|
|
udpBlockCipher cipher.Block
|
|
}
|
|
|
|
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|
var hdrLen int
|
|
if w.udpCipher != nil {
|
|
hdrLen = PacketNonceSize
|
|
}
|
|
|
|
var paddingLen int
|
|
if destination.Port == 53 && buffer.Len() < MaxPaddingLength {
|
|
paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1
|
|
}
|
|
|
|
hdrLen += 16 // packet header
|
|
hdrLen += 1 // header type
|
|
hdrLen += 8 // timestamp
|
|
hdrLen += 8 // remote session id
|
|
hdrLen += 2 // padding length
|
|
hdrLen += paddingLen
|
|
hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
|
|
header := buf.With(buffer.ExtendHeader(hdrLen))
|
|
|
|
var dataIndex int
|
|
if w.udpCipher != nil {
|
|
common.Must1(header.ReadFullFrom(w.session.rng, PacketNonceSize))
|
|
dataIndex = PacketNonceSize
|
|
} else {
|
|
dataIndex = aes.BlockSize
|
|
}
|
|
|
|
common.Must(
|
|
binary.Write(header, binary.BigEndian, w.session.sessionId),
|
|
binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
|
|
header.WriteByte(HeaderTypeServer),
|
|
binary.Write(header, binary.BigEndian, uint64(w.time().Unix())),
|
|
binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
|
|
binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
|
|
)
|
|
|
|
if paddingLen > 0 {
|
|
header.Extend(paddingLen)
|
|
}
|
|
|
|
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
|
|
if err != nil {
|
|
buffer.Release()
|
|
return err
|
|
}
|
|
|
|
if w.udpCipher != nil {
|
|
w.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
|
buffer.Extend(shadowaead.Overhead)
|
|
} else {
|
|
packetHeader := buffer.To(aes.BlockSize)
|
|
w.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
|
|
buffer.Extend(shadowaead.Overhead)
|
|
w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
|
}
|
|
return w.source.WritePacket(buffer, M.SocksaddrFromNet(w.nat.LocalAddr()))
|
|
}
|
|
|
|
func (w *serverPacketWriter) FrontHeadroom() int {
|
|
var hdrLen int
|
|
if w.udpCipher != nil {
|
|
hdrLen = PacketNonceSize
|
|
}
|
|
hdrLen += 16 // packet header
|
|
hdrLen += 1 // header type
|
|
hdrLen += 8 // timestamp
|
|
hdrLen += 8 // remote session id
|
|
hdrLen += 2 // padding length
|
|
hdrLen += MaxPaddingLength
|
|
hdrLen += M.MaxSocksaddrLength
|
|
return hdrLen
|
|
}
|
|
|
|
func (w *serverPacketWriter) RearHeadroom() int {
|
|
return shadowaead.Overhead
|
|
}
|
|
|
|
func (w *serverPacketWriter) Upstream() any {
|
|
return w.source
|
|
}
|
|
|
|
type serverUDPSession struct {
|
|
sessionId uint64
|
|
remoteSessionId uint64
|
|
packetId uint64
|
|
cipher cipher.AEAD
|
|
remoteCipher cipher.AEAD
|
|
window SlidingWindow
|
|
rng io.Reader
|
|
}
|
|
|
|
func (s *serverUDPSession) nextPacketId() uint64 {
|
|
return atomic.AddUint64(&s.packetId, 1)
|
|
}
|
|
|
|
func (s *Service) newUDPSession() *serverUDPSession {
|
|
session := &serverUDPSession{}
|
|
if s.udpCipher != nil {
|
|
session.rng = Blake3KeyedHash(rand.Reader)
|
|
common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
|
|
} else {
|
|
common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
|
|
}
|
|
session.packetId--
|
|
if s.udpCipher == nil {
|
|
sessionId := make([]byte, 8)
|
|
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
|
key := SessionKey(s.psk, sessionId, s.keySaltLength)
|
|
var err error
|
|
session.cipher, err = s.constructor(common.Dup(key))
|
|
common.Must(err)
|
|
common.KeepAlive(key)
|
|
}
|
|
return session
|
|
}
|