diff --git a/cli/ss-local/main.go b/cli/ss-local/main.go index a7d9928..680f419 100644 --- a/cli/ss-local/main.go +++ b/cli/ss-local/main.go @@ -211,21 +211,16 @@ func newClient(f *flags) (*client, error) { } c.method = method } else if common.Contains(shadowaead_2022.List, f.Method) { - var pskList [][shadowaead_2022.KeySaltSize]byte + var pskList [][]byte if f.Key != "" { keyStrList := strings.Split(f.Key, ":") - pskList = make([][shadowaead_2022.KeySaltSize]byte, len(keyStrList)) + pskList = make([][]byte, len(keyStrList)) for i, keyStr := range keyStrList { - var psk [shadowaead_2022.KeySaltSize]byte kb, err := base64.StdEncoding.DecodeString(keyStr) if err != nil { return nil, E.Cause(err, "decode key") } - if len(kb) != shadowaead_2022.KeySaltSize { - return nil, shadowaead.ErrBadKey - } - copy(psk[:], kb) - pskList[i] = psk + pskList[i] = kb } } method, err := shadowaead_2022.New(f.Method, pskList, rng) diff --git a/cli/ss-server/main.go b/cli/ss-server/main.go index 4301163..210cc1c 100644 --- a/cli/ss-server/main.go +++ b/cli/ss-server/main.go @@ -149,16 +149,13 @@ func newServer(f *flags) (*server, error) { } s.service = service } else if common.Contains(shadowaead_2022.List, f.Method) { - var key [shadowaead_2022.KeySaltSize]byte + var key []byte if f.Key != "" { kb, err := base64.StdEncoding.DecodeString(f.Key) if err != nil { return nil, E.Cause(err, "decode key") } - if len(kb) != shadowaead_2022.KeySaltSize { - return nil, shadowaead.ErrBadKey - } - copy(key[:], kb) + key = kb } service, err := shadowaead_2022.NewService(f.Method, key, random.Default, udpTimeout, s) if err != nil { diff --git a/protocol/shadowsocks/shadowaead_2022/protocol.go b/protocol/shadowsocks/shadowaead_2022/protocol.go index afc1e0c..842b8a8 100644 --- a/protocol/shadowsocks/shadowaead_2022/protocol.go +++ b/protocol/shadowsocks/shadowaead_2022/protocol.go @@ -65,7 +65,6 @@ var List = []string{ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Method, error) { m := &Method{ name: method, - psk: pskList[len(pskList)-1], pskList: pskList, secureRNG: secureRNG, replayFilter: replay.NewCuckoo(60), @@ -85,12 +84,16 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth m.constructor = newChacha20Poly1305 } - for _, psk := range pskList { - if len(psk) != m.keyLength { + for i, psk := range pskList { + if len(psk) < m.keyLength { return nil, shadowaead.ErrBadKey + } else if len(psk) > m.keyLength { + pskList[i] = DerivePSK(psk, m.keyLength) } } + m.psk = pskList[len(pskList)-1] + if len(pskList) > 1 { pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize) for i, psk := range pskList { @@ -111,10 +114,17 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth case "2022-blake3-chacha20-poly1305": m.udpCipher = newXChacha20Poly1305(m.psk) } + return m, nil } -func Blake3DeriveKey(psk []byte, salt []byte, keyLength int) []byte { +func DerivePSK(key []byte, keyLength int) []byte { + outKey := buf.Make(keyLength) + blake3.DeriveKey(outKey, "shadowsocks 2022 pre shared key", key) + return outKey +} + +func DeriveSessionKey(psk []byte, salt []byte, keyLength int) []byte { sessionKey := buf.Make(len(psk) + len(salt)) copy(sessionKey, psk) copy(sessionKey[len(psk):], salt) @@ -232,7 +242,7 @@ func (c *clientConn) writeRequest(payload []byte) error { salt := make([]byte, SaltSize) common.Must1(io.ReadFull(c.method.secureRNG, salt)) - key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength) + key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength) writer := shadowaead.NewWriter( c.Conn, c.method.constructor(common.Dup(key)), @@ -301,7 +311,7 @@ func (c *clientConn) readResponse() error { return E.New("salt not unique") } - key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength) + key := DeriveSessionKey(c.method.psk, salt, c.method.keyLength) runtime.KeepAlive(_salt) reader := shadowaead.NewReader( c.Conn, @@ -516,7 +526,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { } else if sessionId == c.session.lastRemoteSessionId { remoteCipher = c.session.lastRemoteCipher } else { - key := Blake3DeriveKey(c.method.psk, packetHeader[:8], c.method.keyLength) + key := DeriveSessionKey(c.method.psk, packetHeader[:8], c.method.keyLength) remoteCipher = c.method.constructor(common.Dup(key)) runtime.KeepAlive(key) } @@ -716,7 +726,7 @@ func (m *Method) newUDPSession() *udpSession { if m.udpCipher == nil { sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) - key := Blake3DeriveKey(m.psk, sessionId, m.keyLength) + key := DeriveSessionKey(m.psk, sessionId, m.keyLength) session.cipher = m.constructor(common.Dup(key)) runtime.KeepAlive(key) } diff --git a/protocol/shadowsocks/shadowaead_2022/service.go b/protocol/shadowsocks/shadowaead_2022/service.go index b7d6055..e1fcf9b 100644 --- a/protocol/shadowsocks/shadowaead_2022/service.go +++ b/protocol/shadowsocks/shadowaead_2022/service.go @@ -45,7 +45,6 @@ type Service struct { func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { s := &Service{ name: method, - psk: psk, secureRNG: secureRNG, replayFilter: replay.NewCuckoo(60), handler: handler, @@ -70,17 +69,20 @@ func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64 s.constructor = newChacha20Poly1305 } - if len(psk) != s.keyLength { + if len(psk) < s.keyLength { return nil, shadowaead.ErrBadKey + } else if len(psk) > s.keyLength { + psk = DerivePSK(psk, s.keyLength) } + s.psk = psk switch method { case "2022-blake3-aes-128-gcm": - s.udpBlockCipher = newAES(s.psk[:]) + s.udpBlockCipher = newAES(psk) case "2022-blake3-aes-256-gcm": - s.udpBlockCipher = newAES(s.psk[:]) + s.udpBlockCipher = newAES(psk) case "2022-blake3-chacha20-poly1305": - s.udpCipher = newXChacha20Poly1305(s.psk[:]) + s.udpCipher = newXChacha20Poly1305(psk) } return s, nil @@ -105,7 +107,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M return E.New("salt not unique") } - requestKey := Blake3DeriveKey(s.psk[:], requestSalt, s.keyLength) + requestKey := DeriveSessionKey(s.psk, requestSalt, s.keyLength) reader := shadowaead.NewReader( conn, s.constructor(common.Dup(requestKey)), @@ -176,7 +178,7 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) { var _salt [SaltSize]byte salt := common.Dup(_salt[:]) common.Must1(io.ReadFull(c.secureRNG, salt)) - key := Blake3DeriveKey(c.uPSK[:], salt, c.keyLength) + key := DeriveSessionKey(c.uPSK, salt, c.keyLength) runtime.KeepAlive(_salt) writer := shadowaead.NewWriter( c.Conn, @@ -292,7 +294,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me if !loaded { session.remoteSessionId = sessionId if packetHeader != nil { - key := Blake3DeriveKey(s.psk[:], packetHeader[:8], s.keyLength) + key := DeriveSessionKey(s.psk, packetHeader[:8], s.keyLength) session.remoteCipher = s.constructor(common.Dup(key)) runtime.KeepAlive(key) } @@ -437,7 +439,7 @@ func (m *Service) newUDPSession() *serverUDPSession { if m.udpCipher == nil { sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) - key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength) + key := DeriveSessionKey(m.psk, sessionId, m.keyLength) session.cipher = m.constructor(common.Dup(key)) runtime.KeepAlive(key) } diff --git a/protocol/shadowsocks/shadowaead_2022/service_multi.go b/protocol/shadowsocks/shadowaead_2022/service_multi.go index c55d4d4..9e2880d 100644 --- a/protocol/shadowsocks/shadowaead_2022/service_multi.go +++ b/protocol/shadowsocks/shadowaead_2022/service_multi.go @@ -29,9 +29,15 @@ type MultiService[U comparable] struct { uPSKHashR map[[aes.BlockSize]byte]U } -func (s *MultiService[U]) AddUser(user U, key []byte) { +func (s *MultiService[U]) AddUser(user U, key []byte) error { + if len(key) < s.keyLength { + return shadowaead.ErrBadKey + } else if len(key) > s.keyLength { + key = DerivePSK(key, s.keyLength) + } + var uPSKHash [aes.BlockSize]byte - hash512 := blake3.Sum512(key[:]) + hash512 := blake3.Sum512(key) copy(uPSKHash[:], hash512[:]) if oldHash, loaded := s.uPSKHash[user]; loaded { @@ -40,8 +46,9 @@ func (s *MultiService[U]) AddUser(user U, key []byte) { s.uPSKHash[user] = uPSKHash s.uPSKHashR[uPSKHash] = user - s.uPSK[user] = key + + return nil } func (s *MultiService[U]) RemoveUser(user U) { @@ -102,7 +109,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta } keyMaterial := buf.Make(s.keyLength + SaltSize) - copy(keyMaterial, s.psk[:]) + copy(keyMaterial, s.psk) copy(keyMaterial[s.keyLength:], requestSalt) _identitySubkey := buf.Make(s.keyLength) identitySubkey := common.Dup(_identitySubkey) @@ -119,7 +126,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta return E.New("invalid request") } - requestKey := Blake3DeriveKey(uPSK[:], requestSalt, s.keyLength) + requestKey := DeriveSessionKey(uPSK, requestSalt, s.keyLength) reader := shadowaead.NewReader( conn, s.constructor(common.Dup(requestKey)), @@ -223,7 +230,7 @@ func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metad }) if !loaded { session.remoteSessionId = sessionId - key := Blake3DeriveKey(uPSK[:], packetHeader[:8], s.keyLength) + key := DeriveSessionKey(uPSK, packetHeader[:8], s.keyLength) session.remoteCipher = s.constructor(common.Dup(key)) runtime.KeepAlive(key) } @@ -305,7 +312,7 @@ func (m *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession { session.packetId-- sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) - key := Blake3DeriveKey(uPSK[:], sessionId, m.keyLength) + key := DeriveSessionKey(uPSK, sessionId, m.keyLength) session.cipher = m.constructor(common.Dup(key)) runtime.KeepAlive(key) return session