From 93cc53b60ceee229938b96edaa792a0edaf0d683 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 13 Sep 2022 22:14:36 +0800 Subject: [PATCH] Refactor bind control --- common/control/bind.go | 59 +++++++++++++------------ common/control/bind_darwin.go | 44 +++--------------- common/control/bind_finder.go | 30 +++++++++++++ common/control/bind_linux.go | 32 +++----------- common/control/bind_other.go | 14 +----- common/control/bind_windows.go | 81 +++++++--------------------------- go.mod | 2 +- go.sum | 4 +- 8 files changed, 93 insertions(+), 173 deletions(-) create mode 100644 common/control/bind_finder.go diff --git a/common/control/bind.go b/common/control/bind.go index 310e78c..e41bd39 100644 --- a/common/control/bind.go +++ b/common/control/bind.go @@ -1,43 +1,44 @@ package control import ( - "net" - - E "github.com/sagernet/sing/common/exceptions" + "os" + "runtime" + "syscall" ) -type BindManager interface { - IndexByName(name string) (int, error) - Update() error +func BindToInterface(finder InterfaceFinder, interfaceName string, interfaceIndex int) Func { + return func(network, address string, conn syscall.RawConn) error { + return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex) + } } -type myBindManager struct { - interfaceIndexByName map[string]int +func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int)) Func { + return func(network, address string, conn syscall.RawConn) error { + interfaceName, interfaceIndex := block(network, address) + return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex) + } } -func (m *myBindManager) IndexByName(name string) (int, error) { - if index, loaded := m.interfaceIndexByName[name]; loaded { - return index, nil - } - err := m.Update() - if err != nil { - return 0, err - } - if index, loaded := m.interfaceIndexByName[name]; loaded { - return index, nil - } - return 0, E.New("interface ", name, " not found") -} +const useInterfaceName = runtime.GOOS == "linux" || runtime.GOOS == "android" -func (m *myBindManager) Update() error { - interfaces, err := net.Interfaces() +func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { + if interfaceName == "" && interfaceIndex == -1 { + return nil + } + if interfaceName != "" && useInterfaceName || interfaceIndex != -1 && !useInterfaceName { + return bindToInterface(conn, network, address, interfaceName, interfaceIndex) + } + if finder == nil { + return os.ErrInvalid + } + var err error + if useInterfaceName { + interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex) + } else { + interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + } if err != nil { return err } - interfaceIndexByName := make(map[string]int) - for _, iface := range interfaces { - interfaceIndexByName[iface.Name] = iface.Index - } - m.interfaceIndexByName = interfaceIndexByName - return nil + return bindToInterface(conn, network, address, interfaceName, interfaceIndex) } diff --git a/common/control/bind_darwin.go b/common/control/bind_darwin.go index 5836563..8262ac7 100644 --- a/common/control/bind_darwin.go +++ b/common/control/bind_darwin.go @@ -6,50 +6,16 @@ import ( "golang.org/x/sys/unix" ) -func NewBindManager() BindManager { - return &myBindManager{ - interfaceIndexByName: make(map[string]int), +func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { + if interfaceIndex == -1 { + return nil } -} - -func BindToInterface(manager BindManager, interfaceName string) Func { - return func(network, address string, conn syscall.RawConn) error { - index, err := manager.IndexByName(interfaceName) - if err != nil { - return err - } - return bindToInterface(conn, network, index) - } -} - -func BindToInterfaceFunc(manager BindManager, interfaceNameFunc func(network, address string) string) Func { - return func(network, address string, conn syscall.RawConn) error { - interfaceName := interfaceNameFunc(network, address) - if interfaceName == "" { - return nil - } - index, err := manager.IndexByName(interfaceName) - if err != nil { - return err - } - return bindToInterface(conn, network, index) - } -} - -func BindToInterfaceIndexFunc(interfaceIndexFunc func(network, address string) int) Func { - return func(network, address string, conn syscall.RawConn) error { - index := interfaceIndexFunc(network, address) - return bindToInterface(conn, network, index) - } -} - -func bindToInterface(conn syscall.RawConn, network string, index int) error { return Raw(conn, func(fd uintptr) error { switch network { case "tcp6", "udp6": - return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, index) + return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, interfaceIndex) default: - return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, index) + return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, interfaceIndex) } }) } diff --git a/common/control/bind_finder.go b/common/control/bind_finder.go new file mode 100644 index 0000000..21820fb --- /dev/null +++ b/common/control/bind_finder.go @@ -0,0 +1,30 @@ +package control + +import "net" + +type InterfaceFinder interface { + InterfaceIndexByName(name string) (int, error) + InterfaceNameByIndex(index int) (string, error) +} + +func DefaultInterfaceFinder() InterfaceFinder { + return (*netInterfaceFinder)(nil) +} + +type netInterfaceFinder struct{} + +func (w *netInterfaceFinder) InterfaceIndexByName(name string) (int, error) { + netInterface, err := net.InterfaceByName(name) + if err != nil { + return 0, err + } + return netInterface.Index, nil +} + +func (w *netInterfaceFinder) InterfaceNameByIndex(index int) (string, error) { + netInterface, err := net.InterfaceByIndex(index) + if err != nil { + return "", err + } + return netInterface.Name, nil +} diff --git a/common/control/bind_linux.go b/common/control/bind_linux.go index 91be667..6ebca49 100644 --- a/common/control/bind_linux.go +++ b/common/control/bind_linux.go @@ -2,32 +2,12 @@ package control import ( "syscall" + + "golang.org/x/sys/unix" ) -func NewBindManager() BindManager { - return nil -} - -func BindToInterface(manager BindManager, interfaceName string) Func { - return func(network, address string, conn syscall.RawConn) error { - return Raw(conn, func(fd uintptr) error { - return syscall.BindToDevice(int(fd), interfaceName) - }) - } -} - -func BindToInterfaceFunc(manager BindManager, interfaceNameFunc func(network, address string) string) Func { - return func(network, address string, conn syscall.RawConn) error { - interfaceName := interfaceNameFunc(network, address) - if interfaceName == "" { - return nil - } - return Raw(conn, func(fd uintptr) error { - return syscall.BindToDevice(int(fd), interfaceName) - }) - } -} - -func BindToInterfaceIndexFunc(interfaceIndexFunc func(network, address string) int) Func { - return nil +func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { + return Raw(conn, func(fd uintptr) error { + return unix.BindToDevice(int(fd), interfaceName) + }) } diff --git a/common/control/bind_other.go b/common/control/bind_other.go index ad13c0c..27d0497 100644 --- a/common/control/bind_other.go +++ b/common/control/bind_other.go @@ -2,18 +2,8 @@ package control -func NewBindManager() BindManager { - return nil -} +import "syscall" -func BindToInterface(manager BindManager, interfaceName string) Func { - return nil -} - -func BindToInterfaceFunc(manager BindManager, interfaceNameFunc func(network, address string) string) Func { - return nil -} - -func BindToInterfaceIndexFunc(interfaceIndexFunc func(network, address string) int) Func { +func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { return nil } diff --git a/common/control/bind_windows.go b/common/control/bind_windows.go index 7a3c9e0..5e23bf1 100644 --- a/common/control/bind_windows.go +++ b/common/control/bind_windows.go @@ -2,46 +2,17 @@ package control import ( "encoding/binary" - "net" - "net/netip" "syscall" "unsafe" + + M "github.com/sagernet/sing/common/metadata" ) -const ( - IP_UNICAST_IF = 31 - IPV6_UNICAST_IF = 31 -) - -func NewBindManager() BindManager { - return &myBindManager{ - interfaceIndexByName: make(map[string]int), - } -} - -func bind4(handle syscall.Handle, ifaceIdx int) error { - var bytes [4]byte - binary.BigEndian.PutUint32(bytes[:], uint32(ifaceIdx)) - idx := *(*uint32)(unsafe.Pointer(&bytes[0])) - return syscall.SetsockoptInt(handle, syscall.IPPROTO_IP, IP_UNICAST_IF, int(idx)) -} - -func bind6(handle syscall.Handle, ifaceIdx int) error { - return syscall.SetsockoptInt(handle, syscall.IPPROTO_IPV6, IPV6_UNICAST_IF, ifaceIdx) -} - -func bindInterfaceIndex(network string, address string, conn syscall.RawConn, interfaceIndex int) error { - ipStr, _, err := net.SplitHostPort(address) - if err == nil { - if ip, err := netip.ParseAddr(ipStr); err == nil && !ip.IsGlobalUnicast() { - return err - } - } +func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { return Raw(conn, func(fd uintptr) error { handle := syscall.Handle(fd) - // handle ip empty, e.g. net.Listen("udp", ":0") - if ipStr == "" { - err = bind4(handle, interfaceIndex) + if M.ParseSocksaddr(address).AddrString() == "" { + err := bind4(handle, interfaceIndex) if err != nil { return err } @@ -58,36 +29,18 @@ func bindInterfaceIndex(network string, address string, conn syscall.RawConn, in }) } -func BindToInterface(manager BindManager, interfaceName string) Func { - return func(network, address string, conn syscall.RawConn) error { - index, err := manager.IndexByName(interfaceName) - if err != nil { - return err - } - return bindInterfaceIndex(network, address, conn, index) - } +const ( + IP_UNICAST_IF = 31 + IPV6_UNICAST_IF = 31 +) + +func bind4(handle syscall.Handle, ifaceIdx int) error { + var bytes [4]byte + binary.BigEndian.PutUint32(bytes[:], uint32(ifaceIdx)) + idx := *(*uint32)(unsafe.Pointer(&bytes[0])) + return syscall.SetsockoptInt(handle, syscall.IPPROTO_IP, IP_UNICAST_IF, int(idx)) } -func BindToInterfaceFunc(manager BindManager, interfaceNameFunc func(network, address string) string) Func { - return func(network, address string, conn syscall.RawConn) error { - interfaceName := interfaceNameFunc(network, address) - if interfaceName == "" { - return nil - } - index, err := manager.IndexByName(interfaceName) - if err != nil { - return err - } - return bindInterfaceIndex(network, address, conn, index) - } -} - -func BindToInterfaceIndexFunc(interfaceIndexFunc func(network, address string) int) Func { - return func(network, address string, conn syscall.RawConn) error { - index := interfaceIndexFunc(network, address) - if index == -1 { - return nil - } - return bindInterfaceIndex(network, address, conn, index) - } +func bind6(handle syscall.Handle, ifaceIdx int) error { + return syscall.SetsockoptInt(handle, syscall.IPPROTO_IPV6, IPV6_UNICAST_IF, ifaceIdx) } diff --git a/go.mod b/go.mod index d83e689..fabefb4 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/sagernet/sing go 1.18 -require golang.org/x/sys v0.0.0-20220818161305-2296e01440c6 +require golang.org/x/sys v0.0.0-20220913120320-3275c407cedc diff --git a/go.sum b/go.sum index 65066cc..6b04132 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -golang.org/x/sys v0.0.0-20220818161305-2296e01440c6 h1:Sx/u41w+OwrInGdEckYmEuU5gHoGSL4QbDz3S9s6j4U= -golang.org/x/sys v0.0.0-20220818161305-2296e01440c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220913120320-3275c407cedc h1:dpclq5m2YrqPGStKmtw7IcNbKLfbIqKXvNxDJKdIKYc= +golang.org/x/sys v0.0.0-20220913120320-3275c407cedc/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=