mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 20:37:40 +03:00
Allow over length for shadowsocks 2022 psk
This commit is contained in:
parent
69d801e651
commit
f1a5f8aaa3
5 changed files with 48 additions and 37 deletions
|
@ -211,21 +211,16 @@ func newClient(f *flags) (*client, error) {
|
||||||
}
|
}
|
||||||
c.method = method
|
c.method = method
|
||||||
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
||||||
var pskList [][shadowaead_2022.KeySaltSize]byte
|
var pskList [][]byte
|
||||||
if f.Key != "" {
|
if f.Key != "" {
|
||||||
keyStrList := strings.Split(f.Key, ":")
|
keyStrList := strings.Split(f.Key, ":")
|
||||||
pskList = make([][shadowaead_2022.KeySaltSize]byte, len(keyStrList))
|
pskList = make([][]byte, len(keyStrList))
|
||||||
for i, keyStr := range keyStrList {
|
for i, keyStr := range keyStrList {
|
||||||
var psk [shadowaead_2022.KeySaltSize]byte
|
|
||||||
kb, err := base64.StdEncoding.DecodeString(keyStr)
|
kb, err := base64.StdEncoding.DecodeString(keyStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "decode key")
|
return nil, E.Cause(err, "decode key")
|
||||||
}
|
}
|
||||||
if len(kb) != shadowaead_2022.KeySaltSize {
|
pskList[i] = kb
|
||||||
return nil, shadowaead.ErrBadKey
|
|
||||||
}
|
|
||||||
copy(psk[:], kb)
|
|
||||||
pskList[i] = psk
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
method, err := shadowaead_2022.New(f.Method, pskList, rng)
|
method, err := shadowaead_2022.New(f.Method, pskList, rng)
|
||||||
|
|
|
@ -149,16 +149,13 @@ func newServer(f *flags) (*server, error) {
|
||||||
}
|
}
|
||||||
s.service = service
|
s.service = service
|
||||||
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
||||||
var key [shadowaead_2022.KeySaltSize]byte
|
var key []byte
|
||||||
if f.Key != "" {
|
if f.Key != "" {
|
||||||
kb, err := base64.StdEncoding.DecodeString(f.Key)
|
kb, err := base64.StdEncoding.DecodeString(f.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "decode key")
|
return nil, E.Cause(err, "decode key")
|
||||||
}
|
}
|
||||||
if len(kb) != shadowaead_2022.KeySaltSize {
|
key = kb
|
||||||
return nil, shadowaead.ErrBadKey
|
|
||||||
}
|
|
||||||
copy(key[:], kb)
|
|
||||||
}
|
}
|
||||||
service, err := shadowaead_2022.NewService(f.Method, key, random.Default, udpTimeout, s)
|
service, err := shadowaead_2022.NewService(f.Method, key, random.Default, udpTimeout, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -65,7 +65,6 @@ var List = []string{
|
||||||
func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Method, error) {
|
func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Method, error) {
|
||||||
m := &Method{
|
m := &Method{
|
||||||
name: method,
|
name: method,
|
||||||
psk: pskList[len(pskList)-1],
|
|
||||||
pskList: pskList,
|
pskList: pskList,
|
||||||
secureRNG: secureRNG,
|
secureRNG: secureRNG,
|
||||||
replayFilter: replay.NewCuckoo(60),
|
replayFilter: replay.NewCuckoo(60),
|
||||||
|
@ -85,12 +84,16 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
|
||||||
m.constructor = newChacha20Poly1305
|
m.constructor = newChacha20Poly1305
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, psk := range pskList {
|
for i, psk := range pskList {
|
||||||
if len(psk) != m.keyLength {
|
if len(psk) < m.keyLength {
|
||||||
return nil, shadowaead.ErrBadKey
|
return nil, shadowaead.ErrBadKey
|
||||||
|
} else if len(psk) > m.keyLength {
|
||||||
|
pskList[i] = DerivePSK(psk, m.keyLength)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.psk = pskList[len(pskList)-1]
|
||||||
|
|
||||||
if len(pskList) > 1 {
|
if len(pskList) > 1 {
|
||||||
pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize)
|
pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize)
|
||||||
for i, psk := range pskList {
|
for i, psk := range pskList {
|
||||||
|
@ -111,10 +114,17 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
|
||||||
case "2022-blake3-chacha20-poly1305":
|
case "2022-blake3-chacha20-poly1305":
|
||||||
m.udpCipher = newXChacha20Poly1305(m.psk)
|
m.udpCipher = newXChacha20Poly1305(m.psk)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Blake3DeriveKey(psk []byte, salt []byte, keyLength int) []byte {
|
func DerivePSK(key []byte, keyLength int) []byte {
|
||||||
|
outKey := buf.Make(keyLength)
|
||||||
|
blake3.DeriveKey(outKey, "shadowsocks 2022 pre shared key", key)
|
||||||
|
return outKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeriveSessionKey(psk []byte, salt []byte, keyLength int) []byte {
|
||||||
sessionKey := buf.Make(len(psk) + len(salt))
|
sessionKey := buf.Make(len(psk) + len(salt))
|
||||||
copy(sessionKey, psk)
|
copy(sessionKey, psk)
|
||||||
copy(sessionKey[len(psk):], salt)
|
copy(sessionKey[len(psk):], salt)
|
||||||
|
@ -232,7 +242,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
salt := make([]byte, SaltSize)
|
salt := make([]byte, SaltSize)
|
||||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||||
|
|
||||||
key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength)
|
key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength)
|
||||||
writer := shadowaead.NewWriter(
|
writer := shadowaead.NewWriter(
|
||||||
c.Conn,
|
c.Conn,
|
||||||
c.method.constructor(common.Dup(key)),
|
c.method.constructor(common.Dup(key)),
|
||||||
|
@ -301,7 +311,7 @@ func (c *clientConn) readResponse() error {
|
||||||
return E.New("salt not unique")
|
return E.New("salt not unique")
|
||||||
}
|
}
|
||||||
|
|
||||||
key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength)
|
key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength)
|
||||||
runtime.KeepAlive(_salt)
|
runtime.KeepAlive(_salt)
|
||||||
reader := shadowaead.NewReader(
|
reader := shadowaead.NewReader(
|
||||||
c.Conn,
|
c.Conn,
|
||||||
|
@ -516,7 +526,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
} else if sessionId == c.session.lastRemoteSessionId {
|
} else if sessionId == c.session.lastRemoteSessionId {
|
||||||
remoteCipher = c.session.lastRemoteCipher
|
remoteCipher = c.session.lastRemoteCipher
|
||||||
} else {
|
} else {
|
||||||
key := Blake3DeriveKey(c.method.psk, packetHeader[:8], c.method.keyLength)
|
key := DeriveSessionKey(c.method.psk, packetHeader[:8], c.method.keyLength)
|
||||||
remoteCipher = c.method.constructor(common.Dup(key))
|
remoteCipher = c.method.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
}
|
}
|
||||||
|
@ -716,7 +726,7 @@ func (m *Method) newUDPSession() *udpSession {
|
||||||
if m.udpCipher == nil {
|
if m.udpCipher == nil {
|
||||||
sessionId := make([]byte, 8)
|
sessionId := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||||
key := Blake3DeriveKey(m.psk, sessionId, m.keyLength)
|
key := DeriveSessionKey(m.psk, sessionId, m.keyLength)
|
||||||
session.cipher = m.constructor(common.Dup(key))
|
session.cipher = m.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,6 @@ type Service struct {
|
||||||
func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||||
s := &Service{
|
s := &Service{
|
||||||
name: method,
|
name: method,
|
||||||
psk: psk,
|
|
||||||
secureRNG: secureRNG,
|
secureRNG: secureRNG,
|
||||||
replayFilter: replay.NewCuckoo(60),
|
replayFilter: replay.NewCuckoo(60),
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
@ -70,17 +69,20 @@ func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64
|
||||||
s.constructor = newChacha20Poly1305
|
s.constructor = newChacha20Poly1305
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(psk) != s.keyLength {
|
if len(psk) < s.keyLength {
|
||||||
return nil, shadowaead.ErrBadKey
|
return nil, shadowaead.ErrBadKey
|
||||||
|
} else if len(psk) > s.keyLength {
|
||||||
|
psk = DerivePSK(psk, s.keyLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.psk = psk
|
||||||
switch method {
|
switch method {
|
||||||
case "2022-blake3-aes-128-gcm":
|
case "2022-blake3-aes-128-gcm":
|
||||||
s.udpBlockCipher = newAES(s.psk[:])
|
s.udpBlockCipher = newAES(psk)
|
||||||
case "2022-blake3-aes-256-gcm":
|
case "2022-blake3-aes-256-gcm":
|
||||||
s.udpBlockCipher = newAES(s.psk[:])
|
s.udpBlockCipher = newAES(psk)
|
||||||
case "2022-blake3-chacha20-poly1305":
|
case "2022-blake3-chacha20-poly1305":
|
||||||
s.udpCipher = newXChacha20Poly1305(s.psk[:])
|
s.udpCipher = newXChacha20Poly1305(psk)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
|
@ -105,7 +107,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
||||||
return E.New("salt not unique")
|
return E.New("salt not unique")
|
||||||
}
|
}
|
||||||
|
|
||||||
requestKey := Blake3DeriveKey(s.psk[:], requestSalt, s.keyLength)
|
requestKey := DeriveSessionKey(s.psk, requestSalt, s.keyLength)
|
||||||
reader := shadowaead.NewReader(
|
reader := shadowaead.NewReader(
|
||||||
conn,
|
conn,
|
||||||
s.constructor(common.Dup(requestKey)),
|
s.constructor(common.Dup(requestKey)),
|
||||||
|
@ -176,7 +178,7 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
||||||
var _salt [SaltSize]byte
|
var _salt [SaltSize]byte
|
||||||
salt := common.Dup(_salt[:])
|
salt := common.Dup(_salt[:])
|
||||||
common.Must1(io.ReadFull(c.secureRNG, salt))
|
common.Must1(io.ReadFull(c.secureRNG, salt))
|
||||||
key := Blake3DeriveKey(c.uPSK[:], salt, c.keyLength)
|
key := DeriveSessionKey(c.uPSK, salt, c.keyLength)
|
||||||
runtime.KeepAlive(_salt)
|
runtime.KeepAlive(_salt)
|
||||||
writer := shadowaead.NewWriter(
|
writer := shadowaead.NewWriter(
|
||||||
c.Conn,
|
c.Conn,
|
||||||
|
@ -292,7 +294,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me
|
||||||
if !loaded {
|
if !loaded {
|
||||||
session.remoteSessionId = sessionId
|
session.remoteSessionId = sessionId
|
||||||
if packetHeader != nil {
|
if packetHeader != nil {
|
||||||
key := Blake3DeriveKey(s.psk[:], packetHeader[:8], s.keyLength)
|
key := DeriveSessionKey(s.psk, packetHeader[:8], s.keyLength)
|
||||||
session.remoteCipher = s.constructor(common.Dup(key))
|
session.remoteCipher = s.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
}
|
}
|
||||||
|
@ -437,7 +439,7 @@ func (m *Service) newUDPSession() *serverUDPSession {
|
||||||
if m.udpCipher == nil {
|
if m.udpCipher == nil {
|
||||||
sessionId := make([]byte, 8)
|
sessionId := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||||
key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength)
|
key := DeriveSessionKey(m.psk, sessionId, m.keyLength)
|
||||||
session.cipher = m.constructor(common.Dup(key))
|
session.cipher = m.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,9 +29,15 @@ type MultiService[U comparable] struct {
|
||||||
uPSKHashR map[[aes.BlockSize]byte]U
|
uPSKHashR map[[aes.BlockSize]byte]U
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MultiService[U]) AddUser(user U, key []byte) {
|
func (s *MultiService[U]) AddUser(user U, key []byte) error {
|
||||||
|
if len(key) < s.keyLength {
|
||||||
|
return shadowaead.ErrBadKey
|
||||||
|
} else if len(key) > s.keyLength {
|
||||||
|
key = DerivePSK(key, s.keyLength)
|
||||||
|
}
|
||||||
|
|
||||||
var uPSKHash [aes.BlockSize]byte
|
var uPSKHash [aes.BlockSize]byte
|
||||||
hash512 := blake3.Sum512(key[:])
|
hash512 := blake3.Sum512(key)
|
||||||
copy(uPSKHash[:], hash512[:])
|
copy(uPSKHash[:], hash512[:])
|
||||||
|
|
||||||
if oldHash, loaded := s.uPSKHash[user]; loaded {
|
if oldHash, loaded := s.uPSKHash[user]; loaded {
|
||||||
|
@ -40,8 +46,9 @@ func (s *MultiService[U]) AddUser(user U, key []byte) {
|
||||||
|
|
||||||
s.uPSKHash[user] = uPSKHash
|
s.uPSKHash[user] = uPSKHash
|
||||||
s.uPSKHashR[uPSKHash] = user
|
s.uPSKHashR[uPSKHash] = user
|
||||||
|
|
||||||
s.uPSK[user] = key
|
s.uPSK[user] = key
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MultiService[U]) RemoveUser(user U) {
|
func (s *MultiService[U]) RemoveUser(user U) {
|
||||||
|
@ -102,7 +109,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
||||||
}
|
}
|
||||||
|
|
||||||
keyMaterial := buf.Make(s.keyLength + SaltSize)
|
keyMaterial := buf.Make(s.keyLength + SaltSize)
|
||||||
copy(keyMaterial, s.psk[:])
|
copy(keyMaterial, s.psk)
|
||||||
copy(keyMaterial[s.keyLength:], requestSalt)
|
copy(keyMaterial[s.keyLength:], requestSalt)
|
||||||
_identitySubkey := buf.Make(s.keyLength)
|
_identitySubkey := buf.Make(s.keyLength)
|
||||||
identitySubkey := common.Dup(_identitySubkey)
|
identitySubkey := common.Dup(_identitySubkey)
|
||||||
|
@ -119,7 +126,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
||||||
return E.New("invalid request")
|
return E.New("invalid request")
|
||||||
}
|
}
|
||||||
|
|
||||||
requestKey := Blake3DeriveKey(uPSK[:], requestSalt, s.keyLength)
|
requestKey := DeriveSessionKey(uPSK, requestSalt, s.keyLength)
|
||||||
reader := shadowaead.NewReader(
|
reader := shadowaead.NewReader(
|
||||||
conn,
|
conn,
|
||||||
s.constructor(common.Dup(requestKey)),
|
s.constructor(common.Dup(requestKey)),
|
||||||
|
@ -223,7 +230,7 @@ func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metad
|
||||||
})
|
})
|
||||||
if !loaded {
|
if !loaded {
|
||||||
session.remoteSessionId = sessionId
|
session.remoteSessionId = sessionId
|
||||||
key := Blake3DeriveKey(uPSK[:], packetHeader[:8], s.keyLength)
|
key := DeriveSessionKey(uPSK, packetHeader[:8], s.keyLength)
|
||||||
session.remoteCipher = s.constructor(common.Dup(key))
|
session.remoteCipher = s.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
}
|
}
|
||||||
|
@ -305,7 +312,7 @@ func (m *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession {
|
||||||
session.packetId--
|
session.packetId--
|
||||||
sessionId := make([]byte, 8)
|
sessionId := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||||
key := Blake3DeriveKey(uPSK[:], sessionId, m.keyLength)
|
key := DeriveSessionKey(uPSK, sessionId, m.keyLength)
|
||||||
session.cipher = m.constructor(common.Dup(key))
|
session.cipher = m.constructor(common.Dup(key))
|
||||||
runtime.KeepAlive(key)
|
runtime.KeepAlive(key)
|
||||||
return session
|
return session
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue