refactor: DNS

This commit is contained in:
世界 2025-03-16 14:50:44 +08:00
parent a530e424e9
commit 86116b9423
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
89 changed files with 4792 additions and 1733 deletions

View file

@ -6,7 +6,9 @@ builds:
- -v
- -trimpath
ldflags:
- -X github.com/sagernet/sing-box/constant.Version={{ .Version }} -s -w -buildid=
- -X github.com/sagernet/sing-box/constant.Version={{ .Version }}
- -s
- -buildid=
tags:
- with_gvisor
- with_quic

73
adapter/dns.go Normal file
View file

@ -0,0 +1,73 @@
package adapter
import (
"context"
"net/netip"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/logger"
"github.com/miekg/dns"
)
type DNSRouter interface {
Lifecycle
Exchange(ctx context.Context, message *dns.Msg, options DNSQueryOptions) (*dns.Msg, error)
Lookup(ctx context.Context, domain string, options DNSQueryOptions) ([]netip.Addr, error)
ClearCache()
LookupReverseMapping(ip netip.Addr) (string, bool)
ResetNetwork()
}
type DNSClient interface {
Start()
Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error)
Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error)
LookupCache(domain string, strategy C.DomainStrategy) ([]netip.Addr, bool)
ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool)
ClearCache()
}
type DNSQueryOptions struct {
Transport DNSTransport
Strategy C.DomainStrategy
DisableCache bool
RewriteTTL *uint32
ClientSubnet netip.Prefix
}
type RDRCStore interface {
LoadRDRC(transportName string, qName string, qType uint16) (rejected bool)
SaveRDRC(transportName string, qName string, qType uint16) error
SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger)
}
type DNSTransport interface {
Type() string
Tag() string
Dependencies() []string
Reset()
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
}
type LegacyDNSTransport interface {
LegacyStrategy() C.DomainStrategy
LegacyClientSubnet() netip.Prefix
}
type DNSTransportRegistry interface {
option.DNSTransportOptionsRegistry
CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (DNSTransport, error)
}
type DNSTransportManager interface {
Lifecycle
Transports() []DNSTransport
Transport(tag string) (DNSTransport, bool)
Default() DNSTransport
FakeIP() FakeIPTransport
Remove(tag string) error
Create(ctx context.Context, logger log.ContextLogger, tag string, outboundType string, options any) error
}

View file

@ -7,7 +7,6 @@ import (
"time"
"github.com/sagernet/sing-box/common/urltest"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/varbin"
)
@ -31,7 +30,7 @@ type CacheFile interface {
FakeIPStorage
StoreRDRC() bool
dns.RDRCStore
RDRCStore
LoadMode() string
StoreMode(mode string) error

View file

@ -3,7 +3,6 @@ package adapter
import (
"net/netip"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/logger"
)
@ -27,6 +26,6 @@ type FakeIPStorage interface {
}
type FakeIPTransport interface {
dns.Transport
DNSTransport
Store() FakeIPStore
}

View file

@ -78,8 +78,6 @@ type InboundContext struct {
FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration
DNSServer string
DestinationAddresses []netip.Addr
SourceGeoIPCode string
GeoIPCode string

View file

@ -23,7 +23,7 @@ type Manager struct {
registry adapter.OutboundRegistry
endpoint adapter.EndpointManager
defaultTag string
access sync.Mutex
access sync.RWMutex
started bool
stage adapter.StartStage
outbounds []adapter.Outbound
@ -169,15 +169,15 @@ func (m *Manager) Close() error {
}
func (m *Manager) Outbounds() []adapter.Outbound {
m.access.Lock()
defer m.access.Unlock()
m.access.RLock()
defer m.access.RUnlock()
return m.outbounds
}
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
m.access.Lock()
m.access.RLock()
outbound, found := m.outboundByTag[tag]
m.access.Unlock()
m.access.RUnlock()
if found {
return outbound, true
}
@ -185,8 +185,8 @@ func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
}
func (m *Manager) Default() adapter.Outbound {
m.access.Lock()
defer m.access.Unlock()
m.access.RLock()
defer m.access.RUnlock()
if m.defaultOutbound != nil {
return m.defaultOutbound
} else {
@ -196,9 +196,9 @@ func (m *Manager) Default() adapter.Outbound {
func (m *Manager) Remove(tag string) error {
m.access.Lock()
defer m.access.Unlock()
outbound, found := m.outboundByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.outboundByTag, tag)
@ -232,7 +232,6 @@ func (m *Manager) Remove(tag string) error {
})
}
}
m.access.Unlock()
if started {
return common.Close(outbound)
}

View file

@ -4,42 +4,25 @@ import (
"context"
"net"
"net/http"
"net/netip"
"sync"
"github.com/sagernet/sing-box/common/geoip"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
mdns "github.com/miekg/dns"
"go4.org/netipx"
)
type Router interface {
Lifecycle
FakeIPStore() FakeIPStore
ConnectionRouter
PreMatch(metadata InboundContext) error
ConnectionRouterEx
GeoIPReader() *geoip.Reader
LoadGeosite(code string) (Rule, error)
RuleSet(tag string) (RuleSet, bool)
NeedWIFIState() bool
Exchange(ctx context.Context, message *mdns.Msg) (*mdns.Msg, error)
Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error)
LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error)
ClearDNSCache()
Rules() []Rule
SetTracker(tracker ConnectionTracker)
ResetNetwork()
}

View file

@ -13,7 +13,6 @@ type Rule interface {
HeadlessRule
Service
Type() string
UpdateGeosite() error
Action() RuleAction
}

63
box.go
View file

@ -16,6 +16,8 @@ import (
"github.com/sagernet/sing-box/common/taskmonitor"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport/local"
"github.com/sagernet/sing-box/experimental"
"github.com/sagernet/sing-box/experimental/cachefile"
"github.com/sagernet/sing-box/experimental/libbox/platform"
@ -41,6 +43,8 @@ type Box struct {
endpoint *endpoint.Manager
inbound *inbound.Manager
outbound *outbound.Manager
dnsTransport *dns.TransportManager
dnsRouter *dns.Router
connection *route.ConnectionManager
router *route.Router
services []adapter.LifecycleService
@ -58,6 +62,7 @@ func Context(
inboundRegistry adapter.InboundRegistry,
outboundRegistry adapter.OutboundRegistry,
endpointRegistry adapter.EndpointRegistry,
dnsTransportRegistry adapter.DNSTransportRegistry,
) context.Context {
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.InboundRegistry](ctx) == nil {
@ -74,6 +79,10 @@ func Context(
ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry)
ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry)
}
if service.FromContext[adapter.DNSTransportRegistry](ctx) == nil {
ctx = service.ContextWith[option.DNSTransportOptionsRegistry](ctx, dnsTransportRegistry)
ctx = service.ContextWith[adapter.DNSTransportRegistry](ctx, dnsTransportRegistry)
}
return ctx
}
@ -88,6 +97,7 @@ func New(options Options) (*Box, error) {
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx)
if endpointRegistry == nil {
return nil, E.New("missing endpoint registry in context")
@ -132,13 +142,17 @@ func New(options Options) (*Box, error) {
}
routeOptions := common.PtrValueOrDefault(options.Route)
dnsOptions := common.PtrValueOrDefault(options.DNS)
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final)
service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
service.MustRegister[adapter.InboundManager](ctx, inboundManager)
service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager)
dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions)
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions)
if err != nil {
return nil, E.Cause(err, "initialize network manager")
@ -146,18 +160,40 @@ func New(options Options) (*Box, error) {
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS))
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions)
service.MustRegister[adapter.Router](ctx, router)
err = router.Initialize(routeOptions.Rules, routeOptions.RuleSet)
if err != nil {
return nil, E.Cause(err, "initialize router")
}
ntpOptions := common.PtrValueOrDefault(options.NTP)
var timeService *tls.TimeServiceWrapper
if ntpOptions.Enabled {
timeService = new(tls.TimeServiceWrapper)
service.MustRegister[ntp.TimeService](ctx, timeService)
}
for i, transportOptions := range dnsOptions.Servers {
var tag string
if transportOptions.Tag != "" {
tag = transportOptions.Tag
} else {
tag = F.ToString(i)
}
err = dnsTransportManager.Create(
ctx,
logFactory.NewLogger(F.ToString("dns/", transportOptions.Type, "[", tag, "]")),
tag,
transportOptions.Type,
transportOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize inbound[", i, "]")
}
}
err = dnsRouter.Initialize(dnsOptions.Rules)
if err != nil {
return nil, E.Cause(err, "initialize dns router")
}
for i, endpointOptions := range options.Endpoints {
var tag string
if endpointOptions.Tag != "" {
@ -238,6 +274,13 @@ func New(options Options) (*Box, error) {
option.DirectOutboundOptions{},
),
))
dnsTransportManager.Initialize(common.Must1(
local.NewTransport(
ctx,
logFactory.NewLogger("dns/local"),
"local",
option.LocalDNSServerOptions{},
)))
if platformInterface != nil {
err = platformInterface.Initialize(networkManager)
if err != nil {
@ -293,6 +336,8 @@ func New(options Options) (*Box, error) {
endpoint: endpointManager,
inbound: inboundManager,
outbound: outboundManager,
dnsTransport: dnsTransportManager,
dnsRouter: dnsRouter,
connection: connectionManager,
router: router,
createdAt: createdAt,
@ -353,11 +398,11 @@ func (s *Box) preStart() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.connection, s.router)
err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router)
if err != nil {
return err
}
@ -381,7 +426,7 @@ func (s *Box) start() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound, s.endpoint)
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint)
if err != nil {
return err
}
@ -389,7 +434,7 @@ func (s *Box) start() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
@ -408,7 +453,7 @@ func (s *Box) Close() error {
close(s.done)
}
err := common.Close(
s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.network,
s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network,
)
for _, lifecycleService := range s.services {
err = E.Append(err, lifecycleService.Close(), func(err error) error {

View file

@ -69,5 +69,5 @@ func preRun(cmd *cobra.Command, args []string) {
configPaths = append(configPaths, "config.json")
}
globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry(), include.DNSTransportRegistry())
}

View file

@ -9,7 +9,6 @@ import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
@ -37,13 +36,13 @@ func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) {
dialer = NewDetour(outboundManager, options.Detour)
}
if options.Detour == "" {
router := service.FromContext[adapter.Router](ctx)
router := service.FromContext[adapter.DNSRouter](ctx)
if router != nil {
dialer = NewResolveDialer(
router,
dialer,
options.Detour == "" && !options.TCPFastOpen,
dns.DomainStrategy(options.DomainStrategy),
C.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay))
}
}
@ -62,10 +61,10 @@ func NewDirect(ctx context.Context, options option.DialerOptions) (ParallelInter
return nil, err
}
return NewResolveParallelInterfaceDialer(
service.FromContext[adapter.Router](ctx),
service.FromContext[adapter.DNSRouter](ctx),
dialer,
true,
dns.DomainStrategy(options.DomainStrategy),
C.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay),
), nil
}

View file

@ -3,13 +3,11 @@ package dialer
import (
"context"
"net"
"net/netip"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
@ -23,12 +21,12 @@ var (
type resolveDialer struct {
dialer N.Dialer
parallel bool
router adapter.Router
strategy dns.DomainStrategy
router adapter.DNSRouter
strategy C.DomainStrategy
fallbackDelay time.Duration
}
func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) N.Dialer {
func NewResolveDialer(router adapter.DNSRouter, dialer N.Dialer, parallel bool, strategy C.DomainStrategy, fallbackDelay time.Duration) N.Dialer {
return &resolveDialer{
dialer,
parallel,
@ -43,7 +41,7 @@ type resolveParallelNetworkDialer struct {
dialer ParallelInterfaceDialer
}
func NewResolveParallelInterfaceDialer(router adapter.Router, dialer ParallelInterfaceDialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer {
func NewResolveParallelInterfaceDialer(router adapter.DNSRouter, dialer ParallelInterfaceDialer, parallel bool, strategy C.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer {
return &resolveParallelNetworkDialer{
resolveDialer{
dialer,
@ -60,22 +58,13 @@ func (d *resolveDialer) DialContext(ctx context.Context, network string, destina
if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Strategy: d.strategy})
if err != nil {
return nil, err
}
if d.parallel {
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, d.fallbackDelay)
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay)
} else {
return N.DialSerial(ctx, d.dialer, network, destination, addresses)
}
@ -85,17 +74,8 @@ func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Strategy: d.strategy})
if err != nil {
return nil, err
}
@ -110,17 +90,10 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{
Strategy: d.strategy,
})
if err != nil {
return nil, err
}
@ -128,7 +101,7 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
fallbackDelay = d.fallbackDelay
}
if d.parallel {
return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
} else {
return DialSerialNetwork(ctx, d.dialer, network, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
}
@ -138,17 +111,8 @@ func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.C
if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Strategy: d.strategy})
if err != nil {
return nil, err
}

View file

@ -7,24 +7,27 @@ import (
"github.com/sagernet/sing-box/adapter"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
type DefaultOutboundDialer struct {
outboundManager adapter.OutboundManager
outbound adapter.OutboundManager
}
func NewDefaultOutbound(outboundManager adapter.OutboundManager) N.Dialer {
return &DefaultOutboundDialer{outboundManager: outboundManager}
func NewDefaultOutbound(ctx context.Context) N.Dialer {
return &DefaultOutboundDialer{
outbound: service.FromContext[adapter.OutboundManager](ctx),
}
}
func (d *DefaultOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return d.outboundManager.Default().DialContext(ctx, network, destination)
return d.outbound.Default().DialContext(ctx, network, destination)
}
func (d *DefaultOutboundDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return d.outboundManager.Default().ListenPacket(ctx, destination)
return d.outbound.Default().ListenPacket(ctx, destination)
}
func (d *DefaultOutboundDialer) Upstream() any {
return d.outboundManager.Default()
return d.outbound.Default()
}

View file

@ -15,8 +15,8 @@ import (
cftls "github.com/sagernet/cloudflare-tls"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/service"
@ -215,7 +215,7 @@ func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverNam
},
},
}
response, err := service.FromContext[adapter.Router](ctx).Exchange(ctx, message)
response, err := service.FromContext[adapter.DNSRouter](ctx).Exchange(ctx, message, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}

View file

@ -5,7 +5,6 @@ import (
"crypto/tls"
"crypto/x509"
"net"
"net/netip"
"os"
"strings"
@ -51,10 +50,8 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb
if options.ServerName != "" {
serverName = options.ServerName
} else if serverAddress != "" {
if _, err := netip.ParseAddr(serverName); err != nil {
serverName = serverAddress
}
}
if serverName == "" && !options.Insecure {
return nil, E.New("missing server_name or insecure=true")
}

View file

@ -1,5 +1,34 @@
package constant
const (
DefaultDNSTTL = 600
)
type DomainStrategy = uint8
const (
DomainStrategyAsIS DomainStrategy = iota
DomainStrategyPreferIPv4
DomainStrategyPreferIPv6
DomainStrategyIPv4Only
DomainStrategyIPv6Only
)
const (
DNSTypeLegacy = "legacy"
DNSTypeUDP = "udp"
DNSTypeTCP = "tcp"
DNSTypeTLS = "tls"
DNSTypeHTTPS = "https"
DNSTypeQUIC = "quic"
DNSTypeHTTP3 = "h3"
DNSTypeHosts = "hosts"
DNSTypeLocal = "local"
DNSTypePreDefined = "predefined"
DNSTypeFakeIP = "fakeip"
DNSTypeDHCP = "dhcp"
)
const (
DNSProviderAliDNS = "alidns"
DNSProviderCloudflare = "cloudflare"

563
dns/client.go Normal file
View file

@ -0,0 +1,563 @@
package dns
import (
"context"
"net"
"net/netip"
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
"github.com/miekg/dns"
)
var (
ErrNoRawSupport = E.New("no raw query support by current transport")
ErrNotCached = E.New("not cached")
ErrResponseRejected = E.New("response rejected")
ErrResponseRejectedCached = E.Extend(ErrResponseRejected, "cached")
)
var _ adapter.DNSClient = (*Client)(nil)
type Client struct {
timeout time.Duration
disableCache bool
disableExpire bool
independentCache bool
rdrc adapter.RDRCStore
initRDRCFunc func() adapter.RDRCStore
logger logger.ContextLogger
cache freelru.Cache[dns.Question, *dns.Msg]
transportCache freelru.Cache[transportCacheKey, *dns.Msg]
}
type ClientOptions struct {
Timeout time.Duration
DisableCache bool
DisableExpire bool
IndependentCache bool
CacheCapacity uint32
RDRC func() adapter.RDRCStore
Logger logger.ContextLogger
}
func NewClient(options ClientOptions) *Client {
client := &Client{
timeout: options.Timeout,
disableCache: options.DisableCache,
disableExpire: options.DisableExpire,
independentCache: options.IndependentCache,
initRDRCFunc: options.RDRC,
logger: options.Logger,
}
if client.timeout == 0 {
client.timeout = C.DNSTimeout
}
cacheCapacity := options.CacheCapacity
if cacheCapacity < 1024 {
cacheCapacity = 1024
}
if !client.disableCache {
if !client.independentCache {
client.cache = common.Must1(freelru.NewSharded[dns.Question, *dns.Msg](cacheCapacity, maphash.NewHasher[dns.Question]().Hash32))
} else {
client.transportCache = common.Must1(freelru.NewSharded[transportCacheKey, *dns.Msg](cacheCapacity, maphash.NewHasher[transportCacheKey]().Hash32))
}
}
return client
}
type transportCacheKey struct {
dns.Question
transportTag string
}
func (c *Client) Start() {
if c.initRDRCFunc != nil {
c.rdrc = c.initRDRCFunc()
}
}
func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) {
if len(message.Question) == 0 {
if c.logger != nil {
c.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
}
responseMessage := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: message.Id,
Response: true,
Rcode: dns.RcodeFormatError,
},
Question: message.Question,
}
return &responseMessage, nil
}
question := message.Question[0]
if options.ClientSubnet.IsValid() {
message = SetClientSubnet(message, options.ClientSubnet, true)
}
isSimpleRequest := len(message.Question) == 1 &&
len(message.Ns) == 0 &&
len(message.Extra) == 0 &&
!options.ClientSubnet.IsValid()
disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
if !disableCache {
response, ttl := c.loadResponse(question, transport)
if response != nil {
logCachedResponse(c.logger, ctx, response, ttl)
response.Id = message.Id
return response, nil
}
}
if question.Qtype == dns.TypeA && options.Strategy == C.DomainStrategyIPv6Only || question.Qtype == dns.TypeAAAA && options.Strategy == C.DomainStrategyIPv4Only {
responseMessage := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: message.Id,
Response: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{question},
}
if c.logger != nil {
c.logger.DebugContext(ctx, "strategy rejected")
}
return &responseMessage, nil
}
messageId := message.Id
contextTransport, clientSubnetLoaded := transportTagFromContext(ctx)
if clientSubnetLoaded && transport.Tag() == contextTransport {
return nil, E.New("DNS query loopback in transport[", contextTransport, "]")
}
ctx = contextWithTransportTag(ctx, transport.Tag())
if responseChecker != nil && c.rdrc != nil {
rejected := c.rdrc.LoadRDRC(transport.Tag(), question.Name, question.Qtype)
if rejected {
return nil, ErrResponseRejectedCached
}
}
ctx, cancel := context.WithTimeout(ctx, c.timeout)
response, err := transport.Exchange(ctx, message)
cancel()
if err != nil {
return nil, err
}
/*if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
validResponse := response
loop:
for {
var (
addresses int
queryCNAME string
)
for _, rawRR := range validResponse.Answer {
switch rr := rawRR.(type) {
case *dns.A:
break loop
case *dns.AAAA:
break loop
case *dns.CNAME:
queryCNAME = rr.Target
}
}
if queryCNAME == "" {
break
}
exMessage := *message
exMessage.Question = []dns.Question{{
Name: queryCNAME,
Qtype: question.Qtype,
}}
validResponse, err = c.Exchange(ctx, transport, &exMessage, options, responseChecker)
if err != nil {
return nil, err
}
}
if validResponse != response {
response.Answer = append(response.Answer, validResponse.Answer...)
}
}*/
if responseChecker != nil {
addr, addrErr := MessageToAddresses(response)
if addrErr != nil || !responseChecker(addr) {
if c.rdrc != nil {
c.rdrc.SaveRDRCAsync(transport.Tag(), question.Name, question.Qtype, c.logger)
}
logRejectedResponse(c.logger, ctx, response)
return response, ErrResponseRejected
}
}
if question.Qtype == dns.TypeHTTPS {
if options.Strategy == C.DomainStrategyIPv4Only || options.Strategy == C.DomainStrategyIPv6Only {
for _, rr := range response.Answer {
https, isHTTPS := rr.(*dns.HTTPS)
if !isHTTPS {
continue
}
content := https.SVCB
content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool {
if options.Strategy == C.DomainStrategyIPv4Only {
return it.Key() != dns.SVCB_IPV6HINT
} else {
return it.Key() != dns.SVCB_IPV4HINT
}
})
https.SVCB = content
}
}
}
var timeToLive uint32
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
timeToLive = record.Header().Ttl
}
}
}
if options.RewriteTTL != nil {
timeToLive = *options.RewriteTTL
}
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
record.Header().Ttl = timeToLive
}
}
response.Id = messageId
if !disableCache {
c.storeCache(transport, question, response, timeToLive)
}
logExchangedResponse(c.logger, ctx, response, timeToLive)
return response, err
}
func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
domain = FqdnToDomain(domain)
dnsName := dns.Fqdn(domain)
if options.Strategy == C.DomainStrategyIPv4Only {
return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker)
} else if options.Strategy == C.DomainStrategyIPv6Only {
return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker)
}
var response4 []netip.Addr
var response6 []netip.Addr
var group task.Group
group.Append("exchange4", func(ctx context.Context) error {
response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker)
if err != nil {
return err
}
response4 = response
return nil
})
group.Append("exchange6", func(ctx context.Context) error {
response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker)
if err != nil {
return err
}
response6 = response
return nil
})
err := group.Run(ctx)
if len(response4) == 0 && len(response6) == 0 {
return nil, err
}
return sortAddresses(response4, response6, options.Strategy), nil
}
func (c *Client) ClearCache() {
if c.cache != nil {
c.cache.Purge()
}
if c.transportCache != nil {
c.transportCache.Purge()
}
}
func (c *Client) LookupCache(domain string, strategy C.DomainStrategy) ([]netip.Addr, bool) {
if c.disableCache || c.independentCache {
return nil, false
}
if dns.IsFqdn(domain) {
domain = domain[:len(domain)-1]
}
dnsName := dns.Fqdn(domain)
if strategy == C.DomainStrategyIPv4Only {
response, err := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}, nil)
if err != ErrNotCached {
return response, true
}
} else if strategy == C.DomainStrategyIPv6Only {
response, err := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}, nil)
if err != ErrNotCached {
return response, true
}
} else {
response4, _ := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}, nil)
response6, _ := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}, nil)
if len(response4) > 0 || len(response6) > 0 {
return sortAddresses(response4, response6, strategy), true
}
}
return nil, false
}
func (c *Client) ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) {
if c.disableCache || c.independentCache || len(message.Question) != 1 {
return nil, false
}
question := message.Question[0]
response, ttl := c.loadResponse(question, nil)
if response == nil {
return nil, false
}
logCachedResponse(c.logger, ctx, response, ttl)
response.Id = message.Id
return response, true
}
func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.DomainStrategy) []netip.Addr {
if strategy == C.DomainStrategyPreferIPv6 {
return append(response6, response4...)
} else {
return append(response4, response6...)
}
}
func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Question, message *dns.Msg, timeToLive uint32) {
if timeToLive == 0 {
return
}
if c.disableExpire {
if !c.independentCache {
c.cache.Add(question, message)
} else {
c.transportCache.Add(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
}, message)
}
return
}
if !c.independentCache {
c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive))
} else {
c.transportCache.AddWithLifetime(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
}, message, time.Second*time.Duration(timeToLive))
}
}
func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
question := dns.Question{
Name: name,
Qtype: qType,
Qclass: dns.ClassINET,
}
disableCache := c.disableCache || options.DisableCache
if !disableCache {
cachedAddresses, err := c.questionCache(question, transport)
if err != ErrNotCached {
return cachedAddresses, err
}
}
message := dns.Msg{
MsgHdr: dns.MsgHdr{
RecursionDesired: true,
},
Question: []dns.Question{question},
}
response, err := c.Exchange(ctx, transport, &message, options, responseChecker)
if err != nil {
return nil, err
}
return MessageToAddresses(response)
}
func (c *Client) questionCache(question dns.Question, transport adapter.DNSTransport) ([]netip.Addr, error) {
response, _ := c.loadResponse(question, transport)
if response == nil {
return nil, ErrNotCached
}
return MessageToAddresses(response)
}
func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int) {
var (
response *dns.Msg
loaded bool
)
if c.disableExpire {
if !c.independentCache {
response, loaded = c.cache.Get(question)
} else {
response, loaded = c.transportCache.Get(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
})
}
if !loaded {
return nil, 0
}
return response.Copy(), 0
} else {
var expireAt time.Time
if !c.independentCache {
response, expireAt, loaded = c.cache.GetWithLifetime(question)
} else {
response, expireAt, loaded = c.transportCache.GetWithLifetime(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
})
}
if !loaded {
return nil, 0
}
timeNow := time.Now()
if timeNow.After(expireAt) {
if !c.independentCache {
c.cache.Remove(question)
} else {
c.transportCache.Remove(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
})
}
return nil, 0
}
var originTTL int
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
if originTTL == 0 || record.Header().Ttl > 0 && int(record.Header().Ttl) < originTTL {
originTTL = int(record.Header().Ttl)
}
}
}
nowTTL := int(expireAt.Sub(timeNow).Seconds())
if nowTTL < 0 {
nowTTL = 0
}
response = response.Copy()
if originTTL > 0 {
duration := uint32(originTTL - nowTTL)
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
record.Header().Ttl = record.Header().Ttl - duration
}
}
} else {
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
record.Header().Ttl = uint32(nowTTL)
}
}
}
return response, nowTTL
}
}
func MessageToAddresses(response *dns.Msg) ([]netip.Addr, error) {
if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
return nil, RCodeError(response.Rcode)
}
addresses := make([]netip.Addr, 0, len(response.Answer))
for _, rawAnswer := range response.Answer {
switch answer := rawAnswer.(type) {
case *dns.A:
addresses = append(addresses, M.AddrFromIP(answer.A))
case *dns.AAAA:
addresses = append(addresses, M.AddrFromIP(answer.AAAA))
case *dns.HTTPS:
for _, value := range answer.SVCB.Value {
if value.Key() == dns.SVCB_IPV4HINT || value.Key() == dns.SVCB_IPV6HINT {
addresses = append(addresses, common.Map(strings.Split(value.String(), ","), M.ParseAddr)...)
}
}
}
}
return addresses, nil
}
func wrapError(err error) error {
switch dnsErr := err.(type) {
case *net.DNSError:
if dnsErr.IsNotFound {
return RCodeNameError
}
case *net.AddrError:
return RCodeNameError
}
return err
}
type transportKey struct{}
func contextWithTransportTag(ctx context.Context, transportTag string) context.Context {
return context.WithValue(ctx, transportKey{}, transportTag)
}
func transportTagFromContext(ctx context.Context) (string, bool) {
value, loaded := ctx.Value(transportKey{}).(string)
return value, loaded
}
func FixedResponse(id uint16, question dns.Question, addresses []netip.Addr, timeToLive uint32) *dns.Msg {
response := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: id,
Rcode: dns.RcodeSuccess,
Response: true,
},
Question: []dns.Question{question},
}
for _, address := range addresses {
if address.Is4() {
response.Answer = append(response.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: timeToLive,
},
A: address.AsSlice(),
})
} else {
response.Answer = append(response.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: timeToLive,
},
AAAA: address.AsSlice(),
})
}
}
return &response
}

