package acl

import (
	"fmt"
	"net"
	"strconv"
	"strings"

	"github.com/apernet/hysteria/extras/v2/outbounds/acl/v2geo"
	lru "github.com/hashicorp/golang-lru/v2"
)

// Protocol restricts a rule to TCP, UDP, or both.
type Protocol int

const (
	// ProtocolBoth matches both TCP and UDP.
	ProtocolBoth Protocol = iota
	// ProtocolTCP matches TCP only.
	ProtocolTCP
	// ProtocolUDP matches UDP only.
	ProtocolUDP
)

// Outbound is the constraint for the outbound type parameter. It accepts any
// type: the rule set only stores outbound values and hands them back from
// Match, it never inspects them.
type Outbound interface {
	any
}

// HostInfo describes the destination being matched: its host name and its
// IPv4/IPv6 addresses.
type HostInfo struct {
	Name string
	IPv4 net.IP
	IPv6 net.IP
}

// String formats h as "name|ipv4|ipv6". It is used to build cache keys.
func (h HostInfo) String() string {
	return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6)
}

// CompiledRuleSet matches destinations against a compiled rule list.
type CompiledRuleSet[O Outbound] interface {
	// Match returns the outbound and hijack address (nil if none) of the
	// first rule that matches host, proto and port. If no rule matches it
	// returns the zero value of O and nil.
	Match(host HostInfo, proto Protocol, port uint16) (O, net.IP)
}

// compiledRule is a single validated TextRule. StartPort..EndPort is an
// inclusive port range; both are 0 when the rule applies to every port (see
// parseProtoPort). HijackAddress is nil when the rule does not hijack.
type compiledRule[O Outbound] struct {
	Outbound      O
	HostMatcher   hostMatcher
	Protocol      Protocol
	StartPort     uint16
	EndPort       uint16
	HijackAddress net.IP
}

// Match reports whether r applies to a connection to host using proto on
// port.
func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool {
	if r.Protocol != ProtocolBoth && r.Protocol != proto {
		return false
	}
	// EndPort == 0 means "any port": parseProtoPort yields (0, 0) for a
	// wildcard port and otherwise guarantees StartPort <= EndPort. Testing
	// StartPort instead would make a range starting at 0 (e.g. "udp/0-1023")
	// match every port.
	if r.EndPort != 0 && (port < r.StartPort || port > r.EndPort) {
		return false
	}
	return r.HostMatcher.Match(host)
}

// matchResult is the cached outcome of a rule-set lookup.
type matchResult[O Outbound] struct {
	Outbound      O
	HijackAddress net.IP
}

// compiledRuleSetImpl is the CompiledRuleSet built by Compile. Rules are
// tried in order and the first match wins; every lookup result, including
// "no match", is memoized in Cache.
type compiledRuleSetImpl[O Outbound] struct {
	Rules []compiledRule[O]
	Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String() + "|proto|port"
}

// Match implements CompiledRuleSet.
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto Protocol, port uint16) (O, net.IP) {
	// Normalize the host name; compileHostMatcher lower-cases rule addresses too.
	host.Name = strings.ToLower(host.Name)
	// Rules can be restricted by protocol and port, so both belong in the
	// cache key; keying on the host alone would replay a result computed for
	// one protocol/port (say TCP/443) for a different one (say UDP/53).
	key := host.String() + "|" + strconv.Itoa(int(proto)) + "|" + strconv.FormatUint(uint64(port), 10)
	if cached, ok := s.Cache.Get(key); ok {
		return cached.Outbound, cached.HijackAddress
	}
	var result matchResult[O] // zero value means "no rule matched"
	for i := range s.Rules {
		if rule := &s.Rules[i]; rule.Match(host, proto, port) {
			result = matchResult[O]{rule.Outbound, rule.HijackAddress}
			break
		}
	}
	// Misses are cached too, so unmatched destinations don't rescan all rules.
	s.Cache.Add(key, result)
	return result.Outbound, result.HijackAddress
}

// CompilationError reports why the rule on line LineNum could not be compiled.
type CompilationError struct {
	LineNum int
	Message string
}

func (e *CompilationError) Error() string {
	return fmt.Sprintf("error at line %d: %s", e.LineNum, e.Message)
}

// GeoLoader supplies GeoIP and GeoSite databases on demand. compileHostMatcher
// looks entries up with lower-case keys (GeoIP by country code, GeoSite by
// list name).
type GeoLoader interface {
	LoadGeoIP() (map[string]*v2geo.GeoIP, error)
	LoadGeoSite() (map[string]*v2geo.GeoSite, error)
}

