refactor the method that support range format ProtoPort

This commit is contained in:
HynoR 2024-03-12 11:00:47 +08:00
parent 2780dc2766
commit 9349f0a1a3
2 changed files with 92 additions and 126 deletions

View file

@ -41,7 +41,8 @@ type compiledRule[O Outbound] struct {
Outbound O Outbound O
HostMatcher hostMatcher HostMatcher hostMatcher
Protocol Protocol Protocol Protocol
Port uint16 StartPort uint16
EndPoint uint16
HijackAddress net.IP HijackAddress net.IP
} }
@ -49,7 +50,7 @@ func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool
if r.Protocol != ProtocolBoth && r.Protocol != proto { if r.Protocol != ProtocolBoth && r.Protocol != proto {
return false return false
} }
if r.Port != 0 && r.Port != port { if r.StartPort != 0 && (port < r.StartPort || port > r.EndPoint) {
return false return false
} }
return r.HostMatcher.Match(host) return r.HostMatcher.Match(host)
@ -107,11 +108,6 @@ type GeoLoader interface {
func Compile[O Outbound](rules []TextRule, outbounds map[string]O, func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
cacheSize int, geoLoader GeoLoader, cacheSize int, geoLoader GeoLoader,
) (CompiledRuleSet[O], error) { ) (CompiledRuleSet[O], error) {
for _, rule := range rules {
if extra := splitPortRangeRules(&rule); extra != nil {
rules = append(rules, extra...)
}
}
compiledRules := make([]compiledRule[O], len(rules)) compiledRules := make([]compiledRule[O], len(rules))
for i, rule := range rules { for i, rule := range rules {
outbound, ok := outbounds[strings.ToLower(rule.Outbound)] outbound, ok := outbounds[strings.ToLower(rule.Outbound)]
@ -122,7 +118,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
if errStr != "" { if errStr != "" {
return nil, &CompilationError{rule.LineNum, errStr} return nil, &CompilationError{rule.LineNum, errStr}
} }
proto, port, ok := parseProtoPort(rule.ProtoPort) proto, startPort, endPort, ok := parseProtoPort(rule.ProtoPort)
if !ok { if !ok {
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)} return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)}
} }
@ -133,7 +129,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)} return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)}
} }
} }
compiledRules[i] = compiledRule[O]{outbound, hm, proto, port, hijackAddress} compiledRules[i] = compiledRule[O]{outbound, hm, proto, startPort, endPort, hijackAddress}
} }
cache, err := lru.New[string, matchResult[O]](cacheSize) cache, err := lru.New[string, matchResult[O]](cacheSize)
if err != nil { if err != nil {
@ -154,26 +150,26 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
// [empty] (same as *) // [empty] (same as *)
// //
// proto must be either "tcp" or "udp", case-insensitive. // proto must be either "tcp" or "udp", case-insensitive.
func parseProtoPort(protoPort string) (Protocol, uint16, bool) { func parseProtoPort(protoPort string) (Protocol, uint16, uint16, bool) {
protoPort = strings.ToLower(protoPort) protoPort = strings.ToLower(protoPort)
if protoPort == "" || protoPort == "*" || protoPort == "*/*" { if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
return ProtocolBoth, 0, true return ProtocolBoth, 0, 0, true
} }
parts := strings.SplitN(protoPort, "/", 2) parts := strings.SplitN(protoPort, "/", 2)
if len(parts) == 1 { if len(parts) == 1 {
// No port, only protocol // No port, only protocol
switch parts[0] { switch parts[0] {
case "tcp": case "tcp":
return ProtocolTCP, 0, true return ProtocolTCP, 0, 0, true
case "udp": case "udp":
return ProtocolUDP, 0, true return ProtocolUDP, 0, 0, true
default: default:
return ProtocolBoth, 0, false return ProtocolBoth, 0, 0, false
} }
} else { } else {
// Both protocol and port // Both protocol and port
var proto Protocol var proto Protocol
var port uint16 var startPort, endPort uint16
switch parts[0] { switch parts[0] {
case "tcp": case "tcp":
proto = ProtocolTCP proto = ProtocolTCP
@ -182,16 +178,34 @@ func parseProtoPort(protoPort string) (Protocol, uint16, bool) {
case "*": case "*":
proto = ProtocolBoth proto = ProtocolBoth
default: default:
return ProtocolBoth, 0, false return ProtocolBoth, 0, 0, false
} }
if parts[1] != "*" { if parts[1] != "*" {
p64, err := strconv.ParseUint(parts[1], 10, 16) ports := strings.SplitN(strings.TrimSpace(parts[1]), "-", 2)
if err != nil { if len(ports) == 1 {
return ProtocolBoth, 0, false p64, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return ProtocolBoth, 0, 0, false
}
startPort = uint16(p64)
endPort = startPort
} else {
p64, err := strconv.ParseUint(ports[0], 10, 16)
if err != nil {
return ProtocolBoth, 0, 0, false
}
startPort = uint16(p64)
p64, err = strconv.ParseUint(ports[1], 10, 16)
if err != nil {
return ProtocolBoth, 0, 0, false
}
endPort = uint16(p64)
if startPort > endPort {
return ProtocolBoth, 0, 0, false
}
} }
port = uint16(p64)
} }
return proto, port, true return proto, startPort, endPort, true
} }
} }
@ -287,48 +301,3 @@ func parseGeoSiteName(s string) (string, []string) {
} }
return base, attrs return base, attrs
} }
// splitPortRangeRules splits a rule containing a port range and divides it into multiple rules, each specifying a single port.
//
// If protoPort has a port range, such as "tcp/80-90",
// the function splits this into individual rules for each port in the range,
// here resulting in rules for ports 80 through 90.
// the original protoPort will be changed to "tcp/80", and the returned rules will have the same Outbound, Address, and HijackAddress.
// but the ProtoPort will be changed to "tcp/81", "tcp/82", ..., "tcp/90".
func splitPortRangeRules(rule *TextRule) []TextRule {
protoPort := strings.ToLower(rule.ProtoPort)
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
return nil
}
parts := strings.SplitN(protoPort, "/", 2)
if len(parts) != 2 {
return nil
}
ports := strings.SplitN(strings.TrimSpace(parts[1]), "-", 2)
if len(ports) != 2 {
return nil
}
minPorts, err := strconv.Atoi(ports[0])
if err != nil {
return nil
}
maxPorts, err := strconv.Atoi(ports[1])
if err != nil {
return nil
}
portLength := maxPorts - minPorts
if portLength <= 0 || minPorts == 0 {
return nil
}
// port range: minPort < port <= MaxPort
extraRules := make([]TextRule, portLength)
for i := range extraRules {
extraRules[i] = *rule
extraRules[i].ProtoPort = fmt.Sprintf("%s/%d", parts[0], minPorts+i+1)
}
// edit ProtoPort from port range to a single port that value is minPort. For example, 80-90 -> 80
rule.ProtoPort = fmt.Sprintf("%s/%d", parts[0], minPorts)
return extraRules
}