69
dns/client_log.go Normal file
View file

@ -0,0 +1,69 @@
package dns
import (
"context"
"strings"
"github.com/sagernet/sing/common/logger"
"github.com/miekg/dns"
)
func logCachedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl int) {
if logger == nil || len(response.Question) == 0 {
return
}
domain := FqdnToDomain(response.Question[0].Name)
logger.DebugContext(ctx, "cached ", domain, " ", dns.RcodeToString[response.Rcode], " ", ttl)
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
logger.InfoContext(ctx, "cached ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String()))
}
}
}
func logExchangedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl uint32) {
if logger == nil || len(response.Question) == 0 {
return
}
domain := FqdnToDomain(response.Question[0].Name)
logger.DebugContext(ctx, "exchanged ", domain, " ", dns.RcodeToString[response.Rcode], " ", ttl)
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
logger.InfoContext(ctx, "exchanged ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String()))
}
}
}
func logRejectedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg) {
if logger == nil || len(response.Question) == 0 {
return
}
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
logger.InfoContext(ctx, "rejected ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String()))
}
}
}
func FqdnToDomain(fqdn string) string {
if dns.IsFqdn(fqdn) {
return fqdn[:len(fqdn)-1]
}
return fqdn
}
func FormatQuestion(string string) string {
for strings.HasPrefix(string, ";") {
string = string[1:]
}
string = strings.ReplaceAll(string, "\t", " ")
string = strings.ReplaceAll(string, "\n", " ")
string = strings.ReplaceAll(string, ";; ", " ")
string = strings.ReplaceAll(string, "; ", " ")
for strings.Contains(string, " ") {
string = strings.ReplaceAll(string, " ", " ")
}
return strings.TrimSpace(string)
}

29
dns/client_truncate.go Normal file
View file

@ -0,0 +1,29 @@
package dns
import (
"github.com/sagernet/sing/common/buf"
"github.com/miekg/dns"
)
func TruncateDNSMessage(request *dns.Msg, response *dns.Msg, headroom int) (*buf.Buffer, error) {
maxLen := 512
if edns0Option := request.IsEdns0(); edns0Option != nil {
if udpSize := int(edns0Option.UDPSize()); udpSize > 512 {
maxLen = udpSize
}
}
responseLen := response.Len()
if responseLen > maxLen {
response.Truncate(maxLen)
}
buffer := buf.NewSize(headroom*2 + 1 + responseLen)
buffer.Resize(headroom, 0)
rawMessage, err := response.PackBuffer(buffer.FreeBytes())
if err != nil {
buffer.Release()
return nil, err
}
buffer.Truncate(len(rawMessage))
return buffer, nil
}

View file

@ -0,0 +1,56 @@
package dns
import (
"net/netip"
"github.com/miekg/dns"
)
func SetClientSubnet(message *dns.Msg, clientSubnet netip.Prefix, override bool) *dns.Msg {
var (
optRecord *dns.OPT
subnetOption *dns.EDNS0_SUBNET
)
findExists:
for _, record := range message.Extra {
var isOPTRecord bool
if optRecord, isOPTRecord = record.(*dns.OPT); isOPTRecord {
for _, option := range optRecord.Option {
var isEDNS0Subnet bool
subnetOption, isEDNS0Subnet = option.(*dns.EDNS0_SUBNET)
if isEDNS0Subnet {
if !override {
return message
}
break findExists
}
}
}
}
if optRecord == nil {
exMessage := *message
message = &exMessage
optRecord = &dns.OPT{
Hdr: dns.RR_Header{
Name: ".",
Rrtype: dns.TypeOPT,
},
}
message.Extra = append(message.Extra, optRecord)
} else {
message = message.Copy()
}
if subnetOption == nil {
subnetOption = new(dns.EDNS0_SUBNET)
optRecord.Option = append(optRecord.Option, subnetOption)
}
subnetOption.Code = dns.EDNS0SUBNET
if clientSubnet.Addr().Is4() {
subnetOption.Family = 1
} else {
subnetOption.Family = 2
}
subnetOption.SourceNetmask = uint8(clientSubnet.Bits())
subnetOption.Address = clientSubnet.Addr().AsSlice()
return message
}

33
dns/rcode.go Normal file
View file

@ -0,0 +1,33 @@
package dns
import F "github.com/sagernet/sing/common/format"
const (
RCodeSuccess RCodeError = 0 // NoError
RCodeFormatError RCodeError = 1 // FormErr
RCodeServerFailure RCodeError = 2 // ServFail
RCodeNameError RCodeError = 3 // NXDomain
RCodeNotImplemented RCodeError = 4 // NotImp
RCodeRefused RCodeError = 5 // Refused
)
type RCodeError uint16
func (e RCodeError) Error() string {
switch e {
case RCodeSuccess:
return "success"
case RCodeFormatError:
return "format error"
case RCodeServerFailure:
return "server failure"
case RCodeNameError:
return "name error"
case RCodeNotImplemented:
return "not implemented"
case RCodeRefused:
return "refused"
default:
return F.ToString("unknown error: ", uint16(e))
}
}

437
dns/router.go Normal file
View file

@ -0,0 +1,437 @@
package dns
import (
"context"
"errors"
"net/netip"
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
R "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
"github.com/sagernet/sing/service"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSRouter = (*Router)(nil)
type Router struct {
ctx context.Context
logger logger.ContextLogger
transport adapter.DNSTransportManager
outbound adapter.OutboundManager
client adapter.DNSClient
rules []adapter.DNSRule
defaultDomainStrategy C.DomainStrategy
dnsReverseMapping freelru.Cache[netip.Addr, string]
platformInterface platform.Interface
}
func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) *Router {
router := &Router{
ctx: ctx,
logger: logFactory.NewLogger("dns"),
transport: service.FromContext[adapter.DNSTransportManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
rules: make([]adapter.DNSRule, 0, len(options.Rules)),
defaultDomainStrategy: C.DomainStrategy(options.Strategy),
}
router.client = NewClient(ClientOptions{
DisableCache: options.DNSClientOptions.DisableCache,
DisableExpire: options.DNSClientOptions.DisableExpire,
IndependentCache: options.DNSClientOptions.IndependentCache,
CacheCapacity: options.DNSClientOptions.CacheCapacity,
RDRC: func() adapter.RDRCStore {
cacheFile := service.FromContext[adapter.CacheFile](ctx)
if cacheFile == nil {
return nil
}
if !cacheFile.StoreRDRC() {
return nil
}
return cacheFile
},
Logger: router.logger,
})
if options.ReverseMapping {
router.dnsReverseMapping = common.Must1(freelru.NewSharded[netip.Addr, string](1024, maphash.NewHasher[netip.Addr]().Hash32))
}
return router
}
func (r *Router) Initialize(rules []option.DNSRule) error {
for i, ruleOptions := range rules {
dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true)
if err != nil {
return E.Cause(err, "parse dns rule[", i, "]")
}
r.rules = append(r.rules, dnsRule)
}
return nil
}
func (r *Router) Start(stage adapter.StartStage) error {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
switch stage {
case adapter.StartStateStart:
monitor.Start("initialize DNS client")
r.client.Start()
monitor.Finish()
for i, rule := range r.rules {
monitor.Start("initialize DNS rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS rule[", i, "]")
}
}
}
return nil
}
func (r *Router) Close() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout)
var err error
for i, rule := range r.rules {
monitor.Start("close dns rule[", i, "]")
err = E.Append(err, rule.Close(), func(err error) error {
return E.Cause(err, "close dns rule[", i, "]")
})
monitor.Finish()
}
return err
}
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
panic("no context")
}
var currentRuleIndex int
if ruleIndex != -1 {
currentRuleIndex = ruleIndex + 1
}
for ; currentRuleIndex < len(r.rules); currentRuleIndex++ {
currentRule := r.rules[currentRuleIndex]
if currentRule.WithAddressLimit() && !isAddressQuery {
continue
}
metadata.ResetRuleCache()
if currentRule.Match(metadata) {
displayRuleIndex := currentRuleIndex
if displayRuleIndex != -1 {
displayRuleIndex += displayRuleIndex + 1
}
ruleDescription := currentRule.String()
if ruleDescription != "" {
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] ", currentRule, " => ", currentRule.Action())
} else {
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
}
switch action := currentRule.Action().(type) {
case *R.RuleActionDNSRoute:
transport, loaded := r.transport.Transport(action.Server)
if !loaded {
r.logger.ErrorContext(ctx, "transport not found: ", action.Server)
continue
}
isFakeIP := transport.Type() == C.DNSTypeFakeIP
if isFakeIP && !allowFakeIP {
continue
}
if action.Strategy != C.DomainStrategyAsIS {
options.Strategy = action.Strategy
}
if isFakeIP || action.DisableCache {
options.DisableCache = true
}
if action.RewriteTTL != nil {
options.RewriteTTL = action.RewriteTTL
}
if action.ClientSubnet.IsValid() {
options.ClientSubnet = action.ClientSubnet
}
if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy {
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = legacyTransport.LegacyStrategy()
}
if !options.ClientSubnet.IsValid() {
options.ClientSubnet = legacyTransport.LegacyClientSubnet()
}
}
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
return transport, currentRule, currentRuleIndex
case *R.RuleActionDNSRouteOptions:
if action.Strategy != C.DomainStrategyAsIS {
options.Strategy = action.Strategy
}
if action.DisableCache {
options.DisableCache = true
}
if action.RewriteTTL != nil {
options.RewriteTTL = action.RewriteTTL
}
if action.ClientSubnet.IsValid() {
options.ClientSubnet = action.ClientSubnet
}
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
case *R.RuleActionReject:
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
return nil, currentRule, currentRuleIndex
}
}
}
return r.transport.Default(), nil, -1
}
func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) {
if len(message.Question) != 1 {
r.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
responseMessage := mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Id: message.Id,
Response: true,
Rcode: mDNS.RcodeFormatError,
},
Question: message.Question,
}
return &responseMessage, nil
}
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
var (
transport adapter.DNSTransport
err error
)
response, cached := r.client.ExchangeCache(ctx, message)
if !cached {
var metadata *adapter.InboundContext
ctx, metadata = adapter.ExtendContext(ctx)
metadata.Destination = M.Socksaddr{}
metadata.QueryType = message.Question[0].Qtype
switch metadata.QueryType {
case mDNS.TypeA:
metadata.IPVersion = 4
case mDNS.TypeAAAA:
metadata.IPVersion = 6
}
metadata.Domain = FqdnToDomain(message.Question[0].Name)
if options.Transport != nil {
transport = options.Transport
if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy {
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = legacyTransport.LegacyStrategy()
}
if !options.ClientSubnet.IsValid() {
options.ClientSubnet = legacyTransport.LegacyClientSubnet()
}
}
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = r.defaultDomainStrategy
}
response, err = r.client.Exchange(ctx, transport, message, options, nil)
} else {
var (
rule adapter.DNSRule
ruleIndex int
)
ruleIndex = -1
for {
dnsCtx := adapter.OverrideContext(ctx)
dnsOptions := options
transport, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message), &dnsOptions)
if rule != nil {
switch action := rule.Action().(type) {
case *R.RuleActionReject:
switch action.Method {
case C.RuleActionRejectMethodDefault:
return FixedResponse(message.Id, message.Question[0], nil, 0), nil
case C.RuleActionRejectMethodDrop:
return nil, tun.ErrDrop
}
}
}
var responseCheck func(responseAddrs []netip.Addr) bool
if rule != nil && rule.WithAddressLimit() {
responseCheck = func(responseAddrs []netip.Addr) bool {
metadata.DestinationAddresses = responseAddrs
return rule.MatchAddressLimit(metadata)
}
}
if dnsOptions.Strategy == C.DomainStrategyAsIS {
dnsOptions.Strategy = r.defaultDomainStrategy
}
response, err = r.client.Exchange(dnsCtx, transport, message, dnsOptions, responseCheck)
var rejected bool
if err != nil {
if errors.Is(err, ErrResponseRejectedCached) {
rejected = true
r.logger.DebugContext(ctx, E.Cause(err, "response rejected for ", FormatQuestion(message.Question[0].String())), " (cached)")
} else if errors.Is(err, ErrResponseRejected) {
rejected = true
r.logger.DebugContext(ctx, E.Cause(err, "response rejected for ", FormatQuestion(message.Question[0].String())))
} else if len(message.Question) > 0 {
r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", FormatQuestion(message.Question[0].String())))
} else {
r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for <empty query>"))
}
}
if responseCheck != nil && rejected {
continue
}
break
}
}
}
if err != nil {
return nil, err
}
if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 {
if transport == nil || transport.Type() != C.DNSTypeFakeIP {
for _, answer := range response.Answer {
switch record := answer.(type) {
case *mDNS.A:
r.dnsReverseMapping.AddWithLifetime(M.AddrFromIP(record.A), FqdnToDomain(record.Hdr.Name), time.Duration(record.Hdr.Ttl)*time.Second)
case *mDNS.AAAA:
r.dnsReverseMapping.AddWithLifetime(M.AddrFromIP(record.AAAA), FqdnToDomain(record.Hdr.Name), time.Duration(record.Hdr.Ttl)*time.Second)
}
}
}
}
return response, nil
}
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
var (
responseAddrs []netip.Addr
cached bool
err error
)
printResult := func() {
if err != nil {
if errors.Is(err, ErrResponseRejectedCached) {
r.logger.DebugContext(ctx, "response rejected for ", domain, " (cached)")
} else if errors.Is(err, ErrResponseRejected) {
r.logger.DebugContext(ctx, "response rejected for ", domain)
} else {
r.logger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
}
} else if len(responseAddrs) == 0 {
r.logger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
err = RCodeNameError
}
}
responseAddrs, cached = r.client.LookupCache(domain, options.Strategy)
if cached {
if len(responseAddrs) == 0 {
return nil, RCodeNameError
}
return responseAddrs, nil
}
r.logger.DebugContext(ctx, "lookup domain ", domain)
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Destination = M.Socksaddr{}
metadata.Domain = FqdnToDomain(domain)
if options.Transport != nil {
transport := options.Transport
if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy {
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = r.defaultDomainStrategy
}
if !options.ClientSubnet.IsValid() {
options.ClientSubnet = legacyTransport.LegacyClientSubnet()
}
}
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = r.defaultDomainStrategy
}
responseAddrs, err = r.client.Lookup(ctx, transport, domain, options, nil)
} else {
var (
transport adapter.DNSTransport
rule adapter.DNSRule
ruleIndex int
)
ruleIndex = -1
for {
dnsCtx := adapter.OverrideContext(ctx)
transport, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true, &options)
if rule != nil {
switch action := rule.Action().(type) {
case *R.RuleActionReject:
switch action.Method {
case C.RuleActionRejectMethodDefault:
return nil, nil
case C.RuleActionRejectMethodDrop:
return nil, tun.ErrDrop
}
}
}
var responseCheck func(responseAddrs []netip.Addr) bool
if rule != nil && rule.WithAddressLimit() {
responseCheck = func(responseAddrs []netip.Addr) bool {
metadata.DestinationAddresses = responseAddrs
return rule.MatchAddressLimit(metadata)
}
}
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = r.defaultDomainStrategy
}
responseAddrs, err = r.client.Lookup(dnsCtx, transport, domain, options, responseCheck)
if responseCheck == nil || err == nil {
break
}
printResult()
}
}
printResult()
if len(responseAddrs) > 0 {
r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " "))
}
return responseAddrs, err
}
func isAddressQuery(message *mDNS.Msg) bool {
for _, question := range message.Question {
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA || question.Qtype == mDNS.TypeHTTPS {
return true
}
}
return false
}
func (r *Router) ClearCache() {
r.client.ClearCache()
if r.platformInterface != nil {
r.platformInterface.ClearDNSCache()
}
}
func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) {
if r.dnsReverseMapping == nil {
return "", false
}
domain, loaded := r.dnsReverseMapping.Get(ip)
return domain, loaded
}
func (r *Router) ResetNetwork() {
r.ClearCache()
for _, transport := range r.transport.Transports() {
transport.Reset()
}
}

View file

@ -3,9 +3,6 @@ package dhcp
import (
"context"
"net"
"net/netip"
"net/url"
"os"
"runtime"
"strings"
"sync"
@ -14,13 +11,18 @@ import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
@ -29,76 +31,70 @@ import (
mDNS "github.com/miekg/dns"
)
func init() {
dns.RegisterTransport([]string{"dhcp"}, func(options dns.TransportOptions) (dns.Transport, error) {
return NewTransport(options)
})
func RegisterTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.DHCPDNSServerOptions](registry, C.DNSTypeDHCP, NewTransport)
}
var _ adapter.DNSTransport = (*Transport)(nil)
type Transport struct {
options dns.TransportOptions
router adapter.Router
dns.TransportAdapter
ctx context.Context
dialer N.Dialer
logger logger.ContextLogger
networkManager adapter.NetworkManager
interfaceName string
autoInterface bool
interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback]
transports []dns.Transport
transports []adapter.DNSTransport
updateAccess sync.Mutex
updatedAt time.Time
}
func NewTransport(options dns.TransportOptions) (*Transport, error) {
linkURL, err := url.Parse(options.Address)
func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.DHCPDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewLocalDialer(ctx, options.LocalDNSServerOptions)
if err != nil {
return nil, err
}
if linkURL.Host == "" {
return nil, E.New("missing interface name for DHCP")
}
transport := &Transport{
options: options,
networkManager: service.FromContext[adapter.NetworkManager](options.Context),
interfaceName: linkURL.Host,
autoInterface: linkURL.Host == "auto",
}
return transport, nil
return &Transport{
TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeDHCP, tag, options.LocalDNSServerOptions),
ctx: ctx,
dialer: transportDialer,
logger: logger,
networkManager: service.FromContext[adapter.NetworkManager](ctx),
interfaceName: options.Interface,
}, nil
}
func (t *Transport) Name() string {
return t.options.Name
}
func (t *Transport) Start() error {
func (t *Transport) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
err := t.fetchServers()
if err != nil {
return err
}
if t.autoInterface {
if t.interfaceName == "" {
t.interfaceCallback = t.networkManager.InterfaceMonitor().RegisterCallback(t.interfaceUpdated)
}
return nil
}
func (t *Transport) Close() error {
for _, transport := range t.transports {
transport.Reset()
}
if t.interfaceCallback != nil {
t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback)
}
return nil
}
func (t *Transport) Reset() {
for _, transport := range t.transports {
transport.Reset()
}
}
func (t *Transport) Close() error {
for _, transport := range t.transports {
transport.Close()
}
if t.interfaceCallback != nil {
t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback)
}
return nil
}
func (t *Transport) Raw() bool {
return true
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
err := t.fetchServers()
if err != nil {
@ -120,7 +116,7 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg,
}
func (t *Transport) fetchInterface() (*control.Interface, error) {
if t.autoInterface {
if t.interfaceName == "" {
if t.networkManager.InterfaceMonitor() == nil {
return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface")
}
@ -152,8 +148,8 @@ func (t *Transport) updateServers() error {
return E.Cause(err, "dhcp: prepare interface")
}
t.options.Logger.Info("dhcp: query DNS servers on ", iface.Name)
fetchCtx, cancel := context.WithTimeout(t.options.Context, C.DHCPTimeout)
t.logger.Info("dhcp: query DNS servers on ", iface.Name)
fetchCtx, cancel := context.WithTimeout(t.ctx, C.DHCPTimeout)
err = t.fetchServers0(fetchCtx, iface)
cancel()
if err != nil {
@ -169,7 +165,7 @@ func (t *Transport) updateServers() error {
func (t *Transport) interfaceUpdated(defaultInterface *control.Interface, flags int) {
err := t.updateServers()
if err != nil {
t.options.Logger.Error("update servers: ", err)
t.logger.Error("update servers: ", err)
}
}
@ -181,7 +177,7 @@ func (t *Transport) fetchServers0(ctx context.Context, iface *control.Interface)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
listenAddr = "255.255.255.255:68"
}
packetConn, err := listener.ListenPacket(t.options.Context, "udp4", listenAddr)
packetConn, err := listener.ListenPacket(t.ctx, "udp4", listenAddr)
if err != nil {
return err
}
@ -219,17 +215,17 @@ func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn ne
dhcpPacket, err := dhcpv4.FromBytes(buffer.Bytes())
if err != nil {
t.options.Logger.Trace("dhcp: parse DHCP response: ", err)
t.logger.Trace("dhcp: parse DHCP response: ", err)
return err
}
if dhcpPacket.MessageType() != dhcpv4.MessageTypeOffer {
t.options.Logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType())
t.logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType())
continue
}
if dhcpPacket.TransactionID != transactionID {
t.options.Logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID)
t.logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID)
continue
}
@ -237,44 +233,27 @@ func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn ne
if len(dns) == 0 {
return nil
}
var addrs []netip.Addr
for _, ip := range dns {
addr, _ := netip.AddrFromSlice(ip)
addrs = append(addrs, addr.Unmap())
}
return t.recreateServers(iface, addrs)
return t.recreateServers(iface, common.Map(dns, func(it net.IP) M.Socksaddr {
return M.SocksaddrFrom(M.AddrFromIP(it), 53)
}))
}
}
func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []netip.Addr) error {
func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []M.Socksaddr) error {
if len(serverAddrs) > 0 {
t.options.Logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string {
return it.String()
}), ","), "]")
t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, M.Socksaddr.String), ","), "]")
}
serverDialer := common.Must1(dialer.NewDefault(t.options.Context, option.DialerOptions{
serverDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{
BindInterface: iface.Name,
UDPFragmentDefault: true,
}))
var transports []dns.Transport
var transports []adapter.DNSTransport
for _, serverAddr := range serverAddrs {
newOptions := t.options
newOptions.Address = serverAddr.String()
newOptions.Dialer = serverDialer
serverTransport, err := dns.NewUDPTransport(newOptions)
if err != nil {
return E.Cause(err, "create UDP transport from DHCP result: ", serverAddr)
}
transports = append(transports, serverTransport)
transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, serverAddr))
}
for _, transport := range t.transports {
transport.Close()
transport.Reset()
}
t.transports = transports
return nil
}
func (t *Transport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
return nil, os.ErrInvalid
}

View file

@ -0,0 +1,67 @@
package fakeip
import (
"context"
"net/netip"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
mDNS "github.com/miekg/dns"
)
func RegisterTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.FakeIPDNSServerOptions](registry, C.DNSTypeFakeIP, NewTransport)
}
var _ adapter.FakeIPTransport = (*Transport)(nil)
type Transport struct {
dns.TransportAdapter
logger logger.ContextLogger
store adapter.FakeIPStore
}
func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.FakeIPDNSServerOptions) (adapter.DNSTransport, error) {
store := NewStore(ctx, logger, options.Inet4Range.Build(netip.Prefix{}), options.Inet6Range.Build(netip.Prefix{}))
return &Transport{
TransportAdapter: dns.NewTransportAdapter(C.DNSTypeFakeIP, tag, nil),
logger: logger,
store: store,
}, nil
}
func (t *Transport) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return t.store.Start()
}
func (t *Transport) Close() error {
return t.store.Close()
}
func (t *Transport) Reset() {
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
question := message.Question[0]
if question.Qtype != mDNS.TypeA && question.Qtype != mDNS.TypeAAAA {
return nil, E.New("only IP queries are supported by fakeip")
}
address, err := t.store.Create(dns.FqdnToDomain(question.Name), question.Qtype == mDNS.TypeAAAA)
if err != nil {
return nil, err
}
return dns.FixedResponse(message.Id, question, []netip.Addr{address}, C.DefaultDNSTTL), nil
}
func (t *Transport) Store() adapter.FakeIPStore {
return t.store
}

View file

@ -0,0 +1,63 @@
package hosts
import (
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
mDNS "github.com/miekg/dns"
)
func RegisterTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.HostsDNSServerOptions](registry, C.DNSTypeHosts, NewTransport)
}
var _ adapter.DNSTransport = (*Transport)(nil)
type Transport struct {
dns.TransportAdapter
files []*File
}
func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.HostsDNSServerOptions) (adapter.DNSTransport, error) {
var files []*File
if len(options.Path) == 0 {
files = append(files, NewFile(DefaultPath))
} else {
for _, path := range options.Path {
files = append(files, NewFile(path))
}
}
return &Transport{
TransportAdapter: dns.NewTransportAdapter(C.DNSTypeHosts, tag, nil),
files: files,
}, nil
}
func (t *Transport) Reset() {
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
question := message.Question[0]
domain := dns.FqdnToDomain(question.Name)
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
for _, file := range t.files {
addresses := file.Lookup(domain)
if len(addresses) > 0 {
return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
}
}
}
return &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Id: message.Id,
Rcode: mDNS.RcodeNameError,
Response: true,
},
Question: []mDNS.Question{question},
}, nil
}

View file

@ -0,0 +1,102 @@
package hosts
import (
"bufio"
"errors"
"io"
"net/netip"
"os"
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
const cacheMaxAge = 5 * time.Second
type File struct {
path string
access sync.Mutex
byName map[string][]netip.Addr
expire time.Time
modTime time.Time
size int64
}
func NewFile(path string) *File {
return &File{
path: path,
}
}
func (f *File) Lookup(name string) []netip.Addr {
f.access.Lock()
defer f.access.Unlock()
f.update()
return f.byName[name]
}
func (f *File) update() {
now := time.Now()
if now.Before(f.expire) && len(f.byName) > 0 {
return
}
stat, err := os.Stat(f.path)
if err != nil {
return
}
if f.modTime.Equal(stat.ModTime()) && f.size == stat.Size() {
f.expire = now.Add(cacheMaxAge)
return
}
byName := make(map[string][]netip.Addr)
file, err := os.Open(f.path)
if err != nil {
return
}
defer file.Close()
reader := bufio.NewReader(file)
var (
prefix []byte
line []byte
isPrefix bool
)
for {
line, isPrefix, err = reader.ReadLine()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return
}
if isPrefix {
prefix = append(prefix, line...)
continue
} else if len(prefix) > 0 {
line = append(prefix, line...)
prefix = nil
}
commentIndex := strings.IndexRune(string(line), '#')
if commentIndex != -1 {
line = line[:commentIndex]
}
fields := strings.Fields(string(line))
if len(fields) < 2 {
continue
}
var addr netip.Addr
addr, err = netip.ParseAddr(fields[0])
if err != nil {
continue
}
for index := 1; index < len(fields); index++ {
canonicalName := dns.CanonicalName(fields[index])
byName[canonicalName] = append(byName[canonicalName], addr)
}
}
f.expire = now.Add(cacheMaxAge)
f.modTime = stat.ModTime()
f.size = stat.Size()
f.byName = byName
}

View file

@ -0,0 +1,16 @@
package hosts_test
import (
"net/netip"
"testing"
"github.com/sagernet/sing-box/dns/transport/hosts"
"github.com/stretchr/testify/require"
)
func TestHosts(t *testing.T) {
t.Parallel()
require.Equal(t, []netip.Addr{netip.AddrFrom4([4]byte{127, 0, 0, 1}), netip.IPv6Loopback()}, hosts.NewFile("testdata/hosts").Lookup("localhost."))
require.NotEmpty(t, hosts.NewFile(hosts.DefaultPath).Lookup("localhost."))
}

View file

@ -0,0 +1,5 @@
//go:build !windows
package hosts
var DefaultPath = "/etc/hosts"

View file

@ -0,0 +1,17 @@
package hosts
import (
"path/filepath"
"golang.org/x/sys/windows"
)
var DefaultPath string
func init() {
systemDirectory, err := windows.GetSystemDirectory()
if err != nil {
systemDirectory = "C:\\Windows\\System32"
}
DefaultPath = filepath.Join(systemDirectory, "Drivers/etc/hosts")
}

2
dns/transport/hosts/testdata/hosts vendored Normal file
View file

@ -0,0 +1,2 @@
127.0.0.1 localhost
::1 localhost

204
dns/transport/https.go Normal file
View file

@ -0,0 +1,204 @@
package transport
import (
"bytes"
"context"
"io"
"net"
"net/http"
"net/url"
"strconv"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
sHTTP "github.com/sagernet/sing/protocol/http"
mDNS "github.com/miekg/dns"
"golang.org/x/net/http2"
)
const MimeType = "application/dns-message"
var _ adapter.DNSTransport = (*HTTPSTransport)(nil)
func RegisterHTTPS(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteHTTPSDNSServerOptions](registry, C.DNSTypeHTTPS, NewHTTPS)
}
type HTTPSTransport struct {
dns.TransportAdapter
logger logger.ContextLogger
dialer N.Dialer
destination *url.URL
headers http.Header
transport *http.Transport
}
func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
if err != nil {
return nil, err
}
tlsOptions := common.PtrValueOrDefault(options.TLS)
tlsOptions.Enabled = true
tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions)
if err != nil {
return nil, err
}
if common.Error(tlsConfig.Config()) == nil && !common.Contains(tlsConfig.NextProtos(), http2.NextProtoTLS) {
tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), http2.NextProtoTLS))
}
if !common.Contains(tlsConfig.NextProtos(), "http/1.1") {
tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), "http/1.1"))
}
headers := options.Headers.Build()
host := headers.Get("Host")
if host != "" {
headers.Del("Host")
} else {
if tlsConfig.ServerName() != "" {
host = tlsConfig.ServerName()
} else {
host = options.Server
}
}
destinationURL := url.URL{
Scheme: "https",
Host: host,
}
if destinationURL.Host == "" {
destinationURL.Host = options.Server
}
if options.ServerPort != 0 && options.ServerPort != 443 {
destinationURL.Host = net.JoinHostPort(destinationURL.Host, strconv.Itoa(int(options.ServerPort)))
}
path := options.Path
if path == "" {
path = "/dns-query"
}
err = sHTTP.URLSetPath(&destinationURL, path)
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 443
}
return NewHTTPSRaw(
dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTPS, tag, options.RemoteDNSServerOptions),
logger,
transportDialer,
&destinationURL,
headers,
serverAddr,
tlsConfig,
), nil
}
func NewHTTPSRaw(
adapter dns.TransportAdapter,
logger log.ContextLogger,
dialer N.Dialer,
destination *url.URL,
headers http.Header,
serverAddr M.Socksaddr,
tlsConfig tls.Config,
) *HTTPSTransport {
var transport *http.Transport
if tlsConfig != nil {
transport = &http.Transport{
ForceAttemptHTTP2: true,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
tcpConn, hErr := dialer.DialContext(ctx, network, serverAddr)
if hErr != nil {
return nil, hErr
}
tlsConn, hErr := aTLS.ClientHandshake(ctx, tcpConn, tlsConfig)
if hErr != nil {
tcpConn.Close()
return nil, hErr
}
return tlsConn, nil
},
}
} else {
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, serverAddr)
},
}
}
return &HTTPSTransport{
TransportAdapter: adapter,
logger: logger,
dialer: dialer,
destination: destination,
headers: headers,
transport: transport,
}
}
func (t *HTTPSTransport) Reset() {
t.transport.CloseIdleConnections()
t.transport = t.transport.Clone()
}
func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
exMessage := *message
exMessage.Id = 0
exMessage.Compress = true
requestBuffer := buf.NewSize(1 + message.Len())
rawMessage, err := exMessage.PackBuffer(requestBuffer.FreeBytes())
if err != nil {
requestBuffer.Release()
return nil, err
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, t.destination.String(), bytes.NewReader(rawMessage))
if err != nil {
requestBuffer.Release()
return nil, err
}
request.Header = t.headers.Clone()
request.Header.Set("Content-Type", MimeType)
request.Header.Set("Accept", MimeType)
response, err := t.transport.RoundTrip(request)
requestBuffer.Release()
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return nil, E.New("unexpected status: ", response.Status)
}
var responseMessage mDNS.Msg
if response.ContentLength > 0 {
responseBuffer := buf.NewSize(int(response.ContentLength))
_, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength))
if err != nil {
return nil, err
}
err = responseMessage.Unpack(responseBuffer.Bytes())
responseBuffer.Release()
} else {
rawMessage, err = io.ReadAll(response.Body)
if err != nil {
return nil, err
}
err = responseMessage.Unpack(rawMessage)
}
if err != nil {
return nil, err
}
return &responseMessage, nil
}

View file

@ -0,0 +1,197 @@
package local
import (
"context"
"math/rand"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport/hosts"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns"
)
func RegisterTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.LocalDNSServerOptions](registry, C.DNSTypeLocal, NewTransport)
}
var _ adapter.DNSTransport = (*Transport)(nil)
type Transport struct {
dns.TransportAdapter
hosts *hosts.File
dialer N.Dialer
}
func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewLocalDialer(ctx, options)
if err != nil {
return nil, err
}
return &Transport{
TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options),
hosts: hosts.NewFile(hosts.DefaultPath),
dialer: transportDialer,
}, nil
}
func (t *Transport) Reset() {
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
question := message.Question[0]
domain := dns.FqdnToDomain(question.Name)
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
addresses := t.hosts.Lookup(domain)
if len(addresses) > 0 {
return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
}
}
systemConfig := getSystemDNSConfig()
if systemConfig.singleRequest || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) {
return t.exchangeSingleRequest(ctx, systemConfig, message, domain)
} else {
return t.exchangeParallel(ctx, systemConfig, message, domain)
}
}
func (t *Transport) exchangeSingleRequest(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
var lastErr error
for _, fqdn := range systemConfig.nameList(domain) {
response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
if err != nil {
lastErr = err
continue
}
return response, nil
}
return nil, lastErr
}
func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
returned := make(chan struct{})
defer close(returned)
type queryResult struct {
response *mDNS.Msg
err error
}
results := make(chan queryResult)
startRacer := func(ctx context.Context, fqdn string) {
response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
if err == nil {
addresses, _ := dns.MessageToAddresses(response)
if len(addresses) == 0 {
err = E.New(fqdn, ": empty result")
}
}
select {
case results <- queryResult{response, err}:
case <-returned:
}
}
queryCtx, queryCancel := context.WithCancel(ctx)
defer queryCancel()
var nameCount int
for _, fqdn := range systemConfig.nameList(domain) {
nameCount++
go startRacer(queryCtx, fqdn)
}
var errors []error
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case result := <-results:
if result.err == nil {
return result.response, nil
}
errors = append(errors, result.err)
if len(errors) == nameCount {
return nil, E.Errors(errors...)
}
}
}
}
func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) {
serverOffset := config.serverOffset()
sLen := uint32(len(config.servers))
var lastErr error
for i := 0; i < config.attempts; i++ {
for j := uint32(0); j < sLen; j++ {
server := config.servers[(serverOffset+j)%sLen]
question := message.Question[0]
question.Name = fqdn
response, err := t.exchangeOne(ctx, M.ParseSocksaddr(server), question, config.timeout, config.useTCP, config.trustAD)
if err != nil {
lastErr = err
continue
}
return response, nil
}
}
return nil, E.Cause(lastErr, fqdn)
}
func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
var networks []string
if useTCP {
networks = []string{N.NetworkTCP}
} else {
networks = []string{N.NetworkUDP, N.NetworkTCP}
}
request := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Id: uint16(rand.Uint32()),
RecursionDesired: true,
AuthenticatedData: ad,
},
Question: []mDNS.Question{question},
Compress: true,
}
request.SetEdns0(maxDNSPacketSize, false)
buffer := buf.Get(buf.UDPBufferSize)
defer buf.Put(buffer)
for _, network := range networks {
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
conn, err := t.dialer.DialContext(ctx, network, server)
if err != nil {
return nil, err
}
defer conn.Close()
if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
conn.SetDeadline(deadline)
}
rawMessage, err := request.PackBuffer(buffer)
if err != nil {
return nil, E.Cause(err, "pack request")
}
_, err = conn.Write(rawMessage)
if err != nil {
return nil, E.Cause(err, "write request")
}
n, err := conn.Read(buffer)
if err != nil {
return nil, E.Cause(err, "read response")
}
var response mDNS.Msg
err = response.Unpack(buffer[:n])
if err != nil {
return nil, E.Cause(err, "unpack response")
}
if response.Truncated && network == N.NetworkUDP {
continue
}
return &response, nil
}
panic("unexpected")
}

View file

@ -0,0 +1,146 @@
package local
import (
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
)
const (
// net.maxDNSPacketSize
maxDNSPacketSize = 1232
)
type resolverConfig struct {
initOnce sync.Once
ch chan struct{}
lastChecked time.Time
dnsConfig atomic.Pointer[dnsConfig]
}
var resolvConf resolverConfig
func getSystemDNSConfig() *dnsConfig {
resolvConf.tryUpdate("/etc/resolv.conf")
return resolvConf.dnsConfig.Load()
}
func (conf *resolverConfig) init() {
conf.dnsConfig.Store(dnsReadConfig("/etc/resolv.conf"))
conf.lastChecked = time.Now()
conf.ch = make(chan struct{}, 1)
}
func (conf *resolverConfig) tryUpdate(name string) {
conf.initOnce.Do(conf.init)
if conf.dnsConfig.Load().noReload {
return
}
if !conf.tryAcquireSema() {
return
}
defer conf.releaseSema()
now := time.Now()
if conf.lastChecked.After(now.Add(-5 * time.Second)) {
return
}
conf.lastChecked = now
if runtime.GOOS != "windows" {
var mtime time.Time
if fi, err := os.Stat(name); err == nil {
mtime = fi.ModTime()
}
if mtime.Equal(conf.dnsConfig.Load().mtime) {
return
}
}
dnsConf := dnsReadConfig(name)
conf.dnsConfig.Store(dnsConf)
}
func (conf *resolverConfig) tryAcquireSema() bool {
select {
case conf.ch <- struct{}{}:
return true
default:
return false
}
}
func (conf *resolverConfig) releaseSema() {
<-conf.ch
}
type dnsConfig struct {
servers []string
search []string
ndots int
timeout time.Duration
attempts int
rotate bool
unknownOpt bool
lookup []string
err error
mtime time.Time
soffset uint32
singleRequest bool
useTCP bool
trustAD bool
noReload bool
}
func (c *dnsConfig) serverOffset() uint32 {
if c.rotate {
return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start
}
return 0
}
func (conf *dnsConfig) nameList(name string) []string {
l := len(name)
rooted := l > 0 && name[l-1] == '.'
if l > 254 || l == 254 && !rooted {
return nil
}
if rooted {
if avoidDNS(name) {
return nil
}
return []string{name}
}
hasNdots := strings.Count(name, ".") >= conf.ndots
name += "."
// l++
names := make([]string, 0, 1+len(conf.search))
if hasNdots && !avoidDNS(name) {
names = append(names, name)
}
for _, suffix := range conf.search {
fqdn := name + suffix
if !avoidDNS(fqdn) && len(fqdn) <= 254 {
names = append(names, fqdn)
}
}
if !hasNdots && !avoidDNS(name) {
names = append(names, name)
}
return names
}
func avoidDNS(name string) bool {
if name == "" {
return true
}
if name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
return strings.HasSuffix(name, ".onion")
}

View file

@ -0,0 +1,175 @@
//go:build !windows
package local
import (
"bufio"
"net"
"net/netip"
"os"
"strings"
"time"
_ "unsafe"
)
func dnsReadConfig(name string) *dnsConfig {
conf := &dnsConfig{
ndots: 1,
timeout: 5 * time.Second,
attempts: 2,
}
file, err := os.Open(name)
if err != nil {
conf.servers = defaultNS
conf.search = dnsDefaultSearch()
conf.err = err
return conf
}
defer file.Close()
fi, err := file.Stat()
if err == nil {
conf.mtime = fi.ModTime()
} else {
conf.servers = defaultNS
conf.search = dnsDefaultSearch()
conf.err = err
return conf
}
reader := bufio.NewReader(file)
var (
prefix []byte
line []byte
isPrefix bool
)
for {
line, isPrefix, err = reader.ReadLine()
if err != nil {
break
}
if isPrefix {
prefix = append(prefix, line...)
continue
} else if len(prefix) > 0 {
line = append(prefix, line...)
prefix = nil
}
if len(line) > 0 && (line[0] == ';' || line[0] == '#') {
continue
}
f := strings.Fields(string(line))
if len(f) < 1 {
continue
}
switch f[0] {
case "nameserver":
if len(f) > 1 && len(conf.servers) < 3 {
if _, err := netip.ParseAddr(f[1]); err == nil {
conf.servers = append(conf.servers, net.JoinHostPort(f[1], "53"))
}
}
case "domain":
if len(f) > 1 {
conf.search = []string{ensureRooted(f[1])}
}
case "search":
conf.search = make([]string, 0, len(f)-1)
for i := 1; i < len(f); i++ {
name := ensureRooted(f[i])
if name == "." {
continue
}
conf.search = append(conf.search, name)
}
case "options":
for _, s := range f[1:] {
switch {
case strings.HasPrefix(s, "ndots:"):
n, _, _ := dtoi(s[6:])
if n < 0 {
n = 0
} else if n > 15 {
n = 15
}
conf.ndots = n
case strings.HasPrefix(s, "timeout:"):
n, _, _ := dtoi(s[8:])
if n < 1 {
n = 1
}
conf.timeout = time.Duration(n) * time.Second
case strings.HasPrefix(s, "attempts:"):
n, _, _ := dtoi(s[9:])
if n < 1 {
n = 1
}
conf.attempts = n
case s == "rotate":
conf.rotate = true
case s == "single-request" || s == "single-request-reopen":
conf.singleRequest = true
case s == "use-vc" || s == "usevc" || s == "tcp":
conf.useTCP = true
case s == "trust-ad":
conf.trustAD = true
case s == "edns0":
case s == "no-reload":
conf.noReload = true
default:
conf.unknownOpt = true
}
}
case "lookup":
conf.lookup = f[1:]
default:
conf.unknownOpt = true
}
}
if len(conf.servers) == 0 {
conf.servers = defaultNS
}
if len(conf.search) == 0 {
conf.search = dnsDefaultSearch()
}
return conf
}
//go:linkname defaultNS net.defaultNS
var defaultNS []string
func dnsDefaultSearch() []string {
hn, err := os.Hostname()
if err != nil {
return nil
}
if i := strings.IndexRune(hn, '.'); i >= 0 && i < len(hn)-1 {
return []string{ensureRooted(hn[i+1:])}
}
return nil
}
func ensureRooted(s string) string {
if len(s) > 0 && s[len(s)-1] == '.' {
return s
}
return s + "."
}
const big = 0xFFFFFF
func dtoi(s string) (n int, i int, ok bool) {
n = 0
for i = 0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
n = n*10 + int(s[i]-'0')
if n >= big {
return big, i, false
}
}
if i == 0 {
return 0, 0, false
}
return n, i, true
}

View file

@ -0,0 +1,100 @@
package local
import (
"net"
"net/netip"
"os"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
func dnsReadConfig(_ string) *dnsConfig {
conf := &dnsConfig{
ndots: 1,
timeout: 5 * time.Second,
attempts: 2,
}
defer func() {
if len(conf.servers) == 0 {
conf.servers = defaultNS
}
}()
aas, err := adapterAddresses()
if err != nil {
return nil
}
for _, aa := range aas {
// Only take interfaces whose OperStatus is IfOperStatusUp(0x01) into DNS configs.
if aa.OperStatus != windows.IfOperStatusUp {
continue
}
// Only take interfaces which have at least one gateway
if aa.FirstGatewayAddress == nil {
continue
}
for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next {
sa, err := dns.Address.Sockaddr.Sockaddr()
if err != nil {
continue
}
var ip netip.Addr
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
ip = netip.AddrFrom4([4]byte{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]})
case *syscall.SockaddrInet6:
var addr16 [16]byte
copy(addr16[:], sa.Addr[:])
if addr16[0] == 0xfe && addr16[1] == 0xc0 {
// fec0/10 IPv6 addresses are site local anycast DNS
// addresses Microsoft sets by default if no other
// IPv6 DNS address is set. Site local anycast is
// deprecated since 2004, see
// https://datatracker.ietf.org/doc/html/rfc3879
continue
}
ip = netip.AddrFrom16(addr16)
default:
// Unexpected type.
continue
}
conf.servers = append(conf.servers, net.JoinHostPort(ip.String(), "53"))
}
}
return conf
}
//go:linkname defaultNS net.defaultNS
var defaultNS []string
func adapterAddresses() ([]*windows.IpAdapterAddresses, error) {
var b []byte
l := uint32(15000) // recommended initial size
for {
b = make([]byte, l)
const flags = windows.GAA_FLAG_INCLUDE_PREFIX | windows.GAA_FLAG_INCLUDE_GATEWAYS
err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, flags, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l)
if err == nil {
if l == 0 {
return nil, nil
}
break
}
if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW {
return nil, os.NewSyscallError("getadaptersaddresses", err)
}
if l <= uint32(len(b)) {
return nil, os.NewSyscallError("getadaptersaddresses", err)
}
}
var aas []*windows.IpAdapterAddresses
for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next {
aas = append(aas, aa)
}
return aas, nil
}

View file

@ -0,0 +1,83 @@
package transport
import (
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*PredefinedTransport)(nil)
func RegisterPredefined(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.PredefinedDNSServerOptions](registry, C.DNSTypePreDefined, NewPredefined)
}
type PredefinedTransport struct {
dns.TransportAdapter
responses []*predefinedResponse
}
type predefinedResponse struct {
questions []mDNS.Question
answer *mDNS.Msg
}
func NewPredefined(ctx context.Context, logger log.ContextLogger, tag string, options option.PredefinedDNSServerOptions) (adapter.DNSTransport, error) {
var responses []*predefinedResponse
for _, response := range options.Responses {
questions, msg, err := response.Build()
if err != nil {
return nil, err
}
responses = append(responses, &predefinedResponse{
questions: questions,
answer: msg,
})
}
if len(responses) == 0 {
return nil, E.New("empty predefined responses")
}
return &PredefinedTransport{
TransportAdapter: dns.NewTransportAdapter(C.DNSTypePreDefined, tag, nil),
responses: responses,
}, nil
}
func (t *PredefinedTransport) Reset() {
}
func (t *PredefinedTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
for _, response := range t.responses {
for _, question := range response.questions {
if func() bool {
if question.Name == "" && question.Qtype == mDNS.TypeNone {
return true
} else if question.Name == "" {
return common.Any(message.Question, func(it mDNS.Question) bool {
return it.Qtype == question.Qtype
})
} else if question.Qtype == mDNS.TypeNone {
return common.Any(message.Question, func(it mDNS.Question) bool {
return it.Name == question.Name
})
} else {
return common.Contains(message.Question, question)
}
}() {
copyAnswer := *response.answer
copyAnswer.Id = message.Id
copyAnswer.Question = message.Question
return &copyAnswer, nil
}
}
}
return nil, dns.RCodeNameError
}

167
dns/transport/quic/http3.go Normal file
View file

@ -0,0 +1,167 @@
package quic
import (
"bytes"
"context"
"io"
"net"
"net/http"
"net/url"
"strconv"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
sHTTP "github.com/sagernet/sing/protocol/http"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*HTTP3Transport)(nil)
func RegisterHTTP3Transport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteHTTPSDNSServerOptions](registry, C.DNSTypeHTTP3, NewHTTP3)
}
type HTTP3Transport struct {
dns.TransportAdapter
logger logger.ContextLogger
dialer N.Dialer
destination *url.URL
headers http.Header
transport *http3.Transport
}
func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
if err != nil {
return nil, err
}
tlsOptions := common.PtrValueOrDefault(options.TLS)
tlsOptions.Enabled = true
tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions)
if err != nil {
return nil, err
}
stdConfig, err := tlsConfig.Config()
if err != nil {
return nil, err
}
headers := options.Headers.Build()
host := headers.Get("Host")
if host != "" {
headers.Del("Host")
} else {
if tlsConfig.ServerName() != "" {
host = tlsConfig.ServerName()
} else {
host = options.Server
}
}
destinationURL := url.URL{
Scheme: "https",
Host: host,
}
if destinationURL.Host == "" {
destinationURL.Host = options.Server
}
if options.ServerPort != 0 && options.ServerPort != 443 {
destinationURL.Host = net.JoinHostPort(destinationURL.Host, strconv.Itoa(int(options.ServerPort)))
}
path := options.Path
if path == "" {
path = "/dns-query"
}
err = sHTTP.URLSetPath(&destinationURL, path)
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 443
}
return &HTTP3Transport{
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTP3, tag, options.RemoteDNSServerOptions),
logger: logger,
dialer: transportDialer,
destination: &destinationURL,
headers: headers,
transport: &http3.Transport{
Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (quic.EarlyConnection, error) {
destinationAddr := M.ParseSocksaddr(addr)
conn, dialErr := transportDialer.DialContext(ctx, N.NetworkUDP, destinationAddr)
if dialErr != nil {
return nil, dialErr
}
return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg)
},
TLSClientConfig: stdConfig,
},
}, nil
}
func (t *HTTP3Transport) Reset() {
t.transport.Close()
}
func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
exMessage := *message
exMessage.Id = 0
exMessage.Compress = true
requestBuffer := buf.NewSize(1 + message.Len())
rawMessage, err := exMessage.PackBuffer(requestBuffer.FreeBytes())
if err != nil {
requestBuffer.Release()
return nil, err
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, t.destination.String(), bytes.NewReader(rawMessage))
if err != nil {
requestBuffer.Release()
return nil, err
}
request.Header = t.headers.Clone()
request.Header.Set("Content-Type", transport.MimeType)
request.Header.Set("Accept", transport.MimeType)
response, err := t.transport.RoundTrip(request)
requestBuffer.Release()
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return nil, E.New("unexpected status: ", response.Status)
}
var responseMessage mDNS.Msg
if response.ContentLength > 0 {
responseBuffer := buf.NewSize(int(response.ContentLength))
_, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength))
if err != nil {
return nil, err
}
err = responseMessage.Unpack(responseBuffer.Bytes())
responseBuffer.Release()
} else {
rawMessage, err = io.ReadAll(response.Body)
if err != nil {
return nil, err
}
err = responseMessage.Unpack(rawMessage)
}
if err != nil {
return nil, err
}
return &responseMessage, nil
}

174
dns/transport/quic/quic.go Normal file
View file

@ -0,0 +1,174 @@
package quic
import (
"context"
"errors"
"sync"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
sQUIC "github.com/sagernet/sing-quic"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*Transport)(nil)
func RegisterTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeQUIC, NewQUIC)
}
type Transport struct {
dns.TransportAdapter
ctx context.Context
logger logger.ContextLogger
dialer N.Dialer
serverAddr M.Socksaddr
tlsConfig tls.Config
access sync.Mutex
connection quic.EarlyConnection
}
func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
if err != nil {
return nil, err
}
tlsOptions := common.PtrValueOrDefault(options.TLS)
tlsOptions.Enabled = true
tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions)
if err != nil {
return nil, err
}
if len(tlsConfig.NextProtos()) == 0 {
tlsConfig.SetNextProtos([]string{"doq"})
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 853
}
return &Transport{
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions),
ctx: ctx,
logger: logger,
dialer: transportDialer,
serverAddr: serverAddr,
tlsConfig: tlsConfig,
}, nil
}
func (t *Transport) Reset() {
t.access.Lock()
defer t.access.Unlock()
connection := t.connection
if connection != nil {
connection.CloseWithError(0, "")
}
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
var (
conn quic.Connection
err error
response *mDNS.Msg
)
for i := 0; i < 2; i++ {
conn, err = t.openConnection()
if err != nil {
return nil, err
}
response, err = t.exchange(ctx, message, conn)
if err == nil {
return response, nil
} else if !isQUICRetryError(err) {
return nil, err
} else {
conn.CloseWithError(quic.ApplicationErrorCode(0), "")
continue
}
}
return nil, err
}
func (t *Transport) openConnection() (quic.EarlyConnection, error) {
connection := t.connection
if connection != nil && !common.Done(connection.Context()) {
return connection, nil
}
t.access.Lock()
defer t.access.Unlock()
connection = t.connection
if connection != nil && !common.Done(connection.Context()) {
return connection, nil
}
conn, err := t.dialer.DialContext(t.ctx, N.NetworkUDP, t.serverAddr)
if err != nil {
return nil, err
}
earlyConnection, err := sQUIC.DialEarly(
t.ctx,
bufio.NewUnbindPacketConn(conn),
t.serverAddr.UDPAddr(),
t.tlsConfig,
nil,
)
if err != nil {
return nil, err
}
t.connection = earlyConnection
return earlyConnection, nil
}
func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, conn quic.Connection) (*mDNS.Msg, error) {
stream, err := conn.OpenStreamSync(ctx)
if err != nil {
return nil, err
}
defer stream.Close()
defer stream.CancelRead(0)
err = transport.WriteMessage(stream, 0, message)
if err != nil {
return nil, err
}
return transport.ReadMessage(stream)
}
// https://github.com/AdguardTeam/dnsproxy/blob/fd1868577652c639cce3da00e12ca548f421baf1/upstream/upstream_quic.go#L394
func isQUICRetryError(err error) (ok bool) {
var qAppErr *quic.ApplicationError
if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 {
return true
}
var qIdleErr *quic.IdleTimeoutError
if errors.As(err, &qIdleErr) {
return true
}
var resetErr *quic.StatelessResetError
if errors.As(err, &resetErr) {
return true
}
var qTransportError *quic.TransportError
if errors.As(err, &qTransportError) && qTransportError.ErrorCode == quic.NoError {
return true
}
if errors.Is(err, quic.Err0RTTRejected) {
return true
}
return false
}

99
dns/transport/tcp.go Normal file
View file

@ -0,0 +1,99 @@
package transport
import (
"context"
"encoding/binary"
"io"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*TCPTransport)(nil)
func RegisterTCP(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeTCP, NewTCP)
}
type TCPTransport struct {
dns.TransportAdapter
dialer N.Dialer
serverAddr M.Socksaddr
}
func NewTCP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options)
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 53
}
return &TCPTransport{
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTCP, tag, options),
dialer: transportDialer,
serverAddr: serverAddr,
}, nil
}
func (t *TCPTransport) Reset() {
}
func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
if err != nil {
return nil, err
}
defer conn.Close()
err = WriteMessage(conn, 0, message)
if err != nil {
return nil, err
}
return ReadMessage(conn)
}
func ReadMessage(reader io.Reader) (*mDNS.Msg, error) {
var responseLen uint16
err := binary.Read(reader, binary.BigEndian, &responseLen)
if err != nil {
return nil, err
}
if responseLen < 10 {
return nil, mDNS.ErrShortRead
}
buffer := buf.NewSize(int(responseLen))
defer buffer.Release()
_, err = buffer.ReadFullFrom(reader, int(responseLen))
if err != nil {
return nil, err
}
var message mDNS.Msg
err = message.Unpack(buffer.Bytes())
return &message, err
}
func WriteMessage(writer io.Writer, messageId uint16, message *mDNS.Msg) error {
requestLen := message.Len()
buffer := buf.NewSize(3 + requestLen)
defer buffer.Release()
common.Must(binary.Write(buffer, binary.BigEndian, uint16(requestLen)))
exMessage := *message
exMessage.Id = messageId
exMessage.Compress = true
rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
if err != nil {
return err
}
buffer.Truncate(2 + len(rawMessage))
return common.Error(writer.Write(buffer.Bytes()))
}

115
dns/transport/tls.go Normal file
View file

@ -0,0 +1,115 @@
package transport
import (
"context"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*TLSTransport)(nil)
func RegisterTLS(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS)
}
type TLSTransport struct {
dns.TransportAdapter
logger logger.ContextLogger
dialer N.Dialer
serverAddr M.Socksaddr
tlsConfig tls.Config
access sync.Mutex
connections list.List[*tlsDNSConn]
}
type tlsDNSConn struct {
tls.Conn
queryId uint16
}
func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
if err != nil {
return nil, err
}
tlsOptions := common.PtrValueOrDefault(options.TLS)
tlsOptions.Enabled = true
tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions)
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 853
}
return &TLSTransport{
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTLS, tag, options.RemoteDNSServerOptions),
logger: logger,
dialer: transportDialer,
serverAddr: serverAddr,
tlsConfig: tlsConfig,
}, nil
}
func (t *TLSTransport) Reset() {
t.access.Lock()
defer t.access.Unlock()
for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
connection.Value.Close()
}
t.connections.Init()
}
func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
t.access.Lock()
conn := t.connections.PopFront()
t.access.Unlock()
if conn != nil {
response, err := t.exchange(message, conn)
if err == nil {
return response, nil
}
}
tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
if err != nil {
return nil, err
}
tlsConn, err := tls.ClientHandshake(ctx, tcpConn, t.tlsConfig)
if err != nil {
tcpConn.Close()
return nil, err
}
return t.exchange(message, &tlsDNSConn{Conn: tlsConn})
}
func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
conn.queryId++
err := WriteMessage(conn, conn.queryId, message)
if err != nil {
conn.Close()
return nil, E.Cause(err, "write request")
}
response, err := ReadMessage(conn)
if err != nil {
conn.Close()
return nil, E.Cause(err, "read response")
}
t.access.Lock()
t.connections.PushBack(conn)
t.access.Unlock()
return response, nil
}

223
dns/transport/udp.go Normal file
View file

@ -0,0 +1,223 @@
package transport
import (
"context"
"net"
"os"
"sync"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*UDPTransport)(nil)
func RegisterUDP(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeUDP, NewUDP)
}
type UDPTransport struct {
dns.TransportAdapter
logger logger.ContextLogger
dialer N.Dialer
serverAddr M.Socksaddr
udpSize int
tcpTransport *TCPTransport
access sync.Mutex
conn *dnsConnection
done chan struct{}
}
func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options)
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 53
}
return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil
}
func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
return &UDPTransport{
TransportAdapter: adapter,
logger: logger,
dialer: dialer,
serverAddr: serverAddr,
udpSize: 512,
tcpTransport: &TCPTransport{
dialer: dialer,
serverAddr: serverAddr,
},
done: make(chan struct{}),
}
}
func (t *UDPTransport) Reset() {
t.access.Lock()
defer t.access.Unlock()
close(t.done)
t.done = make(chan struct{})
}
func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
response, err := t.exchange(ctx, message)
if err != nil {
return nil, err
}
if response.Truncated {
t.logger.InfoContext(ctx, "response truncated, retrying with TCP")
return t.tcpTransport.Exchange(ctx, message)
}
return response, nil
}
func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
conn, err := t.open(ctx)
if err != nil {
return nil, err
}
if edns0Opt := message.IsEdns0(); edns0Opt != nil {
if udpSize := int(edns0Opt.UDPSize()); udpSize > t.udpSize {
t.udpSize = udpSize
}
}
buffer := buf.NewSize(1 + message.Len())
defer buffer.Release()
exMessage := *message
exMessage.Compress = true
messageId := message.Id
callback := &dnsCallback{
done: make(chan struct{}),
}
conn.access.Lock()
conn.queryId++
exMessage.Id = conn.queryId
conn.callbacks[exMessage.Id] = callback
conn.access.Unlock()
defer func() {
conn.access.Lock()
delete(conn.callbacks, messageId)
conn.access.Unlock()
callback.access.Lock()
select {
case <-callback.done:
default:
close(callback.done)
}
callback.access.Unlock()
}()
rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
if err != nil {
return nil, err
}
_, err = conn.Write(rawMessage)
if err != nil {
conn.Close(err)
return nil, err
}
select {
case <-callback.done:
callback.message.Id = messageId
return callback.message, nil
case <-conn.done:
return nil, conn.err
case <-t.done:
return nil, os.ErrClosed
case <-ctx.Done():
conn.Close(ctx.Err())
return nil, ctx.Err()
}
}
func (t *UDPTransport) open(ctx context.Context) (*dnsConnection, error) {
t.access.Lock()
defer t.access.Unlock()
if t.conn != nil {
select {
case <-t.conn.done:
default:
return t.conn, nil
}
}
conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
if err != nil {
return nil, err
}
dnsConn := &dnsConnection{
Conn: conn,
done: make(chan struct{}),
callbacks: make(map[uint16]*dnsCallback),
}
go t.recvLoop(dnsConn)
t.conn = dnsConn
return dnsConn, nil
}
func (t *UDPTransport) recvLoop(conn *dnsConnection) {
for {
buffer := buf.NewSize(t.udpSize)
_, err := buffer.ReadOnceFrom(conn)
if err != nil {
buffer.Release()
conn.Close(err)
return
}
var message mDNS.Msg
err = message.Unpack(buffer.Bytes())
buffer.Release()
if err != nil {
conn.Close(err)
return
}
conn.access.RLock()
callback, loaded := conn.callbacks[message.Id]
conn.access.RUnlock()
if !loaded {
continue
}
callback.access.Lock()
select {
case <-callback.done:
default:
callback.message = &message
close(callback.done)
}
callback.access.Unlock()
}
}
type dnsConnection struct {
net.Conn
access sync.RWMutex
done chan struct{}
closeOnce sync.Once
err error
queryId uint16
callbacks map[uint16]*dnsCallback
}
func (c *dnsConnection) Close(err error) {
c.closeOnce.Do(func() {
close(c.done)
c.err = err
})
c.Conn.Close()
}
type dnsCallback struct {
access sync.Mutex
message *mDNS.Msg
done chan struct{}
}

70
dns/transport_adapter.go Normal file
View file

@ -0,0 +1,70 @@
package dns
import (
"net/netip"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
)
var _ adapter.LegacyDNSTransport = (*TransportAdapter)(nil)
type TransportAdapter struct {
transportType string
transportTag string
dependencies []string
strategy C.DomainStrategy
clientSubnet netip.Prefix
}
func NewTransportAdapter(transportType string, transportTag string, dependencies []string) TransportAdapter {
return TransportAdapter{
transportType: transportType,
transportTag: transportTag,
dependencies: dependencies,
}
}
func NewTransportAdapterWithLocalOptions(transportType string, transportTag string, localOptions option.LocalDNSServerOptions) TransportAdapter {
return TransportAdapter{
transportType: transportType,
transportTag: transportTag,
strategy: C.DomainStrategy(localOptions.LegacyStrategy),
clientSubnet: localOptions.LegacyClientSubnet,
}
}
func NewTransportAdapterWithRemoteOptions(transportType string, transportTag string, remoteOptions option.RemoteDNSServerOptions) TransportAdapter {
var dependencies []string
if remoteOptions.AddressResolver != "" {
dependencies = []string{remoteOptions.AddressResolver}
}
return TransportAdapter{
transportType: transportType,
transportTag: transportTag,
dependencies: dependencies,
strategy: C.DomainStrategy(remoteOptions.LegacyStrategy),
clientSubnet: remoteOptions.LegacyClientSubnet,
}
}
func (a *TransportAdapter) Type() string {
return a.transportType
}
func (a *TransportAdapter) Tag() string {
return a.transportTag
}
func (a *TransportAdapter) Dependencies() []string {
return a.dependencies
}
func (a *TransportAdapter) LegacyStrategy() C.DomainStrategy {
return a.strategy
}
func (a *TransportAdapter) LegacyClientSubnet() netip.Prefix {
return a.clientSubnet
}

93
dns/transport_dialer.go Normal file
View file

@ -0,0 +1,93 @@
package dns
import (
"context"
"net"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (N.Dialer, error) {
if options.LegacyDefaultDialer {
return dialer.NewDefaultOutbound(ctx), nil
} else {
return dialer.New(ctx, options.DialerOptions)
}
}
func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) {
transportDialer, err := NewLocalDialer(ctx, options.LocalDNSServerOptions)
if err != nil {
return nil, err
}
if options.AddressResolver != "" {
transport := service.FromContext[adapter.DNSTransportManager](ctx)
resolverTransport, loaded := transport.Transport(options.AddressResolver)
if !loaded {
return nil, E.New("address resolver not found: ", options.AddressResolver)
}
transportDialer = NewTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.AddressStrategy), time.Duration(options.AddressFallbackDelay))
} else if M.IsDomainName(options.Server) {
return nil, E.New("missing address resolver for server: ", options.Server)
}
return transportDialer, nil
}
type TransportDialer struct {
dialer N.Dialer
dnsRouter adapter.DNSRouter
transport adapter.DNSTransport
strategy C.DomainStrategy
fallbackDelay time.Duration
}
func NewTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *TransportDialer {
return &TransportDialer{
dialer,
dnsRouter,
transport,
strategy,
fallbackDelay,
}
}
func (d *TransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if destination.IsIP() {
return d.dialer.DialContext(ctx, network, destination)
}
addresses, err := d.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{
Transport: d.transport,
Strategy: d.strategy,
})
if err != nil {
return nil, err
}
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay)
}
func (d *TransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if destination.IsIP() {
return d.dialer.ListenPacket(ctx, destination)
}
addresses, err := d.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{
Transport: d.transport,
Strategy: d.strategy,
})
if err != nil {
return nil, err
}
conn, _, err := N.ListenSerial(ctx, d.dialer, destination, addresses)
return conn, err
}
func (d *TransportDialer) Upstream() any {
return d.dialer
}

288
dns/transport_manager.go Normal file
View file

@ -0,0 +1,288 @@
package dns
import (
"context"
"io"
"os"
"strings"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
var _ adapter.DNSTransportManager = (*TransportManager)(nil)
type TransportManager struct {
logger log.ContextLogger
registry adapter.DNSTransportRegistry
outbound adapter.OutboundManager
defaultTag string
access sync.RWMutex
started bool
stage adapter.StartStage
transports []adapter.DNSTransport
transportByTag map[string]adapter.DNSTransport
dependByTag map[string][]string
defaultTransport adapter.DNSTransport
defaultTransportFallback adapter.DNSTransport
fakeIPTransport adapter.FakeIPTransport
}
func NewTransportManager(logger logger.ContextLogger, registry adapter.DNSTransportRegistry, outbound adapter.OutboundManager, defaultTag string) *TransportManager {
return &TransportManager{
logger: logger,
registry: registry,
outbound: outbound,
defaultTag: defaultTag,
transportByTag: make(map[string]adapter.DNSTransport),
dependByTag: make(map[string][]string),
}
}
func (m *TransportManager) Initialize(defaultTransportFallback adapter.DNSTransport) {
m.defaultTransportFallback = defaultTransportFallback
}
func (m *TransportManager) Start(stage adapter.StartStage) error {
m.access.Lock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
outbounds := m.transports
m.access.Unlock()
if stage == adapter.StartStateStart {
return m.startTransports(m.transports)
} else {
for _, outbound := range outbounds {
err := adapter.LegacyStart(outbound, stage)
if err != nil {
return E.Cause(err, stage, " dns/", outbound.Type(), "[", outbound.Tag(), "]")
}
}
}
return nil
}
func (m *TransportManager) startTransports(transports []adapter.DNSTransport) error {
monitor := taskmonitor.New(m.logger, C.StartTimeout)
started := make(map[string]bool)
for {
canContinue := false
startOne:
for _, transportToStart := range transports {
transportTag := transportToStart.Tag()
if started[transportTag] {
continue
}
dependencies := transportToStart.Dependencies()
for _, dependency := range dependencies {
if !started[dependency] {
continue startOne
}
}
started[transportTag] = true
canContinue = true
if starter, isStarter := transportToStart.(adapter.Lifecycle); isStarter {
monitor.Start("start dns/", transportToStart.Type(), "[", transportTag, "]")
err := starter.Start(adapter.StartStateStart)
monitor.Finish()
if err != nil {
return E.Cause(err, "start dns/", transportToStart.Type(), "[", transportTag, "]")
}
}
}
if len(started) == len(transports) {
break
}
if canContinue {
continue
}
currentTransport := common.Find(transports, func(it adapter.DNSTransport) bool {
return !started[it.Tag()]
})
var lintTransport func(oTree []string, oCurrent adapter.DNSTransport) error
lintTransport = func(oTree []string, oCurrent adapter.DNSTransport) error {
problemTransportTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
return !started[it]
})
if common.Contains(oTree, problemTransportTag) {
return E.New("circular server dependency: ", strings.Join(oTree, " -> "), " -> ", problemTransportTag)
}
m.access.Lock()
problemTransport := m.transportByTag[problemTransportTag]
m.access.Unlock()
if problemTransport == nil {
return E.New("dependency[", problemTransportTag, "] not found for server[", oCurrent.Tag(), "]")
}
return lintTransport(append(oTree, problemTransportTag), problemTransport)
}
return lintTransport([]string{currentTransport.Tag()}, currentTransport)
}
return nil
}
func (m *TransportManager) Close() error {
monitor := taskmonitor.New(m.logger, C.StopTimeout)
m.access.Lock()
if !m.started {
m.access.Unlock()
return nil
}
m.started = false
transports := m.transports
m.transports = nil
m.access.Unlock()
var err error
for _, transport := range transports {
if closer, isCloser := transport.(io.Closer); isCloser {
monitor.Start("close server/", transport.Type(), "[", transport.Tag(), "]")
err = E.Append(err, closer.Close(), func(err error) error {
return E.Cause(err, "close server/", transport.Type(), "[", transport.Tag(), "]")
})
monitor.Finish()
}
}
return nil
}
func (m *TransportManager) Transports() []adapter.DNSTransport {
m.access.RLock()
defer m.access.RUnlock()
return m.transports
}
func (m *TransportManager) Transport(tag string) (adapter.DNSTransport, bool) {
m.access.RLock()
outbound, found := m.transportByTag[tag]
m.access.RUnlock()
return outbound, found
}
func (m *TransportManager) Default() adapter.DNSTransport {
m.access.RLock()
defer m.access.RUnlock()
if m.defaultTransport != nil {
return m.defaultTransport
} else {
return m.defaultTransportFallback
}
}
func (m *TransportManager) FakeIP() adapter.FakeIPTransport {
m.access.RLock()
defer m.access.RUnlock()
return m.fakeIPTransport
}
func (m *TransportManager) Remove(tag string) error {
m.access.Lock()
defer m.access.Unlock()
transport, found := m.transportByTag[tag]
if !found {
return os.ErrInvalid
}
delete(m.transportByTag, tag)
index := common.Index(m.transports, func(it adapter.DNSTransport) bool {
return it == transport
})
if index == -1 {
panic("invalid inbound index")
}
m.transports = append(m.transports[:index], m.transports[index+1:]...)
started := m.started
if m.defaultTransport == transport {
if len(m.transports) > 0 {
nextTransport := m.transports[0]
if nextTransport.Type() != C.DNSTypeFakeIP {
return E.New("default server cannot be fakeip")
}
m.defaultTransport = nextTransport
m.logger.Info("updated default server to ", m.defaultTransport.Tag())
} else {
m.defaultTransport = nil
}
}
dependBy := m.dependByTag[tag]
if len(dependBy) > 0 {
return E.New("server[", tag, "] is depended by ", strings.Join(dependBy, ", "))
}
dependencies := transport.Dependencies()
for _, dependency := range dependencies {
if len(m.dependByTag[dependency]) == 1 {
delete(m.dependByTag, dependency)
} else {
m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
return it != tag
})
}
}
if started {
transport.Reset()
}
return nil
}
func (m *TransportManager) Create(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) error {
if tag == "" {
return os.ErrInvalid
}
transport, err := m.registry.CreateDNSTransport(ctx, logger, tag, transportType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(transport, stage)
if err != nil {
return E.Cause(err, stage, " dns/", transport.Type(), "[", transport.Tag(), "]")
}
}
}
if existsTransport, loaded := m.transportByTag[tag]; loaded {
if m.started {
err = common.Close(existsTransport)
if err != nil {
return E.Cause(err, "close dns/", existsTransport.Type(), "[", existsTransport.Tag(), "]")
}
}
existsIndex := common.Index(m.transports, func(it adapter.DNSTransport) bool {
return it == existsTransport
})
if existsIndex == -1 {
panic("invalid inbound index")
}
m.transports = append(m.transports[:existsIndex], m.transports[existsIndex+1:]...)
}
m.transports = append(m.transports, transport)
m.transportByTag[tag] = transport
dependencies := transport.Dependencies()
for _, dependency := range dependencies {
m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
}
if tag == m.defaultTag || (m.defaultTag == "" && m.defaultTransport == nil) {
if transport.Type() == C.DNSTypeFakeIP {
return E.New("default server cannot be fakeip")
}
m.defaultTransport = transport
if m.started {
m.logger.Info("updated default server to ", transport.Tag())
}
}
if transport.Type() == C.DNSTypeFakeIP {
if m.fakeIPTransport != nil {
return E.New("multiple fakeip server are not supported")
}
m.fakeIPTransport = transport.(adapter.FakeIPTransport)
}
return nil
}

72
dns/transport_registry.go Normal file
View file

@ -0,0 +1,72 @@
package dns
import (
"context"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type TransportConstructorFunc[T any] func(ctx context.Context, logger log.ContextLogger, tag string, options T) (adapter.DNSTransport, error)
func RegisterTransport[Options any](registry *TransportRegistry, transportType string, constructor TransportConstructorFunc[Options]) {
registry.register(transportType, func() any {
return new(Options)
}, func(ctx context.Context, logger log.ContextLogger, tag string, rawOptions any) (adapter.DNSTransport, error) {
var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, logger, tag, common.PtrValueOrDefault(options))
})
}
var _ adapter.DNSTransportRegistry = (*TransportRegistry)(nil)
type (
optionsConstructorFunc func() any
constructorFunc func(ctx context.Context, logger log.ContextLogger, tag string, options any) (adapter.DNSTransport, error)
)
type TransportRegistry struct {
access sync.Mutex
optionsType map[string]optionsConstructorFunc
constructors map[string]constructorFunc
}
func NewTransportRegistry() *TransportRegistry {
return &TransportRegistry{
optionsType: make(map[string]optionsConstructorFunc),
constructors: make(map[string]constructorFunc),
}
}
func (r *TransportRegistry) CreateOptions(transportType string) (any, bool) {
r.access.Lock()
defer r.access.Unlock()
optionsConstructor, loaded := r.optionsType[transportType]
if !loaded {
return nil, false
}
return optionsConstructor(), true
}
func (r *TransportRegistry) CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (adapter.DNSTransport, error) {
r.access.Lock()
defer r.access.Unlock()
constructor, loaded := r.constructors[transportType]
if !loaded {
return nil, E.New("transport type not found: " + transportType)
}
return constructor(ctx, logger, tag, options)
}
func (r *TransportRegistry) register(transportType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
r.access.Lock()
defer r.access.Unlock()
r.optionsType[transportType] = optionsConstructor
r.constructors[transportType] = constructor
}

View file

