Allow over length for shadowsocks 2022 psk

This commit is contained in:
世界 2022-05-12 13:48:54 +08:00
parent 69d801e651
commit f1a5f8aaa3
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 48 additions and 37 deletions

View file

@ -211,21 +211,16 @@ func newClient(f *flags) (*client, error) {
} }
c.method = method c.method = method
} else if common.Contains(shadowaead_2022.List, f.Method) { } else if common.Contains(shadowaead_2022.List, f.Method) {
var pskList [][shadowaead_2022.KeySaltSize]byte var pskList [][]byte
if f.Key != "" { if f.Key != "" {
keyStrList := strings.Split(f.Key, ":") keyStrList := strings.Split(f.Key, ":")
pskList = make([][shadowaead_2022.KeySaltSize]byte, len(keyStrList)) pskList = make([][]byte, len(keyStrList))
for i, keyStr := range keyStrList { for i, keyStr := range keyStrList {
var psk [shadowaead_2022.KeySaltSize]byte
kb, err := base64.StdEncoding.DecodeString(keyStr) kb, err := base64.StdEncoding.DecodeString(keyStr)
if err != nil { if err != nil {
return nil, E.Cause(err, "decode key") return nil, E.Cause(err, "decode key")
} }
if len(kb) != shadowaead_2022.KeySaltSize { pskList[i] = kb
return nil, shadowaead.ErrBadKey
}
copy(psk[:], kb)
pskList[i] = psk
} }
} }
method, err := shadowaead_2022.New(f.Method, pskList, rng) method, err := shadowaead_2022.New(f.Method, pskList, rng)

View file

@ -149,16 +149,13 @@ func newServer(f *flags) (*server, error) {
} }
s.service = service s.service = service
} else if common.Contains(shadowaead_2022.List, f.Method) { } else if common.Contains(shadowaead_2022.List, f.Method) {
var key [shadowaead_2022.KeySaltSize]byte var key []byte
if f.Key != "" { if f.Key != "" {
kb, err := base64.StdEncoding.DecodeString(f.Key) kb, err := base64.StdEncoding.DecodeString(f.Key)
if err != nil { if err != nil {
return nil, E.Cause(err, "decode key") return nil, E.Cause(err, "decode key")
} }
if len(kb) != shadowaead_2022.KeySaltSize { key = kb
return nil, shadowaead.ErrBadKey
}
copy(key[:], kb)
} }
service, err := shadowaead_2022.NewService(f.Method, key, random.Default, udpTimeout, s) service, err := shadowaead_2022.NewService(f.Method, key, random.Default, udpTimeout, s)
if err != nil { if err != nil {

View file

@ -65,7 +65,6 @@ var List = []string{
func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Method, error) { func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Method, error) {
m := &Method{ m := &Method{
name: method, name: method,
psk: pskList[len(pskList)-1],
pskList: pskList, pskList: pskList,
secureRNG: secureRNG, secureRNG: secureRNG,
replayFilter: replay.NewCuckoo(60), replayFilter: replay.NewCuckoo(60),
@ -85,12 +84,16 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
m.constructor = newChacha20Poly1305 m.constructor = newChacha20Poly1305
} }
for _, psk := range pskList { for i, psk := range pskList {
if len(psk) != m.keyLength { if len(psk) < m.keyLength {
return nil, shadowaead.ErrBadKey return nil, shadowaead.ErrBadKey
} else if len(psk) > m.keyLength {
pskList[i] = DerivePSK(psk, m.keyLength)
} }
} }
m.psk = pskList[len(pskList)-1]
if len(pskList) > 1 { if len(pskList) > 1 {
pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize) pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize)
for i, psk := range pskList { for i, psk := range pskList {
@ -111,10 +114,17 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
case "2022-blake3-chacha20-poly1305": case "2022-blake3-chacha20-poly1305":
m.udpCipher = newXChacha20Poly1305(m.psk) m.udpCipher = newXChacha20Poly1305(m.psk)
} }
return m, nil return m, nil
} }
func Blake3DeriveKey(psk []byte, salt []byte, keyLength int) []byte { func DerivePSK(key []byte, keyLength int) []byte {
outKey := buf.Make(keyLength)
blake3.DeriveKey(outKey, "shadowsocks 2022 pre shared key", key)
return outKey
}
func DeriveSessionKey(psk []byte, salt []byte, keyLength int) []byte {
sessionKey := buf.Make(len(psk) + len(salt)) sessionKey := buf.Make(len(psk) + len(salt))
copy(sessionKey, psk) copy(sessionKey, psk)
copy(sessionKey[len(psk):], salt) copy(sessionKey[len(psk):], salt)
@ -232,7 +242,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
salt := make([]byte, SaltSize) salt := make([]byte, SaltSize)
common.Must1(io.ReadFull(c.method.secureRNG, salt)) common.Must1(io.ReadFull(c.method.secureRNG, salt))
key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength) key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength)
writer := shadowaead.NewWriter( writer := shadowaead.NewWriter(
c.Conn, c.Conn,
c.method.constructor(common.Dup(key)), c.method.constructor(common.Dup(key)),
@ -301,7 +311,7 @@ func (c *clientConn) readResponse() error {
return E.New("salt not unique") return E.New("salt not unique")
} }
key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength) key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength)
runtime.KeepAlive(_salt) runtime.KeepAlive(_salt)
reader := shadowaead.NewReader( reader := shadowaead.NewReader(
c.Conn, c.Conn,
@ -516,7 +526,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
} else if sessionId == c.session.lastRemoteSessionId { } else if sessionId == c.session.lastRemoteSessionId {
remoteCipher = c.session.lastRemoteCipher remoteCipher = c.session.lastRemoteCipher
} else { } else {
key := Blake3DeriveKey(c.method.psk, packetHeader[:8], c.method.keyLength) key := DeriveSessionKey(c.method.psk, packetHeader[:8], c.method.keyLength)
remoteCipher = c.method.constructor(common.Dup(key)) remoteCipher = c.method.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }
@ -716,7 +726,7 @@ func (m *Method) newUDPSession() *udpSession {
if m.udpCipher == nil { if m.udpCipher == nil {
sessionId := make([]byte, 8) sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := Blake3DeriveKey(m.psk, sessionId, m.keyLength) key := DeriveSessionKey(m.psk, sessionId, m.keyLength)
session.cipher = m.constructor(common.Dup(key)) session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }

View file

@ -45,7 +45,6 @@ type Service struct {
func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{ s := &Service{
name: method, name: method,
psk: psk,
secureRNG: secureRNG, secureRNG: secureRNG,
replayFilter: replay.NewCuckoo(60), replayFilter: replay.NewCuckoo(60),
handler: handler, handler: handler,
@ -70,17 +69,20 @@ func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64
s.constructor = newChacha20Poly1305 s.constructor = newChacha20Poly1305
} }
if len(psk) != s.keyLength { if len(psk) < s.keyLength {
return nil, shadowaead.ErrBadKey return nil, shadowaead.ErrBadKey
} else if len(psk) > s.keyLength {
psk = DerivePSK(psk, s.keyLength)
} }
s.psk = psk
switch method { switch method {
case "2022-blake3-aes-128-gcm": case "2022-blake3-aes-128-gcm":
s.udpBlockCipher = newAES(s.psk[:]) s.udpBlockCipher = newAES(psk)
case "2022-blake3-aes-256-gcm": case "2022-blake3-aes-256-gcm":
s.udpBlockCipher = newAES(s.psk[:]) s.udpBlockCipher = newAES(psk)
case "2022-blake3-chacha20-poly1305": case "2022-blake3-chacha20-poly1305":
s.udpCipher = newXChacha20Poly1305(s.psk[:]) s.udpCipher = newXChacha20Poly1305(psk)
} }
return s, nil return s, nil
@ -105,7 +107,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
return E.New("salt not unique") return E.New("salt not unique")
} }
requestKey := Blake3DeriveKey(s.psk[:], requestSalt, s.keyLength) requestKey := DeriveSessionKey(s.psk, requestSalt, s.keyLength)
reader := shadowaead.NewReader( reader := shadowaead.NewReader(
conn, conn,
s.constructor(common.Dup(requestKey)), s.constructor(common.Dup(requestKey)),
@ -176,7 +178,7 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
var _salt [SaltSize]byte var _salt [SaltSize]byte
salt := common.Dup(_salt[:]) salt := common.Dup(_salt[:])
common.Must1(io.ReadFull(c.secureRNG, salt)) common.Must1(io.ReadFull(c.secureRNG, salt))
key := Blake3DeriveKey(c.uPSK[:], salt, c.keyLength) key := DeriveSessionKey(c.uPSK, salt, c.keyLength)
runtime.KeepAlive(_salt) runtime.KeepAlive(_salt)
writer := shadowaead.NewWriter( writer := shadowaead.NewWriter(
c.Conn, c.Conn,
@ -292,7 +294,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me
if !loaded { if !loaded {
session.remoteSessionId = sessionId session.remoteSessionId = sessionId
if packetHeader != nil { if packetHeader != nil {
key := Blake3DeriveKey(s.psk[:], packetHeader[:8], s.keyLength) key := DeriveSessionKey(s.psk, packetHeader[:8], s.keyLength)
session.remoteCipher = s.constructor(common.Dup(key)) session.remoteCipher = s.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }
@ -437,7 +439,7 @@ func (m *Service) newUDPSession() *serverUDPSession {
if m.udpCipher == nil { if m.udpCipher == nil {
sessionId := make([]byte, 8) sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength) key := DeriveSessionKey(m.psk, sessionId, m.keyLength)
session.cipher = m.constructor(common.Dup(key)) session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }

View file

@ -29,9 +29,15 @@ type MultiService[U comparable] struct {
uPSKHashR map[[aes.BlockSize]byte]U uPSKHashR map[[aes.BlockSize]byte]U
} }
func (s *MultiService[U]) AddUser(user U, key []byte) { func (s *MultiService[U]) AddUser(user U, key []byte) error {
if len(key) < s.keyLength {
return shadowaead.ErrBadKey
} else if len(key) > s.keyLength {
key = DerivePSK(key, s.keyLength)
}
var uPSKHash [aes.BlockSize]byte var uPSKHash [aes.BlockSize]byte
hash512 := blake3.Sum512(key[:]) hash512 := blake3.Sum512(key)
copy(uPSKHash[:], hash512[:]) copy(uPSKHash[:], hash512[:])
if oldHash, loaded := s.uPSKHash[user]; loaded { if oldHash, loaded := s.uPSKHash[user]; loaded {
@ -40,8 +46,9 @@ func (s *MultiService[U]) AddUser(user U, key []byte) {
s.uPSKHash[user] = uPSKHash s.uPSKHash[user] = uPSKHash
s.uPSKHashR[uPSKHash] = user s.uPSKHashR[uPSKHash] = user
s.uPSK[user] = key s.uPSK[user] = key
return nil
} }
func (s *MultiService[U]) RemoveUser(user U) { func (s *MultiService[U]) RemoveUser(user U) {
@ -102,7 +109,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
} }
keyMaterial := buf.Make(s.keyLength + SaltSize) keyMaterial := buf.Make(s.keyLength + SaltSize)
copy(keyMaterial, s.psk[:]) copy(keyMaterial, s.psk)
copy(keyMaterial[s.keyLength:], requestSalt) copy(keyMaterial[s.keyLength:], requestSalt)
_identitySubkey := buf.Make(s.keyLength) _identitySubkey := buf.Make(s.keyLength)
identitySubkey := common.Dup(_identitySubkey) identitySubkey := common.Dup(_identitySubkey)
@ -119,7 +126,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
return E.New("invalid request") return E.New("invalid request")
} }
requestKey := Blake3DeriveKey(uPSK[:], requestSalt, s.keyLength) requestKey := DeriveSessionKey(uPSK, requestSalt, s.keyLength)
reader := shadowaead.NewReader( reader := shadowaead.NewReader(
conn, conn,
s.constructor(common.Dup(requestKey)), s.constructor(common.Dup(requestKey)),
@ -223,7 +230,7 @@ func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metad
}) })
if !loaded { if !loaded {
session.remoteSessionId = sessionId session.remoteSessionId = sessionId
key := Blake3DeriveKey(uPSK[:], packetHeader[:8], s.keyLength) key := DeriveSessionKey(uPSK, packetHeader[:8], s.keyLength)
session.remoteCipher = s.constructor(common.Dup(key)) session.remoteCipher = s.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }
@ -305,7 +312,7 @@ func (m *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession {
session.packetId-- session.packetId--
sessionId := make([]byte, 8) sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := Blake3DeriveKey(uPSK[:], sessionId, m.keyLength) key := DeriveSessionKey(uPSK, sessionId, m.keyLength)
session.cipher = m.constructor(common.Dup(key)) session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
return session return session