From fa5a74ec01f018e0eac91394cd007b23047c22e2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= <i@sekai.icu>
Date: Thu, 20 Mar 2025 20:48:23 +0800
Subject: [PATCH] Explicitly reject detour to empty direct outbounds

---
 adapter/dns.go               |  2 +-
 common/dialer/detour.go      | 36 ++++++++++++++++++++++++------
 common/dialer/dialer.go      |  3 ++-
 dns/router.go                |  2 +-
 dns/transport/dhcp/dhcp.go   | 10 ++-------
 dns/transport/hosts/hosts.go |  7 +++++-
 dns/transport/https.go       | 11 ++++++++-
 dns/transport/local/local.go |  7 +++++-
 dns/transport/quic/http3.go  |  8 +++++--
 dns/transport/quic/quic.go   |  7 +++++-
 dns/transport/tcp.go         | 11 ++++++++-
 dns/transport/tls.go         | 11 ++++++++-
 dns/transport/udp.go         | 11 ++++++++-
 dns/transport_dialer.go      | 16 ++++++++------
 dns/transport_manager.go     |  2 +-
 experimental/libbox/dns.go   |  7 +++++-
 option/dns.go                | 43 ++++++++++++++----------------------
 option/rule.go               |  3 +--
 option/rule_dns.go           |  1 -
 protocol/direct/outbound.go  |  9 ++++++++
 20 files changed, 142 insertions(+), 65 deletions(-)

diff --git a/adapter/dns.go b/adapter/dns.go
index e0f381b8..942f3566 100644
--- a/adapter/dns.go
+++ b/adapter/dns.go
@@ -45,10 +45,10 @@ type RDRCStore interface {
 }
 
 type DNSTransport interface {
+	Lifecycle
 	Type() string
 	Tag() string
 	Dependencies() []string
-	Reset()
 	Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
 }
 
diff --git a/common/dialer/detour.go b/common/dialer/detour.go
index e4a46049..5c0b552b 100644
--- a/common/dialer/detour.go
+++ b/common/dialer/detour.go
@@ -6,26 +6,39 @@ import (
 	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing/common"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 )
 
+type DirectDialer interface {
+	IsEmpty() bool
+}
+
 type DetourDialer struct {
 	outboundManager adapter.OutboundManager
 	detour          string
+	legacyDNSDialer bool
 	dialer          N.Dialer
 	initOnce        sync.Once
 	initErr         error
 }
 
-func NewDetour(outboundManager adapter.OutboundManager, detour string) N.Dialer {
-	return &DetourDialer{outboundManager: outboundManager, detour: detour}
+func NewDetour(outboundManager adapter.OutboundManager, detour string, legacyDNSDialer bool) N.Dialer {
+	return &DetourDialer{
+		outboundManager: outboundManager,
+		detour:          detour,
+		legacyDNSDialer: legacyDNSDialer,
+	}
 }
 
-func (d *DetourDialer) Start() error {
-	_, err := d.Dialer()
-	return err
+func InitializeDetour(dialer N.Dialer) error {
+	detourDialer, isDetour := common.Cast[*DetourDialer](dialer)
+	if !isDetour {
+		return nil
+	}
+	return common.Error(detourDialer.Dialer())
 }
 
 func (d *DetourDialer) Dialer() (N.Dialer, error) {
@@ -34,11 +47,20 @@ func (d *DetourDialer) Dialer() (N.Dialer, error) {
 }
 
 func (d *DetourDialer) init() {
-	var loaded bool
-	d.dialer, loaded = d.outboundManager.Outbound(d.detour)
+	dialer, loaded := d.outboundManager.Outbound(d.detour)
 	if !loaded {
 		d.initErr = E.New("outbound detour not found: ", d.detour)
+		return
 	}
+	if !d.legacyDNSDialer {
+		if directDialer, isDirect := dialer.(DirectDialer); isDirect {
+			if directDialer.IsEmpty() {
+				d.initErr = E.New("detour to an empty direct outbound makes no sense")
+				return
+			}
+		}
+	}
+	d.dialer = dialer
 }
 
 func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go
index b93b7096..88e16740 100644
--- a/common/dialer/dialer.go
+++ b/common/dialer/dialer.go
@@ -23,6 +23,7 @@ type Options struct {
 	DirectResolver   bool
 	ResolverOnDetour bool
 	NewDialer        bool
+	LegacyDNSDialer  bool
 }
 
 // TODO: merge with NewWithOptions
@@ -45,7 +46,7 @@ func NewWithOptions(options Options) (N.Dialer, error) {
 		if outboundManager == nil {
 			return nil, E.New("missing outbound manager")
 		}
-		dialer = NewDetour(outboundManager, dialOptions.Detour)
+		dialer = NewDetour(outboundManager, dialOptions.Detour, options.LegacyDNSDialer)
 	} else {
 		dialer, err = NewDefault(options.Context, dialOptions)
 		if err != nil {
diff --git a/dns/router.go b/dns/router.go
index 42cebd23..44edadbd 100644
--- a/dns/router.go
+++ b/dns/router.go
@@ -449,6 +449,6 @@ func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) {
 func (r *Router) ResetNetwork() {
 	r.ClearCache()
 	for _, transport := range r.transport.Transports() {
-		transport.Reset()
+		transport.Close()
 	}
 }
diff --git a/dns/transport/dhcp/dhcp.go b/dns/transport/dhcp/dhcp.go
index c75d7369..92dd1f8b 100644
--- a/dns/transport/dhcp/dhcp.go
+++ b/dns/transport/dhcp/dhcp.go
@@ -81,7 +81,7 @@ func (t *Transport) Start(stage adapter.StartStage) error {
 
 func (t *Transport) Close() error {
 	for _, transport := range t.transports {
-		transport.Reset()
+		transport.Close()
 	}
 	if t.interfaceCallback != nil {
 		t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback)
@@ -89,12 +89,6 @@ func (t *Transport) Close() error {
 	return nil
 }
 
-func (t *Transport) Reset() {
-	for _, transport := range t.transports {
-		transport.Reset()
-	}
-}
-
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
 	err := t.fetchServers()
 	if err != nil {
@@ -252,7 +246,7 @@ func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []M.So
 		transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, serverAddr))
 	}
 	for _, transport := range t.transports {
-		transport.Reset()
+		transport.Close()
 	}
 	t.transports = transports
 	return nil
diff --git a/dns/transport/hosts/hosts.go b/dns/transport/hosts/hosts.go
index 0a1dd395..a5eecb40 100644
--- a/dns/transport/hosts/hosts.go
+++ b/dns/transport/hosts/hosts.go
@@ -51,7 +51,12 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt
 	}, nil
 }
 
-func (t *Transport) Reset() {
+func (t *Transport) Start(stage adapter.StartStage) error {
+	return nil
+}
+
+func (t *Transport) Close() error {
+	return nil
 }
 
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
diff --git a/dns/transport/https.go b/dns/transport/https.go
index bd150d5f..1750fd26 100644
--- a/dns/transport/https.go
+++ b/dns/transport/https.go
@@ -10,6 +10,7 @@ import (
 	"strconv"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
@@ -149,9 +150,17 @@ func NewHTTPSRaw(
 	}
 }
 
-func (t *HTTPSTransport) Reset() {
+func (t *HTTPSTransport) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	return dialer.InitializeDetour(t.dialer)
+}
+
+func (t *HTTPSTransport) Close() error {
 	t.transport.CloseIdleConnections()
 	t.transport = t.transport.Clone()
+	return nil
 }
 
 func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
diff --git a/dns/transport/local/local.go b/dns/transport/local/local.go
index 4e05ff02..f50405a5 100644
--- a/dns/transport/local/local.go
+++ b/dns/transport/local/local.go
@@ -40,7 +40,12 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt
 	}, nil
 }
 
-func (t *Transport) Reset() {
+func (t *Transport) Start(stage adapter.StartStage) error {
+	return nil
+}
+
+func (t *Transport) Close() error {
+	return nil
 }
 
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
diff --git a/dns/transport/quic/http3.go b/dns/transport/quic/http3.go
index e2a75b50..0d871741 100644
--- a/dns/transport/quic/http3.go
+++ b/dns/transport/quic/http3.go
@@ -111,8 +111,12 @@ func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options
 	}, nil
 }
 
-func (t *HTTP3Transport) Reset() {
-	t.transport.Close()
+func (t *HTTP3Transport) Start(stage adapter.StartStage) error {
+	return nil
+}
+
+func (t *HTTP3Transport) Close() error {
+	return t.transport.Close()
 }
 
 func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
diff --git a/dns/transport/quic/quic.go b/dns/transport/quic/quic.go
index 4ae9ac16..fc5101ee 100644
--- a/dns/transport/quic/quic.go
+++ b/dns/transport/quic/quic.go
@@ -68,13 +68,18 @@ func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options
 	}, nil
 }
 
-func (t *Transport) Reset() {
+func (t *Transport) Start(stage adapter.StartStage) error {
+	return nil
+}
+
+func (t *Transport) Close() error {
 	t.access.Lock()
 	defer t.access.Unlock()
 	connection := t.connection
 	if connection != nil {
 		connection.CloseWithError(0, "")
 	}
+	return nil
 }
 
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
diff --git a/dns/transport/tcp.go b/dns/transport/tcp.go
index 4abeee2f..a814c030 100644
--- a/dns/transport/tcp.go
+++ b/dns/transport/tcp.go
@@ -6,6 +6,7 @@ import (
 	"io"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
 	"github.com/sagernet/sing-box/log"
@@ -46,7 +47,15 @@ func NewTCP(ctx context.Context, logger log.ContextLogger, tag string, options o
 	}, nil
 }
 
-func (t *TCPTransport) Reset() {
+func (t *TCPTransport) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	return dialer.InitializeDetour(t.dialer)
+}
+
+func (t *TCPTransport) Close() error {
+	return nil
 }
 
 func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
diff --git a/dns/transport/tls.go b/dns/transport/tls.go
index ce88d425..a99bf2f7 100644
--- a/dns/transport/tls.go
+++ b/dns/transport/tls.go
@@ -5,6 +5,7 @@ import (
 	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
@@ -65,13 +66,21 @@ func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options o
 	}, nil
 }
 
-func (t *TLSTransport) Reset() {
+func (t *TLSTransport) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	return dialer.InitializeDetour(t.dialer)
+}
+
+func (t *TLSTransport) Close() error {
 	t.access.Lock()
 	defer t.access.Unlock()
 	for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
 		connection.Value.Close()
 	}
 	t.connections.Init()
+	return nil
 }
 
 func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
