diff --git a/stack.go b/stack.go index 71a183d..bdb2818 100644 --- a/stack.go +++ b/stack.go @@ -15,19 +15,20 @@ type Stack interface { } type StackOptions struct { - Context context.Context - Tun Tun - Name string - MTU uint32 - Inet4Address []netip.Prefix - Inet6Address []netip.Prefix - EndpointIndependentNat bool - UDPTimeout int64 - Router Router - Handler Handler - Logger logger.Logger - ForwarderBindInterface bool - InterfaceFinder control.InterfaceFinder + Context context.Context + Tun Tun + Name string + MTU uint32 + Inet4Address []netip.Prefix + Inet6Address []netip.Prefix + EndpointIndependentNat bool + UDPTimeout int64 + Router Router + Handler Handler + Logger logger.Logger + ForwarderBindInterface bool + InterfaceFinder control.InterfaceFinder + ExperimentalFixWindowsFirewall bool } func NewStack( diff --git a/system.go b/system.go index 356c93b..123e1b4 100644 --- a/system.go +++ b/system.go @@ -42,6 +42,7 @@ type System struct { routeMapping *RouteMapping bindInterface bool interfaceFinder control.InterfaceFinder + fixWindowsFirewall bool } type Session struct { @@ -53,18 +54,19 @@ type Session struct { func NewSystem(options StackOptions) (Stack, error) { stack := &System{ - ctx: options.Context, - tun: options.Tun, - tunName: options.Name, - mtu: options.MTU, - udpTimeout: options.UDPTimeout, - router: options.Router, - handler: options.Handler, - logger: options.Logger, - inet4Prefixes: options.Inet4Address, - inet6Prefixes: options.Inet6Address, - bindInterface: options.ForwarderBindInterface, - interfaceFinder: options.InterfaceFinder, + ctx: options.Context, + tun: options.Tun, + tunName: options.Name, + mtu: options.MTU, + udpTimeout: options.UDPTimeout, + router: options.Router, + handler: options.Handler, + logger: options.Logger, + inet4Prefixes: options.Inet4Address, + inet6Prefixes: options.Inet6Address, + bindInterface: options.ForwarderBindInterface, + interfaceFinder: options.InterfaceFinder, + fixWindowsFirewall: options.ExperimentalFixWindowsFirewall, } if stack.router != nil { stack.routeMapping = NewRouteMapping(options.UDPTimeout) @@ -97,6 +99,12 @@ func (s *System) Close() error { } func (s *System) Start() error { + if s.fixWindowsFirewall { + err := fixWindowsFirewall() + if err != nil { + return E.Cause(err, "fix windows firewall for system stack") + } + } var listener net.ListenConfig if s.bindInterface { listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error { diff --git a/system_nonwindows.go b/system_nonwindows.go new file mode 100644 index 0000000..15b8741 --- /dev/null +++ b/system_nonwindows.go @@ -0,0 +1,7 @@ +//go:build !windows + +package tun + +func fixWindowsFirewall() error { + return nil +} diff --git a/system_windows.go b/system_windows.go new file mode 100644 index 0000000..970f438 --- /dev/null +++ b/system_windows.go @@ -0,0 +1,47 @@ +package tun + +import ( + "os" + "os/exec" + "path/filepath" + + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/shell" +) + +func fixWindowsFirewall() error { + const shellStringSplit = "\"" + isPWSH := true + powershell, err := exec.LookPath("pwsh.exe") + if err != nil { + powershell, err = exec.LookPath("powershell.exe") + isPWSH = false + } + if err != nil { + return nil + } + ruleName := "sing-tun rule for " + os.Args[0] + commandPrefix := []string{"-NoProfile", "-NonInteractive"} + if isPWSH { + commandPrefix = append(commandPrefix, "-Command") + } + err = shell.Exec(powershell, append(commandPrefix, + F.ToString("Get-NetFirewallRule -Name ", shellStringSplit, ruleName, shellStringSplit))...).Run() + if err == nil { + return nil + } + fileName := filepath.Base(os.Args[0]) + output, err := shell.Exec(powershell, append(commandPrefix, + F.ToString("New-NetFirewallRule", + " -Name ", shellStringSplit, ruleName, shellStringSplit, + " -DisplayName ", shellStringSplit, "sing-tun (", fileName, ")", shellStringSplit, + " -Program ", shellStringSplit, os.Args[0], shellStringSplit, + " -Direction Inbound", + " -Protocol TCP", + " -Action Allow"))...).Read() + if err != nil { + return E.Extend(err, output) + } + return nil +}