diff --git a/app/cmd/server.go b/app/cmd/server.go index 910de8b..63d6d6d 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -44,6 +44,7 @@ type serverConfig struct { UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"` Auth serverConfigAuth `mapstructure:"auth"` Resolver serverConfigResolver `mapstructure:"resolver"` + ACL serverConfigACL `mapstructure:"acl"` Outbounds []serverConfigOutboundEntry `mapstructure:"outbounds"` Masquerade serverConfigMasquerade `mapstructure:"masquerade"` } @@ -133,6 +134,12 @@ type serverConfigResolver struct { HTTPS serverConfigResolverHTTPS `mapstructure:"https"` } +type serverConfigACL struct { + File string `mapstructure:"file"` + Inline []string `mapstructure:"inline"` + GeoIP string `mapstructure:"geoip"` +} + type serverConfigOutboundDirect struct { Mode string `mapstructure:"mode"` BindIPv4 string `mapstructure:"bindIPv4"` @@ -314,22 +321,60 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error { // Resolver(ACL(Outbounds...)) // Outbounds - var ob outbounds.PluggableOutbound + var obs []outbounds.OutboundEntry if len(c.Outbounds) == 0 { - ob = outbounds.NewDirectOutboundSimple(outbounds.DirectOutboundModeAuto) + // Guarantee we have at least one outbound + obs = []outbounds.OutboundEntry{{ + Name: "default", + Outbound: outbounds.NewDirectOutboundSimple(outbounds.DirectOutboundModeAuto), + }} } else { - // Multiple-outbound is for ACL only, not supported yet. - var err error - entry := c.Outbounds[0] - switch strings.ToLower(entry.Type) { - case "direct": - ob, err = serverConfigOutboundDirectToOutbound(entry.Direct) - default: - err = configError{Field: "outbounds.type", Err: errors.New("unsupported outbound type")} + obs = make([]outbounds.OutboundEntry, len(c.Outbounds)) + for i, entry := range c.Outbounds { + if entry.Name == "" { + return configError{Field: "outbounds.name", Err: errors.New("empty outbound name")} + } + var ob outbounds.PluggableOutbound + var err error + switch strings.ToLower(entry.Type) { + case "direct": + ob, err = serverConfigOutboundDirectToOutbound(entry.Direct) + default: + err = configError{Field: "outbounds.type", Err: errors.New("unsupported outbound type")} + } + if err != nil { + return err + } + obs[i] = outbounds.OutboundEntry{Name: entry.Name, Outbound: ob} } + } + + var uOb outbounds.PluggableOutbound // "unified" outbound + + // ACL + if c.ACL.File != "" && len(c.ACL.Inline) > 0 { + return configError{Field: "acl", Err: errors.New("cannot set both acl.file and acl.inline")} + } + gLoader := &geoipLoader{ + Filename: c.ACL.GeoIP, + DownloadFunc: geoipDownloadFunc, + DownloadErrFunc: geoipDownloadErrFunc, + } + if c.ACL.File != "" { + acl, err := outbounds.NewACLEngineFromFile(c.ACL.File, obs, gLoader.Load) if err != nil { - return err + return configError{Field: "acl.file", Err: err} } + uOb = acl + } else if len(c.ACL.Inline) > 0 { + acl, err := outbounds.NewACLEngineFromString(strings.Join(c.ACL.Inline, "\n"), obs, gLoader.Load) + if err != nil { + return configError{Field: "acl.inline", Err: err} + } + uOb = acl + } else { + // No ACL, use the first outbound + uOb = obs[0].Outbound } // Resolver @@ -340,27 +385,27 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error { if c.Resolver.TCP.Addr == "" { return configError{Field: "resolver.tcp.addr", Err: errors.New("empty resolver address")} } - ob = outbounds.NewStandardResolverTCP(c.Resolver.TCP.Addr, c.Resolver.TCP.Timeout, ob) + uOb = outbounds.NewStandardResolverTCP(c.Resolver.TCP.Addr, c.Resolver.TCP.Timeout, uOb) case "udp": if c.Resolver.UDP.Addr == "" { return configError{Field: "resolver.udp.addr", Err: errors.New("empty resolver address")} } - ob = outbounds.NewStandardResolverUDP(c.Resolver.UDP.Addr, c.Resolver.UDP.Timeout, ob) + uOb = outbounds.NewStandardResolverUDP(c.Resolver.UDP.Addr, c.Resolver.UDP.Timeout, uOb) case "tls", "tcp-tls": if c.Resolver.TLS.Addr == "" { return configError{Field: "resolver.tls.addr", Err: errors.New("empty resolver address")} } - ob = outbounds.NewStandardResolverTLS(c.Resolver.TLS.Addr, c.Resolver.TLS.Timeout, c.Resolver.TLS.SNI, c.Resolver.TLS.Insecure, ob) + uOb = outbounds.NewStandardResolverTLS(c.Resolver.TLS.Addr, c.Resolver.TLS.Timeout, c.Resolver.TLS.SNI, c.Resolver.TLS.Insecure, uOb) case "https", "http": if c.Resolver.HTTPS.Addr == "" { return configError{Field: "resolver.https.addr", Err: errors.New("empty resolver address")} } - ob = outbounds.NewDoHResolver(c.Resolver.HTTPS.Addr, c.Resolver.HTTPS.Timeout, c.Resolver.HTTPS.SNI, c.Resolver.HTTPS.Insecure, ob) + uOb = outbounds.NewDoHResolver(c.Resolver.HTTPS.Addr, c.Resolver.HTTPS.Timeout, c.Resolver.HTTPS.SNI, c.Resolver.HTTPS.Insecure, uOb) default: return configError{Field: "resolver.type", Err: errors.New("unsupported resolver type")} } - hyConfig.Outbound = &outbounds.PluggableOutboundAdapter{PluggableOutbound: ob} + hyConfig.Outbound = &outbounds.PluggableOutboundAdapter{PluggableOutbound: uOb} return nil } @@ -525,6 +570,16 @@ func runServer(cmd *cobra.Command, args []string) { } } +func geoipDownloadFunc(filename, url string) { + logger.Info("downloading GeoIP database", zap.String("filename", filename), zap.String("url", url)) +} + +func geoipDownloadErrFunc(err error) { + if err != nil { + logger.Error("failed to download GeoIP database", zap.Error(err)) + } +} + type serverLogger struct{} func (l *serverLogger) Connect(addr net.Addr, id string, tx uint64) { diff --git a/app/cmd/server_test.go b/app/cmd/server_test.go index 9402dea..961693f 100644 --- a/app/cmd/server_test.go +++ b/app/cmd/server_test.go @@ -95,6 +95,13 @@ func TestServerConfig(t *testing.T) { Insecure: true, }, }, + ACL: serverConfigACL{ + File: "chnroute.txt", + Inline: []string{ + "lmao(ok)", + "kek(cringe,boba,tea)", + }, + }, Outbounds: []serverConfigOutboundEntry{ { Name: "goodstuff", diff --git a/app/cmd/server_test.yaml b/app/cmd/server_test.yaml index 11e3c95..1bcc599 100644 --- a/app/cmd/server_test.yaml +++ b/app/cmd/server_test.yaml @@ -70,6 +70,12 @@ resolver: sni: real.stuff.net insecure: true +acl: + file: chnroute.txt + inline: + - lmao(ok) + - kek(cringe,boba,tea) + outbounds: - name: goodstuff type: direct diff --git a/app/cmd/utils.go b/app/cmd/utils.go index 8ecb5d6..3eaff48 100644 --- a/app/cmd/utils.go +++ b/app/cmd/utils.go @@ -2,10 +2,18 @@ package cmd import ( "fmt" + "io" + "net/http" "os" "github.com/apernet/hysteria/extras/utils" "github.com/mdp/qrterminal/v3" + "github.com/oschwald/geoip2-golang" +) + +const ( + geoipDefaultFilename = "GeoLite2-Country.mmdb" + geoipDownloadURL = "https://git.io/GeoLite2-Country.mmdb" ) // convBandwidth handles both string and int types for bandwidth. @@ -44,3 +52,59 @@ func (e configError) Error() string { func (e configError) Unwrap() error { return e.Err } + +// geoipLoader provides the on-demand GeoIP database loading function required by the ACL engine. +type geoipLoader struct { + Filename string + DownloadFunc func(filename, url string) // Called when downloading the GeoIP database. + DownloadErrFunc func(err error) // Called when downloading the GeoIP database succeeds/fails. + + db *geoip2.Reader +} + +func (l *geoipLoader) download() error { + resp, err := http.Get(geoipDownloadURL) + if err != nil { + return err + } + defer resp.Body.Close() + + f, err := os.Create(geoipDefaultFilename) + if err != nil { + return err + } + defer f.Close() + + _, err = io.Copy(f, resp.Body) + return err +} + +func (l *geoipLoader) Load() *geoip2.Reader { + if l.db == nil { + if l.Filename == "" { + // Filename not specified, try default. + if _, err := os.Stat(geoipDefaultFilename); err == nil { + // Default already exists, just use it. + l.Filename = geoipDefaultFilename + } else if os.IsNotExist(err) { + // Default doesn't exist, download it. + l.DownloadFunc(geoipDefaultFilename, geoipDownloadURL) + err := l.download() + l.DownloadErrFunc(err) + if err != nil { + return nil + } + l.Filename = geoipDefaultFilename + } else { + // Other error + return nil + } + } + db, err := geoip2.Open(l.Filename) + if err != nil { + return nil + } + l.db = db + } + return l.db +} diff --git a/app/server.example.yaml b/app/server.example.yaml index 853ace4..c4206f0 100644 --- a/app/server.example.yaml +++ b/app/server.example.yaml @@ -48,11 +48,18 @@ auth: # sni: server1.yolo.net # insecure: true +# acl: +# inline: +# - haha(8.8.8.8/24, udp/53) +# - reject(v2ex.com) +# - reject(*.v2ex.com) +# - reject(geoip:cn) + # outbounds: # - name: haha # type: direct # direct: -# mode: auto +# mode: 46 # bindIPv4: 2.4.6.8 # bindIPv6: 0:0:0:0:0:ffff:0204:0608 # bindDevice: eth233 diff --git a/extras/outbounds/acl.go b/extras/outbounds/acl.go new file mode 100644 index 0000000..d8df63d --- /dev/null +++ b/extras/outbounds/acl.go @@ -0,0 +1,122 @@ +package outbounds + +import ( + "errors" + "net" + "os" + "strings" + + "github.com/apernet/hysteria/extras/outbounds/acl" + "github.com/oschwald/geoip2-golang" +) + +const ( + aclCacheSize = 1024 +) + +var errRejected = errors.New("rejected") + +// aclEngine is a PluggableOutbound that dispatches connections to different +// outbounds based on ACL rules. +// There are 3 built-in outbounds: +// - direct: directOutbound, auto mode +// - reject: reject the connection +// - default: first outbound in the list, or if the list is empty, equal to direct +// If the user-defined outbounds contain any of the above names, they will +// override the built-in outbounds. +type aclEngine struct { + RuleSet acl.CompiledRuleSet[PluggableOutbound] + Default PluggableOutbound +} + +type OutboundEntry struct { + Name string + Outbound PluggableOutbound +} + +func NewACLEngineFromString(rules string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) { + trs, err := acl.ParseTextRules(rules) + if err != nil { + return nil, err + } + obMap := outboundsToMap(outbounds) + rs, err := acl.Compile[PluggableOutbound](trs, obMap, aclCacheSize, geoipFunc) + if err != nil { + return nil, err + } + return &aclEngine{rs, obMap["default"]}, nil +} + +func NewACLEngineFromFile(filename string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) { + bs, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + return NewACLEngineFromString(string(bs), outbounds, geoipFunc) +} + +func outboundsToMap(outbounds []OutboundEntry) map[string]PluggableOutbound { + obMap := make(map[string]PluggableOutbound) + for _, ob := range outbounds { + obMap[strings.ToLower(ob.Name)] = ob.Outbound + } + // Add built-in outbounds if not overridden + if _, ok := obMap["direct"]; !ok { + obMap["direct"] = NewDirectOutboundSimple(DirectOutboundModeAuto) + } + if _, ok := obMap["reject"]; !ok { + obMap["reject"] = &aclRejectOutbound{} + } + if _, ok := obMap["default"]; !ok { + if len(outbounds) > 0 { + obMap["default"] = outbounds[0].Outbound + } else { + obMap["default"] = obMap["direct"] + } + } + return obMap +} + +func (a *aclEngine) handle(reqAddr *AddrEx, proto acl.Protocol) PluggableOutbound { + hostInfo := acl.HostInfo{Name: reqAddr.Host} + if reqAddr.ResolveInfo != nil { + hostInfo.IPv4 = reqAddr.ResolveInfo.IPv4 + hostInfo.IPv6 = reqAddr.ResolveInfo.IPv6 + } + ob, hijackIP := a.RuleSet.Match(hostInfo, proto, reqAddr.Port) + if ob == nil { + // No match, use default outbound + return a.Default + } + if hijackIP != nil { + // We must rewrite both Host & ResolveInfo, + // as some outbounds only care about Host. + reqAddr.Host = hijackIP.String() + if ip4 := hijackIP.To4(); ip4 != nil { + reqAddr.ResolveInfo = &ResolveInfo{IPv4: ip4} + } else { + reqAddr.ResolveInfo = &ResolveInfo{IPv6: hijackIP} + } + } + return ob +} + +func (a *aclEngine) TCP(reqAddr *AddrEx) (net.Conn, error) { + ob := a.handle(reqAddr, acl.ProtocolTCP) + return ob.TCP(reqAddr) +} + +func (a *aclEngine) UDP(reqAddr *AddrEx) (UDPConn, error) { + ob := a.handle(reqAddr, acl.ProtocolUDP) + return ob.UDP(reqAddr) +} + +type aclRejectOutbound struct{} + +func (a *aclRejectOutbound) TCP(reqAddr *AddrEx) (net.Conn, error) { + return nil, errRejected +} + +func (a *aclRejectOutbound) UDP(reqAddr *AddrEx) (UDPConn, error) { + return nil, errRejected +} diff --git a/extras/outbounds/acl/compile.go b/extras/outbounds/acl/compile.go index 366ffd6..e68f3e0 100644 --- a/extras/outbounds/acl/compile.go +++ b/extras/outbounds/acl/compile.go @@ -10,12 +10,12 @@ import ( "github.com/oschwald/geoip2-golang" ) -type protocol int +type Protocol int const ( - protocolBoth protocol = iota - protocolTCP - protocolUDP + ProtocolBoth Protocol = iota + ProtocolTCP + ProtocolUDP ) type Outbound interface { @@ -33,19 +33,19 @@ func (h HostInfo) String() string { } type CompiledRuleSet[O Outbound] interface { - Match(host HostInfo, proto protocol, port uint16) (O, net.IP) + Match(host HostInfo, proto Protocol, port uint16) (O, net.IP) } type compiledRule[O Outbound] struct { Outbound O HostMatcher hostMatcher - Protocol protocol + Protocol Protocol Port uint16 HijackAddress net.IP } -func (r *compiledRule[O]) Match(host HostInfo, proto protocol, port uint16) bool { - if r.Protocol != protocolBoth && r.Protocol != proto { +func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool { + if r.Protocol != ProtocolBoth && r.Protocol != proto { return false } if r.Port != 0 && r.Port != port { @@ -64,7 +64,7 @@ type compiledRuleSetImpl[O Outbound] struct { Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String() } -func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto protocol, port uint16) (O, net.IP) { +func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto Protocol, port uint16) (O, net.IP) { host.Name = strings.ToLower(host.Name) // Normalize host name to lower case key := host.String() if result, ok := s.Cache.Get(key); ok { @@ -92,12 +92,18 @@ func (e *CompilationError) Error() string { return fmt.Sprintf("error at line %d: %s", e.LineNum, e.Message) } +// Compile compiles TextRules into a CompiledRuleSet. +// Names in the outbounds map MUST be in all lower case. +// geoipFunc is a function that returns the GeoIP database needed by the GeoIP matcher. +// It will be called every time a GeoIP matcher is used during compilation, but won't +// be called if there is no GeoIP rule. We use a function here so that database loading +// is on-demand (only required if used by rules). func Compile[O Outbound](rules []TextRule, outbounds map[string]O, cacheSize int, geoipFunc func() *geoip2.Reader, ) (CompiledRuleSet[O], error) { compiledRules := make([]compiledRule[O], len(rules)) for i, rule := range rules { - outbound, ok := outbounds[rule.Outbound] + outbound, ok := outbounds[strings.ToLower(rule.Outbound)] if !ok { return nil, &CompilationError{rule.LineNum, fmt.Sprintf("outbound %s not found", rule.Outbound)} } @@ -137,40 +143,40 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O, // [empty] (same as *) // // proto must be either "tcp" or "udp", case-insensitive. -func parseProtoPort(protoPort string) (protocol, uint16, bool) { +func parseProtoPort(protoPort string) (Protocol, uint16, bool) { protoPort = strings.ToLower(protoPort) if protoPort == "" || protoPort == "*" || protoPort == "*/*" { - return protocolBoth, 0, true + return ProtocolBoth, 0, true } parts := strings.SplitN(protoPort, "/", 2) if len(parts) == 1 { // No port, only protocol switch parts[0] { case "tcp": - return protocolTCP, 0, true + return ProtocolTCP, 0, true case "udp": - return protocolUDP, 0, true + return ProtocolUDP, 0, true default: - return protocolBoth, 0, false + return ProtocolBoth, 0, false } } else { // Both protocol and port - var proto protocol + var proto Protocol var port uint16 switch parts[0] { case "tcp": - proto = protocolTCP + proto = ProtocolTCP case "udp": - proto = protocolUDP + proto = ProtocolUDP case "*": - proto = protocolBoth + proto = ProtocolBoth default: - return protocolBoth, 0, false + return ProtocolBoth, 0, false } if parts[1] != "*" { p64, err := strconv.ParseUint(parts[1], 10, 16) if err != nil { - return protocolBoth, 0, false + return ProtocolBoth, 0, false } port = uint16(p64) } @@ -194,7 +200,7 @@ func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatch if db == nil { return nil, "failed to load GeoIP database" } - return &geoIPMatcher{db, country}, "" + return &geoipMatcher{db, country}, "" } if strings.Contains(addr, "/") { // CIDR matcher diff --git a/extras/outbounds/acl/compile_test.go b/extras/outbounds/acl/compile_test.go index 6f41112..8f61229 100644 --- a/extras/outbounds/acl/compile_test.go +++ b/extras/outbounds/acl/compile_test.go @@ -69,7 +69,7 @@ func TestCompile(t *testing.T) { tests := []struct { host HostInfo - proto protocol + proto Protocol port uint16 wantOutbound int wantIP net.IP @@ -78,7 +78,7 @@ func TestCompile(t *testing.T) { host: HostInfo{ IPv4: net.ParseIP("1.2.3.4"), }, - proto: protocolTCP, + proto: ProtocolTCP, port: 1234, wantOutbound: ob1, wantIP: nil, @@ -87,7 +87,7 @@ func TestCompile(t *testing.T) { host: HostInfo{ IPv4: net.ParseIP("8.8.8.4"), }, - proto: protocolUDP, + proto: ProtocolUDP, port: 5353, wantOutbound: ob2, wantIP: net.ParseIP("1.1.1.1"), @@ -96,7 +96,7 @@ func TestCompile(t *testing.T) { host: HostInfo{ Name: "lean.delicious.com", }, - proto: protocolUDP, + proto: ProtocolUDP, port: 443, wantOutbound: ob3, wantIP: nil, @@ -105,7 +105,7 @@ func TestCompile(t *testing.T) { host: HostInfo{ IPv6: net.ParseIP("2606:4700::6810:85e5"), }, - proto: protocolTCP, + proto: ProtocolTCP, port: 80, wantOutbound: ob1, wantIP: net.ParseIP("2606:4700::6810:85e6"), @@ -114,7 +114,7 @@ func TestCompile(t *testing.T) { host: HostInfo{ IPv6: net.ParseIP("2606:4700:0:0:0:0:0:1"), }, - proto: protocolUDP, + proto: ProtocolUDP, port: 8888, wantOutbound: ob2, wantIP: nil, @@ -123,7 +123,7 @@ func TestCompile(t *testing.T) { host: HostInfo{ Name: "www.v2ex.com", }, - proto: protocolUDP, + proto: ProtocolUDP, port: 1234, wantOutbound: ob3, wantIP: nil, @@ -132,7 +132,7 @@ func TestCompile(t *testing.T) { host: HostInfo{ Name: "crap.v2ex.com", }, - proto: protocolTCP, + proto: ProtocolTCP, port: 80, wantOutbound: ob1, wantIP: net.ParseIP("2.2.2.2"), @@ -141,7 +141,7 @@ func TestCompile(t *testing.T) { host: HostInfo{ IPv4: net.ParseIP("210.140.92.187"), }, - proto: protocolTCP, + proto: ProtocolTCP, port: 25, wantOutbound: ob2, wantIP: nil, diff --git a/extras/outbounds/acl/matchers.go b/extras/outbounds/acl/matchers.go index 33ec53a..5f9b3d2 100644 --- a/extras/outbounds/acl/matchers.go +++ b/extras/outbounds/acl/matchers.go @@ -55,12 +55,12 @@ func deepMatchRune(str, pattern []rune) bool { return len(str) == 0 && len(pattern) == 0 } -type geoIPMatcher struct { +type geoipMatcher struct { DB *geoip2.Reader Country string // must be uppercase ISO 3166-1 alpha-2 code } -func (m *geoIPMatcher) Match(host HostInfo) bool { +func (m *geoipMatcher) Match(host HostInfo) bool { if host.IPv4 != nil { record, err := m.DB.Country(host.IPv4) if err == nil && record.Country.IsoCode == m.Country { diff --git a/extras/outbounds/acl/matchers_test.go b/extras/outbounds/acl/matchers_test.go index 5ec3884..871b265 100644 --- a/extras/outbounds/acl/matchers_test.go +++ b/extras/outbounds/acl/matchers_test.go @@ -234,7 +234,7 @@ func Test_domainMatcher_Match(t *testing.T) { } } -func Test_geoIPMatcher_Match(t *testing.T) { +func Test_geoipMatcher_Match(t *testing.T) { db, err := geoip2.Open("GeoLite2-Country.mmdb") assert.NoError(t, err) defer db.Close() @@ -298,7 +298,7 @@ func Test_geoIPMatcher_Match(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m := &geoIPMatcher{ + m := &geoipMatcher{ DB: tt.fields.DB, Country: tt.fields.Country, } diff --git a/extras/outbounds/acl/parse.go b/extras/outbounds/acl/parse.go index 11cae44..760514d 100644 --- a/extras/outbounds/acl/parse.go +++ b/extras/outbounds/acl/parse.go @@ -2,7 +2,6 @@ package acl import ( "fmt" - "os" "regexp" "strings" ) @@ -71,11 +70,3 @@ func ParseTextRules(text string) ([]TextRule, error) { } return rules, nil } - -func ParseTextRulesFile(filename string) ([]TextRule, error) { - bs, err := os.ReadFile(filename) - if err != nil { - return nil, err - } - return ParseTextRules(string(bs)) -} diff --git a/extras/outbounds/acl_test.go b/extras/outbounds/acl_test.go new file mode 100644 index 0000000..9c68890 --- /dev/null +++ b/extras/outbounds/acl_test.go @@ -0,0 +1,61 @@ +package outbounds + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestACLEngine(t *testing.T) { + ob1, ob2, ob3 := &mockPluggableOutbound{}, &mockPluggableOutbound{}, &mockPluggableOutbound{} + obs := []OutboundEntry{ + {"ob1", ob1}, + {"ob2", ob2}, + {"ob3", ob3}, + {"direct", ob2}, + } + acl, err := NewACLEngineFromString(` +ob2(google.com,tcp) +ob3(youtube.com,udp) +ob1 (1.1.1.1/24,*,8.8.8.8) +Direct(cia.gov) +reJect(nsa.gov) +`, obs, nil) + assert.NoError(t, err) + + // No match, default, should be the first (ob1) + ob1.EXPECT().TCP(&AddrEx{Host: "example.com"}).Return(nil, nil).Once() + conn, err := acl.TCP(&AddrEx{Host: "example.com"}) + assert.NoError(t, err) + assert.Nil(t, conn) + + // Match ob2 + ob2.EXPECT().TCP(&AddrEx{Host: "google.com"}).Return(nil, nil).Once() + conn, err = acl.TCP(&AddrEx{Host: "google.com"}) + assert.NoError(t, err) + assert.Nil(t, conn) + + // Match ob3 + ob3.EXPECT().UDP(&AddrEx{Host: "youtube.com"}).Return(nil, nil).Once() + udpConn, err := acl.UDP(&AddrEx{Host: "youtube.com"}) + assert.NoError(t, err) + assert.Nil(t, udpConn) + + // Match ob1 hijack IP + ob1.EXPECT().TCP(&AddrEx{Host: "8.8.8.8", ResolveInfo: &ResolveInfo{IPv4: net.ParseIP("8.8.8.8").To4()}}).Return(nil, nil).Once() + conn, err = acl.TCP(&AddrEx{ResolveInfo: &ResolveInfo{IPv4: net.ParseIP("1.1.1.22")}}) + assert.NoError(t, err) + assert.Nil(t, conn) + + // direct should be ob2 as we override it + ob2.EXPECT().TCP(&AddrEx{Host: "cia.gov"}).Return(nil, nil).Once() + conn, err = acl.TCP(&AddrEx{Host: "cia.gov"}) + assert.NoError(t, err) + assert.Nil(t, conn) + + // reject + conn, err = acl.TCP(&AddrEx{Host: "nsa.gov"}) + assert.Error(t, err) + assert.Nil(t, conn) +}