From a2f9fef93663efbf3a79efc2e3eb4a5b576debc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 26 Jul 2024 01:24:58 +0800 Subject: [PATCH] domain: Add adguard matcher --- common/domain/adguard_matcher_test.go | 67 ++++++++++ common/domain/adgurad_matcher.go | 172 ++++++++++++++++++++++++++ common/domain/matcher.go | 104 ++++++++-------- common/domain/matcher_test.go | 3 + common/domain/set.go | 75 +++++------ 5 files changed, 326 insertions(+), 95 deletions(-) create mode 100644 common/domain/adguard_matcher_test.go create mode 100644 common/domain/adgurad_matcher.go diff --git a/common/domain/adguard_matcher_test.go b/common/domain/adguard_matcher_test.go new file mode 100644 index 0000000..2a9ad70 --- /dev/null +++ b/common/domain/adguard_matcher_test.go @@ -0,0 +1,67 @@ +package domain_test + +import ( + "sort" + "testing" + + "github.com/sagernet/sing/common/domain" + + "github.com/stretchr/testify/require" +) + +func TestAdGuardMatcher(t *testing.T) { + t.Parallel() + ruleLines := []string{ + "||example.org^", + "|example.com^", + "example.net^", + "||example.edu", + "||example.edu.tw^", + "|example.gov", + "example.arpa", + } + matcher := domain.NewAdGuardMatcher(ruleLines) + require.NotNil(t, matcher) + matchDomain := []string{ + "example.org", + "www.example.org", + "example.com", + "example.net", + "isexample.net", + "www.example.net", + "example.edu", + "example.edu.cn", + "example.edu.tw", + "www.example.edu", + "www.example.edu.cn", + "example.gov", + "example.gov.cn", + "example.arpa", + "www.example.arpa", + "isexample.arpa", + "example.arpa.cn", + "www.example.arpa.cn", + "isexample.arpa.cn", + } + notMatchDomain := []string{ + "example.org.cn", + "notexample.org", + "example.com.cn", + "www.example.com.cn", + "example.net.cn", + "notexample.edu", + "notexample.edu.cn", + "www.example.gov", + "notexample.gov", + } + for _, domain := range matchDomain { + require.True(t, matcher.Match(domain), domain) + } + for _, domain := range notMatchDomain { + require.False(t, matcher.Match(domain), domain) + } + dLines := matcher.Dump() + sort.Strings(ruleLines) + sort.Strings(dLines) + require.Equal(t, ruleLines, dLines) +} diff --git a/common/domain/adgurad_matcher.go b/common/domain/adgurad_matcher.go new file mode 100644 index 0000000..e13a121 --- /dev/null +++ b/common/domain/adgurad_matcher.go @@ -0,0 +1,172 @@ +package domain + +import ( + "bytes" + "sort" + "strings" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/varbin" +) + +const ( + anyLabel = '*' + suffixLabel = '\b' +) + +type AdGuardMatcher struct { + set *succinctSet +} + +func NewAdGuardMatcher(ruleLines []string) *AdGuardMatcher { + ruleList := make([]string, 0, len(ruleLines)) + for _, ruleLine := range ruleLines { + var ( + isSuffix bool // || + hasStart bool // | + hasEnd bool // ^ + ) + if strings.HasPrefix(ruleLine, "||") { + ruleLine = ruleLine[2:] + isSuffix = true + } else if strings.HasPrefix(ruleLine, "|") { + ruleLine = ruleLine[1:] + hasStart = true + } + if strings.HasSuffix(ruleLine, "^") { + ruleLine = ruleLine[:len(ruleLine)-1] + hasEnd = true + } + if isSuffix { + ruleLine = string(rootLabel) + ruleLine + } else if !hasStart { + ruleLine = string(prefixLabel) + ruleLine + } + if !hasEnd { + if strings.HasSuffix(ruleLine, ".") { + ruleLine = ruleLine[:len(ruleLine)-1] + } + ruleLine += string(suffixLabel) + } + ruleList = append(ruleList, reverseDomain(ruleLine)) + } + ruleList = common.Uniq(ruleList) + sort.Strings(ruleList) + return &AdGuardMatcher{newSuccinctSet(ruleList)} +} + +func ReadAdGuardMatcher(reader varbin.Reader) (*AdGuardMatcher, error) { + set, err := readSuccinctSet(reader) + if err != nil { + return nil, err + } + return &AdGuardMatcher{set}, nil +} + +func (m *AdGuardMatcher) Write(writer varbin.Writer) error { + return m.set.Write(writer) +} + +func (m *AdGuardMatcher) Match(domain string) bool { + key := reverseDomain(domain) + if m.has([]byte(key), 0, 0) { + return true + } + for { + if m.has([]byte(string(suffixLabel)+key), 0, 0) { + return true + } + idx := strings.IndexByte(key, '.') + if idx == -1 { + return false + } + key = key[idx+1:] + } +} + +func (m *AdGuardMatcher) has(key []byte, nodeId, bmIdx int) bool { + for i := 0; i < len(key); i++ { + currentChar := key[i] + for ; ; bmIdx++ { + if getBit(m.set.labelBitmap, bmIdx) != 0 { + return false + } + nextLabel := m.set.labels[bmIdx-nodeId] + if nextLabel == prefixLabel { + return true + } + if nextLabel == rootLabel { + nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1) + hasNext := getBit(m.set.leaves, nextNodeId) != 0 + if currentChar == '.' && hasNext { + return true + } + } + if nextLabel == currentChar { + break + } + if nextLabel == anyLabel { + idx := bytes.IndexRune(key[i:], '.') + nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1) + if idx == -1 { + if getBit(m.set.leaves, nextNodeId) != 0 { + return true + } + idx = 0 + } + nextBmIdx := selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nextNodeId-1) + 1 + if m.has(key[i+idx:], nextNodeId, nextBmIdx) { + return true + } + } + } + nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1) + bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1 + } + if getBit(m.set.leaves, nodeId) != 0 { + return true + } + for ; ; bmIdx++ { + if getBit(m.set.labelBitmap, bmIdx) != 0 { + return false + } + nextLabel := m.set.labels[bmIdx-nodeId] + if nextLabel == prefixLabel || nextLabel == rootLabel { + return true + } + } +} + +func (m *AdGuardMatcher) Dump() (ruleLines []string) { + for _, key := range m.set.keys() { + key = reverseDomain(key) + var ( + isSuffix bool + hasStart bool + hasEnd bool + ) + if key[0] == prefixLabel { + key = key[1:] + } else if key[0] == rootLabel { + key = key[1:] + isSuffix = true + } else { + hasStart = true + } + if key[len(key)-1] == suffixLabel { + key = key[:len(key)-1] + } else { + hasEnd = true + } + if isSuffix { + key = "||" + key + } else if hasStart { + key = "|" + key + } + if hasEnd { + key += "^" + } + ruleLines = append(ruleLines, key) + } + return +} diff --git a/common/domain/matcher.go b/common/domain/matcher.go index 52c32a3..3407aa4 100644 --- a/common/domain/matcher.go +++ b/common/domain/matcher.go @@ -1,13 +1,17 @@ package domain import ( - "encoding/binary" "sort" "unicode/utf8" "github.com/sagernet/sing/common/varbin" ) +const ( + prefixLabel = '\r' + rootLabel = '\n' +) + type Matcher struct { set *succinctSet } @@ -21,16 +25,16 @@ func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *M } seen[domain] = true if domain[0] == '.' { - domainList = append(domainList, reverseDomainSuffix(domain)) + domainList = append(domainList, reverseDomain(string(prefixLabel)+domain)) } else if generateLegacy { domainList = append(domainList, reverseDomain(domain)) suffixDomain := "." + domain if !seen[suffixDomain] { seen[suffixDomain] = true - domainList = append(domainList, reverseDomainSuffix(suffixDomain)) + domainList = append(domainList, reverseDomain(string(prefixLabel)+suffixDomain)) } } else { - domainList = append(domainList, reverseDomainRoot(domain)) + domainList = append(domainList, reverseDomain(string(rootLabel)+domain)) } } for _, domain := range domains { @@ -44,38 +48,60 @@ func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *M return &Matcher{newSuccinctSet(domainList)} } -type matcherData struct { - Version uint8 - Leaves []uint64 - LabelBitmap []uint64 - Labels []byte -} - func ReadMatcher(reader varbin.Reader) (*Matcher, error) { - matcher, err := varbin.ReadValue[matcherData](reader, binary.BigEndian) + set, err := readSuccinctSet(reader) if err != nil { return nil, err } - set := &succinctSet{ - leaves: matcher.Leaves, - labelBitmap: matcher.LabelBitmap, - labels: matcher.Labels, - } - set.init() return &Matcher{set}, nil } -func (m *Matcher) Match(domain string) bool { - return m.set.Has(reverseDomain(domain)) +func (m *Matcher) Write(writer varbin.Writer) error { + return m.set.Write(writer) } -func (m *Matcher) Write(writer varbin.Writer) error { - return varbin.Write(writer, binary.BigEndian, matcherData{ - Version: 1, - Leaves: m.set.leaves, - LabelBitmap: m.set.labelBitmap, - Labels: m.set.labels, - }) +func (m *Matcher) Match(domain string) bool { + return m.has(reverseDomain(domain)) +} + +func (m *Matcher) has(key string) bool { + var nodeId, bmIdx int + for i := 0; i < len(key); i++ { + currentChar := key[i] + for ; ; bmIdx++ { + if getBit(m.set.labelBitmap, bmIdx) != 0 { + return false + } + nextLabel := m.set.labels[bmIdx-nodeId] + if nextLabel == prefixLabel { + return true + } + if nextLabel == rootLabel { + nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1) + hasNext := getBit(m.set.leaves, nextNodeId) != 0 + if currentChar == '.' && hasNext { + return true + } + } + if nextLabel == currentChar { + break + } + } + nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1) + bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1 + } + if getBit(m.set.leaves, nodeId) != 0 { + return true + } + for ; ; bmIdx++ { + if getBit(m.set.labelBitmap, bmIdx) != 0 { + return false + } + nextLabel := m.set.labels[bmIdx-nodeId] + if nextLabel == prefixLabel || nextLabel == rootLabel { + return true + } + } } func (m *Matcher) Dump() (domainList []string, prefixList []string) { @@ -119,27 +145,3 @@ func reverseDomain(domain string) string { } return string(b) } - -func reverseDomainSuffix(domain string) string { - l := len(domain) - b := make([]byte, l+1) - for i := 0; i < l; { - r, n := utf8.DecodeRuneInString(domain[i:]) - i += n - utf8.EncodeRune(b[l-i:], r) - } - b[l] = prefixLabel - return string(b) -} - -func reverseDomainRoot(domain string) string { - l := len(domain) - b := make([]byte, l+1) - for i := 0; i < l; { - r, n := utf8.DecodeRuneInString(domain[i:]) - i += n - utf8.EncodeRune(b[l-i:], r) - } - b[l] = rootLabel - return string(b) -} diff --git a/common/domain/matcher_test.go b/common/domain/matcher_test.go index 5d12e91..59b1368 100644 --- a/common/domain/matcher_test.go +++ b/common/domain/matcher_test.go @@ -12,6 +12,7 @@ import ( ) func TestMatcher(t *testing.T) { + t.Parallel() testDomain := []string{"example.com", "example.org"} testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"} matcher := domain.NewMatcher(testDomain, testDomainSuffix, false) @@ -31,6 +32,7 @@ func TestMatcher(t *testing.T) { } func TestMatcherLegacy(t *testing.T) { + t.Parallel() testDomain := []string{"example.com", "example.org"} testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"} matcher := domain.NewMatcher(testDomain, testDomainSuffix, true) @@ -57,6 +59,7 @@ type simpleRuleSet struct { } func TestDumpLarge(t *testing.T) { + t.Parallel() response, err := http.Get("https://raw.githubusercontent.com/MetaCubeX/meta-rules-dat/sing/geo/geosite/cn.json") require.NoError(t, err) defer response.Body.Close() diff --git a/common/domain/set.go b/common/domain/set.go index 2072e1d..c952f6c 100644 --- a/common/domain/set.go +++ b/common/domain/set.go @@ -1,12 +1,10 @@ package domain import ( + "encoding/binary" "math/bits" -) -const ( - prefixLabel = '\r' - rootLabel = '\n' + "github.com/sagernet/sing/common/varbin" ) // mod from https://github.com/openacid/succinct @@ -45,46 +43,6 @@ func newSuccinctSet(keys []string) *succinctSet { return ss } -func (ss *succinctSet) Has(key string) bool { - var nodeId, bmIdx int - for i := 0; i < len(key); i++ { - currentChar := key[i] - for ; ; bmIdx++ { - if getBit(ss.labelBitmap, bmIdx) != 0 { - return false - } - nextLabel := ss.labels[bmIdx-nodeId] - if nextLabel == prefixLabel { - return true - } - if nextLabel == rootLabel { - nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1) - hasNext := getBit(ss.leaves, nextNodeId) != 0 - if currentChar == '.' && hasNext { - return true - } - } - if nextLabel == currentChar { - break - } - } - nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1) - bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1 - } - if getBit(ss.leaves, nodeId) != 0 { - return true - } - for ; ; bmIdx++ { - if getBit(ss.labelBitmap, bmIdx) != 0 { - return false - } - nextLabel := ss.labels[bmIdx-nodeId] - if nextLabel == prefixLabel || nextLabel == rootLabel { - return true - } - } -} - func (ss *succinctSet) keys() []string { var result []string var currentKey []byte @@ -113,6 +71,35 @@ func (ss *succinctSet) keys() []string { return result } +type succinctSetData struct { + Reserved uint8 + Leaves []uint64 + LabelBitmap []uint64 + Labels []byte +} + +func readSuccinctSet(reader varbin.Reader) (*succinctSet, error) { + matcher, err := varbin.ReadValue[succinctSetData](reader, binary.BigEndian) + if err != nil { + return nil, err + } + set := &succinctSet{ + leaves: matcher.Leaves, + labelBitmap: matcher.LabelBitmap, + labels: matcher.Labels, + } + set.init() + return set, nil +} + +func (ss *succinctSet) Write(writer varbin.Writer) error { + return varbin.Write(writer, binary.BigEndian, succinctSetData{ + Leaves: ss.leaves, + LabelBitmap: ss.labelBitmap, + Labels: ss.labels, + }) +} + func setBit(bm *[]uint64, i int, v int) { for i>>6 >= len(*bm) { *bm = append(*bm, 0)