diff --git a/common/control/protect_linux.go b/common/control/protect_linux.go new file mode 100644 index 0000000..62e234c --- /dev/null +++ b/common/control/protect_linux.go @@ -0,0 +1,45 @@ +//go:build linux + +package control + +import ( + "syscall" + + E "github.com/sagernet/sing/common/exceptions" +) + +func sendAncillaryFileDescriptors(protectPath string, fileDescriptors []int) error { + socket, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + if err != nil { + return E.Cause(err, "open protect socket") + } + defer syscall.Close(socket) + err = syscall.Connect(socket, &syscall.SockaddrUnix{Name: protectPath}) + if err != nil { + return E.Cause(err, "connect protect path") + } + oob := syscall.UnixRights(fileDescriptors...) + dummy := []byte{1} + err = syscall.Sendmsg(socket, dummy, oob, nil, 0) + if err != nil { + return err + } + n, err := syscall.Read(socket, dummy) + if err != nil { + return err + } + if n != 1 { + return E.New("failed to protect fd") + } + return nil +} + +func ProtectPath(protectPath string) Func { + return func(network, address string, conn syscall.RawConn) error { + var innerErr error + err := conn.Control(func(fd uintptr) { + innerErr = sendAncillaryFileDescriptors(protectPath, []int{int(fd)}) + }) + return E.Errors(innerErr, err) + } +} diff --git a/common/control/protect_other.go b/common/control/protect_other.go new file mode 100644 index 0000000..3a5ca0f --- /dev/null +++ b/common/control/protect_other.go @@ -0,0 +1,7 @@ +//go:build !linux + +package control + +func ProtectPath(protectPath string) Func { + return nil +} diff --git a/common/domain/matcher.go b/common/domain/matcher.go new file mode 100644 index 0000000..95dc0dc --- /dev/null +++ b/common/domain/matcher.go @@ -0,0 +1,60 @@ +package domain + +import ( + "sort" + "unicode/utf8" +) + +type Matcher struct { + set *succinctSet +} + +func NewMatcher(domains []string, domainSuffix []string) *Matcher { + domainList := make([]string, 0, len(domains)+len(domainSuffix)) + seen := make(map[string]bool, len(domainList)) + for _, domain := range domainSuffix { + if seen[domain] { + continue + } + seen[domain] = true + domainList = append(domainList, reverseDomainSuffix(domain)) + } + for _, domain := range domains { + if seen[domain] { + continue + } + seen[domain] = true + domainList = append(domainList, reverseDomain(domain)) + } + sort.Strings(domainList) + return &Matcher{ + newSuccinctSet(domainList), + } +} + +func (m *Matcher) Match(domain string) bool { + return m.set.Has(reverseDomain(domain)) +} + +func reverseDomain(domain string) string { + l := len(domain) + b := make([]byte, l) + for i := 0; i < l; { + r, n := utf8.DecodeRuneInString(domain[i:]) + i += n + utf8.EncodeRune(b[l-i:], r) + } + return string(b) +} + +func reverseDomainSuffix(domain string) string { + l := len(domain) + b := make([]byte, l+1) + for i := 0; i < l; { + r, n := utf8.DecodeRuneInString(domain[i:]) + i += n + utf8.EncodeRune(b[l-i:], r) + } + b[l] = prefixLabel + return string(b) +} diff --git a/common/domain/set.go b/common/domain/set.go new file mode 100644 index 0000000..adf661a --- /dev/null +++ b/common/domain/set.go @@ -0,0 +1,231 @@ +package domain + +import ( + "math/bits" +) + +const prefixLabel = '\r' + +// mod from https://github.com/openacid/succinct + +type succinctSet struct { + leaves, labelBitmap []uint64 + labels []byte + ranks, selects []int32 +} + +func newSuccinctSet(keys []string) *succinctSet { + ss := &succinctSet{} + lIdx := 0 + type qElt struct{ s, e, col int } + queue := []qElt{{0, len(keys), 0}} + for i := 0; i < len(queue); i++ { + elt := queue[i] + if elt.col == len(keys[elt.s]) { + // a leaf node + elt.s++ + setBit(&ss.leaves, i, 1) + } + for j := elt.s; j < elt.e; { + frm := j + for ; j < elt.e && keys[j][elt.col] == keys[frm][elt.col]; j++ { + } + queue = append(queue, qElt{frm, j, elt.col + 1}) + ss.labels = append(ss.labels, keys[frm][elt.col]) + setBit(&ss.labelBitmap, lIdx, 0) + lIdx++ + } + setBit(&ss.labelBitmap, lIdx, 1) + lIdx++ + } + ss.init() + return ss +} + +func (ss *succinctSet) Has(key string) bool { + var nodeId, bmIdx int + for i := 0; i < len(key); i++ { + currentChar := key[i] + for ; ; bmIdx++ { + if getBit(ss.labelBitmap, bmIdx) != 0 { + return false + } + nextLabel := ss.labels[bmIdx-nodeId] + if nextLabel == prefixLabel { + return true + } + if nextLabel == currentChar { + break + } + } + nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1) + bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1 + } + if getBit(ss.leaves, nodeId) != 0 { + return true + } + for ; ; bmIdx++ { + if getBit(ss.labelBitmap, bmIdx) != 0 { + return false + } + if ss.labels[bmIdx-nodeId] == prefixLabel { + return true + } + } +} + +func setBit(bm *[]uint64, i int, v int) { + for i>>6 >= len(*bm) { + *bm = append(*bm, 0) + } + (*bm)[i>>6] |= uint64(v) << uint(i&63) +} + +func getBit(bm []uint64, i int) uint64 { + return bm[i>>6] & (1 << uint(i&63)) +} + +func (ss *succinctSet) init() { + ss.selects, ss.ranks = indexSelect32R64(ss.labelBitmap) +} + +func countZeros(bm []uint64, ranks []int32, i int) int { + a, _ := rank64(bm, ranks, int32(i)) + return i - int(a) +} + +func selectIthOne(bm []uint64, ranks, selects []int32, i int) int { + a, _ := select32R64(bm, selects, ranks, int32(i)) + return int(a) +} + +func rank64(words []uint64, rindex []int32, i int32) (int32, int32) { + wordI := i >> 6 + j := uint32(i & 63) + n := rindex[wordI] + w := words[wordI] + c1 := n + int32(bits.OnesCount64(w&mask[j])) + return c1, int32(w>>uint(j)) & 1 +} + +func indexRank64(words []uint64, opts ...bool) []int32 { + trailing := false + if len(opts) > 0 { + trailing = opts[0] + } + l := len(words) + if trailing { + l++ + } + idx := make([]int32, l) + n := int32(0) + for i := 0; i < len(words); i++ { + idx[i] = n + n += int32(bits.OnesCount64(words[i])) + } + if trailing { + idx[len(words)] = n + } + return idx +} + +func select32R64(words []uint64, selectIndex, rankIndex []int32, i int32) (int32, int32) { + a := int32(0) + l := int32(len(words)) + wordI := selectIndex[i>>5] >> 6 + for ; rankIndex[wordI+1] <= i; wordI++ { + } + w := words[wordI] + ww := w + base := wordI << 6 + findIth := int(i - rankIndex[wordI]) + offset := int32(0) + ones := bits.OnesCount32(uint32(ww)) + if ones <= findIth { + findIth -= ones + offset |= 32 + ww >>= 32 + } + ones = bits.OnesCount16(uint16(ww)) + if ones <= findIth { + findIth -= ones + offset |= 16 + ww >>= 16 + } + ones = bits.OnesCount8(uint8(ww)) + if ones <= findIth { + a = int32(select8Lookup[(ww>>5)&(0x7f8)|uint64(findIth-ones)]) + offset + 8 + } else { + a = int32(select8Lookup[(ww&0xff)<<3|uint64(findIth)]) + offset + } + a += base + w &= rMaskUpto[a&63] + if w != 0 { + return a, base + int32(bits.TrailingZeros64(w)) + } + wordI++ + for ; wordI < l; wordI++ { + w = words[wordI] + if w != 0 { + return a, wordI<<6 + int32(bits.TrailingZeros64(w)) + } + } + return a, l << 6 +} + +func indexSelect32R64(words []uint64) ([]int32, []int32) { + l := len(words) << 6 + sidx := make([]int32, 0, len(words)) + + ith := -1 + for i := 0; i < l; i++ { + if words[i>>6]&(1<