sing/protocol/shadowsocks/shadowaead_2022/service.go
2022-05-02 11:18:46 +08:00

455 lines
11 KiB
Go

package shadowaead_2022
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"io"
"math"
"net"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
wgReplay "golang.zx2c4.com/wireguard/replay"
)
// Service is the server side of the Shadowsocks 2022 AEAD protocol. It
// accepts stream connections through NewConnection and UDP packets through
// NewPacket, and forwards the decoded traffic to handler.
type Service struct {
// name is the method name that was passed to NewService.
name string
// secureRNG supplies random bytes for response salts, UDP packet nonces and
// server-side UDP session ids.
secureRNG io.Reader
// keyLength is the length in bytes of the keys derived with Blake3DeriveKey.
keyLength int
// constructor builds the AEAD used for streams and per-session UDP ciphers
// from a derived key.
constructor func(key []byte) cipher.AEAD
// blockConstructor builds a block cipher from a key; it is only set for the
// AES methods.
blockConstructor func(key []byte) cipher.Block
// udpCipher, when non-nil, opens and seals whole UDP packets with a
// per-packet nonce. It is only set for the chacha20-poly1305 method.
udpCipher cipher.AEAD
// udpBlockCipher decrypts/encrypts the first AES block of a UDP packet. It is
// only set for the AES methods.
udpBlockCipher cipher.Block
// psk is the pre-shared key all session keys are derived from.
psk [KeySaltSize]byte
// replayFilter rejects stream request salts that have been seen before.
replayFilter replay.Filter
// handler receives decoded connections, packet connections and errors.
handler shadowsocks.Handler
// udpNat dispatches decoded UDP packets, keyed by the client session id.
udpNat udpnat.Service[uint64]
// sessions maps a client session id to its server-side UDP session state.
sessions cache.LruCache[uint64, *serverUDPSession]
}
// NewService creates a Shadowsocks 2022 server for the given method.
//
// Supported methods are "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm"
// and "2022-blake3-chacha20-poly1305". Any other method yields an error
// instead of a half-initialised service whose nil cipher constructors would
// panic on first use.
//
// psk is the pre-shared key, secureRNG supplies randomness for salts, nonces
// and session ids, udpTimeout is forwarded to the UDP session cache and the
// UDP NAT, and handler receives the decoded traffic.
func NewService(method string, psk [KeySaltSize]byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
	s := &Service{
		name:      method,
		psk:       psk,
		secureRNG: secureRNG,
		handler:   handler,
	}

	// Validate the method before allocating the replay filter, session cache
	// and NAT so that nothing is left behind when the method is rejected.
	switch method {
	case "2022-blake3-aes-128-gcm":
		s.keyLength = 16
		s.constructor = newAESGCM
		s.blockConstructor = newAES
		s.udpBlockCipher = newAES(s.psk[:])
	case "2022-blake3-aes-256-gcm":
		s.keyLength = 32
		s.constructor = newAESGCM
		s.blockConstructor = newAES
		s.udpBlockCipher = newAES(s.psk[:])
	case "2022-blake3-chacha20-poly1305":
		s.keyLength = 32
		s.constructor = newChacha20Poly1305
		s.udpCipher = newXChacha20Poly1305(s.psk[:])
	default:
		return nil, E.New("unsupported method: " + method)
	}

	s.replayFilter = replay.NewCuckoo(60)
	s.sessions = cache.NewLRU[uint64, *serverUDPSession](udpTimeout, true)
	s.udpNat = udpnat.New[uint64](udpTimeout, s)
	return s, nil
}
// NewConnection performs the server handshake on conn and hands the decoded
// stream to the handler. Any failure is wrapped in a
// shadowsocks.ServerConnError carrying the connection and its source.
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
	if cause := s.newConnection(ctx, conn, metadata); cause != nil {
		return &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: cause}
	}
	return nil
}
// newConnection reads and validates the client request header from conn and,
// on success, passes a serverConn wrapping the decrypting reader to the
// handler.
//
// The header is read in this order: request salt, header type, timestamp,
// destination address, padding length, padding.
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
	if debug.Enabled {
		logger.Trace("begin server handshake request")
	}

	// The salt is sent in the clear and must never have been seen before.
	salt := make([]byte, KeySaltSize)
	if _, err := io.ReadFull(conn, salt); err != nil {
		return E.Cause(err, "read request salt")
	}
	if debug.Enabled {
		logger.Trace("read salt ", buf.EncodeHexString(salt))
	}
	if unique := s.replayFilter.Check(salt); !unique {
		return E.New("salt not unique")
	}

	// Everything after the salt is read through the session-keyed AEAD reader.
	sessionKey := Blake3DeriveKey(s.psk[:], salt, s.keyLength)
	reader := shadowaead.NewReader(conn, s.constructor(common.Dup(sessionKey)), MaxPacketSize)

	headerType, err := rw.ReadByte(reader)
	switch {
	case err != nil:
		return E.Cause(err, "read header")
	case headerType != HeaderTypeClient:
		return ErrBadHeaderType
	}

	var timestamp uint64
	if err = binary.Read(reader, binary.BigEndian, &timestamp); err != nil {
		return E.Cause(err, "read timestamp")
	}
	if debug.Enabled {
		logger.Trace("client timestamp ", time.Unix(int64(timestamp), 0).String())
	}
	// Reject requests whose timestamp differs from local time by more than 30s.
	skew := time.Now().Unix() - int64(timestamp)
	if math.Abs(float64(skew)) > 30 {
		return ErrBadTimestamp
	}

	destination, err := socks.AddressSerializer.ReadAddrPort(reader)
	if err != nil {
		return E.Cause(err, "read destination")
	}
	if debug.Enabled {
		logger.Trace("destination ", destination)
	}

	var paddingLen uint16
	if err = binary.Read(reader, binary.BigEndian, &paddingLen); err != nil {
		return E.Cause(err, "read padding length")
	}
	if paddingLen == 0 {
		if debug.Enabled {
			logger.Trace("no padding")
		}
	} else {
		if debug.Enabled {
			logger.Trace("padding ", paddingLen, "B")
		}
		if err = reader.Discard(int(paddingLen)); err != nil {
			return E.Cause(err, "discard padding")
		}
	}

	if debug.Enabled {
		logger.Trace("end server handshake request")
	}

	metadata.Protocol = "shadowsocks"
	metadata.Destination = destination
	upstream := &serverConn{
		Service:     s,
		Conn:        conn,
		uPSK:        s.psk,
		reader:      reader,
		requestSalt: salt,
	}
	return s.handler.NewConnection(ctx, upstream, metadata)
}
// serverConn is the server end of an established stream: reads are decrypted
// through reader, and the first write sends the response header and creates
// writer.
type serverConn struct {
*Service
net.Conn
// uPSK is the pre-shared key the response key is derived from.
uPSK [KeySaltSize]byte
// access serialises the first write, which creates writer.
access sync.Mutex
// reader decrypts data arriving from the client.
reader *shadowaead.Reader
// writer encrypts data sent to the client; it is nil until the response
// header has been written successfully.
writer *shadowaead.Writer
// requestSalt is the client's request salt, echoed in the response header
// and cleared once the header has been built.
requestSalt []byte
}
// writeResponse sends the server response header, followed by payload when it
// is non-empty, and installs the encrypting writer on success.
//
// The header consists of a fresh random salt in the clear and then, encrypted
// with the key derived from that salt: the server header type, the current
// Unix time and the client's request salt.
//
// It returns len(payload) on success. A failure to read from the RNG is
// reported as an error rather than a panic, and the request salt is kept
// until the header has actually been flushed so that a failed attempt does
// not leave the connection unable to build a valid header.
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
	var _salt [KeySaltSize]byte
	salt := common.Dup(_salt[:])
	if _, err = io.ReadFull(c.secureRNG, salt); err != nil {
		return 0, E.Cause(err, "read response salt")
	}

	key := Blake3DeriveKey(c.uPSK[:], salt, c.keyLength)
	writer := shadowaead.NewWriter(
		c.Conn,
		c.constructor(common.Dup(key)),
		MaxPacketSize,
	)

	// The salt goes out unencrypted ahead of the first encrypted chunk.
	header := writer.Buffer()
	header.Write(salt)

	bufferedWriter := writer.BufferedWriter(header.Len())
	common.Must(rw.WriteByte(bufferedWriter, HeaderTypeServer))
	common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
	common.Must1(bufferedWriter.Write(c.requestSalt[:]))

	if len(payload) > 0 {
		if _, err = bufferedWriter.Write(payload); err != nil {
			return 0, err
		}
	}
	if err = bufferedWriter.Flush(); err != nil {
		return 0, err
	}

	// Only now is the handshake complete: drop the salt and publish the writer.
	c.requestSalt = nil
	c.writer = writer
	return len(payload), nil
}
// Write sends p to the client. The first call writes the response header
// together with p; later calls go straight to the encrypting writer.
func (c *serverConn) Write(p []byte) (n int, err error) {
	if c.writer == nil {
		c.access.Lock()
		if c.writer == nil {
			// Still no writer after taking the lock: this call sends the header.
			defer c.access.Unlock()
			return c.writeResponse(p)
		}
		// Another caller installed the writer while we waited for the lock.
		c.access.Unlock()
	}
	return c.writer.Write(p)
}
// ReadFrom copies from r to the client. Once the writer exists the copy is
// delegated to it; before that it goes through rw.ReadFrom0 with c as the
// destination.
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
	if c.writer != nil {
		return c.writer.ReadFrom(r)
	}
	return rw.ReadFrom0(c, r)
}
// WriteTo copies decrypted client data to w by delegating to the reader.
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
	n, err = c.reader.WriteTo(w)
	return n, err
}
// UpstreamReader returns the decrypting reader when one is set, and the raw
// connection otherwise.
func (c *serverConn) UpstreamReader() io.Reader {
	if c.reader != nil {
		return c.reader
	}
	return c.Conn
}
// ReaderReplaceable reports whether the decrypting reader is set.
func (c *serverConn) ReaderReplaceable() bool {
	if c.reader == nil {
		return false
	}
	return true
}
// UpstreamWriter returns the encrypting writer when one is set, and the raw
// connection otherwise.
func (c *serverConn) UpstreamWriter() io.Writer {
	if c.writer != nil {
		return c.writer
	}
	return c.Conn
}
// WriterReplaceable reports whether the encrypting writer is set, i.e. whether
// the response header has already been sent.
func (c *serverConn) WriterReplaceable() bool {
	if c.writer == nil {
		return false
	}
	return true
}
// NewPacket decodes one UDP packet received on conn and dispatches it. Any
// failure is wrapped in a shadowsocks.ServerPacketError carrying the packet
// connection and the packet's source.
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
	if cause := s.newPacket(conn, buffer, metadata); cause != nil {
		return &shadowsocks.ServerPacketError{PacketConn: conn, Source: metadata.Source, Cause: cause}
	}
	return nil
}
// newPacket decrypts and validates one client UDP packet in place and hands
// the remaining payload to the UDP NAT under the client's session id.
//
// With udpCipher set, the packet is a nonce followed by a single AEAD-sealed
// body. Otherwise the first AES block (session id and packet id) is decrypted
// with udpBlockCipher, and the rest is opened with a per-session AEAD keyed
// from the client session id, using bytes 4..16 of that block as the nonce.
//
// After the session id and packet id, the decrypted body holds: header type,
// timestamp, padding length, padding, destination address, payload.
//
// If the packet is rejected after it caused a new session to be created, that
// session is removed again before the error is returned.
func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
	var packetHeader []byte
	if s.udpCipher != nil {
		// Guard the slicing below against truncated datagrams.
		if buffer.Len() < PacketNonceSize {
			return E.New("packet too short")
		}
		// NOTE(review): the AEAD tag does not appear to be trimmed from buffer
		// after Open here or in the per-session Open below - confirm whether
		// the payload passed downstream is expected to carry it.
		_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
		if err != nil {
			return E.Cause(err, "decrypt packet header")
		}
		buffer.Advance(PacketNonceSize)
	} else {
		if buffer.Len() < aes.BlockSize {
			return E.New("packet too short")
		}
		packetHeader = buffer.To(aes.BlockSize)
		s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
	}

	var sessionId, packetId uint64
	if err := binary.Read(buffer, binary.BigEndian, &sessionId); err != nil {
		return err
	}
	if err := binary.Read(buffer, binary.BigEndian, &packetId); err != nil {
		return err
	}

	session, loaded := s.sessions.LoadOrStore(sessionId, s.newUDPSession)
	if !loaded {
		session.remoteSessionId = sessionId
		if packetHeader != nil {
			key := Blake3DeriveKey(s.psk[:], packetHeader[:8], s.keyLength)
			session.remoteCipher = s.constructor(common.Dup(key))
		}
	}

	// fail discards a session that only exists because of this packet, then
	// passes the error through.
	fail := func(cause error) error {
		if !loaded {
			s.sessions.Delete(sessionId)
		}
		return cause
	}

	if !session.filter.ValidateCounter(packetId, math.MaxUint64) {
		return fail(ErrPacketIdNotUnique)
	}

	if packetHeader != nil {
		if _, err := session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil); err != nil {
			return fail(E.Cause(err, "decrypt packet"))
		}
	}

	headerType, err := buffer.ReadByte()
	if err != nil {
		return fail(E.Cause(err, "read header"))
	}
	if headerType != HeaderTypeClient {
		return fail(ErrBadHeaderType)
	}

	var epoch uint64
	if err = binary.Read(buffer, binary.BigEndian, &epoch); err != nil {
		return fail(err)
	}
	// Compute the skew in signed arithmetic: an unsigned subtraction wraps
	// around whenever the client clock is ahead of ours and would reject
	// otherwise valid packets.
	if math.Abs(float64(time.Now().Unix()-int64(epoch))) > 30 {
		return fail(ErrBadTimestamp)
	}

	var paddingLength uint16
	if err = binary.Read(buffer, binary.BigEndian, &paddingLength); err != nil {
		return fail(E.Cause(err, "read padding length"))
	}
	// Never skip past the end of the packet on a bogus padding length.
	if int(paddingLength) > buffer.Len() {
		return fail(E.New("bad padding length"))
	}
	buffer.Advance(int(paddingLength))

	destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
	if err != nil {
		return fail(err)
	}

	metadata.Destination = destination
	// Replies go to wherever the most recent packet of the session came from.
	session.remoteAddr = metadata.Source
	s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
		return &serverPacketWriter{s, conn, session}
	}, buffer, metadata)
	return nil
}
// serverPacketWriter encrypts reply packets for one UDP session and writes
// them to the embedded packet connection.
type serverPacketWriter struct {
*Service
socks.PacketConn
// session is the UDP session whose ids, cipher and remote address are used
// for outgoing packets.
session *serverUDPSession
}
// WritePacket wraps buffer in a server UDP packet and sends it to the
// session's remote address. destination is written into the packet as the
// address the payload originated from. buffer is always released.
//
// The plaintext layout is: server session id, packet id, server header type,
// current Unix time, client session id, zero padding length, address,
// payload.
//
// With udpCipher set, a random nonce is prepended and everything after it is
// sealed in one piece. Otherwise everything after the first AES block is
// sealed with the session cipher, using bytes 4..16 of that block as the
// nonce, and the block itself is then encrypted with udpBlockCipher.
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
	defer buffer.Release()

	_header := buf.StackNew()
	header := common.Dup(_header)

	// dataIndex is where the AEAD-sealed region starts inside header.
	var dataIndex int
	if w.udpCipher != nil {
		common.Must1(header.ReadFullFrom(w.secureRNG, PacketNonceSize))
		// The sealed region begins right after the nonce; the payload length
		// has nothing to do with it.
		dataIndex = PacketNonceSize
	} else {
		dataIndex = aes.BlockSize
	}

	common.Must(
		binary.Write(header, binary.BigEndian, w.session.sessionId),
		binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
		header.WriteByte(HeaderTypeServer),
		binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
		binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
		binary.Write(header, binary.BigEndian, uint16(0)), // padding length
	)

	err := socks.AddressSerializer.WriteAddrPort(header, destination)
	if err != nil {
		return err
	}
	_, err = header.Write(buffer.Bytes())
	if err != nil {
		return err
	}

	if w.udpCipher != nil {
		w.udpCipher.Seal(header.Index(dataIndex), header.To(dataIndex), header.From(dataIndex), nil)
		header.Extend(w.udpCipher.Overhead())
	} else {
		packetHeader := header.To(aes.BlockSize)
		// Seal first: the nonce is taken from the still-plaintext header block.
		w.session.cipher.Seal(header.Index(dataIndex), packetHeader[4:16], header.From(dataIndex), nil)
		header.Extend(w.session.cipher.Overhead())
		w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
	}
	return w.PacketConn.WritePacket(header, w.session.remoteAddr)
}
// serverUDPSession holds the server-side state of one client UDP session.
type serverUDPSession struct {
// sessionId is the server's own session id, drawn from the service RNG.
sessionId uint64
// remoteSessionId is the client's session id.
remoteSessionId uint64
// remoteAddr is the source address of the most recent accepted packet.
remoteAddr *M.AddrPort
// packetId is the id of the last packet sent; it is advanced atomically.
packetId uint64
// cipher seals outgoing packets; it is only set when the service has no
// udpCipher.
cipher cipher.AEAD
// remoteCipher opens incoming packets; it is only set when the service has
// no udpCipher.
remoteCipher cipher.AEAD
// filter rejects packet ids that were already seen for this session.
filter wgReplay.Filter
}
// nextPacketId atomically advances the session's packet counter and returns
// the new value.
func (s *serverUDPSession) nextPacketId() uint64 {
	next := atomic.AddUint64(&s.packetId, 1)
	return next
}
// newUDPSession creates the server-side state for a new UDP session with a
// random session id. When the service has no udpCipher, the session also
// gets an AEAD keyed from that session id for sealing outgoing packets.
func (s *Service) newUDPSession() *serverUDPSession {
	session := new(serverUDPSession)
	common.Must(binary.Read(s.secureRNG, binary.BigEndian, &session.sessionId))
	// Wrap the zero counter around so that the first nextPacketId returns 0.
	session.packetId--
	if s.udpCipher != nil {
		return session
	}

	var rawId [8]byte
	binary.BigEndian.PutUint64(rawId[:], session.sessionId)
	sessionKey := Blake3DeriveKey(s.psk[:], rawId[:], s.keyLength)
	session.cipher = s.constructor(common.Dup(sessionKey))
	return session
}
// NewPacketConnection forwards a new packet connection to the handler
// unchanged.
func (s *Service) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
	err := s.handler.NewPacketConnection(ctx, conn, metadata)
	return err
}
// HandleError forwards err to the handler.
func (s *Service) HandleError(err error) {
	handler := s.handler
	handler.HandleError(err)
}