diff --git a/cli/ss-local/main.go b/cli/ss-local/main.go index 680f419..afe930c 100644 --- a/cli/ss-local/main.go +++ b/cli/ss-local/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "encoding/base64" "encoding/json" "io" "io/ioutil" @@ -32,6 +31,7 @@ import ( "github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022" + "github.com/sagernet/sing/protocol/shadowsocks/shadowimpl" "github.com/sagernet/sing/protocol/shadowsocks/shadowstream" "github.com/sagernet/sing/transport/mixed" "github.com/sagernet/sing/transport/system" @@ -182,53 +182,11 @@ func newClient(f *flags) (*client, error) { if f.ReducedSaltEntropy { rng = &shadowsocks.ReducedEntropyReader{Reader: rng} } - if common.Contains(shadowstream.List, f.Method) { - var key []byte - if f.Key != "" { - kb, err := base64.StdEncoding.DecodeString(f.Key) - if err != nil { - return nil, E.Cause(err, "decode key") - } - key = kb - } - method, err := shadowstream.New(f.Method, key, []byte(f.Password), rng) - if err != nil { - return nil, err - } - c.method = method - } else if common.Contains(shadowaead.List, f.Method) { - var key []byte - if f.Key != "" { - kb, err := base64.StdEncoding.DecodeString(f.Key) - if err != nil { - return nil, E.Cause(err, "decode key") - } - key = kb - } - method, err := shadowaead.New(f.Method, key, []byte(f.Password), rng) - if err != nil { - return nil, err - } - c.method = method - } else if common.Contains(shadowaead_2022.List, f.Method) { - var pskList [][]byte - if f.Key != "" { - keyStrList := strings.Split(f.Key, ":") - pskList = make([][]byte, len(keyStrList)) - for i, keyStr := range keyStrList { - kb, err := base64.StdEncoding.DecodeString(keyStr) - if err != nil { - return nil, E.Cause(err, "decode key") - } - pskList[i] = kb - } - } - method, err := shadowaead_2022.New(f.Method, pskList, rng) - if err != nil { - return nil, err - } - c.method = method + method, err := shadowimpl.FetchMethod(f.Method, f.Key, f.Password, rng) + if err != nil { + return nil, err } + c.method = method } c.dialer.Control = func(network, address string, c syscall.RawConn) error { diff --git a/cli/ss-server/main.go b/cli/ss-server/main.go index 210cc1c..6a2bfe0 100644 --- a/cli/ss-server/main.go +++ b/cli/ss-server/main.go @@ -143,7 +143,7 @@ func newServer(f *flags) (*server, error) { } key = kb } - service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Default, udpTimeout, s) + service, err := shadowaead.NewService(f.Method, key, f.Password, random.Default, udpTimeout, s) if err != nil { return nil, err } @@ -157,7 +157,7 @@ func newServer(f *flags) (*server, error) { } key = kb } - service, err := shadowaead_2022.NewService(f.Method, key, random.Default, udpTimeout, s) + service, err := shadowaead_2022.NewService(f.Method, key, f.Password, random.Default, udpTimeout, s) if err != nil { return nil, err } diff --git a/protocol/shadowsocks/protocol.go b/protocol/shadowsocks/protocol.go index 6f40c4a..50a0d5a 100644 --- a/protocol/shadowsocks/protocol.go +++ b/protocol/shadowsocks/protocol.go @@ -7,10 +7,16 @@ import ( "math/rand" "net" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) +var ( + ErrBadKey = E.New("shadowsocks: bad key") + ErrMissingPassword = E.New("shadowsocks: missing password") +) + type Method interface { Name() string KeyLength() int diff --git a/protocol/shadowsocks/shadowaead/protocol.go b/protocol/shadowsocks/shadowaead/protocol.go index fd4e64d..8458f2a 100644 --- a/protocol/shadowsocks/shadowaead/protocol.go +++ b/protocol/shadowsocks/shadowaead/protocol.go @@ -27,12 +27,7 @@ var List = []string{ "xchacha20-ietf-poly1305", } -var ( - ErrBadKey = E.New("shadowsocks: bad key") - ErrMissingPassword = E.New("shadowsocks: missing password") -) - -func New(method string, key []byte, password []byte, secureRNG io.Reader) (shadowsocks.Method, error) { +func New(method string, key []byte, password string, secureRNG io.Reader) (shadowsocks.Method, error) { m := &Method{ name: method, secureRNG: secureRNG, @@ -65,11 +60,11 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader) (shado if len(key) == m.keySaltLength { m.key = key } else if len(key) > 0 { - return nil, ErrBadKey - } else if len(password) > 0 { - m.key = shadowsocks.Key(password, m.keySaltLength) + return nil, shadowsocks.ErrBadKey + } else if password == "" { + return nil, shadowsocks.ErrMissingPassword } else { - return nil, ErrMissingPassword + m.key = shadowsocks.Key([]byte(password), m.keySaltLength) } return m, nil } @@ -181,12 +176,10 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error { type clientConn struct { net.Conn - method *Method destination M.Socksaddr - - reader *Reader - writer *Writer + reader *Reader + writer *Writer } func (c *clientConn) writeRequest(payload []byte) error { diff --git a/protocol/shadowsocks/shadowaead/service.go b/protocol/shadowsocks/shadowaead/service.go index 14b7764..dd47ce9 100644 --- a/protocol/shadowsocks/shadowaead/service.go +++ b/protocol/shadowsocks/shadowaead/service.go @@ -30,7 +30,7 @@ type Service struct { udpNat *udpnat.Service[netip.AddrPort] } -func NewService(method string, key []byte, password []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { +func NewService(method string, key []byte, password string, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { s := &Service{ name: method, secureRNG: secureRNG, @@ -65,11 +65,11 @@ func NewService(method string, key []byte, password []byte, secureRNG io.Reader, if len(key) == s.keySaltLength { s.key = key } else if len(key) > 0 { - return nil, ErrBadKey - } else if len(password) > 0 { - s.key = shadowsocks.Key(password, s.keySaltLength) + return nil, shadowsocks.ErrBadKey + } else if password != "" { + s.key = shadowsocks.Key([]byte(password), s.keySaltLength) } else { - return nil, ErrMissingPassword + return nil, shadowsocks.ErrMissingPassword } return s, nil } diff --git a/protocol/shadowsocks/shadowaead_2022/protocol.go b/protocol/shadowsocks/shadowaead_2022/protocol.go index 842b8a8..b56af92 100644 --- a/protocol/shadowsocks/shadowaead_2022/protocol.go +++ b/protocol/shadowsocks/shadowaead_2022/protocol.go @@ -4,11 +4,13 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "crypto/sha256" "encoding/binary" "io" "math" "math/rand" "net" + "os" "runtime" "sync/atomic" "time" @@ -31,7 +33,6 @@ const ( HeaderTypeClient = 0 HeaderTypeServer = 1 MaxPaddingLength = 900 - SaltSize = 32 PacketNonceSize = 24 MaxPacketSize = 65535 ) @@ -48,6 +49,7 @@ const ( ) var ( + ErrMissingPasswordPSK = E.New("shadowsocks: missing password or psk") ErrBadHeaderType = E.New("shadowsocks: bad header type") ErrBadTimestamp = E.New("shadowsocks: bad timestamp") ErrBadRequestSalt = E.New("shadowsocks: bad request salt") @@ -62,7 +64,7 @@ var List = []string{ "2022-blake3-chacha20-poly1305", } -func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Method, error) { +func New(method string, pskList [][]byte, password string, secureRNG io.Reader) (shadowsocks.Method, error) { m := &Method{ name: method, pskList: pskList, @@ -72,27 +74,35 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth switch method { case "2022-blake3-aes-128-gcm": - m.keyLength = 16 + m.keySaltLength = 16 m.constructor = newAESGCM m.blockConstructor = newAES case "2022-blake3-aes-256-gcm": - m.keyLength = 32 + m.keySaltLength = 32 m.constructor = newAESGCM m.blockConstructor = newAES case "2022-blake3-chacha20-poly1305": - m.keyLength = 32 + if len(pskList) > 1 { + return nil, os.ErrInvalid + } + m.keySaltLength = 32 m.constructor = newChacha20Poly1305 } - for i, psk := range pskList { - if len(psk) < m.keyLength { - return nil, shadowaead.ErrBadKey - } else if len(psk) > m.keyLength { - pskList[i] = DerivePSK(psk, m.keyLength) + if len(pskList) == 0 { + if password == "" { + return nil, ErrMissingPasswordPSK } + pskList = [][]byte{Key([]byte(password), m.keySaltLength)} } - m.psk = pskList[len(pskList)-1] + for i, psk := range pskList { + if len(psk) < m.keySaltLength { + return nil, shadowsocks.ErrBadKey + } else if len(psk) > m.keySaltLength { + pskList[i] = Key(psk, m.keySaltLength) + } + } if len(pskList) > 1 { pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize) @@ -112,19 +122,18 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth case "2022-blake3-aes-256-gcm": m.udpBlockCipher = newAES(pskList[0]) case "2022-blake3-chacha20-poly1305": - m.udpCipher = newXChacha20Poly1305(m.psk) + m.udpCipher = newXChacha20Poly1305(pskList[0]) } return m, nil } -func DerivePSK(key []byte, keyLength int) []byte { - outKey := buf.Make(keyLength) - blake3.DeriveKey(outKey, "shadowsocks 2022 pre shared key", key) - return outKey +func Key(key []byte, keyLength int) []byte { + psk := sha256.Sum256(key) + return psk[:keyLength] } -func DeriveSessionKey(psk []byte, salt []byte, keyLength int) []byte { +func SessionKey(psk []byte, salt []byte, keyLength int) []byte { sessionKey := buf.Make(len(psk) + len(salt)) copy(sessionKey, psk) copy(sessionKey[len(psk):], salt) @@ -161,12 +170,11 @@ func newXChacha20Poly1305(key []byte) cipher.AEAD { type Method struct { name string - keyLength int + keySaltLength int constructor func(key []byte) cipher.AEAD blockConstructor func(key []byte) cipher.Block udpCipher cipher.AEAD udpBlockCipher cipher.Block - psk []byte pskList [][]byte pskHash []byte secureRNG io.Reader @@ -178,13 +186,13 @@ func (m *Method) Name() string { } func (m *Method) KeyLength() int { - return m.keyLength + return m.keySaltLength } func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { shadowsocksConn := &clientConn{ + Method: m, Conn: conn, - method: m, destination: destination, } return shadowsocksConn, shadowsocksConn.writeRequest(nil) @@ -192,26 +200,23 @@ func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, err func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { return &clientConn{ + Method: m, Conn: conn, - method: m, destination: destination, } } func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { - return &clientPacketConn{conn, m, m.newUDPSession()} + return &clientPacketConn{m, conn, m.newUDPSession()} } type clientConn struct { + *Method net.Conn - - method *Method destination M.Socksaddr - requestSalt []byte - - reader *shadowaead.Reader - writer *shadowaead.Writer + reader *shadowaead.Reader + writer *shadowaead.Writer } func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) { @@ -220,10 +225,10 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) return } for i, psk := range m.pskList { - keyMaterial := buf.Make(m.keyLength + SaltSize) + keyMaterial := buf.Make(m.keySaltLength * 2) copy(keyMaterial, psk) - copy(keyMaterial[m.keyLength:], salt) - _identitySubkey := buf.Make(m.keyLength) + copy(keyMaterial[m.keySaltLength:], salt) + _identitySubkey := buf.Make(m.keySaltLength) identitySubkey := common.Dup(_identitySubkey) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) @@ -239,20 +244,20 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) } func (c *clientConn) writeRequest(payload []byte) error { - salt := make([]byte, SaltSize) - common.Must1(io.ReadFull(c.method.secureRNG, salt)) + salt := buf.Make(c.keySaltLength) + common.Must1(io.ReadFull(c.secureRNG, salt)) - key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength) + key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength) writer := shadowaead.NewWriter( c.Conn, - c.method.constructor(common.Dup(key)), + c.constructor(common.Dup(key)), MaxPacketSize, ) runtime.KeepAlive(key) header := writer.Buffer() header.Write(salt) - c.method.writeExtendedIdentityHeaders(header, salt) + c.writeExtendedIdentityHeaders(header, salt) bufferedWriter := writer.BufferedWriter(header.Len()) @@ -279,7 +284,7 @@ func (c *clientConn) writeRequest(payload []byte) error { if err != nil { return E.Cause(err, "write padding length") } - _, err = io.CopyN(bufferedWriter, c.method.secureRNG, int64(pLen)) + _, err = io.CopyN(bufferedWriter, c.secureRNG, int64(pLen)) if err != nil { return E.Cause(err, "write padding") } @@ -300,22 +305,22 @@ func (c *clientConn) readResponse() error { return nil } - _salt := make([]byte, SaltSize) + _salt := buf.Make(c.keySaltLength) salt := common.Dup(_salt) _, err := io.ReadFull(c.Conn, salt) if err != nil { return err } - if !c.method.replayFilter.Check(salt) { + if !c.replayFilter.Check(salt) { return E.New("salt not unique") } - key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength) + key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength) runtime.KeepAlive(_salt) reader := shadowaead.NewReader( c.Conn, - c.method.constructor(common.Dup(key)), + c.constructor(common.Dup(key)), MaxPacketSize, ) runtime.KeepAlive(key) @@ -339,7 +344,7 @@ func (c *clientConn) readResponse() error { return ErrBadTimestamp } - _requestSalt := make([]byte, SaltSize) + _requestSalt := buf.Make(c.keySaltLength) requestSalt := common.Dup(_requestSalt) _, err = io.ReadFull(reader, requestSalt) if err != nil { @@ -412,19 +417,19 @@ func (c *clientConn) WriterReplaceable() bool { } type clientPacketConn struct { + *Method net.Conn - method *Method session *udpSession } func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { var hdrLen int - if c.method.udpCipher != nil { + if c.udpCipher != nil { hdrLen = PacketNonceSize } hdrLen += 16 // packet header - pskLen := len(c.method.pskList) - if c.method.udpCipher == nil && pskLen > 1 { + pskLen := len(c.pskList) + if c.udpCipher == nil && pskLen > 1 { hdrLen += (pskLen - 1) * aes.BlockSize } hdrLen += 1 // header type @@ -434,8 +439,8 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad header := buf.With(buffer.ExtendHeader(hdrLen)) var dataIndex int - if c.method.udpCipher != nil { - common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize)) + if c.udpCipher != nil { + common.Must1(header.ReadFullFrom(c.secureRNG, PacketNonceSize)) if pskLen > 1 { panic("unsupported chacha extended header") } @@ -449,16 +454,16 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad binary.Write(header, binary.BigEndian, c.session.nextPacketId()), ) - if c.method.udpCipher == nil && pskLen > 1 { - for i, psk := range c.method.pskList { + if c.udpCipher == nil && pskLen > 1 { + for i, psk := range c.pskList { dataIndex += aes.BlockSize - pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] + pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] identityHeader := header.Extend(aes.BlockSize) for textI := 0; textI < aes.BlockSize; textI++ { identityHeader[textI] = pskHash[textI] ^ header.Byte(textI) } - c.method.blockConstructor(psk).Encrypt(identityHeader, identityHeader) + c.blockConstructor(psk).Encrypt(identityHeader, identityHeader) if i == pskLen-2 { break @@ -477,14 +482,14 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad if err != nil { return err } - if c.method.udpCipher != nil { - c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) - buffer.Extend(c.method.udpCipher.Overhead()) + if c.udpCipher != nil { + c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) + buffer.Extend(c.udpCipher.Overhead()) } else { packetHeader := buffer.To(aes.BlockSize) c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) buffer.Extend(c.session.cipher.Overhead()) - c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader) + c.udpBlockCipher.Encrypt(packetHeader, packetHeader) } return common.Error(c.Write(buffer.Bytes())) } @@ -497,16 +502,16 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { buffer.Truncate(n) var packetHeader []byte - if c.method.udpCipher != nil { - _, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) + if c.udpCipher != nil { + _, err = c.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) if err != nil { return M.Socksaddr{}, E.Cause(err, "decrypt packet") } buffer.Advance(PacketNonceSize) - buffer.Truncate(buffer.Len() - c.method.udpCipher.Overhead()) + buffer.Truncate(buffer.Len() - c.udpCipher.Overhead()) } else { packetHeader = buffer.To(aes.BlockSize) - c.method.udpBlockCipher.Decrypt(packetHeader, packetHeader) + c.udpBlockCipher.Decrypt(packetHeader, packetHeader) } var sessionId, packetId uint64 @@ -526,8 +531,8 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { } else if sessionId == c.session.lastRemoteSessionId { remoteCipher = c.session.lastRemoteCipher } else { - key := DeriveSessionKey(c.method.psk, packetHeader[:8], c.method.keyLength) - remoteCipher = c.method.constructor(common.Dup(key)) + key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength) + remoteCipher = c.constructor(common.Dup(key)) runtime.KeepAlive(key) } _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) @@ -622,14 +627,14 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { destination := M.SocksaddrFromNet(addr) var overHead int - if c.method.udpCipher != nil { - overHead = PacketNonceSize + c.method.udpCipher.Overhead() + if c.udpCipher != nil { + overHead = PacketNonceSize + c.udpCipher.Overhead() } else { overHead = c.session.cipher.Overhead() } overHead += 16 // packet header - pskLen := len(c.method.pskList) - if c.method.udpCipher == nil && pskLen > 1 { + pskLen := len(c.pskList) + if c.udpCipher == nil && pskLen > 1 { overHead += (pskLen - 1) * aes.BlockSize } overHead += 1 // header type @@ -642,8 +647,8 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { buffer := buf.With(common.Dup(_buffer)) var dataIndex int - if c.method.udpCipher != nil { - common.Must1(buffer.ReadFullFrom(c.method.secureRNG, PacketNonceSize)) + if c.udpCipher != nil { + common.Must1(buffer.ReadFullFrom(c.secureRNG, PacketNonceSize)) if pskLen > 1 { panic("unsupported chacha extended header") } @@ -657,16 +662,16 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()), ) - if c.method.udpCipher == nil && pskLen > 1 { - for i, psk := range c.method.pskList { + if c.udpCipher == nil && pskLen > 1 { + for i, psk := range c.pskList { dataIndex += aes.BlockSize - pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] + pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] identityHeader := buffer.Extend(aes.BlockSize) for textI := 0; textI < aes.BlockSize; textI++ { identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI) } - c.method.blockConstructor(psk).Encrypt(identityHeader, identityHeader) + c.blockConstructor(psk).Encrypt(identityHeader, identityHeader) if i == pskLen-2 { break @@ -685,14 +690,14 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if err != nil { return } - if c.method.udpCipher != nil { - c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) - buffer.Extend(c.method.udpCipher.Overhead()) + if c.udpCipher != nil { + c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) + buffer.Extend(c.udpCipher.Overhead()) } else { packetHeader := buffer.To(aes.BlockSize) c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) buffer.Extend(c.session.cipher.Overhead()) - c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader) + c.udpBlockCipher.Encrypt(packetHeader, packetHeader) } err = common.Error(c.Write(buffer.Bytes())) if err != nil { @@ -726,7 +731,7 @@ func (m *Method) newUDPSession() *udpSession { if m.udpCipher == nil { sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) - key := DeriveSessionKey(m.psk, sessionId, m.keyLength) + key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength) session.cipher = m.constructor(common.Dup(key)) runtime.KeepAlive(key) } diff --git a/protocol/shadowsocks/shadowaead_2022/service.go b/protocol/shadowsocks/shadowaead_2022/service.go index e1fcf9b..28bc642 100644 --- a/protocol/shadowsocks/shadowaead_2022/service.go +++ b/protocol/shadowsocks/shadowaead_2022/service.go @@ -30,7 +30,7 @@ import ( type Service struct { name string secureRNG io.Reader - keyLength int + keySaltLength int constructor func(key []byte) cipher.AEAD blockConstructor func(key []byte) cipher.Block udpCipher cipher.AEAD @@ -42,7 +42,7 @@ type Service struct { sessions *cache.LruCache[uint64, *serverUDPSession] } -func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { +func NewService(method string, psk []byte, password string, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { s := &Service{ name: method, secureRNG: secureRNG, @@ -57,25 +57,31 @@ func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64 switch method { case "2022-blake3-aes-128-gcm": - s.keyLength = 16 + s.keySaltLength = 16 s.constructor = newAESGCM s.blockConstructor = newAES case "2022-blake3-aes-256-gcm": - s.keyLength = 32 + s.keySaltLength = 32 s.constructor = newAESGCM s.blockConstructor = newAES case "2022-blake3-chacha20-poly1305": - s.keyLength = 32 + s.keySaltLength = 32 s.constructor = newChacha20Poly1305 } - if len(psk) < s.keyLength { - return nil, shadowaead.ErrBadKey - } else if len(psk) > s.keyLength { - psk = DerivePSK(psk, s.keyLength) + if len(psk) == s.keySaltLength { + s.psk = psk + } else if len(psk) != 0 { + if len(psk) < s.keySaltLength { + return nil, shadowsocks.ErrBadKey + } + s.psk = Key(psk, s.keySaltLength) + } else if password == "" { + return nil, ErrMissingPasswordPSK + } else { + s.psk = Key([]byte(password), s.keySaltLength) } - s.psk = psk switch method { case "2022-blake3-aes-128-gcm": s.udpBlockCipher = newAES(psk) @@ -97,7 +103,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M } func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - requestSalt := make([]byte, SaltSize) + requestSalt := buf.Make(s.keySaltLength) _, err := io.ReadFull(conn, requestSalt) if err != nil { return E.Cause(err, "read request salt") @@ -107,7 +113,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M return E.New("salt not unique") } - requestKey := DeriveSessionKey(s.psk, requestSalt, s.keyLength) + requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength) reader := shadowaead.NewReader( conn, s.constructor(common.Dup(requestKey)), @@ -175,10 +181,10 @@ type serverConn struct { } func (c *serverConn) writeResponse(payload []byte) (n int, err error) { - var _salt [SaltSize]byte + _salt := buf.Make(c.keySaltLength) salt := common.Dup(_salt[:]) common.Must1(io.ReadFull(c.secureRNG, salt)) - key := DeriveSessionKey(c.uPSK, salt, c.keyLength) + key := SessionKey(c.uPSK, salt, c.keySaltLength) runtime.KeepAlive(_salt) writer := shadowaead.NewWriter( c.Conn, @@ -294,7 +300,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me if !loaded { session.remoteSessionId = sessionId if packetHeader != nil { - key := DeriveSessionKey(s.psk, packetHeader[:8], s.keyLength) + key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength) session.remoteCipher = s.constructor(common.Dup(key)) runtime.KeepAlive(key) } @@ -439,7 +445,7 @@ func (m *Service) newUDPSession() *serverUDPSession { if m.udpCipher == nil { sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) - key := DeriveSessionKey(m.psk, sessionId, m.keyLength) + key := SessionKey(m.psk, sessionId, m.keySaltLength) session.cipher = m.constructor(common.Dup(key)) runtime.KeepAlive(key) } diff --git a/protocol/shadowsocks/shadowaead_2022/service_multi.go b/protocol/shadowsocks/shadowaead_2022/service_multi.go index 9e2880d..ac5a7f4 100644 --- a/protocol/shadowsocks/shadowaead_2022/service_multi.go +++ b/protocol/shadowsocks/shadowaead_2022/service_multi.go @@ -30,10 +30,10 @@ type MultiService[U comparable] struct { } func (s *MultiService[U]) AddUser(user U, key []byte) error { - if len(key) < s.keyLength { - return shadowaead.ErrBadKey - } else if len(key) > s.keyLength { - key = DerivePSK(key, s.keyLength) + if len(key) < s.keySaltLength { + return shadowsocks.ErrBadKey + } else if len(key) > s.keySaltLength { + key = Key(key, s.keySaltLength) } var uPSKHash [aes.BlockSize]byte @@ -67,7 +67,7 @@ func NewMultiService[U comparable](method string, iPSK []byte, secureRNG io.Read return nil, E.New("unsupported method ", method) } - ss, err := NewService(method, iPSK, secureRNG, udpTimeout, handler) + ss, err := NewService(method, iPSK, "", secureRNG, udpTimeout, handler) if err != nil { return nil, err } @@ -91,7 +91,7 @@ func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, meta } func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - requestSalt := make([]byte, SaltSize) + requestSalt := make([]byte, s.keySaltLength) _, err := io.ReadFull(conn, requestSalt) if err != nil { return E.Cause(err, "read request salt") @@ -108,10 +108,10 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta return E.Cause(err, "read extended identity header") } - keyMaterial := buf.Make(s.keyLength + SaltSize) + keyMaterial := buf.Make(s.keySaltLength * 2) copy(keyMaterial, s.psk) - copy(keyMaterial[s.keyLength:], requestSalt) - _identitySubkey := buf.Make(s.keyLength) + copy(keyMaterial[s.keySaltLength:], requestSalt) + _identitySubkey := buf.Make(s.keySaltLength) identitySubkey := common.Dup(_identitySubkey) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader) @@ -126,7 +126,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta return E.New("invalid request") } - requestKey := DeriveSessionKey(uPSK, requestSalt, s.keyLength) + requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength) reader := shadowaead.NewReader( conn, s.constructor(common.Dup(requestKey)), @@ -230,7 +230,7 @@ func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metad }) if !loaded { session.remoteSessionId = sessionId - key := DeriveSessionKey(uPSK, packetHeader[:8], s.keyLength) + key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength) session.remoteCipher = s.constructor(common.Dup(key)) runtime.KeepAlive(key) } @@ -312,7 +312,7 @@ func (m *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession { session.packetId-- sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) - key := DeriveSessionKey(uPSK, sessionId, m.keyLength) + key := SessionKey(uPSK, sessionId, m.keySaltLength) session.cipher = m.constructor(common.Dup(key)) runtime.KeepAlive(key) return session diff --git a/protocol/shadowsocks/shadowaead_2022/service_multi_test.go b/protocol/shadowsocks/shadowaead_2022/service_multi_test.go index 16d6648..fb04dba 100644 --- a/protocol/shadowsocks/shadowaead_2022/service_multi_test.go +++ b/protocol/shadowsocks/shadowaead_2022/service_multi_test.go @@ -29,7 +29,7 @@ func TestMultiService(t *testing.T) { random.Default.Read(uPSK[:]) multiService.AddUser("my user", uPSK[:]) - client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, random.Default) + client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, "", random.Default) if err != nil { t.Fatal(err) } diff --git a/protocol/shadowsocks/shadowimpl/fetcher.go b/protocol/shadowsocks/shadowimpl/fetcher.go new file mode 100644 index 0000000..11753ba --- /dev/null +++ b/protocol/shadowsocks/shadowimpl/fetcher.go @@ -0,0 +1,56 @@ +package shadowimpl + +import ( + "encoding/base64" + "io" + "strings" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/protocol/shadowsocks" + "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" + "github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022" + "github.com/sagernet/sing/protocol/shadowsocks/shadowstream" +) + +func FetchMethod(method string, key string, password string, secureRNG io.Reader) (shadowsocks.Method, error) { + if method == "none" { + return shadowsocks.NewNone(), nil + } else if common.Contains(shadowstream.List, method) { + var keyBytes []byte + if key != "" { + kb, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return nil, E.Cause(err, "decode key") + } + keyBytes = kb + } + return shadowstream.New(method, keyBytes, password, secureRNG) + } else if common.Contains(shadowaead.List, method) { + var keyBytes []byte + if key != "" { + kb, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return nil, E.Cause(err, "decode key") + } + keyBytes = kb + } + return shadowaead.New(method, keyBytes, password, secureRNG) + } else if common.Contains(shadowaead_2022.List, method) { + var pskList [][]byte + if key != "" { + keyStrList := strings.Split(key, ":") + pskList = make([][]byte, len(keyStrList)) + for i, keyStr := range keyStrList { + kb, err := base64.StdEncoding.DecodeString(keyStr) + if err != nil { + return nil, E.Cause(err, "decode key") + } + pskList[i] = kb + } + } + return shadowaead_2022.New(method, pskList, password, secureRNG) + } else { + return nil, E.New("shadowsocks: unsupported method ", method) + } +} diff --git a/protocol/shadowsocks/shadowstream/protocol.go b/protocol/shadowsocks/shadowstream/protocol.go index 35ef971..c683b03 100644 --- a/protocol/shadowsocks/shadowstream/protocol.go +++ b/protocol/shadowsocks/shadowstream/protocol.go @@ -53,7 +53,7 @@ type Method struct { secureRNG io.Reader } -func New(method string, key []byte, password []byte, secureRNG io.Reader) (shadowsocks.Method, error) { +func New(method string, key []byte, password string, secureRNG io.Reader) (shadowsocks.Method, error) { m := &Method{ name: method, secureRNG: secureRNG, @@ -167,11 +167,11 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader) (shado if len(key) == m.keyLength { m.key = key } else if len(key) > 0 { - return nil, shadowaead.ErrBadKey - } else if len(password) > 0 { - m.key = shadowsocks.Key(password, m.keyLength) + return nil, shadowsocks.ErrBadKey + } else if password != "" { + m.key = shadowsocks.Key([]byte(password), m.keyLength) } else { - return nil, shadowaead.ErrMissingPassword + return nil, shadowsocks.ErrMissingPassword } return m, nil }