diff --git a/dns/transport/udp.go b/dns/transport/udp.go
index 8c905c4c..8d9f0515 100644
--- a/dns/transport/udp.go
+++ b/dns/transport/udp.go
@@ -7,6 +7,7 @@ import (
 	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
 	"github.com/sagernet/sing-box/log"
@@ -64,11 +65,19 @@ func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer
 	}
 }
 
-func (t *UDPTransport) Reset() {
+func (t *UDPTransport) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	return dialer.InitializeDetour(t.dialer)
+}
+
+func (t *UDPTransport) Close() error {
 	t.access.Lock()
 	defer t.access.Unlock()
 	close(t.done)
 	t.done = make(chan struct{})
+	return nil
 }
 
 func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
diff --git a/dns/transport_dialer.go b/dns/transport_dialer.go
index 0b15c7ea..b3ee8082 100644
--- a/dns/transport_dialer.go
+++ b/dns/transport_dialer.go
@@ -20,9 +20,10 @@ func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (
 		return dialer.NewDefaultOutbound(ctx), nil
 	} else {
 		return dialer.NewWithOptions(dialer.Options{
-			Context:        ctx,
-			Options:        options.DialerOptions,
-			DirectResolver: true,
+			Context:         ctx,
+			Options:         options.DialerOptions,
+			DirectResolver:  true,
+			LegacyDNSDialer: options.Legacy,
 		})
 	}
 }
@@ -43,10 +44,11 @@ func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions)
 		return transportDialer, nil
 	} else {
 		return dialer.NewWithOptions(dialer.Options{
-			Context:        ctx,
-			Options:        options.DialerOptions,
-			RemoteIsDomain: options.ServerIsDomain(),
-			DirectResolver: true,
+			Context:         ctx,
+			Options:         options.DialerOptions,
+			RemoteIsDomain:  options.ServerIsDomain(),
+			DirectResolver:  true,
+			LegacyDNSDialer: options.Legacy,
 		})
 	}
 }
diff --git a/dns/transport_manager.go b/dns/transport_manager.go
index 8666a1b9..dff886df 100644
--- a/dns/transport_manager.go
+++ b/dns/transport_manager.go
@@ -225,7 +225,7 @@ func (m *TransportManager) Remove(tag string) error {
 		}
 	}
 	if started {
-		transport.Reset()
+		transport.Close()
 	}
 	return nil
 }
diff --git a/experimental/libbox/dns.go b/experimental/libbox/dns.go
index 7e143442..d5c97b7e 100644
--- a/experimental/libbox/dns.go
+++ b/experimental/libbox/dns.go
@@ -38,7 +38,12 @@ func newPlatformTransport(iif LocalDNSTransport, tag string, options option.Loca
 	}
 }
 
