Fix linux route

This commit is contained in:
世界 2022-07-15 16:49:34 +08:00
parent 9dc73c0bcc
commit 9968d2c8e9
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 489 additions and 103 deletions

52
monitor_android.go Normal file
View file

@ -0,0 +1,52 @@
package tun
import (
E "github.com/sagernet/sing/common/exceptions"
"github.com/vishvananda/netlink"
)
func (m *defaultInterfaceMonitor) checkUpdate() error {
ruleList, err := netlink.RuleList(netlink.FAMILY_ALL)
if err != nil {
return err
}
var defaultTableIndex int
for _, rule := range ruleList {
if rule.Mask == 0xFFFF {
defaultTableIndex = rule.Table
}
}
if defaultTableIndex == 0 {
return E.New("no route to internet")
}
routes, err := netlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{Table: defaultTableIndex}, netlink.RT_FILTER_TABLE)
if err != nil {
return err
}
for _, route := range routes {
var link netlink.Link
link, err = netlink.LinkByIndex(route.LinkIndex)
if err != nil {
return err
}
oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex
m.defaultInterfaceName = link.Attrs().Name
m.defaultInterfaceIndex = link.Attrs().Index
if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
return nil
}
m.callback()
return nil
}
return E.New("no route in the system table")
}

View file

@ -2,7 +2,6 @@ package tun
import (
"context"
"net"
"os"
"sync"
@ -122,41 +121,6 @@ func (m *defaultInterfaceMonitor) Close() error {
return nil
}
func (m *defaultInterfaceMonitor) checkUpdate() error {
routes, err := netlink.RouteList(nil, netlink.FAMILY_V4)
if err != nil {
return err
}
for _, route := range routes {
var link netlink.Link
link, err = netlink.LinkByIndex(route.LinkIndex)
if err != nil {
return err
}
if link.Attrs().Flags&net.FlagUp == 0 {
continue
}
if link.Type() == "tuntap" {
continue
}
oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex
m.defaultInterfaceName = link.Attrs().Name
m.defaultInterfaceIndex = link.Attrs().Index
if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
return nil
}
m.callback()
return nil
}
return E.New("no route to internet")
}
func (m *defaultInterfaceMonitor) DefaultInterfaceName() string {
return m.defaultInterfaceName
}

41
monitor_linux_default.go Normal file
View file

@ -0,0 +1,41 @@
//go:build linux && !android
package tun
import (
E "github.com/sagernet/sing/common/exceptions"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
func (m *defaultInterfaceMonitor) checkUpdate() error {
routes, err := netlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{Table: unix.RT_TABLE_MAIN}, netlink.RT_FILTER_TABLE)
if err != nil {
return err
}
for _, route := range routes {
if route.Dst != nil {
continue
}
var link netlink.Link
link, err = netlink.LinkByIndex(route.LinkIndex)
if err != nil {
return err
}
oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex
m.defaultInterfaceName = link.Attrs().Name
m.defaultInterfaceIndex = link.Attrs().Index
if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
return nil
}
m.callback()
return nil
}
return E.New("no route to internet")
}

213
netlink_linux.go Normal file
View file

@ -0,0 +1,213 @@
package tun
import (
"bytes"
"fmt"
"net/netip"
_ "unsafe"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
"golang.org/x/sys/unix"
)
type Rule struct {
Priority int
Family int
Table int
Mark int
Mask int
TunID uint
Goto int
Src netip.Prefix
Dst netip.Prefix
Flow int
IifName string
OifName string
SuppressIfgroup int
SuppressPrefixLength int
Invert bool
IPProtocol int
SrcPortRange *RulePortRange
DstPortRange *RulePortRange
UIDRange *RuleUIDRange
}
func NewRule() *Rule {
return &Rule{
SuppressIfgroup: -1,
SuppressPrefixLength: -1,
Priority: -1,
Mark: -1,
Mask: -1,
Goto: -1,
Flow: -1,
IPProtocol: -1,
}
}
//go:linkname pkgHandle github.com/vishvananda/netlink.pkgHandle
var pkgHandle *netlink.Handle
//go:linkname newNetlinkRequest github.com/vishvananda/netlink.(*Handle).newNetlinkRequest
func newNetlinkRequest(h *netlink.Handle, proto, flags int) *nl.NetlinkRequest
func RuleAdd(rule *Rule) error {
req := newNetlinkRequest(pkgHandle, unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
return ruleHandle(rule, req)
}
func RuleDel(rule *Rule) error {
req := newNetlinkRequest(pkgHandle, unix.RTM_DELRULE, unix.NLM_F_ACK)
return ruleHandle(rule, req)
}
type RulePortRange struct {
Start uint16
End uint16
}
func (pr *RulePortRange) toRtAttrData() []byte {
native := nl.NativeEndian()
b := [][]byte{make([]byte, 2), make([]byte, 2)}
native.PutUint16(b[0], pr.Start)
native.PutUint16(b[1], pr.End)
return bytes.Join(b, []byte{})
}
type RuleUIDRange struct {
Start uint32
End uint32
}
func (pr *RuleUIDRange) toRtAttrData() []byte {
native := nl.NativeEndian()
b := [][]byte{make([]byte, 4), make([]byte, 4)}
native.PutUint32(b[0], pr.Start)
native.PutUint32(b[1], pr.End)
return bytes.Join(b, []byte{})
}
func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
msg := nl.NewRtMsg()
msg.Family = unix.AF_INET
msg.Protocol = unix.RTPROT_BOOT
msg.Scope = unix.RT_SCOPE_UNIVERSE
msg.Table = unix.RT_TABLE_UNSPEC
msg.Type = unix.RTN_UNSPEC
if rule.Table >= 256 {
msg.Type = unix.FR_ACT_TO_TBL
} else if rule.Goto >= 0 {
msg.Type = unix.FR_ACT_GOTO
} else if req.NlMsghdr.Flags&unix.NLM_F_CREATE > 0 {
msg.Type = unix.FR_ACT_NOP
}
if rule.Invert {
msg.Flags |= netlink.FibRuleInvert
}
if rule.Family != 0 {
msg.Family = uint8(rule.Family)
}
if rule.Table >= 0 && rule.Table < 256 {
msg.Table = uint8(rule.Table)
}
var dstFamily uint8
var rtAttrs []*nl.RtAttr
if rule.Dst.IsValid() {
msg.Dst_len = uint8(rule.Dst.Bits())
msg.Family = uint8(nl.GetIPFamily(rule.Dst.Addr().AsSlice()))
dstFamily = msg.Family
rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, rule.Dst.Addr().AsSlice()))
}
if rule.Src.IsValid() {
msg.Src_len = uint8(rule.Src.Bits())
msg.Family = uint8(nl.GetIPFamily(rule.Src.Addr().AsSlice()))
if dstFamily != 0 && dstFamily != msg.Family {
return fmt.Errorf("source and destination ip are not the same IP family")
}
dstFamily = msg.Family
rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, rule.Src.Addr().AsSlice()))
}
req.AddData(msg)
for i := range rtAttrs {
req.AddData(rtAttrs[i])
}
native := nl.NativeEndian()
if rule.Priority >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Priority))
req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
}
if rule.Mark >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Mark))
req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
}
if rule.Mask >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Mask))
req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
}
if rule.Flow >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Flow))
req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b))
}
if rule.TunID > 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.TunID))
req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b))
}
if rule.Table >= 256 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Table))
req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b))
}
if msg.Table > 0 {
if rule.SuppressPrefixLength >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.SuppressPrefixLength))
req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b))
}
if rule.SuppressIfgroup >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.SuppressIfgroup))
req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b))
}
}
if rule.IifName != "" {
req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName)))
}
if rule.OifName != "" {
req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName)))
}
if rule.Goto >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Goto))
req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b))
}
if rule.IPProtocol >= 0 {
req.AddData(nl.NewRtAttr(unix.FRA_IP_PROTO, []byte{byte(rule.IPProtocol)}))
}
if rule.SrcPortRange != nil {
b := rule.SrcPortRange.toRtAttrData()
req.AddData(nl.NewRtAttr(unix.FRA_SPORT_RANGE, b))
}
if rule.DstPortRange != nil {
b := rule.DstPortRange.toRtAttrData()
req.AddData(nl.NewRtAttr(unix.FRA_DPORT_RANGE, b))
}
if rule.UIDRange != nil {
b := rule.UIDRange.toRtAttrData()
req.AddData(nl.NewRtAttr(unix.FRA_UID_RANGE, b))
}
_, err := req.Execute(unix.NETLINK_ROUTE, 0)
return err
}

View file

@ -1,16 +1,13 @@
package tun
import (
"math"
"net"
"net/netip"
"runtime"
"syscall"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@ -18,11 +15,11 @@ import (
type NativeTun struct {
name string
fd int
inet4Address netip.Prefix
inet6Address netip.Prefix
mtu uint32
autoRoute bool
fdList []int
}
func Open(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu uint32, autoRoute bool) (Tun, error) {
@ -30,49 +27,31 @@ func Open(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu
if err != nil {
return nil, err
}
tunLink, err := netlink.LinkByName(name)
if err != nil {
return nil, E.Errors(err, unix.Close(tunFd))
}
nativeTun := &NativeTun{
name: name,
fdList: []int{tunFd},
fd: tunFd,
mtu: mtu,
inet4Address: inet4Address,
inet6Address: inet6Address,
autoRoute: autoRoute,
}
err = nativeTun.configure()
err = nativeTun.configure(tunLink)
if err != nil {
return nil, E.Errors(err, syscall.Close(tunFd))
return nil, E.Errors(err, unix.Close(tunFd))
}
return nativeTun, nil
}
func (t *NativeTun) routes(tunLink netlink.Link) []netlink.Route {
var routes []netlink.Route
if t.inet4Address.IsValid() {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(0, 32),
},
LinkIndex: tunLink.Attrs().Index,
})
}
if t.inet6Address.IsValid() {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: net.IPv6zero,
Mask: net.CIDRMask(0, 128),
},
LinkIndex: tunLink.Attrs().Index,
})
}
return routes
}
func (t *NativeTun) configure() error {
tunLink, err := netlink.LinkByName(t.name)
func (t *NativeTun) configure(tunLink netlink.Link) error {
err := netlink.LinkSetMTU(tunLink, int(t.mtu))
if err != nil {
return err
}
if t.inet4Address.IsValid() {
addr4, _ := netlink.ParseAddr(t.inet4Address.String())
err = netlink.AddrAdd(tunLink, addr4)
@ -89,61 +68,198 @@ func (t *NativeTun) configure() error {
}
}
err = netlink.LinkSetMTU(tunLink, int(t.mtu))
if err != nil {
return err
}
err = netlink.LinkSetUp(tunLink)
if err != nil {
return err
}
if t.autoRoute {
for _, route := range t.routes(tunLink) {
err = netlink.RouteAdd(&route)
if err != nil {
return err
}
err = t.setRoute(tunLink)
if err != nil {
_ = t.unsetRoute0(tunLink)
return err
}
}
return nil
}
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
var packetDispatchMode fdbased.PacketDispatchMode
if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
packetDispatchMode = fdbased.PacketMMap
} else {
packetDispatchMode = fdbased.RecvMMsg
}
dupFdSize := int(math.Max(float64(runtime.NumCPU()/2), 1)) - 1
for i := 0; i < dupFdSize; i++ {
dupFd, err := syscall.Dup(t.fdList[0])
if err != nil {
return nil, err
}
t.fdList = append(t.fdList, dupFd)
}
return fdbased.New(&fdbased.Options{
FDs: t.fdList,
MTU: t.mtu,
PacketDispatchMode: packetDispatchMode,
FDs: []int{t.fd},
MTU: t.mtu,
})
}
func (t *NativeTun) Close() error {
var errors []error
if t.autoRoute {
errors = append(errors, t.unsetRoute())
}
errors = append(errors, unix.Close(t.fd))
return E.Errors(errors...)
}
const tunTableIndex = 2022
func (t *NativeTun) routes(tunLink netlink.Link) []netlink.Route {
var routes []netlink.Route
if t.inet4Address.IsValid() {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(0, 32),
},
LinkIndex: tunLink.Attrs().Index,
Table: tunTableIndex,
})
}
if t.inet6Address.IsValid() {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: net.IPv6zero,
Mask: net.CIDRMask(0, 128),
},
LinkIndex: tunLink.Attrs().Index,
Table: tunTableIndex,
})
}
return routes
}
func (t *NativeTun) rules() []*Rule {
var rules []*Rule
priority := 9000
it := NewRule()
it.Priority = priority
it.Invert = true
it.UIDRange = &RuleUIDRange{Start: 0, End: 0xFFFFFFFF - 1}
it.Goto = 9100
rules = append(rules, it)
priority++
if t.inet4Address.IsValid() {
it = NewRule()
it.Priority = priority
it.Dst = t.inet4Address.Masked()
it.Table = tunTableIndex
rules = append(rules, it)
priority++
it = NewRule()
it.Priority = priority
it.IPProtocol = unix.IPPROTO_ICMP
it.Goto = 9100
rules = append(rules, it)
priority++
}
if t.inet6Address.IsValid() {
it = NewRule()
it.Priority = priority
it.Dst = t.inet6Address.Masked()
it.Table = tunTableIndex
rules = append(rules, it)
priority++
it = NewRule()
it.Priority = priority
it.IPProtocol = unix.IPPROTO_ICMPV6
it.Goto = 9100
rules = append(rules, it)
priority++
}
it = NewRule()
it.Priority = priority
it.Invert = true
it.DstPortRange = &RulePortRange{Start: 53, End: 53}
it.Table = unix.RT_TABLE_MAIN
it.SuppressPrefixLength = 0
rules = append(rules, it)
priority++
it = NewRule()
it.Priority = priority
it.Invert = true
it.IifName = "lo"
it.Table = tunTableIndex
rules = append(rules, it)
priority++
it = NewRule()
it.Priority = priority
it.IifName = "lo"
it.Src = netip.PrefixFrom(netip.IPv4Unspecified(), 32)
it.Table = tunTableIndex
rules = append(rules, it)
priority++
if t.inet4Address.IsValid() {
it = NewRule()
it.Priority = priority
it.IifName = "lo"
it.Src = t.inet4Address.Masked()
it.Table = tunTableIndex
rules = append(rules, it)
priority++
}
if t.inet6Address.IsValid() {
it = NewRule()
it.Priority = priority
it.IifName = "lo"
it.Src = t.inet6Address.Masked()
it.Table = tunTableIndex
rules = append(rules, it)
priority++
}
it = NewRule()
it.Priority = 9100
rules = append(rules, it)
return rules
}
func (t *NativeTun) setRoute(tunLink netlink.Link) error {
for i, route := range t.routes(tunLink) {
err := netlink.RouteAdd(&route)
if err != nil {
return E.Cause(err, "add route ", i)
}
}
for i, rule := range t.rules() {
err := RuleAdd(rule)
if err != nil {
return E.Cause(err, "add rule ", i, "/", len(t.rules()))
}
}
return nil
}
func (t *NativeTun) unsetRoute() error {
tunLink, err := netlink.LinkByName(t.name)
if err != nil {
return err
}
if t.autoRoute {
for _, route := range t.routes(tunLink) {
err = netlink.RouteDel(&route)
if err != nil {
return err
}
return t.unsetRoute0(tunLink)
}
func (t *NativeTun) unsetRoute0(tunLink netlink.Link) error {
var errors []error
for _, route := range t.routes(tunLink) {
err := netlink.RouteDel(&route)
if err != nil {
errors = append(errors, err)
}
}
return E.Errors(common.Map(t.fdList, syscall.Close)...)
for _, rule := range t.rules() {
err := RuleDel(rule)
if err != nil {
errors = append(errors, err)
}
}
return E.Errors(errors...)
}