diff --git a/.github/workflows/debug.yml b/.github/workflows/debug.yml index 4b3e5d3..fd2846a 100644 --- a/.github/workflows/debug.yml +++ b/.github/workflows/debug.yml @@ -37,4 +37,4 @@ jobs: go mod init build go get -v github.com/sagernet/sing-tun@$version popd - go build -v ./... \ No newline at end of file + go build -v . \ No newline at end of file diff --git a/README.md b/README.md index 366eff6..1899924 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Simple transparent proxy library. -Currently only for linux. +Currently only for linux and windows. ## License diff --git a/format.go b/format.go index e9ee0d9..cb9e463 100644 --- a/format.go +++ b/format.go @@ -1,7 +1,7 @@ package tun //go:generate go install -v mvdan.cc/gofumpt@latest -//go:generate go install -v github.com/daixiang0/gci@latest +//go:generate go install -v github.com/daixiang0/gci@v0.4.0 //go:generate gofumpt -l -w . -//go:generate gofmt -s -w .k +//go:generate gofmt -s -w . //go:generate gci write -s "standard,prefix(github.com/sagernet/),default" . diff --git a/go.mod b/go.mod index 3591863..eab189a 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,12 @@ go 1.18 require ( github.com/sagernet/sing v0.0.0-20220711062652-4394f7cbbae1 github.com/vishvananda/netlink v1.1.0 + golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e gvisor.dev/gvisor v0.0.0-20220711011657-cecae2f4234d ) require ( github.com/google/btree v1.0.1 // indirect github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect - golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect ) diff --git a/go.sum b/go.sum index 9ba509b..04c97ed 100644 --- a/go.sum +++ b/go.sum @@ -9,8 +9,8 @@ github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695AP github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e h1:NHvCuwuS43lGnYhten69ZWqi2QOj/CiDNcKbVqwVoew= +golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gvisor.dev/gvisor v0.0.0-20220711011657-cecae2f4234d h1:KjI6i6P1ib9DiNdNIN8pb2TXfBewpKHf3O58cjj9vw4= diff --git a/gvisor.go b/gvisor.go index 5d49ac7..5bae991 100644 --- a/gvisor.go +++ b/gvisor.go @@ -3,6 +3,11 @@ package tun import ( "context" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -13,34 +18,29 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" ) const defaultNIC tcpip.NICID = 1 type GVisorTun struct { ctx context.Context - tunFd uintptr + tun Tun tunMtu uint32 handler Handler stack *stack.Stack } -func NewGVisor(ctx context.Context, tunFd uintptr, tunMtu uint32, handler Handler) *GVisorTun { +func NewGVisor(ctx context.Context, tun Tun, tunMtu uint32, handler Handler) *GVisorTun { return &GVisorTun{ ctx: ctx, - tunFd: tunFd, + tun: tun, tunMtu: tunMtu, handler: handler, } } func (t *GVisorTun) Start() error { - linkEndpoint, err := NewEndpoint(t.tunFd, t.tunMtu) + linkEndpoint, err := t.tun.NewEndpoint() if err != nil { return err } diff --git a/gvisor_linux.go b/gvisor_linux.go deleted file mode 100644 index 0cc35a6..0000000 --- a/gvisor_linux.go +++ /dev/null @@ -1,22 +0,0 @@ -package tun - -import ( - "runtime" - - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -func NewEndpoint(tunFd uintptr, tunMtu uint32) (stack.LinkEndpoint, error) { - var packetDispatchMode fdbased.PacketDispatchMode - if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" { - packetDispatchMode = fdbased.PacketMMap - } else { - packetDispatchMode = fdbased.RecvMMsg - } - return fdbased.New(&fdbased.Options{ - FDs: []int{int(tunFd)}, - MTU: tunMtu, - PacketDispatchMode: packetDispatchMode, - }) -} diff --git a/gvisor_nonlinux.go b/gvisor_nonlinux.go deleted file mode 100644 index 6fb5fd5..0000000 --- a/gvisor_nonlinux.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !linux - -package tun - -import "gvisor.dev/gvisor/pkg/tcpip/stack" - -func NewEndpoint(tunFd uintptr, tunMtu uint32) (stack.LinkEndpoint, error) { - return NewPosixEndpoint(tunFd, tunMtu) -} diff --git a/gvisor_posix.go b/gvisor_posix.go deleted file mode 100644 index 1fda37b..0000000 --- a/gvisor_posix.go +++ /dev/null @@ -1,118 +0,0 @@ -package tun - -import ( - "os" - - gBuffer "gvisor.dev/gvisor/pkg/buffer" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/rw" -) - -var _ stack.LinkEndpoint = (*PosixEndpoint)(nil) - -type PosixEndpoint struct { - fd uintptr - mtu uint32 - file *os.File - dispatcher stack.NetworkDispatcher -} - -func NewPosixEndpoint(tunFd uintptr, tunMtu uint32) (stack.LinkEndpoint, error) { - return &PosixEndpoint{ - fd: tunFd, - mtu: tunMtu, - file: os.NewFile(tunFd, "tun"), - }, nil -} - -func (e *PosixEndpoint) MTU() uint32 { - return e.mtu -} - -func (e *PosixEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -func (e *PosixEndpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -func (e *PosixEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return stack.CapabilityNone -} - -func (e *PosixEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - if dispatcher == nil && e.dispatcher != nil { - e.dispatcher = nil - return - } - if dispatcher != nil && e.dispatcher == nil { - e.dispatcher = dispatcher - go e.dispatchLoop() - } -} - -func (e *PosixEndpoint) dispatchLoop() { - _buffer := buf.StackNewPacket() - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - for { - n, err := e.file.Read(buffer.FreeBytes()) - if err != nil { - break - } - var view gBuffer.View - view.Append(buffer.To(n)) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: view, - IsForwardedPacket: true, - }) - defer pkt.DecRef() - var p tcpip.NetworkProtocolNumber - ipHeader, ok := pkt.Data().PullUp(1) - if !ok { - continue - } - switch header.IPVersion(ipHeader) { - case header.IPv4Version: - p = header.IPv4ProtocolNumber - case header.IPv6Version: - p = header.IPv6ProtocolNumber - default: - continue - } - e.dispatcher.DeliverNetworkPacket(p, pkt) - } -} - -func (e *PosixEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -func (e *PosixEndpoint) Wait() { -} - -func (e *PosixEndpoint) ARPHardwareType() header.ARPHardwareType { - return header.ARPHardwareNone -} - -func (e *PosixEndpoint) AddHeader(buffer *stack.PacketBuffer) { -} - -func (e *PosixEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { - var n int - for _, packet := range pkts.AsSlice() { - _, err := rw.WriteV(e.fd, packet.Slices()) - if err != nil { - return n, &tcpip.ErrAborted{} - } - n++ - } - return n, nil -} diff --git a/internal/winipcfg/interface_change_handler.go b/internal/winipcfg/interface_change_handler.go new file mode 100644 index 0000000..af29801 --- /dev/null +++ b/internal/winipcfg/interface_change_handler.go @@ -0,0 +1,88 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +import ( + "sync" + + "golang.org/x/sys/windows" +) + +// InterfaceChangeCallback structure allows interface change callback handling. +type InterfaceChangeCallback struct { + cb func(notificationType MibNotificationType, iface *MibIPInterfaceRow) + wait sync.WaitGroup +} + +var ( + interfaceChangeAddRemoveMutex = sync.Mutex{} + interfaceChangeMutex = sync.Mutex{} + interfaceChangeCallbacks = make(map[*InterfaceChangeCallback]bool) + interfaceChangeHandle = windows.Handle(0) +) + +// RegisterInterfaceChangeCallback registers a new InterfaceChangeCallback. If this particular callback is already +// registered, the function will silently return. Returned InterfaceChangeCallback.Unregister method should be used +// to unregister. +func RegisterInterfaceChangeCallback(callback func(notificationType MibNotificationType, iface *MibIPInterfaceRow)) (*InterfaceChangeCallback, error) { + s := &InterfaceChangeCallback{cb: callback} + + interfaceChangeAddRemoveMutex.Lock() + defer interfaceChangeAddRemoveMutex.Unlock() + + interfaceChangeMutex.Lock() + defer interfaceChangeMutex.Unlock() + + interfaceChangeCallbacks[s] = true + + if interfaceChangeHandle == 0 { + err := notifyIPInterfaceChange(windows.AF_UNSPEC, windows.NewCallback(interfaceChanged), 0, false, &interfaceChangeHandle) + if err != nil { + delete(interfaceChangeCallbacks, s) + interfaceChangeHandle = 0 + return nil, err + } + } + + return s, nil +} + +// Unregister unregisters the callback. +func (callback *InterfaceChangeCallback) Unregister() error { + interfaceChangeAddRemoveMutex.Lock() + defer interfaceChangeAddRemoveMutex.Unlock() + + interfaceChangeMutex.Lock() + delete(interfaceChangeCallbacks, callback) + removeIt := len(interfaceChangeCallbacks) == 0 && interfaceChangeHandle != 0 + interfaceChangeMutex.Unlock() + + callback.wait.Wait() + + if removeIt { + err := cancelMibChangeNotify2(interfaceChangeHandle) + if err != nil { + return err + } + interfaceChangeHandle = 0 + } + + return nil +} + +func interfaceChanged(callerContext uintptr, row *MibIPInterfaceRow, notificationType MibNotificationType) uintptr { + rowCopy := *row + interfaceChangeMutex.Lock() + for cb := range interfaceChangeCallbacks { + cb.wait.Add(1) + go func(cb *InterfaceChangeCallback) { + cb.cb(notificationType, &rowCopy) + cb.wait.Done() + }(cb) + } + interfaceChangeMutex.Unlock() + return 0 +} diff --git a/internal/winipcfg/luid.go b/internal/winipcfg/luid.go new file mode 100644 index 0000000..0c898b8 --- /dev/null +++ b/internal/winipcfg/luid.go @@ -0,0 +1,387 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +import ( + "errors" + "net/netip" + "strings" + + "golang.org/x/sys/windows" +) + +// LUID represents a network interface. +type LUID uint64 + +// IPInterface method retrieves IP information for the specified interface on the local computer. +func (luid LUID) IPInterface(family AddressFamily) (*MibIPInterfaceRow, error) { + row := &MibIPInterfaceRow{} + row.Init() + row.InterfaceLUID = luid + row.Family = family + err := row.get() + if err != nil { + return nil, err + } + return row, nil +} + +// Interface method retrieves information for the specified adapter on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getifentry2 +func (luid LUID) Interface() (*MibIfRow2, error) { + row := &MibIfRow2{} + row.InterfaceLUID = luid + err := row.get() + if err != nil { + return nil, err + } + return row, nil +} + +// GUID method converts a locally unique identifier (LUID) for a network interface to a globally unique identifier (GUID) for the interface. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceluidtoguid +func (luid LUID) GUID() (*windows.GUID, error) { + guid := &windows.GUID{} + err := convertInterfaceLUIDToGUID(&luid, guid) + if err != nil { + return nil, err + } + return guid, nil +} + +// LUIDFromGUID function converts a globally unique identifier (GUID) for a network interface to the locally unique identifier (LUID) for the interface. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceguidtoluid +func LUIDFromGUID(guid *windows.GUID) (LUID, error) { + var luid LUID + err := convertInterfaceGUIDToLUID(guid, &luid) + if err != nil { + return 0, err + } + return luid, nil +} + +// LUIDFromIndex function converts a local index for a network interface to the locally unique identifier (LUID) for the interface. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceindextoluid +func LUIDFromIndex(index uint32) (LUID, error) { + var luid LUID + err := convertInterfaceIndexToLUID(index, &luid) + if err != nil { + return 0, err + } + return luid, nil +} + +// IPAddress method returns MibUnicastIPAddressRow struct that matches to provided 'ip' argument. Corresponds to GetUnicastIpAddressEntry +// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry) +func (luid LUID) IPAddress(addr netip.Addr) (*MibUnicastIPAddressRow, error) { + row := &MibUnicastIPAddressRow{InterfaceLUID: luid} + + err := row.Address.SetAddr(addr) + if err != nil { + return nil, err + } + + err = row.get() + if err != nil { + return nil, err + } + + return row, nil +} + +// AddIPAddress method adds new unicast IP address to the interface. Corresponds to CreateUnicastIpAddressEntry function +// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry). +func (luid LUID) AddIPAddress(address netip.Prefix) error { + row := &MibUnicastIPAddressRow{} + row.Init() + row.InterfaceLUID = luid + row.DadState = DadStatePreferred + row.ValidLifetime = 0xffffffff + row.PreferredLifetime = 0xffffffff + err := row.Address.SetAddr(address.Addr()) + if err != nil { + return err + } + row.OnLinkPrefixLength = uint8(address.Bits()) + return row.Create() +} + +// AddIPAddresses method adds multiple new unicast IP addresses to the interface. Corresponds to CreateUnicastIpAddressEntry function +// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry). +func (luid LUID) AddIPAddresses(addresses []netip.Prefix) error { + for i := range addresses { + err := luid.AddIPAddress(addresses[i]) + if err != nil { + return err + } + } + return nil +} + +// SetIPAddresses method sets new unicast IP addresses to the interface. +func (luid LUID) SetIPAddresses(addresses []netip.Prefix) error { + err := luid.FlushIPAddresses(windows.AF_UNSPEC) + if err != nil { + return err + } + return luid.AddIPAddresses(addresses) +} + +// SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface. +func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []netip.Prefix) error { + err := luid.FlushIPAddresses(family) + if err != nil { + return err + } + for i := range addresses { + if !addresses[i].Addr().Is4() && family == windows.AF_INET { + continue + } else if !addresses[i].Addr().Is6() && family == windows.AF_INET6 { + continue + } + err := luid.AddIPAddress(addresses[i]) + if err != nil { + return err + } + } + return nil +} + +// DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function +// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry). +func (luid LUID) DeleteIPAddress(address netip.Prefix) error { + row := &MibUnicastIPAddressRow{} + row.Init() + row.InterfaceLUID = luid + err := row.Address.SetAddr(address.Addr()) + if err != nil { + return err + } + // Note: OnLinkPrefixLength member is ignored by DeleteUnicastIpAddressEntry(). + row.OnLinkPrefixLength = uint8(address.Bits()) + return row.Delete() +} + +// FlushIPAddresses method deletes all interface's unicast IP addresses. +func (luid LUID) FlushIPAddresses(family AddressFamily) error { + var tab *mibUnicastIPAddressTable + err := getUnicastIPAddressTable(family, &tab) + if err != nil { + return err + } + t := tab.get() + for i := range t { + if t[i].InterfaceLUID == luid { + t[i].Delete() + } + } + tab.free() + return nil +} + +// Route method returns route determined with the input arguments. Corresponds to GetIpForwardEntry2 function +// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2). +// NOTE: If the corresponding route isn't found, the method will return error. +func (luid LUID) Route(destination netip.Prefix, nextHop netip.Addr) (*MibIPforwardRow2, error) { + row := &MibIPforwardRow2{} + row.Init() + row.InterfaceLUID = luid + row.ValidLifetime = 0xffffffff + row.PreferredLifetime = 0xffffffff + err := row.DestinationPrefix.SetPrefix(destination) + if err != nil { + return nil, err + } + err = row.NextHop.SetAddr(nextHop) + if err != nil { + return nil, err + } + + err = row.get() + if err != nil { + return nil, err + } + return row, nil +} + +// AddRoute method adds a route to the interface. Corresponds to CreateIpForwardEntry2 function, with added splitDefault feature. +// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2) +func (luid LUID) AddRoute(destination netip.Prefix, nextHop netip.Addr, metric uint32) error { + row := &MibIPforwardRow2{} + row.Init() + row.InterfaceLUID = luid + err := row.DestinationPrefix.SetPrefix(destination) + if err != nil { + return err + } + err = row.NextHop.SetAddr(nextHop) + if err != nil { + return err + } + row.Metric = metric + return row.Create() +} + +// AddRoutes method adds multiple routes to the interface. +func (luid LUID) AddRoutes(routesData []*RouteData) error { + for _, rd := range routesData { + err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric) + if err != nil { + return err + } + } + return nil +} + +// SetRoutes method sets (flush than add) multiple routes to the interface. +func (luid LUID) SetRoutes(routesData []*RouteData) error { + err := luid.FlushRoutes(windows.AF_UNSPEC) + if err != nil { + return err + } + return luid.AddRoutes(routesData) +} + +// SetRoutesForFamily method sets (flush than add) multiple routes for a specific family to the interface. +func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteData) error { + err := luid.FlushRoutes(family) + if err != nil { + return err + } + for _, rd := range routesData { + if !rd.Destination.Addr().Is4() && family == windows.AF_INET { + continue + } else if !rd.Destination.Addr().Is6() && family == windows.AF_INET6 { + continue + } + err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric) + if err != nil { + return err + } + } + return nil +} + +// DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function +// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2). +func (luid LUID) DeleteRoute(destination netip.Prefix, nextHop netip.Addr) error { + row := &MibIPforwardRow2{} + row.Init() + row.InterfaceLUID = luid + err := row.DestinationPrefix.SetPrefix(destination) + if err != nil { + return err + } + err = row.NextHop.SetAddr(nextHop) + if err != nil { + return err + } + err = row.get() + if err != nil { + return err + } + return row.Delete() +} + +// FlushRoutes method deletes all interface's routes. +// It continues on failures, and returns the last error afterwards. +func (luid LUID) FlushRoutes(family AddressFamily) error { + var tab *mibIPforwardTable2 + err := getIPForwardTable2(family, &tab) + if err != nil { + return err + } + t := tab.get() + for i := range t { + if t[i].InterfaceLUID == luid { + err2 := t[i].Delete() + if err2 != nil { + err = err2 + } + } + } + tab.free() + return err +} + +// DNS method returns all DNS server addresses associated with the adapter. +func (luid LUID) DNS() ([]netip.Addr, error) { + addresses, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault) + if err != nil { + return nil, err + } + r := make([]netip.Addr, 0, len(addresses)) + for _, addr := range addresses { + if addr.LUID == luid { + for dns := addr.FirstDNSServerAddress; dns != nil; dns = dns.Next { + if ip := dns.Address.IP(); ip != nil { + if a, ok := netip.AddrFromSlice(ip); ok { + r = append(r, a) + } + } else { + return nil, windows.ERROR_INVALID_PARAMETER + } + } + } + } + return r, nil +} + +// SetDNS method clears previous and associates new DNS servers and search domains with the adapter for a specific family. +func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []string) error { + if family != windows.AF_INET && family != windows.AF_INET6 { + return windows.ERROR_PROTOCOL_UNREACHABLE + } + + var filteredServers []string + for _, server := range servers { + if (server.Is4() && family == windows.AF_INET) || (server.Is6() && family == windows.AF_INET6) { + filteredServers = append(filteredServers, server.String()) + } + } + servers16, err := windows.UTF16PtrFromString(strings.Join(filteredServers, ",")) + if err != nil { + return err + } + domains16, err := windows.UTF16PtrFromString(strings.Join(domains, ",")) + if err != nil { + return err + } + guid, err := luid.GUID() + if err != nil { + return err + } + dnsInterfaceSettings := &DnsInterfaceSettings{ + Version: DnsInterfaceSettingsVersion1, + Flags: DnsInterfaceSettingsFlagNameserver | DnsInterfaceSettingsFlagSearchList, + NameServer: servers16, + SearchList: domains16, + } + if family == windows.AF_INET6 { + dnsInterfaceSettings.Flags |= DnsInterfaceSettingsFlagIPv6 + } + // For >= Windows 10 1809 + err = SetInterfaceDnsSettings(*guid, dnsInterfaceSettings) + if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err + } + + // For < Windows 10 1809 + err = luid.fallbackSetDNSForFamily(family, servers) + if err != nil { + return err + } + if len(domains) > 0 { + return luid.fallbackSetDNSDomain(domains[0]) + } else { + return luid.fallbackSetDNSDomain("") + } +} + +// FlushDNS method clears all DNS servers associated with the adapter. +func (luid LUID) FlushDNS(family AddressFamily) error { + return luid.SetDNS(family, nil, nil) +} diff --git a/internal/winipcfg/mksyscall.go b/internal/winipcfg/mksyscall.go new file mode 100644 index 0000000..d62d38d --- /dev/null +++ b/internal/winipcfg/mksyscall.go @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zwinipcfg_windows.go winipcfg.go diff --git a/internal/winipcfg/netsh.go b/internal/winipcfg/netsh.go new file mode 100644 index 0000000..4f8e5b1 --- /dev/null +++ b/internal/winipcfg/netsh.go @@ -0,0 +1,108 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/netip" + "os/exec" + "path/filepath" + "strings" + "syscall" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" +) + +func runNetsh(cmds []string) error { + system32, err := windows.GetSystemDirectory() + if err != nil { + return err + } + cmd := exec.Command(filepath.Join(system32, "netsh.exe")) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("runNetsh stdin pipe - %w", err) + } + go func() { + defer stdin.Close() + io.WriteString(stdin, strings.Join(append(cmds, "exit\r\n"), "\r\n")) + }() + output, err := cmd.CombinedOutput() + // Horrible kludges, sorry. + cleaned := bytes.ReplaceAll(output, []byte{'\r', '\n'}, []byte{'\n'}) + cleaned = bytes.ReplaceAll(cleaned, []byte("netsh>"), []byte{}) + cleaned = bytes.ReplaceAll(cleaned, []byte("There are no Domain Name Servers (DNS) configured on this computer."), []byte{}) + cleaned = bytes.TrimSpace(cleaned) + if len(cleaned) != 0 && err == nil { + return fmt.Errorf("netsh: %#q", string(cleaned)) + } else if err != nil { + return fmt.Errorf("netsh: %v: %#q", err, string(cleaned)) + } + return nil +} + +const ( + netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both" + netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both" + netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no" + netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no" +) + +func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error { + var templateFlush string + if family == windows.AF_INET { + templateFlush = netshCmdTemplateFlush4 + } else if family == windows.AF_INET6 { + templateFlush = netshCmdTemplateFlush6 + } + + cmds := make([]string, 0, 1+len(dnses)) + ipif, err := luid.IPInterface(family) + if err != nil { + return err + } + cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex)) + for i := 0; i < len(dnses); i++ { + if dnses[i].Is4() && family == windows.AF_INET { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, dnses[i].String())) + } else if dnses[i].Is6() && family == windows.AF_INET6 { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, dnses[i].String())) + } + } + return runNetsh(cmds) +} + +func (luid LUID) fallbackSetDNSDomain(domain string) error { + guid, err := luid.GUID() + if err != nil { + return fmt.Errorf("Error converting luid to guid: %w", err) + } + key, err := registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", guid), registry.QUERY_VALUE) + if err != nil { + return fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %w", err) + } + paths, _, err := key.GetStringsValue("IpConfig") + key.Close() + if err != nil { + return fmt.Errorf("Error reading IpConfig registry key: %w", err) + } + if len(paths) == 0 { + return errors.New("No TCP/IP interfaces found on adapter") + } + key, err = registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), registry.SET_VALUE) + if err != nil { + return fmt.Errorf("Unable to open TCP/IP network registry key: %w", err) + } + err = key.SetStringValue("Domain", domain) + key.Close() + return err +} diff --git a/internal/winipcfg/route_change_handler.go b/internal/winipcfg/route_change_handler.go new file mode 100644 index 0000000..4b78331 --- /dev/null +++ b/internal/winipcfg/route_change_handler.go @@ -0,0 +1,88 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +import ( + "sync" + + "golang.org/x/sys/windows" +) + +// RouteChangeCallback structure allows route change callback handling. +type RouteChangeCallback struct { + cb func(notificationType MibNotificationType, route *MibIPforwardRow2) + wait sync.WaitGroup +} + +var ( + routeChangeAddRemoveMutex = sync.Mutex{} + routeChangeMutex = sync.Mutex{} + routeChangeCallbacks = make(map[*RouteChangeCallback]bool) + routeChangeHandle = windows.Handle(0) +) + +// RegisterRouteChangeCallback registers a new RouteChangeCallback. If this particular callback is already +// registered, the function will silently return. Returned RouteChangeCallback.Unregister method should be used +// to unregister. +func RegisterRouteChangeCallback(callback func(notificationType MibNotificationType, route *MibIPforwardRow2)) (*RouteChangeCallback, error) { + s := &RouteChangeCallback{cb: callback} + + routeChangeAddRemoveMutex.Lock() + defer routeChangeAddRemoveMutex.Unlock() + + routeChangeMutex.Lock() + defer routeChangeMutex.Unlock() + + routeChangeCallbacks[s] = true + + if routeChangeHandle == 0 { + err := notifyRouteChange2(windows.AF_UNSPEC, windows.NewCallback(routeChanged), 0, false, &routeChangeHandle) + if err != nil { + delete(routeChangeCallbacks, s) + routeChangeHandle = 0 + return nil, err + } + } + + return s, nil +} + +// Unregister unregisters the callback. +func (callback *RouteChangeCallback) Unregister() error { + routeChangeAddRemoveMutex.Lock() + defer routeChangeAddRemoveMutex.Unlock() + + routeChangeMutex.Lock() + delete(routeChangeCallbacks, callback) + removeIt := len(routeChangeCallbacks) == 0 && routeChangeHandle != 0 + routeChangeMutex.Unlock() + + callback.wait.Wait() + + if removeIt { + err := cancelMibChangeNotify2(routeChangeHandle) + if err != nil { + return err + } + routeChangeHandle = 0 + } + + return nil +} + +func routeChanged(callerContext uintptr, row *MibIPforwardRow2, notificationType MibNotificationType) uintptr { + rowCopy := *row + routeChangeMutex.Lock() + for cb := range routeChangeCallbacks { + cb.wait.Add(1) + go func(cb *RouteChangeCallback) { + cb.cb(notificationType, &rowCopy) + cb.wait.Done() + }(cb) + } + routeChangeMutex.Unlock() + return 0 +} diff --git a/internal/winipcfg/types.go b/internal/winipcfg/types.go new file mode 100644 index 0000000..8e8f4a5 --- /dev/null +++ b/internal/winipcfg/types.go @@ -0,0 +1,1018 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +import ( + "encoding/binary" + "fmt" + "net/netip" + "strconv" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + anySize = 1 + maxDNSSuffixStringLength = 256 + maxDHCPv6DUIDLength = 130 + ifMaxStringSize = 256 + ifMaxPhysAddressLength = 32 +) + +// AddressFamily enumeration specifies protocol family and is one of the windows.AF_* constants. +type AddressFamily uint16 + +// IPAAFlags enumeration describes adapter addresses flags +// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_addresses_lh +type IPAAFlags uint32 + +const ( + IPAAFlagDdnsEnabled IPAAFlags = 1 << iota + IPAAFlagRegisterAdapterSuffix + IPAAFlagDhcpv4Enabled + IPAAFlagReceiveOnly + IPAAFlagNoMulticast + IPAAFlagIpv6OtherStatefulConfig + IPAAFlagNetbiosOverTcpipEnabled + IPAAFlagIpv4Enabled + IPAAFlagIpv6Enabled + IPAAFlagIpv6ManagedAddressConfigurationSupported +) + +// IfOperStatus enumeration specifies the operational status of an interface. +// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-if_oper_status +type IfOperStatus uint32 + +const ( + IfOperStatusUp IfOperStatus = iota + 1 + IfOperStatusDown + IfOperStatusTesting + IfOperStatusUnknown + IfOperStatusDormant + IfOperStatusNotPresent + IfOperStatusLowerLayerDown +) + +// IfType enumeration specifies interface type. +type IfType uint32 + +const ( + IfTypeOther IfType = 1 // None of the below + IfTypeRegular1822 = 2 + IfTypeHdh1822 = 3 + IfTypeDdnX25 = 4 + IfTypeRfc877X25 = 5 + IfTypeEthernetCSMACD = 6 + IfTypeISO88023CSMACD = 7 + IfTypeISO88024Tokenbus = 8 + IfTypeISO88025Tokenring = 9 + IfTypeISO88026Man = 10 + IfTypeStarlan = 11 + IfTypeProteon10Mbit = 12 + IfTypeProteon80Mbit = 13 + IfTypeHyperchannel = 14 + IfTypeFddi = 15 + IfTypeLapB = 16 + IfTypeSdlc = 17 + IfTypeDs1 = 18 // DS1-MIB + IfTypeE1 = 19 // Obsolete; see DS1-MIB + IfTypeBasicISDN = 20 + IfTypePrimaryISDN = 21 + IfTypePropPoint2PointSerial = 22 // proprietary serial + IfTypePPP = 23 + IfTypeSoftwareLoopback = 24 + IfTypeEon = 25 // CLNP over IP + IfTypeEthernet3Mbit = 26 + IfTypeNsip = 27 // XNS over IP + IfTypeSlip = 28 // Generic Slip + IfTypeUltra = 29 // ULTRA Technologies + IfTypeDs3 = 30 // DS3-MIB + IfTypeSip = 31 // SMDS, coffee + IfTypeFramerelay = 32 // DTE only + IfTypeRs232 = 33 + IfTypePara = 34 // Parallel port + IfTypeArcnet = 35 + IfTypeArcnetPlus = 36 + IfTypeAtm = 37 // ATM cells + IfTypeMioX25 = 38 + IfTypeSonet = 39 // SONET or SDH + IfTypeX25Ple = 40 + IfTypeIso88022LLC = 41 + IfTypeLocaltalk = 42 + IfTypeSmdsDxi = 43 + IfTypeFramerelayService = 44 // FRNETSERV-MIB + IfTypeV35 = 45 + IfTypeHssi = 46 + IfTypeHippi = 47 + IfTypeModem = 48 // Generic Modem + IfTypeAal5 = 49 // AAL5 over ATM + IfTypeSonetPath = 50 + IfTypeSonetVt = 51 + IfTypeSmdsIcip = 52 // SMDS InterCarrier Interface + IfTypePropVirtual = 53 // Proprietary virtual/internal + IfTypePropMultiplexor = 54 // Proprietary multiplexing + IfTypeIEEE80212 = 55 // 100BaseVG + IfTypeFibrechannel = 56 + IfTypeHippiinterface = 57 + IfTypeFramerelayInterconnect = 58 // Obsolete, use 32 or 44 + IfTypeAflane8023 = 59 // ATM Emulated LAN for 802.3 + IfTypeAflane8025 = 60 // ATM Emulated LAN for 802.5 + IfTypeCctemul = 61 // ATM Emulated circuit + IfTypeFastether = 62 // Fast Ethernet (100BaseT) + IfTypeISDN = 63 // ISDN and X.25 + IfTypeV11 = 64 // CCITT V.11/X.21 + IfTypeV36 = 65 // CCITT V.36 + IfTypeG703_64k = 66 // CCITT G703 at 64Kbps + IfTypeG703_2mb = 67 // Obsolete; see DS1-MIB + IfTypeQllc = 68 // SNA QLLC + IfTypeFastetherFX = 69 // Fast Ethernet (100BaseFX) + IfTypeChannel = 70 + IfTypeIEEE80211 = 71 // Radio spread spectrum + IfTypeIBM370parchan = 72 // IBM System 360/370 OEMI Channel + IfTypeEscon = 73 // IBM Enterprise Systems Connection + IfTypeDlsw = 74 // Data Link Switching + IfTypeISDNS = 75 // ISDN S/T interface + IfTypeISDNU = 76 // ISDN U interface + IfTypeLapD = 77 // Link Access Protocol D + IfTypeIpswitch = 78 // IP Switching Objects + IfTypeRsrb = 79 // Remote Source Route Bridging + IfTypeAtmLogical = 80 // ATM Logical Port + IfTypeDs0 = 81 // Digital Signal Level 0 + IfTypeDs0Bundle = 82 // Group of ds0s on the same ds1 + IfTypeBsc = 83 // Bisynchronous Protocol + IfTypeAsync = 84 // Asynchronous Protocol + IfTypeCnr = 85 // Combat Net Radio + IfTypeIso88025rDtr = 86 // ISO 802.5r DTR + IfTypeEplrs = 87 // Ext Pos Loc Report Sys + IfTypeArap = 88 // Appletalk Remote Access Protocol + IfTypePropCnls = 89 // Proprietary Connectionless Proto + IfTypeHostpad = 90 // CCITT-ITU X.29 PAD Protocol + IfTypeTermpad = 91 // CCITT-ITU X.3 PAD Facility + IfTypeFramerelayMpi = 92 // Multiproto Interconnect over FR + IfTypeX213 = 93 // CCITT-ITU X213 + IfTypeAdsl = 94 // Asymmetric Digital Subscrbr Loop + IfTypeRadsl = 95 // Rate-Adapt Digital Subscrbr Loop + IfTypeSdsl = 96 // Symmetric Digital Subscriber Loop + IfTypeVdsl = 97 // Very H-Speed Digital Subscrb Loop + IfTypeIso88025Crfprint = 98 // ISO 802.5 CRFP + IfTypeMyrinet = 99 // Myricom Myrinet + IfTypeVoiceEm = 100 // Voice recEive and transMit + IfTypeVoiceFxo = 101 // Voice Foreign Exchange Office + IfTypeVoiceFxs = 102 // Voice Foreign Exchange Station + IfTypeVoiceEncap = 103 // Voice encapsulation + IfTypeVoiceOverip = 104 // Voice over IP encapsulation + IfTypeAtmDxi = 105 // ATM DXI + IfTypeAtmFuni = 106 // ATM FUNI + IfTypeAtmIma = 107 // ATM IMA + IfTypePPPmultilinkbundle = 108 // PPP Multilink Bundle + IfTypeIpoverCdlc = 109 // IBM ipOverCdlc + IfTypeIpoverClaw = 110 // IBM Common Link Access to Workstn + IfTypeStacktostack = 111 // IBM stackToStack + IfTypeVirtualipaddress = 112 // IBM VIPA + IfTypeMpc = 113 // IBM multi-proto channel support + IfTypeIpoverAtm = 114 // IBM ipOverAtm + IfTypeIso88025Fiber = 115 // ISO 802.5j Fiber Token Ring + IfTypeTdlc = 116 // IBM twinaxial data link control + IfTypeGigabitethernet = 117 + IfTypeHdlc = 118 + IfTypeLapF = 119 + IfTypeV37 = 120 + IfTypeX25Mlp = 121 // Multi-Link Protocol + IfTypeX25Huntgroup = 122 // X.25 Hunt Group + IfTypeTransphdlc = 123 + IfTypeInterleave = 124 // Interleave channel + IfTypeFast = 125 // Fast channel + IfTypeIP = 126 // IP (for APPN HPR in IP networks) + IfTypeDocscableMaclayer = 127 // CATV Mac Layer + IfTypeDocscableDownstream = 128 // CATV Downstream interface + IfTypeDocscableUpstream = 129 // CATV Upstream interface + IfTypeA12mppswitch = 130 // Avalon Parallel Processor + IfTypeTunnel = 131 // Encapsulation interface + IfTypeCoffee = 132 // Coffee pot + IfTypeCes = 133 // Circuit Emulation Service + IfTypeAtmSubinterface = 134 // ATM Sub Interface + IfTypeL2Vlan = 135 // Layer 2 Virtual LAN using 802.1Q + IfTypeL3Ipvlan = 136 // Layer 3 Virtual LAN using IP + IfTypeL3Ipxvlan = 137 // Layer 3 Virtual LAN using IPX + IfTypeDigitalpowerline = 138 // IP over Power Lines + IfTypeMediamailoverip = 139 // Multimedia Mail over IP + IfTypeDtm = 140 // Dynamic syncronous Transfer Mode + IfTypeDcn = 141 // Data Communications Network + IfTypeIpforward = 142 // IP Forwarding Interface + IfTypeMsdsl = 143 // Multi-rate Symmetric DSL + IfTypeIEEE1394 = 144 // IEEE1394 High Perf Serial Bus + IfTypeIfGsn = 145 + IfTypeDvbrccMaclayer = 146 + IfTypeDvbrccDownstream = 147 + IfTypeDvbrccUpstream = 148 + IfTypeAtmVirtual = 149 + IfTypeMplsTunnel = 150 + IfTypeSrp = 151 + IfTypeVoiceoveratm = 152 + IfTypeVoiceoverframerelay = 153 + IfTypeIdsl = 154 + IfTypeCompositelink = 155 + IfTypeSs7Siglink = 156 + IfTypePropWirelessP2P = 157 + IfTypeFrForward = 158 + IfTypeRfc1483 = 159 + IfTypeUsb = 160 + IfTypeIEEE8023adLag = 161 + IfTypeBgpPolicyAccounting = 162 + IfTypeFrf16MfrBundle = 163 + IfTypeH323Gatekeeper = 164 + IfTypeH323Proxy = 165 + IfTypeMpls = 166 + IfTypeMfSiglink = 167 + IfTypeHdsl2 = 168 + IfTypeShdsl = 169 + IfTypeDs1Fdl = 170 + IfTypePos = 171 + IfTypeDvbAsiIn = 172 + IfTypeDvbAsiOut = 173 + IfTypePlc = 174 + IfTypeNfas = 175 + IfTypeTr008 = 176 + IfTypeGr303Rdt = 177 + IfTypeGr303Idt = 178 + IfTypeIsup = 179 + IfTypePropDocsWirelessMaclayer = 180 + IfTypePropDocsWirelessDownstream = 181 + IfTypePropDocsWirelessUpstream = 182 + IfTypeHiperlan2 = 183 + IfTypePropBwaP2MP = 184 + IfTypeSonetOverheadChannel = 185 + IfTypeDigitalWrapperOverheadChannel = 186 + IfTypeAal2 = 187 + IfTypeRadioMac = 188 + IfTypeAtmRadio = 189 + IfTypeImt = 190 + IfTypeMvl = 191 + IfTypeReachDsl = 192 + IfTypeFrDlciEndpt = 193 + IfTypeAtmVciEndpt = 194 + IfTypeOpticalChannel = 195 + IfTypeOpticalTransport = 196 + IfTypeIEEE80216Wman = 237 + IfTypeWwanpp = 243 // WWAN devices based on GSM technology + IfTypeWwanpp2 = 244 // WWAN devices based on CDMA technology + IfTypeIEEE802154 = 259 // IEEE 802.15.4 WPAN interface + IfTypeXboxWireless = 281 +) + +// MibIfEntryLevel enumeration specifies level of interface information to retrieve in GetIfTable2Ex function call. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getifentry2ex +type MibIfEntryLevel uint32 + +const ( + MibIfEntryNormal MibIfEntryLevel = 0 + MibIfEntryNormalWithoutStatistics = 2 +) + +// NdisMedium enumeration type identifies the medium types that NDIS drivers support. +// https://docs.microsoft.com/en-us/windows-hardware/drivers/ddi/content/ntddndis/ne-ntddndis-_ndis_medium +type NdisMedium uint32 + +const ( + NdisMedium802_3 NdisMedium = iota + NdisMedium802_5 + NdisMediumFddi + NdisMediumWan + NdisMediumLocalTalk + NdisMediumDix // defined for convenience, not a real medium + NdisMediumArcnetRaw + NdisMediumArcnet878_2 + NdisMediumAtm + NdisMediumWirelessWan + NdisMediumIrda + NdisMediumBpc + NdisMediumCoWan + NdisMedium1394 + NdisMediumInfiniBand + NdisMediumTunnel + NdisMediumNative802_11 + NdisMediumLoopback + NdisMediumWiMAX + NdisMediumIP + NdisMediumMax +) + +// NdisPhysicalMedium describes NDIS physical medium type. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_row2 +type NdisPhysicalMedium uint32 + +const ( + NdisPhysicalMediumUnspecified NdisPhysicalMedium = iota + NdisPhysicalMediumWirelessLan + NdisPhysicalMediumCableModem + NdisPhysicalMediumPhoneLine + NdisPhysicalMediumPowerLine + NdisPhysicalMediumDSL // includes ADSL and UADSL (G.Lite) + NdisPhysicalMediumFibreChannel + NdisPhysicalMedium1394 + NdisPhysicalMediumWirelessWan + NdisPhysicalMediumNative802_11 + NdisPhysicalMediumBluetooth + NdisPhysicalMediumInfiniband + NdisPhysicalMediumWiMax + NdisPhysicalMediumUWB + NdisPhysicalMedium802_3 + NdisPhysicalMedium802_5 + NdisPhysicalMediumIrda + NdisPhysicalMediumWiredWAN + NdisPhysicalMediumWiredCoWan + NdisPhysicalMediumOther + NdisPhysicalMediumNative802_15_4 + NdisPhysicalMediumMax +) + +// NetIfAccessType enumeration type specifies the NDIS network interface access type. +// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-_net_if_access_type +type NetIfAccessType uint32 + +const ( + NetIfAccessLoopback NetIfAccessType = iota + 1 + NetIfAccessBroadcast + NetIfAccessPointToPoint + NetIfAccessPointToMultiPoint + NetIfAccessMax +) + +// NetIfAdminStatus enumeration type specifies the NDIS network interface administrative status, as described in RFC 2863. +// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-net_if_admin_status +type NetIfAdminStatus uint32 + +const ( + NetIfAdminStatusUp NetIfAdminStatus = iota + 1 + NetIfAdminStatusDown + NetIfAdminStatusTesting +) + +// NetIfConnectionType enumeration type specifies the NDIS network interface connection type. +// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-_net_if_connection_type +type NetIfConnectionType uint32 + +const ( + NetIfConnectionDedicated NetIfConnectionType = iota + 1 + NetIfConnectionPassive + NetIfConnectionDemand + NetIfConnectionMaximum +) + +// NetIfDirectionType enumeration type specifies the NDIS network interface direction type. +// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-net_if_direction_type +type NetIfDirectionType uint32 + +const ( + NetIfDirectionSendReceive NetIfDirectionType = iota + NetIfDirectionSendOnly + NetIfDirectionReceiveOnly + NetIfDirectionMaximum +) + +// NetIfMediaConnectState enumeration type specifies the NDIS network interface connection state. +// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-_net_if_media_connect_state +type NetIfMediaConnectState uint32 + +const ( + MediaConnectStateUnknown NetIfMediaConnectState = iota + MediaConnectStateConnected + MediaConnectStateDisconnected +) + +// DadState enumeration specifies information about the duplicate address detection (DAD) state for an IPv4 or IPv6 address. +// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-nl_dad_state +type DadState uint32 + +const ( + DadStateInvalid DadState = iota + DadStateTentative + DadStateDuplicate + DadStateDeprecated + DadStatePreferred +) + +// PrefixOrigin enumeration specifies the origin of an IPv4 or IPv6 address prefix, and is used with the IP_ADAPTER_UNICAST_ADDRESS structure. +// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-nl_prefix_origin +type PrefixOrigin uint32 + +const ( + PrefixOriginOther PrefixOrigin = iota + PrefixOriginManual + PrefixOriginWellKnown + PrefixOriginDHCP + PrefixOriginRouterAdvertisement + PrefixOriginUnchanged = 1 << 4 +) + +// LinkLocalAddressBehavior enumeration type defines the link local address behavior. +// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-_nl_link_local_address_behavior +type LinkLocalAddressBehavior int32 + +const ( + LinkLocalAddressAlwaysOff LinkLocalAddressBehavior = iota // Never use link locals. + LinkLocalAddressDelayed // Use link locals only if no other addresses. (default for IPv4). Legacy mapping: IPAutoconfigurationEnabled. + LinkLocalAddressAlwaysOn // Always use link locals (default for IPv6). + LinkLocalAddressUnchanged = -1 +) + +// OffloadRod enumeration specifies a set of flags that indicate the offload capabilities for an IP interface. +// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ns-nldef-_nl_interface_offload_rod +type OffloadRod uint8 + +const ( + ChecksumSupported OffloadRod = 1 << iota + OptionsSupported + DatagramChecksumSupported + StreamChecksumSupported + StreamOptionsSupported + FastPathCompatible + LargeSendOffloadSupported + GiantSendOffloadSupported +) + +// RouteOrigin enumeration type defines the origin of the IP route. +// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-nl_route_origin +type RouteOrigin uint32 + +const ( + RouteOriginManual RouteOrigin = iota + RouteOriginWellKnown + RouteOriginDHCP + RouteOriginRouterAdvertisement + RouteOrigin6to4 +) + +// RouteProtocol enumeration type defines the routing mechanism that an IP route was added with, as described in RFC 4292. +// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-nl_route_protocol +type RouteProtocol uint32 + +const ( + RouteProtocolOther RouteProtocol = iota + 1 + RouteProtocolLocal + RouteProtocolNetMgmt + RouteProtocolIcmp + RouteProtocolEgp + RouteProtocolGgp + RouteProtocolHello + RouteProtocolRip + RouteProtocolIsIs + RouteProtocolEsIs + RouteProtocolCisco + RouteProtocolBbn + RouteProtocolOspf + RouteProtocolBgp + RouteProtocolIdpr + RouteProtocolEigrp + RouteProtocolDvmrp + RouteProtocolRpl + RouteProtocolDHCP + RouteProtocolNTAutostatic = 10002 + RouteProtocolNTStatic = 10006 + RouteProtocolNTStaticNonDOD = 10007 +) + +// RouterDiscoveryBehavior enumeration type defines the router discovery behavior, as described in RFC 2461. +// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-_nl_router_discovery_behavior +type RouterDiscoveryBehavior int32 + +const ( + RouterDiscoveryDisabled RouterDiscoveryBehavior = iota + RouterDiscoveryEnabled + RouterDiscoveryDHCP + RouterDiscoveryUnchanged = -1 +) + +// SuffixOrigin enumeration specifies the origin of an IPv4 or IPv6 address suffix, and is used with the IP_ADAPTER_UNICAST_ADDRESS structure. +// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-nl_suffix_origin +type SuffixOrigin uint32 + +const ( + SuffixOriginOther SuffixOrigin = iota + SuffixOriginManual + SuffixOriginWellKnown + SuffixOriginDHCP + SuffixOriginLinkLayerAddress + SuffixOriginRandom + SuffixOriginUnchanged = 1 << 4 +) + +// MibNotificationType enumeration defines the notification type passed to a callback function when a notification occurs. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ne-netioapi-_mib_notification_type +type MibNotificationType uint32 + +const ( + MibParameterNotification MibNotificationType = iota // Parameter change + MibAddInstance // Addition + MibDeleteInstance // Deletion + MibInitialNotification // Initial notification +) + +type ChangeCallback interface { + Unregister() error +} + +// TunnelType enumeration type defines the encapsulation method used by a tunnel, as described by the Internet Assigned Names Authority (IANA). +// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-tunnel_type +type TunnelType uint32 + +const ( + TunnelTypeNone TunnelType = 0 + TunnelTypeOther = 1 + TunnelTypeDirect = 2 + TunnelType6to4 = 11 + TunnelTypeIsatap = 13 + TunnelTypeTeredo = 14 + TunnelTypeIPHTTPS = 15 +) + +// InterfaceAndOperStatusFlags enumeration type defines interface and operation flags +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_row2 +type InterfaceAndOperStatusFlags uint8 + +const ( + IAOSFHardwareInterface InterfaceAndOperStatusFlags = 1 << iota + IAOSFFilterInterface + IAOSFConnectorPresent + IAOSFNotAuthenticated + IAOSFNotMediaConnected + IAOSFPaused + IAOSFLowPower + IAOSFEndPointInterface +) + +// GAAFlags enumeration defines flags used in GetAdaptersAddresses calls +// https://docs.microsoft.com/en-us/windows/desktop/api/iphlpapi/nf-iphlpapi-getadaptersaddresses +type GAAFlags uint32 + +const ( + GAAFlagSkipUnicast GAAFlags = 1 << iota + GAAFlagSkipAnycast + GAAFlagSkipMulticast + GAAFlagSkipDNSServer + GAAFlagIncludePrefix + GAAFlagSkipFriendlyName + GAAFlagIncludeWinsInfo + GAAFlagIncludeGateways + GAAFlagIncludeAllInterfaces + GAAFlagIncludeAllCompartments + GAAFlagIncludeTunnelBindingOrder + GAAFlagSkipDNSInfo + + GAAFlagDefault GAAFlags = 0 + GAAFlagSkipAll = GAAFlagSkipUnicast | GAAFlagSkipAnycast | GAAFlagSkipMulticast | GAAFlagSkipDNSServer | GAAFlagSkipFriendlyName | GAAFlagSkipDNSInfo + GAAFlagIncludeAll = GAAFlagIncludePrefix | GAAFlagIncludeWinsInfo | GAAFlagIncludeGateways | GAAFlagIncludeAllInterfaces | GAAFlagIncludeAllCompartments | GAAFlagIncludeTunnelBindingOrder +) + +// ScopeLevel enumeration is used with the IP_ADAPTER_ADDRESSES structure to identify scope levels for IPv6 addresses. +// https://docs.microsoft.com/en-us/windows/desktop/api/ws2def/ne-ws2def-scope_level +type ScopeLevel uint32 + +const ( + ScopeLevelInterface ScopeLevel = 1 + ScopeLevelLink = 2 + ScopeLevelSubnet = 3 + ScopeLevelAdmin = 4 + ScopeLevelSite = 5 + ScopeLevelOrganization = 8 + ScopeLevelGlobal = 14 + ScopeLevelCount = 16 +) + +// RouteData structure describes a route to add +type RouteData struct { + Destination netip.Prefix + NextHop netip.Addr + Metric uint32 +} + +func (routeData *RouteData) String() string { + return fmt.Sprintf("%+v", *routeData) +} + +// IPAdapterDNSSuffix structure stores a DNS suffix in a linked list of DNS suffixes for a particular adapter. +// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_dns_suffix +type IPAdapterDNSSuffix struct { + Next *IPAdapterDNSSuffix + str [maxDNSSuffixStringLength]uint16 +} + +// String method returns the DNS suffix for this DNS suffix entry. +func (obj *IPAdapterDNSSuffix) String() string { + return windows.UTF16ToString(obj.str[:]) +} + +// AdapterName method returns the name of the adapter with which these addresses are associated. +// Unlike an adapter's friendly name, the adapter name returned by AdapterName is permanent and cannot be modified by the user. +func (addr *IPAdapterAddresses) AdapterName() string { + return windows.BytePtrToString(addr.adapterName) +} + +// DNSSuffix method returns adapter DNS suffix associated with this adapter. +func (addr *IPAdapterAddresses) DNSSuffix() string { + if addr.dnsSuffix == nil { + return "" + } + return windows.UTF16PtrToString(addr.dnsSuffix) +} + +// Description method returns description for the adapter. +func (addr *IPAdapterAddresses) Description() string { + if addr.description == nil { + return "" + } + return windows.UTF16PtrToString(addr.description) +} + +// FriendlyName method returns a user-friendly name for the adapter. For example: "Local Area Connection 1." +// This name appears in contexts such as the ipconfig command line program and the Connection folder. +func (addr *IPAdapterAddresses) FriendlyName() string { + if addr.friendlyName == nil { + return "" + } + return windows.UTF16PtrToString(addr.friendlyName) +} + +// PhysicalAddress method returns the Media Access Control (MAC) address for the adapter. +// For example, on an Ethernet network this member would specify the Ethernet hardware address. +func (addr *IPAdapterAddresses) PhysicalAddress() []byte { + return addr.physicalAddress[:addr.physicalAddressLength] +} + +// DHCPv6ClientDUID method returns the DHCP unique identifier (DUID) for the DHCPv6 client. +// This information is only applicable to an IPv6 adapter address configured using DHCPv6. +func (addr *IPAdapterAddresses) DHCPv6ClientDUID() []byte { + return addr.dhcpv6ClientDUID[:addr.dhcpv6ClientDUIDLength] +} + +// Init method initializes the members of an MIB_IPINTERFACE_ROW entry with default values. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-initializeipinterfaceentry +func (row *MibIPInterfaceRow) Init() { + initializeIPInterfaceEntry(row) +} + +// get method retrieves IP information for the specified interface on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipinterfaceentry +func (row *MibIPInterfaceRow) get() error { + if err := getIPInterfaceEntry(row); err != nil { + return err + } + + // Patch that fixes SitePrefixLength issue + // https://stackoverflow.com/questions/54857292/setipinterfaceentry-returns-error-invalid-parameter?noredirect=1 + switch row.Family { + case windows.AF_INET: + if row.SitePrefixLength > 32 { + row.SitePrefixLength = 0 + } + case windows.AF_INET6: + if row.SitePrefixLength > 128 { + row.SitePrefixLength = 128 + } + } + + return nil +} + +// Set method sets the properties of an IP interface on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-setipinterfaceentry +func (row *MibIPInterfaceRow) Set() error { + return setIPInterfaceEntry(row) +} + +// get method returns all table rows as a Go slice. +func (tab *mibIPInterfaceTable) get() (s []MibIPInterfaceRow) { + return unsafe.Slice(&tab.table[0], tab.numEntries) +} + +// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-freemibtable +func (tab *mibIPInterfaceTable) free() { + freeMibTable(unsafe.Pointer(tab)) +} + +// Alias method returns a string that contains the alias name of the network interface. +func (row *MibIfRow2) Alias() string { + return windows.UTF16ToString(row.alias[:]) +} + +// Description method returns a string that contains a description of the network interface. +func (row *MibIfRow2) Description() string { + return windows.UTF16ToString(row.description[:]) +} + +// PhysicalAddress method returns the physical hardware address of the adapter for this network interface. +func (row *MibIfRow2) PhysicalAddress() []byte { + return row.physicalAddress[:row.physicalAddressLength] +} + +// PermanentPhysicalAddress method returns the permanent physical hardware address of the adapter for this network interface. +func (row *MibIfRow2) PermanentPhysicalAddress() []byte { + return row.permanentPhysicalAddress[:row.physicalAddressLength] +} + +// get method retrieves information for the specified interface on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getifentry2 +func (row *MibIfRow2) get() (ret error) { + return getIfEntry2(row) +} + +// get method returns all table rows as a Go slice. +func (tab *mibIfTable2) get() (s []MibIfRow2) { + return unsafe.Slice(&tab.table[0], tab.numEntries) +} + +// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-freemibtable +func (tab *mibIfTable2) free() { + freeMibTable(unsafe.Pointer(tab)) +} + +// RawSockaddrInet union contains an IPv4, an IPv6 address, or an address family. +// https://docs.microsoft.com/en-us/windows/desktop/api/ws2ipdef/ns-ws2ipdef-_sockaddr_inet +type RawSockaddrInet struct { + Family AddressFamily + data [26]byte +} + +func ntohs(i uint16) uint16 { + return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&i))[:]) +} + +func htons(i uint16) uint16 { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, i) + return *(*uint16)(unsafe.Pointer(&b[0])) +} + +// SetAddrPort method sets family, address, and port to the given IPv4 or IPv6 address and port. +// All other members of the structure are set to zero. +func (addr *RawSockaddrInet) SetAddrPort(addrPort netip.AddrPort) error { + if addrPort.Addr().Is4() { + addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)) + addr4.Family = windows.AF_INET + addr4.Addr = addrPort.Addr().As4() + addr4.Port = htons(addrPort.Port()) + for i := 0; i < 8; i++ { + addr4.Zero[i] = 0 + } + return nil + } else if addrPort.Addr().Is6() { + addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)) + addr6.Family = windows.AF_INET6 + addr6.Addr = addrPort.Addr().As16() + addr6.Port = htons(addrPort.Port()) + addr6.Flowinfo = 0 + scopeId := uint32(0) + if z := addrPort.Addr().Zone(); z != "" { + if s, err := strconv.ParseUint(z, 10, 32); err == nil { + scopeId = uint32(s) + } + } + addr6.Scope_id = scopeId + return nil + } + return windows.ERROR_INVALID_PARAMETER +} + +// SetAddr method sets family and address to the given IPv4 or IPv6 address. +// All other members of the structure are set to zero. +func (addr *RawSockaddrInet) SetAddr(netAddr netip.Addr) error { + return addr.SetAddrPort(netip.AddrPortFrom(netAddr, 0)) +} + +// AddrPort returns the IP address and port. +func (addr *RawSockaddrInet) AddrPort() netip.AddrPort { + return netip.AddrPortFrom(addr.Addr(), addr.Port()) +} + +// Addr returns IPv4 or IPv6 address, or an invalid address if the address is neither. +func (addr *RawSockaddrInet) Addr() netip.Addr { + switch addr.Family { + case windows.AF_INET: + return netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr) + case windows.AF_INET6: + raw := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)) + a := netip.AddrFrom16(raw.Addr) + if raw.Scope_id != 0 { + a = a.WithZone(strconv.FormatUint(uint64(raw.Scope_id), 10)) + } + return a + } + return netip.Addr{} +} + +// Port returns the port if the address if IPv4 or IPv6, or 0 if neither. +func (addr *RawSockaddrInet) Port() uint16 { + switch addr.Family { + case windows.AF_INET: + return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port) + case windows.AF_INET6: + return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port) + } + return 0 +} + +// Init method initializes a MibUnicastIPAddressRow structure with default values for a unicast IP address entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-initializeunicastipaddressentry +func (row *MibUnicastIPAddressRow) Init() { + initializeUnicastIPAddressEntry(row) +} + +// get method retrieves information for an existing unicast IP address entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry +func (row *MibUnicastIPAddressRow) get() error { + return getUnicastIPAddressEntry(row) +} + +// Set method sets the properties of an existing unicast IP address entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-setunicastipaddressentry +func (row *MibUnicastIPAddressRow) Set() error { + return setUnicastIPAddressEntry(row) +} + +// Create method adds a new unicast IP address entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry +func (row *MibUnicastIPAddressRow) Create() error { + return createUnicastIPAddressEntry(row) +} + +// Delete method deletes an existing unicast IP address entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry +func (row *MibUnicastIPAddressRow) Delete() error { + return deleteUnicastIPAddressEntry(row) +} + +// get method returns all table rows as a Go slice. +func (tab *mibUnicastIPAddressTable) get() (s []MibUnicastIPAddressRow) { + return unsafe.Slice(&tab.table[0], tab.numEntries) +} + +// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-freemibtable +func (tab *mibUnicastIPAddressTable) free() { + freeMibTable(unsafe.Pointer(tab)) +} + +// get method retrieves information for an existing anycast IP address entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getanycastipaddressentry +func (row *MibAnycastIPAddressRow) get() error { + return getAnycastIPAddressEntry(row) +} + +// Create method adds a new anycast IP address entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createanycastipaddressentry +func (row *MibAnycastIPAddressRow) Create() error { + return createAnycastIPAddressEntry(row) +} + +// Delete method deletes an existing anycast IP address entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteanycastipaddressentry +func (row *MibAnycastIPAddressRow) Delete() error { + return deleteAnycastIPAddressEntry(row) +} + +// get method returns all table rows as a Go slice. +func (tab *mibAnycastIPAddressTable) get() (s []MibAnycastIPAddressRow) { + return unsafe.Slice(&tab.table[0], tab.numEntries) +} + +// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-freemibtable +func (tab *mibAnycastIPAddressTable) free() { + freeMibTable(unsafe.Pointer(tab)) +} + +// IPAddressPrefix structure stores an IP address prefix. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_ip_address_prefix +type IPAddressPrefix struct { + RawPrefix RawSockaddrInet + PrefixLength uint8 + _ [2]byte +} + +// SetPrefix method sets IP address prefix using netip.Prefix. +func (prefix *IPAddressPrefix) SetPrefix(netPrefix netip.Prefix) error { + err := prefix.RawPrefix.SetAddr(netPrefix.Addr()) + if err != nil { + return err + } + prefix.PrefixLength = uint8(netPrefix.Bits()) + return nil +} + +// Prefix returns IP address prefix as netip.Prefix. +func (prefix *IPAddressPrefix) Prefix() netip.Prefix { + switch prefix.RawPrefix.Family { + case windows.AF_INET: + return netip.PrefixFrom(netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength)) + case windows.AF_INET6: + return netip.PrefixFrom(netip.AddrFrom16((*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength)) + } + return netip.Prefix{} +} + +// MibIPforwardRow2 structure stores information about an IP route entry. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipforward_row2 +type MibIPforwardRow2 struct { + InterfaceLUID LUID + InterfaceIndex uint32 + DestinationPrefix IPAddressPrefix + NextHop RawSockaddrInet + SitePrefixLength uint8 + ValidLifetime uint32 + PreferredLifetime uint32 + Metric uint32 + Protocol RouteProtocol + Loopback bool + AutoconfigureAddress bool + Publish bool + Immortal bool + Age uint32 + Origin RouteOrigin +} + +// Init method initializes a MIB_IPFORWARD_ROW2 structure with default values for an IP route entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-initializeipforwardentry +func (row *MibIPforwardRow2) Init() { + initializeIPForwardEntry(row) +} + +// get method retrieves information for an IP route entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2 +func (row *MibIPforwardRow2) get() error { + return getIPForwardEntry2(row) +} + +// Set method sets the properties of an IP route entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-setipforwardentry2 +func (row *MibIPforwardRow2) Set() error { + return setIPForwardEntry2(row) +} + +// Create method creates a new IP route entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2 +func (row *MibIPforwardRow2) Create() error { + return createIPForwardEntry2(row) +} + +// Delete method deletes an IP route entry on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2 +func (row *MibIPforwardRow2) Delete() error { + return deleteIPForwardEntry2(row) +} + +// get method returns all table rows as a Go slice. +func (tab *mibIPforwardTable2) get() (s []MibIPforwardRow2) { + return unsafe.Slice(&tab.table[0], tab.numEntries) +} + +// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-freemibtable +func (tab *mibIPforwardTable2) free() { + freeMibTable(unsafe.Pointer(tab)) +} + +// +// DNS API +// + +// DnsInterfaceSettings is meant to be used with SetInterfaceDnsSettings +type DnsInterfaceSettings struct { + Version uint32 + _ [4]byte + Flags uint64 + Domain *uint16 + NameServer *uint16 + SearchList *uint16 + RegistrationEnabled uint32 + RegisterAdapterName uint32 + EnableLLMNR uint32 + QueryAdapterName uint32 + ProfileNameServer *uint16 +} + +const ( + DnsInterfaceSettingsVersion1 = 1 // for DnsInterfaceSettings + DnsInterfaceSettingsVersion2 = 2 // for DnsInterfaceSettingsEx + DnsInterfaceSettingsVersion3 = 3 // for DnsInterfaceSettings3 + + DnsInterfaceSettingsFlagIPv6 = 0x0001 + DnsInterfaceSettingsFlagNameserver = 0x0002 + DnsInterfaceSettingsFlagSearchList = 0x0004 + DnsInterfaceSettingsFlagRegistrationEnabled = 0x0008 + DnsInterfaceSettingsFlagRegisterAdapterName = 0x0010 + DnsInterfaceSettingsFlagDomain = 0x0020 + DnsInterfaceSettingsFlagHostname = 0x0040 + DnsInterfaceSettingsFlagEnableLLMNR = 0x0080 + DnsInterfaceSettingsFlagQueryAdapterName = 0x0100 + DnsInterfaceSettingsFlagProfileNameserver = 0x0200 + DnsInterfaceSettingsFlagDisableUnconstrainedQueries = 0x0400 // v2 only + DnsInterfaceSettingsFlagSupplementalSearchList = 0x0800 // v2 only + DnsInterfaceSettingsFlagDOH = 0x1000 // v3 only + DnsInterfaceSettingsFlagDOHProfile = 0x2000 // v3 only +) diff --git a/internal/winipcfg/types_32.go b/internal/winipcfg/types_32.go new file mode 100644 index 0000000..1a8d444 --- /dev/null +++ b/internal/winipcfg/types_32.go @@ -0,0 +1,232 @@ +//go:build 386 || arm + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +import ( + "golang.org/x/sys/windows" +) + +// IPAdapterWINSServerAddress structure stores a single Windows Internet Name Service (WINS) server address in a linked list of WINS server addresses for a particular adapter. +// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_wins_server_address_lh +type IPAdapterWINSServerAddress struct { + Length uint32 + _ uint32 + Next *IPAdapterWINSServerAddress + Address windows.SocketAddress + _ [4]byte +} + +// IPAdapterGatewayAddress structure stores a single gateway address in a linked list of gateway addresses for a particular adapter. +// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_gateway_address_lh +type IPAdapterGatewayAddress struct { + Length uint32 + _ uint32 + Next *IPAdapterGatewayAddress + Address windows.SocketAddress + _ [4]byte +} + +// IPAdapterAddresses structure is the header node for a linked list of addresses for a particular adapter. This structure can simultaneously be used as part of a linked list of IP_ADAPTER_ADDRESSES structures. +// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_addresses_lh +// This is a modified and extended version of windows.IpAdapterAddresses. +type IPAdapterAddresses struct { + Length uint32 + IfIndex uint32 + Next *IPAdapterAddresses + adapterName *byte + FirstUnicastAddress *windows.IpAdapterUnicastAddress + FirstAnycastAddress *windows.IpAdapterAnycastAddress + FirstMulticastAddress *windows.IpAdapterMulticastAddress + FirstDNSServerAddress *windows.IpAdapterDnsServerAdapter + dnsSuffix *uint16 + description *uint16 + friendlyName *uint16 + physicalAddress [windows.MAX_ADAPTER_ADDRESS_LENGTH]byte + physicalAddressLength uint32 + Flags IPAAFlags + MTU uint32 + IfType IfType + OperStatus IfOperStatus + IPv6IfIndex uint32 + ZoneIndices [16]uint32 + FirstPrefix *windows.IpAdapterPrefix + TransmitLinkSpeed uint64 + ReceiveLinkSpeed uint64 + FirstWINSServerAddress *IPAdapterWINSServerAddress + FirstGatewayAddress *IPAdapterGatewayAddress + Ipv4Metric uint32 + Ipv6Metric uint32 + LUID LUID + DHCPv4Server windows.SocketAddress + CompartmentID uint32 + NetworkGUID windows.GUID + ConnectionType NetIfConnectionType + TunnelType TunnelType + DHCPv6Server windows.SocketAddress + dhcpv6ClientDUID [maxDHCPv6DUIDLength]byte + dhcpv6ClientDUIDLength uint32 + DHCPv6IAID uint32 + FirstDNSSuffix *IPAdapterDNSSuffix + _ [4]byte +} + +// MibIPInterfaceRow structure stores interface management information for a particular IP address family on a network interface. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_row +type MibIPInterfaceRow struct { + Family AddressFamily + _ [4]byte + InterfaceLUID LUID + InterfaceIndex uint32 + MaxReassemblySize uint32 + InterfaceIdentifier uint64 + MinRouterAdvertisementInterval uint32 + MaxRouterAdvertisementInterval uint32 + AdvertisingEnabled bool + ForwardingEnabled bool + WeakHostSend bool + WeakHostReceive bool + UseAutomaticMetric bool + UseNeighborUnreachabilityDetection bool + ManagedAddressConfigurationSupported bool + OtherStatefulConfigurationSupported bool + AdvertiseDefaultRoute bool + RouterDiscoveryBehavior RouterDiscoveryBehavior + DadTransmits uint32 + BaseReachableTime uint32 + RetransmitTime uint32 + PathMTUDiscoveryTimeout uint32 + LinkLocalAddressBehavior LinkLocalAddressBehavior + LinkLocalAddressTimeout uint32 + ZoneIndices [ScopeLevelCount]uint32 + SitePrefixLength uint32 + Metric uint32 + NLMTU uint32 + Connected bool + SupportsWakeUpPatterns bool + SupportsNeighborDiscovery bool + SupportsRouterDiscovery bool + ReachableTime uint32 + TransmitOffload OffloadRod + ReceiveOffload OffloadRod + DisableDefaultRoutes bool +} + +// mibIPInterfaceTable structure contains a table of IP interface entries. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_table +type mibIPInterfaceTable struct { + numEntries uint32 + _ [4]byte + table [anySize]MibIPInterfaceRow +} + +// MibIfRow2 structure stores information about a particular interface. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_row2 +type MibIfRow2 struct { + InterfaceLUID LUID + InterfaceIndex uint32 + InterfaceGUID windows.GUID + alias [ifMaxStringSize + 1]uint16 + description [ifMaxStringSize + 1]uint16 + physicalAddressLength uint32 + physicalAddress [ifMaxPhysAddressLength]byte + permanentPhysicalAddress [ifMaxPhysAddressLength]byte + MTU uint32 + Type IfType + TunnelType TunnelType + MediaType NdisMedium + PhysicalMediumType NdisPhysicalMedium + AccessType NetIfAccessType + DirectionType NetIfDirectionType + InterfaceAndOperStatusFlags InterfaceAndOperStatusFlags + OperStatus IfOperStatus + AdminStatus NetIfAdminStatus + MediaConnectState NetIfMediaConnectState + NetworkGUID windows.GUID + ConnectionType NetIfConnectionType + _ [4]byte + TransmitLinkSpeed uint64 + ReceiveLinkSpeed uint64 + InOctets uint64 + InUcastPkts uint64 + InNUcastPkts uint64 + InDiscards uint64 + InErrors uint64 + InUnknownProtos uint64 + InUcastOctets uint64 + InMulticastOctets uint64 + InBroadcastOctets uint64 + OutOctets uint64 + OutUcastPkts uint64 + OutNUcastPkts uint64 + OutDiscards uint64 + OutErrors uint64 + OutUcastOctets uint64 + OutMulticastOctets uint64 + OutBroadcastOctets uint64 + OutQLen uint64 +} + +// mibIfTable2 structure contains a table of logical and physical interface entries. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_table2 +type mibIfTable2 struct { + numEntries uint32 + _ [4]byte + table [anySize]MibIfRow2 +} + +// MibUnicastIPAddressRow structure stores information about a unicast IP address. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_row +type MibUnicastIPAddressRow struct { + Address RawSockaddrInet + _ [4]byte + InterfaceLUID LUID + InterfaceIndex uint32 + PrefixOrigin PrefixOrigin + SuffixOrigin SuffixOrigin + ValidLifetime uint32 + PreferredLifetime uint32 + OnLinkPrefixLength uint8 + SkipAsSource bool + DadState DadState + ScopeID uint32 + CreationTimeStamp int64 +} + +// mibUnicastIPAddressTable structure contains a table of unicast IP address entries. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_table +type mibUnicastIPAddressTable struct { + numEntries uint32 + _ [4]byte + table [anySize]MibUnicastIPAddressRow +} + +// MibAnycastIPAddressRow structure stores information about an anycast IP address. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_anycastipaddress_row +type MibAnycastIPAddressRow struct { + Address RawSockaddrInet + _ [4]byte + InterfaceLUID LUID + InterfaceIndex uint32 + ScopeID uint32 +} + +// mibAnycastIPAddressTable structure contains a table of anycast IP address entries. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-mib_anycastipaddress_table +type mibAnycastIPAddressTable struct { + numEntries uint32 + _ [4]byte + table [anySize]MibAnycastIPAddressRow +} + +// mibIPforwardTable2 structure contains a table of IP route entries. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipforward_table2 +type mibIPforwardTable2 struct { + numEntries uint32 + _ [4]byte + table [anySize]MibIPforwardRow2 +} diff --git a/internal/winipcfg/types_64.go b/internal/winipcfg/types_64.go new file mode 100644 index 0000000..3a1fe07 --- /dev/null +++ b/internal/winipcfg/types_64.go @@ -0,0 +1,220 @@ +//go:build amd64 || arm64 + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +import ( + "golang.org/x/sys/windows" +) + +// IPAdapterWINSServerAddress structure stores a single Windows Internet Name Service (WINS) server address in a linked list of WINS server addresses for a particular adapter. +// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_wins_server_address_lh +type IPAdapterWINSServerAddress struct { + Length uint32 + _ uint32 + Next *IPAdapterWINSServerAddress + Address windows.SocketAddress +} + +// IPAdapterGatewayAddress structure stores a single gateway address in a linked list of gateway addresses for a particular adapter. +// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_gateway_address_lh +type IPAdapterGatewayAddress struct { + Length uint32 + _ uint32 + Next *IPAdapterGatewayAddress + Address windows.SocketAddress +} + +// IPAdapterAddresses structure is the header node for a linked list of addresses for a particular adapter. This structure can simultaneously be used as part of a linked list of IP_ADAPTER_ADDRESSES structures. +// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_addresses_lh +// This is a modified and extended version of windows.IpAdapterAddresses. +type IPAdapterAddresses struct { + Length uint32 + IfIndex uint32 + Next *IPAdapterAddresses + adapterName *byte + FirstUnicastAddress *windows.IpAdapterUnicastAddress + FirstAnycastAddress *windows.IpAdapterAnycastAddress + FirstMulticastAddress *windows.IpAdapterMulticastAddress + FirstDNSServerAddress *windows.IpAdapterDnsServerAdapter + dnsSuffix *uint16 + description *uint16 + friendlyName *uint16 + physicalAddress [windows.MAX_ADAPTER_ADDRESS_LENGTH]byte + physicalAddressLength uint32 + Flags IPAAFlags + MTU uint32 + IfType IfType + OperStatus IfOperStatus + IPv6IfIndex uint32 + ZoneIndices [16]uint32 + FirstPrefix *windows.IpAdapterPrefix + TransmitLinkSpeed uint64 + ReceiveLinkSpeed uint64 + FirstWINSServerAddress *IPAdapterWINSServerAddress + FirstGatewayAddress *IPAdapterGatewayAddress + Ipv4Metric uint32 + Ipv6Metric uint32 + LUID LUID + DHCPv4Server windows.SocketAddress + CompartmentID uint32 + NetworkGUID windows.GUID + ConnectionType NetIfConnectionType + TunnelType TunnelType + DHCPv6Server windows.SocketAddress + dhcpv6ClientDUID [maxDHCPv6DUIDLength]byte + dhcpv6ClientDUIDLength uint32 + DHCPv6IAID uint32 + FirstDNSSuffix *IPAdapterDNSSuffix +} + +// MibIPInterfaceRow structure stores interface management information for a particular IP address family on a network interface. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_row +type MibIPInterfaceRow struct { + Family AddressFamily + InterfaceLUID LUID + InterfaceIndex uint32 + MaxReassemblySize uint32 + InterfaceIdentifier uint64 + MinRouterAdvertisementInterval uint32 + MaxRouterAdvertisementInterval uint32 + AdvertisingEnabled bool + ForwardingEnabled bool + WeakHostSend bool + WeakHostReceive bool + UseAutomaticMetric bool + UseNeighborUnreachabilityDetection bool + ManagedAddressConfigurationSupported bool + OtherStatefulConfigurationSupported bool + AdvertiseDefaultRoute bool + RouterDiscoveryBehavior RouterDiscoveryBehavior + DadTransmits uint32 + BaseReachableTime uint32 + RetransmitTime uint32 + PathMTUDiscoveryTimeout uint32 + LinkLocalAddressBehavior LinkLocalAddressBehavior + LinkLocalAddressTimeout uint32 + ZoneIndices [ScopeLevelCount]uint32 + SitePrefixLength uint32 + Metric uint32 + NLMTU uint32 + Connected bool + SupportsWakeUpPatterns bool + SupportsNeighborDiscovery bool + SupportsRouterDiscovery bool + ReachableTime uint32 + TransmitOffload OffloadRod + ReceiveOffload OffloadRod + DisableDefaultRoutes bool +} + +// mibIPInterfaceTable structure contains a table of IP interface entries. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_table +type mibIPInterfaceTable struct { + numEntries uint32 + table [anySize]MibIPInterfaceRow +} + +// MibIfRow2 structure stores information about a particular interface. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_row2 +type MibIfRow2 struct { + InterfaceLUID LUID + InterfaceIndex uint32 + InterfaceGUID windows.GUID + alias [ifMaxStringSize + 1]uint16 + description [ifMaxStringSize + 1]uint16 + physicalAddressLength uint32 + physicalAddress [ifMaxPhysAddressLength]byte + permanentPhysicalAddress [ifMaxPhysAddressLength]byte + MTU uint32 + Type IfType + TunnelType TunnelType + MediaType NdisMedium + PhysicalMediumType NdisPhysicalMedium + AccessType NetIfAccessType + DirectionType NetIfDirectionType + InterfaceAndOperStatusFlags InterfaceAndOperStatusFlags + OperStatus IfOperStatus + AdminStatus NetIfAdminStatus + MediaConnectState NetIfMediaConnectState + NetworkGUID windows.GUID + ConnectionType NetIfConnectionType + TransmitLinkSpeed uint64 + ReceiveLinkSpeed uint64 + InOctets uint64 + InUcastPkts uint64 + InNUcastPkts uint64 + InDiscards uint64 + InErrors uint64 + InUnknownProtos uint64 + InUcastOctets uint64 + InMulticastOctets uint64 + InBroadcastOctets uint64 + OutOctets uint64 + OutUcastPkts uint64 + OutNUcastPkts uint64 + OutDiscards uint64 + OutErrors uint64 + OutUcastOctets uint64 + OutMulticastOctets uint64 + OutBroadcastOctets uint64 + OutQLen uint64 +} + +// mibIfTable2 structure contains a table of logical and physical interface entries. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_table2 +type mibIfTable2 struct { + numEntries uint32 + table [anySize]MibIfRow2 +} + +// MibUnicastIPAddressRow structure stores information about a unicast IP address. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_row +type MibUnicastIPAddressRow struct { + Address RawSockaddrInet + InterfaceLUID LUID + InterfaceIndex uint32 + PrefixOrigin PrefixOrigin + SuffixOrigin SuffixOrigin + ValidLifetime uint32 + PreferredLifetime uint32 + OnLinkPrefixLength uint8 + SkipAsSource bool + DadState DadState + ScopeID uint32 + CreationTimeStamp int64 +} + +// mibUnicastIPAddressTable structure contains a table of unicast IP address entries. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_table +type mibUnicastIPAddressTable struct { + numEntries uint32 + table [anySize]MibUnicastIPAddressRow +} + +// MibAnycastIPAddressRow structure stores information about an anycast IP address. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_anycastipaddress_row +type MibAnycastIPAddressRow struct { + Address RawSockaddrInet + InterfaceLUID LUID + InterfaceIndex uint32 + ScopeID uint32 +} + +// mibAnycastIPAddressTable structure contains a table of anycast IP address entries. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-mib_anycastipaddress_table +type mibAnycastIPAddressTable struct { + numEntries uint32 + table [anySize]MibAnycastIPAddressRow +} + +// mibIPforwardTable2 structure contains a table of IP route entries. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipforward_table2 +type mibIPforwardTable2 struct { + numEntries uint32 + table [anySize]MibIPforwardRow2 +} diff --git a/internal/winipcfg/types_test.go b/internal/winipcfg/types_test.go new file mode 100644 index 0000000..b72d73f --- /dev/null +++ b/internal/winipcfg/types_test.go @@ -0,0 +1,1056 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +import ( + "testing" + "unsafe" +) + +const ( + mibIPInterfaceRowSize = 168 + mibIPInterfaceRowInterfaceLUIDOffset = 8 + mibIPInterfaceRowInterfaceIndexOffset = 16 + mibIPInterfaceRowMaxReassemblySizeOffset = 20 + mibIPInterfaceRowInterfaceIdentifierOffset = 24 + mibIPInterfaceRowMinRouterAdvertisementIntervalOffset = 32 + mibIPInterfaceRowMaxRouterAdvertisementIntervalOffset = 36 + mibIPInterfaceRowAdvertisingEnabledOffset = 40 + mibIPInterfaceRowForwardingEnabledOffset = 41 + mibIPInterfaceRowWeakHostSendOffset = 42 + mibIPInterfaceRowWeakHostReceiveOffset = 43 + mibIPInterfaceRowUseAutomaticMetricOffset = 44 + mibIPInterfaceRowUseNeighborUnreachabilityDetectionOffset = 45 + mibIPInterfaceRowManagedAddressConfigurationSupportedOffset = 46 + mibIPInterfaceRowOtherStatefulConfigurationSupportedOffset = 47 + mibIPInterfaceRowAdvertiseDefaultRouteOffset = 48 + mibIPInterfaceRowRouterDiscoveryBehaviorOffset = 52 + mibIPInterfaceRowDadTransmitsOffset = 56 + mibIPInterfaceRowBaseReachableTimeOffset = 60 + mibIPInterfaceRowRetransmitTimeOffset = 64 + mibIPInterfaceRowPathMTUDiscoveryTimeoutOffset = 68 + mibIPInterfaceRowLinkLocalAddressBehaviorOffset = 72 + mibIPInterfaceRowLinkLocalAddressTimeoutOffset = 76 + mibIPInterfaceRowZoneIndicesOffset = 80 + mibIPInterfaceRowSitePrefixLengthOffset = 144 + mibIPInterfaceRowMetricOffset = 148 + mibIPInterfaceRowNLMTUOffset = 152 + mibIPInterfaceRowConnectedOffset = 156 + mibIPInterfaceRowSupportsWakeUpPatternsOffset = 157 + mibIPInterfaceRowSupportsNeighborDiscoveryOffset = 158 + mibIPInterfaceRowSupportsRouterDiscoveryOffset = 159 + mibIPInterfaceRowReachableTimeOffset = 160 + mibIPInterfaceRowTransmitOffloadOffset = 164 + mibIPInterfaceRowReceiveOffloadOffset = 165 + mibIPInterfaceRowDisableDefaultRoutesOffset = 166 + + mibIPInterfaceTableSize = 176 + mibIPInterfaceTableTableOffset = 8 + + mibIfRow2Size = 1352 + mibIfRow2InterfaceIndexOffset = 8 + mibIfRow2InterfaceGUIDOffset = 12 + mibIfRow2AliasOffset = 28 + mibIfRow2DescriptionOffset = 542 + mibIfRow2PhysicalAddressLengthOffset = 1056 + mibIfRow2PhysicalAddressOffset = 1060 + mibIfRow2PermanentPhysicalAddressOffset = 1092 + mibIfRow2MTUOffset = 1124 + mibIfRow2TypeOffset = 1128 + mibIfRow2TunnelTypeOffset = 1132 + mibIfRow2MediaTypeOffset = 1136 + mibIfRow2PhysicalMediumTypeOffset = 1140 + mibIfRow2AccessTypeOffset = 1144 + mibIfRow2DirectionTypeOffset = 1148 + mibIfRow2InterfaceAndOperStatusFlagsOffset = 1152 + mibIfRow2OperStatusOffset = 1156 + mibIfRow2AdminStatusOffset = 1160 + mibIfRow2MediaConnectStateOffset = 1164 + mibIfRow2NetworkGUIDOffset = 1168 + mibIfRow2ConnectionTypeOffset = 1184 + mibIfRow2TransmitLinkSpeedOffset = 1192 + mibIfRow2ReceiveLinkSpeedOffset = 1200 + mibIfRow2InOctetsOffset = 1208 + mibIfRow2InUcastPktsOffset = 1216 + mibIfRow2InNUcastPktsOffset = 1224 + mibIfRow2InDiscardsOffset = 1232 + mibIfRow2InErrorsOffset = 1240 + mibIfRow2InUnknownProtosOffset = 1248 + mibIfRow2InUcastOctetsOffset = 1256 + mibIfRow2InMulticastOctetsOffset = 1264 + mibIfRow2InBroadcastOctetsOffset = 1272 + mibIfRow2OutOctetsOffset = 1280 + mibIfRow2OutUcastPktsOffset = 1288 + mibIfRow2OutNUcastPktsOffset = 1296 + mibIfRow2OutDiscardsOffset = 1304 + mibIfRow2OutErrorsOffset = 1312 + mibIfRow2OutUcastOctetsOffset = 1320 + mibIfRow2OutMulticastOctetsOffset = 1328 + mibIfRow2OutBroadcastOctetsOffset = 1336 + mibIfRow2OutQLenOffset = 1344 + + mibIfTable2Size = 1360 + mibIfTable2TableOffset = 8 + + rawSockaddrInetSize = 28 + rawSockaddrInetDataOffset = 2 + + mibUnicastIPAddressRowSize = 80 + mibUnicastIPAddressRowInterfaceLUIDOffset = 32 + mibUnicastIPAddressRowInterfaceIndexOffset = 40 + mibUnicastIPAddressRowPrefixOriginOffset = 44 + mibUnicastIPAddressRowSuffixOriginOffset = 48 + mibUnicastIPAddressRowValidLifetimeOffset = 52 + mibUnicastIPAddressRowPreferredLifetimeOffset = 56 + mibUnicastIPAddressRowOnLinkPrefixLengthOffset = 60 + mibUnicastIPAddressRowSkipAsSourceOffset = 61 + mibUnicastIPAddressRowDadStateOffset = 64 + mibUnicastIPAddressRowScopeIDOffset = 68 + mibUnicastIPAddressRowCreationTimeStampOffset = 72 + + mibUnicastIPAddressTableSize = 88 + mibUnicastIPAddressTableTableOffset = 8 + + mibAnycastIPAddressRowSize = 48 + mibAnycastIPAddressRowInterfaceLUIDOffset = 32 + mibAnycastIPAddressRowInterfaceIndexOffset = 40 + mibAnycastIPAddressRowScopeIDOffset = 44 + + mibAnycastIPAddressTableSize = 56 + mibAnycastIPAddressTableTableOffset = 8 + + ipAddressPrefixSize = 32 + ipAddressPrefixPrefixLengthOffset = 28 + + mibIPforwardRow2Size = 104 + mibIPforwardRow2InterfaceIndexOffset = 8 + mibIPforwardRow2DestinationPrefixOffset = 12 + mibIPforwardRow2NextHopOffset = 44 + mibIPforwardRow2SitePrefixLengthOffset = 72 + mibIPforwardRow2ValidLifetimeOffset = 76 + mibIPforwardRow2PreferredLifetimeOffset = 80 + mibIPforwardRow2MetricOffset = 84 + mibIPforwardRow2ProtocolOffset = 88 + mibIPforwardRow2LoopbackOffset = 92 + mibIPforwardRow2AutoconfigureAddressOffset = 93 + mibIPforwardRow2PublishOffset = 94 + mibIPforwardRow2ImmortalOffset = 95 + mibIPforwardRow2AgeOffset = 96 + mibIPforwardRow2OriginOffset = 100 + + mibIPforwardTable2Size = 112 + mibIPforwardTable2TableOffset = 8 +) + +func TestIPAdapterWINSServerAddress(t *testing.T) { + s := IPAdapterWINSServerAddress{} + sp := uintptr(unsafe.Pointer(&s)) + const actualIPAdapterWINSServerAddressSize = unsafe.Sizeof(s) + + if actualIPAdapterWINSServerAddressSize != ipAdapterWINSServerAddressSize { + t.Errorf("Size of IPAdapterWINSServerAddress is %d, although %d is expected.", actualIPAdapterWINSServerAddressSize, ipAdapterWINSServerAddressSize) + } + + offset := uintptr(unsafe.Pointer(&s.Next)) - sp + if offset != ipAdapterWINSServerAddressNextOffset { + t.Errorf("IPAdapterWINSServerAddress.Next offset is %d although %d is expected", offset, ipAdapterWINSServerAddressNextOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Address)) - sp + if offset != ipAdapterWINSServerAddressAddressOffset { + t.Errorf("IPAdapterWINSServerAddress.Address offset is %d although %d is expected", offset, ipAdapterWINSServerAddressAddressOffset) + } +} + +func TestIPAdapterGatewayAddress(t *testing.T) { + s := IPAdapterGatewayAddress{} + sp := uintptr(unsafe.Pointer(&s)) + const actualIPAdapterGatewayAddressSize = unsafe.Sizeof(s) + + if actualIPAdapterGatewayAddressSize != ipAdapterGatewayAddressSize { + t.Errorf("Size of IPAdapterGatewayAddress is %d, although %d is expected.", actualIPAdapterGatewayAddressSize, ipAdapterGatewayAddressSize) + } + + offset := uintptr(unsafe.Pointer(&s.Next)) - sp + if offset != ipAdapterGatewayAddressNextOffset { + t.Errorf("IPAdapterGatewayAddress.Next offset is %d although %d is expected", offset, ipAdapterGatewayAddressNextOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Address)) - sp + if offset != ipAdapterGatewayAddressAddressOffset { + t.Errorf("IPAdapterGatewayAddress.Address offset is %d although %d is expected", offset, ipAdapterGatewayAddressAddressOffset) + } +} + +func TestIPAdapterDNSSuffix(t *testing.T) { + s := IPAdapterDNSSuffix{} + sp := uintptr(unsafe.Pointer(&s)) + const actualIPAdapterDNSSuffixSize = unsafe.Sizeof(s) + + if actualIPAdapterDNSSuffixSize != ipAdapterDNSSuffixSize { + t.Errorf("Size of IPAdapterDNSSuffix is %d, although %d is expected.", actualIPAdapterDNSSuffixSize, ipAdapterDNSSuffixSize) + } + + offset := uintptr(unsafe.Pointer(&s.str)) - sp + if offset != ipAdapterDNSSuffixStringOffset { + t.Errorf("IPAdapterDNSSuffix.str offset is %d although %d is expected", offset, ipAdapterDNSSuffixStringOffset) + } +} + +func TestInAdapterAddresses(t *testing.T) { + s := IPAdapterAddresses{} + sp := uintptr(unsafe.Pointer(&s)) + const actualIn6AddrSize = unsafe.Sizeof(s) + + if actualIn6AddrSize != ipAdapterAddressesSize { + t.Errorf("Size of IPAdapterAddresses is %d, although %d is expected.", actualIn6AddrSize, ipAdapterAddressesSize) + } + + offset := uintptr(unsafe.Pointer(&s.IfIndex)) - sp + if offset != ipAdapterAddressesIfIndexOffset { + t.Errorf("IPAdapterAddresses.IfIndex offset is %d although %d is expected", offset, ipAdapterAddressesIfIndexOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Next)) - sp + if offset != ipAdapterAddressesNextOffset { + t.Errorf("IPAdapterAddresses.Next offset is %d although %d is expected", offset, ipAdapterAddressesNextOffset) + } + + offset = uintptr(unsafe.Pointer(&s.adapterName)) - sp + if offset != ipAdapterAddressesAdapterNameOffset { + t.Errorf("IPAdapterAddresses.adapterName offset is %d although %d is expected", offset, ipAdapterAddressesAdapterNameOffset) + } + + offset = uintptr(unsafe.Pointer(&s.FirstUnicastAddress)) - sp + if offset != ipAdapterAddressesFirstUnicastAddressOffset { + t.Errorf("IPAdapterAddresses.FirstUnicastAddress offset is %d although %d is expected", offset, ipAdapterAddressesFirstUnicastAddressOffset) + } + + offset = uintptr(unsafe.Pointer(&s.FirstAnycastAddress)) - sp + if offset != ipAdapterAddressesFirstAnycastAddressOffset { + t.Errorf("IPAdapterAddresses.FirstAnycastAddress offset is %d although %d is expected", offset, ipAdapterAddressesFirstAnycastAddressOffset) + } + + offset = uintptr(unsafe.Pointer(&s.FirstMulticastAddress)) - sp + if offset != ipAdapterAddressesFirstMulticastAddressOffset { + t.Errorf("IPAdapterAddresses.FirstMulticastAddress offset is %d although %d is expected", offset, ipAdapterAddressesFirstMulticastAddressOffset) + } + + offset = uintptr(unsafe.Pointer(&s.FirstDNSServerAddress)) - sp + if offset != ipAdapterAddressesFirstDNSServerAddressOffset { + t.Errorf("IPAdapterAddresses.FirstDNSServerAddress offset is %d although %d is expected", offset, ipAdapterAddressesFirstDNSServerAddressOffset) + } + + offset = uintptr(unsafe.Pointer(&s.dnsSuffix)) - sp + if offset != ipAdapterAddressesDNSSuffixOffset { + t.Errorf("IPAdapterAddresses.DNSSuffix offset is %d although %d is expected", offset, ipAdapterAddressesDNSSuffixOffset) + } + + offset = uintptr(unsafe.Pointer(&s.description)) - sp + if offset != ipAdapterAddressesDescriptionOffset { + t.Errorf("IPAdapterAddresses.Description offset is %d although %d is expected", offset, ipAdapterAddressesDescriptionOffset) + } + + offset = uintptr(unsafe.Pointer(&s.friendlyName)) - sp + if offset != ipAdapterAddressesFriendlyNameOffset { + t.Errorf("IPAdapterAddresses.FriendlyName offset is %d although %d is expected", offset, ipAdapterAddressesFriendlyNameOffset) + } + + offset = uintptr(unsafe.Pointer(&s.physicalAddress)) - sp + if offset != ipAdapterAddressesPhysicalAddressOffset { + t.Errorf("IPAdapterAddresses.PhysicalAddress offset is %d although %d is expected", offset, ipAdapterAddressesPhysicalAddressOffset) + } + + offset = uintptr(unsafe.Pointer(&s.physicalAddressLength)) - sp + if offset != ipAdapterAddressesPhysicalAddressLengthOffset { + t.Errorf("IPAdapterAddresses.PhysicalAddressLength offset is %d although %d is expected", offset, ipAdapterAddressesPhysicalAddressLengthOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Flags)) - sp + if offset != ipAdapterAddressesFlagsOffset { + t.Errorf("IPAdapterAddresses.Flags offset is %d although %d is expected", offset, ipAdapterAddressesFlagsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.MTU)) - sp + if offset != ipAdapterAddressesMTUOffset { + t.Errorf("IPAdapterAddresses.MTU offset is %d although %d is expected", offset, ipAdapterAddressesMTUOffset) + } + + offset = uintptr(unsafe.Pointer(&s.IfType)) - sp + if offset != ipAdapterAddressesIfTypeOffset { + t.Errorf("IPAdapterAddresses.IfType offset is %d although %d is expected", offset, ipAdapterAddressesIfTypeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OperStatus)) - sp + if offset != ipAdapterAddressesOperStatusOffset { + t.Errorf("IPAdapterAddresses.OperStatus offset is %d although %d is expected", offset, ipAdapterAddressesOperStatusOffset) + } + + offset = uintptr(unsafe.Pointer(&s.IPv6IfIndex)) - sp + if offset != ipAdapterAddressesIPv6IfIndexOffset { + t.Errorf("IPAdapterAddresses.IPv6IfIndex offset is %d although %d is expected", offset, ipAdapterAddressesIPv6IfIndexOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ZoneIndices)) - sp + if offset != ipAdapterAddressesZoneIndicesOffset { + t.Errorf("IPAdapterAddresses.ZoneIndices offset is %d although %d is expected", offset, ipAdapterAddressesZoneIndicesOffset) + } + + offset = uintptr(unsafe.Pointer(&s.FirstPrefix)) - sp + if offset != ipAdapterAddressesFirstPrefixOffset { + t.Errorf("IPAdapterAddresses.FirstPrefix offset is %d although %d is expected", offset, ipAdapterAddressesFirstPrefixOffset) + } + + offset = uintptr(unsafe.Pointer(&s.TransmitLinkSpeed)) - sp + if offset != ipAdapterAddressesTransmitLinkSpeedOffset { + t.Errorf("IPAdapterAddresses.TransmitLinkSpeed offset is %d although %d is expected", offset, ipAdapterAddressesTransmitLinkSpeedOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ReceiveLinkSpeed)) - sp + if offset != ipAdapterAddressesReceiveLinkSpeedOffset { + t.Errorf("IPAdapterAddresses.ReceiveLinkSpeed offset is %d although %d is expected", offset, ipAdapterAddressesReceiveLinkSpeedOffset) + } + + offset = uintptr(unsafe.Pointer(&s.FirstWINSServerAddress)) - sp + if offset != ipAdapterAddressesFirstWINSServerAddressOffset { + t.Errorf("IPAdapterAddresses.FirstWINSServerAddress offset is %d although %d is expected", offset, ipAdapterAddressesFirstWINSServerAddressOffset) + } + + offset = uintptr(unsafe.Pointer(&s.FirstGatewayAddress)) - sp + if offset != ipAdapterAddressesFirstGatewayAddressOffset { + t.Errorf("IPAdapterAddresses.FirstGatewayAddress offset is %d although %d is expected", offset, ipAdapterAddressesFirstGatewayAddressOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Ipv4Metric)) - sp + if offset != ipAdapterAddressesIPv4MetricOffset { + t.Errorf("IPAdapterAddresses.IPv4Metric offset is %d although %d is expected", offset, ipAdapterAddressesIPv4MetricOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Ipv6Metric)) - sp + if offset != ipAdapterAddressesIPv6MetricOffset { + t.Errorf("IPAdapterAddresses.IPv6Metric offset is %d although %d is expected", offset, ipAdapterAddressesIPv6MetricOffset) + } + + offset = uintptr(unsafe.Pointer(&s.LUID)) - sp + if offset != ipAdapterAddressesLUIDOffset { + t.Errorf("IPAdapterAddresses.LUID offset is %d although %d is expected", offset, ipAdapterAddressesLUIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.DHCPv4Server)) - sp + if offset != ipAdapterAddressesDHCPv4ServerOffset { + t.Errorf("IPAdapterAddresses.DHCPv4Server offset is %d although %d is expected", offset, ipAdapterAddressesDHCPv4ServerOffset) + } + + offset = uintptr(unsafe.Pointer(&s.CompartmentID)) - sp + if offset != ipAdapterAddressesCompartmentIDOffset { + t.Errorf("IPAdapterAddresses.CompartmentID offset is %d although %d is expected", offset, ipAdapterAddressesCompartmentIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.NetworkGUID)) - sp + if offset != ipAdapterAddressesNetworkGUIDOffset { + t.Errorf("IPAdapterAddresses.NetworkGUID offset is %d although %d is expected", offset, ipAdapterAddressesNetworkGUIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ConnectionType)) - sp + if offset != ipAdapterAddressesConnectionTypeOffset { + t.Errorf("IPAdapterAddresses.ConnectionType offset is %d although %d is expected", offset, ipAdapterAddressesConnectionTypeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.TunnelType)) - sp + if offset != ipAdapterAddressesTunnelTypeOffset { + t.Errorf("IPAdapterAddresses.TunnelType offset is %d although %d is expected", offset, ipAdapterAddressesTunnelTypeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.DHCPv6Server)) - sp + if offset != ipAdapterAddressesDHCPv6ServerOffset { + t.Errorf("IPAdapterAddresses.DHCPv6Server offset is %d although %d is expected", offset, ipAdapterAddressesDHCPv6ServerOffset) + } + + offset = uintptr(unsafe.Pointer(&s.dhcpv6ClientDUID)) - sp + if offset != ipAdapterAddressesDHCPv6ClientDUIDOffset { + t.Errorf("IPAdapterAddresses.DHCPv6ClientDUID offset is %d although %d is expected", offset, ipAdapterAddressesDHCPv6ClientDUIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.dhcpv6ClientDUIDLength)) - sp + if offset != ipAdapterAddressesDHCPv6ClientDUIDLengthOffset { + t.Errorf("IPAdapterAddresses.DHCPv6ClientDUIDLength offset is %d although %d is expected", offset, ipAdapterAddressesDHCPv6ClientDUIDLengthOffset) + } + + offset = uintptr(unsafe.Pointer(&s.DHCPv6IAID)) - sp + if offset != ipAdapterAddressesDHCPv6IAIDOffset { + t.Errorf("IPAdapterAddresses.DHCPv6IAID offset is %d although %d is expected", offset, ipAdapterAddressesDHCPv6IAIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.FirstDNSSuffix)) - sp + if offset != ipAdapterAddressesFirstDNSSuffixOffset { + t.Errorf("IPAdapterAddresses.FirstDNSSuffix offset is %d although %d is expected", offset, ipAdapterAddressesFirstDNSSuffixOffset) + } +} + +func TestMibIPInterfaceRow(t *testing.T) { + s := MibIPInterfaceRow{} + sp := uintptr(unsafe.Pointer(&s)) + const actualTestMibIPInterfaceRowSize = unsafe.Sizeof(s) + + if actualTestMibIPInterfaceRowSize != mibIPInterfaceRowSize { + t.Errorf("Size of MibIPInterfaceRow is %d, although %d is expected.", actualTestMibIPInterfaceRowSize, mibIPInterfaceRowSize) + } + + offset := uintptr(unsafe.Pointer(&s.InterfaceLUID)) - sp + if offset != mibIPInterfaceRowInterfaceLUIDOffset { + t.Errorf("MibIPInterfaceRow.InterfaceLUID offset is %d although %d is expected", offset, mibIPInterfaceRowInterfaceLUIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InterfaceIndex)) - sp + if offset != mibIPInterfaceRowInterfaceIndexOffset { + t.Errorf("MibIPInterfaceRow.InterfaceIndex offset is %d although %d is expected", offset, mibIPInterfaceRowInterfaceIndexOffset) + } + + offset = uintptr(unsafe.Pointer(&s.MaxReassemblySize)) - sp + if offset != mibIPInterfaceRowMaxReassemblySizeOffset { + t.Errorf("mibIPInterfaceRow.MaxReassemblySize offset is %d although %d is expected", offset, mibIPInterfaceRowMaxReassemblySizeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InterfaceIdentifier)) - sp + if offset != mibIPInterfaceRowInterfaceIdentifierOffset { + t.Errorf("MibIPInterfaceRow.InterfaceIdentifier offset is %d although %d is expected", offset, mibIPInterfaceRowInterfaceIdentifierOffset) + } + + offset = uintptr(unsafe.Pointer(&s.MinRouterAdvertisementInterval)) - sp + if offset != mibIPInterfaceRowMinRouterAdvertisementIntervalOffset { + t.Errorf("MibIPInterfaceRow.MinRouterAdvertisementInterval offset is %d although %d is expected", offset, mibIPInterfaceRowMinRouterAdvertisementIntervalOffset) + } + + offset = uintptr(unsafe.Pointer(&s.MaxRouterAdvertisementInterval)) - sp + if offset != mibIPInterfaceRowMaxRouterAdvertisementIntervalOffset { + t.Errorf("MibIPInterfaceRow.MaxRouterAdvertisementInterval offset is %d although %d is expected", offset, mibIPInterfaceRowMaxRouterAdvertisementIntervalOffset) + } + + offset = uintptr(unsafe.Pointer(&s.AdvertisingEnabled)) - sp + if offset != mibIPInterfaceRowAdvertisingEnabledOffset { + t.Errorf("MibIPInterfaceRow.AdvertisingEnabled offset is %d although %d is expected", offset, mibIPInterfaceRowAdvertisingEnabledOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ForwardingEnabled)) - sp + if offset != mibIPInterfaceRowForwardingEnabledOffset { + t.Errorf("MibIPInterfaceRow.ForwardingEnabled offset is %d although %d is expected", offset, mibIPInterfaceRowForwardingEnabledOffset) + } + + offset = uintptr(unsafe.Pointer(&s.WeakHostSend)) - sp + if offset != mibIPInterfaceRowWeakHostSendOffset { + t.Errorf("MibIPInterfaceRow.WeakHostSend offset is %d although %d is expected", offset, mibIPInterfaceRowWeakHostSendOffset) + } + + offset = uintptr(unsafe.Pointer(&s.WeakHostReceive)) - sp + if offset != mibIPInterfaceRowWeakHostReceiveOffset { + t.Errorf("MibIPInterfaceRow.WeakHostReceive offset is %d although %d is expected", offset, mibIPInterfaceRowWeakHostReceiveOffset) + } + + offset = uintptr(unsafe.Pointer(&s.UseAutomaticMetric)) - sp + if offset != mibIPInterfaceRowUseAutomaticMetricOffset { + t.Errorf("MibIPInterfaceRow.UseAutomaticMetric offset is %d although %d is expected", offset, mibIPInterfaceRowUseAutomaticMetricOffset) + } + + offset = uintptr(unsafe.Pointer(&s.UseNeighborUnreachabilityDetection)) - sp + if offset != mibIPInterfaceRowUseNeighborUnreachabilityDetectionOffset { + t.Errorf("MibIPInterfaceRow.UseNeighborUnreachabilityDetection offset is %d although %d is expected", offset, mibIPInterfaceRowUseNeighborUnreachabilityDetectionOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ManagedAddressConfigurationSupported)) - sp + if offset != mibIPInterfaceRowManagedAddressConfigurationSupportedOffset { + t.Errorf("MibIPInterfaceRow.ManagedAddressConfigurationSupported offset is %d although %d is expected", offset, mibIPInterfaceRowManagedAddressConfigurationSupportedOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OtherStatefulConfigurationSupported)) - sp + if offset != mibIPInterfaceRowOtherStatefulConfigurationSupportedOffset { + t.Errorf("MibIPInterfaceRow.OtherStatefulConfigurationSupported offset is %d although %d is expected", offset, mibIPInterfaceRowOtherStatefulConfigurationSupportedOffset) + } + + offset = uintptr(unsafe.Pointer(&s.AdvertiseDefaultRoute)) - sp + if offset != mibIPInterfaceRowAdvertiseDefaultRouteOffset { + t.Errorf("MibIPInterfaceRow.AdvertiseDefaultRoute offset is %d although %d is expected", offset, mibIPInterfaceRowAdvertiseDefaultRouteOffset) + } + + offset = uintptr(unsafe.Pointer(&s.RouterDiscoveryBehavior)) - sp + if offset != mibIPInterfaceRowRouterDiscoveryBehaviorOffset { + t.Errorf("MibIPInterfaceRow.RouterDiscoveryBehavior offset is %d although %d is expected", offset, mibIPInterfaceRowRouterDiscoveryBehaviorOffset) + } + + offset = uintptr(unsafe.Pointer(&s.DadTransmits)) - sp + if offset != mibIPInterfaceRowDadTransmitsOffset { + t.Errorf("MibIPInterfaceRow.DadTransmits offset is %d although %d is expected", offset, mibIPInterfaceRowDadTransmitsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.BaseReachableTime)) - sp + if offset != mibIPInterfaceRowBaseReachableTimeOffset { + t.Errorf("MibIPInterfaceRow.BaseReachableTime offset is %d although %d is expected", offset, mibIPInterfaceRowBaseReachableTimeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.RetransmitTime)) - sp + if offset != mibIPInterfaceRowRetransmitTimeOffset { + t.Errorf("MibIPInterfaceRow.RetransmitTime offset is %d although %d is expected", offset, mibIPInterfaceRowRetransmitTimeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.PathMTUDiscoveryTimeout)) - sp + if offset != mibIPInterfaceRowPathMTUDiscoveryTimeoutOffset { + t.Errorf("MibIPInterfaceRow.PathMTUDiscoveryTimeout offset is %d although %d is expected", offset, mibIPInterfaceRowPathMTUDiscoveryTimeoutOffset) + } + + offset = uintptr(unsafe.Pointer(&s.LinkLocalAddressBehavior)) - sp + if offset != mibIPInterfaceRowLinkLocalAddressBehaviorOffset { + t.Errorf("MibIPInterfaceRow.LinkLocalAddressBehavior offset is %d although %d is expected", offset, mibIPInterfaceRowLinkLocalAddressBehaviorOffset) + } + + offset = uintptr(unsafe.Pointer(&s.LinkLocalAddressTimeout)) - sp + if offset != mibIPInterfaceRowLinkLocalAddressTimeoutOffset { + t.Errorf("MibIPInterfaceRow.LinkLocalAddressTimeout offset is %d although %d is expected", offset, mibIPInterfaceRowLinkLocalAddressTimeoutOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ZoneIndices)) - sp + if offset != mibIPInterfaceRowZoneIndicesOffset { + t.Errorf("MibIPInterfaceRow.ZoneIndices offset is %d although %d is expected", offset, mibIPInterfaceRowZoneIndicesOffset) + } + + offset = uintptr(unsafe.Pointer(&s.SitePrefixLength)) - sp + if offset != mibIPInterfaceRowSitePrefixLengthOffset { + t.Errorf("MibIPInterfaceRow.SitePrefixLength offset is %d although %d is expected", offset, mibIPInterfaceRowSitePrefixLengthOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Metric)) - sp + if offset != mibIPInterfaceRowMetricOffset { + t.Errorf("MibIPInterfaceRow.Metric offset is %d although %d is expected", offset, mibIPInterfaceRowMetricOffset) + } + + offset = uintptr(unsafe.Pointer(&s.NLMTU)) - sp + if offset != mibIPInterfaceRowNLMTUOffset { + t.Errorf("MibIPInterfaceRow.NLMTU offset is %d although %d is expected", offset, mibIPInterfaceRowNLMTUOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Connected)) - sp + if offset != mibIPInterfaceRowConnectedOffset { + t.Errorf("MibIPInterfaceRow.Connected offset is %d although %d is expected", offset, mibIPInterfaceRowConnectedOffset) + } + + offset = uintptr(unsafe.Pointer(&s.SupportsWakeUpPatterns)) - sp + if offset != mibIPInterfaceRowSupportsWakeUpPatternsOffset { + t.Errorf("MibIPInterfaceRow.SupportsWakeUpPatterns offset is %d although %d is expected", offset, mibIPInterfaceRowSupportsWakeUpPatternsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.SupportsNeighborDiscovery)) - sp + if offset != mibIPInterfaceRowSupportsNeighborDiscoveryOffset { + t.Errorf("MibIPInterfaceRow.SupportsNeighborDiscovery offset is %d although %d is expected", offset, mibIPInterfaceRowSupportsNeighborDiscoveryOffset) + } + + offset = uintptr(unsafe.Pointer(&s.SupportsRouterDiscovery)) - sp + if offset != mibIPInterfaceRowSupportsRouterDiscoveryOffset { + t.Errorf("MibIPInterfaceRow.SupportsRouterDiscovery offset is %d although %d is expected", offset, mibIPInterfaceRowSupportsRouterDiscoveryOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ReachableTime)) - sp + if offset != mibIPInterfaceRowReachableTimeOffset { + t.Errorf("MibIPInterfaceRow.ReachableTime offset is %d although %d is expected", offset, mibIPInterfaceRowReachableTimeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.TransmitOffload)) - sp + if offset != mibIPInterfaceRowTransmitOffloadOffset { + t.Errorf("MibIPInterfaceRow.TransmitOffload offset is %d although %d is expected", offset, mibIPInterfaceRowTransmitOffloadOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ReceiveOffload)) - sp + if offset != mibIPInterfaceRowReceiveOffloadOffset { + t.Errorf("MibIPInterfaceRow.ReceiveOffload offset is %d although %d is expected", offset, mibIPInterfaceRowReceiveOffloadOffset) + } + + offset = uintptr(unsafe.Pointer(&s.DisableDefaultRoutes)) - sp + if offset != mibIPInterfaceRowDisableDefaultRoutesOffset { + t.Errorf("MibIPInterfaceRow.DisableDefaultRoutes offset is %d although %d is expected", offset, mibIPInterfaceRowDisableDefaultRoutesOffset) + } +} + +func TestMibIPInterfaceTable(t *testing.T) { + s := mibIPInterfaceTable{} + sp := uintptr(unsafe.Pointer(&s)) + const actualmibIPInterfaceTableSize = unsafe.Sizeof(s) + + if actualmibIPInterfaceTableSize != mibIPInterfaceTableSize { + t.Errorf("Size of mibIPInterfaceTable is %d, although %d is expected.", actualmibIPInterfaceTableSize, mibIPInterfaceTableSize) + } + + offset := uintptr(unsafe.Pointer(&s.table)) - sp + if offset != mibIPInterfaceTableTableOffset { + t.Errorf("mibIPInterfaceTable.table offset is %d although %d is expected", offset, mibIPInterfaceTableTableOffset) + } +} + +func TestMibIfRow2(t *testing.T) { + s := MibIfRow2{} + sp := uintptr(unsafe.Pointer(&s)) + const actualMibIfRow2Size = unsafe.Sizeof(s) + + if actualMibIfRow2Size != mibIfRow2Size { + t.Errorf("Size of MibIfRow2 is %d, although %d is expected.", actualMibIfRow2Size, mibIfRow2Size) + } + + offset := uintptr(unsafe.Pointer(&s.InterfaceIndex)) - sp + if offset != mibIfRow2InterfaceIndexOffset { + t.Errorf("MibIfRow2.InterfaceIndex offset is %d although %d is expected", offset, mibIfRow2InterfaceIndexOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InterfaceGUID)) - sp + if offset != mibIfRow2InterfaceGUIDOffset { + t.Errorf("MibIfRow2.InterfaceGUID offset is %d although %d is expected", offset, mibIfRow2InterfaceGUIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.alias)) - sp + if offset != mibIfRow2AliasOffset { + t.Errorf("MibIfRow2.alias offset is %d although %d is expected", offset, mibIfRow2AliasOffset) + } + + offset = uintptr(unsafe.Pointer(&s.description)) - sp + if offset != mibIfRow2DescriptionOffset { + t.Errorf("MibIfRow2.description offset is %d although %d is expected", offset, mibIfRow2DescriptionOffset) + } + + offset = uintptr(unsafe.Pointer(&s.physicalAddressLength)) - sp + if offset != mibIfRow2PhysicalAddressLengthOffset { + t.Errorf("MibIfRow2.physicalAddressLength offset is %d although %d is expected", offset, mibIfRow2PhysicalAddressLengthOffset) + } + + offset = uintptr(unsafe.Pointer(&s.physicalAddress)) - sp + if offset != mibIfRow2PhysicalAddressOffset { + t.Errorf("MibIfRow2.physicalAddress offset is %d although %d is expected", offset, mibIfRow2PhysicalAddressOffset) + } + + offset = uintptr(unsafe.Pointer(&s.permanentPhysicalAddress)) - sp + if offset != mibIfRow2PermanentPhysicalAddressOffset { + t.Errorf("MibIfRow2.permanentPhysicalAddress offset is %d although %d is expected", offset, mibIfRow2PermanentPhysicalAddressOffset) + } + + offset = uintptr(unsafe.Pointer(&s.MTU)) - sp + if offset != mibIfRow2MTUOffset { + t.Errorf("MibIfRow2.MTU offset is %d although %d is expected", offset, mibIfRow2MTUOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Type)) - sp + if offset != mibIfRow2TypeOffset { + t.Errorf("MibIfRow2.Type offset is %d although %d is expected", offset, mibIfRow2TypeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.TunnelType)) - sp + if offset != mibIfRow2TunnelTypeOffset { + t.Errorf("MibIfRow2.TunnelType offset is %d although %d is expected", offset, mibIfRow2TunnelTypeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.MediaType)) - sp + if offset != mibIfRow2MediaTypeOffset { + t.Errorf("MibIfRow2.MediaType offset is %d although %d is expected", offset, mibIfRow2MediaTypeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.PhysicalMediumType)) - sp + if offset != mibIfRow2PhysicalMediumTypeOffset { + t.Errorf("MibIfRow2.PhysicalMediumType offset is %d although %d is expected", offset, mibIfRow2PhysicalMediumTypeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.AccessType)) - sp + if offset != mibIfRow2AccessTypeOffset { + t.Errorf("MibIfRow2.AccessType offset is %d although %d is expected", offset, mibIfRow2AccessTypeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.DirectionType)) - sp + if offset != mibIfRow2DirectionTypeOffset { + t.Errorf("MibIfRow2.DirectionType offset is %d although %d is expected", offset, mibIfRow2DirectionTypeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InterfaceAndOperStatusFlags)) - sp + if offset != mibIfRow2InterfaceAndOperStatusFlagsOffset { + t.Errorf("MibIfRow2.InterfaceAndOperStatusFlags offset is %d although %d is expected", offset, mibIfRow2InterfaceAndOperStatusFlagsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OperStatus)) - sp + if offset != mibIfRow2OperStatusOffset { + t.Errorf("MibIfRow2.OperStatus offset is %d although %d is expected", offset, mibIfRow2OperStatusOffset) + } + + offset = uintptr(unsafe.Pointer(&s.AdminStatus)) - sp + if offset != mibIfRow2AdminStatusOffset { + t.Errorf("MibIfRow2.AdminStatus offset is %d although %d is expected", offset, mibIfRow2AdminStatusOffset) + } + + offset = uintptr(unsafe.Pointer(&s.MediaConnectState)) - sp + if offset != mibIfRow2MediaConnectStateOffset { + t.Errorf("MibIfRow2.MediaConnectState offset is %d although %d is expected", offset, mibIfRow2MediaConnectStateOffset) + } + + offset = uintptr(unsafe.Pointer(&s.NetworkGUID)) - sp + if offset != mibIfRow2NetworkGUIDOffset { + t.Errorf("MibIfRow2.NetworkGUID offset is %d although %d is expected", offset, mibIfRow2NetworkGUIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ConnectionType)) - sp + if offset != mibIfRow2ConnectionTypeOffset { + t.Errorf("MibIfRow2.ConnectionType offset is %d although %d is expected", offset, mibIfRow2ConnectionTypeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.TransmitLinkSpeed)) - sp + if offset != mibIfRow2TransmitLinkSpeedOffset { + t.Errorf("MibIfRow2.TransmitLinkSpeed offset is %d although %d is expected", offset, mibIfRow2TransmitLinkSpeedOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ReceiveLinkSpeed)) - sp + if offset != mibIfRow2ReceiveLinkSpeedOffset { + t.Errorf("MibIfRow2.ReceiveLinkSpeed offset is %d although %d is expected", offset, mibIfRow2ReceiveLinkSpeedOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InOctets)) - sp + if offset != mibIfRow2InOctetsOffset { + t.Errorf("MibIfRow2.InOctets offset is %d although %d is expected", offset, mibIfRow2InOctetsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InUcastPkts)) - sp + if offset != mibIfRow2InUcastPktsOffset { + t.Errorf("MibIfRow2.InUcastPkts offset is %d although %d is expected", offset, mibIfRow2InUcastPktsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InNUcastPkts)) - sp + if offset != mibIfRow2InNUcastPktsOffset { + t.Errorf("MibIfRow2.InNUcastPkts offset is %d although %d is expected", offset, mibIfRow2InNUcastPktsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InDiscards)) - sp + if offset != mibIfRow2InDiscardsOffset { + t.Errorf("MibIfRow2.InDiscards offset is %d although %d is expected", offset, mibIfRow2InDiscardsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InErrors)) - sp + if offset != mibIfRow2InErrorsOffset { + t.Errorf("MibIfRow2.InErrors offset is %d although %d is expected", offset, mibIfRow2InErrorsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InUnknownProtos)) - sp + if offset != mibIfRow2InUnknownProtosOffset { + t.Errorf("MibIfRow2.InUnknownProtos offset is %d although %d is expected", offset, mibIfRow2InUnknownProtosOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InUcastOctets)) - sp + if offset != mibIfRow2InUcastOctetsOffset { + t.Errorf("MibIfRow2.InUcastOctets offset is %d although %d is expected", offset, mibIfRow2InUcastOctetsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InMulticastOctets)) - sp + if offset != mibIfRow2InMulticastOctetsOffset { + t.Errorf("MibIfRow2.InMulticastOctets offset is %d although %d is expected", offset, mibIfRow2InMulticastOctetsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InBroadcastOctets)) - sp + if offset != mibIfRow2InBroadcastOctetsOffset { + t.Errorf("MibIfRow2.InBroadcastOctets offset is %d although %d is expected", offset, mibIfRow2InBroadcastOctetsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OutOctets)) - sp + if offset != mibIfRow2OutOctetsOffset { + t.Errorf("MibIfRow2.OutOctets offset is %d although %d is expected", offset, mibIfRow2OutOctetsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OutUcastPkts)) - sp + if offset != mibIfRow2OutUcastPktsOffset { + t.Errorf("MibIfRow2.OutUcastPkts offset is %d although %d is expected", offset, mibIfRow2OutUcastPktsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OutNUcastPkts)) - sp + if offset != mibIfRow2OutNUcastPktsOffset { + t.Errorf("MibIfRow2.OutNUcastPkts offset is %d although %d is expected", offset, mibIfRow2OutNUcastPktsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OutDiscards)) - sp + if offset != mibIfRow2OutDiscardsOffset { + t.Errorf("MibIfRow2.OutDiscards offset is %d although %d is expected", offset, mibIfRow2OutDiscardsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OutErrors)) - sp + if offset != mibIfRow2OutErrorsOffset { + t.Errorf("MibIfRow2.OutErrors offset is %d although %d is expected", offset, mibIfRow2OutErrorsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OutUcastOctets)) - sp + if offset != mibIfRow2OutUcastOctetsOffset { + t.Errorf("MibIfRow2.OutUcastOctets offset is %d although %d is expected", offset, mibIfRow2OutUcastOctetsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OutMulticastOctets)) - sp + if offset != mibIfRow2OutMulticastOctetsOffset { + t.Errorf("MibIfRow2.OutMulticastOctets offset is %d although %d is expected", offset, mibIfRow2OutMulticastOctetsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OutBroadcastOctets)) - sp + if offset != mibIfRow2OutBroadcastOctetsOffset { + t.Errorf("MibIfRow2.OutBroadcastOctets offset is %d although %d is expected", offset, mibIfRow2OutBroadcastOctetsOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OutQLen)) - sp + if offset != mibIfRow2OutQLenOffset { + t.Errorf("MibIfRow2.OutQLen offset is %d although %d is expected", offset, mibIfRow2OutQLenOffset) + } +} + +func TestMibIfTable2(t *testing.T) { + s := mibIfTable2{} + sp := uintptr(unsafe.Pointer(&s)) + const actualmibIfTable2Size = unsafe.Sizeof(s) + + if actualmibIfTable2Size != mibIfTable2Size { + t.Errorf("Size of mibIfTable2 is %d, although %d is expected.", actualmibIfTable2Size, mibIfTable2Size) + } + + offset := uintptr(unsafe.Pointer(&s.table)) - sp + if offset != mibIfTable2TableOffset { + t.Errorf("mibIfTable2.table offset is %d although %d is expected", offset, mibIfTable2TableOffset) + } +} + +func TestRawSockaddrInet(t *testing.T) { + s := RawSockaddrInet{} + sp := uintptr(unsafe.Pointer(&s)) + const actualRawSockaddrInetSize = unsafe.Sizeof(s) + + if actualRawSockaddrInetSize != rawSockaddrInetSize { + t.Errorf("Size of RawSockaddrInet is %d, although %d is expected.", actualRawSockaddrInetSize, rawSockaddrInetSize) + } + + offset := uintptr(unsafe.Pointer(&s.data)) - sp + if offset != rawSockaddrInetDataOffset { + t.Errorf("RawSockaddrInet.data offset is %d although %d is expected", offset, rawSockaddrInetDataOffset) + } +} + +func TestMibUnicastIPAddressRow(t *testing.T) { + s := MibUnicastIPAddressRow{} + sp := uintptr(unsafe.Pointer(&s)) + const actualMibUnicastIPAddressRowSize = unsafe.Sizeof(s) + + if actualMibUnicastIPAddressRowSize != mibUnicastIPAddressRowSize { + t.Errorf("Size of MibUnicastIPAddressRow is %d, although %d is expected.", actualMibUnicastIPAddressRowSize, mibUnicastIPAddressRowSize) + } + + offset := uintptr(unsafe.Pointer(&s.InterfaceLUID)) - sp + if offset != mibUnicastIPAddressRowInterfaceLUIDOffset { + t.Errorf("MibUnicastIPAddressRow.InterfaceLUID offset is %d although %d is expected", offset, mibUnicastIPAddressRowInterfaceLUIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InterfaceIndex)) - sp + if offset != mibUnicastIPAddressRowInterfaceIndexOffset { + t.Errorf("MibUnicastIPAddressRow.InterfaceIndex offset is %d although %d is expected", offset, mibUnicastIPAddressRowInterfaceIndexOffset) + } + + offset = uintptr(unsafe.Pointer(&s.PrefixOrigin)) - sp + if offset != mibUnicastIPAddressRowPrefixOriginOffset { + t.Errorf("MibUnicastIPAddressRow.PrefixOrigin offset is %d although %d is expected", offset, mibUnicastIPAddressRowPrefixOriginOffset) + } + + offset = uintptr(unsafe.Pointer(&s.SuffixOrigin)) - sp + if offset != mibUnicastIPAddressRowSuffixOriginOffset { + t.Errorf("MibUnicastIPAddressRow.SuffixOrigin offset is %d although %d is expected", offset, mibUnicastIPAddressRowSuffixOriginOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ValidLifetime)) - sp + if offset != mibUnicastIPAddressRowValidLifetimeOffset { + t.Errorf("MibUnicastIPAddressRow.ValidLifetime offset is %d although %d is expected", offset, mibUnicastIPAddressRowValidLifetimeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.PreferredLifetime)) - sp + if offset != mibUnicastIPAddressRowPreferredLifetimeOffset { + t.Errorf("MibUnicastIPAddressRow.PreferredLifetime offset is %d although %d is expected", offset, mibUnicastIPAddressRowPreferredLifetimeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.OnLinkPrefixLength)) - sp + if offset != mibUnicastIPAddressRowOnLinkPrefixLengthOffset { + t.Errorf("MibUnicastIPAddressRow.OnLinkPrefixLength offset is %d although %d is expected", offset, mibUnicastIPAddressRowOnLinkPrefixLengthOffset) + } + + offset = uintptr(unsafe.Pointer(&s.SkipAsSource)) - sp + if offset != mibUnicastIPAddressRowSkipAsSourceOffset { + t.Errorf("MibUnicastIPAddressRow.SkipAsSource offset is %d although %d is expected", offset, mibUnicastIPAddressRowSkipAsSourceOffset) + } + + offset = uintptr(unsafe.Pointer(&s.DadState)) - sp + if offset != mibUnicastIPAddressRowDadStateOffset { + t.Errorf("MibUnicastIPAddressRow.DadState offset is %d although %d is expected", offset, mibUnicastIPAddressRowDadStateOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ScopeID)) - sp + if offset != mibUnicastIPAddressRowScopeIDOffset { + t.Errorf("MibUnicastIPAddressRow.ScopeID offset is %d although %d is expected", offset, mibUnicastIPAddressRowScopeIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.CreationTimeStamp)) - sp + if offset != mibUnicastIPAddressRowCreationTimeStampOffset { + t.Errorf("MibUnicastIPAddressRow.CreationTimeStamp offset is %d although %d is expected", offset, mibUnicastIPAddressRowCreationTimeStampOffset) + } +} + +func TestMibUnicastIPAddressTable(t *testing.T) { + s := mibUnicastIPAddressTable{} + sp := uintptr(unsafe.Pointer(&s)) + const actualmibUnicastIPAddressTableSize = unsafe.Sizeof(s) + + if actualmibUnicastIPAddressTableSize != mibUnicastIPAddressTableSize { + t.Errorf("Size of mibUnicastIPAddressTable is %d, although %d is expected.", actualmibUnicastIPAddressTableSize, mibUnicastIPAddressTableSize) + } + + offset := uintptr(unsafe.Pointer(&s.table)) - sp + if offset != mibUnicastIPAddressTableTableOffset { + t.Errorf("mibUnicastIPAddressTable.table offset is %d although %d is expected", offset, mibUnicastIPAddressTableTableOffset) + } +} + +func TestMibAnycastIPAddressRow(t *testing.T) { + s := MibAnycastIPAddressRow{} + sp := uintptr(unsafe.Pointer(&s)) + const actualMibAnycastIPAddressRowSize = unsafe.Sizeof(s) + + if actualMibAnycastIPAddressRowSize != mibAnycastIPAddressRowSize { + t.Errorf("Size of MibAnycastIPAddressRow is %d, although %d is expected.", actualMibAnycastIPAddressRowSize, mibAnycastIPAddressRowSize) + } + + offset := uintptr(unsafe.Pointer(&s.InterfaceLUID)) - sp + if offset != mibAnycastIPAddressRowInterfaceLUIDOffset { + t.Errorf("MibAnycastIPAddressRow.InterfaceLUID offset is %d although %d is expected", offset, mibAnycastIPAddressRowInterfaceLUIDOffset) + } + + offset = uintptr(unsafe.Pointer(&s.InterfaceIndex)) - sp + if offset != mibAnycastIPAddressRowInterfaceIndexOffset { + t.Errorf("MibAnycastIPAddressRow.InterfaceIndex offset is %d although %d is expected", offset, mibAnycastIPAddressRowInterfaceIndexOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ScopeID)) - sp + if offset != mibAnycastIPAddressRowScopeIDOffset { + t.Errorf("MibAnycastIPAddressRow.ScopeID offset is %d although %d is expected", offset, mibAnycastIPAddressRowScopeIDOffset) + } +} + +func TestMibAnycastIPAddressTable(t *testing.T) { + s := mibAnycastIPAddressTable{} + sp := uintptr(unsafe.Pointer(&s)) + const actualmibAnycastIPAddressTableSize = unsafe.Sizeof(s) + + if actualmibAnycastIPAddressTableSize != mibAnycastIPAddressTableSize { + t.Errorf("Size of mibAnycastIPAddressTable is %d, although %d is expected.", actualmibAnycastIPAddressTableSize, mibAnycastIPAddressTableSize) + } + + offset := uintptr(unsafe.Pointer(&s.table)) - sp + if offset != mibAnycastIPAddressTableTableOffset { + t.Errorf("mibAnycastIPAddressTable.table offset is %d although %d is expected", offset, mibAnycastIPAddressTableTableOffset) + } +} + +func TestIPAddressPrefix(t *testing.T) { + s := IPAddressPrefix{} + sp := uintptr(unsafe.Pointer(&s)) + const actualIPAddressPrefixSize = unsafe.Sizeof(s) + + if actualIPAddressPrefixSize != ipAddressPrefixSize { + t.Errorf("Size of IPAddressPrefix is %d, although %d is expected.", actualIPAddressPrefixSize, ipAddressPrefixSize) + } + + offset := uintptr(unsafe.Pointer(&s.PrefixLength)) - sp + if offset != ipAddressPrefixPrefixLengthOffset { + t.Errorf("IPAddressPrefix.PrefixLength offset is %d although %d is expected", offset, ipAddressPrefixPrefixLengthOffset) + } +} + +func TestMibIPforwardRow2(t *testing.T) { + s := MibIPforwardRow2{} + sp := uintptr(unsafe.Pointer(&s)) + const actualMibIPforwardRow2Size = unsafe.Sizeof(s) + + if actualMibIPforwardRow2Size != mibIPforwardRow2Size { + t.Errorf("Size of MibIPforwardRow2 is %d, although %d is expected.", actualMibIPforwardRow2Size, mibIPforwardRow2Size) + } + + offset := uintptr(unsafe.Pointer(&s.InterfaceIndex)) - sp + if offset != mibIPforwardRow2InterfaceIndexOffset { + t.Errorf("MibIPforwardRow2.InterfaceIndex offset is %d although %d is expected", offset, mibIPforwardRow2InterfaceIndexOffset) + } + + offset = uintptr(unsafe.Pointer(&s.DestinationPrefix)) - sp + if offset != mibIPforwardRow2DestinationPrefixOffset { + t.Errorf("MibIPforwardRow2.DestinationPrefix offset is %d although %d is expected", offset, mibIPforwardRow2DestinationPrefixOffset) + } + + offset = uintptr(unsafe.Pointer(&s.NextHop)) - sp + if offset != mibIPforwardRow2NextHopOffset { + t.Errorf("MibIPforwardRow2.NextHop offset is %d although %d is expected", offset, mibIPforwardRow2NextHopOffset) + } + + offset = uintptr(unsafe.Pointer(&s.SitePrefixLength)) - sp + if offset != mibIPforwardRow2SitePrefixLengthOffset { + t.Errorf("MibIPforwardRow2.SitePrefixLength offset is %d although %d is expected", offset, mibIPforwardRow2SitePrefixLengthOffset) + } + + offset = uintptr(unsafe.Pointer(&s.ValidLifetime)) - sp + if offset != mibIPforwardRow2ValidLifetimeOffset { + t.Errorf("MibIPforwardRow2.ValidLifetime offset is %d although %d is expected", offset, mibIPforwardRow2ValidLifetimeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.PreferredLifetime)) - sp + if offset != mibIPforwardRow2PreferredLifetimeOffset { + t.Errorf("MibIPforwardRow2.PreferredLifetime offset is %d although %d is expected", offset, mibIPforwardRow2PreferredLifetimeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Metric)) - sp + if offset != mibIPforwardRow2MetricOffset { + t.Errorf("MibIPforwardRow2.Metric offset is %d although %d is expected", offset, mibIPforwardRow2MetricOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Protocol)) - sp + if offset != mibIPforwardRow2ProtocolOffset { + t.Errorf("MibIPforwardRow2.Protocol offset is %d although %d is expected", offset, mibIPforwardRow2ProtocolOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Loopback)) - sp + if offset != mibIPforwardRow2LoopbackOffset { + t.Errorf("MibIPforwardRow2.Loopback offset is %d although %d is expected", offset, mibIPforwardRow2LoopbackOffset) + } + + offset = uintptr(unsafe.Pointer(&s.AutoconfigureAddress)) - sp + if offset != mibIPforwardRow2AutoconfigureAddressOffset { + t.Errorf("MibIPforwardRow2.AutoconfigureAddress offset is %d although %d is expected", offset, mibIPforwardRow2AutoconfigureAddressOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Publish)) - sp + if offset != mibIPforwardRow2PublishOffset { + t.Errorf("MibIPforwardRow2.Publish offset is %d although %d is expected", offset, mibIPforwardRow2PublishOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Immortal)) - sp + if offset != mibIPforwardRow2ImmortalOffset { + t.Errorf("MibIPforwardRow2.Immortal offset is %d although %d is expected", offset, mibIPforwardRow2ImmortalOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Age)) - sp + if offset != mibIPforwardRow2AgeOffset { + t.Errorf("MibIPforwardRow2.Age offset is %d although %d is expected", offset, mibIPforwardRow2AgeOffset) + } + + offset = uintptr(unsafe.Pointer(&s.Origin)) - sp + if offset != mibIPforwardRow2OriginOffset { + t.Errorf("MibIPforwardRow2.Origin offset is %d although %d is expected", offset, mibIPforwardRow2OriginOffset) + } +} + +func TestMibIPforwardTable2(t *testing.T) { + s := mibIPforwardTable2{} + sp := uintptr(unsafe.Pointer(&s)) + const actualmibIPforwardTable2Size = unsafe.Sizeof(s) + + if actualmibIPforwardTable2Size != mibIPforwardTable2Size { + t.Errorf("Size of mibIPforwardTable2 is %d, although %d is expected.", actualmibIPforwardTable2Size, mibIPforwardTable2Size) + } + + offset := uintptr(unsafe.Pointer(&s.table)) - sp + if offset != mibIPforwardTable2TableOffset { + t.Errorf("mibIPforwardTable2.table offset is %d although %d is expected", offset, mibIPforwardTable2TableOffset) + } +} diff --git a/internal/winipcfg/types_test_32.go b/internal/winipcfg/types_test_32.go new file mode 100644 index 0000000..9e62bfe --- /dev/null +++ b/internal/winipcfg/types_test_32.go @@ -0,0 +1,59 @@ +//go:build 386 || arm + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +const ( + ipAdapterWINSServerAddressSize = 24 + ipAdapterWINSServerAddressNextOffset = 8 + ipAdapterWINSServerAddressAddressOffset = 12 + + ipAdapterGatewayAddressSize = 24 + ipAdapterGatewayAddressNextOffset = 8 + ipAdapterGatewayAddressAddressOffset = 12 + + ipAdapterDNSSuffixSize = 516 + ipAdapterDNSSuffixStringOffset = 4 + + ipAdapterAddressesSize = 376 + ipAdapterAddressesIfIndexOffset = 4 + ipAdapterAddressesNextOffset = 8 + ipAdapterAddressesAdapterNameOffset = 12 + ipAdapterAddressesFirstUnicastAddressOffset = 16 + ipAdapterAddressesFirstAnycastAddressOffset = 20 + ipAdapterAddressesFirstMulticastAddressOffset = 24 + ipAdapterAddressesFirstDNSServerAddressOffset = 28 + ipAdapterAddressesDNSSuffixOffset = 32 + ipAdapterAddressesDescriptionOffset = 36 + ipAdapterAddressesFriendlyNameOffset = 40 + ipAdapterAddressesPhysicalAddressOffset = 44 + ipAdapterAddressesPhysicalAddressLengthOffset = 52 + ipAdapterAddressesFlagsOffset = 56 + ipAdapterAddressesMTUOffset = 60 + ipAdapterAddressesIfTypeOffset = 64 + ipAdapterAddressesOperStatusOffset = 68 + ipAdapterAddressesIPv6IfIndexOffset = 72 + ipAdapterAddressesZoneIndicesOffset = 76 + ipAdapterAddressesFirstPrefixOffset = 140 + ipAdapterAddressesTransmitLinkSpeedOffset = 144 + ipAdapterAddressesReceiveLinkSpeedOffset = 152 + ipAdapterAddressesFirstWINSServerAddressOffset = 160 + ipAdapterAddressesFirstGatewayAddressOffset = 164 + ipAdapterAddressesIPv4MetricOffset = 168 + ipAdapterAddressesIPv6MetricOffset = 172 + ipAdapterAddressesLUIDOffset = 176 + ipAdapterAddressesDHCPv4ServerOffset = 184 + ipAdapterAddressesCompartmentIDOffset = 192 + ipAdapterAddressesNetworkGUIDOffset = 196 + ipAdapterAddressesConnectionTypeOffset = 212 + ipAdapterAddressesTunnelTypeOffset = 216 + ipAdapterAddressesDHCPv6ServerOffset = 220 + ipAdapterAddressesDHCPv6ClientDUIDOffset = 228 + ipAdapterAddressesDHCPv6ClientDUIDLengthOffset = 360 + ipAdapterAddressesDHCPv6IAIDOffset = 364 + ipAdapterAddressesFirstDNSSuffixOffset = 368 +) diff --git a/internal/winipcfg/types_test_64.go b/internal/winipcfg/types_test_64.go new file mode 100644 index 0000000..8a18157 --- /dev/null +++ b/internal/winipcfg/types_test_64.go @@ -0,0 +1,59 @@ +//go:build amd64 || arm64 + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +const ( + ipAdapterWINSServerAddressSize = 32 + ipAdapterWINSServerAddressNextOffset = 8 + ipAdapterWINSServerAddressAddressOffset = 16 + + ipAdapterGatewayAddressSize = 32 + ipAdapterGatewayAddressNextOffset = 8 + ipAdapterGatewayAddressAddressOffset = 16 + + ipAdapterDNSSuffixSize = 520 + ipAdapterDNSSuffixStringOffset = 8 + + ipAdapterAddressesSize = 448 + ipAdapterAddressesIfIndexOffset = 4 + ipAdapterAddressesNextOffset = 8 + ipAdapterAddressesAdapterNameOffset = 16 + ipAdapterAddressesFirstUnicastAddressOffset = 24 + ipAdapterAddressesFirstAnycastAddressOffset = 32 + ipAdapterAddressesFirstMulticastAddressOffset = 40 + ipAdapterAddressesFirstDNSServerAddressOffset = 48 + ipAdapterAddressesDNSSuffixOffset = 56 + ipAdapterAddressesDescriptionOffset = 64 + ipAdapterAddressesFriendlyNameOffset = 72 + ipAdapterAddressesPhysicalAddressOffset = 80 + ipAdapterAddressesPhysicalAddressLengthOffset = 88 + ipAdapterAddressesFlagsOffset = 92 + ipAdapterAddressesMTUOffset = 96 + ipAdapterAddressesIfTypeOffset = 100 + ipAdapterAddressesOperStatusOffset = 104 + ipAdapterAddressesIPv6IfIndexOffset = 108 + ipAdapterAddressesZoneIndicesOffset = 112 + ipAdapterAddressesFirstPrefixOffset = 176 + ipAdapterAddressesTransmitLinkSpeedOffset = 184 + ipAdapterAddressesReceiveLinkSpeedOffset = 192 + ipAdapterAddressesFirstWINSServerAddressOffset = 200 + ipAdapterAddressesFirstGatewayAddressOffset = 208 + ipAdapterAddressesIPv4MetricOffset = 216 + ipAdapterAddressesIPv6MetricOffset = 220 + ipAdapterAddressesLUIDOffset = 224 + ipAdapterAddressesDHCPv4ServerOffset = 232 + ipAdapterAddressesCompartmentIDOffset = 248 + ipAdapterAddressesNetworkGUIDOffset = 252 + ipAdapterAddressesConnectionTypeOffset = 268 + ipAdapterAddressesTunnelTypeOffset = 272 + ipAdapterAddressesDHCPv6ServerOffset = 280 + ipAdapterAddressesDHCPv6ClientDUIDOffset = 296 + ipAdapterAddressesDHCPv6ClientDUIDLengthOffset = 428 + ipAdapterAddressesDHCPv6IAIDOffset = 432 + ipAdapterAddressesFirstDNSSuffixOffset = 440 +) diff --git a/internal/winipcfg/unicast_address_change_handler.go b/internal/winipcfg/unicast_address_change_handler.go new file mode 100644 index 0000000..cf4fcb3 --- /dev/null +++ b/internal/winipcfg/unicast_address_change_handler.go @@ -0,0 +1,88 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +import ( + "sync" + + "golang.org/x/sys/windows" +) + +// UnicastAddressChangeCallback structure allows unicast address change callback handling. +type UnicastAddressChangeCallback struct { + cb func(notificationType MibNotificationType, unicastAddress *MibUnicastIPAddressRow) + wait sync.WaitGroup +} + +var ( + unicastAddressChangeAddRemoveMutex = sync.Mutex{} + unicastAddressChangeMutex = sync.Mutex{} + unicastAddressChangeCallbacks = make(map[*UnicastAddressChangeCallback]bool) + unicastAddressChangeHandle = windows.Handle(0) +) + +// RegisterUnicastAddressChangeCallback registers a new UnicastAddressChangeCallback. If this particular callback is already +// registered, the function will silently return. Returned UnicastAddressChangeCallback.Unregister method should be used +// to unregister. +func RegisterUnicastAddressChangeCallback(callback func(notificationType MibNotificationType, unicastAddress *MibUnicastIPAddressRow)) (*UnicastAddressChangeCallback, error) { + s := &UnicastAddressChangeCallback{cb: callback} + + unicastAddressChangeAddRemoveMutex.Lock() + defer unicastAddressChangeAddRemoveMutex.Unlock() + + unicastAddressChangeMutex.Lock() + defer unicastAddressChangeMutex.Unlock() + + unicastAddressChangeCallbacks[s] = true + + if unicastAddressChangeHandle == 0 { + err := notifyUnicastIPAddressChange(windows.AF_UNSPEC, windows.NewCallback(unicastAddressChanged), 0, false, &unicastAddressChangeHandle) + if err != nil { + delete(unicastAddressChangeCallbacks, s) + unicastAddressChangeHandle = 0 + return nil, err + } + } + + return s, nil +} + +// Unregister unregisters the callback. +func (callback *UnicastAddressChangeCallback) Unregister() error { + unicastAddressChangeAddRemoveMutex.Lock() + defer unicastAddressChangeAddRemoveMutex.Unlock() + + unicastAddressChangeMutex.Lock() + delete(unicastAddressChangeCallbacks, callback) + removeIt := len(unicastAddressChangeCallbacks) == 0 && unicastAddressChangeHandle != 0 + unicastAddressChangeMutex.Unlock() + + callback.wait.Wait() + + if removeIt { + err := cancelMibChangeNotify2(unicastAddressChangeHandle) + if err != nil { + return err + } + unicastAddressChangeHandle = 0 + } + + return nil +} + +func unicastAddressChanged(callerContext uintptr, row *MibUnicastIPAddressRow, notificationType MibNotificationType) uintptr { + rowCopy := *row + unicastAddressChangeMutex.Lock() + for cb := range unicastAddressChangeCallbacks { + cb.wait.Add(1) + go func(cb *UnicastAddressChangeCallback) { + cb.cb(notificationType, &rowCopy) + cb.wait.Done() + }(cb) + } + unicastAddressChangeMutex.Unlock() + return 0 +} diff --git a/internal/winipcfg/winipcfg.go b/internal/winipcfg/winipcfg.go new file mode 100644 index 0000000..e24157b --- /dev/null +++ b/internal/winipcfg/winipcfg.go @@ -0,0 +1,196 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package winipcfg + +import ( + "runtime" + "unsafe" + + "golang.org/x/sys/windows" +) + +// +// Common functions +// + +//sys freeMibTable(memory unsafe.Pointer) = iphlpapi.FreeMibTable + +// +// Interface-related functions +// + +//sys initializeIPInterfaceEntry(row *MibIPInterfaceRow) = iphlpapi.InitializeIpInterfaceEntry +//sys getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) = iphlpapi.GetIpInterfaceTable +//sys getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) = iphlpapi.GetIpInterfaceEntry +//sys setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) = iphlpapi.SetIpInterfaceEntry +//sys getIfEntry2(row *MibIfRow2) (ret error) = iphlpapi.GetIfEntry2 +//sys getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) = iphlpapi.GetIfTable2Ex +//sys convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid +//sys convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) = iphlpapi.ConvertInterfaceGuidToLuid +//sys convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) = iphlpapi.ConvertInterfaceIndexToLuid + +// GetAdaptersAddresses function retrieves the addresses associated with the adapters on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/iphlpapi/nf-iphlpapi-getadaptersaddresses +func GetAdaptersAddresses(family AddressFamily, flags GAAFlags) ([]*IPAdapterAddresses, error) { + var b []byte + size := uint32(15000) + + for { + b = make([]byte, size) + err := windows.GetAdaptersAddresses(uint32(family), uint32(flags), 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &size) + if err == nil { + break + } + if err != windows.ERROR_BUFFER_OVERFLOW || size <= uint32(len(b)) { + return nil, err + } + } + + result := make([]*IPAdapterAddresses, 0, uintptr(size)/unsafe.Sizeof(IPAdapterAddresses{})) + for wtiaa := (*IPAdapterAddresses)(unsafe.Pointer(&b[0])); wtiaa != nil; wtiaa = wtiaa.Next { + result = append(result, wtiaa) + } + + return result, nil +} + +// GetIPInterfaceTable function retrieves the IP interface entries on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipinterfacetable +func GetIPInterfaceTable(family AddressFamily) ([]MibIPInterfaceRow, error) { + var tab *mibIPInterfaceTable + err := getIPInterfaceTable(family, &tab) + if err != nil { + return nil, err + } + t := append(make([]MibIPInterfaceRow, 0, tab.numEntries), tab.get()...) + tab.free() + return t, nil +} + +// GetIfTable2Ex function retrieves the MIB-II interface table. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getiftable2ex +func GetIfTable2Ex(level MibIfEntryLevel) ([]MibIfRow2, error) { + var tab *mibIfTable2 + err := getIfTable2Ex(level, &tab) + if err != nil { + return nil, err + } + t := append(make([]MibIfRow2, 0, tab.numEntries), tab.get()...) + tab.free() + return t, nil +} + +// +// Unicast IP address-related functions +// + +//sys getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) = iphlpapi.GetUnicastIpAddressTable +//sys initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) = iphlpapi.InitializeUnicastIpAddressEntry +//sys getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.GetUnicastIpAddressEntry +//sys setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.SetUnicastIpAddressEntry +//sys createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.CreateUnicastIpAddressEntry +//sys deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.DeleteUnicastIpAddressEntry + +// GetUnicastIPAddressTable function retrieves the unicast IP address table on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddresstable +func GetUnicastIPAddressTable(family AddressFamily) ([]MibUnicastIPAddressRow, error) { + var tab *mibUnicastIPAddressTable + err := getUnicastIPAddressTable(family, &tab) + if err != nil { + return nil, err + } + t := append(make([]MibUnicastIPAddressRow, 0, tab.numEntries), tab.get()...) + tab.free() + return t, nil +} + +// +// Anycast IP address-related functions +// + +//sys getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) = iphlpapi.GetAnycastIpAddressTable +//sys getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) = iphlpapi.GetAnycastIpAddressEntry +//sys createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) = iphlpapi.CreateAnycastIpAddressEntry +//sys deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) = iphlpapi.DeleteAnycastIpAddressEntry + +// GetAnycastIPAddressTable function retrieves the anycast IP address table on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getanycastipaddresstable +func GetAnycastIPAddressTable(family AddressFamily) ([]MibAnycastIPAddressRow, error) { + var tab *mibAnycastIPAddressTable + err := getAnycastIPAddressTable(family, &tab) + if err != nil { + return nil, err + } + t := append(make([]MibAnycastIPAddressRow, 0, tab.numEntries), tab.get()...) + tab.free() + return t, nil +} + +// +// Routing-related functions +// + +//sys getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) = iphlpapi.GetIpForwardTable2 +//sys initializeIPForwardEntry(route *MibIPforwardRow2) = iphlpapi.InitializeIpForwardEntry +//sys getIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.GetIpForwardEntry2 +//sys setIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.SetIpForwardEntry2 +//sys createIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.CreateIpForwardEntry2 +//sys deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.DeleteIpForwardEntry2 + +// GetIPForwardTable2 function retrieves the IP route entries on the local computer. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardtable2 +func GetIPForwardTable2(family AddressFamily) ([]MibIPforwardRow2, error) { + var tab *mibIPforwardTable2 + err := getIPForwardTable2(family, &tab) + if err != nil { + return nil, err + } + t := append(make([]MibIPforwardRow2, 0, tab.numEntries), tab.get()...) + tab.free() + return t, nil +} + +// +// Notifications-related functions +// + +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-notifyipinterfacechange +//sys notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) = iphlpapi.NotifyIpInterfaceChange + +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-notifyunicastipaddresschange +//sys notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) = iphlpapi.NotifyUnicastIpAddressChange + +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-notifyroutechange2 +//sys notifyRouteChange2(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) = iphlpapi.NotifyRouteChange2 + +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-cancelmibchangenotify2 +//sys cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) = iphlpapi.CancelMibChangeNotify2 + +// +// DNS-related functions +// + +//sys setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? +//sys setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? +//sys setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? + +// The GUID is passed by value, not by reference, which means different +// things on different calling conventions. On amd64, this means it's +// passed by reference anyway, while on arm, arm64, and 386, it's split +// into words. +func SetInterfaceDnsSettings(guid windows.GUID, settings *DnsInterfaceSettings) error { + words := (*[4]uintptr)(unsafe.Pointer(&guid)) + switch runtime.GOARCH { + case "amd64": + return setInterfaceDnsSettingsByPtr(&guid, settings) + case "arm64": + return setInterfaceDnsSettingsByQwords(words[0], words[1], settings) + case "arm", "386": + return setInterfaceDnsSettingsByDwords(words[0], words[1], words[2], words[3], settings) + default: + panic("unknown calling convention") + } +} diff --git a/internal/winipcfg/winipcfg_test.go b/internal/winipcfg/winipcfg_test.go new file mode 100644 index 0000000..b49daf3 --- /dev/null +++ b/internal/winipcfg/winipcfg_test.go @@ -0,0 +1,660 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +/* + +Some tests in this file require: + +- A dedicated network adapter + Any network adapter will do. It may be virtual (WireGuardNT, Wintun, + etc.). The adapter name must contain string "winipcfg_test". + Tests will add, remove, flush DNS servers, change adapter IP address, manipulate + routes etc. + The adapter will not be returned to previous state, so use an expendable one. + +- Elevation + Run go test as Administrator + +*/ + +package winipcfg + +import ( + "net/netip" + "strings" + "syscall" + "testing" + "time" + + "golang.org/x/sys/windows" +) + +const ( + testInterfaceMarker = "winipcfg_test" // The interface we will use for testing must contain this string in its name +) + +// TODO: Add IPv6 tests. +var ( + nonexistantIPv4ToAdd = netip.MustParsePrefix("172.16.1.114/24") + nonexistentRouteIPv4ToAdd = RouteData{ + Destination: netip.MustParsePrefix("172.16.200.0/24"), + NextHop: netip.MustParseAddr("172.16.1.2"), + Metric: 0, + } + dnsesToSet = []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")} +) + +func runningElevated() bool { + var process windows.Token + err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &process) + if err != nil { + return false + } + defer process.Close() + return process.IsElevated() +} + +func getTestInterface() (*IPAdapterAddresses, error) { + ifcs, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagIncludeAll) + if err != nil { + return nil, err + } + + marker := strings.ToLower(testInterfaceMarker) + for _, ifc := range ifcs { + if strings.Contains(strings.ToLower(ifc.FriendlyName()), marker) { + return ifc, nil + } + } + + return nil, windows.ERROR_NOT_FOUND +} + +func getTestIPInterface(family AddressFamily) (*MibIPInterfaceRow, error) { + ifc, err := getTestInterface() + if err != nil { + return nil, err + } + + return ifc.LUID.IPInterface(family) +} + +func TestAdaptersAddresses(t *testing.T) { + ifcs, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagIncludeAll) + if err != nil { + t.Errorf("GetAdaptersAddresses() returned error: %w", err) + } else if ifcs == nil { + t.Errorf("GetAdaptersAddresses() returned nil.") + } else if len(ifcs) == 0 { + t.Errorf("GetAdaptersAddresses() returned empty.") + } else { + for _, i := range ifcs { + i.AdapterName() + i.DNSSuffix() + i.Description() + i.FriendlyName() + i.PhysicalAddress() + i.DHCPv6ClientDUID() + for dnsSuffix := i.FirstDNSSuffix; dnsSuffix != nil; dnsSuffix = dnsSuffix.Next { + _ = dnsSuffix.String() + } + } + } + + ifcs, err = GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault) + + for _, i := range ifcs { + ifc, err := i.LUID.Interface() + if err != nil { + t.Errorf("LUID.Interface() returned an error: %w", err) + continue + } else if ifc == nil { + t.Errorf("LUID.Interface() returned nil.") + continue + } + } + + for _, i := range ifcs { + guid, err := i.LUID.GUID() + if err != nil { + t.Errorf("LUID.GUID() returned an error: %w", err) + continue + } + if guid == nil { + t.Error("LUID.GUID() returned nil.") + continue + } + + luid, err := LUIDFromGUID(guid) + if err != nil { + t.Errorf("LUIDFromGUID() returned an error: %w", err) + continue + } + if luid != i.LUID { + t.Errorf("LUIDFromGUID() returned LUID %d, although expected was %d.", luid, i.LUID) + continue + } + } +} + +func TestIPInterface(t *testing.T) { + ifcs, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault) + if err != nil { + t.Errorf("GetAdaptersAddresses() returned error: %w", err) + } + + for _, i := range ifcs { + _, err := i.LUID.IPInterface(windows.AF_INET) + if err == windows.ERROR_NOT_FOUND { + // Ignore isatap and similar adapters without IPv4. + continue + } + if err != nil { + t.Errorf("LUID.IPInterface(%s) returned an error: %w", i.FriendlyName(), err) + } + + _, err = i.LUID.IPInterface(windows.AF_INET6) + if err != nil { + t.Errorf("LUID.IPInterface(%s) returned an error: %w", i.FriendlyName(), err) + } + } +} + +func TestIPInterfaces(t *testing.T) { + tab, err := GetIPInterfaceTable(windows.AF_UNSPEC) + if err != nil { + t.Errorf("GetIPInterfaceTable() returned an error: %w", err) + return + } else if tab == nil { + t.Error("GetIPInterfaceTable() returned nil.") + } + + if len(tab) == 0 { + t.Error("GetIPInterfaceTable() returned an empty slice.") + return + } +} + +func TestIPChangeMetric(t *testing.T) { + ipifc, err := getTestIPInterface(windows.AF_INET) + if err != nil { + t.Errorf("getTestIPInterface() returned an error: %w", err) + return + } + if !runningElevated() { + t.Errorf("%s requires elevation", t.Name()) + return + } + + var changed bool + cb, err := RegisterInterfaceChangeCallback(func(notificationType MibNotificationType, iface *MibIPInterfaceRow) { + if iface == nil || iface.InterfaceLUID != ipifc.InterfaceLUID { + return + } + switch notificationType { + case MibParameterNotification: + changed = true + } + }) + if err != nil { + t.Errorf("RegisterInterfaceChangeCallback() returned error: %w", err) + return + } + defer func() { + err = cb.Unregister() + if err != nil { + t.Errorf("UnregisterInterfaceChangeCallback() returned error: %w", err) + } + }() + + useAutomaticMetric := ipifc.UseAutomaticMetric + metric := ipifc.Metric + + newMetric := uint32(100) + if newMetric == metric { + newMetric = 200 + } + + ipifc.UseAutomaticMetric = false + ipifc.Metric = newMetric + err = ipifc.Set() + if err != nil { + t.Errorf("MibIPInterfaceRow.Set() returned an error: %w", err) + } + + time.Sleep(500 * time.Millisecond) + + ipifc, err = getTestIPInterface(windows.AF_INET) + if err != nil { + t.Errorf("getTestIPInterface() returned an error: %w", err) + return + } + if ipifc.Metric != newMetric { + t.Errorf("Expected metric: %d; actual metric: %d", newMetric, ipifc.Metric) + } + if ipifc.UseAutomaticMetric { + t.Error("UseAutomaticMetric is true although it's set to false.") + } + if !changed { + t.Errorf("Notification handler has not been called on metric change.") + } + changed = false + + ipifc.UseAutomaticMetric = useAutomaticMetric + ipifc.Metric = metric + err = ipifc.Set() + if err != nil { + t.Errorf("MibIPInterfaceRow.Set() returned an error: %w", err) + } + + time.Sleep(500 * time.Millisecond) + + ipifc, err = getTestIPInterface(windows.AF_INET) + if err != nil { + t.Errorf("getTestIPInterface() returned an error: %w", err) + return + } + if ipifc.Metric != metric { + t.Errorf("Expected metric: %d; actual metric: %d", metric, ipifc.Metric) + } + if ipifc.UseAutomaticMetric != useAutomaticMetric { + t.Errorf("UseAutomaticMetric is %v although %v is expected.", ipifc.UseAutomaticMetric, useAutomaticMetric) + } + if !changed { + t.Errorf("Notification handler has not been called on metric change.") + } +} + +func TestIPChangeMTU(t *testing.T) { + ipifc, err := getTestIPInterface(windows.AF_INET) + if err != nil { + t.Errorf("getTestIPInterface() returned an error: %w", err) + return + } + if !runningElevated() { + t.Errorf("%s requires elevation", t.Name()) + return + } + + prevMTU := ipifc.NLMTU + mtuToSet := prevMTU - 1 + ipifc.NLMTU = mtuToSet + err = ipifc.Set() + if err != nil { + t.Errorf("Interface.Set() returned error: %w", err) + } + + time.Sleep(500 * time.Millisecond) + + ipifc, err = getTestIPInterface(windows.AF_INET) + if err != nil { + t.Errorf("getTestIPInterface() returned an error: %w", err) + return + } + if ipifc.NLMTU != mtuToSet { + t.Errorf("Interface.NLMTU is %d although %d is expected.", ipifc.NLMTU, mtuToSet) + } + + ipifc.NLMTU = prevMTU + err = ipifc.Set() + if err != nil { + t.Errorf("Interface.Set() returned error: %w", err) + } + + time.Sleep(500 * time.Millisecond) + + ipifc, err = getTestIPInterface(windows.AF_INET) + if err != nil { + t.Errorf("getTestIPInterface() returned an error: %w", err) + } + if ipifc.NLMTU != prevMTU { + t.Errorf("Interface.NLMTU is %d although %d is expected.", ipifc.NLMTU, prevMTU) + } +} + +func TestGetIfRow(t *testing.T) { + ifc, err := getTestInterface() + if err != nil { + t.Errorf("getTestInterface() returned an error: %w", err) + return + } + + row, err := ifc.LUID.Interface() + if err != nil { + t.Errorf("LUID.Interface() returned an error: %w", err) + return + } + + row.Alias() + row.Description() + row.PhysicalAddress() + row.PermanentPhysicalAddress() +} + +func TestGetIfRows(t *testing.T) { + tab, err := GetIfTable2Ex(MibIfEntryNormal) + if err != nil { + t.Errorf("GetIfTable2Ex() returned an error: %w", err) + return + } else if tab == nil { + t.Errorf("GetIfTable2Ex() returned nil") + return + } + + for i := range tab { + tab[i].Alias() + tab[i].Description() + tab[i].PhysicalAddress() + tab[i].PermanentPhysicalAddress() + } +} + +func TestUnicastIPAddress(t *testing.T) { + _, err := GetUnicastIPAddressTable(windows.AF_UNSPEC) + if err != nil { + t.Errorf("GetUnicastAddresses() returned an error: %w", err) + return + } +} + +func TestAddDeleteIPAddress(t *testing.T) { + ifc, err := getTestInterface() + if err != nil { + t.Errorf("getTestInterface() returned an error: %w", err) + return + } + if !runningElevated() { + t.Errorf("%s requires elevation", t.Name()) + return + } + + addr, err := ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) + if err == nil { + t.Errorf("Unicast address %s already exists. Please set nonexistantIPv4ToAdd appropriately.", nonexistantIPv4ToAdd.Addr().String()) + return + } else if err != windows.ERROR_NOT_FOUND { + t.Errorf("LUID.IPAddress() returned an error: %w", err) + return + } + + var created, deleted bool + cb, err := RegisterUnicastAddressChangeCallback(func(notificationType MibNotificationType, addr *MibUnicastIPAddressRow) { + if addr == nil || addr.InterfaceLUID != ifc.LUID { + return + } + switch notificationType { + case MibAddInstance: + created = true + case MibDeleteInstance: + deleted = true + } + }) + if err != nil { + t.Errorf("RegisterUnicastAddressChangeCallback() returned an error: %w", err) + } else { + defer cb.Unregister() + } + var count int + for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next { + count-- + } + err = ifc.LUID.AddIPAddresses([]netip.Prefix{nonexistantIPv4ToAdd}) + if err != nil { + t.Errorf("LUID.AddIPAddresses() returned an error: %w", err) + } + + time.Sleep(500 * time.Millisecond) + + ifc, _ = getTestInterface() + for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next { + count++ + } + if count != 1 { + t.Errorf("After adding there are %d new interface(s).", count) + } + addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) + if err != nil { + t.Errorf("LUID.IPAddress() returned an error: %w", err) + } else if addr == nil { + t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", nonexistantIPv4ToAdd.Addr().String()) + } + if !created { + t.Errorf("Notification handler has not been called on add.") + } + + err = ifc.LUID.DeleteIPAddress(nonexistantIPv4ToAdd) + if err != nil { + t.Errorf("LUID.DeleteIPAddress() returned an error: %w", err) + } + + time.Sleep(500 * time.Millisecond) + + addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) + if err == nil { + t.Errorf("Unicast address %s still exists, although it's deleted successfully.", nonexistantIPv4ToAdd.Addr().String()) + } else if err != windows.ERROR_NOT_FOUND { + t.Errorf("LUID.IPAddress() returned an error: %w", err) + } + if !deleted { + t.Errorf("Notification handler has not been called on delete.") + } +} + +func TestGetRoutes(t *testing.T) { + _, err := GetIPForwardTable2(windows.AF_UNSPEC) + if err != nil { + t.Errorf("GetIPForwardTable2() returned error: %w", err) + } +} + +func TestAddDeleteRoute(t *testing.T) { + findRoute := func(luid LUID, dest netip.Prefix) ([]MibIPforwardRow2, error) { + var family AddressFamily + if dest.Addr().Is4() { + family = windows.AF_INET + } else if dest.Addr().Is6() { + family = windows.AF_INET6 + } else { + return nil, windows.ERROR_INVALID_PARAMETER + } + r, err := GetIPForwardTable2(family) + if err != nil { + return nil, err + } + matches := make([]MibIPforwardRow2, 0, len(r)) + for _, route := range r { + if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(dest.Bits()) && route.DestinationPrefix.RawPrefix.Family == family && route.DestinationPrefix.RawPrefix.Addr() == dest.Addr() { + matches = append(matches, route) + } + } + return matches, nil + } + + ifc, err := getTestInterface() + if err != nil { + t.Errorf("getTestInterface() returned an error: %w", err) + return + } + if !runningElevated() { + t.Errorf("%s requires elevation", t.Name()) + return + } + + _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) + if err == nil { + t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?") + return + } else if err != windows.ERROR_NOT_FOUND { + t.Errorf("LUID.Route() returned an error: %w", err) + return + } + + routes, err := findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) + if err != nil { + t.Errorf("findRoute() returned an error: %w", err) + } else if len(routes) != 0 { + t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?", len(routes)) + } + + var created, deleted bool + cb, err := RegisterRouteChangeCallback(func(notificationType MibNotificationType, route *MibIPforwardRow2) { + switch notificationType { + case MibAddInstance: + created = true + case MibDeleteInstance: + deleted = true + } + }) + if err != nil { + t.Errorf("RegisterRouteChangeCallback() returned an error: %w", err) + } else { + defer cb.Unregister() + } + err = ifc.LUID.AddRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop, nonexistentRouteIPv4ToAdd.Metric) + if err != nil { + t.Errorf("LUID.AddRoute() returned an error: %w", err) + } + + time.Sleep(500 * time.Millisecond) + + route, err := ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) + if err == windows.ERROR_NOT_FOUND { + t.Error("LUID.Route() returned nil although the route is added successfully.") + } else if err != nil { + t.Errorf("LUID.Route() returned an error: %w", err) + } else if route.DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() || route.NextHop.Addr() != nonexistentRouteIPv4ToAdd.NextHop { + t.Error("LUID.Route() returned a wrong route!") + } + if !created { + t.Errorf("Route handler has not been called on add.") + } + + routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) + if err != nil { + t.Errorf("findRoute() returned an error: %w", err) + } else if len(routes) != 1 { + t.Errorf("findRoute() returned %d items although %d is expected.", len(routes), 1) + } else if routes[0].DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() { + t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.RawPrefix.Addr().String(), nonexistentRouteIPv4ToAdd.Destination.Addr().String()) + } + + err = ifc.LUID.DeleteRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) + if err != nil { + t.Errorf("LUID.DeleteRoute() returned an error: %w", err) + } + + time.Sleep(500 * time.Millisecond) + + _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) + if err == nil { + t.Error("LUID.Route() returned a route although it is removed successfully.") + } else if err != windows.ERROR_NOT_FOUND { + t.Errorf("LUID.Route() returned an error: %w", err) + } + if !deleted { + t.Errorf("Route handler has not been called on delete.") + } + + routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) + if err != nil { + t.Errorf("findRoute() returned an error: %w", err) + } else if len(routes) != 0 { + t.Errorf("findRoute() returned %d items although the route is deleted successfully.", len(routes)) + } +} + +func TestFlushDNS(t *testing.T) { + ifc, err := getTestInterface() + if err != nil { + t.Errorf("getTestInterface() returned an error: %w", err) + return + } + if !runningElevated() { + t.Errorf("%s requires elevation", t.Name()) + return + } + + prevDNSes, err := ifc.LUID.DNS() + if err != nil { + t.Errorf("LUID.DNS() returned an error: %w", err) + } + + err = ifc.LUID.FlushDNS(syscall.AF_INET) + if err != nil { + t.Errorf("LUID.FlushDNS() returned an error: %w", err) + } + + ifc, _ = getTestInterface() + + n := 0 + dns, err := ifc.LUID.DNS() + if err != nil { + t.Errorf("LUID.DNS() returned an error: %w", err) + } + for _, a := range dns { + if a.Is4() { + n++ + } + } + if n != 0 { + t.Errorf("DNSServerAddresses contains %d items, although FlushDNS is executed successfully.", n) + } + + err = ifc.LUID.SetDNS(windows.AF_INET, prevDNSes, nil) + if err != nil { + t.Errorf("LUID.SetDNS() returned an error: %v.", err) + } +} + +func TestSetDNS(t *testing.T) { + ifc, err := getTestInterface() + if err != nil { + t.Errorf("getTestInterface() returned an error: %w", err) + return + } + if !runningElevated() { + t.Errorf("%s requires elevation", t.Name()) + return + } + + prevDNSes, err := ifc.LUID.DNS() + if err != nil { + t.Errorf("LUID.DNS() returned an error: %w", err) + } + + err = ifc.LUID.SetDNS(windows.AF_INET, dnsesToSet, nil) + if err != nil { + t.Errorf("LUID.SetDNS() returned an error: %w", err) + return + } + + ifc, _ = getTestInterface() + + newDNSes, err := ifc.LUID.DNS() + if err != nil { + t.Errorf("LUID.DNS() returned an error: %w", err) + } else if len(newDNSes) != len(dnsesToSet) { + t.Errorf("dnsesToSet contains %d items, while DNSServerAddresses contains %d.", len(dnsesToSet), len(newDNSes)) + } else { + for i := range dnsesToSet { + if dnsesToSet[i] != newDNSes[i] { + t.Errorf("dnsesToSet[%d] = %s while DNSServerAddresses[%d] = %s.", i, dnsesToSet[i].String(), i, newDNSes[i].String()) + } + } + } + + err = ifc.LUID.SetDNS(windows.AF_INET, prevDNSes, nil) + if err != nil { + t.Errorf("LUID.SetDNS() returned an error: %v.", err) + } +} + +func TestAnycastIPAddress(t *testing.T) { + _, err := GetAnycastIPAddressTable(windows.AF_UNSPEC) + if err != nil { + t.Errorf("GetAnycastIPAddressTable() returned an error: %w", err) + return + } +} diff --git a/internal/winipcfg/zwinipcfg_windows.go b/internal/winipcfg/zwinipcfg_windows.go new file mode 100644 index 0000000..3a0d868 --- /dev/null +++ b/internal/winipcfg/zwinipcfg_windows.go @@ -0,0 +1,350 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package winipcfg + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") + + procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2") + procConvertInterfaceGuidToLuid = modiphlpapi.NewProc("ConvertInterfaceGuidToLuid") + procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid") + procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid") + procCreateAnycastIpAddressEntry = modiphlpapi.NewProc("CreateAnycastIpAddressEntry") + procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2") + procCreateUnicastIpAddressEntry = modiphlpapi.NewProc("CreateUnicastIpAddressEntry") + procDeleteAnycastIpAddressEntry = modiphlpapi.NewProc("DeleteAnycastIpAddressEntry") + procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2") + procDeleteUnicastIpAddressEntry = modiphlpapi.NewProc("DeleteUnicastIpAddressEntry") + procFreeMibTable = modiphlpapi.NewProc("FreeMibTable") + procGetAnycastIpAddressEntry = modiphlpapi.NewProc("GetAnycastIpAddressEntry") + procGetAnycastIpAddressTable = modiphlpapi.NewProc("GetAnycastIpAddressTable") + procGetIfEntry2 = modiphlpapi.NewProc("GetIfEntry2") + procGetIfTable2Ex = modiphlpapi.NewProc("GetIfTable2Ex") + procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2") + procGetIpForwardTable2 = modiphlpapi.NewProc("GetIpForwardTable2") + procGetIpInterfaceEntry = modiphlpapi.NewProc("GetIpInterfaceEntry") + procGetIpInterfaceTable = modiphlpapi.NewProc("GetIpInterfaceTable") + procGetUnicastIpAddressEntry = modiphlpapi.NewProc("GetUnicastIpAddressEntry") + procGetUnicastIpAddressTable = modiphlpapi.NewProc("GetUnicastIpAddressTable") + procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry") + procInitializeIpInterfaceEntry = modiphlpapi.NewProc("InitializeIpInterfaceEntry") + procInitializeUnicastIpAddressEntry = modiphlpapi.NewProc("InitializeUnicastIpAddressEntry") + procNotifyIpInterfaceChange = modiphlpapi.NewProc("NotifyIpInterfaceChange") + procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2") + procNotifyUnicastIpAddressChange = modiphlpapi.NewProc("NotifyUnicastIpAddressChange") + procSetInterfaceDnsSettings = modiphlpapi.NewProc("SetInterfaceDnsSettings") + procSetIpForwardEntry2 = modiphlpapi.NewProc("SetIpForwardEntry2") + procSetIpInterfaceEntry = modiphlpapi.NewProc("SetIpInterfaceEntry") + procSetUnicastIpAddressEntry = modiphlpapi.NewProc("SetUnicastIpAddressEntry") +) + +func cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) { + r0, _, _ := syscall.Syscall(procCancelMibChangeNotify2.Addr(), 1, uintptr(notificationHandle), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) { + r0, _, _ := syscall.Syscall(procConvertInterfaceGuidToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceGUID)), uintptr(unsafe.Pointer(interfaceLUID)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) { + r0, _, _ := syscall.Syscall(procConvertInterfaceIndexToLuid.Addr(), 2, uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) { + r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procCreateAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func createIPForwardEntry2(route *MibIPforwardRow2) (ret error) { + r0, _, _ := syscall.Syscall(procCreateIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procCreateUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procDeleteAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) { + r0, _, _ := syscall.Syscall(procDeleteIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procDeleteUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func freeMibTable(memory unsafe.Pointer) { + syscall.Syscall(procFreeMibTable.Addr(), 1, uintptr(memory), 0, 0) + return +} + +func getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procGetAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) { + r0, _, _ := syscall.Syscall(procGetAnycastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getIfEntry2(row *MibIfRow2) (ret error) { + r0, _, _ := syscall.Syscall(procGetIfEntry2.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) { + r0, _, _ := syscall.Syscall(procGetIfTable2Ex.Addr(), 2, uintptr(level), uintptr(unsafe.Pointer(table)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getIPForwardEntry2(route *MibIPforwardRow2) (ret error) { + r0, _, _ := syscall.Syscall(procGetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) { + r0, _, _ := syscall.Syscall(procGetIpForwardTable2.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) { + r0, _, _ := syscall.Syscall(procGetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) { + r0, _, _ := syscall.Syscall(procGetIpInterfaceTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procGetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) { + r0, _, _ := syscall.Syscall(procGetUnicastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func initializeIPForwardEntry(route *MibIPforwardRow2) { + syscall.Syscall(procInitializeIpForwardEntry.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + return +} + +func initializeIPInterfaceEntry(row *MibIPInterfaceRow) { + syscall.Syscall(procInitializeIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + return +} + +func initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) { + syscall.Syscall(procInitializeUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + return +} + +func notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) { + var _p0 uint32 + if initialNotification { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procNotifyIpInterfaceChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func notifyRouteChange2(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) { + var _p0 uint32 + if initialNotification { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procNotifyRouteChange2.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) { + var _p0 uint32 + if initialNotification { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procNotifyUnicastIpAddressChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) { + ret = procSetInterfaceDnsSettings.Find() + if ret != nil { + return + } + r0, _, _ := syscall.Syscall6(procSetInterfaceDnsSettings.Addr(), 5, uintptr(guid1), uintptr(guid2), uintptr(guid3), uintptr(guid4), uintptr(unsafe.Pointer(settings)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) { + ret = procSetInterfaceDnsSettings.Find() + if ret != nil { + return + } + r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings))) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) { + ret = procSetInterfaceDnsSettings.Find() + if ret != nil { + return + } + r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func setIPForwardEntry2(route *MibIPforwardRow2) (ret error) { + r0, _, _ := syscall.Syscall(procSetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) { + r0, _, _ := syscall.Syscall(procSetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procSetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} diff --git a/internal/wintun/README.md b/internal/wintun/README.md new file mode 100644 index 0000000..9c7cb7a --- /dev/null +++ b/internal/wintun/README.md @@ -0,0 +1,3 @@ +# wintun + +DLL version: 0.14.1 \ No newline at end of file diff --git a/internal/wintun/amd64/wintun.dll b/internal/wintun/amd64/wintun.dll new file mode 100755 index 0000000..aee04e7 Binary files /dev/null and b/internal/wintun/amd64/wintun.dll differ diff --git a/internal/wintun/arm/wintun.dll b/internal/wintun/arm/wintun.dll new file mode 100755 index 0000000..0017794 Binary files /dev/null and b/internal/wintun/arm/wintun.dll differ diff --git a/internal/wintun/arm64/wintun.dll b/internal/wintun/arm64/wintun.dll new file mode 100755 index 0000000..dc4e4ae Binary files /dev/null and b/internal/wintun/arm64/wintun.dll differ diff --git a/internal/wintun/dll_windows.go b/internal/wintun/dll_windows.go new file mode 100644 index 0000000..42285ba --- /dev/null +++ b/internal/wintun/dll_windows.go @@ -0,0 +1,98 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package wintun + +import ( + "fmt" + "sync" + "sync/atomic" + "unsafe" + + "github.com/sagernet/sing-tun/internal/wintun/memmod" + + "golang.org/x/sys/windows" +) + +func (d *lazyDLL) NewProc(name string) *lazyProc { + return &lazyProc{dll: d, Name: name} +} + +type lazyProc struct { + Name string + mu sync.Mutex + dll *lazyDLL + addr uintptr +} + +func (p *lazyProc) Find() error { + if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil { + return nil + } + p.mu.Lock() + defer p.mu.Unlock() + if p.addr != 0 { + return nil + } + + err := p.dll.Load() + if err != nil { + return fmt.Errorf("error loading DLL: %s, MODULE: %s, error: %w", p.dll.Name, p.Name, err) + } + addr, err := p.nameToAddr() + if err != nil { + return fmt.Errorf("error getting %s address: %w", p.Name, err) + } + + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr)) + return nil +} + +func (p *lazyProc) Addr() uintptr { + err := p.Find() + if err != nil { + panic(err) + } + return p.addr +} + +func (p *lazyProc) Load() error { + return p.dll.Load() +} + +type lazyDLL struct { + Name string + Base windows.Handle + mu sync.Mutex + module *memmod.Module +} + +func newLazyDLL(name string) *lazyDLL { + return &lazyDLL{Name: name} +} + +func (d *lazyDLL) Load() error { + if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil { + return nil + } + d.mu.Lock() + defer d.mu.Unlock() + if d.module != nil { + return nil + } + + module, err := memmod.LoadLibrary(dllContent) + if err != nil { + return fmt.Errorf("unable to load library: %w", err) + } + d.Base = windows.Handle(module.BaseAddr()) + + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module)) + return nil +} + +func (p *lazyProc) nameToAddr() (uintptr, error) { + return p.dll.module.ProcAddressByName(p.Name) +} diff --git a/internal/wintun/dll_windows_386.go b/internal/wintun/dll_windows_386.go new file mode 100644 index 0000000..5085e90 --- /dev/null +++ b/internal/wintun/dll_windows_386.go @@ -0,0 +1,13 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package wintun + +import ( + _ "embed" +) + +//go:embed x86/wintun.dll +var dllContent []byte diff --git a/internal/wintun/dll_windows_amd64.go b/internal/wintun/dll_windows_amd64.go new file mode 100644 index 0000000..a5fc1c1 --- /dev/null +++ b/internal/wintun/dll_windows_amd64.go @@ -0,0 +1,13 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package wintun + +import ( + _ "embed" +) + +//go:embed amd64/wintun.dll +var dllContent []byte diff --git a/internal/wintun/dll_windows_arm.go b/internal/wintun/dll_windows_arm.go new file mode 100644 index 0000000..95ee623 --- /dev/null +++ b/internal/wintun/dll_windows_arm.go @@ -0,0 +1,13 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package wintun + +import ( + _ "embed" +) + +//go:embed arm/wintun.dll +var dllContent []byte diff --git a/internal/wintun/dll_windows_arm64.go b/internal/wintun/dll_windows_arm64.go new file mode 100644 index 0000000..e89de0b --- /dev/null +++ b/internal/wintun/dll_windows_arm64.go @@ -0,0 +1,13 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package wintun + +import ( + _ "embed" +) + +//go:embed arm64/wintun.dll +var dllContent []byte diff --git a/internal/wintun/memmod/memmod_windows.go b/internal/wintun/memmod/memmod_windows.go new file mode 100644 index 0000000..05d71f1 --- /dev/null +++ b/internal/wintun/memmod/memmod_windows.go @@ -0,0 +1,698 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package memmod + +import ( + "errors" + "fmt" + "strings" + "sync" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type addressList struct { + next *addressList + address uintptr +} + +func (head *addressList) free() { + for node := head; node != nil; node = node.next { + windows.VirtualFree(node.address, 0, windows.MEM_RELEASE) + } +} + +type Module struct { + headers *IMAGE_NT_HEADERS + codeBase uintptr + modules []windows.Handle + initialized bool + isDLL bool + isRelocated bool + nameExports map[string]uint16 + entry uintptr + blockedMemory *addressList +} + +func (module *Module) BaseAddr() uintptr { + return module.codeBase +} + +func (module *Module) headerDirectory(idx int) *IMAGE_DATA_DIRECTORY { + return &module.headers.OptionalHeader.DataDirectory[idx] +} + +func (module *Module) copySections(address, size uintptr, oldHeaders *IMAGE_NT_HEADERS) error { + sections := module.headers.Sections() + for i := range sections { + if sections[i].SizeOfRawData == 0 { + // Section doesn't contain data in the dll itself, but may define uninitialized data. + sectionSize := oldHeaders.OptionalHeader.SectionAlignment + if sectionSize == 0 { + continue + } + dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress), + uintptr(sectionSize), + windows.MEM_COMMIT, + windows.PAGE_READWRITE) + if err != nil { + return fmt.Errorf("Error allocating section: %w", err) + } + + // Always use position from file to support alignments smaller than page size (allocation above will align to page size). + dest = module.codeBase + uintptr(sections[i].VirtualAddress) + // NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used. + sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff)) + dst := unsafe.Slice((*byte)(a2p(dest)), sectionSize) + for j := range dst { + dst[j] = 0 + } + continue + } + + if size < uintptr(sections[i].PointerToRawData+sections[i].SizeOfRawData) { + return errors.New("Incomplete section") + } + + // Commit memory block and copy data from dll. + dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress), + uintptr(sections[i].SizeOfRawData), + windows.MEM_COMMIT, + windows.PAGE_READWRITE) + if err != nil { + return fmt.Errorf("Error allocating memory block: %w", err) + } + + // Always use position from file to support alignments smaller than page size (allocation above will align to page size). + memcpy( + module.codeBase+uintptr(sections[i].VirtualAddress), + address+uintptr(sections[i].PointerToRawData), + uintptr(sections[i].SizeOfRawData)) + // NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used. + sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff)) + } + + return nil +} + +func (module *Module) realSectionSize(section *IMAGE_SECTION_HEADER) uintptr { + size := section.SizeOfRawData + if size != 0 { + return uintptr(size) + } + if (section.Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) != 0 { + return uintptr(module.headers.OptionalHeader.SizeOfInitializedData) + } + if (section.Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) != 0 { + return uintptr(module.headers.OptionalHeader.SizeOfUninitializedData) + } + return 0 +} + +type sectionFinalizeData struct { + address uintptr + alignedAddress uintptr + size uintptr + characteristics uint32 + last bool +} + +func (module *Module) finalizeSection(sectionData *sectionFinalizeData) error { + if sectionData.size == 0 { + return nil + } + + if (sectionData.characteristics & IMAGE_SCN_MEM_DISCARDABLE) != 0 { + // Section is not needed any more and can safely be freed. + if sectionData.address == sectionData.alignedAddress && + (sectionData.last || + (sectionData.size%uintptr(module.headers.OptionalHeader.SectionAlignment)) == 0) { + // Only allowed to decommit whole pages. + windows.VirtualFree(sectionData.address, sectionData.size, windows.MEM_DECOMMIT) + } + return nil + } + + // determine protection flags based on characteristics + ProtectionFlags := [8]uint32{ + windows.PAGE_NOACCESS, // not writeable, not readable, not executable + windows.PAGE_EXECUTE, // not writeable, not readable, executable + windows.PAGE_READONLY, // not writeable, readable, not executable + windows.PAGE_EXECUTE_READ, // not writeable, readable, executable + windows.PAGE_WRITECOPY, // writeable, not readable, not executable + windows.PAGE_EXECUTE_WRITECOPY, // writeable, not readable, executable + windows.PAGE_READWRITE, // writeable, readable, not executable + windows.PAGE_EXECUTE_READWRITE, // writeable, readable, executable + } + protect := ProtectionFlags[sectionData.characteristics>>29] + if (sectionData.characteristics & IMAGE_SCN_MEM_NOT_CACHED) != 0 { + protect |= windows.PAGE_NOCACHE + } + + // Change memory access flags. + var oldProtect uint32 + err := windows.VirtualProtect(sectionData.address, sectionData.size, protect, &oldProtect) + if err != nil { + return fmt.Errorf("Error protecting memory page: %w", err) + } + + return nil +} + +func (module *Module) registerExceptionHandlers() { + directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXCEPTION) + if directory.Size == 0 || directory.VirtualAddress == 0 { + return + } + runtimeFuncs := (*windows.RUNTIME_FUNCTION)(unsafe.Pointer(module.codeBase + uintptr(directory.VirtualAddress))) + windows.RtlAddFunctionTable(runtimeFuncs, uint32(uintptr(directory.Size)/unsafe.Sizeof(*runtimeFuncs)), module.codeBase) +} + +func (module *Module) finalizeSections() error { + sections := module.headers.Sections() + imageOffset := module.headers.OptionalHeader.imageOffset() + sectionData := sectionFinalizeData{} + sectionData.address = uintptr(sections[0].PhysicalAddress()) | imageOffset + sectionData.alignedAddress = alignDown(sectionData.address, uintptr(module.headers.OptionalHeader.SectionAlignment)) + sectionData.size = module.realSectionSize(§ions[0]) + sections[0].SetVirtualSize(uint32(sectionData.size)) + sectionData.characteristics = sections[0].Characteristics + + // Loop through all sections and change access flags. + for i := uint16(1); i < module.headers.FileHeader.NumberOfSections; i++ { + sectionAddress := uintptr(sections[i].PhysicalAddress()) | imageOffset + alignedAddress := alignDown(sectionAddress, uintptr(module.headers.OptionalHeader.SectionAlignment)) + sectionSize := module.realSectionSize(§ions[i]) + sections[i].SetVirtualSize(uint32(sectionSize)) + // Combine access flags of all sections that share a page. + // TODO: We currently share flags of a trailing large section with the page of a first small section. This should be optimized. + if sectionData.alignedAddress == alignedAddress || sectionData.address+sectionData.size > alignedAddress { + // Section shares page with previous. + if (sections[i].Characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 || (sectionData.characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 { + sectionData.characteristics = (sectionData.characteristics | sections[i].Characteristics) &^ IMAGE_SCN_MEM_DISCARDABLE + } else { + sectionData.characteristics |= sections[i].Characteristics + } + sectionData.size = sectionAddress + sectionSize - sectionData.address + continue + } + + err := module.finalizeSection(§ionData) + if err != nil { + return fmt.Errorf("Error finalizing section: %w", err) + } + sectionData.address = sectionAddress + sectionData.alignedAddress = alignedAddress + sectionData.size = sectionSize + sectionData.characteristics = sections[i].Characteristics + } + sectionData.last = true + err := module.finalizeSection(§ionData) + if err != nil { + return fmt.Errorf("Error finalizing section: %w", err) + } + return nil +} + +func (module *Module) executeTLS() { + directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_TLS) + if directory.VirtualAddress == 0 { + return + } + + tls := (*IMAGE_TLS_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress))) + callback := tls.AddressOfCallbacks + if callback != 0 { + for { + f := *(*uintptr)(a2p(callback)) + if f == 0 { + break + } + syscall.SyscallN(f, module.codeBase, DLL_PROCESS_ATTACH, 0) + callback += unsafe.Sizeof(f) + } + } +} + +func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err error) { + directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_BASERELOC) + if directory.Size == 0 { + return delta == 0, nil + } + + relocationHdr := (*IMAGE_BASE_RELOCATION)(a2p(module.codeBase + uintptr(directory.VirtualAddress))) + for relocationHdr.VirtualAddress > 0 { + dest := module.codeBase + uintptr(relocationHdr.VirtualAddress) + + relInfos := unsafe.Slice( + (*uint16)(a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr))), + (uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(uint16(0))) + for _, relInfo := range relInfos { + // The upper 4 bits define the type of relocation. + relType := relInfo >> 12 + // The lower 12 bits define the offset. + relOffset := uintptr(relInfo & 0xfff) + + switch relType { + case IMAGE_REL_BASED_ABSOLUTE: + // Skip relocation. + + case IMAGE_REL_BASED_LOW: + *(*uint16)(a2p(dest + relOffset)) += uint16(delta & 0xffff) + break + + case IMAGE_REL_BASED_HIGH: + *(*uint16)(a2p(dest + relOffset)) += uint16(uint32(delta) >> 16) + break + + case IMAGE_REL_BASED_HIGHLOW: + *(*uint32)(a2p(dest + relOffset)) += uint32(delta) + + case IMAGE_REL_BASED_DIR64: + *(*uint64)(a2p(dest + relOffset)) += uint64(delta) + + case IMAGE_REL_BASED_THUMB_MOV32: + inst := *(*uint32)(a2p(dest + relOffset)) + imm16 := ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) + + ((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff) + if (inst & 0x8000fbf0) != 0x0000f240 { + return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVW", inst) + } + imm16 += uint32(delta) & 0xffff + hiDelta := (uint32(delta&0xffff0000) >> 16) + ((imm16 & 0xffff0000) >> 16) + *(*uint32)(a2p(dest + relOffset)) = (inst & 0x8f00fbf0) + ((imm16 >> 1) & 0x0400) + + ((imm16 >> 12) & 0x000f) + + ((imm16 << 20) & 0x70000000) + + ((imm16 << 16) & 0xff0000) + if hiDelta != 0 { + inst = *(*uint32)(a2p(dest + relOffset + 4)) + imm16 = ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) + + ((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff) + if (inst & 0x8000fbf0) != 0x0000f2c0 { + return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVT", inst) + } + imm16 += hiDelta + if imm16 > 0xffff { + return false, fmt.Errorf("Resulting immediate value won't fit: %08x", imm16) + } + *(*uint32)(a2p(dest + relOffset + 4)) = (inst & 0x8f00fbf0) + + ((imm16 >> 1) & 0x0400) + + ((imm16 >> 12) & 0x000f) + + ((imm16 << 20) & 0x70000000) + + ((imm16 << 16) & 0xff0000) + } + + default: + return false, fmt.Errorf("Unsupported relocation: %v", relType) + } + } + + // Advance to next relocation block. + relocationHdr = (*IMAGE_BASE_RELOCATION)(a2p(uintptr(unsafe.Pointer(relocationHdr)) + uintptr(relocationHdr.SizeOfBlock))) + } + return true, nil +} + +func (module *Module) buildImportTable() error { + directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_IMPORT) + if directory.Size == 0 { + return nil + } + + module.modules = make([]windows.Handle, 0, 16) + importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress))) + for importDesc.Name != 0 { + handle, err := windows.LoadLibraryEx(windows.BytePtrToString((*byte)(a2p(module.codeBase+uintptr(importDesc.Name)))), 0, windows.LOAD_LIBRARY_SEARCH_SYSTEM32) + if err != nil { + return fmt.Errorf("Error loading module: %w", err) + } + var thunkRef, funcRef *uintptr + if importDesc.OriginalFirstThunk() != 0 { + thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.OriginalFirstThunk()))) + funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk))) + } else { + // No hint table. + thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk))) + funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk))) + } + for *thunkRef != 0 { + if IMAGE_SNAP_BY_ORDINAL(*thunkRef) { + *funcRef, err = windows.GetProcAddressByOrdinal(handle, IMAGE_ORDINAL(*thunkRef)) + } else { + thunkData := (*IMAGE_IMPORT_BY_NAME)(a2p(module.codeBase + *thunkRef)) + *funcRef, err = windows.GetProcAddress(handle, windows.BytePtrToString(&thunkData.Name[0])) + } + if err != nil { + windows.FreeLibrary(handle) + return fmt.Errorf("Error getting function address: %w", err) + } + thunkRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(thunkRef)) + unsafe.Sizeof(*thunkRef))) + funcRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(funcRef)) + unsafe.Sizeof(*funcRef))) + } + module.modules = append(module.modules, handle) + importDesc = (*IMAGE_IMPORT_DESCRIPTOR)(a2p(uintptr(unsafe.Pointer(importDesc)) + unsafe.Sizeof(*importDesc))) + } + return nil +} + +func (module *Module) buildNameExports() error { + directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT) + if directory.Size == 0 { + return errors.New("No export table found") + } + exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress))) + if exports.NumberOfNames == 0 || exports.NumberOfFunctions == 0 { + return errors.New("No functions exported") + } + if exports.NumberOfNames == 0 { + return errors.New("No functions exported by name") + } + nameRefs := unsafe.Slice((*uint32)(a2p(module.codeBase+uintptr(exports.AddressOfNames))), exports.NumberOfNames) + ordinals := unsafe.Slice((*uint16)(a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals))), exports.NumberOfNames) + module.nameExports = make(map[string]uint16) + for i := range nameRefs { + nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i])))) + module.nameExports[nameArray] = ordinals[i] + } + return nil +} + +type addressRange struct { + start uintptr + end uintptr +} + +var ( + loadedAddressRanges []addressRange + loadedAddressRangesMu sync.RWMutex + haveHookedRtlPcToFileHeader sync.Once + hookRtlPcToFileHeaderResult error +) + +func hookRtlPcToFileHeader() error { + var kernelBase windows.Handle + err := windows.GetModuleHandleEx(windows.GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, windows.StringToUTF16Ptr("kernelbase.dll"), &kernelBase) + if err != nil { + return err + } + imageBase := unsafe.Pointer(kernelBase) + dosHeader := (*IMAGE_DOS_HEADER)(imageBase) + ntHeaders := (*IMAGE_NT_HEADERS)(unsafe.Add(imageBase, dosHeader.E_lfanew)) + importsDirectory := ntHeaders.OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT] + importDescriptor := (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(imageBase, importsDirectory.VirtualAddress)) + for ; importDescriptor.Name != 0; importDescriptor = (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(unsafe.Pointer(importDescriptor), unsafe.Sizeof(*importDescriptor))) { + libraryName := windows.BytePtrToString((*byte)(unsafe.Add(imageBase, importDescriptor.Name))) + if strings.EqualFold(libraryName, "ntdll.dll") { + break + } + } + if importDescriptor.Name == 0 { + return errors.New("ntdll.dll not found") + } + originalThunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.OriginalFirstThunk())) + thunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.FirstThunk)) + for ; *originalThunk != 0; originalThunk = (*uintptr)(unsafe.Add(unsafe.Pointer(originalThunk), unsafe.Sizeof(*originalThunk))) { + if *originalThunk&IMAGE_ORDINAL_FLAG == 0 { + function := (*IMAGE_IMPORT_BY_NAME)(unsafe.Add(imageBase, *originalThunk)) + name := windows.BytePtrToString(&function.Name[0]) + if name == "RtlPcToFileHeader" { + break + } + } + thunk = (*uintptr)(unsafe.Add(unsafe.Pointer(thunk), unsafe.Sizeof(*thunk))) + } + if *originalThunk == 0 { + return errors.New("RtlPcToFileHeader not found") + } + var oldProtect uint32 + err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), windows.PAGE_READWRITE, &oldProtect) + if err != nil { + return err + } + originalRtlPcToFileHeader := *thunk + *thunk = windows.NewCallback(func(pcValue uintptr, baseOfImage *uintptr) uintptr { + loadedAddressRangesMu.RLock() + for i := range loadedAddressRanges { + if pcValue >= loadedAddressRanges[i].start && pcValue < loadedAddressRanges[i].end { + pcValue = *thunk + break + } + } + loadedAddressRangesMu.RUnlock() + ret, _, _ := syscall.SyscallN(originalRtlPcToFileHeader, pcValue, uintptr(unsafe.Pointer(baseOfImage))) + return ret + }) + err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), oldProtect, &oldProtect) + if err != nil { + return err + } + return nil +} + +// LoadLibrary loads module image to memory. +func LoadLibrary(data []byte) (module *Module, err error) { + addr := uintptr(unsafe.Pointer(&data[0])) + size := uintptr(len(data)) + if size < unsafe.Sizeof(IMAGE_DOS_HEADER{}) { + return nil, errors.New("Incomplete IMAGE_DOS_HEADER") + } + dosHeader := (*IMAGE_DOS_HEADER)(a2p(addr)) + if dosHeader.E_magic != IMAGE_DOS_SIGNATURE { + return nil, fmt.Errorf("Not an MS-DOS binary (provided: %x, expected: %x)", dosHeader.E_magic, IMAGE_DOS_SIGNATURE) + } + if (size < uintptr(dosHeader.E_lfanew)+unsafe.Sizeof(IMAGE_NT_HEADERS{})) { + return nil, errors.New("Incomplete IMAGE_NT_HEADERS") + } + oldHeader := (*IMAGE_NT_HEADERS)(a2p(addr + uintptr(dosHeader.E_lfanew))) + if oldHeader.Signature != IMAGE_NT_SIGNATURE { + return nil, fmt.Errorf("Not an NT binary (provided: %x, expected: %x)", oldHeader.Signature, IMAGE_NT_SIGNATURE) + } + if oldHeader.FileHeader.Machine != imageFileProcess { + return nil, fmt.Errorf("Foreign platform (provided: %x, expected: %x)", oldHeader.FileHeader.Machine, imageFileProcess) + } + if (oldHeader.OptionalHeader.SectionAlignment & 1) != 0 { + return nil, errors.New("Unaligned section") + } + lastSectionEnd := uintptr(0) + sections := oldHeader.Sections() + optionalSectionSize := oldHeader.OptionalHeader.SectionAlignment + for i := range sections { + var endOfSection uintptr + if sections[i].SizeOfRawData == 0 { + // Section without data in the DLL + endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(optionalSectionSize) + } else { + endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(sections[i].SizeOfRawData) + } + if endOfSection > lastSectionEnd { + lastSectionEnd = endOfSection + } + } + alignedImageSize := alignUp(uintptr(oldHeader.OptionalHeader.SizeOfImage), uintptr(oldHeader.OptionalHeader.SectionAlignment)) + if alignedImageSize != alignUp(lastSectionEnd, uintptr(oldHeader.OptionalHeader.SectionAlignment)) { + return nil, errors.New("Section is not page-aligned") + } + + module = &Module{isDLL: (oldHeader.FileHeader.Characteristics & IMAGE_FILE_DLL) != 0} + defer func() { + if err != nil { + module.Free() + module = nil + } + }() + + // Reserve memory for image of library. + // TODO: Is it correct to commit the complete memory region at once? Calling DllEntry raises an exception if we don't. + module.codeBase, err = windows.VirtualAlloc(oldHeader.OptionalHeader.ImageBase, + alignedImageSize, + windows.MEM_RESERVE|windows.MEM_COMMIT, + windows.PAGE_READWRITE) + if err != nil { + // Try to allocate memory at arbitrary position. + module.codeBase, err = windows.VirtualAlloc(0, + alignedImageSize, + windows.MEM_RESERVE|windows.MEM_COMMIT, + windows.PAGE_READWRITE) + if err != nil { + err = fmt.Errorf("Error allocating code: %w", err) + return + } + } + err = module.check4GBBoundaries(alignedImageSize) + if err != nil { + err = fmt.Errorf("Error reallocating code: %w", err) + return + } + + if size < uintptr(oldHeader.OptionalHeader.SizeOfHeaders) { + err = errors.New("Incomplete headers") + return + } + // Commit memory for headers. + headers, err := windows.VirtualAlloc(module.codeBase, + uintptr(oldHeader.OptionalHeader.SizeOfHeaders), + windows.MEM_COMMIT, + windows.PAGE_READWRITE) + if err != nil { + err = fmt.Errorf("Error allocating headers: %w", err) + return + } + // Copy PE header to code. + memcpy(headers, addr, uintptr(oldHeader.OptionalHeader.SizeOfHeaders)) + module.headers = (*IMAGE_NT_HEADERS)(a2p(headers + uintptr(dosHeader.E_lfanew))) + + // Update position. + module.headers.OptionalHeader.ImageBase = module.codeBase + + // Copy sections from DLL file block to new memory location. + err = module.copySections(addr, size, oldHeader) + if err != nil { + err = fmt.Errorf("Error copying sections: %w", err) + return + } + + // Adjust base address of imported data. + locationDelta := module.headers.OptionalHeader.ImageBase - oldHeader.OptionalHeader.ImageBase + if locationDelta != 0 { + module.isRelocated, err = module.performBaseRelocation(locationDelta) + if err != nil { + err = fmt.Errorf("Error relocating module: %w", err) + return + } + } else { + module.isRelocated = true + } + + // Load required dlls and adjust function table of imports. + err = module.buildImportTable() + if err != nil { + err = fmt.Errorf("Error building import table: %w", err) + return + } + + // Mark memory pages depending on section headers and release sections that are marked as "discardable". + err = module.finalizeSections() + if err != nil { + err = fmt.Errorf("Error finalizing sections: %w", err) + return + } + + // Register exception tables, if they exist. + module.registerExceptionHandlers() + + // Register function PCs. + loadedAddressRangesMu.Lock() + loadedAddressRanges = append(loadedAddressRanges, addressRange{module.codeBase, module.codeBase + alignedImageSize}) + loadedAddressRangesMu.Unlock() + haveHookedRtlPcToFileHeader.Do(func() { + hookRtlPcToFileHeaderResult = hookRtlPcToFileHeader() + }) + err = hookRtlPcToFileHeaderResult + if err != nil { + return + } + + // TLS callbacks are executed BEFORE the main loading. + module.executeTLS() + + // Get entry point of loaded module. + if module.headers.OptionalHeader.AddressOfEntryPoint != 0 { + module.entry = module.codeBase + uintptr(module.headers.OptionalHeader.AddressOfEntryPoint) + if module.isDLL { + // Notify library about attaching to process. + r0, _, _ := syscall.SyscallN(module.entry, module.codeBase, DLL_PROCESS_ATTACH, 0) + successful := r0 != 0 + if !successful { + err = windows.ERROR_DLL_INIT_FAILED + return + } + module.initialized = true + } + } + + module.buildNameExports() + return +} + +// Free releases module resources and unloads it. +func (module *Module) Free() { + if module.initialized { + // Notify library about detaching from process. + syscall.SyscallN(module.entry, module.codeBase, DLL_PROCESS_DETACH, 0) + module.initialized = false + } + if module.modules != nil { + // Free previously opened libraries. + for _, handle := range module.modules { + windows.FreeLibrary(handle) + } + module.modules = nil + } + if module.codeBase != 0 { + windows.VirtualFree(module.codeBase, 0, windows.MEM_RELEASE) + module.codeBase = 0 + } + if module.blockedMemory != nil { + module.blockedMemory.free() + module.blockedMemory = nil + } +} + +// ProcAddressByName returns function address by exported name. +func (module *Module) ProcAddressByName(name string) (uintptr, error) { + directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT) + if directory.Size == 0 { + return 0, errors.New("No export table found") + } + exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress))) + if module.nameExports == nil { + return 0, errors.New("No functions exported by name") + } + if idx, ok := module.nameExports[name]; ok { + if uint32(idx) > exports.NumberOfFunctions { + return 0, errors.New("Ordinal number too high") + } + // AddressOfFunctions contains the RVAs to the "real" functions. + return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil + } + return 0, errors.New("Function not found by name") +} + +// ProcAddressByOrdinal returns function address by exported ordinal. +func (module *Module) ProcAddressByOrdinal(ordinal uint16) (uintptr, error) { + directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT) + if directory.Size == 0 { + return 0, errors.New("No export table found") + } + exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress))) + if uint32(ordinal) < exports.Base { + return 0, errors.New("Ordinal number too low") + } + idx := ordinal - uint16(exports.Base) + if uint32(idx) > exports.NumberOfFunctions { + return 0, errors.New("Ordinal number too high") + } + // AddressOfFunctions contains the RVAs to the "real" functions. + return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil +} + +func alignDown(value, alignment uintptr) uintptr { + return value & ^(alignment - 1) +} + +func alignUp(value, alignment uintptr) uintptr { + return (value + alignment - 1) & ^(alignment - 1) +} + +func a2p(addr uintptr) unsafe.Pointer { + return unsafe.Pointer(addr) +} + +func memcpy(dst, src, size uintptr) { + copy(unsafe.Slice((*byte)(a2p(dst)), size), unsafe.Slice((*byte)(a2p(src)), size)) +} diff --git a/internal/wintun/memmod/memmod_windows_32.go b/internal/wintun/memmod/memmod_windows_32.go new file mode 100644 index 0000000..50e6feb --- /dev/null +++ b/internal/wintun/memmod/memmod_windows_32.go @@ -0,0 +1,16 @@ +//go:build (windows && 386) || (windows && arm) + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package memmod + +func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr { + return 0 +} + +func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) { + return +} diff --git a/internal/wintun/memmod/memmod_windows_386.go b/internal/wintun/memmod/memmod_windows_386.go new file mode 100644 index 0000000..475c5c5 --- /dev/null +++ b/internal/wintun/memmod/memmod_windows_386.go @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package memmod + +const imageFileProcess = IMAGE_FILE_MACHINE_I386 diff --git a/internal/wintun/memmod/memmod_windows_64.go b/internal/wintun/memmod/memmod_windows_64.go new file mode 100644 index 0000000..a53851c --- /dev/null +++ b/internal/wintun/memmod/memmod_windows_64.go @@ -0,0 +1,36 @@ +//go:build (windows && amd64) || (windows && arm64) + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package memmod + +import ( + "fmt" + + "golang.org/x/sys/windows" +) + +func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr { + return uintptr(opthdr.ImageBase & 0xffffffff00000000) +} + +func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) { + for (module.codeBase >> 32) < ((module.codeBase + alignedImageSize) >> 32) { + node := &addressList{ + next: module.blockedMemory, + address: module.codeBase, + } + module.blockedMemory = node + module.codeBase, err = windows.VirtualAlloc(0, + alignedImageSize, + windows.MEM_RESERVE|windows.MEM_COMMIT, + windows.PAGE_READWRITE) + if err != nil { + return fmt.Errorf("Error allocating memory block: %w", err) + } + } + return +} diff --git a/internal/wintun/memmod/memmod_windows_amd64.go b/internal/wintun/memmod/memmod_windows_amd64.go new file mode 100644 index 0000000..a021a63 --- /dev/null +++ b/internal/wintun/memmod/memmod_windows_amd64.go @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package memmod + +const imageFileProcess = IMAGE_FILE_MACHINE_AMD64 diff --git a/internal/wintun/memmod/memmod_windows_arm.go b/internal/wintun/memmod/memmod_windows_arm.go new file mode 100644 index 0000000..4637a01 --- /dev/null +++ b/internal/wintun/memmod/memmod_windows_arm.go @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package memmod + +const imageFileProcess = IMAGE_FILE_MACHINE_ARMNT diff --git a/internal/wintun/memmod/memmod_windows_arm64.go b/internal/wintun/memmod/memmod_windows_arm64.go new file mode 100644 index 0000000..b8f1259 --- /dev/null +++ b/internal/wintun/memmod/memmod_windows_arm64.go @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package memmod + +const imageFileProcess = IMAGE_FILE_MACHINE_ARM64 diff --git a/internal/wintun/memmod/syscall_windows.go b/internal/wintun/memmod/syscall_windows.go new file mode 100644 index 0000000..b79be69 --- /dev/null +++ b/internal/wintun/memmod/syscall_windows.go @@ -0,0 +1,392 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package memmod + +import "unsafe" + +const ( + IMAGE_DOS_SIGNATURE = 0x5A4D // MZ + IMAGE_OS2_SIGNATURE = 0x454E // NE + IMAGE_OS2_SIGNATURE_LE = 0x454C // LE + IMAGE_VXD_SIGNATURE = 0x454C // LE + IMAGE_NT_SIGNATURE = 0x00004550 // PE00 +) + +// DOS .EXE header +type IMAGE_DOS_HEADER struct { + E_magic uint16 // Magic number + E_cblp uint16 // Bytes on last page of file + E_cp uint16 // Pages in file + E_crlc uint16 // Relocations + E_cparhdr uint16 // Size of header in paragraphs + E_minalloc uint16 // Minimum extra paragraphs needed + E_maxalloc uint16 // Maximum extra paragraphs needed + E_ss uint16 // Initial (relative) SS value + E_sp uint16 // Initial SP value + E_csum uint16 // Checksum + E_ip uint16 // Initial IP value + E_cs uint16 // Initial (relative) CS value + E_lfarlc uint16 // File address of relocation table + E_ovno uint16 // Overlay number + E_res [4]uint16 // Reserved words + E_oemid uint16 // OEM identifier (for e_oeminfo) + E_oeminfo uint16 // OEM information; e_oemid specific + E_res2 [10]uint16 // Reserved words + E_lfanew int32 // File address of new exe header +} + +// File header format +type IMAGE_FILE_HEADER struct { + Machine uint16 + NumberOfSections uint16 + TimeDateStamp uint32 + PointerToSymbolTable uint32 + NumberOfSymbols uint32 + SizeOfOptionalHeader uint16 + Characteristics uint16 +} + +const ( + IMAGE_SIZEOF_FILE_HEADER = 20 + + IMAGE_FILE_RELOCS_STRIPPED = 0x0001 // Relocation info stripped from file. + IMAGE_FILE_EXECUTABLE_IMAGE = 0x0002 // File is executable (i.e. no unresolved external references). + IMAGE_FILE_LINE_NUMS_STRIPPED = 0x0004 // Line nunbers stripped from file. + IMAGE_FILE_LOCAL_SYMS_STRIPPED = 0x0008 // Local symbols stripped from file. + IMAGE_FILE_AGGRESIVE_WS_TRIM = 0x0010 // Aggressively trim working set + IMAGE_FILE_LARGE_ADDRESS_AWARE = 0x0020 // App can handle >2gb addresses + IMAGE_FILE_BYTES_REVERSED_LO = 0x0080 // Bytes of machine word are reversed. + IMAGE_FILE_32BIT_MACHINE = 0x0100 // 32 bit word machine. + IMAGE_FILE_DEBUG_STRIPPED = 0x0200 // Debugging info stripped from file in .DBG file + IMAGE_FILE_REMOVABLE_RUN_FROM_SWAP = 0x0400 // If Image is on removable media, copy and run from the swap file. + IMAGE_FILE_NET_RUN_FROM_SWAP = 0x0800 // If Image is on Net, copy and run from the swap file. + IMAGE_FILE_SYSTEM = 0x1000 // System File. + IMAGE_FILE_DLL = 0x2000 // File is a DLL. + IMAGE_FILE_UP_SYSTEM_ONLY = 0x4000 // File should only be run on a UP machine + IMAGE_FILE_BYTES_REVERSED_HI = 0x8000 // Bytes of machine word are reversed. + + IMAGE_FILE_MACHINE_UNKNOWN = 0 + IMAGE_FILE_MACHINE_TARGET_HOST = 0x0001 // Useful for indicating we want to interact with the host and not a WoW guest. + IMAGE_FILE_MACHINE_I386 = 0x014c // Intel 386. + IMAGE_FILE_MACHINE_R3000 = 0x0162 // MIPS little-endian, 0x160 big-endian + IMAGE_FILE_MACHINE_R4000 = 0x0166 // MIPS little-endian + IMAGE_FILE_MACHINE_R10000 = 0x0168 // MIPS little-endian + IMAGE_FILE_MACHINE_WCEMIPSV2 = 0x0169 // MIPS little-endian WCE v2 + IMAGE_FILE_MACHINE_ALPHA = 0x0184 // Alpha_AXP + IMAGE_FILE_MACHINE_SH3 = 0x01a2 // SH3 little-endian + IMAGE_FILE_MACHINE_SH3DSP = 0x01a3 + IMAGE_FILE_MACHINE_SH3E = 0x01a4 // SH3E little-endian + IMAGE_FILE_MACHINE_SH4 = 0x01a6 // SH4 little-endian + IMAGE_FILE_MACHINE_SH5 = 0x01a8 // SH5 + IMAGE_FILE_MACHINE_ARM = 0x01c0 // ARM Little-Endian + IMAGE_FILE_MACHINE_THUMB = 0x01c2 // ARM Thumb/Thumb-2 Little-Endian + IMAGE_FILE_MACHINE_ARMNT = 0x01c4 // ARM Thumb-2 Little-Endian + IMAGE_FILE_MACHINE_AM33 = 0x01d3 + IMAGE_FILE_MACHINE_POWERPC = 0x01F0 // IBM PowerPC Little-Endian + IMAGE_FILE_MACHINE_POWERPCFP = 0x01f1 + IMAGE_FILE_MACHINE_IA64 = 0x0200 // Intel 64 + IMAGE_FILE_MACHINE_MIPS16 = 0x0266 // MIPS + IMAGE_FILE_MACHINE_ALPHA64 = 0x0284 // ALPHA64 + IMAGE_FILE_MACHINE_MIPSFPU = 0x0366 // MIPS + IMAGE_FILE_MACHINE_MIPSFPU16 = 0x0466 // MIPS + IMAGE_FILE_MACHINE_AXP64 = IMAGE_FILE_MACHINE_ALPHA64 + IMAGE_FILE_MACHINE_TRICORE = 0x0520 // Infineon + IMAGE_FILE_MACHINE_CEF = 0x0CEF + IMAGE_FILE_MACHINE_EBC = 0x0EBC // EFI Byte Code + IMAGE_FILE_MACHINE_AMD64 = 0x8664 // AMD64 (K8) + IMAGE_FILE_MACHINE_M32R = 0x9041 // M32R little-endian + IMAGE_FILE_MACHINE_ARM64 = 0xAA64 // ARM64 Little-Endian + IMAGE_FILE_MACHINE_CEE = 0xC0EE +) + +// Directory format +type IMAGE_DATA_DIRECTORY struct { + VirtualAddress uint32 + Size uint32 +} + +const IMAGE_NUMBEROF_DIRECTORY_ENTRIES = 16 + +type IMAGE_NT_HEADERS struct { + Signature uint32 + FileHeader IMAGE_FILE_HEADER + OptionalHeader IMAGE_OPTIONAL_HEADER +} + +func (ntheader *IMAGE_NT_HEADERS) Sections() []IMAGE_SECTION_HEADER { + return (*[0xffff]IMAGE_SECTION_HEADER)(unsafe.Pointer( + (uintptr)(unsafe.Pointer(ntheader)) + + unsafe.Offsetof(ntheader.OptionalHeader) + + uintptr(ntheader.FileHeader.SizeOfOptionalHeader)))[:ntheader.FileHeader.NumberOfSections] +} + +const ( + IMAGE_DIRECTORY_ENTRY_EXPORT = 0 // Export Directory + IMAGE_DIRECTORY_ENTRY_IMPORT = 1 // Import Directory + IMAGE_DIRECTORY_ENTRY_RESOURCE = 2 // Resource Directory + IMAGE_DIRECTORY_ENTRY_EXCEPTION = 3 // Exception Directory + IMAGE_DIRECTORY_ENTRY_SECURITY = 4 // Security Directory + IMAGE_DIRECTORY_ENTRY_BASERELOC = 5 // Base Relocation Table + IMAGE_DIRECTORY_ENTRY_DEBUG = 6 // Debug Directory + IMAGE_DIRECTORY_ENTRY_COPYRIGHT = 7 // (X86 usage) + IMAGE_DIRECTORY_ENTRY_ARCHITECTURE = 7 // Architecture Specific Data + IMAGE_DIRECTORY_ENTRY_GLOBALPTR = 8 // RVA of GP + IMAGE_DIRECTORY_ENTRY_TLS = 9 // TLS Directory + IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG = 10 // Load Configuration Directory + IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT = 11 // Bound Import Directory in headers + IMAGE_DIRECTORY_ENTRY_IAT = 12 // Import Address Table + IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT = 13 // Delay Load Import Descriptors + IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR = 14 // COM Runtime descriptor +) + +const IMAGE_SIZEOF_SHORT_NAME = 8 + +// Section header format +type IMAGE_SECTION_HEADER struct { + Name [IMAGE_SIZEOF_SHORT_NAME]byte + physicalAddressOrVirtualSize uint32 + VirtualAddress uint32 + SizeOfRawData uint32 + PointerToRawData uint32 + PointerToRelocations uint32 + PointerToLinenumbers uint32 + NumberOfRelocations uint16 + NumberOfLinenumbers uint16 + Characteristics uint32 +} + +func (ishdr *IMAGE_SECTION_HEADER) PhysicalAddress() uint32 { + return ishdr.physicalAddressOrVirtualSize +} + +func (ishdr *IMAGE_SECTION_HEADER) SetPhysicalAddress(addr uint32) { + ishdr.physicalAddressOrVirtualSize = addr +} + +func (ishdr *IMAGE_SECTION_HEADER) VirtualSize() uint32 { + return ishdr.physicalAddressOrVirtualSize +} + +func (ishdr *IMAGE_SECTION_HEADER) SetVirtualSize(addr uint32) { + ishdr.physicalAddressOrVirtualSize = addr +} + +const ( + // Dll characteristics. + IMAGE_DLL_CHARACTERISTICS_HIGH_ENTROPY_VA = 0x0020 + IMAGE_DLL_CHARACTERISTICS_DYNAMIC_BASE = 0x0040 + IMAGE_DLL_CHARACTERISTICS_FORCE_INTEGRITY = 0x0080 + IMAGE_DLL_CHARACTERISTICS_NX_COMPAT = 0x0100 + IMAGE_DLL_CHARACTERISTICS_NO_ISOLATION = 0x0200 + IMAGE_DLL_CHARACTERISTICS_NO_SEH = 0x0400 + IMAGE_DLL_CHARACTERISTICS_NO_BIND = 0x0800 + IMAGE_DLL_CHARACTERISTICS_APPCONTAINER = 0x1000 + IMAGE_DLL_CHARACTERISTICS_WDM_DRIVER = 0x2000 + IMAGE_DLL_CHARACTERISTICS_GUARD_CF = 0x4000 + IMAGE_DLL_CHARACTERISTICS_TERMINAL_SERVER_AWARE = 0x8000 +) + +const ( + // Section characteristics. + IMAGE_SCN_TYPE_REG = 0x00000000 // Reserved. + IMAGE_SCN_TYPE_DSECT = 0x00000001 // Reserved. + IMAGE_SCN_TYPE_NOLOAD = 0x00000002 // Reserved. + IMAGE_SCN_TYPE_GROUP = 0x00000004 // Reserved. + IMAGE_SCN_TYPE_NO_PAD = 0x00000008 // Reserved. + IMAGE_SCN_TYPE_COPY = 0x00000010 // Reserved. + + IMAGE_SCN_CNT_CODE = 0x00000020 // Section contains code. + IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040 // Section contains initialized data. + IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080 // Section contains uninitialized data. + + IMAGE_SCN_LNK_OTHER = 0x00000100 // Reserved. + IMAGE_SCN_LNK_INFO = 0x00000200 // Section contains comments or some other type of information. + IMAGE_SCN_TYPE_OVER = 0x00000400 // Reserved. + IMAGE_SCN_LNK_REMOVE = 0x00000800 // Section contents will not become part of image. + IMAGE_SCN_LNK_COMDAT = 0x00001000 // Section contents comdat. + IMAGE_SCN_MEM_PROTECTED = 0x00004000 // Obsolete. + IMAGE_SCN_NO_DEFER_SPEC_EXC = 0x00004000 // Reset speculative exceptions handling bits in the TLB entries for this section. + IMAGE_SCN_GPREL = 0x00008000 // Section content can be accessed relative to GP + IMAGE_SCN_MEM_FARDATA = 0x00008000 + IMAGE_SCN_MEM_SYSHEAP = 0x00010000 // Obsolete. + IMAGE_SCN_MEM_PURGEABLE = 0x00020000 + IMAGE_SCN_MEM_16BIT = 0x00020000 + IMAGE_SCN_MEM_LOCKED = 0x00040000 + IMAGE_SCN_MEM_PRELOAD = 0x00080000 + + IMAGE_SCN_ALIGN_1BYTES = 0x00100000 // + IMAGE_SCN_ALIGN_2BYTES = 0x00200000 // + IMAGE_SCN_ALIGN_4BYTES = 0x00300000 // + IMAGE_SCN_ALIGN_8BYTES = 0x00400000 // + IMAGE_SCN_ALIGN_16BYTES = 0x00500000 // Default alignment if no others are specified. + IMAGE_SCN_ALIGN_32BYTES = 0x00600000 // + IMAGE_SCN_ALIGN_64BYTES = 0x00700000 // + IMAGE_SCN_ALIGN_128BYTES = 0x00800000 // + IMAGE_SCN_ALIGN_256BYTES = 0x00900000 // + IMAGE_SCN_ALIGN_512BYTES = 0x00A00000 // + IMAGE_SCN_ALIGN_1024BYTES = 0x00B00000 // + IMAGE_SCN_ALIGN_2048BYTES = 0x00C00000 // + IMAGE_SCN_ALIGN_4096BYTES = 0x00D00000 // + IMAGE_SCN_ALIGN_8192BYTES = 0x00E00000 // + IMAGE_SCN_ALIGN_MASK = 0x00F00000 + + IMAGE_SCN_LNK_NRELOC_OVFL = 0x01000000 // Section contains extended relocations. + IMAGE_SCN_MEM_DISCARDABLE = 0x02000000 // Section can be discarded. + IMAGE_SCN_MEM_NOT_CACHED = 0x04000000 // Section is not cachable. + IMAGE_SCN_MEM_NOT_PAGED = 0x08000000 // Section is not pageable. + IMAGE_SCN_MEM_SHARED = 0x10000000 // Section is shareable. + IMAGE_SCN_MEM_EXECUTE = 0x20000000 // Section is executable. + IMAGE_SCN_MEM_READ = 0x40000000 // Section is readable. + IMAGE_SCN_MEM_WRITE = 0x80000000 // Section is writeable. + + // TLS Characteristic Flags + IMAGE_SCN_SCALE_INDEX = 0x00000001 // Tls index is scaled. +) + +// Based relocation format +type IMAGE_BASE_RELOCATION struct { + VirtualAddress uint32 + SizeOfBlock uint32 +} + +const ( + IMAGE_REL_BASED_ABSOLUTE = 0 + IMAGE_REL_BASED_HIGH = 1 + IMAGE_REL_BASED_LOW = 2 + IMAGE_REL_BASED_HIGHLOW = 3 + IMAGE_REL_BASED_HIGHADJ = 4 + IMAGE_REL_BASED_MACHINE_SPECIFIC_5 = 5 + IMAGE_REL_BASED_RESERVED = 6 + IMAGE_REL_BASED_MACHINE_SPECIFIC_7 = 7 + IMAGE_REL_BASED_MACHINE_SPECIFIC_8 = 8 + IMAGE_REL_BASED_MACHINE_SPECIFIC_9 = 9 + IMAGE_REL_BASED_DIR64 = 10 + + IMAGE_REL_BASED_IA64_IMM64 = 9 + + IMAGE_REL_BASED_MIPS_JMPADDR = 5 + IMAGE_REL_BASED_MIPS_JMPADDR16 = 9 + + IMAGE_REL_BASED_ARM_MOV32 = 5 + IMAGE_REL_BASED_THUMB_MOV32 = 7 +) + +// Export Format +type IMAGE_EXPORT_DIRECTORY struct { + Characteristics uint32 + TimeDateStamp uint32 + MajorVersion uint16 + MinorVersion uint16 + Name uint32 + Base uint32 + NumberOfFunctions uint32 + NumberOfNames uint32 + AddressOfFunctions uint32 // RVA from base of image + AddressOfNames uint32 // RVA from base of image + AddressOfNameOrdinals uint32 // RVA from base of image +} + +type IMAGE_IMPORT_BY_NAME struct { + Hint uint16 + Name [1]byte +} + +func IMAGE_ORDINAL(ordinal uintptr) uintptr { + return ordinal & 0xffff +} + +func IMAGE_SNAP_BY_ORDINAL(ordinal uintptr) bool { + return (ordinal & IMAGE_ORDINAL_FLAG) != 0 +} + +// Thread Local Storage +type IMAGE_TLS_DIRECTORY struct { + StartAddressOfRawData uintptr + EndAddressOfRawData uintptr + AddressOfIndex uintptr // PDWORD + AddressOfCallbacks uintptr // PIMAGE_TLS_CALLBACK *; + SizeOfZeroFill uint32 + Characteristics uint32 +} + +type IMAGE_IMPORT_DESCRIPTOR struct { + characteristicsOrOriginalFirstThunk uint32 // 0 for terminating null import descriptor + // RVA to original unbound IAT (PIMAGE_THUNK_DATA) + TimeDateStamp uint32 // 0 if not bound, + // -1 if bound, and real date\time stamp + // in IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT (new BIND) + // O.W. date/time stamp of DLL bound to (Old BIND) + ForwarderChain uint32 // -1 if no forwarders + Name uint32 + FirstThunk uint32 // RVA to IAT (if bound this IAT has actual addresses) +} + +func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) Characteristics() uint32 { + return imgimpdesc.characteristicsOrOriginalFirstThunk +} + +func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) OriginalFirstThunk() uint32 { + return imgimpdesc.characteristicsOrOriginalFirstThunk +} + +type IMAGE_DELAYLOAD_DESCRIPTOR struct { + Attributes uint32 + DllNameRVA uint32 + ModuleHandleRVA uint32 + ImportAddressTableRVA uint32 + ImportNameTableRVA uint32 + BoundImportAddressTableRVA uint32 + UnloadInformationTableRVA uint32 + TimeDateStamp uint32 +} + +type IMAGE_LOAD_CONFIG_CODE_INTEGRITY struct { + Flags uint16 + Catalog uint16 + CatalogOffset uint32 + Reserved uint32 +} + +const ( + IMAGE_GUARD_CF_INSTRUMENTED = 0x00000100 + IMAGE_GUARD_CFW_INSTRUMENTED = 0x00000200 + IMAGE_GUARD_CF_FUNCTION_TABLE_PRESENT = 0x00000400 + IMAGE_GUARD_SECURITY_COOKIE_UNUSED = 0x00000800 + IMAGE_GUARD_PROTECT_DELAYLOAD_IAT = 0x00001000 + IMAGE_GUARD_DELAYLOAD_IAT_IN_ITS_OWN_SECTION = 0x00002000 + IMAGE_GUARD_CF_EXPORT_SUPPRESSION_INFO_PRESENT = 0x00004000 + IMAGE_GUARD_CF_ENABLE_EXPORT_SUPPRESSION = 0x00008000 + IMAGE_GUARD_CF_LONGJUMP_TABLE_PRESENT = 0x00010000 + IMAGE_GUARD_RF_INSTRUMENTED = 0x00020000 + IMAGE_GUARD_RF_ENABLE = 0x00040000 + IMAGE_GUARD_RF_STRICT = 0x00080000 + IMAGE_GUARD_RETPOLINE_PRESENT = 0x00100000 + IMAGE_GUARD_EH_CONTINUATION_TABLE_PRESENT = 0x00400000 + IMAGE_GUARD_XFG_ENABLED = 0x00800000 + IMAGE_GUARD_CF_FUNCTION_TABLE_SIZE_MASK = 0xF0000000 + IMAGE_GUARD_CF_FUNCTION_TABLE_SIZE_SHIFT = 28 +) + +const ( + DLL_PROCESS_ATTACH = 1 + DLL_THREAD_ATTACH = 2 + DLL_THREAD_DETACH = 3 + DLL_PROCESS_DETACH = 0 +) + +type SYSTEM_INFO struct { + ProcessorArchitecture uint16 + Reserved uint16 + PageSize uint32 + MinimumApplicationAddress uintptr + MaximumApplicationAddress uintptr + ActiveProcessorMask uintptr + NumberOfProcessors uint32 + ProcessorType uint32 + AllocationGranularity uint32 + ProcessorLevel uint16 + ProcessorRevision uint16 +} diff --git a/internal/wintun/memmod/syscall_windows_32.go b/internal/wintun/memmod/syscall_windows_32.go new file mode 100644 index 0000000..f036ecb --- /dev/null +++ b/internal/wintun/memmod/syscall_windows_32.go @@ -0,0 +1,96 @@ +//go:build (windows && 386) || (windows && arm) + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package memmod + +// Optional header format +type IMAGE_OPTIONAL_HEADER struct { + Magic uint16 + MajorLinkerVersion uint8 + MinorLinkerVersion uint8 + SizeOfCode uint32 + SizeOfInitializedData uint32 + SizeOfUninitializedData uint32 + AddressOfEntryPoint uint32 + BaseOfCode uint32 + BaseOfData uint32 + ImageBase uintptr + SectionAlignment uint32 + FileAlignment uint32 + MajorOperatingSystemVersion uint16 + MinorOperatingSystemVersion uint16 + MajorImageVersion uint16 + MinorImageVersion uint16 + MajorSubsystemVersion uint16 + MinorSubsystemVersion uint16 + Win32VersionValue uint32 + SizeOfImage uint32 + SizeOfHeaders uint32 + CheckSum uint32 + Subsystem uint16 + DllCharacteristics uint16 + SizeOfStackReserve uintptr + SizeOfStackCommit uintptr + SizeOfHeapReserve uintptr + SizeOfHeapCommit uintptr + LoaderFlags uint32 + NumberOfRvaAndSizes uint32 + DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY +} + +const IMAGE_ORDINAL_FLAG uintptr = 0x80000000 + +type IMAGE_LOAD_CONFIG_DIRECTORY struct { + Size uint32 + TimeDateStamp uint32 + MajorVersion uint16 + MinorVersion uint16 + GlobalFlagsClear uint32 + GlobalFlagsSet uint32 + CriticalSectionDefaultTimeout uint32 + DeCommitFreeBlockThreshold uint32 + DeCommitTotalFreeThreshold uint32 + LockPrefixTable uint32 + MaximumAllocationSize uint32 + VirtualMemoryThreshold uint32 + ProcessHeapFlags uint32 + ProcessAffinityMask uint32 + CSDVersion uint16 + DependentLoadFlags uint16 + EditList uint32 + SecurityCookie uint32 + SEHandlerTable uint32 + SEHandlerCount uint32 + GuardCFCheckFunctionPointer uint32 + GuardCFDispatchFunctionPointer uint32 + GuardCFFunctionTable uint32 + GuardCFFunctionCount uint32 + GuardFlags uint32 + CodeIntegrity IMAGE_LOAD_CONFIG_CODE_INTEGRITY + GuardAddressTakenIatEntryTable uint32 + GuardAddressTakenIatEntryCount uint32 + GuardLongJumpTargetTable uint32 + GuardLongJumpTargetCount uint32 + DynamicValueRelocTable uint32 + CHPEMetadataPointer uint32 + GuardRFFailureRoutine uint32 + GuardRFFailureRoutineFunctionPointer uint32 + DynamicValueRelocTableOffset uint32 + DynamicValueRelocTableSection uint16 + Reserved2 uint16 + GuardRFVerifyStackPointerFunctionPointer uint32 + HotPatchTableOffset uint32 + Reserved3 uint32 + EnclaveConfigurationPointer uint32 + VolatileMetadataPointer uint32 + GuardEHContinuationTable uint32 + GuardEHContinuationCount uint32 + GuardXFGCheckFunctionPointer uint32 + GuardXFGDispatchFunctionPointer uint32 + GuardXFGTableDispatchFunctionPointer uint32 + CastGuardOsDeterminedFailureMode uint32 +} diff --git a/internal/wintun/memmod/syscall_windows_64.go b/internal/wintun/memmod/syscall_windows_64.go new file mode 100644 index 0000000..6f2c039 --- /dev/null +++ b/internal/wintun/memmod/syscall_windows_64.go @@ -0,0 +1,95 @@ +//go:build (windows && amd64) || (windows && arm64) + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package memmod + +// Optional header format +type IMAGE_OPTIONAL_HEADER struct { + Magic uint16 + MajorLinkerVersion uint8 + MinorLinkerVersion uint8 + SizeOfCode uint32 + SizeOfInitializedData uint32 + SizeOfUninitializedData uint32 + AddressOfEntryPoint uint32 + BaseOfCode uint32 + ImageBase uintptr + SectionAlignment uint32 + FileAlignment uint32 + MajorOperatingSystemVersion uint16 + MinorOperatingSystemVersion uint16 + MajorImageVersion uint16 + MinorImageVersion uint16 + MajorSubsystemVersion uint16 + MinorSubsystemVersion uint16 + Win32VersionValue uint32 + SizeOfImage uint32 + SizeOfHeaders uint32 + CheckSum uint32 + Subsystem uint16 + DllCharacteristics uint16 + SizeOfStackReserve uintptr + SizeOfStackCommit uintptr + SizeOfHeapReserve uintptr + SizeOfHeapCommit uintptr + LoaderFlags uint32 + NumberOfRvaAndSizes uint32 + DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY +} + +const IMAGE_ORDINAL_FLAG uintptr = 0x8000000000000000 + +type IMAGE_LOAD_CONFIG_DIRECTORY struct { + Size uint32 + TimeDateStamp uint32 + MajorVersion uint16 + MinorVersion uint16 + GlobalFlagsClear uint32 + GlobalFlagsSet uint32 + CriticalSectionDefaultTimeout uint32 + DeCommitFreeBlockThreshold uint64 + DeCommitTotalFreeThreshold uint64 + LockPrefixTable uint64 + MaximumAllocationSize uint64 + VirtualMemoryThreshold uint64 + ProcessAffinityMask uint64 + ProcessHeapFlags uint32 + CSDVersion uint16 + DependentLoadFlags uint16 + EditList uint64 + SecurityCookie uint64 + SEHandlerTable uint64 + SEHandlerCount uint64 + GuardCFCheckFunctionPointer uint64 + GuardCFDispatchFunctionPointer uint64 + GuardCFFunctionTable uint64 + GuardCFFunctionCount uint64 + GuardFlags uint32 + CodeIntegrity IMAGE_LOAD_CONFIG_CODE_INTEGRITY + GuardAddressTakenIatEntryTable uint64 + GuardAddressTakenIatEntryCount uint64 + GuardLongJumpTargetTable uint64 + GuardLongJumpTargetCount uint64 + DynamicValueRelocTable uint64 + CHPEMetadataPointer uint64 + GuardRFFailureRoutine uint64 + GuardRFFailureRoutineFunctionPointer uint64 + DynamicValueRelocTableOffset uint32 + DynamicValueRelocTableSection uint16 + Reserved2 uint16 + GuardRFVerifyStackPointerFunctionPointer uint64 + HotPatchTableOffset uint32 + Reserved3 uint32 + EnclaveConfigurationPointer uint64 + VolatileMetadataPointer uint64 + GuardEHContinuationTable uint64 + GuardEHContinuationCount uint64 + GuardXFGCheckFunctionPointer uint64 + GuardXFGDispatchFunctionPointer uint64 + GuardXFGTableDispatchFunctionPointer uint64 + CastGuardOsDeterminedFailureMode uint64 +} diff --git a/internal/wintun/session_windows.go b/internal/wintun/session_windows.go new file mode 100644 index 0000000..f023baf --- /dev/null +++ b/internal/wintun/session_windows.go @@ -0,0 +1,90 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package wintun + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type Session struct { + handle uintptr +} + +const ( + PacketSizeMax = 0xffff // Maximum packet size + RingCapacityMin = 0x20000 // Minimum ring capacity (128 kiB) + RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB) +) + +// Packet with data +type Packet struct { + Next *Packet // Pointer to next packet in queue + Size uint32 // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE) + Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet +} + +var ( + procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket") + procWintunEndSession = modwintun.NewProc("WintunEndSession") + procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent") + procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket") + procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket") + procWintunSendPacket = modwintun.NewProc("WintunSendPacket") + procWintunStartSession = modwintun.NewProc("WintunStartSession") +) + +func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) { + r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0) + if r0 == 0 { + err = e1 + } else { + session = Session{r0} + } + return +} + +func (session Session) End() { + syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0) + session.handle = 0 +} + +func (session Session) ReadWaitEvent() (handle windows.Handle) { + r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0) + handle = windows.Handle(r0) + return +} + +func (session Session) ReceivePacket() (packet []byte, err error) { + var packetSize uint32 + r0, _, e1 := syscall.Syscall(procWintunReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0) + if r0 == 0 { + err = e1 + return + } + packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize) + return +} + +func (session Session) ReleaseReceivePacket(packet []byte) { + syscall.Syscall(procWintunReleaseReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) +} + +func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) { + r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket.Addr(), 2, session.handle, uintptr(packetSize), 0) + if r0 == 0 { + err = e1 + return + } + packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize) + return +} + +func (session Session) SendPacket(packet []byte) { + syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) +} diff --git a/internal/wintun/wintun_windows.go b/internal/wintun/wintun_windows.go new file mode 100644 index 0000000..a817e6c --- /dev/null +++ b/internal/wintun/wintun_windows.go @@ -0,0 +1,112 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package wintun + +import ( + "runtime" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type ( + Adapter struct { + handle uintptr + } +) + +var ( + modwintun = newLazyDLL("wintun.dll") + procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter") + procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter") + procWintunCloseAdapter = modwintun.NewProc("WintunCloseAdapter") + procWintunDeleteDriver = modwintun.NewProc("WintunDeleteDriver") + procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID") + procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion") +) + +func closeAdapter(wintun *Adapter) { + syscall.SyscallN(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0) +} + +// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter. +// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is +// the GUID of the created network adapter, which then influences NLA generation +// deterministically. If it is set to nil, the GUID is chosen by the system at random, +// and hence a new NLA entry is created for each new adapter. +func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) { + var name16 *uint16 + name16, err = windows.UTF16PtrFromString(name) + if err != nil { + return + } + var tunnelType16 *uint16 + tunnelType16, err = windows.UTF16PtrFromString(tunnelType) + if err != nil { + return + } + r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID))) + if r0 == 0 { + err = e1 + return + } + wintun = &Adapter{handle: r0} + runtime.SetFinalizer(wintun, closeAdapter) + return +} + +// OpenAdapter opens an existing Wintun adapter by name. +func OpenAdapter(name string) (wintun *Adapter, err error) { + var name16 *uint16 + name16, err = windows.UTF16PtrFromString(name) + if err != nil { + return + } + r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0) + if r0 == 0 { + err = e1 + return + } + wintun = &Adapter{handle: r0} + runtime.SetFinalizer(wintun, closeAdapter) + return +} + +// Close closes a Wintun adapter. +func (wintun *Adapter) Close() (err error) { + runtime.SetFinalizer(wintun, nil) + r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0) + if r1 == 0 { + err = e1 + } + return +} + +// Uninstall removes the driver from the system if no drivers are currently in use. +func Uninstall() (err error) { + r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0) + if r1 == 0 { + err = e1 + } + return +} + +// RunningVersion returns the version of the running Wintun driver. +func RunningVersion() (version uint32, err error) { + r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0) + version = uint32(r0) + if version == 0 { + err = e1 + } + return +} + +// LUID returns the LUID of the adapter. +func (wintun *Adapter) LUID() (luid uint64) { + syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0) + return +} diff --git a/internal/wintun/x86/wintun.dll b/internal/wintun/x86/wintun.dll new file mode 100755 index 0000000..2ab97db Binary files /dev/null and b/internal/wintun/x86/wintun.dll differ diff --git a/monitor.go b/monitor.go new file mode 100644 index 0000000..9711ef7 --- /dev/null +++ b/monitor.go @@ -0,0 +1,14 @@ +package tun + +import E "github.com/sagernet/sing/common/exceptions" + +var ErrNoRoute = E.New("no route to internet") + +type InterfaceMonitor interface { + Start() error + Close() error + DefaultInterfaceName() string + DefaultInterfaceIndex() int +} + +type InterfaceMonitorCallback func() diff --git a/monitor_linux.go b/monitor_linux.go new file mode 100644 index 0000000..f36d9cc --- /dev/null +++ b/monitor_linux.go @@ -0,0 +1,101 @@ +package tun + +import ( + "os" + + E "github.com/sagernet/sing/common/exceptions" + + "github.com/vishvananda/netlink" +) + +type NativeMonitor struct { + defaultInterfaceName string + defaultInterfaceIndex int + update chan netlink.RouteUpdate + close chan struct{} + callback InterfaceMonitorCallback +} + +func NewMonitor(callback InterfaceMonitorCallback) (InterfaceMonitor, error) { + return &NativeMonitor{ + callback: callback, + update: make(chan netlink.RouteUpdate, 2), + close: make(chan struct{}), + }, nil +} + +func (m *NativeMonitor) Start() error { + err := netlink.RouteSubscribe(m.update, m.close) + if err != nil { + return err + } + err = m.checkUpdate() + if err != nil { + return err + } + go m.loopUpdate() + return nil +} + +func (m *NativeMonitor) loopUpdate() { + for { + select { + case <-m.close: + return + case <-m.update: + m.checkUpdate() + } + } +} + +func (m *NativeMonitor) checkUpdate() error { + routes, err := netlink.RouteList(nil, netlink.FAMILY_V4) + if err != nil { + return err + } + for _, route := range routes { + if route.Dst != nil { + continue + } + var link netlink.Link + link, err = netlink.LinkByIndex(route.LinkIndex) + if err != nil { + return err + } + + if link.Type() == "tuntap" { + continue + } + + oldInterface := m.defaultInterfaceName + oldIndex := m.defaultInterfaceIndex + + m.defaultInterfaceName = link.Attrs().Name + m.defaultInterfaceIndex = link.Attrs().Index + + if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex { + return nil + } + m.callback() + return nil + } + return E.New("no route to internet") +} + +func (m *NativeMonitor) Close() error { + select { + case <-m.close: + return os.ErrClosed + default: + } + close(m.close) + return nil +} + +func (m *NativeMonitor) DefaultInterfaceName() string { + return m.defaultInterfaceName +} + +func (m *NativeMonitor) DefaultInterfaceIndex() int { + return m.defaultInterfaceIndex +} diff --git a/monitor_other.go b/monitor_other.go new file mode 100644 index 0000000..b10d336 --- /dev/null +++ b/monitor_other.go @@ -0,0 +1,9 @@ +//go:build !linux && !windows + +package tun + +import "os" + +func NewMonitor() (InterfaceMonitor, error) { + return nil, os.ErrInvalid +} diff --git a/monitor_windows.go b/monitor_windows.go new file mode 100644 index 0000000..afca49b --- /dev/null +++ b/monitor_windows.go @@ -0,0 +1,98 @@ +package tun + +import ( + "github.com/sagernet/sing-tun/internal/winipcfg" + + "golang.org/x/sys/windows" +) + +var _ InterfaceMonitor = (*NativeMonitor)(nil) + +type NativeMonitor struct { + listener *winipcfg.RouteChangeCallback + callback InterfaceMonitorCallback + defaultInterfaceName string + defaultInterfaceIndex int +} + +func NewMonitor(callback InterfaceMonitorCallback) (InterfaceMonitor, error) { + return &NativeMonitor{callback: callback}, nil +} + +func (m *NativeMonitor) Start() error { + err := m.checkUpdate() + if err != nil { + return err + } + listener, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { + m.checkUpdate() + }) + if err != nil { + return err + } + m.listener = listener + return nil +} + +func (m *NativeMonitor) checkUpdate() error { + rows, err := winipcfg.GetIPForwardTable2(windows.AF_INET) + if err != nil { + return err + } + + lowestMetric := ^uint32(0) + alias := "" + var index int + + for _, row := range rows { + ifrow, err := row.InterfaceLUID.Interface() + if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp { + continue + } + + iface, err := row.InterfaceLUID.IPInterface(windows.AF_INET) + if err != nil { + continue + } + + if ifrow.Type == winipcfg.IfTypePropVirtual || ifrow.Type == winipcfg.IfTypeSoftwareLoopback { + continue + } + + metric := row.Metric + iface.Metric + if metric < lowestMetric { + lowestMetric = metric + alias = ifrow.Alias() + index = int(ifrow.InterfaceIndex) + } + } + + if alias == "" { + return ErrNoRoute + } + + oldInterface := m.defaultInterfaceName + oldIndex := m.defaultInterfaceIndex + + m.defaultInterfaceName = alias + m.defaultInterfaceIndex = index + + if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex { + return nil + } + + m.callback() + return nil +} + +func (m *NativeMonitor) Close() error { + return m.listener.Unregister() +} + +func (m *NativeMonitor) DefaultInterfaceName() string { + return m.defaultInterfaceName +} + +func (m *NativeMonitor) DefaultInterfaceIndex() int { + return m.defaultInterfaceIndex +} diff --git a/tun.go b/tun.go index 4f0e006..a24240f 100644 --- a/tun.go +++ b/tun.go @@ -2,9 +2,16 @@ package tun import ( N "github.com/sagernet/sing/common/network" + + "gvisor.dev/gvisor/pkg/tcpip/stack" ) type Handler interface { N.TCPConnectionHandler N.UDPConnectionHandler } + +type Tun interface { + NewEndpoint() (stack.LinkEndpoint, error) + Close() error +} diff --git a/tun_linux.go b/tun_linux.go index 5c5c037..11890e0 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -1,44 +1,72 @@ package tun import ( + "math" "net" "net/netip" + "runtime" + "syscall" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" "github.com/vishvananda/netlink" + "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -func Open(name string) (uintptr, error) { - tunFd, err := tun.Open(name) - if err != nil { - return 0, err - } - return uintptr(tunFd), nil +type NativeTun struct { + name string + inet4Address netip.Prefix + inet6Address netip.Prefix + mtu uint32 + autoRoute bool + fdList []int } -func Configure(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu uint32, autoRoute bool) error { - tunLink, err := netlink.LinkByName(name) +func Open(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu uint32, autoRoute bool) (Tun, error) { + tunFd, err := tun.Open(name) + if err != nil { + return nil, err + } + nativeTun := &NativeTun{ + name: name, + fdList: []int{tunFd}, + mtu: mtu, + inet4Address: inet4Address, + inet6Address: inet6Address, + autoRoute: autoRoute, + } + err = nativeTun.configure() + if err != nil { + return nil, E.Errors(err, syscall.Close(tunFd)) + } + return nativeTun, nil +} + +func (t *NativeTun) configure() error { + tunLink, err := netlink.LinkByName(t.name) if err != nil { return err } - - if inet4Address.IsValid() { - addr4, _ := netlink.ParseAddr(inet4Address.String()) + if t.inet4Address.IsValid() { + addr4, _ := netlink.ParseAddr(t.inet4Address.String()) err = netlink.AddrAdd(tunLink, addr4) if err != nil { return err } } - if inet6Address.IsValid() { - addr6, _ := netlink.ParseAddr(inet6Address.String()) + if t.inet6Address.IsValid() { + addr6, _ := netlink.ParseAddr(t.inet6Address.String()) err = netlink.AddrAdd(tunLink, addr6) if err != nil { return err } } - err = netlink.LinkSetMTU(tunLink, int(mtu)) + err = netlink.LinkSetMTU(tunLink, int(t.mtu)) if err != nil { return err } @@ -48,8 +76,8 @@ func Configure(name string, inet4Address netip.Prefix, inet6Address netip.Prefix return err } - if autoRoute { - if inet4Address.IsValid() { + if t.autoRoute { + if t.inet4Address.IsValid() { err = netlink.RouteAdd(&netlink.Route{ Dst: &net.IPNet{ IP: net.IPv4zero, @@ -61,7 +89,7 @@ func Configure(name string, inet4Address netip.Prefix, inet6Address netip.Prefix return err } } - if inet6Address.IsValid() { + if t.inet6Address.IsValid() { err = netlink.RouteAdd(&netlink.Route{ Dst: &net.IPNet{ IP: net.IPv6zero, @@ -74,17 +102,38 @@ func Configure(name string, inet4Address netip.Prefix, inet6Address netip.Prefix } } } - return nil } -func UnConfigure(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, autoRoute bool) error { - if autoRoute { - tunLink, err := netlink.LinkByName(name) +func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { + var packetDispatchMode fdbased.PacketDispatchMode + if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" { + packetDispatchMode = fdbased.PacketMMap + } else { + packetDispatchMode = fdbased.RecvMMsg + } + dupFdSize := int(math.Max(float64(runtime.NumCPU()/2), 1)) - 1 + for i := 0; i < dupFdSize; i++ { + dupFd, err := syscall.Dup(t.fdList[0]) if err != nil { - return err + return nil, err } - if inet4Address.IsValid() { + t.fdList = append(t.fdList, dupFd) + } + return fdbased.New(&fdbased.Options{ + FDs: t.fdList, + MTU: t.mtu, + PacketDispatchMode: packetDispatchMode, + }) +} + +func (t *NativeTun) Close() error { + tunLink, err := netlink.LinkByName(t.name) + if err != nil { + return err + } + if t.autoRoute { + if t.inet4Address.IsValid() { err = netlink.RouteDel(&netlink.Route{ Dst: &net.IPNet{ IP: net.IPv4zero, @@ -96,7 +145,7 @@ func UnConfigure(name string, inet4Address netip.Prefix, inet6Address netip.Pref return err } } - if inet6Address.IsValid() { + if t.inet6Address.IsValid() { err = netlink.RouteDel(&netlink.Route{ Dst: &net.IPNet{ IP: net.IPv6zero, @@ -109,5 +158,5 @@ func UnConfigure(name string, inet4Address netip.Prefix, inet6Address netip.Pref } } } - return nil + return E.Errors(common.Map(t.fdList, syscall.Close)...) } diff --git a/tun_other.go b/tun_other.go index 787c3c1..06ad2d2 100644 --- a/tun_other.go +++ b/tun_other.go @@ -1,4 +1,4 @@ -//go:build !linux +//go:build !linux && !windows package tun diff --git a/tun_windows.go b/tun_windows.go new file mode 100644 index 0000000..2643546 --- /dev/null +++ b/tun_windows.go @@ -0,0 +1,366 @@ +package tun + +import ( + "crypto/md5" + "errors" + "fmt" + "net/netip" + "os" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/sagernet/sing-tun/internal/winipcfg" + "github.com/sagernet/sing-tun/internal/wintun" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + + "golang.org/x/sys/windows" + gBuffer "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var TunnelType = "sing-tun" + +type NativeTun struct { + adapter *wintun.Adapter + inet4Address netip.Prefix + inet6Address netip.Prefix + mtu uint32 + autoRoute bool + session wintun.Session + readWait windows.Handle + rate rateJuggler + running sync.WaitGroup + closeOnce sync.Once + close int32 +} + +func Open(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu uint32, autoRoute bool) (Tun, error) { + adapter, err := wintun.CreateAdapter(name, TunnelType, generateGUIDByDeviceName(name)) + if err != nil { + return nil, err + } + nativeTun := &NativeTun{ + adapter: adapter, + inet4Address: inet4Address, + inet6Address: inet6Address, + mtu: mtu, + autoRoute: autoRoute, + } + err = nativeTun.configure() + if err != nil { + adapter.Close() + return nil, err + } + return nativeTun, nil +} + +func (t *NativeTun) configure() error { + luid := winipcfg.LUID(t.adapter.LUID()) + if t.inet4Address.IsValid() { + err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET), []netip.Prefix{t.inet4Address}) + if err != nil { + return E.Cause(err, "set ipv4 address") + } + } + if t.inet6Address.IsValid() { + err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET6), []netip.Prefix{t.inet6Address}) + if err != nil { + return E.Cause(err, "set ipv6 address") + } + } + err := luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.inet4Address.Addr().Next()}, nil) + if err != nil { + return E.Cause(err, "set ipv4 dns") + } + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.inet6Address.Addr().Next()}, nil) + if err != nil { + return E.Cause(err, "set ipv6 dns") + } + if t.autoRoute { + if t.inet4Address.IsValid() { + err = luid.AddRoute(netip.PrefixFrom(netip.IPv4Unspecified(), 0), netip.IPv4Unspecified(), 0) + if err != nil { + return E.Cause(err, "set ipv4 route") + } + } + if t.inet6Address.IsValid() { + err = luid.AddRoute(netip.PrefixFrom(netip.IPv6Unspecified(), 0), netip.IPv6Unspecified(), 0) + if err != nil { + return E.Cause(err, "set ipv6 route") + } + } + } + if t.inet4Address.IsValid() { + var inetIf *winipcfg.MibIPInterfaceRow + inetIf, err = luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET)) + if err != nil { + return err + } + inetIf.ForwardingEnabled = true + inetIf.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + inetIf.DadTransmits = 0 + inetIf.ManagedAddressConfigurationSupported = false + inetIf.OtherStatefulConfigurationSupported = false + inetIf.NLMTU = t.mtu + if t.autoRoute { + inetIf.UseAutomaticMetric = false + inetIf.Metric = 0 + } + err = inetIf.Set() + if err != nil { + return E.Cause(err, "set ipv4 options") + } + } + if t.inet6Address.IsValid() { + var inet6If *winipcfg.MibIPInterfaceRow + inet6If, err = luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6)) + if err != nil { + return err + } + inet6If.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + inet6If.DadTransmits = 0 + inet6If.ManagedAddressConfigurationSupported = false + inet6If.OtherStatefulConfigurationSupported = false + inet6If.NLMTU = t.mtu + if t.autoRoute { + inet6If.UseAutomaticMetric = false + inet6If.Metric = 0 + } + err = inet6If.Set() + if err != nil { + return E.Cause(err, "set ipv6 options") + } + } + + return nil +} + +func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { + session, err := t.adapter.StartSession(0x800000) + if err != nil { + return nil, err + } + t.session = session + t.readWait = session.ReadWaitEvent() + return &WintunEndpoint{tun: t}, nil +} + +func (t *NativeTun) Read(p []byte) (n int, err error) { + t.running.Add(1) + defer t.running.Done() +retry: + if atomic.LoadInt32(&t.close) == 1 { + return 0, os.ErrClosed + } + start := nanotime() + shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2 + for { + if atomic.LoadInt32(&t.close) == 1 { + return 0, os.ErrClosed + } + packet, err := t.session.ReceivePacket() + switch err { + case nil: + packetSize := len(packet) + n = copy(p, packet) + t.session.ReleaseReceivePacket(packet) + t.rate.update(uint64(packetSize)) + return n, nil + case windows.ERROR_NO_MORE_ITEMS: + if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { + windows.WaitForSingleObject(t.readWait, windows.INFINITE) + goto retry + } + procyield(1) + continue + case windows.ERROR_HANDLE_EOF: + return 0, os.ErrClosed + case windows.ERROR_INVALID_DATA: + return 0, errors.New("send ring corrupt") + } + return 0, fmt.Errorf("read failed: %w", err) + } +} + +func (t *NativeTun) Write(packetElementList [][]byte) (n int, err error) { + t.running.Add(1) + defer t.running.Done() + if atomic.LoadInt32(&t.close) == 1 { + return 0, os.ErrClosed + } + var packetSize int + for _, packetElement := range packetElementList { + packetSize += len(packetElement) + } + t.rate.update(uint64(packetSize)) + packet, err := t.session.AllocateSendPacket(packetSize) + if err == nil { + var index int + for _, packetElement := range packetElementList { + index += copy(packet[index:], packetElement) + } + t.session.SendPacket(packet) + return + } + switch err { + case windows.ERROR_HANDLE_EOF: + return 0, os.ErrClosed + case windows.ERROR_BUFFER_OVERFLOW: + return 0, nil // Dropping when ring is full. + } + return 0, fmt.Errorf("write failed: %w", err) +} + +func (t *NativeTun) Close() error { + var err error + t.closeOnce.Do(func() { + atomic.StoreInt32(&t.close, 1) + windows.SetEvent(t.readWait) + t.running.Wait() + t.session.End() + t.adapter.Close() + }) + return err +} + +func generateGUIDByDeviceName(name string) *windows.GUID { + hash := md5.New() + hash.Write([]byte("wintun")) + hash.Write([]byte(name)) + sum := hash.Sum(nil) + return (*windows.GUID)(unsafe.Pointer(&sum[0])) +} + +//go:linkname procyield runtime.procyield +func procyield(cycles uint32) + +//go:linkname nanotime runtime.nanotime +func nanotime() int64 + +type rateJuggler struct { + current uint64 + nextByteCount uint64 + nextStartTime int64 + changing int32 +} + +func (rate *rateJuggler) update(packetLen uint64) { + now := nanotime() + total := atomic.AddUint64(&rate.nextByteCount, packetLen) + period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) + if period >= rateMeasurementGranularity { + if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { + return + } + atomic.StoreInt64(&rate.nextStartTime, now) + atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) + atomic.StoreUint64(&rate.nextByteCount, 0) + atomic.StoreInt32(&rate.changing, 0) + } +} + +const ( + rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) + spinloopRateThreshold = 800000000 / 8 // 800mbps + spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s +) + +var _ stack.LinkEndpoint = (*WintunEndpoint)(nil) + +type WintunEndpoint struct { + tun *NativeTun + dispatcher stack.NetworkDispatcher +} + +func (e *WintunEndpoint) MTU() uint32 { + return e.tun.mtu +} + +func (e *WintunEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +func (e *WintunEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +func (e *WintunEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityNone +} + +func (e *WintunEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + if dispatcher == nil && e.dispatcher != nil { + e.dispatcher = nil + return + } + if dispatcher != nil && e.dispatcher == nil { + e.dispatcher = dispatcher + go e.dispatchLoop() + } +} + +func (e *WintunEndpoint) dispatchLoop() { + _buffer := buf.StackNewPacket() + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + for { + n, err := e.tun.Read(buffer.FreeBytes()) + if err != nil { + break + } + var view gBuffer.View + view.Append(buffer.To(n)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: view, + IsForwardedPacket: true, + }) + defer pkt.DecRef() + var p tcpip.NetworkProtocolNumber + ipHeader, ok := pkt.Data().PullUp(1) + if !ok { + continue + } + switch header.IPVersion(ipHeader) { + case header.IPv4Version: + p = header.IPv4ProtocolNumber + case header.IPv6Version: + p = header.IPv6ProtocolNumber + default: + continue + } + e.dispatcher.DeliverNetworkPacket(p, pkt) + } +} + +func (e *WintunEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +func (e *WintunEndpoint) Wait() { +} + +func (e *WintunEndpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +func (e *WintunEndpoint) AddHeader(buffer *stack.PacketBuffer) { +} + +func (e *WintunEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) { + var n int + for _, packet := range packetBufferList.AsSlice() { + _, err := e.tun.Write(packet.Slices()) + if err != nil { + return n, &tcpip.ErrAborted{} + } + n++ + } + return n, nil +}