diff --git a/src/client/builder.rs b/src/client/builder.rs index 4725ad9..d49e863 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -7,6 +7,9 @@ use crate::{ Client, }; +#[cfg(feature = "hickory")] +use crate::dns::DnsClient; + use tokio_rustls::rustls::{self, client::danger::ServerCertVerifier, SupportedProtocolVersion}; #[cfg(feature = "webpki")] @@ -18,6 +21,8 @@ pub struct ClientBuilder { ss_verifier: Option>, custom_verifier: Option>, tls_versions: Option<&'static [&'static SupportedProtocolVersion]>, + #[cfg(feature = "hickory")] + dns: Option, } impl Default for ClientBuilder { @@ -36,6 +41,8 @@ impl ClientBuilder { ss_verifier: None, custom_verifier: None, tls_versions: None, + #[cfg(feature = "hickory")] + dns: None, } } @@ -92,6 +99,11 @@ impl ClientBuilder { // TODO let tls_config = tls_config.with_no_client_auth(); + #[cfg(feature = "hickory")] + if let Some(dns) = self.dns { + return Client::from((tls_config, dns)); + } + Client::from(tls_config) } @@ -149,4 +161,10 @@ impl ClientBuilder { self.custom_verifier = Some(Arc::new(custom_verifier)); self } + + #[cfg(feature = "hickory")] + pub fn with_dns_client(mut self, dns: DnsClient) -> Self { + self.dns = Some(dns); + self + } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 3f87b88..d6fe5be 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -5,9 +5,18 @@ pub mod response; pub use response::Response; +#[cfg(feature = "hickory")] +use crate::dns::DnsClient; + +#[cfg(feature = "hickory")] +use hickory_client::rr::IntoName; + use crate::{error::*, status::*}; use builder::ClientBuilder; +#[cfg(feature = "hickory")] +use std::net::SocketAddr; + use std::sync::Arc; use tokio::{ @@ -22,6 +31,8 @@ use url::Url; pub struct Client { connector: TlsConnector, + #[cfg(feature = "hickory")] + dns: Option, } impl From for Client { @@ -30,6 +41,19 @@ impl From for Client { fn from(config: rustls::ClientConfig) -> Self { Client { connector: TlsConnector::from(Arc::new(config)), + #[cfg(feature = "hickory")] + dns: None, + } + } +} + +#[cfg(feature = "hickory")] +impl From<(rustls::ClientConfig, DnsClient)> for Client { + #[inline] + fn from(value: (rustls::ClientConfig, DnsClient)) -> Self { + Client { + connector: TlsConnector::from(Arc::new(value.0)), + dns: Some(value.1), } } } @@ -93,16 +117,11 @@ impl Client { host: &str, port: u16, ) -> Result { - let addr = tokio::net::lookup_host((host, port)) - .await? - .next() - .ok_or(InvalidUrl::ConvertError)?; - let domain = pki_types::ServerName::try_from(host) .map_err(|_| InvalidUrl::ConvertError)? .to_owned(); - let stream = TcpStream::connect(&addr).await?; + let stream = self.try_connect(host, port).await?; let mut stream = self.connector.connect(domain, stream).await?; // Write URL, then CRLF @@ -150,4 +169,59 @@ impl Client { Ok(Response::new(status, message, stream)) } + + pub async fn try_connect(&self, host: &str, port: u16) -> Result { + let mut last_err: Option = None; + + #[cfg(feature = "hickory")] + if let Some(dns) = &self.dns { + let mut dns = dns.clone(); + let name = host.into_name()?; + + for ip_addr in dns.query_ipv4(name.clone()).await? { + match TcpStream::connect(SocketAddr::new(ip_addr, port)).await { + Ok(stream) => { + return Ok(stream); + } + Err(err) => { + last_err = Some(err); + } + } + } + + for ip_addr in dns.query_ipv6(name).await? { + match TcpStream::connect(SocketAddr::new(ip_addr, port)).await { + Ok(stream) => { + return Ok(stream); + } + Err(err) => { + last_err = Some(err); + } + } + } + + if let Some(err) = last_err { + return Err(err.into()); + } + + return Err(LibError::HostLookupError); + } + + for addr in tokio::net::lookup_host((host, port)).await? { + match TcpStream::connect(addr).await { + Ok(stream) => { + return Ok(stream); + } + Err(err) => { + last_err = Some(err); + } + } + } + + if let Some(err) = last_err { + return Err(err.into()); + } + + Err(LibError::HostLookupError) + } } diff --git a/src/dns.rs b/src/dns.rs index ce74922..95fa74f 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -13,6 +13,7 @@ use tokio::net::ToSocketAddrs; use crate::{certs::fingerprint::CertFingerprint, LibError}; +#[derive(Clone)] pub struct DnsClient(AsyncClient); impl DnsClient { @@ -34,14 +35,14 @@ impl DnsClient { pub async fn query_ipv4( &mut self, - name: &str, + name: impl IntoName, ) -> Result, LibError> { self.query_ip(name, RecordType::A).await } pub async fn query_ipv6( &mut self, - name: &str, + name: impl IntoName, ) -> Result, LibError> { self.query_ip(name, RecordType::AAAA).await } @@ -49,7 +50,7 @@ impl DnsClient { #[inline] async fn query_ip( &mut self, - name: &str, + name: impl IntoName, rtype: RecordType, ) -> Result, LibError> { let answers = self