mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 12:27:37 +03:00
Add password support for shadowsocks 2022 ciphers
This commit is contained in:
parent
f1a5f8aaa3
commit
2aae93c5b8
11 changed files with 204 additions and 180 deletions
|
@ -2,7 +2,6 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -32,6 +31,7 @@ import (
|
||||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
|
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
|
||||||
|
"github.com/sagernet/sing/protocol/shadowsocks/shadowimpl"
|
||||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowstream"
|
"github.com/sagernet/sing/protocol/shadowsocks/shadowstream"
|
||||||
"github.com/sagernet/sing/transport/mixed"
|
"github.com/sagernet/sing/transport/mixed"
|
||||||
"github.com/sagernet/sing/transport/system"
|
"github.com/sagernet/sing/transport/system"
|
||||||
|
@ -182,53 +182,11 @@ func newClient(f *flags) (*client, error) {
|
||||||
if f.ReducedSaltEntropy {
|
if f.ReducedSaltEntropy {
|
||||||
rng = &shadowsocks.ReducedEntropyReader{Reader: rng}
|
rng = &shadowsocks.ReducedEntropyReader{Reader: rng}
|
||||||
}
|
}
|
||||||
if common.Contains(shadowstream.List, f.Method) {
|
method, err := shadowimpl.FetchMethod(f.Method, f.Key, f.Password, rng)
|
||||||
var key []byte
|
if err != nil {
|
||||||
if f.Key != "" {
|
return nil, err
|
||||||
kb, err := base64.StdEncoding.DecodeString(f.Key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "decode key")
|
|
||||||
}
|
|
||||||
key = kb
|
|
||||||
}
|
|
||||||
method, err := shadowstream.New(f.Method, key, []byte(f.Password), rng)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.method = method
|
|
||||||
} else if common.Contains(shadowaead.List, f.Method) {
|
|
||||||
var key []byte
|
|
||||||
if f.Key != "" {
|
|
||||||
kb, err := base64.StdEncoding.DecodeString(f.Key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "decode key")
|
|
||||||
}
|
|
||||||
key = kb
|
|
||||||
}
|
|
||||||
method, err := shadowaead.New(f.Method, key, []byte(f.Password), rng)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.method = method
|
|
||||||
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
|
||||||
var pskList [][]byte
|
|
||||||
if f.Key != "" {
|
|
||||||
keyStrList := strings.Split(f.Key, ":")
|
|
||||||
pskList = make([][]byte, len(keyStrList))
|
|
||||||
for i, keyStr := range keyStrList {
|
|
||||||
kb, err := base64.StdEncoding.DecodeString(keyStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "decode key")
|
|
||||||
}
|
|
||||||
pskList[i] = kb
|
|
||||||
}
|
|
||||||
}
|
|
||||||
method, err := shadowaead_2022.New(f.Method, pskList, rng)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.method = method
|
|
||||||
}
|
}
|
||||||
|
c.method = method
|
||||||
}
|
}
|
||||||
|
|
||||||
c.dialer.Control = func(network, address string, c syscall.RawConn) error {
|
c.dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||||
|
|
|
@ -143,7 +143,7 @@ func newServer(f *flags) (*server, error) {
|
||||||
}
|
}
|
||||||
key = kb
|
key = kb
|
||||||
}
|
}
|
||||||
service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Default, udpTimeout, s)
|
service, err := shadowaead.NewService(f.Method, key, f.Password, random.Default, udpTimeout, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -157,7 +157,7 @@ func newServer(f *flags) (*server, error) {
|
||||||
}
|
}
|
||||||
key = kb
|
key = kb
|
||||||
}
|
}
|
||||||
service, err := shadowaead_2022.NewService(f.Method, key, random.Default, udpTimeout, s)
|
service, err := shadowaead_2022.NewService(f.Method, key, f.Password, random.Default, udpTimeout, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,10 +7,16 @@ import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBadKey = E.New("shadowsocks: bad key")
|
||||||
|
ErrMissingPassword = E.New("shadowsocks: missing password")
|
||||||
|
)
|
||||||
|
|
||||||
type Method interface {
|
type Method interface {
|
||||||
Name() string
|
Name() string
|
||||||
KeyLength() int
|
KeyLength() int
|
||||||
|
|
|
@ -27,12 +27,7 @@ var List = []string{
|
||||||
"xchacha20-ietf-poly1305",
|
"xchacha20-ietf-poly1305",
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
func New(method string, key []byte, password string, secureRNG io.Reader) (shadowsocks.Method, error) {
|
||||||
ErrBadKey = E.New("shadowsocks: bad key")
|
|
||||||
ErrMissingPassword = E.New("shadowsocks: missing password")
|
|
||||||
)
|
|
||||||
|
|
||||||
func New(method string, key []byte, password []byte, secureRNG io.Reader) (shadowsocks.Method, error) {
|
|
||||||
m := &Method{
|
m := &Method{
|
||||||
name: method,
|
name: method,
|
||||||
secureRNG: secureRNG,
|
secureRNG: secureRNG,
|
||||||
|
@ -65,11 +60,11 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader) (shado
|
||||||
if len(key) == m.keySaltLength {
|
if len(key) == m.keySaltLength {
|
||||||
m.key = key
|
m.key = key
|
||||||
} else if len(key) > 0 {
|
} else if len(key) > 0 {
|
||||||
return nil, ErrBadKey
|
return nil, shadowsocks.ErrBadKey
|
||||||
} else if len(password) > 0 {
|
} else if password == "" {
|
||||||
m.key = shadowsocks.Key(password, m.keySaltLength)
|
return nil, shadowsocks.ErrMissingPassword
|
||||||
} else {
|
} else {
|
||||||
return nil, ErrMissingPassword
|
m.key = shadowsocks.Key([]byte(password), m.keySaltLength)
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
@ -181,12 +176,10 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error {
|
||||||
|
|
||||||
type clientConn struct {
|
type clientConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
|
|
||||||
method *Method
|
method *Method
|
||||||
destination M.Socksaddr
|
destination M.Socksaddr
|
||||||
|
reader *Reader
|
||||||
reader *Reader
|
writer *Writer
|
||||||
writer *Writer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) writeRequest(payload []byte) error {
|
func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
|
|
|
@ -30,7 +30,7 @@ type Service struct {
|
||||||
udpNat *udpnat.Service[netip.AddrPort]
|
udpNat *udpnat.Service[netip.AddrPort]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(method string, key []byte, password []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
func NewService(method string, key []byte, password string, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||||
s := &Service{
|
s := &Service{
|
||||||
name: method,
|
name: method,
|
||||||
secureRNG: secureRNG,
|
secureRNG: secureRNG,
|
||||||
|
@ -65,11 +65,11 @@ func NewService(method string, key []byte, password []byte, secureRNG io.Reader,
|
||||||
if len(key) == s.keySaltLength {
|
if len(key) == s.keySaltLength {
|
||||||
s.key = key
|
s.key = key
|
||||||
} else if len(key) > 0 {
|
} else if len(key) > 0 {
|
||||||
return nil, ErrBadKey
|
return nil, shadowsocks.ErrBadKey
|
||||||
} else if len(password) > 0 {
|
} else if password != "" {
|
||||||
s.key = shadowsocks.Key(password, s.keySaltLength)
|
s.key = shadowsocks.Key([]byte(password), s.keySaltLength)
|
||||||
} else {
|
} else {
|
||||||
return nil, ErrMissingPassword
|
return nil, shadowsocks.ErrMissingPassword
|
||||||
}
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,11 +4,13 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
@ -31,7 +33,6 @@ const (
|
||||||
HeaderTypeClient = 0
|
HeaderTypeClient = 0
|
||||||
HeaderTypeServer = 1
|
HeaderTypeServer = 1
|
||||||
MaxPaddingLength = 900
|
MaxPaddingLength = 900
|
||||||
SaltSize = 32
|
|
||||||
PacketNonceSize = 24
|
PacketNonceSize = 24
|
||||||
MaxPacketSize = 65535
|
MaxPacketSize = 65535
|
||||||
)
|
)
|
||||||
|
@ -48,6 +49,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
ErrMissingPasswordPSK = E.New("shadowsocks: missing password or psk")
|
||||||
ErrBadHeaderType = E.New("shadowsocks: bad header type")
|
ErrBadHeaderType = E.New("shadowsocks: bad header type")
|
||||||
ErrBadTimestamp = E.New("shadowsocks: bad timestamp")
|
ErrBadTimestamp = E.New("shadowsocks: bad timestamp")
|
||||||
ErrBadRequestSalt = E.New("shadowsocks: bad request salt")
|
ErrBadRequestSalt = E.New("shadowsocks: bad request salt")
|
||||||
|
@ -62,7 +64,7 @@ var List = []string{
|
||||||
"2022-blake3-chacha20-poly1305",
|
"2022-blake3-chacha20-poly1305",
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Method, error) {
|
func New(method string, pskList [][]byte, password string, secureRNG io.Reader) (shadowsocks.Method, error) {
|
||||||
m := &Method{
|
m := &Method{
|
||||||
name: method,
|
name: method,
|
||||||
pskList: pskList,
|
pskList: pskList,
|
||||||
|
@ -72,27 +74,35 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
|
||||||
|
|
||||||
switch method {
|
switch method {
|
||||||
case "2022-blake3-aes-128-gcm":
|
case "2022-blake3-aes-128-gcm":
|
||||||
m.keyLength = 16
|
m.keySaltLength = 16
|
||||||
m.constructor = newAESGCM
|
m.constructor = newAESGCM
|
||||||
m.blockConstructor = newAES
|
m.blockConstructor = newAES
|
||||||
case "2022-blake3-aes-256-gcm":
|
case "2022-blake3-aes-256-gcm":
|
||||||
m.keyLength = 32
|
m.keySaltLength = 32
|
||||||
m.constructor = newAESGCM
|
m.constructor = newAESGCM
|
||||||
m.blockConstructor = newAES
|
m.blockConstructor = newAES
|
||||||
case "2022-blake3-chacha20-poly1305":
|
case "2022-blake3-chacha20-poly1305":
|
||||||
m.keyLength = 32
|
if len(pskList) > 1 {
|
||||||
|
return nil, os.ErrInvalid
|
||||||
|
}
|
||||||
|
m.keySaltLength = 32
|
||||||
m.constructor = newChacha20Poly1305
|
m.constructor = newChacha20Poly1305
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, psk := range pskList {
|
if len(pskList) == 0 {
|
||||||
if len(psk) < m.keyLength {
|
if password == "" {
|
||||||
return nil, shadowaead.ErrBadKey
|
return nil, ErrMissingPasswordPSK
|
||||||
} else if len(psk) > m.keyLength {
|
|
||||||
pskList[i] = DerivePSK(psk, m.keyLength)
|
|
||||||
}
|
}
|
||||||
|
pskList = [][]byte{Key([]byte(password), m.keySaltLength)}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.psk = pskList[len(pskList)-1]
|
for i, psk := range pskList {
|
||||||
|
if len(psk) < m.keySaltLength {
|
||||||
|
return nil, shadowsocks.ErrBadKey
|
||||||
|
} else if len(psk) > m.keySaltLength {
|
||||||
|
pskList[i] = Key(psk, m.keySaltLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(pskList) > 1 {
|
if len(pskList) > 1 {
|
||||||
pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize)
|
pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize)
|
||||||
|
@ -112,19 +122,18 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
|
||||||
case "2022-blake3-aes-256-gcm":
|
case "2022-blake3-aes-256-gcm":
|
||||||
m.udpBlockCipher = newAES(pskList[0])
|
m.udpBlockCipher = newAES(pskList[0])
|
||||||
case "2022-blake3-chacha20-poly1305":
|
case "2022-blake3-chacha20-poly1305":
|
||||||
m.udpCipher = newXChacha20Poly1305(m.psk)
|
m.udpCipher = newXChacha20Poly1305(pskList[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DerivePSK(key []byte, keyLength int) []byte {
|
func Key(key []byte, keyLength int) []byte {
|
||||||
outKey := buf.Make(keyLength)
|
psk := sha256.Sum256(key)
|
||||||
blake3.DeriveKey(outKey, "shadowsocks 2022 pre shared key", key)
|
return psk[:keyLength]
|
||||||
return outKey
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeriveSessionKey(psk []byte, salt []byte, keyLength int) []byte {
|
func SessionKey(psk []byte, salt []byte, keyLength int) []byte {
|
||||||
sessionKey := buf.Make(len(psk) + len(salt))
|
sessionKey := buf.Make(len(psk) + len(salt))
|
||||||
copy(sessionKey, psk)
|
copy(sessionKey, psk)
|
||||||
copy(sessionKey[len(psk):], salt)
|
copy(sessionKey[len(psk):], salt)
|
||||||
|
@ -161,12 +170,11 @@ func newXChacha20Poly1305(key []byte) cipher.AEAD {
|
||||||
|
|
||||||
type Method struct {
|
type Method struct {
|
||||||
name string
|
name string
|
||||||
keyLength int
|
keySaltLength int
|
||||||
constructor func(key []byte) cipher.AEAD
|
constructor func(key []byte) cipher.AEAD
|
||||||
blockConstructor func(key []byte) cipher.Block
|
blockConstructor func(key []byte) cipher.Block
|
||||||
udpCipher cipher.AEAD
|
udpCipher cipher.AEAD
|
||||||
udpBlockCipher cipher.Block
|
udpBlockCipher cipher.Block
|
||||||
psk []byte
|
|
||||||
pskList [][]byte
|
pskList [][]byte
|
||||||
pskHash []byte
|
pskHash []byte
|
||||||
secureRNG io.Reader
|
secureRNG io.Reader
|
||||||
|
@ -178,13 +186,13 @@ func (m *Method) Name() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) KeyLength() int {
|
func (m *Method) KeyLength() int {
|
||||||
return m.keyLength
|
return m.keySaltLength
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
|
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
|
||||||
shadowsocksConn := &clientConn{
|
shadowsocksConn := &clientConn{
|
||||||
|
Method: m,
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
method: m,
|
|
||||||
destination: destination,
|
destination: destination,
|
||||||
}
|
}
|
||||||
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
|
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
|
||||||
|
@ -192,26 +200,23 @@ func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, err
|
||||||
|
|
||||||
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
|
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
|
||||||
return &clientConn{
|
return &clientConn{
|
||||||
|
Method: m,
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
method: m,
|
|
||||||
destination: destination,
|
destination: destination,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
|
func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
|
||||||
return &clientPacketConn{conn, m, m.newUDPSession()}
|
return &clientPacketConn{m, conn, m.newUDPSession()}
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientConn struct {
|
type clientConn struct {
|
||||||
|
*Method
|
||||||
net.Conn
|
net.Conn
|
||||||
|
|
||||||
method *Method
|
|
||||||
destination M.Socksaddr
|
destination M.Socksaddr
|
||||||
|
|
||||||
requestSalt []byte
|
requestSalt []byte
|
||||||
|
reader *shadowaead.Reader
|
||||||
reader *shadowaead.Reader
|
writer *shadowaead.Writer
|
||||||
writer *shadowaead.Writer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
|
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
|
||||||
|
@ -220,10 +225,10 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for i, psk := range m.pskList {
|
for i, psk := range m.pskList {
|
||||||
keyMaterial := buf.Make(m.keyLength + SaltSize)
|
keyMaterial := buf.Make(m.keySaltLength * 2)
|
||||||
copy(keyMaterial, psk)
|
copy(keyMaterial, psk)
|
||||||
copy(keyMaterial[m.keyLength:], salt)
|
copy(keyMaterial[m.keySaltLength:], salt)
|
||||||
_identitySubkey := buf.Make(m.keyLength)
|
_identitySubkey := buf.Make(m.keySaltLength)
|
||||||
identitySubkey := common.Dup(_identitySubkey)
|
identitySubkey := common.Dup(_identitySubkey)
|
||||||
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
|
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
|
||||||
|
|
||||||
|
@ -239,20 +244,20 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) writeRequest(payload []byte) error {
|
func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
salt := make([]byte, SaltSize)
|
salt := buf.Make(c.keySaltLength)
|
||||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
common.Must1(io.ReadFull(c.secureRNG, salt))
|
||||||
|
|
||||||
key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength)
|
key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength)
|
||||||
writer := shadowaead.NewWriter(
|
writer := shadowaead.NewWriter(
|
||||||
c.Conn,
|
c.Conn,
|
||||||
c.method.constructor(common.Dup(key)),
|
c.constructor(common.Dup(key)),
|
||||||
MaxPacketSize,
|
MaxPacketSize,
|
||||||
)
|
)
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
|
|
||||||
header := writer.Buffer()
|
header := writer.Buffer()
|
||||||
header.Write(salt)
|
header.Write(salt)
|
||||||
c.method.writeExtendedIdentityHeaders(header, salt)
|
c.writeExtendedIdentityHeaders(header, salt)
|
||||||
|
|
||||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||||
|
|
||||||
|
@ -279,7 +284,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write padding length")
|
return E.Cause(err, "write padding length")
|
||||||
}
|
}
|
||||||
_, err = io.CopyN(bufferedWriter, c.method.secureRNG, int64(pLen))
|
_, err = io.CopyN(bufferedWriter, c.secureRNG, int64(pLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write padding")
|
return E.Cause(err, "write padding")
|
||||||
}
|
}
|
||||||
|
@ -300,22 +305,22 @@ func (c *clientConn) readResponse() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_salt := make([]byte, SaltSize)
|
_salt := buf.Make(c.keySaltLength)
|
||||||
salt := common.Dup(_salt)
|
salt := common.Dup(_salt)
|
||||||
_, err := io.ReadFull(c.Conn, salt)
|
_, err := io.ReadFull(c.Conn, salt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.method.replayFilter.Check(salt) {
|
if !c.replayFilter.Check(salt) {
|
||||||
return E.New("salt not unique")
|
return E.New("salt not unique")
|
||||||
}
|
}
|
||||||
|
|
||||||
key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength)
|
key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength)
|
||||||
runtime.KeepAlive(_salt)
|
runtime.KeepAlive(_salt)
|
||||||
reader := shadowaead.NewReader(
|
reader := shadowaead.NewReader(
|
||||||
c.Conn,
|
c.Conn,
|
||||||
c.method.constructor(common.Dup(key)),
|
c.constructor(common.Dup(key)),
|
||||||
MaxPacketSize,
|
MaxPacketSize,
|
||||||
)
|
)
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
|
@ -339,7 +344,7 @@ func (c *clientConn) readResponse() error {
|
||||||
return ErrBadTimestamp
|
return ErrBadTimestamp
|
||||||
}
|
}
|
||||||
|
|
||||||
_requestSalt := make([]byte, SaltSize)
|
_requestSalt := buf.Make(c.keySaltLength)
|
||||||
requestSalt := common.Dup(_requestSalt)
|
requestSalt := common.Dup(_requestSalt)
|
||||||
_, err = io.ReadFull(reader, requestSalt)
|
_, err = io.ReadFull(reader, requestSalt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -412,19 +417,19 @@ func (c *clientConn) WriterReplaceable() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientPacketConn struct {
|
type clientPacketConn struct {
|
||||||
|
*Method
|
||||||
net.Conn
|
net.Conn
|
||||||
method *Method
|
|
||||||
session *udpSession
|
session *udpSession
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
var hdrLen int
|
var hdrLen int
|
||||||
if c.method.udpCipher != nil {
|
if c.udpCipher != nil {
|
||||||
hdrLen = PacketNonceSize
|
hdrLen = PacketNonceSize
|
||||||
}
|
}
|
||||||
hdrLen += 16 // packet header
|
hdrLen += 16 // packet header
|
||||||
pskLen := len(c.method.pskList)
|
pskLen := len(c.pskList)
|
||||||
if c.method.udpCipher == nil && pskLen > 1 {
|
if c.udpCipher == nil && pskLen > 1 {
|
||||||
hdrLen += (pskLen - 1) * aes.BlockSize
|
hdrLen += (pskLen - 1) * aes.BlockSize
|
||||||
}
|
}
|
||||||
hdrLen += 1 // header type
|
hdrLen += 1 // header type
|
||||||
|
@ -434,8 +439,8 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
|
||||||
header := buf.With(buffer.ExtendHeader(hdrLen))
|
header := buf.With(buffer.ExtendHeader(hdrLen))
|
||||||
|
|
||||||
var dataIndex int
|
var dataIndex int
|
||||||
if c.method.udpCipher != nil {
|
if c.udpCipher != nil {
|
||||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
|
common.Must1(header.ReadFullFrom(c.secureRNG, PacketNonceSize))
|
||||||
if pskLen > 1 {
|
if pskLen > 1 {
|
||||||
panic("unsupported chacha extended header")
|
panic("unsupported chacha extended header")
|
||||||
}
|
}
|
||||||
|
@ -449,16 +454,16 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
|
||||||
binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
|
binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
|
||||||
)
|
)
|
||||||
|
|
||||||
if c.method.udpCipher == nil && pskLen > 1 {
|
if c.udpCipher == nil && pskLen > 1 {
|
||||||
for i, psk := range c.method.pskList {
|
for i, psk := range c.pskList {
|
||||||
dataIndex += aes.BlockSize
|
dataIndex += aes.BlockSize
|
||||||
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
||||||
|
|
||||||
identityHeader := header.Extend(aes.BlockSize)
|
identityHeader := header.Extend(aes.BlockSize)
|
||||||
for textI := 0; textI < aes.BlockSize; textI++ {
|
for textI := 0; textI < aes.BlockSize; textI++ {
|
||||||
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
|
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
|
||||||
}
|
}
|
||||||
c.method.blockConstructor(psk).Encrypt(identityHeader, identityHeader)
|
c.blockConstructor(psk).Encrypt(identityHeader, identityHeader)
|
||||||
|
|
||||||
if i == pskLen-2 {
|
if i == pskLen-2 {
|
||||||
break
|
break
|
||||||
|
@ -477,14 +482,14 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if c.method.udpCipher != nil {
|
if c.udpCipher != nil {
|
||||||
c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
||||||
buffer.Extend(c.method.udpCipher.Overhead())
|
buffer.Extend(c.udpCipher.Overhead())
|
||||||
} else {
|
} else {
|
||||||
packetHeader := buffer.To(aes.BlockSize)
|
packetHeader := buffer.To(aes.BlockSize)
|
||||||
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
|
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
|
||||||
buffer.Extend(c.session.cipher.Overhead())
|
buffer.Extend(c.session.cipher.Overhead())
|
||||||
c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
c.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
||||||
}
|
}
|
||||||
return common.Error(c.Write(buffer.Bytes()))
|
return common.Error(c.Write(buffer.Bytes()))
|
||||||
}
|
}
|
||||||
|
@ -497,16 +502,16 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
buffer.Truncate(n)
|
buffer.Truncate(n)
|
||||||
|
|
||||||
var packetHeader []byte
|
var packetHeader []byte
|
||||||
if c.method.udpCipher != nil {
|
if c.udpCipher != nil {
|
||||||
_, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
_, err = c.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return M.Socksaddr{}, E.Cause(err, "decrypt packet")
|
return M.Socksaddr{}, E.Cause(err, "decrypt packet")
|
||||||
}
|
}
|
||||||
buffer.Advance(PacketNonceSize)
|
buffer.Advance(PacketNonceSize)
|
||||||
buffer.Truncate(buffer.Len() - c.method.udpCipher.Overhead())
|
buffer.Truncate(buffer.Len() - c.udpCipher.Overhead())
|
||||||
} else {
|
} else {
|
||||||
packetHeader = buffer.To(aes.BlockSize)
|
packetHeader = buffer.To(aes.BlockSize)
|
||||||
c.method.udpBlockCipher.Decrypt(packetHeader, packetHeader)
|
c.udpBlockCipher.Decrypt(packetHeader, packetHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
var sessionId, packetId uint64
|
var sessionId, packetId uint64
|
||||||
|
@ -526,8 +531,8 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
} else if sessionId == c.session.lastRemoteSessionId {
|
} else if sessionId == c.session.lastRemoteSessionId {
|
||||||
remoteCipher = c.session.lastRemoteCipher
|
remoteCipher = c.session.lastRemoteCipher
|
||||||
} else {
|
} else {
|
||||||
key := DeriveSessionKey(c.method.psk, packetHeader[:8], c.method.keyLength)
|
key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength)
|
||||||
remoteCipher = c.method.constructor(common.Dup(key))
|
remoteCipher = c.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
}
|
}
|
||||||
_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
|
_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
|
||||||
|
@ -622,14 +627,14 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
|
||||||
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
destination := M.SocksaddrFromNet(addr)
|
destination := M.SocksaddrFromNet(addr)
|
||||||
var overHead int
|
var overHead int
|
||||||
if c.method.udpCipher != nil {
|
if c.udpCipher != nil {
|
||||||
overHead = PacketNonceSize + c.method.udpCipher.Overhead()
|
overHead = PacketNonceSize + c.udpCipher.Overhead()
|
||||||
} else {
|
} else {
|
||||||
overHead = c.session.cipher.Overhead()
|
overHead = c.session.cipher.Overhead()
|
||||||
}
|
}
|
||||||
overHead += 16 // packet header
|
overHead += 16 // packet header
|
||||||
pskLen := len(c.method.pskList)
|
pskLen := len(c.pskList)
|
||||||
if c.method.udpCipher == nil && pskLen > 1 {
|
if c.udpCipher == nil && pskLen > 1 {
|
||||||
overHead += (pskLen - 1) * aes.BlockSize
|
overHead += (pskLen - 1) * aes.BlockSize
|
||||||
}
|
}
|
||||||
overHead += 1 // header type
|
overHead += 1 // header type
|
||||||
|
@ -642,8 +647,8 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
buffer := buf.With(common.Dup(_buffer))
|
buffer := buf.With(common.Dup(_buffer))
|
||||||
|
|
||||||
var dataIndex int
|
var dataIndex int
|
||||||
if c.method.udpCipher != nil {
|
if c.udpCipher != nil {
|
||||||
common.Must1(buffer.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
|
common.Must1(buffer.ReadFullFrom(c.secureRNG, PacketNonceSize))
|
||||||
if pskLen > 1 {
|
if pskLen > 1 {
|
||||||
panic("unsupported chacha extended header")
|
panic("unsupported chacha extended header")
|
||||||
}
|
}
|
||||||
|
@ -657,16 +662,16 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()),
|
binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()),
|
||||||
)
|
)
|
||||||
|
|
||||||
if c.method.udpCipher == nil && pskLen > 1 {
|
if c.udpCipher == nil && pskLen > 1 {
|
||||||
for i, psk := range c.method.pskList {
|
for i, psk := range c.pskList {
|
||||||
dataIndex += aes.BlockSize
|
dataIndex += aes.BlockSize
|
||||||
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
||||||
|
|
||||||
identityHeader := buffer.Extend(aes.BlockSize)
|
identityHeader := buffer.Extend(aes.BlockSize)
|
||||||
for textI := 0; textI < aes.BlockSize; textI++ {
|
for textI := 0; textI < aes.BlockSize; textI++ {
|
||||||
identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI)
|
identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI)
|
||||||
}
|
}
|
||||||
c.method.blockConstructor(psk).Encrypt(identityHeader, identityHeader)
|
c.blockConstructor(psk).Encrypt(identityHeader, identityHeader)
|
||||||
|
|
||||||
if i == pskLen-2 {
|
if i == pskLen-2 {
|
||||||
break
|
break
|
||||||
|
@ -685,14 +690,14 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if c.method.udpCipher != nil {
|
if c.udpCipher != nil {
|
||||||
c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
||||||
buffer.Extend(c.method.udpCipher.Overhead())
|
buffer.Extend(c.udpCipher.Overhead())
|
||||||
} else {
|
} else {
|
||||||
packetHeader := buffer.To(aes.BlockSize)
|
packetHeader := buffer.To(aes.BlockSize)
|
||||||
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
|
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
|
||||||
buffer.Extend(c.session.cipher.Overhead())
|
buffer.Extend(c.session.cipher.Overhead())
|
||||||
c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
c.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
||||||
}
|
}
|
||||||
err = common.Error(c.Write(buffer.Bytes()))
|
err = common.Error(c.Write(buffer.Bytes()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -726,7 +731,7 @@ func (m *Method) newUDPSession() *udpSession {
|
||||||
if m.udpCipher == nil {
|
if m.udpCipher == nil {
|
||||||
sessionId := make([]byte, 8)
|
sessionId := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||||
key := DeriveSessionKey(m.psk, sessionId, m.keyLength)
|
key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength)
|
||||||
session.cipher = m.constructor(common.Dup(key))
|
session.cipher = m.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ import (
|
||||||
type Service struct {
|
type Service struct {
|
||||||
name string
|
name string
|
||||||
secureRNG io.Reader
|
secureRNG io.Reader
|
||||||
keyLength int
|
keySaltLength int
|
||||||
constructor func(key []byte) cipher.AEAD
|
constructor func(key []byte) cipher.AEAD
|
||||||
blockConstructor func(key []byte) cipher.Block
|
blockConstructor func(key []byte) cipher.Block
|
||||||
udpCipher cipher.AEAD
|
udpCipher cipher.AEAD
|
||||||
|
@ -42,7 +42,7 @@ type Service struct {
|
||||||
sessions *cache.LruCache[uint64, *serverUDPSession]
|
sessions *cache.LruCache[uint64, *serverUDPSession]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
func NewService(method string, psk []byte, password string, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||||
s := &Service{
|
s := &Service{
|
||||||
name: method,
|
name: method,
|
||||||
secureRNG: secureRNG,
|
secureRNG: secureRNG,
|
||||||
|
@ -57,25 +57,31 @@ func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64
|
||||||
|
|
||||||
switch method {
|
switch method {
|
||||||
case "2022-blake3-aes-128-gcm":
|
case "2022-blake3-aes-128-gcm":
|
||||||
s.keyLength = 16
|
s.keySaltLength = 16
|
||||||
s.constructor = newAESGCM
|
s.constructor = newAESGCM
|
||||||
s.blockConstructor = newAES
|
s.blockConstructor = newAES
|
||||||
case "2022-blake3-aes-256-gcm":
|
case "2022-blake3-aes-256-gcm":
|
||||||
s.keyLength = 32
|
s.keySaltLength = 32
|
||||||
s.constructor = newAESGCM
|
s.constructor = newAESGCM
|
||||||
s.blockConstructor = newAES
|
s.blockConstructor = newAES
|
||||||
case "2022-blake3-chacha20-poly1305":
|
case "2022-blake3-chacha20-poly1305":
|
||||||
s.keyLength = 32
|
s.keySaltLength = 32
|
||||||
s.constructor = newChacha20Poly1305
|
s.constructor = newChacha20Poly1305
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(psk) < s.keyLength {
|
if len(psk) == s.keySaltLength {
|
||||||
return nil, shadowaead.ErrBadKey
|
s.psk = psk
|
||||||
} else if len(psk) > s.keyLength {
|
} else if len(psk) != 0 {
|
||||||
psk = DerivePSK(psk, s.keyLength)
|
if len(psk) < s.keySaltLength {
|
||||||
|
return nil, shadowsocks.ErrBadKey
|
||||||
|
}
|
||||||
|
s.psk = Key(psk, s.keySaltLength)
|
||||||
|
} else if password == "" {
|
||||||
|
return nil, ErrMissingPasswordPSK
|
||||||
|
} else {
|
||||||
|
s.psk = Key([]byte(password), s.keySaltLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.psk = psk
|
|
||||||
switch method {
|
switch method {
|
||||||
case "2022-blake3-aes-128-gcm":
|
case "2022-blake3-aes-128-gcm":
|
||||||
s.udpBlockCipher = newAES(psk)
|
s.udpBlockCipher = newAES(psk)
|
||||||
|
@ -97,7 +103,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||||
requestSalt := make([]byte, SaltSize)
|
requestSalt := buf.Make(s.keySaltLength)
|
||||||
_, err := io.ReadFull(conn, requestSalt)
|
_, err := io.ReadFull(conn, requestSalt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "read request salt")
|
return E.Cause(err, "read request salt")
|
||||||
|
@ -107,7 +113,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
||||||
return E.New("salt not unique")
|
return E.New("salt not unique")
|
||||||
}
|
}
|
||||||
|
|
||||||
requestKey := DeriveSessionKey(s.psk, requestSalt, s.keyLength)
|
requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength)
|
||||||
reader := shadowaead.NewReader(
|
reader := shadowaead.NewReader(
|
||||||
conn,
|
conn,
|
||||||
s.constructor(common.Dup(requestKey)),
|
s.constructor(common.Dup(requestKey)),
|
||||||
|
@ -175,10 +181,10 @@ type serverConn struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
||||||
var _salt [SaltSize]byte
|
_salt := buf.Make(c.keySaltLength)
|
||||||
salt := common.Dup(_salt[:])
|
salt := common.Dup(_salt[:])
|
||||||
common.Must1(io.ReadFull(c.secureRNG, salt))
|
common.Must1(io.ReadFull(c.secureRNG, salt))
|
||||||
key := DeriveSessionKey(c.uPSK, salt, c.keyLength)
|
key := SessionKey(c.uPSK, salt, c.keySaltLength)
|
||||||
runtime.KeepAlive(_salt)
|
runtime.KeepAlive(_salt)
|
||||||
writer := shadowaead.NewWriter(
|
writer := shadowaead.NewWriter(
|
||||||
c.Conn,
|
c.Conn,
|
||||||
|
@ -294,7 +300,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me
|
||||||
if !loaded {
|
if !loaded {
|
||||||
session.remoteSessionId = sessionId
|
session.remoteSessionId = sessionId
|
||||||
if packetHeader != nil {
|
if packetHeader != nil {
|
||||||
key := DeriveSessionKey(s.psk, packetHeader[:8], s.keyLength)
|
key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength)
|
||||||
session.remoteCipher = s.constructor(common.Dup(key))
|
session.remoteCipher = s.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
}
|
}
|
||||||
|
@ -439,7 +445,7 @@ func (m *Service) newUDPSession() *serverUDPSession {
|
||||||
if m.udpCipher == nil {
|
if m.udpCipher == nil {
|
||||||
sessionId := make([]byte, 8)
|
sessionId := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||||
key := DeriveSessionKey(m.psk, sessionId, m.keyLength)
|
key := SessionKey(m.psk, sessionId, m.keySaltLength)
|
||||||
session.cipher = m.constructor(common.Dup(key))
|
session.cipher = m.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,10 +30,10 @@ type MultiService[U comparable] struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MultiService[U]) AddUser(user U, key []byte) error {
|
func (s *MultiService[U]) AddUser(user U, key []byte) error {
|
||||||
if len(key) < s.keyLength {
|
if len(key) < s.keySaltLength {
|
||||||
return shadowaead.ErrBadKey
|
return shadowsocks.ErrBadKey
|
||||||
} else if len(key) > s.keyLength {
|
} else if len(key) > s.keySaltLength {
|
||||||
key = DerivePSK(key, s.keyLength)
|
key = Key(key, s.keySaltLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
var uPSKHash [aes.BlockSize]byte
|
var uPSKHash [aes.BlockSize]byte
|
||||||
|
@ -67,7 +67,7 @@ func NewMultiService[U comparable](method string, iPSK []byte, secureRNG io.Read
|
||||||
return nil, E.New("unsupported method ", method)
|
return nil, E.New("unsupported method ", method)
|
||||||
}
|
}
|
||||||
|
|
||||||
ss, err := NewService(method, iPSK, secureRNG, udpTimeout, handler)
|
ss, err := NewService(method, iPSK, "", secureRNG, udpTimeout, handler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -91,7 +91,7 @@ func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, meta
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||||
requestSalt := make([]byte, SaltSize)
|
requestSalt := make([]byte, s.keySaltLength)
|
||||||
_, err := io.ReadFull(conn, requestSalt)
|
_, err := io.ReadFull(conn, requestSalt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "read request salt")
|
return E.Cause(err, "read request salt")
|
||||||
|
@ -108,10 +108,10 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
||||||
return E.Cause(err, "read extended identity header")
|
return E.Cause(err, "read extended identity header")
|
||||||
}
|
}
|
||||||
|
|
||||||
keyMaterial := buf.Make(s.keyLength + SaltSize)
|
keyMaterial := buf.Make(s.keySaltLength * 2)
|
||||||
copy(keyMaterial, s.psk)
|
copy(keyMaterial, s.psk)
|
||||||
copy(keyMaterial[s.keyLength:], requestSalt)
|
copy(keyMaterial[s.keySaltLength:], requestSalt)
|
||||||
_identitySubkey := buf.Make(s.keyLength)
|
_identitySubkey := buf.Make(s.keySaltLength)
|
||||||
identitySubkey := common.Dup(_identitySubkey)
|
identitySubkey := common.Dup(_identitySubkey)
|
||||||
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
|
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
|
||||||
s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader)
|
s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader)
|
||||||
|
@ -126,7 +126,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
||||||
return E.New("invalid request")
|
return E.New("invalid request")
|
||||||
}
|
}
|
||||||
|
|
||||||
requestKey := DeriveSessionKey(uPSK, requestSalt, s.keyLength)
|
requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength)
|
||||||
reader := shadowaead.NewReader(
|
reader := shadowaead.NewReader(
|
||||||
conn,
|
conn,
|
||||||
s.constructor(common.Dup(requestKey)),
|
s.constructor(common.Dup(requestKey)),
|
||||||
|
@ -230,7 +230,7 @@ func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metad
|
||||||
})
|
})
|
||||||
if !loaded {
|
if !loaded {
|
||||||
session.remoteSessionId = sessionId
|
session.remoteSessionId = sessionId
|
||||||
key := DeriveSessionKey(uPSK, packetHeader[:8], s.keyLength)
|
key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength)
|
||||||
session.remoteCipher = s.constructor(common.Dup(key))
|
session.remoteCipher = s.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
}
|
}
|
||||||
|
@ -312,7 +312,7 @@ func (m *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession {
|
||||||
session.packetId--
|
session.packetId--
|
||||||
sessionId := make([]byte, 8)
|
sessionId := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||||
key := DeriveSessionKey(uPSK, sessionId, m.keyLength)
|
key := SessionKey(uPSK, sessionId, m.keySaltLength)
|
||||||
session.cipher = m.constructor(common.Dup(key))
|
session.cipher = m.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
return session
|
return session
|
||||||
|
|
|
@ -29,7 +29,7 @@ func TestMultiService(t *testing.T) {
|
||||||
random.Default.Read(uPSK[:])
|
random.Default.Read(uPSK[:])
|
||||||
multiService.AddUser("my user", uPSK[:])
|
multiService.AddUser("my user", uPSK[:])
|
||||||
|
|
||||||
client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, random.Default)
|
client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, "", random.Default)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
56
protocol/shadowsocks/shadowimpl/fetcher.go
Normal file
56
protocol/shadowsocks/shadowimpl/fetcher.go
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
package shadowimpl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||||
|
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||||
|
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
|
||||||
|
"github.com/sagernet/sing/protocol/shadowsocks/shadowstream"
|
||||||
|
)
|
||||||
|
|
||||||
|
func FetchMethod(method string, key string, password string, secureRNG io.Reader) (shadowsocks.Method, error) {
|
||||||
|
if method == "none" {
|
||||||
|
return shadowsocks.NewNone(), nil
|
||||||
|
} else if common.Contains(shadowstream.List, method) {
|
||||||
|
var keyBytes []byte
|
||||||
|
if key != "" {
|
||||||
|
kb, err := base64.StdEncoding.DecodeString(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "decode key")
|
||||||
|
}
|
||||||
|
keyBytes = kb
|
||||||
|
}
|
||||||
|
return shadowstream.New(method, keyBytes, password, secureRNG)
|
||||||
|
} else if common.Contains(shadowaead.List, method) {
|
||||||
|
var keyBytes []byte
|
||||||
|
if key != "" {
|
||||||
|
kb, err := base64.StdEncoding.DecodeString(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "decode key")
|
||||||
|
}
|
||||||
|
keyBytes = kb
|
||||||
|
}
|
||||||
|
return shadowaead.New(method, keyBytes, password, secureRNG)
|
||||||
|
} else if common.Contains(shadowaead_2022.List, method) {
|
||||||
|
var pskList [][]byte
|
||||||
|
if key != "" {
|
||||||
|
keyStrList := strings.Split(key, ":")
|
||||||
|
pskList = make([][]byte, len(keyStrList))
|
||||||
|
for i, keyStr := range keyStrList {
|
||||||
|
kb, err := base64.StdEncoding.DecodeString(keyStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "decode key")
|
||||||
|
}
|
||||||
|
pskList[i] = kb
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return shadowaead_2022.New(method, pskList, password, secureRNG)
|
||||||
|
} else {
|
||||||
|
return nil, E.New("shadowsocks: unsupported method ", method)
|
||||||
|
}
|
||||||
|
}
|
|
@ -53,7 +53,7 @@ type Method struct {
|
||||||
secureRNG io.Reader
|
secureRNG io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(method string, key []byte, password []byte, secureRNG io.Reader) (shadowsocks.Method, error) {
|
func New(method string, key []byte, password string, secureRNG io.Reader) (shadowsocks.Method, error) {
|
||||||
m := &Method{
|
m := &Method{
|
||||||
name: method,
|
name: method,
|
||||||
secureRNG: secureRNG,
|
secureRNG: secureRNG,
|
||||||
|
@ -167,11 +167,11 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader) (shado
|
||||||
if len(key) == m.keyLength {
|
if len(key) == m.keyLength {
|
||||||
m.key = key
|
m.key = key
|
||||||
} else if len(key) > 0 {
|
} else if len(key) > 0 {
|
||||||
return nil, shadowaead.ErrBadKey
|
return nil, shadowsocks.ErrBadKey
|
||||||
} else if len(password) > 0 {
|
} else if password != "" {
|
||||||
m.key = shadowsocks.Key(password, m.keyLength)
|
m.key = shadowsocks.Key([]byte(password), m.keyLength)
|
||||||
} else {
|
} else {
|
||||||
return nil, shadowaead.ErrMissingPassword
|
return nil, shadowsocks.ErrMissingPassword
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue