refactor: WireGuard endpoint

This commit is contained in:
世界 2024-11-21 18:10:41 +08:00
parent fd299a0961
commit 68781387fe
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
89 changed files with 2187 additions and 679 deletions

View file

@ -2,238 +2,163 @@ package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip"
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
}
var _ adapter.InterfaceUpdateListener = (*Outbound)(nil)
var (
_ adapter.Endpoint = (*Endpoint)(nil)
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
)
type Outbound struct {
outbound.Adapter
ctx context.Context
router adapter.Router
logger logger.ContextLogger
workers int
peers []wireguard.PeerConfig
useStdNetBind bool
listener N.Dialer
ipcConf string
pauseManager pause.Manager
pauseCallback *list.Element[pause.Callback]
bind conn.Bind
device *device.Device
tunDevice wireguard.Device
ctx context.Context
router adapter.Router
logger logger.ContextLogger
localAddresses []netip.Prefix
endpoint *wireguard.Endpoint
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) {
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
deprecated.Report(ctx, deprecated.OptionWireGuardOutbound)
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, options.Network.Build(), tag, options.DialerOptions),
ctx: ctx,
router: router,
logger: logger,
workers: options.Workers,
pauseManager: service.FromContext[pause.Manager](ctx),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx,
router: router,
logger: logger,
localAddresses: options.LocalAddress,
}
peers, err := wireguard.ParsePeers(options)
if err != nil {
return nil, err
}
outbound.peers = peers
if len(options.LocalAddress) == 0 {
return nil, E.New("missing local address")
}
if options.GSO {
if options.GSO && options.Detour != "" {
return nil, E.New("gso is conflict with detour")
}
if options.Detour == "" {
options.IsWireGuardListener = true
outbound.useStdNetBind = true
} else if options.GSO {
return nil, E.New("gso is conflict with detour")
}
listener, err := dialer.New(ctx, options.DialerOptions)
outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil {
return nil, err
}
outbound.listener = listener
var privateKey string
{
bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
if err != nil {
return nil, E.Cause(err, "decode private key")
peers := common.Map(options.Peers, func(it option.LegacyWireGuardPeer) wireguard.PeerOptions {
return wireguard.PeerOptions{
Endpoint: it.ServerOptions.Build(),
PublicKey: it.PublicKey,
PreSharedKey: it.PreSharedKey,
AllowedIPs: it.AllowedIPs,
// PersistentKeepaliveInterval: time.Duration(it.PersistentKeepaliveInterval),
Reserved: it.Reserved,
}
privateKey = hex.EncodeToString(bytes)
}
outbound.ipcConf = "private_key=" + privateKey
mtu := options.MTU
if mtu == 0 {
mtu = 1408
}
var wireTunDevice wireguard.Device
if !options.SystemInterface && tun.WithGVisor {
wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu)
} else {
wireTunDevice, err = wireguard.NewSystemDevice(service.FromContext[adapter.NetworkManager](ctx), options.InterfaceName, options.LocalAddress, mtu, options.GSO)
})
if len(peers) == 0 {
peers = []wireguard.PeerOptions{{
Endpoint: options.ServerOptions.Build(),
PublicKey: options.PeerPublicKey,
PreSharedKey: options.PreSharedKey,
AllowedIPs: []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0), netip.PrefixFrom(netip.IPv6Unspecified(), 0)},
Reserved: options.Reserved,
}}
}
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
Context: ctx,
Logger: logger,
System: options.SystemInterface,
Dialer: outboundDialer,
CreateDialer: func(interfaceName string) N.Dialer {
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
BindInterface: interfaceName,
}))
},
Name: options.InterfaceName,
MTU: options.MTU,
GSO: options.GSO,
Address: options.LocalAddress,
PrivateKey: options.PrivateKey,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
return endpointAddresses[0], nil
},
Peers: peers,
Workers: options.Workers,
})
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
return nil, err
}
outbound.tunDevice = wireTunDevice
outbound.endpoint = wgEndpoint
return outbound, nil
}
func (w *Outbound) Start() error {
if common.Any(w.peers, func(peer wireguard.PeerConfig) bool {
return !peer.Endpoint.IsValid()
}) {
// wait for all outbounds to be started and continue in PortStart
return nil
}
return w.start()
}
func (w *Outbound) PostStart() error {
if common.All(w.peers, func(peer wireguard.PeerConfig) bool {
return peer.Endpoint.IsValid()
}) {
return nil
}
return w.start()
}
func (w *Outbound) start() error {
err := wireguard.ResolvePeers(w.ctx, w.router, w.peers)
if err != nil {
return err
}
var bind conn.Bind
if w.useStdNetBind {
bind = conn.NewStdNetBind(w.listener.(dialer.WireGuardListener))
} else {
var (
isConnect bool
connectAddr netip.AddrPort
reserved [3]uint8
)
peerLen := len(w.peers)
if peerLen == 1 {
isConnect = true
connectAddr = w.peers[0].Endpoint
reserved = w.peers[0].Reserved
}
bind = wireguard.NewClientBind(w.ctx, w.logger, w.listener, isConnect, connectAddr, reserved)
}
if w.useStdNetBind || len(w.peers) > 1 {
for _, peer := range w.peers {
if peer.Reserved != [3]uint8{} {
bind.SetReservedForEndpoint(peer.Endpoint, peer.Reserved)
}
}
}
err = w.tunDevice.Start()
if err != nil {
return err
}
wgDevice := device.NewDevice(w.ctx, w.tunDevice, bind, &device.Logger{
Verbosef: func(format string, args ...interface{}) {
w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
},
Errorf: func(format string, args ...interface{}) {
w.logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
},
}, w.workers)
ipcConf := w.ipcConf
for _, peer := range w.peers {
ipcConf += peer.GenerateIpcLines()
}
err = wgDevice.IpcSet(ipcConf)
if err != nil {
return E.Cause(err, "setup wireguard: \n", ipcConf)
}
w.device = wgDevice
w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated)
return nil
}
func (w *Outbound) Close() error {
if w.device != nil {
w.device.Close()
}
if w.pauseCallback != nil {
w.pauseManager.UnregisterCallback(w.pauseCallback)
func (o *Outbound) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
return o.endpoint.Start(false)
case adapter.StartStatePostStart:
return o.endpoint.Start(true)
}
return nil
}
func (w *Outbound) InterfaceUpdated() {
w.device.BindUpdate()
func (o *Outbound) Close() error {
return o.endpoint.Close()
}
func (o *Outbound) InterfaceUpdated() {
o.endpoint.BindUpdate()
return
}
func (w *Outbound) onPauseUpdated(event int) {
switch event {
case pause.EventDevicePaused:
w.device.Down()
case pause.EventDeviceWake:
w.device.Up()
}
}
func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
func (o *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch network {
case N.NetworkTCP:
w.logger.InfoContext(ctx, "outbound connection to ", destination)
o.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses)
return N.DialSerial(ctx, o.endpoint, network, destination, destinationAddresses)
} else if !destination.Addr.IsValid() {
return nil, E.New("invalid destination: ", destination)
}
return w.tunDevice.DialContext(ctx, network, destination)
return o.endpoint.DialContext(ctx, network, destination)
}
func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses)
packetConn, _, err := N.ListenSerial(ctx, o.endpoint, destination, destinationAddresses)
if err != nil {
return nil, err
}
return packetConn, err
}
return w.tunDevice.ListenPacket(ctx, destination)
return o.endpoint.ListenPacket(ctx, destination)
}