sing-tun/tun_windows.go
2022-09-08 18:11:26 +08:00

316 lines
8.4 KiB
Go

package tun
import (
"crypto/md5"
"errors"
"fmt"
"net/netip"
"os"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/sagernet/sing-tun/internal/winipcfg"
"github.com/sagernet/sing-tun/internal/wintun"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/windows"
)
// TunnelType is the tunnel type string that Open passes to
// wintun.CreateAdapter when it creates the adapter.
var TunnelType = "sing-tun"
type NativeTun struct {
adapter *wintun.Adapter
options Options
session wintun.Session
readWait windows.Handle
rate rateJuggler
running sync.WaitGroup
closeOnce sync.Once
close int32
}
// Open creates a Wintun adapter named options.Name, starts a session on it and
// applies the address, DNS, route and interface settings from options.
//
// If any step after the adapter has been created fails, everything acquired
// so far is released before the error is returned.
func Open(options Options) (WinTun, error) {
	adapter, err := wintun.CreateAdapter(options.Name, TunnelType, generateGUIDByDeviceName(options.Name))
	if err != nil {
		return nil, err
	}
	// 0x800000 (8 MiB) is presumably the ring capacity requested for the
	// session — confirm against the wintun package.
	session, err := adapter.StartSession(0x800000)
	if err != nil {
		// Do not leak the freshly created adapter when the session cannot start.
		adapter.Close()
		return nil, err
	}
	nativeTun := &NativeTun{
		adapter:  adapter,
		options:  options,
		session:  session,
		readWait: session.ReadWaitEvent(),
	}
	if err = nativeTun.configure(); err != nil {
		session.End()
		adapter.Close()
		return nil, err
	}
	return nativeTun, nil
}
// configure applies t.options to the adapter's network interface.
//
// In order, it sets the interface addresses and DNS server for each address
// family that has at least one prefix configured, installs default routes
// when AutoRoute is enabled, and finally adjusts the per-family IP interface
// settings (MTU, router discovery, metric, ...). The DNS server assigned to a
// family is the address immediately following the first configured address.
//
// Every failure is returned wrapped with a short description of the step
// that failed.
func (t *NativeTun) configure() error {
	luid := winipcfg.LUID(t.adapter.LUID())
	hasInet4 := len(t.options.Inet4Address) > 0
	hasInet6 := len(t.options.Inet6Address) > 0

	if hasInet4 {
		err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET), t.options.Inet4Address)
		if err != nil {
			return E.Cause(err, "set ipv4 address")
		}
		err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.options.Inet4Address[0].Addr().Next()}, nil)
		if err != nil {
			return E.Cause(err, "set ipv4 dns")
		}
	}
	if hasInet6 {
		err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET6), t.options.Inet6Address)
		if err != nil {
			return E.Cause(err, "set ipv6 address")
		}
		err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.options.Inet6Address[0].Addr().Next()}, nil)
		if err != nil {
			return E.Cause(err, "set ipv6 dns")
		}
	}

	if t.options.AutoRoute {
		// Default routes (prefix length 0) with an unspecified next hop and metric 0.
		if hasInet4 {
			err := luid.AddRoute(netip.PrefixFrom(netip.IPv4Unspecified(), 0), netip.IPv4Unspecified(), 0)
			if err != nil {
				return E.Cause(err, "set ipv4 route")
			}
		}
		if hasInet6 {
			err := luid.AddRoute(netip.PrefixFrom(netip.IPv6Unspecified(), 0), netip.IPv6Unspecified(), 0)
			if err != nil {
				return E.Cause(err, "set ipv6 route")
			}
		}
	}

	if hasInet4 {
		inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET))
		if err != nil {
			return E.Cause(err, "get ipv4 interface")
		}
		// Forwarding is only switched on for the IPv4 interface.
		inetIf.ForwardingEnabled = true
		inetIf.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
		inetIf.DadTransmits = 0
		inetIf.ManagedAddressConfigurationSupported = false
		inetIf.OtherStatefulConfigurationSupported = false
		inetIf.NLMTU = t.options.MTU
		if t.options.AutoRoute {
			// Pin the interface metric to 0 instead of letting it be chosen automatically.
			inetIf.UseAutomaticMetric = false
			inetIf.Metric = 0
		}
		if err = inetIf.Set(); err != nil {
			return E.Cause(err, "set ipv4 options")
		}
	}
	if hasInet6 {
		inet6If, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6))
		if err != nil {
			return E.Cause(err, "get ipv6 interface")
		}
		inet6If.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
		inet6If.DadTransmits = 0
		inet6If.ManagedAddressConfigurationSupported = false
		inet6If.OtherStatefulConfigurationSupported = false
		inet6If.NLMTU = t.options.MTU
		if t.options.AutoRoute {
			inet6If.UseAutomaticMetric = false
			inet6If.Metric = 0
		}
		if err = inet6If.Set(); err != nil {
			return E.Cause(err, "set ipv6 options")
		}
	}
	return nil
}
// Read is not supported on this device: it never touches p and always
// reports zero bytes together with os.ErrInvalid.
func (t *NativeTun) Read(p []byte) (n int, err error) {
	err = os.ErrInvalid
	return n, err
}
// ReadPacket returns the next packet received from the session together with
// a release function that hands the packet back to the session via
// ReleaseReceivePacket.
//
// While no packet is available it either busy-polls (when the measured rate
// is at or above spinloopRateThreshold and the measurement is recent) or
// blocks on the session's read-wait event. It returns os.ErrClosed once the
// device has been closed or the session reports ERROR_HANDLE_EOF; other
// session errors are returned as errors.
func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
	// Register as an in-flight operation so Close waits for this call to return.
	t.running.Add(1)
	defer t.running.Done()
retry:
	if atomic.LoadInt32(&t.close) == 1 {
		return nil, nil, os.ErrClosed
	}
	start := nanotime()
	// Spin only when the last measured rate is high enough and that
	// measurement is at most two measurement periods old.
	shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2
	for {
		if atomic.LoadInt32(&t.close) == 1 {
			return nil, nil, os.ErrClosed
		}
		packet, err := t.session.ReceivePacket()
		switch err {
		case nil:
			packetSize := len(packet)
			t.rate.update(uint64(packetSize))
			return packet, func() { t.session.ReleaseReceivePacket(packet) }, nil
		case windows.ERROR_NO_MORE_ITEMS:
			// Nothing to receive: block on the read-wait event unless we are
			// still inside the spin window. Close also signals this event,
			// and the retry path re-checks the close flag.
			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
				windows.WaitForSingleObject(t.readWait, windows.INFINITE)
				goto retry
			}
			procyield(1)
			continue
		case windows.ERROR_HANDLE_EOF:
			return nil, nil, os.ErrClosed
		case windows.ERROR_INVALID_DATA:
			return nil, nil, errors.New("send ring corrupt")
		}
		// Any other error is wrapped and reported to the caller.
		return nil, nil, fmt.Errorf("read failed: %w", err)
	}
}
// ReadFunc receives one packet from the session, passes it to block and then
// releases it back to the session.
//
// The packet is released immediately after block returns, so block
// presumably must not retain b beyond the call (verify against the wintun
// package). Waiting for a packet follows the same spin-or-block strategy as
// ReadPacket. It returns os.ErrClosed once the device has been closed or the
// session reports ERROR_HANDLE_EOF; other session errors are returned as
// errors.
func (t *NativeTun) ReadFunc(block func(b []byte)) error {
	// Register as an in-flight operation so Close waits for this call to return.
	t.running.Add(1)
	defer t.running.Done()
retry:
	if atomic.LoadInt32(&t.close) == 1 {
		return os.ErrClosed
	}
	start := nanotime()
	// Spin only when the last measured rate is high enough and that
	// measurement is at most two measurement periods old.
	shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2
	for {
		if atomic.LoadInt32(&t.close) == 1 {
			return os.ErrClosed
		}
		packet, err := t.session.ReceivePacket()
		switch err {
		case nil:
			// Remember the size first; the packet is released before the
			// rate is updated.
			packetSize := len(packet)
			block(packet)
			t.session.ReleaseReceivePacket(packet)
			t.rate.update(uint64(packetSize))
			return nil
		case windows.ERROR_NO_MORE_ITEMS:
			// Nothing to receive: block on the read-wait event unless we are
			// still inside the spin window. Close also signals this event,
			// and the retry path re-checks the close flag.
			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
				windows.WaitForSingleObject(t.readWait, windows.INFINITE)
				goto retry
			}
			procyield(1)
			continue
		case windows.ERROR_HANDLE_EOF:
			return os.ErrClosed
		case windows.ERROR_INVALID_DATA:
			return errors.New("send ring corrupt")
		}
		// Any other error is wrapped and reported to the caller.
		return fmt.Errorf("read failed: %w", err)
	}
}
// Write sends p as a single packet through the session.
//
// It returns len(p) on success and os.ErrClosed when the device has been
// closed or the session reports ERROR_HANDLE_EOF. When the send ring is full
// (ERROR_BUFFER_OVERFLOW) the packet is dropped silently and (0, nil) is
// returned. Any other allocation error is returned wrapped.
func (t *NativeTun) Write(p []byte) (n int, err error) {
	// Register as an in-flight operation so Close waits for this call to return.
	t.running.Add(1)
	defer t.running.Done()
	if atomic.LoadInt32(&t.close) == 1 {
		return 0, os.ErrClosed
	}
	t.rate.update(uint64(len(p)))
	packet, err := t.session.AllocateSendPacket(len(p))
	switch err {
	case nil:
		// Only touch the ring buffer once the allocation is known to have succeeded.
		copy(packet, p)
		t.session.SendPacket(packet)
		return len(p), nil
	case windows.ERROR_HANDLE_EOF:
		return 0, os.ErrClosed
	case windows.ERROR_BUFFER_OVERFLOW:
		return 0, nil // Dropping when ring is full.
	}
	return 0, fmt.Errorf("write failed: %w", err)
}
// write sends the concatenation of packetElementList as a single packet
// through the session.
//
// On success it returns the total number of bytes sent. It returns
// os.ErrClosed when the device has been closed or the session reports
// ERROR_HANDLE_EOF. When the send ring is full (ERROR_BUFFER_OVERFLOW) the
// packet is dropped silently and (0, nil) is returned. Any other allocation
// error is returned wrapped.
func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
	// Register as an in-flight operation so Close waits for this call to return.
	t.running.Add(1)
	defer t.running.Done()
	if atomic.LoadInt32(&t.close) == 1 {
		return 0, os.ErrClosed
	}
	var packetSize int
	for _, packetElement := range packetElementList {
		packetSize += len(packetElement)
	}
	t.rate.update(uint64(packetSize))
	packet, err := t.session.AllocateSendPacket(packetSize)
	switch err {
	case nil:
		// Gather the elements back to back into the allocated ring slot.
		var index int
		for _, packetElement := range packetElementList {
			index += copy(packet[index:], packetElement)
		}
		t.session.SendPacket(packet)
		// Report the bytes actually sent rather than the zero-valued named result.
		return packetSize, nil
	case windows.ERROR_HANDLE_EOF:
		return 0, os.ErrClosed
	case windows.ERROR_BUFFER_OVERFLOW:
		return 0, nil // Dropping when ring is full.
	}
	return 0, fmt.Errorf("write failed: %w", err)
}
// Close shuts the device down. Only the first call does any work; later
// calls are no-ops. The returned error is always nil.
func (t *NativeTun) Close() error {
	shutdown := func() {
		// Flag the device as closed, then wake any reader blocked on the
		// read-wait event so it can observe the flag.
		atomic.StoreInt32(&t.close, 1)
		windows.SetEvent(t.readWait)
		// Let in-flight reads and writes finish before tearing down the
		// session and the adapter.
		t.running.Wait()
		t.session.End()
		t.adapter.Close()
	}
	t.closeOnce.Do(shutdown)
	return nil
}
// generateGUIDByDeviceName derives a GUID from the device name so that the
// same name always maps to the same GUID. The GUID is the 16-byte MD5 digest
// of "wintun" followed by name, reinterpreted in place as a windows.GUID.
func generateGUIDByDeviceName(name string) *windows.GUID {
	seed := append([]byte("wintun"), name...)
	digest := md5.Sum(seed)
	return (*windows.GUID)(unsafe.Pointer(&digest))
}
// procyield is bound to the Go runtime's internal runtime.procyield via
// go:linkname. The read paths call it between polls while spinning.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

// nanotime is bound to the Go runtime's internal runtime.nanotime via
// go:linkname. Its result is treated throughout this file as a timestamp in
// nanoseconds whose differences measure elapsed time.
//
//go:linkname nanotime runtime.nanotime
func nanotime() int64
// rateJuggler keeps a running throughput estimate. All fields are accessed
// with sync/atomic functions. The 64-bit fields come first so that they are
// 64-bit aligned whenever the struct itself is.
type rateJuggler struct {
	// current is the rate computed for the last completed measurement
	// period, in bytes per second.
	current uint64
	// nextByteCount accumulates the bytes seen in the period being measured.
	nextByteCount uint64
	// nextStartTime is the nanotime() value at which the period being
	// measured started.
	nextStartTime int64
	// changing is 1 while one caller of update is rolling the period over.
	changing int32
}
// update adds packetLen bytes to the period currently being measured. Once
// at least rateMeasurementGranularity nanoseconds have elapsed since the
// period started, a single caller publishes the bytes-per-second figure for
// that period and starts a new one; callers that lose the race for the
// changing flag return without doing anything further.
func (rate *rateJuggler) update(packetLen uint64) {
	now := nanotime()
	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
	elapsed := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
	if elapsed < rateMeasurementGranularity {
		return
	}
	// Someone else is already rolling the period over.
	if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
		return
	}
	const nanosPerSecond = uint64(time.Second / time.Nanosecond)
	bytesPerSecond := total * nanosPerSecond / elapsed
	atomic.StoreInt64(&rate.nextStartTime, now)
	atomic.StoreUint64(&rate.current, bytesPerSecond)
	atomic.StoreUint64(&rate.nextByteCount, 0)
	atomic.StoreInt32(&rate.changing, 0)
}
const (
	// rateMeasurementGranularity is the length of one rate measurement
	// period: half a second, expressed in nanoseconds.
	rateMeasurementGranularity = uint64(500 * time.Millisecond / time.Nanosecond)
	// spinloopRateThreshold is the rate, in bytes per second, at or above
	// which readers busy-poll instead of blocking (800 Mbit/s).
	spinloopRateThreshold = 800_000_000 / 8
	// spinloopDuration is the longest a reader busy-polls before it falls
	// back to blocking: 1/80 of a millisecond, in nanoseconds (~1 Gbit/s).
	spinloopDuration = uint64((time.Millisecond / 80) / time.Nanosecond)
)