refactor/feat: DnsClient as IP resolver in Client

This commit is contained in:
DarkCat09 2024-08-15 19:13:17 +04:00
parent de8333d164
commit 05f22493be
Signed by: DarkCat09
GPG key ID: 0A26CD5B3345D6E3
3 changed files with 102 additions and 9 deletions

View file

@ -7,6 +7,9 @@ use crate::{
Client,
};
#[cfg(feature = "hickory")]
use crate::dns::DnsClient;
use tokio_rustls::rustls::{self, client::danger::ServerCertVerifier, SupportedProtocolVersion};
#[cfg(feature = "webpki")]
@ -18,6 +21,8 @@ pub struct ClientBuilder {
ss_verifier: Option<Box<dyn SelfsignedCertVerifier>>,
custom_verifier: Option<Arc<dyn ServerCertVerifier + 'static>>,
tls_versions: Option<&'static [&'static SupportedProtocolVersion]>,
#[cfg(feature = "hickory")]
dns: Option<DnsClient>,
}
impl Default for ClientBuilder {
@ -36,6 +41,8 @@ impl ClientBuilder {
ss_verifier: None,
custom_verifier: None,
tls_versions: None,
#[cfg(feature = "hickory")]
dns: None,
}
}
@ -92,6 +99,11 @@ impl ClientBuilder {
// TODO
let tls_config = tls_config.with_no_client_auth();
#[cfg(feature = "hickory")]
if let Some(dns) = self.dns {
return Client::from((tls_config, dns));
}
Client::from(tls_config)
}
@ -149,4 +161,10 @@ impl ClientBuilder {
self.custom_verifier = Some(Arc::new(custom_verifier));
self
}
#[cfg(feature = "hickory")]
pub fn with_dns_client(mut self, dns: DnsClient) -> Self {
self.dns = Some(dns);
self
}
}

View file

@ -5,9 +5,18 @@ pub mod response;
pub use response::Response;
#[cfg(feature = "hickory")]
use crate::dns::DnsClient;
#[cfg(feature = "hickory")]
use hickory_client::rr::IntoName;
use crate::{error::*, status::*};
use builder::ClientBuilder;
#[cfg(feature = "hickory")]
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::{
@ -22,6 +31,8 @@ use url::Url;
pub struct Client {
connector: TlsConnector,
#[cfg(feature = "hickory")]
dns: Option<DnsClient>,
}
impl From<rustls::ClientConfig> for Client {
@ -30,6 +41,19 @@ impl From<rustls::ClientConfig> for Client {
fn from(config: rustls::ClientConfig) -> Self {
Client {
connector: TlsConnector::from(Arc::new(config)),
#[cfg(feature = "hickory")]
dns: None,
}
}
}
#[cfg(feature = "hickory")]
impl From<(rustls::ClientConfig, DnsClient)> for Client {
#[inline]
fn from(value: (rustls::ClientConfig, DnsClient)) -> Self {
Client {
connector: TlsConnector::from(Arc::new(value.0)),
dns: Some(value.1),
}
}
}
@ -93,16 +117,11 @@ impl Client {
host: &str,
port: u16,
) -> Result<Response, LibError> {
let addr = tokio::net::lookup_host((host, port))
.await?
.next()
.ok_or(InvalidUrl::ConvertError)?;
let domain = pki_types::ServerName::try_from(host)
.map_err(|_| InvalidUrl::ConvertError)?
.to_owned();
let stream = TcpStream::connect(&addr).await?;
let stream = self.try_connect(host, port).await?;
let mut stream = self.connector.connect(domain, stream).await?;
// Write URL, then CRLF
@ -150,4 +169,59 @@ impl Client {
Ok(Response::new(status, message, stream))
}
pub async fn try_connect(&self, host: &str, port: u16) -> Result<TcpStream, LibError> {
let mut last_err: Option<std::io::Error> = None;
#[cfg(feature = "hickory")]
if let Some(dns) = &self.dns {
let mut dns = dns.clone();
let name = host.into_name()?;
for ip_addr in dns.query_ipv4(name.clone()).await? {
match TcpStream::connect(SocketAddr::new(ip_addr, port)).await {
Ok(stream) => {
return Ok(stream);
}
Err(err) => {
last_err = Some(err);
}
}
}
for ip_addr in dns.query_ipv6(name).await? {
match TcpStream::connect(SocketAddr::new(ip_addr, port)).await {
Ok(stream) => {
return Ok(stream);
}
Err(err) => {
last_err = Some(err);
}
}
}
if let Some(err) = last_err {
return Err(err.into());
}
return Err(LibError::HostLookupError);
}
for addr in tokio::net::lookup_host((host, port)).await? {
match TcpStream::connect(addr).await {
Ok(stream) => {
return Ok(stream);
}
Err(err) => {
last_err = Some(err);
}
}
}
if let Some(err) = last_err {
return Err(err.into());
}
Err(LibError::HostLookupError)
}
}

View file

@ -13,6 +13,7 @@ use tokio::net::ToSocketAddrs;
use crate::{certs::fingerprint::CertFingerprint, LibError};
#[derive(Clone)]
pub struct DnsClient(AsyncClient);
impl DnsClient {
@ -34,14 +35,14 @@ impl DnsClient {
pub async fn query_ipv4(
&mut self,
name: &str,
name: impl IntoName,
) -> Result<impl Iterator<Item = IpAddr>, LibError> {
self.query_ip(name, RecordType::A).await
}
pub async fn query_ipv6(
&mut self,
name: &str,
name: impl IntoName,
) -> Result<impl Iterator<Item = IpAddr>, LibError> {
self.query_ip(name, RecordType::AAAA).await
}
@ -49,7 +50,7 @@ impl DnsClient {
#[inline]
async fn query_ip(
&mut self,
name: &str,
name: impl IntoName,
rtype: RecordType,
) -> Result<impl Iterator<Item = IpAddr>, LibError> {
let answers = self