diff --git a/go.mod b/go.mod index e4d7596..3130f46 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module github.com/sagernet/sing-tun go 1.18 require ( - github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0 + github.com/sagernet/sing v0.0.0-20220726034811-bc109486f14e github.com/vishvananda/netlink v1.1.0 - golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e + golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f gvisor.dev/gvisor v0.0.0-20220711011657-cecae2f4234d ) diff --git a/go.sum b/go.sum index 8211807..9d3cc31 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0 h1:8tnMLN6jdqKkjPXwgEekwloPaAmvbxQAMMHdWYOiMj8= -github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= +github.com/sagernet/sing v0.0.0-20220726034811-bc109486f14e h1:5lfrAc+vSv0iW6eHGNLyHC+a/k6BDGJvYxYxwB/68Kk= +github.com/sagernet/sing v0.0.0-20220726034811-bc109486f14e/go.mod h1:GbtQfZSpmtD3cXeD1qX2LCMwY8dH+bnnInDTqd92IsM= github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= @@ -9,8 +9,8 @@ github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695AP github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e h1:NHvCuwuS43lGnYhten69ZWqi2QOj/CiDNcKbVqwVoew= -golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gvisor.dev/gvisor v0.0.0-20220711011657-cecae2f4234d h1:KjI6i6P1ib9DiNdNIN8pb2TXfBewpKHf3O58cjj9vw4= diff --git a/gvisor.go b/gvisor.go index be9e288..a1e4d39 100644 --- a/gvisor.go +++ b/gvisor.go @@ -24,19 +24,30 @@ import ( const defaultNIC tcpip.NICID = 1 type GVisorTun struct { - ctx context.Context - tun Tun - tunMtu uint32 - handler Handler - stack *stack.Stack + ctx context.Context + tun Tun + tunMtu uint32 + endpointIndependentNat bool + endpointIndependentNatTimeout int64 + handler Handler + stack *stack.Stack } -func NewGVisor(ctx context.Context, tun Tun, tunMtu uint32, handler Handler) *GVisorTun { +func NewGVisor( + ctx context.Context, + tun Tun, + tunMtu uint32, + endpointIndependentNat bool, + endpointIndependentNatTimeout int64, + handler Handler, +) *GVisorTun { return &GVisorTun{ - ctx: ctx, - tun: tun, - tunMtu: tunMtu, - handler: handler, + ctx: ctx, + tun: tun, + tunMtu: tunMtu, + endpointIndependentNat: endpointIndependentNat, + endpointIndependentNatTimeout: endpointIndependentNatTimeout, + handler: handler, } } @@ -82,7 +93,8 @@ func (t *GVisorTun) Start() error { ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) mOpt := tcpip.TCPModerateReceiveBufferOption(true) ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt) - tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) { + + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) { var wq waiter.Queue endpoint, err := r.CreateEndpoint(&wq) if err != nil { @@ -111,34 +123,36 @@ func (t *GVisorTun) Start() error { endpoint.Abort() } }() - }) - ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(id stack.TransportEndpointID, buffer *stack.PacketBuffer) bool { - return tcpForwarder.HandlePacket(id, buffer) - }) - udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { - var wq waiter.Queue - endpoint, err := request.CreateEndpoint(&wq) - if err != nil { - return - } - udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint) - lAddr := udpConn.RemoteAddr() - rAddr := udpConn.LocalAddr() - if lAddr == nil || rAddr == nil { - endpoint.Abort() - return - } - go func() { - var metadata M.Metadata - metadata.Source = M.SocksaddrFromNet(lAddr) - metadata.Destination = M.SocksaddrFromNet(rAddr) - hErr := t.handler.NewPacketConnection(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(udpConn), Addr: M.SocksaddrFromNet(rAddr)}), metadata) - if hErr != nil { - endpoint.Abort() + }).HandlePacket) + + if !t.endpointIndependentNat { + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { + var wq waiter.Queue + endpoint, err := request.CreateEndpoint(&wq) + if err != nil { + return } - }() - }) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint) + lAddr := udpConn.RemoteAddr() + rAddr := udpConn.LocalAddr() + if lAddr == nil || rAddr == nil { + endpoint.Abort() + return + } + go func() { + var metadata M.Metadata + metadata.Source = M.SocksaddrFromNet(lAddr) + metadata.Destination = M.SocksaddrFromNet(rAddr) + hErr := t.handler.NewPacketConnection(ContextWithNeedTimeout(t.ctx, true), bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(udpConn), Addr: M.SocksaddrFromNet(rAddr)}), metadata) + if hErr != nil { + endpoint.Abort() + } + }() + }).HandlePacket) + } else { + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.endpointIndependentNatTimeout).HandlePacket) + } + t.stack = ipStack return nil } diff --git a/gvisor_udp.go b/gvisor_udp.go new file mode 100644 index 0000000..ddfd11a --- /dev/null +++ b/gvisor_udp.go @@ -0,0 +1,128 @@ +package tun + +import ( + "context" + "math" + "net" + "net/netip" + + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/udpnat" + + gBuffer "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type UDPForwarder struct { + ctx context.Context + stack *stack.Stack + handler Handler + udpNat *udpnat.Service[netip.AddrPort] +} + +func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder { + return &UDPForwarder{ + ctx: ctx, + stack: stack, + handler: handler, + udpNat: udpnat.New[netip.AddrPort](udpTimeout, nopErrorHandler{handler}), + } +} + +func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + var upstreamMetadata M.Metadata + upstreamMetadata.Source = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.RemoteAddress)), id.RemotePort) + upstreamMetadata.Destination = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.LocalAddress)), id.LocalPort) + + f.udpNat.NewPacket( + f.ctx, + upstreamMetadata.Source.AddrPort(), + buf.As(pkt.Data().AsRange().AsView()), + upstreamMetadata, + func(natConn N.PacketConn) N.PacketWriter { + return &UDPBackWriter{f.stack, id.RemoteAddress, id.RemotePort} + }, + ) + return true +} + +type UDPBackWriter struct { + stack *stack.Stack + source tcpip.Address + sourcePort uint16 +} + +func (w *UDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + defer buffer.Release() + + var netProto tcpip.NetworkProtocolNumber + if destination.IsIPv4() { + netProto = header.IPv4ProtocolNumber + } else { + netProto = header.IPv6ProtocolNumber + } + + route, err := w.stack.FindRoute( + defaultNIC, + tcpip.Address(destination.Addr.AsSlice()), + w.source, + netProto, + false, + ) + if err != nil { + return E.New(err) + } + defer route.Release() + + packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()), + Payload: gBuffer.NewWithData(buffer.Bytes()), + }) + defer packet.DecRef() + + packet.TransportProtocolNumber = header.UDPProtocolNumber + udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize)) + pLen := uint16(packet.Size()) + udpHdr.Encode(&header.UDPFields{ + SrcPort: destination.Port, + DstPort: w.sourcePort, + Length: pLen, + }) + + if route.RequiresTXTransportChecksum() && netProto == header.IPv6ProtocolNumber { + xsum := udpHdr.CalculateChecksum(header.ChecksumCombine( + route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen), + packet.Data().AsRange().Checksum(), + )) + if xsum != math.MaxUint16 { + xsum = ^xsum + } + udpHdr.SetChecksum(xsum) + } + + err = route.WritePacket(stack.NetworkHeaderParams{ + Protocol: header.UDPProtocolNumber, + TTL: route.DefaultTTL(), + TOS: 0, + }, packet) + + if err != nil { + route.Stats().UDP.PacketSendErrors.Increment() + return E.New(err) + } + + route.Stats().UDP.PacketsSent.Increment() + return nil +} + +type nopErrorHandler struct { + Handler +} + +func (h nopErrorHandler) NewError(ctx context.Context, err error) { +} diff --git a/timeout.go b/timeout.go new file mode 100644 index 0000000..838ca6b --- /dev/null +++ b/timeout.go @@ -0,0 +1,14 @@ +package tun + +import "context" + +type needTimeoutKey struct{} + +func ContextWithNeedTimeout(ctx context.Context, need bool) context.Context { + return context.WithValue(ctx, (*needTimeoutKey)(nil), need) +} + +func NeedTimeoutFromContext(ctx context.Context) bool { + need, _ := ctx.Value((*needTimeoutKey)(nil)).(bool) + return need +}