mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 04:27:39 +03:00
feat: ACL
This commit is contained in:
parent
6fa958815b
commit
a7d74a9ec1
12 changed files with 380 additions and 61 deletions
|
@ -44,6 +44,7 @@ type serverConfig struct {
|
|||
UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"`
|
||||
Auth serverConfigAuth `mapstructure:"auth"`
|
||||
Resolver serverConfigResolver `mapstructure:"resolver"`
|
||||
ACL serverConfigACL `mapstructure:"acl"`
|
||||
Outbounds []serverConfigOutboundEntry `mapstructure:"outbounds"`
|
||||
Masquerade serverConfigMasquerade `mapstructure:"masquerade"`
|
||||
}
|
||||
|
@ -133,6 +134,12 @@ type serverConfigResolver struct {
|
|||
HTTPS serverConfigResolverHTTPS `mapstructure:"https"`
|
||||
}
|
||||
|
||||
type serverConfigACL struct {
|
||||
File string `mapstructure:"file"`
|
||||
Inline []string `mapstructure:"inline"`
|
||||
GeoIP string `mapstructure:"geoip"`
|
||||
}
|
||||
|
||||
type serverConfigOutboundDirect struct {
|
||||
Mode string `mapstructure:"mode"`
|
||||
BindIPv4 string `mapstructure:"bindIPv4"`
|
||||
|
@ -314,22 +321,60 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error {
|
|||
// Resolver(ACL(Outbounds...))
|
||||
|
||||
// Outbounds
|
||||
var ob outbounds.PluggableOutbound
|
||||
var obs []outbounds.OutboundEntry
|
||||
if len(c.Outbounds) == 0 {
|
||||
ob = outbounds.NewDirectOutboundSimple(outbounds.DirectOutboundModeAuto)
|
||||
// Guarantee we have at least one outbound
|
||||
obs = []outbounds.OutboundEntry{{
|
||||
Name: "default",
|
||||
Outbound: outbounds.NewDirectOutboundSimple(outbounds.DirectOutboundModeAuto),
|
||||
}}
|
||||
} else {
|
||||
// Multiple-outbound is for ACL only, not supported yet.
|
||||
var err error
|
||||
entry := c.Outbounds[0]
|
||||
switch strings.ToLower(entry.Type) {
|
||||
case "direct":
|
||||
ob, err = serverConfigOutboundDirectToOutbound(entry.Direct)
|
||||
default:
|
||||
err = configError{Field: "outbounds.type", Err: errors.New("unsupported outbound type")}
|
||||
obs = make([]outbounds.OutboundEntry, len(c.Outbounds))
|
||||
for i, entry := range c.Outbounds {
|
||||
if entry.Name == "" {
|
||||
return configError{Field: "outbounds.name", Err: errors.New("empty outbound name")}
|
||||
}
|
||||
var ob outbounds.PluggableOutbound
|
||||
var err error
|
||||
switch strings.ToLower(entry.Type) {
|
||||
case "direct":
|
||||
ob, err = serverConfigOutboundDirectToOutbound(entry.Direct)
|
||||
default:
|
||||
err = configError{Field: "outbounds.type", Err: errors.New("unsupported outbound type")}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
obs[i] = outbounds.OutboundEntry{Name: entry.Name, Outbound: ob}
|
||||
}
|
||||
}
|
||||
|
||||
var uOb outbounds.PluggableOutbound // "unified" outbound
|
||||
|
||||
// ACL
|
||||
if c.ACL.File != "" && len(c.ACL.Inline) > 0 {
|
||||
return configError{Field: "acl", Err: errors.New("cannot set both acl.file and acl.inline")}
|
||||
}
|
||||
gLoader := &geoipLoader{
|
||||
Filename: c.ACL.GeoIP,
|
||||
DownloadFunc: geoipDownloadFunc,
|
||||
DownloadErrFunc: geoipDownloadErrFunc,
|
||||
}
|
||||
if c.ACL.File != "" {
|
||||
acl, err := outbounds.NewACLEngineFromFile(c.ACL.File, obs, gLoader.Load)
|
||||
if err != nil {
|
||||
return err
|
||||
return configError{Field: "acl.file", Err: err}
|
||||
}
|
||||
uOb = acl
|
||||
} else if len(c.ACL.Inline) > 0 {
|
||||
acl, err := outbounds.NewACLEngineFromString(strings.Join(c.ACL.Inline, "\n"), obs, gLoader.Load)
|
||||
if err != nil {
|
||||
return configError{Field: "acl.inline", Err: err}
|
||||
}
|
||||
uOb = acl
|
||||
} else {
|
||||
// No ACL, use the first outbound
|
||||
uOb = obs[0].Outbound
|
||||
}
|
||||
|
||||
// Resolver
|
||||
|
@ -340,27 +385,27 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error {
|
|||
if c.Resolver.TCP.Addr == "" {
|
||||
return configError{Field: "resolver.tcp.addr", Err: errors.New("empty resolver address")}
|
||||
}
|
||||
ob = outbounds.NewStandardResolverTCP(c.Resolver.TCP.Addr, c.Resolver.TCP.Timeout, ob)
|
||||
uOb = outbounds.NewStandardResolverTCP(c.Resolver.TCP.Addr, c.Resolver.TCP.Timeout, uOb)
|
||||
case "udp":
|
||||
if c.Resolver.UDP.Addr == "" {
|
||||
return configError{Field: "resolver.udp.addr", Err: errors.New("empty resolver address")}
|
||||
}
|
||||
ob = outbounds.NewStandardResolverUDP(c.Resolver.UDP.Addr, c.Resolver.UDP.Timeout, ob)
|
||||
uOb = outbounds.NewStandardResolverUDP(c.Resolver.UDP.Addr, c.Resolver.UDP.Timeout, uOb)
|
||||
case "tls", "tcp-tls":
|
||||
if c.Resolver.TLS.Addr == "" {
|
||||
return configError{Field: "resolver.tls.addr", Err: errors.New("empty resolver address")}
|
||||
}
|
||||
ob = outbounds.NewStandardResolverTLS(c.Resolver.TLS.Addr, c.Resolver.TLS.Timeout, c.Resolver.TLS.SNI, c.Resolver.TLS.Insecure, ob)
|
||||
uOb = outbounds.NewStandardResolverTLS(c.Resolver.TLS.Addr, c.Resolver.TLS.Timeout, c.Resolver.TLS.SNI, c.Resolver.TLS.Insecure, uOb)
|
||||
case "https", "http":
|
||||
if c.Resolver.HTTPS.Addr == "" {
|
||||
return configError{Field: "resolver.https.addr", Err: errors.New("empty resolver address")}
|
||||
}
|
||||
ob = outbounds.NewDoHResolver(c.Resolver.HTTPS.Addr, c.Resolver.HTTPS.Timeout, c.Resolver.HTTPS.SNI, c.Resolver.HTTPS.Insecure, ob)
|
||||
uOb = outbounds.NewDoHResolver(c.Resolver.HTTPS.Addr, c.Resolver.HTTPS.Timeout, c.Resolver.HTTPS.SNI, c.Resolver.HTTPS.Insecure, uOb)
|
||||
default:
|
||||
return configError{Field: "resolver.type", Err: errors.New("unsupported resolver type")}
|
||||
}
|
||||
|
||||
hyConfig.Outbound = &outbounds.PluggableOutboundAdapter{PluggableOutbound: ob}
|
||||
hyConfig.Outbound = &outbounds.PluggableOutboundAdapter{PluggableOutbound: uOb}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -525,6 +570,16 @@ func runServer(cmd *cobra.Command, args []string) {
|
|||
}
|
||||
}
|
||||
|
||||
func geoipDownloadFunc(filename, url string) {
|
||||
logger.Info("downloading GeoIP database", zap.String("filename", filename), zap.String("url", url))
|
||||
}
|
||||
|
||||
func geoipDownloadErrFunc(err error) {
|
||||
if err != nil {
|
||||
logger.Error("failed to download GeoIP database", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
type serverLogger struct{}
|
||||
|
||||
func (l *serverLogger) Connect(addr net.Addr, id string, tx uint64) {
|
||||
|
|
|
@ -95,6 +95,13 @@ func TestServerConfig(t *testing.T) {
|
|||
Insecure: true,
|
||||
},
|
||||
},
|
||||
ACL: serverConfigACL{
|
||||
File: "chnroute.txt",
|
||||
Inline: []string{
|
||||
"lmao(ok)",
|
||||
"kek(cringe,boba,tea)",
|
||||
},
|
||||
},
|
||||
Outbounds: []serverConfigOutboundEntry{
|
||||
{
|
||||
Name: "goodstuff",
|
||||
|
|
|
@ -70,6 +70,12 @@ resolver:
|
|||
sni: real.stuff.net
|
||||
insecure: true
|
||||
|
||||
acl:
|
||||
file: chnroute.txt
|
||||
inline:
|
||||
- lmao(ok)
|
||||
- kek(cringe,boba,tea)
|
||||
|
||||
outbounds:
|
||||
- name: goodstuff
|
||||
type: direct
|
||||
|
|
|
@ -2,10 +2,18 @@ package cmd
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/apernet/hysteria/extras/utils"
|
||||
"github.com/mdp/qrterminal/v3"
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
)
|
||||
|
||||
const (
|
||||
geoipDefaultFilename = "GeoLite2-Country.mmdb"
|
||||
geoipDownloadURL = "https://git.io/GeoLite2-Country.mmdb"
|
||||
)
|
||||
|
||||
// convBandwidth handles both string and int types for bandwidth.
|
||||
|
@ -44,3 +52,59 @@ func (e configError) Error() string {
|
|||
func (e configError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// geoipLoader provides the on-demand GeoIP database loading function required by the ACL engine.
|
||||
type geoipLoader struct {
|
||||
Filename string
|
||||
DownloadFunc func(filename, url string) // Called when downloading the GeoIP database.
|
||||
DownloadErrFunc func(err error) // Called when downloading the GeoIP database succeeds/fails.
|
||||
|
||||
db *geoip2.Reader
|
||||
}
|
||||
|
||||
func (l *geoipLoader) download() error {
|
||||
resp, err := http.Get(geoipDownloadURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
f, err := os.Create(geoipDefaultFilename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(f, resp.Body)
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *geoipLoader) Load() *geoip2.Reader {
|
||||
if l.db == nil {
|
||||
if l.Filename == "" {
|
||||
// Filename not specified, try default.
|
||||
if _, err := os.Stat(geoipDefaultFilename); err == nil {
|
||||
// Default already exists, just use it.
|
||||
l.Filename = geoipDefaultFilename
|
||||
} else if os.IsNotExist(err) {
|
||||
// Default doesn't exist, download it.
|
||||
l.DownloadFunc(geoipDefaultFilename, geoipDownloadURL)
|
||||
err := l.download()
|
||||
l.DownloadErrFunc(err)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
l.Filename = geoipDefaultFilename
|
||||
} else {
|
||||
// Other error
|
||||
return nil
|
||||
}
|
||||
}
|
||||
db, err := geoip2.Open(l.Filename)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
l.db = db
|
||||
}
|
||||
return l.db
|
||||
}
|
||||
|
|
|
@ -48,11 +48,18 @@ auth:
|
|||
# sni: server1.yolo.net
|
||||
# insecure: true
|
||||
|
||||
# acl:
|
||||
# inline:
|
||||
# - haha(8.8.8.8/24, udp/53)
|
||||
# - reject(v2ex.com)
|
||||
# - reject(*.v2ex.com)
|
||||
# - reject(geoip:cn)
|
||||
|
||||
# outbounds:
|
||||
# - name: haha
|
||||
# type: direct
|
||||
# direct:
|
||||
# mode: auto
|
||||
# mode: 46
|
||||
# bindIPv4: 2.4.6.8
|
||||
# bindIPv6: 0:0:0:0:0:ffff:0204:0608
|
||||
# bindDevice: eth233
|
||||
|
|
122
extras/outbounds/acl.go
Normal file
122
extras/outbounds/acl.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package outbounds
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/apernet/hysteria/extras/outbounds/acl"
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
)
|
||||
|
||||
const (
|
||||
aclCacheSize = 1024
|
||||
)
|
||||
|
||||
var errRejected = errors.New("rejected")
|
||||
|
||||
// aclEngine is a PluggableOutbound that dispatches connections to different
|
||||
// outbounds based on ACL rules.
|
||||
// There are 3 built-in outbounds:
|
||||
// - direct: directOutbound, auto mode
|
||||
// - reject: reject the connection
|
||||
// - default: first outbound in the list, or if the list is empty, equal to direct
|
||||
// If the user-defined outbounds contain any of the above names, they will
|
||||
// override the built-in outbounds.
|
||||
type aclEngine struct {
|
||||
RuleSet acl.CompiledRuleSet[PluggableOutbound]
|
||||
Default PluggableOutbound
|
||||
}
|
||||
|
||||
type OutboundEntry struct {
|
||||
Name string
|
||||
Outbound PluggableOutbound
|
||||
}
|
||||
|
||||
func NewACLEngineFromString(rules string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) {
|
||||
trs, err := acl.ParseTextRules(rules)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
obMap := outboundsToMap(outbounds)
|
||||
rs, err := acl.Compile[PluggableOutbound](trs, obMap, aclCacheSize, geoipFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &aclEngine{rs, obMap["default"]}, nil
|
||||
}
|
||||
|
||||
func NewACLEngineFromFile(filename string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) {
|
||||
bs, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewACLEngineFromString(string(bs), outbounds, geoipFunc)
|
||||
}
|
||||
|
||||
func outboundsToMap(outbounds []OutboundEntry) map[string]PluggableOutbound {
|
||||
obMap := make(map[string]PluggableOutbound)
|
||||
for _, ob := range outbounds {
|
||||
obMap[strings.ToLower(ob.Name)] = ob.Outbound
|
||||
}
|
||||
// Add built-in outbounds if not overridden
|
||||
if _, ok := obMap["direct"]; !ok {
|
||||
obMap["direct"] = NewDirectOutboundSimple(DirectOutboundModeAuto)
|
||||
}
|
||||
if _, ok := obMap["reject"]; !ok {
|
||||
obMap["reject"] = &aclRejectOutbound{}
|
||||
}
|
||||
if _, ok := obMap["default"]; !ok {
|
||||
if len(outbounds) > 0 {
|
||||
obMap["default"] = outbounds[0].Outbound
|
||||
} else {
|
||||
obMap["default"] = obMap["direct"]
|
||||
}
|
||||
}
|
||||
return obMap
|
||||
}
|
||||
|
||||
func (a *aclEngine) handle(reqAddr *AddrEx, proto acl.Protocol) PluggableOutbound {
|
||||
hostInfo := acl.HostInfo{Name: reqAddr.Host}
|
||||
if reqAddr.ResolveInfo != nil {
|
||||
hostInfo.IPv4 = reqAddr.ResolveInfo.IPv4
|
||||
hostInfo.IPv6 = reqAddr.ResolveInfo.IPv6
|
||||
}
|
||||
ob, hijackIP := a.RuleSet.Match(hostInfo, proto, reqAddr.Port)
|
||||
if ob == nil {
|
||||
// No match, use default outbound
|
||||
return a.Default
|
||||
}
|
||||
if hijackIP != nil {
|
||||
// We must rewrite both Host & ResolveInfo,
|
||||
// as some outbounds only care about Host.
|
||||
reqAddr.Host = hijackIP.String()
|
||||
if ip4 := hijackIP.To4(); ip4 != nil {
|
||||
reqAddr.ResolveInfo = &ResolveInfo{IPv4: ip4}
|
||||
} else {
|
||||
reqAddr.ResolveInfo = &ResolveInfo{IPv6: hijackIP}
|
||||
}
|
||||
}
|
||||
return ob
|
||||
}
|
||||
|
||||
func (a *aclEngine) TCP(reqAddr *AddrEx) (net.Conn, error) {
|
||||
ob := a.handle(reqAddr, acl.ProtocolTCP)
|
||||
return ob.TCP(reqAddr)
|
||||
}
|
||||
|
||||
func (a *aclEngine) UDP(reqAddr *AddrEx) (UDPConn, error) {
|
||||
ob := a.handle(reqAddr, acl.ProtocolUDP)
|
||||
return ob.UDP(reqAddr)
|
||||
}
|
||||
|
||||
type aclRejectOutbound struct{}
|
||||
|
||||
func (a *aclRejectOutbound) TCP(reqAddr *AddrEx) (net.Conn, error) {
|
||||
return nil, errRejected
|
||||
}
|
||||
|
||||
func (a *aclRejectOutbound) UDP(reqAddr *AddrEx) (UDPConn, error) {
|
||||
return nil, errRejected
|
||||
}
|
|
@ -10,12 +10,12 @@ import (
|
|||
"github.com/oschwald/geoip2-golang"
|
||||
)
|
||||
|
||||
type protocol int
|
||||
type Protocol int
|
||||
|
||||
const (
|
||||
protocolBoth protocol = iota
|
||||
protocolTCP
|
||||
protocolUDP
|
||||
ProtocolBoth Protocol = iota
|
||||
ProtocolTCP
|
||||
ProtocolUDP
|
||||
)
|
||||
|
||||
type Outbound interface {
|
||||
|
@ -33,19 +33,19 @@ func (h HostInfo) String() string {
|
|||
}
|
||||
|
||||
type CompiledRuleSet[O Outbound] interface {
|
||||
Match(host HostInfo, proto protocol, port uint16) (O, net.IP)
|
||||
Match(host HostInfo, proto Protocol, port uint16) (O, net.IP)
|
||||
}
|
||||
|
||||
type compiledRule[O Outbound] struct {
|
||||
Outbound O
|
||||
HostMatcher hostMatcher
|
||||
Protocol protocol
|
||||
Protocol Protocol
|
||||
Port uint16
|
||||
HijackAddress net.IP
|
||||
}
|
||||
|
||||
func (r *compiledRule[O]) Match(host HostInfo, proto protocol, port uint16) bool {
|
||||
if r.Protocol != protocolBoth && r.Protocol != proto {
|
||||
func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool {
|
||||
if r.Protocol != ProtocolBoth && r.Protocol != proto {
|
||||
return false
|
||||
}
|
||||
if r.Port != 0 && r.Port != port {
|
||||
|
@ -64,7 +64,7 @@ type compiledRuleSetImpl[O Outbound] struct {
|
|||
Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String()
|
||||
}
|
||||
|
||||
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto protocol, port uint16) (O, net.IP) {
|
||||
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto Protocol, port uint16) (O, net.IP) {
|
||||
host.Name = strings.ToLower(host.Name) // Normalize host name to lower case
|
||||
key := host.String()
|
||||
if result, ok := s.Cache.Get(key); ok {
|
||||
|
@ -92,12 +92,18 @@ func (e *CompilationError) Error() string {
|
|||
return fmt.Sprintf("error at line %d: %s", e.LineNum, e.Message)
|
||||
}
|
||||
|
||||
// Compile compiles TextRules into a CompiledRuleSet.
|
||||
// Names in the outbounds map MUST be in all lower case.
|
||||
// geoipFunc is a function that returns the GeoIP database needed by the GeoIP matcher.
|
||||
// It will be called every time a GeoIP matcher is used during compilation, but won't
|
||||
// be called if there is no GeoIP rule. We use a function here so that database loading
|
||||
// is on-demand (only required if used by rules).
|
||||
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
||||
cacheSize int, geoipFunc func() *geoip2.Reader,
|
||||
) (CompiledRuleSet[O], error) {
|
||||
compiledRules := make([]compiledRule[O], len(rules))
|
||||
for i, rule := range rules {
|
||||
outbound, ok := outbounds[rule.Outbound]
|
||||
outbound, ok := outbounds[strings.ToLower(rule.Outbound)]
|
||||
if !ok {
|
||||
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("outbound %s not found", rule.Outbound)}
|
||||
}
|
||||
|
@ -137,40 +143,40 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
|||
// [empty] (same as *)
|
||||
//
|
||||
// proto must be either "tcp" or "udp", case-insensitive.
|
||||
func parseProtoPort(protoPort string) (protocol, uint16, bool) {
|
||||
func parseProtoPort(protoPort string) (Protocol, uint16, bool) {
|
||||
protoPort = strings.ToLower(protoPort)
|
||||
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
|
||||
return protocolBoth, 0, true
|
||||
return ProtocolBoth, 0, true
|
||||
}
|
||||
parts := strings.SplitN(protoPort, "/", 2)
|
||||
if len(parts) == 1 {
|
||||
// No port, only protocol
|
||||
switch parts[0] {
|
||||
case "tcp":
|
||||
return protocolTCP, 0, true
|
||||
return ProtocolTCP, 0, true
|
||||
case "udp":
|
||||
return protocolUDP, 0, true
|
||||
return ProtocolUDP, 0, true
|
||||
default:
|
||||
return protocolBoth, 0, false
|
||||
return ProtocolBoth, 0, false
|
||||
}
|
||||
} else {
|
||||
// Both protocol and port
|
||||
var proto protocol
|
||||
var proto Protocol
|
||||
var port uint16
|
||||
switch parts[0] {
|
||||
case "tcp":
|
||||
proto = protocolTCP
|
||||
proto = ProtocolTCP
|
||||
case "udp":
|
||||
proto = protocolUDP
|
||||
proto = ProtocolUDP
|
||||
case "*":
|
||||
proto = protocolBoth
|
||||
proto = ProtocolBoth
|
||||
default:
|
||||
return protocolBoth, 0, false
|
||||
return ProtocolBoth, 0, false
|
||||
}
|
||||
if parts[1] != "*" {
|
||||
p64, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return protocolBoth, 0, false
|
||||
return ProtocolBoth, 0, false
|
||||
}
|
||||
port = uint16(p64)
|
||||
}
|
||||
|
@ -194,7 +200,7 @@ func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatch
|
|||
if db == nil {
|
||||
return nil, "failed to load GeoIP database"
|
||||
}
|
||||
return &geoIPMatcher{db, country}, ""
|
||||
return &geoipMatcher{db, country}, ""
|
||||
}
|
||||
if strings.Contains(addr, "/") {
|
||||
// CIDR matcher
|
||||
|
|
|
@ -69,7 +69,7 @@ func TestCompile(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
host HostInfo
|
||||
proto protocol
|
||||
proto Protocol
|
||||
port uint16
|
||||
wantOutbound int
|
||||
wantIP net.IP
|
||||
|
@ -78,7 +78,7 @@ func TestCompile(t *testing.T) {
|
|||
host: HostInfo{
|
||||
IPv4: net.ParseIP("1.2.3.4"),
|
||||
},
|
||||
proto: protocolTCP,
|
||||
proto: ProtocolTCP,
|
||||
port: 1234,
|
||||
wantOutbound: ob1,
|
||||
wantIP: nil,
|
||||
|
@ -87,7 +87,7 @@ func TestCompile(t *testing.T) {
|
|||
host: HostInfo{
|
||||
IPv4: net.ParseIP("8.8.8.4"),
|
||||
},
|
||||
proto: protocolUDP,
|
||||
proto: ProtocolUDP,
|
||||
port: 5353,
|
||||
wantOutbound: ob2,
|
||||
wantIP: net.ParseIP("1.1.1.1"),
|
||||
|
@ -96,7 +96,7 @@ func TestCompile(t *testing.T) {
|
|||
host: HostInfo{
|
||||
Name: "lean.delicious.com",
|
||||
},
|
||||
proto: protocolUDP,
|
||||
proto: ProtocolUDP,
|
||||
port: 443,
|
||||
wantOutbound: ob3,
|
||||
wantIP: nil,
|
||||
|
@ -105,7 +105,7 @@ func TestCompile(t *testing.T) {
|
|||
host: HostInfo{
|
||||
IPv6: net.ParseIP("2606:4700::6810:85e5"),
|
||||
},
|
||||
proto: protocolTCP,
|
||||
proto: ProtocolTCP,
|
||||
port: 80,
|
||||
wantOutbound: ob1,
|
||||
wantIP: net.ParseIP("2606:4700::6810:85e6"),
|
||||
|
@ -114,7 +114,7 @@ func TestCompile(t *testing.T) {
|
|||
host: HostInfo{
|
||||
IPv6: net.ParseIP("2606:4700:0:0:0:0:0:1"),
|
||||
},
|
||||
proto: protocolUDP,
|
||||
proto: ProtocolUDP,
|
||||
port: 8888,
|
||||
wantOutbound: ob2,
|
||||
wantIP: nil,
|
||||
|
@ -123,7 +123,7 @@ func TestCompile(t *testing.T) {
|
|||
host: HostInfo{
|
||||
Name: "www.v2ex.com",
|
||||
},
|
||||
proto: protocolUDP,
|
||||
proto: ProtocolUDP,
|
||||
port: 1234,
|
||||
wantOutbound: ob3,
|
||||
wantIP: nil,
|
||||
|
@ -132,7 +132,7 @@ func TestCompile(t *testing.T) {
|
|||
host: HostInfo{
|
||||
Name: "crap.v2ex.com",
|
||||
},
|
||||
proto: protocolTCP,
|
||||
proto: ProtocolTCP,
|
||||
port: 80,
|
||||
wantOutbound: ob1,
|
||||
wantIP: net.ParseIP("2.2.2.2"),
|
||||
|
@ -141,7 +141,7 @@ func TestCompile(t *testing.T) {
|
|||
host: HostInfo{
|
||||
IPv4: net.ParseIP("210.140.92.187"),
|
||||
},
|
||||
proto: protocolTCP,
|
||||
proto: ProtocolTCP,
|
||||
port: 25,
|
||||
wantOutbound: ob2,
|
||||
wantIP: nil,
|
||||
|
|
|
@ -55,12 +55,12 @@ func deepMatchRune(str, pattern []rune) bool {
|
|||
return len(str) == 0 && len(pattern) == 0
|
||||
}
|
||||
|
||||
type geoIPMatcher struct {
|
||||
type geoipMatcher struct {
|
||||
DB *geoip2.Reader
|
||||
Country string // must be uppercase ISO 3166-1 alpha-2 code
|
||||
}
|
||||
|
||||
func (m *geoIPMatcher) Match(host HostInfo) bool {
|
||||
func (m *geoipMatcher) Match(host HostInfo) bool {
|
||||
if host.IPv4 != nil {
|
||||
record, err := m.DB.Country(host.IPv4)
|
||||
if err == nil && record.Country.IsoCode == m.Country {
|
||||
|
|
|
@ -234,7 +234,7 @@ func Test_domainMatcher_Match(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_geoIPMatcher_Match(t *testing.T) {
|
||||
func Test_geoipMatcher_Match(t *testing.T) {
|
||||
db, err := geoip2.Open("GeoLite2-Country.mmdb")
|
||||
assert.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
@ -298,7 +298,7 @@ func Test_geoIPMatcher_Match(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &geoIPMatcher{
|
||||
m := &geoipMatcher{
|
||||
DB: tt.fields.DB,
|
||||
Country: tt.fields.Country,
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package acl
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
@ -71,11 +70,3 @@ func ParseTextRules(text string) ([]TextRule, error) {
|
|||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func ParseTextRulesFile(filename string) ([]TextRule, error) {
|
||||
bs, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ParseTextRules(string(bs))
|
||||
}
|
||||
|
|
61
extras/outbounds/acl_test.go
Normal file
61
extras/outbounds/acl_test.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package outbounds
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestACLEngine(t *testing.T) {
|
||||
ob1, ob2, ob3 := &mockPluggableOutbound{}, &mockPluggableOutbound{}, &mockPluggableOutbound{}
|
||||
obs := []OutboundEntry{
|
||||
{"ob1", ob1},
|
||||
{"ob2", ob2},
|
||||
{"ob3", ob3},
|
||||
{"direct", ob2},
|
||||
}
|
||||
acl, err := NewACLEngineFromString(`
|
||||
ob2(google.com,tcp)
|
||||
ob3(youtube.com,udp)
|
||||
ob1 (1.1.1.1/24,*,8.8.8.8)
|
||||
Direct(cia.gov)
|
||||
reJect(nsa.gov)
|
||||
`, obs, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// No match, default, should be the first (ob1)
|
||||
ob1.EXPECT().TCP(&AddrEx{Host: "example.com"}).Return(nil, nil).Once()
|
||||
conn, err := acl.TCP(&AddrEx{Host: "example.com"})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, conn)
|
||||
|
||||
// Match ob2
|
||||
ob2.EXPECT().TCP(&AddrEx{Host: "google.com"}).Return(nil, nil).Once()
|
||||
conn, err = acl.TCP(&AddrEx{Host: "google.com"})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, conn)
|
||||
|
||||
// Match ob3
|
||||
ob3.EXPECT().UDP(&AddrEx{Host: "youtube.com"}).Return(nil, nil).Once()
|
||||
udpConn, err := acl.UDP(&AddrEx{Host: "youtube.com"})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, udpConn)
|
||||
|
||||
// Match ob1 hijack IP
|
||||
ob1.EXPECT().TCP(&AddrEx{Host: "8.8.8.8", ResolveInfo: &ResolveInfo{IPv4: net.ParseIP("8.8.8.8").To4()}}).Return(nil, nil).Once()
|
||||
conn, err = acl.TCP(&AddrEx{ResolveInfo: &ResolveInfo{IPv4: net.ParseIP("1.1.1.22")}})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, conn)
|
||||
|
||||
// direct should be ob2 as we override it
|
||||
ob2.EXPECT().TCP(&AddrEx{Host: "cia.gov"}).Return(nil, nil).Once()
|
||||
conn, err = acl.TCP(&AddrEx{Host: "cia.gov"})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, conn)
|
||||
|
||||
// reject
|
||||
conn, err = acl.TCP(&AddrEx{Host: "nsa.gov"})
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, conn)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue