mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 04:27:39 +03:00
refactor: protect => quic.sockopts
Android's VpnService.protect() itself is confusing, so we rename the "protect" feature with the name `fdControlUnixSocket` and make it a sub-option under `quic.sockopts`. A unit test is added to make sure the protect feature works. I also added two other common options to `quic.sockopts` that I copied from my other projects but did not fully test here.
This commit is contained in:
parent
a05383c2a1
commit
3e34da1aa8
9 changed files with 321 additions and 115 deletions
|
@ -1,9 +0,0 @@
|
|||
// Package protect set VPN protect for every conns to bypass route.
|
||||
package protect
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// ListenUDPFunc listen UDP with VPN protect.
|
||||
type ListenUDPFunc func() (net.PacketConn, error)
|
|
@ -1,78 +0,0 @@
|
|||
//go:build linux
|
||||
|
||||
package protect
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
timevalSec = 3
|
||||
)
|
||||
|
||||
// protect try to connect with path by unix socket, then send the conn's fd to it.
|
||||
func protect(connFd int, path string) error {
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
socketFd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, unix.PROT_NONE)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer unix.Close(socketFd)
|
||||
|
||||
_ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: timevalSec})
|
||||
_ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &unix.Timeval{Sec: timevalSec})
|
||||
|
||||
err = unix.Connect(socketFd, &unix.SockaddrUnix{Name: path})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = unix.Sendmsg(socketFd, nil, unix.UnixRights(connFd), nil, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dummy := []byte{1}
|
||||
n, err := unix.Read(socketFd, dummy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != 1 {
|
||||
return errors.New("protect failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ListenUDP(protectPath string) ListenUDPFunc {
|
||||
return func() (net.PacketConn, error) {
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = protect(fdFromConn(udpConn), protectPath)
|
||||
if err != nil {
|
||||
_ = udpConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// fdFromConn get net.Conn's file descriptor.
|
||||
func fdFromConn(conn net.Conn) int {
|
||||
v := reflect.ValueOf(conn)
|
||||
netFD := reflect.Indirect(reflect.Indirect(v).FieldByName("fd"))
|
||||
pfd := reflect.Indirect(netFD.FieldByName("pfd"))
|
||||
fd := int(pfd.FieldByName("Sysfd").Int())
|
||||
return fd
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
//go:build !linux
|
||||
|
||||
package protect
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func ListenUDP(protectPath string) ListenUDPFunc {
|
||||
return func() (net.PacketConn, error) {
|
||||
return net.ListenUDP("udp", nil)
|
||||
}
|
||||
}
|
|
@ -7,8 +7,6 @@ import (
|
|||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/apernet/hysteria/extras/protect"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -22,7 +20,7 @@ type udpHopPacketConn struct {
|
|||
Addr net.Addr
|
||||
Addrs []net.Addr
|
||||
HopInterval time.Duration
|
||||
ListenUDPFunc protect.ListenUDPFunc
|
||||
ListenUDPFunc ListenUDPFunc
|
||||
|
||||
connMutex sync.RWMutex
|
||||
prevConn net.PacketConn
|
||||
|
@ -46,7 +44,9 @@ type udpPacket struct {
|
|||
Err error
|
||||
}
|
||||
|
||||
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc protect.ListenUDPFunc) (net.PacketConn, error) {
|
||||
type ListenUDPFunc = func() (net.PacketConn, error)
|
||||
|
||||
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc) (net.PacketConn, error) {
|
||||
if hopInterval == 0 {
|
||||
hopInterval = defaultHopInterval
|
||||
} else if hopInterval < 5*time.Second {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue