mirror of
https://github.com/SagerNet/sing-shadowsocks.git
synced 2025-04-03 03:47:40 +03:00
392 lines
10 KiB
Go
392 lines
10 KiB
Go
package shadowaead_2022
|
|
|
|
import (
|
|
"context"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"math"
|
|
"net"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing-shadowsocks"
|
|
"github.com/sagernet/sing-shadowsocks/shadowaead"
|
|
"github.com/sagernet/sing/common"
|
|
"github.com/sagernet/sing/common/auth"
|
|
"github.com/sagernet/sing/common/buf"
|
|
"github.com/sagernet/sing/common/bufio/deadline"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
N "github.com/sagernet/sing/common/network"
|
|
"github.com/sagernet/sing/common/rw"
|
|
|
|
"lukechampine.com/blake3"
|
|
)
|
|
|
|
var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil)
|
|
|
|
type MultiService[U comparable] struct {
|
|
*Service
|
|
|
|
uPSK map[U][]byte
|
|
uPSKHash map[[aes.BlockSize]byte]U
|
|
uCipher map[U]cipher.Block
|
|
}
|
|
|
|
func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
|
|
if password == "" {
|
|
return nil, ErrMissingPSK
|
|
}
|
|
iPSK, err := base64.StdEncoding.DecodeString(password)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "decode psk")
|
|
}
|
|
return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc)
|
|
}
|
|
|
|
func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
|
|
switch method {
|
|
case "2022-blake3-aes-128-gcm":
|
|
case "2022-blake3-aes-256-gcm":
|
|
default:
|
|
return nil, os.ErrInvalid
|
|
}
|
|
|
|
ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s := &MultiService[U]{
|
|
Service: ss.(*Service),
|
|
|
|
uPSK: make(map[U][]byte),
|
|
uPSKHash: make(map[[aes.BlockSize]byte]U),
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
func (s *MultiService[U]) UpdateUsers(userList []U, keyList [][]byte) error {
|
|
uPSK := make(map[U][]byte)
|
|
uPSKHash := make(map[[aes.BlockSize]byte]U)
|
|
uCipher := make(map[U]cipher.Block)
|
|
for i, user := range userList {
|
|
key := keyList[i]
|
|
if len(key) < s.keySaltLength {
|
|
return shadowsocks.ErrBadKey
|
|
} else if len(key) > s.keySaltLength {
|
|
key = Key(key, s.keySaltLength)
|
|
}
|
|
|
|
var hash [aes.BlockSize]byte
|
|
hash512 := blake3.Sum512(key)
|
|
copy(hash[:], hash512[:])
|
|
|
|
uPSKHash[hash] = user
|
|
uPSK[user] = key
|
|
var err error
|
|
uCipher[user], err = s.blockConstructor(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
s.uPSK = uPSK
|
|
s.uPSKHash = uPSKHash
|
|
s.uCipher = uCipher
|
|
return nil
|
|
}
|
|
|
|
func (s *MultiService[U]) UpdateUsersWithPasswords(userList []U, passwordList []string) error {
|
|
keyList := make([][]byte, 0, len(passwordList))
|
|
for _, password := range passwordList {
|
|
if password == "" {
|
|
return shadowsocks.ErrMissingPassword
|
|
}
|
|
uPSK, err := base64.StdEncoding.DecodeString(password)
|
|
if err != nil {
|
|
return E.Cause(err, "decode psk")
|
|
}
|
|
keyList = append(keyList, uPSK)
|
|
}
|
|
return s.UpdateUsers(userList, keyList)
|
|
}
|
|
|
|
func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
|
err := s.newConnection(ctx, conn, metadata)
|
|
if err != nil {
|
|
err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
|
requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength)
|
|
n, err := conn.Read(requestHeader)
|
|
if err != nil {
|
|
return err
|
|
} else if n < len(requestHeader) {
|
|
return shadowaead.ErrBadHeader
|
|
}
|
|
requestSalt := requestHeader[:s.keySaltLength]
|
|
if !s.replayFilter.Check(requestSalt) {
|
|
return ErrSaltNotUnique
|
|
}
|
|
|
|
var _eiHeader [aes.BlockSize]byte
|
|
eiHeader := common.Dup(_eiHeader[:])
|
|
copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize])
|
|
|
|
keyMaterial := buf.Make(s.keySaltLength * 2)
|
|
copy(keyMaterial, s.psk)
|
|
copy(keyMaterial[s.keySaltLength:], requestSalt)
|
|
_identitySubkey := buf.StackNewSize(s.keySaltLength)
|
|
identitySubkey := common.Dup(_identitySubkey)
|
|
identitySubkey.Extend(identitySubkey.FreeLen())
|
|
blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial)
|
|
b, err := s.blockConstructor(identitySubkey.Bytes())
|
|
identitySubkey.Release()
|
|
common.KeepAlive(_identitySubkey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
b.Decrypt(eiHeader, eiHeader)
|
|
|
|
var user U
|
|
var uPSK []byte
|
|
if u, loaded := s.uPSKHash[_eiHeader]; loaded {
|
|
user = u
|
|
uPSK = s.uPSK[u]
|
|
} else {
|
|
return E.New("invalid request")
|
|
}
|
|
common.KeepAlive(_eiHeader)
|
|
|
|
requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength)
|
|
readCipher, err := s.constructor(common.Dup(requestKey))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
reader := shadowaead.NewReader(
|
|
conn,
|
|
readCipher,
|
|
MaxPacketSize,
|
|
)
|
|
|
|
err = reader.ReadExternalChunk(requestHeader[s.keySaltLength+aes.BlockSize:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
headerType, err := rw.ReadByte(reader)
|
|
if err != nil {
|
|
return E.Cause(err, "read header")
|
|
}
|
|
|
|
if headerType != HeaderTypeClient {
|
|
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
|
|
}
|
|
|
|
var epoch uint64
|
|
err = binary.Read(reader, binary.BigEndian, &epoch)
|
|
if err != nil {
|
|
return E.Cause(err, "read timestamp")
|
|
}
|
|
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
|
|
if diff > 30 {
|
|
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
|
|
}
|
|
var length uint16
|
|
err = binary.Read(reader, binary.BigEndian, &length)
|
|
if err != nil {
|
|
return E.Cause(err, "read length")
|
|
}
|
|
|
|
err = reader.ReadWithLength(length)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
|
|
if err != nil {
|
|
return E.Cause(err, "read destination")
|
|
}
|
|
|
|
var paddingLen uint16
|
|
err = binary.Read(reader, binary.BigEndian, &paddingLen)
|
|
if err != nil {
|
|
return E.Cause(err, "read padding length")
|
|
}
|
|
|
|
if reader.Cached() < int(paddingLen) {
|
|
return ErrBadPadding
|
|
} else if paddingLen > 0 {
|
|
err = reader.Discard(int(paddingLen))
|
|
if err != nil {
|
|
return E.Cause(err, "discard padding")
|
|
}
|
|
} else if reader.Cached() == 0 {
|
|
return ErrNoPadding
|
|
}
|
|
|
|
protocolConn := &serverConn{
|
|
Service: s.Service,
|
|
Conn: conn,
|
|
uPSK: uPSK,
|
|
headerType: headerType,
|
|
requestSalt: requestSalt,
|
|
}
|
|
|
|
protocolConn.reader = reader
|
|
metadata.Protocol = "shadowsocks"
|
|
metadata.Destination = destination
|
|
return s.handler.NewConnection(auth.ContextWithUser(ctx, user), deadline.NewConn(protocolConn), metadata)
|
|
}
|
|
|
|
func (s *MultiService[U]) WriteIsThreadUnsafe() {
|
|
}
|
|
|
|
func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
|
err := s.newPacket(ctx, conn, buffer, metadata)
|
|
if err != nil {
|
|
err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
|
if buffer.Len() < PacketMinimalHeaderSize {
|
|
return ErrPacketTooShort
|
|
}
|
|
|
|
packetHeader := buffer.To(aes.BlockSize)
|
|
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
|
|
|
|
var _eiHeader [aes.BlockSize]byte
|
|
eiHeader := common.Dup(_eiHeader[:])
|
|
s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize))
|
|
xorWords(eiHeader, eiHeader, packetHeader)
|
|
|
|
var user U
|
|
var uPSK []byte
|
|
if u, loaded := s.uPSKHash[_eiHeader]; loaded {
|
|
user = u
|
|
uPSK = s.uPSK[u]
|
|
} else {
|
|
return E.New("invalid request")
|
|
}
|
|
|
|
var sessionId, packetId uint64
|
|
err := binary.Read(buffer, binary.BigEndian, &sessionId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = binary.Read(buffer, binary.BigEndian, &packetId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
buffer.Advance(aes.BlockSize)
|
|
|
|
session, loaded := s.udpSessions.LoadOrStore(sessionId, func() *serverUDPSession {
|
|
return s.newUDPSession(uPSK)
|
|
})
|
|
if !loaded {
|
|
session.remoteSessionId = sessionId
|
|
key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength)
|
|
session.remoteCipher, err = s.constructor(common.Dup(key))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
common.KeepAlive(key)
|
|
}
|
|
|
|
goto process
|
|
|
|
returnErr:
|
|
if !loaded {
|
|
s.udpSessions.Delete(sessionId)
|
|
}
|
|
return err
|
|
|
|
process:
|
|
if !session.window.Check(packetId) {
|
|
err = ErrPacketIdNotUnique
|
|
goto returnErr
|
|
}
|
|
|
|
if packetHeader != nil {
|
|
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
|
|
if err != nil {
|
|
err = E.Cause(err, "decrypt packet")
|
|
goto returnErr
|
|
}
|
|
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
|
|
}
|
|
|
|
session.window.Add(packetId)
|
|
|
|
var headerType byte
|
|
headerType, err = buffer.ReadByte()
|
|
if err != nil {
|
|
err = E.Cause(err, "decrypt packet")
|
|
goto returnErr
|
|
}
|
|
if headerType != HeaderTypeClient {
|
|
err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
|
|
goto returnErr
|
|
}
|
|
|
|
var epoch uint64
|
|
err = binary.Read(buffer, binary.BigEndian, &epoch)
|
|
if err != nil {
|
|
goto returnErr
|
|
}
|
|
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
|
|
if diff > 30 {
|
|
err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
|
|
goto returnErr
|
|
}
|
|
|
|
var paddingLen uint16
|
|
err = binary.Read(buffer, binary.BigEndian, &paddingLen)
|
|
if err != nil {
|
|
err = E.Cause(err, "read padding length")
|
|
goto returnErr
|
|
}
|
|
buffer.Advance(int(paddingLen))
|
|
|
|
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
|
|
if err != nil {
|
|
goto returnErr
|
|
}
|
|
|
|
metadata.Protocol = "shadowsocks"
|
|
metadata.Destination = destination
|
|
s.udpNat.NewContextPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
|
return auth.ContextWithUser(ctx, user), &serverPacketWriter{s.Service, conn, natConn, session, s.uCipher[user]}
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession {
|
|
session := &serverUDPSession{}
|
|
if s.udpCipher != nil {
|
|
session.rng = Blake3KeyedHash(rand.Reader)
|
|
common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
|
|
} else {
|
|
common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
|
|
}
|
|
session.packetId--
|
|
sessionId := make([]byte, 8)
|
|
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
|
key := SessionKey(uPSK, sessionId, s.keySaltLength)
|
|
var err error
|
|
session.cipher, err = s.constructor(common.Dup(key))
|
|
common.Must(err)
|
|
common.KeepAlive(key)
|
|
return session
|
|
}
|