View file

@ -304,75 +304,72 @@ func Test_parseGeoSiteName(t *testing.T) {
} }
} }
func Test_splitPortRangeRules(t *testing.T) { func TestCompileRangePort(t *testing.T) {
ob1, ob2, ob3, ob4 := 1, 2, 3, 4
rules := []TextRule{ rules := []TextRule{
{
Outbound: "ob1",
Address: "1.2.3.4",
ProtoPort: "tcp/1-1024",
HijackAddress: "",
},
{
Outbound: "ob1",
Address: "1.2.3.4",
ProtoPort: "udp/1-1024",
HijackAddress: "",
},
{
Outbound: "ob1",
Address: "1.2.3.4",
ProtoPort: "*/1-1024",
HijackAddress: "",
},
{
Outbound: "ob1",
Address: "1.2.3.4",
ProtoPort: "tcp/0-222",
HijackAddress: "",
},
{
Outbound: "ob1",
Address: "1.2.3.4",
ProtoPort: "tcp/1024",
HijackAddress: "",
},
{
Outbound: "ob1",
Address: "1.2.3.4",
ProtoPort: "tcp/-1-9",
HijackAddress: "",
},
{ {
Outbound: "ob1", Outbound: "ob1",
Address: "1.2.3.4", Address: "1.2.3.4",
ProtoPort: "tcp/6881-6889", ProtoPort: "tcp/6881-6889",
HijackAddress: "", HijackAddress: "",
}, },
{
Outbound: "ob2",
Address: "8.8.8.0/24",
ProtoPort: "udp/2525-3333",
HijackAddress: "1.1.1.1",
},
{
Outbound: "ob3",
Address: "1.1.1.0/24",
ProtoPort: "*/1-65535",
HijackAddress: "",
},
{
Outbound: "ob4",
Address: "1.1.1.0/24",
ProtoPort: "*/22",
HijackAddress: "",
},
} }
_, rangeLen0 := splitPortRangeRules(&rules[0]) _, err := Compile[int](rules, map[string]int{
assert.Equal(t, 1023, rangeLen0) "ob1": ob1,
"ob2": ob2,
"ob3": ob3,
"ob4": ob4,
}, 100, &testGeoLoader{})
assert.NoError(t, err)
_, rangeLen1 := splitPortRangeRules(&rules[1]) ob11 := 1
assert.Equal(t, 1023, rangeLen1) rules2 := []TextRule{
_, rangeLen2 := splitPortRangeRules(&rules[2]) {
assert.Equal(t, 1023, rangeLen2) Outbound: "ob11",
Address: "1.1.2.0/24",
_, rangeLen3 := splitPortRangeRules(&rules[3]) ProtoPort: "*/3-1", // invalid range
assert.Equal(t, 0, rangeLen3) HijackAddress: "",
},
_, rangeLen4 := splitPortRangeRules(&rules[4])
assert.Equal(t, 0, rangeLen4)
_, rangeLen5 := splitPortRangeRules(&rules[5])
assert.Equal(t, 0, rangeLen5)
rangeRule, _ := splitPortRangeRules(&rules[6])
for _, rule := range rangeRule {
assert.Equal(t, "ob1", rule.Outbound)
assert.Equal(t, "1.2.3.4", rule.Address)
t.Log(rule.ProtoPort)
assert.Equal(t, "", rule.HijackAddress)
} }
assert.Equal(t, "tcp/6881", rules[6].ProtoPort)
_, err = Compile[int](rules2, map[string]int{
"ob11": ob11,
}, 100, &testGeoLoader{})
assert.Error(t, err)
ob21 := 1
rules3 := []TextRule{
{
Outbound: "ob21",
Address: "1.1.2.0/24",
ProtoPort: "*/-114-514", // invalid range
HijackAddress: "",
},
}
_, err = Compile[int](rules3, map[string]int{
"ob21": ob21,
}, 100, &testGeoLoader{})
assert.Error(t, err)
} }