mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-04-03 03:47:37 +03:00
refactor: WireGuard endpoint
This commit is contained in:
parent
fd299a0961
commit
68781387fe
89 changed files with 2187 additions and 679 deletions
|
@ -128,7 +128,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
|
|||
select {
|
||||
case <-c.done:
|
||||
default:
|
||||
c.logger.Error(context.Background(), E.Cause(err, "read packet"))
|
||||
c.logger.Error(E.Cause(err, "read packet"))
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
|
@ -138,7 +138,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
|
|||
b := packets[0]
|
||||
common.ClearArray(b[1:4])
|
||||
}
|
||||
eps[0] = Endpoint(M.AddrPortFromNet(addr))
|
||||
eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
|
||||
count = 1
|
||||
return
|
||||
}
|
||||
|
@ -169,7 +169,7 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
|||
time.Sleep(time.Second)
|
||||
return err
|
||||
}
|
||||
destination := netip.AddrPort(ep.(Endpoint))
|
||||
destination := netip.AddrPort(ep.(remoteEndpoint))
|
||||
for _, b := range bufs {
|
||||
if len(b) > 3 {
|
||||
reserved, loaded := c.reservedForEndpoint[destination]
|
||||
|
@ -192,7 +192,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Endpoint(ap), nil
|
||||
return remoteEndpoint(ap), nil
|
||||
}
|
||||
|
||||
func (c *ClientBind) BatchSize() int {
|
||||
|
@ -229,3 +229,31 @@ func (w *wireConn) Close() error {
|
|||
close(w.done)
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ conn.Endpoint = (*remoteEndpoint)(nil)
|
||||
|
||||
type remoteEndpoint netip.AddrPort
|
||||
|
||||
func (e remoteEndpoint) ClearSrc() {
|
||||
}
|
||||
|
||||
func (e remoteEndpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e remoteEndpoint) DstToString() string {
|
||||
return (netip.AddrPort)(e).String()
|
||||
}
|
||||
|
||||
func (e remoteEndpoint) DstToBytes() []byte {
|
||||
b, _ := (netip.AddrPort)(e).MarshalBinary()
|
||||
return b
|
||||
}
|
||||
|
||||
func (e remoteEndpoint) DstIP() netip.Addr {
|
||||
return (netip.AddrPort)(e).Addr()
|
||||
}
|
||||
|
||||
func (e remoteEndpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,44 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/wireguard-go/tun"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||
)
|
||||
|
||||
type Device interface {
|
||||
tun.Device
|
||||
wgTun.Device
|
||||
N.Dialer
|
||||
Start() error
|
||||
// NewEndpoint() (stack.LinkEndpoint, error)
|
||||
SetDevice(device *device.Device)
|
||||
}
|
||||
|
||||
type DeviceOptions struct {
|
||||
Context context.Context
|
||||
Logger logger.ContextLogger
|
||||
System bool
|
||||
Handler tun.Handler
|
||||
UDPTimeout time.Duration
|
||||
CreateDialer func(interfaceName string) N.Dialer
|
||||
Name string
|
||||
MTU uint32
|
||||
GSO bool
|
||||
Address []netip.Prefix
|
||||
AllowedAddress []netip.Prefix
|
||||
}
|
||||
|
||||
func NewDevice(options DeviceOptions) (Device, error) {
|
||||
if !options.System {
|
||||
return newStackDevice(options)
|
||||
} else if options.Handler == nil {
|
||||
return newSystemDevice(options)
|
||||
} else {
|
||||
return newSystemStackDevice(options)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ package wireguard
|
|||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
|
@ -15,52 +14,41 @@ import (
|
|||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||
)
|
||||
|
||||
var _ Device = (*StackDevice)(nil)
|
||||
var _ Device = (*stackDevice)(nil)
|
||||
|
||||
const defaultNIC tcpip.NICID = 1
|
||||
|
||||
type StackDevice struct {
|
||||
stack *stack.Stack
|
||||
mtu uint32
|
||||
events chan wgTun.Event
|
||||
outbound chan *stack.PacketBuffer
|
||||
packetOutbound chan *buf.Buffer
|
||||
done chan struct{}
|
||||
dispatcher stack.NetworkDispatcher
|
||||
addr4 tcpip.Address
|
||||
addr6 tcpip.Address
|
||||
type stackDevice struct {
|
||||
stack *stack.Stack
|
||||
mtu uint32
|
||||
events chan wgTun.Event
|
||||
outbound chan *stack.PacketBuffer
|
||||
done chan struct{}
|
||||
dispatcher stack.NetworkDispatcher
|
||||
addr4 tcpip.Address
|
||||
addr6 tcpip.Address
|
||||
}
|
||||
|
||||
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
|
||||
ipStack := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
|
||||
HandleLocal: true,
|
||||
})
|
||||
tunDevice := &StackDevice{
|
||||
stack: ipStack,
|
||||
mtu: mtu,
|
||||
events: make(chan wgTun.Event, 1),
|
||||
outbound: make(chan *stack.PacketBuffer, 256),
|
||||
packetOutbound: make(chan *buf.Buffer, 256),
|
||||
done: make(chan struct{}),
|
||||
func newStackDevice(options DeviceOptions) (*stackDevice, error) {
|
||||
tunDevice := &stackDevice{
|
||||
mtu: options.MTU,
|
||||
events: make(chan wgTun.Event, 1),
|
||||
outbound: make(chan *stack.PacketBuffer, 256),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
|
||||
ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
|
||||
if err != nil {
|
||||
return nil, E.New(err.String())
|
||||
return nil, err
|
||||
}
|
||||
for _, prefix := range localAddresses {
|
||||
for _, prefix := range options.Address {
|
||||
addr := tun.AddressFromAddr(prefix.Addr())
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
|
@ -75,32 +63,27 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
|
|||
tunDevice.addr6 = addr
|
||||
protoAddr.Protocol = ipv6.ProtocolNumber
|
||||
}
|
||||
err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
|
||||
if err != nil {
|
||||
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
|
||||
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
|
||||
if gErr != nil {
|
||||
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
|
||||
}
|
||||
}
|
||||
sOpt := tcpip.TCPSACKEnabled(true)
|
||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
|
||||
cOpt := tcpip.CongestionControlOption("cubic")
|
||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
|
||||
ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
|
||||
ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
|
||||
tunDevice.stack = ipStack
|
||||
if options.Handler != nil {
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
|
||||
}
|
||||
return tunDevice, nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
|
||||
return (*wireEndpoint)(w), nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
addr := tcpip.FullAddress{
|
||||
NIC: defaultNIC,
|
||||
NIC: tun.DefaultNIC,
|
||||
Port: destination.Port,
|
||||
Addr: tun.AddressFromAddr(destination.Addr),
|
||||
}
|
||||
bind := tcpip.FullAddress{
|
||||
NIC: defaultNIC,
|
||||
NIC: tun.DefaultNIC,
|
||||
}
|
||||
var networkProtocol tcpip.NetworkProtocolNumber
|
||||
if destination.IsIPv4() {
|
||||
|
@ -128,9 +111,9 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati
|
|||
}
|
||||
}
|
||||
|
||||
func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
bind := tcpip.FullAddress{
|
||||
NIC: defaultNIC,
|
||||
NIC: tun.DefaultNIC,
|
||||
}
|
||||
var networkProtocol tcpip.NetworkProtocolNumber
|
||||
if destination.IsIPv4() {
|
||||
|
@ -147,24 +130,19 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
|
|||
return udpConn, nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) Inet4Address() netip.Addr {
|
||||
return tun.AddrFromAddress(w.addr4)
|
||||
func (w *stackDevice) SetDevice(device *device.Device) {
|
||||
}
|
||||
|
||||
func (w *StackDevice) Inet6Address() netip.Addr {
|
||||
return tun.AddrFromAddress(w.addr6)
|
||||
}
|
||||
|
||||
func (w *StackDevice) Start() error {
|
||||
func (w *stackDevice) Start() error {
|
||||
w.events <- wgTun.EventUp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) File() *os.File {
|
||||
func (w *stackDevice) File() *os.File {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||
func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||
select {
|
||||
case packetBuffer, ok := <-w.outbound:
|
||||
if !ok {
|
||||
|
@ -180,17 +158,12 @@ func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, e
|
|||
sizes[0] = n
|
||||
count = 1
|
||||
return
|
||||
case packet := <-w.packetOutbound:
|
||||
defer packet.Release()
|
||||
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
|
||||
count = 1
|
||||
return
|
||||
case <-w.done:
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
for _, b := range bufs {
|
||||
b = b[offset:]
|
||||
if len(b) == 0 {
|
||||
|
@ -213,23 +186,23 @@ func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (w *StackDevice) Flush() error {
|
||||
func (w *stackDevice) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) MTU() (int, error) {
|
||||
func (w *stackDevice) MTU() (int, error) {
|
||||
return int(w.mtu), nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) Name() (string, error) {
|
||||
func (w *stackDevice) Name() (string, error) {
|
||||
return "sing-box", nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) Events() <-chan wgTun.Event {
|
||||
func (w *stackDevice) Events() <-chan wgTun.Event {
|
||||
return w.events
|
||||
}
|
||||
|
||||
func (w *StackDevice) Close() error {
|
||||
func (w *stackDevice) Close() error {
|
||||
close(w.done)
|
||||
close(w.events)
|
||||
w.stack.Close()
|
||||
|
@ -240,13 +213,13 @@ func (w *StackDevice) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) BatchSize() int {
|
||||
func (w *stackDevice) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
|
||||
|
||||
type wireEndpoint StackDevice
|
||||
type wireEndpoint stackDevice
|
||||
|
||||
func (ep *wireEndpoint) MTU() uint32 {
|
||||
return ep.mtu
|
||||
|
|
|
@ -2,12 +2,12 @@
|
|||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
import "github.com/sagernet/sing-tun"
|
||||
|
||||
"github.com/sagernet/sing-tun"
|
||||
)
|
||||
|
||||
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) {
|
||||
func newStackDevice(options DeviceOptions) (Device, error) {
|
||||
return nil, tun.ErrGVisorNotIncluded
|
||||
}
|
||||
|
||||
func newSystemStackDevice(options DeviceOptions) (Device, error) {
|
||||
return nil, tun.ErrGVisorNotIncluded
|
||||
}
|
||||
|
|
|
@ -6,96 +6,89 @@ import (
|
|||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/service"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||
)
|
||||
|
||||
var _ Device = (*SystemDevice)(nil)
|
||||
var _ Device = (*systemDevice)(nil)
|
||||
|
||||
type SystemDevice struct {
|
||||
dialer N.Dialer
|
||||
device tun.Tun
|
||||
batchDevice tun.LinuxTUN
|
||||
name string
|
||||
mtu uint32
|
||||
inet4Addresses []netip.Prefix
|
||||
inet6Addresses []netip.Prefix
|
||||
gso bool
|
||||
events chan wgTun.Event
|
||||
closeOnce sync.Once
|
||||
type systemDevice struct {
|
||||
options DeviceOptions
|
||||
dialer N.Dialer
|
||||
device tun.Tun
|
||||
batchDevice tun.LinuxTUN
|
||||
events chan wgTun.Event
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewSystemDevice(networkManager adapter.NetworkManager, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) {
|
||||
var inet4Addresses []netip.Prefix
|
||||
var inet6Addresses []netip.Prefix
|
||||
for _, prefixes := range localPrefixes {
|
||||
if prefixes.Addr().Is4() {
|
||||
inet4Addresses = append(inet4Addresses, prefixes)
|
||||
} else {
|
||||
inet6Addresses = append(inet6Addresses, prefixes)
|
||||
}
|
||||
func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
|
||||
if options.Name == "" {
|
||||
options.Name = tun.CalculateInterfaceName("wg")
|
||||
}
|
||||
if interfaceName == "" {
|
||||
interfaceName = tun.CalculateInterfaceName("wg")
|
||||
}
|
||||
|
||||
return &SystemDevice{
|
||||
dialer: common.Must1(dialer.NewDefault(networkManager, option.DialerOptions{
|
||||
BindInterface: interfaceName,
|
||||
})),
|
||||
name: interfaceName,
|
||||
mtu: mtu,
|
||||
inet4Addresses: inet4Addresses,
|
||||
inet6Addresses: inet6Addresses,
|
||||
gso: gso,
|
||||
events: make(chan wgTun.Event, 1),
|
||||
return &systemDevice{
|
||||
options: options,
|
||||
dialer: options.CreateDialer(options.Name),
|
||||
events: make(chan wgTun.Event, 1),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
func (w *systemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
return w.dialer.DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return w.dialer.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Inet4Address() netip.Addr {
|
||||
if len(w.inet4Addresses) == 0 {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return w.inet4Addresses[0].Addr()
|
||||
func (w *systemDevice) SetDevice(device *device.Device) {
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Inet6Address() netip.Addr {
|
||||
if len(w.inet6Addresses) == 0 {
|
||||
return netip.Addr{}
|
||||
func (w *systemDevice) Start() error {
|
||||
networkManager := service.FromContext[adapter.NetworkManager](w.options.Context)
|
||||
tunOptions := tun.Options{
|
||||
Name: w.options.Name,
|
||||
Inet4Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is4()
|
||||
}),
|
||||
Inet6Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is6()
|
||||
}),
|
||||
MTU: w.options.MTU,
|
||||
GSO: w.options.GSO,
|
||||
InterfaceScope: true,
|
||||
Inet4RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is4()
|
||||
}),
|
||||
Inet6RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool { return it.Addr().Is6() }),
|
||||
InterfaceMonitor: networkManager.InterfaceMonitor(),
|
||||
InterfaceFinder: networkManager.InterfaceFinder(),
|
||||
Logger: w.options.Logger,
|
||||
}
|
||||
return w.inet6Addresses[0].Addr()
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Start() error {
|
||||
tunInterface, err := tun.New(tun.Options{
|
||||
Name: w.name,
|
||||
Inet4Address: w.inet4Addresses,
|
||||
Inet6Address: w.inet6Addresses,
|
||||
MTU: w.mtu,
|
||||
GSO: w.gso,
|
||||
})
|
||||
// works with Linux, macOS with IFSCOPE routes, not tested on Windows
|
||||
if runtime.GOOS == "darwin" {
|
||||
tunOptions.AutoRoute = true
|
||||
}
|
||||
tunInterface, err := tun.New(tunOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tunInterface.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.options.Logger.Info("started at ", w.options.Name)
|
||||
w.device = tunInterface
|
||||
if w.gso {
|
||||
if w.options.GSO {
|
||||
batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
|
||||
if !isBatchTUN {
|
||||
tunInterface.Close()
|
||||
|
@ -107,15 +100,15 @@ func (w *SystemDevice) Start() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) File() *os.File {
|
||||
func (w *systemDevice) File() *os.File {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||
func (w *systemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||
if w.batchDevice != nil {
|
||||
count, err = w.batchDevice.BatchRead(bufs, offset, sizes)
|
||||
count, err = w.batchDevice.BatchRead(bufs, offset-tun.PacketOffset, sizes)
|
||||
} else {
|
||||
sizes[0], err = w.device.Read(bufs[0][offset:])
|
||||
sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
|
||||
if err == nil {
|
||||
count = 1
|
||||
} else if errors.Is(err, tun.ErrTooManySegments) {
|
||||
|
@ -125,12 +118,16 @@ func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int,
|
|||
return
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
func (w *systemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
if w.batchDevice != nil {
|
||||
return 0, w.batchDevice.BatchWrite(bufs, offset)
|
||||
return w.batchDevice.BatchWrite(bufs, offset)
|
||||
} else {
|
||||
for _, b := range bufs {
|
||||
_, err = w.device.Write(b[offset:])
|
||||
for _, packet := range bufs {
|
||||
if tun.PacketOffset > 0 {
|
||||
common.ClearArray(packet[offset-tun.PacketOffset : offset])
|
||||
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
|
||||
}
|
||||
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -140,28 +137,28 @@ func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Flush() error {
|
||||
func (w *systemDevice) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) MTU() (int, error) {
|
||||
return int(w.mtu), nil
|
||||
func (w *systemDevice) MTU() (int, error) {
|
||||
return int(w.options.MTU), nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Name() (string, error) {
|
||||
return w.name, nil
|
||||
func (w *systemDevice) Name() (string, error) {
|
||||
return w.options.Name, nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Events() <-chan wgTun.Event {
|
||||
func (w *systemDevice) Events() <-chan wgTun.Event {
|
||||
return w.events
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Close() error {
|
||||
func (w *systemDevice) Close() error {
|
||||
close(w.events)
|
||||
return w.device.Close()
|
||||
}
|
||||
|
||||
func (w *SystemDevice) BatchSize() int {
|
||||
func (w *systemDevice) BatchSize() int {
|
||||
if w.batchDevice != nil {
|
||||
return w.batchDevice.BatchSize()
|
||||
}
|
||||
|
|
182
transport/wireguard/device_system_stack.go
Normal file
182
transport/wireguard/device_system_stack.go
Normal file
|
@ -0,0 +1,182 @@
|
|||
//go:build with_gvisor
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
)
|
||||
|
||||
var _ Device = (*systemStackDevice)(nil)
|
||||
|
||||
type systemStackDevice struct {
|
||||
*systemDevice
|
||||
stack *stack.Stack
|
||||
endpoint *deviceEndpoint
|
||||
writeBufs [][]byte
|
||||
}
|
||||
|
||||
func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
|
||||
system, err := newSystemDevice(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
endpoint := &deviceEndpoint{
|
||||
mtu: options.MTU,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
ipStack, err := tun.NewGVisorStack(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
|
||||
return &systemStackDevice{
|
||||
systemDevice: system,
|
||||
stack: ipStack,
|
||||
endpoint: endpoint,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *systemStackDevice) SetDevice(device *device.Device) {
|
||||
w.endpoint.device = device
|
||||
}
|
||||
|
||||
func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
if w.batchDevice != nil {
|
||||
w.writeBufs = w.writeBufs[:0]
|
||||
for _, packet := range bufs {
|
||||
if !w.writeStack(packet[offset:]) {
|
||||
w.writeBufs = append(w.writeBufs, packet)
|
||||
}
|
||||
}
|
||||
if len(w.writeBufs) > 0 {
|
||||
return w.batchDevice.BatchWrite(bufs, offset)
|
||||
}
|
||||
} else {
|
||||
for _, packet := range bufs {
|
||||
if !w.writeStack(packet[offset:]) {
|
||||
if tun.PacketOffset > 0 {
|
||||
common.ClearArray(packet[offset-tun.PacketOffset : offset])
|
||||
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
|
||||
}
|
||||
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// WireGuard will not read count
|
||||
return
|
||||
}
|
||||
|
||||
func (w *systemStackDevice) Close() error {
|
||||
close(w.endpoint.done)
|
||||
w.stack.Close()
|
||||
for _, endpoint := range w.stack.CleanupEndpoints() {
|
||||
endpoint.Abort()
|
||||
}
|
||||
w.stack.Wait()
|
||||
return w.systemDevice.Close()
|
||||
}
|
||||
|
||||
func (w *systemStackDevice) writeStack(packet []byte) bool {
|
||||
var (
|
||||
networkProtocol tcpip.NetworkProtocolNumber
|
||||
destination netip.Addr
|
||||
)
|
||||
switch header.IPVersion(packet) {
|
||||
case header.IPv4Version:
|
||||
networkProtocol = header.IPv4ProtocolNumber
|
||||
destination = netip.AddrFrom4(header.IPv4(packet).DestinationAddress().As4())
|
||||
case header.IPv6Version:
|
||||
networkProtocol = header.IPv6ProtocolNumber
|
||||
destination = netip.AddrFrom16(header.IPv6(packet).DestinationAddress().As16())
|
||||
}
|
||||
for _, prefix := range w.options.Address {
|
||||
if prefix.Contains(destination) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
w.endpoint.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
|
||||
packetBuffer.DecRef()
|
||||
return true
|
||||
}
|
||||
|
||||
type deviceEndpoint struct {
|
||||
mtu uint32
|
||||
done chan struct{}
|
||||
device *device.Device
|
||||
dispatcher stack.NetworkDispatcher
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) MTU() uint32 {
|
||||
return ep.mtu
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) SetMTU(mtu uint32) {
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) MaxHeaderLength() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) LinkAddress() tcpip.LinkAddress {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return stack.CapabilityRXChecksumOffload
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
ep.dispatcher = dispatcher
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) IsAttached() bool {
|
||||
return ep.dispatcher != nil
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) Wait() {
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
return header.ARPHardwareNone
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) AddHeader(buffer *stack.PacketBuffer) {
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
|
||||
for _, packetBuffer := range list.AsSlice() {
|
||||
destination := packetBuffer.Network().DestinationAddress()
|
||||
ep.device.InputPacket(destination.AsSlice(), packetBuffer.AsSlices())
|
||||
}
|
||||
return list.Len(), nil
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) Close() {
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) SetOnCloseAction(f func()) {
|
||||
}
|
|
@ -1,35 +1,260 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
"github.com/sagernet/sing/service"
|
||||
"github.com/sagernet/sing/service/pause"
|
||||
"github.com/sagernet/wireguard-go/conn"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
var _ conn.Endpoint = (*Endpoint)(nil)
|
||||
|
||||
type Endpoint netip.AddrPort
|
||||
|
||||
func (e Endpoint) ClearSrc() {
|
||||
type Endpoint struct {
|
||||
options EndpointOptions
|
||||
peers []peerConfig
|
||||
ipcConf string
|
||||
allowedAddress []netip.Prefix
|
||||
tunDevice Device
|
||||
device *device.Device
|
||||
pauseManager pause.Manager
|
||||
pauseCallback *list.Element[pause.Callback]
|
||||
}
|
||||
|
||||
func (e Endpoint) SrcToString() string {
|
||||
return ""
|
||||
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
|
||||
if options.PrivateKey == "" {
|
||||
return nil, E.New("missing private key")
|
||||
}
|
||||
privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode private key")
|
||||
}
|
||||
privateKey := hex.EncodeToString(privateKeyBytes)
|
||||
ipcConf := "private_key=" + privateKey
|
||||
if options.ListenPort != 0 {
|
||||
ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
|
||||
}
|
||||
var peers []peerConfig
|
||||
for peerIndex, rawPeer := range options.Peers {
|
||||
peer := peerConfig{
|
||||
allowedIPs: rawPeer.AllowedIPs,
|
||||
keepalive: rawPeer.PersistentKeepaliveInterval,
|
||||
}
|
||||
if rawPeer.Endpoint.Addr.IsValid() {
|
||||
peer.endpoint = rawPeer.Endpoint.AddrPort()
|
||||
} else if rawPeer.Endpoint.IsFqdn() {
|
||||
peer.destination = rawPeer.Endpoint
|
||||
}
|
||||
publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
|
||||
}
|
||||
peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
|
||||
if rawPeer.PreSharedKey != "" {
|
||||
preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
|
||||
}
|
||||
peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
|
||||
}
|
||||
if len(rawPeer.AllowedIPs) == 0 {
|
||||
return nil, E.New("missing allowed ips for peer ", peerIndex)
|
||||
}
|
||||
if len(rawPeer.Reserved) > 0 {
|
||||
if len(rawPeer.Reserved) != 3 {
|
||||
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
|
||||
}
|
||||
copy(peer.reserved[:], rawPeer.Reserved[:])
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
var allowedPrefixBuilder netipx.IPSetBuilder
|
||||
for _, peer := range options.Peers {
|
||||
for _, prefix := range peer.AllowedIPs {
|
||||
allowedPrefixBuilder.AddPrefix(prefix)
|
||||
}
|
||||
}
|
||||
allowedIPSet, err := allowedPrefixBuilder.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowedAddresses := allowedIPSet.Prefixes()
|
||||
if options.MTU == 0 {
|
||||
options.MTU = 1408
|
||||
}
|
||||
deviceOptions := DeviceOptions{
|
||||
Context: options.Context,
|
||||
Logger: options.Logger,
|
||||
System: options.System,
|
||||
Handler: options.Handler,
|
||||
UDPTimeout: options.UDPTimeout,
|
||||
CreateDialer: options.CreateDialer,
|
||||
Name: options.Name,
|
||||
MTU: options.MTU,
|
||||
GSO: options.GSO,
|
||||
Address: options.Address,
|
||||
AllowedAddress: allowedAddresses,
|
||||
}
|
||||
tunDevice, err := NewDevice(deviceOptions)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create WireGuard device")
|
||||
}
|
||||
return &Endpoint{
|
||||
options: options,
|
||||
peers: peers,
|
||||
ipcConf: ipcConf,
|
||||
allowedAddress: allowedAddresses,
|
||||
tunDevice: tunDevice,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e Endpoint) DstToString() string {
|
||||
return (netip.AddrPort)(e).String()
|
||||
func (e *Endpoint) Start(resolve bool) error {
|
||||
if common.Any(e.peers, func(peer peerConfig) bool {
|
||||
return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
|
||||
}) {
|
||||
if !resolve {
|
||||
return nil
|
||||
}
|
||||
for peerIndex, peer := range e.peers {
|
||||
if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
|
||||
continue
|
||||
}
|
||||
destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
|
||||
}
|
||||
e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
|
||||
}
|
||||
} else if resolve {
|
||||
return nil
|
||||
}
|
||||
var bind conn.Bind
|
||||
wgListener, isWgListener := e.options.Dialer.(conn.Listener)
|
||||
if isWgListener {
|
||||
bind = conn.NewStdNetBind(wgListener)
|
||||
} else {
|
||||
var (
|
||||
isConnect bool
|
||||
connectAddr netip.AddrPort
|
||||
reserved [3]uint8
|
||||
)
|
||||
if len(e.peers) == 1 {
|
||||
isConnect = true
|
||||
connectAddr = e.peers[0].endpoint
|
||||
reserved = e.peers[0].reserved
|
||||
}
|
||||
bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
|
||||
}
|
||||
if isWgListener || len(e.peers) > 1 {
|
||||
for _, peer := range e.peers {
|
||||
if peer.reserved != [3]uint8{} {
|
||||
bind.SetReservedForEndpoint(peer.endpoint, peer.reserved)
|
||||
}
|
||||
}
|
||||
}
|
||||
err := e.tunDevice.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger := &device.Logger{
|
||||
Verbosef: func(format string, args ...interface{}) {
|
||||
e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
|
||||
},
|
||||
Errorf: func(format string, args ...interface{}) {
|
||||
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
|
||||
},
|
||||
}
|
||||
wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
|
||||
e.tunDevice.SetDevice(wgDevice)
|
||||
ipcConf := e.ipcConf
|
||||
for _, peer := range e.peers {
|
||||
ipcConf += peer.GenerateIpcLines()
|
||||
}
|
||||
err = wgDevice.IpcSet(ipcConf)
|
||||
if err != nil {
|
||||
return E.Cause(err, "setup wireguard: \n", ipcConf)
|
||||
}
|
||||
e.device = wgDevice
|
||||
e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
|
||||
if e.pauseManager != nil {
|
||||
e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e Endpoint) DstToBytes() []byte {
|
||||
b, _ := (netip.AddrPort)(e).MarshalBinary()
|
||||
return b
|
||||
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
return e.tunDevice.DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
func (e Endpoint) DstIP() netip.Addr {
|
||||
return (netip.AddrPort)(e).Addr()
|
||||
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
return e.tunDevice.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
func (e Endpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
func (e *Endpoint) BindUpdate() error {
|
||||
return e.device.BindUpdate()
|
||||
}
|
||||
|
||||
func (e *Endpoint) Close() error {
|
||||
if e.device != nil {
|
||||
e.device.Close()
|
||||
}
|
||||
if e.pauseCallback != nil {
|
||||
e.pauseManager.UnregisterCallback(e.pauseCallback)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Endpoint) onPauseUpdated(event int) {
|
||||
switch event {
|
||||
case pause.EventDevicePaused:
|
||||
e.device.Down()
|
||||
case pause.EventDeviceWake:
|
||||
e.device.Up()
|
||||
}
|
||||
}
|
||||
|
||||
type peerConfig struct {
|
||||
destination M.Socksaddr
|
||||
endpoint netip.AddrPort
|
||||
publicKeyHex string
|
||||
preSharedKeyHex string
|
||||
allowedIPs []netip.Prefix
|
||||
keepalive uint16
|
||||
reserved [3]uint8
|
||||
}
|
||||
|
||||
func (c peerConfig) GenerateIpcLines() string {
|
||||
ipcLines := "\npublic_key=" + c.publicKeyHex
|
||||
if c.endpoint.IsValid() {
|
||||
ipcLines += "\nendpoint=" + c.endpoint.String()
|
||||
}
|
||||
if c.preSharedKeyHex != "" {
|
||||
ipcLines += "\npreshared_key=" + c.preSharedKeyHex
|
||||
}
|
||||
for _, allowedIP := range c.allowedIPs {
|
||||
ipcLines += "\nallowed_ip=" + allowedIP.String()
|
||||
}
|
||||
if c.keepalive > 0 {
|
||||
ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
|
||||
}
|
||||
return ipcLines
|
||||
}
|
||||
|
|
40
transport/wireguard/endpoint_options.go
Normal file
40
transport/wireguard/endpoint_options.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type EndpointOptions struct {
|
||||
Context context.Context
|
||||
Logger logger.ContextLogger
|
||||
System bool
|
||||
Handler tun.Handler
|
||||
UDPTimeout time.Duration
|
||||
Dialer N.Dialer
|
||||
CreateDialer func(interfaceName string) N.Dialer
|
||||
Name string
|
||||
MTU uint32
|
||||
GSO bool
|
||||
Address []netip.Prefix
|
||||
PrivateKey string
|
||||
ListenPort uint16
|
||||
ResolvePeer func(domain string) (netip.Addr, error)
|
||||
Peers []PeerOptions
|
||||
Workers int
|
||||
}
|
||||
|
||||
type PeerOptions struct {
|
||||
Endpoint M.Socksaddr
|
||||
PublicKey string
|
||||
PreSharedKey string
|
||||
AllowedIPs []netip.Prefix
|
||||
PersistentKeepaliveInterval uint16
|
||||
Reserved []uint8
|
||||
}
|
|
@ -1,148 +0,0 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-dns"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type PeerConfig struct {
|
||||
destination M.Socksaddr
|
||||
domainStrategy dns.DomainStrategy
|
||||
Endpoint netip.AddrPort
|
||||
PublicKey string
|
||||
PreSharedKey string
|
||||
AllowedIPs []string
|
||||
Reserved [3]uint8
|
||||
}
|
||||
|
||||
func (c PeerConfig) GenerateIpcLines() string {
|
||||
ipcLines := "\npublic_key=" + c.PublicKey
|
||||
ipcLines += "\nendpoint=" + c.Endpoint.String()
|
||||
if c.PreSharedKey != "" {
|
||||
ipcLines += "\npreshared_key=" + c.PreSharedKey
|
||||
}
|
||||
for _, allowedIP := range c.AllowedIPs {
|
||||
ipcLines += "\nallowed_ip=" + allowedIP
|
||||
}
|
||||
return ipcLines
|
||||
}
|
||||
|
||||
func ParsePeers(options option.WireGuardOutboundOptions) ([]PeerConfig, error) {
|
||||
var peers []PeerConfig
|
||||
if len(options.Peers) > 0 {
|
||||
for peerIndex, rawPeer := range options.Peers {
|
||||
peer := PeerConfig{
|
||||
AllowedIPs: rawPeer.AllowedIPs,
|
||||
}
|
||||
destination := rawPeer.ServerOptions.Build()
|
||||
if destination.IsFqdn() {
|
||||
peer.destination = destination
|
||||
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
|
||||
} else {
|
||||
peer.Endpoint = destination.AddrPort()
|
||||
}
|
||||
{
|
||||
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
|
||||
}
|
||||
peer.PublicKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if rawPeer.PreSharedKey != "" {
|
||||
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
|
||||
}
|
||||
peer.PreSharedKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if len(rawPeer.AllowedIPs) == 0 {
|
||||
return nil, E.New("missing allowed_ips for peer ", peerIndex)
|
||||
}
|
||||
if len(rawPeer.Reserved) > 0 {
|
||||
if len(rawPeer.Reserved) != 3 {
|
||||
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.Reserved))
|
||||
}
|
||||
copy(peer.Reserved[:], options.Reserved)
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
} else {
|
||||
peer := PeerConfig{}
|
||||
var (
|
||||
addressHas4 bool
|
||||
addressHas6 bool
|
||||
)
|
||||
for _, localAddress := range options.LocalAddress {
|
||||
if localAddress.Addr().Is4() {
|
||||
addressHas4 = true
|
||||
} else {
|
||||
addressHas6 = true
|
||||
}
|
||||
}
|
||||
if addressHas4 {
|
||||
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv4Unspecified(), 0).String())
|
||||
}
|
||||
if addressHas6 {
|
||||
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv6Unspecified(), 0).String())
|
||||
}
|
||||
destination := options.ServerOptions.Build()
|
||||
if destination.IsFqdn() {
|
||||
peer.destination = destination
|
||||
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
|
||||
} else {
|
||||
peer.Endpoint = destination.AddrPort()
|
||||
}
|
||||
{
|
||||
bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode peer public key")
|
||||
}
|
||||
peer.PublicKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if options.PreSharedKey != "" {
|
||||
bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode pre shared key")
|
||||
}
|
||||
peer.PreSharedKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if len(options.Reserved) > 0 {
|
||||
if len(options.Reserved) != 3 {
|
||||
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(peer.Reserved))
|
||||
}
|
||||
copy(peer.Reserved[:], options.Reserved)
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func ResolvePeers(ctx context.Context, router adapter.Router, peers []PeerConfig) error {
|
||||
for peerIndex, peer := range peers {
|
||||
if peer.Endpoint.IsValid() {
|
||||
continue
|
||||
}
|
||||
destinationAddresses, err := router.Lookup(ctx, peer.destination.Fqdn, peer.domainStrategy)
|
||||
if err != nil {
|
||||
if len(peers) == 1 {
|
||||
return E.Cause(err, "resolve endpoint domain")
|
||||
} else {
|
||||
return E.Cause(err, "resolve endpoint domain for peer ", peerIndex)
|
||||
}
|
||||
}
|
||||
if len(destinationAddresses) == 0 {
|
||||
return E.New("no addresses found for endpoint domain: ", peer.destination.Fqdn)
|
||||
}
|
||||
peers[peerIndex].Endpoint = netip.AddrPortFrom(destinationAddresses[0], peer.destination.Port)
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue