From a5cb9f4f5f6d2bd5185d62528847ea5505b0794d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= <i@sekai.icu>
Date: Wed, 6 Nov 2024 17:10:13 +0800
Subject: [PATCH] Remove unused ICMP replies & Improve tcpbuf options

---
 stack_gvisor.go                |  86 +++++++++----
 stack_gvisor_lazy.go           | 228 ---------------------------------
 stack_gvisor_tcpbuf_default.go |  18 +++
 stack_gvisor_tcpbuf_ios.go     |  21 +++
 stack_gvisor_udp.go            |  12 +-
 stack_system.go                |  36 +-----
 tun_linux.go                   |  10 +-
 7 files changed, 122 insertions(+), 289 deletions(-)
 delete mode 100644 stack_gvisor_lazy.go
 create mode 100644 stack_gvisor_tcpbuf_default.go
 create mode 100644 stack_gvisor_tcpbuf_ios.go

diff --git a/stack_gvisor.go b/stack_gvisor.go
index 60af865..89983f1 100644
--- a/stack_gvisor.go
+++ b/stack_gvisor.go
@@ -5,7 +5,7 @@ package tun
 import (
 	"context"
 	"net/netip"
-	"os"
+	"runtime"
 	"time"
 
 	"github.com/sagernet/gvisor/pkg/tcpip"
@@ -17,6 +17,7 @@ import (
 	"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
 	"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
 	"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
+	"github.com/sagernet/gvisor/pkg/waiter"
 	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/sing/common/logger"
 	M "github.com/sagernet/sing/common/metadata"
@@ -77,17 +78,35 @@ func (t *GVisor) Start() error {
 		destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
 		pErr := t.handler.PrepareConnection(N.NetworkTCP, source, destination)
 		if pErr != nil {
-			r.Complete(gWriteUnreachable(t.stack, r.Packet(), err) == os.ErrInvalid)
+			r.Complete(pErr != ErrDrop)
 			return
 		}
-		conn := &gLazyConn{
-			parentCtx:  t.ctx,
-			stack:      t.stack,
-			request:    r,
-			localAddr:  source.TCPAddr(),
-			remoteAddr: destination.TCPAddr(),
+		var (
+			wq       waiter.Queue
+			endpoint tcpip.Endpoint
+			tErr     tcpip.Error
+		)
+		handshakeCtx, cancel := context.WithCancel(context.Background())
+		go func() {
+			select {
+			case <-t.ctx.Done():
+				wq.Notify(wq.Events())
+			case <-handshakeCtx.Done():
+			}
+		}()
+		endpoint, tErr = r.CreateEndpoint(&wq)
+		cancel()
+		if tErr != nil {
+			r.Complete(true)
+			return
 		}
-		go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
+		r.Complete(false)
+		endpoint.SocketOptions().SetKeepAlive(true)
+		keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
+		endpoint.SetSockOpt(&keepAliveIdle)
+		keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
+		endpoint.SetSockOpt(&keepAliveInterval)
+		go t.handler.NewConnectionEx(t.ctx, gonet.NewTCPConn(&wq, endpoint), source, destination, nil)
 	})
 	ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
 	ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
@@ -134,30 +153,47 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
 			icmp.NewProtocol6,
 		},
 	})
-	tErr := ipStack.CreateNIC(defaultNIC, ep)
-	if tErr != nil {
-		return nil, E.New("create nic: ", gonet.TranslateNetstackError(tErr))
+	err := ipStack.CreateNIC(defaultNIC, ep)
+	if err != nil {
+		return nil, gonet.TranslateNetstackError(err)
 	}
 	ipStack.SetRouteTable([]tcpip.Route{
 		{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
 		{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
 	})
-	ipStack.SetSpoofing(defaultNIC, true)
-	ipStack.SetPromiscuousMode(defaultNIC, true)
-	bufSize := 20 * 1024
-	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
-		Min:     1,
-		Default: bufSize,
-		Max:     bufSize,
-	})
-	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
-		Min:     1,
-		Default: bufSize,
-		Max:     bufSize,
-	})
+	err = ipStack.SetSpoofing(defaultNIC, true)
+	if err != nil {
+		return nil, gonet.TranslateNetstackError(err)
+	}
+	err = ipStack.SetPromiscuousMode(defaultNIC, true)
+	if err != nil {
+		return nil, gonet.TranslateNetstackError(err)
+	}
 	sOpt := tcpip.TCPSACKEnabled(true)
 	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
 	mOpt := tcpip.TCPModerateReceiveBufferOption(true)
 	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
+	if runtime.GOOS == "windows" {
+		tcpRecoveryOpt := tcpip.TCPRecovery(0)
+		err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRecoveryOpt)
+	}
+	tcpRXBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{
+		Min:     tcpRXBufMinSize,
+		Default: tcpRXBufDefSize,
+		Max:     tcpRXBufMaxSize,
+	}
+	err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRXBufOpt)
+	if err != nil {
+		return nil, gonet.TranslateNetstackError(err)
+	}
+	tcpTXBufOpt := tcpip.TCPSendBufferSizeRangeOption{
+		Min:     tcpTXBufMinSize,
+		Default: tcpTXBufDefSize,
+		Max:     tcpTXBufMaxSize,
+	}
+	err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTXBufOpt)
+	if err != nil {
+		return nil, gonet.TranslateNetstackError(err)
+	}
 	return ipStack, nil
 }
diff --git a/stack_gvisor_lazy.go b/stack_gvisor_lazy.go
deleted file mode 100644
index 16abdac..0000000
--- a/stack_gvisor_lazy.go
+++ /dev/null
@@ -1,228 +0,0 @@
-//go:build with_gvisor
-
-package tun
-
-import (
-	"context"
-	"errors"
-	"net"
-	"os"
-	"syscall"
-	"time"
-
-	"github.com/sagernet/gvisor/pkg/tcpip"
-	"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
-	"github.com/sagernet/gvisor/pkg/tcpip/header"
-	"github.com/sagernet/gvisor/pkg/tcpip/stack"
-	"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
-	"github.com/sagernet/gvisor/pkg/waiter"
-)
-
-type gLazyConn struct {
-	tcpConn       *gonet.TCPConn
-	parentCtx     context.Context
-	stack         *stack.Stack
-	request       *tcp.ForwarderRequest
-	localAddr     net.Addr
-	remoteAddr    net.Addr
-	handshakeDone bool
-	handshakeErr  error
-}
-
-func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
-	if c.handshakeDone {
-		return nil
-	}
-	defer func() {
-		c.handshakeDone = true
-	}()
-	var (
-		wq       waiter.Queue
-		endpoint tcpip.Endpoint
-	)
-	handshakeCtx, cancel := context.WithCancel(ctx)
-	go func() {
-		select {
-		case <-c.parentCtx.Done():
-			wq.Notify(wq.Events())
-		case <-handshakeCtx.Done():
-		}
-	}()
-	endpoint, err := c.request.CreateEndpoint(&wq)
-	cancel()
-	if err != nil {
-		gErr := gonet.TranslateNetstackError(err)
-		c.handshakeErr = gErr
-		c.request.Complete(true)
-		return gErr
-	}
-	c.request.Complete(false)
-	endpoint.SocketOptions().SetKeepAlive(true)
-	keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
-	endpoint.SetSockOpt(&keepAliveIdle)
-	keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
-	endpoint.SetSockOpt(&keepAliveInterval)
-	tcpConn := gonet.NewTCPConn(&wq, endpoint)
-	c.tcpConn = tcpConn
-	return nil
-}
-
-func (c *gLazyConn) HandshakeFailure(err error) error {
-	if c.handshakeDone {
-		return nil
-	}
-	c.request.Complete(gWriteUnreachable(c.stack, c.request.Packet(), err) == os.ErrInvalid)
-	c.handshakeDone = true
-	c.handshakeErr = err
-	return nil
-}
-
-func (c *gLazyConn) HandshakeSuccess() error {
-	return c.HandshakeContext(context.Background())
-}
-
-func (c *gLazyConn) Read(b []byte) (n int, err error) {
-	if !c.handshakeDone {
-		err = c.HandshakeContext(context.Background())
-		if err != nil {
-			return
-		}
-	} else if c.handshakeErr != nil {
-		return 0, c.handshakeErr
-	}
-	return c.tcpConn.Read(b)
-}
-
-func (c *gLazyConn) Write(b []byte) (n int, err error) {
-	if !c.handshakeDone {
-		err = c.HandshakeContext(context.Background())
-		if err != nil {
-			return
-		}
-	} else if c.handshakeErr != nil {
-		return 0, c.handshakeErr
-	}
-	return c.tcpConn.Write(b)
-}
-
-func (c *gLazyConn) LocalAddr() net.Addr {
-	return c.localAddr
-}
-
-func (c *gLazyConn) RemoteAddr() net.Addr {
-	return c.remoteAddr
-}
-
-func (c *gLazyConn) SetDeadline(t time.Time) error {
-	if !c.handshakeDone {
-		err := c.HandshakeContext(context.Background())
-		if err != nil {
-			return err
-		}
-	} else if c.handshakeErr != nil {
-		return c.handshakeErr
-	}
-	return c.tcpConn.SetDeadline(t)
-}
-
-func (c *gLazyConn) SetReadDeadline(t time.Time) error {
-	if !c.handshakeDone {
-		err := c.HandshakeContext(context.Background())
-		if err != nil {
-			return err
-		}
-	} else if c.handshakeErr != nil {
-		return c.handshakeErr
-	}
-	return c.tcpConn.SetReadDeadline(t)
-}
-
-func (c *gLazyConn) SetWriteDeadline(t time.Time) error {
-	if !c.handshakeDone {
-		err := c.HandshakeContext(context.Background())
-		if err != nil {
-			return err
-		}
-	} else if c.handshakeErr != nil {
-		return c.handshakeErr
-	}
-	return c.tcpConn.SetWriteDeadline(t)
-}
-
-func (c *gLazyConn) Close() error {
-	if !c.handshakeDone {
-		c.request.Complete(true)
-		c.handshakeErr = net.ErrClosed
-		return nil
-	} else if c.handshakeErr != nil {
-		return nil
-	}
-	return c.tcpConn.Close()
-}
-
-func (c *gLazyConn) CloseRead() error {
-	if !c.handshakeDone {
-		c.request.Complete(true)
-		c.handshakeErr = net.ErrClosed
-		return nil
-	} else if c.handshakeErr != nil {
-		return nil
-	}
-	return c.tcpConn.CloseRead()
-}
-
-func (c *gLazyConn) CloseWrite() error {
-	if !c.handshakeDone {
-		c.request.Complete(true)
-		c.handshakeErr = net.ErrClosed
-		return nil
-	} else if c.handshakeErr != nil {
-		return nil
-	}
-	return c.tcpConn.CloseRead()
-}
-
-func (c *gLazyConn) ReaderReplaceable() bool {
-	return c.handshakeDone && c.handshakeErr == nil
-}
-
-func (c *gLazyConn) WriterReplaceable() bool {
-	return c.handshakeDone && c.handshakeErr == nil
-}
-
-func (c *gLazyConn) Upstream() any {
-	return c.tcpConn
-}
-
-func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error {
-	if errors.Is(err, ErrDrop) {
-		return nil
-	} else if errors.Is(err, syscall.ENETUNREACH) {
-		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
-			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable)
-		} else {
-			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
-		}
-	} else if errors.Is(err, syscall.EHOSTUNREACH) {
-		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
-			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable)
-		} else {
-			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPAddrUnreachable)
-		}
-	} else if errors.Is(err, syscall.ECONNREFUSED) {
-		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
-			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
-		} else {
-			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable)
-		}
-	}
-	return os.ErrInvalid
-}
-
-func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error {
-	return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true))
-}
-
-func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error {
-	return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true))
-}
diff --git a/stack_gvisor_tcpbuf_default.go b/stack_gvisor_tcpbuf_default.go
new file mode 100644
index 0000000..f636d1a
--- /dev/null
+++ b/stack_gvisor_tcpbuf_default.go
@@ -0,0 +1,18 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build with_gvisor && !ios
+
+package tun
+
+import "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
+
+const (
+	tcpRXBufMinSize = tcp.MinBufferSize
+	tcpRXBufDefSize = tcp.DefaultSendBufferSize
+	tcpRXBufMaxSize = 8 << 20 // 8MiB
+
+	tcpTXBufMinSize = tcp.MinBufferSize
+	tcpTXBufDefSize = tcp.DefaultReceiveBufferSize
+	tcpTXBufMaxSize = 6 << 20 // 6MiB
+)
diff --git a/stack_gvisor_tcpbuf_ios.go b/stack_gvisor_tcpbuf_ios.go
new file mode 100644
index 0000000..495e59b
--- /dev/null
+++ b/stack_gvisor_tcpbuf_ios.go
@@ -0,0 +1,21 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build with_gvisor
+
+package tun
+
+import "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
+
+const (
+	// tcp{RX,TX}Buf{Min,Def,Max}Size mirror gVisor defaults. We leave these
+	// unchanged on iOS for now as to not increase pressure towards the
+	// NetworkExtension memory limit.
+	tcpRXBufMinSize = tcp.MinBufferSize
+	tcpRXBufDefSize = tcp.DefaultSendBufferSize
+	tcpRXBufMaxSize = tcp.MaxBufferSize
+
+	tcpTXBufMinSize = tcp.MinBufferSize
+	tcpTXBufDefSize = tcp.DefaultReceiveBufferSize
+	tcpTXBufMaxSize = tcp.MaxBufferSize
+)
diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go
index dd0c8a0..22e7e09 100644
--- a/stack_gvisor_udp.go
+++ b/stack_gvisor_udp.go
@@ -59,7 +59,9 @@ func rangeIterate(r stack.Range, fn func(*buffer.View))
 func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
 	pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination)
 	if pErr != nil {
-		gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr)
+		if pErr != ErrDrop {
+			gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr)
+		}
 		return false, nil, nil, nil
 	}
 	var sourceNetwork tcpip.NetworkProtocolNumber
@@ -147,3 +149,11 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock
 	route.Stats().UDP.PacketsSent.Increment()
 	return nil
 }
+
+func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error {
+	if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+		return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, stack.RejectIPv4WithICMPPortUnreachable, true))
+	} else {
+		return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, stack.RejectIPv6WithICMPPortUnreachable, true))
+	}
+}
diff --git a/stack_system.go b/stack_system.go
index 2baa0c6..a06329e 100644
--- a/stack_system.go
+++ b/stack_system.go
@@ -2,7 +2,6 @@ package tun
 
 import (
 	"context"
-	"errors"
 	"net"
 	"net/netip"
 	"syscall"
@@ -357,14 +356,8 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
 	} else {
 		natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
 		if err != nil {
-			if errors.Is(err, ErrDrop) {
+			if err == ErrDrop {
 				return false, nil
-			} else if errors.Is(err, syscall.ENETUNREACH) {
-				return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4NetUnreachable)
-			} else if errors.Is(err, syscall.EHOSTUNREACH) {
-				return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4HostUnreachable)
-			} else if errors.Is(err, syscall.ECONNREFUSED) {
-				return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable)
 			} else {
 				return false, s.resetIPv4TCP(ipHdr, tcpHdr)
 			}
@@ -450,14 +443,8 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err
 	} else {
 		natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
 		if err != nil {
-			if errors.Is(err, ErrDrop) {
+			if err == ErrDrop {
 				return false, nil
-			} else if errors.Is(err, syscall.ENETUNREACH) {
-				return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6NetworkUnreachable)
-			} else if errors.Is(err, syscall.EHOSTUNREACH) {
-				return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6AddressUnreachable)
-			} else if errors.Is(err, syscall.ECONNREFUSED) {
-				return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6PortUnreachable)
 			} else {
 				return false, s.resetIPv6TCP(ipHdr, tcpHdr)
 			}
@@ -551,23 +538,12 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error {
 func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
 	pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination)
 	if pErr != nil {
-		if errors.Is(pErr, ErrDrop) {
-		} else if source.IsIPv4() {
-			ipHdr := userData.(header.IPv4)
-			if errors.Is(pErr, syscall.ENETUNREACH) {
-				s.rejectIPv4WithICMP(ipHdr, header.ICMPv4NetUnreachable)
-			} else if errors.Is(pErr, syscall.EHOSTUNREACH) {
-				s.rejectIPv4WithICMP(ipHdr, header.ICMPv4HostUnreachable)
-			} else {
+		if pErr != ErrDrop {
+			if source.IsIPv4() {
+				ipHdr := userData.(header.IPv4)
 				s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable)
-			}
-		} else {
-			ipHdr := userData.(header.IPv6)
-			if errors.Is(pErr, syscall.ENETUNREACH) {
-				s.rejectIPv6WithICMP(ipHdr, header.ICMPv6NetworkUnreachable)
-			} else if errors.Is(pErr, syscall.EHOSTUNREACH) {
-				s.rejectIPv6WithICMP(ipHdr, header.ICMPv6AddressUnreachable)
 			} else {
+				ipHdr := userData.(header.IPv6)
 				s.rejectIPv6WithICMP(ipHdr, header.ICMPv6PortUnreachable)
 			}
 		}
diff --git a/tun_linux.go b/tun_linux.go
index 390d5b4..a799b8d 100644
--- a/tun_linux.go
+++ b/tun_linux.go
@@ -224,7 +224,6 @@ func open(name string, vnetHdr bool) (int, error) {
 func (t *NativeTun) configure(tunLink netlink.Link) error {
 	err := netlink.LinkSetMTU(tunLink, int(t.options.MTU))
 	if errors.Is(err, unix.EPERM) {
-		// unprivileged
 		return nil
 	} else if err != nil {
 		return err
@@ -293,16 +292,17 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
 }
 
 func (t *NativeTun) Start() error {
+	if t.options.FileDescriptor != 0 {
+		return nil
+	}
+
 	tunLink, err := netlink.LinkByName(t.options.Name)
 	if err != nil {
 		return err
 	}
 
 	err = netlink.LinkSetUp(tunLink)
-	if errors.Is(err, unix.EPERM) {
-		// unprivileged
-		return nil
-	} else if err != nil {
+	if err != nil {
 		return err
 	}