hysteria/extras/outbounds/dns_standard.go
2023-08-13 13:32:11 -07:00

221 lines
5.4 KiB
Go

package outbounds
import (
"crypto/tls"
"net"
"time"
"github.com/miekg/dns"
)
const (
resolverDefaultTimeout = 2 * time.Second
standardResolverRetryTimes = 2
)
// standardResolver is a PluggableOutbound DNS resolver that resolves hostnames
// using the user-provided DNS server.
// Based on "github.com/miekg/dns", it supports UDP, TCP & DNS-over-TLS (TCP).
type standardResolver struct {
Addr string
Client *dns.Client
Next PluggableOutbound
}
func NewStandardResolverUDP(addr string, timeout time.Duration, next PluggableOutbound) PluggableOutbound {
return &standardResolver{
Addr: addDefaultPort(addr),
Client: &dns.Client{
Timeout: timeoutOrDefault(timeout),
},
Next: next,
}
}
func NewStandardResolverTCP(addr string, timeout time.Duration, next PluggableOutbound) PluggableOutbound {
return &standardResolver{
Addr: addDefaultPort(addr),
Client: &dns.Client{
Net: "tcp",
Timeout: timeoutOrDefault(timeout),
},
Next: next,
}
}
func NewStandardResolverTLS(addr string, timeout time.Duration, sni string, insecure bool, next PluggableOutbound) PluggableOutbound {
return &standardResolver{
Addr: addDefaultPortTLS(addr),
Client: &dns.Client{
Net: "tcp-tls",
Timeout: timeoutOrDefault(timeout),
TLSConfig: &tls.Config{
ServerName: sni,
InsecureSkipVerify: insecure,
},
},
Next: next,
}
}
// addDefaultPort adds the default DNS port (53) to the address if not present.
func addDefaultPort(addr string) string {
if _, _, err := net.SplitHostPort(addr); err != nil {
return net.JoinHostPort(addr, "53")
}
return addr
}
// addDefaultPortTLS adds the default DNS-over-TLS port (853) to the address if not present.
func addDefaultPortTLS(addr string) string {
if _, _, err := net.SplitHostPort(addr); err != nil {
return net.JoinHostPort(addr, "853")
}
return addr
}
func timeoutOrDefault(timeout time.Duration) time.Duration {
if timeout == 0 {
return resolverDefaultTimeout
}
return timeout
}
// skipCNAMEChain skips the CNAME chain and returns the last CNAME target.
// Sometimes the DNS server returns a CNAME chain like this, in one packet:
// domain1.com. CNAME domain2.com.
// domain2.com. CNAME domain3.com.
// In this case, we should avoid sending a query for domain2.com and go
// straight to domain3.com.
func (r *standardResolver) skipCNAMEChain(answers []dns.RR) string {
var lastCNAME string
for _, a := range answers {
if cname, ok := a.(*dns.CNAME); ok {
if lastCNAME == "" {
// First CNAME
lastCNAME = cname.Target
} else if cname.Hdr.Name == lastCNAME {
// CNAME chain
lastCNAME = cname.Target
} else {
// CNAME chain ends
return lastCNAME
}
}
}
return lastCNAME
}
// lookup4 resolves a hostname to an IPv4 address.
// If there's no IPv4 address, it returns (nil, nil), no error.
func (r *standardResolver) lookup4(host string) (net.IP, error) {
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(host), dns.TypeA)
m.RecursionDesired = true
resp, _, err := r.Client.Exchange(m, r.Addr)
if err != nil {
return nil, err
}
if len(resp.Answer) == 0 {
return nil, nil
}
// Sometimes the DNS server returns both CNAME and A records in one packet.
hasCNAME := false
for _, a := range resp.Answer {
if aa, ok := a.(*dns.A); ok {
return aa.A.To4(), nil
} else if _, ok := a.(*dns.CNAME); ok {
hasCNAME = true
}
}
if hasCNAME {
return r.lookup4(r.skipCNAMEChain(resp.Answer))
} else {
// Should not happen
return nil, nil
}
}
// lookup6 resolves a hostname to an IPv6 address.
// If there's no IPv6 address, it returns (nil, nil), no error.
func (r *standardResolver) lookup6(host string) (net.IP, error) {
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
m.RecursionDesired = true
resp, _, err := r.Client.Exchange(m, r.Addr)
if err != nil {
return nil, err
}
if len(resp.Answer) == 0 {
return nil, nil
}
// Sometimes the DNS server returns both CNAME and AAAA records in one packet.
hasCNAME := false
for _, a := range resp.Answer {
if aa, ok := a.(*dns.AAAA); ok {
return aa.AAAA.To16(), nil
} else if _, ok := a.(*dns.CNAME); ok {
hasCNAME = true
}
}
if hasCNAME {
return r.lookup6(r.skipCNAMEChain(resp.Answer))
} else {
// Should not happen
return nil, nil
}
}
func (r *standardResolver) resolve(reqAddr *AddrEx) {
if tryParseIP(reqAddr) {
// The host is already an IP address, we don't need to resolve it.
return
}
type lookupResult struct {
ip net.IP
err error
}
ch4, ch6 := make(chan lookupResult, 1), make(chan lookupResult, 1)
go func() {
var ip net.IP
var err error
for i := 0; i < standardResolverRetryTimes; i++ {
ip, err = r.lookup4(reqAddr.Host)
if err == nil {
break
}
}
ch4 <- lookupResult{ip, err}
}()
go func() {
var ip net.IP
var err error
for i := 0; i < standardResolverRetryTimes; i++ {
ip, err = r.lookup6(reqAddr.Host)
if err == nil {
break
}
}
ch6 <- lookupResult{ip, err}
}()
result4, result6 := <-ch4, <-ch6
reqAddr.ResolveInfo = &ResolveInfo{
IPv4: result4.ip,
IPv6: result6.ip,
}
if result4.err != nil {
reqAddr.ResolveInfo.Err = result4.err
} else if result6.err != nil {
reqAddr.ResolveInfo.Err = result6.err
}
}
func (r *standardResolver) TCP(reqAddr *AddrEx) (net.Conn, error) {
r.resolve(reqAddr)
return r.Next.TCP(reqAddr)
}
func (r *standardResolver) UDP(reqAddr *AddrEx) (UDPConn, error) {
r.resolve(reqAddr)
return r.Next.UDP(reqAddr)
}