Fix shadowsocks blake3-aes-128-gcm implementation

This commit is contained in:
世界 2022-05-12 13:23:34 +08:00
parent 1adec71629
commit 69d801e651
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 81 additions and 59 deletions

View file

@ -28,8 +28,8 @@ var List = []string{
} }
var ( var (
ErrBadKey = E.New("bad key") ErrBadKey = E.New("shadowsocks: bad key")
ErrMissingPassword = E.New("missing password") ErrMissingPassword = E.New("shadowsocks: missing password")
) )
func New(method string, key []byte, password []byte, secureRNG io.Reader) (shadowsocks.Method, error) { func New(method string, key []byte, password []byte, secureRNG io.Reader) (shadowsocks.Method, error) {

View file

@ -31,7 +31,7 @@ const (
HeaderTypeClient = 0 HeaderTypeClient = 0
HeaderTypeServer = 1 HeaderTypeServer = 1
MaxPaddingLength = 900 MaxPaddingLength = 900
KeySaltSize = 32 SaltSize = 32
PacketNonceSize = 24 PacketNonceSize = 24
MaxPacketSize = 65535 MaxPacketSize = 65535
) )
@ -48,12 +48,12 @@ const (
) )
var ( var (
ErrBadHeaderType = E.New("bad header type") ErrBadHeaderType = E.New("shadowsocks: bad header type")
ErrBadTimestamp = E.New("bad timestamp") ErrBadTimestamp = E.New("shadowsocks: bad timestamp")
ErrBadRequestSalt = E.New("bad request salt") ErrBadRequestSalt = E.New("shadowsocks: bad request salt")
ErrBadClientSessionId = E.New("bad client session id") ErrBadClientSessionId = E.New("shadowsocks: bad client session id")
ErrPacketIdNotUnique = E.New("packet id not unique") ErrPacketIdNotUnique = E.New("shadowsocks: packet id not unique")
ErrTooManyServerSessions = E.New("server session changed more than once during the last minute") ErrTooManyServerSessions = E.New("shadowsocks: server session changed more than once during the last minute")
) )
var List = []string{ var List = []string{
@ -62,7 +62,7 @@ var List = []string{
"2022-blake3-chacha20-poly1305", "2022-blake3-chacha20-poly1305",
} }
func New(method string, pskList [][KeySaltSize]byte, secureRNG io.Reader) (shadowsocks.Method, error) { func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Method, error) {
m := &Method{ m := &Method{
name: method, name: method,
psk: pskList[len(pskList)-1], psk: pskList[len(pskList)-1],
@ -71,13 +71,33 @@ func New(method string, pskList [][KeySaltSize]byte, secureRNG io.Reader) (shado
replayFilter: replay.NewCuckoo(60), replayFilter: replay.NewCuckoo(60),
} }
switch method {
case "2022-blake3-aes-128-gcm":
m.keyLength = 16
m.constructor = newAESGCM
m.blockConstructor = newAES
case "2022-blake3-aes-256-gcm":
m.keyLength = 32
m.constructor = newAESGCM
m.blockConstructor = newAES
case "2022-blake3-chacha20-poly1305":
m.keyLength = 32
m.constructor = newChacha20Poly1305
}
for _, psk := range pskList {
if len(psk) != m.keyLength {
return nil, shadowaead.ErrBadKey
}
}
if len(pskList) > 1 { if len(pskList) > 1 {
pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize) pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize)
for i, psk := range pskList { for i, psk := range pskList {
if i == 0 { if i == 0 {
continue continue
} }
hash := blake3.Sum512(psk[:]) hash := blake3.Sum512(psk)
copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], hash[:aes.BlockSize]) copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], hash[:aes.BlockSize])
} }
m.pskHash = pskHash m.pskHash = pskHash
@ -85,26 +105,18 @@ func New(method string, pskList [][KeySaltSize]byte, secureRNG io.Reader) (shado
switch method { switch method {
case "2022-blake3-aes-128-gcm": case "2022-blake3-aes-128-gcm":
m.keyLength = 16 m.udpBlockCipher = newAES(pskList[0])
m.constructor = newAESGCM
m.blockConstructor = newAES
m.udpBlockCipher = newAES(pskList[0][:])
case "2022-blake3-aes-256-gcm": case "2022-blake3-aes-256-gcm":
m.keyLength = 32 m.udpBlockCipher = newAES(pskList[0])
m.constructor = newAESGCM
m.blockConstructor = newAES
m.udpBlockCipher = newAES(pskList[0][:])
case "2022-blake3-chacha20-poly1305": case "2022-blake3-chacha20-poly1305":
m.keyLength = 32 m.udpCipher = newXChacha20Poly1305(m.psk)
m.constructor = newChacha20Poly1305
m.udpCipher = newXChacha20Poly1305(m.psk[:])
} }
return m, nil return m, nil
} }
func Blake3DeriveKey(psk []byte, salt []byte, keyLength int) []byte { func Blake3DeriveKey(psk []byte, salt []byte, keyLength int) []byte {
sessionKey := buf.Make(len(psk) + len(salt)) sessionKey := buf.Make(len(psk) + len(salt))
copy(sessionKey, psk[:]) copy(sessionKey, psk)
copy(sessionKey[len(psk):], salt) copy(sessionKey[len(psk):], salt)
outKey := buf.Make(keyLength) outKey := buf.Make(keyLength)
blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey) blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
@ -144,8 +156,8 @@ type Method struct {
blockConstructor func(key []byte) cipher.Block blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD udpCipher cipher.AEAD
udpBlockCipher cipher.Block udpBlockCipher cipher.Block
psk [KeySaltSize]byte psk []byte
pskList [][KeySaltSize]byte pskList [][]byte
pskHash []byte pskHash []byte
secureRNG io.Reader secureRNG io.Reader
replayFilter replay.Filter replayFilter replay.Filter
@ -198,9 +210,9 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
return return
} }
for i, psk := range m.pskList { for i, psk := range m.pskList {
keyMaterial := make([]byte, 2*KeySaltSize) keyMaterial := buf.Make(m.keyLength + SaltSize)
copy(keyMaterial, psk[:]) copy(keyMaterial, psk)
copy(keyMaterial[KeySaltSize:], salt) copy(keyMaterial[m.keyLength:], salt)
_identitySubkey := buf.Make(m.keyLength) _identitySubkey := buf.Make(m.keyLength)
identitySubkey := common.Dup(_identitySubkey) identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
@ -217,10 +229,10 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
} }
func (c *clientConn) writeRequest(payload []byte) error { func (c *clientConn) writeRequest(payload []byte) error {
salt := make([]byte, KeySaltSize) salt := make([]byte, SaltSize)
common.Must1(io.ReadFull(c.method.secureRNG, salt)) common.Must1(io.ReadFull(c.method.secureRNG, salt))
key := Blake3DeriveKey(c.method.psk[:], salt, c.method.keyLength) key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength)
writer := shadowaead.NewWriter( writer := shadowaead.NewWriter(
c.Conn, c.Conn,
c.method.constructor(common.Dup(key)), c.method.constructor(common.Dup(key)),
@ -278,7 +290,7 @@ func (c *clientConn) readResponse() error {
return nil return nil
} }
_salt := make([]byte, KeySaltSize) _salt := make([]byte, SaltSize)
salt := common.Dup(_salt) salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt) _, err := io.ReadFull(c.Conn, salt)
if err != nil { if err != nil {
@ -289,7 +301,7 @@ func (c *clientConn) readResponse() error {
return E.New("salt not unique") return E.New("salt not unique")
} }
key := Blake3DeriveKey(c.method.psk[:], salt, c.method.keyLength) key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength)
runtime.KeepAlive(_salt) runtime.KeepAlive(_salt)
reader := shadowaead.NewReader( reader := shadowaead.NewReader(
c.Conn, c.Conn,
@ -317,7 +329,7 @@ func (c *clientConn) readResponse() error {
return ErrBadTimestamp return ErrBadTimestamp
} }
_requestSalt := make([]byte, KeySaltSize) _requestSalt := make([]byte, SaltSize)
requestSalt := common.Dup(_requestSalt) requestSalt := common.Dup(_requestSalt)
_, err = io.ReadFull(reader, requestSalt) _, err = io.ReadFull(reader, requestSalt)
if err != nil { if err != nil {
@ -436,7 +448,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
for textI := 0; textI < aes.BlockSize; textI++ { for textI := 0; textI < aes.BlockSize; textI++ {
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI) identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
} }
c.method.blockConstructor(psk[:]).Encrypt(identityHeader, identityHeader) c.method.blockConstructor(psk).Encrypt(identityHeader, identityHeader)
if i == pskLen-2 { if i == pskLen-2 {
break break
@ -504,7 +516,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
} else if sessionId == c.session.lastRemoteSessionId { } else if sessionId == c.session.lastRemoteSessionId {
remoteCipher = c.session.lastRemoteCipher remoteCipher = c.session.lastRemoteCipher
} else { } else {
key := Blake3DeriveKey(c.method.psk[:], packetHeader[:8], c.method.keyLength) key := Blake3DeriveKey(c.method.psk, packetHeader[:8], c.method.keyLength)
remoteCipher = c.method.constructor(common.Dup(key)) remoteCipher = c.method.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }
@ -644,7 +656,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
for textI := 0; textI < aes.BlockSize; textI++ { for textI := 0; textI < aes.BlockSize; textI++ {
identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI) identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI)
} }
c.method.blockConstructor(psk[:]).Encrypt(identityHeader, identityHeader) c.method.blockConstructor(psk).Encrypt(identityHeader, identityHeader)
if i == pskLen-2 { if i == pskLen-2 {
break break
@ -704,7 +716,7 @@ func (m *Method) newUDPSession() *udpSession {
if m.udpCipher == nil { if m.udpCipher == nil {
sessionId := make([]byte, 8) sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength) key := Blake3DeriveKey(m.psk, sessionId, m.keyLength)
session.cipher = m.constructor(common.Dup(key)) session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key) runtime.KeepAlive(key)
} }

View file

@ -35,14 +35,14 @@ type Service struct {
blockConstructor func(key []byte) cipher.Block blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD udpCipher cipher.AEAD
udpBlockCipher cipher.Block udpBlockCipher cipher.Block
psk [KeySaltSize]byte psk []byte
replayFilter replay.Filter replayFilter replay.Filter
handler shadowsocks.Handler handler shadowsocks.Handler
udpNat *udpnat.Service[uint64] udpNat *udpnat.Service[uint64]
sessions *cache.LruCache[uint64, *serverUDPSession] sessions *cache.LruCache[uint64, *serverUDPSession]
} }
func NewService(method string, psk [KeySaltSize]byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{ s := &Service{
name: method, name: method,
psk: psk, psk: psk,
@ -61,15 +61,25 @@ func NewService(method string, psk [KeySaltSize]byte, secureRNG io.Reader, udpTi
s.keyLength = 16 s.keyLength = 16
s.constructor = newAESGCM s.constructor = newAESGCM
s.blockConstructor = newAES s.blockConstructor = newAES
s.udpBlockCipher = newAES(s.psk[:])
case "2022-blake3-aes-256-gcm": case "2022-blake3-aes-256-gcm":
s.keyLength = 32 s.keyLength = 32
s.constructor = newAESGCM s.constructor = newAESGCM
s.blockConstructor = newAES s.blockConstructor = newAES
s.udpBlockCipher = newAES(s.psk[:])
case "2022-blake3-chacha20-poly1305": case "2022-blake3-chacha20-poly1305":
s.keyLength = 32 s.keyLength = 32
s.constructor = newChacha20Poly1305 s.constructor = newChacha20Poly1305
}
if len(psk) != s.keyLength {
return nil, shadowaead.ErrBadKey
}
switch method {
case "2022-blake3-aes-128-gcm":
s.udpBlockCipher = newAES(s.psk[:])
case "2022-blake3-aes-256-gcm":
s.udpBlockCipher = newAES(s.psk[:])
case "2022-blake3-chacha20-poly1305":
s.udpCipher = newXChacha20Poly1305(s.psk[:]) s.udpCipher = newXChacha20Poly1305(s.psk[:])
} }
@ -85,7 +95,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
} }
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
requestSalt := make([]byte, KeySaltSize) requestSalt := make([]byte, SaltSize)
_, err := io.ReadFull(conn, requestSalt) _, err := io.ReadFull(conn, requestSalt)
if err != nil { if err != nil {
return E.Cause(err, "read request salt") return E.Cause(err, "read request salt")
@ -155,7 +165,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
type serverConn struct { type serverConn struct {
*Service *Service
net.Conn net.Conn
uPSK [KeySaltSize]byte uPSK []byte
access sync.Mutex access sync.Mutex
reader *shadowaead.Reader reader *shadowaead.Reader
writer *shadowaead.Writer writer *shadowaead.Writer
@ -163,7 +173,7 @@ type serverConn struct {
} }
func (c *serverConn) writeResponse(payload []byte) (n int, err error) { func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
var _salt [KeySaltSize]byte var _salt [SaltSize]byte
salt := common.Dup(_salt[:]) salt := common.Dup(_salt[:])
common.Must1(io.ReadFull(c.secureRNG, salt)) common.Must1(io.ReadFull(c.secureRNG, salt))
key := Blake3DeriveKey(c.uPSK[:], salt, c.keyLength) key := Blake3DeriveKey(c.uPSK[:], salt, c.keyLength)

View file

@ -24,12 +24,12 @@ import (
type MultiService[U comparable] struct { type MultiService[U comparable] struct {
*Service *Service
uPSK map[U][KeySaltSize]byte uPSK map[U][]byte
uPSKHash map[U][aes.BlockSize]byte uPSKHash map[U][aes.BlockSize]byte
uPSKHashR map[[aes.BlockSize]byte]U uPSKHashR map[[aes.BlockSize]byte]U
} }
func (s *MultiService[U]) AddUser(user U, key [KeySaltSize]byte) { func (s *MultiService[U]) AddUser(user U, key []byte) {
var uPSKHash [aes.BlockSize]byte var uPSKHash [aes.BlockSize]byte
hash512 := blake3.Sum512(key[:]) hash512 := blake3.Sum512(key[:])
copy(uPSKHash[:], hash512[:]) copy(uPSKHash[:], hash512[:])
@ -52,7 +52,7 @@ func (s *MultiService[U]) RemoveUser(user U) {
delete(s.uPSKHash, user) delete(s.uPSKHash, user)
} }
func NewMultiService[U comparable](method string, iPSK [KeySaltSize]byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { func NewMultiService[U comparable](method string, iPSK []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) {
switch method { switch method {
case "2022-blake3-aes-128-gcm": case "2022-blake3-aes-128-gcm":
case "2022-blake3-aes-256-gcm": case "2022-blake3-aes-256-gcm":
@ -68,7 +68,7 @@ func NewMultiService[U comparable](method string, iPSK [KeySaltSize]byte, secure
s := &MultiService[U]{ s := &MultiService[U]{
Service: ss.(*Service), Service: ss.(*Service),
uPSK: make(map[U][KeySaltSize]byte), uPSK: make(map[U][]byte),
uPSKHash: make(map[U][aes.BlockSize]byte), uPSKHash: make(map[U][aes.BlockSize]byte),
uPSKHashR: make(map[[aes.BlockSize]byte]U), uPSKHashR: make(map[[aes.BlockSize]byte]U),
} }
@ -84,7 +84,7 @@ func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, meta
} }
func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
requestSalt := make([]byte, KeySaltSize) requestSalt := make([]byte, SaltSize)
_, err := io.ReadFull(conn, requestSalt) _, err := io.ReadFull(conn, requestSalt)
if err != nil { if err != nil {
return E.Cause(err, "read request salt") return E.Cause(err, "read request salt")
@ -101,9 +101,9 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
return E.Cause(err, "read extended identity header") return E.Cause(err, "read extended identity header")
} }
keyMaterial := make([]byte, 2*KeySaltSize) keyMaterial := buf.Make(s.keyLength + SaltSize)
copy(keyMaterial, s.psk[:]) copy(keyMaterial, s.psk[:])
copy(keyMaterial[KeySaltSize:], requestSalt) copy(keyMaterial[s.keyLength:], requestSalt)
_identitySubkey := buf.Make(s.keyLength) _identitySubkey := buf.Make(s.keyLength)
identitySubkey := common.Dup(_identitySubkey) identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
@ -111,7 +111,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
runtime.KeepAlive(_identitySubkey) runtime.KeepAlive(_identitySubkey)
var user U var user U
var uPSK [KeySaltSize]byte var uPSK []byte
if u, loaded := s.uPSKHashR[_eiHeader]; loaded { if u, loaded := s.uPSKHashR[_eiHeader]; loaded {
user = u user = u
uPSK = s.uPSK[u] uPSK = s.uPSK[u]
@ -200,7 +200,7 @@ func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metad
} }
var user U var user U
var uPSK [KeySaltSize]byte var uPSK []byte
if u, loaded := s.uPSKHashR[_eiHeader]; loaded { if u, loaded := s.uPSKHashR[_eiHeader]; loaded {
user = u user = u
uPSK = s.uPSK[u] uPSK = s.uPSK[u]
@ -299,7 +299,7 @@ process:
return nil return nil
} }
func (m *MultiService[U]) newUDPSession(uPSK [KeySaltSize]byte) *serverUDPSession { func (m *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession {
session := &serverUDPSession{} session := &serverUDPSession{}
common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId)) common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
session.packetId-- session.packetId--

View file

@ -15,21 +15,21 @@ import (
func TestMultiService(t *testing.T) { func TestMultiService(t *testing.T) {
method := "2022-blake3-aes-128-gcm" method := "2022-blake3-aes-128-gcm"
var iPSK [shadowaead_2022.KeySaltSize]byte var iPSK [16]byte
random.Default.Read(iPSK[:]) random.Default.Read(iPSK[:])
var wg sync.WaitGroup var wg sync.WaitGroup
multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK, random.Default, 500, &multiHandler{t, &wg}) multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], random.Default, 500, &multiHandler{t, &wg})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var uPSK [shadowaead_2022.KeySaltSize]byte var uPSK [16]byte
random.Default.Read(uPSK[:]) random.Default.Read(uPSK[:])
multiService.AddUser("my user", uPSK) multiService.AddUser("my user", uPSK[:])
client, err := shadowaead_2022.New(method, [][shadowaead_2022.KeySaltSize]byte{iPSK, uPSK}, random.Default) client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, random.Default)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }