Refactor bind control

This commit is contained in:
世界 2022-09-13 22:14:36 +08:00
parent a606585cf7
commit 93cc53b60c
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
8 changed files with 93 additions and 173 deletions

View file

@ -1,43 +1,44 @@
package control
import (
"net"
E "github.com/sagernet/sing/common/exceptions"
"os"
"runtime"
"syscall"
)
type BindManager interface {
IndexByName(name string) (int, error)
Update() error
func BindToInterface(finder InterfaceFinder, interfaceName string, interfaceIndex int) Func {
return func(network, address string, conn syscall.RawConn) error {
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
}
}
type myBindManager struct {
interfaceIndexByName map[string]int
func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int)) Func {
return func(network, address string, conn syscall.RawConn) error {
interfaceName, interfaceIndex := block(network, address)
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
}
}
func (m *myBindManager) IndexByName(name string) (int, error) {
if index, loaded := m.interfaceIndexByName[name]; loaded {
return index, nil
}
err := m.Update()
if err != nil {
return 0, err
}
if index, loaded := m.interfaceIndexByName[name]; loaded {
return index, nil
}
return 0, E.New("interface ", name, " not found")
}
const useInterfaceName = runtime.GOOS == "linux" || runtime.GOOS == "android"
func (m *myBindManager) Update() error {
interfaces, err := net.Interfaces()
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
if interfaceName == "" && interfaceIndex == -1 {
return nil
}
if interfaceName != "" && useInterfaceName || interfaceIndex != -1 && !useInterfaceName {
return bindToInterface(conn, network, address, interfaceName, interfaceIndex)
}
if finder == nil {
return os.ErrInvalid
}
var err error
if useInterfaceName {
interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex)
} else {
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
}
if err != nil {
return err
}
interfaceIndexByName := make(map[string]int)
for _, iface := range interfaces {
interfaceIndexByName[iface.Name] = iface.Index
}
m.interfaceIndexByName = interfaceIndexByName
return nil
return bindToInterface(conn, network, address, interfaceName, interfaceIndex)
}

View file

@ -6,50 +6,16 @@ import (
"golang.org/x/sys/unix"
)
func NewBindManager() BindManager {
return &myBindManager{
interfaceIndexByName: make(map[string]int),
}
}
func BindToInterface(manager BindManager, interfaceName string) Func {
return func(network, address string, conn syscall.RawConn) error {
index, err := manager.IndexByName(interfaceName)
if err != nil {
return err
}
return bindToInterface(conn, network, index)
}
}
func BindToInterfaceFunc(manager BindManager, interfaceNameFunc func(network, address string) string) Func {
return func(network, address string, conn syscall.RawConn) error {
interfaceName := interfaceNameFunc(network, address)
if interfaceName == "" {
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
if interfaceIndex == -1 {
return nil
}
index, err := manager.IndexByName(interfaceName)
if err != nil {
return err
}
return bindToInterface(conn, network, index)
}
}
func BindToInterfaceIndexFunc(interfaceIndexFunc func(network, address string) int) Func {
return func(network, address string, conn syscall.RawConn) error {
index := interfaceIndexFunc(network, address)
return bindToInterface(conn, network, index)
}
}
func bindToInterface(conn syscall.RawConn, network string, index int) error {
return Raw(conn, func(fd uintptr) error {
switch network {
case "tcp6", "udp6":
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, index)
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, interfaceIndex)
default:
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, interfaceIndex)
}
})
}

View file

@ -0,0 +1,30 @@
package control
import "net"
type InterfaceFinder interface {
InterfaceIndexByName(name string) (int, error)
InterfaceNameByIndex(index int) (string, error)
}
func DefaultInterfaceFinder() InterfaceFinder {
return (*netInterfaceFinder)(nil)
}
type netInterfaceFinder struct{}
func (w *netInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
netInterface, err := net.InterfaceByName(name)
if err != nil {
return 0, err
}
return netInterface.Index, nil
}
func (w *netInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
netInterface, err := net.InterfaceByIndex(index)
if err != nil {
return "", err
}
return netInterface.Name, nil
}

View file

@ -2,32 +2,12 @@ package control
import (
"syscall"
"golang.org/x/sys/unix"
)
func NewBindManager() BindManager {
return nil
}
func BindToInterface(manager BindManager, interfaceName string) Func {
return func(network, address string, conn syscall.RawConn) error {
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
return Raw(conn, func(fd uintptr) error {
return syscall.BindToDevice(int(fd), interfaceName)
return unix.BindToDevice(int(fd), interfaceName)
})
}
}
func BindToInterfaceFunc(manager BindManager, interfaceNameFunc func(network, address string) string) Func {
return func(network, address string, conn syscall.RawConn) error {
interfaceName := interfaceNameFunc(network, address)
if interfaceName == "" {
return nil
}
return Raw(conn, func(fd uintptr) error {
return syscall.BindToDevice(int(fd), interfaceName)
})
}
}
func BindToInterfaceIndexFunc(interfaceIndexFunc func(network, address string) int) Func {
return nil
}