// Compile compiles TextRules into a CompiledRuleSet.
// Names in the outbounds map MUST be in all lower case; each rule's outbound
// name is lower-cased before lookup.
// We want on-demand loading of GeoIP/GeoSite databases, so instead of passing the
// databases directly, we use a GeoLoader interface to load them only when needed
// by at least one rule.
// cacheSize is the capacity of the match-result LRU cache; lru.New rejects
// non-positive sizes, and that error is returned wrapped.
//
// Any invalid rule aborts compilation with a *CompilationError naming the
// offending line.
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
	cacheSize int, geoLoader GeoLoader,
) (CompiledRuleSet[O], error) {
	compiledRules := make([]compiledRule[O], len(rules))
	for i, rule := range rules {
		outbound, ok := outbounds[strings.ToLower(rule.Outbound)]
		if !ok {
			return nil, &CompilationError{rule.LineNum, fmt.Sprintf("outbound %s not found", rule.Outbound)}
		}
		hm, errStr := compileHostMatcher(rule.Address, geoLoader)
		if errStr != "" {
			return nil, &CompilationError{rule.LineNum, errStr}
		}
		proto, startPort, endPort, ok := parseProtoPort(rule.ProtoPort)
		if !ok {
			return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)}
		}
		var hijackAddress net.IP
		if rule.HijackAddress != "" {
			hijackAddress = net.ParseIP(rule.HijackAddress)
			if hijackAddress == nil {
				return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)}
			}
		}
		compiledRules[i] = compiledRule[O]{outbound, hm, proto, startPort, endPort, hijackAddress}
	}
	cache, err := lru.New[string, matchResult[O]](cacheSize)
	if err != nil {
		return nil, fmt.Errorf("creating match cache (size %d): %w", cacheSize, err)
	}
	return &compiledRuleSetImpl[O]{compiledRules, cache}, nil
}

// parseProtoPort parses the protocol and port from a protoPort string.
// protoPort must be in one of the following formats:
//
//	proto/port
//	proto/*
//	proto
//	*/port
//	*/*
//	*
//	[empty] (same as *)
//
// proto must be either "tcp" or "udp", case-insensitive.
// port is either a single port number or an inclusive range "start-end"
// with start <= end.
func parseProtoPort(protoPort string) (Protocol, uint16, uint16, bool) { protoPort = strings.ToLower(protoPort) if protoPort == "" || protoPort == "*" || protoPort == "*/*" { return ProtocolBoth, 0, 0, true } parts := strings.SplitN(protoPort, "/", 2) if len(parts) == 1 { // No port, only protocol switch parts[0] { case "tcp": return ProtocolTCP, 0, 0, true case "udp": return ProtocolUDP, 0, 0, true default: return ProtocolBoth, 0, 0, false } } else { // Both protocol and port var proto Protocol var startPort, endPort uint16 switch parts[0] { case "tcp": proto = ProtocolTCP case "udp": proto = ProtocolUDP case "*": proto = ProtocolBoth default: return ProtocolBoth, 0, 0, false } if parts[1] != "*" { // We allow either a single port or a range (e.g. "1000-2000") ports := strings.SplitN(strings.TrimSpace(parts[1]), "-", 2) if len(ports) == 1 { p64, err := strconv.ParseUint(parts[1], 10, 16) if err != nil { return ProtocolBoth, 0, 0, false } startPort = uint16(p64) endPort = startPort } else { p64, err := strconv.ParseUint(ports[0], 10, 16) if err != nil { return ProtocolBoth, 0, 0, false } startPort = uint16(p64) p64, err = strconv.ParseUint(ports[1], 10, 16) if err != nil { return ProtocolBoth, 0, 0, false } endPort = uint16(p64) if startPort > endPort { return ProtocolBoth, 0, 0, false } } } return proto, startPort, endPort, true } } func compileHostMatcher(addr string, geoLoader GeoLoader) (hostMatcher, string) { addr = strings.ToLower(addr) // Normalize to lower case if addr == "*" || addr == "all" { // Match all hosts return &allMatcher{}, "" } if strings.HasPrefix(addr, "geoip:") { // GeoIP matcher country := addr[6:] if len(country) == 0 { return nil, "empty GeoIP country code" } gMap, err := geoLoader.LoadGeoIP() if err != nil { return nil, err.Error() } list, ok := gMap[country] if !ok || list == nil { return nil, fmt.Sprintf("GeoIP country code %s not found", country) } m, err := newGeoIPMatcher(list) if err != nil { return nil, err.Error() } return m, "" } 
if strings.HasPrefix(addr, "geosite:") { // GeoSite matcher name, attrs := parseGeoSiteName(addr[8:]) if len(name) == 0 { return nil, "empty GeoSite name" } gMap, err := geoLoader.LoadGeoSite() if err != nil { return nil, err.Error() } list, ok := gMap[name] if !ok || list == nil { return nil, fmt.Sprintf("GeoSite name %s not found", name) } m, err := newGeositeMatcher(list, attrs) if err != nil { return nil, err.Error() } return m, "" } if strings.HasPrefix(addr, "suffix:") { // Domain suffix matcher suffix := addr[7:] if len(suffix) == 0 { return nil, "empty domain suffix" } return &domainMatcher{ Pattern: suffix, Mode: domainMatchSuffix, }, "" } if strings.Contains(addr, "/") { // CIDR matcher _, ipnet, err := net.ParseCIDR(addr) if err != nil { return nil, fmt.Sprintf("invalid CIDR address: %s", addr) } return &cidrMatcher{ipnet}, "" } if ip := net.ParseIP(addr); ip != nil { // Single IP matcher return &ipMatcher{ip}, "" } if strings.Contains(addr, "*") { // Wildcard domain matcher return &domainMatcher{ Pattern: addr, Mode: domainMatchWildcard, }, "" } // Nothing else matched, treat it as a non-wildcard domain return &domainMatcher{ Pattern: addr, Mode: domainMatchExact, }, "" } func parseGeoSiteName(s string) (string, []string) { parts := strings.Split(s, "@") base := strings.TrimSpace(parts[0]) attrs := parts[1:] for i := range attrs { attrs[i] = strings.TrimSpace(attrs[i]) } return base, attrs }