From 39baf9c17e7d3efcaa5669757b664ead978768d2 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 14 Apr 2023 20:44:20 +0800 Subject: [PATCH] Add multi-user service for no-2022 aead --- shadowaead/service.go | 72 ++++------- shadowaead/service_multi.go | 208 +++++++++++++++++++++++++++++++ shadowaead_2022/service_multi.go | 2 + shadowsocks.go | 10 ++ 4 files changed, 245 insertions(+), 47 deletions(-) create mode 100644 shadowaead/service_multi.go diff --git a/shadowaead/service.go b/shadowaead/service.go index f63bb6f..117e0bc 100644 --- a/shadowaead/service.go +++ b/shadowaead/service.go @@ -2,8 +2,6 @@ package shadowaead import ( "context" - "crypto/aes" - "crypto/cipher" "crypto/rand" "io" "net" @@ -20,8 +18,6 @@ import ( N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/udpnat" - - "golang.org/x/crypto/chacha20poly1305" ) var ErrBadHeader = E.New("bad header") @@ -29,47 +25,22 @@ var ErrBadHeader = E.New("bad header") var _ shadowsocks.Service = (*Service)(nil) type Service struct { - name string - keySaltLength int - constructor func(key []byte) (cipher.AEAD, error) - key []byte - password string - handler shadowsocks.Handler - udpNat *udpnat.Service[netip.AddrPort] + *Method + password string + handler shadowsocks.Handler + udpNat *udpnat.Service[netip.AddrPort] } func NewService(method string, key []byte, password string, udpTimeout int64, handler shadowsocks.Handler) (*Service, error) { + m, err := New(method, key, password) + if err != nil { + return nil, err + } s := &Service{ - name: method, + Method: m, handler: handler, udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), } - switch method { - case "aes-128-gcm": - s.keySaltLength = 16 - s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) - case "aes-192-gcm": - s.keySaltLength = 24 - s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) - case "aes-256-gcm": - s.keySaltLength = 32 - s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) - case "chacha20-ietf-poly1305": - s.keySaltLength = 32 - s.constructor = chacha20poly1305.New - case "xchacha20-ietf-poly1305": - s.keySaltLength = 32 - s.constructor = chacha20poly1305.NewX - } - if len(key) == s.keySaltLength { - s.key = key - } else if len(key) > 0 { - return nil, shadowsocks.ErrBadKey - } else if password != "" { - s.key = shadowsocks.Key([]byte(password), s.keySaltLength) - } else { - return nil, shadowsocks.ErrMissingPassword - } return s, nil } @@ -95,7 +66,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M header := common.Dup(_header) defer header.Release() - _, err := header.ReadOnceFrom(conn) + _, err := header.ReadFullFrom(conn, header.FreeLen()) if err != nil { return E.Cause(err, "read header") } else if !header.IsFull() { @@ -127,9 +98,9 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M metadata.Destination = destination return s.handler.NewConnection(ctx, deadline.NewConn(&serverConn{ - Service: s, - Conn: conn, - reader: reader, + Method: s.Method, + Conn: conn, + reader: reader, }), metadata) } @@ -138,7 +109,7 @@ func (s *Service) NewError(ctx context.Context, err error) { } type serverConn struct { - *Service + *Method net.Conn access sync.Mutex reader *Reader @@ -218,11 +189,11 @@ func (c *serverConn) Upstream() any { return c.Conn } -func (s *Service) ReaderMTU() int { +func (c *serverConn) ReaderMTU() int { return MaxPacketSize } -func (s *Service) WriteIsThreadUnsafe() { +func (c *serverConn) WriteIsThreadUnsafe() { } func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { @@ -261,13 +232,13 @@ func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf. metadata.Protocol = "shadowsocks" metadata.Destination = destination s.udpNat.NewPacket(ctx, metadata.Source.AddrPort(), buffer, metadata, func(natConn N.PacketConn) N.PacketWriter { - return &serverPacketWriter{s, conn, natConn} + return &serverPacketWriter{s.Method, conn, natConn} }) return nil } type serverPacketWriter struct { - *Service + *Method source N.PacketConn nat N.PacketConn } @@ -309,3 +280,10 @@ func (w *serverPacketWriter) WriterMTU() int { func (w *serverPacketWriter) Upstream() any { return w.source } + +func (w *serverPacketWriter) ReaderMTU() int { + return MaxPacketSize +} + +func (w *serverPacketWriter) WriteIsThreadUnsafe() { +} diff --git a/shadowaead/service_multi.go b/shadowaead/service_multi.go new file mode 100644 index 0000000..df94782 --- /dev/null +++ b/shadowaead/service_multi.go @@ -0,0 +1,208 @@ +package shadowaead + +import ( + "context" + "crypto/cipher" + "io" + "net" + "net/netip" + + "github.com/sagernet/sing-shadowsocks" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio/deadline" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/udpnat" +) + +var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil) + +type MultiService[U comparable] struct { + name string + methodMap map[U]*Method + handler shadowsocks.Handler + udpNat *udpnat.Service[netip.AddrPort] +} + +func NewMultiService[U comparable](method string, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { + s := &MultiService[U]{ + name: method, + handler: handler, + udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), + } + return s, nil +} + +func (s *MultiService[U]) Name() string { + return s.name +} + +func (s *MultiService[U]) UpdateUsers(userList []U, keyList [][]byte) error { + s.methodMap = make(map[U]*Method) + for i, user := range userList { + key := keyList[i] + method, err := New(s.name, key, "") + if err != nil { + return err + } + s.methodMap[user] = method + } + return nil +} + +func (s *MultiService[U]) UpdateUsersWithPasswords(userList []U, passwordList []string) error { + s.methodMap = make(map[U]*Method) + for i, user := range userList { + password := passwordList[i] + method, err := New(s.name, nil, password) + if err != nil { + return err + } + s.methodMap[user] = method + } + return nil +} + +func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + err := s.newConnection(ctx, conn, metadata) + if err != nil { + err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} + } + return err +} + +func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + var user U + var method *Method + for u, m := range s.methodMap { + user, method = u, m + break + } + if method == nil { + return shadowsocks.ErrNoUsers + } + _header := buf.StackNewSize(method.keySaltLength + PacketLengthBufferSize + Overhead) + defer common.KeepAlive(_header) + header := common.Dup(_header) + defer header.Release() + + _, err := header.ReadFullFrom(conn, header.FreeLen()) + if err != nil { + return E.Cause(err, "read header") + } else if !header.IsFull() { + return ErrBadHeader + } + + var reader *Reader + var readCipher cipher.AEAD + for u, m := range s.methodMap { + _key := buf.StackNewSize(method.keySaltLength) + key := common.Dup(_key) + Kdf(m.key, header.To(m.keySaltLength), key) + readCipher, err = m.constructor(key.Bytes()) + key.Release() + common.KeepAlive(_key) + if err != nil { + return err + } + reader = NewReader(conn, readCipher, MaxPacketSize) + + err = reader.ReadWithLengthChunk(header.From(method.keySaltLength)) + if err != nil { + continue + } + + user, method = u, m + break + } + if err != nil { + return err + } + + destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) + if err != nil { + return err + } + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + + return s.handler.NewConnection(auth.ContextWithUser(ctx, user), deadline.NewConn(&serverConn{ + Method: method, + Conn: conn, + reader: reader, + }), metadata) +} + +func (s *MultiService[U]) WriteIsThreadUnsafe() { +} + +func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + err := s.newPacket(ctx, conn, buffer, metadata) + if err != nil { + err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} + } + return err +} + +func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + var user U + var method *Method + for u, m := range s.methodMap { + user, method = u, m + break + } + if method == nil { + return shadowsocks.ErrNoUsers + } + if buffer.Len() < method.keySaltLength { + return io.ErrShortBuffer + } + var readCipher cipher.AEAD + var err error + for u, m := range s.methodMap { + _key := buf.StackNewSize(m.keySaltLength) + key := common.Dup(_key) + Kdf(m.key, buffer.To(m.keySaltLength), key) + readCipher, err = m.constructor(key.Bytes()) + key.Release() + common.KeepAlive(_key) + if err != nil { + return err + } + var packet []byte + packet, err = readCipher.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(m.keySaltLength), nil) + if err != nil { + continue + } + + buffer.Advance(m.keySaltLength) + buffer.Truncate(len(packet)) + + user, method = u, m + break + } + if err != nil { + return err + } + + destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return err + } + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + s.udpNat.NewPacket(auth.ContextWithUser(ctx, user), metadata.Source.AddrPort(), buffer, metadata, func(natConn N.PacketConn) N.PacketWriter { + return &serverPacketWriter{method, conn, natConn} + }) + return nil +} + +func (s *MultiService[U]) NewError(ctx context.Context, err error) { + s.handler.NewError(ctx, err) +} diff --git a/shadowaead_2022/service_multi.go b/shadowaead_2022/service_multi.go index c60eca1..65f2b9c 100644 --- a/shadowaead_2022/service_multi.go +++ b/shadowaead_2022/service_multi.go @@ -26,6 +26,8 @@ import ( "lukechampine.com/blake3" ) +var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil) + type MultiService[U comparable] struct { *Service diff --git a/shadowsocks.go b/shadowsocks.go index 84e4bb0..46479c7 100644 --- a/shadowsocks.go +++ b/shadowsocks.go @@ -14,6 +14,7 @@ import ( var ( ErrBadKey = E.New("bad key") ErrMissingPassword = E.New("missing password") + ErrNoUsers = E.New("no users") ) type Method interface { @@ -31,6 +32,15 @@ type Service interface { E.Handler } +type MultiService[U comparable] interface { + Name() string + UpdateUsers(userList []U, keyList [][]byte) error + UpdateUsersWithPasswords(userList []U, passwordList []string) error + N.TCPConnectionHandler + N.UDPHandler + E.Handler +} + type Handler interface { N.TCPConnectionHandler N.UDPConnectionHandler