Fix linux route

This commit is contained in:
世界 2022-07-15 16:49:34 +08:00
parent 9dc73c0bcc
commit 9968d2c8e9
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 489 additions and 103 deletions

213
netlink_linux.go Normal file
View file

@ -0,0 +1,213 @@
package tun
import (
"bytes"
"fmt"
"net/netip"
_ "unsafe"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
"golang.org/x/sys/unix"
)
type Rule struct {
Priority int
Family int
Table int
Mark int
Mask int
TunID uint
Goto int
Src netip.Prefix
Dst netip.Prefix
Flow int
IifName string
OifName string
SuppressIfgroup int
SuppressPrefixLength int
Invert bool
IPProtocol int
SrcPortRange *RulePortRange
DstPortRange *RulePortRange
UIDRange *RuleUIDRange
}
func NewRule() *Rule {
return &Rule{
SuppressIfgroup: -1,
SuppressPrefixLength: -1,
Priority: -1,
Mark: -1,
Mask: -1,
Goto: -1,
Flow: -1,
IPProtocol: -1,
}
}
//go:linkname pkgHandle github.com/vishvananda/netlink.pkgHandle
var pkgHandle *netlink.Handle
//go:linkname newNetlinkRequest github.com/vishvananda/netlink.(*Handle).newNetlinkRequest
func newNetlinkRequest(h *netlink.Handle, proto, flags int) *nl.NetlinkRequest
func RuleAdd(rule *Rule) error {
req := newNetlinkRequest(pkgHandle, unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
return ruleHandle(rule, req)
}
func RuleDel(rule *Rule) error {
req := newNetlinkRequest(pkgHandle, unix.RTM_DELRULE, unix.NLM_F_ACK)
return ruleHandle(rule, req)
}
type RulePortRange struct {
Start uint16
End uint16
}
func (pr *RulePortRange) toRtAttrData() []byte {
native := nl.NativeEndian()
b := [][]byte{make([]byte, 2), make([]byte, 2)}
native.PutUint16(b[0], pr.Start)
native.PutUint16(b[1], pr.End)
return bytes.Join(b, []byte{})
}
type RuleUIDRange struct {
Start uint32
End uint32
}
func (pr *RuleUIDRange) toRtAttrData() []byte {
native := nl.NativeEndian()
b := [][]byte{make([]byte, 4), make([]byte, 4)}
native.PutUint32(b[0], pr.Start)
native.PutUint32(b[1], pr.End)
return bytes.Join(b, []byte{})
}
func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
msg := nl.NewRtMsg()
msg.Family = unix.AF_INET
msg.Protocol = unix.RTPROT_BOOT
msg.Scope = unix.RT_SCOPE_UNIVERSE
msg.Table = unix.RT_TABLE_UNSPEC
msg.Type = unix.RTN_UNSPEC
if rule.Table >= 256 {
msg.Type = unix.FR_ACT_TO_TBL
} else if rule.Goto >= 0 {
msg.Type = unix.FR_ACT_GOTO
} else if req.NlMsghdr.Flags&unix.NLM_F_CREATE > 0 {
msg.Type = unix.FR_ACT_NOP
}
if rule.Invert {
msg.Flags |= netlink.FibRuleInvert
}
if rule.Family != 0 {
msg.Family = uint8(rule.Family)
}
if rule.Table >= 0 && rule.Table < 256 {
msg.Table = uint8(rule.Table)
}
var dstFamily uint8
var rtAttrs []*nl.RtAttr
if rule.Dst.IsValid() {
msg.Dst_len = uint8(rule.Dst.Bits())
msg.Family = uint8(nl.GetIPFamily(rule.Dst.Addr().AsSlice()))
dstFamily = msg.Family
rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, rule.Dst.Addr().AsSlice()))
}
if rule.Src.IsValid() {
msg.Src_len = uint8(rule.Src.Bits())
msg.Family = uint8(nl.GetIPFamily(rule.Src.Addr().AsSlice()))
if dstFamily != 0 && dstFamily != msg.Family {
return fmt.Errorf("source and destination ip are not the same IP family")
}
dstFamily = msg.Family
rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, rule.Src.Addr().AsSlice()))
}
req.AddData(msg)
for i := range rtAttrs {
req.AddData(rtAttrs[i])
}
native := nl.NativeEndian()
if rule.Priority >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Priority))
req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
}
if rule.Mark >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Mark))
req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
}
if rule.Mask >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Mask))
req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
}
if rule.Flow >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Flow))
req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b))
}
if rule.TunID > 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.TunID))
req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b))
}
if rule.Table >= 256 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Table))
req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b))
}
if msg.Table > 0 {
if rule.SuppressPrefixLength >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.SuppressPrefixLength))
req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b))
}
if rule.SuppressIfgroup >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.SuppressIfgroup))
req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b))
}
}
if rule.IifName != "" {
req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName)))
}
if rule.OifName != "" {
req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName)))
}
if rule.Goto >= 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Goto))
req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b))
}
if rule.IPProtocol >= 0 {
req.AddData(nl.NewRtAttr(unix.FRA_IP_PROTO, []byte{byte(rule.IPProtocol)}))
}
if rule.SrcPortRange != nil {
b := rule.SrcPortRange.toRtAttrData()
req.AddData(nl.NewRtAttr(unix.FRA_SPORT_RANGE, b))
}
if rule.DstPortRange != nil {
b := rule.DstPortRange.toRtAttrData()
req.AddData(nl.NewRtAttr(unix.FRA_DPORT_RANGE, b))
}
if rule.UIDRange != nil {
b := rule.UIDRange.toRtAttrData()
req.AddData(nl.NewRtAttr(unix.FRA_UID_RANGE, b))
}
_, err := req.Execute(unix.NETLINK_ROUTE, 0)
return err
}