mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-06 05:57:38 +03:00
feat: ACL
This commit is contained in:
parent
6fa958815b
commit
a7d74a9ec1
12 changed files with 380 additions and 61 deletions
|
@ -44,6 +44,7 @@ type serverConfig struct {
|
||||||
UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"`
|
UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"`
|
||||||
Auth serverConfigAuth `mapstructure:"auth"`
|
Auth serverConfigAuth `mapstructure:"auth"`
|
||||||
Resolver serverConfigResolver `mapstructure:"resolver"`
|
Resolver serverConfigResolver `mapstructure:"resolver"`
|
||||||
|
ACL serverConfigACL `mapstructure:"acl"`
|
||||||
Outbounds []serverConfigOutboundEntry `mapstructure:"outbounds"`
|
Outbounds []serverConfigOutboundEntry `mapstructure:"outbounds"`
|
||||||
Masquerade serverConfigMasquerade `mapstructure:"masquerade"`
|
Masquerade serverConfigMasquerade `mapstructure:"masquerade"`
|
||||||
}
|
}
|
||||||
|
@ -133,6 +134,12 @@ type serverConfigResolver struct {
|
||||||
HTTPS serverConfigResolverHTTPS `mapstructure:"https"`
|
HTTPS serverConfigResolverHTTPS `mapstructure:"https"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type serverConfigACL struct {
|
||||||
|
File string `mapstructure:"file"`
|
||||||
|
Inline []string `mapstructure:"inline"`
|
||||||
|
GeoIP string `mapstructure:"geoip"`
|
||||||
|
}
|
||||||
|
|
||||||
type serverConfigOutboundDirect struct {
|
type serverConfigOutboundDirect struct {
|
||||||
Mode string `mapstructure:"mode"`
|
Mode string `mapstructure:"mode"`
|
||||||
BindIPv4 string `mapstructure:"bindIPv4"`
|
BindIPv4 string `mapstructure:"bindIPv4"`
|
||||||
|
@ -314,13 +321,21 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error {
|
||||||
// Resolver(ACL(Outbounds...))
|
// Resolver(ACL(Outbounds...))
|
||||||
|
|
||||||
// Outbounds
|
// Outbounds
|
||||||
var ob outbounds.PluggableOutbound
|
var obs []outbounds.OutboundEntry
|
||||||
if len(c.Outbounds) == 0 {
|
if len(c.Outbounds) == 0 {
|
||||||
ob = outbounds.NewDirectOutboundSimple(outbounds.DirectOutboundModeAuto)
|
// Guarantee we have at least one outbound
|
||||||
|
obs = []outbounds.OutboundEntry{{
|
||||||
|
Name: "default",
|
||||||
|
Outbound: outbounds.NewDirectOutboundSimple(outbounds.DirectOutboundModeAuto),
|
||||||
|
}}
|
||||||
} else {
|
} else {
|
||||||
// Multiple-outbound is for ACL only, not supported yet.
|
obs = make([]outbounds.OutboundEntry, len(c.Outbounds))
|
||||||
|
for i, entry := range c.Outbounds {
|
||||||
|
if entry.Name == "" {
|
||||||
|
return configError{Field: "outbounds.name", Err: errors.New("empty outbound name")}
|
||||||
|
}
|
||||||
|
var ob outbounds.PluggableOutbound
|
||||||
var err error
|
var err error
|
||||||
entry := c.Outbounds[0]
|
|
||||||
switch strings.ToLower(entry.Type) {
|
switch strings.ToLower(entry.Type) {
|
||||||
case "direct":
|
case "direct":
|
||||||
ob, err = serverConfigOutboundDirectToOutbound(entry.Direct)
|
ob, err = serverConfigOutboundDirectToOutbound(entry.Direct)
|
||||||
|
@ -330,6 +345,36 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
obs[i] = outbounds.OutboundEntry{Name: entry.Name, Outbound: ob}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var uOb outbounds.PluggableOutbound // "unified" outbound
|
||||||
|
|
||||||
|
// ACL
|
||||||
|
if c.ACL.File != "" && len(c.ACL.Inline) > 0 {
|
||||||
|
return configError{Field: "acl", Err: errors.New("cannot set both acl.file and acl.inline")}
|
||||||
|
}
|
||||||
|
gLoader := &geoipLoader{
|
||||||
|
Filename: c.ACL.GeoIP,
|
||||||
|
DownloadFunc: geoipDownloadFunc,
|
||||||
|
DownloadErrFunc: geoipDownloadErrFunc,
|
||||||
|
}
|
||||||
|
if c.ACL.File != "" {
|
||||||
|
acl, err := outbounds.NewACLEngineFromFile(c.ACL.File, obs, gLoader.Load)
|
||||||
|
if err != nil {
|
||||||
|
return configError{Field: "acl.file", Err: err}
|
||||||
|
}
|
||||||
|
uOb = acl
|
||||||
|
} else if len(c.ACL.Inline) > 0 {
|
||||||
|
acl, err := outbounds.NewACLEngineFromString(strings.Join(c.ACL.Inline, "\n"), obs, gLoader.Load)
|
||||||
|
if err != nil {
|
||||||
|
return configError{Field: "acl.inline", Err: err}
|
||||||
|
}
|
||||||
|
uOb = acl
|
||||||
|
} else {
|
||||||
|
// No ACL, use the first outbound
|
||||||
|
uOb = obs[0].Outbound
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolver
|
// Resolver
|
||||||
|
@ -340,27 +385,27 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error {
|
||||||
if c.Resolver.TCP.Addr == "" {
|
if c.Resolver.TCP.Addr == "" {
|
||||||
return configError{Field: "resolver.tcp.addr", Err: errors.New("empty resolver address")}
|
return configError{Field: "resolver.tcp.addr", Err: errors.New("empty resolver address")}
|
||||||
}
|
}
|
||||||
ob = outbounds.NewStandardResolverTCP(c.Resolver.TCP.Addr, c.Resolver.TCP.Timeout, ob)
|
uOb = outbounds.NewStandardResolverTCP(c.Resolver.TCP.Addr, c.Resolver.TCP.Timeout, uOb)
|
||||||
case "udp":
|
case "udp":
|
||||||
if c.Resolver.UDP.Addr == "" {
|
if c.Resolver.UDP.Addr == "" {
|
||||||
return configError{Field: "resolver.udp.addr", Err: errors.New("empty resolver address")}
|
return configError{Field: "resolver.udp.addr", Err: errors.New("empty resolver address")}
|
||||||
}
|
}
|
||||||
ob = outbounds.NewStandardResolverUDP(c.Resolver.UDP.Addr, c.Resolver.UDP.Timeout, ob)
|
uOb = outbounds.NewStandardResolverUDP(c.Resolver.UDP.Addr, c.Resolver.UDP.Timeout, uOb)
|
||||||
case "tls", "tcp-tls":
|
case "tls", "tcp-tls":
|
||||||
if c.Resolver.TLS.Addr == "" {
|
if c.Resolver.TLS.Addr == "" {
|
||||||
return configError{Field: "resolver.tls.addr", Err: errors.New("empty resolver address")}
|
return configError{Field: "resolver.tls.addr", Err: errors.New("empty resolver address")}
|
||||||
}
|
}
|
||||||
ob = outbounds.NewStandardResolverTLS(c.Resolver.TLS.Addr, c.Resolver.TLS.Timeout, c.Resolver.TLS.SNI, c.Resolver.TLS.Insecure, ob)
|
uOb = outbounds.NewStandardResolverTLS(c.Resolver.TLS.Addr, c.Resolver.TLS.Timeout, c.Resolver.TLS.SNI, c.Resolver.TLS.Insecure, uOb)
|
||||||
case "https", "http":
|
case "https", "http":
|
||||||
if c.Resolver.HTTPS.Addr == "" {
|
if c.Resolver.HTTPS.Addr == "" {
|
||||||
return configError{Field: "resolver.https.addr", Err: errors.New("empty resolver address")}
|
return configError{Field: "resolver.https.addr", Err: errors.New("empty resolver address")}
|
||||||
}
|
}
|
||||||
ob = outbounds.NewDoHResolver(c.Resolver.HTTPS.Addr, c.Resolver.HTTPS.Timeout, c.Resolver.HTTPS.SNI, c.Resolver.HTTPS.Insecure, ob)
|
uOb = outbounds.NewDoHResolver(c.Resolver.HTTPS.Addr, c.Resolver.HTTPS.Timeout, c.Resolver.HTTPS.SNI, c.Resolver.HTTPS.Insecure, uOb)
|
||||||
default:
|
default:
|
||||||
return configError{Field: "resolver.type", Err: errors.New("unsupported resolver type")}
|
return configError{Field: "resolver.type", Err: errors.New("unsupported resolver type")}
|
||||||
}
|
}
|
||||||
|
|
||||||
hyConfig.Outbound = &outbounds.PluggableOutboundAdapter{PluggableOutbound: ob}
|
hyConfig.Outbound = &outbounds.PluggableOutboundAdapter{PluggableOutbound: uOb}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -525,6 +570,16 @@ func runServer(cmd *cobra.Command, args []string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func geoipDownloadFunc(filename, url string) {
|
||||||
|
logger.Info("downloading GeoIP database", zap.String("filename", filename), zap.String("url", url))
|
||||||
|
}
|
||||||
|
|
||||||
|
func geoipDownloadErrFunc(err error) {
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to download GeoIP database", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type serverLogger struct{}
|
type serverLogger struct{}
|
||||||
|
|
||||||
func (l *serverLogger) Connect(addr net.Addr, id string, tx uint64) {
|
func (l *serverLogger) Connect(addr net.Addr, id string, tx uint64) {
|
||||||
|
|
|
@ -95,6 +95,13 @@ func TestServerConfig(t *testing.T) {
|
||||||
Insecure: true,
|
Insecure: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
ACL: serverConfigACL{
|
||||||
|
File: "chnroute.txt",
|
||||||
|
Inline: []string{
|
||||||
|
"lmao(ok)",
|
||||||
|
"kek(cringe,boba,tea)",
|
||||||
|
},
|
||||||
|
},
|
||||||
Outbounds: []serverConfigOutboundEntry{
|
Outbounds: []serverConfigOutboundEntry{
|
||||||
{
|
{
|
||||||
Name: "goodstuff",
|
Name: "goodstuff",
|
||||||
|
|
|
@ -70,6 +70,12 @@ resolver:
|
||||||
sni: real.stuff.net
|
sni: real.stuff.net
|
||||||
insecure: true
|
insecure: true
|
||||||
|
|
||||||
|
acl:
|
||||||
|
file: chnroute.txt
|
||||||
|
inline:
|
||||||
|
- lmao(ok)
|
||||||
|
- kek(cringe,boba,tea)
|
||||||
|
|
||||||
outbounds:
|
outbounds:
|
||||||
- name: goodstuff
|
- name: goodstuff
|
||||||
type: direct
|
type: direct
|
||||||
|
|
|
@ -2,10 +2,18 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/apernet/hysteria/extras/utils"
|
"github.com/apernet/hysteria/extras/utils"
|
||||||
"github.com/mdp/qrterminal/v3"
|
"github.com/mdp/qrterminal/v3"
|
||||||
|
"github.com/oschwald/geoip2-golang"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
geoipDefaultFilename = "GeoLite2-Country.mmdb"
|
||||||
|
geoipDownloadURL = "https://git.io/GeoLite2-Country.mmdb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// convBandwidth handles both string and int types for bandwidth.
|
// convBandwidth handles both string and int types for bandwidth.
|
||||||
|
@ -44,3 +52,59 @@ func (e configError) Error() string {
|
||||||
func (e configError) Unwrap() error {
|
func (e configError) Unwrap() error {
|
||||||
return e.Err
|
return e.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// geoipLoader provides the on-demand GeoIP database loading function required by the ACL engine.
|
||||||
|
type geoipLoader struct {
|
||||||
|
Filename string
|
||||||
|
DownloadFunc func(filename, url string) // Called when downloading the GeoIP database.
|
||||||
|
DownloadErrFunc func(err error) // Called when downloading the GeoIP database succeeds/fails.
|
||||||
|
|
||||||
|
db *geoip2.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *geoipLoader) download() error {
|
||||||
|
resp, err := http.Get(geoipDownloadURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
f, err := os.Create(geoipDefaultFilename)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(f, resp.Body)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *geoipLoader) Load() *geoip2.Reader {
|
||||||
|
if l.db == nil {
|
||||||
|
if l.Filename == "" {
|
||||||
|
// Filename not specified, try default.
|
||||||
|
if _, err := os.Stat(geoipDefaultFilename); err == nil {
|
||||||
|
// Default already exists, just use it.
|
||||||
|
l.Filename = geoipDefaultFilename
|
||||||
|
} else if os.IsNotExist(err) {
|
||||||
|
// Default doesn't exist, download it.
|
||||||
|
l.DownloadFunc(geoipDefaultFilename, geoipDownloadURL)
|
||||||
|
err := l.download()
|
||||||
|
l.DownloadErrFunc(err)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
l.Filename = geoipDefaultFilename
|
||||||
|
} else {
|
||||||
|
// Other error
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
db, err := geoip2.Open(l.Filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
l.db = db
|
||||||
|
}
|
||||||
|
return l.db
|
||||||
|
}
|
||||||
|
|
|
@ -48,11 +48,18 @@ auth:
|
||||||
# sni: server1.yolo.net
|
# sni: server1.yolo.net
|
||||||
# insecure: true
|
# insecure: true
|
||||||
|
|
||||||
|
# acl:
|
||||||
|
# inline:
|
||||||
|
# - haha(8.8.8.8/24, udp/53)
|
||||||
|
# - reject(v2ex.com)
|
||||||
|
# - reject(*.v2ex.com)
|
||||||
|
# - reject(geoip:cn)
|
||||||
|
|
||||||
# outbounds:
|
# outbounds:
|
||||||
# - name: haha
|
# - name: haha
|
||||||
# type: direct
|
# type: direct
|
||||||
# direct:
|
# direct:
|
||||||
# mode: auto
|
# mode: 46
|
||||||
# bindIPv4: 2.4.6.8
|
# bindIPv4: 2.4.6.8
|
||||||
# bindIPv6: 0:0:0:0:0:ffff:0204:0608
|
# bindIPv6: 0:0:0:0:0:ffff:0204:0608
|
||||||
# bindDevice: eth233
|
# bindDevice: eth233
|
||||||
|
|
122
extras/outbounds/acl.go
Normal file
122
extras/outbounds/acl.go
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
package outbounds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/apernet/hysteria/extras/outbounds/acl"
|
||||||
|
"github.com/oschwald/geoip2-golang"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
aclCacheSize = 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
var errRejected = errors.New("rejected")
|
||||||
|
|
||||||
|
// aclEngine is a PluggableOutbound that dispatches connections to different
|
||||||
|
// outbounds based on ACL rules.
|
||||||
|
// There are 3 built-in outbounds:
|
||||||
|
// - direct: directOutbound, auto mode
|
||||||
|
// - reject: reject the connection
|
||||||
|
// - default: first outbound in the list, or if the list is empty, equal to direct
|
||||||
|
// If the user-defined outbounds contain any of the above names, they will
|
||||||
|
// override the built-in outbounds.
|
||||||
|
type aclEngine struct {
|
||||||
|
RuleSet acl.CompiledRuleSet[PluggableOutbound]
|
||||||
|
Default PluggableOutbound
|
||||||
|
}
|
||||||
|
|
||||||
|
type OutboundEntry struct {
|
||||||
|
Name string
|
||||||
|
Outbound PluggableOutbound
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewACLEngineFromString(rules string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) {
|
||||||
|
trs, err := acl.ParseTextRules(rules)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
obMap := outboundsToMap(outbounds)
|
||||||
|
rs, err := acl.Compile[PluggableOutbound](trs, obMap, aclCacheSize, geoipFunc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &aclEngine{rs, obMap["default"]}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewACLEngineFromFile(filename string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) {
|
||||||
|
bs, err := os.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewACLEngineFromString(string(bs), outbounds, geoipFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func outboundsToMap(outbounds []OutboundEntry) map[string]PluggableOutbound {
|
||||||
|
obMap := make(map[string]PluggableOutbound)
|
||||||
|
for _, ob := range outbounds {
|
||||||
|
obMap[strings.ToLower(ob.Name)] = ob.Outbound
|
||||||
|
}
|
||||||
|
// Add built-in outbounds if not overridden
|
||||||
|
if _, ok := obMap["direct"]; !ok {
|
||||||
|
obMap["direct"] = NewDirectOutboundSimple(DirectOutboundModeAuto)
|
||||||
|
}
|
||||||
|
if _, ok := obMap["reject"]; !ok {
|
||||||
|
obMap["reject"] = &aclRejectOutbound{}
|
||||||
|
}
|
||||||
|
if _, ok := obMap["default"]; !ok {
|
||||||
|
if len(outbounds) > 0 {
|
||||||
|
obMap["default"] = outbounds[0].Outbound
|
||||||
|
} else {
|
||||||
|
obMap["default"] = obMap["direct"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return obMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *aclEngine) handle(reqAddr *AddrEx, proto acl.Protocol) PluggableOutbound {
|
||||||
|
hostInfo := acl.HostInfo{Name: reqAddr.Host}
|
||||||
|
if reqAddr.ResolveInfo != nil {
|
||||||
|
hostInfo.IPv4 = reqAddr.ResolveInfo.IPv4
|
||||||
|
hostInfo.IPv6 = reqAddr.ResolveInfo.IPv6
|
||||||
|
}
|
||||||
|
ob, hijackIP := a.RuleSet.Match(hostInfo, proto, reqAddr.Port)
|
||||||
|
if ob == nil {
|
||||||
|
// No match, use default outbound
|
||||||
|
return a.Default
|
||||||
|
}
|
||||||
|
if hijackIP != nil {
|
||||||
|
// We must rewrite both Host & ResolveInfo,
|
||||||
|
// as some outbounds only care about Host.
|
||||||
|
reqAddr.Host = hijackIP.String()
|
||||||
|
if ip4 := hijackIP.To4(); ip4 != nil {
|
||||||
|
reqAddr.ResolveInfo = &ResolveInfo{IPv4: ip4}
|
||||||
|
} else {
|
||||||
|
reqAddr.ResolveInfo = &ResolveInfo{IPv6: hijackIP}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ob
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *aclEngine) TCP(reqAddr *AddrEx) (net.Conn, error) {
|
||||||
|
ob := a.handle(reqAddr, acl.ProtocolTCP)
|
||||||
|
return ob.TCP(reqAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *aclEngine) UDP(reqAddr *AddrEx) (UDPConn, error) {
|
||||||
|
ob := a.handle(reqAddr, acl.ProtocolUDP)
|
||||||
|
return ob.UDP(reqAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type aclRejectOutbound struct{}
|
||||||
|
|
||||||
|
func (a *aclRejectOutbound) TCP(reqAddr *AddrEx) (net.Conn, error) {
|
||||||
|
return nil, errRejected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *aclRejectOutbound) UDP(reqAddr *AddrEx) (UDPConn, error) {
|
||||||
|
return nil, errRejected
|
||||||
|
}
|
|
@ -10,12 +10,12 @@ import (
|
||||||
"github.com/oschwald/geoip2-golang"
|
"github.com/oschwald/geoip2-golang"
|
||||||
)
|
)
|
||||||
|
|
||||||
type protocol int
|
type Protocol int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
protocolBoth protocol = iota
|
ProtocolBoth Protocol = iota
|
||||||
protocolTCP
|
ProtocolTCP
|
||||||
protocolUDP
|
ProtocolUDP
|
||||||
)
|
)
|
||||||
|
|
||||||
type Outbound interface {
|
type Outbound interface {
|
||||||
|
@ -33,19 +33,19 @@ func (h HostInfo) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompiledRuleSet[O Outbound] interface {
|
type CompiledRuleSet[O Outbound] interface {
|
||||||
Match(host HostInfo, proto protocol, port uint16) (O, net.IP)
|
Match(host HostInfo, proto Protocol, port uint16) (O, net.IP)
|
||||||
}
|
}
|
||||||
|
|
||||||
type compiledRule[O Outbound] struct {
|
type compiledRule[O Outbound] struct {
|
||||||
Outbound O
|
Outbound O
|
||||||
HostMatcher hostMatcher
|
HostMatcher hostMatcher
|
||||||
Protocol protocol
|
Protocol Protocol
|
||||||
Port uint16
|
Port uint16
|
||||||
HijackAddress net.IP
|
HijackAddress net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *compiledRule[O]) Match(host HostInfo, proto protocol, port uint16) bool {
|
func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool {
|
||||||
if r.Protocol != protocolBoth && r.Protocol != proto {
|
if r.Protocol != ProtocolBoth && r.Protocol != proto {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if r.Port != 0 && r.Port != port {
|
if r.Port != 0 && r.Port != port {
|
||||||
|
@ -64,7 +64,7 @@ type compiledRuleSetImpl[O Outbound] struct {
|
||||||
Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String()
|
Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto protocol, port uint16) (O, net.IP) {
|
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto Protocol, port uint16) (O, net.IP) {
|
||||||
host.Name = strings.ToLower(host.Name) // Normalize host name to lower case
|
host.Name = strings.ToLower(host.Name) // Normalize host name to lower case
|
||||||
key := host.String()
|
key := host.String()
|
||||||
if result, ok := s.Cache.Get(key); ok {
|
if result, ok := s.Cache.Get(key); ok {
|
||||||
|
@ -92,12 +92,18 @@ func (e *CompilationError) Error() string {
|
||||||
return fmt.Sprintf("error at line %d: %s", e.LineNum, e.Message)
|
return fmt.Sprintf("error at line %d: %s", e.LineNum, e.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compile compiles TextRules into a CompiledRuleSet.
|
||||||
|
// Names in the outbounds map MUST be in all lower case.
|
||||||
|
// geoipFunc is a function that returns the GeoIP database needed by the GeoIP matcher.
|
||||||
|
// It will be called every time a GeoIP matcher is used during compilation, but won't
|
||||||
|
// be called if there is no GeoIP rule. We use a function here so that database loading
|
||||||
|
// is on-demand (only required if used by rules).
|
||||||
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
||||||
cacheSize int, geoipFunc func() *geoip2.Reader,
|
cacheSize int, geoipFunc func() *geoip2.Reader,
|
||||||
) (CompiledRuleSet[O], error) {
|
) (CompiledRuleSet[O], error) {
|
||||||
compiledRules := make([]compiledRule[O], len(rules))
|
compiledRules := make([]compiledRule[O], len(rules))
|
||||||
for i, rule := range rules {
|
for i, rule := range rules {
|
||||||
outbound, ok := outbounds[rule.Outbound]
|
outbound, ok := outbounds[strings.ToLower(rule.Outbound)]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("outbound %s not found", rule.Outbound)}
|
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("outbound %s not found", rule.Outbound)}
|
||||||
}
|
}
|
||||||
|
@ -137,40 +143,40 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
||||||
// [empty] (same as *)
|
// [empty] (same as *)
|
||||||
//
|
//
|
||||||
// proto must be either "tcp" or "udp", case-insensitive.
|
// proto must be either "tcp" or "udp", case-insensitive.
|
||||||
func parseProtoPort(protoPort string) (protocol, uint16, bool) {
|
func parseProtoPort(protoPort string) (Protocol, uint16, bool) {
|
||||||
protoPort = strings.ToLower(protoPort)
|
protoPort = strings.ToLower(protoPort)
|
||||||
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
|
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
|
||||||
return protocolBoth, 0, true
|
return ProtocolBoth, 0, true
|
||||||
}
|
}
|
||||||
parts := strings.SplitN(protoPort, "/", 2)
|
parts := strings.SplitN(protoPort, "/", 2)
|
||||||
if len(parts) == 1 {
|
if len(parts) == 1 {
|
||||||
// No port, only protocol
|
// No port, only protocol
|
||||||
switch parts[0] {
|
switch parts[0] {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
return protocolTCP, 0, true
|
return ProtocolTCP, 0, true
|
||||||
case "udp":
|
case "udp":
|
||||||
return protocolUDP, 0, true
|
return ProtocolUDP, 0, true
|
||||||
default:
|
default:
|
||||||
return protocolBoth, 0, false
|
return ProtocolBoth, 0, false
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Both protocol and port
|
// Both protocol and port
|
||||||
var proto protocol
|
var proto Protocol
|
||||||
var port uint16
|
var port uint16
|
||||||
switch parts[0] {
|
switch parts[0] {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
proto = protocolTCP
|
proto = ProtocolTCP
|
||||||
case "udp":
|
case "udp":
|
||||||
proto = protocolUDP
|
proto = ProtocolUDP
|
||||||
case "*":
|
case "*":
|
||||||
proto = protocolBoth
|
proto = ProtocolBoth
|
||||||
default:
|
default:
|
||||||
return protocolBoth, 0, false
|
return ProtocolBoth, 0, false
|
||||||
}
|
}
|
||||||
if parts[1] != "*" {
|
if parts[1] != "*" {
|
||||||
p64, err := strconv.ParseUint(parts[1], 10, 16)
|
p64, err := strconv.ParseUint(parts[1], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return protocolBoth, 0, false
|
return ProtocolBoth, 0, false
|
||||||
}
|
}
|
||||||
port = uint16(p64)
|
port = uint16(p64)
|
||||||
}
|
}
|
||||||
|
@ -194,7 +200,7 @@ func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatch
|
||||||
if db == nil {
|
if db == nil {
|
||||||
return nil, "failed to load GeoIP database"
|
return nil, "failed to load GeoIP database"
|
||||||
}
|
}
|
||||||
return &geoIPMatcher{db, country}, ""
|
return &geoipMatcher{db, country}, ""
|
||||||
}
|
}
|
||||||
if strings.Contains(addr, "/") {
|
if strings.Contains(addr, "/") {
|
||||||
// CIDR matcher
|
// CIDR matcher
|
||||||
|
|
|
@ -69,7 +69,7 @@ func TestCompile(t *testing.T) {
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
host HostInfo
|
host HostInfo
|
||||||
proto protocol
|
proto Protocol
|
||||||
port uint16
|
port uint16
|
||||||
wantOutbound int
|
wantOutbound int
|
||||||
wantIP net.IP
|
wantIP net.IP
|
||||||
|
@ -78,7 +78,7 @@ func TestCompile(t *testing.T) {
|
||||||
host: HostInfo{
|
host: HostInfo{
|
||||||
IPv4: net.ParseIP("1.2.3.4"),
|
IPv4: net.ParseIP("1.2.3.4"),
|
||||||
},
|
},
|
||||||
proto: protocolTCP,
|
proto: ProtocolTCP,
|
||||||
port: 1234,
|
port: 1234,
|
||||||
wantOutbound: ob1,
|
wantOutbound: ob1,
|
||||||
wantIP: nil,
|
wantIP: nil,
|
||||||
|
@ -87,7 +87,7 @@ func TestCompile(t *testing.T) {
|
||||||
host: HostInfo{
|
host: HostInfo{
|
||||||
IPv4: net.ParseIP("8.8.8.4"),
|
IPv4: net.ParseIP("8.8.8.4"),
|
||||||
},
|
},
|
||||||
proto: protocolUDP,
|
proto: ProtocolUDP,
|
||||||
port: 5353,
|
port: 5353,
|
||||||
wantOutbound: ob2,
|
wantOutbound: ob2,
|
||||||
wantIP: net.ParseIP("1.1.1.1"),
|
wantIP: net.ParseIP("1.1.1.1"),
|
||||||
|
@ -96,7 +96,7 @@ func TestCompile(t *testing.T) {
|
||||||
host: HostInfo{
|
host: HostInfo{
|
||||||
Name: "lean.delicious.com",
|
Name: "lean.delicious.com",
|
||||||
},
|
},
|
||||||
proto: protocolUDP,
|
proto: ProtocolUDP,
|
||||||
port: 443,
|
port: 443,
|
||||||
wantOutbound: ob3,
|
wantOutbound: ob3,
|
||||||
wantIP: nil,
|
wantIP: nil,
|
||||||
|
@ -105,7 +105,7 @@ func TestCompile(t *testing.T) {
|
||||||
host: HostInfo{
|
host: HostInfo{
|
||||||
IPv6: net.ParseIP("2606:4700::6810:85e5"),
|
IPv6: net.ParseIP("2606:4700::6810:85e5"),
|
||||||
},
|
},
|
||||||
proto: protocolTCP,
|
proto: ProtocolTCP,
|
||||||
port: 80,
|
port: 80,
|
||||||
wantOutbound: ob1,
|
wantOutbound: ob1,
|
||||||
wantIP: net.ParseIP("2606:4700::6810:85e6"),
|
wantIP: net.ParseIP("2606:4700::6810:85e6"),
|
||||||
|
@ -114,7 +114,7 @@ func TestCompile(t *testing.T) {
|
||||||
host: HostInfo{
|
host: HostInfo{
|
||||||
IPv6: net.ParseIP("2606:4700:0:0:0:0:0:1"),
|
IPv6: net.ParseIP("2606:4700:0:0:0:0:0:1"),
|
||||||
},
|
},
|
||||||
proto: protocolUDP,
|
proto: ProtocolUDP,
|
||||||
port: 8888,
|
port: 8888,
|
||||||
wantOutbound: ob2,
|
wantOutbound: ob2,
|
||||||
wantIP: nil,
|
wantIP: nil,
|
||||||
|
@ -123,7 +123,7 @@ func TestCompile(t *testing.T) {
|
||||||
host: HostInfo{
|
host: HostInfo{
|
||||||
Name: "www.v2ex.com",
|
Name: "www.v2ex.com",
|
||||||
},
|
},
|
||||||
proto: protocolUDP,
|
proto: ProtocolUDP,
|
||||||
port: 1234,
|
port: 1234,
|
||||||
wantOutbound: ob3,
|
wantOutbound: ob3,
|
||||||
wantIP: nil,
|
wantIP: nil,
|
||||||
|
@ -132,7 +132,7 @@ func TestCompile(t *testing.T) {
|
||||||
host: HostInfo{
|
host: HostInfo{
|
||||||
Name: "crap.v2ex.com",
|
Name: "crap.v2ex.com",
|
||||||
},
|
},
|
||||||
proto: protocolTCP,
|
proto: ProtocolTCP,
|
||||||
port: 80,
|
port: 80,
|
||||||
wantOutbound: ob1,
|
wantOutbound: ob1,
|
||||||
wantIP: net.ParseIP("2.2.2.2"),
|
wantIP: net.ParseIP("2.2.2.2"),
|
||||||
|
@ -141,7 +141,7 @@ func TestCompile(t *testing.T) {
|
||||||
host: HostInfo{
|
host: HostInfo{
|
||||||
IPv4: net.ParseIP("210.140.92.187"),
|
IPv4: net.ParseIP("210.140.92.187"),
|
||||||
},
|
},
|
||||||
proto: protocolTCP,
|
proto: ProtocolTCP,
|
||||||
port: 25,
|
port: 25,
|
||||||
wantOutbound: ob2,
|
wantOutbound: ob2,
|
||||||
wantIP: nil,
|
wantIP: nil,
|
||||||
|
|
|
@ -55,12 +55,12 @@ func deepMatchRune(str, pattern []rune) bool {
|
||||||
return len(str) == 0 && len(pattern) == 0
|
return len(str) == 0 && len(pattern) == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
type geoIPMatcher struct {
|
type geoipMatcher struct {
|
||||||
DB *geoip2.Reader
|
DB *geoip2.Reader
|
||||||
Country string // must be uppercase ISO 3166-1 alpha-2 code
|
Country string // must be uppercase ISO 3166-1 alpha-2 code
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *geoIPMatcher) Match(host HostInfo) bool {
|
func (m *geoipMatcher) Match(host HostInfo) bool {
|
||||||
if host.IPv4 != nil {
|
if host.IPv4 != nil {
|
||||||
record, err := m.DB.Country(host.IPv4)
|
record, err := m.DB.Country(host.IPv4)
|
||||||
if err == nil && record.Country.IsoCode == m.Country {
|
if err == nil && record.Country.IsoCode == m.Country {
|
||||||
|
|
|
@ -234,7 +234,7 @@ func Test_domainMatcher_Match(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_geoIPMatcher_Match(t *testing.T) {
|
func Test_geoipMatcher_Match(t *testing.T) {
|
||||||
db, err := geoip2.Open("GeoLite2-Country.mmdb")
|
db, err := geoip2.Open("GeoLite2-Country.mmdb")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
@ -298,7 +298,7 @@ func Test_geoIPMatcher_Match(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
m := &geoIPMatcher{
|
m := &geoipMatcher{
|
||||||
DB: tt.fields.DB,
|
DB: tt.fields.DB,
|
||||||
Country: tt.fields.Country,
|
Country: tt.fields.Country,
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
@ -71,11 +70,3 @@ func ParseTextRules(text string) ([]TextRule, error) {
|
||||||
}
|
}
|
||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseTextRulesFile(filename string) ([]TextRule, error) {
|
|
||||||
bs, err := os.ReadFile(filename)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ParseTextRules(string(bs))
|
|
||||||
}
|
|
||||||
|
|
61
extras/outbounds/acl_test.go
Normal file
61
extras/outbounds/acl_test.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package outbounds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestACLEngine(t *testing.T) {
|
||||||
|
ob1, ob2, ob3 := &mockPluggableOutbound{}, &mockPluggableOutbound{}, &mockPluggableOutbound{}
|
||||||
|
obs := []OutboundEntry{
|
||||||
|
{"ob1", ob1},
|
||||||
|
{"ob2", ob2},
|
||||||
|
{"ob3", ob3},
|
||||||
|
{"direct", ob2},
|
||||||
|
}
|
||||||
|
acl, err := NewACLEngineFromString(`
|
||||||
|
ob2(google.com,tcp)
|
||||||
|
ob3(youtube.com,udp)
|
||||||
|
ob1 (1.1.1.1/24,*,8.8.8.8)
|
||||||
|
Direct(cia.gov)
|
||||||
|
reJect(nsa.gov)
|
||||||
|
`, obs, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// No match, default, should be the first (ob1)
|
||||||
|
ob1.EXPECT().TCP(&AddrEx{Host: "example.com"}).Return(nil, nil).Once()
|
||||||
|
conn, err := acl.TCP(&AddrEx{Host: "example.com"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, conn)
|
||||||
|
|
||||||
|
// Match ob2
|
||||||
|
ob2.EXPECT().TCP(&AddrEx{Host: "google.com"}).Return(nil, nil).Once()
|
||||||
|
conn, err = acl.TCP(&AddrEx{Host: "google.com"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, conn)
|
||||||
|
|
||||||
|
// Match ob3
|
||||||
|
ob3.EXPECT().UDP(&AddrEx{Host: "youtube.com"}).Return(nil, nil).Once()
|
||||||
|
udpConn, err := acl.UDP(&AddrEx{Host: "youtube.com"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, udpConn)
|
||||||
|
|
||||||
|
// Match ob1 hijack IP
|
||||||
|
ob1.EXPECT().TCP(&AddrEx{Host: "8.8.8.8", ResolveInfo: &ResolveInfo{IPv4: net.ParseIP("8.8.8.8").To4()}}).Return(nil, nil).Once()
|
||||||
|
conn, err = acl.TCP(&AddrEx{ResolveInfo: &ResolveInfo{IPv4: net.ParseIP("1.1.1.22")}})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, conn)
|
||||||
|
|
||||||
|
// direct should be ob2 as we override it
|
||||||
|
ob2.EXPECT().TCP(&AddrEx{Host: "cia.gov"}).Return(nil, nil).Once()
|
||||||
|
conn, err = acl.TCP(&AddrEx{Host: "cia.gov"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, conn)
|
||||||
|
|
||||||
|
// reject
|
||||||
|
conn, err = acl.TCP(&AddrEx{Host: "nsa.gov"})
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, conn)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue