mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-02 03:57:38 +03:00
feat: allow specifying port ranges for sniffing
This commit is contained in:
parent
b481b49a28
commit
deeeafd8d7
9 changed files with 276 additions and 182 deletions
|
@ -36,6 +36,7 @@ import (
|
|||
"github.com/apernet/hysteria/extras/v2/outbounds"
|
||||
"github.com/apernet/hysteria/extras/v2/sniff"
|
||||
"github.com/apernet/hysteria/extras/v2/trafficlogger"
|
||||
eUtils "github.com/apernet/hysteria/extras/v2/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -185,6 +186,8 @@ type serverConfigSniff struct {
|
|||
Enable bool `mapstructure:"enable"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
RewriteDomain bool `mapstructure:"rewriteDomain"`
|
||||
TCPPorts string `mapstructure:"tcpPorts"`
|
||||
UDPPorts string `mapstructure:"udpPorts"`
|
||||
}
|
||||
|
||||
type serverConfigACL struct {
|
||||
|
@ -551,10 +554,23 @@ func serverConfigOutboundHTTPToOutbound(c serverConfigOutboundHTTP) (outbounds.P
|
|||
|
||||
func (c *serverConfig) fillRequestHook(hyConfig *server.Config) error {
|
||||
if c.Sniff.Enable {
|
||||
hyConfig.RequestHook = &sniff.Sniffer{
|
||||
s := &sniff.Sniffer{
|
||||
Timeout: c.Sniff.Timeout,
|
||||
RewriteDomain: c.Sniff.RewriteDomain,
|
||||
}
|
||||
if c.Sniff.TCPPorts != "" {
|
||||
s.TCPPorts = eUtils.ParsePortUnion(c.Sniff.TCPPorts)
|
||||
if s.TCPPorts == nil {
|
||||
return configError{Field: "sniff.tcpPorts", Err: errors.New("invalid port union")}
|
||||
}
|
||||
}
|
||||
if c.Sniff.UDPPorts != "" {
|
||||
s.UDPPorts = eUtils.ParsePortUnion(c.Sniff.UDPPorts)
|
||||
if s.UDPPorts == nil {
|
||||
return configError{Field: "sniff.udpPorts", Err: errors.New("invalid port union")}
|
||||
}
|
||||
}
|
||||
hyConfig.RequestHook = s
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -115,6 +115,8 @@ func TestServerConfig(t *testing.T) {
|
|||
Enable: true,
|
||||
Timeout: 1 * time.Second,
|
||||
RewriteDomain: true,
|
||||
TCPPorts: "80,443,1000-2000",
|
||||
UDPPorts: "443",
|
||||
},
|
||||
ACL: serverConfigACL{
|
||||
File: "chnroute.txt",
|
||||
|
|
|
@ -87,6 +87,8 @@ sniff:
|
|||
enable: true
|
||||
timeout: 1s
|
||||
rewriteDomain: true
|
||||
tcpPorts: 80,443,1000-2000
|
||||
udpPorts: 443
|
||||
|
||||
acl:
|
||||
file: chnroute.txt
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
|
||||
"github.com/apernet/hysteria/core/v2/server"
|
||||
quicInternal "github.com/apernet/hysteria/extras/v2/sniff/internal/quic"
|
||||
"github.com/apernet/hysteria/extras/v2/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -29,6 +31,8 @@ var _ server.RequestHook = (*Sniffer)(nil)
|
|||
type Sniffer struct {
|
||||
Timeout time.Duration
|
||||
RewriteDomain bool // Whether to rewrite the address even when it's already a domain
|
||||
TCPPorts utils.PortUnion
|
||||
UDPPorts utils.PortUnion
|
||||
}
|
||||
|
||||
func (h *Sniffer) isDomain(addr string) bool {
|
||||
|
@ -62,7 +66,26 @@ func (h *Sniffer) isTLS(buf []byte) bool {
|
|||
|
||||
func (h *Sniffer) Check(isUDP bool, reqAddr string) bool {
|
||||
// @ means it's internal (e.g. speed test)
|
||||
return !strings.HasPrefix(reqAddr, "@") && (h.RewriteDomain || !h.isDomain(reqAddr))
|
||||
if strings.HasPrefix(reqAddr, "@") {
|
||||
return false
|
||||
}
|
||||
host, port, err := net.SplitHostPort(reqAddr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if !h.RewriteDomain && net.ParseIP(host) == nil {
|
||||
// Is a domain and domain rewriting is disabled
|
||||
return false
|
||||
}
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if isUDP {
|
||||
return h.UDPPorts == nil || h.UDPPorts.Contains(uint16(portNum))
|
||||
} else {
|
||||
return h.TCPPorts == nil || h.TCPPorts.Contains(uint16(portNum))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Sniffer) TCP(stream quic.Stream, reqAddr *string) ([]byte, error) {
|
||||
|
|
|
@ -6,10 +6,35 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/apernet/hysteria/extras/v2/utils"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestSnifferCheck(t *testing.T) {
|
||||
sniffer := &Sniffer{
|
||||
Timeout: 1 * time.Second,
|
||||
RewriteDomain: false,
|
||||
TCPPorts: nil, // nil = all
|
||||
UDPPorts: nil, // nil = all
|
||||
}
|
||||
|
||||
assert.True(t, sniffer.Check(false, "1.1.1.1:80"))
|
||||
assert.False(t, sniffer.Check(false, "example.com:443"))
|
||||
|
||||
sniffer.RewriteDomain = true
|
||||
assert.True(t, sniffer.Check(false, "example.com:443"))
|
||||
|
||||
sniffer.TCPPorts = []utils.PortRange{{80, 80}}
|
||||
assert.True(t, sniffer.Check(false, "google.com:80"))
|
||||
assert.False(t, sniffer.Check(false, "google.com:443"))
|
||||
|
||||
sniffer.UDPPorts = []utils.PortRange{{443, 443}}
|
||||
assert.True(t, sniffer.Check(true, "google.com:443"))
|
||||
assert.False(t, sniffer.Check(true, "google.com:80"))
|
||||
}
|
||||
|
||||
func TestSnifferTCP(t *testing.T) {
|
||||
sniffer := &Sniffer{
|
||||
Timeout: 1 * time.Second,
|
||||
|
@ -40,27 +65,16 @@ func TestSnifferTCP(t *testing.T) {
|
|||
|
||||
// Rewrite IP to domain
|
||||
reqAddr := "111.111.111.111:80"
|
||||
assert.True(t, sniffer.Check(false, reqAddr))
|
||||
putback, err := sniffer.TCP(stream, &reqAddr)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, *buf, putback)
|
||||
assert.Equal(t, "example.com:80", reqAddr)
|
||||
|
||||
// Do not rewrite if it's already a domain
|
||||
index = 0
|
||||
reqAddr = "gulag.cc:443"
|
||||
assert.False(t, sniffer.Check(false, reqAddr))
|
||||
|
||||
// Turn on rewrite and now it should rewrite
|
||||
sniffer.RewriteDomain = true
|
||||
assert.True(t, sniffer.Check(false, reqAddr))
|
||||
|
||||
// Test TLS
|
||||
*buf, err = base64.StdEncoding.DecodeString("FgMBARcBAAETAwPJL2jlt1OAo+Rslkjv/aqKiTthKMaCKg2Gvd+uALDbDCDdY+UIk8ouadEB9fC3j52Y1i7SJZqGIgBRIS6kKieYrAAoEwITAcAswCvAMMAvwCTAI8AowCfACsAJwBTAEwCdAJwAPQA8ADUALwEAAKIAAAAOAAwAAAlpcGluZm8uaW8ABQAFAQAAAAAAKwAJCAMEAwMDAgMBAA0AGgAYCAQIBQgGBAEFAQIBBAMFAwIDAgIGAQYDACMAAAAKAAgABgAdABcAGAAQAAsACQhodHRwLzEuMQAzACYAJAAdACBguQbqNJNyamYxYcrBFpBP7pWv5TgZsP9gwGtMYNKVBQAxAAAAFwAA/wEAAQAALQACAQE=")
|
||||
assert.NoError(t, err)
|
||||
index = 0
|
||||
reqAddr = "222.222.222.222:443"
|
||||
assert.True(t, sniffer.Check(false, reqAddr))
|
||||
putback, err = sniffer.TCP(stream, &reqAddr)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, *buf, putback)
|
||||
|
@ -70,7 +84,6 @@ func TestSnifferTCP(t *testing.T) {
|
|||
*buf = []byte("Wait It's All Ohio? Always Has Been.")
|
||||
index = 0
|
||||
reqAddr = "123.123.123.123:123"
|
||||
assert.True(t, sniffer.Check(false, reqAddr))
|
||||
putback, err = sniffer.TCP(stream, &reqAddr)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, *buf, putback)
|
||||
|
@ -80,7 +93,6 @@ func TestSnifferTCP(t *testing.T) {
|
|||
*buf = []byte("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a")
|
||||
index = 0
|
||||
reqAddr = "45.45.45.45:45"
|
||||
assert.True(t, sniffer.Check(false, reqAddr))
|
||||
putback, err = sniffer.TCP(stream, &reqAddr)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("\x01\x02\x03"), putback)
|
||||
|
@ -94,7 +106,6 @@ func TestSnifferTCP(t *testing.T) {
|
|||
return 0, io.EOF
|
||||
})
|
||||
reqAddr = "66.66.66.66:66"
|
||||
assert.True(t, sniffer.Check(false, reqAddr))
|
||||
putback, err = sniffer.TCP(blockStream, &reqAddr)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte{}, putback)
|
||||
|
@ -109,7 +120,6 @@ func TestSnifferUDP(t *testing.T) {
|
|||
|
||||
// Test QUIC
|
||||
reqAddr := "2.3.4.5:443"
|
||||
assert.True(t, sniffer.Check(true, reqAddr))
|
||||
pkt, err := base64.StdEncoding.DecodeString("ygAAAAEIwugWgPS7ulYAAES8hY891uwgGE9GG4CPOLd+nsDe28raso24lCSFmlFwYQG1uF39ikbL13/R9ZTghYmTl+jEbr6F9TxxRiOgpTmKRmh6aKZiIiVfy5pVRckovaI8lq0WRoW9xoFNTyYtQP8TVJ3bLCK+zUqpquEQSyWf7CE43ywayyMpE9UlIoPXFWCoopXLM1SvzdQ+17P51N9KR7m4emti4DWWTBLMQOvrwd2HEEkbiZdRO1wf6ZXJlIat5dN0R/6uod60OFPO+u+awvq67MoMReC7+5I/xWI+xx6o4JpnZNn6YPG8Gqi8hS6doNcAAdtD8h5eMLuHCCgkpX3QVjjfWtcOhtw9xKjU43HhUPwzUTv+JDLgwuTQCTmlfYlb3B+pk4b2I9si0tJ0SBuYaZ2VQPtZbj2hpGXw3gn11pbN8xsbKkQL50+Scd4dGJxWQlGaJHeaU5WOCkxLXc635z8m5XO/CBHVYPGp4pfwfwNUgbe5WF+3MaUIlDB8dMfsnrO0BmZPo379jVx0SFLTAiS8wAdHib1WNEY8qKYnTWuiyxYg1GZEhJt0nXmI+8f0eJq42DgHBWC+Rf5rRBr/Sf25o3mFAmTUaul0Woo9/CIrpT73B63N91xd9A77i4ru995YG8l9Hen+eLtpDU9Q9376nwMDYBzeYG9U/Rn0Urbm6q4hmAgV/xlNJ2rAyDS+yLnwqD6I0PRy8bZJEttcidb/SkOyrpgMiAzWeT+SO+c/k+Y8H0UTRa05faZUrhuUaym9wAcaIVRA6nFI+fejfjVp+7afFv+kWn3vCqQEij+CRHuxkltrixZMD2rfYj6NUW7TTYBtPRtuV/V0ZIDjRR26vr4K+0D84+l3c0mA/l6nmpP5kkco3nmpdjtQN6sGXL7+5o0nnsftX5d6/n5mLyEpP+AEDl1zk3iqkS62RsITwql6DMMoGbSDdUpMclCIeM0vlo3CkxGMO7QA9ruVeNddkL3EWMivl+uxO43sXEEqYQHVl4N75y63t05GOf7/gm9Kb/BJ8MpG9ViEkVYaskQCzi3D8bVpzo8FfTj8te8B6c3ikc/cm7r8k0ZcZpr+YiLGDYq+0ilHxpqJfmq8dPkSvxdzLcUSvy7+LMQ/TTobRSF7L4JhtDKck0+00vl9H35Tkh9N+MsVtpKdWyoqZ4XaK2Nx1M6AieczXpdFc0y7lYPoUfF4IeW8WzeVUclol5ElYjkyFz/lDOGAe1bF2g5AYaGWCPiGleVZknNdD5ihB8W8Mfkt1pEwq2S97AHrppqkf/VoIfZzeqH8wUFw8fDDrZIpnoa0rW7HfwIQaqJhPCyB9Z6TVbV4x9UWmaHfVAcinCK/7o10dtaj3rvEqcUC/iPceGq3Tqv/p9GGNJ+Ci2JBjXqNxYr893Llk75VdPD9pM6y1SM0P80oXNy32VMtafkFFST8GpvvqWcxUJ93kzaY8RmU1g3XFOImSU2utU6+FUQ2Pn5uLwcfT2cTYfTpPGh+WXjSbZ6trqdEMEsLHybuPo2UN4WpVLXVQma3kSaHQggcLlEip8GhEUAy/xCb2eKqhI4HkDpDjwDnDVKufWlnRaOHf58cc8Woi+WT8JTOkHC+nBEG6fKRPHDG08U5yayIQIjI")
|
||||
assert.NoError(t, err)
|
||||
err = sniffer.UDP(pkt, &reqAddr)
|
||||
|
@ -119,7 +129,6 @@ func TestSnifferUDP(t *testing.T) {
|
|||
// Test unrecognized
|
||||
pkt = []byte("oh my sweet summer child")
|
||||
reqAddr = "90.90.90.90:90"
|
||||
assert.True(t, sniffer.Check(true, reqAddr))
|
||||
err = sniffer.UDP(pkt, &reqAddr)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "90.90.90.90:90", reqAddr)
|
||||
|
|
|
@ -3,8 +3,8 @@ package udphop
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/apernet/hysteria/extras/v2/utils"
|
||||
)
|
||||
|
||||
type InvalidPortError struct {
|
||||
|
@ -57,36 +57,11 @@ func ResolveUDPHopAddr(addr string) (*UDPHopAddr, error) {
|
|||
PortStr: portStr,
|
||||
}
|
||||
|
||||
portStrs := strings.Split(portStr, ",")
|
||||
for _, portStr := range portStrs {
|
||||
if strings.Contains(portStr, "-") {
|
||||
// Port range
|
||||
portRange := strings.Split(portStr, "-")
|
||||
if len(portRange) != 2 {
|
||||
return nil, InvalidPortError{portStr}
|
||||
}
|
||||
start, err := strconv.ParseUint(portRange[0], 10, 16)
|
||||
if err != nil {
|
||||
return nil, InvalidPortError{portStr}
|
||||
}
|
||||
end, err := strconv.ParseUint(portRange[1], 10, 16)
|
||||
if err != nil {
|
||||
return nil, InvalidPortError{portStr}
|
||||
}
|
||||
if start > end {
|
||||
start, end = end, start
|
||||
}
|
||||
for i := start; i <= end; i++ {
|
||||
result.Ports = append(result.Ports, uint16(i))
|
||||
}
|
||||
} else {
|
||||
// Single port
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return nil, InvalidPortError{portStr}
|
||||
}
|
||||
result.Ports = append(result.Ports, uint16(port))
|
||||
}
|
||||
pu := utils.ParsePortUnion(portStr)
|
||||
if pu == nil {
|
||||
return nil, InvalidPortError{portStr}
|
||||
}
|
||||
result.Ports = pu.Ports()
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
@ -1,132 +0,0 @@
|
|||
package udphop
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveUDPHopAddr(t *testing.T) {
|
||||
type args struct {
|
||||
addr string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *UDPHopAddr
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
addr: "",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no port",
|
||||
args: args{
|
||||
addr: "8.8.8.8",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "single port",
|
||||
args: args{
|
||||
addr: "8.8.4.4:1234",
|
||||
},
|
||||
want: &UDPHopAddr{
|
||||
IP: net.ParseIP("8.8.4.4"),
|
||||
Ports: []uint16{1234},
|
||||
PortStr: "1234",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple ports",
|
||||
args: args{
|
||||
addr: "8.8.3.3:1234,5678,9012",
|
||||
},
|
||||
want: &UDPHopAddr{
|
||||
IP: net.ParseIP("8.8.3.3"),
|
||||
Ports: []uint16{1234, 5678, 9012},
|
||||
PortStr: "1234,5678,9012",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port range",
|
||||
args: args{
|
||||
addr: "1.2.3.4:1234-1240",
|
||||
},
|
||||
want: &UDPHopAddr{
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Ports: []uint16{1234, 1235, 1236, 1237, 1238, 1239, 1240},
|
||||
PortStr: "1234-1240",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port range reversed",
|
||||
args: args{
|
||||
addr: "123.123.123.123:9990-9980",
|
||||
},
|
||||
want: &UDPHopAddr{
|
||||
IP: net.ParseIP("123.123.123.123"),
|
||||
Ports: []uint16{9980, 9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988, 9989, 9990},
|
||||
PortStr: "9990-9980",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port range & port list",
|
||||
args: args{
|
||||
addr: "9.9.9.9:1234-1236,5678,9012",
|
||||
},
|
||||
want: &UDPHopAddr{
|
||||
IP: net.ParseIP("9.9.9.9"),
|
||||
Ports: []uint16{1234, 1235, 1236, 5678, 9012},
|
||||
PortStr: "1234-1236,5678,9012",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid port",
|
||||
args: args{
|
||||
addr: "5.5.5.5:1234,bs",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid port range 1",
|
||||
args: args{
|
||||
addr: "6.6.6.6:7788-bbss",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid port range 2",
|
||||
args: args{
|
||||
addr: "1.0.0.1:8899-9002-9005",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ResolveUDPHopAddr(tt.args.addr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseUDPHopAddr() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ParseUDPHopAddr() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
107
extras/utils/portunion.go
Normal file
107
extras/utils/portunion.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PortUnion is a collection of multiple port ranges.
|
||||
type PortUnion []PortRange
|
||||
|
||||
// PortRange represents a range of ports.
|
||||
// Start and End are inclusive. [Start, End]
|
||||
type PortRange struct {
|
||||
Start, End uint16
|
||||
}
|
||||
|
||||
// ParsePortUnion parses a string of comma-separated port ranges (or single ports) into a PortUnion.
|
||||
// Returns nil if the input is invalid.
|
||||
// The returned PortUnion is guaranteed to be normalized.
|
||||
func ParsePortUnion(s string) PortUnion {
|
||||
if s == "all" || s == "*" {
|
||||
// Wildcard special case
|
||||
return PortUnion{PortRange{0, 65535}}
|
||||
}
|
||||
var result PortUnion
|
||||
portStrs := strings.Split(s, ",")
|
||||
for _, portStr := range portStrs {
|
||||
if strings.Contains(portStr, "-") {
|
||||
// Port range
|
||||
portRange := strings.Split(portStr, "-")
|
||||
if len(portRange) != 2 {
|
||||
return nil
|
||||
}
|
||||
start, err := strconv.ParseUint(portRange[0], 10, 16)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
end, err := strconv.ParseUint(portRange[1], 10, 16)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if start > end {
|
||||
start, end = end, start
|
||||
}
|
||||
result = append(result, PortRange{uint16(start), uint16(end)})
|
||||
} else {
|
||||
// Single port
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
result = append(result, PortRange{uint16(port), uint16(port)})
|
||||
}
|
||||
}
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
return result.Normalize()
|
||||
}
|
||||
|
||||
// Normalize normalizes a PortUnion.
|
||||
// No overlapping ranges, ranges are sorted from low to high.
|
||||
func (u PortUnion) Normalize() PortUnion {
|
||||
if len(u) == 0 {
|
||||
return u
|
||||
}
|
||||
sort.Slice(u, func(i, j int) bool {
|
||||
if u[i].Start == u[j].Start {
|
||||
return u[i].End < u[j].End
|
||||
}
|
||||
return u[i].Start < u[j].Start
|
||||
})
|
||||
normalized := PortUnion{u[0]}
|
||||
for _, current := range u[1:] {
|
||||
last := &normalized[len(normalized)-1]
|
||||
if current.Start <= last.End+1 {
|
||||
if current.End > last.End {
|
||||
last.End = current.End
|
||||
}
|
||||
} else {
|
||||
normalized = append(normalized, current)
|
||||
}
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// Ports returns all ports in the PortUnion as a slice.
|
||||
func (u PortUnion) Ports() []uint16 {
|
||||
var ports []uint16
|
||||
for _, r := range u {
|
||||
for i := r.Start; i <= r.End; i++ {
|
||||
ports = append(ports, i)
|
||||
}
|
||||
}
|
||||
return ports
|
||||
}
|
||||
|
||||
// Contains returns true if the PortUnion contains the given port.
|
||||
func (u PortUnion) Contains(port uint16) bool {
|
||||
for _, r := range u {
|
||||
if port >= r.Start && port <= r.End {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
92
extras/utils/portunion_test.go
Normal file
92
extras/utils/portunion_test.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParsePortUnion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want PortUnion
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
s: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "all 1",
|
||||
s: "all",
|
||||
want: PortUnion{{0, 65535}},
|
||||
},
|
||||
{
|
||||
name: "all 2",
|
||||
s: "*",
|
||||
want: PortUnion{{0, 65535}},
|
||||
},
|
||||
{
|
||||
name: "single port",
|
||||
s: "1234",
|
||||
want: PortUnion{{1234, 1234}},
|
||||
},
|
||||
{
|
||||
name: "multiple ports (unsorted)",
|
||||
s: "5678,1234,9012",
|
||||
want: PortUnion{{1234, 1234}, {5678, 5678}, {9012, 9012}},
|
||||
},
|
||||
{
|
||||
name: "one range",
|
||||
s: "1234-1240",
|
||||
want: PortUnion{{1234, 1240}},
|
||||
},
|
||||
{
|
||||
name: "one range (reversed)",
|
||||
s: "1240-1234",
|
||||
want: PortUnion{{1234, 1240}},
|
||||
},
|
||||
{
|
||||
name: "multiple ports and ranges (reversed, unsorted, overlapping)",
|
||||
s: "5678,1200-1236,9100-9012,1234-1240",
|
||||
want: PortUnion{{1200, 1240}, {5678, 5678}, {9012, 9100}},
|
||||
},
|
||||
{
|
||||
name: "invalid 1",
|
||||
s: "1234-",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid 2",
|
||||
s: "1234-ggez",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid 3",
|
||||
s: "233,",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid 4",
|
||||
s: "1234-1240-1250",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid 5",
|
||||
s: "-,,",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid 6",
|
||||
s: "http",
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ParsePortUnion(tt.s); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ParsePortUnion() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue