mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 04:27:39 +03:00
Merge pull request #975 from HynoR/master
Support range format ProtoPort
This commit is contained in:
commit
55beaff012
2 changed files with 72 additions and 21 deletions
|
@ -41,7 +41,8 @@ type compiledRule[O Outbound] struct {
|
|||
Outbound O
|
||||
HostMatcher hostMatcher
|
||||
Protocol Protocol
|
||||
Port uint16
|
||||
StartPort uint16
|
||||
EndPort uint16
|
||||
HijackAddress net.IP
|
||||
}
|
||||
|
||||
|
@ -49,7 +50,7 @@ func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool
|
|||
if r.Protocol != ProtocolBoth && r.Protocol != proto {
|
||||
return false
|
||||
}
|
||||
if r.Port != 0 && r.Port != port {
|
||||
if r.StartPort != 0 && (port < r.StartPort || port > r.EndPort) {
|
||||
return false
|
||||
}
|
||||
return r.HostMatcher.Match(host)
|
||||
|
@ -100,10 +101,9 @@ type GeoLoader interface {
|
|||
|
||||
// Compile compiles TextRules into a CompiledRuleSet.
|
||||
// Names in the outbounds map MUST be in all lower case.
|
||||
// geoipFunc is a function that returns the GeoIP database needed by the GeoIP matcher.
|
||||
// It will be called every time a GeoIP matcher is used during compilation, but won't
|
||||
// be called if there is no GeoIP rule. We use a function here so that database loading
|
||||
// is on-demand (only required if used by rules).
|
||||
// We want on-demand loading of GeoIP/GeoSite databases, so instead of passing the
|
||||
// databases directly, we use a GeoLoader interface to load them only when needed
|
||||
// by at least one rule.
|
||||
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
||||
cacheSize int, geoLoader GeoLoader,
|
||||
) (CompiledRuleSet[O], error) {
|
||||
|
@ -117,7 +117,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
|||
if errStr != "" {
|
||||
return nil, &CompilationError{rule.LineNum, errStr}
|
||||
}
|
||||
proto, port, ok := parseProtoPort(rule.ProtoPort)
|
||||
proto, startPort, endPort, ok := parseProtoPort(rule.ProtoPort)
|
||||
if !ok {
|
||||
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)}
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
|||
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)}
|
||||
}
|
||||
}
|
||||
compiledRules[i] = compiledRule[O]{outbound, hm, proto, port, hijackAddress}
|
||||
compiledRules[i] = compiledRule[O]{outbound, hm, proto, startPort, endPort, hijackAddress}
|
||||
}
|
||||
cache, err := lru.New[string, matchResult[O]](cacheSize)
|
||||
if err != nil {
|
||||
|
@ -149,26 +149,26 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
|||
// [empty] (same as *)
|
||||
//
|
||||
// proto must be either "tcp" or "udp", case-insensitive.
|
||||
func parseProtoPort(protoPort string) (Protocol, uint16, bool) {
|
||||
func parseProtoPort(protoPort string) (Protocol, uint16, uint16, bool) {
|
||||
protoPort = strings.ToLower(protoPort)
|
||||
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
|
||||
return ProtocolBoth, 0, true
|
||||
return ProtocolBoth, 0, 0, true
|
||||
}
|
||||
parts := strings.SplitN(protoPort, "/", 2)
|
||||
if len(parts) == 1 {
|
||||
// No port, only protocol
|
||||
switch parts[0] {
|
||||
case "tcp":
|
||||
return ProtocolTCP, 0, true
|
||||
return ProtocolTCP, 0, 0, true
|
||||
case "udp":
|
||||
return ProtocolUDP, 0, true
|
||||
return ProtocolUDP, 0, 0, true
|
||||
default:
|
||||
return ProtocolBoth, 0, false
|
||||
return ProtocolBoth, 0, 0, false
|
||||
}
|
||||
} else {
|
||||
// Both protocol and port
|
||||
var proto Protocol
|
||||
var port uint16
|
||||
var startPort, endPort uint16
|
||||
switch parts[0] {
|
||||
case "tcp":
|
||||
proto = ProtocolTCP
|
||||
|
@ -177,16 +177,35 @@ func parseProtoPort(protoPort string) (Protocol, uint16, bool) {
|
|||
case "*":
|
||||
proto = ProtocolBoth
|
||||
default:
|
||||
return ProtocolBoth, 0, false
|
||||
return ProtocolBoth, 0, 0, false
|
||||
}
|
||||
if parts[1] != "*" {
|
||||
p64, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return ProtocolBoth, 0, false
|
||||
// We allow either a single port or a range (e.g. "1000-2000")
|
||||
ports := strings.SplitN(strings.TrimSpace(parts[1]), "-", 2)
|
||||
if len(ports) == 1 {
|
||||
p64, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return ProtocolBoth, 0, 0, false
|
||||
}
|
||||
startPort = uint16(p64)
|
||||
endPort = startPort
|
||||
} else {
|
||||
p64, err := strconv.ParseUint(ports[0], 10, 16)
|
||||
if err != nil {
|
||||
return ProtocolBoth, 0, 0, false
|
||||
}
|
||||
startPort = uint16(p64)
|
||||
p64, err = strconv.ParseUint(ports[1], 10, 16)
|
||||
if err != nil {
|
||||
return ProtocolBoth, 0, 0, false
|
||||
}
|
||||
endPort = uint16(p64)
|
||||
if startPort > endPort {
|
||||
return ProtocolBoth, 0, 0, false
|
||||
}
|
||||
}
|
||||
port = uint16(p64)
|
||||
}
|
||||
return proto, port, true
|
||||
return proto, startPort, endPort, true
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ func (l *testGeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
|
|||
}
|
||||
|
||||
func TestCompile(t *testing.T) {
|
||||
ob1, ob2, ob3, ob4, ob5 := 1, 2, 3, 4, 5
|
||||
ob1, ob2, ob3, ob4, ob5, ob6 := 1, 2, 3, 4, 5, 6
|
||||
rules := []TextRule{
|
||||
{
|
||||
Outbound: "ob1",
|
||||
|
@ -90,6 +90,12 @@ func TestCompile(t *testing.T) {
|
|||
ProtoPort: "*/*",
|
||||
HijackAddress: "",
|
||||
},
|
||||
{
|
||||
Outbound: "ob6",
|
||||
Address: "all",
|
||||
ProtoPort: "tcp/6881-6889",
|
||||
HijackAddress: "",
|
||||
},
|
||||
}
|
||||
comp, err := Compile[int](rules, map[string]int{
|
||||
"ob1": ob1,
|
||||
|
@ -97,6 +103,7 @@ func TestCompile(t *testing.T) {
|
|||
"ob3": ob3,
|
||||
"ob4": ob4,
|
||||
"ob5": ob5,
|
||||
"ob6": ob6,
|
||||
}, 100, &testGeoLoader{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
@ -242,6 +249,15 @@ func TestCompile(t *testing.T) {
|
|||
wantOutbound: 0, // no match default
|
||||
wantIP: nil,
|
||||
},
|
||||
{
|
||||
host: HostInfo{
|
||||
IPv4: net.ParseIP("223.1.1.1"),
|
||||
},
|
||||
proto: ProtocolTCP,
|
||||
port: 6883,
|
||||
wantOutbound: ob6, // match range port rule 6881-6889
|
||||
wantIP: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
@ -249,6 +265,22 @@ func TestCompile(t *testing.T) {
|
|||
assert.Equal(t, test.wantOutbound, gotOutbound)
|
||||
assert.Equal(t, test.wantIP, gotIP)
|
||||
}
|
||||
|
||||
// Test Invalid Port Range Rule
|
||||
eb1 := 1
|
||||
invalidRules := []TextRule{
|
||||
{
|
||||
Outbound: "eb1",
|
||||
Address: "1.1.2.0/24",
|
||||
ProtoPort: "*/3-1",
|
||||
HijackAddress: "",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = Compile[int](invalidRules, map[string]int{
|
||||
"eb1": eb1,
|
||||
}, 100, &testGeoLoader{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_parseGeoSiteName(t *testing.T) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue