mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 20:47:38 +03:00
303 lines
7.6 KiB
Go
303 lines
7.6 KiB
Go
package acl
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/apernet/hysteria/extras/outbounds/acl/v2geo"
|
|
|
|
lru "github.com/hashicorp/golang-lru/v2"
|
|
)
|
|
|
|
// Protocol selects which transport protocol(s) a rule applies to.
type Protocol int

// Protocol values. ProtocolBoth is the zero value, so a rule whose Protocol
// is left unset applies to both TCP and UDP.
const (
	ProtocolBoth Protocol = 0 // rule applies to TCP and UDP
	ProtocolTCP  Protocol = 1 // rule applies to TCP only
	ProtocolUDP  Protocol = 2 // rule applies to UDP only
)
|
|
|
|
// Outbound is the type constraint for the per-rule outbound value carried by
// a rule set. It imposes no requirements: any type may be used as O.
type Outbound interface{ any }
|
|
|
|
type HostInfo struct {
|
|
Name string
|
|
IPv4 net.IP
|
|
IPv6 net.IP
|
|
}
|
|
|
|
func (h HostInfo) String() string {
|
|
return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6)
|
|
}
|
|
|
|
// CompiledRuleSet is a compiled, ready-to-evaluate ACL rule set, as produced
// by Compile.
type CompiledRuleSet[O Outbound] interface {
	// Match returns the outbound selected for a connection to host using
	// proto on port, together with the matching rule's hijack address (nil
	// when the rule has none). The implementation returned by Compile yields
	// the zero O and a nil IP when no rule matches.
	Match(host HostInfo, proto Protocol, port uint16) (O, net.IP)
}
|
|
|
|
type compiledRule[O Outbound] struct {
|
|
Outbound O
|
|
HostMatcher hostMatcher
|
|
Protocol Protocol
|
|
StartPort uint16
|
|
EndPort uint16
|
|
HijackAddress net.IP
|
|
}
|
|
|
|
func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool {
|
|
if r.Protocol != ProtocolBoth && r.Protocol != proto {
|
|
return false
|
|
}
|
|
if r.StartPort != 0 && (port < r.StartPort || port > r.EndPort) {
|
|
return false
|
|
}
|
|
return r.HostMatcher.Match(host)
|
|
}
|
|
|
|
// matchResult is a cached outcome of compiledRuleSetImpl.Match. A zero
// Outbound together with a nil HijackAddress records "no rule matched".
type matchResult[O Outbound] struct {
	Outbound      O
	HijackAddress net.IP
}
|
|
|
|
type compiledRuleSetImpl[O Outbound] struct {
|
|
Rules []compiledRule[O]
|
|
Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String()
|
|
}
|
|
|
|
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto Protocol, port uint16) (O, net.IP) {
|
|
host.Name = strings.ToLower(host.Name) // Normalize host name to lower case
|
|
key := host.String()
|
|
if result, ok := s.Cache.Get(key); ok {
|
|
return result.Outbound, result.HijackAddress
|
|
}
|
|
for _, rule := range s.Rules {
|
|
if rule.Match(host, proto, port) {
|
|
result := matchResult[O]{rule.Outbound, rule.HijackAddress}
|
|
s.Cache.Add(key, result)
|
|
return result.Outbound, result.HijackAddress
|
|
}
|
|
}
|
|
// No match should also be cached
|
|
var zero O
|
|
s.Cache.Add(key, matchResult[O]{zero, nil})
|
|
return zero, nil
|
|
}
|
|
|
|
type CompilationError struct {
|
|
LineNum int
|
|
Message string
|
|
}
|
|
|
|
func (e *CompilationError) Error() string {
|
|
return fmt.Sprintf("error at line %d: %s", e.LineNum, e.Message)
|
|
}
|
|
|
|
// GeoLoader supplies the GeoIP and GeoSite databases used by "geoip:" and
// "geosite:" rules. Compile calls these methods only while compiling such a
// rule, and calls them again for every such rule. NOTE(review): implementations
// that load from disk presumably should cache the result; confirm.
// A returned error aborts compilation of the rule.
type GeoLoader interface {
	// LoadGeoIP returns the GeoIP lists. Rules look entries up by their
	// lower-cased country code (e.g. "geoip:CN" looks up "cn").
	LoadGeoIP() (map[string]*v2geo.GeoIP, error)
	// LoadGeoSite returns the GeoSite lists. Rules look entries up by their
	// lower-cased, whitespace-trimmed site name with any "@attr" parts removed.
	LoadGeoSite() (map[string]*v2geo.GeoSite, error)
}
|
|
|
|
// Compile compiles TextRules into a CompiledRuleSet.
// Names in the outbounds map MUST be in all lower case; each rule's Outbound
// name is lower-cased before lookup.
// We want on-demand loading of GeoIP/GeoSite databases, so instead of passing the
// databases directly, we use a GeoLoader interface to load them only when needed
// by at least one rule.
//
// Rules keep their input order, which is the order they are evaluated in at
// match time. Any invalid rule aborts compilation with a *CompilationError
// carrying that rule's LineNum. cacheSize is passed to lru.New, and an error
// from it is returned unchanged.
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
	cacheSize int, geoLoader GeoLoader,
) (CompiledRuleSet[O], error) {
	compiled := make([]compiledRule[O], 0, len(rules))
	for _, rule := range rules {
		// lineErr attaches the current rule's line number to a message.
		lineErr := func(msg string) error {
			return &CompilationError{LineNum: rule.LineNum, Message: msg}
		}

		outbound, found := outbounds[strings.ToLower(rule.Outbound)]
		if !found {
			return nil, lineErr(fmt.Sprintf("outbound %s not found", rule.Outbound))
		}

		matcher, errMsg := compileHostMatcher(rule.Address, geoLoader)
		if errMsg != "" {
			return nil, lineErr(errMsg)
		}

		proto, startPort, endPort, valid := parseProtoPort(rule.ProtoPort)
		if !valid {
			return nil, lineErr(fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort))
		}

		// An empty HijackAddress means "no hijack" and leaves the IP nil.
		var hijack net.IP
		if rule.HijackAddress != "" {
			if hijack = net.ParseIP(rule.HijackAddress); hijack == nil {
				return nil, lineErr(fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress))
			}
		}

		compiled = append(compiled, compiledRule[O]{
			Outbound:      outbound,
			HostMatcher:   matcher,
			Protocol:      proto,
			StartPort:     startPort,
			EndPort:       endPort,
			HijackAddress: hijack,
		})
	}

	cache, err := lru.New[string, matchResult[O]](cacheSize)
	if err != nil {
		return nil, err
	}
	return &compiledRuleSetImpl[O]{Rules: compiled, Cache: cache}, nil
}
|
|
|
|
// parseProtoPort parses the protocol and port from a protoPort string.
// protoPort must be in one of the following formats:
//
//	proto/port
//	proto/*
//	proto
//	*/port
//	*/*
//	*
//	[empty] (same as *)
//
// proto must be either "tcp" or "udp", case-insensitive. port is either a
// single port ("443") or an inclusive range ("1000-2000"). Whitespace around
// the port part and around each range bound is ignored.
//
// It returns the protocol, the start and end port (0, 0 meaning "any port"),
// and whether the input was valid. On failure the other results are
// ProtocolBoth, 0, 0.
func parseProtoPort(protoPort string) (Protocol, uint16, uint16, bool) {
	protoPort = strings.ToLower(protoPort)
	if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
		return ProtocolBoth, 0, 0, true
	}

	protoStr, portStr, hasPort := strings.Cut(protoPort, "/")
	if !hasPort {
		// Protocol only (a lone "*" was handled above).
		switch protoStr {
		case "tcp":
			return ProtocolTCP, 0, 0, true
		case "udp":
			return ProtocolUDP, 0, 0, true
		default:
			return ProtocolBoth, 0, 0, false
		}
	}

	var proto Protocol
	switch protoStr {
	case "tcp":
		proto = ProtocolTCP
	case "udp":
		proto = ProtocolUDP
	case "*":
		proto = ProtocolBoth
	default:
		return ProtocolBoth, 0, 0, false
	}

	portStr = strings.TrimSpace(portStr)
	if portStr == "*" {
		return proto, 0, 0, true // any port
	}

	// Each bound is trimmed on its own, so "443", " 443" and "1000 - 2000" are
	// all accepted. ParseUint with bitSize 16 rejects values above 65535.
	parsePort := func(s string) (uint16, bool) {
		p, err := strconv.ParseUint(strings.TrimSpace(s), 10, 16)
		return uint16(p), err == nil
	}
	startStr, endStr, isRange := strings.Cut(portStr, "-")
	startPort, ok := parsePort(startStr)
	if !ok {
		return ProtocolBoth, 0, 0, false
	}
	endPort := startPort
	if isRange {
		endPort, ok = parsePort(endStr)
		if !ok || startPort > endPort {
			return ProtocolBoth, 0, 0, false
		}
	}
	return proto, startPort, endPort, true
}
|
|
|
|
// compileHostMatcher builds the hostMatcher for a rule's address field. The
// address is lower-cased first and then interpreted, in this order, as:
//
//	"*" or "all"       every host
//	"geoip:<code>"     GeoIP country list (loaded via geoLoader)
//	"geosite:<name>"   GeoSite list, optionally filtered by "@attr" (via geoLoader)
//	"suffix:<domain>"  domain suffix
//	contains "/"       CIDR block
//	parses as an IP    single IP
//	contains "*"       wildcard domain
//	otherwise          exact domain
//
// On success it returns the matcher and an empty string. On failure it returns
// nil and a human-readable message, which Compile wraps in a CompilationError.
// geoLoader is consulted only for geoip:/geosite: addresses and may be nil
// when no such rule is present.
func compileHostMatcher(addr string, geoLoader GeoLoader) (hostMatcher, string) {
	addr = strings.ToLower(addr) // Normalize to lower case
	if addr == "*" || addr == "all" {
		// Match all hosts
		return &allMatcher{}, ""
	}
	if strings.HasPrefix(addr, "geoip:") {
		// GeoIP matcher
		country := strings.TrimPrefix(addr, "geoip:")
		if len(country) == 0 {
			return nil, "empty GeoIP country code"
		}
		// Report a missing loader as a compilation error rather than
		// panicking on a nil interface call.
		if geoLoader == nil {
			return nil, "GeoIP database not available: no GeoLoader provided"
		}
		gMap, err := geoLoader.LoadGeoIP()
		if err != nil {
			return nil, err.Error()
		}
		list, ok := gMap[country]
		if !ok || list == nil {
			return nil, fmt.Sprintf("GeoIP country code %s not found", country)
		}
		m, err := newGeoIPMatcher(list)
		if err != nil {
			return nil, err.Error()
		}
		return m, ""
	}
	if strings.HasPrefix(addr, "geosite:") {
		// GeoSite matcher
		name, attrs := parseGeoSiteName(strings.TrimPrefix(addr, "geosite:"))
		if len(name) == 0 {
			return nil, "empty GeoSite name"
		}
		// Same guard as for GeoIP: a missing loader is a configuration error.
		if geoLoader == nil {
			return nil, "GeoSite database not available: no GeoLoader provided"
		}
		gMap, err := geoLoader.LoadGeoSite()
		if err != nil {
			return nil, err.Error()
		}
		list, ok := gMap[name]
		if !ok || list == nil {
			return nil, fmt.Sprintf("GeoSite name %s not found", name)
		}
		m, err := newGeositeMatcher(list, attrs)
		if err != nil {
			return nil, err.Error()
		}
		return m, ""
	}
	if strings.HasPrefix(addr, "suffix:") {
		// Domain suffix matcher
		suffix := strings.TrimPrefix(addr, "suffix:")
		if len(suffix) == 0 {
			return nil, "empty domain suffix"
		}
		return &domainMatcher{
			Pattern: suffix,
			Mode:    domainMatchSuffix,
		}, ""
	}
	if strings.Contains(addr, "/") {
		// CIDR matcher
		_, ipnet, err := net.ParseCIDR(addr)
		if err != nil {
			return nil, fmt.Sprintf("invalid CIDR address: %s", addr)
		}
		return &cidrMatcher{ipnet}, ""
	}
	if ip := net.ParseIP(addr); ip != nil {
		// Single IP matcher
		return &ipMatcher{ip}, ""
	}
	if strings.Contains(addr, "*") {
		// Wildcard domain matcher
		return &domainMatcher{
			Pattern: addr,
			Mode:    domainMatchWildcard,
		}, ""
	}
	// Nothing else matched, treat it as a non-wildcard domain
	return &domainMatcher{
		Pattern: addr,
		Mode:    domainMatchExact,
	}, ""
}
|
|
|
|
// parseGeoSiteName splits a GeoSite reference of the form "name@attr1@attr2"
// into its base name and attribute list. Every part is whitespace-trimmed.
// With no "@", the attribute list is empty (length 0, not nil).
func parseGeoSiteName(s string) (string, []string) {
	fields := strings.Split(s, "@")
	for i, f := range fields {
		fields[i] = strings.TrimSpace(f)
	}
	// strings.Split always returns at least one element, so fields[0] is safe.
	return fields[0], fields[1:]
}
|