View file

@ -2,18 +2,8 @@
package control
func NewBindManager() BindManager {
return nil
}
import "syscall"
func BindToInterface(manager BindManager, interfaceName string) Func {
return nil
}
func BindToInterfaceFunc(manager BindManager, interfaceNameFunc func(network, address string) string) Func {
return nil
}
func BindToInterfaceIndexFunc(interfaceIndexFunc func(network, address string) int) Func {
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
return nil
}

View file

@ -2,46 +2,17 @@ package control
import (
"encoding/binary"
"net"
"net/netip"
"syscall"
"unsafe"
M "github.com/sagernet/sing/common/metadata"
)
const (
IP_UNICAST_IF = 31
IPV6_UNICAST_IF = 31
)
func NewBindManager() BindManager {
return &myBindManager{
interfaceIndexByName: make(map[string]int),
}
}
func bind4(handle syscall.Handle, ifaceIdx int) error {
var bytes [4]byte
binary.BigEndian.PutUint32(bytes[:], uint32(ifaceIdx))
idx := *(*uint32)(unsafe.Pointer(&bytes[0]))
return syscall.SetsockoptInt(handle, syscall.IPPROTO_IP, IP_UNICAST_IF, int(idx))
}
func bind6(handle syscall.Handle, ifaceIdx int) error {
return syscall.SetsockoptInt(handle, syscall.IPPROTO_IPV6, IPV6_UNICAST_IF, ifaceIdx)
}
func bindInterfaceIndex(network string, address string, conn syscall.RawConn, interfaceIndex int) error {
ipStr, _, err := net.SplitHostPort(address)
if err == nil {
if ip, err := netip.ParseAddr(ipStr); err == nil && !ip.IsGlobalUnicast() {
return err
}
}
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
return Raw(conn, func(fd uintptr) error {
handle := syscall.Handle(fd)
// handle ip empty, e.g. net.Listen("udp", ":0")
if ipStr == "" {
err = bind4(handle, interfaceIndex)
if M.ParseSocksaddr(address).AddrString() == "" {
err := bind4(handle, interfaceIndex)
if err != nil {
return err
}
@ -58,36 +29,18 @@ func bindInterfaceIndex(network string, address string, conn syscall.RawConn, in
})
}
func BindToInterface(manager BindManager, interfaceName string) Func {
return func(network, address string, conn syscall.RawConn) error {
index, err := manager.IndexByName(interfaceName)
if err != nil {
return err
}
return bindInterfaceIndex(network, address, conn, index)
}
const (
IP_UNICAST_IF = 31
IPV6_UNICAST_IF = 31
)
func bind4(handle syscall.Handle, ifaceIdx int) error {
var bytes [4]byte
binary.BigEndian.PutUint32(bytes[:], uint32(ifaceIdx))
idx := *(*uint32)(unsafe.Pointer(&bytes[0]))
return syscall.SetsockoptInt(handle, syscall.IPPROTO_IP, IP_UNICAST_IF, int(idx))
}
func BindToInterfaceFunc(manager BindManager, interfaceNameFunc func(network, address string) string) Func {
return func(network, address string, conn syscall.RawConn) error {
interfaceName := interfaceNameFunc(network, address)
if interfaceName == "" {
return nil
}
index, err := manager.IndexByName(interfaceName)
if err != nil {
return err
}
return bindInterfaceIndex(network, address, conn, index)
}
}
func BindToInterfaceIndexFunc(interfaceIndexFunc func(network, address string) int) Func {
return func(network, address string, conn syscall.RawConn) error {
index := interfaceIndexFunc(network, address)
if index == -1 {
return nil
}
return bindInterfaceIndex(network, address, conn, index)
}
func bind6(handle syscall.Handle, ifaceIdx int) error {
return syscall.SetsockoptInt(handle, syscall.IPPROTO_IPV6, IPV6_UNICAST_IF, ifaceIdx)
}

2
go.mod
View file

@ -2,4 +2,4 @@ module github.com/sagernet/sing
go 1.18
require golang.org/x/sys v0.0.0-20220818161305-2296e01440c6
require golang.org/x/sys v0.0.0-20220913120320-3275c407cedc

4
go.sum
View file

@ -1,2 +1,2 @@
golang.org/x/sys v0.0.0-20220818161305-2296e01440c6 h1:Sx/u41w+OwrInGdEckYmEuU5gHoGSL4QbDz3S9s6j4U=
golang.org/x/sys v0.0.0-20220818161305-2296e01440c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220913120320-3275c407cedc h1:dpclq5m2YrqPGStKmtw7IcNbKLfbIqKXvNxDJKdIKYc=
golang.org/x/sys v0.0.0-20220913120320-3275c407cedc/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=