package domain

import (
	"encoding/binary"
	"math/bits"

	"github.com/sagernet/sing/common/varbin"
)

// mod from https://github.com/openacid/succinct

// succinctSet is a static byte-labelled trie stored as bitmaps. Nodes are
// numbered in breadth-first order, with the root as node 0.
type succinctSet struct {
	// leaves has bit i set when node i terminates a key.
	// labelBitmap holds, for each node in order, one 0 bit per child edge
	// followed by a single 1 bit that closes the node.
	leaves, labelBitmap []uint64
	// labels holds the byte of each child edge, in the same order as the
	// 0 bits of labelBitmap.
	labels []byte
	// ranks and selects are lookup indexes over labelBitmap, built by init
	// (via indexSelect32R64) and consumed by countZeros / selectIthOne.
	ranks, selects []int32
}

// newSuccinctSet builds a succinctSet containing exactly keys.
//
// keys must be sorted in ascending byte order and contain no duplicates:
// child edges are formed by grouping adjacent keys that share the byte at
// the current column, and a duplicate key would be indexed past its end.
// An empty keys slice yields an empty set.
func newSuccinctSet(keys []string) *succinctSet {
	ss := &succinctSet{}
	lIdx := 0

	// qElt is a pending node: every key in keys[s:e] shares the prefix of
	// length col that leads to this node.
	type qElt struct{ s, e, col int }

	// The queue is processed in order, so queue index i is the node id and
	// nodes are laid out breadth-first.
	queue := []qElt{{0, len(keys), 0}}
	for i := 0; i < len(queue); i++ {
		elt := queue[i]

		// Because keys are sorted, a key that ends exactly at this node is
		// always the first of the range. The elt.s < elt.e guard only
		// matters for the root of an empty key set, where keys[elt.s] would
		// otherwise index an empty slice.
		if elt.s < elt.e && elt.col == len(keys[elt.s]) {
			// a leaf node
			elt.s++
			setBit(&ss.leaves, i, 1)
		}

		// Emit one child edge per distinct byte at column col.
		for j := elt.s; j < elt.e; {
			frm := j
			for ; j < elt.e && keys[j][elt.col] == keys[frm][elt.col]; j++ {
			}
			queue = append(queue, qElt{frm, j, elt.col + 1})
			ss.labels = append(ss.labels, keys[frm][elt.col])
			setBit(&ss.labelBitmap, lIdx, 0)
			lIdx++
		}

		// Close this node's run in labelBitmap.
		setBit(&ss.labelBitmap, lIdx, 1)
		lIdx++
	}

	ss.init()
	return ss
}

// keys returns every key stored in the set, found by a depth-first walk
// from the root. A key is emitted before any longer key it prefixes, and
// children are visited in the order their labels were stored. For a set
// built by newSuccinctSet from sorted keys, that gives ascending byte order.
// An empty set returns nil.
func (ss *succinctSet) keys() []string {
	var result []string
	var currentKey []byte

	var traverse func(nodeID, bmIdx int)
	traverse = func(nodeID, bmIdx int) {
		if getBit(ss.leaves, nodeID) != 0 {
			result = append(result, string(currentKey))
		}
		// bmIdx starts at this node's run in labelBitmap; each 0 bit is a
		// child edge and the first 1 bit ends the run.
		for ; ; bmIdx++ {
			if getBit(ss.labelBitmap, bmIdx) != 0 {
				return
			}
			// Node nodeID's run is preceded by exactly nodeID one-bits, so
			// the number of zeros before bmIdx, i.e. this edge's label
			// index, is bmIdx-nodeID.
			nextLabel := ss.labels[bmIdx-nodeID]
			currentKey = append(currentKey, nextLabel)
			// Children are numbered from 1 in edge order, so the child id is
			// the count of zeros up to and including bmIdx. Its run starts
			// right after the (childID-1)-th one-bit.
			nextNodeID := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
			nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeID-1) + 1
			traverse(nextNodeID, nextBmIdx)
			currentKey = currentKey[:len(currentKey)-1]
		}
	}

	traverse(0, 0)
	return result
}

// succinctSetData is the serialized form of a succinctSet. Reserved is
// written as zero by Write and ignored by readSuccinctSet.
type succinctSetData struct {
	Reserved    uint8
	Leaves      []uint64
	LabelBitmap []uint64
	Labels      []byte
}

// readSuccinctSet decodes a succinctSetData value from reader (big-endian
// varbin encoding) and rebuilds the rank/select indexes, which are not
// part of the serialized form.
func readSuccinctSet(reader varbin.Reader) (*succinctSet, error) {
	matcher, err := varbin.ReadValue[succinctSetData](reader, binary.BigEndian)
	if err != nil {
		return nil, err
	}
	set := &succinctSet{
		leaves:      matcher.Leaves,
		labelBitmap: matcher.LabelBitmap,
		labels:      matcher.Labels,
	}
	set.init()
	return set, nil
}

// Write serializes the set's leaves, label bitmap and labels to writer in
// big-endian varbin encoding. The derived ranks/selects indexes are not
// written; readSuccinctSet recomputes them.
func (ss *succinctSet) Write(writer varbin.Writer) error {
	return varbin.Write(writer, binary.BigEndian, succinctSetData{
		Leaves:      ss.leaves,
		LabelBitmap: ss.labelBitmap,
		Labels:      ss.labels,
	})
}

// setBit ORs v (expected to be 0 or 1) into bit i of *bm, appending zero
// words as needed so that bit i is addressable. With v == 0 it only grows
// the bitmap; it never clears a bit.
func setBit(bm *[]uint64, i int, v int) {
	for i>>6 >= len(*bm) {
		*bm = append(*bm, 0)
	}
	(*bm)[i>>6] |= uint64(v) << uint(i&63)
}

// getBit reports bit i of bm in place. It returns 1<<(i&63) when the bit
// is set and 0 when it is clear; the value is not normalized to 1. Bits
// past the end of bm read as 0, so the nil leaves bitmap of an empty set
// can be queried without panicking.
func getBit(bm []uint64, i int) uint64 {
	w := i >> 6
	if w >= len(bm) {
		return 0
	}
	return bm[w] & (1 << uint(i&63))
}

// init (re)builds the select and rank indexes over labelBitmap. It must be
// called whenever labelBitmap is replaced.
func (ss *succinctSet) init() {
	ss.selects, ss.ranks = indexSelect32R64(ss.labelBitmap)
}

// countZeros returns the number of 0 bits in bm before position i. It
// assumes mask[j] (defined elsewhere) selects the bits below j, which makes
// rank64 count the ones in [0, i). Bit i itself must lie within bm.
func countZeros(bm []uint64, ranks []int32, i int) int {
	a, _ := rank64(bm, ranks, int32(i))
	return i - int(a)
}

// selectIthOne returns the position of the i-th (0-based) 1 bit in bm,
// using the ranks/selects indexes built by init.
func selectIthOne(bm []uint64, ranks, selects []int32, i int) int {
	a, _ := select32R64(bm, selects, ranks, int32(i))
	return int(a)
}

// rank64 returns two values:
//   - the rank index entry for the word containing bit i, plus the number of
//     set bits of that word kept by mask[i&63] (mask is defined elsewhere;
//     presumably the bits below i, giving the count of ones in [0, i));
//   - the value (0 or 1) of bit i itself.
//
// rindex is expected to hold the number of ones before each word, as
// produced by indexRank64.
func rank64(words []uint64, rindex []int32, i int32) (int32, int32) {
	wordI := i >> 6
	j := uint32(i & 63)
	n := rindex[wordI]
	w := words[wordI]
	c1 := n + int32(bits.OnesCount64(w&mask[j]))
	return c1, int32(w>>uint(j)) & 1
}

// indexRank64 returns idx where idx[k] is the number of set bits in
// words[:k]. If the optional first flag is true, one extra trailing entry
// holding the total number of set bits is appended.
func indexRank64(words []uint64, opts ...bool) []int32 {
	trailing := false
	if len(opts) > 0 {
		trailing = opts[0]
	}
	l := len(words)
	if trailing {
		l++
	}
	idx := make([]int32, l)
	n := int32(0)
	for i := 0; i < len(words); i++ {
		idx[i] = n
		n += int32(bits.OnesCount64(words[i]))
	}
	if trailing {
		idx[len(words)] = n
	}
	return idx
}

// select32R64 returns two positions: first, the position of the i-th
// (0-based) 1 bit in words; second, the position of the next set bit after
// it, or len(words)*64 when there is none. The "next" bit is found via
// rMaskUpto, defined elsewhere, so verify its exact semantics there.
// selectIndex locates a starting word for every 32nd one, and rankIndex,
// which needs a trailing total entry, refines it to the exact word. The
// final step is resolved with the 8-bit select8Lookup table.
func select32R64(words []uint64, selectIndex, rankIndex []int32, i int32) (int32, int32) {
	a := int32(0)
	l := int32(len(words))
	wordI := selectIndex[i>>5] >> 6
	for ; rankIndex[wordI+1] <= i; wordI++ {
	}
	w := words[wordI]
	ww := w
	base := wordI << 6
	findIth := int(i - rankIndex[wordI])
	offset := int32(0)
	ones := bits.OnesCount32(uint32(ww))
	if ones <= findIth {
		findIth -= ones
		offset |= 32
		ww >>= 32
	}
	ones = bits.OnesCount16(uint16(ww))
	if ones <= findIth {
		findIth -= ones
		offset |= 16
		ww >>= 16
	}
	ones = bits.OnesCount8(uint8(ww))
	if ones <= findIth {
		a = int32(select8Lookup[(ww>>5)&(0x7f8)|uint64(findIth-ones)]) + offset + 8
	} else {
		a = int32(select8Lookup[(ww&0xff)<<3|uint64(findIth)]) + offset
	}
	a += base
	w &= rMaskUpto[a&63]
	if w != 0 {
		return a, base + int32(bits.TrailingZeros64(w))
	}
wordI++ for ; wordI < l; wordI++ { w = words[wordI] if w != 0 { return a, wordI<<6 + int32(bits.TrailingZeros64(w)) } } return a, l << 6 } func indexSelect32R64(words []uint64) ([]int32, []int32) { l := len(words) << 6 sidx := make([]int32, 0, len(words)) ith := -1 for i := 0; i < l; i++ { if words[i>>6]&(1<