diff --git a/go.mod b/go.mod index 8dc33a6..bab7de6 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.18 require ( github.com/eycorsican/go-tun2socks v1.16.11 + github.com/fsnotify/fsnotify v1.5.4 github.com/sagernet/netlink v0.0.0-20220803045538-bdac49abf805 github.com/sagernet/sing v0.0.0-20220814164830-4f2b872a8cbf golang.org/x/net v0.0.0-20220805013720-a33c5aa5df48 diff --git a/go.sum b/go.sum index 6840960..bf6f5a8 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/eycorsican/go-tun2socks v1.16.11 h1:+hJDNgisrYaGEqoSxhdikMgMJ4Ilfwm/IZDrWRrbaH8= github.com/eycorsican/go-tun2socks v1.16.11/go.mod h1:wgB2BFT8ZaPKyKOQ/5dljMG/YIow+AIXyq4KBwJ5sGQ= +github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= +github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/sagernet/netlink v0.0.0-20220803045538-bdac49abf805 h1:hE+vtsjBCCPmxkRz9jZA+CicHgVkDT6H+Av5ZzskVxs= @@ -16,6 +18,7 @@ golang.org/x/net v0.0.0-20220805013720-a33c5aa5df48/go.mod h1:YDH+HFinaLZZlnHAfS golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220804214406-8e32c043e418 h1:9vYwv7OjYaky/tlAeD7C4oC9EsPTlaFl1H2jS++V+ME= golang.org/x/sys v0.0.0-20220804214406-8e32c043e418/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/packages.go b/packages.go new file mode 100644 index 0000000..816948b --- /dev/null +++ b/packages.go @@ -0,0 +1,17 @@ +package tun + +import E "github.com/sagernet/sing/common/exceptions" + +type PackageManager interface { + Start() error + Close() error + IDByPackage(packageName string) ([]uint32, bool) + IDBySharedPackage(sharedPackage string) (uint32, bool) + PackageByID(id uint32) (string, bool) + SharedPackageByID(id uint32) (string, bool) +} + +type PackageManagerCallback interface { + OnPackagesUpdated(packages int, sharedUsers int) + E.Handler +} diff --git a/packages_android.go b/packages_android.go new file mode 100644 index 0000000..94f4fe4 --- /dev/null +++ b/packages_android.go @@ -0,0 +1,171 @@ +package tun + +import ( + "context" + "encoding/xml" + "io" + "os" + "strconv" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/fsnotify/fsnotify" +) + +type packageManager struct { + callback PackageManagerCallback + watcher *fsnotify.Watcher + idByPackage map[string][]uint32 + sharedByPackage map[string]uint32 + packageById map[uint32]string + sharedById map[uint32]string +} + +func NewPackageManager(callback PackageManagerCallback) (PackageManager, error) { + return &packageManager{callback: callback}, nil +} + +func (m *packageManager) Start() error { + err := m.updatePackages() + if err != nil { + return E.Cause(err, "read packages list") + } + err = m.startWatcher() + if err != nil { + m.callback.NewError(context.Background(), E.Cause(err, "create fsnotify watcher")) + } + return nil +} + +func (m *packageManager) startWatcher() error { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + err = watcher.Add("/data/system/packages.xml") + if err != nil { + return err + } + m.watcher = watcher + go m.loopUpdate() + return nil +} + +func (m *packageManager) loopUpdate() { + for { + select { + case _, ok := <-m.watcher.Events: + if !ok { + return + } + err := m.updatePackages() + if err != nil { + m.callback.NewError(context.Background(), E.Cause(err, "update packages")) + } + case err, ok := <-m.watcher.Errors: + if !ok { + return + } + m.callback.NewError(context.Background(), E.Cause(err, "fsnotify error")) + } + } +} + +func (m *packageManager) Close() error { + return common.Close(common.PtrOrNil(m.watcher)) +} + +func (m *packageManager) IDByPackage(packageName string) ([]uint32, bool) { + id, loaded := m.idByPackage[packageName] + return id, loaded +} + +func (m *packageManager) IDBySharedPackage(sharedPackage string) (uint32, bool) { + id, loaded := m.sharedByPackage[sharedPackage] + return id, loaded +} + +func (m *packageManager) PackageByID(id uint32) (string, bool) { + packageName, loaded := m.packageById[id] + return packageName, loaded +} + +func (m *packageManager) SharedPackageByID(id uint32) (string, bool) { + sharedPackage, loaded := m.sharedById[id] + return sharedPackage, loaded +} + +func (m *packageManager) updatePackages() error { + idByPackage := make(map[string][]uint32) + sharedByPackage := make(map[string]uint32) + packageById := make(map[uint32]string) + sharedById := make(map[uint32]string) + packagesData, err := os.Open("/data/system/packages.xml") + if err != nil { + return err + } + decoder := xml.NewDecoder(packagesData) + var token xml.Token + for { + token, err = decoder.Token() + if err == io.EOF { + break + } else if err != nil { + return err + } + + element, isStart := token.(xml.StartElement) + if !isStart { + continue + } + + switch element.Name.Local { + case "package": + var name string + var userID uint64 + for _, attr := range element.Attr { + switch attr.Name.Local { + case "name": + name = attr.Value + case "userId", "sharedUserId": + userID, err = strconv.ParseUint(attr.Value, 10, 32) + if err != nil { + return err + } + } + } + if userID == 0 && name == "" { + continue + } + idByPackage[name] = append(idByPackage[name], uint32(userID)) + packageById[uint32(userID)] = name + case "shared-user": + var name string + var userID uint64 + for _, attr := range element.Attr { + switch attr.Name.Local { + case "name": + name = attr.Value + case "userId": + userID, err = strconv.ParseUint(attr.Value, 10, 32) + if err != nil { + return err + } + packageById[uint32(userID)] = name + } + } + if userID == 0 && name == "" { + continue + } + sharedByPackage[name] = uint32(userID) + sharedById[uint32(userID)] = name + } + } + m.idByPackage = idByPackage + m.sharedByPackage = sharedByPackage + m.packageById = packageById + m.sharedById = sharedById + m.callback.OnPackagesUpdated(len(packageById), len(sharedById)) + return nil +} diff --git a/packages_stub.go b/packages_stub.go new file mode 100644 index 0000000..104454c --- /dev/null +++ b/packages_stub.go @@ -0,0 +1,9 @@ +//go:build !android + +package tun + +import "os" + +func NewPackageManager(callback PackageManagerCallback) (PackageManager, error) { + return nil, os.ErrInvalid +} diff --git a/tun.go b/tun.go index c394471..73b189e 100644 --- a/tun.go +++ b/tun.go @@ -5,11 +5,9 @@ import ( "net" "net/netip" "runtime" - "sort" "strconv" "strings" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" N "github.com/sagernet/sing/common/network" @@ -39,41 +37,10 @@ type Options struct { MTU uint32 AutoRoute bool IncludeUID []ranges.Range[uint32] - IncludeAndroidUser []int ExcludeUID []ranges.Range[uint32] -} - -func (o Options) ExcludedRanges() (uidRanges []ranges.Range[uint32]) { - var includeAndroidUser []int - if runtime.GOOS == "android" { - includeAndroidUser = o.IncludeAndroidUser - } - return buildExcludedRanges(o.IncludeUID, o.ExcludeUID, includeAndroidUser) -} - -const ( - androidUserRange = 100000 - userEnd uint32 = 0xFFFFFFFF - 1 -) - -func buildExcludedRanges(includeRanges []ranges.Range[uint32], excludeRanges []ranges.Range[uint32], includeAndroidUser []int) (uidRanges []ranges.Range[uint32]) { - if len(includeRanges) > 0 { - uidRanges = includeRanges - } - if len(includeAndroidUser) > 0 { - includeAndroidUser = common.Uniq(includeAndroidUser) - sort.Ints(includeAndroidUser) - for _, androidUser := range includeAndroidUser { - uidRanges = append(uidRanges, ranges.New[uint32](uint32(androidUser)*androidUserRange, uint32(androidUser+1)*androidUserRange-1)) - } - } - if len(uidRanges) > 0 { - uidRanges = ranges.Exclude(uidRanges, excludeRanges) - uidRanges = ranges.Revert(0, userEnd, uidRanges) - } else { - uidRanges = excludeRanges - } - return ranges.Merge(uidRanges) + IncludeAndroidUser []int + IncludePackage []string + ExcludePackage []string } func DefaultInterfaceName() (tunName string) { diff --git a/tun_rules.go b/tun_rules.go new file mode 100644 index 0000000..19bb402 --- /dev/null +++ b/tun_rules.go @@ -0,0 +1,72 @@ +package tun + +import ( + "context" + "sort" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/ranges" +) + +const ( + androidUserRange = 100000 + userEnd uint32 = 0xFFFFFFFF - 1 +) + +func (o *Options) BuildAndroidRules(packageManager PackageManager, errorHandler E.Handler) { + if len(o.IncludeAndroidUser) > 0 { + o.IncludeAndroidUser = common.Uniq(o.IncludeAndroidUser) + sort.Ints(o.IncludeAndroidUser) + for _, androidUser := range o.IncludeAndroidUser { + o.IncludeUID = append(o.IncludeUID, ranges.New[uint32](uint32(androidUser)*androidUserRange, uint32(androidUser+1)*androidUserRange-1)) + } + } + if len(o.IncludePackage) > 0 { + o.IncludePackage = common.Uniq(o.IncludePackage) + for _, packageName := range o.IncludePackage { + if sharedId, loaded := packageManager.IDBySharedPackage(packageName); loaded { + o.IncludeUID = append(o.IncludeUID, ranges.NewSingle(sharedId)) + continue + } + if ids, loaded := packageManager.IDByPackage(packageName); loaded { + for _, id := range ids { + o.IncludeUID = append(o.IncludeUID, ranges.NewSingle(id)) + } + continue + } + errorHandler.NewError(context.Background(), E.New("package to include not found: ", packageName)) + } + } + if len(o.ExcludePackage) > 0 { + o.ExcludePackage = common.Uniq(o.ExcludePackage) + for _, packageName := range o.ExcludePackage { + if sharedId, loaded := packageManager.IDBySharedPackage(packageName); loaded { + o.ExcludeUID = append(o.ExcludeUID, ranges.NewSingle(sharedId)) + continue + } + if ids, loaded := packageManager.IDByPackage(packageName); loaded { + for _, id := range ids { + o.ExcludeUID = append(o.ExcludeUID, ranges.NewSingle(id)) + } + continue + } + errorHandler.NewError(context.Background(), E.New("package to exclude not found: ", packageName)) + } + } +} + +func (o *Options) ExcludedRanges() (uidRanges []ranges.Range[uint32]) { + return buildExcludedRanges(o.IncludeUID, o.ExcludeUID) +} + +func buildExcludedRanges(includeRanges []ranges.Range[uint32], excludeRanges []ranges.Range[uint32]) (uidRanges []ranges.Range[uint32]) { + uidRanges = includeRanges + if len(uidRanges) > 0 { + uidRanges = ranges.Exclude(uidRanges, excludeRanges) + uidRanges = ranges.Revert(0, userEnd, uidRanges) + } else { + uidRanges = excludeRanges + } + return ranges.Merge(uidRanges) +}