-func (p *platformTransport) Reset() {
+func (p *platformTransport) Start(stage adapter.StartStage) error {
+	return nil
+}
+
+func (p *platformTransport) Close() error {
+	return nil
 }
 
 func (p *platformTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
diff --git a/option/dns.go b/option/dns.go
index 1a42b0f4..f303b894 100644
--- a/option/dns.go
+++ b/option/dns.go
@@ -191,34 +191,24 @@ func (o *DNSServerOptions) Upgrade(ctx context.Context) error {
 			serverType = C.DNSTypeUDP
 		}
 	}
-	var remoteOptions RemoteDNSServerOptions
-	if options.Detour == "" {
-		remoteOptions = RemoteDNSServerOptions{
-			LocalDNSServerOptions: LocalDNSServerOptions{
-				LegacyStrategy:      options.Strategy,
-				LegacyDefaultDialer: options.Detour == "",
-				LegacyClientSubnet:  options.ClientSubnet.Build(netip.Prefix{}),
-			},
-			LegacyAddressResolver:      options.AddressResolver,
-			LegacyAddressStrategy:      options.AddressStrategy,
-			LegacyAddressFallbackDelay: options.AddressFallbackDelay,
-		}
-	} else {
-		remoteOptions = RemoteDNSServerOptions{
-			LocalDNSServerOptions: LocalDNSServerOptions{
-				DialerOptions: DialerOptions{
-					Detour: options.Detour,
-					DomainResolver: &DomainResolveOptions{
-						Server:   options.AddressResolver,
-						Strategy: options.AddressStrategy,
-					},
-					FallbackDelay: options.AddressFallbackDelay,
+	remoteOptions := RemoteDNSServerOptions{
+		LocalDNSServerOptions: LocalDNSServerOptions{
+			DialerOptions: DialerOptions{
+				Detour: options.Detour,
+				DomainResolver: &DomainResolveOptions{
+					Server:   options.AddressResolver,
+					Strategy: options.AddressStrategy,
 				},
-				LegacyStrategy:      options.Strategy,
-				LegacyDefaultDialer: options.Detour == "",
-				LegacyClientSubnet:  options.ClientSubnet.Build(netip.Prefix{}),
+				FallbackDelay: options.AddressFallbackDelay,
 			},
-		}
+			Legacy:              true,
+			LegacyStrategy:      options.Strategy,
+			LegacyDefaultDialer: options.Detour == "",
+			LegacyClientSubnet:  options.ClientSubnet.Build(netip.Prefix{}),
+		},
+		LegacyAddressResolver:      options.AddressResolver,
+		LegacyAddressStrategy:      options.AddressStrategy,
+		LegacyAddressFallbackDelay: options.AddressFallbackDelay,
 	}
 	switch serverType {
 	case C.DNSTypeLocal:
@@ -362,6 +352,7 @@ type HostsDNSServerOptions struct {
 
 type LocalDNSServerOptions struct {
 	DialerOptions
+	Legacy              bool           `json:"-"`
 	LegacyStrategy      DomainStrategy `json:"-"`
 	LegacyDefaultDialer bool           `json:"-"`
 	LegacyClientSubnet  netip.Prefix   `json:"-"`
diff --git a/option/rule.go b/option/rule.go
index b769dab8..41bcc126 100644
--- a/option/rule.go
+++ b/option/rule.go
@@ -125,10 +125,9 @@ func (r *DefaultRule) UnmarshalJSON(data []byte) error {
 	return badjson.UnmarshallExcluded(data, &r.RawDefaultRule, &r.RuleAction)
 }
 
-func (r *DefaultRule) IsValid() bool {
+func (r DefaultRule) IsValid() bool {
 	var defaultValue DefaultRule
 	defaultValue.Invert = r.Invert
-	defaultValue.Action = r.Action
 	return !reflect.DeepEqual(r, defaultValue)
 }
 
diff --git a/option/rule_dns.go b/option/rule_dns.go
index 9d6fb138..87b15017 100644
--- a/option/rule_dns.go
+++ b/option/rule_dns.go
@@ -132,7 +132,6 @@ func (r *DefaultDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte)
 func (r DefaultDNSRule) IsValid() bool {
 	var defaultValue DefaultDNSRule
 	defaultValue.Invert = r.Invert
-	defaultValue.DNSRuleAction = r.DNSRuleAction
 	return !reflect.DeepEqual(r, defaultValue)
 }
 
diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go
index 7ad756f2..9cd1490b 100644
--- a/protocol/direct/outbound.go
+++ b/protocol/direct/outbound.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"net"
 	"net/netip"
+	"reflect"
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
@@ -27,6 +28,7 @@ func RegisterOutbound(registry *outbound.Registry) {
 var (
 	_ N.ParallelDialer             = (*Outbound)(nil)
 	_ dialer.ParallelNetworkDialer = (*Outbound)(nil)
+	_ dialer.DirectDialer          = (*Outbound)(nil)
 )
 
 type Outbound struct {
@@ -37,6 +39,7 @@ type Outbound struct {
 	fallbackDelay       time.Duration
 	overrideOption      int
 	overrideDestination M.Socksaddr
+	isEmpty             bool
 	// loopBack *loopBackDetector
 }
 
@@ -56,6 +59,8 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
 		domainStrategy: C.DomainStrategy(options.DomainStrategy),
 		fallbackDelay:  time.Duration(options.FallbackDelay),
 		dialer:         outboundDialer.(dialer.ParallelInterfaceDialer),
+		//nolint:staticcheck
+		isEmpty: reflect.DeepEqual(options.DialerOptions, option.DialerOptions{UDPFragmentDefault: true}) && options.OverrideAddress == "" && options.OverridePort == 0,
 		// loopBack:       newLoopBackDetector(router),
 	}
 	//nolint:staticcheck
@@ -242,6 +247,10 @@ func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M.
 	return conn, newDestination, nil
 }
 
+func (h *Outbound) IsEmpty() bool {
+	return h.isEmpty
+}
+
 /*func (h *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
 	if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) {
 		return E.New("reject loopback connection to ", metadata.Destination)