feat: full geoip/geosite support

This commit is contained in:
Toby 2023-10-28 13:55:20 -07:00
parent bcacc46f1d
commit e604c12f7e
18 changed files with 674 additions and 229 deletions

View file

@ -7,7 +7,6 @@ import (
"strings"
"github.com/apernet/hysteria/extras/outbounds/acl"
"github.com/oschwald/geoip2-golang"
)
const (
@ -34,25 +33,25 @@ type OutboundEntry struct {
Outbound PluggableOutbound
}
func NewACLEngineFromString(rules string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) {
func NewACLEngineFromString(rules string, outbounds []OutboundEntry, geoLoader acl.GeoLoader) (PluggableOutbound, error) {
trs, err := acl.ParseTextRules(rules)
if err != nil {
return nil, err
}
obMap := outboundsToMap(outbounds)
rs, err := acl.Compile[PluggableOutbound](trs, obMap, aclCacheSize, geoipFunc)
rs, err := acl.Compile[PluggableOutbound](trs, obMap, aclCacheSize, geoLoader)
if err != nil {
return nil, err
}
return &aclEngine{rs, obMap["default"]}, nil
}
func NewACLEngineFromFile(filename string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) {
func NewACLEngineFromFile(filename string, outbounds []OutboundEntry, geoLoader acl.GeoLoader) (PluggableOutbound, error) {
bs, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
return NewACLEngineFromString(string(bs), outbounds, geoipFunc)
return NewACLEngineFromString(string(bs), outbounds, geoLoader)
}
func outboundsToMap(outbounds []OutboundEntry) map[string]PluggableOutbound {

View file

@ -6,8 +6,9 @@ import (
"strconv"
"strings"
"github.com/apernet/hysteria/extras/outbounds/acl/v2geo"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/oschwald/geoip2-golang"
)
type Protocol int
@ -92,6 +93,11 @@ func (e *CompilationError) Error() string {
return fmt.Sprintf("error at line %d: %s", e.LineNum, e.Message)
}
type GeoLoader interface {
LoadGeoIP() (map[string]*v2geo.GeoIP, error)
LoadGeoSite() (map[string]*v2geo.GeoSite, error)
}
// Compile compiles TextRules into a CompiledRuleSet.
// Names in the outbounds map MUST be in all lower case.
// geoipFunc is a function that returns the GeoIP database needed by the GeoIP matcher.
@ -99,7 +105,7 @@ func (e *CompilationError) Error() string {
// be called if there is no GeoIP rule. We use a function here so that database loading
// is on-demand (only required if used by rules).
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
cacheSize int, geoipFunc func() *geoip2.Reader,
cacheSize int, geoLoader GeoLoader,
) (CompiledRuleSet[O], error) {
compiledRules := make([]compiledRule[O], len(rules))
for i, rule := range rules {
@ -107,7 +113,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
if !ok {
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("outbound %s not found", rule.Outbound)}
}
hm, errStr := compileHostMatcher(rule.Address, geoipFunc)
hm, errStr := compileHostMatcher(rule.Address, geoLoader)
if errStr != "" {
return nil, &CompilationError{rule.LineNum, errStr}
}
@ -184,7 +190,7 @@ func parseProtoPort(protoPort string) (Protocol, uint16, bool) {
}
}
func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatcher, string) {
func compileHostMatcher(addr string, geoLoader GeoLoader) (hostMatcher, string) {
addr = strings.ToLower(addr) // Normalize to lower case
if addr == "*" || addr == "all" {
// Match all hosts
@ -192,15 +198,43 @@ func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatch
}
if strings.HasPrefix(addr, "geoip:") {
// GeoIP matcher
country := strings.ToUpper(addr[6:])
if len(country) != 2 {
return nil, fmt.Sprintf("invalid country code: %s", country)
country := addr[6:]
if len(country) == 0 {
return nil, "empty GeoIP country code"
}
db := geoipFunc()
if db == nil {
return nil, "failed to load GeoIP database"
gMap, err := geoLoader.LoadGeoIP()
if err != nil {
return nil, err.Error()
}
return &geoipMatcher{db, country}, ""
list, ok := gMap[country]
if !ok || list == nil {
return nil, fmt.Sprintf("GeoIP country code %s not found", country)
}
m, err := newGeoIPMatcher(list)
if err != nil {
return nil, err.Error()
}
return m, ""
}
if strings.HasPrefix(addr, "geosite:") {
// GeoSite matcher
name, attrs := parseGeoSiteName(addr[8:])
if len(name) == 0 {
return nil, "empty GeoSite name"
}
gMap, err := geoLoader.LoadGeoSite()
if err != nil {
return nil, err.Error()
}
list, ok := gMap[name]
if !ok || list == nil {
return nil, fmt.Sprintf("GeoSite name %s not found", name)
}
m, err := newGeositeMatcher(list, attrs)
if err != nil {
return nil, err.Error()
}
return m, ""
}
if strings.Contains(addr, "/") {
// CIDR matcher
@ -227,3 +261,13 @@ func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatch
Wildcard: false,
}, ""
}
func parseGeoSiteName(s string) (string, []string) {
parts := strings.Split(s, "@")
base := strings.TrimSpace(parts[0])
attrs := parts[1:]
for i := range attrs {
attrs[i] = strings.TrimSpace(attrs[i])
}
return base, attrs
}

View file

@ -4,12 +4,25 @@ import (
"net"
"testing"
"github.com/oschwald/geoip2-golang"
"github.com/apernet/hysteria/extras/outbounds/acl/v2geo"
"github.com/stretchr/testify/assert"
)
var _ GeoLoader = (*testGeoLoader)(nil)
type testGeoLoader struct{}
func (l *testGeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
return v2geo.LoadGeoIP("v2geo/geoip.dat")
}
func (l *testGeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
return v2geo.LoadGeoSite("v2geo/geosite.dat")
}
func TestCompile(t *testing.T) {
ob1, ob2, ob3 := 1, 2, 3
ob1, ob2, ob3, ob4 := 1, 2, 3, 4
rules := []TextRule{
{
Outbound: "ob1",
@ -59,12 +72,25 @@ func TestCompile(t *testing.T) {
ProtoPort: "*/*",
HijackAddress: "",
},
{
Outbound: "ob4",
Address: "geosite:4chan",
ProtoPort: "*/*",
HijackAddress: "",
},
{
Outbound: "ob4",
Address: "geosite:google @cn",
ProtoPort: "*/*",
HijackAddress: "",
},
}
reader, err := geoip2.Open("GeoLite2-Country.mmdb")
assert.NoError(t, err)
comp, err := Compile[int](rules, map[string]int{"ob1": ob1, "ob2": ob2, "ob3": ob3}, 100, func() *geoip2.Reader {
return reader
})
comp, err := Compile[int](rules, map[string]int{
"ob1": ob1,
"ob2": ob2,
"ob3": ob3,
"ob4": ob4,
}, 100, &testGeoLoader{})
assert.NoError(t, err)
tests := []struct {
@ -146,6 +172,42 @@ func TestCompile(t *testing.T) {
wantOutbound: ob2,
wantIP: nil,
},
{
host: HostInfo{
IPv4: net.ParseIP("175.45.176.73"),
},
proto: ProtocolTCP,
port: 80,
wantOutbound: 0, // no match default
wantIP: nil,
},
{
host: HostInfo{
Name: "boards.4channel.org",
},
proto: ProtocolTCP,
port: 443,
wantOutbound: ob4,
wantIP: nil,
},
{
host: HostInfo{
Name: "gstatic-cn.com",
},
proto: ProtocolUDP,
port: 9999,
wantOutbound: ob4,
wantIP: nil,
},
{
host: HostInfo{
Name: "hoho.waymo.com",
},
proto: ProtocolUDP,
port: 9999,
wantOutbound: 0, // no match default
wantIP: nil,
},
}
for _, test := range tests {
@ -154,3 +216,56 @@ func TestCompile(t *testing.T) {
assert.Equal(t, test.wantIP, gotIP)
}
}
func Test_parseGeoSiteName(t *testing.T) {
tests := []struct {
name string
s string
want string
want1 []string
}{
{
name: "no attrs",
s: "pornhub",
want: "pornhub",
want1: []string{},
},
{
name: "one attr 1",
s: "xiaomi@cn",
want: "xiaomi",
want1: []string{"cn"},
},
{
name: "one attr 2",
s: " google @jp ",
want: "google",
want1: []string{"jp"},
},
{
name: "two attrs 1",
s: "netflix@jp@kr",
want: "netflix",
want1: []string{"jp", "kr"},
},
{
name: "two attrs 2",
s: "netflix @xixi @haha ",
want: "netflix",
want1: []string{"xixi", "haha"},
},
{
name: "empty",
s: "",
want: "",
want1: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := parseGeoSiteName(tt.s)
assert.Equalf(t, tt.want, got, "parseGeoSiteName(%v)", tt.s)
assert.Equalf(t, tt.want1, got1, "parseGeoSiteName(%v)", tt.s)
})
}
}

View file

@ -2,8 +2,6 @@ package acl
import (
"net"
"github.com/oschwald/geoip2-golang"
)
type hostMatcher interface {
@ -55,27 +53,6 @@ func deepMatchRune(str, pattern []rune) bool {
return len(str) == 0 && len(pattern) == 0
}
type geoipMatcher struct {
DB *geoip2.Reader
Country string // must be uppercase ISO 3166-1 alpha-2 code
}
func (m *geoipMatcher) Match(host HostInfo) bool {
if host.IPv4 != nil {
record, err := m.DB.Country(host.IPv4)
if err == nil && record.Country.IsoCode == m.Country {
return true
}
}
if host.IPv6 != nil {
record, err := m.DB.Country(host.IPv6)
if err == nil && record.Country.IsoCode == m.Country {
return true
}
}
return false
}
type allMatcher struct{}
func (m *allMatcher) Match(host HostInfo) bool {

View file

@ -3,9 +3,6 @@ package acl
import (
"net"
"testing"
"github.com/oschwald/geoip2-golang"
"github.com/stretchr/testify/assert"
)
func Test_ipMatcher_Match(t *testing.T) {
@ -233,78 +230,3 @@ func Test_domainMatcher_Match(t *testing.T) {
})
}
}
func Test_geoipMatcher_Match(t *testing.T) {
db, err := geoip2.Open("GeoLite2-Country.mmdb")
assert.NoError(t, err)
defer db.Close()
type fields struct {
DB *geoip2.Reader
Country string
}
tests := []struct {
name string
fields fields
host HostInfo
want bool
}{
{
name: "ipv4 match",
fields: fields{
DB: db,
Country: "JP",
},
host: HostInfo{
IPv4: net.ParseIP("210.140.92.181"),
},
want: true,
},
{
name: "ipv6 match",
fields: fields{
DB: db,
Country: "US",
},
host: HostInfo{
IPv6: net.ParseIP("2606:4700::6810:85e5"),
},
want: true,
},
{
name: "no match",
fields: fields{
DB: db,
Country: "AU",
},
host: HostInfo{
IPv4: net.ParseIP("210.140.92.181"),
IPv6: net.ParseIP("2606:4700::6810:85e5"),
},
want: false,
},
{
name: "both nil",
fields: fields{
DB: db,
Country: "KR",
},
host: HostInfo{
IPv4: nil,
IPv6: nil,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &geoipMatcher{
DB: tt.fields.DB,
Country: tt.fields.Country,
}
if got := m.Match(tt.host); got != tt.want {
t.Errorf("Match() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,213 @@
package acl
import (
"bytes"
"errors"
"net"
"regexp"
"sort"
"strings"
"github.com/apernet/hysteria/extras/outbounds/acl/v2geo"
)
var _ hostMatcher = (*geoipMatcher)(nil)
type geoipMatcher struct {
N4 []*net.IPNet // sorted
N6 []*net.IPNet // sorted
Inverse bool
}
// matchIP tries to match the given IP address with the corresponding IPNets.
// Note that this function does NOT handle the Inverse flag.
func (m *geoipMatcher) matchIP(ip net.IP) bool {
var n []*net.IPNet
if ip4 := ip.To4(); ip4 != nil {
// N4 stores IPv4 addresses in 4-byte form.
// Make sure we use it here too, otherwise bytes.Compare will fail.
ip = ip4
n = m.N4
} else {
n = m.N6
}
left, right := 0, len(n)-1
for left <= right {
mid := (left + right) / 2
if n[mid].Contains(ip) {
return true
} else if bytes.Compare(n[mid].IP, ip) < 0 {
left = mid + 1
} else {
right = mid - 1
}
}
return false
}
func (m *geoipMatcher) Match(host HostInfo) bool {
if host.IPv4 != nil {
if m.matchIP(host.IPv4) {
return !m.Inverse
}
}
if host.IPv6 != nil {
if m.matchIP(host.IPv6) {
return !m.Inverse
}
}
return m.Inverse
}
func newGeoIPMatcher(list *v2geo.GeoIP) (*geoipMatcher, error) {
n4 := make([]*net.IPNet, 0)
n6 := make([]*net.IPNet, 0)
for _, cidr := range list.Cidr {
if len(cidr.Ip) == 4 {
// IPv4
n4 = append(n4, &net.IPNet{
IP: cidr.Ip,
Mask: net.CIDRMask(int(cidr.Prefix), 32),
})
} else if len(cidr.Ip) == 16 {
// IPv6
n6 = append(n6, &net.IPNet{
IP: cidr.Ip,
Mask: net.CIDRMask(int(cidr.Prefix), 128),
})
} else {
return nil, errors.New("invalid IP length")
}
}
// Sort the IPNets, so we can do binary search later.
sort.Slice(n4, func(i, j int) bool {
return bytes.Compare(n4[i].IP, n4[j].IP) < 0
})
sort.Slice(n6, func(i, j int) bool {
return bytes.Compare(n6[i].IP, n6[j].IP) < 0
})
return &geoipMatcher{
N4: n4,
N6: n6,
Inverse: list.InverseMatch,
}, nil
}
var _ hostMatcher = (*geositeMatcher)(nil)
type geositeDomainType int
const (
geositeDomainPlain geositeDomainType = iota
geositeDomainRegex
geositeDomainRoot
geositeDomainFull
)
type geositeDomain struct {
Type geositeDomainType
Value string
Regex *regexp.Regexp
Attrs map[string]bool
}
type geositeMatcher struct {
Domains []geositeDomain
// Attributes are matched using "and" logic - if you have multiple attributes here,
// a domain must have all of those attributes to be considered a match.
Attrs []string
}
func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
// Match attributes first
if len(m.Attrs) > 0 {
if len(domain.Attrs) == 0 {
return false
}
for _, attr := range m.Attrs {
if !domain.Attrs[attr] {
return false
}
}
}
switch domain.Type {
case geositeDomainPlain:
return strings.Contains(host.Name, domain.Value)
case geositeDomainRegex:
if domain.Regex != nil {
return domain.Regex.MatchString(host.Name)
}
case geositeDomainFull:
return host.Name == domain.Value
case geositeDomainRoot:
if host.Name == domain.Value {
return true
}
return strings.HasSuffix(host.Name, "."+domain.Value)
default:
return false
}
return false
}
func (m *geositeMatcher) Match(host HostInfo) bool {
for _, domain := range m.Domains {
if m.matchDomain(domain, host) {
return true
}
}
return false
}
func newGeositeMatcher(list *v2geo.GeoSite, attrs []string) (*geositeMatcher, error) {
domains := make([]geositeDomain, len(list.Domain))
for i, domain := range list.Domain {
switch domain.Type {
case v2geo.Domain_Plain:
domains[i] = geositeDomain{
Type: geositeDomainPlain,
Value: domain.Value,
Attrs: domainAttributeToMap(domain.Attribute),
}
case v2geo.Domain_Regex:
regex, err := regexp.Compile(domain.Value)
if err != nil {
return nil, err
}
domains[i] = geositeDomain{
Type: geositeDomainRegex,
Regex: regex,
Attrs: domainAttributeToMap(domain.Attribute),
}
case v2geo.Domain_Full:
domains[i] = geositeDomain{
Type: geositeDomainFull,
Value: domain.Value,
Attrs: domainAttributeToMap(domain.Attribute),
}
case v2geo.Domain_RootDomain:
domains[i] = geositeDomain{
Type: geositeDomainRoot,
Value: domain.Value,
Attrs: domainAttributeToMap(domain.Attribute),
}
default:
return nil, errors.New("unsupported domain type")
}
}
return &geositeMatcher{
Domains: domains,
Attrs: attrs,
}, nil
}
func domainAttributeToMap(attrs []*v2geo.Domain_Attribute) map[string]bool {
m := make(map[string]bool)
for _, attr := range attrs {
// Supposedly there are also int attributes,
// but nobody seems to use them, so we treat everything as boolean for now.
m[attr.Key] = true
}
return m
}

View file

@ -0,0 +1,141 @@
package acl
import (
"net"
"testing"
"github.com/apernet/hysteria/extras/outbounds/acl/v2geo"
"github.com/stretchr/testify/assert"
)
func Test_geoipMatcher_Match(t *testing.T) {
geoipMap, err := v2geo.LoadGeoIP("v2geo/geoip.dat")
assert.NoError(t, err)
m, err := newGeoIPMatcher(geoipMap["us"])
assert.NoError(t, err)
tests := []struct {
name string
host HostInfo
want bool
}{
{
name: "IPv4 match",
host: HostInfo{
IPv4: net.ParseIP("73.222.1.100"),
},
want: true,
},
{
name: "IPv4 no match",
host: HostInfo{
IPv4: net.ParseIP("123.123.123.123"),
},
want: false,
},
{
name: "IPv6 match",
host: HostInfo{
IPv6: net.ParseIP("2607:f8b0:4005:80c::2004"),
},
want: true,
},
{
name: "IPv6 no match",
host: HostInfo{
IPv6: net.ParseIP("240e:947:6001::1f8"),
},
want: false,
},
{
name: "both nil",
host: HostInfo{
IPv4: nil,
IPv6: nil,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, m.Match(tt.host), "Match(%v)", tt.host)
})
}
}
func Test_geositeMatcher_Match(t *testing.T) {
geositeMap, err := v2geo.LoadGeoSite("v2geo/geosite.dat")
assert.NoError(t, err)
m, err := newGeositeMatcher(geositeMap["apple"], nil)
assert.NoError(t, err)
tests := []struct {
name string
attrs []string
host HostInfo
want bool
}{
{
name: "subdomain",
attrs: nil,
host: HostInfo{
Name: "poop.i-book.com",
},
want: true,
},
{
name: "subdomain root",
attrs: nil,
host: HostInfo{
Name: "applepaycash.net",
},
want: true,
},
{
name: "full",
attrs: nil,
host: HostInfo{
Name: "courier-push-apple.com.akadns.net",
},
want: true,
},
{
name: "regexp",
attrs: nil,
host: HostInfo{
Name: "cdn4.apple-mapkit.com",
},
want: true,
},
{
name: "attr match",
attrs: []string{"cn"},
host: HostInfo{
Name: "bag.itunes.apple.com",
},
want: true,
},
{
name: "attr multi no match",
attrs: []string{"cn", "haha"},
host: HostInfo{
Name: "bag.itunes.apple.com",
},
want: false,
},
{
name: "attr no match",
attrs: []string{"cn"},
host: HostInfo{
Name: "mr-apple.com.tw",
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m.Attrs = tt.attrs
assert.Equalf(t, tt.want, m.Match(tt.host), "Match(%v)", tt.host)
})
}
}

View file

@ -7,13 +7,9 @@ import (
"google.golang.org/protobuf/proto"
)
type GeoIPMap map[string]*GeoIP
type GeoSiteMap map[string]*GeoSite
// LoadGeoIP loads a GeoIP data file and converts it to a map.
// The keys of the map (country codes) are all normalized to lowercase.
func LoadGeoIP(filename string) (GeoIPMap, error) {
func LoadGeoIP(filename string) (map[string]*GeoIP, error) {
bs, err := os.ReadFile(filename)
if err != nil {
return nil, err
@ -22,7 +18,7 @@ func LoadGeoIP(filename string) (GeoIPMap, error) {
if err := proto.Unmarshal(bs, &list); err != nil {
return nil, err
}
m := make(GeoIPMap)
m := make(map[string]*GeoIP)
for _, entry := range list.Entry {
m[strings.ToLower(entry.CountryCode)] = entry
}
@ -31,7 +27,7 @@ func LoadGeoIP(filename string) (GeoIPMap, error) {
// LoadGeoSite loads a GeoSite data file and converts it to a map.
// The keys of the map (site keys) are all normalized to lowercase.
func LoadGeoSite(filename string) (GeoSiteMap, error) {
func LoadGeoSite(filename string) (map[string]*GeoSite, error) {
bs, err := os.ReadFile(filename)
if err != nil {
return nil, err
@ -40,7 +36,7 @@ func LoadGeoSite(filename string) (GeoSiteMap, error) {
if err := proto.Unmarshal(bs, &list); err != nil {
return nil, err
}
m := make(GeoSiteMap)
m := make(map[string]*GeoSite)
for _, entry := range list.Entry {
m[strings.ToLower(entry.CountryCode)] = entry
}