Merge pull request #975 from HynoR/master

Support range format ProtoPort
This commit is contained in:
Toby 2024-03-12 20:26:56 -07:00 committed by GitHub
commit 55beaff012
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 72 additions and 21 deletions

View file

@ -41,7 +41,8 @@ type compiledRule[O Outbound] struct {
Outbound O Outbound O
HostMatcher hostMatcher HostMatcher hostMatcher
Protocol Protocol Protocol Protocol
Port uint16 StartPort uint16
EndPort uint16
HijackAddress net.IP HijackAddress net.IP
} }
@ -49,7 +50,7 @@ func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool
if r.Protocol != ProtocolBoth && r.Protocol != proto { if r.Protocol != ProtocolBoth && r.Protocol != proto {
return false return false
} }
if r.Port != 0 && r.Port != port { if r.StartPort != 0 && (port < r.StartPort || port > r.EndPort) {
return false return false
} }
return r.HostMatcher.Match(host) return r.HostMatcher.Match(host)
@ -100,10 +101,9 @@ type GeoLoader interface {
// Compile compiles TextRules into a CompiledRuleSet. // Compile compiles TextRules into a CompiledRuleSet.
// Names in the outbounds map MUST be in all lower case. // Names in the outbounds map MUST be in all lower case.
// geoipFunc is a function that returns the GeoIP database needed by the GeoIP matcher. // We want on-demand loading of GeoIP/GeoSite databases, so instead of passing the
// It will be called every time a GeoIP matcher is used during compilation, but won't // databases directly, we use a GeoLoader interface to load them only when needed
// be called if there is no GeoIP rule. We use a function here so that database loading // by at least one rule.
// is on-demand (only required if used by rules).
func Compile[O Outbound](rules []TextRule, outbounds map[string]O, func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
cacheSize int, geoLoader GeoLoader, cacheSize int, geoLoader GeoLoader,
) (CompiledRuleSet[O], error) { ) (CompiledRuleSet[O], error) {
@ -117,7 +117,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
if errStr != "" { if errStr != "" {
return nil, &CompilationError{rule.LineNum, errStr} return nil, &CompilationError{rule.LineNum, errStr}
} }
proto, port, ok := parseProtoPort(rule.ProtoPort) proto, startPort, endPort, ok := parseProtoPort(rule.ProtoPort)
if !ok { if !ok {
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)} return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)}
} }
@ -128,7 +128,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)} return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)}
} }
} }
compiledRules[i] = compiledRule[O]{outbound, hm, proto, port, hijackAddress} compiledRules[i] = compiledRule[O]{outbound, hm, proto, startPort, endPort, hijackAddress}
} }
cache, err := lru.New[string, matchResult[O]](cacheSize) cache, err := lru.New[string, matchResult[O]](cacheSize)
if err != nil { if err != nil {
@ -149,26 +149,26 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
// [empty] (same as *) // [empty] (same as *)
// //
// proto must be either "tcp" or "udp", case-insensitive. // proto must be either "tcp" or "udp", case-insensitive.
func parseProtoPort(protoPort string) (Protocol, uint16, bool) { func parseProtoPort(protoPort string) (Protocol, uint16, uint16, bool) {
protoPort = strings.ToLower(protoPort) protoPort = strings.ToLower(protoPort)
if protoPort == "" || protoPort == "*" || protoPort == "*/*" { if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
return ProtocolBoth, 0, true return ProtocolBoth, 0, 0, true
} }
parts := strings.SplitN(protoPort, "/", 2) parts := strings.SplitN(protoPort, "/", 2)
if len(parts) == 1 { if len(parts) == 1 {
// No port, only protocol // No port, only protocol
switch parts[0] { switch parts[0] {
case "tcp": case "tcp":
return ProtocolTCP, 0, true return ProtocolTCP, 0, 0, true
case "udp": case "udp":
return ProtocolUDP, 0, true return ProtocolUDP, 0, 0, true
default: default:
return ProtocolBoth, 0, false return ProtocolBoth, 0, 0, false
} }
} else { } else {
// Both protocol and port // Both protocol and port
var proto Protocol var proto Protocol
var port uint16 var startPort, endPort uint16
switch parts[0] { switch parts[0] {
case "tcp": case "tcp":
proto = ProtocolTCP proto = ProtocolTCP
@ -177,16 +177,35 @@ func parseProtoPort(protoPort string) (Protocol, uint16, bool) {
case "*": case "*":
proto = ProtocolBoth proto = ProtocolBoth
default: default:
return ProtocolBoth, 0, false return ProtocolBoth, 0, 0, false
} }
if parts[1] != "*" { if parts[1] != "*" {
p64, err := strconv.ParseUint(parts[1], 10, 16) // We allow either a single port or a range (e.g. "1000-2000")
if err != nil { ports := strings.SplitN(strings.TrimSpace(parts[1]), "-", 2)
return ProtocolBoth, 0, false if len(ports) == 1 {
p64, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return ProtocolBoth, 0, 0, false
}
startPort = uint16(p64)
endPort = startPort
} else {
p64, err := strconv.ParseUint(ports[0], 10, 16)
if err != nil {
return ProtocolBoth, 0, 0, false
}
startPort = uint16(p64)
p64, err = strconv.ParseUint(ports[1], 10, 16)
if err != nil {
return ProtocolBoth, 0, 0, false
}
endPort = uint16(p64)
if startPort > endPort {
return ProtocolBoth, 0, 0, false
}
} }
port = uint16(p64)
} }
return proto, port, true return proto, startPort, endPort, true
} }
} }

View file

@ -22,7 +22,7 @@ func (l *testGeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
} }
func TestCompile(t *testing.T) { func TestCompile(t *testing.T) {
ob1, ob2, ob3, ob4, ob5 := 1, 2, 3, 4, 5 ob1, ob2, ob3, ob4, ob5, ob6 := 1, 2, 3, 4, 5, 6
rules := []TextRule{ rules := []TextRule{
{ {
Outbound: "ob1", Outbound: "ob1",
@ -90,6 +90,12 @@ func TestCompile(t *testing.T) {
ProtoPort: "*/*", ProtoPort: "*/*",
HijackAddress: "", HijackAddress: "",
}, },
{
Outbound: "ob6",
Address: "all",
ProtoPort: "tcp/6881-6889",
HijackAddress: "",
},
} }
comp, err := Compile[int](rules, map[string]int{ comp, err := Compile[int](rules, map[string]int{
"ob1": ob1, "ob1": ob1,
@ -97,6 +103,7 @@ func TestCompile(t *testing.T) {
"ob3": ob3, "ob3": ob3,
"ob4": ob4, "ob4": ob4,
"ob5": ob5, "ob5": ob5,
"ob6": ob6,
}, 100, &testGeoLoader{}) }, 100, &testGeoLoader{})
assert.NoError(t, err) assert.NoError(t, err)
@ -242,6 +249,15 @@ func TestCompile(t *testing.T) {
wantOutbound: 0, // no match default wantOutbound: 0, // no match default
wantIP: nil, wantIP: nil,
}, },
{
host: HostInfo{
IPv4: net.ParseIP("223.1.1.1"),
},
proto: ProtocolTCP,
port: 6883,
wantOutbound: ob6, // match range port rule 6881-6889
wantIP: nil,
},
} }
for _, test := range tests { for _, test := range tests {
@ -249,6 +265,22 @@ func TestCompile(t *testing.T) {
assert.Equal(t, test.wantOutbound, gotOutbound) assert.Equal(t, test.wantOutbound, gotOutbound)
assert.Equal(t, test.wantIP, gotIP) assert.Equal(t, test.wantIP, gotIP)
} }
// Test Invalid Port Range Rule
eb1 := 1
invalidRules := []TextRule{
{
Outbound: "eb1",
Address: "1.1.2.0/24",
ProtoPort: "*/3-1",
HijackAddress: "",
},
}
_, err = Compile[int](invalidRules, map[string]int{
"eb1": eb1,
}, 100, &testGeoLoader{})
assert.Error(t, err)
} }
func Test_parseGeoSiteName(t *testing.T) { func Test_parseGeoSiteName(t *testing.T) {