diff --git a/monitor_android.go b/monitor_android.go new file mode 100644 index 0000000..1805ad4 --- /dev/null +++ b/monitor_android.go @@ -0,0 +1,52 @@ +package tun + +import ( + E "github.com/sagernet/sing/common/exceptions" + + "github.com/vishvananda/netlink" +) + +func (m *defaultInterfaceMonitor) checkUpdate() error { + ruleList, err := netlink.RuleList(netlink.FAMILY_ALL) + if err != nil { + return err + } + + var defaultTableIndex int + for _, rule := range ruleList { + if rule.Mask == 0xFFFF { + defaultTableIndex = rule.Table + } + } + + if defaultTableIndex == 0 { + return E.New("no route to internet") + } + + routes, err := netlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{Table: defaultTableIndex}, netlink.RT_FILTER_TABLE) + if err != nil { + return err + } + + for _, route := range routes { + var link netlink.Link + link, err = netlink.LinkByIndex(route.LinkIndex) + if err != nil { + return err + } + + oldInterface := m.defaultInterfaceName + oldIndex := m.defaultInterfaceIndex + + m.defaultInterfaceName = link.Attrs().Name + m.defaultInterfaceIndex = link.Attrs().Index + + if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex { + return nil + } + m.callback() + return nil + } + + return E.New("no route in the system table") +} diff --git a/monitor_linux.go b/monitor_linux.go index e2b6b55..be5d962 100644 --- a/monitor_linux.go +++ b/monitor_linux.go @@ -2,7 +2,6 @@ package tun import ( "context" - "net" "os" "sync" @@ -122,41 +121,6 @@ func (m *defaultInterfaceMonitor) Close() error { return nil } -func (m *defaultInterfaceMonitor) checkUpdate() error { - routes, err := netlink.RouteList(nil, netlink.FAMILY_V4) - if err != nil { - return err - } - for _, route := range routes { - var link netlink.Link - link, err = netlink.LinkByIndex(route.LinkIndex) - if err != nil { - return err - } - - if link.Attrs().Flags&net.FlagUp == 0 { - continue - } - - if link.Type() == "tuntap" { - continue - } - - oldInterface := m.defaultInterfaceName - oldIndex := m.defaultInterfaceIndex - - m.defaultInterfaceName = link.Attrs().Name - m.defaultInterfaceIndex = link.Attrs().Index - - if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex { - return nil - } - m.callback() - return nil - } - return E.New("no route to internet") -} - func (m *defaultInterfaceMonitor) DefaultInterfaceName() string { return m.defaultInterfaceName } diff --git a/monitor_linux_default.go b/monitor_linux_default.go new file mode 100644 index 0000000..27456be --- /dev/null +++ b/monitor_linux_default.go @@ -0,0 +1,41 @@ +//go:build linux && !android + +package tun + +import ( + E "github.com/sagernet/sing/common/exceptions" + + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" +) + +func (m *defaultInterfaceMonitor) checkUpdate() error { + routes, err := netlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{Table: unix.RT_TABLE_MAIN}, netlink.RT_FILTER_TABLE) + if err != nil { + return err + } + for _, route := range routes { + if route.Dst != nil { + continue + } + + var link netlink.Link + link, err = netlink.LinkByIndex(route.LinkIndex) + if err != nil { + return err + } + + oldInterface := m.defaultInterfaceName + oldIndex := m.defaultInterfaceIndex + + m.defaultInterfaceName = link.Attrs().Name + m.defaultInterfaceIndex = link.Attrs().Index + + if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex { + return nil + } + m.callback() + return nil + } + return E.New("no route to internet") +} diff --git a/netlink_linux.go b/netlink_linux.go new file mode 100644 index 0000000..752f400 --- /dev/null +++ b/netlink_linux.go @@ -0,0 +1,213 @@ +package tun + +import ( + "bytes" + "fmt" + "net/netip" + _ "unsafe" + + "github.com/vishvananda/netlink" + "github.com/vishvananda/netlink/nl" + "golang.org/x/sys/unix" +) + +type Rule struct { + Priority int + Family int + Table int + Mark int + Mask int + TunID uint + Goto int + Src netip.Prefix + Dst netip.Prefix + Flow int + IifName string + OifName string + SuppressIfgroup int + SuppressPrefixLength int + Invert bool + + IPProtocol int + SrcPortRange *RulePortRange + DstPortRange *RulePortRange + UIDRange *RuleUIDRange +} + +func NewRule() *Rule { + return &Rule{ + SuppressIfgroup: -1, + SuppressPrefixLength: -1, + Priority: -1, + Mark: -1, + Mask: -1, + Goto: -1, + Flow: -1, + IPProtocol: -1, + } +} + +//go:linkname pkgHandle github.com/vishvananda/netlink.pkgHandle +var pkgHandle *netlink.Handle + +//go:linkname newNetlinkRequest github.com/vishvananda/netlink.(*Handle).newNetlinkRequest +func newNetlinkRequest(h *netlink.Handle, proto, flags int) *nl.NetlinkRequest + +func RuleAdd(rule *Rule) error { + req := newNetlinkRequest(pkgHandle, unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK) + return ruleHandle(rule, req) +} + +func RuleDel(rule *Rule) error { + req := newNetlinkRequest(pkgHandle, unix.RTM_DELRULE, unix.NLM_F_ACK) + return ruleHandle(rule, req) +} + +type RulePortRange struct { + Start uint16 + End uint16 +} + +func (pr *RulePortRange) toRtAttrData() []byte { + native := nl.NativeEndian() + b := [][]byte{make([]byte, 2), make([]byte, 2)} + native.PutUint16(b[0], pr.Start) + native.PutUint16(b[1], pr.End) + return bytes.Join(b, []byte{}) +} + +type RuleUIDRange struct { + Start uint32 + End uint32 +} + +func (pr *RuleUIDRange) toRtAttrData() []byte { + native := nl.NativeEndian() + b := [][]byte{make([]byte, 4), make([]byte, 4)} + native.PutUint32(b[0], pr.Start) + native.PutUint32(b[1], pr.End) + return bytes.Join(b, []byte{}) +} + +func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error { + msg := nl.NewRtMsg() + msg.Family = unix.AF_INET + msg.Protocol = unix.RTPROT_BOOT + msg.Scope = unix.RT_SCOPE_UNIVERSE + msg.Table = unix.RT_TABLE_UNSPEC + msg.Type = unix.RTN_UNSPEC + if rule.Table >= 256 { + msg.Type = unix.FR_ACT_TO_TBL + } else if rule.Goto >= 0 { + msg.Type = unix.FR_ACT_GOTO + } else if req.NlMsghdr.Flags&unix.NLM_F_CREATE > 0 { + msg.Type = unix.FR_ACT_NOP + } + if rule.Invert { + msg.Flags |= netlink.FibRuleInvert + } + if rule.Family != 0 { + msg.Family = uint8(rule.Family) + } + if rule.Table >= 0 && rule.Table < 256 { + msg.Table = uint8(rule.Table) + } + + var dstFamily uint8 + var rtAttrs []*nl.RtAttr + + if rule.Dst.IsValid() { + msg.Dst_len = uint8(rule.Dst.Bits()) + msg.Family = uint8(nl.GetIPFamily(rule.Dst.Addr().AsSlice())) + dstFamily = msg.Family + rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, rule.Dst.Addr().AsSlice())) + } + + if rule.Src.IsValid() { + msg.Src_len = uint8(rule.Src.Bits()) + msg.Family = uint8(nl.GetIPFamily(rule.Src.Addr().AsSlice())) + if dstFamily != 0 && dstFamily != msg.Family { + return fmt.Errorf("source and destination ip are not the same IP family") + } + dstFamily = msg.Family + rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, rule.Src.Addr().AsSlice())) + } + + req.AddData(msg) + for i := range rtAttrs { + req.AddData(rtAttrs[i]) + } + + native := nl.NativeEndian() + + if rule.Priority >= 0 { + b := make([]byte, 4) + native.PutUint32(b, uint32(rule.Priority)) + req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b)) + } + if rule.Mark >= 0 { + b := make([]byte, 4) + native.PutUint32(b, uint32(rule.Mark)) + req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b)) + } + if rule.Mask >= 0 { + b := make([]byte, 4) + native.PutUint32(b, uint32(rule.Mask)) + req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b)) + } + if rule.Flow >= 0 { + b := make([]byte, 4) + native.PutUint32(b, uint32(rule.Flow)) + req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b)) + } + if rule.TunID > 0 { + b := make([]byte, 4) + native.PutUint32(b, uint32(rule.TunID)) + req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b)) + } + if rule.Table >= 256 { + b := make([]byte, 4) + native.PutUint32(b, uint32(rule.Table)) + req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b)) + } + if msg.Table > 0 { + if rule.SuppressPrefixLength >= 0 { + b := make([]byte, 4) + native.PutUint32(b, uint32(rule.SuppressPrefixLength)) + req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b)) + } + if rule.SuppressIfgroup >= 0 { + b := make([]byte, 4) + native.PutUint32(b, uint32(rule.SuppressIfgroup)) + req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b)) + } + } + if rule.IifName != "" { + req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName))) + } + if rule.OifName != "" { + req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName))) + } + if rule.Goto >= 0 { + b := make([]byte, 4) + native.PutUint32(b, uint32(rule.Goto)) + req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b)) + } + if rule.IPProtocol >= 0 { + req.AddData(nl.NewRtAttr(unix.FRA_IP_PROTO, []byte{byte(rule.IPProtocol)})) + } + if rule.SrcPortRange != nil { + b := rule.SrcPortRange.toRtAttrData() + req.AddData(nl.NewRtAttr(unix.FRA_SPORT_RANGE, b)) + } + if rule.DstPortRange != nil { + b := rule.DstPortRange.toRtAttrData() + req.AddData(nl.NewRtAttr(unix.FRA_DPORT_RANGE, b)) + } + if rule.UIDRange != nil { + b := rule.UIDRange.toRtAttrData() + req.AddData(nl.NewRtAttr(unix.FRA_UID_RANGE, b)) + } + _, err := req.Execute(unix.NETLINK_ROUTE, 0) + return err +} diff --git a/tun_linux.go b/tun_linux.go index 168fe1d..c0f0b30 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -1,16 +1,13 @@ package tun import ( - "math" "net" "net/netip" - "runtime" - "syscall" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/tun" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -18,11 +15,11 @@ import ( type NativeTun struct { name string + fd int inet4Address netip.Prefix inet6Address netip.Prefix mtu uint32 autoRoute bool - fdList []int } func Open(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu uint32, autoRoute bool) (Tun, error) { @@ -30,49 +27,31 @@ func Open(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu if err != nil { return nil, err } + tunLink, err := netlink.LinkByName(name) + if err != nil { + return nil, E.Errors(err, unix.Close(tunFd)) + } nativeTun := &NativeTun{ name: name, - fdList: []int{tunFd}, + fd: tunFd, mtu: mtu, inet4Address: inet4Address, inet6Address: inet6Address, autoRoute: autoRoute, } - err = nativeTun.configure() + err = nativeTun.configure(tunLink) if err != nil { - return nil, E.Errors(err, syscall.Close(tunFd)) + return nil, E.Errors(err, unix.Close(tunFd)) } return nativeTun, nil } -func (t *NativeTun) routes(tunLink netlink.Link) []netlink.Route { - var routes []netlink.Route - if t.inet4Address.IsValid() { - routes = append(routes, netlink.Route{ - Dst: &net.IPNet{ - IP: net.IPv4zero, - Mask: net.CIDRMask(0, 32), - }, - LinkIndex: tunLink.Attrs().Index, - }) - } - if t.inet6Address.IsValid() { - routes = append(routes, netlink.Route{ - Dst: &net.IPNet{ - IP: net.IPv6zero, - Mask: net.CIDRMask(0, 128), - }, - LinkIndex: tunLink.Attrs().Index, - }) - } - return routes -} - -func (t *NativeTun) configure() error { - tunLink, err := netlink.LinkByName(t.name) +func (t *NativeTun) configure(tunLink netlink.Link) error { + err := netlink.LinkSetMTU(tunLink, int(t.mtu)) if err != nil { return err } + if t.inet4Address.IsValid() { addr4, _ := netlink.ParseAddr(t.inet4Address.String()) err = netlink.AddrAdd(tunLink, addr4) @@ -89,61 +68,198 @@ func (t *NativeTun) configure() error { } } - err = netlink.LinkSetMTU(tunLink, int(t.mtu)) - if err != nil { - return err - } - err = netlink.LinkSetUp(tunLink) if err != nil { return err } if t.autoRoute { - for _, route := range t.routes(tunLink) { - err = netlink.RouteAdd(&route) - if err != nil { - return err - } + err = t.setRoute(tunLink) + if err != nil { + _ = t.unsetRoute0(tunLink) + return err } } return nil } func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { - var packetDispatchMode fdbased.PacketDispatchMode - if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" { - packetDispatchMode = fdbased.PacketMMap - } else { - packetDispatchMode = fdbased.RecvMMsg - } - dupFdSize := int(math.Max(float64(runtime.NumCPU()/2), 1)) - 1 - for i := 0; i < dupFdSize; i++ { - dupFd, err := syscall.Dup(t.fdList[0]) - if err != nil { - return nil, err - } - t.fdList = append(t.fdList, dupFd) - } return fdbased.New(&fdbased.Options{ - FDs: t.fdList, - MTU: t.mtu, - PacketDispatchMode: packetDispatchMode, + FDs: []int{t.fd}, + MTU: t.mtu, }) } func (t *NativeTun) Close() error { + var errors []error + if t.autoRoute { + errors = append(errors, t.unsetRoute()) + } + errors = append(errors, unix.Close(t.fd)) + return E.Errors(errors...) +} + +const tunTableIndex = 2022 + +func (t *NativeTun) routes(tunLink netlink.Link) []netlink.Route { + var routes []netlink.Route + if t.inet4Address.IsValid() { + routes = append(routes, netlink.Route{ + Dst: &net.IPNet{ + IP: net.IPv4zero, + Mask: net.CIDRMask(0, 32), + }, + LinkIndex: tunLink.Attrs().Index, + Table: tunTableIndex, + }) + } + if t.inet6Address.IsValid() { + routes = append(routes, netlink.Route{ + Dst: &net.IPNet{ + IP: net.IPv6zero, + Mask: net.CIDRMask(0, 128), + }, + LinkIndex: tunLink.Attrs().Index, + Table: tunTableIndex, + }) + } + return routes +} + +func (t *NativeTun) rules() []*Rule { + var rules []*Rule + + priority := 9000 + + it := NewRule() + it.Priority = priority + it.Invert = true + it.UIDRange = &RuleUIDRange{Start: 0, End: 0xFFFFFFFF - 1} + it.Goto = 9100 + rules = append(rules, it) + priority++ + + if t.inet4Address.IsValid() { + it = NewRule() + it.Priority = priority + it.Dst = t.inet4Address.Masked() + it.Table = tunTableIndex + rules = append(rules, it) + priority++ + + it = NewRule() + it.Priority = priority + it.IPProtocol = unix.IPPROTO_ICMP + it.Goto = 9100 + rules = append(rules, it) + priority++ + } + + if t.inet6Address.IsValid() { + it = NewRule() + it.Priority = priority + it.Dst = t.inet6Address.Masked() + it.Table = tunTableIndex + rules = append(rules, it) + priority++ + + it = NewRule() + it.Priority = priority + it.IPProtocol = unix.IPPROTO_ICMPV6 + it.Goto = 9100 + rules = append(rules, it) + priority++ + } + + it = NewRule() + it.Priority = priority + it.Invert = true + it.DstPortRange = &RulePortRange{Start: 53, End: 53} + it.Table = unix.RT_TABLE_MAIN + it.SuppressPrefixLength = 0 + rules = append(rules, it) + priority++ + + it = NewRule() + it.Priority = priority + it.Invert = true + it.IifName = "lo" + it.Table = tunTableIndex + rules = append(rules, it) + priority++ + + it = NewRule() + it.Priority = priority + it.IifName = "lo" + it.Src = netip.PrefixFrom(netip.IPv4Unspecified(), 32) + it.Table = tunTableIndex + rules = append(rules, it) + priority++ + + if t.inet4Address.IsValid() { + it = NewRule() + it.Priority = priority + it.IifName = "lo" + it.Src = t.inet4Address.Masked() + it.Table = tunTableIndex + rules = append(rules, it) + priority++ + } + + if t.inet6Address.IsValid() { + it = NewRule() + it.Priority = priority + it.IifName = "lo" + it.Src = t.inet6Address.Masked() + it.Table = tunTableIndex + rules = append(rules, it) + priority++ + } + + it = NewRule() + it.Priority = 9100 + rules = append(rules, it) + + return rules +} + +func (t *NativeTun) setRoute(tunLink netlink.Link) error { + for i, route := range t.routes(tunLink) { + err := netlink.RouteAdd(&route) + if err != nil { + return E.Cause(err, "add route ", i) + } + } + for i, rule := range t.rules() { + err := RuleAdd(rule) + if err != nil { + return E.Cause(err, "add rule ", i, "/", len(t.rules())) + } + } + return nil +} + +func (t *NativeTun) unsetRoute() error { tunLink, err := netlink.LinkByName(t.name) if err != nil { return err } - if t.autoRoute { - for _, route := range t.routes(tunLink) { - err = netlink.RouteDel(&route) - if err != nil { - return err - } + return t.unsetRoute0(tunLink) +} + +func (t *NativeTun) unsetRoute0(tunLink netlink.Link) error { + var errors []error + for _, route := range t.routes(tunLink) { + err := netlink.RouteDel(&route) + if err != nil { + errors = append(errors, err) } } - return E.Errors(common.Map(t.fdList, syscall.Close)...) + for _, rule := range t.rules() { + err := RuleDel(rule) + if err != nil { + errors = append(errors, err) + } + } + return E.Errors(errors...) }