mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 20:47:38 +03:00
feat: WIP ACL
This commit is contained in:
parent
cab753718d
commit
cd2524c767
9 changed files with 930 additions and 0 deletions
|
@ -5,7 +5,9 @@ go 1.20
|
|||
require (
|
||||
github.com/apernet/hysteria/core v0.0.0-00010101000000-000000000000
|
||||
github.com/babolivier/go-doh-client v0.0.0-20201028162107-a76cff4cb8b6
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.5
|
||||
github.com/miekg/dns v1.1.55
|
||||
github.com/oschwald/geoip2-golang v1.9.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
golang.org/x/crypto v0.11.0
|
||||
)
|
||||
|
@ -17,6 +19,7 @@ require (
|
|||
github.com/golang/mock v1.6.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/oschwald/maxminddb-golang v1.11.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.3.1 // indirect
|
||||
|
|
|
@ -17,6 +17,8 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
|
|||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.5 h1:wW7h1TG88eUIJ2i69gaE3uNVtEPIagzhGvHgwfx2Vm4=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.5/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
|
||||
|
@ -25,6 +27,10 @@ github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWb
|
|||
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
|
||||
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
github.com/oschwald/geoip2-golang v1.9.0 h1:uvD3O6fXAXs+usU+UGExshpdP13GAqp4GBrzN7IgKZc=
|
||||
github.com/oschwald/geoip2-golang v1.9.0/go.mod h1:BHK6TvDyATVQhKNbQBdrj9eAvuwOMi2zSFXizL3K81Y=
|
||||
github.com/oschwald/maxminddb-golang v1.11.0 h1:aSXMqYR/EPNjGE8epgqwDay+P30hCBZIveY0WZbAWh0=
|
||||
github.com/oschwald/maxminddb-golang v1.11.0/go.mod h1:YmVI+H0zh3ySFR3w+oz8PCfglAFj3PuCmui13+P9zDg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
|
|
BIN
extras/outbounds/acl/GeoLite2-Country.mmdb
Normal file
BIN
extras/outbounds/acl/GeoLite2-Country.mmdb
Normal file
Binary file not shown.
223
extras/outbounds/acl/compile.go
Normal file
223
extras/outbounds/acl/compile.go
Normal file
|
@ -0,0 +1,223 @@
|
|||
package acl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
)
|
||||
|
||||
type protocol int
|
||||
|
||||
const (
|
||||
protocolBoth protocol = iota
|
||||
protocolTCP
|
||||
protocolUDP
|
||||
)
|
||||
|
||||
type Outbound interface {
|
||||
any
|
||||
}
|
||||
|
||||
type HostInfo struct {
|
||||
Name string
|
||||
IPv4 net.IP
|
||||
IPv6 net.IP
|
||||
}
|
||||
|
||||
func (h HostInfo) String() string {
|
||||
return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6)
|
||||
}
|
||||
|
||||
type CompiledRuleSet[O Outbound] interface {
|
||||
Match(host HostInfo, proto protocol, port uint16) (O, net.IP)
|
||||
}
|
||||
|
||||
type compiledRule[O Outbound] struct {
|
||||
Outbound O
|
||||
HostMatcher hostMatcher
|
||||
Protocol protocol
|
||||
Port uint16
|
||||
HijackAddress net.IP
|
||||
}
|
||||
|
||||
func (r *compiledRule[O]) Match(host HostInfo, proto protocol, port uint16) bool {
|
||||
if r.Protocol != protocolBoth && r.Protocol != proto {
|
||||
return false
|
||||
}
|
||||
if r.Port != 0 && r.Port != port {
|
||||
return false
|
||||
}
|
||||
return r.HostMatcher.Match(host)
|
||||
}
|
||||
|
||||
type matchResult[O Outbound] struct {
|
||||
Outbound O
|
||||
HijackAddress net.IP
|
||||
}
|
||||
|
||||
type compiledRuleSetImpl[O Outbound] struct {
|
||||
Rules []compiledRule[O]
|
||||
Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String()
|
||||
}
|
||||
|
||||
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto protocol, port uint16) (O, net.IP) {
|
||||
host.Name = strings.ToLower(host.Name) // Normalize host name to lower case
|
||||
key := host.String()
|
||||
if result, ok := s.Cache.Get(key); ok {
|
||||
return result.Outbound, result.HijackAddress
|
||||
}
|
||||
for _, rule := range s.Rules {
|
||||
if rule.Match(host, proto, port) {
|
||||
result := matchResult[O]{rule.Outbound, rule.HijackAddress}
|
||||
s.Cache.Add(key, result)
|
||||
return result.Outbound, result.HijackAddress
|
||||
}
|
||||
}
|
||||
// No match should also be cached
|
||||
var zero O
|
||||
s.Cache.Add(key, matchResult[O]{zero, nil})
|
||||
return zero, nil
|
||||
}
|
||||
|
||||
type CompilationError struct {
|
||||
Index int
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *CompilationError) Error() string {
|
||||
return fmt.Sprintf("error at index %d: %s", e.Index, e.Message)
|
||||
}
|
||||
|
||||
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
||||
cacheSize int, geoipFunc func() *geoip2.Reader,
|
||||
) (CompiledRuleSet[O], error) {
|
||||
compiledRules := make([]compiledRule[O], len(rules))
|
||||
for i, rule := range rules {
|
||||
outbound, ok := outbounds[rule.Outbound]
|
||||
if !ok {
|
||||
return nil, &CompilationError{i, fmt.Sprintf("outbound %s not found", rule.Outbound)}
|
||||
}
|
||||
hm, errStr := compileHostMatcher(rule.Address, geoipFunc)
|
||||
if errStr != "" {
|
||||
return nil, &CompilationError{i, errStr}
|
||||
}
|
||||
proto, port, ok := parseProtoPort(rule.ProtoPort)
|
||||
if !ok {
|
||||
return nil, &CompilationError{i, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)}
|
||||
}
|
||||
var hijackAddress net.IP
|
||||
if rule.HijackAddress != "" {
|
||||
hijackAddress = net.ParseIP(rule.HijackAddress)
|
||||
if hijackAddress == nil {
|
||||
return nil, &CompilationError{i, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)}
|
||||
}
|
||||
}
|
||||
compiledRules[i] = compiledRule[O]{outbound, hm, proto, port, hijackAddress}
|
||||
}
|
||||
cache, err := lru.New[string, matchResult[O]](cacheSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &compiledRuleSetImpl[O]{compiledRules, cache}, nil
|
||||
}
|
||||
|
||||
// parseProtoPort parses the protocol and port from a protoPort string.
|
||||
// protoPort must be in one of the following formats:
|
||||
//
|
||||
// proto/port
|
||||
// proto/*
|
||||
// proto
|
||||
// */port
|
||||
// */*
|
||||
// *
|
||||
// [empty] (same as *)
|
||||
//
|
||||
// proto must be either "tcp" or "udp", case-insensitive.
|
||||
func parseProtoPort(protoPort string) (protocol, uint16, bool) {
|
||||
protoPort = strings.ToLower(protoPort)
|
||||
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
|
||||
return protocolBoth, 0, true
|
||||
}
|
||||
parts := strings.SplitN(protoPort, "/", 2)
|
||||
if len(parts) == 1 {
|
||||
// No port, only protocol
|
||||
switch parts[0] {
|
||||
case "tcp":
|
||||
return protocolTCP, 0, true
|
||||
case "udp":
|
||||
return protocolUDP, 0, true
|
||||
default:
|
||||
return protocolBoth, 0, false
|
||||
}
|
||||
} else {
|
||||
// Both protocol and port
|
||||
var proto protocol
|
||||
var port uint16
|
||||
switch parts[0] {
|
||||
case "tcp":
|
||||
proto = protocolTCP
|
||||
case "udp":
|
||||
proto = protocolUDP
|
||||
case "*":
|
||||
proto = protocolBoth
|
||||
default:
|
||||
return protocolBoth, 0, false
|
||||
}
|
||||
if parts[1] != "*" {
|
||||
p64, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return protocolBoth, 0, false
|
||||
}
|
||||
port = uint16(p64)
|
||||
}
|
||||
return proto, port, true
|
||||
}
|
||||
}
|
||||
|
||||
func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatcher, string) {
|
||||
addr = strings.ToLower(addr) // Normalize to lower case
|
||||
if addr == "*" || addr == "all" {
|
||||
// Match all hosts
|
||||
return &allMatcher{}, ""
|
||||
}
|
||||
if strings.HasPrefix(addr, "geoip:") {
|
||||
// GeoIP matcher
|
||||
country := strings.ToUpper(addr[6:])
|
||||
if len(country) != 2 {
|
||||
return nil, fmt.Sprintf("invalid country code: %s", country)
|
||||
}
|
||||
db := geoipFunc()
|
||||
if db == nil {
|
||||
return nil, "failed to load GeoIP database"
|
||||
}
|
||||
return &geoIPMatcher{db, country}, ""
|
||||
}
|
||||
if strings.Contains(addr, "/") {
|
||||
// CIDR matcher
|
||||
_, ipnet, err := net.ParseCIDR(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Sprintf("invalid CIDR address: %s", addr)
|
||||
}
|
||||
return &cidrMatcher{ipnet}, ""
|
||||
}
|
||||
if ip := net.ParseIP(addr); ip != nil {
|
||||
// Single IP matcher
|
||||
return &ipMatcher{ip}, ""
|
||||
}
|
||||
if strings.Contains(addr, "*") {
|
||||
// Wildcard domain matcher
|
||||
return &domainMatcher{
|
||||
Pattern: addr,
|
||||
Wildcard: true,
|
||||
}, ""
|
||||
}
|
||||
// Nothing else matched, treat it as a non-wildcard domain
|
||||
return &domainMatcher{
|
||||
Pattern: addr,
|
||||
Wildcard: false,
|
||||
}, ""
|
||||
}
|
156
extras/outbounds/acl/compile_test.go
Normal file
156
extras/outbounds/acl/compile_test.go
Normal file
|
@ -0,0 +1,156 @@
|
|||
package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCompile(t *testing.T) {
|
||||
ob1, ob2, ob3 := 1, 2, 3
|
||||
rules := []TextRule{
|
||||
{
|
||||
Outbound: "ob1",
|
||||
Address: "1.2.3.4",
|
||||
ProtoPort: "",
|
||||
HijackAddress: "",
|
||||
},
|
||||
{
|
||||
Outbound: "ob2",
|
||||
Address: "8.8.8.0/24",
|
||||
ProtoPort: "*",
|
||||
HijackAddress: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
Outbound: "ob3",
|
||||
Address: "all",
|
||||
ProtoPort: "udp/443",
|
||||
HijackAddress: "",
|
||||
},
|
||||
{
|
||||
Outbound: "ob1",
|
||||
Address: "2606:4700::6810:85e5",
|
||||
ProtoPort: "tcp",
|
||||
HijackAddress: "2606:4700::6810:85e6",
|
||||
},
|
||||
{
|
||||
Outbound: "ob2",
|
||||
Address: "2606:4700::/44",
|
||||
ProtoPort: "*/8888",
|
||||
HijackAddress: "",
|
||||
},
|
||||
{
|
||||
Outbound: "ob3",
|
||||
Address: "*.v2ex.com",
|
||||
ProtoPort: "udp",
|
||||
HijackAddress: "",
|
||||
},
|
||||
{
|
||||
Outbound: "ob1",
|
||||
Address: "crap.v2ex.com",
|
||||
ProtoPort: "tcp/80",
|
||||
HijackAddress: "2.2.2.2",
|
||||
},
|
||||
{
|
||||
Outbound: "ob2",
|
||||
Address: "geoip:JP",
|
||||
ProtoPort: "*/*",
|
||||
HijackAddress: "",
|
||||
},
|
||||
}
|
||||
reader, err := geoip2.Open("GeoLite2-Country.mmdb")
|
||||
assert.NoError(t, err)
|
||||
comp, err := Compile[int](rules, map[string]int{"ob1": ob1, "ob2": ob2, "ob3": ob3}, 100, func() *geoip2.Reader {
|
||||
return reader
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
host HostInfo
|
||||
proto protocol
|
||||
port uint16
|
||||
wantOutbound int
|
||||
wantIP net.IP
|
||||
}{
|
||||
{
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("1.2.3.4"),
|
||||
},
|
||||
proto: protocolTCP,
|
||||
port: 1234,
|
||||
wantOutbound: ob1,
|
||||
wantIP: nil,
|
||||
},
|
||||
{
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("8.8.8.4"),
|
||||
},
|
||||
proto: protocolUDP,
|
||||
port: 5353,
|
||||
wantOutbound: ob2,
|
||||
wantIP: net.ParseIP("1.1.1.1"),
|
||||
},
|
||||
{
|
||||
host: HostInfo{
|
||||
Name: "lean.delicious.com",
|
||||
},
|
||||
proto: protocolUDP,
|
||||
port: 443,
|
||||
wantOutbound: ob3,
|
||||
wantIP: nil,
|
||||
},
|
||||
{
|
||||
host: HostInfo{
|
||||
IPv6: net.ParseIP("2606:4700::6810:85e5"),
|
||||
},
|
||||
proto: protocolTCP,
|
||||
port: 80,
|
||||
wantOutbound: ob1,
|
||||
wantIP: net.ParseIP("2606:4700::6810:85e6"),
|
||||
},
|
||||
{
|
||||
host: HostInfo{
|
||||
IPv6: net.ParseIP("2606:4700:0:0:0:0:0:1"),
|
||||
},
|
||||
proto: protocolUDP,
|
||||
port: 8888,
|
||||
wantOutbound: ob2,
|
||||
wantIP: nil,
|
||||
},
|
||||
{
|
||||
host: HostInfo{
|
||||
Name: "www.v2ex.com",
|
||||
},
|
||||
proto: protocolUDP,
|
||||
port: 1234,
|
||||
wantOutbound: ob3,
|
||||
wantIP: nil,
|
||||
},
|
||||
{
|
||||
host: HostInfo{
|
||||
Name: "crap.v2ex.com",
|
||||
},
|
||||
proto: protocolTCP,
|
||||
port: 80,
|
||||
wantOutbound: ob1,
|
||||
wantIP: net.ParseIP("2.2.2.2"),
|
||||
},
|
||||
{
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("210.140.92.187"),
|
||||
},
|
||||
proto: protocolTCP,
|
||||
port: 25,
|
||||
wantOutbound: ob2,
|
||||
wantIP: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
gotOutbound, gotIP := comp.Match(test.host, test.proto, test.port)
|
||||
assert.Equal(t, test.wantOutbound, gotOutbound)
|
||||
assert.Equal(t, test.wantIP, gotIP)
|
||||
}
|
||||
}
|
83
extras/outbounds/acl/matchers.go
Normal file
83
extras/outbounds/acl/matchers.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
)
|
||||
|
||||
type hostMatcher interface {
|
||||
Match(HostInfo) bool
|
||||
}
|
||||
|
||||
type ipMatcher struct {
|
||||
IP net.IP
|
||||
}
|
||||
|
||||
func (m *ipMatcher) Match(host HostInfo) bool {
|
||||
return m.IP.Equal(host.IPv4) || m.IP.Equal(host.IPv6)
|
||||
}
|
||||
|
||||
type cidrMatcher struct {
|
||||
IPNet *net.IPNet
|
||||
}
|
||||
|
||||
func (m *cidrMatcher) Match(host HostInfo) bool {
|
||||
return m.IPNet.Contains(host.IPv4) || m.IPNet.Contains(host.IPv6)
|
||||
}
|
||||
|
||||
type domainMatcher struct {
|
||||
Pattern string
|
||||
Wildcard bool
|
||||
}
|
||||
|
||||
func (m *domainMatcher) Match(host HostInfo) bool {
|
||||
if m.Wildcard {
|
||||
return deepMatchRune([]rune(host.Name), []rune(m.Pattern))
|
||||
}
|
||||
return m.Pattern == host.Name
|
||||
}
|
||||
|
||||
func deepMatchRune(str, pattern []rune) bool {
|
||||
for len(pattern) > 0 {
|
||||
switch pattern[0] {
|
||||
default:
|
||||
if len(str) == 0 || str[0] != pattern[0] {
|
||||
return false
|
||||
}
|
||||
case '*':
|
||||
return deepMatchRune(str, pattern[1:]) ||
|
||||
(len(str) > 0 && deepMatchRune(str[1:], pattern))
|
||||
}
|
||||
str = str[1:]
|
||||
pattern = pattern[1:]
|
||||
}
|
||||
return len(str) == 0 && len(pattern) == 0
|
||||
}
|
||||
|
||||
type geoIPMatcher struct {
|
||||
DB *geoip2.Reader
|
||||
Country string // must be uppercase ISO 3166-1 alpha-2 code
|
||||
}
|
||||
|
||||
func (m *geoIPMatcher) Match(host HostInfo) bool {
|
||||
if host.IPv4 != nil {
|
||||
record, err := m.DB.Country(host.IPv4)
|
||||
if err == nil && record.Country.IsoCode == m.Country {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if host.IPv6 != nil {
|
||||
record, err := m.DB.Country(host.IPv6)
|
||||
if err == nil && record.Country.IsoCode == m.Country {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type allMatcher struct{}
|
||||
|
||||
func (m *allMatcher) Match(host HostInfo) bool {
|
||||
return true
|
||||
}
|
310
extras/outbounds/acl/matchers_test.go
Normal file
310
extras/outbounds/acl/matchers_test.go
Normal file
|
@ -0,0 +1,310 @@
|
|||
package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_ipMatcher_Match(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
IP net.IP
|
||||
host HostInfo
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "ipv4 match",
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
host: HostInfo{
|
||||
IPv4: net.IPv4(127, 0, 0, 1),
|
||||
IPv6: nil,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6 match",
|
||||
IP: net.IPv6loopback,
|
||||
host: HostInfo{
|
||||
IPv4: nil,
|
||||
IPv6: net.IPv6loopback,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
host: HostInfo{
|
||||
IPv4: net.IPv4(127, 0, 0, 2),
|
||||
IPv6: net.IPv6loopback,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "both nil",
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
host: HostInfo{
|
||||
IPv4: nil,
|
||||
IPv6: nil,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &ipMatcher{
|
||||
IP: tt.IP,
|
||||
}
|
||||
if got := m.Match(tt.host); got != tt.want {
|
||||
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_cidrMatcher_Match(t *testing.T) {
|
||||
_, cidr1, _ := net.ParseCIDR("192.168.1.0/24")
|
||||
_, cidr2, _ := net.ParseCIDR("::1/128")
|
||||
_, cidr3, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
_, cidr4, _ := net.ParseCIDR("::/0")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
IPNet *net.IPNet
|
||||
host HostInfo
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "ipv4 match",
|
||||
IPNet: cidr1,
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("192.168.1.100"),
|
||||
IPv6: net.ParseIP("::1"),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6 match",
|
||||
IPNet: cidr2,
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("10.0.0.1"),
|
||||
IPv6: net.ParseIP("::1"),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
IPNet: cidr1,
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("10.0.0.1"),
|
||||
IPv6: net.ParseIP("2001:db8::2:1"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ipv4 broad",
|
||||
IPNet: cidr3,
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("10.0.0.1"),
|
||||
IPv6: net.ParseIP("::1"),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6 broad",
|
||||
IPNet: cidr4,
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("10.0.0.1"),
|
||||
IPv6: net.ParseIP("2001:db8::2:1"),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "both nil",
|
||||
IPNet: cidr1,
|
||||
host: HostInfo{
|
||||
IPv4: nil,
|
||||
IPv6: nil,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &cidrMatcher{
|
||||
IPNet: tt.IPNet,
|
||||
}
|
||||
if got := m.Match(tt.host); got != tt.want {
|
||||
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_domainMatcher_Match(t *testing.T) {
|
||||
type fields struct {
|
||||
Pattern string
|
||||
Wildcard bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
host HostInfo
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "non-wildcard match",
|
||||
fields: fields{
|
||||
Pattern: "example.com",
|
||||
Wildcard: false,
|
||||
},
|
||||
host: HostInfo{
|
||||
Name: "example.com",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-wildcard no match",
|
||||
fields: fields{
|
||||
Pattern: "example.com",
|
||||
Wildcard: false,
|
||||
},
|
||||
host: HostInfo{
|
||||
Name: "example.org",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard match 1",
|
||||
fields: fields{
|
||||
Pattern: "*.example.com",
|
||||
Wildcard: true,
|
||||
},
|
||||
host: HostInfo{
|
||||
Name: "www.example.com",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard match 2",
|
||||
fields: fields{
|
||||
Pattern: "example*.com",
|
||||
Wildcard: true,
|
||||
},
|
||||
host: HostInfo{
|
||||
Name: "example2.com",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match",
|
||||
fields: fields{
|
||||
Pattern: "*.example.com",
|
||||
Wildcard: true,
|
||||
},
|
||||
host: HostInfo{
|
||||
Name: "example.com",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
fields: fields{
|
||||
Pattern: "*.example.com",
|
||||
Wildcard: true,
|
||||
},
|
||||
host: HostInfo{
|
||||
Name: "",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &domainMatcher{
|
||||
Pattern: tt.fields.Pattern,
|
||||
Wildcard: tt.fields.Wildcard,
|
||||
}
|
||||
if got := m.Match(tt.host); got != tt.want {
|
||||
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_geoIPMatcher_Match(t *testing.T) {
|
||||
db, err := geoip2.Open("GeoLite2-Country.mmdb")
|
||||
assert.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
type fields struct {
|
||||
DB *geoip2.Reader
|
||||
Country string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
host HostInfo
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "ipv4 match",
|
||||
fields: fields{
|
||||
DB: db,
|
||||
Country: "JP",
|
||||
},
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("210.140.92.181"),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6 match",
|
||||
fields: fields{
|
||||
DB: db,
|
||||
Country: "US",
|
||||
},
|
||||
host: HostInfo{
|
||||
IPv6: net.ParseIP("2606:4700::6810:85e5"),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
fields: fields{
|
||||
DB: db,
|
||||
Country: "AU",
|
||||
},
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("210.140.92.181"),
|
||||
IPv6: net.ParseIP("2606:4700::6810:85e5"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "both nil",
|
||||
fields: fields{
|
||||
DB: db,
|
||||
Country: "KR",
|
||||
},
|
||||
host: HostInfo{
|
||||
IPv4: nil,
|
||||
IPv6: nil,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &geoIPMatcher{
|
||||
DB: tt.fields.DB,
|
||||
Country: tt.fields.Country,
|
||||
}
|
||||
if got := m.Match(tt.host); got != tt.want {
|
||||
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
79
extras/outbounds/acl/parse.go
Normal file
79
extras/outbounds/acl/parse.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package acl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var linePattern = regexp.MustCompile(`^(\w+)\s*\(([^,]+)(?:,([^,]+))?(?:,([^,]+))?\)$`)
|
||||
|
||||
type InvalidSyntaxError struct {
|
||||
Line string
|
||||
LineNum int
|
||||
}
|
||||
|
||||
func (e *InvalidSyntaxError) Error() string {
|
||||
return fmt.Sprintf("invalid syntax at line %d: %s", e.LineNum, e.Line)
|
||||
}
|
||||
|
||||
// TextRule is the struct representation of a (non-comment) line parsed from an ACL file.
|
||||
// A line can be parsed into a TextRule as long as it matches one of the following patterns:
|
||||
//
|
||||
// outbound(address)
|
||||
// outbound(address,protoPort)
|
||||
// outbound(address,protoPort,hijackAddress)
|
||||
//
|
||||
// It does not check whether any of the fields is valid - it's up to the compiler to do so.
|
||||
type TextRule struct {
|
||||
Outbound string
|
||||
Address string
|
||||
ProtoPort string
|
||||
HijackAddress string
|
||||
}
|
||||
|
||||
func parseLine(line string) *TextRule {
|
||||
matches := linePattern.FindStringSubmatch(line)
|
||||
if matches == nil {
|
||||
return nil
|
||||
}
|
||||
return &TextRule{
|
||||
Outbound: matches[1],
|
||||
Address: strings.TrimSpace(matches[2]),
|
||||
ProtoPort: strings.TrimSpace(matches[3]),
|
||||
HijackAddress: strings.TrimSpace(matches[4]),
|
||||
}
|
||||
}
|
||||
|
||||
func ParseTextRules(text string) ([]TextRule, error) {
|
||||
rules := make([]TextRule, 0)
|
||||
lineNum := 0
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
lineNum++
|
||||
// Remove comments
|
||||
if i := strings.Index(line, "#"); i >= 0 {
|
||||
line = line[:i]
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
// Skip empty lines
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
// Parse line
|
||||
rule := parseLine(line)
|
||||
if rule == nil {
|
||||
return nil, &InvalidSyntaxError{line, lineNum}
|
||||
}
|
||||
rules = append(rules, *rule)
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func ParseTextRulesFile(filename string) ([]TextRule, error) {
|
||||
bs, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ParseTextRules(string(bs))
|
||||
}
|
70
extras/outbounds/acl/parse_test.go
Normal file
70
extras/outbounds/acl/parse_test.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package acl
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseTextRules(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
want []TextRule
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
text: "",
|
||||
want: []TextRule{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
text: `
|
||||
# just a comment
|
||||
# another comment
|
||||
direct(1.1.1.1)
|
||||
direct(8.8.8.0/24)
|
||||
reject(all, udp/443) # inline comment
|
||||
reject(geoip:cn)
|
||||
reject(*.v2ex.com)
|
||||
my_custom_outbound1(9.9.9.9,*, 8.8.8.8) # bebop
|
||||
my_custom_outbound2(all)
|
||||
`,
|
||||
want: []TextRule{
|
||||
{Outbound: "direct", Address: "1.1.1.1"},
|
||||
{Outbound: "direct", Address: "8.8.8.0/24"},
|
||||
{Outbound: "reject", Address: "all", ProtoPort: "udp/443"},
|
||||
{Outbound: "reject", Address: "geoip:cn"},
|
||||
{Outbound: "reject", Address: "*.v2ex.com"},
|
||||
{Outbound: "my_custom_outbound1", Address: "9.9.9.9", ProtoPort: "*", HijackAddress: "8.8.8.8"},
|
||||
{Outbound: "my_custom_outbound2", Address: "all"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "fail 1",
|
||||
text: `boom()`,
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail 2",
|
||||
text: `lol(1,1,1,1)`,
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseTextRules(tt.text)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseTextRules() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ParseTextRules() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue