mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-04-03 03:47:37 +03:00
Improve read wait interface &
Refactor Authenticator interface to struct & Update smux & Update gVisor to 20231204.0 & Update quic-go to v0.40.1 & Update wireguard-go & Add GSO support for TUN/WireGuard & Fix router pre-start & Fix bind forwarder to interface for systems stack
This commit is contained in:
parent
35fd9de3ff
commit
89c723e3e4
48 changed files with 902 additions and 658 deletions
|
@ -17,16 +17,16 @@ func (c *NATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) {
|
|||
|
||||
type waitNATPacketConn struct {
|
||||
*NATPacketConn
|
||||
waiter N.PacketReadWaiter
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (c *waitNATPacketConn) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
c.waiter.InitializeReadWaiter(newBuffer)
|
||||
func (c *waitNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return c.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (c *waitNATPacketConn) WaitReadPacket() (destination M.Socksaddr, err error) {
|
||||
destination, err = c.waiter.WaitReadPacket()
|
||||
if socksaddrWithoutPort(destination) == c.origin {
|
||||
func (c *waitNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
buffer, destination, err = c.readWaiter.WaitReadPacket()
|
||||
if err == nil && socksaddrWithoutPort(destination) == c.origin {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.destination.Addr,
|
||||
Fqdn: c.destination.Fqdn,
|
||||
|
|
|
@ -53,7 +53,7 @@ func newMuxConnection0(ctx context.Context, stream net.Conn, metadata M.Metadata
|
|||
case CommandTCP:
|
||||
return handler.NewConnection(ctx, stream, metadata)
|
||||
case CommandUDP:
|
||||
return handler.NewPacketConnection(ctx, &PacketConn{stream}, metadata)
|
||||
return handler.NewPacketConnection(ctx, &PacketConn{Conn: stream}, metadata)
|
||||
default:
|
||||
return E.New("unknown command ", command)
|
||||
}
|
||||
|
|
|
@ -85,9 +85,10 @@ func (c *ClientConn) Upstream() any {
|
|||
|
||||
type ClientPacketConn struct {
|
||||
net.Conn
|
||||
access sync.Mutex
|
||||
key [KeyLength]byte
|
||||
headerWritten bool
|
||||
access sync.Mutex
|
||||
key [KeyLength]byte
|
||||
headerWritten bool
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
|
||||
|
|
45
transport/trojan/protocol_wait.go
Normal file
45
transport/trojan/protocol_wait.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package trojan
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
var _ N.PacketReadWaiter = (*ClientPacketConn)(nil)
|
||||
|
||||
func (c *ClientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
c.readWaitOptions = options
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.Conn)
|
||||
if err != nil {
|
||||
return nil, M.Socksaddr{}, E.Cause(err, "read destination")
|
||||
}
|
||||
|
||||
var length uint16
|
||||
err = binary.Read(c.Conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return nil, M.Socksaddr{}, E.Cause(err, "read chunk length")
|
||||
}
|
||||
|
||||
err = rw.SkipN(c.Conn, 2)
|
||||
if err != nil {
|
||||
return nil, M.Socksaddr{}, E.Cause(err, "skip crlf")
|
||||
}
|
||||
|
||||
buffer = c.readWaitOptions.NewPacketBuffer()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, int(length))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
c.readWaitOptions.PostReturn(buffer)
|
||||
return
|
||||
}
|
|
@ -105,7 +105,7 @@ func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata
|
|||
case CommandTCP:
|
||||
return s.handler.NewConnection(ctx, conn, metadata)
|
||||
case CommandUDP:
|
||||
return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata)
|
||||
return s.handler.NewPacketConnection(ctx, &PacketConn{Conn: conn}, metadata)
|
||||
// case CommandMux:
|
||||
default:
|
||||
return HandleMuxConnection(ctx, conn, metadata, s.handler)
|
||||
|
@ -122,6 +122,7 @@ func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Met
|
|||
|
||||
type PacketConn struct {
|
||||
net.Conn
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||
|
|
45
transport/trojan/service_wait.go
Normal file
45
transport/trojan/service_wait.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package trojan
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
var _ N.PacketReadWaiter = (*PacketConn)(nil)
|
||||
|
||||
func (c *PacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
c.readWaitOptions = options
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *PacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.Conn)
|
||||
if err != nil {
|
||||
return nil, M.Socksaddr{}, E.Cause(err, "read destination")
|
||||
}
|
||||
|
||||
var length uint16
|
||||
err = binary.Read(c.Conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return nil, M.Socksaddr{}, E.Cause(err, "read chunk length")
|
||||
}
|
||||
|
||||
err = rw.SkipN(c.Conn, 2)
|
||||
if err != nil {
|
||||
return nil, M.Socksaddr{}, E.Cause(err, "skip crlf")
|
||||
}
|
||||
|
||||
buffer = c.readWaitOptions.NewPacketBuffer()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, int(length))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
c.readWaitOptions.PostReturn(buffer)
|
||||
return
|
||||
}
|
|
@ -12,7 +12,6 @@ import (
|
|||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/service/pause"
|
||||
"github.com/sagernet/wireguard-go/conn"
|
||||
)
|
||||
|
||||
|
@ -22,33 +21,27 @@ type ClientBind struct {
|
|||
ctx context.Context
|
||||
errorHandler E.Handler
|
||||
dialer N.Dialer
|
||||
reservedForEndpoint map[M.Socksaddr][3]uint8
|
||||
reservedForEndpoint map[netip.AddrPort][3]uint8
|
||||
connAccess sync.Mutex
|
||||
conn *wireConn
|
||||
done chan struct{}
|
||||
isConnect bool
|
||||
connectAddr M.Socksaddr
|
||||
connectAddr netip.AddrPort
|
||||
reserved [3]uint8
|
||||
pauseManager pause.Manager
|
||||
}
|
||||
|
||||
func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
|
||||
func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
|
||||
return &ClientBind{
|
||||
ctx: ctx,
|
||||
errorHandler: errorHandler,
|
||||
dialer: dialer,
|
||||
reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
|
||||
reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
|
||||
isConnect: isConnect,
|
||||
connectAddr: connectAddr,
|
||||
reserved: reserved,
|
||||
pauseManager: pause.ManagerFromContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientBind) SetReservedForEndpoint(destination M.Socksaddr, reserved [3]byte) {
|
||||
c.reservedForEndpoint[destination] = reserved
|
||||
}
|
||||
|
||||
func (c *ClientBind) connect() (*wireConn, error) {
|
||||
serverConn := c.conn
|
||||
if serverConn != nil {
|
||||
|
@ -71,16 +64,13 @@ func (c *ClientBind) connect() (*wireConn, error) {
|
|||
}
|
||||
}
|
||||
if c.isConnect {
|
||||
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
|
||||
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.conn = &wireConn{
|
||||
PacketConn: &bufio.UnbindPacketConn{
|
||||
ExtendedConn: bufio.NewExtendedConn(udpConn),
|
||||
Addr: c.connectAddr,
|
||||
},
|
||||
done: make(chan struct{}),
|
||||
PacketConn: bufio.NewUnbindPacketConn(udpConn),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
} else {
|
||||
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
|
||||
|
@ -116,7 +106,6 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
|
|||
c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
|
||||
err = nil
|
||||
time.Sleep(time.Second)
|
||||
c.pauseManager.WaitActive()
|
||||
return
|
||||
}
|
||||
n, addr, err := udpConn.ReadFrom(packets[0])
|
||||
|
@ -133,11 +122,9 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
|
|||
sizes[0] = n
|
||||
if n > 3 {
|
||||
b := packets[0]
|
||||
b[1] = 0
|
||||
b[2] = 0
|
||||
b[3] = 0
|
||||
common.ClearArray(b[1:4])
|
||||
}
|
||||
eps[0] = Endpoint(M.SocksaddrFromNet(addr))
|
||||
eps[0] = Endpoint(M.AddrPortFromNet(addr))
|
||||
count = 1
|
||||
return
|
||||
}
|
||||
|
@ -170,18 +157,16 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destination := M.Socksaddr(ep.(Endpoint))
|
||||
destination := netip.AddrPort(ep.(Endpoint))
|
||||
for _, b := range bufs {
|
||||
if len(b) > 3 {
|
||||
reserved, loaded := c.reservedForEndpoint[destination]
|
||||
if !loaded {
|
||||
reserved = c.reserved
|
||||
}
|
||||
b[1] = reserved[0]
|
||||
b[2] = reserved[1]
|
||||
b[3] = reserved[2]
|
||||
copy(b[1:4], reserved[:])
|
||||
}
|
||||
_, err = udpConn.WriteTo(b, destination)
|
||||
_, err = udpConn.WriteTo(b, M.SocksaddrFromNetIP(destination))
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
return err
|
||||
|
@ -191,13 +176,21 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
|||
}
|
||||
|
||||
func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
return Endpoint(M.ParseSocksaddr(s)), nil
|
||||
ap, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Endpoint(ap), nil
|
||||
}
|
||||
|
||||
func (c *ClientBind) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
|
||||
c.reservedForEndpoint[destination] = reserved
|
||||
}
|
||||
|
||||
type wireConn struct {
|
||||
net.PacketConn
|
||||
access sync.Mutex
|
||||
|
|
|
@ -265,7 +265,7 @@ func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
|
|||
}
|
||||
|
||||
func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return stack.CapabilityNone
|
||||
return stack.CapabilityRXChecksumOffload
|
||||
}
|
||||
|
||||
func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
|
|
|
@ -2,6 +2,7 @@ package wireguard
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
@ -11,6 +12,7 @@ import (
|
|||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||
|
@ -19,16 +21,17 @@ import (
|
|||
var _ Device = (*SystemDevice)(nil)
|
||||
|
||||
type SystemDevice struct {
|
||||
dialer N.Dialer
|
||||
device tun.Tun
|
||||
name string
|
||||
mtu int
|
||||
events chan wgTun.Event
|
||||
addr4 netip.Addr
|
||||
addr6 netip.Addr
|
||||
dialer N.Dialer
|
||||
device tun.Tun
|
||||
batchDevice tun.LinuxTUN
|
||||
name string
|
||||
mtu int
|
||||
events chan wgTun.Event
|
||||
addr4 netip.Addr
|
||||
addr6 netip.Addr
|
||||
}
|
||||
|
||||
func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32) (*SystemDevice, error) {
|
||||
func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) {
|
||||
var inet4Addresses []netip.Prefix
|
||||
var inet6Addresses []netip.Prefix
|
||||
for _, prefixes := range localPrefixes {
|
||||
|
@ -46,6 +49,7 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes
|
|||
Inet4Address: inet4Addresses,
|
||||
Inet6Address: inet6Addresses,
|
||||
MTU: mtu,
|
||||
GSO: gso,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -58,16 +62,25 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes
|
|||
if len(inet6Addresses) > 0 {
|
||||
inet6Address = inet6Addresses[0].Addr()
|
||||
}
|
||||
var batchDevice tun.LinuxTUN
|
||||
if gso {
|
||||
batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
|
||||
if !isBatchTUN {
|
||||
return nil, E.New("GSO is not supported on current platform")
|
||||
}
|
||||
batchDevice = batchTUN
|
||||
}
|
||||
return &SystemDevice{
|
||||
dialer: common.Must1(dialer.NewDefault(router, option.DialerOptions{
|
||||
BindInterface: interfaceName,
|
||||
})),
|
||||
device: tunInterface,
|
||||
name: interfaceName,
|
||||
mtu: int(mtu),
|
||||
events: make(chan wgTun.Event),
|
||||
addr4: inet4Address,
|
||||
addr6: inet6Address,
|
||||
device: tunInterface,
|
||||
batchDevice: batchDevice,
|
||||
name: interfaceName,
|
||||
mtu: int(mtu),
|
||||
events: make(chan wgTun.Event),
|
||||
addr4: inet4Address,
|
||||
addr6: inet6Address,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -97,21 +110,31 @@ func (w *SystemDevice) File() *os.File {
|
|||
}
|
||||
|
||||
func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||
sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
|
||||
if err == nil {
|
||||
count = 1
|
||||
if w.batchDevice != nil {
|
||||
count, err = w.batchDevice.BatchRead(bufs, offset, sizes)
|
||||
} else {
|
||||
sizes[0], err = w.device.Read(bufs[0][offset:])
|
||||
if err == nil {
|
||||
count = 1
|
||||
} else if errors.Is(err, tun.ErrTooManySegments) {
|
||||
err = wgTun.ErrTooManySegments
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
for _, b := range bufs {
|
||||
_, err = w.device.Write(b[offset:])
|
||||
if err != nil {
|
||||
return
|
||||
if w.batchDevice != nil {
|
||||
return 0, w.batchDevice.BatchWrite(bufs, offset)
|
||||
} else {
|
||||
for _, b := range bufs {
|
||||
_, err = w.device.Write(b[offset:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
count++
|
||||
}
|
||||
// WireGuard will not read count
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -136,5 +159,8 @@ func (w *SystemDevice) Close() error {
|
|||
}
|
||||
|
||||
func (w *SystemDevice) BatchSize() int {
|
||||
if w.batchDevice != nil {
|
||||
return w.batchDevice.BatchSize()
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
|
|
@ -3,13 +3,12 @@ package wireguard
|
|||
import (
|
||||
"net/netip"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/wireguard-go/conn"
|
||||
)
|
||||
|
||||
var _ conn.Endpoint = (*Endpoint)(nil)
|
||||
|
||||
type Endpoint M.Socksaddr
|
||||
type Endpoint netip.AddrPort
|
||||
|
||||
func (e Endpoint) ClearSrc() {
|
||||
}
|
||||
|
@ -19,16 +18,16 @@ func (e Endpoint) SrcToString() string {
|
|||
}
|
||||
|
||||
func (e Endpoint) DstToString() string {
|
||||
return (M.Socksaddr)(e).String()
|
||||
return (netip.AddrPort)(e).String()
|
||||
}
|
||||
|
||||
func (e Endpoint) DstToBytes() []byte {
|
||||
b, _ := (M.Socksaddr)(e).AddrPort().MarshalBinary()
|
||||
b, _ := (netip.AddrPort)(e).MarshalBinary()
|
||||
return b
|
||||
}
|
||||
|
||||
func (e Endpoint) DstIP() netip.Addr {
|
||||
return (M.Socksaddr)(e).Addr
|
||||
return (netip.AddrPort)(e).Addr()
|
||||
}
|
||||
|
||||
func (e Endpoint) SrcIP() netip.Addr {
|
||||
|
|
148
transport/wireguard/resolve.go
Normal file
148
transport/wireguard/resolve.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
dns "github.com/sagernet/sing-dns"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type PeerConfig struct {
|
||||
destination M.Socksaddr
|
||||
domainStrategy dns.DomainStrategy
|
||||
Endpoint netip.AddrPort
|
||||
PublicKey string
|
||||
PreSharedKey string
|
||||
AllowedIPs []string
|
||||
Reserved [3]uint8
|
||||
}
|
||||
|
||||
func (c PeerConfig) GenerateIpcLines() string {
|
||||
ipcLines := "\npublic_key=" + c.PublicKey
|
||||
ipcLines += "\nendpoint=" + c.Endpoint.String()
|
||||
if c.PreSharedKey != "" {
|
||||
ipcLines += "\npreshared_key=" + c.PreSharedKey
|
||||
}
|
||||
for _, allowedIP := range c.AllowedIPs {
|
||||
ipcLines += "\nallowed_ip=" + allowedIP
|
||||
}
|
||||
return ipcLines
|
||||
}
|
||||
|
||||
func ParsePeers(options option.WireGuardOutboundOptions) ([]PeerConfig, error) {
|
||||
var peers []PeerConfig
|
||||
if len(options.Peers) > 0 {
|
||||
for peerIndex, rawPeer := range options.Peers {
|
||||
peer := PeerConfig{
|
||||
AllowedIPs: rawPeer.AllowedIPs,
|
||||
}
|
||||
destination := rawPeer.ServerOptions.Build()
|
||||
if destination.IsFqdn() {
|
||||
peer.destination = destination
|
||||
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
|
||||
} else {
|
||||
peer.Endpoint = destination.AddrPort()
|
||||
}
|
||||
{
|
||||
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
|
||||
}
|
||||
peer.PublicKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if rawPeer.PreSharedKey != "" {
|
||||
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
|
||||
}
|
||||
peer.PreSharedKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if len(rawPeer.AllowedIPs) == 0 {
|
||||
return nil, E.New("missing allowed_ips for peer ", peerIndex)
|
||||
}
|
||||
if len(rawPeer.Reserved) > 0 {
|
||||
if len(rawPeer.Reserved) != 3 {
|
||||
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.Reserved))
|
||||
}
|
||||
copy(peer.Reserved[:], options.Reserved)
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
} else {
|
||||
peer := PeerConfig{}
|
||||
var (
|
||||
addressHas4 bool
|
||||
addressHas6 bool
|
||||
)
|
||||
for _, localAddress := range options.LocalAddress {
|
||||
if localAddress.Addr().Is4() {
|
||||
addressHas4 = true
|
||||
} else {
|
||||
addressHas6 = true
|
||||
}
|
||||
}
|
||||
if addressHas4 {
|
||||
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv4Unspecified(), 0).String())
|
||||
}
|
||||
if addressHas6 {
|
||||
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv6Unspecified(), 0).String())
|
||||
}
|
||||
destination := options.ServerOptions.Build()
|
||||
if destination.IsFqdn() {
|
||||
peer.destination = destination
|
||||
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
|
||||
} else {
|
||||
peer.Endpoint = destination.AddrPort()
|
||||
}
|
||||
{
|
||||
bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode peer public key")
|
||||
}
|
||||
peer.PublicKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if options.PreSharedKey != "" {
|
||||
bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode pre shared key")
|
||||
}
|
||||
peer.PreSharedKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if len(options.Reserved) > 0 {
|
||||
if len(options.Reserved) != 3 {
|
||||
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(peer.Reserved))
|
||||
}
|
||||
copy(peer.Reserved[:], options.Reserved)
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func ResolvePeers(ctx context.Context, router adapter.Router, peers []PeerConfig) error {
|
||||
for peerIndex, peer := range peers {
|
||||
if peer.Endpoint.IsValid() {
|
||||
continue
|
||||
}
|
||||
destinationAddresses, err := router.Lookup(ctx, peer.destination.Fqdn, peer.domainStrategy)
|
||||
if err != nil {
|
||||
if len(peers) == 1 {
|
||||
return E.Cause(err, "resolve endpoint domain")
|
||||
} else {
|
||||
return E.Cause(err, "resolve endpoint domain for peer ", peerIndex)
|
||||
}
|
||||
}
|
||||
if len(destinationAddresses) == 0 {
|
||||
return E.New("no addresses found for endpoint domain: ", peer.destination.Fqdn)
|
||||
}
|
||||
peers[peerIndex].Endpoint = netip.AddrPortFrom(destinationAddresses[0], peer.destination.Port)
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue