diff --git a/protocol/shadowsocks/shadowaead/protocol.go b/protocol/shadowsocks/shadowaead/protocol.go index 21c5a08..fd4e64d 100644 --- a/protocol/shadowsocks/shadowaead/protocol.go +++ b/protocol/shadowsocks/shadowaead/protocol.go @@ -28,8 +28,8 @@ var List = []string{ } var ( - ErrBadKey = E.New("bad key") - ErrMissingPassword = E.New("missing password") + ErrBadKey = E.New("shadowsocks: bad key") + ErrMissingPassword = E.New("shadowsocks: missing password") ) func New(method string, key []byte, password []byte, secureRNG io.Reader) (shadowsocks.Method, error) { diff --git a/protocol/shadowsocks/shadowaead_2022/protocol.go b/protocol/shadowsocks/shadowaead_2022/protocol.go index 8c5881f..afc1e0c 100644 --- a/protocol/shadowsocks/shadowaead_2022/protocol.go +++ b/protocol/shadowsocks/shadowaead_2022/protocol.go @@ -31,7 +31,7 @@ const ( HeaderTypeClient = 0 HeaderTypeServer = 1 MaxPaddingLength = 900 - KeySaltSize = 32 + SaltSize = 32 PacketNonceSize = 24 MaxPacketSize = 65535 ) @@ -48,12 +48,12 @@ const ( ) var ( - ErrBadHeaderType = E.New("bad header type") - ErrBadTimestamp = E.New("bad timestamp") - ErrBadRequestSalt = E.New("bad request salt") - ErrBadClientSessionId = E.New("bad client session id") - ErrPacketIdNotUnique = E.New("packet id not unique") - ErrTooManyServerSessions = E.New("server session changed more than once during the last minute") + ErrBadHeaderType = E.New("shadowsocks: bad header type") + ErrBadTimestamp = E.New("shadowsocks: bad timestamp") + ErrBadRequestSalt = E.New("shadowsocks: bad request salt") + ErrBadClientSessionId = E.New("shadowsocks: bad client session id") + ErrPacketIdNotUnique = E.New("shadowsocks: packet id not unique") + ErrTooManyServerSessions = E.New("shadowsocks: server session changed more than once during the last minute") ) var List = []string{ @@ -62,7 +62,7 @@ var List = []string{ "2022-blake3-chacha20-poly1305", } -func New(method string, pskList [][KeySaltSize]byte, secureRNG io.Reader) (shadowsocks.Method, error) { +func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Method, error) { m := &Method{ name: method, psk: pskList[len(pskList)-1], @@ -71,13 +71,33 @@ func New(method string, pskList [][KeySaltSize]byte, secureRNG io.Reader) (shado replayFilter: replay.NewCuckoo(60), } + switch method { + case "2022-blake3-aes-128-gcm": + m.keyLength = 16 + m.constructor = newAESGCM + m.blockConstructor = newAES + case "2022-blake3-aes-256-gcm": + m.keyLength = 32 + m.constructor = newAESGCM + m.blockConstructor = newAES + case "2022-blake3-chacha20-poly1305": + m.keyLength = 32 + m.constructor = newChacha20Poly1305 + } + + for _, psk := range pskList { + if len(psk) != m.keyLength { + return nil, shadowaead.ErrBadKey + } + } + if len(pskList) > 1 { pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize) for i, psk := range pskList { if i == 0 { continue } - hash := blake3.Sum512(psk[:]) + hash := blake3.Sum512(psk) copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], hash[:aes.BlockSize]) } m.pskHash = pskHash @@ -85,26 +105,18 @@ func New(method string, pskList [][KeySaltSize]byte, secureRNG io.Reader) (shado switch method { case "2022-blake3-aes-128-gcm": - m.keyLength = 16 - m.constructor = newAESGCM - m.blockConstructor = newAES - m.udpBlockCipher = newAES(pskList[0][:]) + m.udpBlockCipher = newAES(pskList[0]) case "2022-blake3-aes-256-gcm": - m.keyLength = 32 - m.constructor = newAESGCM - m.blockConstructor = newAES - m.udpBlockCipher = newAES(pskList[0][:]) + m.udpBlockCipher = newAES(pskList[0]) case "2022-blake3-chacha20-poly1305": - m.keyLength = 32 - m.constructor = newChacha20Poly1305 - m.udpCipher = newXChacha20Poly1305(m.psk[:]) + m.udpCipher = newXChacha20Poly1305(m.psk) } return m, nil } func Blake3DeriveKey(psk []byte, salt []byte, keyLength int) []byte { sessionKey := buf.Make(len(psk) + len(salt)) - copy(sessionKey, psk[:]) + copy(sessionKey, psk) copy(sessionKey[len(psk):], salt) outKey := buf.Make(keyLength) blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey) @@ -144,8 +156,8 @@ type Method struct { blockConstructor func(key []byte) cipher.Block udpCipher cipher.AEAD udpBlockCipher cipher.Block - psk [KeySaltSize]byte - pskList [][KeySaltSize]byte + psk []byte + pskList [][]byte pskHash []byte secureRNG io.Reader replayFilter replay.Filter @@ -198,9 +210,9 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) return } for i, psk := range m.pskList { - keyMaterial := make([]byte, 2*KeySaltSize) - copy(keyMaterial, psk[:]) - copy(keyMaterial[KeySaltSize:], salt) + keyMaterial := buf.Make(m.keyLength + SaltSize) + copy(keyMaterial, psk) + copy(keyMaterial[m.keyLength:], salt) _identitySubkey := buf.Make(m.keyLength) identitySubkey := common.Dup(_identitySubkey) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) @@ -217,10 +229,10 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) } func (c *clientConn) writeRequest(payload []byte) error { - salt := make([]byte, KeySaltSize) + salt := make([]byte, SaltSize) common.Must1(io.ReadFull(c.method.secureRNG, salt)) - key := Blake3DeriveKey(c.method.psk[:], salt, c.method.keyLength) + key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength) writer := shadowaead.NewWriter( c.Conn, c.method.constructor(common.Dup(key)), @@ -278,7 +290,7 @@ func (c *clientConn) readResponse() error { return nil } - _salt := make([]byte, KeySaltSize) + _salt := make([]byte, SaltSize) salt := common.Dup(_salt) _, err := io.ReadFull(c.Conn, salt) if err != nil { @@ -289,7 +301,7 @@ func (c *clientConn) readResponse() error { return E.New("salt not unique") } - key := Blake3DeriveKey(c.method.psk[:], salt, c.method.keyLength) + key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength) runtime.KeepAlive(_salt) reader := shadowaead.NewReader( c.Conn, @@ -317,7 +329,7 @@ func (c *clientConn) readResponse() error { return ErrBadTimestamp } - _requestSalt := make([]byte, KeySaltSize) + _requestSalt := make([]byte, SaltSize) requestSalt := common.Dup(_requestSalt) _, err = io.ReadFull(reader, requestSalt) if err != nil { @@ -436,7 +448,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad for textI := 0; textI < aes.BlockSize; textI++ { identityHeader[textI] = pskHash[textI] ^ header.Byte(textI) } - c.method.blockConstructor(psk[:]).Encrypt(identityHeader, identityHeader) + c.method.blockConstructor(psk).Encrypt(identityHeader, identityHeader) if i == pskLen-2 { break @@ -504,7 +516,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { } else if sessionId == c.session.lastRemoteSessionId { remoteCipher = c.session.lastRemoteCipher } else { - key := Blake3DeriveKey(c.method.psk[:], packetHeader[:8], c.method.keyLength) + key := Blake3DeriveKey(c.method.psk, packetHeader[:8], c.method.keyLength) remoteCipher = c.method.constructor(common.Dup(key)) runtime.KeepAlive(key) } @@ -644,7 +656,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { for textI := 0; textI < aes.BlockSize; textI++ { identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI) } - c.method.blockConstructor(psk[:]).Encrypt(identityHeader, identityHeader) + c.method.blockConstructor(psk).Encrypt(identityHeader, identityHeader) if i == pskLen-2 { break @@ -704,7 +716,7 @@ func (m *Method) newUDPSession() *udpSession { if m.udpCipher == nil { sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) - key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength) + key := Blake3DeriveKey(m.psk, sessionId, m.keyLength) session.cipher = m.constructor(common.Dup(key)) runtime.KeepAlive(key) } diff --git a/protocol/shadowsocks/shadowaead_2022/service.go b/protocol/shadowsocks/shadowaead_2022/service.go index a0d1f08..b7d6055 100644 --- a/protocol/shadowsocks/shadowaead_2022/service.go +++ b/protocol/shadowsocks/shadowaead_2022/service.go @@ -35,14 +35,14 @@ type Service struct { blockConstructor func(key []byte) cipher.Block udpCipher cipher.AEAD udpBlockCipher cipher.Block - psk [KeySaltSize]byte + psk []byte replayFilter replay.Filter handler shadowsocks.Handler udpNat *udpnat.Service[uint64] sessions *cache.LruCache[uint64, *serverUDPSession] } -func NewService(method string, psk [KeySaltSize]byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { +func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { s := &Service{ name: method, psk: psk, @@ -61,15 +61,25 @@ func NewService(method string, psk [KeySaltSize]byte, secureRNG io.Reader, udpTi s.keyLength = 16 s.constructor = newAESGCM s.blockConstructor = newAES - s.udpBlockCipher = newAES(s.psk[:]) case "2022-blake3-aes-256-gcm": s.keyLength = 32 s.constructor = newAESGCM s.blockConstructor = newAES - s.udpBlockCipher = newAES(s.psk[:]) case "2022-blake3-chacha20-poly1305": s.keyLength = 32 s.constructor = newChacha20Poly1305 + } + + if len(psk) != s.keyLength { + return nil, shadowaead.ErrBadKey + } + + switch method { + case "2022-blake3-aes-128-gcm": + s.udpBlockCipher = newAES(s.psk[:]) + case "2022-blake3-aes-256-gcm": + s.udpBlockCipher = newAES(s.psk[:]) + case "2022-blake3-chacha20-poly1305": s.udpCipher = newXChacha20Poly1305(s.psk[:]) } @@ -85,7 +95,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M } func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - requestSalt := make([]byte, KeySaltSize) + requestSalt := make([]byte, SaltSize) _, err := io.ReadFull(conn, requestSalt) if err != nil { return E.Cause(err, "read request salt") @@ -155,7 +165,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M type serverConn struct { *Service net.Conn - uPSK [KeySaltSize]byte + uPSK []byte access sync.Mutex reader *shadowaead.Reader writer *shadowaead.Writer @@ -163,7 +173,7 @@ type serverConn struct { } func (c *serverConn) writeResponse(payload []byte) (n int, err error) { - var _salt [KeySaltSize]byte + var _salt [SaltSize]byte salt := common.Dup(_salt[:]) common.Must1(io.ReadFull(c.secureRNG, salt)) key := Blake3DeriveKey(c.uPSK[:], salt, c.keyLength) diff --git a/protocol/shadowsocks/shadowaead_2022/service_multi.go b/protocol/shadowsocks/shadowaead_2022/service_multi.go index 509e881..c55d4d4 100644 --- a/protocol/shadowsocks/shadowaead_2022/service_multi.go +++ b/protocol/shadowsocks/shadowaead_2022/service_multi.go @@ -24,12 +24,12 @@ import ( type MultiService[U comparable] struct { *Service - uPSK map[U][KeySaltSize]byte + uPSK map[U][]byte uPSKHash map[U][aes.BlockSize]byte uPSKHashR map[[aes.BlockSize]byte]U } -func (s *MultiService[U]) AddUser(user U, key [KeySaltSize]byte) { +func (s *MultiService[U]) AddUser(user U, key []byte) { var uPSKHash [aes.BlockSize]byte hash512 := blake3.Sum512(key[:]) copy(uPSKHash[:], hash512[:]) @@ -52,7 +52,7 @@ func (s *MultiService[U]) RemoveUser(user U) { delete(s.uPSKHash, user) } -func NewMultiService[U comparable](method string, iPSK [KeySaltSize]byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { +func NewMultiService[U comparable](method string, iPSK []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { switch method { case "2022-blake3-aes-128-gcm": case "2022-blake3-aes-256-gcm": @@ -68,7 +68,7 @@ func NewMultiService[U comparable](method string, iPSK [KeySaltSize]byte, secure s := &MultiService[U]{ Service: ss.(*Service), - uPSK: make(map[U][KeySaltSize]byte), + uPSK: make(map[U][]byte), uPSKHash: make(map[U][aes.BlockSize]byte), uPSKHashR: make(map[[aes.BlockSize]byte]U), } @@ -84,7 +84,7 @@ func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, meta } func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - requestSalt := make([]byte, KeySaltSize) + requestSalt := make([]byte, SaltSize) _, err := io.ReadFull(conn, requestSalt) if err != nil { return E.Cause(err, "read request salt") @@ -101,9 +101,9 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta return E.Cause(err, "read extended identity header") } - keyMaterial := make([]byte, 2*KeySaltSize) + keyMaterial := buf.Make(s.keyLength + SaltSize) copy(keyMaterial, s.psk[:]) - copy(keyMaterial[KeySaltSize:], requestSalt) + copy(keyMaterial[s.keyLength:], requestSalt) _identitySubkey := buf.Make(s.keyLength) identitySubkey := common.Dup(_identitySubkey) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) @@ -111,7 +111,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta runtime.KeepAlive(_identitySubkey) var user U - var uPSK [KeySaltSize]byte + var uPSK []byte if u, loaded := s.uPSKHashR[_eiHeader]; loaded { user = u uPSK = s.uPSK[u] @@ -200,7 +200,7 @@ func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metad } var user U - var uPSK [KeySaltSize]byte + var uPSK []byte if u, loaded := s.uPSKHashR[_eiHeader]; loaded { user = u uPSK = s.uPSK[u] @@ -299,7 +299,7 @@ process: return nil } -func (m *MultiService[U]) newUDPSession(uPSK [KeySaltSize]byte) *serverUDPSession { +func (m *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession { session := &serverUDPSession{} common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId)) session.packetId-- diff --git a/protocol/shadowsocks/shadowaead_2022/service_multi_test.go b/protocol/shadowsocks/shadowaead_2022/service_multi_test.go index 64c41c3..16d6648 100644 --- a/protocol/shadowsocks/shadowaead_2022/service_multi_test.go +++ b/protocol/shadowsocks/shadowaead_2022/service_multi_test.go @@ -15,21 +15,21 @@ import ( func TestMultiService(t *testing.T) { method := "2022-blake3-aes-128-gcm" - var iPSK [shadowaead_2022.KeySaltSize]byte + var iPSK [16]byte random.Default.Read(iPSK[:]) var wg sync.WaitGroup - multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK, random.Default, 500, &multiHandler{t, &wg}) + multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], random.Default, 500, &multiHandler{t, &wg}) if err != nil { t.Fatal(err) } - var uPSK [shadowaead_2022.KeySaltSize]byte + var uPSK [16]byte random.Default.Read(uPSK[:]) - multiService.AddUser("my user", uPSK) + multiService.AddUser("my user", uPSK[:]) - client, err := shadowaead_2022.New(method, [][shadowaead_2022.KeySaltSize]byte{iPSK, uPSK}, random.Default) + client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, random.Default) if err != nil { t.Fatal(err) }