Add endpoint independent nat support

This commit is contained in:
世界 2022-07-26 19:15:04 +08:00
parent 3b0c717db3
commit b4bded886e
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 200 additions and 44 deletions

4
go.mod
View file

@ -3,9 +3,9 @@ module github.com/sagernet/sing-tun
go 1.18
require (
github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0
github.com/sagernet/sing v0.0.0-20220726034811-bc109486f14e
github.com/vishvananda/netlink v1.1.0
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f
gvisor.dev/gvisor v0.0.0-20220711011657-cecae2f4234d
)

8
go.sum
View file

@ -1,7 +1,7 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0 h1:8tnMLN6jdqKkjPXwgEekwloPaAmvbxQAMMHdWYOiMj8=
github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c=
github.com/sagernet/sing v0.0.0-20220726034811-bc109486f14e h1:5lfrAc+vSv0iW6eHGNLyHC+a/k6BDGJvYxYxwB/68Kk=
github.com/sagernet/sing v0.0.0-20220726034811-bc109486f14e/go.mod h1:GbtQfZSpmtD3cXeD1qX2LCMwY8dH+bnnInDTqd92IsM=
github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
@ -9,8 +9,8 @@ github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695AP
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e h1:NHvCuwuS43lGnYhten69ZWqi2QOj/CiDNcKbVqwVoew=
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
gvisor.dev/gvisor v0.0.0-20220711011657-cecae2f4234d h1:KjI6i6P1ib9DiNdNIN8pb2TXfBewpKHf3O58cjj9vw4=

View file

@ -27,15 +27,26 @@ type GVisorTun struct {
ctx context.Context
tun Tun
tunMtu uint32
endpointIndependentNat bool
endpointIndependentNatTimeout int64
handler Handler
stack *stack.Stack
}
func NewGVisor(ctx context.Context, tun Tun, tunMtu uint32, handler Handler) *GVisorTun {
func NewGVisor(
ctx context.Context,
tun Tun,
tunMtu uint32,
endpointIndependentNat bool,
endpointIndependentNatTimeout int64,
handler Handler,
) *GVisorTun {
return &GVisorTun{
ctx: ctx,
tun: tun,
tunMtu: tunMtu,
endpointIndependentNat: endpointIndependentNat,
endpointIndependentNatTimeout: endpointIndependentNatTimeout,
handler: handler,
}
}
@ -82,7 +93,8 @@ func (t *GVisorTun) Start() error {
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
endpoint, err := r.CreateEndpoint(&wq)
if err != nil {
@ -111,11 +123,10 @@ func (t *GVisorTun) Start() error {
endpoint.Abort()
}
}()
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(id stack.TransportEndpointID, buffer *stack.PacketBuffer) bool {
return tcpForwarder.HandlePacket(id, buffer)
})
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
}).HandlePacket)
if !t.endpointIndependentNat {
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
var wq waiter.Queue
endpoint, err := request.CreateEndpoint(&wq)
if err != nil {
@ -132,13 +143,16 @@ func (t *GVisorTun) Start() error {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
hErr := t.handler.NewPacketConnection(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(udpConn), Addr: M.SocksaddrFromNet(rAddr)}), metadata)
hErr := t.handler.NewPacketConnection(ContextWithNeedTimeout(t.ctx, true), bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(udpConn), Addr: M.SocksaddrFromNet(rAddr)}), metadata)
if hErr != nil {
endpoint.Abort()
}
}()
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
}).HandlePacket)
} else {
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.endpointIndependentNatTimeout).HandlePacket)
}
t.stack = ipStack
return nil
}

128
gvisor_udp.go Normal file
View file

@ -0,0 +1,128 @@
package tun
import (
"context"
"math"
"net"
"net/netip"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
gBuffer "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
type UDPForwarder struct {
ctx context.Context
stack *stack.Stack
handler Handler
udpNat *udpnat.Service[netip.AddrPort]
}
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
return &UDPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
udpNat: udpnat.New[netip.AddrPort](udpTimeout, nopErrorHandler{handler}),
}
}
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
var upstreamMetadata M.Metadata
upstreamMetadata.Source = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.RemoteAddress)), id.RemotePort)
upstreamMetadata.Destination = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.LocalAddress)), id.LocalPort)
f.udpNat.NewPacket(
f.ctx,
upstreamMetadata.Source.AddrPort(),
buf.As(pkt.Data().AsRange().AsView()),
upstreamMetadata,
func(natConn N.PacketConn) N.PacketWriter {
return &UDPBackWriter{f.stack, id.RemoteAddress, id.RemotePort}
},
)
return true
}
type UDPBackWriter struct {
stack *stack.Stack
source tcpip.Address
sourcePort uint16
}
func (w *UDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
var netProto tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
netProto = header.IPv4ProtocolNumber
} else {
netProto = header.IPv6ProtocolNumber
}
route, err := w.stack.FindRoute(
defaultNIC,
tcpip.Address(destination.Addr.AsSlice()),
w.source,
netProto,
false,
)
if err != nil {
return E.New(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
Payload: gBuffer.NewWithData(buffer.Bytes()),
})
defer packet.DecRef()
packet.TransportProtocolNumber = header.UDPProtocolNumber
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
pLen := uint16(packet.Size())
udpHdr.Encode(&header.UDPFields{
SrcPort: destination.Port,
DstPort: w.sourcePort,
Length: pLen,
})
if route.RequiresTXTransportChecksum() && netProto == header.IPv6ProtocolNumber {
xsum := udpHdr.CalculateChecksum(header.ChecksumCombine(
route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
packet.Data().AsRange().Checksum(),
))
if xsum != math.MaxUint16 {
xsum = ^xsum
}
udpHdr.SetChecksum(xsum)
}
err = route.WritePacket(stack.NetworkHeaderParams{
Protocol: header.UDPProtocolNumber,
TTL: route.DefaultTTL(),
TOS: 0,
}, packet)
if err != nil {
route.Stats().UDP.PacketSendErrors.Increment()
return E.New(err)
}
route.Stats().UDP.PacketsSent.Increment()
return nil
}
type nopErrorHandler struct {
Handler
}
func (h nopErrorHandler) NewError(ctx context.Context, err error) {
}

14
timeout.go Normal file
View file

@ -0,0 +1,14 @@
package tun
import "context"
type needTimeoutKey struct{}
func ContextWithNeedTimeout(ctx context.Context, need bool) context.Context {
return context.WithValue(ctx, (*needTimeoutKey)(nil), need)
}
func NeedTimeoutFromContext(ctx context.Context) bool {
need, _ := ctx.Value((*needTimeoutKey)(nil)).(bool)
return need
}