diff --git a/adapter/dns.go b/adapter/dns.go index e0f381b8..942f3566 100644 --- a/adapter/dns.go +++ b/adapter/dns.go @@ -45,10 +45,10 @@ type RDRCStore interface { } type DNSTransport interface { + Lifecycle Type() string Tag() string Dependencies() []string - Reset() Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) } diff --git a/common/dialer/detour.go b/common/dialer/detour.go index e4a46049..5c0b552b 100644 --- a/common/dialer/detour.go +++ b/common/dialer/detour.go @@ -6,26 +6,39 @@ import ( "sync" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) +type DirectDialer interface { + IsEmpty() bool +} + type DetourDialer struct { outboundManager adapter.OutboundManager detour string + legacyDNSDialer bool dialer N.Dialer initOnce sync.Once initErr error } -func NewDetour(outboundManager adapter.OutboundManager, detour string) N.Dialer { - return &DetourDialer{outboundManager: outboundManager, detour: detour} +func NewDetour(outboundManager adapter.OutboundManager, detour string, legacyDNSDialer bool) N.Dialer { + return &DetourDialer{ + outboundManager: outboundManager, + detour: detour, + legacyDNSDialer: legacyDNSDialer, + } } -func (d *DetourDialer) Start() error { - _, err := d.Dialer() - return err +func InitializeDetour(dialer N.Dialer) error { + detourDialer, isDetour := common.Cast[*DetourDialer](dialer) + if !isDetour { + return nil + } + return common.Error(detourDialer.Dialer()) } func (d *DetourDialer) Dialer() (N.Dialer, error) { @@ -34,11 +47,20 @@ func (d *DetourDialer) Dialer() (N.Dialer, error) { } func (d *DetourDialer) init() { - var loaded bool - d.dialer, loaded = d.outboundManager.Outbound(d.detour) + dialer, loaded := d.outboundManager.Outbound(d.detour) if !loaded { d.initErr = E.New("outbound detour not found: ", d.detour) + return } + if !d.legacyDNSDialer { + if directDialer, isDirect := dialer.(DirectDialer); isDirect { + if directDialer.IsEmpty() { + d.initErr = E.New("detour to an empty direct outbound makes no sense") + return + } + } + } + d.dialer = dialer } func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index b93b7096..88e16740 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -23,6 +23,7 @@ type Options struct { DirectResolver bool ResolverOnDetour bool NewDialer bool + LegacyDNSDialer bool } // TODO: merge with NewWithOptions @@ -45,7 +46,7 @@ func NewWithOptions(options Options) (N.Dialer, error) { if outboundManager == nil { return nil, E.New("missing outbound manager") } - dialer = NewDetour(outboundManager, dialOptions.Detour) + dialer = NewDetour(outboundManager, dialOptions.Detour, options.LegacyDNSDialer) } else { dialer, err = NewDefault(options.Context, dialOptions) if err != nil { diff --git a/dns/router.go b/dns/router.go index 42cebd23..44edadbd 100644 --- a/dns/router.go +++ b/dns/router.go @@ -449,6 +449,6 @@ func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) { func (r *Router) ResetNetwork() { r.ClearCache() for _, transport := range r.transport.Transports() { - transport.Reset() + transport.Close() } } diff --git a/dns/transport/dhcp/dhcp.go b/dns/transport/dhcp/dhcp.go index c75d7369..92dd1f8b 100644 --- a/dns/transport/dhcp/dhcp.go +++ b/dns/transport/dhcp/dhcp.go @@ -81,7 +81,7 @@ func (t *Transport) Start(stage adapter.StartStage) error { func (t *Transport) Close() error { for _, transport := range t.transports { - transport.Reset() + transport.Close() } if t.interfaceCallback != nil { t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback) @@ -89,12 +89,6 @@ func (t *Transport) Close() error { return nil } -func (t *Transport) Reset() { - for _, transport := range t.transports { - transport.Reset() - } -} - func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { err := t.fetchServers() if err != nil { @@ -252,7 +246,7 @@ func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []M.So transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, serverAddr)) } for _, transport := range t.transports { - transport.Reset() + transport.Close() } t.transports = transports return nil diff --git a/dns/transport/hosts/hosts.go b/dns/transport/hosts/hosts.go index 0a1dd395..a5eecb40 100644 --- a/dns/transport/hosts/hosts.go +++ b/dns/transport/hosts/hosts.go @@ -51,7 +51,12 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt }, nil } -func (t *Transport) Reset() { +func (t *Transport) Start(stage adapter.StartStage) error { + return nil +} + +func (t *Transport) Close() error { + return nil } func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { diff --git a/dns/transport/https.go b/dns/transport/https.go index bd150d5f..1750fd26 100644 --- a/dns/transport/https.go +++ b/dns/transport/https.go @@ -10,6 +10,7 @@ import ( "strconv" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" @@ -149,9 +150,17 @@ func NewHTTPSRaw( } } -func (t *HTTPSTransport) Reset() { +func (t *HTTPSTransport) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + return dialer.InitializeDetour(t.dialer) +} + +func (t *HTTPSTransport) Close() error { t.transport.CloseIdleConnections() t.transport = t.transport.Clone() + return nil } func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { diff --git a/dns/transport/local/local.go b/dns/transport/local/local.go index 4e05ff02..f50405a5 100644 --- a/dns/transport/local/local.go +++ b/dns/transport/local/local.go @@ -40,7 +40,12 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt }, nil } -func (t *Transport) Reset() { +func (t *Transport) Start(stage adapter.StartStage) error { + return nil +} + +func (t *Transport) Close() error { + return nil } func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { diff --git a/dns/transport/quic/http3.go b/dns/transport/quic/http3.go index e2a75b50..0d871741 100644 --- a/dns/transport/quic/http3.go +++ b/dns/transport/quic/http3.go @@ -111,8 +111,12 @@ func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options }, nil } -func (t *HTTP3Transport) Reset() { - t.transport.Close() +func (t *HTTP3Transport) Start(stage adapter.StartStage) error { + return nil +} + +func (t *HTTP3Transport) Close() error { + return t.transport.Close() } func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { diff --git a/dns/transport/quic/quic.go b/dns/transport/quic/quic.go index 4ae9ac16..fc5101ee 100644 --- a/dns/transport/quic/quic.go +++ b/dns/transport/quic/quic.go @@ -68,13 +68,18 @@ func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options }, nil } -func (t *Transport) Reset() { +func (t *Transport) Start(stage adapter.StartStage) error { + return nil +} + +func (t *Transport) Close() error { t.access.Lock() defer t.access.Unlock() connection := t.connection if connection != nil { connection.CloseWithError(0, "") } + return nil } func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { diff --git a/dns/transport/tcp.go b/dns/transport/tcp.go index 4abeee2f..a814c030 100644 --- a/dns/transport/tcp.go +++ b/dns/transport/tcp.go @@ -6,6 +6,7 @@ import ( "io" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/log" @@ -46,7 +47,15 @@ func NewTCP(ctx context.Context, logger log.ContextLogger, tag string, options o }, nil } -func (t *TCPTransport) Reset() { +func (t *TCPTransport) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + return dialer.InitializeDetour(t.dialer) +} + +func (t *TCPTransport) Close() error { + return nil } func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { diff --git a/dns/transport/tls.go b/dns/transport/tls.go index ce88d425..a99bf2f7 100644 --- a/dns/transport/tls.go +++ b/dns/transport/tls.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" @@ -65,13 +66,21 @@ func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options o }, nil } -func (t *TLSTransport) Reset() { +func (t *TLSTransport) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + return dialer.InitializeDetour(t.dialer) +} + +func (t *TLSTransport) Close() error { t.access.Lock() defer t.access.Unlock() for connection := t.connections.Front(); connection != nil; connection = connection.Next() { connection.Value.Close() } t.connections.Init() + return nil } func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { diff --git a/dns/transport/udp.go b/dns/transport/udp.go index 8c905c4c..8d9f0515 100644 --- a/dns/transport/udp.go +++ b/dns/transport/udp.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/log" @@ -64,11 +65,19 @@ func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer } } -func (t *UDPTransport) Reset() { +func (t *UDPTransport) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + return dialer.InitializeDetour(t.dialer) +} + +func (t *UDPTransport) Close() error { t.access.Lock() defer t.access.Unlock() close(t.done) t.done = make(chan struct{}) + return nil } func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { diff --git a/dns/transport_dialer.go b/dns/transport_dialer.go index 0b15c7ea..b3ee8082 100644 --- a/dns/transport_dialer.go +++ b/dns/transport_dialer.go @@ -20,9 +20,10 @@ func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) ( return dialer.NewDefaultOutbound(ctx), nil } else { return dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: options.DialerOptions, - DirectResolver: true, + Context: ctx, + Options: options.DialerOptions, + DirectResolver: true, + LegacyDNSDialer: options.Legacy, }) } } @@ -43,10 +44,11 @@ func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) return transportDialer, nil } else { return dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: options.DialerOptions, - RemoteIsDomain: options.ServerIsDomain(), - DirectResolver: true, + Context: ctx, + Options: options.DialerOptions, + RemoteIsDomain: options.ServerIsDomain(), + DirectResolver: true, + LegacyDNSDialer: options.Legacy, }) } } diff --git a/dns/transport_manager.go b/dns/transport_manager.go index 8666a1b9..dff886df 100644 --- a/dns/transport_manager.go +++ b/dns/transport_manager.go @@ -225,7 +225,7 @@ func (m *TransportManager) Remove(tag string) error { } } if started { - transport.Reset() + transport.Close() } return nil } diff --git a/experimental/libbox/dns.go b/experimental/libbox/dns.go index 7e143442..d5c97b7e 100644 --- a/experimental/libbox/dns.go +++ b/experimental/libbox/dns.go @@ -38,7 +38,12 @@ func newPlatformTransport(iif LocalDNSTransport, tag string, options option.Loca } } -func (p *platformTransport) Reset() { +func (p *platformTransport) Start(stage adapter.StartStage) error { + return nil +} + +func (p *platformTransport) Close() error { + return nil } func (p *platformTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { diff --git a/option/dns.go b/option/dns.go index 1a42b0f4..f303b894 100644 --- a/option/dns.go +++ b/option/dns.go @@ -191,34 +191,24 @@ func (o *DNSServerOptions) Upgrade(ctx context.Context) error { serverType = C.DNSTypeUDP } } - var remoteOptions RemoteDNSServerOptions - if options.Detour == "" { - remoteOptions = RemoteDNSServerOptions{ - LocalDNSServerOptions: LocalDNSServerOptions{ - LegacyStrategy: options.Strategy, - LegacyDefaultDialer: options.Detour == "", - LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}), - }, - LegacyAddressResolver: options.AddressResolver, - LegacyAddressStrategy: options.AddressStrategy, - LegacyAddressFallbackDelay: options.AddressFallbackDelay, - } - } else { - remoteOptions = RemoteDNSServerOptions{ - LocalDNSServerOptions: LocalDNSServerOptions{ - DialerOptions: DialerOptions{ - Detour: options.Detour, - DomainResolver: &DomainResolveOptions{ - Server: options.AddressResolver, - Strategy: options.AddressStrategy, - }, - FallbackDelay: options.AddressFallbackDelay, + remoteOptions := RemoteDNSServerOptions{ + LocalDNSServerOptions: LocalDNSServerOptions{ + DialerOptions: DialerOptions{ + Detour: options.Detour, + DomainResolver: &DomainResolveOptions{ + Server: options.AddressResolver, + Strategy: options.AddressStrategy, }, - LegacyStrategy: options.Strategy, - LegacyDefaultDialer: options.Detour == "", - LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}), + FallbackDelay: options.AddressFallbackDelay, }, - } + Legacy: true, + LegacyStrategy: options.Strategy, + LegacyDefaultDialer: options.Detour == "", + LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}), + }, + LegacyAddressResolver: options.AddressResolver, + LegacyAddressStrategy: options.AddressStrategy, + LegacyAddressFallbackDelay: options.AddressFallbackDelay, } switch serverType { case C.DNSTypeLocal: @@ -362,6 +352,7 @@ type HostsDNSServerOptions struct { type LocalDNSServerOptions struct { DialerOptions + Legacy bool `json:"-"` LegacyStrategy DomainStrategy `json:"-"` LegacyDefaultDialer bool `json:"-"` LegacyClientSubnet netip.Prefix `json:"-"` diff --git a/option/rule.go b/option/rule.go index b769dab8..41bcc126 100644 --- a/option/rule.go +++ b/option/rule.go @@ -125,10 +125,9 @@ func (r *DefaultRule) UnmarshalJSON(data []byte) error { return badjson.UnmarshallExcluded(data, &r.RawDefaultRule, &r.RuleAction) } -func (r *DefaultRule) IsValid() bool { +func (r DefaultRule) IsValid() bool { var defaultValue DefaultRule defaultValue.Invert = r.Invert - defaultValue.Action = r.Action return !reflect.DeepEqual(r, defaultValue) } diff --git a/option/rule_dns.go b/option/rule_dns.go index 9d6fb138..87b15017 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -132,7 +132,6 @@ func (r *DefaultDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte) func (r DefaultDNSRule) IsValid() bool { var defaultValue DefaultDNSRule defaultValue.Invert = r.Invert - defaultValue.DNSRuleAction = r.DNSRuleAction return !reflect.DeepEqual(r, defaultValue) } diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go index 7ad756f2..9cd1490b 100644 --- a/protocol/direct/outbound.go +++ b/protocol/direct/outbound.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/netip" + "reflect" "time" "github.com/sagernet/sing-box/adapter" @@ -27,6 +28,7 @@ func RegisterOutbound(registry *outbound.Registry) { var ( _ N.ParallelDialer = (*Outbound)(nil) _ dialer.ParallelNetworkDialer = (*Outbound)(nil) + _ dialer.DirectDialer = (*Outbound)(nil) ) type Outbound struct { @@ -37,6 +39,7 @@ type Outbound struct { fallbackDelay time.Duration overrideOption int overrideDestination M.Socksaddr + isEmpty bool // loopBack *loopBackDetector } @@ -56,6 +59,8 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL domainStrategy: C.DomainStrategy(options.DomainStrategy), fallbackDelay: time.Duration(options.FallbackDelay), dialer: outboundDialer.(dialer.ParallelInterfaceDialer), + //nolint:staticcheck + isEmpty: reflect.DeepEqual(options.DialerOptions, option.DialerOptions{UDPFragmentDefault: true}) && options.OverrideAddress == "" && options.OverridePort == 0, // loopBack: newLoopBackDetector(router), } //nolint:staticcheck @@ -242,6 +247,10 @@ func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M. return conn, newDestination, nil } +func (h *Outbound) IsEmpty() bool { + return h.isEmpty +} + /*func (h *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) { return E.New("reject loopback connection to ", metadata.Destination)