refactor: WireGuard endpoint

This commit is contained in:
世界 2024-11-21 18:10:41 +08:00
parent fd299a0961
commit 68781387fe
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
89 changed files with 2187 additions and 679 deletions

View file

@ -128,7 +128,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
select {
case <-c.done:
default:
c.logger.Error(context.Background(), E.Cause(err, "read packet"))
c.logger.Error(E.Cause(err, "read packet"))
err = nil
}
return
@ -138,7 +138,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
b := packets[0]
common.ClearArray(b[1:4])
}
eps[0] = Endpoint(M.AddrPortFromNet(addr))
eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
count = 1
return
}
@ -169,7 +169,7 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
time.Sleep(time.Second)
return err
}
destination := netip.AddrPort(ep.(Endpoint))
destination := netip.AddrPort(ep.(remoteEndpoint))
for _, b := range bufs {
if len(b) > 3 {
reserved, loaded := c.reservedForEndpoint[destination]
@ -192,7 +192,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
if err != nil {
return nil, err
}
return Endpoint(ap), nil
return remoteEndpoint(ap), nil
}
func (c *ClientBind) BatchSize() int {
@ -229,3 +229,31 @@ func (w *wireConn) Close() error {
close(w.done)
return nil
}
var _ conn.Endpoint = (*remoteEndpoint)(nil)
type remoteEndpoint netip.AddrPort
func (e remoteEndpoint) ClearSrc() {
}
func (e remoteEndpoint) SrcToString() string {
return ""
}
func (e remoteEndpoint) DstToString() string {
return (netip.AddrPort)(e).String()
}
func (e remoteEndpoint) DstToBytes() []byte {
b, _ := (netip.AddrPort)(e).MarshalBinary()
return b
}
func (e remoteEndpoint) DstIP() netip.Addr {
return (netip.AddrPort)(e).Addr()
}
func (e remoteEndpoint) SrcIP() netip.Addr {
return netip.Addr{}
}

View file

@ -1,13 +1,44 @@
package wireguard
import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/tun"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun"
)
type Device interface {
tun.Device
wgTun.Device
N.Dialer
Start() error
// NewEndpoint() (stack.LinkEndpoint, error)
SetDevice(device *device.Device)
}
type DeviceOptions struct {
Context context.Context
Logger logger.ContextLogger
System bool
Handler tun.Handler
UDPTimeout time.Duration
CreateDialer func(interfaceName string) N.Dialer
Name string
MTU uint32
GSO bool
Address []netip.Prefix
AllowedAddress []netip.Prefix
}
func NewDevice(options DeviceOptions) (Device, error) {
if !options.System {
return newStackDevice(options)
} else if options.Handler == nil {
return newSystemDevice(options)
} else {
return newSystemStackDevice(options)
}
}

View file

@ -5,7 +5,6 @@ package wireguard
import (
"context"
"net"
"net/netip"
"os"
"github.com/sagernet/gvisor/pkg/buffer"
@ -15,52 +14,41 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun"
)
var _ Device = (*StackDevice)(nil)
var _ Device = (*stackDevice)(nil)
const defaultNIC tcpip.NICID = 1
type StackDevice struct {
stack *stack.Stack
mtu uint32
events chan wgTun.Event
outbound chan *stack.PacketBuffer
packetOutbound chan *buf.Buffer
done chan struct{}
dispatcher stack.NetworkDispatcher
addr4 tcpip.Address
addr6 tcpip.Address
type stackDevice struct {
stack *stack.Stack
mtu uint32
events chan wgTun.Event
outbound chan *stack.PacketBuffer
done chan struct{}
dispatcher stack.NetworkDispatcher
addr4 tcpip.Address
addr6 tcpip.Address
}
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
HandleLocal: true,
})
tunDevice := &StackDevice{
stack: ipStack,
mtu: mtu,
events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
packetOutbound: make(chan *buf.Buffer, 256),
done: make(chan struct{}),
func newStackDevice(options DeviceOptions) (*stackDevice, error) {
tunDevice := &stackDevice{
mtu: options.MTU,
events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
done: make(chan struct{}),
}
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
if err != nil {
return nil, E.New(err.String())
return nil, err
}
for _, prefix := range localAddresses {
for _, prefix := range options.Address {
addr := tun.AddressFromAddr(prefix.Addr())
protoAddr := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddressWithPrefix{
@ -75,32 +63,27 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
tunDevice.addr6 = addr
protoAddr.Protocol = ipv6.ProtocolNumber
}
err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
if err != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
if gErr != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
}
}
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
cOpt := tcpip.CongestionControlOption("cubic")
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
tunDevice.stack = ipStack
if options.Handler != nil {
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
}
return tunDevice, nil
}
func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
return (*wireEndpoint)(w), nil
}
func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
addr := tcpip.FullAddress{
NIC: defaultNIC,
NIC: tun.DefaultNIC,
Port: destination.Port,
Addr: tun.AddressFromAddr(destination.Addr),
}
bind := tcpip.FullAddress{
NIC: defaultNIC,
NIC: tun.DefaultNIC,
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
@ -128,9 +111,9 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati
}
}
func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
bind := tcpip.FullAddress{
NIC: defaultNIC,
NIC: tun.DefaultNIC,
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
@ -147,24 +130,19 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
return udpConn, nil
}
func (w *StackDevice) Inet4Address() netip.Addr {
return tun.AddrFromAddress(w.addr4)
func (w *stackDevice) SetDevice(device *device.Device) {
}
func (w *StackDevice) Inet6Address() netip.Addr {
return tun.AddrFromAddress(w.addr6)
}
func (w *StackDevice) Start() error {
func (w *stackDevice) Start() error {
w.events <- wgTun.EventUp
return nil
}
func (w *StackDevice) File() *os.File {
func (w *stackDevice) File() *os.File {
return nil
}
func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
select {
case packetBuffer, ok := <-w.outbound:
if !ok {
@ -180,17 +158,12 @@ func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, e
sizes[0] = n
count = 1
return
case packet := <-w.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
count = 1
return
case <-w.done:
return 0, os.ErrClosed
}
}
func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
for _, b := range bufs {
b = b[offset:]
if len(b) == 0 {
@ -213,23 +186,23 @@ func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
return
}
func (w *StackDevice) Flush() error {
func (w *stackDevice) Flush() error {
return nil
}
func (w *StackDevice) MTU() (int, error) {
func (w *stackDevice) MTU() (int, error) {
return int(w.mtu), nil
}
func (w *StackDevice) Name() (string, error) {
func (w *stackDevice) Name() (string, error) {
return "sing-box", nil
}
func (w *StackDevice) Events() <-chan wgTun.Event {
func (w *stackDevice) Events() <-chan wgTun.Event {
return w.events
}
func (w *StackDevice) Close() error {
func (w *stackDevice) Close() error {
close(w.done)
close(w.events)
w.stack.Close()
@ -240,13 +213,13 @@ func (w *StackDevice) Close() error {
return nil
}
func (w *StackDevice) BatchSize() int {
func (w *stackDevice) BatchSize() int {
return 1
}
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
type wireEndpoint StackDevice
type wireEndpoint stackDevice
func (ep *wireEndpoint) MTU() uint32 {
return ep.mtu

View file

@ -2,12 +2,12 @@
package wireguard
import (
"net/netip"
import "github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun"
)
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) {
func newStackDevice(options DeviceOptions) (Device, error) {
return nil, tun.ErrGVisorNotIncluded
}
func newSystemStackDevice(options DeviceOptions) (Device, error) {
return nil, tun.ErrGVisorNotIncluded
}

View file

@ -6,96 +6,89 @@ import (
"net"
"net/netip"
"os"
"runtime"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun"
)
var _ Device = (*SystemDevice)(nil)
var _ Device = (*systemDevice)(nil)
type SystemDevice struct {
dialer N.Dialer
device tun.Tun
batchDevice tun.LinuxTUN
name string
mtu uint32
inet4Addresses []netip.Prefix
inet6Addresses []netip.Prefix
gso bool
events chan wgTun.Event
closeOnce sync.Once
type systemDevice struct {
options DeviceOptions
dialer N.Dialer
device tun.Tun
batchDevice tun.LinuxTUN
events chan wgTun.Event
closeOnce sync.Once
}
func NewSystemDevice(networkManager adapter.NetworkManager, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) {
var inet4Addresses []netip.Prefix
var inet6Addresses []netip.Prefix
for _, prefixes := range localPrefixes {
if prefixes.Addr().Is4() {
inet4Addresses = append(inet4Addresses, prefixes)
} else {
inet6Addresses = append(inet6Addresses, prefixes)
}
func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
if options.Name == "" {
options.Name = tun.CalculateInterfaceName("wg")
}
if interfaceName == "" {
interfaceName = tun.CalculateInterfaceName("wg")
}
return &SystemDevice{
dialer: common.Must1(dialer.NewDefault(networkManager, option.DialerOptions{
BindInterface: interfaceName,
})),
name: interfaceName,
mtu: mtu,
inet4Addresses: inet4Addresses,
inet6Addresses: inet6Addresses,
gso: gso,
events: make(chan wgTun.Event, 1),
return &systemDevice{
options: options,
dialer: options.CreateDialer(options.Name),
events: make(chan wgTun.Event, 1),
}, nil
}
func (w *SystemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
func (w *systemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return w.dialer.DialContext(ctx, network, destination)
}
func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return w.dialer.ListenPacket(ctx, destination)
}
func (w *SystemDevice) Inet4Address() netip.Addr {
if len(w.inet4Addresses) == 0 {
return netip.Addr{}
}
return w.inet4Addresses[0].Addr()
func (w *systemDevice) SetDevice(device *device.Device) {
}
func (w *SystemDevice) Inet6Address() netip.Addr {
if len(w.inet6Addresses) == 0 {
return netip.Addr{}
func (w *systemDevice) Start() error {
networkManager := service.FromContext[adapter.NetworkManager](w.options.Context)
tunOptions := tun.Options{
Name: w.options.Name,
Inet4Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
return it.Addr().Is4()
}),
Inet6Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
return it.Addr().Is6()
}),
MTU: w.options.MTU,
GSO: w.options.GSO,
InterfaceScope: true,
Inet4RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool {
return it.Addr().Is4()
}),
Inet6RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool { return it.Addr().Is6() }),
InterfaceMonitor: networkManager.InterfaceMonitor(),
InterfaceFinder: networkManager.InterfaceFinder(),
Logger: w.options.Logger,
}
return w.inet6Addresses[0].Addr()
}
func (w *SystemDevice) Start() error {
tunInterface, err := tun.New(tun.Options{
Name: w.name,
Inet4Address: w.inet4Addresses,
Inet6Address: w.inet6Addresses,
MTU: w.mtu,
GSO: w.gso,
})
// works with Linux, macOS with IFSCOPE routes, not tested on Windows
if runtime.GOOS == "darwin" {
tunOptions.AutoRoute = true
}
tunInterface, err := tun.New(tunOptions)
if err != nil {
return err
}
err = tunInterface.Start()
if err != nil {
return err
}
w.options.Logger.Info("started at ", w.options.Name)
w.device = tunInterface
if w.gso {
if w.options.GSO {
batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
if !isBatchTUN {
tunInterface.Close()
@ -107,15 +100,15 @@ func (w *SystemDevice) Start() error {
return nil
}
func (w *SystemDevice) File() *os.File {
func (w *systemDevice) File() *os.File {
return nil
}
func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
func (w *systemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
if w.batchDevice != nil {
count, err = w.batchDevice.BatchRead(bufs, offset, sizes)
count, err = w.batchDevice.BatchRead(bufs, offset-tun.PacketOffset, sizes)
} else {
sizes[0], err = w.device.Read(bufs[0][offset:])
sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
if err == nil {
count = 1
} else if errors.Is(err, tun.ErrTooManySegments) {
@ -125,12 +118,16 @@ func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int,
return
}
func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
func (w *systemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
if w.batchDevice != nil {
return 0, w.batchDevice.BatchWrite(bufs, offset)
return w.batchDevice.BatchWrite(bufs, offset)
} else {
for _, b := range bufs {
_, err = w.device.Write(b[offset:])
for _, packet := range bufs {
if tun.PacketOffset > 0 {
common.ClearArray(packet[offset-tun.PacketOffset : offset])
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
}
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
if err != nil {
return
}
@ -140,28 +137,28 @@ func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
return
}
func (w *SystemDevice) Flush() error {
func (w *systemDevice) Flush() error {
return nil
}
func (w *SystemDevice) MTU() (int, error) {
return int(w.mtu), nil
func (w *systemDevice) MTU() (int, error) {
return int(w.options.MTU), nil
}
func (w *SystemDevice) Name() (string, error) {
return w.name, nil
func (w *systemDevice) Name() (string, error) {
return w.options.Name, nil
}
func (w *SystemDevice) Events() <-chan wgTun.Event {
func (w *systemDevice) Events() <-chan wgTun.Event {
return w.events
}
func (w *SystemDevice) Close() error {
func (w *systemDevice) Close() error {
close(w.events)
return w.device.Close()
}
func (w *SystemDevice) BatchSize() int {
func (w *systemDevice) BatchSize() int {
if w.batchDevice != nil {
return w.batchDevice.BatchSize()
}

View file

@ -0,0 +1,182 @@
//go:build with_gvisor
package wireguard
import (
"net/netip"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
"github.com/sagernet/wireguard-go/device"
)
var _ Device = (*systemStackDevice)(nil)
type systemStackDevice struct {
*systemDevice
stack *stack.Stack
endpoint *deviceEndpoint
writeBufs [][]byte
}
func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
system, err := newSystemDevice(options)
if err != nil {
return nil, err
}
endpoint := &deviceEndpoint{
mtu: options.MTU,
done: make(chan struct{}),
}
ipStack, err := tun.NewGVisorStack(endpoint)
if err != nil {
return nil, err
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
return &systemStackDevice{
systemDevice: system,
stack: ipStack,
endpoint: endpoint,
}, nil
}
func (w *systemStackDevice) SetDevice(device *device.Device) {
w.endpoint.device = device
}
func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
if w.batchDevice != nil {
w.writeBufs = w.writeBufs[:0]
for _, packet := range bufs {
if !w.writeStack(packet[offset:]) {
w.writeBufs = append(w.writeBufs, packet)
}
}
if len(w.writeBufs) > 0 {
return w.batchDevice.BatchWrite(bufs, offset)
}
} else {
for _, packet := range bufs {
if !w.writeStack(packet[offset:]) {
if tun.PacketOffset > 0 {
common.ClearArray(packet[offset-tun.PacketOffset : offset])
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
}
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
}
if err != nil {
return
}
}
}
// WireGuard will not read count
return
}
func (w *systemStackDevice) Close() error {
close(w.endpoint.done)
w.stack.Close()
for _, endpoint := range w.stack.CleanupEndpoints() {
endpoint.Abort()
}
w.stack.Wait()
return w.systemDevice.Close()
}
func (w *systemStackDevice) writeStack(packet []byte) bool {
var (
networkProtocol tcpip.NetworkProtocolNumber
destination netip.Addr
)
switch header.IPVersion(packet) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
destination = netip.AddrFrom4(header.IPv4(packet).DestinationAddress().As4())
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
destination = netip.AddrFrom16(header.IPv6(packet).DestinationAddress().As16())
}
for _, prefix := range w.options.Address {
if prefix.Contains(destination) {
return false
}
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
w.endpoint.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
packetBuffer.DecRef()
return true
}
type deviceEndpoint struct {
mtu uint32
done chan struct{}
device *device.Device
dispatcher stack.NetworkDispatcher
}
func (ep *deviceEndpoint) MTU() uint32 {
return ep.mtu
}
func (ep *deviceEndpoint) SetMTU(mtu uint32) {
}
func (ep *deviceEndpoint) MaxHeaderLength() uint16 {
return 0
}
func (ep *deviceEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (ep *deviceEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
}
func (ep *deviceEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityRXChecksumOffload
}
func (ep *deviceEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
ep.dispatcher = dispatcher
}
func (ep *deviceEndpoint) IsAttached() bool {
return ep.dispatcher != nil
}
func (ep *deviceEndpoint) Wait() {
}
func (ep *deviceEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (ep *deviceEndpoint) AddHeader(buffer *stack.PacketBuffer) {
}
func (ep *deviceEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
return true
}
func (ep *deviceEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
for _, packetBuffer := range list.AsSlice() {
destination := packetBuffer.Network().DestinationAddress()
ep.device.InputPacket(destination.AsSlice(), packetBuffer.AsSlices())
}
return list.Len(), nil
}
func (ep *deviceEndpoint) Close() {
}
func (ep *deviceEndpoint) SetOnCloseAction(f func()) {
}

View file

@ -1,35 +1,260 @@
package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip"
"os"
"strings"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
"go4.org/netipx"
)
var _ conn.Endpoint = (*Endpoint)(nil)
type Endpoint netip.AddrPort
func (e Endpoint) ClearSrc() {
type Endpoint struct {
options EndpointOptions
peers []peerConfig
ipcConf string
allowedAddress []netip.Prefix
tunDevice Device
device *device.Device
pauseManager pause.Manager
pauseCallback *list.Element[pause.Callback]
}
func (e Endpoint) SrcToString() string {
return ""
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
if options.PrivateKey == "" {
return nil, E.New("missing private key")
}
privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
if err != nil {
return nil, E.Cause(err, "decode private key")
}
privateKey := hex.EncodeToString(privateKeyBytes)
ipcConf := "private_key=" + privateKey
if options.ListenPort != 0 {
ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
}
var peers []peerConfig
for peerIndex, rawPeer := range options.Peers {
peer := peerConfig{
allowedIPs: rawPeer.AllowedIPs,
keepalive: rawPeer.PersistentKeepaliveInterval,
}
if rawPeer.Endpoint.Addr.IsValid() {
peer.endpoint = rawPeer.Endpoint.AddrPort()
} else if rawPeer.Endpoint.IsFqdn() {
peer.destination = rawPeer.Endpoint
}
publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
}
peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
if rawPeer.PreSharedKey != "" {
preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
}
peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
}
if len(rawPeer.AllowedIPs) == 0 {
return nil, E.New("missing allowed ips for peer ", peerIndex)
}
if len(rawPeer.Reserved) > 0 {
if len(rawPeer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
}
copy(peer.reserved[:], rawPeer.Reserved[:])
}
peers = append(peers, peer)
}
var allowedPrefixBuilder netipx.IPSetBuilder
for _, peer := range options.Peers {
for _, prefix := range peer.AllowedIPs {
allowedPrefixBuilder.AddPrefix(prefix)
}
}
allowedIPSet, err := allowedPrefixBuilder.IPSet()
if err != nil {
return nil, err
}
allowedAddresses := allowedIPSet.Prefixes()
if options.MTU == 0 {
options.MTU = 1408
}
deviceOptions := DeviceOptions{
Context: options.Context,
Logger: options.Logger,
System: options.System,
Handler: options.Handler,
UDPTimeout: options.UDPTimeout,
CreateDialer: options.CreateDialer,
Name: options.Name,
MTU: options.MTU,
GSO: options.GSO,
Address: options.Address,
AllowedAddress: allowedAddresses,
}
tunDevice, err := NewDevice(deviceOptions)
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
return &Endpoint{
options: options,
peers: peers,
ipcConf: ipcConf,
allowedAddress: allowedAddresses,
tunDevice: tunDevice,
}, nil
}
func (e Endpoint) DstToString() string {
return (netip.AddrPort)(e).String()
func (e *Endpoint) Start(resolve bool) error {
if common.Any(e.peers, func(peer peerConfig) bool {
return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
}) {
if !resolve {
return nil
}
for peerIndex, peer := range e.peers {
if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
continue
}
destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
if err != nil {
return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
}
e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
}
} else if resolve {
return nil
}
var bind conn.Bind
wgListener, isWgListener := e.options.Dialer.(conn.Listener)
if isWgListener {
bind = conn.NewStdNetBind(wgListener)
} else {
var (
isConnect bool
connectAddr netip.AddrPort
reserved [3]uint8
)
if len(e.peers) == 1 {
isConnect = true
connectAddr = e.peers[0].endpoint
reserved = e.peers[0].reserved
}
bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
}
if isWgListener || len(e.peers) > 1 {
for _, peer := range e.peers {
if peer.reserved != [3]uint8{} {
bind.SetReservedForEndpoint(peer.endpoint, peer.reserved)
}
}
}
err := e.tunDevice.Start()
if err != nil {
return err
}
logger := &device.Logger{
Verbosef: func(format string, args ...interface{}) {
e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
},
Errorf: func(format string, args ...interface{}) {
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
},
}
wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
e.tunDevice.SetDevice(wgDevice)
ipcConf := e.ipcConf
for _, peer := range e.peers {
ipcConf += peer.GenerateIpcLines()
}
err = wgDevice.IpcSet(ipcConf)
if err != nil {
return E.Cause(err, "setup wireguard: \n", ipcConf)
}
e.device = wgDevice
e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
if e.pauseManager != nil {
e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
}
return nil
}
func (e Endpoint) DstToBytes() []byte {
b, _ := (netip.AddrPort)(e).MarshalBinary()
return b
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if !destination.Addr.IsValid() {
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
return e.tunDevice.DialContext(ctx, network, destination)
}
func (e Endpoint) DstIP() netip.Addr {
return (netip.AddrPort)(e).Addr()
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if !destination.Addr.IsValid() {
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
return e.tunDevice.ListenPacket(ctx, destination)
}
func (e Endpoint) SrcIP() netip.Addr {
return netip.Addr{}
func (e *Endpoint) BindUpdate() error {
return e.device.BindUpdate()
}
func (e *Endpoint) Close() error {
if e.device != nil {
e.device.Close()
}
if e.pauseCallback != nil {
e.pauseManager.UnregisterCallback(e.pauseCallback)
}
return nil
}
func (e *Endpoint) onPauseUpdated(event int) {
switch event {
case pause.EventDevicePaused:
e.device.Down()
case pause.EventDeviceWake:
e.device.Up()
}
}
type peerConfig struct {
destination M.Socksaddr
endpoint netip.AddrPort
publicKeyHex string
preSharedKeyHex string
allowedIPs []netip.Prefix
keepalive uint16
reserved [3]uint8
}
func (c peerConfig) GenerateIpcLines() string {
ipcLines := "\npublic_key=" + c.publicKeyHex
if c.endpoint.IsValid() {
ipcLines += "\nendpoint=" + c.endpoint.String()
}
if c.preSharedKeyHex != "" {
ipcLines += "\npreshared_key=" + c.preSharedKeyHex
}
for _, allowedIP := range c.allowedIPs {
ipcLines += "\nallowed_ip=" + allowedIP.String()
}
if c.keepalive > 0 {
ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
}
return ipcLines
}

View file

@ -0,0 +1,40 @@
package wireguard
import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type EndpointOptions struct {
Context context.Context
Logger logger.ContextLogger
System bool
Handler tun.Handler
UDPTimeout time.Duration
Dialer N.Dialer
CreateDialer func(interfaceName string) N.Dialer
Name string
MTU uint32
GSO bool
Address []netip.Prefix
PrivateKey string
ListenPort uint16
ResolvePeer func(domain string) (netip.Addr, error)
Peers []PeerOptions
Workers int
}
type PeerOptions struct {
Endpoint M.Socksaddr
PublicKey string
PreSharedKey string
AllowedIPs []netip.Prefix
PersistentKeepaliveInterval uint16
Reserved []uint8
}

View file

@ -1,148 +0,0 @@
package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"net/netip"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
type PeerConfig struct {
destination M.Socksaddr
domainStrategy dns.DomainStrategy
Endpoint netip.AddrPort
PublicKey string
PreSharedKey string
AllowedIPs []string
Reserved [3]uint8
}
func (c PeerConfig) GenerateIpcLines() string {
ipcLines := "\npublic_key=" + c.PublicKey
ipcLines += "\nendpoint=" + c.Endpoint.String()
if c.PreSharedKey != "" {
ipcLines += "\npreshared_key=" + c.PreSharedKey
}
for _, allowedIP := range c.AllowedIPs {
ipcLines += "\nallowed_ip=" + allowedIP
}
return ipcLines
}
func ParsePeers(options option.WireGuardOutboundOptions) ([]PeerConfig, error) {
var peers []PeerConfig
if len(options.Peers) > 0 {
for peerIndex, rawPeer := range options.Peers {
peer := PeerConfig{
AllowedIPs: rawPeer.AllowedIPs,
}
destination := rawPeer.ServerOptions.Build()
if destination.IsFqdn() {
peer.destination = destination
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
} else {
peer.Endpoint = destination.AddrPort()
}
{
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
}
peer.PublicKey = hex.EncodeToString(bytes)
}
if rawPeer.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
}
peer.PreSharedKey = hex.EncodeToString(bytes)
}
if len(rawPeer.AllowedIPs) == 0 {
return nil, E.New("missing allowed_ips for peer ", peerIndex)
}
if len(rawPeer.Reserved) > 0 {
if len(rawPeer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.Reserved))
}
copy(peer.Reserved[:], options.Reserved)
}
peers = append(peers, peer)
}
} else {
peer := PeerConfig{}
var (
addressHas4 bool
addressHas6 bool
)
for _, localAddress := range options.LocalAddress {
if localAddress.Addr().Is4() {
addressHas4 = true
} else {
addressHas6 = true
}
}
if addressHas4 {
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv4Unspecified(), 0).String())
}
if addressHas6 {
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv6Unspecified(), 0).String())
}
destination := options.ServerOptions.Build()
if destination.IsFqdn() {
peer.destination = destination
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
} else {
peer.Endpoint = destination.AddrPort()
}
{
bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
if err != nil {
return nil, E.Cause(err, "decode peer public key")
}
peer.PublicKey = hex.EncodeToString(bytes)
}
if options.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key")
}
peer.PreSharedKey = hex.EncodeToString(bytes)
}
if len(options.Reserved) > 0 {
if len(options.Reserved) != 3 {
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(peer.Reserved))
}
copy(peer.Reserved[:], options.Reserved)
}
peers = append(peers, peer)
}
return peers, nil
}
func ResolvePeers(ctx context.Context, router adapter.Router, peers []PeerConfig) error {
for peerIndex, peer := range peers {
if peer.Endpoint.IsValid() {
continue
}
destinationAddresses, err := router.Lookup(ctx, peer.destination.Fqdn, peer.domainStrategy)
if err != nil {
if len(peers) == 1 {
return E.Cause(err, "resolve endpoint domain")
} else {
return E.Cause(err, "resolve endpoint domain for peer ", peerIndex)
}
}
if len(destinationAddresses) == 0 {
return E.New("no addresses found for endpoint domain: ", peer.destination.Fqdn)
}
peers[peerIndex].Endpoint = netip.AddrPortFrom(destinationAddresses[0], peer.destination.Port)
}
return nil
}