mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 04:27:39 +03:00
refactor: protect => quic.sockopts
Android's VpnService.protect() itself is confusing, so we rename the "protect" feature with the name `fdControlUnixSocket` and make it a sub-option under `quic.sockopts`. A unit test is added to make sure the protect feature works. I also added two other common options to `quic.sockopts` that I copied from my other projects but did not fully test here.
This commit is contained in:
parent
a05383c2a1
commit
3e34da1aa8
9 changed files with 321 additions and 115 deletions
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/apernet/hysteria/app/internal/forwarding"
|
||||
"github.com/apernet/hysteria/app/internal/http"
|
||||
"github.com/apernet/hysteria/app/internal/redirect"
|
||||
"github.com/apernet/hysteria/app/internal/sockopts"
|
||||
"github.com/apernet/hysteria/app/internal/socks5"
|
||||
"github.com/apernet/hysteria/app/internal/tproxy"
|
||||
"github.com/apernet/hysteria/app/internal/tun"
|
||||
|
@ -30,7 +31,6 @@ import (
|
|||
"github.com/apernet/hysteria/core/client"
|
||||
"github.com/apernet/hysteria/extras/correctnet"
|
||||
"github.com/apernet/hysteria/extras/obfs"
|
||||
"github.com/apernet/hysteria/extras/protect"
|
||||
"github.com/apernet/hysteria/extras/transport/udphop"
|
||||
)
|
||||
|
||||
|
@ -56,7 +56,6 @@ func initClientFlags() {
|
|||
|
||||
type clientConfig struct {
|
||||
Server string `mapstructure:"server"`
|
||||
ProtectPath string `mapstructure:"protectPath"`
|
||||
Auth string `mapstructure:"auth"`
|
||||
Transport clientConfigTransport `mapstructure:"transport"`
|
||||
Obfs clientConfigObfs `mapstructure:"obfs"`
|
||||
|
@ -101,13 +100,20 @@ type clientConfigTLS struct {
|
|||
}
|
||||
|
||||
type clientConfigQUIC struct {
|
||||
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
|
||||
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
|
||||
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
|
||||
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
|
||||
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
|
||||
KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"`
|
||||
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
|
||||
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
|
||||
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
|
||||
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
|
||||
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
|
||||
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
|
||||
KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"`
|
||||
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
|
||||
Sockopts clientConfigQUICSockopts `mapstructure:"sockopts"`
|
||||
}
|
||||
|
||||
type clientConfigQUICSockopts struct {
|
||||
BindInterface *string `mapstructure:"bindInterface"`
|
||||
FirewallMark *uint32 `mapstructure:"fwmark"`
|
||||
FdControlUnixSocket *string `mapstructure:"fdControlUnixSocket"`
|
||||
}
|
||||
|
||||
type clientConfigBandwidth struct {
|
||||
|
@ -197,6 +203,21 @@ func (c *clientConfig) fillServerAddr(hyConfig *client.Config) error {
|
|||
// fillConnFactory must be called after fillServerAddr, as we have different logic
|
||||
// for ConnFactory depending on whether we have a port hopping address.
|
||||
func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error {
|
||||
so := &sockopts.SocketOptions{
|
||||
BindInterface: c.QUIC.Sockopts.BindInterface,
|
||||
FirewallMark: c.QUIC.Sockopts.FirewallMark,
|
||||
FdControlUnixSocket: c.QUIC.Sockopts.FdControlUnixSocket,
|
||||
}
|
||||
if err := so.CheckSupported(); err != nil {
|
||||
var unsupportedErr *sockopts.UnsupportedError
|
||||
if errors.As(err, &unsupportedErr) {
|
||||
return configError{
|
||||
Field: "quic.sockopts." + unsupportedErr.Field,
|
||||
Err: errors.New("unsupported on this platform"),
|
||||
}
|
||||
}
|
||||
return configError{Field: "quic.sockopts", Err: err}
|
||||
}
|
||||
// Inner PacketConn
|
||||
var newFunc func(addr net.Addr) (net.PacketConn, error)
|
||||
switch strings.ToLower(c.Transport.Type) {
|
||||
|
@ -204,11 +225,11 @@ func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error {
|
|||
if hyConfig.ServerAddr.Network() == "udphop" {
|
||||
hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr)
|
||||
newFunc = func(addr net.Addr) (net.PacketConn, error) {
|
||||
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, protect.ListenUDP(c.ProtectPath))
|
||||
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, so.ListenUDP)
|
||||
}
|
||||
} else {
|
||||
newFunc = func(addr net.Addr) (net.PacketConn, error) {
|
||||
return protect.ListenUDP(c.ProtectPath)()
|
||||
return so.ListenUDP()
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
|
65
app/internal/sockopts/fd_control_unix_socket_test.py
Normal file
65
app/internal/sockopts/fd_control_unix_socket_test.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
import socket
|
||||
import array
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
|
||||
|
||||
def serve(path):
|
||||
try:
|
||||
os.unlink(path)
|
||||
except OSError:
|
||||
if os.path.exists(path):
|
||||
raise
|
||||
|
||||
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server.bind(path)
|
||||
server.listen()
|
||||
print(f"Listening on {path}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
connection, client_address = server.accept()
|
||||
print(f"Client connected")
|
||||
|
||||
try:
|
||||
# Receiving fd from client
|
||||
fds = array.array("i")
|
||||
msg, ancdata, flags, addr = connection.recvmsg(1, socket.CMSG_LEN(struct.calcsize('i')))
|
||||
for cmsg_level, cmsg_type, cmsg_data in ancdata:
|
||||
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
|
||||
fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
|
||||
|
||||
fd = fds[0]
|
||||
|
||||
# We make a call to setsockopt(2) here, so client can verify we have received the fd
|
||||
# In the real scenario, the server would set things like SO_MARK,
|
||||
# we use SO_RCVBUF as it doesn't require any special capabilities.
|
||||
nbytes = struct.pack("i", 2500)
|
||||
fdsocket = fd_to_socket(fd)
|
||||
fdsocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, nbytes)
|
||||
fdsocket.close()
|
||||
|
||||
# The only protocol-like thing specified in the client implementation.
|
||||
connection.send(b'\x01')
|
||||
finally:
|
||||
connection.close()
|
||||
print("Connection closed")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Exit")
|
||||
|
||||
finally:
|
||||
server.close()
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def fd_to_socket(fd):
|
||||
return socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
raise ValueError("unix socket path is required")
|
||||
|
||||
serve(sys.argv[1])
|
76
app/internal/sockopts/sockopts.go
Normal file
76
app/internal/sockopts/sockopts.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package sockopts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
type SocketOptions struct {
|
||||
BindInterface *string
|
||||
FirewallMark *uint32
|
||||
FdControlUnixSocket *string
|
||||
}
|
||||
|
||||
// implemented in platform-specific files
|
||||
var (
|
||||
bindInterfaceFunc func(c *net.UDPConn, device string) error
|
||||
firewallMarkFunc func(c *net.UDPConn, fwmark uint32) error
|
||||
fdControlUnixSocketFunc func(c *net.UDPConn, path string) error
|
||||
)
|
||||
|
||||
func (o *SocketOptions) CheckSupported() (err error) {
|
||||
if o.BindInterface != nil && bindInterfaceFunc == nil {
|
||||
return &UnsupportedError{"bindInterface"}
|
||||
}
|
||||
if o.FirewallMark != nil && firewallMarkFunc == nil {
|
||||
return &UnsupportedError{"fwmark"}
|
||||
}
|
||||
if o.FdControlUnixSocket != nil && fdControlUnixSocketFunc == nil {
|
||||
return &UnsupportedError{"fdControlUnixSocket"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type UnsupportedError struct {
|
||||
Field string
|
||||
}
|
||||
|
||||
func (e *UnsupportedError) Error() string {
|
||||
return fmt.Sprintf("%s is not supported on this platform", e.Field)
|
||||
}
|
||||
|
||||
func (o *SocketOptions) ListenUDP() (uconn net.PacketConn, err error) {
|
||||
uconn, err = net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = o.applyToUDPConn(uconn.(*net.UDPConn))
|
||||
if err != nil {
|
||||
uconn.Close()
|
||||
uconn = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (o *SocketOptions) applyToUDPConn(c *net.UDPConn) (err error) {
|
||||
if o.BindInterface != nil && bindInterfaceFunc != nil {
|
||||
err = bindInterfaceFunc(c, *o.BindInterface)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to bind to interface: %w", err)
|
||||
}
|
||||
}
|
||||
if o.FirewallMark != nil && firewallMarkFunc != nil {
|
||||
err = firewallMarkFunc(c, *o.FirewallMark)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to set fwmark: %w", err)
|
||||
}
|
||||
}
|
||||
if o.FdControlUnixSocket != nil && fdControlUnixSocketFunc != nil {
|
||||
err = fdControlUnixSocketFunc(c, *o.FdControlUnixSocket)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to send fd to control unix socket: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
91
app/internal/sockopts/sockopts_linux.go
Normal file
91
app/internal/sockopts/sockopts_linux.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
//go:build linux
|
||||
|
||||
package sockopts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
fdControlUnixTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
func init() {
|
||||
bindInterfaceFunc = bindInterfaceImpl
|
||||
firewallMarkFunc = firewallMarkImpl
|
||||
fdControlUnixSocketFunc = fdControlUnixSocketImpl
|
||||
}
|
||||
|
||||
func controlUDPConn(c *net.UDPConn, cb func(fd int) error) (err error) {
|
||||
rconn, err := c.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cerr := rconn.Control(func(fd uintptr) {
|
||||
err = cb(int(fd))
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cerr != nil {
|
||||
err = fmt.Errorf("failed to control fd: %w", cerr)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func bindInterfaceImpl(c *net.UDPConn, device string) error {
|
||||
return controlUDPConn(c, func(fd int) error {
|
||||
return unix.BindToDevice(fd, device)
|
||||
})
|
||||
}
|
||||
|
||||
func firewallMarkImpl(c *net.UDPConn, fwmark uint32) error {
|
||||
return controlUDPConn(c, func(fd int) error {
|
||||
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, int(fwmark))
|
||||
})
|
||||
}
|
||||
|
||||
func fdControlUnixSocketImpl(c *net.UDPConn, path string) error {
|
||||
return controlUDPConn(c, func(fd int) error {
|
||||
socketFd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, unix.PROT_NONE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create unix socket: %w", err)
|
||||
}
|
||||
defer unix.Close(socketFd)
|
||||
|
||||
timeoutUsec := fdControlUnixTimeout.Microseconds()
|
||||
timeout := unix.Timeval{
|
||||
Sec: timeoutUsec / 1e6,
|
||||
Usec: timeoutUsec % 1e6,
|
||||
}
|
||||
|
||||
_ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeout)
|
||||
_ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &timeout)
|
||||
|
||||
err = unix.Connect(socketFd, &unix.SockaddrUnix{Name: path})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
err = unix.Sendmsg(socketFd, nil, unix.UnixRights(fd), nil, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send: %w", err)
|
||||
}
|
||||
|
||||
dummy := []byte{1}
|
||||
n, err := unix.Read(socketFd, dummy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to receive: %w", err)
|
||||
}
|
||||
if n != 1 {
|
||||
return fmt.Errorf("socket closed unexpectedly")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
53
app/internal/sockopts/sockopts_linux_test.go
Normal file
53
app/internal/sockopts/sockopts_linux_test.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
//go:build linux
|
||||
|
||||
package sockopts
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func Test_fdControlUnixSocketImpl(t *testing.T) {
|
||||
sockPath := "./fd_control_unix_socket_test.sock"
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Run test server
|
||||
cmd := exec.Command("python", "fd_control_unix_socket_test.py", sockPath)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err := cmd.Start()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer cmd.Process.Kill()
|
||||
|
||||
// Wait for the server to start
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
so := SocketOptions{
|
||||
FdControlUnixSocket: &sockPath,
|
||||
}
|
||||
conn, err := so.ListenUDP()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = controlUDPConn(conn.(*net.UDPConn), func(fd int) (err error) {
|
||||
rcvbuf, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// The test server called setsockopt(fd, SOL_SOCKET, SO_RCVBUF, 2500),
|
||||
// and kernel will double this value for getsockopt().
|
||||
assert.Equal(t, 5000, rcvbuf)
|
||||
return
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
// Package protect set VPN protect for every conns to bypass route.
|
||||
package protect
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// ListenUDPFunc listen UDP with VPN protect.
|
||||
type ListenUDPFunc func() (net.PacketConn, error)
|
|
@ -1,78 +0,0 @@
|
|||
//go:build linux
|
||||
|
||||
package protect
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
timevalSec = 3
|
||||
)
|
||||
|
||||
// protect try to connect with path by unix socket, then send the conn's fd to it.
|
||||
func protect(connFd int, path string) error {
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
socketFd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, unix.PROT_NONE)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer unix.Close(socketFd)
|
||||
|
||||
_ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: timevalSec})
|
||||
_ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &unix.Timeval{Sec: timevalSec})
|
||||
|
||||
err = unix.Connect(socketFd, &unix.SockaddrUnix{Name: path})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = unix.Sendmsg(socketFd, nil, unix.UnixRights(connFd), nil, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dummy := []byte{1}
|
||||
n, err := unix.Read(socketFd, dummy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != 1 {
|
||||
return errors.New("protect failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ListenUDP(protectPath string) ListenUDPFunc {
|
||||
return func() (net.PacketConn, error) {
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = protect(fdFromConn(udpConn), protectPath)
|
||||
if err != nil {
|
||||
_ = udpConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// fdFromConn get net.Conn's file descriptor.
|
||||
func fdFromConn(conn net.Conn) int {
|
||||
v := reflect.ValueOf(conn)
|
||||
netFD := reflect.Indirect(reflect.Indirect(v).FieldByName("fd"))
|
||||
pfd := reflect.Indirect(netFD.FieldByName("pfd"))
|
||||
fd := int(pfd.FieldByName("Sysfd").Int())
|
||||
return fd
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
//go:build !linux
|
||||
|
||||
package protect
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func ListenUDP(protectPath string) ListenUDPFunc {
|
||||
return func() (net.PacketConn, error) {
|
||||
return net.ListenUDP("udp", nil)
|
||||
}
|
||||
}
|
|
@ -7,8 +7,6 @@ import (
|
|||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/apernet/hysteria/extras/protect"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -22,7 +20,7 @@ type udpHopPacketConn struct {
|
|||
Addr net.Addr
|
||||
Addrs []net.Addr
|
||||
HopInterval time.Duration
|
||||
ListenUDPFunc protect.ListenUDPFunc
|
||||
ListenUDPFunc ListenUDPFunc
|
||||
|
||||
connMutex sync.RWMutex
|
||||
prevConn net.PacketConn
|
||||
|
@ -46,7 +44,9 @@ type udpPacket struct {
|
|||
Err error
|
||||
}
|
||||
|
||||
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc protect.ListenUDPFunc) (net.PacketConn, error) {
|
||||
type ListenUDPFunc = func() (net.PacketConn, error)
|
||||
|
||||
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc) (net.PacketConn, error) {
|
||||
if hopInterval == 0 {
|
||||
hopInterval = defaultHopInterval
|
||||
} else if hopInterval < 5*time.Second {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue