Use api to create windows firewall rules

This commit is contained in:
dyhkwong 2023-05-20 12:11:00 +08:00 committed by GitHub
parent 91df97aee2
commit b02f252916
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 321 additions and 67 deletions

2
go.mod
View file

@ -4,9 +4,11 @@ go 1.18
require (
github.com/fsnotify/fsnotify v1.6.0
github.com/go-ole/go-ole v1.2.6
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97
github.com/sagernet/sing v0.2.4
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9
golang.org/x/net v0.9.0
golang.org/x/sys v0.7.0
gvisor.dev/gvisor v0.0.0-20230415003630-3981d5d5e523

5
go.sum
View file

@ -1,5 +1,7 @@
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 h1:5+m7c6AkmAylhauulqN/c5dnh8/KssrE9c93TQrXldA=
@ -9,10 +11,13 @@ github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJ
github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
github.com/sagernet/sing v0.2.4 h1:gC8BR5sglbJZX23RtMyFa8EETP9YEUADhfbEzU1yVbo=
github.com/sagernet/sing v0.2.4/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg=
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

274
internal/winfw/winfw.go Normal file
View file

@ -0,0 +1,274 @@
// Copyright (c) 2018 Samuel Melrose
// SPDX-License-Identifier: MIT
// https://github.com/iamacarpet/go-win64api/blob/ef6dbdd6db97301ae08a55eedea773476985a602/firewall.go
//go:build windows
package winfw
import (
"fmt"
"runtime"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
"github.com/scjalliance/comshim"
)
// Firewall related API constants.
const (
NET_FW_IP_PROTOCOL_TCP = 6
NET_FW_IP_PROTOCOL_UDP = 17
NET_FW_IP_PROTOCOL_ICMPv4 = 1
NET_FW_IP_PROTOCOL_ICMPv6 = 58
NET_FW_IP_PROTOCOL_ANY = 256
NET_FW_RULE_DIR_IN = 1
NET_FW_RULE_DIR_OUT = 2
NET_FW_ACTION_BLOCK = 0
NET_FW_ACTION_ALLOW = 1
// NET_FW_PROFILE2_CURRENT is not real API constant, just helper used in FW functions.
// It can mean one profile or multiple (even all) profiles. It depends on which profiles
// are currently in use. Every active interface can have it's own profile. F.e.: Public for Wifi,
// Domain for VPN, and Private for LAN. All at the same time.
NET_FW_PROFILE2_CURRENT = 0
NET_FW_PROFILE2_DOMAIN = 1
NET_FW_PROFILE2_PRIVATE = 2
NET_FW_PROFILE2_PUBLIC = 4
NET_FW_PROFILE2_ALL = 2147483647
)
// Firewall Rule Groups
// Use this magical strings instead of group names. It will work on all language Windows versions.
// You can find more string locations here:
// https://windows10dll.nirsoft.net/firewallapi_dll.html
const (
NET_FW_FILE_AND_PRINTER_SHARING = "@FirewallAPI.dll,-28502"
NET_FW_REMOTE_DESKTOP = "@FirewallAPI.dll,-28752"
)
// FWRule represents Firewall Rule.
type FWRule struct {
Name, Description, ApplicationName, ServiceName string
LocalPorts, RemotePorts string
// LocalAddresses, RemoteAddresses are always returned with netmask, f.e.:
// `10.10.1.1/255.255.255.0`
LocalAddresses, RemoteAddresses string
// ICMPTypesAndCodes is string. You can find define multiple codes separated by ":" (colon).
// Types are listed here:
// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml
// So to allow ping set it to:
// "0"
ICMPTypesAndCodes string
Grouping string
// InterfaceTypes can be:
// "LAN", "Wireless", "RemoteAccess", "All"
// You can add multiple deviding with comma:
// "LAN, Wireless"
InterfaceTypes string
Protocol, Direction, Action, Profiles int32
Enabled, EdgeTraversal bool
}
// FirewallRuleAddAdvanced allows to modify almost all available FW Rule parameters.
// You probably do not want to use this, as function allows to create any rule, even opening all ports
// in given profile. So use with caution.
func FirewallRuleAddAdvanced(rule FWRule) (bool, error) {
return firewallRuleAdd(rule.Name, rule.Description, rule.Grouping, rule.ApplicationName, rule.ServiceName,
rule.LocalPorts, rule.RemotePorts, rule.LocalAddresses, rule.RemoteAddresses, rule.ICMPTypesAndCodes,
rule.Protocol, rule.Direction, rule.Action, rule.Profiles, rule.Enabled, rule.EdgeTraversal)
}
// firewallRuleAdd is universal function to add all kinds of rules.
func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remotePorts, localAddresses, remoteAddresses, icmpTypes string, protocol, direction, action, profile int32, enabled, edgeTraversal bool) (bool, error) {
if name == "" {
return false, fmt.Errorf("empty FW Rule name, name is mandatory")
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
u, fwPolicy, err := firewallAPIInit()
if err != nil {
return false, err
}
defer firewallAPIRelease(u, fwPolicy)
if profile == NET_FW_PROFILE2_CURRENT {
currentProfiles, err := oleutil.GetProperty(fwPolicy, "CurrentProfileTypes")
if err != nil {
return false, fmt.Errorf("Failed to get CurrentProfiles: %s", err)
}
profile = currentProfiles.Value().(int32)
}
unknownRules, err := oleutil.GetProperty(fwPolicy, "Rules")
if err != nil {
return false, fmt.Errorf("Failed to get Rules: %s", err)
}
rules := unknownRules.ToIDispatch()
if ok, err := FirewallRuleExistsByName(rules, name); err != nil {
return false, fmt.Errorf("Error while checking rules for duplicate: %s", err)
} else if ok {
return false, nil
}
unknown2, err := oleutil.CreateObject("HNetCfg.FWRule")
if err != nil {
return false, fmt.Errorf("Error creating Rule object: %s", err)
}
defer unknown2.Release()
fwRule, err := unknown2.QueryInterface(ole.IID_IDispatch)
if err != nil {
return false, fmt.Errorf("Error creating Rule object (2): %s", err)
}
defer fwRule.Release()
if _, err := oleutil.PutProperty(fwRule, "Name", name); err != nil {
return false, fmt.Errorf("Error setting property (Name) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Description", description); err != nil {
return false, fmt.Errorf("Error setting property (Description) of Rule: %s", err)
}
if appPath != "" {
if _, err := oleutil.PutProperty(fwRule, "Applicationname", appPath); err != nil {
return false, fmt.Errorf("Error setting property (Applicationname) of Rule: %s", err)
}
}
if serviceName != "" {
if _, err := oleutil.PutProperty(fwRule, "ServiceName", serviceName); err != nil {
return false, fmt.Errorf("Error setting property (ServiceName) of Rule: %s", err)
}
}
if protocol != 0 {
if _, err := oleutil.PutProperty(fwRule, "Protocol", protocol); err != nil {
return false, fmt.Errorf("Error setting property (Protocol) of Rule: %s", err)
}
}
if icmpTypes != "" {
if _, err := oleutil.PutProperty(fwRule, "IcmpTypesAndCodes", icmpTypes); err != nil {
return false, fmt.Errorf("Error setting property (IcmpTypesAndCodes) of Rule: %s", err)
}
}
if ports != "" {
if _, err := oleutil.PutProperty(fwRule, "LocalPorts", ports); err != nil {
return false, fmt.Errorf("Error setting property (LocalPorts) of Rule: %s", err)
}
}
if remotePorts != "" {
if _, err := oleutil.PutProperty(fwRule, "RemotePorts", remotePorts); err != nil {
return false, fmt.Errorf("Error setting property (RemotePorts) of Rule: %s", err)
}
}
if localAddresses != "" {
if _, err := oleutil.PutProperty(fwRule, "LocalAddresses", localAddresses); err != nil {
return false, fmt.Errorf("Error setting property (LocalAddresses) of Rule: %s", err)
}
}
if remoteAddresses != "" {
if _, err := oleutil.PutProperty(fwRule, "RemoteAddresses", remoteAddresses); err != nil {
return false, fmt.Errorf("Error setting property (RemoteAddresses) of Rule: %s", err)
}
}
if direction != 0 {
if _, err := oleutil.PutProperty(fwRule, "Direction", direction); err != nil {
return false, fmt.Errorf("Error setting property (Direction) of Rule: %s", err)
}
}
if _, err := oleutil.PutProperty(fwRule, "Enabled", enabled); err != nil {
return false, fmt.Errorf("Error setting property (Enabled) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Grouping", group); err != nil {
return false, fmt.Errorf("Error setting property (Grouping) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Profiles", profile); err != nil {
return false, fmt.Errorf("Error setting property (Profiles) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Action", action); err != nil {
return false, fmt.Errorf("Error setting property (Action) of Rule: %s", err)
}
if edgeTraversal {
if _, err := oleutil.PutProperty(fwRule, "EdgeTraversal", edgeTraversal); err != nil {
return false, fmt.Errorf("Error setting property (EdgeTraversal) of Rule: %s", err)
}
}
if _, err := oleutil.CallMethod(rules, "Add", fwRule); err != nil {
return false, fmt.Errorf("Error adding Rule: %s", err)
}
return true, nil
}
func FirewallRuleExistsByName(rules *ole.IDispatch, name string) (bool, error) {
enumProperty, err := rules.GetProperty("_NewEnum")
if err != nil {
return false, fmt.Errorf("Failed to get enumeration property on Rules: %s", err)
}
defer enumProperty.Clear()
enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
if err != nil {
return false, fmt.Errorf("Failed to cast enum to correct type: %s", err)
}
if enum == nil {
return false, fmt.Errorf("can't get IEnumVARIANT, enum is nil")
}
for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
if err != nil {
return false, fmt.Errorf("Failed to seek next Rule item: %s", err)
}
t, err := func() (bool, error) {
item := itemRaw.ToIDispatch()
defer item.Release()
if item, err := oleutil.GetProperty(item, "Name"); err != nil {
return false, fmt.Errorf("Failed to get Property (Name) of Rule")
} else if item.ToString() == name {
return true, nil
}
return false, nil
}()
if err != nil {
return false, err
} else if t {
return true, nil
}
}
return false, nil
}
// firewallAPIInit initialize common fw api.
// then:
// dispatch firewallAPIRelease(u, fwp)
func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
comshim.Add(1)
unknown, err := oleutil.CreateObject("HNetCfg.FwPolicy2")
if err != nil {
return nil, nil, fmt.Errorf("Failed to create FwPolicy Object: %s", err)
}
fwPolicy, err := unknown.QueryInterface(ole.IID_IDispatch)
if err != nil {
unknown.Release()
return nil, nil, fmt.Errorf("Failed to create FwPolicy Object (2): %s", err)
}
return unknown, fwPolicy, nil
}
// firewallAPIRelease cleans memory.
func firewallAPIRelease(u *ole.IUnknown, fwp *ole.IDispatch) {
fwp.Release()
u.Release()
comshim.Done()
}

View file

@ -15,20 +15,19 @@ type Stack interface {
}
type StackOptions struct {
Context context.Context
Tun Tun
Name string
MTU uint32
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
EndpointIndependentNat bool
UDPTimeout int64
Router Router
Handler Handler
Logger logger.Logger
ForwarderBindInterface bool
InterfaceFinder control.InterfaceFinder
ExperimentalFixWindowsFirewall bool
Context context.Context
Tun Tun
Name string
MTU uint32
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
EndpointIndependentNat bool
UDPTimeout int64
Router Router
Handler Handler
Logger logger.Logger
ForwarderBindInterface bool
InterfaceFinder control.InterfaceFinder
}
func NewStack(

View file

@ -42,7 +42,6 @@ type System struct {
routeMapping *RouteMapping
bindInterface bool
interfaceFinder control.InterfaceFinder
fixWindowsFirewall bool
}
type Session struct {
@ -54,19 +53,18 @@ type Session struct {
func NewSystem(options StackOptions) (Stack, error) {
stack := &System{
ctx: options.Context,
tun: options.Tun,
tunName: options.Name,
mtu: options.MTU,
udpTimeout: options.UDPTimeout,
router: options.Router,
handler: options.Handler,
logger: options.Logger,
inet4Prefixes: options.Inet4Address,
inet6Prefixes: options.Inet6Address,
bindInterface: options.ForwarderBindInterface,
interfaceFinder: options.InterfaceFinder,
fixWindowsFirewall: options.ExperimentalFixWindowsFirewall,
ctx: options.Context,
tun: options.Tun,
tunName: options.Name,
mtu: options.MTU,
udpTimeout: options.UDPTimeout,
router: options.Router,
handler: options.Handler,
logger: options.Logger,
inet4Prefixes: options.Inet4Address,
inet6Prefixes: options.Inet6Address,
bindInterface: options.ForwarderBindInterface,
interfaceFinder: options.InterfaceFinder,
}
if stack.router != nil {
stack.routeMapping = NewRouteMapping(options.UDPTimeout)
@ -99,11 +97,9 @@ func (s *System) Close() error {
}
func (s *System) Start() error {
if s.fixWindowsFirewall {
err := fixWindowsFirewall()
if err != nil {
return E.Cause(err, "fix windows firewall for system stack")
}
err := fixWindowsFirewall()
if err != nil {
return E.Cause(err, "fix windows firewall for system stack")
}
var listener net.ListenConfig
if s.bindInterface {

View file

@ -2,46 +2,24 @@ package tun
import (
"os"
"os/exec"
"path/filepath"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/shell"
"github.com/sagernet/sing-tun/internal/winfw"
)
func fixWindowsFirewall() error {
const shellStringSplit = "\""
isPWSH := true
powershell, err := exec.LookPath("pwsh.exe")
absPath, err := filepath.Abs(os.Args[0])
if err != nil {
powershell, err = exec.LookPath("powershell.exe")
isPWSH = false
return err
}
if err != nil {
return nil
rule := winfw.FWRule{
Name: "sing-tun (" + absPath + ")",
ApplicationName: absPath,
Enabled: true,
Protocol: winfw.NET_FW_IP_PROTOCOL_TCP,
Direction: winfw.NET_FW_RULE_DIR_IN,
Action: winfw.NET_FW_ACTION_ALLOW,
}
ruleName := "sing-tun rule for " + os.Args[0]
commandPrefix := []string{"-NoProfile", "-NonInteractive"}
if isPWSH {
commandPrefix = append(commandPrefix, "-Command")
}
err = shell.Exec(powershell, append(commandPrefix,
F.ToString("Get-NetFirewallRule -Name ", shellStringSplit, ruleName, shellStringSplit))...).Run()
if err == nil {
return nil
}
fileName := filepath.Base(os.Args[0])
output, err := shell.Exec(powershell, append(commandPrefix,
F.ToString("New-NetFirewallRule",
" -Name ", shellStringSplit, ruleName, shellStringSplit,
" -DisplayName ", shellStringSplit, "sing-tun (", fileName, ")", shellStringSplit,
" -Program ", shellStringSplit, os.Args[0], shellStringSplit,
" -Direction Inbound",
" -Protocol TCP",
" -Action Allow"))...).Read()
if err != nil {
return E.Extend(err, output)
}
return nil
_, err = winfw.FirewallRuleAddAdvanced(rule)
return err
}