From 3e34da1aa87f8409307e2ade59f85826f4236849 Mon Sep 17 00:00:00 2001 From: Haruue Date: Fri, 5 Apr 2024 02:20:45 +0800 Subject: [PATCH] refactor: protect => quic.sockopts Android's VpnService.protect() itself is confusing, so we rename the "protect" feature with the name `fdControlUnixSocket` and make it a sub-option under `quic.sockopts`. A unit test is added to make sure the protect feature works. I also added two other common options to `quic.sockopts` that I copied from my other projects but did not fully test here. --- app/cmd/client.go | 43 ++++++--- .../sockopts/fd_control_unix_socket_test.py | 65 +++++++++++++ app/internal/sockopts/sockopts.go | 76 ++++++++++++++++ app/internal/sockopts/sockopts_linux.go | 91 +++++++++++++++++++ app/internal/sockopts/sockopts_linux_test.go | 53 +++++++++++ extras/protect/protect.go | 9 -- extras/protect/protect_linux.go | 78 ---------------- extras/protect/protect_stub.go | 13 --- extras/transport/udphop/conn.go | 8 +- 9 files changed, 321 insertions(+), 115 deletions(-) create mode 100644 app/internal/sockopts/fd_control_unix_socket_test.py create mode 100644 app/internal/sockopts/sockopts.go create mode 100644 app/internal/sockopts/sockopts_linux.go create mode 100644 app/internal/sockopts/sockopts_linux_test.go delete mode 100644 extras/protect/protect.go delete mode 100644 extras/protect/protect_linux.go delete mode 100644 extras/protect/protect_stub.go diff --git a/app/cmd/client.go b/app/cmd/client.go index 78071a5..c230140 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -22,6 +22,7 @@ import ( "github.com/apernet/hysteria/app/internal/forwarding" "github.com/apernet/hysteria/app/internal/http" "github.com/apernet/hysteria/app/internal/redirect" + "github.com/apernet/hysteria/app/internal/sockopts" "github.com/apernet/hysteria/app/internal/socks5" "github.com/apernet/hysteria/app/internal/tproxy" "github.com/apernet/hysteria/app/internal/tun" @@ -30,7 +31,6 @@ import ( "github.com/apernet/hysteria/core/client" "github.com/apernet/hysteria/extras/correctnet" "github.com/apernet/hysteria/extras/obfs" - "github.com/apernet/hysteria/extras/protect" "github.com/apernet/hysteria/extras/transport/udphop" ) @@ -56,7 +56,6 @@ func initClientFlags() { type clientConfig struct { Server string `mapstructure:"server"` - ProtectPath string `mapstructure:"protectPath"` Auth string `mapstructure:"auth"` Transport clientConfigTransport `mapstructure:"transport"` Obfs clientConfigObfs `mapstructure:"obfs"` @@ -101,13 +100,20 @@ type clientConfigTLS struct { } type clientConfigQUIC struct { - InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` - MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` - InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` - MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` - MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` - KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"` - DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` + InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` + MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` + InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` + MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` + MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` + KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"` + DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` + Sockopts clientConfigQUICSockopts `mapstructure:"sockopts"` +} + +type clientConfigQUICSockopts struct { + BindInterface *string `mapstructure:"bindInterface"` + FirewallMark *uint32 `mapstructure:"fwmark"` + FdControlUnixSocket *string `mapstructure:"fdControlUnixSocket"` } type clientConfigBandwidth struct { @@ -197,6 +203,21 @@ func (c *clientConfig) fillServerAddr(hyConfig *client.Config) error { // fillConnFactory must be called after fillServerAddr, as we have different logic // for ConnFactory depending on whether we have a port hopping address. func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error { + so := &sockopts.SocketOptions{ + BindInterface: c.QUIC.Sockopts.BindInterface, + FirewallMark: c.QUIC.Sockopts.FirewallMark, + FdControlUnixSocket: c.QUIC.Sockopts.FdControlUnixSocket, + } + if err := so.CheckSupported(); err != nil { + var unsupportedErr *sockopts.UnsupportedError + if errors.As(err, &unsupportedErr) { + return configError{ + Field: "quic.sockopts." + unsupportedErr.Field, + Err: errors.New("unsupported on this platform"), + } + } + return configError{Field: "quic.sockopts", Err: err} + } // Inner PacketConn var newFunc func(addr net.Addr) (net.PacketConn, error) switch strings.ToLower(c.Transport.Type) { @@ -204,11 +225,11 @@ func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error { if hyConfig.ServerAddr.Network() == "udphop" { hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr) newFunc = func(addr net.Addr) (net.PacketConn, error) { - return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, protect.ListenUDP(c.ProtectPath)) + return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, so.ListenUDP) } } else { newFunc = func(addr net.Addr) (net.PacketConn, error) { - return protect.ListenUDP(c.ProtectPath)() + return so.ListenUDP() } } default: diff --git a/app/internal/sockopts/fd_control_unix_socket_test.py b/app/internal/sockopts/fd_control_unix_socket_test.py new file mode 100644 index 0000000..e47a6f6 --- /dev/null +++ b/app/internal/sockopts/fd_control_unix_socket_test.py @@ -0,0 +1,65 @@ +import socket +import array +import os +import struct +import sys + + +def serve(path): + try: + os.unlink(path) + except OSError: + if os.path.exists(path): + raise + + server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server.bind(path) + server.listen() + print(f"Listening on {path}") + + try: + while True: + connection, client_address = server.accept() + print(f"Client connected") + + try: + # Receiving fd from client + fds = array.array("i") + msg, ancdata, flags, addr = connection.recvmsg(1, socket.CMSG_LEN(struct.calcsize('i'))) + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: + fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + + fd = fds[0] + + # We make a call to setsockopt(2) here, so client can verify we have received the fd + # In the real scenario, the server would set things like SO_MARK, + # we use SO_RCVBUF as it doesn't require any special capabilities. + nbytes = struct.pack("i", 2500) + fdsocket = fd_to_socket(fd) + fdsocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, nbytes) + fdsocket.close() + + # The only protocol-like thing specified in the client implementation. + connection.send(b'\x01') + finally: + connection.close() + print("Connection closed") + + except KeyboardInterrupt: + print("Exit") + + finally: + server.close() + os.unlink(path) + + +def fd_to_socket(fd): + return socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + raise ValueError("unix socket path is required") + + serve(sys.argv[1]) diff --git a/app/internal/sockopts/sockopts.go b/app/internal/sockopts/sockopts.go new file mode 100644 index 0000000..9c08922 --- /dev/null +++ b/app/internal/sockopts/sockopts.go @@ -0,0 +1,76 @@ +package sockopts + +import ( + "fmt" + "net" +) + +type SocketOptions struct { + BindInterface *string + FirewallMark *uint32 + FdControlUnixSocket *string +} + +// implemented in platform-specific files +var ( + bindInterfaceFunc func(c *net.UDPConn, device string) error + firewallMarkFunc func(c *net.UDPConn, fwmark uint32) error + fdControlUnixSocketFunc func(c *net.UDPConn, path string) error +) + +func (o *SocketOptions) CheckSupported() (err error) { + if o.BindInterface != nil && bindInterfaceFunc == nil { + return &UnsupportedError{"bindInterface"} + } + if o.FirewallMark != nil && firewallMarkFunc == nil { + return &UnsupportedError{"fwmark"} + } + if o.FdControlUnixSocket != nil && fdControlUnixSocketFunc == nil { + return &UnsupportedError{"fdControlUnixSocket"} + } + return nil +} + +type UnsupportedError struct { + Field string +} + +func (e *UnsupportedError) Error() string { + return fmt.Sprintf("%s is not supported on this platform", e.Field) +} + +func (o *SocketOptions) ListenUDP() (uconn net.PacketConn, err error) { + uconn, err = net.ListenUDP("udp", nil) + if err != nil { + return + } + err = o.applyToUDPConn(uconn.(*net.UDPConn)) + if err != nil { + uconn.Close() + uconn = nil + return + } + return +} + +func (o *SocketOptions) applyToUDPConn(c *net.UDPConn) (err error) { + if o.BindInterface != nil && bindInterfaceFunc != nil { + err = bindInterfaceFunc(c, *o.BindInterface) + if err != nil { + err = fmt.Errorf("failed to bind to interface: %w", err) + } + } + if o.FirewallMark != nil && firewallMarkFunc != nil { + err = firewallMarkFunc(c, *o.FirewallMark) + if err != nil { + err = fmt.Errorf("failed to set fwmark: %w", err) + } + } + if o.FdControlUnixSocket != nil && fdControlUnixSocketFunc != nil { + err = fdControlUnixSocketFunc(c, *o.FdControlUnixSocket) + if err != nil { + err = fmt.Errorf("failed to send fd to control unix socket: %w", err) + } + } + return +} diff --git a/app/internal/sockopts/sockopts_linux.go b/app/internal/sockopts/sockopts_linux.go new file mode 100644 index 0000000..f4d8621 --- /dev/null +++ b/app/internal/sockopts/sockopts_linux.go @@ -0,0 +1,91 @@ +//go:build linux + +package sockopts + +import ( + "fmt" + "net" + "time" + + "golang.org/x/sys/unix" +) + +const ( + fdControlUnixTimeout = 3 * time.Second +) + +func init() { + bindInterfaceFunc = bindInterfaceImpl + firewallMarkFunc = firewallMarkImpl + fdControlUnixSocketFunc = fdControlUnixSocketImpl +} + +func controlUDPConn(c *net.UDPConn, cb func(fd int) error) (err error) { + rconn, err := c.SyscallConn() + if err != nil { + return + } + cerr := rconn.Control(func(fd uintptr) { + err = cb(int(fd)) + }) + if err != nil { + return + } + if cerr != nil { + err = fmt.Errorf("failed to control fd: %w", cerr) + return + } + return +} + +func bindInterfaceImpl(c *net.UDPConn, device string) error { + return controlUDPConn(c, func(fd int) error { + return unix.BindToDevice(fd, device) + }) +} + +func firewallMarkImpl(c *net.UDPConn, fwmark uint32) error { + return controlUDPConn(c, func(fd int) error { + return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, int(fwmark)) + }) +} + +func fdControlUnixSocketImpl(c *net.UDPConn, path string) error { + return controlUDPConn(c, func(fd int) error { + socketFd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, unix.PROT_NONE) + if err != nil { + return fmt.Errorf("failed to create unix socket: %w", err) + } + defer unix.Close(socketFd) + + timeoutUsec := fdControlUnixTimeout.Microseconds() + timeout := unix.Timeval{ + Sec: timeoutUsec / 1e6, + Usec: timeoutUsec % 1e6, + } + + _ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeout) + _ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &timeout) + + err = unix.Connect(socketFd, &unix.SockaddrUnix{Name: path}) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + + err = unix.Sendmsg(socketFd, nil, unix.UnixRights(fd), nil, 0) + if err != nil { + return fmt.Errorf("failed to send: %w", err) + } + + dummy := []byte{1} + n, err := unix.Read(socketFd, dummy) + if err != nil { + return fmt.Errorf("failed to receive: %w", err) + } + if n != 1 { + return fmt.Errorf("socket closed unexpectedly") + } + + return nil + }) +} diff --git a/app/internal/sockopts/sockopts_linux_test.go b/app/internal/sockopts/sockopts_linux_test.go new file mode 100644 index 0000000..66614a4 --- /dev/null +++ b/app/internal/sockopts/sockopts_linux_test.go @@ -0,0 +1,53 @@ +//go:build linux + +package sockopts + +import ( + "net" + "os" + "os/exec" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/sys/unix" +) + +func Test_fdControlUnixSocketImpl(t *testing.T) { + sockPath := "./fd_control_unix_socket_test.sock" + defer os.Remove(sockPath) + + // Run test server + cmd := exec.Command("python", "fd_control_unix_socket_test.py", sockPath) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if !assert.NoError(t, err) { + return + } + defer cmd.Process.Kill() + + // Wait for the server to start + time.Sleep(1 * time.Second) + + so := SocketOptions{ + FdControlUnixSocket: &sockPath, + } + conn, err := so.ListenUDP() + if !assert.NoError(t, err) { + return + } + defer conn.Close() + + err = controlUDPConn(conn.(*net.UDPConn), func(fd int) (err error) { + rcvbuf, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF) + if err != nil { + return + } + // The test server called setsockopt(fd, SOL_SOCKET, SO_RCVBUF, 2500), + // and kernel will double this value for getsockopt(). + assert.Equal(t, 5000, rcvbuf) + return + }) + assert.NoError(t, err) +} diff --git a/extras/protect/protect.go b/extras/protect/protect.go deleted file mode 100644 index 1b5cd08..0000000 --- a/extras/protect/protect.go +++ /dev/null @@ -1,9 +0,0 @@ -// Package protect set VPN protect for every conns to bypass route. -package protect - -import ( - "net" -) - -// ListenUDPFunc listen UDP with VPN protect. -type ListenUDPFunc func() (net.PacketConn, error) diff --git a/extras/protect/protect_linux.go b/extras/protect/protect_linux.go deleted file mode 100644 index 003eb13..0000000 --- a/extras/protect/protect_linux.go +++ /dev/null @@ -1,78 +0,0 @@ -//go:build linux - -package protect - -import ( - "errors" - "net" - "reflect" - - "golang.org/x/sys/unix" -) - -const ( - timevalSec = 3 -) - -// protect try to connect with path by unix socket, then send the conn's fd to it. -func protect(connFd int, path string) error { - if path == "" { - return nil - } - - socketFd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, unix.PROT_NONE) - if err != nil { - return err - } - defer unix.Close(socketFd) - - _ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: timevalSec}) - _ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &unix.Timeval{Sec: timevalSec}) - - err = unix.Connect(socketFd, &unix.SockaddrUnix{Name: path}) - if err != nil { - return err - } - - err = unix.Sendmsg(socketFd, nil, unix.UnixRights(connFd), nil, 0) - if err != nil { - return err - } - - dummy := []byte{1} - n, err := unix.Read(socketFd, dummy) - if err != nil { - return err - } - if n != 1 { - return errors.New("protect failed") - } - - return nil -} - -func ListenUDP(protectPath string) ListenUDPFunc { - return func() (net.PacketConn, error) { - udpConn, err := net.ListenUDP("udp", nil) - if err != nil { - return nil, err - } - - err = protect(fdFromConn(udpConn), protectPath) - if err != nil { - _ = udpConn.Close() - return nil, err - } - - return udpConn, nil - } -} - -// fdFromConn get net.Conn's file descriptor. -func fdFromConn(conn net.Conn) int { - v := reflect.ValueOf(conn) - netFD := reflect.Indirect(reflect.Indirect(v).FieldByName("fd")) - pfd := reflect.Indirect(netFD.FieldByName("pfd")) - fd := int(pfd.FieldByName("Sysfd").Int()) - return fd -} diff --git a/extras/protect/protect_stub.go b/extras/protect/protect_stub.go deleted file mode 100644 index a8da279..0000000 --- a/extras/protect/protect_stub.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build !linux - -package protect - -import ( - "net" -) - -func ListenUDP(protectPath string) ListenUDPFunc { - return func() (net.PacketConn, error) { - return net.ListenUDP("udp", nil) - } -} diff --git a/extras/transport/udphop/conn.go b/extras/transport/udphop/conn.go index 722aa1c..32cc31c 100644 --- a/extras/transport/udphop/conn.go +++ b/extras/transport/udphop/conn.go @@ -7,8 +7,6 @@ import ( "sync" "syscall" "time" - - "github.com/apernet/hysteria/extras/protect" ) const ( @@ -22,7 +20,7 @@ type udpHopPacketConn struct { Addr net.Addr Addrs []net.Addr HopInterval time.Duration - ListenUDPFunc protect.ListenUDPFunc + ListenUDPFunc ListenUDPFunc connMutex sync.RWMutex prevConn net.PacketConn @@ -46,7 +44,9 @@ type udpPacket struct { Err error } -func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc protect.ListenUDPFunc) (net.PacketConn, error) { +type ListenUDPFunc = func() (net.PacketConn, error) + +func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc) (net.PacketConn, error) { if hopInterval == 0 { hopInterval = defaultHopInterval } else if hopInterval < 5*time.Second {