diff --git a/extras/outbounds/acl/compile.go b/extras/outbounds/acl/compile.go index ebf5207..d27c498 100644 --- a/extras/outbounds/acl/compile.go +++ b/extras/outbounds/acl/compile.go @@ -41,7 +41,8 @@ type compiledRule[O Outbound] struct { Outbound O HostMatcher hostMatcher Protocol Protocol - Port uint16 + StartPort uint16 + EndPoint uint16 HijackAddress net.IP } @@ -49,7 +50,7 @@ func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool if r.Protocol != ProtocolBoth && r.Protocol != proto { return false } - if r.Port != 0 && r.Port != port { + if r.StartPort != 0 && (port < r.StartPort || port > r.EndPoint) { return false } return r.HostMatcher.Match(host) @@ -107,11 +108,6 @@ type GeoLoader interface { func Compile[O Outbound](rules []TextRule, outbounds map[string]O, cacheSize int, geoLoader GeoLoader, ) (CompiledRuleSet[O], error) { - for _, rule := range rules { - if extra := splitPortRangeRules(&rule); extra != nil { - rules = append(rules, extra...) - } - } compiledRules := make([]compiledRule[O], len(rules)) for i, rule := range rules { outbound, ok := outbounds[strings.ToLower(rule.Outbound)] @@ -122,7 +118,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O, if errStr != "" { return nil, &CompilationError{rule.LineNum, errStr} } - proto, port, ok := parseProtoPort(rule.ProtoPort) + proto, startPort, endPort, ok := parseProtoPort(rule.ProtoPort) if !ok { return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)} } @@ -133,7 +129,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O, return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)} } } - compiledRules[i] = compiledRule[O]{outbound, hm, proto, port, hijackAddress} + compiledRules[i] = compiledRule[O]{outbound, hm, proto, startPort, endPort, hijackAddress} } cache, err := lru.New[string, matchResult[O]](cacheSize) if err != nil { @@ -154,26 +150,26 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O, // [empty] (same as *) // // proto must be either "tcp" or "udp", case-insensitive. -func parseProtoPort(protoPort string) (Protocol, uint16, bool) { +func parseProtoPort(protoPort string) (Protocol, uint16, uint16, bool) { protoPort = strings.ToLower(protoPort) if protoPort == "" || protoPort == "*" || protoPort == "*/*" { - return ProtocolBoth, 0, true + return ProtocolBoth, 0, 0, true } parts := strings.SplitN(protoPort, "/", 2) if len(parts) == 1 { // No port, only protocol switch parts[0] { case "tcp": - return ProtocolTCP, 0, true + return ProtocolTCP, 0, 0, true case "udp": - return ProtocolUDP, 0, true + return ProtocolUDP, 0, 0, true default: - return ProtocolBoth, 0, false + return ProtocolBoth, 0, 0, false } } else { // Both protocol and port var proto Protocol - var port uint16 + var startPort, endPort uint16 switch parts[0] { case "tcp": proto = ProtocolTCP @@ -182,16 +178,34 @@ func parseProtoPort(protoPort string) (Protocol, uint16, bool) { case "*": proto = ProtocolBoth default: - return ProtocolBoth, 0, false + return ProtocolBoth, 0, 0, false } if parts[1] != "*" { - p64, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return ProtocolBoth, 0, false + ports := strings.SplitN(strings.TrimSpace(parts[1]), "-", 2) + if len(ports) == 1 { + p64, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return ProtocolBoth, 0, 0, false + } + startPort = uint16(p64) + endPort = startPort + } else { + p64, err := strconv.ParseUint(ports[0], 10, 16) + if err != nil { + return ProtocolBoth, 0, 0, false + } + startPort = uint16(p64) + p64, err = strconv.ParseUint(ports[1], 10, 16) + if err != nil { + return ProtocolBoth, 0, 0, false + } + endPort = uint16(p64) + if startPort > endPort { + return ProtocolBoth, 0, 0, false + } } - port = uint16(p64) } - return proto, port, true + return proto, startPort, endPort, true } } @@ -287,48 +301,3 @@ func parseGeoSiteName(s string) (string, []string) { } return base, attrs } - -// splitPortRangeRules splits a rule containing a port range and divides it into multiple rules, each specifying a single port. -// -// If protoPort has a port range, such as "tcp/80-90", -// the function splits this into individual rules for each port in the range, -// here resulting in rules for ports 80 through 90. -// the original protoPort will be changed to "tcp/80", and the returned rules will have the same Outbound, Address, and HijackAddress. -// but the ProtoPort will be changed to "tcp/81", "tcp/82", ..., "tcp/90". -func splitPortRangeRules(rule *TextRule) []TextRule { - protoPort := strings.ToLower(rule.ProtoPort) - if protoPort == "" || protoPort == "*" || protoPort == "*/*" { - return nil - } - parts := strings.SplitN(protoPort, "/", 2) - if len(parts) != 2 { - return nil - } - ports := strings.SplitN(strings.TrimSpace(parts[1]), "-", 2) - if len(ports) != 2 { - return nil - } - minPorts, err := strconv.Atoi(ports[0]) - if err != nil { - return nil - } - maxPorts, err := strconv.Atoi(ports[1]) - if err != nil { - return nil - } - - portLength := maxPorts - minPorts - if portLength <= 0 || minPorts == 0 { - return nil - } - - // port range: minPort < port <= MaxPort - extraRules := make([]TextRule, portLength) - for i := range extraRules { - extraRules[i] = *rule - extraRules[i].ProtoPort = fmt.Sprintf("%s/%d", parts[0], minPorts+i+1) - } - // edit ProtoPort from port range to a single port that value is minPort. For example, 80-90 -> 80 - rule.ProtoPort = fmt.Sprintf("%s/%d", parts[0], minPorts) - return extraRules -} diff --git a/extras/outbounds/acl/compile_test.go b/extras/outbounds/acl/compile_test.go index 87245b7..aaae8c2 100644 --- a/extras/outbounds/acl/compile_test.go +++ b/extras/outbounds/acl/compile_test.go @@ -304,75 +304,72 @@ func Test_parseGeoSiteName(t *testing.T) { } } -func Test_splitPortRangeRules(t *testing.T) { +func TestCompileRangePort(t *testing.T) { + ob1, ob2, ob3, ob4 := 1, 2, 3, 4 rules := []TextRule{ - { - Outbound: "ob1", - Address: "1.2.3.4", - ProtoPort: "tcp/1-1024", - HijackAddress: "", - }, - { - Outbound: "ob1", - Address: "1.2.3.4", - ProtoPort: "udp/1-1024", - HijackAddress: "", - }, - { - Outbound: "ob1", - Address: "1.2.3.4", - ProtoPort: "*/1-1024", - HijackAddress: "", - }, - { - Outbound: "ob1", - Address: "1.2.3.4", - ProtoPort: "tcp/0-222", - HijackAddress: "", - }, - { - Outbound: "ob1", - Address: "1.2.3.4", - ProtoPort: "tcp/1024", - HijackAddress: "", - }, - { - Outbound: "ob1", - Address: "1.2.3.4", - ProtoPort: "tcp/-1-9", - HijackAddress: "", - }, { Outbound: "ob1", Address: "1.2.3.4", ProtoPort: "tcp/6881-6889", HijackAddress: "", }, + { + Outbound: "ob2", + Address: "8.8.8.0/24", + ProtoPort: "udp/2525-3333", + HijackAddress: "1.1.1.1", + }, + { + Outbound: "ob3", + Address: "1.1.1.0/24", + ProtoPort: "*/1-65535", + HijackAddress: "", + }, + { + Outbound: "ob4", + Address: "1.1.1.0/24", + ProtoPort: "*/22", + HijackAddress: "", + }, } - _, rangeLen0 := splitPortRangeRules(&rules[0]) - assert.Equal(t, 1023, rangeLen0) + _, err := Compile[int](rules, map[string]int{ + "ob1": ob1, + "ob2": ob2, + "ob3": ob3, + "ob4": ob4, + }, 100, &testGeoLoader{}) + assert.NoError(t, err) - _, rangeLen1 := splitPortRangeRules(&rules[1]) - assert.Equal(t, 1023, rangeLen1) + ob11 := 1 + rules2 := []TextRule{ - _, rangeLen2 := splitPortRangeRules(&rules[2]) - assert.Equal(t, 1023, rangeLen2) - - _, rangeLen3 := splitPortRangeRules(&rules[3]) - assert.Equal(t, 0, rangeLen3) - - _, rangeLen4 := splitPortRangeRules(&rules[4]) - assert.Equal(t, 0, rangeLen4) - - _, rangeLen5 := splitPortRangeRules(&rules[5]) - assert.Equal(t, 0, rangeLen5) - - rangeRule, _ := splitPortRangeRules(&rules[6]) - for _, rule := range rangeRule { - assert.Equal(t, "ob1", rule.Outbound) - assert.Equal(t, "1.2.3.4", rule.Address) - t.Log(rule.ProtoPort) - assert.Equal(t, "", rule.HijackAddress) + { + Outbound: "ob11", + Address: "1.1.2.0/24", + ProtoPort: "*/3-1", // invalid range + HijackAddress: "", + }, } - assert.Equal(t, "tcp/6881", rules[6].ProtoPort) + + _, err = Compile[int](rules2, map[string]int{ + "ob11": ob11, + }, 100, &testGeoLoader{}) + assert.Error(t, err) + + ob21 := 1 + rules3 := []TextRule{ + + { + Outbound: "ob21", + Address: "1.1.2.0/24", + ProtoPort: "*/-114-514", // invalid range + HijackAddress: "", + }, + } + + _, err = Compile[int](rules3, map[string]int{ + "ob21": ob21, + }, 100, &testGeoLoader{}) + assert.Error(t, err) + }