@ -13,13 +13,13 @@ import (
"github.com/miekg/dns"
)
func dnsRouter(router adapter.Router) http.Handler {
func dnsRouter(router adapter.DNSRouter) http.Handler {
r := chi.NewRouter()
r.Get("/query", queryDNS(router))
return r
}
func queryDNS(router adapter.Router) func(w http.ResponseWriter, r *http.Request) {
func queryDNS(router adapter.DNSRouter) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
name := r.URL.Query().Get("name")
qTypeStr := r.URL.Query().Get("type")
@ -39,7 +39,7 @@ func queryDNS(router adapter.Router) func(w http.ResponseWriter, r *http.Request
msg := dns.Msg{}
msg.SetQuestion(dns.Fqdn(name), qType)
resp, err := router.Exchange(ctx, &msg)
resp, err := router.Exchange(ctx, &msg, adapter.DNSQueryOptions{})
if err != nil {
render.Status(r, http.StatusInternalServerError)
render.JSON(w, r, newError(err.Error()))

View file

@ -42,6 +42,7 @@ var _ adapter.ClashServer = (*Server)(nil)
type Server struct {
ctx context.Context
router adapter.Router
dnsRouter adapter.DNSRouter
outbound adapter.OutboundManager
endpoint adapter.EndpointManager
logger log.Logger
@ -64,6 +65,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op
s := &Server{
ctx: ctx,
router: service.FromContext[adapter.Router](ctx),
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
endpoint: service.FromContext[adapter.EndpointManager](ctx),
logger: logFactory.NewLogger("clash-api"),
@ -121,7 +123,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op
r.Mount("/script", scriptRouter())
r.Mount("/profile", profileRouter())
r.Mount("/cache", cacheRouter(ctx))
r.Mount("/dns", dnsRouter(s.router))
r.Mount("/dns", dnsRouter(s.dnsRouter))
s.setupMetaAPI(r)
})
@ -221,7 +223,7 @@ func (s *Server) SetMode(newMode string) {
default:
}
}
s.router.ClearDNSCache()
s.dnsRouter.ClearCache()
cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
if cacheFile != nil {
err := cacheFile.StoreMode(newMode)

View file

@ -146,6 +146,21 @@ var OptionTUNGSO = Note{
EnvName: "TUN_GSO",
}
var OptionLegacyDNSTransport = Note{
Name: "legacy-dns-transport",
Description: "legacy DNS transport",
DeprecatedVersion: "1.12.0",
ScheduledVersion: "1.14.0",
EnvName: "LEGACY_DNS_TRANSPORT",
}
var OptionLegacyDNSFakeIPOptions = Note{
Name: "legacy-dns-fakeip-options",
Description: "legacy DNS fakeip options",
DeprecatedVersion: "1.12.0",
ScheduledVersion: "1.14.0",
}
var Options = []Note{
OptionBadMatchSource,
OptionGEOIP,

View file

@ -9,8 +9,11 @@ import (
"github.com/sagernet/sing-box"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/process"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/include"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/control"
@ -21,6 +24,18 @@ import (
"github.com/sagernet/sing/service"
)
func BaseContext(platformInterface PlatformInterface) context.Context {
dnsRegistry := include.DNSTransportRegistry()
if platformInterface != nil {
if localTransport := platformInterface.LocalDNSTransport(); localTransport != nil {
dns.RegisterTransport[option.LocalDNSServerOptions](dnsRegistry, C.DNSTypeLocal, func(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) {
return newPlatformTransport(localTransport, tag, options), nil
})
}
}
return box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry(), dnsRegistry)
}
func parseConfig(ctx context.Context, configContent string) (option.Options, error) {
options, err := json.UnmarshalExtendedContext[option.Options](ctx, []byte(configContent))
if err != nil {
@ -30,7 +45,7 @@ func parseConfig(ctx context.Context, configContent string) (option.Options, err
}
func CheckConfig(configContent string) error {
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
ctx := BaseContext(nil)
options, err := parseConfig(ctx, configContent)
if err != nil {
return err
@ -131,7 +146,7 @@ func (s *platformInterfaceStub) SendNotification(notification *platform.Notifica
}
func FormatConfig(configContent string) (*StringBox, error) {
options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()), configContent)
options, err := parseConfig(BaseContext(nil), configContent)
if err != nil {
return nil, err
}

View file

@ -6,7 +6,10 @@ import (
"strings"
"syscall"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
@ -21,53 +24,32 @@ type LocalDNSTransport interface {
Exchange(ctx *ExchangeContext, message []byte) error
}
func RegisterLocalDNSTransport(transport LocalDNSTransport) {
if transport == nil {
dns.RegisterTransport([]string{"local"}, func(options dns.TransportOptions) (dns.Transport, error) {
return dns.NewLocalTransport(options), nil
})
} else {
dns.RegisterTransport([]string{"local"}, func(options dns.TransportOptions) (dns.Transport, error) {
return &platformLocalDNSTransport{
iif: transport,
}, nil
})
}
}
var _ adapter.DNSTransport = (*platformTransport)(nil)
var _ dns.Transport = (*platformLocalDNSTransport)(nil)
type platformLocalDNSTransport struct {
type platformTransport struct {
dns.TransportAdapter
iif LocalDNSTransport
}
func (p *platformLocalDNSTransport) Name() string {
return "local"
func newPlatformTransport(iif LocalDNSTransport, tag string, options option.LocalDNSServerOptions) *platformTransport {
return &platformTransport{
TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options),
iif: iif,
}
}
func (p *platformLocalDNSTransport) Start() error {
return nil
func (p *platformTransport) Reset() {
}
func (p *platformLocalDNSTransport) Reset() {
}
func (p *platformLocalDNSTransport) Close() error {
return nil
}
func (p *platformLocalDNSTransport) Raw() bool {
return p.iif.Raw()
}
func (p *platformLocalDNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
func (p *platformTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
response := &ExchangeContext{
context: ctx,
}
if p.iif.Raw() {
messageBytes, err := message.Pack()
if err != nil {
return nil, err
}
response := &ExchangeContext{
context: ctx,
}
var responseMessage *mDNS.Msg
var group task.Group
group.Append0(func(ctx context.Context) error {
@ -86,53 +68,36 @@ func (p *platformLocalDNSTransport) Exchange(ctx context.Context, message *mDNS.
return nil, err
}
return responseMessage, nil
}
func (p *platformLocalDNSTransport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
} else {
question := message.Question[0]
var network string
switch strategy {
case dns.DomainStrategyUseIPv4:
switch question.Qtype {
case mDNS.TypeA:
network = "ip4"
case dns.DomainStrategyPreferIPv6:
case mDNS.TypeAAAA:
network = "ip6"
default:
network = "ip"
return nil, E.New("only IP queries are supported by current version of Android")
}
response := &ExchangeContext{
context: ctx,
}
var responseAddr []netip.Addr
var responseAddrs []netip.Addr
var group task.Group
group.Append0(func(ctx context.Context) error {
err := p.iif.Lookup(response, network, domain)
err := p.iif.Lookup(response, network, question.Name)
if err != nil {
return err
}
if response.error != nil {
return response.error
}
switch strategy {
case dns.DomainStrategyUseIPv4:
responseAddr = common.Filter(response.addresses, func(it netip.Addr) bool {
return it.Is4()
})
case dns.DomainStrategyPreferIPv6:
responseAddr = common.Filter(response.addresses, func(it netip.Addr) bool {
return it.Is6()
})
default:
responseAddr = response.addresses
}
/*if len(responseAddr) == 0 {
response.error = dns.RCodeSuccess
}*/
responseAddrs = response.addresses
return nil
})
err := group.Run(ctx)
if err != nil {
return nil, err
}
return responseAddr, nil
return dns.FixedResponse(message.Id, question, responseAddrs, C.DefaultDNSTTL), nil
}
}
type Func interface {

View file

@ -6,6 +6,7 @@ import (
)
type PlatformInterface interface {
LocalDNSTransport() LocalDNSTransport
UsePlatformAutoDetectInterfaceControl() bool
AutoDetectInterfaceControl(fd int32) error
OpenTun(options TunOptions) (int32, error)

View file

@ -18,7 +18,6 @@ import (
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/experimental/libbox/internal/procfs"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/include"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun"
@ -44,7 +43,7 @@ type BoxService struct {
}
func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) {
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
ctx := BaseContext(platformInterface)
ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID)
service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager))
options, err := parseConfig(ctx, configContent)
@ -192,6 +191,9 @@ func (w *platformInterfaceWrapper) Interfaces() ([]adapter.NetworkInterface, err
continue
}
w.defaultInterfaceAccess.Lock()
// (GOOS=windows) SA4006: this value of `isDefault` is never used
// Why not used?
//nolint:staticcheck
isDefault := w.defaultInterface != nil && int(netInterface.Index) == w.defaultInterface.Index
w.defaultInterfaceAccess.Unlock()
interfaces = append(interfaces, adapter.NetworkInterface{

1
go.mod
View file

@ -27,7 +27,6 @@ require (
github.com/sagernet/quic-go v0.49.0-beta.1
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
github.com/sagernet/sing v0.6.5
github.com/sagernet/sing-dns v0.4.0
github.com/sagernet/sing-mux v0.3.1
github.com/sagernet/sing-quic v0.4.0
github.com/sagernet/sing-shadowsocks v0.2.7

2
go.sum
View file

@ -121,8 +121,6 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4Wk
github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
github.com/sagernet/sing v0.6.5 h1:TBKTK6Ms0/MNTZm+cTC2hhKunE42XrNIdsxcYtWqeUU=
github.com/sagernet/sing v0.6.5/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing-dns v0.4.0 h1:+mNoOuR3nljjouCH+qMg4zHI1+R9T2ReblGFkZPEndc=
github.com/sagernet/sing-dns v0.4.0/go.mod h1:dweQs54ng2YGzoJfz+F9dGuDNdP5pJ3PLeggnK5VWc8=
github.com/sagernet/sing-mux v0.3.1 h1:kvCc8HyGAskDHDQ0yQvoTi/7J4cZPB/VJMsAM3MmdQI=
github.com/sagernet/sing-mux v0.3.1/go.mod h1:Mkdz8LnDstthz0HWuA/5foncnDIdcNN5KZ6AdJX+x78=
github.com/sagernet/sing-quic v0.4.0 h1:E4geazHk/UrJTXMlT+CBCKmn8V86RhtNeczWtfeoEFc=

View file

@ -2,4 +2,11 @@
package include
import _ "github.com/sagernet/sing-box/transport/dhcp"
import (
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport/dhcp"
)
func registerDHCPTransport(registry *dns.TransportRegistry) {
dhcp.RegisterTransport(registry)
}

View file

@ -3,12 +3,18 @@
package include
import (
"github.com/sagernet/sing-dns"
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func init() {
dns.RegisterTransport([]string{"dhcp"}, func(options dns.TransportOptions) (dns.Transport, error) {
func registerDHCPTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.DHCPDNSServerOptions](registry, C.DNSTypeDHCP, func(ctx context.Context, logger log.ContextLogger, tag string, options option.DHCPDNSServerOptions) (adapter.DNSTransport, error) {
return nil, E.New(`DHCP is not included in this build, rebuild with -tags with_dhcp`)
})
}

View file

@ -5,12 +5,13 @@ package include
import (
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport/quic"
"github.com/sagernet/sing-box/protocol/hysteria"
"github.com/sagernet/sing-box/protocol/hysteria2"
_ "github.com/sagernet/sing-box/protocol/naive/quic"
"github.com/sagernet/sing-box/protocol/tuic"
_ "github.com/sagernet/sing-box/transport/v2rayquic"
_ "github.com/sagernet/sing-dns/quic"
)
func registerQUICInbounds(registry *inbound.Registry) {
@ -24,3 +25,8 @@ func registerQUICOutbounds(registry *outbound.Registry) {
tuic.RegisterOutbound(registry)
hysteria2.RegisterOutbound(registry)
}
func registerQUICTransports(registry *dns.TransportRegistry) {
quic.RegisterTransport(registry)
quic.RegisterHTTP3Transport(registry)
}

View file

@ -13,20 +13,17 @@ import (
"github.com/sagernet/sing-box/common/listener"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/protocol/naive"
"github.com/sagernet/sing-box/transport/v2ray"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func init() {
dns.RegisterTransport([]string{"quic", "h3"}, func(options dns.TransportOptions) (dns.Transport, error) {
return nil, C.ErrQUICNotIncluded
})
v2ray.RegisterQUICConstructor(
func(ctx context.Context, logger logger.ContextLogger, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) {
return nil, C.ErrQUICNotIncluded
@ -63,3 +60,12 @@ func registerQUICOutbounds(registry *outbound.Registry) {
return nil, C.ErrQUICNotIncluded
})
}
func registerQUICTransports(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeQUIC, func(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
return nil, C.ErrQUICNotIncluded
})
dns.RegisterTransport[option.RemoteHTTPSDNSServerOptions](registry, C.DNSTypeHTTP3, func(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) {
return nil, C.ErrQUICNotIncluded
})
}

View file

@ -8,11 +8,16 @@ import (
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing-box/dns/transport/fakeip"
"github.com/sagernet/sing-box/dns/transport/hosts"
"github.com/sagernet/sing-box/dns/transport/local"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/protocol/block"
"github.com/sagernet/sing-box/protocol/direct"
"github.com/sagernet/sing-box/protocol/dns"
protocolDNS "github.com/sagernet/sing-box/protocol/dns"
"github.com/sagernet/sing-box/protocol/group"
"github.com/sagernet/sing-box/protocol/http"
"github.com/sagernet/sing-box/protocol/mixed"
@ -61,7 +66,7 @@ func OutboundRegistry() *outbound.Registry {
direct.RegisterOutbound(registry)
block.RegisterOutbound(registry)
dns.RegisterOutbound(registry)
protocolDNS.RegisterOutbound(registry)
group.RegisterSelector(registry)
group.RegisterURLTest(registry)
@ -91,6 +96,24 @@ func EndpointRegistry() *endpoint.Registry {
return registry
}
func DNSTransportRegistry() *dns.TransportRegistry {
registry := dns.NewTransportRegistry()
transport.RegisterTCP(registry)
transport.RegisterUDP(registry)
transport.RegisterTLS(registry)
transport.RegisterHTTPS(registry)
transport.RegisterPredefined(registry)
hosts.RegisterTransport(registry)
local.RegisterTransport(registry)
fakeip.RegisterTransport(registry)
registerQUICTransports(registry)
registerDHCPTransport(registry)
return registry
}
func registerStubForRemovedInbounds(registry *inbound.Registry) {
inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) {
return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0")

View file

@ -1,29 +1,53 @@
package option
import (
"context"
"net/netip"
"net/url"
"os"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson"
"github.com/sagernet/sing/common/json/badoption"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/service"
"github.com/miekg/dns"
)
type DNSOptions struct {
Servers []DNSServerOptions `json:"servers,omitempty"`
type RawDNSOptions struct {
Servers []NewDNSServerOptions `json:"servers,omitempty"`
Rules []DNSRule `json:"rules,omitempty"`
Final string `json:"final,omitempty"`
ReverseMapping bool `json:"reverse_mapping,omitempty"`
FakeIP *DNSFakeIPOptions `json:"fakeip,omitempty"`
DNSClientOptions
}
type DNSServerOptions struct {
Tag string `json:"tag,omitempty"`
Address string `json:"address"`
AddressResolver string `json:"address_resolver,omitempty"`
AddressStrategy DomainStrategy `json:"address_strategy,omitempty"`
AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"`
Strategy DomainStrategy `json:"strategy,omitempty"`
Detour string `json:"detour,omitempty"`
ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"`
type LegacyDNSOptions struct {
FakeIP *LegacyDNSFakeIPOptions `json:"fakeip,omitempty"`
}
type DNSOptions struct {
RawDNSOptions
LegacyDNSOptions
}
func (o *DNSOptions) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.UnmarshalContext(ctx, content, &o.LegacyDNSOptions)
if err != nil {
return err
}
if o.FakeIP != nil && o.FakeIP.Enabled {
deprecated.Report(ctx, deprecated.OptionLegacyDNSFakeIPOptions)
ctx = context.WithValue(ctx, (*LegacyDNSFakeIPOptions)(nil), o.FakeIP)
}
legacyOptions := o.LegacyDNSOptions
o.LegacyDNSOptions = LegacyDNSOptions{}
return badjson.UnmarshallExcludedContext(ctx, content, legacyOptions, &o.RawDNSOptions)
}
type DNSClientOptions struct {
@ -35,8 +59,261 @@ type DNSClientOptions struct {
ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"`
}
type DNSFakeIPOptions struct {
type LegacyDNSFakeIPOptions struct {
Enabled bool `json:"enabled,omitempty"`
Inet4Range *netip.Prefix `json:"inet4_range,omitempty"`
Inet6Range *netip.Prefix `json:"inet6_range,omitempty"`
Inet4Range *badoption.Prefix `json:"inet4_range,omitempty"`
Inet6Range *badoption.Prefix `json:"inet6_range,omitempty"`
}
type DNSTransportOptionsRegistry interface {
CreateOptions(transportType string) (any, bool)
}
type _NewDNSServerOptions struct {
Type string `json:"type,omitempty"`
Tag string `json:"tag,omitempty"`
Options any `json:"-"`
}
type NewDNSServerOptions _NewDNSServerOptions
func (o *NewDNSServerOptions) MarshalJSONContext(ctx context.Context) ([]byte, error) {
return badjson.MarshallObjectsContext(ctx, (*_NewDNSServerOptions)(o), o.Options)
}
func (o *NewDNSServerOptions) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.UnmarshalContext(ctx, content, (*_NewDNSServerOptions)(o))
if err != nil {
return err
}
registry := service.FromContext[DNSTransportOptionsRegistry](ctx)
if registry == nil {
return E.New("missing outbound options registry in context")
}
var options any
switch o.Type {
case "", C.DNSTypeLegacy:
o.Type = C.DNSTypeLegacy
options = new(LegacyDNSServerOptions)
deprecated.Report(ctx, deprecated.OptionLegacyDNSTransport)
default:
var loaded bool
options, loaded = registry.CreateOptions(o.Type)
if !loaded {
return E.New("unknown transport type: ", o.Type)
}
}
err = badjson.UnmarshallExcludedContext(ctx, content, (*_Outbound)(o), options)
if err != nil {
return err
}
o.Options = options
if o.Type == C.DNSTypeLegacy {
err = o.Upgrade(ctx)
if err != nil {
return err
}
}
return nil
}
func (o *NewDNSServerOptions) Upgrade(ctx context.Context) error {
if o.Type != C.DNSTypeLegacy {
return nil
}
defer func() {
encoder := json.NewEncoder(os.Stderr)
encoder.SetIndent("", " ")
encoder.Encode(o)
}()
options := o.Options.(*LegacyDNSServerOptions)
serverURL, _ := url.Parse(options.Address)
var serverType string
if serverURL.Scheme != "" {
serverType = serverURL.Scheme
} else {
switch options.Address {
case "local", "fakeip":
serverType = options.Address
default:
serverType = C.DNSTypeUDP
}
}
remoteOptions := RemoteDNSServerOptions{
LocalDNSServerOptions: LocalDNSServerOptions{
DialerOptions: DialerOptions{
Detour: options.Detour,
},
LegacyStrategy: options.Strategy,
LegacyDefaultDialer: options.Detour == "",
LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
},
AddressResolver: options.AddressResolver,
AddressStrategy: options.AddressStrategy,
AddressFallbackDelay: options.AddressFallbackDelay,
}
switch serverType {
case C.DNSTypeLocal:
o.Type = C.DNSTypeLocal
o.Options = &remoteOptions.LocalDNSServerOptions
case C.DNSTypeUDP:
o.Type = C.DNSTypeUDP
o.Options = &remoteOptions
var serverAddr M.Socksaddr
if serverURL.Scheme == "" {
serverAddr = M.ParseSocksaddr(options.Address)
} else {
serverAddr = M.ParseSocksaddr(serverURL.Host)
}
if !serverAddr.IsValid() {
return E.New("invalid server address")
}
remoteOptions.Server = serverAddr.Addr.String()
if serverAddr.Port != 0 && serverAddr.Port != 53 {
remoteOptions.ServerPort = serverAddr.Port
}
remoteOptions.Server = serverAddr.AddrString()
remoteOptions.ServerPort = serverAddr.Port
case C.DNSTypeTCP:
o.Type = C.DNSTypeTCP
o.Options = &remoteOptions
serverAddr := M.ParseSocksaddr(serverURL.Host)
if !serverAddr.IsValid() {
return E.New("invalid server address")
}
remoteOptions.Server = serverAddr.Addr.String()
if serverAddr.Port != 0 && serverAddr.Port != 53 {
remoteOptions.ServerPort = serverAddr.Port
}
remoteOptions.Server = serverAddr.AddrString()
remoteOptions.ServerPort = serverAddr.Port
case C.DNSTypeTLS, C.DNSTypeQUIC:
o.Type = serverType
serverAddr := M.ParseSocksaddr(serverURL.Host)
if !serverAddr.IsValid() {
return E.New("invalid server address")
}
remoteOptions.Server = serverAddr.Addr.String()
if serverAddr.Port != 0 && serverAddr.Port != 853 {
remoteOptions.ServerPort = serverAddr.Port
}
o.Options = &RemoteTLSDNSServerOptions{
RemoteDNSServerOptions: remoteOptions,
}
case C.DNSTypeHTTPS, C.DNSTypeHTTP3:
o.Type = serverType
httpsOptions := RemoteHTTPSDNSServerOptions{
RemoteTLSDNSServerOptions: RemoteTLSDNSServerOptions{
RemoteDNSServerOptions: remoteOptions,
},
}
o.Options = &httpsOptions
serverAddr := M.ParseSocksaddr(serverURL.Host)
if !serverAddr.IsValid() {
return E.New("invalid server address")
}
httpsOptions.Server = serverAddr.Addr.String()
if serverAddr.Port != 0 && serverAddr.Port != 443 {
httpsOptions.ServerPort = serverAddr.Port
}
if serverURL.Path != "/dns-query" {
httpsOptions.Path = serverURL.Path
}
case "rcode":
var rcode int
switch serverURL.Host {
case "success":
rcode = dns.RcodeSuccess
case "format_error":
rcode = dns.RcodeFormatError
case "server_failure":
rcode = dns.RcodeServerFailure
case "name_error":
rcode = dns.RcodeNameError
case "not_implemented":
rcode = dns.RcodeNotImplemented
case "refused":
rcode = dns.RcodeRefused
default:
return E.New("unknown rcode: ", serverURL.Host)
}
o.Type = C.DNSTypePreDefined
o.Options = &PredefinedDNSServerOptions{
Responses: []DNSResponseOptions{
{
RCode: common.Ptr(DNSRCode(rcode)),
},
},
}
case C.DNSTypeDHCP:
o.Type = C.DNSTypeDHCP
dhcpOptions := DHCPDNSServerOptions{}
if serverURL.Host != "" && serverURL.Host != "auto" {
dhcpOptions.Interface = serverURL.Host
}
o.Options = &dhcpOptions
case C.DNSTypeFakeIP:
o.Type = C.DNSTypeFakeIP
fakeipOptions := FakeIPDNSServerOptions{}
if legacyOptions, loaded := ctx.Value((*LegacyDNSFakeIPOptions)(nil)).(*LegacyDNSFakeIPOptions); loaded {
fakeipOptions.Inet4Range = legacyOptions.Inet4Range
fakeipOptions.Inet6Range = legacyOptions.Inet6Range
}
o.Options = &fakeipOptions
default:
return E.New("unsupported DNS server scheme: ", serverType)
}
return nil
}
type LegacyDNSServerOptions struct {
Address string `json:"address"`
AddressResolver string `json:"address_resolver,omitempty"`
AddressStrategy DomainStrategy `json:"address_strategy,omitempty"`
AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"`
Strategy DomainStrategy `json:"strategy,omitempty"`
Detour string `json:"detour,omitempty"`
ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"`
}
type HostsDNSServerOptions struct {
Path badoption.Listable[string] `json:"path,omitempty"`
Predefined badjson.TypedMap[string, badoption.Listable[netip.Addr]] `json:"predefined,omitempty"`
}
type LocalDNSServerOptions struct {
DialerOptions
LegacyStrategy DomainStrategy `json:"-"`
LegacyDefaultDialer bool `json:"-"`
LegacyClientSubnet netip.Prefix `json:"-"`
}
type RemoteDNSServerOptions struct {
LocalDNSServerOptions
ServerOptions
AddressResolver string `json:"address_resolver,omitempty"`
AddressStrategy DomainStrategy `json:"address_strategy,omitempty"`
AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"`
}
type RemoteTLSDNSServerOptions struct {
RemoteDNSServerOptions
OutboundTLSOptionsContainer
}
type RemoteHTTPSDNSServerOptions struct {
RemoteTLSDNSServerOptions
Path string `json:"path,omitempty"`
Method string `json:"method,omitempty"`
Headers badoption.HTTPHeader `json:"headers,omitempty"`
}
type FakeIPDNSServerOptions struct {
Inet4Range *badoption.Prefix `json:"inet4_range,omitempty"`
Inet6Range *badoption.Prefix `json:"inet6_range,omitempty"`
}
type DHCPDNSServerOptions struct {
LocalDNSServerOptions
Interface string `json:"interface,omitempty"`
}

161
option/dns_record.go Normal file
View file

@ -0,0 +1,161 @@
package option
import (
"encoding/base64"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badoption"
M "github.com/sagernet/sing/common/metadata"
"github.com/miekg/dns"
)
type PredefinedDNSServerOptions struct {
Responses []DNSResponseOptions `json:"responses,omitempty"`
}
type DNSResponseOptions struct {
Query badoption.Listable[string] `json:"query,omitempty"`
QueryType badoption.Listable[DNSQueryType] `json:"query_type,omitempty"`
RCode *DNSRCode `json:"rcode,omitempty"`
Answer badoption.Listable[DNSRecordOptions] `json:"answer,omitempty"`
Ns badoption.Listable[DNSRecordOptions] `json:"ns,omitempty"`
Extra badoption.Listable[DNSRecordOptions] `json:"extra,omitempty"`
}
type DNSRCode int
func (r DNSRCode) MarshalJSON() ([]byte, error) {
rCodeValue, loaded := dns.RcodeToString[int(r)]
if loaded {
return json.Marshal(rCodeValue)
}
return json.Marshal(int(r))
}
func (r *DNSRCode) UnmarshalJSON(bytes []byte) error {
var intValue int
err := json.Unmarshal(bytes, &intValue)
if err == nil {
*r = DNSRCode(intValue)
return nil
}
var stringValue string
err = json.Unmarshal(bytes, &stringValue)
if err != nil {
return err
}
rCodeValue, loaded := dns.StringToRcode[stringValue]
if !loaded {
return E.New("unknown rcode: " + stringValue)
}
*r = DNSRCode(rCodeValue)
return nil
}
func (r *DNSRCode) Build() int {
if r == nil {
return dns.RcodeSuccess
}
return int(*r)
}
func (o DNSResponseOptions) Build() ([]dns.Question, *dns.Msg, error) {
var questions []dns.Question
if len(o.Query) == 0 && len(o.QueryType) == 0 {
questions = []dns.Question{{Qclass: dns.ClassINET}}
} else if len(o.Query) == 0 {
for _, queryType := range o.QueryType {
questions = append(questions, dns.Question{
Qtype: uint16(queryType),
Qclass: dns.ClassINET,
})
}
} else if len(o.QueryType) == 0 {
for _, domain := range o.Query {
questions = append(questions, dns.Question{
Name: dns.Fqdn(domain),
Qclass: dns.ClassINET,
})
}
} else {
for _, queryType := range o.QueryType {
for _, domain := range o.Query {
questions = append(questions, dns.Question{
Name: dns.Fqdn(domain),
Qtype: uint16(queryType),
Qclass: dns.ClassINET,
})
}
}
}
return questions, &dns.Msg{
MsgHdr: dns.MsgHdr{
Response: true,
Rcode: o.RCode.Build(),
Authoritative: true,
RecursionDesired: true,
RecursionAvailable: true,
},
Answer: common.Map(o.Answer, DNSRecordOptions.build),
Ns: common.Map(o.Ns, DNSRecordOptions.build),
Extra: common.Map(o.Extra, DNSRecordOptions.build),
}, nil
}
type DNSRecordOptions struct {
dns.RR
fromBase64 bool
}
func (o DNSRecordOptions) MarshalJSON() ([]byte, error) {
if o.fromBase64 {
buffer := buf.Get(dns.Len(o.RR))
defer buf.Put(buffer)
offset, err := dns.PackRR(o.RR, buffer, 0, nil, false)
if err != nil {
return nil, err
}
return json.Marshal(base64.StdEncoding.EncodeToString(buffer[:offset]))
}
return json.Marshal(o.RR.String())
}
func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error {
var stringValue string
err := json.Unmarshal(data, &stringValue)
if err != nil {
return err
}
binary, err := base64.StdEncoding.DecodeString(stringValue)
if err == nil {
return o.unmarshalBase64(binary)
}
record, err := dns.NewRR(stringValue)
if err != nil {
return err
}
if a, isA := record.(*dns.A); isA {
a.A = M.AddrFromIP(a.A).Unmap().AsSlice()
}
o.RR = record
return nil
}
func (o *DNSRecordOptions) unmarshalBase64(binary []byte) error {
record, _, err := dns.UnpackRR(binary, 0)
if err != nil {
return E.New("parse binary DNS record")
}
o.RR = record
o.fromBase64 = true
return nil
}
func (o DNSRecordOptions) build() dns.RR {
return o.RR
}

View file

@ -7,7 +7,6 @@ import (
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson"
@ -168,12 +167,14 @@ func (r *RouteOptionsActionOptions) UnmarshalJSON(data []byte) error {
type DNSRouteActionOptions struct {
Server string `json:"server,omitempty"`
Strategy DomainStrategy `json:"strategy,omitempty"`
DisableCache bool `json:"disable_cache,omitempty"`
RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"`
ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"`
}
type _DNSRouteOptionsActionOptions struct {
Strategy DomainStrategy `json:"strategy,omitempty"`
DisableCache bool `json:"disable_cache,omitempty"`
RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"`
ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"`
@ -225,7 +226,7 @@ func (d DirectActionOptions) Descriptions() []string {
if d.UDPFragment != nil {
descriptions = append(descriptions, "udp_fragment="+fmt.Sprint(*d.UDPFragment))
}
if d.DomainStrategy != DomainStrategy(dns.DomainStrategyAsIS) {
if d.DomainStrategy != DomainStrategy(C.DomainStrategyAsIS) {
descriptions = append(descriptions, "domain_strategy="+d.DomainStrategy.String())
}
if d.FallbackDelay != 0 {
@ -252,6 +253,14 @@ type _RejectActionOptions struct {
type RejectActionOptions _RejectActionOptions
func (r RejectActionOptions) MarshalJSON() ([]byte, error) {
switch r.Method {
case C.RuleActionRejectMethodDefault:
r.Method = ""
}
return json.Marshal((_RejectActionOptions)(r))
}
func (r *RejectActionOptions) UnmarshalJSON(bytes []byte) error {
err := json.Unmarshal(bytes, (*_RejectActionOptions)(r))
if err != nil {

View file

@ -83,6 +83,7 @@ type RawDefaultDNSRule struct {
GeoIP badoption.Listable[string] `json:"geoip,omitempty"`
IPCIDR badoption.Listable[string] `json:"ip_cidr,omitempty"`
IPIsPrivate bool `json:"ip_is_private,omitempty"`
IPAcceptAny bool `json:"ip_accept_any,omitempty"`
SourceIPCIDR badoption.Listable[string] `json:"source_ip_cidr,omitempty"`
SourceIPIsPrivate bool `json:"source_ip_is_private,omitempty"`
SourcePort badoption.Listable[uint16] `json:"source_port,omitempty"`

View file

@ -4,7 +4,6 @@ import (
"strings"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/json"
@ -45,19 +44,19 @@ func (v NetworkList) Build() []string {
return strings.Split(string(v), "\n")
}
type DomainStrategy dns.DomainStrategy
type DomainStrategy C.DomainStrategy
func (s DomainStrategy) String() string {
switch dns.DomainStrategy(s) {
case dns.DomainStrategyAsIS:
switch C.DomainStrategy(s) {
case C.DomainStrategyAsIS:
return ""
case dns.DomainStrategyPreferIPv4:
case C.DomainStrategyPreferIPv4:
return "prefer_ipv4"
case dns.DomainStrategyPreferIPv6:
case C.DomainStrategyPreferIPv6:
return "prefer_ipv6"
case dns.DomainStrategyUseIPv4:
case C.DomainStrategyIPv4Only:
return "ipv4_only"
case dns.DomainStrategyUseIPv6:
case C.DomainStrategyIPv6Only:
return "ipv6_only"
default:
panic(E.New("unknown domain strategy: ", s))
@ -66,17 +65,17 @@ func (s DomainStrategy) String() string {
func (s DomainStrategy) MarshalJSON() ([]byte, error) {
var value string
switch dns.DomainStrategy(s) {
case dns.DomainStrategyAsIS:
switch C.DomainStrategy(s) {
case C.DomainStrategyAsIS:
value = ""
// value = "as_is"
case dns.DomainStrategyPreferIPv4:
case C.DomainStrategyPreferIPv4:
value = "prefer_ipv4"
case dns.DomainStrategyPreferIPv6:
case C.DomainStrategyPreferIPv6:
value = "prefer_ipv6"
case dns.DomainStrategyUseIPv4:
case C.DomainStrategyIPv4Only:
value = "ipv4_only"
case dns.DomainStrategyUseIPv6:
case C.DomainStrategyIPv6Only:
value = "ipv6_only"
default:
return nil, E.New("unknown domain strategy: ", s)
@ -92,15 +91,15 @@ func (s *DomainStrategy) UnmarshalJSON(bytes []byte) error {
}
switch value {
case "", "as_is":
*s = DomainStrategy(dns.DomainStrategyAsIS)
*s = DomainStrategy(C.DomainStrategyAsIS)
case "prefer_ipv4":
*s = DomainStrategy(dns.DomainStrategyPreferIPv4)
*s = DomainStrategy(C.DomainStrategyPreferIPv4)
case "prefer_ipv6":
*s = DomainStrategy(dns.DomainStrategyPreferIPv6)
*s = DomainStrategy(C.DomainStrategyPreferIPv6)
case "ipv4_only":
*s = DomainStrategy(dns.DomainStrategyUseIPv4)
*s = DomainStrategy(C.DomainStrategyIPv4Only)
case "ipv6_only":
*s = DomainStrategy(dns.DomainStrategyUseIPv6)
*s = DomainStrategy(C.DomainStrategyIPv6Only)
default:
return E.New("unknown domain strategy: ", value)
}

View file

@ -12,7 +12,6 @@ import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
@ -34,7 +33,7 @@ type Outbound struct {
outbound.Adapter
logger logger.ContextLogger
dialer dialer.ParallelInterfaceDialer
domainStrategy dns.DomainStrategy
domainStrategy C.DomainStrategy
fallbackDelay time.Duration
overrideOption int
overrideDestination M.Socksaddr
@ -50,7 +49,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
logger: logger,
domainStrategy: dns.DomainStrategy(options.DomainStrategy),
domainStrategy: C.DomainStrategy(options.DomainStrategy),
fallbackDelay: time.Duration(options.FallbackDelay),
dialer: outboundDialer,
// loopBack: newLoopBackDetector(router),
@ -151,26 +150,26 @@ func (h *Outbound) DialParallel(ctx context.Context, network string, destination
case N.NetworkUDP:
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
var domainStrategy dns.DomainStrategy
if h.domainStrategy != dns.DomainStrategyAsIS {
var domainStrategy C.DomainStrategy
if h.domainStrategy != C.DomainStrategyAsIS {
domainStrategy = h.domainStrategy
} else {
//nolint:staticcheck
domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)
domainStrategy = C.DomainStrategy(metadata.InboundOptions.DomainStrategy)
}
switch domainStrategy {
case dns.DomainStrategyUseIPv4:
case C.DomainStrategyIPv4Only:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is4)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv4 address available for ", destination)
}
case dns.DomainStrategyUseIPv6:
case C.DomainStrategyIPv6Only:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is6)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv6 address available for ", destination)
}
}
return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, nil, nil, nil, h.fallbackDelay)
return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == C.DomainStrategyPreferIPv6, nil, nil, nil, h.fallbackDelay)
}
func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy *C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
@ -191,26 +190,26 @@ func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, dest
case N.NetworkUDP:
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
var domainStrategy dns.DomainStrategy
if h.domainStrategy != dns.DomainStrategyAsIS {
var domainStrategy C.DomainStrategy
if h.domainStrategy != C.DomainStrategyAsIS {
domainStrategy = h.domainStrategy
} else {
//nolint:staticcheck
domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)
domainStrategy = C.DomainStrategy(metadata.InboundOptions.DomainStrategy)
}
switch domainStrategy {
case dns.DomainStrategyUseIPv4:
case C.DomainStrategyIPv4Only:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is4)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv4 address available for ", destination)
}
case dns.DomainStrategyUseIPv6:
case C.DomainStrategyIPv6Only:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is6)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv6 address available for ", destination)
}
}
return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, networkStrategy, networkType, fallbackNetworkType, fallbackDelay)
return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == C.DomainStrategyPreferIPv6, networkStrategy, networkType, fallbackNetworkType, fallbackDelay)
}
func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy *C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) {

View file

@ -7,7 +7,7 @@ import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
@ -19,7 +19,7 @@ import (
mDNS "github.com/miekg/dns"
)
func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net.Conn, metadata adapter.InboundContext) error {
func HandleStreamDNSRequest(ctx context.Context, router adapter.DNSRouter, conn net.Conn, metadata adapter.InboundContext) error {
var queryLength uint16
err := binary.Read(conn, binary.BigEndian, &queryLength)
if err != nil {
@ -41,7 +41,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net
}
metadataInQuery := metadata
go func() error {
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{})
if err != nil {
conn.Close()
return err
@ -61,7 +61,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net
return nil
}
func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error {
func NewDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error {
metadata.Destination = M.Socksaddr{}
var reader N.PacketReader = conn
var counters []N.CountFunc
@ -123,7 +123,7 @@ func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.P
}
metadataInQuery := metadata
go func() error {
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{})
if err != nil {
cancel(err)
return err
@ -148,7 +148,7 @@ func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.P
return group.Run(fastClose)
}
func newDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
func newDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
fastClose, cancel := common.ContextWithCancelCause(ctx)
timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
var group task.Group
@ -193,7 +193,7 @@ func newDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.P
}
metadataInQuery := metadata
go func() error {
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{})
if err != nil {
cancel(err)
return err

View file

@ -14,6 +14,7 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
@ -22,14 +23,14 @@ func RegisterOutbound(registry *outbound.Registry) {
type Outbound struct {
outbound.Adapter
router adapter.Router
router adapter.DNSRouter
logger logger.ContextLogger
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) {
return &Outbound{
Adapter: outbound.NewAdapter(C.TypeDNS, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
router: router,
router: service.FromContext[adapter.DNSRouter](ctx),
logger: logger,
}, nil
}

View file

@ -17,6 +17,7 @@ import (
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
@ -27,7 +28,7 @@ var _ adapter.Outbound = (*Outbound)(nil)
type Outbound struct {
outbound.Adapter
router adapter.Router
dnsRouter adapter.DNSRouter
logger logger.ContextLogger
client *socks.Client
resolve bool
@ -51,7 +52,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
}
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions),
router: router,
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
logger: logger,
client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password),
resolve: version == socks.Version4,
@ -83,7 +84,7 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
if h.resolve && destination.IsFqdn() {
destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := h.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}
@ -101,7 +102,7 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
return h.uotClient.ListenPacket(ctx, destination)
}
if h.resolve && destination.IsFqdn() {
destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := h.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}

View file

@ -13,13 +13,13 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterEndpoint(registry *endpoint.Registry) {
@ -35,6 +35,7 @@ type Endpoint struct {
endpoint.Adapter
ctx context.Context
router adapter.Router
dnsRouter adapter.DNSRouter
logger logger.ContextLogger
localAddresses []netip.Prefix
endpoint *wireguard.Endpoint
@ -45,6 +46,7 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL
Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx,
router: router,
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
logger: logger,
localAddresses: options.Address,
}
@ -79,7 +81,9 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL
PrivateKey: options.PrivateKey,
ListenPort: options.ListenPort,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
endpointAddresses, lookupErr := ep.dnsRouter.Lookup(ctx, domain, adapter.DNSQueryOptions{
Strategy: C.DomainStrategy(options.DomainStrategy),
})
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
@ -185,7 +189,7 @@ func (w *Endpoint) DialContext(ctx context.Context, network string, destination
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}
@ -199,7 +203,7 @@ func (w *Endpoint) DialContext(ctx context.Context, network string, destination
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}

View file

@ -13,12 +13,12 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
@ -33,7 +33,7 @@ var (
type Outbound struct {
outbound.Adapter
ctx context.Context
router adapter.Router
dnsRouter adapter.DNSRouter
logger logger.ContextLogger
localAddresses []netip.Prefix
endpoint *wireguard.Endpoint
@ -47,7 +47,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx,
router: router,
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
logger: logger,
localAddresses: options.LocalAddress,
}
@ -94,7 +94,9 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
Address: options.LocalAddress,
PrivateKey: options.PrivateKey,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
endpointAddresses, lookupErr := outbound.dnsRouter.Lookup(ctx, domain, adapter.DNSQueryOptions{
Strategy: C.DomainStrategy(options.DomainStrategy),
})
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
@ -137,7 +139,7 @@ func (o *Outbound) DialContext(ctx context.Context, network string, destination
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsFqdn() {
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := o.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}
@ -151,7 +153,7 @@ func (o *Outbound) DialContext(ctx context.Context, network string, destination
func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() {
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := o.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}

View file

@ -14,10 +14,13 @@
"type": "shadowsocks",
"listen": "::",
"listen_port": 8080,
"sniff": true,
"network": "tcp",
"method": "2022-blake3-aes-128-gcm",
"password": "8JCsPssfgS8tiRwiMlhARg=="
"password": "Gn1JUS14bLUHgv1cWDDp4A==",
"multiplex": {
"enabled": true,
"padding": true
}
}
],
"outbounds": [
@ -32,7 +35,7 @@
"route": {
"rules": [
{
"protocol": "dns",
"port": 53,
"outbound": "dns-out"
}
]

View file

@ -8,11 +8,12 @@ import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
dnsOutbound "github.com/sagernet/sing-box/protocol/dns"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat2"
@ -24,7 +25,7 @@ func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata ad
metadata.Destination = M.Socksaddr{}
for {
conn.SetReadDeadline(time.Now().Add(C.DNSTimeout))
err := dnsOutbound.HandleStreamDNSRequest(ctx, r, conn, metadata)
err := dnsOutbound.HandleStreamDNSRequest(ctx, r.dns, conn, metadata)
if err != nil {
return err
}
@ -38,37 +39,38 @@ func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetB
buffer := packet.Buffer
destination := packet.Destination
N.PutPacketBuffer(packet)
go ExchangeDNSPacket(ctx, r, natConn, buffer, metadata, destination)
go ExchangeDNSPacket(ctx, r.dns, r.logger, natConn, buffer, metadata, destination)
}
natConn.SetHandler(&dnsHijacker{
router: r,
router: r.dns,
logger: r.logger,
conn: conn,
ctx: ctx,
metadata: metadata,
})
return
}
err := dnsOutbound.NewDNSPacketConnection(ctx, r, conn, packetBuffers, metadata)
err := dnsOutbound.NewDNSPacketConnection(ctx, r.dns, conn, packetBuffers, metadata)
if err != nil && !E.IsClosedOrCanceled(err) {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection"))
r.logger.ErrorContext(ctx, E.Cause(err, "process DNS packet connection"))
}
}
func ExchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) {
func ExchangeDNSPacket(ctx context.Context, router adapter.DNSRouter, logger logger.ContextLogger, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) {
err := exchangeDNSPacket(ctx, router, conn, buffer, metadata, destination)
if err != nil && !errors.Is(err, tun.ErrDrop) && !E.IsClosedOrCanceled(err) {
router.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection"))
logger.ErrorContext(ctx, E.Cause(err, "process DNS packet connection"))
}
}
func exchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) error {
func exchangeDNSPacket(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) error {
var message mDNS.Msg
err := message.Unpack(buffer.Bytes())
buffer.Release()
if err != nil {
return E.Cause(err, "unpack request")
}
response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message, adapter.DNSQueryOptions{})
if err != nil {
return err
}
@ -81,12 +83,13 @@ func exchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, b
}
type dnsHijacker struct {
router *Router
router adapter.DNSRouter
logger logger.ContextLogger
conn N.PacketConn
ctx context.Context
metadata adapter.InboundContext
}
func (h *dnsHijacker) NewPacketEx(buffer *buf.Buffer, destination M.Socksaddr) {
go ExchangeDNSPacket(h.ctx, h.router, h.conn, buffer, h.metadata, destination)
go ExchangeDNSPacket(h.ctx, h.router, h.logger, h.conn, buffer, h.metadata, destination)
}

View file

@ -1,246 +0,0 @@
package route
import (
"context"
"io"
"net"
"net/http"
"os"
"path/filepath"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/geoip"
"github.com/sagernet/sing-box/common/geosite"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
R "github.com/sagernet/sing-box/route/rule"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/service/filemanager"
)
func (r *Router) GeoIPReader() *geoip.Reader {
return r.geoIPReader
}
func (r *Router) LoadGeosite(code string) (adapter.Rule, error) {
rule, cached := r.geositeCache[code]
if cached {
return rule, nil
}
items, err := r.geositeReader.Read(code)
if err != nil {
return nil, err
}
rule, err = R.NewDefaultRule(r.ctx, nil, geosite.Compile(items))
if err != nil {
return nil, err
}
r.geositeCache[code] = rule
return rule, nil
}
func (r *Router) prepareGeoIPDatabase() error {
deprecated.Report(r.ctx, deprecated.OptionGEOIP)
var geoPath string
if r.geoIPOptions.Path != "" {
geoPath = r.geoIPOptions.Path
} else {
geoPath = "geoip.db"
if foundPath, loaded := C.FindPath(geoPath); loaded {
geoPath = foundPath
}
}
if !rw.IsFile(geoPath) {
geoPath = filemanager.BasePath(r.ctx, geoPath)
}
if stat, err := os.Stat(geoPath); err == nil {
if stat.IsDir() {
return E.New("geoip path is a directory: ", geoPath)
}
if stat.Size() == 0 {
os.Remove(geoPath)
}
}
if !rw.IsFile(geoPath) {
r.logger.Warn("geoip database not exists: ", geoPath)
var err error
for attempts := 0; attempts < 3; attempts++ {
err = r.downloadGeoIPDatabase(geoPath)
if err == nil {
break
}
r.logger.Error("download geoip database: ", err)
os.Remove(geoPath)
// time.Sleep(10 * time.Second)
}
if err != nil {
return err
}
}
geoReader, codes, err := geoip.Open(geoPath)
if err != nil {
return E.Cause(err, "open geoip database")
}
r.logger.Info("loaded geoip database: ", len(codes), " codes")
r.geoIPReader = geoReader
return nil
}
func (r *Router) prepareGeositeDatabase() error {
deprecated.Report(r.ctx, deprecated.OptionGEOSITE)
var geoPath string
if r.geositeOptions.Path != "" {
geoPath = r.geositeOptions.Path
} else {
geoPath = "geosite.db"
if foundPath, loaded := C.FindPath(geoPath); loaded {
geoPath = foundPath
}
}
if !rw.IsFile(geoPath) {
geoPath = filemanager.BasePath(r.ctx, geoPath)
}
if stat, err := os.Stat(geoPath); err == nil {
if stat.IsDir() {
return E.New("geoip path is a directory: ", geoPath)
}
if stat.Size() == 0 {
os.Remove(geoPath)
}
}
if !rw.IsFile(geoPath) {
r.logger.Warn("geosite database not exists: ", geoPath)
var err error
for attempts := 0; attempts < 3; attempts++ {
err = r.downloadGeositeDatabase(geoPath)
if err == nil {
break
}
r.logger.Error("download geosite database: ", err)
os.Remove(geoPath)
}
if err != nil {
return err
}
}
geoReader, codes, err := geosite.Open(geoPath)
if err == nil {
r.logger.Info("loaded geosite database: ", len(codes), " codes")
r.geositeReader = geoReader
} else {
return E.Cause(err, "open geosite database")
}
return nil
}
func (r *Router) downloadGeoIPDatabase(savePath string) error {
var downloadURL string
if r.geoIPOptions.DownloadURL != "" {
downloadURL = r.geoIPOptions.DownloadURL
} else {
downloadURL = "https://github.com/SagerNet/sing-geoip/releases/latest/download/geoip.db"
}
r.logger.Info("downloading geoip database")
var detour adapter.Outbound
if r.geoIPOptions.DownloadDetour != "" {
outbound, loaded := r.outbound.Outbound(r.geoIPOptions.DownloadDetour)
if !loaded {
return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour)
}
detour = outbound
} else {
detour = r.outbound.Default()
}
if parentDir := filepath.Dir(savePath); parentDir != "" {
filemanager.MkdirAll(r.ctx, parentDir, 0o755)
}
httpClient := &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
},
}
defer httpClient.CloseIdleConnections()
request, err := http.NewRequest("GET", downloadURL, nil)
if err != nil {
return err
}
response, err := httpClient.Do(request.WithContext(r.ctx))
if err != nil {
return err
}
defer response.Body.Close()
saveFile, err := filemanager.Create(r.ctx, savePath)
if err != nil {
return E.Cause(err, "open output file: ", downloadURL)
}
_, err = io.Copy(saveFile, response.Body)
saveFile.Close()
if err != nil {
filemanager.Remove(r.ctx, savePath)
}
return err
}
func (r *Router) downloadGeositeDatabase(savePath string) error {
var downloadURL string
if r.geositeOptions.DownloadURL != "" {
downloadURL = r.geositeOptions.DownloadURL
} else {
downloadURL = "https://github.com/SagerNet/sing-geosite/releases/latest/download/geosite.db"
}
r.logger.Info("downloading geosite database")
var detour adapter.Outbound
if r.geositeOptions.DownloadDetour != "" {
outbound, loaded := r.outbound.Outbound(r.geositeOptions.DownloadDetour)
if !loaded {
return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour)
}
detour = outbound
} else {
detour = r.outbound.Default()
}
if parentDir := filepath.Dir(savePath); parentDir != "" {
filemanager.MkdirAll(r.ctx, parentDir, 0o755)
}
httpClient := &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
},
}
defer httpClient.CloseIdleConnections()
request, err := http.NewRequest("GET", downloadURL, nil)
if err != nil {
return err
}
response, err := httpClient.Do(request.WithContext(r.ctx))
if err != nil {
return err
}
defer response.Body.Close()
saveFile, err := filemanager.Create(r.ctx, savePath)
if err != nil {
return E.Cause(err, "open output file: ", downloadURL)
}
_, err = io.Copy(saveFile, response.Body)
saveFile.Close()
if err != nil {
filemanager.Remove(r.ctx, savePath)
}
return err
}

View file

@ -17,7 +17,6 @@ import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-mux"
"github.com/sagernet/sing-vmess"
"github.com/sagernet/sing/common"
@ -325,12 +324,13 @@ func (r *Router) matchRule(
metadata.ProcessInfo = processInfo
}
}
if r.fakeIPStore != nil && r.fakeIPStore.Contains(metadata.Destination.Addr) {
domain, loaded := r.fakeIPStore.Lookup(metadata.Destination.Addr)
if metadata.Destination.Addr.IsValid() && r.dnsTransport.FakeIP() != nil && r.dnsTransport.FakeIP().Store().Contains(metadata.Destination.Addr) {
domain, loaded := r.dnsTransport.FakeIP().Store().Lookup(metadata.Destination.Addr)
if !loaded {
fatalErr = E.New("missing fakeip record, try to configure experimental.cache_file")
fatalErr = E.New("missing fakeip record, try enable `experimental.cache_file`")
return
}
if domain != "" {
metadata.OriginDestination = metadata.Destination
metadata.Destination = M.Socksaddr{
Fqdn: domain,
@ -339,8 +339,8 @@ func (r *Router) matchRule(
metadata.FakeIP = true
r.logger.DebugContext(ctx, "found fakeip domain: ", domain)
}
if r.dnsReverseMapping != nil && metadata.Domain == "" {
domain, loaded := r.dnsReverseMapping.Query(metadata.Destination.Addr)
} else if metadata.Domain == "" {
domain, loaded := r.dns.LookupReverseMapping(metadata.Destination.Addr)
if loaded {
metadata.Domain = domain
r.logger.DebugContext(ctx, "found reserve mapped domain: ", metadata.Domain)
@ -369,9 +369,9 @@ func (r *Router) matchRule(
packetBuffers = newPackerBuffers
}
}
if dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) != dns.DomainStrategyAsIS {
if C.DomainStrategy(metadata.InboundOptions.DomainStrategy) != C.DomainStrategyAsIS {
fatalErr = r.actionResolve(ctx, metadata, &rule.RuleActionResolve{
Strategy: dns.DomainStrategy(metadata.InboundOptions.DomainStrategy),
Strategy: C.DomainStrategy(metadata.InboundOptions.DomainStrategy),
})
if fatalErr != nil {
return
@ -649,13 +649,23 @@ func (r *Router) actionSniff(
func (r *Router) actionResolve(ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionResolve) error {
if metadata.Destination.IsFqdn() {
metadata.DNSServer = action.Server
addresses, err := r.Lookup(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn, action.Strategy)
var transport adapter.DNSTransport
if action.Server != "" {
var loaded bool
transport, loaded = r.dnsTransport.Transport(action.Server)
if !loaded {
return E.New("DNS server not found: ", action.Server)
}
}
addresses, err := r.dns.Lookup(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn, adapter.DNSQueryOptions{
Transport: transport,
Strategy: action.Strategy,
})
if err != nil {
return err
}
metadata.DestinationAddresses = addresses
r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]")
r.logger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]")
if metadata.Destination.IsIPv4() {
metadata.IPVersion = 4
} else if metadata.Destination.IsIPv6() {

View file

@ -1,348 +0,0 @@
package route
import (
"context"
"errors"
"net/netip"
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
R "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
mDNS "github.com/miekg/dns"
)
type DNSReverseMapping struct {
cache *cache.LruCache[netip.Addr, string]
}
func NewDNSReverseMapping() *DNSReverseMapping {
return &DNSReverseMapping{
cache: cache.New[netip.Addr, string](),
}
}
func (m *DNSReverseMapping) Save(address netip.Addr, domain string, ttl int) {
m.cache.StoreWithExpire(address, domain, time.Now().Add(time.Duration(ttl)*time.Second))
}
func (m *DNSReverseMapping) Query(address netip.Addr) (string, bool) {
domain, loaded := m.cache.Load(address)
return domain, loaded
}
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool) (dns.Transport, dns.QueryOptions, adapter.DNSRule, int) {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
panic("no context")
}
var options dns.QueryOptions
var currentRuleIndex int
if ruleIndex != -1 {
currentRuleIndex = ruleIndex + 1
}
for ; currentRuleIndex < len(r.dnsRules); currentRuleIndex++ {
currentRule := r.dnsRules[currentRuleIndex]
if currentRule.WithAddressLimit() && !isAddressQuery {
continue
}
metadata.ResetRuleCache()
if currentRule.Match(metadata) {
ruleDescription := currentRule.String()
if ruleDescription != "" {
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
} else {
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action())
}
switch action := currentRule.Action().(type) {
case *R.RuleActionDNSRoute:
transport, loaded := r.transportMap[action.Server]
if !loaded {
r.dnsLogger.ErrorContext(ctx, "transport not found: ", action.Server)
continue
}
_, isFakeIP := transport.(adapter.FakeIPTransport)
if isFakeIP && !allowFakeIP {
continue
}
if isFakeIP || action.DisableCache {
options.DisableCache = true
}
if action.RewriteTTL != nil {
options.RewriteTTL = action.RewriteTTL
}
if action.ClientSubnet.IsValid() {
options.ClientSubnet = action.ClientSubnet
}
if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded {
options.Strategy = domainStrategy
} else {
options.Strategy = r.defaultDomainStrategy
}
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action())
return transport, options, currentRule, currentRuleIndex
case *R.RuleActionDNSRouteOptions:
if action.DisableCache {
options.DisableCache = true
}
if action.RewriteTTL != nil {
options.RewriteTTL = action.RewriteTTL
}
if action.ClientSubnet.IsValid() {
options.ClientSubnet = action.ClientSubnet
}
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action())
case *R.RuleActionReject:
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action())
return nil, options, currentRule, currentRuleIndex
}
}
}
if domainStrategy, dsLoaded := r.transportDomainStrategy[r.defaultTransport]; dsLoaded {
options.Strategy = domainStrategy
} else {
options.Strategy = r.defaultDomainStrategy
}
return r.defaultTransport, options, nil, -1
}
func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
if len(message.Question) != 1 {
r.dnsLogger.WarnContext(ctx, "bad question size: ", len(message.Question))
responseMessage := mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Id: message.Id,
Response: true,
Rcode: mDNS.RcodeFormatError,
},
Question: message.Question,
}
return &responseMessage, nil
}
var (
response *mDNS.Msg
cached bool
transport dns.Transport
err error
)
response, cached = r.dnsClient.ExchangeCache(ctx, message)
if !cached {
var metadata *adapter.InboundContext
ctx, metadata = adapter.ExtendContext(ctx)
metadata.Destination = M.Socksaddr{}
metadata.QueryType = message.Question[0].Qtype
switch metadata.QueryType {
case mDNS.TypeA:
metadata.IPVersion = 4
case mDNS.TypeAAAA:
metadata.IPVersion = 6
}
metadata.Domain = fqdnToDomain(message.Question[0].Name)
var (
options dns.QueryOptions
rule adapter.DNSRule
ruleIndex int
)
ruleIndex = -1
for {
dnsCtx := adapter.OverrideContext(ctx)
var addressLimit bool
transport, options, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message))
if rule != nil {
switch action := rule.Action().(type) {
case *R.RuleActionReject:
switch action.Method {
case C.RuleActionRejectMethodDefault:
return dns.FixedResponse(message.Id, message.Question[0], nil, 0), nil
case C.RuleActionRejectMethodDrop:
return nil, tun.ErrDrop
}
}
}
r.dnsLogger.DebugContext(ctx, "exchange ", formatQuestion(message.Question[0].String()), " via ", transport.Name())
if rule != nil && rule.WithAddressLimit() {
addressLimit = true
response, err = r.dnsClient.ExchangeWithResponseCheck(dnsCtx, transport, message, options, func(responseAddrs []netip.Addr) bool {
metadata.DestinationAddresses = responseAddrs
return rule.MatchAddressLimit(metadata)
})
} else {
addressLimit = false
response, err = r.dnsClient.Exchange(dnsCtx, transport, message, options)
}
var rejected bool
if err != nil {
if errors.Is(err, dns.ErrResponseRejectedCached) {
rejected = true
r.dnsLogger.DebugContext(ctx, E.Cause(err, "response rejected for ", formatQuestion(message.Question[0].String())), " (cached)")
} else if errors.Is(err, dns.ErrResponseRejected) {
rejected = true
r.dnsLogger.DebugContext(ctx, E.Cause(err, "response rejected for ", formatQuestion(message.Question[0].String())))
} else if len(message.Question) > 0 {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", formatQuestion(message.Question[0].String())))
} else {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for <empty query>"))
}
}
if addressLimit && rejected {
continue
}
break
}
}
if err != nil {
return nil, err
}
if r.dnsReverseMapping != nil && response != nil && len(response.Answer) > 0 {
if _, isFakeIP := transport.(adapter.FakeIPTransport); !isFakeIP {
for _, answer := range response.Answer {
switch record := answer.(type) {
case *mDNS.A:
r.dnsReverseMapping.Save(M.AddrFromIP(record.A), fqdnToDomain(record.Hdr.Name), int(record.Hdr.Ttl))
case *mDNS.AAAA:
r.dnsReverseMapping.Save(M.AddrFromIP(record.AAAA), fqdnToDomain(record.Hdr.Name), int(record.Hdr.Ttl))
}
}
}
}
return response, nil
}
func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
var (
responseAddrs []netip.Addr
cached bool
err error
)
printResult := func() {
if err != nil {
if errors.Is(err, dns.ErrResponseRejectedCached) {
r.dnsLogger.DebugContext(ctx, "response rejected for ", domain, " (cached)")
} else if errors.Is(err, dns.ErrResponseRejected) {
r.dnsLogger.DebugContext(ctx, "response rejected for ", domain)
} else {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
}
} else if len(responseAddrs) == 0 {
r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
err = dns.RCodeNameError
}
}
responseAddrs, cached = r.dnsClient.LookupCache(ctx, domain, strategy)
if cached {
if len(responseAddrs) == 0 {
return nil, dns.RCodeNameError
}
return responseAddrs, nil
}
r.dnsLogger.DebugContext(ctx, "lookup domain ", domain)
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Destination = M.Socksaddr{}
metadata.Domain = domain
if metadata.DNSServer != "" {
transport, loaded := r.transportMap[metadata.DNSServer]
if !loaded {
return nil, E.New("transport not found: ", metadata.DNSServer)
}
if strategy == dns.DomainStrategyAsIS {
if transportDomainStrategy, loaded := r.transportDomainStrategy[transport]; loaded {
strategy = transportDomainStrategy
} else {
strategy = r.defaultDomainStrategy
}
}
responseAddrs, err = r.dnsClient.Lookup(ctx, transport, domain, dns.QueryOptions{Strategy: strategy})
} else {
var (
transport dns.Transport
options dns.QueryOptions
rule adapter.DNSRule
ruleIndex int
)
ruleIndex = -1
for {
dnsCtx := adapter.OverrideContext(ctx)
var addressLimit bool
transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true)
if strategy != dns.DomainStrategyAsIS {
options.Strategy = strategy
}
if rule != nil {
switch action := rule.Action().(type) {
case *R.RuleActionReject:
switch action.Method {
case C.RuleActionRejectMethodDefault:
return nil, nil
case C.RuleActionRejectMethodDrop:
return nil, tun.ErrDrop
}
}
}
if rule != nil && rule.WithAddressLimit() {
addressLimit = true
responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool {
metadata.DestinationAddresses = responseAddrs
return rule.MatchAddressLimit(metadata)
})
} else {
addressLimit = false
responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options)
}
if !addressLimit || err == nil {
break
}
printResult()
}
}
printResult()
if len(responseAddrs) > 0 {
r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " "))
}
return responseAddrs, err
}
func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) {
return r.Lookup(ctx, domain, dns.DomainStrategyAsIS)
}
func (r *Router) ClearDNSCache() {
r.dnsClient.ClearCache()
if r.platformInterface != nil {
r.platformInterface.ClearDNSCache()
}
}
func isAddressQuery(message *mDNS.Msg) bool {
for _, question := range message.Question {
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA || question.Qtype == mDNS.TypeHTTPS {
return true
}
}
return false
}
func fqdnToDomain(fqdn string) string {
if mDNS.IsFqdn(fqdn) {
return fqdn[:len(fqdn)-1]
}
return fqdn
}
func formatQuestion(string string) string {
if strings.HasPrefix(string, ";") {
string = string[1:]
}
string = strings.ReplaceAll(string, "\t", " ")
for strings.Contains(string, " ") {
string = strings.ReplaceAll(string, " ", " ")
}
return string
}

View file

@ -2,17 +2,10 @@ package route
import (
"context"
"net/netip"
"net/url"
"os"
"runtime"
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/geoip"
"github.com/sagernet/sing-box/common/geosite"
"github.com/sagernet/sing-box/common/process"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
@ -20,13 +13,7 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
R "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-box/transport/fakeip"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
@ -37,31 +24,16 @@ var _ adapter.Router = (*Router)(nil)
type Router struct {
ctx context.Context
logger log.ContextLogger
dnsLogger log.ContextLogger
inbound adapter.InboundManager
outbound adapter.OutboundManager
dns adapter.DNSRouter
dnsTransport adapter.DNSTransportManager
connection adapter.ConnectionManager
network adapter.NetworkManager
rules []adapter.Rule
needGeoIPDatabase bool
needGeositeDatabase bool
geoIPOptions option.GeoIPOptions
geositeOptions option.GeositeOptions
geoIPReader *geoip.Reader
geositeReader *geosite.Reader
geositeCache map[string]adapter.Rule
needFindProcess bool
dnsClient *dns.Client
defaultDomainStrategy dns.DomainStrategy
dnsRules []adapter.DNSRule
ruleSets []adapter.RuleSet
ruleSetMap map[string]adapter.RuleSet
defaultTransport dns.Transport
transports []dns.Transport
transportMap map[string]dns.Transport
transportDomainStrategy map[dns.Transport]dns.DomainStrategy
dnsReverseMapping *DNSReverseMapping
fakeIPStore adapter.FakeIPStore
processSearcher process.Searcher
pauseManager pause.Manager
tracker adapter.ConnectionTracker
@ -70,299 +42,51 @@ type Router struct {
started bool
}
func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) (*Router, error) {
router := &Router{
func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) *Router {
return &Router{
ctx: ctx,
logger: logFactory.NewLogger("router"),
dnsLogger: logFactory.NewLogger("dns"),
inbound: service.FromContext[adapter.InboundManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
dns: service.FromContext[adapter.DNSRouter](ctx),
dnsTransport: service.FromContext[adapter.DNSTransportManager](ctx),
connection: service.FromContext[adapter.ConnectionManager](ctx),
network: service.FromContext[adapter.NetworkManager](ctx),
rules: make([]adapter.Rule, 0, len(options.Rules)),
dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
ruleSetMap: make(map[string]adapter.RuleSet),
needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule),
needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule),
geoIPOptions: common.PtrValueOrDefault(options.GeoIP),
geositeOptions: common.PtrValueOrDefault(options.Geosite),
geositeCache: make(map[string]adapter.Rule),
needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess,
defaultDomainStrategy: dns.DomainStrategy(dnsOptions.Strategy),
pauseManager: service.FromContext[pause.Manager](ctx),
platformInterface: service.FromContext[platform.Interface](ctx),
needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule),
}
service.MustRegister[adapter.Router](ctx, router)
router.dnsClient = dns.NewClient(dns.ClientOptions{
DisableCache: dnsOptions.DNSClientOptions.DisableCache,
DisableExpire: dnsOptions.DNSClientOptions.DisableExpire,
IndependentCache: dnsOptions.DNSClientOptions.IndependentCache,
CacheCapacity: dnsOptions.DNSClientOptions.CacheCapacity,
RDRC: func() dns.RDRCStore {
cacheFile := service.FromContext[adapter.CacheFile](ctx)
if cacheFile == nil {
}
func (r *Router) Initialize(rules []option.Rule, ruleSets []option.RuleSet) error {
for i, options := range rules {
rule, err := R.NewRule(r.ctx, r.logger, options, false)
if err != nil {
return E.Cause(err, "parse rule[", i, "]")
}
r.rules = append(r.rules, rule)
}
for i, options := range ruleSets {
if _, exists := r.ruleSetMap[options.Tag]; exists {
return E.New("duplicate rule-set tag: ", options.Tag)
}
ruleSet, err := R.NewRuleSet(r.ctx, r.logger, options)
if err != nil {
return E.Cause(err, "parse rule-set[", i, "]")
}
r.ruleSets = append(r.ruleSets, ruleSet)
r.ruleSetMap[options.Tag] = ruleSet
}
return nil
}
if !cacheFile.StoreRDRC() {
return nil
}
return cacheFile
},
Logger: router.dnsLogger,
})
for i, ruleOptions := range options.Rules {
routeRule, err := R.NewRule(ctx, router.logger, ruleOptions, true)
if err != nil {
return nil, E.Cause(err, "parse rule[", i, "]")
}
router.rules = append(router.rules, routeRule)
}
for i, dnsRuleOptions := range dnsOptions.Rules {
dnsRule, err := R.NewDNSRule(ctx, router.logger, dnsRuleOptions, true)
if err != nil {
return nil, E.Cause(err, "parse dns rule[", i, "]")
}
router.dnsRules = append(router.dnsRules, dnsRule)
}
for i, ruleSetOptions := range options.RuleSet {
if _, exists := router.ruleSetMap[ruleSetOptions.Tag]; exists {
return nil, E.New("duplicate rule-set tag: ", ruleSetOptions.Tag)
}
ruleSet, err := R.NewRuleSet(ctx, router.logger, ruleSetOptions)
if err != nil {
return nil, E.Cause(err, "parse rule-set[", i, "]")
}
router.ruleSets = append(router.ruleSets, ruleSet)
router.ruleSetMap[ruleSetOptions.Tag] = ruleSet
}
transports := make([]dns.Transport, len(dnsOptions.Servers))
dummyTransportMap := make(map[string]dns.Transport)
transportMap := make(map[string]dns.Transport)
transportTags := make([]string, len(dnsOptions.Servers))
transportTagMap := make(map[string]bool)
transportDomainStrategy := make(map[dns.Transport]dns.DomainStrategy)
for i, server := range dnsOptions.Servers {
var tag string
if server.Tag != "" {
tag = server.Tag
} else {
tag = F.ToString(i)
}
if transportTagMap[tag] {
return nil, E.New("duplicate dns server tag: ", tag)
}
transportTags[i] = tag
transportTagMap[tag] = true
}
outboundManager := service.FromContext[adapter.OutboundManager](ctx)
for {
lastLen := len(dummyTransportMap)
for i, server := range dnsOptions.Servers {
tag := transportTags[i]
if _, exists := dummyTransportMap[tag]; exists {
continue
}
var detour N.Dialer
if server.Detour == "" {
detour = dialer.NewDefaultOutbound(outboundManager)
} else {
detour = dialer.NewDetour(outboundManager, server.Detour)
}
var serverProtocol string
switch server.Address {
case "local":
serverProtocol = "local"
default:
serverURL, _ := url.Parse(server.Address)
var serverAddress string
if serverURL != nil {
if serverURL.Scheme == "" {
serverProtocol = "udp"
} else {
serverProtocol = serverURL.Scheme
}
serverAddress = serverURL.Hostname()
}
if serverAddress == "" {
serverAddress = server.Address
}
notIpAddress := !M.ParseSocksaddr(serverAddress).Addr.IsValid()
if server.AddressResolver != "" {
if !transportTagMap[server.AddressResolver] {
return nil, E.New("parse dns server[", tag, "]: address resolver not found: ", server.AddressResolver)
}
if upstream, exists := dummyTransportMap[server.AddressResolver]; exists {
detour = dns.NewDialerWrapper(detour, router.dnsClient, upstream, dns.DomainStrategy(server.AddressStrategy), time.Duration(server.AddressFallbackDelay))
} else {
continue
}
} else if notIpAddress && strings.Contains(server.Address, ".") {
return nil, E.New("parse dns server[", tag, "]: missing address_resolver")
}
}
var clientSubnet netip.Prefix
if server.ClientSubnet != nil {
clientSubnet = netip.Prefix(common.PtrValueOrDefault(server.ClientSubnet))
} else if dnsOptions.ClientSubnet != nil {
clientSubnet = netip.Prefix(common.PtrValueOrDefault(dnsOptions.ClientSubnet))
}
if serverProtocol == "" {
serverProtocol = "transport"
}
transport, err := dns.CreateTransport(dns.TransportOptions{
Context: ctx,
Logger: logFactory.NewLogger(F.ToString("dns/", serverProtocol, "[", tag, "]")),
Name: tag,
Dialer: detour,
Address: server.Address,
ClientSubnet: clientSubnet,
})
if err != nil {
return nil, E.Cause(err, "parse dns server[", tag, "]")
}
transports[i] = transport
dummyTransportMap[tag] = transport
if server.Tag != "" {
transportMap[server.Tag] = transport
}
strategy := dns.DomainStrategy(server.Strategy)
if strategy != dns.DomainStrategyAsIS {
transportDomainStrategy[transport] = strategy
}
}
if len(transports) == len(dummyTransportMap) {
break
}
if lastLen != len(dummyTransportMap) {
continue
}
unresolvedTags := common.MapIndexed(common.FilterIndexed(dnsOptions.Servers, func(index int, server option.DNSServerOptions) bool {
_, exists := dummyTransportMap[transportTags[index]]
return !exists
}), func(index int, server option.DNSServerOptions) string {
return transportTags[index]
})
if len(unresolvedTags) == 0 {
panic(F.ToString("unexpected unresolved dns servers: ", len(transports), " ", len(dummyTransportMap), " ", len(transportMap)))
}
return nil, E.New("found circular reference in dns servers: ", strings.Join(unresolvedTags, " "))
}
var defaultTransport dns.Transport
if dnsOptions.Final != "" {
defaultTransport = dummyTransportMap[dnsOptions.Final]
if defaultTransport == nil {
return nil, E.New("default dns server not found: ", dnsOptions.Final)
}
}
if defaultTransport == nil {
if len(transports) == 0 {
transports = append(transports, common.Must1(dns.CreateTransport(dns.TransportOptions{
Context: ctx,
Name: "local",
Address: "local",
Dialer: common.Must1(dialer.NewDefault(ctx, option.DialerOptions{})),
})))
}
defaultTransport = transports[0]
}
if _, isFakeIP := defaultTransport.(adapter.FakeIPTransport); isFakeIP {
return nil, E.New("default DNS server cannot be fakeip")
}
router.defaultTransport = defaultTransport
router.transports = transports
router.transportMap = transportMap
router.transportDomainStrategy = transportDomainStrategy
if dnsOptions.ReverseMapping {
router.dnsReverseMapping = NewDNSReverseMapping()
}
if fakeIPOptions := dnsOptions.FakeIP; fakeIPOptions != nil && dnsOptions.FakeIP.Enabled {
var inet4Range netip.Prefix
var inet6Range netip.Prefix
if fakeIPOptions.Inet4Range != nil {
inet4Range = *fakeIPOptions.Inet4Range
}
if fakeIPOptions.Inet6Range != nil {
inet6Range = *fakeIPOptions.Inet6Range
}
router.fakeIPStore = fakeip.NewStore(ctx, router.logger, inet4Range, inet6Range)
}
return router, nil
}
func (r *Router) Start(stage adapter.StartStage) error {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
switch stage {
case adapter.StartStateInitialize:
if r.fakeIPStore != nil {
monitor.Start("initialize fakeip store")
err := r.fakeIPStore.Start()
monitor.Finish()
if err != nil {
return err
}
}
case adapter.StartStateStart:
if r.needGeoIPDatabase {
monitor.Start("initialize geoip database")
err := r.prepareGeoIPDatabase()
monitor.Finish()
if err != nil {
return err
}
}
if r.needGeositeDatabase {
monitor.Start("initialize geosite database")
err := r.prepareGeositeDatabase()
monitor.Finish()
if err != nil {
return err
}
}
if r.needGeositeDatabase {
for _, rule := range r.rules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
for _, rule := range r.dnsRules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
err := common.Close(r.geositeReader)
if err != nil {
return err
}
r.geositeCache = nil
r.geositeReader = nil
}
monitor.Start("initialize DNS client")
r.dnsClient.Start()
monitor.Finish()
for i, rule := range r.dnsRules {
monitor.Start("initialize DNS rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS rule[", i, "]")
}
}
for i, transport := range r.transports {
monitor.Start("initialize DNS transport[", i, "]")
err := transport.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS server[", i, "]")
}
}
var cacheContext *adapter.HTTPStartContext
if len(r.ruleSets) > 0 {
monitor.Start("initialize rule-set")
@ -438,7 +162,7 @@ func (r *Router) Start(stage adapter.StartStage) error {
r.started = true
return nil
case adapter.StartStateStarted:
for _, ruleSet := range r.ruleSetMap {
for _, ruleSet := range r.ruleSets {
ruleSet.Cleanup()
}
runtime.GC()
@ -456,34 +180,6 @@ func (r *Router) Close() error {
})
monitor.Finish()
}
for i, rule := range r.dnsRules {
monitor.Start("close dns rule[", i, "]")
err = E.Append(err, rule.Close(), func(err error) error {
return E.Cause(err, "close dns rule[", i, "]")
})
monitor.Finish()
}
for i, transport := range r.transports {
monitor.Start("close dns transport[", i, "]")
err = E.Append(err, transport.Close(), func(err error) error {
return E.Cause(err, "close dns transport[", i, "]")
})
monitor.Finish()
}
if r.geoIPReader != nil {
monitor.Start("close geoip reader")
err = E.Append(err, r.geoIPReader.Close(), func(err error) error {
return E.Cause(err, "close geoip reader")
})
monitor.Finish()
}
if r.fakeIPStore != nil {
monitor.Start("close fakeip store")
err = E.Append(err, r.fakeIPStore.Close(), func(err error) error {
return E.Cause(err, "close fakeip store")
})
monitor.Finish()
}
for i, ruleSet := range r.ruleSets {
monitor.Start("close rule-set[", i, "]")
err = E.Append(err, ruleSet.Close(), func(err error) error {
@ -494,10 +190,6 @@ func (r *Router) Close() error {
return err
}
func (r *Router) FakeIPStore() adapter.FakeIPStore {
return r.fakeIPStore
}
func (r *Router) RuleSet(tag string) (adapter.RuleSet, bool) {
ruleSet, loaded := r.ruleSetMap[tag]
return ruleSet, loaded
@ -517,7 +209,5 @@ func (r *Router) SetTracker(tracker adapter.ConnectionTracker) {
func (r *Router) ResetNetwork() {
r.network.ResetNetwork()
for _, transport := range r.transports {
transport.Reset()
}
r.dns.ResetNetwork()
}

View file

@ -51,18 +51,6 @@ func (r *abstractDefaultRule) Close() error {
return nil
}
func (r *abstractDefaultRule) UpdateGeosite() error {
for _, item := range r.allItems {
if geositeItem, isSite := item.(*GeositeItem); isSite {
err := geositeItem.Update()
if err != nil {
return err
}
}
}
return nil
}
func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
if len(r.allItems) == 0 {
return true
@ -173,19 +161,6 @@ func (r *abstractLogicalRule) Type() string {
return C.RuleTypeLogical
}
func (r *abstractLogicalRule) UpdateGeosite() error {
for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (adapter.Rule, bool) {
rule, loaded := it.(adapter.Rule)
return rule, loaded
}) {
err := rule.UpdateGeosite()
if err != nil {
return err
}
}
return nil
}
func (r *abstractLogicalRule) Start() error {
for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (interface {
Start() error

View file

@ -13,7 +13,6 @@ import (
"github.com/sagernet/sing-box/common/sniff"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
@ -85,7 +84,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
return sniffAction, sniffAction.build()
case C.RuleActionTypeResolve:
return &RuleActionResolve{
Strategy: dns.DomainStrategy(action.ResolveOptions.Strategy),
Strategy: C.DomainStrategy(action.ResolveOptions.Strategy),
Server: action.ResolveOptions.Server,
}, nil
default:
@ -101,6 +100,7 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction)
return &RuleActionDNSRoute{
Server: action.RouteOptions.Server,
RuleActionDNSRouteOptions: RuleActionDNSRouteOptions{
Strategy: C.DomainStrategy(action.RouteOptions.Strategy),
DisableCache: action.RouteOptions.DisableCache,
RewriteTTL: action.RouteOptions.RewriteTTL,
ClientSubnet: netip.Prefix(common.PtrValueOrDefault(action.RouteOptions.ClientSubnet)),
@ -108,6 +108,7 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction)
}
case C.RuleActionTypeRouteOptions:
return &RuleActionDNSRouteOptions{
Strategy: C.DomainStrategy(action.RouteOptionsOptions.Strategy),
DisableCache: action.RouteOptionsOptions.DisableCache,
RewriteTTL: action.RouteOptionsOptions.RewriteTTL,
ClientSubnet: netip.Prefix(common.PtrValueOrDefault(action.RouteOptionsOptions.ClientSubnet)),
@ -214,6 +215,7 @@ func (r *RuleActionDNSRoute) String() string {
}
type RuleActionDNSRouteOptions struct {
Strategy C.DomainStrategy
DisableCache bool
RewriteTTL *uint32
ClientSubnet netip.Prefix
@ -362,7 +364,7 @@ func (r *RuleActionSniff) String() string {
}
type RuleActionResolve struct {
Strategy dns.DomainStrategy
Strategy C.DomainStrategy
Server string
}
@ -371,11 +373,11 @@ func (r *RuleActionResolve) Type() string {
}
func (r *RuleActionResolve) String() string {
if r.Strategy == dns.DomainStrategyAsIS && r.Server == "" {
if r.Strategy == C.DomainStrategyAsIS && r.Server == "" {
return F.ToString("resolve")
} else if r.Strategy != dns.DomainStrategyAsIS && r.Server == "" {
} else if r.Strategy != C.DomainStrategyAsIS && r.Server == "" {
return F.ToString("resolve(", option.DomainStrategy(r.Strategy).String(), ")")
} else if r.Strategy == dns.DomainStrategyAsIS && r.Server != "" {
} else if r.Strategy == C.DomainStrategyAsIS && r.Server != "" {
return F.ToString("resolve(", r.Server, ")")
} else {
return F.ToString("resolve(", option.DomainStrategy(r.Strategy).String(), ",", r.Server, ")")

View file

@ -120,19 +120,13 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio
rule.allItems = append(rule.allItems, item)
}
if len(options.Geosite) > 0 {
item := NewGeositeItem(router, logger, options.Geosite)
rule.destinationAddressItems = append(rule.destinationAddressItems, item)
rule.allItems = append(rule.allItems, item)
return nil, E.New("geosite database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0")
}
if len(options.SourceGeoIP) > 0 {
item := NewGeoIPItem(router, logger, true, options.SourceGeoIP)
rule.sourceAddressItems = append(rule.sourceAddressItems, item)
rule.allItems = append(rule.allItems, item)
return nil, E.New("geoip database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0")
}
if len(options.GeoIP) > 0 {
item := NewGeoIPItem(router, logger, false, options.GeoIP)
rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item)
return nil, E.New("geoip database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0")
}
if len(options.SourceIPCIDR) > 0 {
item, err := NewIPCIDRItem(true, options.SourceIPCIDR)

View file

@ -111,19 +111,13 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op
rule.allItems = append(rule.allItems, item)
}
if len(options.Geosite) > 0 {
item := NewGeositeItem(router, logger, options.Geosite)
rule.destinationAddressItems = append(rule.destinationAddressItems, item)
rule.allItems = append(rule.allItems, item)
return nil, E.New("geosite database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0")
}
if len(options.SourceGeoIP) > 0 {
item := NewGeoIPItem(router, logger, true, options.SourceGeoIP)
rule.sourceAddressItems = append(rule.sourceAddressItems, item)
rule.allItems = append(rule.allItems, item)
return nil, E.New("geoip database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0")
}
if len(options.GeoIP) > 0 {
item := NewGeoIPItem(router, logger, false, options.GeoIP)
rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item)
return nil, E.New("geoip database is deprecated in sing-box 1.8.0 and removed in sing-box 1.12.0")
}
if len(options.SourceIPCIDR) > 0 {
item, err := NewIPCIDRItem(true, options.SourceIPCIDR)
@ -151,6 +145,11 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op
rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item)
}
if options.IPAcceptAny {
item := NewIPAcceptAnyItem()
rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.SourcePort) > 0 {
item := NewPortItem(true, options.SourcePort)
rule.sourcePortItems = append(rule.sourcePortItems, item)

View file

@ -1,98 +0,0 @@
package rule
import (
"net/netip"
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
N "github.com/sagernet/sing/common/network"
)
var _ RuleItem = (*GeoIPItem)(nil)
type GeoIPItem struct {
router adapter.Router
logger log.ContextLogger
isSource bool
codes []string
codeMap map[string]bool
}
func NewGeoIPItem(router adapter.Router, logger log.ContextLogger, isSource bool, codes []string) *GeoIPItem {
codeMap := make(map[string]bool)
for _, code := range codes {
codeMap[code] = true
}
return &GeoIPItem{
router: router,
logger: logger,
codes: codes,
isSource: isSource,
codeMap: codeMap,
}
}
func (r *GeoIPItem) Match(metadata *adapter.InboundContext) bool {
var geoipCode string
if r.isSource && metadata.SourceGeoIPCode != "" {
geoipCode = metadata.SourceGeoIPCode
} else if !r.isSource && metadata.GeoIPCode != "" {
geoipCode = metadata.GeoIPCode
}
if geoipCode != "" {
return r.codeMap[geoipCode]
}
var destination netip.Addr
if r.isSource {
destination = metadata.Source.Addr
} else {
destination = metadata.Destination.Addr
}
if destination.IsValid() {
return r.match(metadata, destination)
}
for _, destinationAddress := range metadata.DestinationAddresses {
if r.match(metadata, destinationAddress) {
return true
}
}
return false
}
func (r *GeoIPItem) match(metadata *adapter.InboundContext, destination netip.Addr) bool {
var geoipCode string
geoReader := r.router.GeoIPReader()
if !N.IsPublicAddr(destination) {
geoipCode = "private"
} else if geoReader != nil {
geoipCode = geoReader.Lookup(destination)
}
if geoipCode == "" {
return false
}
if r.isSource {
metadata.SourceGeoIPCode = geoipCode
} else {
metadata.GeoIPCode = geoipCode
}
return r.codeMap[geoipCode]
}
func (r *GeoIPItem) String() string {
var description string
if r.isSource {
description = "source_geoip="
} else {
description = "geoip="
}
cLen := len(r.codes)
if cLen == 1 {
description += r.codes[0]
} else if cLen > 3 {
description += "[" + strings.Join(r.codes[:3], " ") + "...]"
} else {
description += "[" + strings.Join(r.codes, " ") + "]"
}
return description
}

View file

@ -1,61 +0,0 @@
package rule
import (
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
var _ RuleItem = (*GeositeItem)(nil)
type GeositeItem struct {
router adapter.Router
logger log.ContextLogger
codes []string
matchers []adapter.Rule
}
func NewGeositeItem(router adapter.Router, logger log.ContextLogger, codes []string) *GeositeItem {
return &GeositeItem{
router: router,
logger: logger,
codes: codes,
}
}
func (r *GeositeItem) Update() error {
matchers := make([]adapter.Rule, 0, len(r.codes))
for _, code := range r.codes {
matcher, err := r.router.LoadGeosite(code)
if err != nil {
return E.Cause(err, "read geosite")
}
matchers = append(matchers, matcher)
}
r.matchers = matchers
return nil
}
func (r *GeositeItem) Match(metadata *adapter.InboundContext) bool {
for _, matcher := range r.matchers {
if matcher.Match(metadata) {
return true
}
}
return false
}
func (r *GeositeItem) String() string {
description := "geosite="
cLen := len(r.codes)
if cLen == 1 {
description += r.codes[0]
} else if cLen > 3 {
description += "[" + strings.Join(r.codes[:3], " ") + "...]"
} else {
description += "[" + strings.Join(r.codes, " ") + "]"
}
return description
}

View file

@ -0,0 +1,21 @@
package rule
import (
"github.com/sagernet/sing-box/adapter"
)
var _ RuleItem = (*IPAcceptAnyItem)(nil)
type IPAcceptAnyItem struct{}
func NewIPAcceptAnyItem() *IPAcceptAnyItem {
return &IPAcceptAnyItem{}
}
func (r *IPAcceptAnyItem) Match(metadata *adapter.InboundContext) bool {
return len(metadata.DestinationAddresses) > 0
}
func (r *IPAcceptAnyItem) String() string {
return "ip_accept_any=true"
}

View file

@ -3,7 +3,6 @@ package route
import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
)
func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool {
@ -38,22 +37,6 @@ func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bo
return false
}
func isGeoIPRule(rule option.DefaultRule) bool {
return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode)
}
func isGeoIPDNSRule(rule option.DefaultDNSRule) bool {
return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode)
}
func isGeositeRule(rule option.DefaultRule) bool {
return len(rule.Geosite) > 0
}
func isGeositeDNSRule(rule option.DefaultDNSRule) bool {
return len(rule.Geosite) > 0
}
func isProcessRule(rule option.DefaultRule) bool {
return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.ProcessPathRegex) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0
}
@ -62,10 +45,6 @@ func isProcessDNSRule(rule option.DefaultDNSRule) bool {
return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.ProcessPathRegex) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0
}
func notPrivateNode(code string) bool {
return code != "private"
}
func isWIFIRule(rule option.DefaultRule) bool {
return len(rule.WIFISSID) > 0 || len(rule.WIFIBSSID) > 0
}

View file

@ -6,7 +6,6 @@ import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/json/badoption"
@ -34,7 +33,7 @@ func TestTUICDomainUDP(t *testing.T) {
Listen: common.Ptr(badoption.Addr(netip.IPv4Unspecified())),
ListenPort: serverPort,
InboundOptions: option.InboundOptions{
DomainStrategy: option.DomainStrategy(dns.DomainStrategyUseIPv6),
DomainStrategy: option.DomainStrategy(C.DomainStrategyIPv6Only),
},
},
Users: []option.TUICUser{{

View file

@ -1,95 +0,0 @@
package fakeip
import (
"context"
"net/netip"
"os"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
mDNS "github.com/miekg/dns"
)
var (
_ dns.Transport = (*Transport)(nil)
_ adapter.FakeIPTransport = (*Transport)(nil)
)
func init() {
dns.RegisterTransport([]string{"fakeip"}, func(options dns.TransportOptions) (dns.Transport, error) {
return NewTransport(options)
})
}
type Transport struct {
name string
router adapter.Router
store adapter.FakeIPStore
logger logger.ContextLogger
}
func NewTransport(options dns.TransportOptions) (*Transport, error) {
router := service.FromContext[adapter.Router](options.Context)
if router == nil {
return nil, E.New("missing router in context")
}
return &Transport{
name: options.Name,
router: router,
logger: options.Logger,
}, nil
}
func (s *Transport) Name() string {
return s.name
}
func (s *Transport) Start() error {
s.store = s.router.FakeIPStore()
if s.store == nil {
return E.New("fakeip not enabled")
}
return nil
}
func (s *Transport) Reset() {
}
func (s *Transport) Close() error {
return nil
}
func (s *Transport) Raw() bool {
return false
}
func (s *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
return nil, os.ErrInvalid
}
func (s *Transport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
var addresses []netip.Addr
if strategy != dns.DomainStrategyUseIPv6 {
inet4Address, err := s.store.Create(domain, false)
if err != nil {
return nil, err
}
addresses = append(addresses, inet4Address)
}
if strategy != dns.DomainStrategyUseIPv4 {
inet6Address, err := s.store.Create(domain, true)
if err != nil {
return nil, err
}
addresses = append(addresses, inet6Address)
}
return addresses, nil
}
func (s *Transport) Store() adapter.FakeIPStore {
return s.store
}

View file

@ -164,7 +164,7 @@ func (s *Server) Serve(listener net.Listener) error {
if len(s.tlsConfig.NextProtos()) == 0 {
s.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"})
} else if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
s.tlsConfig.SetNextProtos(append([]string{http2.NextProtoTLS}, s.tlsConfig.NextProtos()...))
}
listener = aTLS.NewListener(listener, s.tlsConfig)
}