refactor: rewrite self-signed cert verifier

This commit is contained in:
DarkCat09 2024-08-21 15:46:52 +04:00
parent 4314df372b
commit e042a139bf
Signed by: DarkCat09
GPG key ID: BD3CE9B65916CD82
11 changed files with 294 additions and 272 deletions

1
Cargo.lock generated
View file

@ -879,6 +879,7 @@ dependencies = [
name = "tokio-gemini" name = "tokio-gemini"
version = "0.4.0" version = "0.4.0"
dependencies = [ dependencies = [
"async-trait",
"base16ct", "base16ct",
"base64ct", "base64ct",
"bytes", "bytes",

View file

@ -24,6 +24,7 @@ tokio-rustls = { version = "0.26.0", default-features = false, features = ["ring
dashmap = { version = "6.0.1", optional = true } dashmap = { version = "6.0.1", optional = true }
hickory-client = { version = "0.24.1", optional = true } hickory-client = { version = "0.24.1", optional = true }
async-trait = "0.1.81"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] }

View file

@ -1,11 +1,13 @@
use tokio_gemini::{ use tokio_gemini::{
certs::{file_sscv::FileBasedCertVerifier, insecure::AllowAllCertVerifier}, certs::{
dane, file_sscv::KnownHostsFile, fingerprint::CertFingerprint, SelfsignedCertVerifier,
},
dns::DnsClient, dns::DnsClient,
Client, LibError, Client, LibError,
}; };
// //
// cargo add tokio_gemini -F file-sscv,hickory // cargo add tokio_gemini -F file-sscv,dane
// cargo add tokio -F macros,rt-multi-thread,io-util,fs // cargo add tokio -F macros,rt-multi-thread,io-util,fs
// //
@ -13,12 +15,6 @@ const USAGE: &str = "-k\t\tinsecure mode (trust all certs)
-d <DNS server addr>\tuse custom DNS for resolving & DANE -d <DNS server addr>\tuse custom DNS for resolving & DANE
-h\t\tshow help"; -h\t\tshow help";
struct Config {
insecure: bool,
dns: Option<String>,
url: String,
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), LibError> { async fn main() -> Result<(), LibError> {
let config = parse_args(); let config = parse_args();
@ -68,11 +64,11 @@ fn parse_args() -> Config {
"-k" => config.insecure = true, "-k" => config.insecure = true,
"-d" => expected_dns = true, "-d" => expected_dns = true,
"-h" => { "-h" => {
println!("{}", USAGE); eprintln!("{}", USAGE);
std::process::exit(0); std::process::exit(0);
} }
url => { url => {
println!("URL: {}", url); eprintln!("URL: {}", url);
config.url = url.to_owned(); config.url = url.to_owned();
break; break;
} }
@ -80,7 +76,7 @@ fn parse_args() -> Config {
} }
if expected_dns { if expected_dns {
println!("{}", USAGE); eprintln!("{}", USAGE);
std::process::exit(0); std::process::exit(0);
} }
@ -94,21 +90,94 @@ async fn build_client(config: &Config) -> Result<Client, LibError> {
None None
}; };
let known_hosts = KnownHostsFile::parse_file("known_hosts").await?;
let verifier = CertVerifier {
known_hosts,
dns: dns.clone(),
};
let client = tokio_gemini::Client::builder(); let client = tokio_gemini::Client::builder();
let client = if config.insecure { let client = if config.insecure {
client.with_custom_verifier(AllowAllCertVerifier::yes_i_know_what_i_am_doing()) client.dangerous_with_no_verifier()
} else { } else {
client.with_selfsigned_cert_verifier( client.with_selfsigned_cert_verifier(verifier)
FileBasedCertVerifier::init("known_hosts", dns.clone()).await?,
)
}; };
let client = if let Some(dns) = dns { let client = client.maybe_with_dns_client(dns);
client.with_dns_client(dns)
} else {
client
};
Ok(client.build()) Ok(client.build())
} }
struct Config {
insecure: bool,
dns: Option<String>,
url: String,
}
struct CertVerifier {
known_hosts: KnownHostsFile,
dns: Option<DnsClient>,
}
impl SelfsignedCertVerifier for CertVerifier {
async fn verify<'c>(
&self,
cert: &'c tokio_gemini::certs::CertificateDer<'c>,
host: &str,
port: u16,
) -> Result<bool, tokio_gemini::LibError> {
if let Some(known) = self.known_hosts.get_known_cert(host) {
// if found in known_hosts, just compare certs
Ok(known.fingerprint.hash_and_compare(cert))
} else {
// otherwise, generate a hash and add to known_hosts
let hash = if let Some(dns) = &self.dns {
// if DNS client is configured, try verifying the cert
// via DANE instead of blindly trusting
match dane::dane(&dns, cert, host, port).await {
Ok(hash) => hash, // use the fingerprint matched with TLSA record
Err(LibError::HostLookupError) => {
// no TLSA record found -- server admin haven't set it up
eprintln!(
"TLSA not configured for tcp:{}:{}, trusting on first use",
host, port,
);
// just generate a hash for this cert
CertFingerprint::new_sha256(cert)
}
Err(e) => {
// some other problem (e.g. DNS server rejected the request),
// we shouldn't continue
eprintln!("DANE verification failed: {:?}", e);
return Err(e);
}
}
} else {
eprintln!("DANE disabled");
// just generate a hash for this cert
CertFingerprint::new_sha256(cert)
};
let fingerprint = hash.base64();
let fptype = hash.fingerprint_type_str();
eprintln!(
"Warning: adding trusted cert for {} with FP {}",
host, &fingerprint,
);
// adding the cert hash to trusted
// can be done simplier:
// self.known_hosts.add_trusted_cert(host, hash).await.unwrap...
self.known_hosts.add_cert_to_hashmap(host, hash);
self.known_hosts
.add_cert_to_file(host, &fingerprint, fptype)
.await
.unwrap_or_else(|e| {
eprintln!("Cert saved in-memory, unable to write to file: {:?}", e);
});
Ok(true)
}
}
}

View file

@ -33,17 +33,18 @@ async fn main() -> Result<(), LibError> {
struct CertVerifier; struct CertVerifier;
impl SelfsignedCertVerifier for CertVerifier { impl SelfsignedCertVerifier for CertVerifier {
fn verify( async fn verify<'c>(
&self, &self,
cert: &tokio_gemini::certs::CertificateDer, cert: &'c tokio_gemini::certs::CertificateDer<'c>,
host: &str, host: &str,
_now: tokio_gemini::certs::UnixTime, port: u16,
) -> Result<bool, tokio_rustls::rustls::Error> { ) -> Result<bool, tokio_gemini::LibError> {
// For real verification example with known_hosts file // For real verification example with known_hosts and DANE
// see examples/main.rs // see examples/main.rs
eprintln!( eprintln!(
"Host = {}\nFingerprint = {}", "Host = {}:{}\nFingerprint = {}",
host, host,
port,
CertFingerprint::new_sha256(cert).base64(), CertFingerprint::new_sha256(cert).base64(),
); );
Ok(true) Ok(true)

View file

@ -1,6 +1,8 @@
//! Custom verifier for Rustls accepting any TLS cert //! Custom verifier for Rustls accepting any TLS cert
//! (usually called "insecure mode") //! (usually called "insecure mode")
use std::sync::Arc;
use tokio_rustls::rustls::{ use tokio_rustls::rustls::{
self, self,
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
@ -9,28 +11,28 @@ use tokio_rustls::rustls::{
/// Custom verifier for Rustls accepting any TLS certificate /// Custom verifier for Rustls accepting any TLS certificate
#[derive(Debug)] #[derive(Debug)]
pub struct AllowAllCertVerifier(std::sync::Arc<CryptoProvider>); pub struct AllowAllCertVerifier(Arc<CryptoProvider>);
impl AllowAllCertVerifier { impl AllowAllCertVerifier {
/// Constructor for this verifier. /// Constructor for this verifier
/// Use only if you know what you are doing. pub fn new() -> Self {
///
/// # Examples
/// ```
/// let client = tokio_gemini::Client::builder()
/// .with_custom_verifier(AllowAllCertVerifier::yes_i_know_what_i_am_doing())
/// .build()
/// ```
pub fn yes_i_know_what_i_am_doing() -> Self {
AllowAllCertVerifier( AllowAllCertVerifier(
CryptoProvider::get_default() CryptoProvider::get_default()
.cloned() .cloned()
.unwrap_or_else(|| std::sync::Arc::new(rustls::crypto::ring::default_provider())), .unwrap_or_else(|| Arc::new(rustls::crypto::ring::default_provider())),
) )
} }
} }
impl From<Arc<CryptoProvider>> for AllowAllCertVerifier {
#[inline]
fn from(value: Arc<CryptoProvider>) -> Self {
AllowAllCertVerifier(value)
}
}
impl ServerCertVerifier for AllowAllCertVerifier { impl ServerCertVerifier for AllowAllCertVerifier {
#[inline]
fn verify_server_cert( fn verify_server_cert(
&self, &self,
_end_entity: &rustls::pki_types::CertificateDer<'_>, _end_entity: &rustls::pki_types::CertificateDer<'_>,
@ -42,6 +44,7 @@ impl ServerCertVerifier for AllowAllCertVerifier {
Ok(ServerCertVerified::assertion()) Ok(ServerCertVerified::assertion())
} }
#[inline]
fn verify_tls12_signature( fn verify_tls12_signature(
&self, &self,
message: &[u8], message: &[u8],
@ -56,6 +59,7 @@ impl ServerCertVerifier for AllowAllCertVerifier {
) )
} }
#[inline]
fn verify_tls13_signature( fn verify_tls13_signature(
&self, &self,
message: &[u8], message: &[u8],
@ -70,6 +74,7 @@ impl ServerCertVerifier for AllowAllCertVerifier {
) )
} }
#[inline]
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> { fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.0.signature_verification_algorithms.supported_schemes() self.0.signature_verification_algorithms.supported_schemes()
} }

54
src/certs/dane.rs Normal file
View file

@ -0,0 +1,54 @@
use crate::{
certs::{fingerprint::CertFingerprint, CertificateDer},
dns::DnsClient,
LibError,
};
pub async fn dane<'d>(
dns: &DnsClient,
cert: &CertificateDer<'d>,
host: &str,
port: u16,
) -> Result<CertFingerprint, LibError> {
let mut dns = dns.clone();
let mut sha256: Option<CertFingerprint> = None;
let mut sha512: Option<CertFingerprint> = None;
for tlsa_fp in dns.query_tlsa(host, port).await? {
match tlsa_fp {
CertFingerprint::Sha256(_) => {
if sha256.is_none() {
sha256 = Some(CertFingerprint::new_sha256(cert));
}
let this_fp = sha256.as_ref().unwrap();
if this_fp == &tlsa_fp {
return Ok(sha256.unwrap());
}
}
CertFingerprint::Sha512(_) => {
if sha512.is_none() {
sha512 = Some(CertFingerprint::new_sha512(cert));
}
let this_fp = sha512.as_ref().unwrap();
if this_fp == &tlsa_fp {
return Ok(sha512.unwrap());
}
}
CertFingerprint::Raw(_) => {
let this_fp = CertFingerprint::new_raw(cert);
if this_fp == tlsa_fp {
return Ok(CertFingerprint::new_sha256(cert));
}
}
}
}
if let Some(sha256) = sha256 {
Ok(sha256)
} else if let Some(sha512) = sha512 {
Ok(sha512)
} else {
Ok(CertFingerprint::new_sha256(cert))
}
}

View file

@ -1,35 +1,25 @@
use std::{ use std::{borrow::Cow, os::fd::AsFd, path::Path, sync::Mutex};
borrow::Cow,
io::{BufWriter, Write},
os::fd::AsFd,
path::Path,
sync::Mutex,
};
use dashmap::DashMap; use dashmap::DashMap;
use tokio::io::AsyncBufReadExt; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufWriter};
use tokio_rustls::rustls::pki_types::{CertificateDer, UnixTime};
use crate::{ use crate::certs::{fingerprint::CertFingerprint, SelfsignedCert};
certs::{fingerprint::CertFingerprint, SelfsignedCert, SelfsignedCertVerifier},
LibError,
};
#[cfg(feature = "hickory")] pub struct KnownHostsFile {
use crate::dns::DnsClient;
pub struct FileBasedCertVerifier {
fd: Mutex<std::os::fd::OwnedFd>, fd: Mutex<std::os::fd::OwnedFd>,
map: DashMap<String, SelfsignedCert>, map: DashMap<String, SelfsignedCert>,
#[cfg(feature = "hickory")]
dns: Option<DnsClient>,
} }
impl FileBasedCertVerifier { impl KnownHostsFile {
pub async fn init( /// Read or create a known_hosts file at the given path.
path: impl AsRef<Path>, /// Format of known_hosts is: _(fields are separated by any space sequence)_
#[cfg(feature = "hickory")] dns: Option<DnsClient>, /// ```
) -> Result<Self, LibError> { /// #host expires algo base64 fingerprint
/// dc09.ru 1722930541 sha512 dGVzdHRlc3R0ZXN0Cg
/// ```
pub async fn parse_file(path: impl AsRef<Path>) -> std::io::Result<KnownHostsFile> {
// TODO: remove eprintln!()s
let map = DashMap::new(); let map = DashMap::new();
if tokio::fs::try_exists(&path).await? { if tokio::fs::try_exists(&path).await? {
@ -45,10 +35,6 @@ impl FileBasedCertVerifier {
break; break;
} }
// Format:
// host <space> expires <space> hash-algo <space> fingerprint
// Example:
// dc09.ru 1722930541 sha512 dGVzdHRlc3R0ZXN0Cg
if let [host, expires, algo, fp] = if let [host, expires, algo, fp] =
*buf.split_whitespace().take(4).collect::<Cow<[&str]>>() *buf.split_whitespace().take(4).collect::<Cow<[&str]>>()
{ {
@ -91,34 +77,42 @@ impl FileBasedCertVerifier {
} }
let fd = Mutex::new( let fd = Mutex::new(
std::fs::OpenOptions::new() tokio::fs::OpenOptions::new()
.append(true) .append(true)
.create(true) .create(true)
.open(path)? .open(path)
.await?
.as_fd() .as_fd()
.try_clone_to_owned()?, .try_clone_to_owned()?,
); );
Ok(FileBasedCertVerifier { Ok(KnownHostsFile { fd, map })
fd,
map,
#[cfg(feature = "hickory")]
dns,
})
} }
}
impl FileBasedCertVerifier { pub fn get_known_cert(
pub fn add_trusted_cert(
&self, &self,
host: &str, host: &str,
hash: CertFingerprint, ) -> Option<dashmap::mapref::one::Ref<String, SelfsignedCert>> {
) -> Result<(), std::io::Error> { self.map.get(host)
let fp = hash.base64(); }
let ft = hash.fingerprint_type_str();
// TODO: remove eprintln!()
eprintln!("Warning: adding {} cert with FP {}", &host, &fp);
/// Add a new entry to known_hosts (both to the in-memory hashmap and to the file).
/// Should be called only after the TLS cert is verified for validity (including DANE).
/// Exactly the same as `add_cert_to_hashmap` + `add_cert_to_file`.
pub async fn add_trusted_cert(&self, host: &str, hash: CertFingerprint) -> std::io::Result<()> {
let fp = hash.base64();
let fptype = hash.fingerprint_type_str();
self.add_cert_to_hashmap(host, hash);
self.add_cert_to_file(host, &fp, fptype).await?;
Ok(())
}
/// Add a new trusted cert only to the in-memory hashmap,
/// do not write to the known_hosts file.
pub fn add_cert_to_hashmap(&self, host: &str, hash: CertFingerprint) {
self.map.insert( self.map.insert(
host.to_owned(), host.to_owned(),
SelfsignedCert { SelfsignedCert {
@ -126,117 +120,35 @@ impl FileBasedCertVerifier {
expires: 0, // TODO after implementing cert parsing in tokio-gemini expires: 0, // TODO after implementing cert parsing in tokio-gemini
}, },
); );
}
/// Write a new trusted cert's fingerprint to the known_hosts file.
/// - `fp` is a TLS cert hash in base64 (see [`CertFingerprint::base64`]),
/// - `fptype` is a name of hashing algorithm (see [`CertFingerprint::fingerprint_type_str`]).
/// The certificate will not be trusted in the current session unless you call `add_cert_to_hashmap`,
/// so use this function only if you need modularity, otherwise just use `add_trusted_cert`.
pub async fn add_cert_to_file(
&self,
host: &str,
fp: &str,
fptype: &str,
) -> std::io::Result<()> {
// trick with cloning file descriptor // trick with cloning file descriptor
// because we are not allowed to mutate &self // because we are not allowed to mutate &self
let f = std::fs::File::from(self.fd.lock().unwrap().try_clone()?); let f = tokio::fs::File::from(std::fs::File::from(self.fd.lock().unwrap().try_clone()?));
let mut bw = BufWriter::new(f); let mut bw = BufWriter::new(f);
bw.write_all(host.as_bytes())?;
bw.write_all(b" 0 ")?; // TODO after implementing `expires` bw.write_all(host.as_bytes()).await?;
bw.write_all(ft.as_bytes())?; bw.write_all(b" 0 ").await?; // TODO after implementing `expires`
bw.write_all(b" ")?;
bw.write_all(fp.as_bytes())?; bw.write_all(fptype.as_bytes()).await?;
bw.write_all(b"\n")?; bw.write_all(b" ").await?;
bw.flush()?;
bw.write_all(fp.as_bytes()).await?;
bw.write_all(b"\n").await?;
bw.flush().await?;
Ok(()) Ok(())
} }
#[cfg(feature = "hickory")]
pub fn dane(
&self,
cert: &CertificateDer,
host: &str,
port: u16,
) -> Result<CertFingerprint, LibError> {
let mut dns = if let Some(dns) = &self.dns {
dns.clone()
} else {
return Err(LibError::HostLookupError);
};
let rt = tokio::runtime::Handle::try_current()?;
let mut sha256: Option<CertFingerprint> = None;
let mut sha512: Option<CertFingerprint> = None;
for tlsa_fp in rt.block_on(dns.query_tlsa(host, port))? {
match tlsa_fp {
CertFingerprint::Sha256(_) => {
if sha256.is_none() {
sha256 = Some(CertFingerprint::new_sha256(cert));
}
let this_fp = sha256.as_ref().unwrap();
if this_fp == &tlsa_fp {
return Ok(sha256.unwrap());
}
}
CertFingerprint::Sha512(_) => {
if sha512.is_none() {
sha512 = Some(CertFingerprint::new_sha512(cert));
}
let this_fp = sha512.as_ref().unwrap();
if this_fp == &tlsa_fp {
return Ok(sha512.unwrap());
}
}
CertFingerprint::Raw(_) => {
let this_fp = CertFingerprint::new_raw(cert);
if this_fp == tlsa_fp {
return Ok(CertFingerprint::new_sha256(cert));
}
}
}
}
if let Some(sha256) = sha256 {
Ok(sha256)
} else if let Some(sha512) = sha512 {
Ok(sha512)
} else {
Ok(CertFingerprint::new_sha256(cert))
}
}
}
impl SelfsignedCertVerifier for FileBasedCertVerifier {
fn verify(
&self,
cert: &CertificateDer,
host: &str,
_now: UnixTime,
) -> Result<bool, tokio_rustls::rustls::Error> {
//
// TODO: remove eprintln!()s
//
if let Some(known_cert) = self.map.get(host) {
// if host is found in known_hosts, compare certs
let this_hash = match known_cert.fingerprint {
CertFingerprint::Sha256(_) => CertFingerprint::new_sha256(cert),
CertFingerprint::Sha512(_) => CertFingerprint::new_sha512(cert),
CertFingerprint::Raw(_) => CertFingerprint::new_raw(cert),
};
Ok(this_hash == known_cert.fingerprint)
} else {
// host is unknown, generate hash and add to known_hosts
#[cfg(feature = "hickory")]
let this_hash = match self.dane(cert, host, 1965) {
Ok(hash) => hash,
Err(e) => {
eprintln!("DANE verification failed: {:?}", e);
CertFingerprint::new_sha256(cert)
}
};
#[cfg(not(feature = "hickory"))]
let this_hash = CertFingerprint::new_sha256(cert);
self.add_trusted_cert(host, this_hash).unwrap_or_else(|e| {
eprintln!("Unable to add new cert: {:?}", e);
});
Ok(true)
}
}
} }

View file

@ -8,7 +8,7 @@ pub use sha2::{Digest, Sha256, Sha512};
use base16ct::upper as b16; use base16ct::upper as b16;
use base64ct::{Base64Unpadded as b64, Encoding}; use base64ct::{Base64Unpadded as b64, Encoding};
use super::verifier::CertificateDer; use super::CertificateDer;
pub const SHA256_LEN: usize = 32; // 256 / 8 pub const SHA256_LEN: usize = 32; // 256 / 8
pub const SHA512_LEN: usize = 64; // 512 / 8 pub const SHA512_LEN: usize = 64; // 512 / 8
@ -127,6 +127,15 @@ impl CertFingerprint {
} }
impl CertFingerprint { impl CertFingerprint {
pub fn hash_and_compare(&self, cert: &CertificateDer) -> bool {
let hash = match self {
Self::Sha256(_) => Self::new_sha256(cert),
Self::Sha512(_) => Self::new_sha512(cert),
Self::Raw(_) => Self::new_raw(cert),
};
*self == hash
}
pub fn fingerprint_type_str(&self) -> &'static str { pub fn fingerprint_type_str(&self) -> &'static str {
match self { match self {
Self::Sha256(_) => "sha256", Self::Sha256(_) => "sha256",

View file

@ -1,27 +1,29 @@
//! Everything related to TLS certs verification //! Everything related to TLS certs verification
pub mod allow_all;
pub mod fingerprint; pub mod fingerprint;
pub mod insecure;
#[cfg(feature = "file-sscv")] #[cfg(feature = "file-sscv")]
pub mod file_sscv; pub mod file_sscv;
pub(crate) mod verifier; #[cfg(feature = "hickory")]
pub mod dane;
use async_trait::async_trait;
pub use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, UnixTime}; pub use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use tokio_rustls::rustls; /// Trait for implementing self-signed cert verifiers,
/// probably via known_hosts with TOFU policy or DANE verification.
/// Trait for implementing self-signed cert verifiers /// It is recommended to use helpers from file_sscv.
/// like [`file_sscv::FileBasedCertVerifier`] #[async_trait]
/// (probably via known_hosts with TOFU policy or DANE verification)
pub trait SelfsignedCertVerifier: Send + Sync { pub trait SelfsignedCertVerifier: Send + Sync {
fn verify( async fn verify<'c>(
&self, &self,
cert: &CertificateDer, cert: &'c CertificateDer<'c>,
host: &str, host: &str,
now: UnixTime, port: u16,
) -> Result<bool, rustls::Error>; // now: UnixTime,
) -> Result<bool, crate::LibError>;
} }
/// Structure holding a cert fingerprint and expiry date, /// Structure holding a cert fingerprint and expiry date,

View file

@ -3,20 +3,21 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
certs::{verifier::CustomCertVerifier, SelfsignedCertVerifier}, certs::{allow_all::AllowAllCertVerifier, SelfsignedCertVerifier},
Client, Client,
}; };
#[cfg(feature = "hickory")] #[cfg(feature = "hickory")]
use crate::dns::DnsClient; use crate::dns::DnsClient;
use tokio_rustls::rustls::{self, client::danger::ServerCertVerifier, SupportedProtocolVersion}; use tokio_rustls::{
rustls::{self, SupportedProtocolVersion},
TlsConnector,
};
/// Builder for creating configured [`Client`] instance /// Builder for creating configured [`Client`] instance
pub struct ClientBuilder { pub struct ClientBuilder {
root_certs: rustls::RootCertStore, ss_verifier: Option<Arc<dyn SelfsignedCertVerifier>>,
ss_verifier: Option<Box<dyn SelfsignedCertVerifier>>,
custom_verifier: Option<Arc<dyn ServerCertVerifier + 'static>>,
tls_versions: Option<&'static [&'static SupportedProtocolVersion]>, tls_versions: Option<&'static [&'static SupportedProtocolVersion]>,
#[cfg(feature = "hickory")] #[cfg(feature = "hickory")]
dns: Option<DnsClient>, dns: Option<DnsClient>,
@ -34,9 +35,7 @@ impl ClientBuilder {
/// no cert verifiers and default TLS versions. /// no cert verifiers and default TLS versions.
pub fn new() -> Self { pub fn new() -> Self {
ClientBuilder { ClientBuilder {
root_certs: rustls::RootCertStore::empty(),
ss_verifier: None, ss_verifier: None,
custom_verifier: None,
tls_versions: None, tls_versions: None,
#[cfg(feature = "hickory")] #[cfg(feature = "hickory")]
dns: None, dns: None,
@ -55,30 +54,19 @@ impl ClientBuilder {
} else { } else {
rustls::DEFAULT_VERSIONS rustls::DEFAULT_VERSIONS
}) })
.unwrap(); .unwrap()
let tls_config = if let Some(cv) = self.custom_verifier {
tls_config.dangerous().with_custom_certificate_verifier(cv)
} else if let Some(ssv) = self.ss_verifier {
tls_config
.dangerous() .dangerous()
.with_custom_certificate_verifier(Arc::new(CustomCertVerifier { .with_custom_certificate_verifier(Arc::new(AllowAllCertVerifier::from(provider)));
provider: provider.clone(),
ss_verifier: ssv,
}))
} else {
tls_config.with_root_certificates(self.root_certs)
};
// TODO // TODO
let tls_config = tls_config.with_no_client_auth(); let tls_config = tls_config.with_no_client_auth();
Client {
connector: TlsConnector::from(Arc::new(tls_config)),
ss_verifier: self.ss_verifier,
#[cfg(feature = "hickory")] #[cfg(feature = "hickory")]
if let Some(dns) = self.dns { dns: self.dns,
return Client::from((tls_config, dns));
} }
Client::from(tls_config)
} }
/// Limit the supported TLS versions list to the specified ones. /// Limit the supported TLS versions list to the specified ones.
@ -99,25 +87,28 @@ impl ClientBuilder {
mut self, mut self,
ss_verifier: impl SelfsignedCertVerifier + 'static, ss_verifier: impl SelfsignedCertVerifier + 'static,
) -> Self { ) -> Self {
self.ss_verifier = Some(Box::new(ss_verifier)); self.ss_verifier = Some(Arc::new(ss_verifier));
self self
} }
/// Include a custom TLS cert verifier implementing rustls' [`ServerCertVerifier`]. /// Disable TLS cert verification.
/// Normally need to be used only for [`crate::certs::insecure::AllowAllCertVerifier`]. /// Use only if you definitely know what you are doing.
/// Note: the webpki verifier and a self-signed cert verifier are not called // TODO: will make sense after implementing a typestate pattern
/// when a custom verifier is set. pub fn dangerous_with_no_verifier(self) -> Self {
pub fn with_custom_verifier(
mut self,
custom_verifier: impl ServerCertVerifier + 'static,
) -> Self {
self.custom_verifier = Some(Arc::new(custom_verifier));
self self
} }
/// Use a custom DNS client for resolving IPs.
#[cfg(feature = "hickory")] #[cfg(feature = "hickory")]
pub fn with_dns_client(mut self, dns: DnsClient) -> Self { pub fn with_dns_client(mut self, dns: DnsClient) -> Self {
self.dns = Some(dns); self.dns = Some(dns);
self self
} }
/// Same as [`Builder::with_dns_client`], but accepts Option.
#[cfg(feature = "hickory")]
pub fn maybe_with_dns_client(mut self, dns: Option<DnsClient>) -> Self {
self.dns = dns;
self
}
} }

View file

@ -12,7 +12,11 @@ use hickory_client::rr::IntoName;
#[cfg(feature = "hickory")] #[cfg(feature = "hickory")]
use std::net::SocketAddr; use std::net::SocketAddr;
use crate::{error::*, status::*}; use crate::{
certs::{SelfsignedCertVerifier, ServerName},
error::*,
status::*,
};
use builder::ClientBuilder; use builder::ClientBuilder;
use std::sync::Arc; use std::sync::Arc;
@ -21,45 +25,18 @@ use tokio::{
io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}, io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt},
net::TcpStream, net::TcpStream,
}; };
use tokio_rustls::{ use tokio_rustls::TlsConnector;
rustls::{self, pki_types},
TlsConnector,
};
use url::Url; use url::Url;
pub struct Client { pub struct Client {
connector: TlsConnector, pub(crate) connector: TlsConnector,
pub(crate) ss_verifier: Option<Arc<dyn SelfsignedCertVerifier>>,
#[cfg(feature = "hickory")] #[cfg(feature = "hickory")]
dns: Option<DnsClient>, pub(crate) dns: Option<DnsClient>,
}
impl From<rustls::ClientConfig> for Client {
/// Create a Client from a Rustls config.
#[inline]
fn from(config: rustls::ClientConfig) -> Self {
Client {
connector: TlsConnector::from(Arc::new(config)),
#[cfg(feature = "hickory")]
dns: None,
}
}
}
#[cfg(feature = "hickory")]
impl From<(rustls::ClientConfig, DnsClient)> for Client {
/// Create a Client from a Rustls config and
/// a DnsClient instance as a custom resolver.
#[inline]
fn from(value: (rustls::ClientConfig, DnsClient)) -> Self {
Client {
connector: TlsConnector::from(Arc::new(value.0)),
dns: Some(value.1),
}
}
} }
impl Client { impl Client {
/// Create a Client with a customized configuration, /// Construct a Client with a customized configuration,
/// see [`ClientBuilder`] methods. /// see [`ClientBuilder`] methods.
pub fn builder() -> ClientBuilder { pub fn builder() -> ClientBuilder {
ClientBuilder::new() ClientBuilder::new()
@ -117,7 +94,7 @@ impl Client {
host: &str, host: &str,
port: u16, port: u16,
) -> Result<Response, LibError> { ) -> Result<Response, LibError> {
let domain = pki_types::ServerName::try_from(host) let domain = ServerName::try_from(host)
.map_err(|_| InvalidUrl::ConvertError)? .map_err(|_| InvalidUrl::ConvertError)?
.to_owned(); .to_owned();
@ -170,7 +147,7 @@ impl Client {
Ok(Response::new(status, message, stream)) Ok(Response::new(status, message, stream))
} }
pub async fn try_connect(&self, host: &str, port: u16) -> Result<TcpStream, LibError> { async fn try_connect(&self, host: &str, port: u16) -> Result<TcpStream, LibError> {
let mut last_err: Option<std::io::Error> = None; let mut last_err: Option<std::io::Error> = None;
#[cfg(feature = "hickory")] #[cfg(feature = "hickory")]