mirror of
https://github.com/SagerNet/sing-tun.git
synced 2025-04-02 11:27:39 +03:00
Migrate to udpnat2 / Add PrepareConnection
This commit is contained in:
parent
99eea00432
commit
7f8e556bb0
6 changed files with 121 additions and 84 deletions
3
stack.go
3
stack.go
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
@ -23,7 +24,7 @@ type StackOptions struct {
|
|||
Tun Tun
|
||||
TunOptions Options
|
||||
EndpointIndependentNat bool
|
||||
UDPTimeout int64
|
||||
UDPTimeout time.Duration
|
||||
Handler Handler
|
||||
Logger logger.Logger
|
||||
ForwarderBindInterface bool
|
||||
|
|
|
@ -5,6 +5,7 @@ package tun
|
|||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
|
@ -32,7 +33,7 @@ type GVisor struct {
|
|||
ctx context.Context
|
||||
tun GVisorTun
|
||||
endpointIndependentNat bool
|
||||
udpTimeout int64
|
||||
udpTimeout time.Duration
|
||||
broadcastAddr netip.Addr
|
||||
handler Handler
|
||||
logger logger.Logger
|
||||
|
@ -85,13 +86,18 @@ func (t *GVisor) Start() error {
|
|||
localAddr: source.TCPAddr(),
|
||||
remoteAddr: destination.TCPAddr(),
|
||||
}
|
||||
pErr := t.handler.PrepareConnection(source, destination)
|
||||
if pErr != nil {
|
||||
r.Complete(gWriteUnreachable(t.stack, r.Packet(), err) == os.ErrInvalid)
|
||||
return
|
||||
}
|
||||
go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
|
||||
})
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
if !t.endpointIndependentNat {
|
||||
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
|
||||
udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) {
|
||||
var wq waiter.Queue
|
||||
endpoint, err := request.CreateEndpoint(&wq)
|
||||
endpoint, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -102,9 +108,15 @@ func (t *GVisor) Start() error {
|
|||
endpoint.Abort()
|
||||
return
|
||||
}
|
||||
source := M.SocksaddrFromNet(lAddr)
|
||||
destination := M.SocksaddrFromNet(rAddr)
|
||||
pErr := t.handler.PrepareConnection(source, destination)
|
||||
if pErr != nil {
|
||||
gWriteUnreachable(t.stack, r.Packet(), pErr)
|
||||
r.Packet().DecRef()
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
source := M.SocksaddrFromNet(lAddr)
|
||||
destination := M.SocksaddrFromNet(rAddr)
|
||||
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, destination), time.Duration(t.udpTimeout)*time.Second)
|
||||
t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil)
|
||||
}()
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
|
@ -19,59 +21,60 @@ import (
|
|||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
"github.com/sagernet/sing/common/udpnat2"
|
||||
)
|
||||
|
||||
type UDPForwarder struct {
|
||||
ctx context.Context
|
||||
stack *stack.Stack
|
||||
udpNat *udpnat.Service[netip.AddrPort]
|
||||
|
||||
// cache
|
||||
cacheProto tcpip.NetworkProtocolNumber
|
||||
cacheID stack.TransportEndpointID
|
||||
ctx context.Context
|
||||
stack *stack.Stack
|
||||
handler Handler
|
||||
udpNat *udpnat.Service
|
||||
}
|
||||
|
||||
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
|
||||
return &UDPForwarder{
|
||||
ctx: ctx,
|
||||
stack: stack,
|
||||
udpNat: udpnat.NewEx[netip.AddrPort](udpTimeout, handler),
|
||||
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder {
|
||||
forwarder := &UDPForwarder{
|
||||
ctx: ctx,
|
||||
stack: stack,
|
||||
handler: handler,
|
||||
}
|
||||
forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout)
|
||||
return forwarder
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
|
||||
destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
|
||||
if source.IsIPv4() {
|
||||
f.cacheProto = header.IPv4ProtocolNumber
|
||||
} else {
|
||||
f.cacheProto = header.IPv6ProtocolNumber
|
||||
}
|
||||
gBuffer := pkt.Data().ToBuffer()
|
||||
sBuffer := buf.NewSize(int(gBuffer.Size()))
|
||||
gBuffer.Apply(func(view *buffer.View) {
|
||||
sBuffer.Write(view.AsSlice())
|
||||
bufferRange := pkt.Data().AsRange()
|
||||
bufferSlices := make([][]byte, bufferRange.Size())
|
||||
rangeIterate(bufferRange, func(view *buffer.View) {
|
||||
bufferSlices = append(bufferSlices, view.AsSlice())
|
||||
})
|
||||
f.cacheID = id
|
||||
f.udpNat.NewPacketEx(
|
||||
f.ctx,
|
||||
source.AddrPort(),
|
||||
sBuffer,
|
||||
source,
|
||||
destination,
|
||||
f.newUDPConn,
|
||||
)
|
||||
f.udpNat.NewPacket(bufferSlices, source, destination, pkt)
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
|
||||
return &UDPBackWriter{
|
||||
stack: f.stack,
|
||||
source: f.cacheID.RemoteAddress,
|
||||
sourcePort: f.cacheID.RemotePort,
|
||||
sourceNetwork: f.cacheProto,
|
||||
//go:linkname rangeIterate github.com/sagernet/gvisor/pkg/tcpip/stack.Range.iterate
|
||||
func rangeIterate(r stack.Range, fn func(*buffer.View))
|
||||
|
||||
func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
|
||||
pErr := f.handler.PrepareConnection(source, destination)
|
||||
if pErr != nil {
|
||||
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr)
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
var sourceNetwork tcpip.NetworkProtocolNumber
|
||||
if source.Addr.Is4() {
|
||||
sourceNetwork = header.IPv4ProtocolNumber
|
||||
} else {
|
||||
sourceNetwork = header.IPv6ProtocolNumber
|
||||
}
|
||||
writer := &UDPBackWriter{
|
||||
stack: f.stack,
|
||||
source: AddressFromAddr(source.Addr),
|
||||
sourcePort: source.Port,
|
||||
sourceNetwork: sourceNetwork,
|
||||
}
|
||||
return true, f.ctx, writer, nil
|
||||
}
|
||||
|
||||
type UDPBackWriter struct {
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
"github.com/sagernet/sing/common/udpnat2"
|
||||
)
|
||||
|
||||
var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
|
||||
|
@ -34,13 +34,13 @@ type System struct {
|
|||
inet6ServerAddress netip.Addr
|
||||
inet6Address netip.Addr
|
||||
broadcastAddr netip.Addr
|
||||
udpTimeout int64
|
||||
udpTimeout time.Duration
|
||||
tcpListener net.Listener
|
||||
tcpListener6 net.Listener
|
||||
tcpPort uint16
|
||||
tcpPort6 uint16
|
||||
tcpNat *TCPNat
|
||||
udpNat *udpnat.Service[netip.AddrPort]
|
||||
udpNat *udpnat.Service
|
||||
bindInterface bool
|
||||
interfaceFinder control.InterfaceFinder
|
||||
frontHeadroom int
|
||||
|
@ -151,8 +151,8 @@ func (s *System) start() error {
|
|||
s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port
|
||||
go s.acceptLoop(tcpListener)
|
||||
}
|
||||
s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout))
|
||||
s.udpNat = udpnat.NewEx[netip.AddrPort](s.udpTimeout, s.handler)
|
||||
s.tcpNat = NewNat(s.ctx, s.udpTimeout)
|
||||
s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -354,7 +354,11 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.
|
|||
packet.SetDestinationIP(session.Source.Addr())
|
||||
header.SetDestinationPort(session.Source.Port())
|
||||
} else {
|
||||
natPort := s.tcpNat.Lookup(source, destination)
|
||||
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
|
||||
if err != nil {
|
||||
// TODO: implement rejects
|
||||
return nil
|
||||
}
|
||||
packet.SetSourceIP(s.inet4Address)
|
||||
header.SetSourcePort(natPort)
|
||||
packet.SetDestinationIP(s.inet4ServerAddress)
|
||||
|
@ -385,7 +389,11 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.
|
|||
packet.SetDestinationIP(session.Source.Addr())
|
||||
header.SetDestinationPort(session.Source.Port())
|
||||
} else {
|
||||
natPort := s.tcpNat.Lookup(source, destination)
|
||||
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
|
||||
if err != nil {
|
||||
// TODO: implement rejects
|
||||
return nil
|
||||
}
|
||||
packet.SetSourceIP(s.inet6Address)
|
||||
header.SetSourcePort(natPort)
|
||||
packet.SetDestinationIP(s.inet6ServerAddress)
|
||||
|
@ -409,27 +417,12 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
|
|||
if !header.Valid() {
|
||||
return E.New("ipv4: udp: invalid packet")
|
||||
}
|
||||
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
|
||||
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
|
||||
if !destination.Addr().IsGlobalUnicast() {
|
||||
source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort())
|
||||
destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort())
|
||||
if !destination.Addr.IsGlobalUnicast() {
|
||||
return nil
|
||||
}
|
||||
data := buf.As(header.Payload())
|
||||
if data.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter {
|
||||
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
|
||||
headerCopy := make([]byte, headerLen)
|
||||
copy(headerCopy, packet[:headerLen])
|
||||
return &systemUDPPacketWriter4{
|
||||
s.tun,
|
||||
s.frontHeadroom + PacketOffset,
|
||||
headerCopy,
|
||||
source,
|
||||
s.txChecksumOffload,
|
||||
}
|
||||
})
|
||||
s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -437,28 +430,48 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.
|
|||
if !header.Valid() {
|
||||
return E.New("ipv6: udp: invalid packet")
|
||||
}
|
||||
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
|
||||
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
|
||||
if !destination.Addr().IsGlobalUnicast() {
|
||||
source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort())
|
||||
destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort())
|
||||
if !destination.Addr.IsGlobalUnicast() {
|
||||
return nil
|
||||
}
|
||||
data := buf.As(header.Payload())
|
||||
if data.Len() == 0 {
|
||||
return nil
|
||||
s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
|
||||
pErr := s.handler.PrepareConnection(source, destination)
|
||||
if pErr != nil {
|
||||
// TODO: implement ICMP port unreachable
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter {
|
||||
headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize
|
||||
var writer N.PacketWriter
|
||||
if source.IsIPv4() {
|
||||
packet := userData.(clashtcpip.IPv4Packet)
|
||||
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
|
||||
headerCopy := make([]byte, headerLen)
|
||||
copy(headerCopy, packet[:headerLen])
|
||||
return &systemUDPPacketWriter6{
|
||||
writer = &systemUDPPacketWriter4{
|
||||
s.tun,
|
||||
s.frontHeadroom + PacketOffset,
|
||||
headerCopy,
|
||||
source,
|
||||
source.AddrPort(),
|
||||
s.txChecksumOffload,
|
||||
}
|
||||
})
|
||||
return nil
|
||||
} else {
|
||||
packet := userData.(clashtcpip.IPv6Packet)
|
||||
headerLen := len(packet) - int(packet.PayloadLength()) + clashtcpip.UDPHeaderSize
|
||||
headerCopy := make([]byte, headerLen)
|
||||
copy(headerCopy, packet[:headerLen])
|
||||
writer = &systemUDPPacketWriter6{
|
||||
s.tun,
|
||||
s.frontHeadroom + PacketOffset,
|
||||
headerCopy,
|
||||
source.AddrPort(),
|
||||
s.txChecksumOffload,
|
||||
}
|
||||
}
|
||||
return true, s.ctx, writer, nil
|
||||
}
|
||||
|
||||
func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error {
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type TCPNat struct {
|
||||
|
@ -68,12 +70,16 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession {
|
|||
return session
|
||||
}
|
||||
|
||||
func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 {
|
||||
func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort, handler Handler) (uint16, error) {
|
||||
n.addrAccess.RLock()
|
||||
port, loaded := n.addrMap[source]
|
||||
n.addrAccess.RUnlock()
|
||||
if loaded {
|
||||
return port
|
||||
return port, nil
|
||||
}
|
||||
pErr := handler.PrepareConnection(M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination))
|
||||
if pErr != nil {
|
||||
return 0, pErr
|
||||
}
|
||||
n.addrAccess.Lock()
|
||||
nextPort := n.portIndex
|
||||
|
@ -92,5 +98,5 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint1
|
|||
LastActive: time.Now(),
|
||||
}
|
||||
n.portAccess.Unlock()
|
||||
return nextPort
|
||||
return nextPort, nil
|
||||
}
|
||||
|
|
2
tun.go
2
tun.go
|
@ -10,11 +10,13 @@ import (
|
|||
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ranges"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
PrepareConnection(source M.Socksaddr, destination M.Socksaddr) error
|
||||
N.TCPConnectionHandlerEx
|
||||
N.UDPConnectionHandlerEx
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue