feat: allow specifying port ranges for sniffing

This commit is contained in:
Toby 2024-06-30 12:04:59 -07:00
parent b481b49a28
commit deeeafd8d7
9 changed files with 276 additions and 182 deletions

View file

@ -36,6 +36,7 @@ import (
"github.com/apernet/hysteria/extras/v2/outbounds"
"github.com/apernet/hysteria/extras/v2/sniff"
"github.com/apernet/hysteria/extras/v2/trafficlogger"
eUtils "github.com/apernet/hysteria/extras/v2/utils"
)
const (
@ -185,6 +186,8 @@ type serverConfigSniff struct {
Enable bool `mapstructure:"enable"`
Timeout time.Duration `mapstructure:"timeout"`
RewriteDomain bool `mapstructure:"rewriteDomain"`
TCPPorts string `mapstructure:"tcpPorts"`
UDPPorts string `mapstructure:"udpPorts"`
}
type serverConfigACL struct {
@ -551,10 +554,23 @@ func serverConfigOutboundHTTPToOutbound(c serverConfigOutboundHTTP) (outbounds.P
func (c *serverConfig) fillRequestHook(hyConfig *server.Config) error {
if c.Sniff.Enable {
hyConfig.RequestHook = &sniff.Sniffer{
s := &sniff.Sniffer{
Timeout: c.Sniff.Timeout,
RewriteDomain: c.Sniff.RewriteDomain,
}
if c.Sniff.TCPPorts != "" {
s.TCPPorts = eUtils.ParsePortUnion(c.Sniff.TCPPorts)
if s.TCPPorts == nil {
return configError{Field: "sniff.tcpPorts", Err: errors.New("invalid port union")}
}
}
if c.Sniff.UDPPorts != "" {
s.UDPPorts = eUtils.ParsePortUnion(c.Sniff.UDPPorts)
if s.UDPPorts == nil {
return configError{Field: "sniff.udpPorts", Err: errors.New("invalid port union")}
}
}
hyConfig.RequestHook = s
}
return nil
}

View file

@ -115,6 +115,8 @@ func TestServerConfig(t *testing.T) {
Enable: true,
Timeout: 1 * time.Second,
RewriteDomain: true,
TCPPorts: "80,443,1000-2000",
UDPPorts: "443",
},
ACL: serverConfigACL{
File: "chnroute.txt",

View file

@ -87,6 +87,8 @@ sniff:
enable: true
timeout: 1s
rewriteDomain: true
tcpPorts: 80,443,1000-2000
udpPorts: 443
acl:
file: chnroute.txt

View file

@ -5,6 +5,7 @@ import (
"io"
"net"
"net/http"
"strconv"
"strings"
"time"
@ -13,6 +14,7 @@ import (
"github.com/apernet/hysteria/core/v2/server"
quicInternal "github.com/apernet/hysteria/extras/v2/sniff/internal/quic"
"github.com/apernet/hysteria/extras/v2/utils"
)
const (
@ -29,6 +31,8 @@ var _ server.RequestHook = (*Sniffer)(nil)
type Sniffer struct {
Timeout time.Duration
RewriteDomain bool // Whether to rewrite the address even when it's already a domain
TCPPorts utils.PortUnion
UDPPorts utils.PortUnion
}
func (h *Sniffer) isDomain(addr string) bool {
@ -62,7 +66,26 @@ func (h *Sniffer) isTLS(buf []byte) bool {
func (h *Sniffer) Check(isUDP bool, reqAddr string) bool {
// @ means it's internal (e.g. speed test)
return !strings.HasPrefix(reqAddr, "@") && (h.RewriteDomain || !h.isDomain(reqAddr))
if strings.HasPrefix(reqAddr, "@") {
return false
}
host, port, err := net.SplitHostPort(reqAddr)
if err != nil {
return false
}
if !h.RewriteDomain && net.ParseIP(host) == nil {
// Is a domain and domain rewriting is disabled
return false
}
portNum, err := strconv.Atoi(port)
if err != nil {
return false
}
if isUDP {
return h.UDPPorts == nil || h.UDPPorts.Contains(uint16(portNum))
} else {
return h.TCPPorts == nil || h.TCPPorts.Contains(uint16(portNum))
}
}
func (h *Sniffer) TCP(stream quic.Stream, reqAddr *string) ([]byte, error) {

View file

@ -6,10 +6,35 @@ import (
"testing"
"time"
"github.com/apernet/hysteria/extras/v2/utils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestSnifferCheck(t *testing.T) {
sniffer := &Sniffer{
Timeout: 1 * time.Second,
RewriteDomain: false,
TCPPorts: nil, // nil = all
UDPPorts: nil, // nil = all
}
assert.True(t, sniffer.Check(false, "1.1.1.1:80"))
assert.False(t, sniffer.Check(false, "example.com:443"))
sniffer.RewriteDomain = true
assert.True(t, sniffer.Check(false, "example.com:443"))
sniffer.TCPPorts = []utils.PortRange{{80, 80}}
assert.True(t, sniffer.Check(false, "google.com:80"))
assert.False(t, sniffer.Check(false, "google.com:443"))
sniffer.UDPPorts = []utils.PortRange{{443, 443}}
assert.True(t, sniffer.Check(true, "google.com:443"))
assert.False(t, sniffer.Check(true, "google.com:80"))
}
func TestSnifferTCP(t *testing.T) {
sniffer := &Sniffer{
Timeout: 1 * time.Second,
@ -40,27 +65,16 @@ func TestSnifferTCP(t *testing.T) {
// Rewrite IP to domain
reqAddr := "111.111.111.111:80"
assert.True(t, sniffer.Check(false, reqAddr))
putback, err := sniffer.TCP(stream, &reqAddr)
assert.NoError(t, err)
assert.Equal(t, *buf, putback)
assert.Equal(t, "example.com:80", reqAddr)
// Do not rewrite if it's already a domain
index = 0
reqAddr = "gulag.cc:443"
assert.False(t, sniffer.Check(false, reqAddr))
// Turn on rewrite and now it should rewrite
sniffer.RewriteDomain = true
assert.True(t, sniffer.Check(false, reqAddr))
// Test TLS
*buf, err = base64.StdEncoding.DecodeString("FgMBARcBAAETAwPJL2jlt1OAo+Rslkjv/aqKiTthKMaCKg2Gvd+uALDbDCDdY+UIk8ouadEB9fC3j52Y1i7SJZqGIgBRIS6kKieYrAAoEwITAcAswCvAMMAvwCTAI8AowCfACsAJwBTAEwCdAJwAPQA8ADUALwEAAKIAAAAOAAwAAAlpcGluZm8uaW8ABQAFAQAAAAAAKwAJCAMEAwMDAgMBAA0AGgAYCAQIBQgGBAEFAQIBBAMFAwIDAgIGAQYDACMAAAAKAAgABgAdABcAGAAQAAsACQhodHRwLzEuMQAzACYAJAAdACBguQbqNJNyamYxYcrBFpBP7pWv5TgZsP9gwGtMYNKVBQAxAAAAFwAA/wEAAQAALQACAQE=")
assert.NoError(t, err)
index = 0
reqAddr = "222.222.222.222:443"
assert.True(t, sniffer.Check(false, reqAddr))
putback, err = sniffer.TCP(stream, &reqAddr)
assert.NoError(t, err)
assert.Equal(t, *buf, putback)
@ -70,7 +84,6 @@ func TestSnifferTCP(t *testing.T) {
*buf = []byte("Wait It's All Ohio? Always Has Been.")
index = 0
reqAddr = "123.123.123.123:123"
assert.True(t, sniffer.Check(false, reqAddr))
putback, err = sniffer.TCP(stream, &reqAddr)
assert.NoError(t, err)
assert.Equal(t, *buf, putback)
@ -80,7 +93,6 @@ func TestSnifferTCP(t *testing.T) {
*buf = []byte("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a")
index = 0
reqAddr = "45.45.45.45:45"
assert.True(t, sniffer.Check(false, reqAddr))
putback, err = sniffer.TCP(stream, &reqAddr)
assert.NoError(t, err)
assert.Equal(t, []byte("\x01\x02\x03"), putback)
@ -94,7 +106,6 @@ func TestSnifferTCP(t *testing.T) {
return 0, io.EOF
})
reqAddr = "66.66.66.66:66"
assert.True(t, sniffer.Check(false, reqAddr))
putback, err = sniffer.TCP(blockStream, &reqAddr)
assert.NoError(t, err)
assert.Equal(t, []byte{}, putback)
@ -109,7 +120,6 @@ func TestSnifferUDP(t *testing.T) {
// Test QUIC
reqAddr := "2.3.4.5:443"
assert.True(t, sniffer.Check(true, reqAddr))
pkt, err := base64.StdEncoding.DecodeString("ygAAAAEIwugWgPS7ulYAAES8hY891uwgGE9GG4CPOLd+nsDe28raso24lCSFmlFwYQG1uF39ikbL13/R9ZTghYmTl+jEbr6F9TxxRiOgpTmKRmh6aKZiIiVfy5pVRckovaI8lq0WRoW9xoFNTyYtQP8TVJ3bLCK+zUqpquEQSyWf7CE43ywayyMpE9UlIoPXFWCoopXLM1SvzdQ+17P51N9KR7m4emti4DWWTBLMQOvrwd2HEEkbiZdRO1wf6ZXJlIat5dN0R/6uod60OFPO+u+awvq67MoMReC7+5I/xWI+xx6o4JpnZNn6YPG8Gqi8hS6doNcAAdtD8h5eMLuHCCgkpX3QVjjfWtcOhtw9xKjU43HhUPwzUTv+JDLgwuTQCTmlfYlb3B+pk4b2I9si0tJ0SBuYaZ2VQPtZbj2hpGXw3gn11pbN8xsbKkQL50+Scd4dGJxWQlGaJHeaU5WOCkxLXc635z8m5XO/CBHVYPGp4pfwfwNUgbe5WF+3MaUIlDB8dMfsnrO0BmZPo379jVx0SFLTAiS8wAdHib1WNEY8qKYnTWuiyxYg1GZEhJt0nXmI+8f0eJq42DgHBWC+Rf5rRBr/Sf25o3mFAmTUaul0Woo9/CIrpT73B63N91xd9A77i4ru995YG8l9Hen+eLtpDU9Q9376nwMDYBzeYG9U/Rn0Urbm6q4hmAgV/xlNJ2rAyDS+yLnwqD6I0PRy8bZJEttcidb/SkOyrpgMiAzWeT+SO+c/k+Y8H0UTRa05faZUrhuUaym9wAcaIVRA6nFI+fejfjVp+7afFv+kWn3vCqQEij+CRHuxkltrixZMD2rfYj6NUW7TTYBtPRtuV/V0ZIDjRR26vr4K+0D84+l3c0mA/l6nmpP5kkco3nmpdjtQN6sGXL7+5o0nnsftX5d6/n5mLyEpP+AEDl1zk3iqkS62RsITwql6DMMoGbSDdUpMclCIeM0vlo3CkxGMO7QA9ruVeNddkL3EWMivl+uxO43sXEEqYQHVl4N75y63t05GOf7/gm9Kb/BJ8MpG9ViEkVYaskQCzi3D8bVpzo8FfTj8te8B6c3ikc/cm7r8k0ZcZpr+YiLGDYq+0ilHxpqJfmq8dPkSvxdzLcUSvy7+LMQ/TTobRSF7L4JhtDKck0+00vl9H35Tkh9N+MsVtpKdWyoqZ4XaK2Nx1M6AieczXpdFc0y7lYPoUfF4IeW8WzeVUclol5ElYjkyFz/lDOGAe1bF2g5AYaGWCPiGleVZknNdD5ihB8W8Mfkt1pEwq2S97AHrppqkf/VoIfZzeqH8wUFw8fDDrZIpnoa0rW7HfwIQaqJhPCyB9Z6TVbV4x9UWmaHfVAcinCK/7o10dtaj3rvEqcUC/iPceGq3Tqv/p9GGNJ+Ci2JBjXqNxYr893Llk75VdPD9pM6y1SM0P80oXNy32VMtafkFFST8GpvvqWcxUJ93kzaY8RmU1g3XFOImSU2utU6+FUQ2Pn5uLwcfT2cTYfTpPGh+WXjSbZ6trqdEMEsLHybuPo2UN4WpVLXVQma3kSaHQggcLlEip8GhEUAy/xCb2eKqhI4HkDpDjwDnDVKufWlnRaOHf58cc8Woi+WT8JTOkHC+nBEG6fKRPHDG08U5yayIQIjI")
assert.NoError(t, err)
err = sniffer.UDP(pkt, &reqAddr)
@ -119,7 +129,6 @@ func TestSnifferUDP(t *testing.T) {
// Test unrecognized
pkt = []byte("oh my sweet summer child")
reqAddr = "90.90.90.90:90"
assert.True(t, sniffer.Check(true, reqAddr))
err = sniffer.UDP(pkt, &reqAddr)
assert.NoError(t, err)
assert.Equal(t, "90.90.90.90:90", reqAddr)

View file

@ -3,8 +3,8 @@ package udphop
import (
"fmt"
"net"
"strconv"
"strings"
"github.com/apernet/hysteria/extras/v2/utils"
)
type InvalidPortError struct {
@ -57,36 +57,11 @@ func ResolveUDPHopAddr(addr string) (*UDPHopAddr, error) {
PortStr: portStr,
}
portStrs := strings.Split(portStr, ",")
for _, portStr := range portStrs {
if strings.Contains(portStr, "-") {
// Port range
portRange := strings.Split(portStr, "-")
if len(portRange) != 2 {
return nil, InvalidPortError{portStr}
}
start, err := strconv.ParseUint(portRange[0], 10, 16)
if err != nil {
return nil, InvalidPortError{portStr}
}
end, err := strconv.ParseUint(portRange[1], 10, 16)
if err != nil {
return nil, InvalidPortError{portStr}
}
if start > end {
start, end = end, start
}
for i := start; i <= end; i++ {
result.Ports = append(result.Ports, uint16(i))
}
} else {
// Single port
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, InvalidPortError{portStr}
}
result.Ports = append(result.Ports, uint16(port))
}
pu := utils.ParsePortUnion(portStr)
if pu == nil {
return nil, InvalidPortError{portStr}
}
result.Ports = pu.Ports()
return result, nil
}

View file

@ -1,132 +0,0 @@
package udphop
import (
"net"
"reflect"
"testing"
)
func TestResolveUDPHopAddr(t *testing.T) {
type args struct {
addr string
}
tests := []struct {
name string
args args
want *UDPHopAddr
wantErr bool
}{
{
name: "empty",
args: args{
addr: "",
},
want: nil,
wantErr: true,
},
{
name: "no port",
args: args{
addr: "8.8.8.8",
},
want: nil,
wantErr: true,
},
{
name: "single port",
args: args{
addr: "8.8.4.4:1234",
},
want: &UDPHopAddr{
IP: net.ParseIP("8.8.4.4"),
Ports: []uint16{1234},
PortStr: "1234",
},
wantErr: false,
},
{
name: "multiple ports",
args: args{
addr: "8.8.3.3:1234,5678,9012",
},
want: &UDPHopAddr{
IP: net.ParseIP("8.8.3.3"),
Ports: []uint16{1234, 5678, 9012},
PortStr: "1234,5678,9012",
},
wantErr: false,
},
{
name: "port range",
args: args{
addr: "1.2.3.4:1234-1240",
},
want: &UDPHopAddr{
IP: net.ParseIP("1.2.3.4"),
Ports: []uint16{1234, 1235, 1236, 1237, 1238, 1239, 1240},
PortStr: "1234-1240",
},
wantErr: false,
},
{
name: "port range reversed",
args: args{
addr: "123.123.123.123:9990-9980",
},
want: &UDPHopAddr{
IP: net.ParseIP("123.123.123.123"),
Ports: []uint16{9980, 9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988, 9989, 9990},
PortStr: "9990-9980",
},
wantErr: false,
},
{
name: "port range & port list",
args: args{
addr: "9.9.9.9:1234-1236,5678,9012",
},
want: &UDPHopAddr{
IP: net.ParseIP("9.9.9.9"),
Ports: []uint16{1234, 1235, 1236, 5678, 9012},
PortStr: "1234-1236,5678,9012",
},
wantErr: false,
},
{
name: "invalid port",
args: args{
addr: "5.5.5.5:1234,bs",
},
want: nil,
wantErr: true,
},
{
name: "invalid port range 1",
args: args{
addr: "6.6.6.6:7788-bbss",
},
want: nil,
wantErr: true,
},
{
name: "invalid port range 2",
args: args{
addr: "1.0.0.1:8899-9002-9005",
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ResolveUDPHopAddr(tt.args.addr)
if (err != nil) != tt.wantErr {
t.Errorf("ParseUDPHopAddr() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParseUDPHopAddr() got = %v, want %v", got, tt.want)
}
})
}
}

107
extras/utils/portunion.go Normal file
View file

@ -0,0 +1,107 @@
package utils
import (
"sort"
"strconv"
"strings"
)
// PortUnion is a collection of multiple port ranges.
type PortUnion []PortRange
// PortRange represents a range of ports.
// Start and End are inclusive. [Start, End]
type PortRange struct {
Start, End uint16
}
// ParsePortUnion parses a string of comma-separated port ranges (or single ports) into a PortUnion.
// Returns nil if the input is invalid.
// The returned PortUnion is guaranteed to be normalized.
func ParsePortUnion(s string) PortUnion {
if s == "all" || s == "*" {
// Wildcard special case
return PortUnion{PortRange{0, 65535}}
}
var result PortUnion
portStrs := strings.Split(s, ",")
for _, portStr := range portStrs {
if strings.Contains(portStr, "-") {
// Port range
portRange := strings.Split(portStr, "-")
if len(portRange) != 2 {
return nil
}
start, err := strconv.ParseUint(portRange[0], 10, 16)
if err != nil {
return nil
}
end, err := strconv.ParseUint(portRange[1], 10, 16)
if err != nil {
return nil
}
if start > end {
start, end = end, start
}
result = append(result, PortRange{uint16(start), uint16(end)})
} else {
// Single port
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil
}
result = append(result, PortRange{uint16(port), uint16(port)})
}
}
if result == nil {
return nil
}
return result.Normalize()
}
// Normalize normalizes a PortUnion.
// No overlapping ranges, ranges are sorted from low to high.
func (u PortUnion) Normalize() PortUnion {
if len(u) == 0 {
return u
}
sort.Slice(u, func(i, j int) bool {
if u[i].Start == u[j].Start {
return u[i].End < u[j].End
}
return u[i].Start < u[j].Start
})
normalized := PortUnion{u[0]}
for _, current := range u[1:] {
last := &normalized[len(normalized)-1]
if current.Start <= last.End+1 {
if current.End > last.End {
last.End = current.End
}
} else {
normalized = append(normalized, current)
}
}
return normalized
}
// Ports returns all ports in the PortUnion as a slice.
func (u PortUnion) Ports() []uint16 {
var ports []uint16
for _, r := range u {
for i := r.Start; i <= r.End; i++ {
ports = append(ports, i)
}
}
return ports
}
// Contains returns true if the PortUnion contains the given port.
func (u PortUnion) Contains(port uint16) bool {
for _, r := range u {
if port >= r.Start && port <= r.End {
return true
}
}
return false
}

View file

@ -0,0 +1,92 @@
package utils
import (
"reflect"
"testing"
)
func TestParsePortUnion(t *testing.T) {
tests := []struct {
name string
s string
want PortUnion
}{
{
name: "empty",
s: "",
want: nil,
},
{
name: "all 1",
s: "all",
want: PortUnion{{0, 65535}},
},
{
name: "all 2",
s: "*",
want: PortUnion{{0, 65535}},
},
{
name: "single port",
s: "1234",
want: PortUnion{{1234, 1234}},
},
{
name: "multiple ports (unsorted)",
s: "5678,1234,9012",
want: PortUnion{{1234, 1234}, {5678, 5678}, {9012, 9012}},
},
{
name: "one range",
s: "1234-1240",
want: PortUnion{{1234, 1240}},
},
{
name: "one range (reversed)",
s: "1240-1234",
want: PortUnion{{1234, 1240}},
},
{
name: "multiple ports and ranges (reversed, unsorted, overlapping)",
s: "5678,1200-1236,9100-9012,1234-1240",
want: PortUnion{{1200, 1240}, {5678, 5678}, {9012, 9100}},
},
{
name: "invalid 1",
s: "1234-",
want: nil,
},
{
name: "invalid 2",
s: "1234-ggez",
want: nil,
},
{
name: "invalid 3",
s: "233,",
want: nil,
},
{
name: "invalid 4",
s: "1234-1240-1250",
want: nil,
},
{
name: "invalid 5",
s: "-,,",
want: nil,
},
{
name: "invalid 6",
s: "http",
want: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ParsePortUnion(tt.s); !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParsePortUnion() = %v, want %v", got, tt.want)
}
})
}
}