diff --git a/app/cmd/client.go b/app/cmd/client.go index c1d04bd..78071a5 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -30,6 +30,7 @@ import ( "github.com/apernet/hysteria/core/client" "github.com/apernet/hysteria/extras/correctnet" "github.com/apernet/hysteria/extras/obfs" + "github.com/apernet/hysteria/extras/protect" "github.com/apernet/hysteria/extras/transport/udphop" ) @@ -55,6 +56,7 @@ func initClientFlags() { type clientConfig struct { Server string `mapstructure:"server"` + ProtectPath string `mapstructure:"protectPath"` Auth string `mapstructure:"auth"` Transport clientConfigTransport `mapstructure:"transport"` Obfs clientConfigObfs `mapstructure:"obfs"` @@ -202,11 +204,11 @@ func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error { if hyConfig.ServerAddr.Network() == "udphop" { hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr) newFunc = func(addr net.Addr) (net.PacketConn, error) { - return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, nil) + return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, protect.ListenUDP(c.ProtectPath)) } } else { newFunc = func(addr net.Addr) (net.PacketConn, error) { - return net.ListenUDP("udp", nil) + return protect.ListenUDP(c.ProtectPath)() } } default: diff --git a/extras/go.mod b/extras/go.mod index bc21044..a470d1e 100644 --- a/extras/go.mod +++ b/extras/go.mod @@ -11,6 +11,7 @@ require ( github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301 golang.org/x/crypto v0.19.0 golang.org/x/net v0.21.0 + golang.org/x/sys v0.17.0 google.golang.org/protobuf v1.33.0 ) @@ -28,7 +29,6 @@ require ( go.uber.org/mock v0.4.0 // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/sys v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.11.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/extras/protect/protect.go b/extras/protect/protect.go new file mode 100644 index 0000000..1b5cd08 --- /dev/null +++ b/extras/protect/protect.go @@ -0,0 +1,9 @@ +// Package protect set VPN protect for every conns to bypass route. +package protect + +import ( + "net" +) + +// ListenUDPFunc listen UDP with VPN protect. +type ListenUDPFunc func() (net.PacketConn, error) diff --git a/extras/protect/protect_linux.go b/extras/protect/protect_linux.go new file mode 100644 index 0000000..0fffe13 --- /dev/null +++ b/extras/protect/protect_linux.go @@ -0,0 +1,72 @@ +//go:build linux + +package protect + +import ( + "errors" + "net" + + "golang.org/x/sys/unix" +) + +const ( + timevalSec = 3 +) + +func protect(connFd int, path string) error { + if path == "" { + return nil + } + + socketFd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, unix.PROT_NONE) + if err != nil { + return err + } + defer unix.Close(socketFd) + + _ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: timevalSec}) + _ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &unix.Timeval{Sec: timevalSec}) + + err = unix.Connect(socketFd, &unix.SockaddrUnix{Name: path}) + if err != nil { + return err + } + + err = unix.Sendmsg(socketFd, nil, unix.UnixRights(connFd), nil, 0) + if err != nil { + return err + } + + dummy := []byte{1} + n, err := unix.Read(socketFd, dummy) + if err != nil { + return err + } + if n != 1 { + return errors.New("protect failed") + } + + return nil +} + +func ListenUDP(protectPath string) ListenUDPFunc { + return func() (net.PacketConn, error) { + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, err + } + + udpFile, err := udpConn.File() + if err != nil { + return nil, err + } + + err = protect(int(udpFile.Fd()), protectPath) + if err != nil { + _ = udpConn.Close() + return nil, err + } + + return udpConn, nil + } +} diff --git a/extras/protect/protect_stub.go b/extras/protect/protect_stub.go new file mode 100644 index 0000000..a8da279 --- /dev/null +++ b/extras/protect/protect_stub.go @@ -0,0 +1,13 @@ +//go:build !linux + +package protect + +import ( + "net" +) + +func ListenUDP(protectPath string) ListenUDPFunc { + return func() (net.PacketConn, error) { + return net.ListenUDP("udp", nil) + } +} diff --git a/extras/transport/udphop/conn.go b/extras/transport/udphop/conn.go index f20c583..722aa1c 100644 --- a/extras/transport/udphop/conn.go +++ b/extras/transport/udphop/conn.go @@ -7,6 +7,8 @@ import ( "sync" "syscall" "time" + + "github.com/apernet/hysteria/extras/protect" ) const ( @@ -20,7 +22,7 @@ type udpHopPacketConn struct { Addr net.Addr Addrs []net.Addr HopInterval time.Duration - ListenUDPFunc ListenUDPFunc + ListenUDPFunc protect.ListenUDPFunc connMutex sync.RWMutex prevConn net.PacketConn @@ -44,9 +46,7 @@ type udpPacket struct { Err error } -type ListenUDPFunc func() (net.PacketConn, error) - -func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc) (net.PacketConn, error) { +func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc protect.ListenUDPFunc) (net.PacketConn, error) { if hopInterval == 0 { hopInterval = defaultHopInterval } else if hopInterval < 5*time.Second {