Using netipx.IPSet safely

This commit is contained in:
wwqgtxx 2024-06-19 10:03:15 +08:00
parent ef83d1643c
commit 625ac412bb

View file

@ -4,7 +4,6 @@ package tun
import ( import (
"net/netip" "net/netip"
"unsafe"
"github.com/sagernet/nftables" "github.com/sagernet/nftables"
"github.com/sagernet/nftables/expr" "github.com/sagernet/nftables/expr"
@ -77,42 +76,34 @@ func nftablesCreateIPSet(
id uint32, name string, family nftables.TableFamily, id uint32, name string, family nftables.TableFamily,
setList []*netipx.IPSet, prefixList []netip.Prefix, appendDefault bool, update bool, setList []*netipx.IPSet, prefixList []netip.Prefix, appendDefault bool, update bool,
) (*nftables.Set, error) { ) (*nftables.Set, error) {
if len(prefixList) > 0 { var builder netipx.IPSetBuilder
var builder netipx.IPSetBuilder for _, prefix := range prefixList {
for _, prefix := range prefixList { builder.AddPrefix(prefix)
builder.AddPrefix(prefix)
}
ipSet, err := builder.IPSet()
if err != nil {
return nil, err
}
setList = append(setList, ipSet)
} }
ipSets := make([]*myIPSet, 0, len(setList))
var rangeLen int
for _, set := range setList { for _, set := range setList {
mySet := (*myIPSet)(unsafe.Pointer(set)) builder.AddSet(set)
ipSets = append(ipSets, mySet)
rangeLen += len(mySet.rr)
} }
setElements := make([]nftables.SetElement, 0, rangeLen) ipSet, err := builder.IPSet()
for _, mySet := range ipSets { if err != nil {
for _, rr := range mySet.rr { return nil, err
if (family == nftables.TableFamilyIPv4) != rr.from.Is4() { }
continue ipRanges := ipSet.Ranges()
} setElements := make([]nftables.SetElement, 0, len(ipRanges))
endAddr := rr.to.Next() for _, rr := range ipRanges {
if !endAddr.IsValid() { if (family == nftables.TableFamilyIPv4) != rr.From().Is4() {
endAddr = rr.from continue
}
setElements = append(setElements, nftables.SetElement{
Key: rr.from.AsSlice(),
})
setElements = append(setElements, nftables.SetElement{
Key: endAddr.AsSlice(),
IntervalEnd: true,
})
} }
endAddr := rr.To().Next()
if !endAddr.IsValid() {
endAddr = rr.From()
}
setElements = append(setElements, nftables.SetElement{
Key: rr.From().AsSlice(),
})
setElements = append(setElements, nftables.SetElement{
Key: endAddr.AsSlice(),
IntervalEnd: true,
})
} }
if len(prefixList) == 0 && appendDefault { if len(prefixList) == 0 && appendDefault {
if family == nftables.TableFamilyIPv4 { if family == nftables.TableFamilyIPv4 {
@ -179,12 +170,3 @@ func nftablesCreateIPSet(
} }
return mySet, nil return mySet, nil
} }
type myIPSet struct {
rr []myIPRange
}
type myIPRange struct {
from netip.Addr
to netip.Addr
}