feat: ACL

This commit is contained in:
Toby 2023-08-14 19:00:56 -07:00
parent 6fa958815b
commit a7d74a9ec1
12 changed files with 380 additions and 61 deletions

122
extras/outbounds/acl.go Normal file
View file

@ -0,0 +1,122 @@
package outbounds
import (
"errors"
"net"
"os"
"strings"
"github.com/apernet/hysteria/extras/outbounds/acl"
"github.com/oschwald/geoip2-golang"
)
const (
aclCacheSize = 1024
)
var errRejected = errors.New("rejected")
// aclEngine is a PluggableOutbound that dispatches connections to different
// outbounds based on ACL rules.
// There are 3 built-in outbounds:
// - direct: directOutbound, auto mode
// - reject: reject the connection
// - default: first outbound in the list, or if the list is empty, equal to direct
// If the user-defined outbounds contain any of the above names, they will
// override the built-in outbounds.
type aclEngine struct {
RuleSet acl.CompiledRuleSet[PluggableOutbound]
Default PluggableOutbound
}
type OutboundEntry struct {
Name string
Outbound PluggableOutbound
}
func NewACLEngineFromString(rules string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) {
trs, err := acl.ParseTextRules(rules)
if err != nil {
return nil, err
}
obMap := outboundsToMap(outbounds)
rs, err := acl.Compile[PluggableOutbound](trs, obMap, aclCacheSize, geoipFunc)
if err != nil {
return nil, err
}
return &aclEngine{rs, obMap["default"]}, nil
}
func NewACLEngineFromFile(filename string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) {
bs, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
return NewACLEngineFromString(string(bs), outbounds, geoipFunc)
}
func outboundsToMap(outbounds []OutboundEntry) map[string]PluggableOutbound {
obMap := make(map[string]PluggableOutbound)
for _, ob := range outbounds {
obMap[strings.ToLower(ob.Name)] = ob.Outbound
}
// Add built-in outbounds if not overridden
if _, ok := obMap["direct"]; !ok {
obMap["direct"] = NewDirectOutboundSimple(DirectOutboundModeAuto)
}
if _, ok := obMap["reject"]; !ok {
obMap["reject"] = &aclRejectOutbound{}
}
if _, ok := obMap["default"]; !ok {
if len(outbounds) > 0 {
obMap["default"] = outbounds[0].Outbound
} else {
obMap["default"] = obMap["direct"]
}
}
return obMap
}
func (a *aclEngine) handle(reqAddr *AddrEx, proto acl.Protocol) PluggableOutbound {
hostInfo := acl.HostInfo{Name: reqAddr.Host}
if reqAddr.ResolveInfo != nil {
hostInfo.IPv4 = reqAddr.ResolveInfo.IPv4
hostInfo.IPv6 = reqAddr.ResolveInfo.IPv6
}
ob, hijackIP := a.RuleSet.Match(hostInfo, proto, reqAddr.Port)
if ob == nil {
// No match, use default outbound
return a.Default
}
if hijackIP != nil {
// We must rewrite both Host & ResolveInfo,
// as some outbounds only care about Host.
reqAddr.Host = hijackIP.String()
if ip4 := hijackIP.To4(); ip4 != nil {
reqAddr.ResolveInfo = &ResolveInfo{IPv4: ip4}
} else {
reqAddr.ResolveInfo = &ResolveInfo{IPv6: hijackIP}
}
}
return ob
}
func (a *aclEngine) TCP(reqAddr *AddrEx) (net.Conn, error) {
ob := a.handle(reqAddr, acl.ProtocolTCP)
return ob.TCP(reqAddr)
}
func (a *aclEngine) UDP(reqAddr *AddrEx) (UDPConn, error) {
ob := a.handle(reqAddr, acl.ProtocolUDP)
return ob.UDP(reqAddr)
}
type aclRejectOutbound struct{}
func (a *aclRejectOutbound) TCP(reqAddr *AddrEx) (net.Conn, error) {
return nil, errRejected
}
func (a *aclRejectOutbound) UDP(reqAddr *AddrEx) (UDPConn, error) {
return nil, errRejected
}

View file

@ -10,12 +10,12 @@ import (
"github.com/oschwald/geoip2-golang"
)
type protocol int
type Protocol int
const (
protocolBoth protocol = iota
protocolTCP
protocolUDP
ProtocolBoth Protocol = iota
ProtocolTCP
ProtocolUDP
)
type Outbound interface {
@ -33,19 +33,19 @@ func (h HostInfo) String() string {
}
type CompiledRuleSet[O Outbound] interface {
Match(host HostInfo, proto protocol, port uint16) (O, net.IP)
Match(host HostInfo, proto Protocol, port uint16) (O, net.IP)
}
type compiledRule[O Outbound] struct {
Outbound O
HostMatcher hostMatcher
Protocol protocol
Protocol Protocol
Port uint16
HijackAddress net.IP
}
func (r *compiledRule[O]) Match(host HostInfo, proto protocol, port uint16) bool {
if r.Protocol != protocolBoth && r.Protocol != proto {
func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool {
if r.Protocol != ProtocolBoth && r.Protocol != proto {
return false
}
if r.Port != 0 && r.Port != port {
@ -64,7 +64,7 @@ type compiledRuleSetImpl[O Outbound] struct {
Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String()
}
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto protocol, port uint16) (O, net.IP) {
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto Protocol, port uint16) (O, net.IP) {
host.Name = strings.ToLower(host.Name) // Normalize host name to lower case
key := host.String()
if result, ok := s.Cache.Get(key); ok {
@ -92,12 +92,18 @@ func (e *CompilationError) Error() string {
return fmt.Sprintf("error at line %d: %s", e.LineNum, e.Message)
}
// Compile compiles TextRules into a CompiledRuleSet.
// Names in the outbounds map MUST be in all lower case.
// geoipFunc is a function that returns the GeoIP database needed by the GeoIP matcher.
// It will be called every time a GeoIP matcher is used during compilation, but won't
// be called if there is no GeoIP rule. We use a function here so that database loading
// is on-demand (only required if used by rules).
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
cacheSize int, geoipFunc func() *geoip2.Reader,
) (CompiledRuleSet[O], error) {
compiledRules := make([]compiledRule[O], len(rules))
for i, rule := range rules {
outbound, ok := outbounds[rule.Outbound]
outbound, ok := outbounds[strings.ToLower(rule.Outbound)]
if !ok {
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("outbound %s not found", rule.Outbound)}
}
@ -137,40 +143,40 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
// [empty] (same as *)
//
// proto must be either "tcp" or "udp", case-insensitive.
func parseProtoPort(protoPort string) (protocol, uint16, bool) {
func parseProtoPort(protoPort string) (Protocol, uint16, bool) {
protoPort = strings.ToLower(protoPort)
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
return protocolBoth, 0, true
return ProtocolBoth, 0, true
}
parts := strings.SplitN(protoPort, "/", 2)
if len(parts) == 1 {
// No port, only protocol
switch parts[0] {
case "tcp":
return protocolTCP, 0, true
return ProtocolTCP, 0, true
case "udp":
return protocolUDP, 0, true
return ProtocolUDP, 0, true
default:
return protocolBoth, 0, false
return ProtocolBoth, 0, false
}
} else {
// Both protocol and port
var proto protocol
var proto Protocol
var port uint16
switch parts[0] {
case "tcp":
proto = protocolTCP
proto = ProtocolTCP
case "udp":
proto = protocolUDP
proto = ProtocolUDP
case "*":
proto = protocolBoth
proto = ProtocolBoth
default:
return protocolBoth, 0, false
return ProtocolBoth, 0, false
}
if parts[1] != "*" {
p64, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return protocolBoth, 0, false
return ProtocolBoth, 0, false
}
port = uint16(p64)
}
@ -194,7 +200,7 @@ func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatch
if db == nil {
return nil, "failed to load GeoIP database"
}
return &geoIPMatcher{db, country}, ""
return &geoipMatcher{db, country}, ""
}
if strings.Contains(addr, "/") {
// CIDR matcher

View file

@ -69,7 +69,7 @@ func TestCompile(t *testing.T) {
tests := []struct {
host HostInfo
proto protocol
proto Protocol
port uint16
wantOutbound int
wantIP net.IP
@ -78,7 +78,7 @@ func TestCompile(t *testing.T) {
host: HostInfo{
IPv4: net.ParseIP("1.2.3.4"),
},
proto: protocolTCP,
proto: ProtocolTCP,
port: 1234,
wantOutbound: ob1,
wantIP: nil,
@ -87,7 +87,7 @@ func TestCompile(t *testing.T) {
host: HostInfo{
IPv4: net.ParseIP("8.8.8.4"),
},
proto: protocolUDP,
proto: ProtocolUDP,
port: 5353,
wantOutbound: ob2,
wantIP: net.ParseIP("1.1.1.1"),
@ -96,7 +96,7 @@ func TestCompile(t *testing.T) {
host: HostInfo{
Name: "lean.delicious.com",
},
proto: protocolUDP,
proto: ProtocolUDP,
port: 443,
wantOutbound: ob3,
wantIP: nil,
@ -105,7 +105,7 @@ func TestCompile(t *testing.T) {
host: HostInfo{
IPv6: net.ParseIP("2606:4700::6810:85e5"),
},
proto: protocolTCP,
proto: ProtocolTCP,
port: 80,
wantOutbound: ob1,
wantIP: net.ParseIP("2606:4700::6810:85e6"),
@ -114,7 +114,7 @@ func TestCompile(t *testing.T) {
host: HostInfo{
IPv6: net.ParseIP("2606:4700:0:0:0:0:0:1"),
},
proto: protocolUDP,
proto: ProtocolUDP,
port: 8888,
wantOutbound: ob2,
wantIP: nil,
@ -123,7 +123,7 @@ func TestCompile(t *testing.T) {
host: HostInfo{
Name: "www.v2ex.com",
},
proto: protocolUDP,
proto: ProtocolUDP,
port: 1234,
wantOutbound: ob3,
wantIP: nil,
@ -132,7 +132,7 @@ func TestCompile(t *testing.T) {
host: HostInfo{
Name: "crap.v2ex.com",
},
proto: protocolTCP,
proto: ProtocolTCP,
port: 80,
wantOutbound: ob1,
wantIP: net.ParseIP("2.2.2.2"),
@ -141,7 +141,7 @@ func TestCompile(t *testing.T) {
host: HostInfo{
IPv4: net.ParseIP("210.140.92.187"),
},
proto: protocolTCP,
proto: ProtocolTCP,
port: 25,
wantOutbound: ob2,
wantIP: nil,

View file

@ -55,12 +55,12 @@ func deepMatchRune(str, pattern []rune) bool {
return len(str) == 0 && len(pattern) == 0
}
type geoIPMatcher struct {
type geoipMatcher struct {
DB *geoip2.Reader
Country string // must be uppercase ISO 3166-1 alpha-2 code
}
func (m *geoIPMatcher) Match(host HostInfo) bool {
func (m *geoipMatcher) Match(host HostInfo) bool {
if host.IPv4 != nil {
record, err := m.DB.Country(host.IPv4)
if err == nil && record.Country.IsoCode == m.Country {

View file

@ -234,7 +234,7 @@ func Test_domainMatcher_Match(t *testing.T) {
}
}
func Test_geoIPMatcher_Match(t *testing.T) {
func Test_geoipMatcher_Match(t *testing.T) {
db, err := geoip2.Open("GeoLite2-Country.mmdb")
assert.NoError(t, err)
defer db.Close()
@ -298,7 +298,7 @@ func Test_geoIPMatcher_Match(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &geoIPMatcher{
m := &geoipMatcher{
DB: tt.fields.DB,
Country: tt.fields.Country,
}

View file

@ -2,7 +2,6 @@ package acl
import (
"fmt"
"os"
"regexp"
"strings"
)
@ -71,11 +70,3 @@ func ParseTextRules(text string) ([]TextRule, error) {
}
return rules, nil
}
func ParseTextRulesFile(filename string) ([]TextRule, error) {
bs, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
return ParseTextRules(string(bs))
}

View file

@ -0,0 +1,61 @@
package outbounds
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestACLEngine(t *testing.T) {
ob1, ob2, ob3 := &mockPluggableOutbound{}, &mockPluggableOutbound{}, &mockPluggableOutbound{}
obs := []OutboundEntry{
{"ob1", ob1},
{"ob2", ob2},
{"ob3", ob3},
{"direct", ob2},
}
acl, err := NewACLEngineFromString(`
ob2(google.com,tcp)
ob3(youtube.com,udp)
ob1 (1.1.1.1/24,*,8.8.8.8)
Direct(cia.gov)
reJect(nsa.gov)
`, obs, nil)
assert.NoError(t, err)
// No match, default, should be the first (ob1)
ob1.EXPECT().TCP(&AddrEx{Host: "example.com"}).Return(nil, nil).Once()
conn, err := acl.TCP(&AddrEx{Host: "example.com"})
assert.NoError(t, err)
assert.Nil(t, conn)
// Match ob2
ob2.EXPECT().TCP(&AddrEx{Host: "google.com"}).Return(nil, nil).Once()
conn, err = acl.TCP(&AddrEx{Host: "google.com"})
assert.NoError(t, err)
assert.Nil(t, conn)
// Match ob3
ob3.EXPECT().UDP(&AddrEx{Host: "youtube.com"}).Return(nil, nil).Once()
udpConn, err := acl.UDP(&AddrEx{Host: "youtube.com"})
assert.NoError(t, err)
assert.Nil(t, udpConn)
// Match ob1 hijack IP
ob1.EXPECT().TCP(&AddrEx{Host: "8.8.8.8", ResolveInfo: &ResolveInfo{IPv4: net.ParseIP("8.8.8.8").To4()}}).Return(nil, nil).Once()
conn, err = acl.TCP(&AddrEx{ResolveInfo: &ResolveInfo{IPv4: net.ParseIP("1.1.1.22")}})
assert.NoError(t, err)
assert.Nil(t, conn)
// direct should be ob2 as we override it
ob2.EXPECT().TCP(&AddrEx{Host: "cia.gov"}).Return(nil, nil).Once()
conn, err = acl.TCP(&AddrEx{Host: "cia.gov"})
assert.NoError(t, err)
assert.Nil(t, conn)
// reject
conn, err = acl.TCP(&AddrEx{Host: "nsa.gov"})
assert.Error(t, err)
assert.Nil(t, conn)
}