sing-tun/netlink_linux.go
2022-07-15 22:39:49 +08:00

213 lines
5.4 KiB
Go

package tun
import (
"bytes"
"fmt"
"net/netip"
_ "unsafe"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
"golang.org/x/sys/unix"
)
// Rule describes one policy-routing rule. ruleHandle encodes it into an
// RTM_NEWRULE / RTM_DELRULE netlink request.
//
// The numeric selectors Priority, Mark, Mask, Flow, Goto, IPProtocol,
// SuppressIfgroup and SuppressPrefixLength are only sent when they are >= 0.
// The two Suppress* fields are additionally gated on Table inside
// ruleHandle. NewRule returns a Rule with all of these preset to -1 ("unset").
// TunID is only sent when non-zero.
//
// Family defaults to AF_INET when zero. A valid Src or Dst overrides it with
// the family of its address, and Src and Dst must then agree.
// Table values 0-255 are carried in the rtmsg header; values >= 256 are sent
// as an FRA_TABLE attribute. IifName/OifName are sent when non-empty and the
// *Range fields when non-nil.
type Rule struct {
	Priority             int
	Family               int
	Table                int
	Mark                 int
	Mask                 int
	TunID                uint
	Goto                 int
	Src                  netip.Prefix
	Dst                  netip.Prefix
	Flow                 int
	IifName              string
	OifName              string
	SuppressIfgroup      int
	SuppressPrefixLength int
	Invert               bool // sets netlink.FibRuleInvert in the rtmsg flags
	IPProtocol           int  // sent as a single byte (FRA_IP_PROTO)
	SrcPortRange         *RulePortRange
	DstPortRange         *RulePortRange
	UIDRange             *RuleUIDRange
}
// NewRule allocates a Rule whose optional numeric selectors (priority, mark,
// mask, goto, flow, IP protocol and the two suppressors) are all -1.
// ruleHandle skips any of these that are negative, so callers only need to
// set the selectors they actually want to match on.
func NewRule() *Rule {
	r := new(Rule)
	unset := []*int{
		&r.Priority,
		&r.Mark,
		&r.Mask,
		&r.Goto,
		&r.Flow,
		&r.IPProtocol,
		&r.SuppressIfgroup,
		&r.SuppressPrefixLength,
	}
	for _, field := range unset {
		*field = -1
	}
	return r
}
// pkgHandle is bound via go:linkname to the unexported package-level
// variable github.com/vishvananda/netlink.pkgHandle (presumably netlink's
// default handle), so requests built in this file reuse it instead of
// opening a new one.
// NOTE(review): this depends on netlink internals and breaks if the upstream
// symbol is renamed or removed; the blank `unsafe` import enables linkname.
//
//go:linkname pkgHandle github.com/vishvananda/netlink.pkgHandle
var pkgHandle *netlink.Handle

// newNetlinkRequest is bound via go:linkname to the unexported method
// (*netlink.Handle).newNetlinkRequest; the method receiver becomes the first
// parameter here. proto is the netlink message type (e.g. RTM_NEWRULE) and
// flags the NLM_F_* header flags.
// NOTE(review): a function declared without a body is only accepted by the
// Go toolchain when the package also contains an assembly (.s) file —
// presumably one exists elsewhere in this package.
//
//go:linkname newNetlinkRequest github.com/vishvananda/netlink.(*Handle).newNetlinkRequest
func newNetlinkRequest(h *netlink.Handle, proto, flags int) *nl.NetlinkRequest
// RuleAdd installs rule by sending an RTM_NEWRULE request with
// NLM_F_CREATE|NLM_F_EXCL|NLM_F_ACK. Because of NLM_F_EXCL, the kernel is
// expected to refuse a rule that already exists rather than replace it.
func RuleAdd(rule *Rule) error {
	const createFlags = unix.NLM_F_CREATE | unix.NLM_F_EXCL | unix.NLM_F_ACK
	return ruleHandle(rule, newNetlinkRequest(pkgHandle, unix.RTM_NEWRULE, createFlags))
}
// RuleDel removes the rule described by rule by sending an RTM_DELRULE
// request that asks for an acknowledgement (NLM_F_ACK).
func RuleDel(rule *Rule) error {
	const deleteFlags = unix.NLM_F_ACK
	return ruleHandle(rule, newNetlinkRequest(pkgHandle, unix.RTM_DELRULE, deleteFlags))
}
// RulePortRange is a port selector for a Rule. ruleHandle sends it as
// FRA_SPORT_RANGE (Rule.SrcPortRange) or FRA_DPORT_RANGE (Rule.DstPortRange).
// NOTE(review): the kernel presumably treats [Start, End] as inclusive —
// confirm against linux/fib_rules.h.
type RulePortRange struct {
	Start uint16
	End   uint16
}
// toRtAttrData encodes the range as a 4-byte attribute payload: Start
// followed by End, each a uint16 in the byte order of nl.NativeEndian().
func (pr *RulePortRange) toRtAttrData() []byte {
	order := nl.NativeEndian()
	parts := make([][]byte, 0, 2)
	for _, port := range [...]uint16{pr.Start, pr.End} {
		field := make([]byte, 2)
		order.PutUint16(field, port)
		parts = append(parts, field)
	}
	return bytes.Join(parts, nil)
}
// RuleUIDRange is a user-ID selector for a Rule. ruleHandle sends it as
// the FRA_UID_RANGE attribute when Rule.UIDRange is non-nil.
// NOTE(review): the kernel presumably treats [Start, End] as inclusive —
// confirm against linux/fib_rules.h.
type RuleUIDRange struct {
	Start uint32
	End   uint32
}
// toRtAttrData encodes the range as an 8-byte attribute payload: Start
// followed by End, each a uint32 in the byte order of nl.NativeEndian().
func (pr *RuleUIDRange) toRtAttrData() []byte {
	order := nl.NativeEndian()
	parts := make([][]byte, 0, 2)
	for _, uid := range [...]uint32{pr.Start, pr.End} {
		field := make([]byte, 4)
		order.PutUint32(field, uid)
		parts = append(parts, field)
	}
	return bytes.Join(parts, nil)
}
// ruleHandle serialises rule into req (the RTM_NEWRULE or RTM_DELRULE
// request prepared by RuleAdd / RuleDel), executes it on a NETLINK_ROUTE
// socket and returns the error, if any, from req.Execute.
//
// The rule action travels in the rtmsg Type field and is chosen as:
//   - FR_ACT_TO_TBL when Table >= 256, or when creating a rule whose Table
//     is in 1-255;
//   - FR_ACT_GOTO when Goto >= 0;
//   - FR_ACT_NOP when creating a rule with neither a table nor a goto;
//   - left as 0 otherwise (this only happens for deletions).
//
// Tables 0-255 are carried in the rtmsg header, larger ones in the FRA_TABLE
// attribute. If Src and Dst are both set but belong to different address
// families, an error is returned and nothing is sent.
func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
	msg := nl.NewRtMsg()
	msg.Family = unix.AF_INET
	msg.Protocol = unix.RTPROT_BOOT
	msg.Scope = unix.RT_SCOPE_UNIVERSE
	msg.Table = unix.RT_TABLE_UNSPEC
	msg.Type = unix.RTN_UNSPEC

	creating := req.NlMsghdr.Flags&unix.NLM_F_CREATE != 0
	switch {
	case rule.Table >= 256:
		msg.Type = unix.FR_ACT_TO_TBL
	case rule.Goto >= 0:
		msg.Type = unix.FR_ACT_GOTO
	case creating && rule.Table > 0:
		// A rule that names a table must look that table up, exactly as for
		// tables >= 256. Previously these rules fell through to FR_ACT_NOP
		// ("no operation" in linux/fib_rules.h), so e.g. table 100 or main
		// never took effect.
		msg.Type = unix.FR_ACT_TO_TBL
	case creating:
		msg.Type = unix.FR_ACT_NOP
	}
	if rule.Invert {
		msg.Flags |= netlink.FibRuleInvert
	}
	if rule.Family != 0 {
		msg.Family = uint8(rule.Family)
	}
	if rule.Table >= 0 && rule.Table < 256 {
		msg.Table = uint8(rule.Table)
	}

	// A valid Dst/Src overrides msg.Family with the family of its address;
	// the two must agree when both are present.
	var dstFamily uint8
	var rtAttrs []*nl.RtAttr
	if rule.Dst.IsValid() {
		msg.Dst_len = uint8(rule.Dst.Bits())
		msg.Family = uint8(nl.GetIPFamily(rule.Dst.Addr().AsSlice()))
		dstFamily = msg.Family
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, rule.Dst.Addr().AsSlice()))
	}
	if rule.Src.IsValid() {
		msg.Src_len = uint8(rule.Src.Bits())
		msg.Family = uint8(nl.GetIPFamily(rule.Src.Addr().AsSlice()))
		if dstFamily != 0 && dstFamily != msg.Family {
			return fmt.Errorf("source and destination ip are not the same IP family")
		}
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, rule.Src.Addr().AsSlice()))
	}

	req.AddData(msg)
	for _, attr := range rtAttrs {
		req.AddData(attr)
	}

	native := nl.NativeEndian()
	// addUint32 appends a 4-byte attribute encoded with nl.NativeEndian().
	addUint32 := func(attrType int, value uint32) {
		b := make([]byte, 4)
		native.PutUint32(b, value)
		req.AddData(nl.NewRtAttr(attrType, b))
	}

	if rule.Priority >= 0 {
		addUint32(nl.FRA_PRIORITY, uint32(rule.Priority))
	}
	if rule.Mark >= 0 {
		addUint32(nl.FRA_FWMARK, uint32(rule.Mark))
	}
	if rule.Mask >= 0 {
		addUint32(nl.FRA_FWMASK, uint32(rule.Mask))
	}
	if rule.Flow >= 0 {
		addUint32(nl.FRA_FLOW, uint32(rule.Flow))
	}
	if rule.TunID > 0 {
		addUint32(nl.FRA_TUN_ID, uint32(rule.TunID))
	}
	if rule.Table >= 256 {
		addUint32(nl.FRA_TABLE, uint32(rule.Table))
	}
	// Suppressors belong to any rule that looks up a table, whether the table
	// is carried in the header (1-255) or in FRA_TABLE (>= 256). Gating on
	// msg.Table, as before, silently dropped them for tables >= 256 because
	// msg.Table stays RT_TABLE_UNSPEC there.
	if rule.Table > 0 {
		if rule.SuppressPrefixLength >= 0 {
			addUint32(nl.FRA_SUPPRESS_PREFIXLEN, uint32(rule.SuppressPrefixLength))
		}
		if rule.SuppressIfgroup >= 0 {
			addUint32(nl.FRA_SUPPRESS_IFGROUP, uint32(rule.SuppressIfgroup))
		}
	}
	if rule.IifName != "" {
		req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName)))
	}
	if rule.OifName != "" {
		req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName)))
	}
	if rule.Goto >= 0 {
		addUint32(nl.FRA_GOTO, uint32(rule.Goto))
	}
	if rule.IPProtocol >= 0 {
		req.AddData(nl.NewRtAttr(unix.FRA_IP_PROTO, []byte{byte(rule.IPProtocol)}))
	}
	if rule.SrcPortRange != nil {
		req.AddData(nl.NewRtAttr(unix.FRA_SPORT_RANGE, rule.SrcPortRange.toRtAttrData()))
	}
	if rule.DstPortRange != nil {
		req.AddData(nl.NewRtAttr(unix.FRA_DPORT_RANGE, rule.DstPortRange.toRtAttrData()))
	}
	if rule.UIDRange != nil {
		req.AddData(nl.NewRtAttr(unix.FRA_UID_RANGE, rule.UIDRange.toRtAttrData()))
	}

	_, err := req.Execute(unix.NETLINK_ROUTE, 0)
	return err
}