From e042a139bfdef379b39b85a242f702cf971a2f06 Mon Sep 17 00:00:00 2001 From: DarkCat09 Date: Wed, 21 Aug 2024 15:46:52 +0400 Subject: [PATCH] refactor: rewrite self-signed cert verifier --- Cargo.lock | 1 + Cargo.toml | 1 + examples/main.rs | 109 +++++++++--- examples/simple.rs | 13 +- src/certs/{insecure.rs => allow_all.rs} | 29 ++-- src/certs/dane.rs | 54 ++++++ src/certs/file_sscv.rs | 214 +++++++----------------- src/certs/fingerprint.rs | 11 +- src/certs/mod.rs | 24 +-- src/client/builder.rs | 63 +++---- src/client/mod.rs | 47 ++---- 11 files changed, 294 insertions(+), 272 deletions(-) rename src/certs/{insecure.rs => allow_all.rs} (78%) create mode 100644 src/certs/dane.rs diff --git a/Cargo.lock b/Cargo.lock index 5b239bc..a3bbbe5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -879,6 +879,7 @@ dependencies = [ name = "tokio-gemini" version = "0.4.0" dependencies = [ + "async-trait", "base16ct", "base64ct", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 39ab814..2fefd57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ tokio-rustls = { version = "0.26.0", default-features = false, features = ["ring dashmap = { version = "6.0.1", optional = true } hickory-client = { version = "0.24.1", optional = true } +async-trait = "0.1.81" [dev-dependencies] tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] } diff --git a/examples/main.rs b/examples/main.rs index 35ff328..b1172ca 100644 --- a/examples/main.rs +++ b/examples/main.rs @@ -1,11 +1,13 @@ use tokio_gemini::{ - certs::{file_sscv::FileBasedCertVerifier, insecure::AllowAllCertVerifier}, + certs::{ + dane, file_sscv::KnownHostsFile, fingerprint::CertFingerprint, SelfsignedCertVerifier, + }, dns::DnsClient, Client, LibError, }; // -// cargo add tokio_gemini -F file-sscv,hickory +// cargo add tokio_gemini -F file-sscv,dane // cargo add tokio -F macros,rt-multi-thread,io-util,fs // @@ -13,12 +15,6 @@ const USAGE: &str = "-k\t\tinsecure mode (trust all certs) -d \tuse custom DNS for resolving & DANE -h\t\tshow help"; -struct Config { - insecure: bool, - dns: Option, - url: String, -} - #[tokio::main] async fn main() -> Result<(), LibError> { let config = parse_args(); @@ -68,11 +64,11 @@ fn parse_args() -> Config { "-k" => config.insecure = true, "-d" => expected_dns = true, "-h" => { - println!("{}", USAGE); + eprintln!("{}", USAGE); std::process::exit(0); } url => { - println!("URL: {}", url); + eprintln!("URL: {}", url); config.url = url.to_owned(); break; } @@ -80,7 +76,7 @@ fn parse_args() -> Config { } if expected_dns { - println!("{}", USAGE); + eprintln!("{}", USAGE); std::process::exit(0); } @@ -94,21 +90,94 @@ async fn build_client(config: &Config) -> Result { None }; + let known_hosts = KnownHostsFile::parse_file("known_hosts").await?; + let verifier = CertVerifier { + known_hosts, + dns: dns.clone(), + }; + let client = tokio_gemini::Client::builder(); let client = if config.insecure { - client.with_custom_verifier(AllowAllCertVerifier::yes_i_know_what_i_am_doing()) + client.dangerous_with_no_verifier() } else { - client.with_selfsigned_cert_verifier( - FileBasedCertVerifier::init("known_hosts", dns.clone()).await?, - ) + client.with_selfsigned_cert_verifier(verifier) }; - let client = if let Some(dns) = dns { - client.with_dns_client(dns) - } else { - client - }; + let client = client.maybe_with_dns_client(dns); Ok(client.build()) } + +struct Config { + insecure: bool, + dns: Option, + url: String, +} + +struct CertVerifier { + known_hosts: KnownHostsFile, + dns: Option, +} + +impl SelfsignedCertVerifier for CertVerifier { + async fn verify<'c>( + &self, + cert: &'c tokio_gemini::certs::CertificateDer<'c>, + host: &str, + port: u16, + ) -> Result { + if let Some(known) = self.known_hosts.get_known_cert(host) { + // if found in known_hosts, just compare certs + Ok(known.fingerprint.hash_and_compare(cert)) + } else { + // otherwise, generate a hash and add to known_hosts + let hash = if let Some(dns) = &self.dns { + // if DNS client is configured, try verifying the cert + // via DANE instead of blindly trusting + match dane::dane(&dns, cert, host, port).await { + Ok(hash) => hash, // use the fingerprint matched with TLSA record + Err(LibError::HostLookupError) => { + // no TLSA record found -- server admin haven't set it up + eprintln!( + "TLSA not configured for tcp:{}:{}, trusting on first use", + host, port, + ); + // just generate a hash for this cert + CertFingerprint::new_sha256(cert) + } + Err(e) => { + // some other problem (e.g. DNS server rejected the request), + // we shouldn't continue + eprintln!("DANE verification failed: {:?}", e); + return Err(e); + } + } + } else { + eprintln!("DANE disabled"); + // just generate a hash for this cert + CertFingerprint::new_sha256(cert) + }; + + let fingerprint = hash.base64(); + let fptype = hash.fingerprint_type_str(); + eprintln!( + "Warning: adding trusted cert for {} with FP {}", + host, &fingerprint, + ); + + // adding the cert hash to trusted + // can be done simplier: + // self.known_hosts.add_trusted_cert(host, hash).await.unwrap... + self.known_hosts.add_cert_to_hashmap(host, hash); + self.known_hosts + .add_cert_to_file(host, &fingerprint, fptype) + .await + .unwrap_or_else(|e| { + eprintln!("Cert saved in-memory, unable to write to file: {:?}", e); + }); + + Ok(true) + } + } +} diff --git a/examples/simple.rs b/examples/simple.rs index b652577..dab66cd 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -33,17 +33,18 @@ async fn main() -> Result<(), LibError> { struct CertVerifier; impl SelfsignedCertVerifier for CertVerifier { - fn verify( + async fn verify<'c>( &self, - cert: &tokio_gemini::certs::CertificateDer, + cert: &'c tokio_gemini::certs::CertificateDer<'c>, host: &str, - _now: tokio_gemini::certs::UnixTime, - ) -> Result { - // For real verification example with known_hosts file + port: u16, + ) -> Result { + // For real verification example with known_hosts and DANE // see examples/main.rs eprintln!( - "Host = {}\nFingerprint = {}", + "Host = {}:{}\nFingerprint = {}", host, + port, CertFingerprint::new_sha256(cert).base64(), ); Ok(true) diff --git a/src/certs/insecure.rs b/src/certs/allow_all.rs similarity index 78% rename from src/certs/insecure.rs rename to src/certs/allow_all.rs index 8322c65..d204116 100644 --- a/src/certs/insecure.rs +++ b/src/certs/allow_all.rs @@ -1,6 +1,8 @@ //! Custom verifier for Rustls accepting any TLS cert //! (usually called "insecure mode") +use std::sync::Arc; + use tokio_rustls::rustls::{ self, client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, @@ -9,28 +11,28 @@ use tokio_rustls::rustls::{ /// Custom verifier for Rustls accepting any TLS certificate #[derive(Debug)] -pub struct AllowAllCertVerifier(std::sync::Arc); +pub struct AllowAllCertVerifier(Arc); impl AllowAllCertVerifier { - /// Constructor for this verifier. - /// Use only if you know what you are doing. - /// - /// # Examples - /// ``` - /// let client = tokio_gemini::Client::builder() - /// .with_custom_verifier(AllowAllCertVerifier::yes_i_know_what_i_am_doing()) - /// .build() - /// ``` - pub fn yes_i_know_what_i_am_doing() -> Self { + /// Constructor for this verifier + pub fn new() -> Self { AllowAllCertVerifier( CryptoProvider::get_default() .cloned() - .unwrap_or_else(|| std::sync::Arc::new(rustls::crypto::ring::default_provider())), + .unwrap_or_else(|| Arc::new(rustls::crypto::ring::default_provider())), ) } } +impl From> for AllowAllCertVerifier { + #[inline] + fn from(value: Arc) -> Self { + AllowAllCertVerifier(value) + } +} + impl ServerCertVerifier for AllowAllCertVerifier { + #[inline] fn verify_server_cert( &self, _end_entity: &rustls::pki_types::CertificateDer<'_>, @@ -42,6 +44,7 @@ impl ServerCertVerifier for AllowAllCertVerifier { Ok(ServerCertVerified::assertion()) } + #[inline] fn verify_tls12_signature( &self, message: &[u8], @@ -56,6 +59,7 @@ impl ServerCertVerifier for AllowAllCertVerifier { ) } + #[inline] fn verify_tls13_signature( &self, message: &[u8], @@ -70,6 +74,7 @@ impl ServerCertVerifier for AllowAllCertVerifier { ) } + #[inline] fn supported_verify_schemes(&self) -> Vec { self.0.signature_verification_algorithms.supported_schemes() } diff --git a/src/certs/dane.rs b/src/certs/dane.rs new file mode 100644 index 0000000..9eccbb8 --- /dev/null +++ b/src/certs/dane.rs @@ -0,0 +1,54 @@ +use crate::{ + certs::{fingerprint::CertFingerprint, CertificateDer}, + dns::DnsClient, + LibError, +}; + +pub async fn dane<'d>( + dns: &DnsClient, + cert: &CertificateDer<'d>, + host: &str, + port: u16, +) -> Result { + let mut dns = dns.clone(); + + let mut sha256: Option = None; + let mut sha512: Option = None; + + for tlsa_fp in dns.query_tlsa(host, port).await? { + match tlsa_fp { + CertFingerprint::Sha256(_) => { + if sha256.is_none() { + sha256 = Some(CertFingerprint::new_sha256(cert)); + } + let this_fp = sha256.as_ref().unwrap(); + if this_fp == &tlsa_fp { + return Ok(sha256.unwrap()); + } + } + CertFingerprint::Sha512(_) => { + if sha512.is_none() { + sha512 = Some(CertFingerprint::new_sha512(cert)); + } + let this_fp = sha512.as_ref().unwrap(); + if this_fp == &tlsa_fp { + return Ok(sha512.unwrap()); + } + } + CertFingerprint::Raw(_) => { + let this_fp = CertFingerprint::new_raw(cert); + if this_fp == tlsa_fp { + return Ok(CertFingerprint::new_sha256(cert)); + } + } + } + } + + if let Some(sha256) = sha256 { + Ok(sha256) + } else if let Some(sha512) = sha512 { + Ok(sha512) + } else { + Ok(CertFingerprint::new_sha256(cert)) + } +} diff --git a/src/certs/file_sscv.rs b/src/certs/file_sscv.rs index 041e7ea..31944e9 100644 --- a/src/certs/file_sscv.rs +++ b/src/certs/file_sscv.rs @@ -1,35 +1,25 @@ -use std::{ - borrow::Cow, - io::{BufWriter, Write}, - os::fd::AsFd, - path::Path, - sync::Mutex, -}; +use std::{borrow::Cow, os::fd::AsFd, path::Path, sync::Mutex}; use dashmap::DashMap; -use tokio::io::AsyncBufReadExt; -use tokio_rustls::rustls::pki_types::{CertificateDer, UnixTime}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufWriter}; -use crate::{ - certs::{fingerprint::CertFingerprint, SelfsignedCert, SelfsignedCertVerifier}, - LibError, -}; +use crate::certs::{fingerprint::CertFingerprint, SelfsignedCert}; -#[cfg(feature = "hickory")] -use crate::dns::DnsClient; - -pub struct FileBasedCertVerifier { +pub struct KnownHostsFile { fd: Mutex, map: DashMap, - #[cfg(feature = "hickory")] - dns: Option, } -impl FileBasedCertVerifier { - pub async fn init( - path: impl AsRef, - #[cfg(feature = "hickory")] dns: Option, - ) -> Result { +impl KnownHostsFile { + /// Read or create a known_hosts file at the given path. + /// Format of known_hosts is: _(fields are separated by any space sequence)_ + /// ``` + /// #host expires algo base64 fingerprint + /// dc09.ru 1722930541 sha512 dGVzdHRlc3R0ZXN0Cg + /// ``` + pub async fn parse_file(path: impl AsRef) -> std::io::Result { + // TODO: remove eprintln!()s + let map = DashMap::new(); if tokio::fs::try_exists(&path).await? { @@ -45,10 +35,6 @@ impl FileBasedCertVerifier { break; } - // Format: - // host expires hash-algo fingerprint - // Example: - // dc09.ru 1722930541 sha512 dGVzdHRlc3R0ZXN0Cg if let [host, expires, algo, fp] = *buf.split_whitespace().take(4).collect::>() { @@ -91,34 +77,42 @@ impl FileBasedCertVerifier { } let fd = Mutex::new( - std::fs::OpenOptions::new() + tokio::fs::OpenOptions::new() .append(true) .create(true) - .open(path)? + .open(path) + .await? .as_fd() .try_clone_to_owned()?, ); - Ok(FileBasedCertVerifier { - fd, - map, - #[cfg(feature = "hickory")] - dns, - }) + Ok(KnownHostsFile { fd, map }) } -} -impl FileBasedCertVerifier { - pub fn add_trusted_cert( + pub fn get_known_cert( &self, host: &str, - hash: CertFingerprint, - ) -> Result<(), std::io::Error> { - let fp = hash.base64(); - let ft = hash.fingerprint_type_str(); - // TODO: remove eprintln!() - eprintln!("Warning: adding {} cert with FP {}", &host, &fp); + ) -> Option> { + self.map.get(host) + } + /// Add a new entry to known_hosts (both to the in-memory hashmap and to the file). + /// Should be called only after the TLS cert is verified for validity (including DANE). + /// Exactly the same as `add_cert_to_hashmap` + `add_cert_to_file`. + pub async fn add_trusted_cert(&self, host: &str, hash: CertFingerprint) -> std::io::Result<()> { + let fp = hash.base64(); + let fptype = hash.fingerprint_type_str(); + + self.add_cert_to_hashmap(host, hash); + + self.add_cert_to_file(host, &fp, fptype).await?; + + Ok(()) + } + + /// Add a new trusted cert only to the in-memory hashmap, + /// do not write to the known_hosts file. + pub fn add_cert_to_hashmap(&self, host: &str, hash: CertFingerprint) { self.map.insert( host.to_owned(), SelfsignedCert { @@ -126,117 +120,35 @@ impl FileBasedCertVerifier { expires: 0, // TODO after implementing cert parsing in tokio-gemini }, ); + } + /// Write a new trusted cert's fingerprint to the known_hosts file. + /// - `fp` is a TLS cert hash in base64 (see [`CertFingerprint::base64`]), + /// - `fptype` is a name of hashing algorithm (see [`CertFingerprint::fingerprint_type_str`]). + /// The certificate will not be trusted in the current session unless you call `add_cert_to_hashmap`, + /// so use this function only if you need modularity, otherwise just use `add_trusted_cert`. + pub async fn add_cert_to_file( + &self, + host: &str, + fp: &str, + fptype: &str, + ) -> std::io::Result<()> { // trick with cloning file descriptor // because we are not allowed to mutate &self - let f = std::fs::File::from(self.fd.lock().unwrap().try_clone()?); + let f = tokio::fs::File::from(std::fs::File::from(self.fd.lock().unwrap().try_clone()?)); let mut bw = BufWriter::new(f); - bw.write_all(host.as_bytes())?; - bw.write_all(b" 0 ")?; // TODO after implementing `expires` - bw.write_all(ft.as_bytes())?; - bw.write_all(b" ")?; - bw.write_all(fp.as_bytes())?; - bw.write_all(b"\n")?; - bw.flush()?; + + bw.write_all(host.as_bytes()).await?; + bw.write_all(b" 0 ").await?; // TODO after implementing `expires` + + bw.write_all(fptype.as_bytes()).await?; + bw.write_all(b" ").await?; + + bw.write_all(fp.as_bytes()).await?; + bw.write_all(b"\n").await?; + + bw.flush().await?; Ok(()) } - - #[cfg(feature = "hickory")] - pub fn dane( - &self, - cert: &CertificateDer, - host: &str, - port: u16, - ) -> Result { - let mut dns = if let Some(dns) = &self.dns { - dns.clone() - } else { - return Err(LibError::HostLookupError); - }; - - let rt = tokio::runtime::Handle::try_current()?; - - let mut sha256: Option = None; - let mut sha512: Option = None; - - for tlsa_fp in rt.block_on(dns.query_tlsa(host, port))? { - match tlsa_fp { - CertFingerprint::Sha256(_) => { - if sha256.is_none() { - sha256 = Some(CertFingerprint::new_sha256(cert)); - } - let this_fp = sha256.as_ref().unwrap(); - if this_fp == &tlsa_fp { - return Ok(sha256.unwrap()); - } - } - CertFingerprint::Sha512(_) => { - if sha512.is_none() { - sha512 = Some(CertFingerprint::new_sha512(cert)); - } - let this_fp = sha512.as_ref().unwrap(); - if this_fp == &tlsa_fp { - return Ok(sha512.unwrap()); - } - } - CertFingerprint::Raw(_) => { - let this_fp = CertFingerprint::new_raw(cert); - if this_fp == tlsa_fp { - return Ok(CertFingerprint::new_sha256(cert)); - } - } - } - } - - if let Some(sha256) = sha256 { - Ok(sha256) - } else if let Some(sha512) = sha512 { - Ok(sha512) - } else { - Ok(CertFingerprint::new_sha256(cert)) - } - } -} - -impl SelfsignedCertVerifier for FileBasedCertVerifier { - fn verify( - &self, - cert: &CertificateDer, - host: &str, - _now: UnixTime, - ) -> Result { - // - // TODO: remove eprintln!()s - // - - if let Some(known_cert) = self.map.get(host) { - // if host is found in known_hosts, compare certs - let this_hash = match known_cert.fingerprint { - CertFingerprint::Sha256(_) => CertFingerprint::new_sha256(cert), - CertFingerprint::Sha512(_) => CertFingerprint::new_sha512(cert), - CertFingerprint::Raw(_) => CertFingerprint::new_raw(cert), - }; - Ok(this_hash == known_cert.fingerprint) - } else { - // host is unknown, generate hash and add to known_hosts - #[cfg(feature = "hickory")] - let this_hash = match self.dane(cert, host, 1965) { - Ok(hash) => hash, - Err(e) => { - eprintln!("DANE verification failed: {:?}", e); - CertFingerprint::new_sha256(cert) - } - }; - - #[cfg(not(feature = "hickory"))] - let this_hash = CertFingerprint::new_sha256(cert); - - self.add_trusted_cert(host, this_hash).unwrap_or_else(|e| { - eprintln!("Unable to add new cert: {:?}", e); - }); - - Ok(true) - } - } } diff --git a/src/certs/fingerprint.rs b/src/certs/fingerprint.rs index 77197d7..ef72936 100644 --- a/src/certs/fingerprint.rs +++ b/src/certs/fingerprint.rs @@ -8,7 +8,7 @@ pub use sha2::{Digest, Sha256, Sha512}; use base16ct::upper as b16; use base64ct::{Base64Unpadded as b64, Encoding}; -use super::verifier::CertificateDer; +use super::CertificateDer; pub const SHA256_LEN: usize = 32; // 256 / 8 pub const SHA512_LEN: usize = 64; // 512 / 8 @@ -127,6 +127,15 @@ impl CertFingerprint { } impl CertFingerprint { + pub fn hash_and_compare(&self, cert: &CertificateDer) -> bool { + let hash = match self { + Self::Sha256(_) => Self::new_sha256(cert), + Self::Sha512(_) => Self::new_sha512(cert), + Self::Raw(_) => Self::new_raw(cert), + }; + *self == hash + } + pub fn fingerprint_type_str(&self) -> &'static str { match self { Self::Sha256(_) => "sha256", diff --git a/src/certs/mod.rs b/src/certs/mod.rs index fafe98f..bb728ed 100644 --- a/src/certs/mod.rs +++ b/src/certs/mod.rs @@ -1,27 +1,29 @@ //! Everything related to TLS certs verification +pub mod allow_all; pub mod fingerprint; -pub mod insecure; #[cfg(feature = "file-sscv")] pub mod file_sscv; -pub(crate) mod verifier; +#[cfg(feature = "hickory")] +pub mod dane; +use async_trait::async_trait; pub use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, UnixTime}; -use tokio_rustls::rustls; - -/// Trait for implementing self-signed cert verifiers -/// like [`file_sscv::FileBasedCertVerifier`] -/// (probably via known_hosts with TOFU policy or DANE verification) +/// Trait for implementing self-signed cert verifiers, +/// probably via known_hosts with TOFU policy or DANE verification. +/// It is recommended to use helpers from file_sscv. +#[async_trait] pub trait SelfsignedCertVerifier: Send + Sync { - fn verify( + async fn verify<'c>( &self, - cert: &CertificateDer, + cert: &'c CertificateDer<'c>, host: &str, - now: UnixTime, - ) -> Result; + port: u16, + // now: UnixTime, + ) -> Result; } /// Structure holding a cert fingerprint and expiry date, diff --git a/src/client/builder.rs b/src/client/builder.rs index bc4eebb..7178cce 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -3,20 +3,21 @@ use std::sync::Arc; use crate::{ - certs::{verifier::CustomCertVerifier, SelfsignedCertVerifier}, + certs::{allow_all::AllowAllCertVerifier, SelfsignedCertVerifier}, Client, }; #[cfg(feature = "hickory")] use crate::dns::DnsClient; -use tokio_rustls::rustls::{self, client::danger::ServerCertVerifier, SupportedProtocolVersion}; +use tokio_rustls::{ + rustls::{self, SupportedProtocolVersion}, + TlsConnector, +}; /// Builder for creating configured [`Client`] instance pub struct ClientBuilder { - root_certs: rustls::RootCertStore, - ss_verifier: Option>, - custom_verifier: Option>, + ss_verifier: Option>, tls_versions: Option<&'static [&'static SupportedProtocolVersion]>, #[cfg(feature = "hickory")] dns: Option, @@ -34,9 +35,7 @@ impl ClientBuilder { /// no cert verifiers and default TLS versions. pub fn new() -> Self { ClientBuilder { - root_certs: rustls::RootCertStore::empty(), ss_verifier: None, - custom_verifier: None, tls_versions: None, #[cfg(feature = "hickory")] dns: None, @@ -55,30 +54,19 @@ impl ClientBuilder { } else { rustls::DEFAULT_VERSIONS }) - .unwrap(); - - let tls_config = if let Some(cv) = self.custom_verifier { - tls_config.dangerous().with_custom_certificate_verifier(cv) - } else if let Some(ssv) = self.ss_verifier { - tls_config - .dangerous() - .with_custom_certificate_verifier(Arc::new(CustomCertVerifier { - provider: provider.clone(), - ss_verifier: ssv, - })) - } else { - tls_config.with_root_certificates(self.root_certs) - }; + .unwrap() + .dangerous() + .with_custom_certificate_verifier(Arc::new(AllowAllCertVerifier::from(provider))); // TODO let tls_config = tls_config.with_no_client_auth(); - #[cfg(feature = "hickory")] - if let Some(dns) = self.dns { - return Client::from((tls_config, dns)); + Client { + connector: TlsConnector::from(Arc::new(tls_config)), + ss_verifier: self.ss_verifier, + #[cfg(feature = "hickory")] + dns: self.dns, } - - Client::from(tls_config) } /// Limit the supported TLS versions list to the specified ones. @@ -99,25 +87,28 @@ impl ClientBuilder { mut self, ss_verifier: impl SelfsignedCertVerifier + 'static, ) -> Self { - self.ss_verifier = Some(Box::new(ss_verifier)); + self.ss_verifier = Some(Arc::new(ss_verifier)); self } - /// Include a custom TLS cert verifier implementing rustls' [`ServerCertVerifier`]. - /// Normally need to be used only for [`crate::certs::insecure::AllowAllCertVerifier`]. - /// Note: the webpki verifier and a self-signed cert verifier are not called - /// when a custom verifier is set. - pub fn with_custom_verifier( - mut self, - custom_verifier: impl ServerCertVerifier + 'static, - ) -> Self { - self.custom_verifier = Some(Arc::new(custom_verifier)); + /// Disable TLS cert verification. + /// Use only if you definitely know what you are doing. + // TODO: will make sense after implementing a typestate pattern + pub fn dangerous_with_no_verifier(self) -> Self { self } + /// Use a custom DNS client for resolving IPs. #[cfg(feature = "hickory")] pub fn with_dns_client(mut self, dns: DnsClient) -> Self { self.dns = Some(dns); self } + + /// Same as [`Builder::with_dns_client`], but accepts Option. + #[cfg(feature = "hickory")] + pub fn maybe_with_dns_client(mut self, dns: Option) -> Self { + self.dns = dns; + self + } } diff --git a/src/client/mod.rs b/src/client/mod.rs index f431188..b49336c 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -12,7 +12,11 @@ use hickory_client::rr::IntoName; #[cfg(feature = "hickory")] use std::net::SocketAddr; -use crate::{error::*, status::*}; +use crate::{ + certs::{SelfsignedCertVerifier, ServerName}, + error::*, + status::*, +}; use builder::ClientBuilder; use std::sync::Arc; @@ -21,45 +25,18 @@ use tokio::{ io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}, net::TcpStream, }; -use tokio_rustls::{ - rustls::{self, pki_types}, - TlsConnector, -}; +use tokio_rustls::TlsConnector; use url::Url; pub struct Client { - connector: TlsConnector, + pub(crate) connector: TlsConnector, + pub(crate) ss_verifier: Option>, #[cfg(feature = "hickory")] - dns: Option, -} - -impl From for Client { - /// Create a Client from a Rustls config. - #[inline] - fn from(config: rustls::ClientConfig) -> Self { - Client { - connector: TlsConnector::from(Arc::new(config)), - #[cfg(feature = "hickory")] - dns: None, - } - } -} - -#[cfg(feature = "hickory")] -impl From<(rustls::ClientConfig, DnsClient)> for Client { - /// Create a Client from a Rustls config and - /// a DnsClient instance as a custom resolver. - #[inline] - fn from(value: (rustls::ClientConfig, DnsClient)) -> Self { - Client { - connector: TlsConnector::from(Arc::new(value.0)), - dns: Some(value.1), - } - } + pub(crate) dns: Option, } impl Client { - /// Create a Client with a customized configuration, + /// Construct a Client with a customized configuration, /// see [`ClientBuilder`] methods. pub fn builder() -> ClientBuilder { ClientBuilder::new() @@ -117,7 +94,7 @@ impl Client { host: &str, port: u16, ) -> Result { - let domain = pki_types::ServerName::try_from(host) + let domain = ServerName::try_from(host) .map_err(|_| InvalidUrl::ConvertError)? .to_owned(); @@ -170,7 +147,7 @@ impl Client { Ok(Response::new(status, message, stream)) } - pub async fn try_connect(&self, host: &str, port: u16) -> Result { + async fn try_connect(&self, host: &str, port: u16) -> Result { let mut last_err: Option = None; #[cfg(feature = "hickory")]