Add password support for shadowsocks 2022 ciphers

This commit is contained in:
世界 2022-05-12 16:46:38 +08:00
parent f1a5f8aaa3
commit 2aae93c5b8
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
11 changed files with 204 additions and 180 deletions

View file

@ -2,7 +2,6 @@ package main
import ( import (
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil" "io/ioutil"
@ -32,6 +31,7 @@ import (
"github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
"github.com/sagernet/sing/protocol/shadowsocks/shadowimpl"
"github.com/sagernet/sing/protocol/shadowsocks/shadowstream" "github.com/sagernet/sing/protocol/shadowsocks/shadowstream"
"github.com/sagernet/sing/transport/mixed" "github.com/sagernet/sing/transport/mixed"
"github.com/sagernet/sing/transport/system" "github.com/sagernet/sing/transport/system"
@ -182,53 +182,11 @@ func newClient(f *flags) (*client, error) {
if f.ReducedSaltEntropy { if f.ReducedSaltEntropy {
rng = &shadowsocks.ReducedEntropyReader{Reader: rng} rng = &shadowsocks.ReducedEntropyReader{Reader: rng}
} }
if common.Contains(shadowstream.List, f.Method) { method, err := shadowimpl.FetchMethod(f.Method, f.Key, f.Password, rng)
var key []byte if err != nil {
if f.Key != "" { return nil, err
kb, err := base64.StdEncoding.DecodeString(f.Key)
if err != nil {
return nil, E.Cause(err, "decode key")
}
key = kb
}
method, err := shadowstream.New(f.Method, key, []byte(f.Password), rng)
if err != nil {
return nil, err
}
c.method = method
} else if common.Contains(shadowaead.List, f.Method) {
var key []byte
if f.Key != "" {
kb, err := base64.StdEncoding.DecodeString(f.Key)
if err != nil {
return nil, E.Cause(err, "decode key")
}
key = kb
}
method, err := shadowaead.New(f.Method, key, []byte(f.Password), rng)
if err != nil {
return nil, err
}
c.method = method
} else if common.Contains(shadowaead_2022.List, f.Method) {
var pskList [][]byte
if f.Key != "" {
keyStrList := strings.Split(f.Key, ":")
pskList = make([][]byte, len(keyStrList))
for i, keyStr := range keyStrList {
kb, err := base64.StdEncoding.DecodeString(keyStr)
if err != nil {
return nil, E.Cause(err, "decode key")
}
pskList[i] = kb
}
}
method, err := shadowaead_2022.New(f.Method, pskList, rng)
if err != nil {
return nil, err
}
c.method = method
} }
c.method = method
} }
c.dialer.Control = func(network, address string, c syscall.RawConn) error { c.dialer.Control = func(network, address string, c syscall.RawConn) error {

View file

@ -143,7 +143,7 @@ func newServer(f *flags) (*server, error) {
} }
key = kb key = kb
} }
service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Default, udpTimeout, s) service, err := shadowaead.NewService(f.Method, key, f.Password, random.Default, udpTimeout, s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -157,7 +157,7 @@ func newServer(f *flags) (*server, error) {
} }
key = kb key = kb
} }
service, err := shadowaead_2022.NewService(f.Method, key, random.Default, udpTimeout, s) service, err := shadowaead_2022.NewService(f.Method, key, f.Password, random.Default, udpTimeout, s)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -7,10 +7,16 @@ import (
"math/rand" "math/rand"
"net" "net"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
var (
ErrBadKey = E.New("shadowsocks: bad key")
ErrMissingPassword = E.New("shadowsocks: missing password")
)
type Method interface { type Method interface {
Name() string Name() string
KeyLength() int KeyLength() int

View file

@ -27,12 +27,7 @@ var List = []string{
"xchacha20-ietf-poly1305", "xchacha20-ietf-poly1305",
} }
var ( func New(method string, key []byte, password string, secureRNG io.Reader) (shadowsocks.Method, error) {
ErrBadKey = E.New("shadowsocks: bad key")
ErrMissingPassword = E.New("shadowsocks: missing password")
)
func New(method string, key []byte, password []byte, secureRNG io.Reader) (shadowsocks.Method, error) {
m := &Method{ m := &Method{
name: method, name: method,
secureRNG: secureRNG, secureRNG: secureRNG,
@ -65,11 +60,11 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader) (shado
if len(key) == m.keySaltLength { if len(key) == m.keySaltLength {
m.key = key m.key = key
} else if len(key) > 0 { } else if len(key) > 0 {
return nil, ErrBadKey return nil, shadowsocks.ErrBadKey
} else if len(password) > 0 { } else if password == "" {
m.key = shadowsocks.Key(password, m.keySaltLength) return nil, shadowsocks.ErrMissingPassword
} else { } else {
return nil, ErrMissingPassword m.key = shadowsocks.Key([]byte(password), m.keySaltLength)
} }
return m, nil return m, nil
} }
@ -181,12 +176,10 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error {
type clientConn struct { type clientConn struct {
net.Conn net.Conn
method *Method method *Method
destination M.Socksaddr destination M.Socksaddr
reader *Reader
reader *Reader writer *Writer
writer *Writer
} }
func (c *clientConn) writeRequest(payload []byte) error { func (c *clientConn) writeRequest(payload []byte) error {

View file

@ -30,7 +30,7 @@ type Service struct {
udpNat *udpnat.Service[netip.AddrPort] udpNat *udpnat.Service[netip.AddrPort]
} }
func NewService(method string, key []byte, password []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { func NewService(method string, key []byte, password string, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{ s := &Service{
name: method, name: method,
secureRNG: secureRNG, secureRNG: secureRNG,
@ -65,11 +65,11 @@ func NewService(method string, key []byte, password []byte, secureRNG io.Reader,
if len(key) == s.keySaltLength { if len(key) == s.keySaltLength {
s.key = key s.key = key
} else if len(key) > 0 { } else if len(key) > 0 {
return nil, ErrBadKey return nil, shadowsocks.ErrBadKey
} else if len(password) > 0 { } else if password != "" {
s.key = shadowsocks.Key(password, s.keySaltLength) s.key = shadowsocks.Key([]byte(password), s.keySaltLength)
} else { } else {
return nil, ErrMissingPassword return nil, shadowsocks.ErrMissingPassword
} }
return s, nil return s, nil
} }

View file

@ -4,11 +4,13 @@ import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/sha256"
"encoding/binary" "encoding/binary"
"io" "io"
"math" "math"
"math/rand" "math/rand"
"net" "net"
"os"
"runtime" "runtime"
"sync/atomic" "sync/atomic"
"time" "time"
@ -31,7 +33,6 @@ const (
HeaderTypeClient = 0 HeaderTypeClient = 0
HeaderTypeServer = 1 HeaderTypeServer = 1
MaxPaddingLength = 900 MaxPaddingLength = 900
SaltSize = 32
PacketNonceSize = 24 PacketNonceSize = 24
MaxPacketSize = 65535 MaxPacketSize = 65535
) )
@ -48,6 +49,7 @@ const (
) )
var ( var (
ErrMissingPasswordPSK = E.New("shadowsocks: missing password or psk")
ErrBadHeaderType = E.New("shadowsocks: bad header type") ErrBadHeaderType = E.New("shadowsocks: bad header type")
ErrBadTimestamp = E.New("shadowsocks: bad timestamp") ErrBadTimestamp = E.New("shadowsocks: bad timestamp")
ErrBadRequestSalt = E.New("shadowsocks: bad request salt") ErrBadRequestSalt = E.New("shadowsocks: bad request salt")
@ -62,7 +64,7 @@ var List = []string{
"2022-blake3-chacha20-poly1305", "2022-blake3-chacha20-poly1305",
} }
func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Method, error) { func New(method string, pskList [][]byte, password string, secureRNG io.Reader) (shadowsocks.Method, error) {
m := &Method{ m := &Method{
name: method, name: method,
pskList: pskList, pskList: pskList,
@ -72,27 +74,35 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
switch method { switch method {
case "2022-blake3-aes-128-gcm": case "2022-blake3-aes-128-gcm":
m.keyLength = 16 m.keySaltLength = 16
m.constructor = newAESGCM m.constructor = newAESGCM
m.blockConstructor = newAES m.blockConstructor = newAES
case "2022-blake3-aes-256-gcm": case "2022-blake3-aes-256-gcm":
m.keyLength = 32 m.keySaltLength = 32
m.constructor = newAESGCM m.constructor = newAESGCM
m.blockConstructor = newAES m.blockConstructor = newAES
case "2022-blake3-chacha20-poly1305": case "2022-blake3-chacha20-poly1305":
m.keyLength = 32 if len(pskList) > 1 {
return nil, os.ErrInvalid
}
m.keySaltLength = 32
m.constructor = newChacha20Poly1305 m.constructor = newChacha20Poly1305
} }
for i, psk := range pskList { if len(pskList) == 0 {
if len(psk) < m.keyLength { if password == "" {
return nil, shadowaead.ErrBadKey return nil, ErrMissingPasswordPSK
} else if len(psk) > m.keyLength {
pskList[i] = DerivePSK(psk, m.keyLength)
} }
pskList = [][]byte{Key([]byte(password), m.keySaltLength)}
} }
m.psk = pskList[len(pskList)-1] for i, psk := range pskList {
if len(psk) < m.keySaltLength {
return nil, shadowsocks.ErrBadKey
} else if len(psk) > m.keySaltLength {
pskList[i] = Key(psk, m.keySaltLength)
}
}
if len(pskList) > 1 { if len(pskList) > 1 {
pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize) pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize)
@ -112,19 +122,18 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
case "2022-blake3-aes-256-gcm": case "2022-blake3-aes-256-gcm":
m.udpBlockCipher = newAES(pskList[0]) m.udpBlockCipher = newAES(pskList[0])
case "2022-blake3-chacha20-poly1305": case "2022-blake3-chacha20-poly1305":
m.udpCipher = newXChacha20Poly1305(m.psk) m.udpCipher = newXChacha20Poly1305(pskList[0])
} }
return m, nil return m, nil
} }
func DerivePSK(key []byte, keyLength int) []byte { func Key(key []byte, keyLength int) []byte {
outKey := buf.Make(keyLength) psk := sha256.Sum256(key)
blake3.DeriveKey(outKey, "shadowsocks 2022 pre shared key", key) return psk[:keyLength]
return outKey
} }
func DeriveSessionKey(psk []byte, salt []byte, keyLength int) []byte { func SessionKey(psk []byte, salt []byte, keyLength int) []byte {
sessionKey := buf.Make(len(psk) + len(salt)) sessionKey := buf.Make(len(psk) + len(salt))
copy(sessionKey, psk) copy(sessionKey, psk)
copy(sessionKey[len(psk):], salt) copy(sessionKey[len(psk):], salt)
@ -161,12 +170,11 @@ func newXChacha20Poly1305(key []byte) cipher.AEAD {
type Method struct { type Method struct {
name string name string
keyLength int keySaltLength int
constructor func(key []byte) cipher.AEAD constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD udpCipher cipher.AEAD
udpBlockCipher cipher.Block udpBlockCipher cipher.Block
psk []byte
pskList [][]byte pskList [][]byte
pskHash []byte pskHash []byte
secureRNG io.Reader secureRNG io.Reader
@ -178,13 +186,13 @@ func (m *Method) Name() string {
} }
func (m *Method) KeyLength() int { func (m *Method) KeyLength() int {
return m.keyLength return m.keySaltLength
} }
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &clientConn{ shadowsocksConn := &clientConn{
Method: m,
Conn: conn, Conn: conn,
method: m,
destination: destination, destination: destination,
} }
return shadowsocksConn, shadowsocksConn.writeRequest(nil) return shadowsocksConn, shadowsocksConn.writeRequest(nil)
@ -192,26 +200,23 @@ func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, err
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &clientConn{ return &clientConn{
Method: m,
Conn: conn, Conn: conn,
method: m,
destination: destination, destination: destination,
} }
} }
func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
return &clientPacketConn{conn, m, m.newUDPSession()} return &clientPacketConn{m, conn, m.newUDPSession()}
} }
type clientConn struct { type clientConn struct {
*Method
net.Conn net.Conn
method *Method
destination M.Socksaddr destination M.Socksaddr
requestSalt []byte requestSalt []byte
reader *shadowaead.Reader
reader *shadowaead.Reader writer *shadowaead.Writer
writer *shadowaead.Writer
} }
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) { func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
@ -220,10 +225,10 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
return return
} }
for i, psk := range m.pskList { for i, psk := range m.pskList {
keyMaterial := buf.Make(m.keyLength + SaltSize) keyMaterial := buf.Make(m.keySaltLength * 2)
copy(keyMaterial, psk) copy(keyMaterial, psk)
copy(keyMaterial[m.keyLength:], salt) copy(keyMaterial[m.keySaltLength:], salt)
_identitySubkey := buf.Make(m.keyLength) _identitySubkey := buf.Make(m.keySaltLength)
identitySubkey := common.Dup(_identitySubkey) identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
@ -239,20 +244,20 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
} }
func (c *clientConn) writeRequest(payload []byte) error { func (c *clientConn) writeRequest(payload []byte) error {
salt := make([]byte, SaltSize) salt := buf.Make(c.keySaltLength)
common.Must1(io.ReadFull(c.method.secureRNG, salt)) common.Must1(io.ReadFull(c.secureRNG, salt))
key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength) key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength)
writer := shadowaead.NewWriter( writer := shadowaead.NewWriter(
c.Conn, c.Conn,
c.method.constructor(common.Dup(key)), c.constructor(common.Dup(key)),
MaxPacketSize, MaxPacketSize,
) )
runtime.KeepAlive(key) runtime.KeepAlive(key)
header := writer.Buffer() header := writer.Buffer()
header.Write(salt) header.Write(salt)
c.method.writeExtendedIdentityHeaders(header, salt) c.writeExtendedIdentityHeaders(header, salt)
bufferedWriter := writer.BufferedWriter(header.Len()) bufferedWriter := writer.BufferedWriter(header.Len())
@ -279,7 +284,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
if err != nil { if err != nil {
return E.Cause(err, "write padding length") return E.Cause(err, "write padding length")
} }
_, err = io.CopyN(bufferedWriter, c.method.secureRNG, int64(pLen)) _, err = io.CopyN(bufferedWriter, c.secureRNG, int64(pLen))
if err != nil { if err != nil {
return E.Cause(err, "write padding") return E.Cause(err, "write padding")
} }
@ -300,22 +305,22 @@ func (c *clientConn) readResponse() error {
return nil return nil
} }
_salt := make([]byte, SaltSize) _salt := buf.Make(c.keySaltLength)
salt := common.Dup(_salt) salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt) _, err := io.ReadFull(c.Conn, salt)
if err != nil { if err != nil {
return err return err
} }
if !c.method.replayFilter.Check(salt) { if !c.replayFilter.Check(salt) {
return E.New("salt not unique") return E.New("salt not unique")
} }
key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength) key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength)
runtime.KeepAlive(_salt) runtime.KeepAlive(_salt)
reader := shadowaead.NewReader( reader := shadowaead.NewReader(
c.Conn, c.Conn,
c.method.constructor(common.Dup(key)), c.constructor(common.Dup(key)),
MaxPacketSize, MaxPacketSize,
) )
runtime.KeepAlive(key) runtime.KeepAlive(key)
@ -339,7 +344,7 @@ func (c *clientConn) readResponse() error {
return ErrBadTimestamp return ErrBadTimestamp
} }
_requestSalt := make([]byte, SaltSize) _requestSalt := buf.Make(c.keySaltLength)
requestSalt := common.Dup(_requestSalt) requestSalt := common.Dup(_requestSalt)
_, err = io.ReadFull(reader, requestSalt) _, err = io.ReadFull(reader, requestSalt)
if err != nil { if err != nil {
@ -412,19 +417,19 @@ func (c *clientConn) WriterReplaceable() bool {
} }
type clientPacketConn struct { type clientPacketConn struct {
*Method
net.Conn net.Conn
method *Method
session *udpSession session *udpSession
} }
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
var hdrLen int var hdrLen int
if c.method.udpCipher != nil { if c.udpCipher != nil {
hdrLen = PacketNonceSize hdrLen = PacketNonceSize
} }
hdrLen += 16 // packet header hdrLen += 16 // packet header
pskLen := len(c.method.pskList) pskLen := len(c.pskList)
if c.method.udpCipher == nil && pskLen > 1 { if c.udpCipher == nil && pskLen > 1 {
hdrLen += (pskLen - 1) * aes.BlockSize hdrLen += (pskLen - 1) * aes.BlockSize
} }
hdrLen += 1 // header type hdrLen += 1 // header type
@ -434,8 +439,8 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
header := buf.With(buffer.ExtendHeader(hdrLen)) header := buf.With(buffer.ExtendHeader(hdrLen))
var dataIndex int var dataIndex int
if c.method.udpCipher != nil { if c.udpCipher != nil {
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize)) common.Must1(header.ReadFullFrom(c.secureRNG, PacketNonceSize))
if pskLen > 1 { if pskLen > 1 {
panic("unsupported chacha extended header") panic("unsupported chacha extended header")
} }
@ -449,16 +454,16 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
binary.Write(header, binary.BigEndian, c.session.nextPacketId()), binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
) )
if c.method.udpCipher == nil && pskLen > 1 { if c.udpCipher == nil && pskLen > 1 {
for i, psk := range c.method.pskList { for i, psk := range c.pskList {
dataIndex += aes.BlockSize dataIndex += aes.BlockSize
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
identityHeader := header.Extend(aes.BlockSize) identityHeader := header.Extend(aes.BlockSize)
for textI := 0; textI < aes.BlockSize; textI++ { for textI := 0; textI < aes.BlockSize; textI++ {
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI) identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
} }
c.method.blockConstructor(psk).Encrypt(identityHeader, identityHeader) c.blockConstructor(psk).Encrypt(identityHeader, identityHeader)
if i == pskLen-2 { if i == pskLen-2 {
break break
@ -477,14 +482,14 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
if err != nil { if err != nil {
return err return err
} }
if c.method.udpCipher != nil { if c.udpCipher != nil {
c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
buffer.Extend(c.method.udpCipher.Overhead()) buffer.Extend(c.udpCipher.Overhead())
} else { } else {
packetHeader := buffer.To(aes.BlockSize) packetHeader := buffer.To(aes.BlockSize)
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
buffer.Extend(c.session.cipher.Overhead()) buffer.Extend(c.session.cipher.Overhead())
c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader) c.udpBlockCipher.Encrypt(packetHeader, packetHeader)
} }
return common.Error(c.Write(buffer.Bytes())) return common.Error(c.Write(buffer.Bytes()))
} }
@ -497,16 +502,16 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
buffer.Truncate(n) buffer.Truncate(n)
var packetHeader []byte var packetHeader []byte
if c.method.udpCipher != nil { if c.udpCipher != nil {
_, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) _, err = c.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
if err != nil { if err != nil {
return M.Socksaddr{}, E.Cause(err, "decrypt packet") return M.Socksaddr{}, E.Cause(err, "decrypt packet")
} }
buffer.Advance(PacketNonceSize) buffer.Advance(PacketNonceSize)
buffer.Truncate(buffer.Len() - c.method.udpCipher.Overhead()) buffer.Truncate(buffer.Len() - c.udpCipher.Overhead())
} else { } else {
packetHeader = buffer.To(aes.BlockSize) packetHeader = buffer.To(aes.BlockSize)
c.method.udpBlockCipher.Decrypt(packetHeader, packetHeader) c.udpBlockCipher.Decrypt(packetHeader, packetHeader)
} }
var sessionId, packetId uint64 var sessionId, packetId uint64
@ -526,8 +531,8 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
} else if sessionId == c.session.lastRemoteSessionId { } else if sessionId == c.session.lastRemoteSessionId {
remoteCipher = c.session.lastRemoteCipher remoteCipher = c.session.lastRemoteCipher
} else { } else {
key := DeriveSessionKey(c.method.psk, packetHeader[:8], c.method.keyLength) key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength)
remoteCipher = c.method.constructor(common.Dup(key)) remoteCipher = c.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }
_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
@ -622,14 +627,14 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr) destination := M.SocksaddrFromNet(addr)
var overHead int var overHead int
if c.method.udpCipher != nil { if c.udpCipher != nil {
overHead = PacketNonceSize + c.method.udpCipher.Overhead() overHead = PacketNonceSize + c.udpCipher.Overhead()
} else { } else {
overHead = c.session.cipher.Overhead() overHead = c.session.cipher.Overhead()
} }
overHead += 16 // packet header overHead += 16 // packet header
pskLen := len(c.method.pskList) pskLen := len(c.pskList)
if c.method.udpCipher == nil && pskLen > 1 { if c.udpCipher == nil && pskLen > 1 {
overHead += (pskLen - 1) * aes.BlockSize overHead += (pskLen - 1) * aes.BlockSize
} }
overHead += 1 // header type overHead += 1 // header type
@ -642,8 +647,8 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
buffer := buf.With(common.Dup(_buffer)) buffer := buf.With(common.Dup(_buffer))
var dataIndex int var dataIndex int
if c.method.udpCipher != nil { if c.udpCipher != nil {
common.Must1(buffer.ReadFullFrom(c.method.secureRNG, PacketNonceSize)) common.Must1(buffer.ReadFullFrom(c.secureRNG, PacketNonceSize))
if pskLen > 1 { if pskLen > 1 {
panic("unsupported chacha extended header") panic("unsupported chacha extended header")
} }
@ -657,16 +662,16 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()), binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()),
) )
if c.method.udpCipher == nil && pskLen > 1 { if c.udpCipher == nil && pskLen > 1 {
for i, psk := range c.method.pskList { for i, psk := range c.pskList {
dataIndex += aes.BlockSize dataIndex += aes.BlockSize
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
identityHeader := buffer.Extend(aes.BlockSize) identityHeader := buffer.Extend(aes.BlockSize)
for textI := 0; textI < aes.BlockSize; textI++ { for textI := 0; textI < aes.BlockSize; textI++ {
identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI) identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI)
} }
c.method.blockConstructor(psk).Encrypt(identityHeader, identityHeader) c.blockConstructor(psk).Encrypt(identityHeader, identityHeader)
if i == pskLen-2 { if i == pskLen-2 {
break break
@ -685,14 +690,14 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if err != nil { if err != nil {
return return
} }
if c.method.udpCipher != nil { if c.udpCipher != nil {
c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
buffer.Extend(c.method.udpCipher.Overhead()) buffer.Extend(c.udpCipher.Overhead())
} else { } else {
packetHeader := buffer.To(aes.BlockSize) packetHeader := buffer.To(aes.BlockSize)
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
buffer.Extend(c.session.cipher.Overhead()) buffer.Extend(c.session.cipher.Overhead())
c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader) c.udpBlockCipher.Encrypt(packetHeader, packetHeader)
} }
err = common.Error(c.Write(buffer.Bytes())) err = common.Error(c.Write(buffer.Bytes()))
if err != nil { if err != nil {
@ -726,7 +731,7 @@ func (m *Method) newUDPSession() *udpSession {
if m.udpCipher == nil { if m.udpCipher == nil {
sessionId := make([]byte, 8) sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := DeriveSessionKey(m.psk, sessionId, m.keyLength) key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength)
session.cipher = m.constructor(common.Dup(key)) session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }

View file

@ -30,7 +30,7 @@ import (
type Service struct { type Service struct {
name string name string
secureRNG io.Reader secureRNG io.Reader
keyLength int keySaltLength int
constructor func(key []byte) cipher.AEAD constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD udpCipher cipher.AEAD
@ -42,7 +42,7 @@ type Service struct {
sessions *cache.LruCache[uint64, *serverUDPSession] sessions *cache.LruCache[uint64, *serverUDPSession]
} }
func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { func NewService(method string, psk []byte, password string, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{ s := &Service{
name: method, name: method,
secureRNG: secureRNG, secureRNG: secureRNG,
@ -57,25 +57,31 @@ func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64
switch method { switch method {
case "2022-blake3-aes-128-gcm": case "2022-blake3-aes-128-gcm":
s.keyLength = 16 s.keySaltLength = 16
s.constructor = newAESGCM s.constructor = newAESGCM
s.blockConstructor = newAES s.blockConstructor = newAES
case "2022-blake3-aes-256-gcm": case "2022-blake3-aes-256-gcm":
s.keyLength = 32 s.keySaltLength = 32
s.constructor = newAESGCM s.constructor = newAESGCM
s.blockConstructor = newAES s.blockConstructor = newAES
case "2022-blake3-chacha20-poly1305": case "2022-blake3-chacha20-poly1305":
s.keyLength = 32 s.keySaltLength = 32
s.constructor = newChacha20Poly1305 s.constructor = newChacha20Poly1305
} }
if len(psk) < s.keyLength { if len(psk) == s.keySaltLength {
return nil, shadowaead.ErrBadKey s.psk = psk
} else if len(psk) > s.keyLength { } else if len(psk) != 0 {
psk = DerivePSK(psk, s.keyLength) if len(psk) < s.keySaltLength {
return nil, shadowsocks.ErrBadKey
}
s.psk = Key(psk, s.keySaltLength)
} else if password == "" {
return nil, ErrMissingPasswordPSK
} else {
s.psk = Key([]byte(password), s.keySaltLength)
} }
s.psk = psk
switch method { switch method {
case "2022-blake3-aes-128-gcm": case "2022-blake3-aes-128-gcm":
s.udpBlockCipher = newAES(psk) s.udpBlockCipher = newAES(psk)
@ -97,7 +103,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
} }
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
requestSalt := make([]byte, SaltSize) requestSalt := buf.Make(s.keySaltLength)
_, err := io.ReadFull(conn, requestSalt) _, err := io.ReadFull(conn, requestSalt)
if err != nil { if err != nil {
return E.Cause(err, "read request salt") return E.Cause(err, "read request salt")
@ -107,7 +113,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
return E.New("salt not unique") return E.New("salt not unique")
} }
requestKey := DeriveSessionKey(s.psk, requestSalt, s.keyLength) requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength)
reader := shadowaead.NewReader( reader := shadowaead.NewReader(
conn, conn,
s.constructor(common.Dup(requestKey)), s.constructor(common.Dup(requestKey)),
@ -175,10 +181,10 @@ type serverConn struct {
} }
func (c *serverConn) writeResponse(payload []byte) (n int, err error) { func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
var _salt [SaltSize]byte _salt := buf.Make(c.keySaltLength)
salt := common.Dup(_salt[:]) salt := common.Dup(_salt[:])
common.Must1(io.ReadFull(c.secureRNG, salt)) common.Must1(io.ReadFull(c.secureRNG, salt))
key := DeriveSessionKey(c.uPSK, salt, c.keyLength) key := SessionKey(c.uPSK, salt, c.keySaltLength)
runtime.KeepAlive(_salt) runtime.KeepAlive(_salt)
writer := shadowaead.NewWriter( writer := shadowaead.NewWriter(
c.Conn, c.Conn,
@ -294,7 +300,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me
if !loaded { if !loaded {
session.remoteSessionId = sessionId session.remoteSessionId = sessionId
if packetHeader != nil { if packetHeader != nil {
key := DeriveSessionKey(s.psk, packetHeader[:8], s.keyLength) key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength)
session.remoteCipher = s.constructor(common.Dup(key)) session.remoteCipher = s.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }
@ -439,7 +445,7 @@ func (m *Service) newUDPSession() *serverUDPSession {
if m.udpCipher == nil { if m.udpCipher == nil {
sessionId := make([]byte, 8) sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := DeriveSessionKey(m.psk, sessionId, m.keyLength) key := SessionKey(m.psk, sessionId, m.keySaltLength)
session.cipher = m.constructor(common.Dup(key)) session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }

View file

@ -30,10 +30,10 @@ type MultiService[U comparable] struct {
} }
func (s *MultiService[U]) AddUser(user U, key []byte) error { func (s *MultiService[U]) AddUser(user U, key []byte) error {
if len(key) < s.keyLength { if len(key) < s.keySaltLength {
return shadowaead.ErrBadKey return shadowsocks.ErrBadKey
} else if len(key) > s.keyLength { } else if len(key) > s.keySaltLength {
key = DerivePSK(key, s.keyLength) key = Key(key, s.keySaltLength)
} }
var uPSKHash [aes.BlockSize]byte var uPSKHash [aes.BlockSize]byte
@ -67,7 +67,7 @@ func NewMultiService[U comparable](method string, iPSK []byte, secureRNG io.Read
return nil, E.New("unsupported method ", method) return nil, E.New("unsupported method ", method)
} }
ss, err := NewService(method, iPSK, secureRNG, udpTimeout, handler) ss, err := NewService(method, iPSK, "", secureRNG, udpTimeout, handler)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -91,7 +91,7 @@ func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, meta
} }
func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
requestSalt := make([]byte, SaltSize) requestSalt := make([]byte, s.keySaltLength)
_, err := io.ReadFull(conn, requestSalt) _, err := io.ReadFull(conn, requestSalt)
if err != nil { if err != nil {
return E.Cause(err, "read request salt") return E.Cause(err, "read request salt")
@ -108,10 +108,10 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
return E.Cause(err, "read extended identity header") return E.Cause(err, "read extended identity header")
} }
keyMaterial := buf.Make(s.keyLength + SaltSize) keyMaterial := buf.Make(s.keySaltLength * 2)
copy(keyMaterial, s.psk) copy(keyMaterial, s.psk)
copy(keyMaterial[s.keyLength:], requestSalt) copy(keyMaterial[s.keySaltLength:], requestSalt)
_identitySubkey := buf.Make(s.keyLength) _identitySubkey := buf.Make(s.keySaltLength)
identitySubkey := common.Dup(_identitySubkey) identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader) s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader)
@ -126,7 +126,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
return E.New("invalid request") return E.New("invalid request")
} }
requestKey := DeriveSessionKey(uPSK, requestSalt, s.keyLength) requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength)
reader := shadowaead.NewReader( reader := shadowaead.NewReader(
conn, conn,
s.constructor(common.Dup(requestKey)), s.constructor(common.Dup(requestKey)),
@ -230,7 +230,7 @@ func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metad
}) })
if !loaded { if !loaded {
session.remoteSessionId = sessionId session.remoteSessionId = sessionId
key := DeriveSessionKey(uPSK, packetHeader[:8], s.keyLength) key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength)
session.remoteCipher = s.constructor(common.Dup(key)) session.remoteCipher = s.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }
@ -312,7 +312,7 @@ func (m *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession {
session.packetId-- session.packetId--
sessionId := make([]byte, 8) sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := DeriveSessionKey(uPSK, sessionId, m.keyLength) key := SessionKey(uPSK, sessionId, m.keySaltLength)
session.cipher = m.constructor(common.Dup(key)) session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
return session return session

View file

@ -29,7 +29,7 @@ func TestMultiService(t *testing.T) {
random.Default.Read(uPSK[:]) random.Default.Read(uPSK[:])
multiService.AddUser("my user", uPSK[:]) multiService.AddUser("my user", uPSK[:])
client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, random.Default) client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, "", random.Default)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -0,0 +1,56 @@
package shadowimpl
import (
"encoding/base64"
"io"
"strings"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
"github.com/sagernet/sing/protocol/shadowsocks/shadowstream"
)
func FetchMethod(method string, key string, password string, secureRNG io.Reader) (shadowsocks.Method, error) {
if method == "none" {
return shadowsocks.NewNone(), nil
} else if common.Contains(shadowstream.List, method) {
var keyBytes []byte
if key != "" {
kb, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, E.Cause(err, "decode key")
}
keyBytes = kb
}
return shadowstream.New(method, keyBytes, password, secureRNG)
} else if common.Contains(shadowaead.List, method) {
var keyBytes []byte
if key != "" {
kb, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, E.Cause(err, "decode key")
}
keyBytes = kb
}
return shadowaead.New(method, keyBytes, password, secureRNG)
} else if common.Contains(shadowaead_2022.List, method) {
var pskList [][]byte
if key != "" {
keyStrList := strings.Split(key, ":")
pskList = make([][]byte, len(keyStrList))
for i, keyStr := range keyStrList {
kb, err := base64.StdEncoding.DecodeString(keyStr)
if err != nil {
return nil, E.Cause(err, "decode key")
}
pskList[i] = kb
}
}
return shadowaead_2022.New(method, pskList, password, secureRNG)
} else {
return nil, E.New("shadowsocks: unsupported method ", method)
}
}

View file

@ -53,7 +53,7 @@ type Method struct {
secureRNG io.Reader secureRNG io.Reader
} }
func New(method string, key []byte, password []byte, secureRNG io.Reader) (shadowsocks.Method, error) { func New(method string, key []byte, password string, secureRNG io.Reader) (shadowsocks.Method, error) {
m := &Method{ m := &Method{
name: method, name: method,
secureRNG: secureRNG, secureRNG: secureRNG,
@ -167,11 +167,11 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader) (shado
if len(key) == m.keyLength { if len(key) == m.keyLength {
m.key = key m.key = key
} else if len(key) > 0 { } else if len(key) > 0 {
return nil, shadowaead.ErrBadKey return nil, shadowsocks.ErrBadKey
} else if len(password) > 0 { } else if password != "" {
m.key = shadowsocks.Key(password, m.keyLength) m.key = shadowsocks.Key([]byte(password), m.keyLength)
} else { } else {
return nil, shadowaead.ErrMissingPassword return nil, shadowsocks.ErrMissingPassword
} }
return m, nil return m, nil
} }