mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 20:47:38 +03:00
refactor the method that support range format ProtoPort
This commit is contained in:
parent
2780dc2766
commit
9349f0a1a3
2 changed files with 92 additions and 126 deletions
|
@ -41,7 +41,8 @@ type compiledRule[O Outbound] struct {
|
||||||
Outbound O
|
Outbound O
|
||||||
HostMatcher hostMatcher
|
HostMatcher hostMatcher
|
||||||
Protocol Protocol
|
Protocol Protocol
|
||||||
Port uint16
|
StartPort uint16
|
||||||
|
EndPoint uint16
|
||||||
HijackAddress net.IP
|
HijackAddress net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,7 +50,7 @@ func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool
|
||||||
if r.Protocol != ProtocolBoth && r.Protocol != proto {
|
if r.Protocol != ProtocolBoth && r.Protocol != proto {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if r.Port != 0 && r.Port != port {
|
if r.StartPort != 0 && (port < r.StartPort || port > r.EndPoint) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return r.HostMatcher.Match(host)
|
return r.HostMatcher.Match(host)
|
||||||
|
@ -107,11 +108,6 @@ type GeoLoader interface {
|
||||||
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
||||||
cacheSize int, geoLoader GeoLoader,
|
cacheSize int, geoLoader GeoLoader,
|
||||||
) (CompiledRuleSet[O], error) {
|
) (CompiledRuleSet[O], error) {
|
||||||
for _, rule := range rules {
|
|
||||||
if extra := splitPortRangeRules(&rule); extra != nil {
|
|
||||||
rules = append(rules, extra...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
compiledRules := make([]compiledRule[O], len(rules))
|
compiledRules := make([]compiledRule[O], len(rules))
|
||||||
for i, rule := range rules {
|
for i, rule := range rules {
|
||||||
outbound, ok := outbounds[strings.ToLower(rule.Outbound)]
|
outbound, ok := outbounds[strings.ToLower(rule.Outbound)]
|
||||||
|
@ -122,7 +118,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
||||||
if errStr != "" {
|
if errStr != "" {
|
||||||
return nil, &CompilationError{rule.LineNum, errStr}
|
return nil, &CompilationError{rule.LineNum, errStr}
|
||||||
}
|
}
|
||||||
proto, port, ok := parseProtoPort(rule.ProtoPort)
|
proto, startPort, endPort, ok := parseProtoPort(rule.ProtoPort)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)}
|
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)}
|
||||||
}
|
}
|
||||||
|
@ -133,7 +129,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
||||||
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)}
|
return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
compiledRules[i] = compiledRule[O]{outbound, hm, proto, port, hijackAddress}
|
compiledRules[i] = compiledRule[O]{outbound, hm, proto, startPort, endPort, hijackAddress}
|
||||||
}
|
}
|
||||||
cache, err := lru.New[string, matchResult[O]](cacheSize)
|
cache, err := lru.New[string, matchResult[O]](cacheSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -154,26 +150,26 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
||||||
// [empty] (same as *)
|
// [empty] (same as *)
|
||||||
//
|
//
|
||||||
// proto must be either "tcp" or "udp", case-insensitive.
|
// proto must be either "tcp" or "udp", case-insensitive.
|
||||||
func parseProtoPort(protoPort string) (Protocol, uint16, bool) {
|
func parseProtoPort(protoPort string) (Protocol, uint16, uint16, bool) {
|
||||||
protoPort = strings.ToLower(protoPort)
|
protoPort = strings.ToLower(protoPort)
|
||||||
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
|
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
|
||||||
return ProtocolBoth, 0, true
|
return ProtocolBoth, 0, 0, true
|
||||||
}
|
}
|
||||||
parts := strings.SplitN(protoPort, "/", 2)
|
parts := strings.SplitN(protoPort, "/", 2)
|
||||||
if len(parts) == 1 {
|
if len(parts) == 1 {
|
||||||
// No port, only protocol
|
// No port, only protocol
|
||||||
switch parts[0] {
|
switch parts[0] {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
return ProtocolTCP, 0, true
|
return ProtocolTCP, 0, 0, true
|
||||||
case "udp":
|
case "udp":
|
||||||
return ProtocolUDP, 0, true
|
return ProtocolUDP, 0, 0, true
|
||||||
default:
|
default:
|
||||||
return ProtocolBoth, 0, false
|
return ProtocolBoth, 0, 0, false
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Both protocol and port
|
// Both protocol and port
|
||||||
var proto Protocol
|
var proto Protocol
|
||||||
var port uint16
|
var startPort, endPort uint16
|
||||||
switch parts[0] {
|
switch parts[0] {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
proto = ProtocolTCP
|
proto = ProtocolTCP
|
||||||
|
@ -182,16 +178,34 @@ func parseProtoPort(protoPort string) (Protocol, uint16, bool) {
|
||||||
case "*":
|
case "*":
|
||||||
proto = ProtocolBoth
|
proto = ProtocolBoth
|
||||||
default:
|
default:
|
||||||
return ProtocolBoth, 0, false
|
return ProtocolBoth, 0, 0, false
|
||||||
}
|
}
|
||||||
if parts[1] != "*" {
|
if parts[1] != "*" {
|
||||||
p64, err := strconv.ParseUint(parts[1], 10, 16)
|
ports := strings.SplitN(strings.TrimSpace(parts[1]), "-", 2)
|
||||||
if err != nil {
|
if len(ports) == 1 {
|
||||||
return ProtocolBoth, 0, false
|
p64, err := strconv.ParseUint(parts[1], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return ProtocolBoth, 0, 0, false
|
||||||
|
}
|
||||||
|
startPort = uint16(p64)
|
||||||
|
endPort = startPort
|
||||||
|
} else {
|
||||||
|
p64, err := strconv.ParseUint(ports[0], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return ProtocolBoth, 0, 0, false
|
||||||
|
}
|
||||||
|
startPort = uint16(p64)
|
||||||
|
p64, err = strconv.ParseUint(ports[1], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return ProtocolBoth, 0, 0, false
|
||||||
|
}
|
||||||
|
endPort = uint16(p64)
|
||||||
|
if startPort > endPort {
|
||||||
|
return ProtocolBoth, 0, 0, false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
port = uint16(p64)
|
|
||||||
}
|
}
|
||||||
return proto, port, true
|
return proto, startPort, endPort, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -287,48 +301,3 @@ func parseGeoSiteName(s string) (string, []string) {
|
||||||
}
|
}
|
||||||
return base, attrs
|
return base, attrs
|
||||||
}
|
}
|
||||||
|
|
||||||
// splitPortRangeRules splits a rule containing a port range and divides it into multiple rules, each specifying a single port.
|
|
||||||
//
|
|
||||||
// If protoPort has a port range, such as "tcp/80-90",
|
|
||||||
// the function splits this into individual rules for each port in the range,
|
|
||||||
// here resulting in rules for ports 80 through 90.
|
|
||||||
// the original protoPort will be changed to "tcp/80", and the returned rules will have the same Outbound, Address, and HijackAddress.
|
|
||||||
// but the ProtoPort will be changed to "tcp/81", "tcp/82", ..., "tcp/90".
|
|
||||||
func splitPortRangeRules(rule *TextRule) []TextRule {
|
|
||||||
protoPort := strings.ToLower(rule.ProtoPort)
|
|
||||||
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
parts := strings.SplitN(protoPort, "/", 2)
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
ports := strings.SplitN(strings.TrimSpace(parts[1]), "-", 2)
|
|
||||||
if len(ports) != 2 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
minPorts, err := strconv.Atoi(ports[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
maxPorts, err := strconv.Atoi(ports[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
portLength := maxPorts - minPorts
|
|
||||||
if portLength <= 0 || minPorts == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// port range: minPort < port <= MaxPort
|
|
||||||
extraRules := make([]TextRule, portLength)
|
|
||||||
for i := range extraRules {
|
|
||||||
extraRules[i] = *rule
|
|
||||||
extraRules[i].ProtoPort = fmt.Sprintf("%s/%d", parts[0], minPorts+i+1)
|
|
||||||
}
|
|
||||||
// edit ProtoPort from port range to a single port that value is minPort. For example, 80-90 -> 80
|
|
||||||
rule.ProtoPort = fmt.Sprintf("%s/%d", parts[0], minPorts)
|
|
||||||
return extraRules
|
|
||||||
}
|
|
||||||
|
|
|
@ -304,75 +304,72 @@ func Test_parseGeoSiteName(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_splitPortRangeRules(t *testing.T) {
|
func TestCompileRangePort(t *testing.T) {
|
||||||
|
ob1, ob2, ob3, ob4 := 1, 2, 3, 4
|
||||||
rules := []TextRule{
|
rules := []TextRule{
|
||||||
{
|
|
||||||
Outbound: "ob1",
|
|
||||||
Address: "1.2.3.4",
|
|
||||||
ProtoPort: "tcp/1-1024",
|
|
||||||
HijackAddress: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Outbound: "ob1",
|
|
||||||
Address: "1.2.3.4",
|
|
||||||
ProtoPort: "udp/1-1024",
|
|
||||||
HijackAddress: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Outbound: "ob1",
|
|
||||||
Address: "1.2.3.4",
|
|
||||||
ProtoPort: "*/1-1024",
|
|
||||||
HijackAddress: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Outbound: "ob1",
|
|
||||||
Address: "1.2.3.4",
|
|
||||||
ProtoPort: "tcp/0-222",
|
|
||||||
HijackAddress: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Outbound: "ob1",
|
|
||||||
Address: "1.2.3.4",
|
|
||||||
ProtoPort: "tcp/1024",
|
|
||||||
HijackAddress: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Outbound: "ob1",
|
|
||||||
Address: "1.2.3.4",
|
|
||||||
ProtoPort: "tcp/-1-9",
|
|
||||||
HijackAddress: "",
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Outbound: "ob1",
|
Outbound: "ob1",
|
||||||
Address: "1.2.3.4",
|
Address: "1.2.3.4",
|
||||||
ProtoPort: "tcp/6881-6889",
|
ProtoPort: "tcp/6881-6889",
|
||||||
HijackAddress: "",
|
HijackAddress: "",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Outbound: "ob2",
|
||||||
|
Address: "8.8.8.0/24",
|
||||||
|
ProtoPort: "udp/2525-3333",
|
||||||
|
HijackAddress: "1.1.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Outbound: "ob3",
|
||||||
|
Address: "1.1.1.0/24",
|
||||||
|
ProtoPort: "*/1-65535",
|
||||||
|
HijackAddress: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Outbound: "ob4",
|
||||||
|
Address: "1.1.1.0/24",
|
||||||
|
ProtoPort: "*/22",
|
||||||
|
HijackAddress: "",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
_, rangeLen0 := splitPortRangeRules(&rules[0])
|
_, err := Compile[int](rules, map[string]int{
|
||||||
assert.Equal(t, 1023, rangeLen0)
|
"ob1": ob1,
|
||||||
|
"ob2": ob2,
|
||||||
|
"ob3": ob3,
|
||||||
|
"ob4": ob4,
|
||||||
|
}, 100, &testGeoLoader{})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, rangeLen1 := splitPortRangeRules(&rules[1])
|
ob11 := 1
|
||||||
assert.Equal(t, 1023, rangeLen1)
|
rules2 := []TextRule{
|
||||||
|
|
||||||
_, rangeLen2 := splitPortRangeRules(&rules[2])
|
{
|
||||||
assert.Equal(t, 1023, rangeLen2)
|
Outbound: "ob11",
|
||||||
|
Address: "1.1.2.0/24",
|
||||||
_, rangeLen3 := splitPortRangeRules(&rules[3])
|
ProtoPort: "*/3-1", // invalid range
|
||||||
assert.Equal(t, 0, rangeLen3)
|
HijackAddress: "",
|
||||||
|
},
|
||||||
_, rangeLen4 := splitPortRangeRules(&rules[4])
|
|
||||||
assert.Equal(t, 0, rangeLen4)
|
|
||||||
|
|
||||||
_, rangeLen5 := splitPortRangeRules(&rules[5])
|
|
||||||
assert.Equal(t, 0, rangeLen5)
|
|
||||||
|
|
||||||
rangeRule, _ := splitPortRangeRules(&rules[6])
|
|
||||||
for _, rule := range rangeRule {
|
|
||||||
assert.Equal(t, "ob1", rule.Outbound)
|
|
||||||
assert.Equal(t, "1.2.3.4", rule.Address)
|
|
||||||
t.Log(rule.ProtoPort)
|
|
||||||
assert.Equal(t, "", rule.HijackAddress)
|
|
||||||
}
|
}
|
||||||
assert.Equal(t, "tcp/6881", rules[6].ProtoPort)
|
|
||||||
|
_, err = Compile[int](rules2, map[string]int{
|
||||||
|
"ob11": ob11,
|
||||||
|
}, 100, &testGeoLoader{})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
ob21 := 1
|
||||||
|
rules3 := []TextRule{
|
||||||
|
|
||||||
|
{
|
||||||
|
Outbound: "ob21",
|
||||||
|
Address: "1.1.2.0/24",
|
||||||
|
ProtoPort: "*/-114-514", // invalid range
|
||||||
|
HijackAddress: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = Compile[int](rules3, map[string]int{
|
||||||
|
"ob21": ob21,
|
||||||
|
}, 100, &testGeoLoader{})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue