refactor: protect => quic.sockopts

Android's VpnService.protect() itself is confusing, so we rename the
"protect" feature with the name `fdControlUnixSocket` and make it a
sub-option under `quic.sockopts`.

A unit test is added to make sure the protect feature works.

I also added two other common options to `quic.sockopts` that I copied
from my other projects but did not fully test here.
This commit is contained in:
Haruue 2024-04-05 02:20:45 +08:00
parent a05383c2a1
commit 3e34da1aa8
No known key found for this signature in database
GPG key ID: F6083B28CBCBC148
9 changed files with 321 additions and 115 deletions

View file

@ -0,0 +1,65 @@
import socket
import array
import os
import struct
import sys
def serve(path):
try:
os.unlink(path)
except OSError:
if os.path.exists(path):
raise
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server.bind(path)
server.listen()
print(f"Listening on {path}")
try:
while True:
connection, client_address = server.accept()
print(f"Client connected")
try:
# Receiving fd from client
fds = array.array("i")
msg, ancdata, flags, addr = connection.recvmsg(1, socket.CMSG_LEN(struct.calcsize('i')))
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
fd = fds[0]
# We make a call to setsockopt(2) here, so client can verify we have received the fd
# In the real scenario, the server would set things like SO_MARK,
# we use SO_RCVBUF as it doesn't require any special capabilities.
nbytes = struct.pack("i", 2500)
fdsocket = fd_to_socket(fd)
fdsocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, nbytes)
fdsocket.close()
# The only protocol-like thing specified in the client implementation.
connection.send(b'\x01')
finally:
connection.close()
print("Connection closed")
except KeyboardInterrupt:
print("Exit")
finally:
server.close()
os.unlink(path)
def fd_to_socket(fd):
return socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM)
if __name__ == "__main__":
if len(sys.argv) < 2:
raise ValueError("unix socket path is required")
serve(sys.argv[1])

View file

@ -0,0 +1,76 @@
package sockopts
import (
"fmt"
"net"
)
type SocketOptions struct {
BindInterface *string
FirewallMark *uint32
FdControlUnixSocket *string
}
// implemented in platform-specific files
var (
bindInterfaceFunc func(c *net.UDPConn, device string) error
firewallMarkFunc func(c *net.UDPConn, fwmark uint32) error
fdControlUnixSocketFunc func(c *net.UDPConn, path string) error
)
func (o *SocketOptions) CheckSupported() (err error) {
if o.BindInterface != nil && bindInterfaceFunc == nil {
return &UnsupportedError{"bindInterface"}
}
if o.FirewallMark != nil && firewallMarkFunc == nil {
return &UnsupportedError{"fwmark"}
}
if o.FdControlUnixSocket != nil && fdControlUnixSocketFunc == nil {
return &UnsupportedError{"fdControlUnixSocket"}
}
return nil
}
type UnsupportedError struct {
Field string
}
func (e *UnsupportedError) Error() string {
return fmt.Sprintf("%s is not supported on this platform", e.Field)
}
func (o *SocketOptions) ListenUDP() (uconn net.PacketConn, err error) {
uconn, err = net.ListenUDP("udp", nil)
if err != nil {
return
}
err = o.applyToUDPConn(uconn.(*net.UDPConn))
if err != nil {
uconn.Close()
uconn = nil
return
}
return
}
func (o *SocketOptions) applyToUDPConn(c *net.UDPConn) (err error) {
if o.BindInterface != nil && bindInterfaceFunc != nil {
err = bindInterfaceFunc(c, *o.BindInterface)
if err != nil {
err = fmt.Errorf("failed to bind to interface: %w", err)
}
}
if o.FirewallMark != nil && firewallMarkFunc != nil {
err = firewallMarkFunc(c, *o.FirewallMark)
if err != nil {
err = fmt.Errorf("failed to set fwmark: %w", err)
}
}
if o.FdControlUnixSocket != nil && fdControlUnixSocketFunc != nil {
err = fdControlUnixSocketFunc(c, *o.FdControlUnixSocket)
if err != nil {
err = fmt.Errorf("failed to send fd to control unix socket: %w", err)
}
}
return
}

View file

@ -0,0 +1,91 @@
//go:build linux
package sockopts
import (
"fmt"
"net"
"time"
"golang.org/x/sys/unix"
)
const (
fdControlUnixTimeout = 3 * time.Second
)
func init() {
bindInterfaceFunc = bindInterfaceImpl
firewallMarkFunc = firewallMarkImpl
fdControlUnixSocketFunc = fdControlUnixSocketImpl
}
func controlUDPConn(c *net.UDPConn, cb func(fd int) error) (err error) {
rconn, err := c.SyscallConn()
if err != nil {
return
}
cerr := rconn.Control(func(fd uintptr) {
err = cb(int(fd))
})
if err != nil {
return
}
if cerr != nil {
err = fmt.Errorf("failed to control fd: %w", cerr)
return
}
return
}
func bindInterfaceImpl(c *net.UDPConn, device string) error {
return controlUDPConn(c, func(fd int) error {
return unix.BindToDevice(fd, device)
})
}
func firewallMarkImpl(c *net.UDPConn, fwmark uint32) error {
return controlUDPConn(c, func(fd int) error {
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, int(fwmark))
})
}
func fdControlUnixSocketImpl(c *net.UDPConn, path string) error {
return controlUDPConn(c, func(fd int) error {
socketFd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, unix.PROT_NONE)
if err != nil {
return fmt.Errorf("failed to create unix socket: %w", err)
}
defer unix.Close(socketFd)
timeoutUsec := fdControlUnixTimeout.Microseconds()
timeout := unix.Timeval{
Sec: timeoutUsec / 1e6,
Usec: timeoutUsec % 1e6,
}
_ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeout)
_ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &timeout)
err = unix.Connect(socketFd, &unix.SockaddrUnix{Name: path})
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
}
err = unix.Sendmsg(socketFd, nil, unix.UnixRights(fd), nil, 0)
if err != nil {
return fmt.Errorf("failed to send: %w", err)
}
dummy := []byte{1}
n, err := unix.Read(socketFd, dummy)
if err != nil {
return fmt.Errorf("failed to receive: %w", err)
}
if n != 1 {
return fmt.Errorf("socket closed unexpectedly")
}
return nil
})
}

View file

@ -0,0 +1,53 @@
//go:build linux
package sockopts
import (
"net"
"os"
"os/exec"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/sys/unix"
)
func Test_fdControlUnixSocketImpl(t *testing.T) {
sockPath := "./fd_control_unix_socket_test.sock"
defer os.Remove(sockPath)
// Run test server
cmd := exec.Command("python", "fd_control_unix_socket_test.py", sockPath)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
if !assert.NoError(t, err) {
return
}
defer cmd.Process.Kill()
// Wait for the server to start
time.Sleep(1 * time.Second)
so := SocketOptions{
FdControlUnixSocket: &sockPath,
}
conn, err := so.ListenUDP()
if !assert.NoError(t, err) {
return
}
defer conn.Close()
err = controlUDPConn(conn.(*net.UDPConn), func(fd int) (err error) {
rcvbuf, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF)
if err != nil {
return
}
// The test server called setsockopt(fd, SOL_SOCKET, SO_RCVBUF, 2500),
// and kernel will double this value for getsockopt().
assert.Equal(t, 5000, rcvbuf)
return
})
assert.NoError(t, err)
}