From deeeafd8d7726760c66390b48e70dff9805b18e9 Mon Sep 17 00:00:00 2001 From: Toby Date: Sun, 30 Jun 2024 12:04:59 -0700 Subject: [PATCH] feat: allow specifying port ranges for sniffing --- app/cmd/server.go | 18 +++- app/cmd/server_test.go | 2 + app/cmd/server_test.yaml | 2 + extras/sniff/sniff.go | 25 ++++- extras/sniff/sniff_test.go | 41 +++++---- extras/transport/udphop/addr.go | 39 ++------ extras/transport/udphop/addr_test.go | 132 --------------------------- extras/utils/portunion.go | 107 ++++++++++++++++++++++ extras/utils/portunion_test.go | 92 +++++++++++++++++++ 9 files changed, 276 insertions(+), 182 deletions(-) delete mode 100644 extras/transport/udphop/addr_test.go create mode 100644 extras/utils/portunion.go create mode 100644 extras/utils/portunion_test.go diff --git a/app/cmd/server.go b/app/cmd/server.go index fdd9a54..b45fb15 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -36,6 +36,7 @@ import ( "github.com/apernet/hysteria/extras/v2/outbounds" "github.com/apernet/hysteria/extras/v2/sniff" "github.com/apernet/hysteria/extras/v2/trafficlogger" + eUtils "github.com/apernet/hysteria/extras/v2/utils" ) const ( @@ -185,6 +186,8 @@ type serverConfigSniff struct { Enable bool `mapstructure:"enable"` Timeout time.Duration `mapstructure:"timeout"` RewriteDomain bool `mapstructure:"rewriteDomain"` + TCPPorts string `mapstructure:"tcpPorts"` + UDPPorts string `mapstructure:"udpPorts"` } type serverConfigACL struct { @@ -551,10 +554,23 @@ func serverConfigOutboundHTTPToOutbound(c serverConfigOutboundHTTP) (outbounds.P func (c *serverConfig) fillRequestHook(hyConfig *server.Config) error { if c.Sniff.Enable { - hyConfig.RequestHook = &sniff.Sniffer{ + s := &sniff.Sniffer{ Timeout: c.Sniff.Timeout, RewriteDomain: c.Sniff.RewriteDomain, } + if c.Sniff.TCPPorts != "" { + s.TCPPorts = eUtils.ParsePortUnion(c.Sniff.TCPPorts) + if s.TCPPorts == nil { + return configError{Field: "sniff.tcpPorts", Err: errors.New("invalid port union")} + } + } + if c.Sniff.UDPPorts != "" { + s.UDPPorts = eUtils.ParsePortUnion(c.Sniff.UDPPorts) + if s.UDPPorts == nil { + return configError{Field: "sniff.udpPorts", Err: errors.New("invalid port union")} + } + } + hyConfig.RequestHook = s } return nil } diff --git a/app/cmd/server_test.go b/app/cmd/server_test.go index bd46681..bb2d12a 100644 --- a/app/cmd/server_test.go +++ b/app/cmd/server_test.go @@ -115,6 +115,8 @@ func TestServerConfig(t *testing.T) { Enable: true, Timeout: 1 * time.Second, RewriteDomain: true, + TCPPorts: "80,443,1000-2000", + UDPPorts: "443", }, ACL: serverConfigACL{ File: "chnroute.txt", diff --git a/app/cmd/server_test.yaml b/app/cmd/server_test.yaml index 343b0a9..ff0bf52 100644 --- a/app/cmd/server_test.yaml +++ b/app/cmd/server_test.yaml @@ -87,6 +87,8 @@ sniff: enable: true timeout: 1s rewriteDomain: true + tcpPorts: 80,443,1000-2000 + udpPorts: 443 acl: file: chnroute.txt diff --git a/extras/sniff/sniff.go b/extras/sniff/sniff.go index 68b3fbc..e0c94d4 100644 --- a/extras/sniff/sniff.go +++ b/extras/sniff/sniff.go @@ -5,6 +5,7 @@ import ( "io" "net" "net/http" + "strconv" "strings" "time" @@ -13,6 +14,7 @@ import ( "github.com/apernet/hysteria/core/v2/server" quicInternal "github.com/apernet/hysteria/extras/v2/sniff/internal/quic" + "github.com/apernet/hysteria/extras/v2/utils" ) const ( @@ -29,6 +31,8 @@ var _ server.RequestHook = (*Sniffer)(nil) type Sniffer struct { Timeout time.Duration RewriteDomain bool // Whether to rewrite the address even when it's already a domain + TCPPorts utils.PortUnion + UDPPorts utils.PortUnion } func (h *Sniffer) isDomain(addr string) bool { @@ -62,7 +66,26 @@ func (h *Sniffer) isTLS(buf []byte) bool { func (h *Sniffer) Check(isUDP bool, reqAddr string) bool { // @ means it's internal (e.g. speed test) - return !strings.HasPrefix(reqAddr, "@") && (h.RewriteDomain || !h.isDomain(reqAddr)) + if strings.HasPrefix(reqAddr, "@") { + return false + } + host, port, err := net.SplitHostPort(reqAddr) + if err != nil { + return false + } + if !h.RewriteDomain && net.ParseIP(host) == nil { + // Is a domain and domain rewriting is disabled + return false + } + portNum, err := strconv.Atoi(port) + if err != nil { + return false + } + if isUDP { + return h.UDPPorts == nil || h.UDPPorts.Contains(uint16(portNum)) + } else { + return h.TCPPorts == nil || h.TCPPorts.Contains(uint16(portNum)) + } } func (h *Sniffer) TCP(stream quic.Stream, reqAddr *string) ([]byte, error) { diff --git a/extras/sniff/sniff_test.go b/extras/sniff/sniff_test.go index fb86c3b..a22784e 100644 --- a/extras/sniff/sniff_test.go +++ b/extras/sniff/sniff_test.go @@ -6,10 +6,35 @@ import ( "testing" "time" + "github.com/apernet/hysteria/extras/v2/utils" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) +func TestSnifferCheck(t *testing.T) { + sniffer := &Sniffer{ + Timeout: 1 * time.Second, + RewriteDomain: false, + TCPPorts: nil, // nil = all + UDPPorts: nil, // nil = all + } + + assert.True(t, sniffer.Check(false, "1.1.1.1:80")) + assert.False(t, sniffer.Check(false, "example.com:443")) + + sniffer.RewriteDomain = true + assert.True(t, sniffer.Check(false, "example.com:443")) + + sniffer.TCPPorts = []utils.PortRange{{80, 80}} + assert.True(t, sniffer.Check(false, "google.com:80")) + assert.False(t, sniffer.Check(false, "google.com:443")) + + sniffer.UDPPorts = []utils.PortRange{{443, 443}} + assert.True(t, sniffer.Check(true, "google.com:443")) + assert.False(t, sniffer.Check(true, "google.com:80")) +} + func TestSnifferTCP(t *testing.T) { sniffer := &Sniffer{ Timeout: 1 * time.Second, @@ -40,27 +65,16 @@ func TestSnifferTCP(t *testing.T) { // Rewrite IP to domain reqAddr := "111.111.111.111:80" - assert.True(t, sniffer.Check(false, reqAddr)) putback, err := sniffer.TCP(stream, &reqAddr) assert.NoError(t, err) assert.Equal(t, *buf, putback) assert.Equal(t, "example.com:80", reqAddr) - // Do not rewrite if it's already a domain - index = 0 - reqAddr = "gulag.cc:443" - assert.False(t, sniffer.Check(false, reqAddr)) - - // Turn on rewrite and now it should rewrite - sniffer.RewriteDomain = true - assert.True(t, sniffer.Check(false, reqAddr)) - // Test TLS *buf, err = base64.StdEncoding.DecodeString("FgMBARcBAAETAwPJL2jlt1OAo+Rslkjv/aqKiTthKMaCKg2Gvd+uALDbDCDdY+UIk8ouadEB9fC3j52Y1i7SJZqGIgBRIS6kKieYrAAoEwITAcAswCvAMMAvwCTAI8AowCfACsAJwBTAEwCdAJwAPQA8ADUALwEAAKIAAAAOAAwAAAlpcGluZm8uaW8ABQAFAQAAAAAAKwAJCAMEAwMDAgMBAA0AGgAYCAQIBQgGBAEFAQIBBAMFAwIDAgIGAQYDACMAAAAKAAgABgAdABcAGAAQAAsACQhodHRwLzEuMQAzACYAJAAdACBguQbqNJNyamYxYcrBFpBP7pWv5TgZsP9gwGtMYNKVBQAxAAAAFwAA/wEAAQAALQACAQE=") assert.NoError(t, err) index = 0 reqAddr = "222.222.222.222:443" - assert.True(t, sniffer.Check(false, reqAddr)) putback, err = sniffer.TCP(stream, &reqAddr) assert.NoError(t, err) assert.Equal(t, *buf, putback) @@ -70,7 +84,6 @@ func TestSnifferTCP(t *testing.T) { *buf = []byte("Wait It's All Ohio? Always Has Been.") index = 0 reqAddr = "123.123.123.123:123" - assert.True(t, sniffer.Check(false, reqAddr)) putback, err = sniffer.TCP(stream, &reqAddr) assert.NoError(t, err) assert.Equal(t, *buf, putback) @@ -80,7 +93,6 @@ func TestSnifferTCP(t *testing.T) { *buf = []byte("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a") index = 0 reqAddr = "45.45.45.45:45" - assert.True(t, sniffer.Check(false, reqAddr)) putback, err = sniffer.TCP(stream, &reqAddr) assert.NoError(t, err) assert.Equal(t, []byte("\x01\x02\x03"), putback) @@ -94,7 +106,6 @@ func TestSnifferTCP(t *testing.T) { return 0, io.EOF }) reqAddr = "66.66.66.66:66" - assert.True(t, sniffer.Check(false, reqAddr)) putback, err = sniffer.TCP(blockStream, &reqAddr) assert.NoError(t, err) assert.Equal(t, []byte{}, putback) @@ -109,7 +120,6 @@ func TestSnifferUDP(t *testing.T) { // Test QUIC reqAddr := "2.3.4.5:443" - assert.True(t, sniffer.Check(true, reqAddr)) pkt, err := base64.StdEncoding.DecodeString("ygAAAAEIwugWgPS7ulYAAES8hY891uwgGE9GG4CPOLd+nsDe28raso24lCSFmlFwYQG1uF39ikbL13/R9ZTghYmTl+jEbr6F9TxxRiOgpTmKRmh6aKZiIiVfy5pVRckovaI8lq0WRoW9xoFNTyYtQP8TVJ3bLCK+zUqpquEQSyWf7CE43ywayyMpE9UlIoPXFWCoopXLM1SvzdQ+17P51N9KR7m4emti4DWWTBLMQOvrwd2HEEkbiZdRO1wf6ZXJlIat5dN0R/6uod60OFPO+u+awvq67MoMReC7+5I/xWI+xx6o4JpnZNn6YPG8Gqi8hS6doNcAAdtD8h5eMLuHCCgkpX3QVjjfWtcOhtw9xKjU43HhUPwzUTv+JDLgwuTQCTmlfYlb3B+pk4b2I9si0tJ0SBuYaZ2VQPtZbj2hpGXw3gn11pbN8xsbKkQL50+Scd4dGJxWQlGaJHeaU5WOCkxLXc635z8m5XO/CBHVYPGp4pfwfwNUgbe5WF+3MaUIlDB8dMfsnrO0BmZPo379jVx0SFLTAiS8wAdHib1WNEY8qKYnTWuiyxYg1GZEhJt0nXmI+8f0eJq42DgHBWC+Rf5rRBr/Sf25o3mFAmTUaul0Woo9/CIrpT73B63N91xd9A77i4ru995YG8l9Hen+eLtpDU9Q9376nwMDYBzeYG9U/Rn0Urbm6q4hmAgV/xlNJ2rAyDS+yLnwqD6I0PRy8bZJEttcidb/SkOyrpgMiAzWeT+SO+c/k+Y8H0UTRa05faZUrhuUaym9wAcaIVRA6nFI+fejfjVp+7afFv+kWn3vCqQEij+CRHuxkltrixZMD2rfYj6NUW7TTYBtPRtuV/V0ZIDjRR26vr4K+0D84+l3c0mA/l6nmpP5kkco3nmpdjtQN6sGXL7+5o0nnsftX5d6/n5mLyEpP+AEDl1zk3iqkS62RsITwql6DMMoGbSDdUpMclCIeM0vlo3CkxGMO7QA9ruVeNddkL3EWMivl+uxO43sXEEqYQHVl4N75y63t05GOf7/gm9Kb/BJ8MpG9ViEkVYaskQCzi3D8bVpzo8FfTj8te8B6c3ikc/cm7r8k0ZcZpr+YiLGDYq+0ilHxpqJfmq8dPkSvxdzLcUSvy7+LMQ/TTobRSF7L4JhtDKck0+00vl9H35Tkh9N+MsVtpKdWyoqZ4XaK2Nx1M6AieczXpdFc0y7lYPoUfF4IeW8WzeVUclol5ElYjkyFz/lDOGAe1bF2g5AYaGWCPiGleVZknNdD5ihB8W8Mfkt1pEwq2S97AHrppqkf/VoIfZzeqH8wUFw8fDDrZIpnoa0rW7HfwIQaqJhPCyB9Z6TVbV4x9UWmaHfVAcinCK/7o10dtaj3rvEqcUC/iPceGq3Tqv/p9GGNJ+Ci2JBjXqNxYr893Llk75VdPD9pM6y1SM0P80oXNy32VMtafkFFST8GpvvqWcxUJ93kzaY8RmU1g3XFOImSU2utU6+FUQ2Pn5uLwcfT2cTYfTpPGh+WXjSbZ6trqdEMEsLHybuPo2UN4WpVLXVQma3kSaHQggcLlEip8GhEUAy/xCb2eKqhI4HkDpDjwDnDVKufWlnRaOHf58cc8Woi+WT8JTOkHC+nBEG6fKRPHDG08U5yayIQIjI") assert.NoError(t, err) err = sniffer.UDP(pkt, &reqAddr) @@ -119,7 +129,6 @@ func TestSnifferUDP(t *testing.T) { // Test unrecognized pkt = []byte("oh my sweet summer child") reqAddr = "90.90.90.90:90" - assert.True(t, sniffer.Check(true, reqAddr)) err = sniffer.UDP(pkt, &reqAddr) assert.NoError(t, err) assert.Equal(t, "90.90.90.90:90", reqAddr) diff --git a/extras/transport/udphop/addr.go b/extras/transport/udphop/addr.go index 3c70472..afde26a 100644 --- a/extras/transport/udphop/addr.go +++ b/extras/transport/udphop/addr.go @@ -3,8 +3,8 @@ package udphop import ( "fmt" "net" - "strconv" - "strings" + + "github.com/apernet/hysteria/extras/v2/utils" ) type InvalidPortError struct { @@ -57,36 +57,11 @@ func ResolveUDPHopAddr(addr string) (*UDPHopAddr, error) { PortStr: portStr, } - portStrs := strings.Split(portStr, ",") - for _, portStr := range portStrs { - if strings.Contains(portStr, "-") { - // Port range - portRange := strings.Split(portStr, "-") - if len(portRange) != 2 { - return nil, InvalidPortError{portStr} - } - start, err := strconv.ParseUint(portRange[0], 10, 16) - if err != nil { - return nil, InvalidPortError{portStr} - } - end, err := strconv.ParseUint(portRange[1], 10, 16) - if err != nil { - return nil, InvalidPortError{portStr} - } - if start > end { - start, end = end, start - } - for i := start; i <= end; i++ { - result.Ports = append(result.Ports, uint16(i)) - } - } else { - // Single port - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return nil, InvalidPortError{portStr} - } - result.Ports = append(result.Ports, uint16(port)) - } + pu := utils.ParsePortUnion(portStr) + if pu == nil { + return nil, InvalidPortError{portStr} } + result.Ports = pu.Ports() + return result, nil } diff --git a/extras/transport/udphop/addr_test.go b/extras/transport/udphop/addr_test.go deleted file mode 100644 index 94a1016..0000000 --- a/extras/transport/udphop/addr_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package udphop - -import ( - "net" - "reflect" - "testing" -) - -func TestResolveUDPHopAddr(t *testing.T) { - type args struct { - addr string - } - tests := []struct { - name string - args args - want *UDPHopAddr - wantErr bool - }{ - { - name: "empty", - args: args{ - addr: "", - }, - want: nil, - wantErr: true, - }, - { - name: "no port", - args: args{ - addr: "8.8.8.8", - }, - want: nil, - wantErr: true, - }, - { - name: "single port", - args: args{ - addr: "8.8.4.4:1234", - }, - want: &UDPHopAddr{ - IP: net.ParseIP("8.8.4.4"), - Ports: []uint16{1234}, - PortStr: "1234", - }, - wantErr: false, - }, - { - name: "multiple ports", - args: args{ - addr: "8.8.3.3:1234,5678,9012", - }, - want: &UDPHopAddr{ - IP: net.ParseIP("8.8.3.3"), - Ports: []uint16{1234, 5678, 9012}, - PortStr: "1234,5678,9012", - }, - wantErr: false, - }, - { - name: "port range", - args: args{ - addr: "1.2.3.4:1234-1240", - }, - want: &UDPHopAddr{ - IP: net.ParseIP("1.2.3.4"), - Ports: []uint16{1234, 1235, 1236, 1237, 1238, 1239, 1240}, - PortStr: "1234-1240", - }, - wantErr: false, - }, - { - name: "port range reversed", - args: args{ - addr: "123.123.123.123:9990-9980", - }, - want: &UDPHopAddr{ - IP: net.ParseIP("123.123.123.123"), - Ports: []uint16{9980, 9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988, 9989, 9990}, - PortStr: "9990-9980", - }, - wantErr: false, - }, - { - name: "port range & port list", - args: args{ - addr: "9.9.9.9:1234-1236,5678,9012", - }, - want: &UDPHopAddr{ - IP: net.ParseIP("9.9.9.9"), - Ports: []uint16{1234, 1235, 1236, 5678, 9012}, - PortStr: "1234-1236,5678,9012", - }, - wantErr: false, - }, - { - name: "invalid port", - args: args{ - addr: "5.5.5.5:1234,bs", - }, - want: nil, - wantErr: true, - }, - { - name: "invalid port range 1", - args: args{ - addr: "6.6.6.6:7788-bbss", - }, - want: nil, - wantErr: true, - }, - { - name: "invalid port range 2", - args: args{ - addr: "1.0.0.1:8899-9002-9005", - }, - want: nil, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := ResolveUDPHopAddr(tt.args.addr) - if (err != nil) != tt.wantErr { - t.Errorf("ParseUDPHopAddr() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("ParseUDPHopAddr() got = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/extras/utils/portunion.go b/extras/utils/portunion.go new file mode 100644 index 0000000..20a31d0 --- /dev/null +++ b/extras/utils/portunion.go @@ -0,0 +1,107 @@ +package utils + +import ( + "sort" + "strconv" + "strings" +) + +// PortUnion is a collection of multiple port ranges. +type PortUnion []PortRange + +// PortRange represents a range of ports. +// Start and End are inclusive. [Start, End] +type PortRange struct { + Start, End uint16 +} + +// ParsePortUnion parses a string of comma-separated port ranges (or single ports) into a PortUnion. +// Returns nil if the input is invalid. +// The returned PortUnion is guaranteed to be normalized. +func ParsePortUnion(s string) PortUnion { + if s == "all" || s == "*" { + // Wildcard special case + return PortUnion{PortRange{0, 65535}} + } + var result PortUnion + portStrs := strings.Split(s, ",") + for _, portStr := range portStrs { + if strings.Contains(portStr, "-") { + // Port range + portRange := strings.Split(portStr, "-") + if len(portRange) != 2 { + return nil + } + start, err := strconv.ParseUint(portRange[0], 10, 16) + if err != nil { + return nil + } + end, err := strconv.ParseUint(portRange[1], 10, 16) + if err != nil { + return nil + } + if start > end { + start, end = end, start + } + result = append(result, PortRange{uint16(start), uint16(end)}) + } else { + // Single port + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil + } + result = append(result, PortRange{uint16(port), uint16(port)}) + } + } + if result == nil { + return nil + } + return result.Normalize() +} + +// Normalize normalizes a PortUnion. +// No overlapping ranges, ranges are sorted from low to high. +func (u PortUnion) Normalize() PortUnion { + if len(u) == 0 { + return u + } + sort.Slice(u, func(i, j int) bool { + if u[i].Start == u[j].Start { + return u[i].End < u[j].End + } + return u[i].Start < u[j].Start + }) + normalized := PortUnion{u[0]} + for _, current := range u[1:] { + last := &normalized[len(normalized)-1] + if current.Start <= last.End+1 { + if current.End > last.End { + last.End = current.End + } + } else { + normalized = append(normalized, current) + } + } + return normalized +} + +// Ports returns all ports in the PortUnion as a slice. +func (u PortUnion) Ports() []uint16 { + var ports []uint16 + for _, r := range u { + for i := r.Start; i <= r.End; i++ { + ports = append(ports, i) + } + } + return ports +} + +// Contains returns true if the PortUnion contains the given port. +func (u PortUnion) Contains(port uint16) bool { + for _, r := range u { + if port >= r.Start && port <= r.End { + return true + } + } + return false +} diff --git a/extras/utils/portunion_test.go b/extras/utils/portunion_test.go new file mode 100644 index 0000000..551bae1 --- /dev/null +++ b/extras/utils/portunion_test.go @@ -0,0 +1,92 @@ +package utils + +import ( + "reflect" + "testing" +) + +func TestParsePortUnion(t *testing.T) { + tests := []struct { + name string + s string + want PortUnion + }{ + { + name: "empty", + s: "", + want: nil, + }, + { + name: "all 1", + s: "all", + want: PortUnion{{0, 65535}}, + }, + { + name: "all 2", + s: "*", + want: PortUnion{{0, 65535}}, + }, + { + name: "single port", + s: "1234", + want: PortUnion{{1234, 1234}}, + }, + { + name: "multiple ports (unsorted)", + s: "5678,1234,9012", + want: PortUnion{{1234, 1234}, {5678, 5678}, {9012, 9012}}, + }, + { + name: "one range", + s: "1234-1240", + want: PortUnion{{1234, 1240}}, + }, + { + name: "one range (reversed)", + s: "1240-1234", + want: PortUnion{{1234, 1240}}, + }, + { + name: "multiple ports and ranges (reversed, unsorted, overlapping)", + s: "5678,1200-1236,9100-9012,1234-1240", + want: PortUnion{{1200, 1240}, {5678, 5678}, {9012, 9100}}, + }, + { + name: "invalid 1", + s: "1234-", + want: nil, + }, + { + name: "invalid 2", + s: "1234-ggez", + want: nil, + }, + { + name: "invalid 3", + s: "233,", + want: nil, + }, + { + name: "invalid 4", + s: "1234-1240-1250", + want: nil, + }, + { + name: "invalid 5", + s: "-,,", + want: nil, + }, + { + name: "invalid 6", + s: "http", + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ParsePortUnion(tt.s); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParsePortUnion() = %v, want %v", got, tt.want) + } + }) + } +}