diff --git a/common/tls/ech_client.go b/common/tls/ech_client.go index 80219350..9c3744b4 100644 --- a/common/tls/ech_client.go +++ b/common/tls/ech_client.go @@ -221,7 +221,7 @@ func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverNam return nil, err } if response.Rcode != mDNS.RcodeSuccess { - return nil, dns.RCodeError(response.Rcode) + return nil, dns.RcodeError(response.Rcode) } for _, rr := range response.Answer { switch resource := rr.(type) { diff --git a/dns/client.go b/dns/client.go index 79b6fce5..d44883b5 100644 --- a/dns/client.go +++ b/dns/client.go @@ -17,7 +17,7 @@ import ( "github.com/sagernet/sing/contrab/freelru" "github.com/sagernet/sing/contrab/maphash" - "github.com/miekg/dns" + dns "github.com/miekg/dns" ) var ( @@ -484,7 +484,7 @@ func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransp func MessageToAddresses(response *dns.Msg) ([]netip.Addr, error) { if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError { - return nil, RCodeError(response.Rcode) + return nil, RcodeError(response.Rcode) } addresses := make([]netip.Addr, 0, len(response.Answer)) for _, rawAnswer := range response.Answer { @@ -508,10 +508,10 @@ func wrapError(err error) error { switch dnsErr := err.(type) { case *net.DNSError: if dnsErr.IsNotFound { - return RCodeNameError + return RcodeNameError } case *net.AddrError: - return RCodeNameError + return RcodeNameError } return err } @@ -561,3 +561,73 @@ func FixedResponse(id uint16, question dns.Question, addresses []netip.Addr, tim } return &response } + +func FixedResponseCNAME(id uint16, question dns.Question, record string, timeToLive uint32) *dns.Msg { + response := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: id, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{question}, + Answer: []dns.RR{ + &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: timeToLive, + }, + Target: record, + }, + }, + } + return &response +} + +func FixedResponseTXT(id uint16, question dns.Question, records []string, timeToLive uint32) *dns.Msg { + response := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: id, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{question}, + Answer: []dns.RR{ + &dns.TXT{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: timeToLive, + }, + Txt: records, + }, + }, + } + return &response +} + +func FixedResponseMX(id uint16, question dns.Question, records []*net.MX, timeToLive uint32) *dns.Msg { + response := dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: id, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{question}, + } + for _, record := range records { + response.Answer = append(response.Answer, &dns.MX{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: timeToLive, + }, + Preference: record.Pref, + Mx: record.Host, + }) + } + return &response +} diff --git a/dns/rcode.go b/dns/rcode.go index 5b7e52cc..08545474 100644 --- a/dns/rcode.go +++ b/dns/rcode.go @@ -1,33 +1,17 @@ package dns -import F "github.com/sagernet/sing/common/format" - -const ( - RCodeSuccess RCodeError = 0 // NoError - RCodeFormatError RCodeError = 1 // FormErr - RCodeServerFailure RCodeError = 2 // ServFail - RCodeNameError RCodeError = 3 // NXDomain - RCodeNotImplemented RCodeError = 4 // NotImp - RCodeRefused RCodeError = 5 // Refused +import ( + mDNS "github.com/miekg/dns" ) -type RCodeError uint16 +const ( + RcodeFormatError RcodeError = mDNS.RcodeFormatError + RcodeNameError RcodeError = mDNS.RcodeNameError + RcodeRefused RcodeError = mDNS.RcodeRefused +) -func (e RCodeError) Error() string { - switch e { - case RCodeSuccess: - return "success" - case RCodeFormatError: - return "format error" - case RCodeServerFailure: - return "server failure" - case RCodeNameError: - return "name error" - case RCodeNotImplemented: - return "not implemented" - case RCodeRefused: - return "refused" - default: - return F.ToString("unknown error: ", uint16(e)) - } +type RcodeError int + +func (e RcodeError) Error() string { + return mDNS.RcodeToString[int(e)] } diff --git a/dns/router.go b/dns/router.go index 4102128e..5db5429a 100644 --- a/dns/router.go +++ b/dns/router.go @@ -329,13 +329,13 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ } } else if len(responseAddrs) == 0 { r.logger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result") - err = RCodeNameError + err = RcodeNameError } } responseAddrs, cached = r.client.LookupCache(domain, options.Strategy) if cached { if len(responseAddrs) == 0 { - return nil, RCodeNameError + return nil, RcodeNameError } return responseAddrs, nil } diff --git a/dns/transport/local/local.go b/dns/transport/local/local.go index 7f02c608..e4236275 100644 --- a/dns/transport/local/local.go +++ b/dns/transport/local/local.go @@ -19,10 +19,6 @@ import ( mDNS "github.com/miekg/dns" ) -func RegisterTransport(registry *dns.TransportRegistry) { - dns.RegisterTransport[option.LocalDNSServerOptions](registry, C.DNSTypeLocal, NewTransport) -} - var _ adapter.DNSTransport = (*Transport)(nil) type Transport struct { diff --git a/dns/transport/local/local_fallback.go b/dns/transport/local/local_fallback.go new file mode 100644 index 00000000..0a7cd9f0 --- /dev/null +++ b/dns/transport/local/local_fallback.go @@ -0,0 +1,201 @@ +package local + +import ( + "context" + "errors" + "net" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/experimental/libbox/platform" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/service" + + mDNS "github.com/miekg/dns" +) + +func RegisterTransport(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.LocalDNSServerOptions](registry, C.DNSTypeLocal, NewFallbackTransport) +} + +type FallbackTransport struct { + adapter.DNSTransport + ctx context.Context + fallback bool + resolver net.Resolver +} + +func NewFallbackTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) { + transport, err := NewTransport(ctx, logger, tag, options) + if err != nil { + return nil, err + } + return &FallbackTransport{ + DNSTransport: transport, + ctx: ctx, + }, nil +} + +func (f *FallbackTransport) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + platformInterface := service.FromContext[platform.Interface](f.ctx) + if platformInterface == nil { + return nil + } + inboundManager := service.FromContext[adapter.InboundManager](f.ctx) + for _, inbound := range inboundManager.Inbounds() { + if inbound.Type() == C.TypeTun { + // platform tun hijacks DNS, so we can only use cgo resolver here + f.fallback = true + break + } + } + return nil +} + +func (f *FallbackTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + if f.fallback { + return f.DNSTransport.Exchange(ctx, message) + } + question := message.Question[0] + domain := dns.FqdnToDomain(question.Name) + if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA { + var network string + if question.Qtype == mDNS.TypeA { + network = "ip4" + } else { + network = "ip6" + } + addresses, err := f.resolver.LookupNetIP(ctx, network, domain) + if err != nil { + var dnsError *net.DNSError + if errors.As(err, &dnsError) && dnsError.IsNotFound { + return nil, dns.RcodeRefused + } + return nil, err + } + return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil + } else if question.Qtype == mDNS.TypeNS { + records, err := f.resolver.LookupNS(ctx, domain) + if err != nil { + var dnsError *net.DNSError + if errors.As(err, &dnsError) && dnsError.IsNotFound { + return nil, dns.RcodeRefused + } + return nil, err + } + response := &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Rcode: mDNS.RcodeSuccess, + Response: true, + }, + Question: []mDNS.Question{question}, + } + for _, record := range records { + response.Answer = append(response.Answer, &mDNS.NS{ + Hdr: mDNS.RR_Header{ + Name: question.Name, + Rrtype: mDNS.TypeNS, + Class: mDNS.ClassINET, + Ttl: C.DefaultDNSTTL, + }, + Ns: record.Host, + }) + } + return response, nil + } else if question.Qtype == mDNS.TypeCNAME { + cname, err := f.resolver.LookupCNAME(ctx, domain) + if err != nil { + var dnsError *net.DNSError + if errors.As(err, &dnsError) && dnsError.IsNotFound { + return nil, dns.RcodeRefused + } + return nil, err + } + return &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Rcode: mDNS.RcodeSuccess, + Response: true, + }, + Question: []mDNS.Question{question}, + Answer: []mDNS.RR{ + &mDNS.CNAME{ + Hdr: mDNS.RR_Header{ + Name: question.Name, + Rrtype: mDNS.TypeCNAME, + Class: mDNS.ClassINET, + Ttl: C.DefaultDNSTTL, + }, + Target: cname, + }, + }, + }, nil + } else if question.Qtype == mDNS.TypeTXT { + records, err := f.resolver.LookupTXT(ctx, domain) + if err != nil { + var dnsError *net.DNSError + if errors.As(err, &dnsError) && dnsError.IsNotFound { + return nil, dns.RcodeRefused + } + return nil, err + } + return &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Rcode: mDNS.RcodeSuccess, + Response: true, + }, + Question: []mDNS.Question{question}, + Answer: []mDNS.RR{ + &mDNS.TXT{ + Hdr: mDNS.RR_Header{ + Name: question.Name, + Rrtype: mDNS.TypeCNAME, + Class: mDNS.ClassINET, + Ttl: C.DefaultDNSTTL, + }, + Txt: records, + }, + }, + }, nil + } else if question.Qtype == mDNS.TypeMX { + records, err := f.resolver.LookupMX(ctx, domain) + if err != nil { + var dnsError *net.DNSError + if errors.As(err, &dnsError) && dnsError.IsNotFound { + return nil, dns.RcodeRefused + } + return nil, err + } + response := &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Rcode: mDNS.RcodeSuccess, + Response: true, + }, + Question: []mDNS.Question{question}, + } + for _, record := range records { + response.Answer = append(response.Answer, &mDNS.MX{ + Hdr: mDNS.RR_Header{ + Name: question.Name, + Rrtype: mDNS.TypeA, + Class: mDNS.ClassINET, + Ttl: C.DefaultDNSTTL, + }, + Preference: record.Pref, + Mx: record.Host, + }) + } + return response, nil + } else { + return nil, E.New("only A, AAAA, NS, CNAME, TXT, MX queries are supported on current platform when using TUN, please switch to a fixed DNS server.") + } +} diff --git a/dns/transport/local/resolv_windows.go b/dns/transport/local/resolv_windows.go index 577e7a12..f6b81090 100644 --- a/dns/transport/local/resolv_windows.go +++ b/dns/transport/local/resolv_windows.go @@ -69,9 +69,6 @@ func dnsReadConfig(_ string) *dnsConfig { return conf } -//go:linkname defaultNS net.defaultNS -var defaultNS []string - func adapterAddresses() ([]*windows.IpAdapterAddresses, error) { var b []byte l := uint32(15000) // recommended initial size diff --git a/dns/transport/predefined.go b/dns/transport/predefined.go index 3f112886..dbb78e5c 100644 --- a/dns/transport/predefined.go +++ b/dns/transport/predefined.go @@ -79,5 +79,5 @@ func (t *PredefinedTransport) Exchange(ctx context.Context, message *mDNS.Msg) ( } } } - return nil, dns.RCodeNameError + return nil, dns.RcodeNameError } diff --git a/experimental/libbox/dns.go b/experimental/libbox/dns.go index a7ccd2a2..7e143442 100644 --- a/experimental/libbox/dns.go +++ b/experimental/libbox/dns.go @@ -134,7 +134,7 @@ func (c *ExchangeContext) RawSuccess(result []byte) { } func (c *ExchangeContext) ErrorCode(code int32) { - c.error = dns.RCodeError(code) + c.error = dns.RcodeError(code) } func (c *ExchangeContext) ErrnoCode(code int32) { diff --git a/protocol/dns/handle.go b/protocol/dns/handle.go index c4ad79d9..765e5051 100644 --- a/protocol/dns/handle.go +++ b/protocol/dns/handle.go @@ -26,7 +26,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.DNSRouter, conn return err } if queryLength == 0 { - return dns.RCodeFormatError + return dns.RcodeFormatError } buffer := buf.NewSize(int(queryLength)) defer buffer.Release() diff --git a/protocol/tailscale/dns_transport.go b/protocol/tailscale/dns_transport.go index 0c83c698..3447b6b2 100644 --- a/protocol/tailscale/dns_transport.go +++ b/protocol/tailscale/dns_transport.go @@ -287,7 +287,7 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M return nil, E.New("missing default resolvers") } } - return nil, dns.RCodeNameError + return nil, dns.RcodeNameError } type DNSDialer struct {