mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-02 03:57:38 +03:00
feat: udp port hopping
This commit is contained in:
parent
1ea7c515ae
commit
3e5eccd6e3
9 changed files with 1947 additions and 43 deletions
|
@ -6,13 +6,14 @@ import (
|
|||
"encoding/hex"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/apernet/hysteria/extras/transport/udphop"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
|
@ -21,6 +22,7 @@ import (
|
|||
"github.com/apernet/hysteria/app/internal/http"
|
||||
"github.com/apernet/hysteria/app/internal/socks5"
|
||||
"github.com/apernet/hysteria/app/internal/tproxy"
|
||||
"github.com/apernet/hysteria/app/internal/url"
|
||||
"github.com/apernet/hysteria/app/internal/utils"
|
||||
"github.com/apernet/hysteria/core/client"
|
||||
"github.com/apernet/hysteria/extras/obfs"
|
||||
|
@ -49,6 +51,7 @@ func initClientFlags() {
|
|||
type clientConfig struct {
|
||||
Server string `mapstructure:"server"`
|
||||
Auth string `mapstructure:"auth"`
|
||||
Transport clientConfigTransport `mapstructure:"transport"`
|
||||
Obfs clientConfigObfs `mapstructure:"obfs"`
|
||||
TLS clientConfigTLS `mapstructure:"tls"`
|
||||
QUIC clientConfigQUIC `mapstructure:"quic"`
|
||||
|
@ -63,6 +66,15 @@ type clientConfig struct {
|
|||
UDPTProxy *udpTProxyConfig `mapstructure:"udpTProxy"`
|
||||
}
|
||||
|
||||
type clientConfigTransportUDP struct {
|
||||
HopInterval time.Duration `mapstructure:"hopInterval"`
|
||||
}
|
||||
|
||||
type clientConfigTransport struct {
|
||||
Type string `mapstructure:"type"`
|
||||
UDP clientConfigTransportUDP `mapstructure:"udp"`
|
||||
}
|
||||
|
||||
type clientConfigObfsSalamander struct {
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
|
@ -128,34 +140,18 @@ type udpTProxyConfig struct {
|
|||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
}
|
||||
|
||||
func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error {
|
||||
switch strings.ToLower(c.Obfs.Type) {
|
||||
case "", "plain":
|
||||
// Default, do nothing
|
||||
return nil
|
||||
case "salamander":
|
||||
ob, err := obfs.NewSalamanderObfuscator([]byte(c.Obfs.Salamander.Password))
|
||||
if err != nil {
|
||||
return configError{Field: "obfs.salamander.password", Err: err}
|
||||
}
|
||||
hyConfig.ConnFactory = &obfsConnFactory{
|
||||
NewFunc: func(addr net.Addr) (net.PacketConn, error) {
|
||||
return net.ListenUDP("udp", nil)
|
||||
},
|
||||
Obfuscator: ob,
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return configError{Field: "obfs.type", Err: errors.New("unsupported obfuscation type")}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientConfig) fillServerAddr(hyConfig *client.Config) error {
|
||||
if c.Server == "" {
|
||||
return configError{Field: "server", Err: errors.New("server address is empty")}
|
||||
}
|
||||
host, hostPort := parseServerAddrString(c.Server)
|
||||
addr, err := net.ResolveUDPAddr("udp", hostPort)
|
||||
var addr net.Addr
|
||||
var err error
|
||||
host, port, hostPort := parseServerAddrString(c.Server)
|
||||
if !isPortHoppingPort(port) {
|
||||
addr, err = net.ResolveUDPAddr("udp", hostPort)
|
||||
} else {
|
||||
addr, err = udphop.ResolveUDPHopAddr(hostPort)
|
||||
}
|
||||
if err != nil {
|
||||
return configError{Field: "server", Err: err}
|
||||
}
|
||||
|
@ -168,6 +164,47 @@ func (c *clientConfig) fillServerAddr(hyConfig *client.Config) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// fillConnFactory must be called after fillServerAddr, as we have different logic
|
||||
// for ConnFactory depending on whether we have a port hopping address.
|
||||
func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error {
|
||||
// Inner PacketConn
|
||||
var newFunc func(addr net.Addr) (net.PacketConn, error)
|
||||
switch strings.ToLower(c.Transport.Type) {
|
||||
case "", "udp":
|
||||
if hyConfig.ServerAddr.Network() == "udphop" {
|
||||
hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr)
|
||||
newFunc = func(addr net.Addr) (net.PacketConn, error) {
|
||||
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval)
|
||||
}
|
||||
} else {
|
||||
newFunc = func(addr net.Addr) (net.PacketConn, error) {
|
||||
return net.ListenUDP("udp", nil)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return configError{Field: "transport.type", Err: errors.New("unsupported transport type")}
|
||||
}
|
||||
// Obfuscation
|
||||
var ob obfs.Obfuscator
|
||||
var err error
|
||||
switch strings.ToLower(c.Obfs.Type) {
|
||||
case "", "plain":
|
||||
// Keep it nil
|
||||
case "salamander":
|
||||
ob, err = obfs.NewSalamanderObfuscator([]byte(c.Obfs.Salamander.Password))
|
||||
if err != nil {
|
||||
return configError{Field: "obfs.salamander.password", Err: err}
|
||||
}
|
||||
default:
|
||||
return configError{Field: "obfs.type", Err: errors.New("unsupported obfuscation type")}
|
||||
}
|
||||
hyConfig.ConnFactory = &adaptiveConnFactory{
|
||||
NewFunc: newFunc,
|
||||
Obfuscator: ob,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientConfig) fillAuth(hyConfig *client.Config) error {
|
||||
hyConfig.Auth = c.Auth
|
||||
return nil
|
||||
|
@ -330,8 +367,8 @@ func (c *clientConfig) Config() (*client.Config, error) {
|
|||
c.parseURI()
|
||||
hyConfig := &client.Config{}
|
||||
fillers := []func(*client.Config) error{
|
||||
c.fillConnFactory,
|
||||
c.fillServerAddr,
|
||||
c.fillConnFactory,
|
||||
c.fillAuth,
|
||||
c.fillTLSConfig,
|
||||
c.fillQUICConfig,
|
||||
|
@ -596,12 +633,18 @@ func clientUDPTProxy(config udpTProxyConfig, c client.Client) error {
|
|||
|
||||
// parseServerAddrString parses server address string.
|
||||
// Server address can be in either "host:port" or "host" format (in which case we assume port 443).
|
||||
func parseServerAddrString(addrStr string) (host, hostPort string) {
|
||||
h, _, err := net.SplitHostPort(addrStr)
|
||||
func parseServerAddrString(addrStr string) (host, port, hostPort string) {
|
||||
h, p, err := net.SplitHostPort(addrStr)
|
||||
if err != nil {
|
||||
return addrStr, net.JoinHostPort(addrStr, "443")
|
||||
return addrStr, "443", net.JoinHostPort(addrStr, "443")
|
||||
}
|
||||
return h, addrStr
|
||||
return h, p, addrStr
|
||||
}
|
||||
|
||||
// isPortHoppingPort returns whether the port string is a port hopping port.
|
||||
// We consider a port string to be a port hopping port if it contains "-" or ",".
|
||||
func isPortHoppingPort(port string) bool {
|
||||
return strings.Contains(port, "-") || strings.Contains(port, ",")
|
||||
}
|
||||
|
||||
// normalizeCertHash normalizes a certificate hash string.
|
||||
|
@ -613,18 +656,21 @@ func normalizeCertHash(hash string) string {
|
|||
return r
|
||||
}
|
||||
|
||||
// obfsConnFactory adds obfuscation to a function that creates net.PacketConn.
|
||||
type obfsConnFactory struct {
|
||||
type adaptiveConnFactory struct {
|
||||
NewFunc func(addr net.Addr) (net.PacketConn, error)
|
||||
Obfuscator obfs.Obfuscator
|
||||
Obfuscator obfs.Obfuscator // nil if no obfuscation
|
||||
}
|
||||
|
||||
func (f *obfsConnFactory) New(addr net.Addr) (net.PacketConn, error) {
|
||||
conn, err := f.NewFunc(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (f *adaptiveConnFactory) New(addr net.Addr) (net.PacketConn, error) {
|
||||
if f.Obfuscator == nil {
|
||||
return f.NewFunc(addr)
|
||||
} else {
|
||||
conn, err := f.NewFunc(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return obfs.WrapPacketConn(conn, f.Obfuscator), nil
|
||||
}
|
||||
return obfs.WrapPacketConn(conn, f.Obfuscator), nil
|
||||
}
|
||||
|
||||
func connectLog(count int) {
|
||||
|
|
|
@ -20,6 +20,12 @@ func TestClientConfig(t *testing.T) {
|
|||
assert.Equal(t, config, clientConfig{
|
||||
Server: "example.com",
|
||||
Auth: "weak_ahh_password",
|
||||
Transport: clientConfigTransport{
|
||||
Type: "udp",
|
||||
UDP: clientConfigTransportUDP{
|
||||
HopInterval: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
Obfs: clientConfigObfs{
|
||||
Type: "salamander",
|
||||
Salamander: clientConfigObfsSalamander{
|
||||
|
@ -98,13 +104,21 @@ func TestClientConfigURI(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
uri: "hysteria2://john:wick@continental.org/",
|
||||
uri: "hysteria2://john:wick@continental.org:4443/",
|
||||
uriOK: true,
|
||||
config: &clientConfig{
|
||||
Server: "continental.org",
|
||||
Server: "continental.org:4443",
|
||||
Auth: "john:wick",
|
||||
},
|
||||
},
|
||||
{
|
||||
uri: "hysteria2://saul@better.call:7000-10000,20000/",
|
||||
uriOK: true,
|
||||
config: &clientConfig{
|
||||
Server: "better.call:7000-10000,20000",
|
||||
Auth: "saul",
|
||||
},
|
||||
},
|
||||
{
|
||||
uri: "hysteria2://noauth.com/?insecure=1&obfs=salamander&obfs-password=66ccff&pinSHA256=deadbeef&sni=crap.cc",
|
||||
uriOK: true,
|
||||
|
|
|
@ -2,6 +2,11 @@ server: example.com
|
|||
|
||||
auth: weak_ahh_password
|
||||
|
||||
transport:
|
||||
type: udp
|
||||
udp:
|
||||
hopInterval: 30s
|
||||
|
||||
obfs:
|
||||
type: salamander
|
||||
salamander:
|
||||
|
|
1270
app/internal/url/url.go
Normal file
1270
app/internal/url/url.go
Normal file
File diff suppressed because it is too large
Load diff
91
app/internal/url/url_test.go
Normal file
91
app/internal/url/url_test.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package url
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
type args struct {
|
||||
rawURL string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *URL
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no port",
|
||||
args: args{
|
||||
rawURL: "hysteria2://ganggang@icecreamsogood/",
|
||||
},
|
||||
want: &URL{
|
||||
Scheme: "hysteria2",
|
||||
User: User("ganggang"),
|
||||
Host: "icecreamsogood",
|
||||
Path: "/",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single port",
|
||||
args: args{
|
||||
rawURL: "hysteria2://yesyes@icecreamsogood:8888/",
|
||||
},
|
||||
want: &URL{
|
||||
Scheme: "hysteria2",
|
||||
User: User("yesyes"),
|
||||
Host: "icecreamsogood:8888",
|
||||
Path: "/",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi port",
|
||||
args: args{
|
||||
rawURL: "hysteria2://darkness@laplus.org:8888,9999,11111/",
|
||||
},
|
||||
want: &URL{
|
||||
Scheme: "hysteria2",
|
||||
User: User("darkness"),
|
||||
Host: "laplus.org:8888,9999,11111",
|
||||
Path: "/",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "range port",
|
||||
args: args{
|
||||
rawURL: "hysteria2://darkness@laplus.org:8888-9999/",
|
||||
},
|
||||
want: &URL{
|
||||
Scheme: "hysteria2",
|
||||
User: User("darkness"),
|
||||
Host: "laplus.org:8888-9999",
|
||||
Path: "/",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "both",
|
||||
args: args{
|
||||
rawURL: "hysteria2://gawr:gura@atlantis.moe:443,7788-8899,10010/",
|
||||
},
|
||||
want: &URL{
|
||||
Scheme: "hysteria2",
|
||||
User: UserPassword("gawr", "gura"),
|
||||
Host: "atlantis.moe:443,7788-8899,10010",
|
||||
Path: "/",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Parse(tt.args.rawURL)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Parse() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -7,7 +7,7 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
const udpBufSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough
|
||||
const udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough
|
||||
|
||||
// Obfuscator is the interface that wraps the Obfuscate and Deobfuscate methods.
|
||||
// Both methods return the number of bytes written to out.
|
||||
|
@ -45,8 +45,8 @@ func WrapPacketConn(conn net.PacketConn, obfs Obfuscator) net.PacketConn {
|
|||
opc := &obfsPacketConn{
|
||||
Conn: conn,
|
||||
Obfs: obfs,
|
||||
readBuf: make([]byte, udpBufSize),
|
||||
writeBuf: make([]byte, udpBufSize),
|
||||
readBuf: make([]byte, udpBufferSize),
|
||||
writeBuf: make([]byte, udpBufferSize),
|
||||
}
|
||||
if udpConn, ok := conn.(*net.UDPConn); ok {
|
||||
return &obfsPacketConnUDP{
|
||||
|
|
86
extras/transport/udphop/addr.go
Normal file
86
extras/transport/udphop/addr.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package udphop
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrInvalidPort = errors.New("invalid port")
|
||||
|
||||
// UDPHopAddr contains an IP address and a list of ports.
|
||||
type UDPHopAddr struct {
|
||||
IP net.IP
|
||||
Ports []uint16
|
||||
PortStr string
|
||||
}
|
||||
|
||||
func (a *UDPHopAddr) Network() string {
|
||||
return "udphop"
|
||||
}
|
||||
|
||||
func (a *UDPHopAddr) String() string {
|
||||
return net.JoinHostPort(a.IP.String(), a.PortStr)
|
||||
}
|
||||
|
||||
// addrs returns a list of net.Addr's, one for each port.
|
||||
func (a *UDPHopAddr) addrs() ([]net.Addr, error) {
|
||||
var addrs []net.Addr
|
||||
for _, port := range a.Ports {
|
||||
addr := &net.UDPAddr{
|
||||
IP: a.IP,
|
||||
Port: int(port),
|
||||
}
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func ResolveUDPHopAddr(addr string) (*UDPHopAddr, error) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip, err := net.ResolveIPAddr("ip", host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := &UDPHopAddr{
|
||||
IP: ip.IP,
|
||||
PortStr: portStr,
|
||||
}
|
||||
|
||||
portStrs := strings.Split(portStr, ",")
|
||||
for _, portStr := range portStrs {
|
||||
if strings.Contains(portStr, "-") {
|
||||
// Port range
|
||||
portRange := strings.Split(portStr, "-")
|
||||
if len(portRange) != 2 {
|
||||
return nil, ErrInvalidPort
|
||||
}
|
||||
start, err := strconv.ParseUint(portRange[0], 10, 16)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidPort
|
||||
}
|
||||
end, err := strconv.ParseUint(portRange[1], 10, 16)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidPort
|
||||
}
|
||||
if start > end {
|
||||
start, end = end, start
|
||||
}
|
||||
for i := start; i <= end; i++ {
|
||||
result.Ports = append(result.Ports, uint16(i))
|
||||
}
|
||||
} else {
|
||||
// Single port
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidPort
|
||||
}
|
||||
result.Ports = append(result.Ports, uint16(port))
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
132
extras/transport/udphop/addr_test.go
Normal file
132
extras/transport/udphop/addr_test.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
package udphop
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveUDPHopAddr(t *testing.T) {
|
||||
type args struct {
|
||||
addr string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *UDPHopAddr
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
addr: "",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no port",
|
||||
args: args{
|
||||
addr: "8.8.8.8",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "single port",
|
||||
args: args{
|
||||
addr: "8.8.4.4:1234",
|
||||
},
|
||||
want: &UDPHopAddr{
|
||||
IP: net.ParseIP("8.8.4.4"),
|
||||
Ports: []uint16{1234},
|
||||
PortStr: "1234",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple ports",
|
||||
args: args{
|
||||
addr: "8.8.3.3:1234,5678,9012",
|
||||
},
|
||||
want: &UDPHopAddr{
|
||||
IP: net.ParseIP("8.8.3.3"),
|
||||
Ports: []uint16{1234, 5678, 9012},
|
||||
PortStr: "1234,5678,9012",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port range",
|
||||
args: args{
|
||||
addr: "1.2.3.4:1234-1240",
|
||||
},
|
||||
want: &UDPHopAddr{
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Ports: []uint16{1234, 1235, 1236, 1237, 1238, 1239, 1240},
|
||||
PortStr: "1234-1240",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port range reversed",
|
||||
args: args{
|
||||
addr: "123.123.123.123:9990-9980",
|
||||
},
|
||||
want: &UDPHopAddr{
|
||||
IP: net.ParseIP("123.123.123.123"),
|
||||
Ports: []uint16{9980, 9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988, 9989, 9990},
|
||||
PortStr: "9990-9980",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port range & port list",
|
||||
args: args{
|
||||
addr: "9.9.9.9:1234-1236,5678,9012",
|
||||
},
|
||||
want: &UDPHopAddr{
|
||||
IP: net.ParseIP("9.9.9.9"),
|
||||
Ports: []uint16{1234, 1235, 1236, 5678, 9012},
|
||||
PortStr: "1234-1236,5678,9012",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid port",
|
||||
args: args{
|
||||
addr: "5.5.5.5:1234,bs",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid port range 1",
|
||||
args: args{
|
||||
addr: "6.6.6.6:7788-bbss",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid port range 2",
|
||||
args: args{
|
||||
addr: "1.0.0.1:8899-9002-9005",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ResolveUDPHopAddr(tt.args.addr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseUDPHopAddr() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ParseUDPHopAddr() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
260
extras/transport/udphop/conn.go
Normal file
260
extras/transport/udphop/conn.go
Normal file
|
@ -0,0 +1,260 @@
|
|||
package udphop
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
packetQueueSize = 1024
|
||||
udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough
|
||||
|
||||
defaultHopInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
type udpHopPacketConn struct {
|
||||
Addr net.Addr
|
||||
Addrs []net.Addr
|
||||
HopInterval time.Duration
|
||||
|
||||
connMutex sync.RWMutex
|
||||
prevConn net.PacketConn
|
||||
currentConn net.PacketConn
|
||||
addrIndex int
|
||||
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
|
||||
recvQueue chan *udpPacket
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
|
||||
bufPool sync.Pool
|
||||
}
|
||||
|
||||
type udpPacket struct {
|
||||
Buf []byte
|
||||
N int
|
||||
Addr net.Addr
|
||||
}
|
||||
|
||||
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration) (net.PacketConn, error) {
|
||||
if hopInterval == 0 {
|
||||
hopInterval = defaultHopInterval
|
||||
} else if hopInterval < 5*time.Second {
|
||||
return nil, errors.New("hop interval must be at least 5 seconds")
|
||||
}
|
||||
addrs, err := addr.addrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
curConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hConn := &udpHopPacketConn{
|
||||
Addr: addr,
|
||||
Addrs: addrs,
|
||||
HopInterval: hopInterval,
|
||||
prevConn: nil,
|
||||
currentConn: curConn,
|
||||
addrIndex: rand.Intn(len(addrs)),
|
||||
recvQueue: make(chan *udpPacket, packetQueueSize),
|
||||
closeChan: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, udpBufferSize)
|
||||
},
|
||||
},
|
||||
}
|
||||
go hConn.recvLoop(curConn)
|
||||
go hConn.hopLoop()
|
||||
return hConn, nil
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) recvLoop(conn net.PacketConn) {
|
||||
for {
|
||||
buf := u.bufPool.Get().([]byte)
|
||||
n, addr, err := conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case u.recvQueue <- &udpPacket{buf, n, addr}:
|
||||
default:
|
||||
// Queue is full, drop the packet
|
||||
u.bufPool.Put(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) hopLoop() {
|
||||
ticker := time.NewTicker(u.HopInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
u.hop()
|
||||
case <-u.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) hop() {
|
||||
u.connMutex.Lock()
|
||||
defer u.connMutex.Unlock()
|
||||
if u.closed {
|
||||
return
|
||||
}
|
||||
newConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
// Could be temporary, just skip this hop
|
||||
return
|
||||
}
|
||||
// We need to keep receiving packets from the previous connection,
|
||||
// because otherwise there will be packet loss due to the time gap
|
||||
// between we hop to a new port and the server acknowledges this change.
|
||||
// So we do the following:
|
||||
// Close prevConn,
|
||||
// move currentConn to prevConn,
|
||||
// set newConn as currentConn,
|
||||
// start recvLoop on newConn.
|
||||
if u.prevConn != nil {
|
||||
_ = u.prevConn.Close() // recvLoop for this conn will exit
|
||||
}
|
||||
u.prevConn = u.currentConn
|
||||
u.currentConn = newConn
|
||||
// Set buffer sizes if previously set
|
||||
if u.readBufferSize > 0 {
|
||||
_ = trySetReadBuffer(u.currentConn, u.readBufferSize)
|
||||
}
|
||||
if u.writeBufferSize > 0 {
|
||||
_ = trySetWriteBuffer(u.currentConn, u.writeBufferSize)
|
||||
}
|
||||
go u.recvLoop(newConn)
|
||||
// Update addrIndex to a new random value
|
||||
u.addrIndex = rand.Intn(len(u.Addrs))
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
for {
|
||||
select {
|
||||
case p := <-u.recvQueue:
|
||||
// Currently we do not check whether the packet is from
|
||||
// the server or not due to performance reasons.
|
||||
n := copy(b, p.Buf[:p.N])
|
||||
u.bufPool.Put(p.Buf)
|
||||
return n, u.Addr, nil
|
||||
case <-u.closeChan:
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
u.connMutex.RLock()
|
||||
defer u.connMutex.RUnlock()
|
||||
if u.closed {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
// Skip the check for now, always write to the server,
|
||||
// for the same reason as in ReadFrom.
|
||||
return u.currentConn.WriteTo(b, u.Addrs[u.addrIndex])
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) Close() error {
|
||||
u.connMutex.Lock()
|
||||
defer u.connMutex.Unlock()
|
||||
if u.closed {
|
||||
return nil
|
||||
}
|
||||
// Close prevConn and currentConn
|
||||
// Close closeChan to unblock ReadFrom & hopLoop
|
||||
// Set closed flag to true to prevent double close
|
||||
if u.prevConn != nil {
|
||||
_ = u.prevConn.Close()
|
||||
}
|
||||
err := u.currentConn.Close()
|
||||
close(u.closeChan)
|
||||
u.closed = true
|
||||
u.Addrs = nil // For GC
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) LocalAddr() net.Addr {
|
||||
u.connMutex.RLock()
|
||||
defer u.connMutex.RUnlock()
|
||||
return u.currentConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) SetDeadline(t time.Time) error {
|
||||
// Not implemented
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) SetReadDeadline(t time.Time) error {
|
||||
// Not implemented
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) SetWriteDeadline(t time.Time) error {
|
||||
// Not implemented
|
||||
return nil
|
||||
}
|
||||
|
||||
// UDP-specific methods below
|
||||
|
||||
func (u *udpHopPacketConn) SetReadBuffer(bytes int) error {
|
||||
u.connMutex.Lock()
|
||||
defer u.connMutex.Unlock()
|
||||
u.readBufferSize = bytes
|
||||
if u.prevConn != nil {
|
||||
_ = trySetReadBuffer(u.prevConn, bytes)
|
||||
}
|
||||
return trySetReadBuffer(u.currentConn, bytes)
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) SetWriteBuffer(bytes int) error {
|
||||
u.connMutex.Lock()
|
||||
defer u.connMutex.Unlock()
|
||||
u.writeBufferSize = bytes
|
||||
if u.prevConn != nil {
|
||||
_ = trySetWriteBuffer(u.prevConn, bytes)
|
||||
}
|
||||
return trySetWriteBuffer(u.currentConn, bytes)
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) SyscallConn() (syscall.RawConn, error) {
|
||||
u.connMutex.RLock()
|
||||
defer u.connMutex.RUnlock()
|
||||
sc, ok := u.currentConn.(syscall.Conn)
|
||||
if !ok {
|
||||
return nil, errors.New("not supported")
|
||||
}
|
||||
return sc.SyscallConn()
|
||||
}
|
||||
|
||||
func trySetReadBuffer(pc net.PacketConn, bytes int) error {
|
||||
sc, ok := pc.(interface {
|
||||
SetReadBuffer(bytes int) error
|
||||
})
|
||||
if ok {
|
||||
return sc.SetReadBuffer(bytes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func trySetWriteBuffer(pc net.PacketConn, bytes int) error {
|
||||
sc, ok := pc.(interface {
|
||||
SetWriteBuffer(bytes int) error
|
||||
})
|
||||
if ok {
|
||||
return sc.SetWriteBuffer(bytes)
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue