Migrate to udpnat2 / Add PrepareConnection

This commit is contained in:
世界 2024-10-21 21:55:07 +08:00
parent 99eea00432
commit 7f8e556bb0
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
6 changed files with 121 additions and 84 deletions

View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"net" "net"
"net/netip" "net/netip"
"time"
"github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@ -23,7 +24,7 @@ type StackOptions struct {
Tun Tun Tun Tun
TunOptions Options TunOptions Options
EndpointIndependentNat bool EndpointIndependentNat bool
UDPTimeout int64 UDPTimeout time.Duration
Handler Handler Handler Handler
Logger logger.Logger Logger logger.Logger
ForwarderBindInterface bool ForwarderBindInterface bool

View file

@ -5,6 +5,7 @@ package tun
import ( import (
"context" "context"
"net/netip" "net/netip"
"os"
"time" "time"
"github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip"
@ -32,7 +33,7 @@ type GVisor struct {
ctx context.Context ctx context.Context
tun GVisorTun tun GVisorTun
endpointIndependentNat bool endpointIndependentNat bool
udpTimeout int64 udpTimeout time.Duration
broadcastAddr netip.Addr broadcastAddr netip.Addr
handler Handler handler Handler
logger logger.Logger logger logger.Logger
@ -85,13 +86,18 @@ func (t *GVisor) Start() error {
localAddr: source.TCPAddr(), localAddr: source.TCPAddr(),
remoteAddr: destination.TCPAddr(), remoteAddr: destination.TCPAddr(),
} }
pErr := t.handler.PrepareConnection(source, destination)
if pErr != nil {
r.Complete(gWriteUnreachable(t.stack, r.Packet(), err) == os.ErrInvalid)
return
}
go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil) go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
}) })
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
if !t.endpointIndependentNat { if !t.endpointIndependentNat {
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) {
var wq waiter.Queue var wq waiter.Queue
endpoint, err := request.CreateEndpoint(&wq) endpoint, err := r.CreateEndpoint(&wq)
if err != nil { if err != nil {
return return
} }
@ -102,9 +108,15 @@ func (t *GVisor) Start() error {
endpoint.Abort() endpoint.Abort()
return return
} }
source := M.SocksaddrFromNet(lAddr)
destination := M.SocksaddrFromNet(rAddr)
pErr := t.handler.PrepareConnection(source, destination)
if pErr != nil {
gWriteUnreachable(t.stack, r.Packet(), pErr)
r.Packet().DecRef()
return
}
go func() { go func() {
source := M.SocksaddrFromNet(lAddr)
destination := M.SocksaddrFromNet(rAddr)
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, destination), time.Duration(t.udpTimeout)*time.Second) ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, destination), time.Duration(t.udpTimeout)*time.Second)
t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil) t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil)
}() }()

View file

@ -8,6 +8,8 @@ import (
"net/netip" "net/netip"
"os" "os"
"sync" "sync"
"time"
_ "unsafe"
"github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip"
@ -19,59 +21,60 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat" "github.com/sagernet/sing/common/udpnat2"
) )
type UDPForwarder struct { type UDPForwarder struct {
ctx context.Context ctx context.Context
stack *stack.Stack stack *stack.Stack
udpNat *udpnat.Service[netip.AddrPort] handler Handler
udpNat *udpnat.Service
// cache
cacheProto tcpip.NetworkProtocolNumber
cacheID stack.TransportEndpointID
} }
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder { func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder {
return &UDPForwarder{ forwarder := &UDPForwarder{
ctx: ctx, ctx: ctx,
stack: stack, stack: stack,
udpNat: udpnat.NewEx[netip.AddrPort](udpTimeout, handler), handler: handler,
} }
forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout)
return forwarder
} }
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
if source.IsIPv4() { bufferRange := pkt.Data().AsRange()
f.cacheProto = header.IPv4ProtocolNumber bufferSlices := make([][]byte, bufferRange.Size())
} else { rangeIterate(bufferRange, func(view *buffer.View) {
f.cacheProto = header.IPv6ProtocolNumber bufferSlices = append(bufferSlices, view.AsSlice())
}
gBuffer := pkt.Data().ToBuffer()
sBuffer := buf.NewSize(int(gBuffer.Size()))
gBuffer.Apply(func(view *buffer.View) {
sBuffer.Write(view.AsSlice())
}) })
f.cacheID = id f.udpNat.NewPacket(bufferSlices, source, destination, pkt)
f.udpNat.NewPacketEx(
f.ctx,
source.AddrPort(),
sBuffer,
source,
destination,
f.newUDPConn,
)
return true return true
} }
func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter { //go:linkname rangeIterate github.com/sagernet/gvisor/pkg/tcpip/stack.Range.iterate
return &UDPBackWriter{ func rangeIterate(r stack.Range, fn func(*buffer.View))
stack: f.stack,
source: f.cacheID.RemoteAddress, func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
sourcePort: f.cacheID.RemotePort, pErr := f.handler.PrepareConnection(source, destination)
sourceNetwork: f.cacheProto, if pErr != nil {
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr)
return false, nil, nil, nil
} }
var sourceNetwork tcpip.NetworkProtocolNumber
if source.Addr.Is4() {
sourceNetwork = header.IPv4ProtocolNumber
} else {
sourceNetwork = header.IPv6ProtocolNumber
}
writer := &UDPBackWriter{
stack: f.stack,
source: AddressFromAddr(source.Addr),
sourcePort: source.Port,
sourceNetwork: sourceNetwork,
}
return true, f.ctx, writer, nil
} }
type UDPBackWriter struct { type UDPBackWriter struct {

View file

@ -15,7 +15,7 @@ import (
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat" "github.com/sagernet/sing/common/udpnat2"
) )
var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25") var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
@ -34,13 +34,13 @@ type System struct {
inet6ServerAddress netip.Addr inet6ServerAddress netip.Addr
inet6Address netip.Addr inet6Address netip.Addr
broadcastAddr netip.Addr broadcastAddr netip.Addr
udpTimeout int64 udpTimeout time.Duration
tcpListener net.Listener tcpListener net.Listener
tcpListener6 net.Listener tcpListener6 net.Listener
tcpPort uint16 tcpPort uint16
tcpPort6 uint16 tcpPort6 uint16
tcpNat *TCPNat tcpNat *TCPNat
udpNat *udpnat.Service[netip.AddrPort] udpNat *udpnat.Service
bindInterface bool bindInterface bool
interfaceFinder control.InterfaceFinder interfaceFinder control.InterfaceFinder
frontHeadroom int frontHeadroom int
@ -151,8 +151,8 @@ func (s *System) start() error {
s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port
go s.acceptLoop(tcpListener) go s.acceptLoop(tcpListener)
} }
s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout)) s.tcpNat = NewNat(s.ctx, s.udpTimeout)
s.udpNat = udpnat.NewEx[netip.AddrPort](s.udpTimeout, s.handler) s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout)
return nil return nil
} }
@ -354,7 +354,11 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.
packet.SetDestinationIP(session.Source.Addr()) packet.SetDestinationIP(session.Source.Addr())
header.SetDestinationPort(session.Source.Port()) header.SetDestinationPort(session.Source.Port())
} else { } else {
natPort := s.tcpNat.Lookup(source, destination) natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
// TODO: implement rejects
return nil
}
packet.SetSourceIP(s.inet4Address) packet.SetSourceIP(s.inet4Address)
header.SetSourcePort(natPort) header.SetSourcePort(natPort)
packet.SetDestinationIP(s.inet4ServerAddress) packet.SetDestinationIP(s.inet4ServerAddress)
@ -385,7 +389,11 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.
packet.SetDestinationIP(session.Source.Addr()) packet.SetDestinationIP(session.Source.Addr())
header.SetDestinationPort(session.Source.Port()) header.SetDestinationPort(session.Source.Port())
} else { } else {
natPort := s.tcpNat.Lookup(source, destination) natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
// TODO: implement rejects
return nil
}
packet.SetSourceIP(s.inet6Address) packet.SetSourceIP(s.inet6Address)
header.SetSourcePort(natPort) header.SetSourcePort(natPort)
packet.SetDestinationIP(s.inet6ServerAddress) packet.SetDestinationIP(s.inet6ServerAddress)
@ -409,27 +417,12 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
if !header.Valid() { if !header.Valid() {
return E.New("ipv4: udp: invalid packet") return E.New("ipv4: udp: invalid packet")
} }
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() { if !destination.Addr.IsGlobalUnicast() {
return nil return nil
} }
data := buf.As(header.Payload()) s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet)
if data.Len() == 0 {
return nil
}
s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter {
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter4{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source,
s.txChecksumOffload,
}
})
return nil return nil
} }
@ -437,28 +430,48 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.
if !header.Valid() { if !header.Valid() {
return E.New("ipv6: udp: invalid packet") return E.New("ipv6: udp: invalid packet")
} }
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() { if !destination.Addr.IsGlobalUnicast() {
return nil return nil
} }
data := buf.As(header.Payload()) s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet)
if data.Len() == 0 { return nil
return nil }
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := s.handler.PrepareConnection(source, destination)
if pErr != nil {
// TODO: implement ICMP port unreachable
return false, nil, nil, nil
} }
s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter { var writer N.PacketWriter
headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize if source.IsIPv4() {
packet := userData.(clashtcpip.IPv4Packet)
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen) headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen]) copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter6{ writer = &systemUDPPacketWriter4{
s.tun, s.tun,
s.frontHeadroom + PacketOffset, s.frontHeadroom + PacketOffset,
headerCopy, headerCopy,
source, source.AddrPort(),
s.txChecksumOffload, s.txChecksumOffload,
} }
}) } else {
return nil packet := userData.(clashtcpip.IPv6Packet)
headerLen := len(packet) - int(packet.PayloadLength()) + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
writer = &systemUDPPacketWriter6{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source.AddrPort(),
s.txChecksumOffload,
}
}
return true, s.ctx, writer, nil
} }
func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error {

View file

@ -5,6 +5,8 @@ import (
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
M "github.com/sagernet/sing/common/metadata"
) )
type TCPNat struct { type TCPNat struct {
@ -68,12 +70,16 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession {
return session return session
} }
func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 { func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort, handler Handler) (uint16, error) {
n.addrAccess.RLock() n.addrAccess.RLock()
port, loaded := n.addrMap[source] port, loaded := n.addrMap[source]
n.addrAccess.RUnlock() n.addrAccess.RUnlock()
if loaded { if loaded {
return port return port, nil
}
pErr := handler.PrepareConnection(M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination))
if pErr != nil {
return 0, pErr
} }
n.addrAccess.Lock() n.addrAccess.Lock()
nextPort := n.portIndex nextPort := n.portIndex
@ -92,5 +98,5 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint1
LastActive: time.Now(), LastActive: time.Now(),
} }
n.portAccess.Unlock() n.portAccess.Unlock()
return nextPort return nextPort, nil
} }

2
tun.go
View file

@ -10,11 +10,13 @@ import (
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ranges" "github.com/sagernet/sing/common/ranges"
) )
type Handler interface { type Handler interface {
PrepareConnection(source M.Socksaddr, destination M.Socksaddr) error
N.TCPConnectionHandlerEx N.TCPConnectionHandlerEx
N.UDPConnectionHandlerEx N.UDPConnectionHandlerEx
} }