From 86116b9423d8572196bf05b114b26269ffb665a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 16 Mar 2025 14:50:44 +0800 Subject: [PATCH] refactor: DNS --- .goreleaser.fury.yaml | 4 +- adapter/dns.go | 73 +++ adapter/experimental.go | 3 +- adapter/fakeip.go | 3 +- adapter/inbound.go | 2 - adapter/outbound/manager.go | 17 +- adapter/router.go | 17 - adapter/rule.go | 1 - box.go | 107 +++- cmd/sing-box/cmd.go | 2 +- common/dialer/dialer.go | 9 +- common/dialer/resolve.go | 60 +- common/dialer/router.go | 15 +- common/tls/ech_client.go | 4 +- common/tls/std_client.go | 5 +- constant/dns.go | 29 + dns/client.go | 563 ++++++++++++++++++ dns/client_log.go | 69 +++ dns/client_truncate.go | 29 + dns/extension_edns0_subnet.go | 56 ++ dns/rcode.go | 33 + dns/router.go | 437 ++++++++++++++ .../server.go => dns/transport/dhcp/dhcp.go | 135 ++--- dns/transport/fakeip/fakeip.go | 67 +++ {transport => dns/transport}/fakeip/memory.go | 0 {transport => dns/transport}/fakeip/store.go | 0 dns/transport/hosts/hosts.go | 63 ++ dns/transport/hosts/hosts_file.go | 102 ++++ dns/transport/hosts/hosts_test.go | 16 + dns/transport/hosts/hosts_unix.go | 5 + dns/transport/hosts/hosts_windows.go | 17 + dns/transport/hosts/testdata/hosts | 2 + dns/transport/https.go | 204 +++++++ dns/transport/local/local.go | 197 ++++++ dns/transport/local/resolv.go | 146 +++++ dns/transport/local/resolv_unix.go | 175 ++++++ dns/transport/local/resolv_windows.go | 100 ++++ dns/transport/predefined.go | 83 +++ dns/transport/quic/http3.go | 167 ++++++ dns/transport/quic/quic.go | 174 ++++++ dns/transport/tcp.go | 99 +++ dns/transport/tls.go | 115 ++++ dns/transport/udp.go | 223 +++++++ dns/transport_adapter.go | 70 +++ dns/transport_dialer.go | 93 +++ dns/transport_manager.go | 288 +++++++++ dns/transport_registry.go | 72 +++ experimental/clashapi/dns.go | 6 +- experimental/clashapi/server.go | 16 +- experimental/deprecated/constants.go | 15 + experimental/libbox/config.go | 19 +- experimental/libbox/dns.go | 161 ++--- experimental/libbox/platform.go | 1 + experimental/libbox/service.go | 6 +- go.mod | 1 - go.sum | 2 - include/dhcp.go | 9 +- include/dhcp_stub.go | 12 +- include/quic.go | 8 +- include/quic_stub.go | 14 +- include/registry.go | 27 +- option/dns.go | 315 +++++++++- option/dns_record.go | 161 +++++ option/rule_action.go | 13 +- option/rule_dns.go | 1 + option/types.go | 37 +- protocol/direct/outbound.go | 29 +- protocol/dns/handle.go | 14 +- protocol/dns/outbound.go | 5 +- protocol/socks/outbound.go | 17 +- protocol/wireguard/endpoint.go | 12 +- protocol/wireguard/outbound.go | 14 +- release/config/config.json | 9 +- route/dns.go | 27 +- route/geo_resources.go | 246 -------- route/route.go | 46 +- route/route_dns.go | 348 ----------- route/router.go | 414 ++----------- route/rule/rule_abstract.go | 25 - route/rule/rule_action.go | 14 +- route/rule/rule_default.go | 12 +- route/rule/rule_dns.go | 17 +- route/rule/rule_item_geoip.go | 98 --- route/rule/rule_item_geosite.go | 61 -- route/rule/rule_item_ip_accept_any.go | 21 + route/rule_conds.go | 21 - test/domain_inbound_test.go | 3 +- transport/fakeip/server.go | 95 --- transport/v2rayhttp/server.go | 2 +- 89 files changed, 4792 insertions(+), 1733 deletions(-) create mode 100644 adapter/dns.go create mode 100644 dns/client.go create mode 100644 dns/client_log.go create mode 100644 dns/client_truncate.go create mode 100644 dns/extension_edns0_subnet.go create mode 100644 dns/rcode.go create mode 100644 dns/router.go rename transport/dhcp/server.go => dns/transport/dhcp/dhcp.go (66%) create mode 100644 dns/transport/fakeip/fakeip.go rename {transport => dns/transport}/fakeip/memory.go (100%) rename {transport => dns/transport}/fakeip/store.go (100%) create mode 100644 dns/transport/hosts/hosts.go create mode 100644 dns/transport/hosts/hosts_file.go create mode 100644 dns/transport/hosts/hosts_test.go create mode 100644 dns/transport/hosts/hosts_unix.go create mode 100644 dns/transport/hosts/hosts_windows.go create mode 100644 dns/transport/hosts/testdata/hosts create mode 100644 dns/transport/https.go create mode 100644 dns/transport/local/local.go create mode 100644 dns/transport/local/resolv.go create mode 100644 dns/transport/local/resolv_unix.go create mode 100644 dns/transport/local/resolv_windows.go create mode 100644 dns/transport/predefined.go create mode 100644 dns/transport/quic/http3.go create mode 100644 dns/transport/quic/quic.go create mode 100644 dns/transport/tcp.go create mode 100644 dns/transport/tls.go create mode 100644 dns/transport/udp.go create mode 100644 dns/transport_adapter.go create mode 100644 dns/transport_dialer.go create mode 100644 dns/transport_manager.go create mode 100644 dns/transport_registry.go create mode 100644 option/dns_record.go delete mode 100644 route/geo_resources.go delete mode 100644 route/route_dns.go delete mode 100644 route/rule/rule_item_geoip.go delete mode 100644 route/rule/rule_item_geosite.go create mode 100644 route/rule/rule_item_ip_accept_any.go delete mode 100644 transport/fakeip/server.go diff --git a/.goreleaser.fury.yaml b/.goreleaser.fury.yaml index fbd1ae42..d80dd408 100644 --- a/.goreleaser.fury.yaml +++ b/.goreleaser.fury.yaml @@ -6,7 +6,9 @@ builds: - -v - -trimpath ldflags: - - -X github.com/sagernet/sing-box/constant.Version={{ .Version }} -s -w -buildid= + - -X github.com/sagernet/sing-box/constant.Version={{ .Version }} + - -s + - -buildid= tags: - with_gvisor - with_quic diff --git a/adapter/dns.go b/adapter/dns.go new file mode 100644 index 00000000..e0f381b8 --- /dev/null +++ b/adapter/dns.go @@ -0,0 +1,73 @@ +package adapter + +import ( + "context" + "net/netip" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/logger" + + "github.com/miekg/dns" +) + +type DNSRouter interface { + Lifecycle + Exchange(ctx context.Context, message *dns.Msg, options DNSQueryOptions) (*dns.Msg, error) + Lookup(ctx context.Context, domain string, options DNSQueryOptions) ([]netip.Addr, error) + ClearCache() + LookupReverseMapping(ip netip.Addr) (string, bool) + ResetNetwork() +} + +type DNSClient interface { + Start() + Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) + Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) + LookupCache(domain string, strategy C.DomainStrategy) ([]netip.Addr, bool) + ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) + ClearCache() +} + +type DNSQueryOptions struct { + Transport DNSTransport + Strategy C.DomainStrategy + DisableCache bool + RewriteTTL *uint32 + ClientSubnet netip.Prefix +} + +type RDRCStore interface { + LoadRDRC(transportName string, qName string, qType uint16) (rejected bool) + SaveRDRC(transportName string, qName string, qType uint16) error + SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger) +} + +type DNSTransport interface { + Type() string + Tag() string + Dependencies() []string + Reset() + Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) +} + +type LegacyDNSTransport interface { + LegacyStrategy() C.DomainStrategy + LegacyClientSubnet() netip.Prefix +} + +type DNSTransportRegistry interface { + option.DNSTransportOptionsRegistry + CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (DNSTransport, error) +} + +type DNSTransportManager interface { + Lifecycle + Transports() []DNSTransport + Transport(tag string) (DNSTransport, bool) + Default() DNSTransport + FakeIP() FakeIPTransport + Remove(tag string) error + Create(ctx context.Context, logger log.ContextLogger, tag string, outboundType string, options any) error +} diff --git a/adapter/experimental.go b/adapter/experimental.go index 648eb418..99d7c9a5 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -7,7 +7,6 @@ import ( "time" "github.com/sagernet/sing-box/common/urltest" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common/varbin" ) @@ -31,7 +30,7 @@ type CacheFile interface { FakeIPStorage StoreRDRC() bool - dns.RDRCStore + RDRCStore LoadMode() string StoreMode(mode string) error diff --git a/adapter/fakeip.go b/adapter/fakeip.go index 51247c32..97d1c3c0 100644 --- a/adapter/fakeip.go +++ b/adapter/fakeip.go @@ -3,7 +3,6 @@ package adapter import ( "net/netip" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common/logger" ) @@ -27,6 +26,6 @@ type FakeIPStorage interface { } type FakeIPTransport interface { - dns.Transport + DNSTransport Store() FakeIPStore } diff --git a/adapter/inbound.go b/adapter/inbound.go index 173dd0ee..11085099 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -78,8 +78,6 @@ type InboundContext struct { FallbackNetworkType []C.InterfaceType FallbackDelay time.Duration - DNSServer string - DestinationAddresses []netip.Addr SourceGeoIPCode string GeoIPCode string diff --git a/adapter/outbound/manager.go b/adapter/outbound/manager.go index c3941d02..977fe4ca 100644 --- a/adapter/outbound/manager.go +++ b/adapter/outbound/manager.go @@ -23,7 +23,7 @@ type Manager struct { registry adapter.OutboundRegistry endpoint adapter.EndpointManager defaultTag string - access sync.Mutex + access sync.RWMutex started bool stage adapter.StartStage outbounds []adapter.Outbound @@ -169,15 +169,15 @@ func (m *Manager) Close() error { } func (m *Manager) Outbounds() []adapter.Outbound { - m.access.Lock() - defer m.access.Unlock() + m.access.RLock() + defer m.access.RUnlock() return m.outbounds } func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) { - m.access.Lock() + m.access.RLock() outbound, found := m.outboundByTag[tag] - m.access.Unlock() + m.access.RUnlock() if found { return outbound, true } @@ -185,8 +185,8 @@ func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) { } func (m *Manager) Default() adapter.Outbound { - m.access.Lock() - defer m.access.Unlock() + m.access.RLock() + defer m.access.RUnlock() if m.defaultOutbound != nil { return m.defaultOutbound } else { @@ -196,9 +196,9 @@ func (m *Manager) Default() adapter.Outbound { func (m *Manager) Remove(tag string) error { m.access.Lock() + defer m.access.Unlock() outbound, found := m.outboundByTag[tag] if !found { - m.access.Unlock() return os.ErrInvalid } delete(m.outboundByTag, tag) @@ -232,7 +232,6 @@ func (m *Manager) Remove(tag string) error { }) } } - m.access.Unlock() if started { return common.Close(outbound) } diff --git a/adapter/router.go b/adapter/router.go index a637e506..9cefb49d 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -4,42 +4,25 @@ import ( "context" "net" "net/http" - "net/netip" "sync" - "github.com/sagernet/sing-box/common/geoip" C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-dns" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" - mdns "github.com/miekg/dns" "go4.org/netipx" ) type Router interface { Lifecycle - - FakeIPStore() FakeIPStore - ConnectionRouter PreMatch(metadata InboundContext) error ConnectionRouterEx - - GeoIPReader() *geoip.Reader - LoadGeosite(code string) (Rule, error) RuleSet(tag string) (RuleSet, bool) NeedWIFIState() bool - - Exchange(ctx context.Context, message *mdns.Msg) (*mdns.Msg, error) - Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) - LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) - ClearDNSCache() Rules() []Rule - SetTracker(tracker ConnectionTracker) - ResetNetwork() } diff --git a/adapter/rule.go b/adapter/rule.go index f3737a25..2512a77b 100644 --- a/adapter/rule.go +++ b/adapter/rule.go @@ -13,7 +13,6 @@ type Rule interface { HeadlessRule Service Type() string - UpdateGeosite() error Action() RuleAction } diff --git a/box.go b/box.go index 8eb8f2f3..05431663 100644 --- a/box.go +++ b/box.go @@ -16,6 +16,8 @@ import ( "github.com/sagernet/sing-box/common/taskmonitor" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport/local" "github.com/sagernet/sing-box/experimental" "github.com/sagernet/sing-box/experimental/cachefile" "github.com/sagernet/sing-box/experimental/libbox/platform" @@ -34,17 +36,19 @@ import ( var _ adapter.Service = (*Box)(nil) type Box struct { - createdAt time.Time - logFactory log.Factory - logger log.ContextLogger - network *route.NetworkManager - endpoint *endpoint.Manager - inbound *inbound.Manager - outbound *outbound.Manager - connection *route.ConnectionManager - router *route.Router - services []adapter.LifecycleService - done chan struct{} + createdAt time.Time + logFactory log.Factory + logger log.ContextLogger + network *route.NetworkManager + endpoint *endpoint.Manager + inbound *inbound.Manager + outbound *outbound.Manager + dnsTransport *dns.TransportManager + dnsRouter *dns.Router + connection *route.ConnectionManager + router *route.Router + services []adapter.LifecycleService + done chan struct{} } type Options struct { @@ -58,6 +62,7 @@ func Context( inboundRegistry adapter.InboundRegistry, outboundRegistry adapter.OutboundRegistry, endpointRegistry adapter.EndpointRegistry, + dnsTransportRegistry adapter.DNSTransportRegistry, ) context.Context { if service.FromContext[option.InboundOptionsRegistry](ctx) == nil || service.FromContext[adapter.InboundRegistry](ctx) == nil { @@ -74,6 +79,10 @@ func Context( ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry) ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry) } + if service.FromContext[adapter.DNSTransportRegistry](ctx) == nil { + ctx = service.ContextWith[option.DNSTransportOptionsRegistry](ctx, dnsTransportRegistry) + ctx = service.ContextWith[adapter.DNSTransportRegistry](ctx, dnsTransportRegistry) + } return ctx } @@ -88,6 +97,7 @@ func New(options Options) (*Box, error) { endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) + dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx) if endpointRegistry == nil { return nil, E.New("missing endpoint registry in context") @@ -132,13 +142,17 @@ func New(options Options) (*Box, error) { } routeOptions := common.PtrValueOrDefault(options.Route) + dnsOptions := common.PtrValueOrDefault(options.DNS) endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry) inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager) outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final) + dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final) service.MustRegister[adapter.EndpointManager](ctx, endpointManager) service.MustRegister[adapter.InboundManager](ctx, inboundManager) service.MustRegister[adapter.OutboundManager](ctx, outboundManager) - + service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager) + dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions) + service.MustRegister[adapter.DNSRouter](ctx, dnsRouter) networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions) if err != nil { return nil, E.Cause(err, "initialize network manager") @@ -146,18 +160,40 @@ func New(options Options) (*Box, error) { service.MustRegister[adapter.NetworkManager](ctx, networkManager) connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection")) service.MustRegister[adapter.ConnectionManager](ctx, connectionManager) - router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS)) + router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions) + service.MustRegister[adapter.Router](ctx, router) + err = router.Initialize(routeOptions.Rules, routeOptions.RuleSet) if err != nil { return nil, E.Cause(err, "initialize router") } - ntpOptions := common.PtrValueOrDefault(options.NTP) var timeService *tls.TimeServiceWrapper if ntpOptions.Enabled { timeService = new(tls.TimeServiceWrapper) service.MustRegister[ntp.TimeService](ctx, timeService) } - + for i, transportOptions := range dnsOptions.Servers { + var tag string + if transportOptions.Tag != "" { + tag = transportOptions.Tag + } else { + tag = F.ToString(i) + } + err = dnsTransportManager.Create( + ctx, + logFactory.NewLogger(F.ToString("dns/", transportOptions.Type, "[", tag, "]")), + tag, + transportOptions.Type, + transportOptions.Options, + ) + if err != nil { + return nil, E.Cause(err, "initialize inbound[", i, "]") + } + } + err = dnsRouter.Initialize(dnsOptions.Rules) + if err != nil { + return nil, E.Cause(err, "initialize dns router") + } for i, endpointOptions := range options.Endpoints { var tag string if endpointOptions.Tag != "" { @@ -238,6 +274,13 @@ func New(options Options) (*Box, error) { option.DirectOutboundOptions{}, ), )) + dnsTransportManager.Initialize(common.Must1( + local.NewTransport( + ctx, + logFactory.NewLogger("dns/local"), + "local", + option.LocalDNSServerOptions{}, + ))) if platformInterface != nil { err = platformInterface.Initialize(networkManager) if err != nil { @@ -289,17 +332,19 @@ func New(options Options) (*Box, error) { services = append(services, adapter.NewLifecycleService(ntpService, "ntp service")) } return &Box{ - network: networkManager, - endpoint: endpointManager, - inbound: inboundManager, - outbound: outboundManager, - connection: connectionManager, - router: router, - createdAt: createdAt, - logFactory: logFactory, - logger: logFactory.Logger(), - services: services, - done: make(chan struct{}), + network: networkManager, + endpoint: endpointManager, + inbound: inboundManager, + outbound: outboundManager, + dnsTransport: dnsTransportManager, + dnsRouter: dnsRouter, + connection: connectionManager, + router: router, + createdAt: createdAt, + logFactory: logFactory, + logger: logFactory.Logger(), + services: services, + done: make(chan struct{}), }, nil } @@ -353,11 +398,11 @@ func (s *Box) preStart() error { if err != nil { return err } - err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint) + err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint) if err != nil { return err } - err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.connection, s.router) + err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router) if err != nil { return err } @@ -381,7 +426,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound, s.endpoint) + err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint) if err != nil { return err } @@ -389,7 +434,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint) + err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint) if err != nil { return err } @@ -408,7 +453,7 @@ func (s *Box) Close() error { close(s.done) } err := common.Close( - s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.network, + s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network, ) for _, lifecycleService := range s.services { err = E.Append(err, lifecycleService.Close(), func(err error) error { diff --git a/cmd/sing-box/cmd.go b/cmd/sing-box/cmd.go index d55235b8..55fe1179 100644 --- a/cmd/sing-box/cmd.go +++ b/cmd/sing-box/cmd.go @@ -69,5 +69,5 @@ func preRun(cmd *cobra.Command, args []string) { configPaths = append(configPaths, "config.json") } globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger())) - globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()) + globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry(), include.DNSTransportRegistry()) } diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index 89d1eeab..f63e3864 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -9,7 +9,6 @@ import ( "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -37,13 +36,13 @@ func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) { dialer = NewDetour(outboundManager, options.Detour) } if options.Detour == "" { - router := service.FromContext[adapter.Router](ctx) + router := service.FromContext[adapter.DNSRouter](ctx) if router != nil { dialer = NewResolveDialer( router, dialer, options.Detour == "" && !options.TCPFastOpen, - dns.DomainStrategy(options.DomainStrategy), + C.DomainStrategy(options.DomainStrategy), time.Duration(options.FallbackDelay)) } } @@ -62,10 +61,10 @@ func NewDirect(ctx context.Context, options option.DialerOptions) (ParallelInter return nil, err } return NewResolveParallelInterfaceDialer( - service.FromContext[adapter.Router](ctx), + service.FromContext[adapter.DNSRouter](ctx), dialer, true, - dns.DomainStrategy(options.DomainStrategy), + C.DomainStrategy(options.DomainStrategy), time.Duration(options.FallbackDelay), ), nil } diff --git a/common/dialer/resolve.go b/common/dialer/resolve.go index ede1afd6..3d667a6c 100644 --- a/common/dialer/resolve.go +++ b/common/dialer/resolve.go @@ -3,13 +3,11 @@ package dialer import ( "context" "net" - "net/netip" "time" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -23,12 +21,12 @@ var ( type resolveDialer struct { dialer N.Dialer parallel bool - router adapter.Router - strategy dns.DomainStrategy + router adapter.DNSRouter + strategy C.DomainStrategy fallbackDelay time.Duration } -func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) N.Dialer { +func NewResolveDialer(router adapter.DNSRouter, dialer N.Dialer, parallel bool, strategy C.DomainStrategy, fallbackDelay time.Duration) N.Dialer { return &resolveDialer{ dialer, parallel, @@ -43,7 +41,7 @@ type resolveParallelNetworkDialer struct { dialer ParallelInterfaceDialer } -func NewResolveParallelInterfaceDialer(router adapter.Router, dialer ParallelInterfaceDialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer { +func NewResolveParallelInterfaceDialer(router adapter.DNSRouter, dialer ParallelInterfaceDialer, parallel bool, strategy C.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer { return &resolveParallelNetworkDialer{ resolveDialer{ dialer, @@ -60,22 +58,13 @@ func (d *resolveDialer) DialContext(ctx context.Context, network string, destina if !destination.IsFqdn() { return d.dialer.DialContext(ctx, network, destination) } - ctx, metadata := adapter.ExtendContext(ctx) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) - metadata.Destination = destination - metadata.Domain = "" - var addresses []netip.Addr - var err error - if d.strategy == dns.DomainStrategyAsIS { - addresses, err = d.router.LookupDefault(ctx, destination.Fqdn) - } else { - addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy) - } + addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Strategy: d.strategy}) if err != nil { return nil, err } if d.parallel { - return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, d.fallbackDelay) + return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay) } else { return N.DialSerial(ctx, d.dialer, network, destination, addresses) } @@ -85,17 +74,8 @@ func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd if !destination.IsFqdn() { return d.dialer.ListenPacket(ctx, destination) } - ctx, metadata := adapter.ExtendContext(ctx) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) - metadata.Destination = destination - metadata.Domain = "" - var addresses []netip.Addr - var err error - if d.strategy == dns.DomainStrategyAsIS { - addresses, err = d.router.LookupDefault(ctx, destination.Fqdn) - } else { - addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy) - } + addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Strategy: d.strategy}) if err != nil { return nil, err } @@ -110,17 +90,10 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context if !destination.IsFqdn() { return d.dialer.DialContext(ctx, network, destination) } - ctx, metadata := adapter.ExtendContext(ctx) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) - metadata.Destination = destination - metadata.Domain = "" - var addresses []netip.Addr - var err error - if d.strategy == dns.DomainStrategyAsIS { - addresses, err = d.router.LookupDefault(ctx, destination.Fqdn) - } else { - addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy) - } + addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{ + Strategy: d.strategy, + }) if err != nil { return nil, err } @@ -128,7 +101,7 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context fallbackDelay = d.fallbackDelay } if d.parallel { - return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) + return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) } else { return DialSerialNetwork(ctx, d.dialer, network, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) } @@ -138,17 +111,8 @@ func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.C if !destination.IsFqdn() { return d.dialer.ListenPacket(ctx, destination) } - ctx, metadata := adapter.ExtendContext(ctx) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) - metadata.Destination = destination - metadata.Domain = "" - var addresses []netip.Addr - var err error - if d.strategy == dns.DomainStrategyAsIS { - addresses, err = d.router.LookupDefault(ctx, destination.Fqdn) - } else { - addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy) - } + addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Strategy: d.strategy}) if err != nil { return nil, err } diff --git a/common/dialer/router.go b/common/dialer/router.go index 3edce65b..801a36b1 100644 --- a/common/dialer/router.go +++ b/common/dialer/router.go @@ -7,24 +7,27 @@ import ( "github.com/sagernet/sing-box/adapter" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" ) type DefaultOutboundDialer struct { - outboundManager adapter.OutboundManager + outbound adapter.OutboundManager } -func NewDefaultOutbound(outboundManager adapter.OutboundManager) N.Dialer { - return &DefaultOutboundDialer{outboundManager: outboundManager} +func NewDefaultOutbound(ctx context.Context) N.Dialer { + return &DefaultOutboundDialer{ + outbound: service.FromContext[adapter.OutboundManager](ctx), + } } func (d *DefaultOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - return d.outboundManager.Default().DialContext(ctx, network, destination) + return d.outbound.Default().DialContext(ctx, network, destination) } func (d *DefaultOutboundDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return d.outboundManager.Default().ListenPacket(ctx, destination) + return d.outbound.Default().ListenPacket(ctx, destination) } func (d *DefaultOutboundDialer) Upstream() any { - return d.outboundManager.Default() + return d.outbound.Default() } diff --git a/common/tls/ech_client.go b/common/tls/ech_client.go index 0ae3997a..fff1873d 100644 --- a/common/tls/ech_client.go +++ b/common/tls/ech_client.go @@ -15,8 +15,8 @@ import ( cftls "github.com/sagernet/cloudflare-tls" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/service" @@ -215,7 +215,7 @@ func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverNam }, }, } - response, err := service.FromContext[adapter.Router](ctx).Exchange(ctx, message) + response, err := service.FromContext[adapter.DNSRouter](ctx).Exchange(ctx, message, adapter.DNSQueryOptions{}) if err != nil { return nil, err } diff --git a/common/tls/std_client.go b/common/tls/std_client.go index 90f51821..7cd130a6 100644 --- a/common/tls/std_client.go +++ b/common/tls/std_client.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "crypto/x509" "net" - "net/netip" "os" "strings" @@ -51,9 +50,7 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb if options.ServerName != "" { serverName = options.ServerName } else if serverAddress != "" { - if _, err := netip.ParseAddr(serverName); err != nil { - serverName = serverAddress - } + serverName = serverAddress } if serverName == "" && !options.Insecure { return nil, E.New("missing server_name or insecure=true") diff --git a/constant/dns.go b/constant/dns.go index 3907b8c1..99461a27 100644 --- a/constant/dns.go +++ b/constant/dns.go @@ -1,5 +1,34 @@ package constant +const ( + DefaultDNSTTL = 600 +) + +type DomainStrategy = uint8 + +const ( + DomainStrategyAsIS DomainStrategy = iota + DomainStrategyPreferIPv4 + DomainStrategyPreferIPv6 + DomainStrategyIPv4Only + DomainStrategyIPv6Only +) + +const ( + DNSTypeLegacy = "legacy" + DNSTypeUDP = "udp" + DNSTypeTCP = "tcp" + DNSTypeTLS = "tls" + DNSTypeHTTPS = "https" + DNSTypeQUIC = "quic" + DNSTypeHTTP3 = "h3" + DNSTypeHosts = "hosts" + DNSTypeLocal = "local" + DNSTypePreDefined = "predefined" + DNSTypeFakeIP = "fakeip" + DNSTypeDHCP = "dhcp" +) + const ( DNSProviderAliDNS = "alidns" DNSProviderCloudflare = "cloudflare" diff --git a/dns/client.go b/dns/client.go new file mode 100644 index 00000000..79b6fce5 --- /dev/null +++ b/dns/client.go @@ -0,0 +1,563 @@ +package dns + +import ( + "context" + "net" + "net/netip" + "strings" + "time" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/task" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" + + "github.com/miekg/dns" +) + +var ( + ErrNoRawSupport = E.New("no raw query support by current transport") + ErrNotCached = E.New("not cached") + ErrResponseRejected = E.New("response rejected") + ErrResponseRejectedCached = E.Extend(ErrResponseRejected, "cached") +) + +var _ adapter.DNSClient = (*Client)(nil) + +type Client struct { + timeout time.Duration + disableCache bool + disableExpire bool + independentCache bool + rdrc adapter.RDRCStore + initRDRCFunc func() adapter.RDRCStore + logger logger.ContextLogger + cache freelru.Cache[dns.Question, *dns.Msg] + transportCache freelru.Cache[transportCacheKey, *dns.Msg] +} + +type ClientOptions struct { + Timeout time.Duration + DisableCache bool + DisableExpire bool + IndependentCache bool + CacheCapacity uint32 + RDRC func() adapter.RDRCStore + Logger logger.ContextLogger +} + +func NewClient(options ClientOptions) *Client { + client := &Client{ + timeout: options.Timeout, + disableCache: options.DisableCache, + disableExpire: options.DisableExpire, + independentCache: options.IndependentCache, + initRDRCFunc: options.RDRC, + logger: options.Logger, + } + if client.timeout == 0 { + client.timeout = C.DNSTimeout + } + cacheCapacity := options.CacheCapacity + if cacheCapacity < 1024 { + cacheCapacity = 1024 + } + if !client.disableCache { + if !client.independentCache { + client.cache = common.Must1(freelru.NewSharded[dns.Question, *dns.Msg](cacheCapacity, maphash.NewHasher[dns.Question]().Hash32)) + } else { + client.transportCache = common.Must1(freelru.NewSharded[transportCacheKey, *dns.Msg](cacheCapacity, maphash.NewHasher[transportCacheKey]().Hash32)) + } + } + return client +} + +type transportCacheKey struct { + dns.Question + transportTag string +} + +func (c *Client) Start() { + if c.initRDRCFunc != nil { + c.rdrc = c.initRDRCFunc() + } +} + +func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) { + if len(message.Question) == 0 { + if c.logger != nil { + c.logger.WarnContext(ctx, "bad question size: ", len(message.Question)) + } + responseMessage := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: message.Id, + Response: true, + Rcode: dns.RcodeFormatError, + }, + Question: message.Question, + } + return &responseMessage, nil + } + question := message.Question[0] + if options.ClientSubnet.IsValid() { + message = SetClientSubnet(message, options.ClientSubnet, true) + } + isSimpleRequest := len(message.Question) == 1 && + len(message.Ns) == 0 && + len(message.Extra) == 0 && + !options.ClientSubnet.IsValid() + disableCache := !isSimpleRequest || c.disableCache || options.DisableCache + if !disableCache { + response, ttl := c.loadResponse(question, transport) + if response != nil { + logCachedResponse(c.logger, ctx, response, ttl) + response.Id = message.Id + return response, nil + } + } + if question.Qtype == dns.TypeA && options.Strategy == C.DomainStrategyIPv6Only || question.Qtype == dns.TypeAAAA && options.Strategy == C.DomainStrategyIPv4Only { + responseMessage := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: message.Id, + Response: true, + Rcode: dns.RcodeSuccess, + }, + Question: []dns.Question{question}, + } + if c.logger != nil { + c.logger.DebugContext(ctx, "strategy rejected") + } + return &responseMessage, nil + } + messageId := message.Id + contextTransport, clientSubnetLoaded := transportTagFromContext(ctx) + if clientSubnetLoaded && transport.Tag() == contextTransport { + return nil, E.New("DNS query loopback in transport[", contextTransport, "]") + } + ctx = contextWithTransportTag(ctx, transport.Tag()) + if responseChecker != nil && c.rdrc != nil { + rejected := c.rdrc.LoadRDRC(transport.Tag(), question.Name, question.Qtype) + if rejected { + return nil, ErrResponseRejectedCached + } + } + ctx, cancel := context.WithTimeout(ctx, c.timeout) + response, err := transport.Exchange(ctx, message) + cancel() + if err != nil { + return nil, err + } + /*if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA { + validResponse := response + loop: + for { + var ( + addresses int + queryCNAME string + ) + for _, rawRR := range validResponse.Answer { + switch rr := rawRR.(type) { + case *dns.A: + break loop + case *dns.AAAA: + break loop + case *dns.CNAME: + queryCNAME = rr.Target + } + } + if queryCNAME == "" { + break + } + exMessage := *message + exMessage.Question = []dns.Question{{ + Name: queryCNAME, + Qtype: question.Qtype, + }} + validResponse, err = c.Exchange(ctx, transport, &exMessage, options, responseChecker) + if err != nil { + return nil, err + } + } + if validResponse != response { + response.Answer = append(response.Answer, validResponse.Answer...) + } + }*/ + if responseChecker != nil { + addr, addrErr := MessageToAddresses(response) + if addrErr != nil || !responseChecker(addr) { + if c.rdrc != nil { + c.rdrc.SaveRDRCAsync(transport.Tag(), question.Name, question.Qtype, c.logger) + } + logRejectedResponse(c.logger, ctx, response) + return response, ErrResponseRejected + } + } + if question.Qtype == dns.TypeHTTPS { + if options.Strategy == C.DomainStrategyIPv4Only || options.Strategy == C.DomainStrategyIPv6Only { + for _, rr := range response.Answer { + https, isHTTPS := rr.(*dns.HTTPS) + if !isHTTPS { + continue + } + content := https.SVCB + content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool { + if options.Strategy == C.DomainStrategyIPv4Only { + return it.Key() != dns.SVCB_IPV6HINT + } else { + return it.Key() != dns.SVCB_IPV4HINT + } + }) + https.SVCB = content + } + } + } + var timeToLive uint32 + for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} { + for _, record := range recordList { + if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive { + timeToLive = record.Header().Ttl + } + } + } + if options.RewriteTTL != nil { + timeToLive = *options.RewriteTTL + } + for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} { + for _, record := range recordList { + record.Header().Ttl = timeToLive + } + } + response.Id = messageId + if !disableCache { + c.storeCache(transport, question, response, timeToLive) + } + logExchangedResponse(c.logger, ctx, response, timeToLive) + return response, err +} + +func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { + domain = FqdnToDomain(domain) + dnsName := dns.Fqdn(domain) + if options.Strategy == C.DomainStrategyIPv4Only { + return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker) + } else if options.Strategy == C.DomainStrategyIPv6Only { + return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker) + } + var response4 []netip.Addr + var response6 []netip.Addr + var group task.Group + group.Append("exchange4", func(ctx context.Context) error { + response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker) + if err != nil { + return err + } + response4 = response + return nil + }) + group.Append("exchange6", func(ctx context.Context) error { + response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker) + if err != nil { + return err + } + response6 = response + return nil + }) + err := group.Run(ctx) + if len(response4) == 0 && len(response6) == 0 { + return nil, err + } + return sortAddresses(response4, response6, options.Strategy), nil +} + +func (c *Client) ClearCache() { + if c.cache != nil { + c.cache.Purge() + } + if c.transportCache != nil { + c.transportCache.Purge() + } +} + +func (c *Client) LookupCache(domain string, strategy C.DomainStrategy) ([]netip.Addr, bool) { + if c.disableCache || c.independentCache { + return nil, false + } + if dns.IsFqdn(domain) { + domain = domain[:len(domain)-1] + } + dnsName := dns.Fqdn(domain) + if strategy == C.DomainStrategyIPv4Only { + response, err := c.questionCache(dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, nil) + if err != ErrNotCached { + return response, true + } + } else if strategy == C.DomainStrategyIPv6Only { + response, err := c.questionCache(dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + }, nil) + if err != ErrNotCached { + return response, true + } + } else { + response4, _ := c.questionCache(dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, nil) + response6, _ := c.questionCache(dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + }, nil) + if len(response4) > 0 || len(response6) > 0 { + return sortAddresses(response4, response6, strategy), true + } + } + return nil, false +} + +func (c *Client) ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) { + if c.disableCache || c.independentCache || len(message.Question) != 1 { + return nil, false + } + question := message.Question[0] + response, ttl := c.loadResponse(question, nil) + if response == nil { + return nil, false + } + logCachedResponse(c.logger, ctx, response, ttl) + response.Id = message.Id + return response, true +} + +func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.DomainStrategy) []netip.Addr { + if strategy == C.DomainStrategyPreferIPv6 { + return append(response6, response4...) + } else { + return append(response4, response6...) + } +} + +func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Question, message *dns.Msg, timeToLive uint32) { + if timeToLive == 0 { + return + } + if c.disableExpire { + if !c.independentCache { + c.cache.Add(question, message) + } else { + c.transportCache.Add(transportCacheKey{ + Question: question, + transportTag: transport.Tag(), + }, message) + } + return + } + if !c.independentCache { + c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive)) + } else { + c.transportCache.AddWithLifetime(transportCacheKey{ + Question: question, + transportTag: transport.Tag(), + }, message, time.Second*time.Duration(timeToLive)) + } +} + +func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { + question := dns.Question{ + Name: name, + Qtype: qType, + Qclass: dns.ClassINET, + } + disableCache := c.disableCache || options.DisableCache + if !disableCache { + cachedAddresses, err := c.questionCache(question, transport) + if err != ErrNotCached { + return cachedAddresses, err + } + } + message := dns.Msg{ + MsgHdr: dns.MsgHdr{ + RecursionDesired: true, + }, + Question: []dns.Question{question}, + } + response, err := c.Exchange(ctx, transport, &message, options, responseChecker) + if err != nil { + return nil, err + } + return MessageToAddresses(response) +} + +func (c *Client) questionCache(question dns.Question, transport adapter.DNSTransport) ([]netip.Addr, error) { + response, _ := c.loadResponse(question, transport) + if response == nil { + return nil, ErrNotCached + } + return MessageToAddresses(response) +} + +func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int) { + var ( + response *dns.Msg + loaded bool + ) + if c.disableExpire { + if !c.independentCache { + response, loaded = c.cache.Get(question) + } else { + response, loaded = c.transportCache.Get(transportCacheKey{ + Question: question, + transportTag: transport.Tag(), + }) + } + if !loaded { + return nil, 0 + } + return response.Copy(), 0 + } else { + var expireAt time.Time + if !c.independentCache { + response, expireAt, loaded = c.cache.GetWithLifetime(question) + } else { + response, expireAt, loaded = c.transportCache.GetWithLifetime(transportCacheKey{ + Question: question, + transportTag: transport.Tag(), + }) + } + if !loaded { + return nil, 0 + } + timeNow := time.Now() + if timeNow.After(expireAt) { + if !c.independentCache { + c.cache.Remove(question) + } else { + c.transportCache.Remove(transportCacheKey{ + Question: question, + transportTag: transport.Tag(), + }) + } + return nil, 0 + } + var originTTL int + for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} { + for _, record := range recordList { + if originTTL == 0 || record.Header().Ttl > 0 && int(record.Header().Ttl) < originTTL { + originTTL = int(record.Header().Ttl) + } + } + } + nowTTL := int(expireAt.Sub(timeNow).Seconds()) + if nowTTL < 0 { + nowTTL = 0 + } + response = response.Copy() + if originTTL > 0 { + duration := uint32(originTTL - nowTTL) + for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} { + for _, record := range recordList { + record.Header().Ttl = record.Header().Ttl - duration + } + } + } else { + for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} { + for _, record := range recordList { + record.Header().Ttl = uint32(nowTTL) + } + } + } + return response, nowTTL + } +} + +func MessageToAddresses(response *dns.Msg) ([]netip.Addr, error) { + if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError { + return nil, RCodeError(response.Rcode) + } + addresses := make([]netip.Addr, 0, len(response.Answer)) + for _, rawAnswer := range response.Answer { + switch answer := rawAnswer.(type) { + case *dns.A: + addresses = append(addresses, M.AddrFromIP(answer.A)) + case *dns.AAAA: + addresses = append(addresses, M.AddrFromIP(answer.AAAA)) + case *dns.HTTPS: + for _, value := range answer.SVCB.Value { + if value.Key() == dns.SVCB_IPV4HINT || value.Key() == dns.SVCB_IPV6HINT { + addresses = append(addresses, common.Map(strings.Split(value.String(), ","), M.ParseAddr)...) + } + } + } + } + return addresses, nil +} + +func wrapError(err error) error { + switch dnsErr := err.(type) { + case *net.DNSError: + if dnsErr.IsNotFound { + return RCodeNameError + } + case *net.AddrError: + return RCodeNameError + } + return err +} + +type transportKey struct{} + +func contextWithTransportTag(ctx context.Context, transportTag string) context.Context { + return context.WithValue(ctx, transportKey{}, transportTag) +} + +func transportTagFromContext(ctx context.Context) (string, bool) { + value, loaded := ctx.Value(transportKey{}).(string) + return value, loaded +} + +func FixedResponse(id uint16, question dns.Question, addresses []netip.Addr, timeToLive uint32) *dns.Msg { + response := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: id, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{question}, + } + for _, address := range addresses { + if address.Is4() { + response.Answer = append(response.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: timeToLive, + }, + A: address.AsSlice(), + }) + } else { + response.Answer = append(response.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: timeToLive, + }, + AAAA: address.AsSlice(), + }) + } + } + return &response +} diff --git a/dns/client_log.go b/dns/client_log.go new file mode 100644 index 00000000..67d00708 --- /dev/null +++ b/dns/client_log.go @@ -0,0 +1,69 @@ +package dns + +import ( + "context" + "strings" + + "github.com/sagernet/sing/common/logger" + + "github.com/miekg/dns" +) + +func logCachedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl int) { + if logger == nil || len(response.Question) == 0 { + return + } + domain := FqdnToDomain(response.Question[0].Name) + logger.DebugContext(ctx, "cached ", domain, " ", dns.RcodeToString[response.Rcode], " ", ttl) + for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} { + for _, record := range recordList { + logger.InfoContext(ctx, "cached ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String())) + } + } +} + +func logExchangedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl uint32) { + if logger == nil || len(response.Question) == 0 { + return + } + domain := FqdnToDomain(response.Question[0].Name) + logger.DebugContext(ctx, "exchanged ", domain, " ", dns.RcodeToString[response.Rcode], " ", ttl) + for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} { + for _, record := range recordList { + logger.InfoContext(ctx, "exchanged ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String())) + } + } +} + +func logRejectedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg) { + if logger == nil || len(response.Question) == 0 { + return + } + for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} { + for _, record := range recordList { + logger.InfoContext(ctx, "rejected ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String())) + } + } +} + +func FqdnToDomain(fqdn string) string { + if dns.IsFqdn(fqdn) { + return fqdn[:len(fqdn)-1] + } + return fqdn +} + +func FormatQuestion(string string) string { + for strings.HasPrefix(string, ";") { + string = string[1:] + } + string = strings.ReplaceAll(string, "\t", " ") + string = strings.ReplaceAll(string, "\n", " ") + string = strings.ReplaceAll(string, ";; ", " ") + string = strings.ReplaceAll(string, "; ", " ") + + for strings.Contains(string, " ") { + string = strings.ReplaceAll(string, " ", " ") + } + return strings.TrimSpace(string) +} diff --git a/dns/client_truncate.go b/dns/client_truncate.go new file mode 100644 index 00000000..e64064e6 --- /dev/null +++ b/dns/client_truncate.go @@ -0,0 +1,29 @@ +package dns + +import ( + "github.com/sagernet/sing/common/buf" + + "github.com/miekg/dns" +) + +func TruncateDNSMessage(request *dns.Msg, response *dns.Msg, headroom int) (*buf.Buffer, error) { + maxLen := 512 + if edns0Option := request.IsEdns0(); edns0Option != nil { + if udpSize := int(edns0Option.UDPSize()); udpSize > 512 { + maxLen = udpSize + } + } + responseLen := response.Len() + if responseLen > maxLen { + response.Truncate(maxLen) + } + buffer := buf.NewSize(headroom*2 + 1 + responseLen) + buffer.Resize(headroom, 0) + rawMessage, err := response.PackBuffer(buffer.FreeBytes()) + if err != nil { + buffer.Release() + return nil, err + } + buffer.Truncate(len(rawMessage)) + return buffer, nil +} diff --git a/dns/extension_edns0_subnet.go b/dns/extension_edns0_subnet.go new file mode 100644 index 00000000..1c4033d3 --- /dev/null +++ b/dns/extension_edns0_subnet.go @@ -0,0 +1,56 @@ +package dns + +import ( + "net/netip" + + "github.com/miekg/dns" +) + +func SetClientSubnet(message *dns.Msg, clientSubnet netip.Prefix, override bool) *dns.Msg { + var ( + optRecord *dns.OPT + subnetOption *dns.EDNS0_SUBNET + ) +findExists: + for _, record := range message.Extra { + var isOPTRecord bool + if optRecord, isOPTRecord = record.(*dns.OPT); isOPTRecord { + for _, option := range optRecord.Option { + var isEDNS0Subnet bool + subnetOption, isEDNS0Subnet = option.(*dns.EDNS0_SUBNET) + if isEDNS0Subnet { + if !override { + return message + } + break findExists + } + } + } + } + if optRecord == nil { + exMessage := *message + message = &exMessage + optRecord = &dns.OPT{ + Hdr: dns.RR_Header{ + Name: ".", + Rrtype: dns.TypeOPT, + }, + } + message.Extra = append(message.Extra, optRecord) + } else { + message = message.Copy() + } + if subnetOption == nil { + subnetOption = new(dns.EDNS0_SUBNET) + optRecord.Option = append(optRecord.Option, subnetOption) + } + subnetOption.Code = dns.EDNS0SUBNET + if clientSubnet.Addr().Is4() { + subnetOption.Family = 1 + } else { + subnetOption.Family = 2 + } + subnetOption.SourceNetmask = uint8(clientSubnet.Bits()) + subnetOption.Address = clientSubnet.Addr().AsSlice() + return message +} diff --git a/dns/rcode.go b/dns/rcode.go new file mode 100644 index 00000000..5b7e52cc --- /dev/null +++ b/dns/rcode.go @@ -0,0 +1,33 @@ +package dns + +import F "github.com/sagernet/sing/common/format" + +const ( + RCodeSuccess RCodeError = 0 // NoError + RCodeFormatError RCodeError = 1 // FormErr + RCodeServerFailure RCodeError = 2 // ServFail + RCodeNameError RCodeError = 3 // NXDomain + RCodeNotImplemented RCodeError = 4 // NotImp + RCodeRefused RCodeError = 5 // Refused +) + +type RCodeError uint16 + +func (e RCodeError) Error() string { + switch e { + case RCodeSuccess: + return "success" + case RCodeFormatError: + return "format error" + case RCodeServerFailure: + return "server failure" + case RCodeNameError: + return "name error" + case RCodeNotImplemented: + return "not implemented" + case RCodeRefused: + return "refused" + default: + return F.ToString("unknown error: ", uint16(e)) + } +} diff --git a/dns/router.go b/dns/router.go new file mode 100644 index 00000000..8ecb8891 --- /dev/null +++ b/dns/router.go @@ -0,0 +1,437 @@ +package dns + +import ( + "context" + "errors" + "net/netip" + "strings" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/taskmonitor" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/libbox/platform" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + R "github.com/sagernet/sing-box/route/rule" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" + "github.com/sagernet/sing/service" + + mDNS "github.com/miekg/dns" +) + +var _ adapter.DNSRouter = (*Router)(nil) + +type Router struct { + ctx context.Context + logger logger.ContextLogger + transport adapter.DNSTransportManager + outbound adapter.OutboundManager + client adapter.DNSClient + rules []adapter.DNSRule + defaultDomainStrategy C.DomainStrategy + dnsReverseMapping freelru.Cache[netip.Addr, string] + platformInterface platform.Interface +} + +func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) *Router { + router := &Router{ + ctx: ctx, + logger: logFactory.NewLogger("dns"), + transport: service.FromContext[adapter.DNSTransportManager](ctx), + outbound: service.FromContext[adapter.OutboundManager](ctx), + rules: make([]adapter.DNSRule, 0, len(options.Rules)), + defaultDomainStrategy: C.DomainStrategy(options.Strategy), + } + router.client = NewClient(ClientOptions{ + DisableCache: options.DNSClientOptions.DisableCache, + DisableExpire: options.DNSClientOptions.DisableExpire, + IndependentCache: options.DNSClientOptions.IndependentCache, + CacheCapacity: options.DNSClientOptions.CacheCapacity, + RDRC: func() adapter.RDRCStore { + cacheFile := service.FromContext[adapter.CacheFile](ctx) + if cacheFile == nil { + return nil + } + if !cacheFile.StoreRDRC() { + return nil + } + return cacheFile + }, + Logger: router.logger, + }) + if options.ReverseMapping { + router.dnsReverseMapping = common.Must1(freelru.NewSharded[netip.Addr, string](1024, maphash.NewHasher[netip.Addr]().Hash32)) + } + return router +} + +func (r *Router) Initialize(rules []option.DNSRule) error { + for i, ruleOptions := range rules { + dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true) + if err != nil { + return E.Cause(err, "parse dns rule[", i, "]") + } + r.rules = append(r.rules, dnsRule) + } + return nil +} + +func (r *Router) Start(stage adapter.StartStage) error { + monitor := taskmonitor.New(r.logger, C.StartTimeout) + switch stage { + case adapter.StartStateStart: + monitor.Start("initialize DNS client") + r.client.Start() + monitor.Finish() + + for i, rule := range r.rules { + monitor.Start("initialize DNS rule[", i, "]") + err := rule.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "initialize DNS rule[", i, "]") + } + } + } + return nil +} + +func (r *Router) Close() error { + monitor := taskmonitor.New(r.logger, C.StopTimeout) + var err error + for i, rule := range r.rules { + monitor.Start("close dns rule[", i, "]") + err = E.Append(err, rule.Close(), func(err error) error { + return E.Cause(err, "close dns rule[", i, "]") + }) + monitor.Finish() + } + return err +} + +func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) { + metadata := adapter.ContextFrom(ctx) + if metadata == nil { + panic("no context") + } + var currentRuleIndex int + if ruleIndex != -1 { + currentRuleIndex = ruleIndex + 1 + } + for ; currentRuleIndex < len(r.rules); currentRuleIndex++ { + currentRule := r.rules[currentRuleIndex] + if currentRule.WithAddressLimit() && !isAddressQuery { + continue + } + metadata.ResetRuleCache() + if currentRule.Match(metadata) { + displayRuleIndex := currentRuleIndex + if displayRuleIndex != -1 { + displayRuleIndex += displayRuleIndex + 1 + } + ruleDescription := currentRule.String() + if ruleDescription != "" { + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] ", currentRule, " => ", currentRule.Action()) + } else { + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action()) + } + switch action := currentRule.Action().(type) { + case *R.RuleActionDNSRoute: + transport, loaded := r.transport.Transport(action.Server) + if !loaded { + r.logger.ErrorContext(ctx, "transport not found: ", action.Server) + continue + } + isFakeIP := transport.Type() == C.DNSTypeFakeIP + if isFakeIP && !allowFakeIP { + continue + } + if action.Strategy != C.DomainStrategyAsIS { + options.Strategy = action.Strategy + } + if isFakeIP || action.DisableCache { + options.DisableCache = true + } + if action.RewriteTTL != nil { + options.RewriteTTL = action.RewriteTTL + } + if action.ClientSubnet.IsValid() { + options.ClientSubnet = action.ClientSubnet + } + if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy { + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = legacyTransport.LegacyStrategy() + } + if !options.ClientSubnet.IsValid() { + options.ClientSubnet = legacyTransport.LegacyClientSubnet() + } + } + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action()) + return transport, currentRule, currentRuleIndex + case *R.RuleActionDNSRouteOptions: + if action.Strategy != C.DomainStrategyAsIS { + options.Strategy = action.Strategy + } + if action.DisableCache { + options.DisableCache = true + } + if action.RewriteTTL != nil { + options.RewriteTTL = action.RewriteTTL + } + if action.ClientSubnet.IsValid() { + options.ClientSubnet = action.ClientSubnet + } + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action()) + case *R.RuleActionReject: + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action()) + return nil, currentRule, currentRuleIndex + } + } + } + return r.transport.Default(), nil, -1 +} + +func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) { + if len(message.Question) != 1 { + r.logger.WarnContext(ctx, "bad question size: ", len(message.Question)) + responseMessage := mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Response: true, + Rcode: mDNS.RcodeFormatError, + }, + Question: message.Question, + } + return &responseMessage, nil + } + r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String())) + var ( + transport adapter.DNSTransport + err error + ) + response, cached := r.client.ExchangeCache(ctx, message) + if !cached { + var metadata *adapter.InboundContext + ctx, metadata = adapter.ExtendContext(ctx) + metadata.Destination = M.Socksaddr{} + metadata.QueryType = message.Question[0].Qtype + switch metadata.QueryType { + case mDNS.TypeA: + metadata.IPVersion = 4 + case mDNS.TypeAAAA: + metadata.IPVersion = 6 + } + metadata.Domain = FqdnToDomain(message.Question[0].Name) + if options.Transport != nil { + transport = options.Transport + if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy { + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = legacyTransport.LegacyStrategy() + } + if !options.ClientSubnet.IsValid() { + options.ClientSubnet = legacyTransport.LegacyClientSubnet() + } + } + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = r.defaultDomainStrategy + } + response, err = r.client.Exchange(ctx, transport, message, options, nil) + } else { + var ( + rule adapter.DNSRule + ruleIndex int + ) + ruleIndex = -1 + for { + dnsCtx := adapter.OverrideContext(ctx) + dnsOptions := options + transport, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message), &dnsOptions) + if rule != nil { + switch action := rule.Action().(type) { + case *R.RuleActionReject: + switch action.Method { + case C.RuleActionRejectMethodDefault: + return FixedResponse(message.Id, message.Question[0], nil, 0), nil + case C.RuleActionRejectMethodDrop: + return nil, tun.ErrDrop + } + } + } + var responseCheck func(responseAddrs []netip.Addr) bool + if rule != nil && rule.WithAddressLimit() { + responseCheck = func(responseAddrs []netip.Addr) bool { + metadata.DestinationAddresses = responseAddrs + return rule.MatchAddressLimit(metadata) + } + } + if dnsOptions.Strategy == C.DomainStrategyAsIS { + dnsOptions.Strategy = r.defaultDomainStrategy + } + response, err = r.client.Exchange(dnsCtx, transport, message, dnsOptions, responseCheck) + var rejected bool + if err != nil { + if errors.Is(err, ErrResponseRejectedCached) { + rejected = true + r.logger.DebugContext(ctx, E.Cause(err, "response rejected for ", FormatQuestion(message.Question[0].String())), " (cached)") + } else if errors.Is(err, ErrResponseRejected) { + rejected = true + r.logger.DebugContext(ctx, E.Cause(err, "response rejected for ", FormatQuestion(message.Question[0].String()))) + } else if len(message.Question) > 0 { + r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", FormatQuestion(message.Question[0].String()))) + } else { + r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for ")) + } + } + if responseCheck != nil && rejected { + continue + } + break + } + } + } + if err != nil { + return nil, err + } + if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 { + if transport == nil || transport.Type() != C.DNSTypeFakeIP { + for _, answer := range response.Answer { + switch record := answer.(type) { + case *mDNS.A: + r.dnsReverseMapping.AddWithLifetime(M.AddrFromIP(record.A), FqdnToDomain(record.Hdr.Name), time.Duration(record.Hdr.Ttl)*time.Second) + case *mDNS.AAAA: + r.dnsReverseMapping.AddWithLifetime(M.AddrFromIP(record.AAAA), FqdnToDomain(record.Hdr.Name), time.Duration(record.Hdr.Ttl)*time.Second) + } + } + } + } + return response, nil +} + +func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { + var ( + responseAddrs []netip.Addr + cached bool + err error + ) + printResult := func() { + if err != nil { + if errors.Is(err, ErrResponseRejectedCached) { + r.logger.DebugContext(ctx, "response rejected for ", domain, " (cached)") + } else if errors.Is(err, ErrResponseRejected) { + r.logger.DebugContext(ctx, "response rejected for ", domain) + } else { + r.logger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain)) + } + } else if len(responseAddrs) == 0 { + r.logger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result") + err = RCodeNameError + } + } + responseAddrs, cached = r.client.LookupCache(domain, options.Strategy) + if cached { + if len(responseAddrs) == 0 { + return nil, RCodeNameError + } + return responseAddrs, nil + } + r.logger.DebugContext(ctx, "lookup domain ", domain) + ctx, metadata := adapter.ExtendContext(ctx) + metadata.Destination = M.Socksaddr{} + metadata.Domain = FqdnToDomain(domain) + if options.Transport != nil { + transport := options.Transport + if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy { + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = r.defaultDomainStrategy + } + if !options.ClientSubnet.IsValid() { + options.ClientSubnet = legacyTransport.LegacyClientSubnet() + } + } + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = r.defaultDomainStrategy + } + responseAddrs, err = r.client.Lookup(ctx, transport, domain, options, nil) + } else { + var ( + transport adapter.DNSTransport + rule adapter.DNSRule + ruleIndex int + ) + ruleIndex = -1 + for { + dnsCtx := adapter.OverrideContext(ctx) + transport, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true, &options) + if rule != nil { + switch action := rule.Action().(type) { + case *R.RuleActionReject: + switch action.Method { + case C.RuleActionRejectMethodDefault: + return nil, nil + case C.RuleActionRejectMethodDrop: + return nil, tun.ErrDrop + } + } + } + var responseCheck func(responseAddrs []netip.Addr) bool + if rule != nil && rule.WithAddressLimit() { + responseCheck = func(responseAddrs []netip.Addr) bool { + metadata.DestinationAddresses = responseAddrs + return rule.MatchAddressLimit(metadata) + } + } + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = r.defaultDomainStrategy + } + responseAddrs, err = r.client.Lookup(dnsCtx, transport, domain, options, responseCheck) + if responseCheck == nil || err == nil { + break + } + printResult() + } + } + printResult() + if len(responseAddrs) > 0 { + r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " ")) + } + return responseAddrs, err +} + +func isAddressQuery(message *mDNS.Msg) bool { + for _, question := range message.Question { + if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA || question.Qtype == mDNS.TypeHTTPS { + return true + } + } + return false +} + +func (r *Router) ClearCache() { + r.client.ClearCache() + if r.platformInterface != nil { + r.platformInterface.ClearDNSCache() + } +} + +func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) { + if r.dnsReverseMapping == nil { + return "", false + } + domain, loaded := r.dnsReverseMapping.Get(ip) + return domain, loaded +} + +func (r *Router) ResetNetwork() { + r.ClearCache() + for _, transport := range r.transport.Transports() { + transport.Reset() + } +} diff --git a/transport/dhcp/server.go b/dns/transport/dhcp/dhcp.go similarity index 66% rename from transport/dhcp/server.go rename to dns/transport/dhcp/dhcp.go index 8b9187f0..c75d7369 100644 --- a/transport/dhcp/server.go +++ b/dns/transport/dhcp/dhcp.go @@ -3,9 +3,6 @@ package dhcp import ( "context" "net" - "net/netip" - "net/url" - "os" "runtime" "strings" "sync" @@ -14,13 +11,18 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport" + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/service" @@ -29,76 +31,70 @@ import ( mDNS "github.com/miekg/dns" ) -func init() { - dns.RegisterTransport([]string{"dhcp"}, func(options dns.TransportOptions) (dns.Transport, error) { - return NewTransport(options) - }) +func RegisterTransport(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.DHCPDNSServerOptions](registry, C.DNSTypeDHCP, NewTransport) } +var _ adapter.DNSTransport = (*Transport)(nil) + type Transport struct { - options dns.TransportOptions - router adapter.Router + dns.TransportAdapter + ctx context.Context + dialer N.Dialer + logger logger.ContextLogger networkManager adapter.NetworkManager interfaceName string - autoInterface bool interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback] - transports []dns.Transport + transports []adapter.DNSTransport updateAccess sync.Mutex updatedAt time.Time } -func NewTransport(options dns.TransportOptions) (*Transport, error) { - linkURL, err := url.Parse(options.Address) +func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.DHCPDNSServerOptions) (adapter.DNSTransport, error) { + transportDialer, err := dns.NewLocalDialer(ctx, options.LocalDNSServerOptions) if err != nil { return nil, err } - if linkURL.Host == "" { - return nil, E.New("missing interface name for DHCP") - } - transport := &Transport{ - options: options, - networkManager: service.FromContext[adapter.NetworkManager](options.Context), - interfaceName: linkURL.Host, - autoInterface: linkURL.Host == "auto", - } - return transport, nil + return &Transport{ + TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeDHCP, tag, options.LocalDNSServerOptions), + ctx: ctx, + dialer: transportDialer, + logger: logger, + networkManager: service.FromContext[adapter.NetworkManager](ctx), + interfaceName: options.Interface, + }, nil } -func (t *Transport) Name() string { - return t.options.Name -} - -func (t *Transport) Start() error { +func (t *Transport) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } err := t.fetchServers() if err != nil { return err } - if t.autoInterface { + if t.interfaceName == "" { t.interfaceCallback = t.networkManager.InterfaceMonitor().RegisterCallback(t.interfaceUpdated) } return nil } +func (t *Transport) Close() error { + for _, transport := range t.transports { + transport.Reset() + } + if t.interfaceCallback != nil { + t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback) + } + return nil +} + func (t *Transport) Reset() { for _, transport := range t.transports { transport.Reset() } } -func (t *Transport) Close() error { - for _, transport := range t.transports { - transport.Close() - } - if t.interfaceCallback != nil { - t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback) - } - return nil -} - -func (t *Transport) Raw() bool { - return true -} - func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { err := t.fetchServers() if err != nil { @@ -120,7 +116,7 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, } func (t *Transport) fetchInterface() (*control.Interface, error) { - if t.autoInterface { + if t.interfaceName == "" { if t.networkManager.InterfaceMonitor() == nil { return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface") } @@ -152,8 +148,8 @@ func (t *Transport) updateServers() error { return E.Cause(err, "dhcp: prepare interface") } - t.options.Logger.Info("dhcp: query DNS servers on ", iface.Name) - fetchCtx, cancel := context.WithTimeout(t.options.Context, C.DHCPTimeout) + t.logger.Info("dhcp: query DNS servers on ", iface.Name) + fetchCtx, cancel := context.WithTimeout(t.ctx, C.DHCPTimeout) err = t.fetchServers0(fetchCtx, iface) cancel() if err != nil { @@ -169,7 +165,7 @@ func (t *Transport) updateServers() error { func (t *Transport) interfaceUpdated(defaultInterface *control.Interface, flags int) { err := t.updateServers() if err != nil { - t.options.Logger.Error("update servers: ", err) + t.logger.Error("update servers: ", err) } } @@ -181,7 +177,7 @@ func (t *Transport) fetchServers0(ctx context.Context, iface *control.Interface) if runtime.GOOS == "linux" || runtime.GOOS == "android" { listenAddr = "255.255.255.255:68" } - packetConn, err := listener.ListenPacket(t.options.Context, "udp4", listenAddr) + packetConn, err := listener.ListenPacket(t.ctx, "udp4", listenAddr) if err != nil { return err } @@ -219,17 +215,17 @@ func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn ne dhcpPacket, err := dhcpv4.FromBytes(buffer.Bytes()) if err != nil { - t.options.Logger.Trace("dhcp: parse DHCP response: ", err) + t.logger.Trace("dhcp: parse DHCP response: ", err) return err } if dhcpPacket.MessageType() != dhcpv4.MessageTypeOffer { - t.options.Logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType()) + t.logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType()) continue } if dhcpPacket.TransactionID != transactionID { - t.options.Logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID) + t.logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID) continue } @@ -237,44 +233,27 @@ func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn ne if len(dns) == 0 { return nil } - - var addrs []netip.Addr - for _, ip := range dns { - addr, _ := netip.AddrFromSlice(ip) - addrs = append(addrs, addr.Unmap()) - } - return t.recreateServers(iface, addrs) + return t.recreateServers(iface, common.Map(dns, func(it net.IP) M.Socksaddr { + return M.SocksaddrFrom(M.AddrFromIP(it), 53) + })) } } -func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []netip.Addr) error { +func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []M.Socksaddr) error { if len(serverAddrs) > 0 { - t.options.Logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string { - return it.String() - }), ","), "]") + t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, M.Socksaddr.String), ","), "]") } - serverDialer := common.Must1(dialer.NewDefault(t.options.Context, option.DialerOptions{ + serverDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{ BindInterface: iface.Name, UDPFragmentDefault: true, })) - var transports []dns.Transport + var transports []adapter.DNSTransport for _, serverAddr := range serverAddrs { - newOptions := t.options - newOptions.Address = serverAddr.String() - newOptions.Dialer = serverDialer - serverTransport, err := dns.NewUDPTransport(newOptions) - if err != nil { - return E.Cause(err, "create UDP transport from DHCP result: ", serverAddr) - } - transports = append(transports, serverTransport) + transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, serverAddr)) } for _, transport := range t.transports { - transport.Close() + transport.Reset() } t.transports = transports return nil } - -func (t *Transport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) { - return nil, os.ErrInvalid -} diff --git a/dns/transport/fakeip/fakeip.go b/dns/transport/fakeip/fakeip.go new file mode 100644 index 00000000..07f0fd09 --- /dev/null +++ b/dns/transport/fakeip/fakeip.go @@ -0,0 +1,67 @@ +package fakeip + +import ( + "context" + "net/netip" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + + mDNS "github.com/miekg/dns" +) + +func RegisterTransport(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.FakeIPDNSServerOptions](registry, C.DNSTypeFakeIP, NewTransport) +} + +var _ adapter.FakeIPTransport = (*Transport)(nil) + +type Transport struct { + dns.TransportAdapter + logger logger.ContextLogger + store adapter.FakeIPStore +} + +func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.FakeIPDNSServerOptions) (adapter.DNSTransport, error) { + store := NewStore(ctx, logger, options.Inet4Range.Build(netip.Prefix{}), options.Inet6Range.Build(netip.Prefix{})) + return &Transport{ + TransportAdapter: dns.NewTransportAdapter(C.DNSTypeFakeIP, tag, nil), + logger: logger, + store: store, + }, nil +} + +func (t *Transport) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + return t.store.Start() +} + +func (t *Transport) Close() error { + return t.store.Close() +} + +func (t *Transport) Reset() { +} + +func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + question := message.Question[0] + if question.Qtype != mDNS.TypeA && question.Qtype != mDNS.TypeAAAA { + return nil, E.New("only IP queries are supported by fakeip") + } + address, err := t.store.Create(dns.FqdnToDomain(question.Name), question.Qtype == mDNS.TypeAAAA) + if err != nil { + return nil, err + } + return dns.FixedResponse(message.Id, question, []netip.Addr{address}, C.DefaultDNSTTL), nil +} + +func (t *Transport) Store() adapter.FakeIPStore { + return t.store +} diff --git a/transport/fakeip/memory.go b/dns/transport/fakeip/memory.go similarity index 100% rename from transport/fakeip/memory.go rename to dns/transport/fakeip/memory.go diff --git a/transport/fakeip/store.go b/dns/transport/fakeip/store.go similarity index 100% rename from transport/fakeip/store.go rename to dns/transport/fakeip/store.go diff --git a/dns/transport/hosts/hosts.go b/dns/transport/hosts/hosts.go new file mode 100644 index 00000000..29f6778a --- /dev/null +++ b/dns/transport/hosts/hosts.go @@ -0,0 +1,63 @@ +package hosts + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + + mDNS "github.com/miekg/dns" +) + +func RegisterTransport(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.HostsDNSServerOptions](registry, C.DNSTypeHosts, NewTransport) +} + +var _ adapter.DNSTransport = (*Transport)(nil) + +type Transport struct { + dns.TransportAdapter + files []*File +} + +func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.HostsDNSServerOptions) (adapter.DNSTransport, error) { + var files []*File + if len(options.Path) == 0 { + files = append(files, NewFile(DefaultPath)) + } else { + for _, path := range options.Path { + files = append(files, NewFile(path)) + } + } + return &Transport{ + TransportAdapter: dns.NewTransportAdapter(C.DNSTypeHosts, tag, nil), + files: files, + }, nil +} + +func (t *Transport) Reset() { +} + +func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + question := message.Question[0] + domain := dns.FqdnToDomain(question.Name) + if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA { + for _, file := range t.files { + addresses := file.Lookup(domain) + if len(addresses) > 0 { + return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil + } + } + } + return &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Rcode: mDNS.RcodeNameError, + Response: true, + }, + Question: []mDNS.Question{question}, + }, nil +} diff --git a/dns/transport/hosts/hosts_file.go b/dns/transport/hosts/hosts_file.go new file mode 100644 index 00000000..7ff34f69 --- /dev/null +++ b/dns/transport/hosts/hosts_file.go @@ -0,0 +1,102 @@ +package hosts + +import ( + "bufio" + "errors" + "io" + "net/netip" + "os" + "strings" + "sync" + "time" + + "github.com/miekg/dns" +) + +const cacheMaxAge = 5 * time.Second + +type File struct { + path string + access sync.Mutex + byName map[string][]netip.Addr + expire time.Time + modTime time.Time + size int64 +} + +func NewFile(path string) *File { + return &File{ + path: path, + } +} + +func (f *File) Lookup(name string) []netip.Addr { + f.access.Lock() + defer f.access.Unlock() + f.update() + return f.byName[name] +} + +func (f *File) update() { + now := time.Now() + if now.Before(f.expire) && len(f.byName) > 0 { + return + } + stat, err := os.Stat(f.path) + if err != nil { + return + } + if f.modTime.Equal(stat.ModTime()) && f.size == stat.Size() { + f.expire = now.Add(cacheMaxAge) + return + } + byName := make(map[string][]netip.Addr) + file, err := os.Open(f.path) + if err != nil { + return + } + defer file.Close() + reader := bufio.NewReader(file) + var ( + prefix []byte + line []byte + isPrefix bool + ) + for { + line, isPrefix, err = reader.ReadLine() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return + } + if isPrefix { + prefix = append(prefix, line...) + continue + } else if len(prefix) > 0 { + line = append(prefix, line...) + prefix = nil + } + commentIndex := strings.IndexRune(string(line), '#') + if commentIndex != -1 { + line = line[:commentIndex] + } + fields := strings.Fields(string(line)) + if len(fields) < 2 { + continue + } + var addr netip.Addr + addr, err = netip.ParseAddr(fields[0]) + if err != nil { + continue + } + for index := 1; index < len(fields); index++ { + canonicalName := dns.CanonicalName(fields[index]) + byName[canonicalName] = append(byName[canonicalName], addr) + } + } + f.expire = now.Add(cacheMaxAge) + f.modTime = stat.ModTime() + f.size = stat.Size() + f.byName = byName +} diff --git a/dns/transport/hosts/hosts_test.go b/dns/transport/hosts/hosts_test.go new file mode 100644 index 00000000..944aa437 --- /dev/null +++ b/dns/transport/hosts/hosts_test.go @@ -0,0 +1,16 @@ +package hosts_test + +import ( + "net/netip" + "testing" + + "github.com/sagernet/sing-box/dns/transport/hosts" + + "github.com/stretchr/testify/require" +) + +func TestHosts(t *testing.T) { + t.Parallel() + require.Equal(t, []netip.Addr{netip.AddrFrom4([4]byte{127, 0, 0, 1}), netip.IPv6Loopback()}, hosts.NewFile("testdata/hosts").Lookup("localhost.")) + require.NotEmpty(t, hosts.NewFile(hosts.DefaultPath).Lookup("localhost.")) +} diff --git a/dns/transport/hosts/hosts_unix.go b/dns/transport/hosts/hosts_unix.go new file mode 100644 index 00000000..4caed8b4 --- /dev/null +++ b/dns/transport/hosts/hosts_unix.go @@ -0,0 +1,5 @@ +//go:build !windows + +package hosts + +var DefaultPath = "/etc/hosts" diff --git a/dns/transport/hosts/hosts_windows.go b/dns/transport/hosts/hosts_windows.go new file mode 100644 index 00000000..3144e50d --- /dev/null +++ b/dns/transport/hosts/hosts_windows.go @@ -0,0 +1,17 @@ +package hosts + +import ( + "path/filepath" + + "golang.org/x/sys/windows" +) + +var DefaultPath string + +func init() { + systemDirectory, err := windows.GetSystemDirectory() + if err != nil { + systemDirectory = "C:\\Windows\\System32" + } + DefaultPath = filepath.Join(systemDirectory, "Drivers/etc/hosts") +} diff --git a/dns/transport/hosts/testdata/hosts b/dns/transport/hosts/testdata/hosts new file mode 100644 index 00000000..9ddcc8c1 --- /dev/null +++ b/dns/transport/hosts/testdata/hosts @@ -0,0 +1,2 @@ +127.0.0.1 localhost +::1 localhost diff --git a/dns/transport/https.go b/dns/transport/https.go new file mode 100644 index 00000000..1cfb2574 --- /dev/null +++ b/dns/transport/https.go @@ -0,0 +1,204 @@ +package transport + +import ( + "bytes" + "context" + "io" + "net" + "net/http" + "net/url" + "strconv" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + aTLS "github.com/sagernet/sing/common/tls" + sHTTP "github.com/sagernet/sing/protocol/http" + + mDNS "github.com/miekg/dns" + "golang.org/x/net/http2" +) + +const MimeType = "application/dns-message" + +var _ adapter.DNSTransport = (*HTTPSTransport)(nil) + +func RegisterHTTPS(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.RemoteHTTPSDNSServerOptions](registry, C.DNSTypeHTTPS, NewHTTPS) +} + +type HTTPSTransport struct { + dns.TransportAdapter + logger logger.ContextLogger + dialer N.Dialer + destination *url.URL + headers http.Header + transport *http.Transport +} + +func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) { + transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions) + if err != nil { + return nil, err + } + tlsOptions := common.PtrValueOrDefault(options.TLS) + tlsOptions.Enabled = true + tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions) + if err != nil { + return nil, err + } + if common.Error(tlsConfig.Config()) == nil && !common.Contains(tlsConfig.NextProtos(), http2.NextProtoTLS) { + tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), http2.NextProtoTLS)) + } + if !common.Contains(tlsConfig.NextProtos(), "http/1.1") { + tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), "http/1.1")) + } + headers := options.Headers.Build() + host := headers.Get("Host") + if host != "" { + headers.Del("Host") + } else { + if tlsConfig.ServerName() != "" { + host = tlsConfig.ServerName() + } else { + host = options.Server + } + } + destinationURL := url.URL{ + Scheme: "https", + Host: host, + } + if destinationURL.Host == "" { + destinationURL.Host = options.Server + } + if options.ServerPort != 0 && options.ServerPort != 443 { + destinationURL.Host = net.JoinHostPort(destinationURL.Host, strconv.Itoa(int(options.ServerPort))) + } + path := options.Path + if path == "" { + path = "/dns-query" + } + err = sHTTP.URLSetPath(&destinationURL, path) + if err != nil { + return nil, err + } + serverAddr := options.ServerOptions.Build() + if serverAddr.Port == 0 { + serverAddr.Port = 443 + } + return NewHTTPSRaw( + dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTPS, tag, options.RemoteDNSServerOptions), + logger, + transportDialer, + &destinationURL, + headers, + serverAddr, + tlsConfig, + ), nil +} + +func NewHTTPSRaw( + adapter dns.TransportAdapter, + logger log.ContextLogger, + dialer N.Dialer, + destination *url.URL, + headers http.Header, + serverAddr M.Socksaddr, + tlsConfig tls.Config, +) *HTTPSTransport { + var transport *http.Transport + if tlsConfig != nil { + transport = &http.Transport{ + ForceAttemptHTTP2: true, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + tcpConn, hErr := dialer.DialContext(ctx, network, serverAddr) + if hErr != nil { + return nil, hErr + } + tlsConn, hErr := aTLS.ClientHandshake(ctx, tcpConn, tlsConfig) + if hErr != nil { + tcpConn.Close() + return nil, hErr + } + return tlsConn, nil + }, + } + } else { + transport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, serverAddr) + }, + } + } + return &HTTPSTransport{ + TransportAdapter: adapter, + logger: logger, + dialer: dialer, + destination: destination, + headers: headers, + transport: transport, + } +} + +func (t *HTTPSTransport) Reset() { + t.transport.CloseIdleConnections() + t.transport = t.transport.Clone() +} + +func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + exMessage := *message + exMessage.Id = 0 + exMessage.Compress = true + requestBuffer := buf.NewSize(1 + message.Len()) + rawMessage, err := exMessage.PackBuffer(requestBuffer.FreeBytes()) + if err != nil { + requestBuffer.Release() + return nil, err + } + request, err := http.NewRequestWithContext(ctx, http.MethodPost, t.destination.String(), bytes.NewReader(rawMessage)) + if err != nil { + requestBuffer.Release() + return nil, err + } + request.Header = t.headers.Clone() + request.Header.Set("Content-Type", MimeType) + request.Header.Set("Accept", MimeType) + response, err := t.transport.RoundTrip(request) + requestBuffer.Release() + if err != nil { + return nil, err + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return nil, E.New("unexpected status: ", response.Status) + } + var responseMessage mDNS.Msg + if response.ContentLength > 0 { + responseBuffer := buf.NewSize(int(response.ContentLength)) + _, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength)) + if err != nil { + return nil, err + } + err = responseMessage.Unpack(responseBuffer.Bytes()) + responseBuffer.Release() + } else { + rawMessage, err = io.ReadAll(response.Body) + if err != nil { + return nil, err + } + err = responseMessage.Unpack(rawMessage) + } + if err != nil { + return nil, err + } + return &responseMessage, nil +} diff --git a/dns/transport/local/local.go b/dns/transport/local/local.go new file mode 100644 index 00000000..e5e8aef9 --- /dev/null +++ b/dns/transport/local/local.go @@ -0,0 +1,197 @@ +package local + +import ( + "context" + "math/rand" + "time" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport/hosts" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + mDNS "github.com/miekg/dns" +) + +func RegisterTransport(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.LocalDNSServerOptions](registry, C.DNSTypeLocal, NewTransport) +} + +var _ adapter.DNSTransport = (*Transport)(nil) + +type Transport struct { + dns.TransportAdapter + hosts *hosts.File + dialer N.Dialer +} + +func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) { + transportDialer, err := dns.NewLocalDialer(ctx, options) + if err != nil { + return nil, err + } + return &Transport{ + TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options), + hosts: hosts.NewFile(hosts.DefaultPath), + dialer: transportDialer, + }, nil +} + +func (t *Transport) Reset() { +} + +func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + question := message.Question[0] + domain := dns.FqdnToDomain(question.Name) + if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA { + addresses := t.hosts.Lookup(domain) + if len(addresses) > 0 { + return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil + } + } + systemConfig := getSystemDNSConfig() + if systemConfig.singleRequest || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) { + return t.exchangeSingleRequest(ctx, systemConfig, message, domain) + } else { + return t.exchangeParallel(ctx, systemConfig, message, domain) + } +} + +func (t *Transport) exchangeSingleRequest(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { + var lastErr error + for _, fqdn := range systemConfig.nameList(domain) { + response, err := t.tryOneName(ctx, systemConfig, fqdn, message) + if err != nil { + lastErr = err + continue + } + return response, nil + } + return nil, lastErr +} + +func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { + returned := make(chan struct{}) + defer close(returned) + type queryResult struct { + response *mDNS.Msg + err error + } + results := make(chan queryResult) + startRacer := func(ctx context.Context, fqdn string) { + response, err := t.tryOneName(ctx, systemConfig, fqdn, message) + if err == nil { + addresses, _ := dns.MessageToAddresses(response) + if len(addresses) == 0 { + err = E.New(fqdn, ": empty result") + } + } + select { + case results <- queryResult{response, err}: + case <-returned: + } + } + queryCtx, queryCancel := context.WithCancel(ctx) + defer queryCancel() + var nameCount int + for _, fqdn := range systemConfig.nameList(domain) { + nameCount++ + go startRacer(queryCtx, fqdn) + } + var errors []error + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-results: + if result.err == nil { + return result.response, nil + } + errors = append(errors, result.err) + if len(errors) == nameCount { + return nil, E.Errors(errors...) + } + } + } +} + +func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) { + serverOffset := config.serverOffset() + sLen := uint32(len(config.servers)) + var lastErr error + for i := 0; i < config.attempts; i++ { + for j := uint32(0); j < sLen; j++ { + server := config.servers[(serverOffset+j)%sLen] + question := message.Question[0] + question.Name = fqdn + response, err := t.exchangeOne(ctx, M.ParseSocksaddr(server), question, config.timeout, config.useTCP, config.trustAD) + if err != nil { + lastErr = err + continue + } + return response, nil + } + } + return nil, E.Cause(lastErr, fqdn) +} + +func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) { + var networks []string + if useTCP { + networks = []string{N.NetworkTCP} + } else { + networks = []string{N.NetworkUDP, N.NetworkTCP} + } + request := &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: uint16(rand.Uint32()), + RecursionDesired: true, + AuthenticatedData: ad, + }, + Question: []mDNS.Question{question}, + Compress: true, + } + request.SetEdns0(maxDNSPacketSize, false) + buffer := buf.Get(buf.UDPBufferSize) + defer buf.Put(buffer) + for _, network := range networks { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + conn, err := t.dialer.DialContext(ctx, network, server) + if err != nil { + return nil, err + } + defer conn.Close() + if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { + conn.SetDeadline(deadline) + } + rawMessage, err := request.PackBuffer(buffer) + if err != nil { + return nil, E.Cause(err, "pack request") + } + _, err = conn.Write(rawMessage) + if err != nil { + return nil, E.Cause(err, "write request") + } + n, err := conn.Read(buffer) + if err != nil { + return nil, E.Cause(err, "read response") + } + var response mDNS.Msg + err = response.Unpack(buffer[:n]) + if err != nil { + return nil, E.Cause(err, "unpack response") + } + if response.Truncated && network == N.NetworkUDP { + continue + } + return &response, nil + } + panic("unexpected") +} diff --git a/dns/transport/local/resolv.go b/dns/transport/local/resolv.go new file mode 100644 index 00000000..5484d0ec --- /dev/null +++ b/dns/transport/local/resolv.go @@ -0,0 +1,146 @@ +package local + +import ( + "os" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" +) + +const ( + // net.maxDNSPacketSize + maxDNSPacketSize = 1232 +) + +type resolverConfig struct { + initOnce sync.Once + ch chan struct{} + lastChecked time.Time + dnsConfig atomic.Pointer[dnsConfig] +} + +var resolvConf resolverConfig + +func getSystemDNSConfig() *dnsConfig { + resolvConf.tryUpdate("/etc/resolv.conf") + return resolvConf.dnsConfig.Load() +} + +func (conf *resolverConfig) init() { + conf.dnsConfig.Store(dnsReadConfig("/etc/resolv.conf")) + conf.lastChecked = time.Now() + conf.ch = make(chan struct{}, 1) +} + +func (conf *resolverConfig) tryUpdate(name string) { + conf.initOnce.Do(conf.init) + + if conf.dnsConfig.Load().noReload { + return + } + if !conf.tryAcquireSema() { + return + } + defer conf.releaseSema() + + now := time.Now() + if conf.lastChecked.After(now.Add(-5 * time.Second)) { + return + } + conf.lastChecked = now + if runtime.GOOS != "windows" { + var mtime time.Time + if fi, err := os.Stat(name); err == nil { + mtime = fi.ModTime() + } + if mtime.Equal(conf.dnsConfig.Load().mtime) { + return + } + } + dnsConf := dnsReadConfig(name) + conf.dnsConfig.Store(dnsConf) +} + +func (conf *resolverConfig) tryAcquireSema() bool { + select { + case conf.ch <- struct{}{}: + return true + default: + return false + } +} + +func (conf *resolverConfig) releaseSema() { + <-conf.ch +} + +type dnsConfig struct { + servers []string + search []string + ndots int + timeout time.Duration + attempts int + rotate bool + unknownOpt bool + lookup []string + err error + mtime time.Time + soffset uint32 + singleRequest bool + useTCP bool + trustAD bool + noReload bool +} + +func (c *dnsConfig) serverOffset() uint32 { + if c.rotate { + return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start + } + return 0 +} + +func (conf *dnsConfig) nameList(name string) []string { + l := len(name) + rooted := l > 0 && name[l-1] == '.' + if l > 254 || l == 254 && !rooted { + return nil + } + + if rooted { + if avoidDNS(name) { + return nil + } + return []string{name} + } + + hasNdots := strings.Count(name, ".") >= conf.ndots + name += "." + // l++ + + names := make([]string, 0, 1+len(conf.search)) + if hasNdots && !avoidDNS(name) { + names = append(names, name) + } + for _, suffix := range conf.search { + fqdn := name + suffix + if !avoidDNS(fqdn) && len(fqdn) <= 254 { + names = append(names, fqdn) + } + } + if !hasNdots && !avoidDNS(name) { + names = append(names, name) + } + return names +} + +func avoidDNS(name string) bool { + if name == "" { + return true + } + if name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return strings.HasSuffix(name, ".onion") +} diff --git a/dns/transport/local/resolv_unix.go b/dns/transport/local/resolv_unix.go new file mode 100644 index 00000000..6594ae41 --- /dev/null +++ b/dns/transport/local/resolv_unix.go @@ -0,0 +1,175 @@ +//go:build !windows + +package local + +import ( + "bufio" + "net" + "net/netip" + "os" + "strings" + "time" + _ "unsafe" +) + +func dnsReadConfig(name string) *dnsConfig { + conf := &dnsConfig{ + ndots: 1, + timeout: 5 * time.Second, + attempts: 2, + } + file, err := os.Open(name) + if err != nil { + conf.servers = defaultNS + conf.search = dnsDefaultSearch() + conf.err = err + return conf + } + defer file.Close() + fi, err := file.Stat() + if err == nil { + conf.mtime = fi.ModTime() + } else { + conf.servers = defaultNS + conf.search = dnsDefaultSearch() + conf.err = err + return conf + } + reader := bufio.NewReader(file) + var ( + prefix []byte + line []byte + isPrefix bool + ) + for { + line, isPrefix, err = reader.ReadLine() + if err != nil { + break + } + if isPrefix { + prefix = append(prefix, line...) + continue + } else if len(prefix) > 0 { + line = append(prefix, line...) + prefix = nil + } + if len(line) > 0 && (line[0] == ';' || line[0] == '#') { + continue + } + f := strings.Fields(string(line)) + if len(f) < 1 { + continue + } + switch f[0] { + case "nameserver": + if len(f) > 1 && len(conf.servers) < 3 { + if _, err := netip.ParseAddr(f[1]); err == nil { + conf.servers = append(conf.servers, net.JoinHostPort(f[1], "53")) + } + } + case "domain": + if len(f) > 1 { + conf.search = []string{ensureRooted(f[1])} + } + + case "search": + conf.search = make([]string, 0, len(f)-1) + for i := 1; i < len(f); i++ { + name := ensureRooted(f[i]) + if name == "." { + continue + } + conf.search = append(conf.search, name) + } + + case "options": + for _, s := range f[1:] { + switch { + case strings.HasPrefix(s, "ndots:"): + n, _, _ := dtoi(s[6:]) + if n < 0 { + n = 0 + } else if n > 15 { + n = 15 + } + conf.ndots = n + case strings.HasPrefix(s, "timeout:"): + n, _, _ := dtoi(s[8:]) + if n < 1 { + n = 1 + } + conf.timeout = time.Duration(n) * time.Second + case strings.HasPrefix(s, "attempts:"): + n, _, _ := dtoi(s[9:]) + if n < 1 { + n = 1 + } + conf.attempts = n + case s == "rotate": + conf.rotate = true + case s == "single-request" || s == "single-request-reopen": + conf.singleRequest = true + case s == "use-vc" || s == "usevc" || s == "tcp": + conf.useTCP = true + case s == "trust-ad": + conf.trustAD = true + case s == "edns0": + case s == "no-reload": + conf.noReload = true + default: + conf.unknownOpt = true + } + } + + case "lookup": + conf.lookup = f[1:] + + default: + conf.unknownOpt = true + } + } + if len(conf.servers) == 0 { + conf.servers = defaultNS + } + if len(conf.search) == 0 { + conf.search = dnsDefaultSearch() + } + return conf +} + +//go:linkname defaultNS net.defaultNS +var defaultNS []string + +func dnsDefaultSearch() []string { + hn, err := os.Hostname() + if err != nil { + return nil + } + if i := strings.IndexRune(hn, '.'); i >= 0 && i < len(hn)-1 { + return []string{ensureRooted(hn[i+1:])} + } + return nil +} + +func ensureRooted(s string) string { + if len(s) > 0 && s[len(s)-1] == '.' { + return s + } + return s + "." +} + +const big = 0xFFFFFF + +func dtoi(s string) (n int, i int, ok bool) { + n = 0 + for i = 0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { + n = n*10 + int(s[i]-'0') + if n >= big { + return big, i, false + } + } + if i == 0 { + return 0, 0, false + } + return n, i, true +} diff --git a/dns/transport/local/resolv_windows.go b/dns/transport/local/resolv_windows.go new file mode 100644 index 00000000..577e7a12 --- /dev/null +++ b/dns/transport/local/resolv_windows.go @@ -0,0 +1,100 @@ +package local + +import ( + "net" + "net/netip" + "os" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +func dnsReadConfig(_ string) *dnsConfig { + conf := &dnsConfig{ + ndots: 1, + timeout: 5 * time.Second, + attempts: 2, + } + defer func() { + if len(conf.servers) == 0 { + conf.servers = defaultNS + } + }() + aas, err := adapterAddresses() + if err != nil { + return nil + } + + for _, aa := range aas { + // Only take interfaces whose OperStatus is IfOperStatusUp(0x01) into DNS configs. + if aa.OperStatus != windows.IfOperStatusUp { + continue + } + + // Only take interfaces which have at least one gateway + if aa.FirstGatewayAddress == nil { + continue + } + + for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next { + sa, err := dns.Address.Sockaddr.Sockaddr() + if err != nil { + continue + } + var ip netip.Addr + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + ip = netip.AddrFrom4([4]byte{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}) + case *syscall.SockaddrInet6: + var addr16 [16]byte + copy(addr16[:], sa.Addr[:]) + if addr16[0] == 0xfe && addr16[1] == 0xc0 { + // fec0/10 IPv6 addresses are site local anycast DNS + // addresses Microsoft sets by default if no other + // IPv6 DNS address is set. Site local anycast is + // deprecated since 2004, see + // https://datatracker.ietf.org/doc/html/rfc3879 + continue + } + ip = netip.AddrFrom16(addr16) + default: + // Unexpected type. + continue + } + conf.servers = append(conf.servers, net.JoinHostPort(ip.String(), "53")) + } + } + return conf +} + +//go:linkname defaultNS net.defaultNS +var defaultNS []string + +func adapterAddresses() ([]*windows.IpAdapterAddresses, error) { + var b []byte + l := uint32(15000) // recommended initial size + for { + b = make([]byte, l) + const flags = windows.GAA_FLAG_INCLUDE_PREFIX | windows.GAA_FLAG_INCLUDE_GATEWAYS + err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, flags, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l) + if err == nil { + if l == 0 { + return nil, nil + } + break + } + if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW { + return nil, os.NewSyscallError("getadaptersaddresses", err) + } + if l <= uint32(len(b)) { + return nil, os.NewSyscallError("getadaptersaddresses", err) + } + } + var aas []*windows.IpAdapterAddresses + for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next { + aas = append(aas, aa) + } + return aas, nil +} diff --git a/dns/transport/predefined.go b/dns/transport/predefined.go new file mode 100644 index 00000000..3f112886 --- /dev/null +++ b/dns/transport/predefined.go @@ -0,0 +1,83 @@ +package transport + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + + mDNS "github.com/miekg/dns" +) + +var _ adapter.DNSTransport = (*PredefinedTransport)(nil) + +func RegisterPredefined(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.PredefinedDNSServerOptions](registry, C.DNSTypePreDefined, NewPredefined) +} + +type PredefinedTransport struct { + dns.TransportAdapter + responses []*predefinedResponse +} + +type predefinedResponse struct { + questions []mDNS.Question + answer *mDNS.Msg +} + +func NewPredefined(ctx context.Context, logger log.ContextLogger, tag string, options option.PredefinedDNSServerOptions) (adapter.DNSTransport, error) { + var responses []*predefinedResponse + for _, response := range options.Responses { + questions, msg, err := response.Build() + if err != nil { + return nil, err + } + responses = append(responses, &predefinedResponse{ + questions: questions, + answer: msg, + }) + } + if len(responses) == 0 { + return nil, E.New("empty predefined responses") + } + return &PredefinedTransport{ + TransportAdapter: dns.NewTransportAdapter(C.DNSTypePreDefined, tag, nil), + responses: responses, + }, nil +} + +func (t *PredefinedTransport) Reset() { +} + +func (t *PredefinedTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + for _, response := range t.responses { + for _, question := range response.questions { + if func() bool { + if question.Name == "" && question.Qtype == mDNS.TypeNone { + return true + } else if question.Name == "" { + return common.Any(message.Question, func(it mDNS.Question) bool { + return it.Qtype == question.Qtype + }) + } else if question.Qtype == mDNS.TypeNone { + return common.Any(message.Question, func(it mDNS.Question) bool { + return it.Name == question.Name + }) + } else { + return common.Contains(message.Question, question) + } + }() { + copyAnswer := *response.answer + copyAnswer.Id = message.Id + copyAnswer.Question = message.Question + return ©Answer, nil + } + } + } + return nil, dns.RCodeNameError +} diff --git a/dns/transport/quic/http3.go b/dns/transport/quic/http3.go new file mode 100644 index 00000000..a5181ae0 --- /dev/null +++ b/dns/transport/quic/http3.go @@ -0,0 +1,167 @@ +package quic + +import ( + "bytes" + "context" + "io" + "net" + "net/http" + "net/url" + "strconv" + + "github.com/sagernet/quic-go" + "github.com/sagernet/quic-go/http3" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + sHTTP "github.com/sagernet/sing/protocol/http" + + mDNS "github.com/miekg/dns" +) + +var _ adapter.DNSTransport = (*HTTP3Transport)(nil) + +func RegisterHTTP3Transport(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.RemoteHTTPSDNSServerOptions](registry, C.DNSTypeHTTP3, NewHTTP3) +} + +type HTTP3Transport struct { + dns.TransportAdapter + logger logger.ContextLogger + dialer N.Dialer + destination *url.URL + headers http.Header + transport *http3.Transport +} + +func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) { + transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions) + if err != nil { + return nil, err + } + tlsOptions := common.PtrValueOrDefault(options.TLS) + tlsOptions.Enabled = true + tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions) + if err != nil { + return nil, err + } + stdConfig, err := tlsConfig.Config() + if err != nil { + return nil, err + } + headers := options.Headers.Build() + host := headers.Get("Host") + if host != "" { + headers.Del("Host") + } else { + if tlsConfig.ServerName() != "" { + host = tlsConfig.ServerName() + } else { + host = options.Server + } + } + destinationURL := url.URL{ + Scheme: "https", + Host: host, + } + if destinationURL.Host == "" { + destinationURL.Host = options.Server + } + if options.ServerPort != 0 && options.ServerPort != 443 { + destinationURL.Host = net.JoinHostPort(destinationURL.Host, strconv.Itoa(int(options.ServerPort))) + } + path := options.Path + if path == "" { + path = "/dns-query" + } + err = sHTTP.URLSetPath(&destinationURL, path) + if err != nil { + return nil, err + } + serverAddr := options.ServerOptions.Build() + if serverAddr.Port == 0 { + serverAddr.Port = 443 + } + return &HTTP3Transport{ + TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTP3, tag, options.RemoteDNSServerOptions), + logger: logger, + dialer: transportDialer, + destination: &destinationURL, + headers: headers, + transport: &http3.Transport{ + Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (quic.EarlyConnection, error) { + destinationAddr := M.ParseSocksaddr(addr) + conn, dialErr := transportDialer.DialContext(ctx, N.NetworkUDP, destinationAddr) + if dialErr != nil { + return nil, dialErr + } + return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg) + }, + TLSClientConfig: stdConfig, + }, + }, nil +} + +func (t *HTTP3Transport) Reset() { + t.transport.Close() +} + +func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + exMessage := *message + exMessage.Id = 0 + exMessage.Compress = true + requestBuffer := buf.NewSize(1 + message.Len()) + rawMessage, err := exMessage.PackBuffer(requestBuffer.FreeBytes()) + if err != nil { + requestBuffer.Release() + return nil, err + } + request, err := http.NewRequestWithContext(ctx, http.MethodPost, t.destination.String(), bytes.NewReader(rawMessage)) + if err != nil { + requestBuffer.Release() + return nil, err + } + request.Header = t.headers.Clone() + request.Header.Set("Content-Type", transport.MimeType) + request.Header.Set("Accept", transport.MimeType) + response, err := t.transport.RoundTrip(request) + requestBuffer.Release() + if err != nil { + return nil, err + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return nil, E.New("unexpected status: ", response.Status) + } + var responseMessage mDNS.Msg + if response.ContentLength > 0 { + responseBuffer := buf.NewSize(int(response.ContentLength)) + _, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength)) + if err != nil { + return nil, err + } + err = responseMessage.Unpack(responseBuffer.Bytes()) + responseBuffer.Release() + } else { + rawMessage, err = io.ReadAll(response.Body) + if err != nil { + return nil, err + } + err = responseMessage.Unpack(rawMessage) + } + if err != nil { + return nil, err + } + return &responseMessage, nil +} diff --git a/dns/transport/quic/quic.go b/dns/transport/quic/quic.go new file mode 100644 index 00000000..d3844c2b --- /dev/null +++ b/dns/transport/quic/quic.go @@ -0,0 +1,174 @@ +package quic + +import ( + "context" + "errors" + "sync" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + sQUIC "github.com/sagernet/sing-quic" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + mDNS "github.com/miekg/dns" +) + +var _ adapter.DNSTransport = (*Transport)(nil) + +func RegisterTransport(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeQUIC, NewQUIC) +} + +type Transport struct { + dns.TransportAdapter + ctx context.Context + logger logger.ContextLogger + dialer N.Dialer + serverAddr M.Socksaddr + tlsConfig tls.Config + access sync.Mutex + connection quic.EarlyConnection +} + +func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) { + transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions) + if err != nil { + return nil, err + } + tlsOptions := common.PtrValueOrDefault(options.TLS) + tlsOptions.Enabled = true + tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions) + if err != nil { + return nil, err + } + if len(tlsConfig.NextProtos()) == 0 { + tlsConfig.SetNextProtos([]string{"doq"}) + } + serverAddr := options.ServerOptions.Build() + if serverAddr.Port == 0 { + serverAddr.Port = 853 + } + return &Transport{ + TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions), + ctx: ctx, + logger: logger, + dialer: transportDialer, + serverAddr: serverAddr, + tlsConfig: tlsConfig, + }, nil +} + +func (t *Transport) Reset() { + t.access.Lock() + defer t.access.Unlock() + connection := t.connection + if connection != nil { + connection.CloseWithError(0, "") + } +} + +func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + var ( + conn quic.Connection + err error + response *mDNS.Msg + ) + for i := 0; i < 2; i++ { + conn, err = t.openConnection() + if err != nil { + return nil, err + } + response, err = t.exchange(ctx, message, conn) + if err == nil { + return response, nil + } else if !isQUICRetryError(err) { + return nil, err + } else { + conn.CloseWithError(quic.ApplicationErrorCode(0), "") + continue + } + } + return nil, err +} + +func (t *Transport) openConnection() (quic.EarlyConnection, error) { + connection := t.connection + if connection != nil && !common.Done(connection.Context()) { + return connection, nil + } + t.access.Lock() + defer t.access.Unlock() + connection = t.connection + if connection != nil && !common.Done(connection.Context()) { + return connection, nil + } + conn, err := t.dialer.DialContext(t.ctx, N.NetworkUDP, t.serverAddr) + if err != nil { + return nil, err + } + earlyConnection, err := sQUIC.DialEarly( + t.ctx, + bufio.NewUnbindPacketConn(conn), + t.serverAddr.UDPAddr(), + t.tlsConfig, + nil, + ) + if err != nil { + return nil, err + } + t.connection = earlyConnection + return earlyConnection, nil +} + +func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, conn quic.Connection) (*mDNS.Msg, error) { + stream, err := conn.OpenStreamSync(ctx) + if err != nil { + return nil, err + } + defer stream.Close() + defer stream.CancelRead(0) + err = transport.WriteMessage(stream, 0, message) + if err != nil { + return nil, err + } + return transport.ReadMessage(stream) +} + +// https://github.com/AdguardTeam/dnsproxy/blob/fd1868577652c639cce3da00e12ca548f421baf1/upstream/upstream_quic.go#L394 +func isQUICRetryError(err error) (ok bool) { + var qAppErr *quic.ApplicationError + if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 { + return true + } + + var qIdleErr *quic.IdleTimeoutError + if errors.As(err, &qIdleErr) { + return true + } + + var resetErr *quic.StatelessResetError + if errors.As(err, &resetErr) { + return true + } + + var qTransportError *quic.TransportError + if errors.As(err, &qTransportError) && qTransportError.ErrorCode == quic.NoError { + return true + } + + if errors.Is(err, quic.Err0RTTRejected) { + return true + } + + return false +} diff --git a/dns/transport/tcp.go b/dns/transport/tcp.go new file mode 100644 index 00000000..6061585e --- /dev/null +++ b/dns/transport/tcp.go @@ -0,0 +1,99 @@ +package transport + +import ( + "context" + "encoding/binary" + "io" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + mDNS "github.com/miekg/dns" +) + +var _ adapter.DNSTransport = (*TCPTransport)(nil) + +func RegisterTCP(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeTCP, NewTCP) +} + +type TCPTransport struct { + dns.TransportAdapter + dialer N.Dialer + serverAddr M.Socksaddr +} + +func NewTCP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) { + transportDialer, err := dns.NewRemoteDialer(ctx, options) + if err != nil { + return nil, err + } + serverAddr := options.ServerOptions.Build() + if serverAddr.Port == 0 { + serverAddr.Port = 53 + } + return &TCPTransport{ + TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTCP, tag, options), + dialer: transportDialer, + serverAddr: serverAddr, + }, nil +} + +func (t *TCPTransport) Reset() { +} + +func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr) + if err != nil { + return nil, err + } + defer conn.Close() + err = WriteMessage(conn, 0, message) + if err != nil { + return nil, err + } + return ReadMessage(conn) +} + +func ReadMessage(reader io.Reader) (*mDNS.Msg, error) { + var responseLen uint16 + err := binary.Read(reader, binary.BigEndian, &responseLen) + if err != nil { + return nil, err + } + if responseLen < 10 { + return nil, mDNS.ErrShortRead + } + buffer := buf.NewSize(int(responseLen)) + defer buffer.Release() + _, err = buffer.ReadFullFrom(reader, int(responseLen)) + if err != nil { + return nil, err + } + var message mDNS.Msg + err = message.Unpack(buffer.Bytes()) + return &message, err +} + +func WriteMessage(writer io.Writer, messageId uint16, message *mDNS.Msg) error { + requestLen := message.Len() + buffer := buf.NewSize(3 + requestLen) + defer buffer.Release() + common.Must(binary.Write(buffer, binary.BigEndian, uint16(requestLen))) + exMessage := *message + exMessage.Id = messageId + exMessage.Compress = true + rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes()) + if err != nil { + return err + } + buffer.Truncate(2 + len(rawMessage)) + return common.Error(writer.Write(buffer.Bytes())) +} diff --git a/dns/transport/tls.go b/dns/transport/tls.go new file mode 100644 index 00000000..28fa885a --- /dev/null +++ b/dns/transport/tls.go @@ -0,0 +1,115 @@ +package transport + +import ( + "context" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/x/list" + + mDNS "github.com/miekg/dns" +) + +var _ adapter.DNSTransport = (*TLSTransport)(nil) + +func RegisterTLS(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS) +} + +type TLSTransport struct { + dns.TransportAdapter + logger logger.ContextLogger + dialer N.Dialer + serverAddr M.Socksaddr + tlsConfig tls.Config + access sync.Mutex + connections list.List[*tlsDNSConn] +} + +type tlsDNSConn struct { + tls.Conn + queryId uint16 +} + +func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) { + transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions) + if err != nil { + return nil, err + } + tlsOptions := common.PtrValueOrDefault(options.TLS) + tlsOptions.Enabled = true + tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions) + if err != nil { + return nil, err + } + serverAddr := options.ServerOptions.Build() + if serverAddr.Port == 0 { + serverAddr.Port = 853 + } + return &TLSTransport{ + TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTLS, tag, options.RemoteDNSServerOptions), + logger: logger, + dialer: transportDialer, + serverAddr: serverAddr, + tlsConfig: tlsConfig, + }, nil +} + +func (t *TLSTransport) Reset() { + t.access.Lock() + defer t.access.Unlock() + for connection := t.connections.Front(); connection != nil; connection = connection.Next() { + connection.Value.Close() + } + t.connections.Init() +} + +func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + t.access.Lock() + conn := t.connections.PopFront() + t.access.Unlock() + if conn != nil { + response, err := t.exchange(message, conn) + if err == nil { + return response, nil + } + } + tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr) + if err != nil { + return nil, err + } + tlsConn, err := tls.ClientHandshake(ctx, tcpConn, t.tlsConfig) + if err != nil { + tcpConn.Close() + return nil, err + } + return t.exchange(message, &tlsDNSConn{Conn: tlsConn}) +} + +func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) { + conn.queryId++ + err := WriteMessage(conn, conn.queryId, message) + if err != nil { + conn.Close() + return nil, E.Cause(err, "write request") + } + response, err := ReadMessage(conn) + if err != nil { + conn.Close() + return nil, E.Cause(err, "read response") + } + t.access.Lock() + t.connections.PushBack(conn) + t.access.Unlock() + return response, nil +} diff --git a/dns/transport/udp.go b/dns/transport/udp.go new file mode 100644 index 00000000..5099c6f6 --- /dev/null +++ b/dns/transport/udp.go @@ -0,0 +1,223 @@ +package transport + +import ( + "context" + "net" + "os" + "sync" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + mDNS "github.com/miekg/dns" +) + +var _ adapter.DNSTransport = (*UDPTransport)(nil) + +func RegisterUDP(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeUDP, NewUDP) +} + +type UDPTransport struct { + dns.TransportAdapter + logger logger.ContextLogger + dialer N.Dialer + serverAddr M.Socksaddr + udpSize int + tcpTransport *TCPTransport + access sync.Mutex + conn *dnsConnection + done chan struct{} +} + +func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) { + transportDialer, err := dns.NewRemoteDialer(ctx, options) + if err != nil { + return nil, err + } + serverAddr := options.ServerOptions.Build() + if serverAddr.Port == 0 { + serverAddr.Port = 53 + } + return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil +} + +func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr) *UDPTransport { + return &UDPTransport{ + TransportAdapter: adapter, + logger: logger, + dialer: dialer, + serverAddr: serverAddr, + udpSize: 512, + tcpTransport: &TCPTransport{ + dialer: dialer, + serverAddr: serverAddr, + }, + done: make(chan struct{}), + } +} + +func (t *UDPTransport) Reset() { + t.access.Lock() + defer t.access.Unlock() + close(t.done) + t.done = make(chan struct{}) +} + +func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + response, err := t.exchange(ctx, message) + if err != nil { + return nil, err + } + if response.Truncated { + t.logger.InfoContext(ctx, "response truncated, retrying with TCP") + return t.tcpTransport.Exchange(ctx, message) + } + return response, nil +} + +func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + conn, err := t.open(ctx) + if err != nil { + return nil, err + } + if edns0Opt := message.IsEdns0(); edns0Opt != nil { + if udpSize := int(edns0Opt.UDPSize()); udpSize > t.udpSize { + t.udpSize = udpSize + } + } + buffer := buf.NewSize(1 + message.Len()) + defer buffer.Release() + exMessage := *message + exMessage.Compress = true + messageId := message.Id + callback := &dnsCallback{ + done: make(chan struct{}), + } + conn.access.Lock() + conn.queryId++ + exMessage.Id = conn.queryId + conn.callbacks[exMessage.Id] = callback + conn.access.Unlock() + defer func() { + conn.access.Lock() + delete(conn.callbacks, messageId) + conn.access.Unlock() + callback.access.Lock() + select { + case <-callback.done: + default: + close(callback.done) + } + callback.access.Unlock() + }() + rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes()) + if err != nil { + return nil, err + } + _, err = conn.Write(rawMessage) + if err != nil { + conn.Close(err) + return nil, err + } + select { + case <-callback.done: + callback.message.Id = messageId + return callback.message, nil + case <-conn.done: + return nil, conn.err + case <-t.done: + return nil, os.ErrClosed + case <-ctx.Done(): + conn.Close(ctx.Err()) + return nil, ctx.Err() + } +} + +func (t *UDPTransport) open(ctx context.Context) (*dnsConnection, error) { + t.access.Lock() + defer t.access.Unlock() + if t.conn != nil { + select { + case <-t.conn.done: + default: + return t.conn, nil + } + } + conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) + if err != nil { + return nil, err + } + dnsConn := &dnsConnection{ + Conn: conn, + done: make(chan struct{}), + callbacks: make(map[uint16]*dnsCallback), + } + go t.recvLoop(dnsConn) + t.conn = dnsConn + return dnsConn, nil +} + +func (t *UDPTransport) recvLoop(conn *dnsConnection) { + for { + buffer := buf.NewSize(t.udpSize) + _, err := buffer.ReadOnceFrom(conn) + if err != nil { + buffer.Release() + conn.Close(err) + return + } + var message mDNS.Msg + err = message.Unpack(buffer.Bytes()) + buffer.Release() + if err != nil { + conn.Close(err) + return + } + conn.access.RLock() + callback, loaded := conn.callbacks[message.Id] + conn.access.RUnlock() + if !loaded { + continue + } + callback.access.Lock() + select { + case <-callback.done: + default: + callback.message = &message + close(callback.done) + } + callback.access.Unlock() + } +} + +type dnsConnection struct { + net.Conn + access sync.RWMutex + done chan struct{} + closeOnce sync.Once + err error + queryId uint16 + callbacks map[uint16]*dnsCallback +} + +func (c *dnsConnection) Close(err error) { + c.closeOnce.Do(func() { + close(c.done) + c.err = err + }) + c.Conn.Close() +} + +type dnsCallback struct { + access sync.Mutex + message *mDNS.Msg + done chan struct{} +} diff --git a/dns/transport_adapter.go b/dns/transport_adapter.go new file mode 100644 index 00000000..02c84621 --- /dev/null +++ b/dns/transport_adapter.go @@ -0,0 +1,70 @@ +package dns + +import ( + "net/netip" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" +) + +var _ adapter.LegacyDNSTransport = (*TransportAdapter)(nil) + +type TransportAdapter struct { + transportType string + transportTag string + dependencies []string + strategy C.DomainStrategy + clientSubnet netip.Prefix +} + +func NewTransportAdapter(transportType string, transportTag string, dependencies []string) TransportAdapter { + return TransportAdapter{ + transportType: transportType, + transportTag: transportTag, + dependencies: dependencies, + } +} + +func NewTransportAdapterWithLocalOptions(transportType string, transportTag string, localOptions option.LocalDNSServerOptions) TransportAdapter { + return TransportAdapter{ + transportType: transportType, + transportTag: transportTag, + strategy: C.DomainStrategy(localOptions.LegacyStrategy), + clientSubnet: localOptions.LegacyClientSubnet, + } +} + +func NewTransportAdapterWithRemoteOptions(transportType string, transportTag string, remoteOptions option.RemoteDNSServerOptions) TransportAdapter { + var dependencies []string + if remoteOptions.AddressResolver != "" { + dependencies = []string{remoteOptions.AddressResolver} + } + return TransportAdapter{ + transportType: transportType, + transportTag: transportTag, + dependencies: dependencies, + strategy: C.DomainStrategy(remoteOptions.LegacyStrategy), + clientSubnet: remoteOptions.LegacyClientSubnet, + } +} + +func (a *TransportAdapter) Type() string { + return a.transportType +} + +func (a *TransportAdapter) Tag() string { + return a.transportTag +} + +func (a *TransportAdapter) Dependencies() []string { + return a.dependencies +} + +func (a *TransportAdapter) LegacyStrategy() C.DomainStrategy { + return a.strategy +} + +func (a *TransportAdapter) LegacyClientSubnet() netip.Prefix { + return a.clientSubnet +} diff --git a/dns/transport_dialer.go b/dns/transport_dialer.go new file mode 100644 index 00000000..14e1188d --- /dev/null +++ b/dns/transport_dialer.go @@ -0,0 +1,93 @@ +package dns + +import ( + "context" + "net" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" +) + +func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (N.Dialer, error) { + if options.LegacyDefaultDialer { + return dialer.NewDefaultOutbound(ctx), nil + } else { + return dialer.New(ctx, options.DialerOptions) + } +} + +func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) { + transportDialer, err := NewLocalDialer(ctx, options.LocalDNSServerOptions) + if err != nil { + return nil, err + } + if options.AddressResolver != "" { + transport := service.FromContext[adapter.DNSTransportManager](ctx) + resolverTransport, loaded := transport.Transport(options.AddressResolver) + if !loaded { + return nil, E.New("address resolver not found: ", options.AddressResolver) + } + transportDialer = NewTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.AddressStrategy), time.Duration(options.AddressFallbackDelay)) + } else if M.IsDomainName(options.Server) { + return nil, E.New("missing address resolver for server: ", options.Server) + } + return transportDialer, nil +} + +type TransportDialer struct { + dialer N.Dialer + dnsRouter adapter.DNSRouter + transport adapter.DNSTransport + strategy C.DomainStrategy + fallbackDelay time.Duration +} + +func NewTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *TransportDialer { + return &TransportDialer{ + dialer, + dnsRouter, + transport, + strategy, + fallbackDelay, + } +} + +func (d *TransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if destination.IsIP() { + return d.dialer.DialContext(ctx, network, destination) + } + addresses, err := d.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{ + Transport: d.transport, + Strategy: d.strategy, + }) + if err != nil { + return nil, err + } + return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay) +} + +func (d *TransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if destination.IsIP() { + return d.dialer.ListenPacket(ctx, destination) + } + addresses, err := d.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{ + Transport: d.transport, + Strategy: d.strategy, + }) + if err != nil { + return nil, err + } + conn, _, err := N.ListenSerial(ctx, d.dialer, destination, addresses) + return conn, err +} + +func (d *TransportDialer) Upstream() any { + return d.dialer +} diff --git a/dns/transport_manager.go b/dns/transport_manager.go new file mode 100644 index 00000000..4497923b --- /dev/null +++ b/dns/transport_manager.go @@ -0,0 +1,288 @@ +package dns + +import ( + "context" + "io" + "os" + "strings" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/taskmonitor" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" +) + +var _ adapter.DNSTransportManager = (*TransportManager)(nil) + +type TransportManager struct { + logger log.ContextLogger + registry adapter.DNSTransportRegistry + outbound adapter.OutboundManager + defaultTag string + access sync.RWMutex + started bool + stage adapter.StartStage + transports []adapter.DNSTransport + transportByTag map[string]adapter.DNSTransport + dependByTag map[string][]string + defaultTransport adapter.DNSTransport + defaultTransportFallback adapter.DNSTransport + fakeIPTransport adapter.FakeIPTransport +} + +func NewTransportManager(logger logger.ContextLogger, registry adapter.DNSTransportRegistry, outbound adapter.OutboundManager, defaultTag string) *TransportManager { + return &TransportManager{ + logger: logger, + registry: registry, + outbound: outbound, + defaultTag: defaultTag, + transportByTag: make(map[string]adapter.DNSTransport), + dependByTag: make(map[string][]string), + } +} + +func (m *TransportManager) Initialize(defaultTransportFallback adapter.DNSTransport) { + m.defaultTransportFallback = defaultTransportFallback +} + +func (m *TransportManager) Start(stage adapter.StartStage) error { + m.access.Lock() + if m.started && m.stage >= stage { + panic("already started") + } + m.started = true + m.stage = stage + outbounds := m.transports + m.access.Unlock() + if stage == adapter.StartStateStart { + return m.startTransports(m.transports) + } else { + for _, outbound := range outbounds { + err := adapter.LegacyStart(outbound, stage) + if err != nil { + return E.Cause(err, stage, " dns/", outbound.Type(), "[", outbound.Tag(), "]") + } + } + } + return nil +} + +func (m *TransportManager) startTransports(transports []adapter.DNSTransport) error { + monitor := taskmonitor.New(m.logger, C.StartTimeout) + started := make(map[string]bool) + for { + canContinue := false + startOne: + for _, transportToStart := range transports { + transportTag := transportToStart.Tag() + if started[transportTag] { + continue + } + dependencies := transportToStart.Dependencies() + for _, dependency := range dependencies { + if !started[dependency] { + continue startOne + } + } + started[transportTag] = true + canContinue = true + if starter, isStarter := transportToStart.(adapter.Lifecycle); isStarter { + monitor.Start("start dns/", transportToStart.Type(), "[", transportTag, "]") + err := starter.Start(adapter.StartStateStart) + monitor.Finish() + if err != nil { + return E.Cause(err, "start dns/", transportToStart.Type(), "[", transportTag, "]") + } + } + } + if len(started) == len(transports) { + break + } + if canContinue { + continue + } + currentTransport := common.Find(transports, func(it adapter.DNSTransport) bool { + return !started[it.Tag()] + }) + var lintTransport func(oTree []string, oCurrent adapter.DNSTransport) error + lintTransport = func(oTree []string, oCurrent adapter.DNSTransport) error { + problemTransportTag := common.Find(oCurrent.Dependencies(), func(it string) bool { + return !started[it] + }) + if common.Contains(oTree, problemTransportTag) { + return E.New("circular server dependency: ", strings.Join(oTree, " -> "), " -> ", problemTransportTag) + } + m.access.Lock() + problemTransport := m.transportByTag[problemTransportTag] + m.access.Unlock() + if problemTransport == nil { + return E.New("dependency[", problemTransportTag, "] not found for server[", oCurrent.Tag(), "]") + } + return lintTransport(append(oTree, problemTransportTag), problemTransport) + } + return lintTransport([]string{currentTransport.Tag()}, currentTransport) + } + return nil +} + +func (m *TransportManager) Close() error { + monitor := taskmonitor.New(m.logger, C.StopTimeout) + m.access.Lock() + if !m.started { + m.access.Unlock() + return nil + } + m.started = false + transports := m.transports + m.transports = nil + m.access.Unlock() + var err error + for _, transport := range transports { + if closer, isCloser := transport.(io.Closer); isCloser { + monitor.Start("close server/", transport.Type(), "[", transport.Tag(), "]") + err = E.Append(err, closer.Close(), func(err error) error { + return E.Cause(err, "close server/", transport.Type(), "[", transport.Tag(), "]") + }) + monitor.Finish() + } + } + return nil +} + +func (m *TransportManager) Transports() []adapter.DNSTransport { + m.access.RLock() + defer m.access.RUnlock() + return m.transports +} + +func (m *TransportManager) Transport(tag string) (adapter.DNSTransport, bool) { + m.access.RLock() + outbound, found := m.transportByTag[tag] + m.access.RUnlock() + return outbound, found +} + +func (m *TransportManager) Default() adapter.DNSTransport { + m.access.RLock() + defer m.access.RUnlock() + if m.defaultTransport != nil { + return m.defaultTransport + } else { + return m.defaultTransportFallback + } +} + +func (m *TransportManager) FakeIP() adapter.FakeIPTransport { + m.access.RLock() + defer m.access.RUnlock() + return m.fakeIPTransport +} + +func (m *TransportManager) Remove(tag string) error { + m.access.Lock() + defer m.access.Unlock() + transport, found := m.transportByTag[tag] + if !found { + return os.ErrInvalid + } + delete(m.transportByTag, tag) + index := common.Index(m.transports, func(it adapter.DNSTransport) bool { + return it == transport + }) + if index == -1 { + panic("invalid inbound index") + } + m.transports = append(m.transports[:index], m.transports[index+1:]...) + started := m.started + if m.defaultTransport == transport { + if len(m.transports) > 0 { + nextTransport := m.transports[0] + if nextTransport.Type() != C.DNSTypeFakeIP { + return E.New("default server cannot be fakeip") + } + m.defaultTransport = nextTransport + m.logger.Info("updated default server to ", m.defaultTransport.Tag()) + } else { + m.defaultTransport = nil + } + } + dependBy := m.dependByTag[tag] + if len(dependBy) > 0 { + return E.New("server[", tag, "] is depended by ", strings.Join(dependBy, ", ")) + } + dependencies := transport.Dependencies() + for _, dependency := range dependencies { + if len(m.dependByTag[dependency]) == 1 { + delete(m.dependByTag, dependency) + } else { + m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool { + return it != tag + }) + } + } + if started { + transport.Reset() + } + return nil +} + +func (m *TransportManager) Create(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) error { + if tag == "" { + return os.ErrInvalid + } + transport, err := m.registry.CreateDNSTransport(ctx, logger, tag, transportType, options) + if err != nil { + return err + } + m.access.Lock() + defer m.access.Unlock() + if m.started { + for _, stage := range adapter.ListStartStages { + err = adapter.LegacyStart(transport, stage) + if err != nil { + return E.Cause(err, stage, " dns/", transport.Type(), "[", transport.Tag(), "]") + } + } + } + if existsTransport, loaded := m.transportByTag[tag]; loaded { + if m.started { + err = common.Close(existsTransport) + if err != nil { + return E.Cause(err, "close dns/", existsTransport.Type(), "[", existsTransport.Tag(), "]") + } + } + existsIndex := common.Index(m.transports, func(it adapter.DNSTransport) bool { + return it == existsTransport + }) + if existsIndex == -1 { + panic("invalid inbound index") + } + m.transports = append(m.transports[:existsIndex], m.transports[existsIndex+1:]...) + } + m.transports = append(m.transports, transport) + m.transportByTag[tag] = transport + dependencies := transport.Dependencies() + for _, dependency := range dependencies { + m.dependByTag[dependency] = append(m.dependByTag[dependency], tag) + } + if tag == m.defaultTag || (m.defaultTag == "" && m.defaultTransport == nil) { + if transport.Type() == C.DNSTypeFakeIP { + return E.New("default server cannot be fakeip") + } + m.defaultTransport = transport + if m.started { + m.logger.Info("updated default server to ", transport.Tag()) + } + } + if transport.Type() == C.DNSTypeFakeIP { + if m.fakeIPTransport != nil { + return E.New("multiple fakeip server are not supported") + } + m.fakeIPTransport = transport.(adapter.FakeIPTransport) + } + return nil +} diff --git a/dns/transport_registry.go b/dns/transport_registry.go new file mode 100644 index 00000000..d838158b --- /dev/null +++ b/dns/transport_registry.go @@ -0,0 +1,72 @@ +package dns + +import ( + "context" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +type TransportConstructorFunc[T any] func(ctx context.Context, logger log.ContextLogger, tag string, options T) (adapter.DNSTransport, error) + +func RegisterTransport[Options any](registry *TransportRegistry, transportType string, constructor TransportConstructorFunc[Options]) { + registry.register(transportType, func() any { + return new(Options) + }, func(ctx context.Context, logger log.ContextLogger, tag string, rawOptions any) (adapter.DNSTransport, error) { + var options *Options + if rawOptions != nil { + options = rawOptions.(*Options) + } + return constructor(ctx, logger, tag, common.PtrValueOrDefault(options)) + }) +} + +var _ adapter.DNSTransportRegistry = (*TransportRegistry)(nil) + +type ( + optionsConstructorFunc func() any + constructorFunc func(ctx context.Context, logger log.ContextLogger, tag string, options any) (adapter.DNSTransport, error) +) + +type TransportRegistry struct { + access sync.Mutex + optionsType map[string]optionsConstructorFunc + constructors map[string]constructorFunc +} + +func NewTransportRegistry() *TransportRegistry { + return &TransportRegistry{ + optionsType: make(map[string]optionsConstructorFunc), + constructors: make(map[string]constructorFunc), + } +} + +func (r *TransportRegistry) CreateOptions(transportType string) (any, bool) { + r.access.Lock() + defer r.access.Unlock() + optionsConstructor, loaded := r.optionsType[transportType] + if !loaded { + return nil, false + } + return optionsConstructor(), true +} + +func (r *TransportRegistry) CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (adapter.DNSTransport, error) { + r.access.Lock() + defer r.access.Unlock() + constructor, loaded := r.constructors[transportType] + if !loaded { + return nil, E.New("transport type not found: " + transportType) + } + return constructor(ctx, logger, tag, options) +} + +func (r *TransportRegistry) register(transportType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) { + r.access.Lock() + defer r.access.Unlock() + r.optionsType[transportType] = optionsConstructor + r.constructors[transportType] = constructor +} diff --git a/experimental/clashapi/dns.go b/experimental/clashapi/dns.go index 2a21a7c1..4f850f82 100644 --- a/experimental/clashapi/dns.go +++ b/experimental/clashapi/dns.go @@ -13,13 +13,13 @@ import ( "github.com/miekg/dns" ) -func dnsRouter(router adapter.Router) http.Handler { +func dnsRouter(router adapter.DNSRouter) http.Handler { r := chi.NewRouter() r.Get("/query", queryDNS(router)) return r } -func queryDNS(router adapter.Router) func(w http.ResponseWriter, r *http.Request) { +func queryDNS(router adapter.DNSRouter) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { name := r.URL.Query().Get("name") qTypeStr := r.URL.Query().Get("type") @@ -39,7 +39,7 @@ func queryDNS(router adapter.Router) func(w http.ResponseWriter, r *http.Request msg := dns.Msg{} msg.SetQuestion(dns.Fqdn(name), qType) - resp, err := router.Exchange(ctx, &msg) + resp, err := router.Exchange(ctx, &msg, adapter.DNSQueryOptions{}) if err != nil { render.Status(r, http.StatusInternalServerError) render.JSON(w, r, newError(err.Error())) diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index efe2ce36..d01e180a 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -42,6 +42,7 @@ var _ adapter.ClashServer = (*Server)(nil) type Server struct { ctx context.Context router adapter.Router + dnsRouter adapter.DNSRouter outbound adapter.OutboundManager endpoint adapter.EndpointManager logger log.Logger @@ -62,11 +63,12 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op trafficManager := trafficontrol.NewManager() chiRouter := chi.NewRouter() s := &Server{ - ctx: ctx, - router: service.FromContext[adapter.Router](ctx), - outbound: service.FromContext[adapter.OutboundManager](ctx), - endpoint: service.FromContext[adapter.EndpointManager](ctx), - logger: logFactory.NewLogger("clash-api"), + ctx: ctx, + router: service.FromContext[adapter.Router](ctx), + dnsRouter: service.FromContext[adapter.DNSRouter](ctx), + outbound: service.FromContext[adapter.OutboundManager](ctx), + endpoint: service.FromContext[adapter.EndpointManager](ctx), + logger: logFactory.NewLogger("clash-api"), httpServer: &http.Server{ Addr: options.ExternalController, Handler: chiRouter, @@ -121,7 +123,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op r.Mount("/script", scriptRouter()) r.Mount("/profile", profileRouter()) r.Mount("/cache", cacheRouter(ctx)) - r.Mount("/dns", dnsRouter(s.router)) + r.Mount("/dns", dnsRouter(s.dnsRouter)) s.setupMetaAPI(r) }) @@ -221,7 +223,7 @@ func (s *Server) SetMode(newMode string) { default: } } - s.router.ClearDNSCache() + s.dnsRouter.ClearCache() cacheFile := service.FromContext[adapter.CacheFile](s.ctx) if cacheFile != nil { err := cacheFile.StoreMode(newMode) diff --git a/experimental/deprecated/constants.go b/experimental/deprecated/constants.go index 68aa9aca..bf648365 100644 --- a/experimental/deprecated/constants.go +++ b/experimental/deprecated/constants.go @@ -146,6 +146,21 @@ var OptionTUNGSO = Note{ EnvName: "TUN_GSO", } +var OptionLegacyDNSTransport = Note{ + Name: "legacy-dns-transport", + Description: "legacy DNS transport", + DeprecatedVersion: "1.12.0", + ScheduledVersion: "1.14.0", + EnvName: "LEGACY_DNS_TRANSPORT", +} + +var OptionLegacyDNSFakeIPOptions = Note{ + Name: "legacy-dns-fakeip-options", + Description: "legacy DNS fakeip options", + DeprecatedVersion: "1.12.0", + ScheduledVersion: "1.14.0", +} + var Options = []Note{ OptionBadMatchSource, OptionGEOIP, diff --git a/experimental/libbox/config.go b/experimental/libbox/config.go index 159fd8f6..603d8b6d 100644 --- a/experimental/libbox/config.go +++ b/experimental/libbox/config.go @@ -9,8 +9,11 @@ import ( "github.com/sagernet/sing-box" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/process" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/include" + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/control" @@ -21,6 +24,18 @@ import ( "github.com/sagernet/sing/service" ) +func BaseContext(platformInterface PlatformInterface) context.Context { + dnsRegistry := include.DNSTransportRegistry() + if platformInterface != nil { + if localTransport := platformInterface.LocalDNSTransport(); localTransport != nil { + dns.RegisterTransport[option.LocalDNSServerOptions](dnsRegistry, C.DNSTypeLocal, func(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) { + return newPlatformTransport(localTransport, tag, options), nil + }) + } + } + return box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry(), dnsRegistry) +} + func parseConfig(ctx context.Context, configContent string) (option.Options, error) { options, err := json.UnmarshalExtendedContext[option.Options](ctx, []byte(configContent)) if err != nil { @@ -30,7 +45,7 @@ func parseConfig(ctx context.Context, configContent string) (option.Options, err } func CheckConfig(configContent string) error { - ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()) + ctx := BaseContext(nil) options, err := parseConfig(ctx, configContent) if err != nil { return err @@ -131,7 +146,7 @@ func (s *platformInterfaceStub) SendNotification(notification *platform.Notifica } func FormatConfig(configContent string) (*StringBox, error) { - options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()), configContent) + options, err := parseConfig(BaseContext(nil), configContent) if err != nil { return nil, err } diff --git a/experimental/libbox/dns.go b/experimental/libbox/dns.go index a46d9b42..a7ccd2a2 100644 --- a/experimental/libbox/dns.go +++ b/experimental/libbox/dns.go @@ -6,7 +6,10 @@ import ( "strings" "syscall" - "github.com/sagernet/sing-dns" + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -21,118 +24,80 @@ type LocalDNSTransport interface { Exchange(ctx *ExchangeContext, message []byte) error } -func RegisterLocalDNSTransport(transport LocalDNSTransport) { - if transport == nil { - dns.RegisterTransport([]string{"local"}, func(options dns.TransportOptions) (dns.Transport, error) { - return dns.NewLocalTransport(options), nil - }) - } else { - dns.RegisterTransport([]string{"local"}, func(options dns.TransportOptions) (dns.Transport, error) { - return &platformLocalDNSTransport{ - iif: transport, - }, nil - }) - } -} +var _ adapter.DNSTransport = (*platformTransport)(nil) -var _ dns.Transport = (*platformLocalDNSTransport)(nil) - -type platformLocalDNSTransport struct { +type platformTransport struct { + dns.TransportAdapter iif LocalDNSTransport } -func (p *platformLocalDNSTransport) Name() string { - return "local" -} - -func (p *platformLocalDNSTransport) Start() error { - return nil -} - -func (p *platformLocalDNSTransport) Reset() { -} - -func (p *platformLocalDNSTransport) Close() error { - return nil -} - -func (p *platformLocalDNSTransport) Raw() bool { - return p.iif.Raw() -} - -func (p *platformLocalDNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - messageBytes, err := message.Pack() - if err != nil { - return nil, err +func newPlatformTransport(iif LocalDNSTransport, tag string, options option.LocalDNSServerOptions) *platformTransport { + return &platformTransport{ + TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options), + iif: iif, } +} + +func (p *platformTransport) Reset() { +} + +func (p *platformTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { response := &ExchangeContext{ context: ctx, } - var responseMessage *mDNS.Msg - var group task.Group - group.Append0(func(ctx context.Context) error { - err = p.iif.Exchange(response, messageBytes) + if p.iif.Raw() { + messageBytes, err := message.Pack() if err != nil { - return err + return nil, err } - if response.error != nil { - return response.error - } - responseMessage = &response.message - return nil - }) - err = group.Run(ctx) - if err != nil { - return nil, err - } - return responseMessage, nil -} - -func (p *platformLocalDNSTransport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) { - var network string - switch strategy { - case dns.DomainStrategyUseIPv4: - network = "ip4" - case dns.DomainStrategyPreferIPv6: - network = "ip6" - default: - network = "ip" - } - response := &ExchangeContext{ - context: ctx, - } - var responseAddr []netip.Addr - var group task.Group - group.Append0(func(ctx context.Context) error { - err := p.iif.Lookup(response, network, domain) + var responseMessage *mDNS.Msg + var group task.Group + group.Append0(func(ctx context.Context) error { + err = p.iif.Exchange(response, messageBytes) + if err != nil { + return err + } + if response.error != nil { + return response.error + } + responseMessage = &response.message + return nil + }) + err = group.Run(ctx) if err != nil { - return err + return nil, err } - if response.error != nil { - return response.error - } - switch strategy { - case dns.DomainStrategyUseIPv4: - responseAddr = common.Filter(response.addresses, func(it netip.Addr) bool { - return it.Is4() - }) - case dns.DomainStrategyPreferIPv6: - responseAddr = common.Filter(response.addresses, func(it netip.Addr) bool { - return it.Is6() - }) + return responseMessage, nil + } else { + question := message.Question[0] + var network string + switch question.Qtype { + case mDNS.TypeA: + network = "ip4" + case mDNS.TypeAAAA: + network = "ip6" default: - responseAddr = response.addresses + return nil, E.New("only IP queries are supported by current version of Android") } - /*if len(responseAddr) == 0 { - response.error = dns.RCodeSuccess - }*/ - return nil - }) - err := group.Run(ctx) - if err != nil { - return nil, err + var responseAddrs []netip.Addr + var group task.Group + group.Append0(func(ctx context.Context) error { + err := p.iif.Lookup(response, network, question.Name) + if err != nil { + return err + } + if response.error != nil { + return response.error + } + responseAddrs = response.addresses + return nil + }) + err := group.Run(ctx) + if err != nil { + return nil, err + } + return dns.FixedResponse(message.Id, question, responseAddrs, C.DefaultDNSTTL), nil } - return responseAddr, nil } type Func interface { diff --git a/experimental/libbox/platform.go b/experimental/libbox/platform.go index d5951cd3..f0590367 100644 --- a/experimental/libbox/platform.go +++ b/experimental/libbox/platform.go @@ -6,6 +6,7 @@ import ( ) type PlatformInterface interface { + LocalDNSTransport() LocalDNSTransport UsePlatformAutoDetectInterfaceControl() bool AutoDetectInterfaceControl(fd int32) error OpenTun(options TunOptions) (int32, error) diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index 8d42d26e..a1e2ec8e 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -18,7 +18,6 @@ import ( "github.com/sagernet/sing-box/experimental/deprecated" "github.com/sagernet/sing-box/experimental/libbox/internal/procfs" "github.com/sagernet/sing-box/experimental/libbox/platform" - "github.com/sagernet/sing-box/include" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" @@ -44,7 +43,7 @@ type BoxService struct { } func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) { - ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()) + ctx := BaseContext(platformInterface) ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID) service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager)) options, err := parseConfig(ctx, configContent) @@ -192,6 +191,9 @@ func (w *platformInterfaceWrapper) Interfaces() ([]adapter.NetworkInterface, err continue } w.defaultInterfaceAccess.Lock() + // (GOOS=windows) SA4006: this value of `isDefault` is never used + // Why not used? + //nolint:staticcheck isDefault := w.defaultInterface != nil && int(netInterface.Index) == w.defaultInterface.Index w.defaultInterfaceAccess.Unlock() interfaces = append(interfaces, adapter.NetworkInterface{ diff --git a/go.mod b/go.mod index 52167f1b..d3aa6559 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,6 @@ require ( github.com/sagernet/quic-go v0.49.0-beta.1 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 github.com/sagernet/sing v0.6.5 - github.com/sagernet/sing-dns v0.4.0 github.com/sagernet/sing-mux v0.3.1 github.com/sagernet/sing-quic v0.4.0 github.com/sagernet/sing-shadowsocks v0.2.7 diff --git a/go.sum b/go.sum index ca9f34dd..f446adfa 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,6 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4Wk github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= github.com/sagernet/sing v0.6.5 h1:TBKTK6Ms0/MNTZm+cTC2hhKunE42XrNIdsxcYtWqeUU= github.com/sagernet/sing v0.6.5/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= -github.com/sagernet/sing-dns v0.4.0 h1:+mNoOuR3nljjouCH+qMg4zHI1+R9T2ReblGFkZPEndc= -github.com/sagernet/sing-dns v0.4.0/go.mod h1:dweQs54ng2YGzoJfz+F9dGuDNdP5pJ3PLeggnK5VWc8= github.com/sagernet/sing-mux v0.3.1 h1:kvCc8HyGAskDHDQ0yQvoTi/7J4cZPB/VJMsAM3MmdQI= github.com/sagernet/sing-mux v0.3.1/go.mod h1:Mkdz8LnDstthz0HWuA/5foncnDIdcNN5KZ6AdJX+x78= github.com/sagernet/sing-quic v0.4.0 h1:E4geazHk/UrJTXMlT+CBCKmn8V86RhtNeczWtfeoEFc= diff --git a/include/dhcp.go b/include/dhcp.go index 0e4b4ccf..8cf074be 100644 --- a/include/dhcp.go +++ b/include/dhcp.go @@ -2,4 +2,11 @@ package include -import _ "github.com/sagernet/sing-box/transport/dhcp" +import ( + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport/dhcp" +) + +func registerDHCPTransport(registry *dns.TransportRegistry) { + dhcp.RegisterTransport(registry) +} diff --git a/include/dhcp_stub.go b/include/dhcp_stub.go index 47a19d2e..272f313a 100644 --- a/include/dhcp_stub.go +++ b/include/dhcp_stub.go @@ -3,12 +3,18 @@ package include import ( - "github.com/sagernet/sing-dns" + "context" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" ) -func init() { - dns.RegisterTransport([]string{"dhcp"}, func(options dns.TransportOptions) (dns.Transport, error) { +func registerDHCPTransport(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.DHCPDNSServerOptions](registry, C.DNSTypeDHCP, func(ctx context.Context, logger log.ContextLogger, tag string, options option.DHCPDNSServerOptions) (adapter.DNSTransport, error) { return nil, E.New(`DHCP is not included in this build, rebuild with -tags with_dhcp`) }) } diff --git a/include/quic.go b/include/quic.go index 980b4581..6a3f3017 100644 --- a/include/quic.go +++ b/include/quic.go @@ -5,12 +5,13 @@ package include import ( "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport/quic" "github.com/sagernet/sing-box/protocol/hysteria" "github.com/sagernet/sing-box/protocol/hysteria2" _ "github.com/sagernet/sing-box/protocol/naive/quic" "github.com/sagernet/sing-box/protocol/tuic" _ "github.com/sagernet/sing-box/transport/v2rayquic" - _ "github.com/sagernet/sing-dns/quic" ) func registerQUICInbounds(registry *inbound.Registry) { @@ -24,3 +25,8 @@ func registerQUICOutbounds(registry *outbound.Registry) { tuic.RegisterOutbound(registry) hysteria2.RegisterOutbound(registry) } + +func registerQUICTransports(registry *dns.TransportRegistry) { + quic.RegisterTransport(registry) + quic.RegisterHTTP3Transport(registry) +} diff --git a/include/quic_stub.go b/include/quic_stub.go index 66c08590..c20a5114 100644 --- a/include/quic_stub.go +++ b/include/quic_stub.go @@ -13,20 +13,17 @@ import ( "github.com/sagernet/sing-box/common/listener" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/protocol/naive" "github.com/sagernet/sing-box/transport/v2ray" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) func init() { - dns.RegisterTransport([]string{"quic", "h3"}, func(options dns.TransportOptions) (dns.Transport, error) { - return nil, C.ErrQUICNotIncluded - }) v2ray.RegisterQUICConstructor( func(ctx context.Context, logger logger.ContextLogger, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) { return nil, C.ErrQUICNotIncluded @@ -63,3 +60,12 @@ func registerQUICOutbounds(registry *outbound.Registry) { return nil, C.ErrQUICNotIncluded }) } + +func registerQUICTransports(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeQUIC, func(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) { + return nil, C.ErrQUICNotIncluded + }) + dns.RegisterTransport[option.RemoteHTTPSDNSServerOptions](registry, C.DNSTypeHTTP3, func(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) { + return nil, C.ErrQUICNotIncluded + }) +} diff --git a/include/registry.go b/include/registry.go index e71ffb0c..cbf793f4 100644 --- a/include/registry.go +++ b/include/registry.go @@ -8,11 +8,16 @@ import ( "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport" + "github.com/sagernet/sing-box/dns/transport/fakeip" + "github.com/sagernet/sing-box/dns/transport/hosts" + "github.com/sagernet/sing-box/dns/transport/local" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/protocol/block" "github.com/sagernet/sing-box/protocol/direct" - "github.com/sagernet/sing-box/protocol/dns" + protocolDNS "github.com/sagernet/sing-box/protocol/dns" "github.com/sagernet/sing-box/protocol/group" "github.com/sagernet/sing-box/protocol/http" "github.com/sagernet/sing-box/protocol/mixed" @@ -61,7 +66,7 @@ func OutboundRegistry() *outbound.Registry { direct.RegisterOutbound(registry) block.RegisterOutbound(registry) - dns.RegisterOutbound(registry) + protocolDNS.RegisterOutbound(registry) group.RegisterSelector(registry) group.RegisterURLTest(registry) @@ -91,6 +96,24 @@ func EndpointRegistry() *endpoint.Registry { return registry } +func DNSTransportRegistry() *dns.TransportRegistry { + registry := dns.NewTransportRegistry() + + transport.RegisterTCP(registry) + transport.RegisterUDP(registry) + transport.RegisterTLS(registry) + transport.RegisterHTTPS(registry) + transport.RegisterPredefined(registry) + hosts.RegisterTransport(registry) + local.RegisterTransport(registry) + fakeip.RegisterTransport(registry) + + registerQUICTransports(registry) + registerDHCPTransport(registry) + + return registry +} + func registerStubForRemovedInbounds(registry *inbound.Registry) { inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) { return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0") diff --git a/option/dns.go b/option/dns.go index 272c5180..8c9b8bda 100644 --- a/option/dns.go +++ b/option/dns.go @@ -1,29 +1,53 @@ package option import ( + "context" "net/netip" + "net/url" + "os" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/deprecated" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" "github.com/sagernet/sing/common/json/badoption" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/service" + + "github.com/miekg/dns" ) -type DNSOptions struct { - Servers []DNSServerOptions `json:"servers,omitempty"` - Rules []DNSRule `json:"rules,omitempty"` - Final string `json:"final,omitempty"` - ReverseMapping bool `json:"reverse_mapping,omitempty"` - FakeIP *DNSFakeIPOptions `json:"fakeip,omitempty"` +type RawDNSOptions struct { + Servers []NewDNSServerOptions `json:"servers,omitempty"` + Rules []DNSRule `json:"rules,omitempty"` + Final string `json:"final,omitempty"` + ReverseMapping bool `json:"reverse_mapping,omitempty"` DNSClientOptions } -type DNSServerOptions struct { - Tag string `json:"tag,omitempty"` - Address string `json:"address"` - AddressResolver string `json:"address_resolver,omitempty"` - AddressStrategy DomainStrategy `json:"address_strategy,omitempty"` - AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"` - Strategy DomainStrategy `json:"strategy,omitempty"` - Detour string `json:"detour,omitempty"` - ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"` +type LegacyDNSOptions struct { + FakeIP *LegacyDNSFakeIPOptions `json:"fakeip,omitempty"` +} + +type DNSOptions struct { + RawDNSOptions + LegacyDNSOptions +} + +func (o *DNSOptions) UnmarshalJSONContext(ctx context.Context, content []byte) error { + err := json.UnmarshalContext(ctx, content, &o.LegacyDNSOptions) + if err != nil { + return err + } + if o.FakeIP != nil && o.FakeIP.Enabled { + deprecated.Report(ctx, deprecated.OptionLegacyDNSFakeIPOptions) + ctx = context.WithValue(ctx, (*LegacyDNSFakeIPOptions)(nil), o.FakeIP) + } + legacyOptions := o.LegacyDNSOptions + o.LegacyDNSOptions = LegacyDNSOptions{} + return badjson.UnmarshallExcludedContext(ctx, content, legacyOptions, &o.RawDNSOptions) } type DNSClientOptions struct { @@ -35,8 +59,261 @@ type DNSClientOptions struct { ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"` } -type DNSFakeIPOptions struct { - Enabled bool `json:"enabled,omitempty"` - Inet4Range *netip.Prefix `json:"inet4_range,omitempty"` - Inet6Range *netip.Prefix `json:"inet6_range,omitempty"` +type LegacyDNSFakeIPOptions struct { + Enabled bool `json:"enabled,omitempty"` + Inet4Range *badoption.Prefix `json:"inet4_range,omitempty"` + Inet6Range *badoption.Prefix `json:"inet6_range,omitempty"` +} + +type DNSTransportOptionsRegistry interface { + CreateOptions(transportType string) (any, bool) +} + +type _NewDNSServerOptions struct { + Type string `json:"type,omitempty"` + Tag string `json:"tag,omitempty"` + Options any `json:"-"` +} + +type NewDNSServerOptions _NewDNSServerOptions + +func (o *NewDNSServerOptions) MarshalJSONContext(ctx context.Context) ([]byte, error) { + return badjson.MarshallObjectsContext(ctx, (*_NewDNSServerOptions)(o), o.Options) +} + +func (o *NewDNSServerOptions) UnmarshalJSONContext(ctx context.Context, content []byte) error { + err := json.UnmarshalContext(ctx, content, (*_NewDNSServerOptions)(o)) + if err != nil { + return err + } + registry := service.FromContext[DNSTransportOptionsRegistry](ctx) + if registry == nil { + return E.New("missing outbound options registry in context") + } + var options any + switch o.Type { + case "", C.DNSTypeLegacy: + o.Type = C.DNSTypeLegacy + options = new(LegacyDNSServerOptions) + deprecated.Report(ctx, deprecated.OptionLegacyDNSTransport) + default: + var loaded bool + options, loaded = registry.CreateOptions(o.Type) + if !loaded { + return E.New("unknown transport type: ", o.Type) + } + } + err = badjson.UnmarshallExcludedContext(ctx, content, (*_Outbound)(o), options) + if err != nil { + return err + } + o.Options = options + if o.Type == C.DNSTypeLegacy { + err = o.Upgrade(ctx) + if err != nil { + return err + } + } + return nil +} + +func (o *NewDNSServerOptions) Upgrade(ctx context.Context) error { + if o.Type != C.DNSTypeLegacy { + return nil + } + defer func() { + encoder := json.NewEncoder(os.Stderr) + encoder.SetIndent("", " ") + encoder.Encode(o) + }() + options := o.Options.(*LegacyDNSServerOptions) + serverURL, _ := url.Parse(options.Address) + var serverType string + if serverURL.Scheme != "" { + serverType = serverURL.Scheme + } else { + switch options.Address { + case "local", "fakeip": + serverType = options.Address + default: + serverType = C.DNSTypeUDP + } + } + remoteOptions := RemoteDNSServerOptions{ + LocalDNSServerOptions: LocalDNSServerOptions{ + DialerOptions: DialerOptions{ + Detour: options.Detour, + }, + LegacyStrategy: options.Strategy, + LegacyDefaultDialer: options.Detour == "", + LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}), + }, + AddressResolver: options.AddressResolver, + AddressStrategy: options.AddressStrategy, + AddressFallbackDelay: options.AddressFallbackDelay, + } + switch serverType { + case C.DNSTypeLocal: + o.Type = C.DNSTypeLocal + o.Options = &remoteOptions.LocalDNSServerOptions + case C.DNSTypeUDP: + o.Type = C.DNSTypeUDP + o.Options = &remoteOptions + var serverAddr M.Socksaddr + if serverURL.Scheme == "" { + serverAddr = M.ParseSocksaddr(options.Address) + } else { + serverAddr = M.ParseSocksaddr(serverURL.Host) + } + if !serverAddr.IsValid() { + return E.New("invalid server address") + } + remoteOptions.Server = serverAddr.Addr.String() + if serverAddr.Port != 0 && serverAddr.Port != 53 { + remoteOptions.ServerPort = serverAddr.Port + } + remoteOptions.Server = serverAddr.AddrString() + remoteOptions.ServerPort = serverAddr.Port + case C.DNSTypeTCP: + o.Type = C.DNSTypeTCP + o.Options = &remoteOptions + serverAddr := M.ParseSocksaddr(serverURL.Host) + if !serverAddr.IsValid() { + return E.New("invalid server address") + } + remoteOptions.Server = serverAddr.Addr.String() + if serverAddr.Port != 0 && serverAddr.Port != 53 { + remoteOptions.ServerPort = serverAddr.Port + } + remoteOptions.Server = serverAddr.AddrString() + remoteOptions.ServerPort = serverAddr.Port + case C.DNSTypeTLS, C.DNSTypeQUIC: + o.Type = serverType + serverAddr := M.ParseSocksaddr(serverURL.Host) + if !serverAddr.IsValid() { + return E.New("invalid server address") + } + remoteOptions.Server = serverAddr.Addr.String() + if serverAddr.Port != 0 && serverAddr.Port != 853 { + remoteOptions.ServerPort = serverAddr.Port + } + o.Options = &RemoteTLSDNSServerOptions{ + RemoteDNSServerOptions: remoteOptions, + } + case C.DNSTypeHTTPS, C.DNSTypeHTTP3: + o.Type = serverType + httpsOptions := RemoteHTTPSDNSServerOptions{ + RemoteTLSDNSServerOptions: RemoteTLSDNSServerOptions{ + RemoteDNSServerOptions: remoteOptions, + }, + } + o.Options = &httpsOptions + serverAddr := M.ParseSocksaddr(serverURL.Host) + if !serverAddr.IsValid() { + return E.New("invalid server address") + } + httpsOptions.Server = serverAddr.Addr.String() + if serverAddr.Port != 0 && serverAddr.Port != 443 { + httpsOptions.ServerPort = serverAddr.Port + } + if serverURL.Path != "/dns-query" { + httpsOptions.Path = serverURL.Path + } + case "rcode": + var rcode int + switch serverURL.Host { + case "success": + rcode = dns.RcodeSuccess + case "format_error": + rcode = dns.RcodeFormatError + case "server_failure": + rcode = dns.RcodeServerFailure + case "name_error": + rcode = dns.RcodeNameError + case "not_implemented": + rcode = dns.RcodeNotImplemented + case "refused": + rcode = dns.RcodeRefused + default: + return E.New("unknown rcode: ", serverURL.Host) + } + o.Type = C.DNSTypePreDefined + o.Options = &PredefinedDNSServerOptions{ + Responses: []DNSResponseOptions{ + { + RCode: common.Ptr(DNSRCode(rcode)), + }, + }, + } + case C.DNSTypeDHCP: + o.Type = C.DNSTypeDHCP + dhcpOptions := DHCPDNSServerOptions{} + if serverURL.Host != "" && serverURL.Host != "auto" { + dhcpOptions.Interface = serverURL.Host + } + o.Options = &dhcpOptions + case C.DNSTypeFakeIP: + o.Type = C.DNSTypeFakeIP + fakeipOptions := FakeIPDNSServerOptions{} + if legacyOptions, loaded := ctx.Value((*LegacyDNSFakeIPOptions)(nil)).(*LegacyDNSFakeIPOptions); loaded { + fakeipOptions.Inet4Range = legacyOptions.Inet4Range + fakeipOptions.Inet6Range = legacyOptions.Inet6Range + } + o.Options = &fakeipOptions + default: + return E.New("unsupported DNS server scheme: ", serverType) + } + return nil +} + +type LegacyDNSServerOptions struct { + Address string `json:"address"` + AddressResolver string `json:"address_resolver,omitempty"` + AddressStrategy DomainStrategy `json:"address_strategy,omitempty"` + AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"` + Strategy DomainStrategy `json:"strategy,omitempty"` + Detour string `json:"detour,omitempty"` + ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"` +} + +type HostsDNSServerOptions struct { + Path badoption.Listable[string] `json:"path,omitempty"` + Predefined badjson.TypedMap[string, badoption.Listable[netip.Addr]] `json:"predefined,omitempty"` +} + +type LocalDNSServerOptions struct { + DialerOptions + LegacyStrategy DomainStrategy `json:"-"` + LegacyDefaultDialer bool `json:"-"` + LegacyClientSubnet netip.Prefix `json:"-"` +} + +type RemoteDNSServerOptions struct { + LocalDNSServerOptions + ServerOptions + AddressResolver string `json:"address_resolver,omitempty"` + AddressStrategy DomainStrategy `json:"address_strategy,omitempty"` + AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"` +} + +type RemoteTLSDNSServerOptions struct { + RemoteDNSServerOptions + OutboundTLSOptionsContainer +} + +type RemoteHTTPSDNSServerOptions struct { + RemoteTLSDNSServerOptions + Path string `json:"path,omitempty"` + Method string `json:"method,omitempty"` + Headers badoption.HTTPHeader `json:"headers,omitempty"` +} + +type FakeIPDNSServerOptions struct { + Inet4Range *badoption.Prefix `json:"inet4_range,omitempty"` + Inet6Range *badoption.Prefix `json:"inet6_range,omitempty"` +} + +type DHCPDNSServerOptions struct { + LocalDNSServerOptions + Interface string `json:"interface,omitempty"` } diff --git a/option/dns_record.go b/option/dns_record.go new file mode 100644 index 00000000..c76a76c6 --- /dev/null +++ b/option/dns_record.go @@ -0,0 +1,161 @@ +package option + +import ( + "encoding/base64" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badoption" + M "github.com/sagernet/sing/common/metadata" + + "github.com/miekg/dns" +) + +type PredefinedDNSServerOptions struct { + Responses []DNSResponseOptions `json:"responses,omitempty"` +} + +type DNSResponseOptions struct { + Query badoption.Listable[string] `json:"query,omitempty"` + QueryType badoption.Listable[DNSQueryType] `json:"query_type,omitempty"` + + RCode *DNSRCode `json:"rcode,omitempty"` + Answer badoption.Listable[DNSRecordOptions] `json:"answer,omitempty"` + Ns badoption.Listable[DNSRecordOptions] `json:"ns,omitempty"` + Extra badoption.Listable[DNSRecordOptions] `json:"extra,omitempty"` +} + +type DNSRCode int + +func (r DNSRCode) MarshalJSON() ([]byte, error) { + rCodeValue, loaded := dns.RcodeToString[int(r)] + if loaded { + return json.Marshal(rCodeValue) + } + return json.Marshal(int(r)) +} + +func (r *DNSRCode) UnmarshalJSON(bytes []byte) error { + var intValue int + err := json.Unmarshal(bytes, &intValue) + if err == nil { + *r = DNSRCode(intValue) + return nil + } + var stringValue string + err = json.Unmarshal(bytes, &stringValue) + if err != nil { + return err + } + rCodeValue, loaded := dns.StringToRcode[stringValue] + if !loaded { + return E.New("unknown rcode: " + stringValue) + } + *r = DNSRCode(rCodeValue) + return nil +} + +func (r *DNSRCode) Build() int { + if r == nil { + return dns.RcodeSuccess + } + return int(*r) +} + +func (o DNSResponseOptions) Build() ([]dns.Question, *dns.Msg, error) { + var questions []dns.Question + if len(o.Query) == 0 && len(o.QueryType) == 0 { + questions = []dns.Question{{Qclass: dns.ClassINET}} + } else if len(o.Query) == 0 { + for _, queryType := range o.QueryType { + questions = append(questions, dns.Question{ + Qtype: uint16(queryType), + Qclass: dns.ClassINET, + }) + } + } else if len(o.QueryType) == 0 { + for _, domain := range o.Query { + questions = append(questions, dns.Question{ + Name: dns.Fqdn(domain), + Qclass: dns.ClassINET, + }) + } + } else { + for _, queryType := range o.QueryType { + for _, domain := range o.Query { + questions = append(questions, dns.Question{ + Name: dns.Fqdn(domain), + Qtype: uint16(queryType), + Qclass: dns.ClassINET, + }) + } + } + } + return questions, &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Response: true, + Rcode: o.RCode.Build(), + Authoritative: true, + RecursionDesired: true, + RecursionAvailable: true, + }, + Answer: common.Map(o.Answer, DNSRecordOptions.build), + Ns: common.Map(o.Ns, DNSRecordOptions.build), + Extra: common.Map(o.Extra, DNSRecordOptions.build), + }, nil +} + +type DNSRecordOptions struct { + dns.RR + fromBase64 bool +} + +func (o DNSRecordOptions) MarshalJSON() ([]byte, error) { + if o.fromBase64 { + buffer := buf.Get(dns.Len(o.RR)) + defer buf.Put(buffer) + offset, err := dns.PackRR(o.RR, buffer, 0, nil, false) + if err != nil { + return nil, err + } + return json.Marshal(base64.StdEncoding.EncodeToString(buffer[:offset])) + } + return json.Marshal(o.RR.String()) +} + +func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error { + var stringValue string + err := json.Unmarshal(data, &stringValue) + if err != nil { + return err + } + binary, err := base64.StdEncoding.DecodeString(stringValue) + if err == nil { + return o.unmarshalBase64(binary) + } + record, err := dns.NewRR(stringValue) + if err != nil { + return err + } + if a, isA := record.(*dns.A); isA { + a.A = M.AddrFromIP(a.A).Unmap().AsSlice() + } + o.RR = record + return nil +} + +func (o *DNSRecordOptions) unmarshalBase64(binary []byte) error { + record, _, err := dns.UnpackRR(binary, 0) + if err != nil { + return E.New("parse binary DNS record") + } + o.RR = record + o.fromBase64 = true + return nil +} + +func (o DNSRecordOptions) build() dns.RR { + return o.RR +} diff --git a/option/rule_action.go b/option/rule_action.go index b7003628..a715d260 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -7,7 +7,6 @@ import ( "time" C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/json/badjson" @@ -168,12 +167,14 @@ func (r *RouteOptionsActionOptions) UnmarshalJSON(data []byte) error { type DNSRouteActionOptions struct { Server string `json:"server,omitempty"` + Strategy DomainStrategy `json:"strategy,omitempty"` DisableCache bool `json:"disable_cache,omitempty"` RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"` } type _DNSRouteOptionsActionOptions struct { + Strategy DomainStrategy `json:"strategy,omitempty"` DisableCache bool `json:"disable_cache,omitempty"` RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"` @@ -225,7 +226,7 @@ func (d DirectActionOptions) Descriptions() []string { if d.UDPFragment != nil { descriptions = append(descriptions, "udp_fragment="+fmt.Sprint(*d.UDPFragment)) } - if d.DomainStrategy != DomainStrategy(dns.DomainStrategyAsIS) { + if d.DomainStrategy != DomainStrategy(C.DomainStrategyAsIS) { descriptions = append(descriptions, "domain_strategy="+d.DomainStrategy.String()) } if d.FallbackDelay != 0 { @@ -252,6 +253,14 @@ type _RejectActionOptions struct { type RejectActionOptions _RejectActionOptions +func (r RejectActionOptions) MarshalJSON() ([]byte, error) { + switch r.Method { + case C.RuleActionRejectMethodDefault: + r.Method = "" + } + return json.Marshal((_RejectActionOptions)(r)) +} + func (r *RejectActionOptions) UnmarshalJSON(bytes []byte) error { err := json.Unmarshal(bytes, (*_RejectActionOptions)(r)) if err != nil { diff --git a/option/rule_dns.go b/option/rule_dns.go index b437eb54..9d6fb138 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -83,6 +83,7 @@ type RawDefaultDNSRule struct { GeoIP badoption.Listable[string] `json:"geoip,omitempty"` IPCIDR badoption.Listable[string] `json:"ip_cidr,omitempty"` IPIsPrivate bool `json:"ip_is_private,omitempty"` + IPAcceptAny bool `json:"ip_accept_any,omitempty"` SourceIPCIDR badoption.Listable[string] `json:"source_ip_cidr,omitempty"` SourceIPIsPrivate bool `json:"source_ip_is_private,omitempty"` SourcePort badoption.Listable[uint16] `json:"source_port,omitempty"` diff --git a/option/types.go b/option/types.go index 66f58ef8..fe7d4b3d 100644 --- a/option/types.go +++ b/option/types.go @@ -4,7 +4,6 @@ import ( "strings" C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/json" @@ -45,19 +44,19 @@ func (v NetworkList) Build() []string { return strings.Split(string(v), "\n") } -type DomainStrategy dns.DomainStrategy +type DomainStrategy C.DomainStrategy func (s DomainStrategy) String() string { - switch dns.DomainStrategy(s) { - case dns.DomainStrategyAsIS: + switch C.DomainStrategy(s) { + case C.DomainStrategyAsIS: return "" - case dns.DomainStrategyPreferIPv4: + case C.DomainStrategyPreferIPv4: return "prefer_ipv4" - case dns.DomainStrategyPreferIPv6: + case C.DomainStrategyPreferIPv6: return "prefer_ipv6" - case dns.DomainStrategyUseIPv4: + case C.DomainStrategyIPv4Only: return "ipv4_only" - case dns.DomainStrategyUseIPv6: + case C.DomainStrategyIPv6Only: return "ipv6_only" default: panic(E.New("unknown domain strategy: ", s)) @@ -66,17 +65,17 @@ func (s DomainStrategy) String() string { func (s DomainStrategy) MarshalJSON() ([]byte, error) { var value string - switch dns.DomainStrategy(s) { - case dns.DomainStrategyAsIS: + switch C.DomainStrategy(s) { + case C.DomainStrategyAsIS: value = "" // value = "as_is" - case dns.DomainStrategyPreferIPv4: + case C.DomainStrategyPreferIPv4: value = "prefer_ipv4" - case dns.DomainStrategyPreferIPv6: + case C.DomainStrategyPreferIPv6: value = "prefer_ipv6" - case dns.DomainStrategyUseIPv4: + case C.DomainStrategyIPv4Only: value = "ipv4_only" - case dns.DomainStrategyUseIPv6: + case C.DomainStrategyIPv6Only: value = "ipv6_only" default: return nil, E.New("unknown domain strategy: ", s) @@ -92,15 +91,15 @@ func (s *DomainStrategy) UnmarshalJSON(bytes []byte) error { } switch value { case "", "as_is": - *s = DomainStrategy(dns.DomainStrategyAsIS) + *s = DomainStrategy(C.DomainStrategyAsIS) case "prefer_ipv4": - *s = DomainStrategy(dns.DomainStrategyPreferIPv4) + *s = DomainStrategy(C.DomainStrategyPreferIPv4) case "prefer_ipv6": - *s = DomainStrategy(dns.DomainStrategyPreferIPv6) + *s = DomainStrategy(C.DomainStrategyPreferIPv6) case "ipv4_only": - *s = DomainStrategy(dns.DomainStrategyUseIPv4) + *s = DomainStrategy(C.DomainStrategyIPv4Only) case "ipv6_only": - *s = DomainStrategy(dns.DomainStrategyUseIPv6) + *s = DomainStrategy(C.DomainStrategyIPv6Only) default: return E.New("unknown domain strategy: ", value) } diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go index aba56336..d173ec53 100644 --- a/protocol/direct/outbound.go +++ b/protocol/direct/outbound.go @@ -12,7 +12,6 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" @@ -34,7 +33,7 @@ type Outbound struct { outbound.Adapter logger logger.ContextLogger dialer dialer.ParallelInterfaceDialer - domainStrategy dns.DomainStrategy + domainStrategy C.DomainStrategy fallbackDelay time.Duration overrideOption int overrideDestination M.Socksaddr @@ -50,7 +49,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL outbound := &Outbound{ Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), logger: logger, - domainStrategy: dns.DomainStrategy(options.DomainStrategy), + domainStrategy: C.DomainStrategy(options.DomainStrategy), fallbackDelay: time.Duration(options.FallbackDelay), dialer: outboundDialer, // loopBack: newLoopBackDetector(router), @@ -151,26 +150,26 @@ func (h *Outbound) DialParallel(ctx context.Context, network string, destination case N.NetworkUDP: h.logger.InfoContext(ctx, "outbound packet connection to ", destination) } - var domainStrategy dns.DomainStrategy - if h.domainStrategy != dns.DomainStrategyAsIS { + var domainStrategy C.DomainStrategy + if h.domainStrategy != C.DomainStrategyAsIS { domainStrategy = h.domainStrategy } else { //nolint:staticcheck - domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) + domainStrategy = C.DomainStrategy(metadata.InboundOptions.DomainStrategy) } switch domainStrategy { - case dns.DomainStrategyUseIPv4: + case C.DomainStrategyIPv4Only: destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is4) if len(destinationAddresses) == 0 { return nil, E.New("no IPv4 address available for ", destination) } - case dns.DomainStrategyUseIPv6: + case C.DomainStrategyIPv6Only: destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is6) if len(destinationAddresses) == 0 { return nil, E.New("no IPv6 address available for ", destination) } } - return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, nil, nil, nil, h.fallbackDelay) + return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == C.DomainStrategyPreferIPv6, nil, nil, nil, h.fallbackDelay) } func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy *C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { @@ -191,26 +190,26 @@ func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, dest case N.NetworkUDP: h.logger.InfoContext(ctx, "outbound packet connection to ", destination) } - var domainStrategy dns.DomainStrategy - if h.domainStrategy != dns.DomainStrategyAsIS { + var domainStrategy C.DomainStrategy + if h.domainStrategy != C.DomainStrategyAsIS { domainStrategy = h.domainStrategy } else { //nolint:staticcheck - domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) + domainStrategy = C.DomainStrategy(metadata.InboundOptions.DomainStrategy) } switch domainStrategy { - case dns.DomainStrategyUseIPv4: + case C.DomainStrategyIPv4Only: destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is4) if len(destinationAddresses) == 0 { return nil, E.New("no IPv4 address available for ", destination) } - case dns.DomainStrategyUseIPv6: + case C.DomainStrategyIPv6Only: destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is6) if len(destinationAddresses) == 0 { return nil, E.New("no IPv6 address available for ", destination) } } - return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, networkStrategy, networkType, fallbackNetworkType, fallbackDelay) + return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == C.DomainStrategyPreferIPv6, networkStrategy, networkType, fallbackNetworkType, fallbackDelay) } func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy *C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) { diff --git a/protocol/dns/handle.go b/protocol/dns/handle.go index bc58d9e2..c4ad79d9 100644 --- a/protocol/dns/handle.go +++ b/protocol/dns/handle.go @@ -7,7 +7,7 @@ import ( "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-dns" + "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -19,7 +19,7 @@ import ( mDNS "github.com/miekg/dns" ) -func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net.Conn, metadata adapter.InboundContext) error { +func HandleStreamDNSRequest(ctx context.Context, router adapter.DNSRouter, conn net.Conn, metadata adapter.InboundContext) error { var queryLength uint16 err := binary.Read(conn, binary.BigEndian, &queryLength) if err != nil { @@ -41,7 +41,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net } metadataInQuery := metadata go func() error { - response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) + response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{}) if err != nil { conn.Close() return err @@ -61,7 +61,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net return nil } -func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error { +func NewDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error { metadata.Destination = M.Socksaddr{} var reader N.PacketReader = conn var counters []N.CountFunc @@ -123,7 +123,7 @@ func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.P } metadataInQuery := metadata go func() error { - response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) + response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{}) if err != nil { cancel(err) return err @@ -148,7 +148,7 @@ func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.P return group.Run(fastClose) } -func newDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error { +func newDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error { fastClose, cancel := common.ContextWithCancelCause(ctx) timeout := canceler.New(fastClose, cancel, C.DNSTimeout) var group task.Group @@ -193,7 +193,7 @@ func newDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.P } metadataInQuery := metadata go func() error { - response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) + response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{}) if err != nil { cancel(err) return err diff --git a/protocol/dns/outbound.go b/protocol/dns/outbound.go index 5f06557b..277d7454 100644 --- a/protocol/dns/outbound.go +++ b/protocol/dns/outbound.go @@ -14,6 +14,7 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" ) func RegisterOutbound(registry *outbound.Registry) { @@ -22,14 +23,14 @@ func RegisterOutbound(registry *outbound.Registry) { type Outbound struct { outbound.Adapter - router adapter.Router + router adapter.DNSRouter logger logger.ContextLogger } func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) { return &Outbound{ Adapter: outbound.NewAdapter(C.TypeDNS, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil), - router: router, + router: service.FromContext[adapter.DNSRouter](ctx), logger: logger, }, nil } diff --git a/protocol/socks/outbound.go b/protocol/socks/outbound.go index 0632f082..323149e2 100644 --- a/protocol/socks/outbound.go +++ b/protocol/socks/outbound.go @@ -17,6 +17,7 @@ import ( N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/uot" "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/service" ) func RegisterOutbound(registry *outbound.Registry) { @@ -27,7 +28,7 @@ var _ adapter.Outbound = (*Outbound)(nil) type Outbound struct { outbound.Adapter - router adapter.Router + dnsRouter adapter.DNSRouter logger logger.ContextLogger client *socks.Client resolve bool @@ -50,11 +51,11 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions), - router: router, - logger: logger, - client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password), - resolve: version == socks.Version4, + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions), + dnsRouter: service.FromContext[adapter.DNSRouter](ctx), + logger: logger, + client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password), + resolve: version == socks.Version4, } uotOptions := common.PtrValueOrDefault(options.UDPOverTCP) if uotOptions.Enabled { @@ -83,7 +84,7 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination return nil, E.Extend(N.ErrUnknownNetwork, network) } if h.resolve && destination.IsFqdn() { - destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn) + destinationAddresses, err := h.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) if err != nil { return nil, err } @@ -101,7 +102,7 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n return h.uotClient.ListenPacket(ctx, destination) } if h.resolve && destination.IsFqdn() { - destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn) + destinationAddresses, err := h.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) if err != nil { return nil, err } diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go index 21d72bd9..300701a9 100644 --- a/protocol/wireguard/endpoint.go +++ b/protocol/wireguard/endpoint.go @@ -13,13 +13,13 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/wireguard" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" ) func RegisterEndpoint(registry *endpoint.Registry) { @@ -35,6 +35,7 @@ type Endpoint struct { endpoint.Adapter ctx context.Context router adapter.Router + dnsRouter adapter.DNSRouter logger logger.ContextLogger localAddresses []netip.Prefix endpoint *wireguard.Endpoint @@ -45,6 +46,7 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), ctx: ctx, router: router, + dnsRouter: service.FromContext[adapter.DNSRouter](ctx), logger: logger, localAddresses: options.Address, } @@ -79,7 +81,9 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL PrivateKey: options.PrivateKey, ListenPort: options.ListenPort, ResolvePeer: func(domain string) (netip.Addr, error) { - endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy)) + endpointAddresses, lookupErr := ep.dnsRouter.Lookup(ctx, domain, adapter.DNSQueryOptions{ + Strategy: C.DomainStrategy(options.DomainStrategy), + }) if lookupErr != nil { return netip.Addr{}, lookupErr } @@ -185,7 +189,7 @@ func (w *Endpoint) DialContext(ctx context.Context, network string, destination w.logger.InfoContext(ctx, "outbound packet connection to ", destination) } if destination.IsFqdn() { - destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) + destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) if err != nil { return nil, err } @@ -199,7 +203,7 @@ func (w *Endpoint) DialContext(ctx context.Context, network string, destination func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { w.logger.InfoContext(ctx, "outbound packet connection to ", destination) if destination.IsFqdn() { - destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) + destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) if err != nil { return nil, err } diff --git a/protocol/wireguard/outbound.go b/protocol/wireguard/outbound.go index 3e299705..4aa49a8d 100644 --- a/protocol/wireguard/outbound.go +++ b/protocol/wireguard/outbound.go @@ -13,12 +13,12 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/wireguard" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" ) func RegisterOutbound(registry *outbound.Registry) { @@ -33,7 +33,7 @@ var ( type Outbound struct { outbound.Adapter ctx context.Context - router adapter.Router + dnsRouter adapter.DNSRouter logger logger.ContextLogger localAddresses []netip.Prefix endpoint *wireguard.Endpoint @@ -47,7 +47,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL outbound := &Outbound{ Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), ctx: ctx, - router: router, + dnsRouter: service.FromContext[adapter.DNSRouter](ctx), logger: logger, localAddresses: options.LocalAddress, } @@ -94,7 +94,9 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL Address: options.LocalAddress, PrivateKey: options.PrivateKey, ResolvePeer: func(domain string) (netip.Addr, error) { - endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy)) + endpointAddresses, lookupErr := outbound.dnsRouter.Lookup(ctx, domain, adapter.DNSQueryOptions{ + Strategy: C.DomainStrategy(options.DomainStrategy), + }) if lookupErr != nil { return netip.Addr{}, lookupErr } @@ -137,7 +139,7 @@ func (o *Outbound) DialContext(ctx context.Context, network string, destination o.logger.InfoContext(ctx, "outbound packet connection to ", destination) } if destination.IsFqdn() { - destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn) + destinationAddresses, err := o.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) if err != nil { return nil, err } @@ -151,7 +153,7 @@ func (o *Outbound) DialContext(ctx context.Context, network string, destination func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { o.logger.InfoContext(ctx, "outbound packet connection to ", destination) if destination.IsFqdn() { - destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn) + destinationAddresses, err := o.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) if err != nil { return nil, err } diff --git a/release/config/config.json b/release/config/config.json index c518d18b..bdc78d40 100644 --- a/release/config/config.json +++ b/release/config/config.json @@ -14,10 +14,13 @@ "type": "shadowsocks", "listen": "::", "listen_port": 8080, - "sniff": true, "network": "tcp", "method": "2022-blake3-aes-128-gcm", - "password": "8JCsPssfgS8tiRwiMlhARg==" + "password": "Gn1JUS14bLUHgv1cWDDp4A==", + "multiplex": { + "enabled": true, + "padding": true + } } ], "outbounds": [ @@ -32,7 +35,7 @@ "route": { "rules": [ { - "protocol": "dns", + "port": 53, "outbound": "dns-out" } ] diff --git a/route/dns.go b/route/dns.go index 2c6efefe..7d2b5778 100644 --- a/route/dns.go +++ b/route/dns.go @@ -8,11 +8,12 @@ import ( "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" dnsOutbound "github.com/sagernet/sing-box/protocol/dns" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/udpnat2" @@ -24,7 +25,7 @@ func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata ad metadata.Destination = M.Socksaddr{} for { conn.SetReadDeadline(time.Now().Add(C.DNSTimeout)) - err := dnsOutbound.HandleStreamDNSRequest(ctx, r, conn, metadata) + err := dnsOutbound.HandleStreamDNSRequest(ctx, r.dns, conn, metadata) if err != nil { return err } @@ -38,37 +39,38 @@ func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetB buffer := packet.Buffer destination := packet.Destination N.PutPacketBuffer(packet) - go ExchangeDNSPacket(ctx, r, natConn, buffer, metadata, destination) + go ExchangeDNSPacket(ctx, r.dns, r.logger, natConn, buffer, metadata, destination) } natConn.SetHandler(&dnsHijacker{ - router: r, + router: r.dns, + logger: r.logger, conn: conn, ctx: ctx, metadata: metadata, }) return } - err := dnsOutbound.NewDNSPacketConnection(ctx, r, conn, packetBuffers, metadata) + err := dnsOutbound.NewDNSPacketConnection(ctx, r.dns, conn, packetBuffers, metadata) if err != nil && !E.IsClosedOrCanceled(err) { - r.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection")) + r.logger.ErrorContext(ctx, E.Cause(err, "process DNS packet connection")) } } -func ExchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) { +func ExchangeDNSPacket(ctx context.Context, router adapter.DNSRouter, logger logger.ContextLogger, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) { err := exchangeDNSPacket(ctx, router, conn, buffer, metadata, destination) if err != nil && !errors.Is(err, tun.ErrDrop) && !E.IsClosedOrCanceled(err) { - router.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection")) + logger.ErrorContext(ctx, E.Cause(err, "process DNS packet connection")) } } -func exchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) error { +func exchangeDNSPacket(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) error { var message mDNS.Msg err := message.Unpack(buffer.Bytes()) buffer.Release() if err != nil { return E.Cause(err, "unpack request") } - response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message) + response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message, adapter.DNSQueryOptions{}) if err != nil { return err } @@ -81,12 +83,13 @@ func exchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, b } type dnsHijacker struct { - router *Router + router adapter.DNSRouter + logger logger.ContextLogger conn N.PacketConn ctx context.Context metadata adapter.InboundContext } func (h *dnsHijacker) NewPacketEx(buffer *buf.Buffer, destination M.Socksaddr) { - go ExchangeDNSPacket(h.ctx, h.router, h.conn, buffer, h.metadata, destination) + go ExchangeDNSPacket(h.ctx, h.router, h.logger, h.conn, buffer, h.metadata, destination) } diff --git a/route/geo_resources.go b/route/geo_resources.go deleted file mode 100644 index 8a8a3ef5..00000000 --- a/route/geo_resources.go +++ /dev/null @@ -1,246 +0,0 @@ -package route - -import ( - "context" - "io" - "net" - "net/http" - "os" - "path/filepath" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/geoip" - "github.com/sagernet/sing-box/common/geosite" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/experimental/deprecated" - R "github.com/sagernet/sing-box/route/rule" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/service/filemanager" -) - -func (r *Router) GeoIPReader() *geoip.Reader { - return r.geoIPReader -} - -func (r *Router) LoadGeosite(code string) (adapter.Rule, error) { - rule, cached := r.geositeCache[code] - if cached { - return rule, nil - } - items, err := r.geositeReader.Read(code) - if err != nil { - return nil, err - } - rule, err = R.NewDefaultRule(r.ctx, nil, geosite.Compile(items)) - if err != nil { - return nil, err - } - r.geositeCache[code] = rule - return rule, nil -} - -func (r *Router) prepareGeoIPDatabase() error { - deprecated.Report(r.ctx, deprecated.OptionGEOIP) - var geoPath string - if r.geoIPOptions.Path != "" { - geoPath = r.geoIPOptions.Path - } else { - geoPath = "geoip.db" - if foundPath, loaded := C.FindPath(geoPath); loaded { - geoPath = foundPath - } - } - if !rw.IsFile(geoPath) { - geoPath = filemanager.BasePath(r.ctx, geoPath) - } - if stat, err := os.Stat(geoPath); err == nil { - if stat.IsDir() { - return E.New("geoip path is a directory: ", geoPath) - } - if stat.Size() == 0 { - os.Remove(geoPath) - } - } - if !rw.IsFile(geoPath) { - r.logger.Warn("geoip database not exists: ", geoPath) - var err error - for attempts := 0; attempts < 3; attempts++ { - err = r.downloadGeoIPDatabase(geoPath) - if err == nil { - break - } - r.logger.Error("download geoip database: ", err) - os.Remove(geoPath) - // time.Sleep(10 * time.Second) - } - if err != nil { - return err - } - } - geoReader, codes, err := geoip.Open(geoPath) - if err != nil { - return E.Cause(err, "open geoip database") - } - r.logger.Info("loaded geoip database: ", len(codes), " codes") - r.geoIPReader = geoReader - return nil -} - -func (r *Router) prepareGeositeDatabase() error { - deprecated.Report(r.ctx, deprecated.OptionGEOSITE) - var geoPath string - if r.geositeOptions.Path != "" { - geoPath = r.geositeOptions.Path - } else { - geoPath = "geosite.db" - if foundPath, loaded := C.FindPath(geoPath); loaded { - geoPath = foundPath - } - } - if !rw.IsFile(geoPath) { - geoPath = filemanager.BasePath(r.ctx, geoPath) - } - if stat, err := os.Stat(geoPath); err == nil { - if stat.IsDir() { - return E.New("geoip path is a directory: ", geoPath) - } - if stat.Size() == 0 { - os.Remove(geoPath) - } - } - if !rw.IsFile(geoPath) { - r.logger.Warn("geosite database not exists: ", geoPath) - var err error - for attempts := 0; attempts < 3; attempts++ { - err = r.downloadGeositeDatabase(geoPath) - if err == nil { - break - } - r.logger.Error("download geosite database: ", err) - os.Remove(geoPath) - } - if err != nil { - return err - } - } - geoReader, codes, err := geosite.Open(geoPath) - if err == nil { - r.logger.Info("loaded geosite database: ", len(codes), " codes") - r.geositeReader = geoReader - } else { - return E.Cause(err, "open geosite database") - } - return nil -} - -func (r *Router) downloadGeoIPDatabase(savePath string) error { - var downloadURL string - if r.geoIPOptions.DownloadURL != "" { - downloadURL = r.geoIPOptions.DownloadURL - } else { - downloadURL = "https://github.com/SagerNet/sing-geoip/releases/latest/download/geoip.db" - } - r.logger.Info("downloading geoip database") - var detour adapter.Outbound - if r.geoIPOptions.DownloadDetour != "" { - outbound, loaded := r.outbound.Outbound(r.geoIPOptions.DownloadDetour) - if !loaded { - return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour) - } - detour = outbound - } else { - detour = r.outbound.Default() - } - - if parentDir := filepath.Dir(savePath); parentDir != "" { - filemanager.MkdirAll(r.ctx, parentDir, 0o755) - } - - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: C.TCPTimeout, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - defer httpClient.CloseIdleConnections() - request, err := http.NewRequest("GET", downloadURL, nil) - if err != nil { - return err - } - response, err := httpClient.Do(request.WithContext(r.ctx)) - if err != nil { - return err - } - defer response.Body.Close() - - saveFile, err := filemanager.Create(r.ctx, savePath) - if err != nil { - return E.Cause(err, "open output file: ", downloadURL) - } - _, err = io.Copy(saveFile, response.Body) - saveFile.Close() - if err != nil { - filemanager.Remove(r.ctx, savePath) - } - return err -} - -func (r *Router) downloadGeositeDatabase(savePath string) error { - var downloadURL string - if r.geositeOptions.DownloadURL != "" { - downloadURL = r.geositeOptions.DownloadURL - } else { - downloadURL = "https://github.com/SagerNet/sing-geosite/releases/latest/download/geosite.db" - } - r.logger.Info("downloading geosite database") - var detour adapter.Outbound - if r.geositeOptions.DownloadDetour != "" { - outbound, loaded := r.outbound.Outbound(r.geositeOptions.DownloadDetour) - if !loaded { - return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour) - } - detour = outbound - } else { - detour = r.outbound.Default() - } - - if parentDir := filepath.Dir(savePath); parentDir != "" { - filemanager.MkdirAll(r.ctx, parentDir, 0o755) - } - - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: C.TCPTimeout, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - defer httpClient.CloseIdleConnections() - request, err := http.NewRequest("GET", downloadURL, nil) - if err != nil { - return err - } - response, err := httpClient.Do(request.WithContext(r.ctx)) - if err != nil { - return err - } - defer response.Body.Close() - - saveFile, err := filemanager.Create(r.ctx, savePath) - if err != nil { - return E.Cause(err, "open output file: ", downloadURL) - } - _, err = io.Copy(saveFile, response.Body) - saveFile.Close() - if err != nil { - filemanager.Remove(r.ctx, savePath) - } - return err -} diff --git a/route/route.go b/route/route.go index 834d3425..ac81420c 100644 --- a/route/route.go +++ b/route/route.go @@ -17,7 +17,6 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/route/rule" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing-mux" "github.com/sagernet/sing-vmess" "github.com/sagernet/sing/common" @@ -325,22 +324,23 @@ func (r *Router) matchRule( metadata.ProcessInfo = processInfo } } - if r.fakeIPStore != nil && r.fakeIPStore.Contains(metadata.Destination.Addr) { - domain, loaded := r.fakeIPStore.Lookup(metadata.Destination.Addr) + if metadata.Destination.Addr.IsValid() && r.dnsTransport.FakeIP() != nil && r.dnsTransport.FakeIP().Store().Contains(metadata.Destination.Addr) { + domain, loaded := r.dnsTransport.FakeIP().Store().Lookup(metadata.Destination.Addr) if !loaded { - fatalErr = E.New("missing fakeip record, try to configure experimental.cache_file") + fatalErr = E.New("missing fakeip record, try enable `experimental.cache_file`") return } - metadata.OriginDestination = metadata.Destination - metadata.Destination = M.Socksaddr{ - Fqdn: domain, - Port: metadata.Destination.Port, + if domain != "" { + metadata.OriginDestination = metadata.Destination + metadata.Destination = M.Socksaddr{ + Fqdn: domain, + Port: metadata.Destination.Port, + } + metadata.FakeIP = true + r.logger.DebugContext(ctx, "found fakeip domain: ", domain) } - metadata.FakeIP = true - r.logger.DebugContext(ctx, "found fakeip domain: ", domain) - } - if r.dnsReverseMapping != nil && metadata.Domain == "" { - domain, loaded := r.dnsReverseMapping.Query(metadata.Destination.Addr) + } else if metadata.Domain == "" { + domain, loaded := r.dns.LookupReverseMapping(metadata.Destination.Addr) if loaded { metadata.Domain = domain r.logger.DebugContext(ctx, "found reserve mapped domain: ", metadata.Domain) @@ -369,9 +369,9 @@ func (r *Router) matchRule( packetBuffers = newPackerBuffers } } - if dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) != dns.DomainStrategyAsIS { + if C.DomainStrategy(metadata.InboundOptions.DomainStrategy) != C.DomainStrategyAsIS { fatalErr = r.actionResolve(ctx, metadata, &rule.RuleActionResolve{ - Strategy: dns.DomainStrategy(metadata.InboundOptions.DomainStrategy), + Strategy: C.DomainStrategy(metadata.InboundOptions.DomainStrategy), }) if fatalErr != nil { return @@ -649,13 +649,23 @@ func (r *Router) actionSniff( func (r *Router) actionResolve(ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionResolve) error { if metadata.Destination.IsFqdn() { - metadata.DNSServer = action.Server - addresses, err := r.Lookup(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn, action.Strategy) + var transport adapter.DNSTransport + if action.Server != "" { + var loaded bool + transport, loaded = r.dnsTransport.Transport(action.Server) + if !loaded { + return E.New("DNS server not found: ", action.Server) + } + } + addresses, err := r.dns.Lookup(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn, adapter.DNSQueryOptions{ + Transport: transport, + Strategy: action.Strategy, + }) if err != nil { return err } metadata.DestinationAddresses = addresses - r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") + r.logger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") if metadata.Destination.IsIPv4() { metadata.IPVersion = 4 } else if metadata.Destination.IsIPv6() { diff --git a/route/route_dns.go b/route/route_dns.go deleted file mode 100644 index 3d7dc64f..00000000 --- a/route/route_dns.go +++ /dev/null @@ -1,348 +0,0 @@ -package route - -import ( - "context" - "errors" - "net/netip" - "strings" - "time" - - "github.com/sagernet/sing-box/adapter" - C "github.com/sagernet/sing-box/constant" - R "github.com/sagernet/sing-box/route/rule" - "github.com/sagernet/sing-dns" - "github.com/sagernet/sing-tun" - "github.com/sagernet/sing/common/cache" - E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - M "github.com/sagernet/sing/common/metadata" - - mDNS "github.com/miekg/dns" -) - -type DNSReverseMapping struct { - cache *cache.LruCache[netip.Addr, string] -} - -func NewDNSReverseMapping() *DNSReverseMapping { - return &DNSReverseMapping{ - cache: cache.New[netip.Addr, string](), - } -} - -func (m *DNSReverseMapping) Save(address netip.Addr, domain string, ttl int) { - m.cache.StoreWithExpire(address, domain, time.Now().Add(time.Duration(ttl)*time.Second)) -} - -func (m *DNSReverseMapping) Query(address netip.Addr) (string, bool) { - domain, loaded := m.cache.Load(address) - return domain, loaded -} - -func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool) (dns.Transport, dns.QueryOptions, adapter.DNSRule, int) { - metadata := adapter.ContextFrom(ctx) - if metadata == nil { - panic("no context") - } - var options dns.QueryOptions - var currentRuleIndex int - if ruleIndex != -1 { - currentRuleIndex = ruleIndex + 1 - } - for ; currentRuleIndex < len(r.dnsRules); currentRuleIndex++ { - currentRule := r.dnsRules[currentRuleIndex] - if currentRule.WithAddressLimit() && !isAddressQuery { - continue - } - metadata.ResetRuleCache() - if currentRule.Match(metadata) { - ruleDescription := currentRule.String() - if ruleDescription != "" { - r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action()) - } else { - r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action()) - } - switch action := currentRule.Action().(type) { - case *R.RuleActionDNSRoute: - transport, loaded := r.transportMap[action.Server] - if !loaded { - r.dnsLogger.ErrorContext(ctx, "transport not found: ", action.Server) - continue - } - _, isFakeIP := transport.(adapter.FakeIPTransport) - if isFakeIP && !allowFakeIP { - continue - } - if isFakeIP || action.DisableCache { - options.DisableCache = true - } - if action.RewriteTTL != nil { - options.RewriteTTL = action.RewriteTTL - } - if action.ClientSubnet.IsValid() { - options.ClientSubnet = action.ClientSubnet - } - if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded { - options.Strategy = domainStrategy - } else { - options.Strategy = r.defaultDomainStrategy - } - r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action()) - return transport, options, currentRule, currentRuleIndex - case *R.RuleActionDNSRouteOptions: - if action.DisableCache { - options.DisableCache = true - } - if action.RewriteTTL != nil { - options.RewriteTTL = action.RewriteTTL - } - if action.ClientSubnet.IsValid() { - options.ClientSubnet = action.ClientSubnet - } - r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action()) - case *R.RuleActionReject: - r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action()) - return nil, options, currentRule, currentRuleIndex - } - } - } - if domainStrategy, dsLoaded := r.transportDomainStrategy[r.defaultTransport]; dsLoaded { - options.Strategy = domainStrategy - } else { - options.Strategy = r.defaultDomainStrategy - } - return r.defaultTransport, options, nil, -1 -} - -func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - if len(message.Question) != 1 { - r.dnsLogger.WarnContext(ctx, "bad question size: ", len(message.Question)) - responseMessage := mDNS.Msg{ - MsgHdr: mDNS.MsgHdr{ - Id: message.Id, - Response: true, - Rcode: mDNS.RcodeFormatError, - }, - Question: message.Question, - } - return &responseMessage, nil - } - var ( - response *mDNS.Msg - cached bool - transport dns.Transport - err error - ) - response, cached = r.dnsClient.ExchangeCache(ctx, message) - if !cached { - var metadata *adapter.InboundContext - ctx, metadata = adapter.ExtendContext(ctx) - metadata.Destination = M.Socksaddr{} - metadata.QueryType = message.Question[0].Qtype - switch metadata.QueryType { - case mDNS.TypeA: - metadata.IPVersion = 4 - case mDNS.TypeAAAA: - metadata.IPVersion = 6 - } - metadata.Domain = fqdnToDomain(message.Question[0].Name) - var ( - options dns.QueryOptions - rule adapter.DNSRule - ruleIndex int - ) - ruleIndex = -1 - for { - dnsCtx := adapter.OverrideContext(ctx) - var addressLimit bool - transport, options, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message)) - if rule != nil { - switch action := rule.Action().(type) { - case *R.RuleActionReject: - switch action.Method { - case C.RuleActionRejectMethodDefault: - return dns.FixedResponse(message.Id, message.Question[0], nil, 0), nil - case C.RuleActionRejectMethodDrop: - return nil, tun.ErrDrop - } - } - } - r.dnsLogger.DebugContext(ctx, "exchange ", formatQuestion(message.Question[0].String()), " via ", transport.Name()) - if rule != nil && rule.WithAddressLimit() { - addressLimit = true - response, err = r.dnsClient.ExchangeWithResponseCheck(dnsCtx, transport, message, options, func(responseAddrs []netip.Addr) bool { - metadata.DestinationAddresses = responseAddrs - return rule.MatchAddressLimit(metadata) - }) - } else { - addressLimit = false - response, err = r.dnsClient.Exchange(dnsCtx, transport, message, options) - } - var rejected bool - if err != nil { - if errors.Is(err, dns.ErrResponseRejectedCached) { - rejected = true - r.dnsLogger.DebugContext(ctx, E.Cause(err, "response rejected for ", formatQuestion(message.Question[0].String())), " (cached)") - } else if errors.Is(err, dns.ErrResponseRejected) { - rejected = true - r.dnsLogger.DebugContext(ctx, E.Cause(err, "response rejected for ", formatQuestion(message.Question[0].String()))) - } else if len(message.Question) > 0 { - r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", formatQuestion(message.Question[0].String()))) - } else { - r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for ")) - } - } - if addressLimit && rejected { - continue - } - break - } - } - if err != nil { - return nil, err - } - if r.dnsReverseMapping != nil && response != nil && len(response.Answer) > 0 { - if _, isFakeIP := transport.(adapter.FakeIPTransport); !isFakeIP { - for _, answer := range response.Answer { - switch record := answer.(type) { - case *mDNS.A: - r.dnsReverseMapping.Save(M.AddrFromIP(record.A), fqdnToDomain(record.Hdr.Name), int(record.Hdr.Ttl)) - case *mDNS.AAAA: - r.dnsReverseMapping.Save(M.AddrFromIP(record.AAAA), fqdnToDomain(record.Hdr.Name), int(record.Hdr.Ttl)) - } - } - } - } - return response, nil -} - -func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) { - var ( - responseAddrs []netip.Addr - cached bool - err error - ) - printResult := func() { - if err != nil { - if errors.Is(err, dns.ErrResponseRejectedCached) { - r.dnsLogger.DebugContext(ctx, "response rejected for ", domain, " (cached)") - } else if errors.Is(err, dns.ErrResponseRejected) { - r.dnsLogger.DebugContext(ctx, "response rejected for ", domain) - } else { - r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain)) - } - } else if len(responseAddrs) == 0 { - r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result") - err = dns.RCodeNameError - } - } - responseAddrs, cached = r.dnsClient.LookupCache(ctx, domain, strategy) - if cached { - if len(responseAddrs) == 0 { - return nil, dns.RCodeNameError - } - return responseAddrs, nil - } - r.dnsLogger.DebugContext(ctx, "lookup domain ", domain) - ctx, metadata := adapter.ExtendContext(ctx) - metadata.Destination = M.Socksaddr{} - metadata.Domain = domain - if metadata.DNSServer != "" { - transport, loaded := r.transportMap[metadata.DNSServer] - if !loaded { - return nil, E.New("transport not found: ", metadata.DNSServer) - } - if strategy == dns.DomainStrategyAsIS { - if transportDomainStrategy, loaded := r.transportDomainStrategy[transport]; loaded { - strategy = transportDomainStrategy - } else { - strategy = r.defaultDomainStrategy - } - } - responseAddrs, err = r.dnsClient.Lookup(ctx, transport, domain, dns.QueryOptions{Strategy: strategy}) - } else { - var ( - transport dns.Transport - options dns.QueryOptions - rule adapter.DNSRule - ruleIndex int - ) - ruleIndex = -1 - for { - dnsCtx := adapter.OverrideContext(ctx) - var addressLimit bool - transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true) - if strategy != dns.DomainStrategyAsIS { - options.Strategy = strategy - } - if rule != nil { - switch action := rule.Action().(type) { - case *R.RuleActionReject: - switch action.Method { - case C.RuleActionRejectMethodDefault: - return nil, nil - case C.RuleActionRejectMethodDrop: - return nil, tun.ErrDrop - } - } - } - if rule != nil && rule.WithAddressLimit() { - addressLimit = true - responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool { - metadata.DestinationAddresses = responseAddrs - return rule.MatchAddressLimit(metadata) - }) - } else { - addressLimit = false - responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options) - } - if !addressLimit || err == nil { - break - } - printResult() - } - } - printResult() - if len(responseAddrs) > 0 { - r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " ")) - } - return responseAddrs, err -} - -func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) { - return r.Lookup(ctx, domain, dns.DomainStrategyAsIS) -} - -func (r *Router) ClearDNSCache() { - r.dnsClient.ClearCache() - if r.platformInterface != nil { - r.platformInterface.ClearDNSCache() - } -} - -func isAddressQuery(message *mDNS.Msg) bool { - for _, question := range message.Question { - if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA || question.Qtype == mDNS.TypeHTTPS { - return true - } - } - return false -} - -func fqdnToDomain(fqdn string) string { - if mDNS.IsFqdn(fqdn) { - return fqdn[:len(fqdn)-1] - } - return fqdn -} - -func formatQuestion(string string) string { - if strings.HasPrefix(string, ";") { - string = string[1:] - } - string = strings.ReplaceAll(string, "\t", " ") - for strings.Contains(string, " ") { - string = strings.ReplaceAll(string, " ", " ") - } - return string -} diff --git a/route/router.go b/route/router.go index 68f5dc35..63fc7b10 100644 --- a/route/router.go +++ b/route/router.go @@ -2,17 +2,10 @@ package route import ( "context" - "net/netip" - "net/url" "os" "runtime" - "strings" - "time" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/dialer" - "github.com/sagernet/sing-box/common/geoip" - "github.com/sagernet/sing-box/common/geosite" "github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/common/taskmonitor" C "github.com/sagernet/sing-box/constant" @@ -20,13 +13,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" R "github.com/sagernet/sing-box/route/rule" - "github.com/sagernet/sing-box/transport/fakeip" - "github.com/sagernet/sing-dns" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/task" "github.com/sagernet/sing/service" "github.com/sagernet/sing/service/pause" @@ -35,334 +22,71 @@ import ( var _ adapter.Router = (*Router)(nil) type Router struct { - ctx context.Context - logger log.ContextLogger - dnsLogger log.ContextLogger - inbound adapter.InboundManager - outbound adapter.OutboundManager - connection adapter.ConnectionManager - network adapter.NetworkManager - rules []adapter.Rule - needGeoIPDatabase bool - needGeositeDatabase bool - geoIPOptions option.GeoIPOptions - geositeOptions option.GeositeOptions - geoIPReader *geoip.Reader - geositeReader *geosite.Reader - geositeCache map[string]adapter.Rule - needFindProcess bool - dnsClient *dns.Client - defaultDomainStrategy dns.DomainStrategy - dnsRules []adapter.DNSRule - ruleSets []adapter.RuleSet - ruleSetMap map[string]adapter.RuleSet - defaultTransport dns.Transport - transports []dns.Transport - transportMap map[string]dns.Transport - transportDomainStrategy map[dns.Transport]dns.DomainStrategy - dnsReverseMapping *DNSReverseMapping - fakeIPStore adapter.FakeIPStore - processSearcher process.Searcher - pauseManager pause.Manager - tracker adapter.ConnectionTracker - platformInterface platform.Interface - needWIFIState bool - started bool + ctx context.Context + logger log.ContextLogger + inbound adapter.InboundManager + outbound adapter.OutboundManager + dns adapter.DNSRouter + dnsTransport adapter.DNSTransportManager + connection adapter.ConnectionManager + network adapter.NetworkManager + rules []adapter.Rule + needFindProcess bool + ruleSets []adapter.RuleSet + ruleSetMap map[string]adapter.RuleSet + processSearcher process.Searcher + pauseManager pause.Manager + tracker adapter.ConnectionTracker + platformInterface platform.Interface + needWIFIState bool + started bool } -func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) (*Router, error) { - router := &Router{ - ctx: ctx, - logger: logFactory.NewLogger("router"), - dnsLogger: logFactory.NewLogger("dns"), - inbound: service.FromContext[adapter.InboundManager](ctx), - outbound: service.FromContext[adapter.OutboundManager](ctx), - connection: service.FromContext[adapter.ConnectionManager](ctx), - network: service.FromContext[adapter.NetworkManager](ctx), - rules: make([]adapter.Rule, 0, len(options.Rules)), - dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), - ruleSetMap: make(map[string]adapter.RuleSet), - needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule), - needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule), - geoIPOptions: common.PtrValueOrDefault(options.GeoIP), - geositeOptions: common.PtrValueOrDefault(options.Geosite), - geositeCache: make(map[string]adapter.Rule), - needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess, - defaultDomainStrategy: dns.DomainStrategy(dnsOptions.Strategy), - pauseManager: service.FromContext[pause.Manager](ctx), - platformInterface: service.FromContext[platform.Interface](ctx), - needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule), - } - service.MustRegister[adapter.Router](ctx, router) - router.dnsClient = dns.NewClient(dns.ClientOptions{ - DisableCache: dnsOptions.DNSClientOptions.DisableCache, - DisableExpire: dnsOptions.DNSClientOptions.DisableExpire, - IndependentCache: dnsOptions.DNSClientOptions.IndependentCache, - CacheCapacity: dnsOptions.DNSClientOptions.CacheCapacity, - RDRC: func() dns.RDRCStore { - cacheFile := service.FromContext[adapter.CacheFile](ctx) - if cacheFile == nil { - return nil - } - if !cacheFile.StoreRDRC() { - return nil - } - return cacheFile - }, - Logger: router.dnsLogger, - }) - for i, ruleOptions := range options.Rules { - routeRule, err := R.NewRule(ctx, router.logger, ruleOptions, true) - if err != nil { - return nil, E.Cause(err, "parse rule[", i, "]") - } - router.rules = append(router.rules, routeRule) - } - for i, dnsRuleOptions := range dnsOptions.Rules { - dnsRule, err := R.NewDNSRule(ctx, router.logger, dnsRuleOptions, true) - if err != nil { - return nil, E.Cause(err, "parse dns rule[", i, "]") - } - router.dnsRules = append(router.dnsRules, dnsRule) - } - for i, ruleSetOptions := range options.RuleSet { - if _, exists := router.ruleSetMap[ruleSetOptions.Tag]; exists { - return nil, E.New("duplicate rule-set tag: ", ruleSetOptions.Tag) - } - ruleSet, err := R.NewRuleSet(ctx, router.logger, ruleSetOptions) - if err != nil { - return nil, E.Cause(err, "parse rule-set[", i, "]") - } - router.ruleSets = append(router.ruleSets, ruleSet) - router.ruleSetMap[ruleSetOptions.Tag] = ruleSet +func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) *Router { + return &Router{ + ctx: ctx, + logger: logFactory.NewLogger("router"), + inbound: service.FromContext[adapter.InboundManager](ctx), + outbound: service.FromContext[adapter.OutboundManager](ctx), + dns: service.FromContext[adapter.DNSRouter](ctx), + dnsTransport: service.FromContext[adapter.DNSTransportManager](ctx), + connection: service.FromContext[adapter.ConnectionManager](ctx), + network: service.FromContext[adapter.NetworkManager](ctx), + rules: make([]adapter.Rule, 0, len(options.Rules)), + ruleSetMap: make(map[string]adapter.RuleSet), + needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess, + pauseManager: service.FromContext[pause.Manager](ctx), + platformInterface: service.FromContext[platform.Interface](ctx), + needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule), } +} - transports := make([]dns.Transport, len(dnsOptions.Servers)) - dummyTransportMap := make(map[string]dns.Transport) - transportMap := make(map[string]dns.Transport) - transportTags := make([]string, len(dnsOptions.Servers)) - transportTagMap := make(map[string]bool) - transportDomainStrategy := make(map[dns.Transport]dns.DomainStrategy) - for i, server := range dnsOptions.Servers { - var tag string - if server.Tag != "" { - tag = server.Tag - } else { - tag = F.ToString(i) +func (r *Router) Initialize(rules []option.Rule, ruleSets []option.RuleSet) error { + for i, options := range rules { + rule, err := R.NewRule(r.ctx, r.logger, options, false) + if err != nil { + return E.Cause(err, "parse rule[", i, "]") } - if transportTagMap[tag] { - return nil, E.New("duplicate dns server tag: ", tag) - } - transportTags[i] = tag - transportTagMap[tag] = true + r.rules = append(r.rules, rule) } - outboundManager := service.FromContext[adapter.OutboundManager](ctx) - for { - lastLen := len(dummyTransportMap) - for i, server := range dnsOptions.Servers { - tag := transportTags[i] - if _, exists := dummyTransportMap[tag]; exists { - continue - } - var detour N.Dialer - if server.Detour == "" { - detour = dialer.NewDefaultOutbound(outboundManager) - } else { - detour = dialer.NewDetour(outboundManager, server.Detour) - } - var serverProtocol string - switch server.Address { - case "local": - serverProtocol = "local" - default: - serverURL, _ := url.Parse(server.Address) - var serverAddress string - if serverURL != nil { - if serverURL.Scheme == "" { - serverProtocol = "udp" - } else { - serverProtocol = serverURL.Scheme - } - serverAddress = serverURL.Hostname() - } - if serverAddress == "" { - serverAddress = server.Address - } - notIpAddress := !M.ParseSocksaddr(serverAddress).Addr.IsValid() - if server.AddressResolver != "" { - if !transportTagMap[server.AddressResolver] { - return nil, E.New("parse dns server[", tag, "]: address resolver not found: ", server.AddressResolver) - } - if upstream, exists := dummyTransportMap[server.AddressResolver]; exists { - detour = dns.NewDialerWrapper(detour, router.dnsClient, upstream, dns.DomainStrategy(server.AddressStrategy), time.Duration(server.AddressFallbackDelay)) - } else { - continue - } - } else if notIpAddress && strings.Contains(server.Address, ".") { - return nil, E.New("parse dns server[", tag, "]: missing address_resolver") - } - } - var clientSubnet netip.Prefix - if server.ClientSubnet != nil { - clientSubnet = netip.Prefix(common.PtrValueOrDefault(server.ClientSubnet)) - } else if dnsOptions.ClientSubnet != nil { - clientSubnet = netip.Prefix(common.PtrValueOrDefault(dnsOptions.ClientSubnet)) - } - if serverProtocol == "" { - serverProtocol = "transport" - } - transport, err := dns.CreateTransport(dns.TransportOptions{ - Context: ctx, - Logger: logFactory.NewLogger(F.ToString("dns/", serverProtocol, "[", tag, "]")), - Name: tag, - Dialer: detour, - Address: server.Address, - ClientSubnet: clientSubnet, - }) - if err != nil { - return nil, E.Cause(err, "parse dns server[", tag, "]") - } - transports[i] = transport - dummyTransportMap[tag] = transport - if server.Tag != "" { - transportMap[server.Tag] = transport - } - strategy := dns.DomainStrategy(server.Strategy) - if strategy != dns.DomainStrategyAsIS { - transportDomainStrategy[transport] = strategy - } + for i, options := range ruleSets { + if _, exists := r.ruleSetMap[options.Tag]; exists { + return E.New("duplicate rule-set tag: ", options.Tag) } - if len(transports) == len(dummyTransportMap) { - break + ruleSet, err := R.NewRuleSet(r.ctx, r.logger, options) + if err != nil { + return E.Cause(err, "parse rule-set[", i, "]") } - if lastLen != len(dummyTransportMap) { - continue - } - unresolvedTags := common.MapIndexed(common.FilterIndexed(dnsOptions.Servers, func(index int, server option.DNSServerOptions) bool { - _, exists := dummyTransportMap[transportTags[index]] - return !exists - }), func(index int, server option.DNSServerOptions) string { - return transportTags[index] - }) - if len(unresolvedTags) == 0 { - panic(F.ToString("unexpected unresolved dns servers: ", len(transports), " ", len(dummyTransportMap), " ", len(transportMap))) - } - return nil, E.New("found circular reference in dns servers: ", strings.Join(unresolvedTags, " ")) + r.ruleSets = append(r.ruleSets, ruleSet) + r.ruleSetMap[options.Tag] = ruleSet } - var defaultTransport dns.Transport - if dnsOptions.Final != "" { - defaultTransport = dummyTransportMap[dnsOptions.Final] - if defaultTransport == nil { - return nil, E.New("default dns server not found: ", dnsOptions.Final) - } - } - if defaultTransport == nil { - if len(transports) == 0 { - transports = append(transports, common.Must1(dns.CreateTransport(dns.TransportOptions{ - Context: ctx, - Name: "local", - Address: "local", - Dialer: common.Must1(dialer.NewDefault(ctx, option.DialerOptions{})), - }))) - } - defaultTransport = transports[0] - } - if _, isFakeIP := defaultTransport.(adapter.FakeIPTransport); isFakeIP { - return nil, E.New("default DNS server cannot be fakeip") - } - router.defaultTransport = defaultTransport - router.transports = transports - router.transportMap = transportMap - router.transportDomainStrategy = transportDomainStrategy - - if dnsOptions.ReverseMapping { - router.dnsReverseMapping = NewDNSReverseMapping() - } - - if fakeIPOptions := dnsOptions.FakeIP; fakeIPOptions != nil && dnsOptions.FakeIP.Enabled { - var inet4Range netip.Prefix - var inet6Range netip.Prefix - if fakeIPOptions.Inet4Range != nil { - inet4Range = *fakeIPOptions.Inet4Range - } - if fakeIPOptions.Inet6Range != nil { - inet6Range = *fakeIPOptions.Inet6Range - } - router.fakeIPStore = fakeip.NewStore(ctx, router.logger, inet4Range, inet6Range) - } - return router, nil + return nil } func (r *Router) Start(stage adapter.StartStage) error { monitor := taskmonitor.New(r.logger, C.StartTimeout) switch stage { - case adapter.StartStateInitialize: - if r.fakeIPStore != nil { - monitor.Start("initialize fakeip store") - err := r.fakeIPStore.Start() - monitor.Finish() - if err != nil { - return err - } - } case adapter.StartStateStart: - if r.needGeoIPDatabase { - monitor.Start("initialize geoip database") - err := r.prepareGeoIPDatabase() - monitor.Finish() - if err != nil { - return err - } - } - if r.needGeositeDatabase { - monitor.Start("initialize geosite database") - err := r.prepareGeositeDatabase() - monitor.Finish() - if err != nil { - return err - } - } - if r.needGeositeDatabase { - for _, rule := range r.rules { - err := rule.UpdateGeosite() - if err != nil { - r.logger.Error("failed to initialize geosite: ", err) - } - } - for _, rule := range r.dnsRules { - err := rule.UpdateGeosite() - if err != nil { - r.logger.Error("failed to initialize geosite: ", err) - } - } - err := common.Close(r.geositeReader) - if err != nil { - return err - } - r.geositeCache = nil - r.geositeReader = nil - } - - monitor.Start("initialize DNS client") - r.dnsClient.Start() - monitor.Finish() - - for i, rule := range r.dnsRules { - monitor.Start("initialize DNS rule[", i, "]") - err := rule.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "initialize DNS rule[", i, "]") - } - } - for i, transport := range r.transports { - monitor.Start("initialize DNS transport[", i, "]") - err := transport.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "initialize DNS server[", i, "]") - } - } var cacheContext *adapter.HTTPStartContext if len(r.ruleSets) > 0 { monitor.Start("initialize rule-set") @@ -438,7 +162,7 @@ func (r *Router) Start(stage adapter.StartStage) error { r.started = true return nil case adapter.StartStateStarted: - for _, ruleSet := range r.ruleSetMap { + for _, ruleSet := range r.ruleSets { ruleSet.Cleanup() } runtime.GC() @@ -456,34 +180,6 @@ func (r *Router) Close() error { }) monitor.Finish() } - for i, rule := range r.dnsRules { - monitor.Start("close dns rule[", i, "]") - err = E.Append(err, rule.Close(), func(err error) error { - return E.Cause(err, "close dns rule[", i, "]") - }) - monitor.Finish() - } - for i, transport := range r.transports { - monitor.Start("close dns transport[", i, "]") - err = E.Append(err, transport.Close(), func(err error) error { - return E.Cause(err, "close dns transport[", i, "]") - }) - monitor.Finish() - } - if r.geoIPReader != nil { - monitor.Start("close geoip reader") - err = E.Append(err, r.geoIPReader.Close(), func(err error) error { - return E.Cause(err, "close geoip reader") - }) - monitor.Finish() - } - if r.fakeIPStore != nil { - monitor.Start("close fakeip store") - err = E.Append(err, r.fakeIPStore.Close(), func(err error) error { - return E.Cause(err, "close fakeip store") - }) - monitor.Finish() - } for i, ruleSet := range r.ruleSets { monitor.Start("close rule-set[", i, "]") err = E.Append(err, ruleSet.Close(), func(err error) error { @@ -494,10 +190,6 @@ func (r *Router) Close() error { return err } -func (r *Router) FakeIPStore() adapter.FakeIPStore { - return r.fakeIPStore -} - func (r *Router) RuleSet(tag string) (adapter.RuleSet, bool) { ruleSet, loaded := r.ruleSetMap[tag] return ruleSet, loaded @@ -517,7 +209,5 @@ func (r *Router) SetTracker(tracker adapter.ConnectionTracker) { func (r *Router) ResetNetwork() { r.network.ResetNetwork() - for _, transport := range r.transports { - transport.Reset() - } + r.dns.ResetNetwork() } diff --git a/route/rule/rule_abstract.go b/route/rule/rule_abstract.go index 6a569341..5be215e0 100644 --- a/route/rule/rule_abstract.go +++ b/route/rule/rule_abstract.go @@ -51,18 +51,6 @@ func (r *abstractDefaultRule) Close() error { return nil } -func (r *abstractDefaultRule) UpdateGeosite() error { - for _, item := range r.allItems { - if geositeItem, isSite := item.(*GeositeItem); isSite { - err := geositeItem.Update() - if err != nil { - return err - } - } - } - return nil -} - func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool { if len(r.allItems) == 0 { return true @@ -173,19 +161,6 @@ func (r *abstractLogicalRule) Type() string { return C.RuleTypeLogical } -func (r *abstractLogicalRule) UpdateGeosite() error { - for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (adapter.Rule, bool) { - rule, loaded := it.(adapter.Rule) - return rule, loaded - }) { - err := rule.UpdateGeosite() - if err != nil { - return err - } - } - return nil -} - func (r *abstractLogicalRule) Start() error { for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (interface { Start() error diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index 8989ff3c..7287ad86 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -13,7 +13,6 @@ import ( "github.com/sagernet/sing-box/common/sniff" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" @@ -85,7 +84,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti return sniffAction, sniffAction.build() case C.RuleActionTypeResolve: return &RuleActionResolve{ - Strategy: dns.DomainStrategy(action.ResolveOptions.Strategy), + Strategy: C.DomainStrategy(action.ResolveOptions.Strategy), Server: action.ResolveOptions.Server, }, nil default: @@ -101,6 +100,7 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction) return &RuleActionDNSRoute{ Server: action.RouteOptions.Server, RuleActionDNSRouteOptions: RuleActionDNSRouteOptions{ + Strategy: C.DomainStrategy(action.RouteOptions.Strategy), DisableCache: action.RouteOptions.DisableCache, RewriteTTL: action.RouteOptions.RewriteTTL, ClientSubnet: netip.Prefix(common.PtrValueOrDefault(action.RouteOptions.ClientSubnet)), @@ -108,6 +108,7 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction) } case C.RuleActionTypeRouteOptions: return &RuleActionDNSRouteOptions{ + Strategy: C.DomainStrategy(action.RouteOptionsOptions.Strategy), DisableCache: action.RouteOptionsOptions.DisableCache, RewriteTTL: action.RouteOptionsOptions.RewriteTTL, ClientSubnet: netip.Prefix(common.PtrValueOrDefault(action.RouteOptionsOptions.ClientSubnet)), @@ -214,6 +215,7 @@ func (r *RuleActionDNSRoute) String() string { } type RuleActionDNSRouteOptions struct { + Strategy C.DomainStrategy DisableCache bool RewriteTTL *uint32 ClientSubnet netip.Prefix @@ -362,7 +364,7 @@ func (r *RuleActionSniff) String() string { } type RuleActionResolve struct { - Strategy dns.DomainStrategy + Strategy C.DomainStrategy Server string } @@ -371,11 +373,11 @@ func (r *RuleActionResolve) Type() string { } func (r *RuleActionResolve) String() string { - if r.Strategy == dns.DomainStrategyAsIS && r.Server == "" { + if r.Strategy == C.DomainStrategyAsIS && r.Server == "" { return F.ToString("resolve") - } else if r.Strategy != dns.DomainStrategyAsIS && r.Server == "" { + } else if r.Strategy != C.DomainStrategyAsIS && r.Server == "" { return F.ToString("resolve(", option.DomainStrategy(r.Strategy).String(), ")") - } else if r.Strategy == dns.DomainStrategyAsIS && r.Server != "" { + } else if r.Strategy == C.DomainStrategyAsIS && r.Server != "" { return F.ToString("resolve(", r.Server, ")") } else { return F.ToString("resolve(", option.DomainStrategy(r.Strategy).String(), ",", r.Server, ")") diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index 2794c287..aa6059d2 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -120,19 +120,13 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio rule.allItems = append(rule.allItems, item) } if len(options.Geosite) > 0 { - item := NewGeositeItem(router, logger, options.Geosite) - rule.destinationAddressItems = append(rule.destinationAddressItems, item) - rule.allItems = append(rule.allItems, item) + return nil, E.New("geosite database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0") } if len(options.SourceGeoIP) > 0 { - item := NewGeoIPItem(router, logger, true, options.SourceGeoIP) - rule.sourceAddressItems = append(rule.sourceAddressItems, item) - rule.allItems = append(rule.allItems, item) + return nil, E.New("geoip database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0") } if len(options.GeoIP) > 0 { - item := NewGeoIPItem(router, logger, false, options.GeoIP) - rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item) - rule.allItems = append(rule.allItems, item) + return nil, E.New("geoip database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0") } if len(options.SourceIPCIDR) > 0 { item, err := NewIPCIDRItem(true, options.SourceIPCIDR) diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index fb8c6b78..9d1c69b8 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -111,19 +111,13 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op rule.allItems = append(rule.allItems, item) } if len(options.Geosite) > 0 { - item := NewGeositeItem(router, logger, options.Geosite) - rule.destinationAddressItems = append(rule.destinationAddressItems, item) - rule.allItems = append(rule.allItems, item) + return nil, E.New("geosite database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0") } if len(options.SourceGeoIP) > 0 { - item := NewGeoIPItem(router, logger, true, options.SourceGeoIP) - rule.sourceAddressItems = append(rule.sourceAddressItems, item) - rule.allItems = append(rule.allItems, item) + return nil, E.New("geoip database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0") } if len(options.GeoIP) > 0 { - item := NewGeoIPItem(router, logger, false, options.GeoIP) - rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item) - rule.allItems = append(rule.allItems, item) + return nil, E.New("geoip database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0") } if len(options.SourceIPCIDR) > 0 { item, err := NewIPCIDRItem(true, options.SourceIPCIDR) @@ -151,6 +145,11 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item) rule.allItems = append(rule.allItems, item) } + if options.IPAcceptAny { + item := NewIPAcceptAnyItem() + rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item) + rule.allItems = append(rule.allItems, item) + } if len(options.SourcePort) > 0 { item := NewPortItem(true, options.SourcePort) rule.sourcePortItems = append(rule.sourcePortItems, item) diff --git a/route/rule/rule_item_geoip.go b/route/rule/rule_item_geoip.go deleted file mode 100644 index 3c967fec..00000000 --- a/route/rule/rule_item_geoip.go +++ /dev/null @@ -1,98 +0,0 @@ -package rule - -import ( - "net/netip" - "strings" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/log" - N "github.com/sagernet/sing/common/network" -) - -var _ RuleItem = (*GeoIPItem)(nil) - -type GeoIPItem struct { - router adapter.Router - logger log.ContextLogger - isSource bool - codes []string - codeMap map[string]bool -} - -func NewGeoIPItem(router adapter.Router, logger log.ContextLogger, isSource bool, codes []string) *GeoIPItem { - codeMap := make(map[string]bool) - for _, code := range codes { - codeMap[code] = true - } - return &GeoIPItem{ - router: router, - logger: logger, - codes: codes, - isSource: isSource, - codeMap: codeMap, - } -} - -func (r *GeoIPItem) Match(metadata *adapter.InboundContext) bool { - var geoipCode string - if r.isSource && metadata.SourceGeoIPCode != "" { - geoipCode = metadata.SourceGeoIPCode - } else if !r.isSource && metadata.GeoIPCode != "" { - geoipCode = metadata.GeoIPCode - } - if geoipCode != "" { - return r.codeMap[geoipCode] - } - var destination netip.Addr - if r.isSource { - destination = metadata.Source.Addr - } else { - destination = metadata.Destination.Addr - } - if destination.IsValid() { - return r.match(metadata, destination) - } - for _, destinationAddress := range metadata.DestinationAddresses { - if r.match(metadata, destinationAddress) { - return true - } - } - return false -} - -func (r *GeoIPItem) match(metadata *adapter.InboundContext, destination netip.Addr) bool { - var geoipCode string - geoReader := r.router.GeoIPReader() - if !N.IsPublicAddr(destination) { - geoipCode = "private" - } else if geoReader != nil { - geoipCode = geoReader.Lookup(destination) - } - if geoipCode == "" { - return false - } - if r.isSource { - metadata.SourceGeoIPCode = geoipCode - } else { - metadata.GeoIPCode = geoipCode - } - return r.codeMap[geoipCode] -} - -func (r *GeoIPItem) String() string { - var description string - if r.isSource { - description = "source_geoip=" - } else { - description = "geoip=" - } - cLen := len(r.codes) - if cLen == 1 { - description += r.codes[0] - } else if cLen > 3 { - description += "[" + strings.Join(r.codes[:3], " ") + "...]" - } else { - description += "[" + strings.Join(r.codes, " ") + "]" - } - return description -} diff --git a/route/rule/rule_item_geosite.go b/route/rule/rule_item_geosite.go deleted file mode 100644 index 9e5e03c8..00000000 --- a/route/rule/rule_item_geosite.go +++ /dev/null @@ -1,61 +0,0 @@ -package rule - -import ( - "strings" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/log" - E "github.com/sagernet/sing/common/exceptions" -) - -var _ RuleItem = (*GeositeItem)(nil) - -type GeositeItem struct { - router adapter.Router - logger log.ContextLogger - codes []string - matchers []adapter.Rule -} - -func NewGeositeItem(router adapter.Router, logger log.ContextLogger, codes []string) *GeositeItem { - return &GeositeItem{ - router: router, - logger: logger, - codes: codes, - } -} - -func (r *GeositeItem) Update() error { - matchers := make([]adapter.Rule, 0, len(r.codes)) - for _, code := range r.codes { - matcher, err := r.router.LoadGeosite(code) - if err != nil { - return E.Cause(err, "read geosite") - } - matchers = append(matchers, matcher) - } - r.matchers = matchers - return nil -} - -func (r *GeositeItem) Match(metadata *adapter.InboundContext) bool { - for _, matcher := range r.matchers { - if matcher.Match(metadata) { - return true - } - } - return false -} - -func (r *GeositeItem) String() string { - description := "geosite=" - cLen := len(r.codes) - if cLen == 1 { - description += r.codes[0] - } else if cLen > 3 { - description += "[" + strings.Join(r.codes[:3], " ") + "...]" - } else { - description += "[" + strings.Join(r.codes, " ") + "]" - } - return description -} diff --git a/route/rule/rule_item_ip_accept_any.go b/route/rule/rule_item_ip_accept_any.go new file mode 100644 index 00000000..1ca71257 --- /dev/null +++ b/route/rule/rule_item_ip_accept_any.go @@ -0,0 +1,21 @@ +package rule + +import ( + "github.com/sagernet/sing-box/adapter" +) + +var _ RuleItem = (*IPAcceptAnyItem)(nil) + +type IPAcceptAnyItem struct{} + +func NewIPAcceptAnyItem() *IPAcceptAnyItem { + return &IPAcceptAnyItem{} +} + +func (r *IPAcceptAnyItem) Match(metadata *adapter.InboundContext) bool { + return len(metadata.DestinationAddresses) > 0 +} + +func (r *IPAcceptAnyItem) String() string { + return "ip_accept_any=true" +} diff --git a/route/rule_conds.go b/route/rule_conds.go index 76447176..55c4a058 100644 --- a/route/rule_conds.go +++ b/route/rule_conds.go @@ -3,7 +3,6 @@ package route import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common" ) func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool { @@ -38,22 +37,6 @@ func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bo return false } -func isGeoIPRule(rule option.DefaultRule) bool { - return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode) -} - -func isGeoIPDNSRule(rule option.DefaultDNSRule) bool { - return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode) -} - -func isGeositeRule(rule option.DefaultRule) bool { - return len(rule.Geosite) > 0 -} - -func isGeositeDNSRule(rule option.DefaultDNSRule) bool { - return len(rule.Geosite) > 0 -} - func isProcessRule(rule option.DefaultRule) bool { return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.ProcessPathRegex) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 } @@ -62,10 +45,6 @@ func isProcessDNSRule(rule option.DefaultDNSRule) bool { return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.ProcessPathRegex) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 } -func notPrivateNode(code string) bool { - return code != "private" -} - func isWIFIRule(rule option.DefaultRule) bool { return len(rule.WIFISSID) > 0 || len(rule.WIFIBSSID) > 0 } diff --git a/test/domain_inbound_test.go b/test/domain_inbound_test.go index f39cd187..605740d4 100644 --- a/test/domain_inbound_test.go +++ b/test/domain_inbound_test.go @@ -6,7 +6,6 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/json/badoption" @@ -34,7 +33,7 @@ func TestTUICDomainUDP(t *testing.T) { Listen: common.Ptr(badoption.Addr(netip.IPv4Unspecified())), ListenPort: serverPort, InboundOptions: option.InboundOptions{ - DomainStrategy: option.DomainStrategy(dns.DomainStrategyUseIPv6), + DomainStrategy: option.DomainStrategy(C.DomainStrategyIPv6Only), }, }, Users: []option.TUICUser{{ diff --git a/transport/fakeip/server.go b/transport/fakeip/server.go deleted file mode 100644 index d1bbb2aa..00000000 --- a/transport/fakeip/server.go +++ /dev/null @@ -1,95 +0,0 @@ -package fakeip - -import ( - "context" - "net/netip" - "os" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-dns" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - "github.com/sagernet/sing/service" - - mDNS "github.com/miekg/dns" -) - -var ( - _ dns.Transport = (*Transport)(nil) - _ adapter.FakeIPTransport = (*Transport)(nil) -) - -func init() { - dns.RegisterTransport([]string{"fakeip"}, func(options dns.TransportOptions) (dns.Transport, error) { - return NewTransport(options) - }) -} - -type Transport struct { - name string - router adapter.Router - store adapter.FakeIPStore - logger logger.ContextLogger -} - -func NewTransport(options dns.TransportOptions) (*Transport, error) { - router := service.FromContext[adapter.Router](options.Context) - if router == nil { - return nil, E.New("missing router in context") - } - return &Transport{ - name: options.Name, - router: router, - logger: options.Logger, - }, nil -} - -func (s *Transport) Name() string { - return s.name -} - -func (s *Transport) Start() error { - s.store = s.router.FakeIPStore() - if s.store == nil { - return E.New("fakeip not enabled") - } - return nil -} - -func (s *Transport) Reset() { -} - -func (s *Transport) Close() error { - return nil -} - -func (s *Transport) Raw() bool { - return false -} - -func (s *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - return nil, os.ErrInvalid -} - -func (s *Transport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) { - var addresses []netip.Addr - if strategy != dns.DomainStrategyUseIPv6 { - inet4Address, err := s.store.Create(domain, false) - if err != nil { - return nil, err - } - addresses = append(addresses, inet4Address) - } - if strategy != dns.DomainStrategyUseIPv4 { - inet6Address, err := s.store.Create(domain, true) - if err != nil { - return nil, err - } - addresses = append(addresses, inet6Address) - } - return addresses, nil -} - -func (s *Transport) Store() adapter.FakeIPStore { - return s.store -} diff --git a/transport/v2rayhttp/server.go b/transport/v2rayhttp/server.go index dd2bc9a2..4b830ae2 100644 --- a/transport/v2rayhttp/server.go +++ b/transport/v2rayhttp/server.go @@ -164,7 +164,7 @@ func (s *Server) Serve(listener net.Listener) error { if len(s.tlsConfig.NextProtos()) == 0 { s.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"}) } else if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) { - s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...)) + s.tlsConfig.SetNextProtos(append([]string{http2.NextProtoTLS}, s.tlsConfig.NextProtos()...)) } listener = aTLS.NewListener(listener, s.tlsConfig) }