diff --git a/src/client/mod.rs b/src/client/mod.rs index b49336c..a6c358b 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -25,7 +25,7 @@ use tokio::{ io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}, net::TcpStream, }; -use tokio_rustls::TlsConnector; +use tokio_rustls::{rustls, TlsConnector}; use url::Url; pub struct Client { @@ -101,6 +101,20 @@ impl Client { let stream = self.try_connect(host, port).await?; let mut stream = self.connector.connect(domain, stream).await?; + if let Some(ssv) = &self.ss_verifier { + let cert = stream + .get_ref() + .1 // rustls::ClientConnection + .peer_certificates() + .unwrap() // i think handshake already completed if we awaited on connector.connect? + .first() + .ok_or(rustls::Error::NoCertificatesPresented)?; + + if !ssv.verify(cert, host, port).await? { + return Err(rustls::CertificateError::ApplicationVerificationFailure.into()); + } + } + // Write URL, then CRLF stream.write_all(url_str.as_bytes()).await?; stream.write_all(b"\r\n").await?; diff --git a/src/error.rs b/src/error.rs index 1a0c435..9d96d04 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,11 +1,11 @@ //! Library error structures and enums +use tokio_rustls::rustls; + #[cfg(feature = "hickory")] use hickory_client::{ error::ClientError as HickoryClientError, proto::error::ProtoError as HickoryProtoError, }; -#[cfg(feature = "hickory")] -use tokio::runtime::TryCurrentError; /// Main error structure, also a wrapper for everything else #[derive(Debug)] @@ -15,9 +15,12 @@ pub enum LibError { IoError(std::io::Error), /// URL parse or check error InvalidUrlError(InvalidUrl), - /// DNS provided no suitable records + /// DNS server has provided no suitable records /// (e. g. domain does not exist) HostLookupError, + /// TLS library error related to certificate/signature + /// verification failure or connection failure + RustlsError(rustls::Error), /// Response status code is out of [10; 69] range StatusOutOfRange(u8), /// Response metadata or content cannot be parsed @@ -25,16 +28,9 @@ pub enum LibError { DataNotUtf8(std::string::FromUtf8Error), /// Provided string is not a valid MIME type InvalidMime(mime::FromStrError), - /// Hickory Client error + /// Hickory DNS client error #[cfg(feature = "hickory")] DnsClientError(HickoryClientError), - /// Hickory Proto error - #[cfg(feature = "hickory")] - DnsProtoError(HickoryProtoError), - /// Could not get Tokio runtime handle - /// inside Rustls cert verifier - #[cfg(feature = "hickory")] - NoTokioRuntime(TryCurrentError), } impl From for LibError { @@ -58,6 +54,20 @@ impl From for LibError { } } +impl From for LibError { + #[inline] + fn from(err: rustls::Error) -> Self { + Self::RustlsError(err) + } +} + +impl From for LibError { + #[inline] + fn from(err: rustls::CertificateError) -> Self { + Self::RustlsError(err.into()) + } +} + impl LibError { #[inline] pub fn status_out_of_range(num: u8) -> Self { @@ -91,15 +101,7 @@ impl From for LibError { impl From for LibError { #[inline] fn from(err: HickoryProtoError) -> Self { - Self::DnsProtoError(err) - } -} - -#[cfg(feature = "hickory")] -impl From for LibError { - #[inline] - fn from(err: TryCurrentError) -> Self { - Self::NoTokioRuntime(err) + Self::DnsClientError(err.into()) } }