diff --git a/src/certs/file_sscv.rs b/src/certs/file_sscv.rs index 460c37b..b85241d 100644 --- a/src/certs/file_sscv.rs +++ b/src/certs/file_sscv.rs @@ -1,4 +1,10 @@ -use std::{borrow::Cow, io::Write, os::fd::AsFd, path::Path, sync::Mutex}; +use std::{ + borrow::Cow, + io::{BufWriter, Write}, + os::fd::AsFd, + path::Path, + sync::Mutex, +}; use dashmap::DashMap; use tokio::io::AsyncBufReadExt; @@ -12,6 +18,8 @@ use crate::{ pub struct FileBasedCertVerifier { fd: Mutex, map: DashMap, + #[cfg(feature = "hickory")] + dns: Option, } impl FileBasedCertVerifier { @@ -85,7 +93,103 @@ impl FileBasedCertVerifier { .try_clone_to_owned()?, ); - Ok(FileBasedCertVerifier { fd, map }) + Ok(FileBasedCertVerifier { + fd, + map, + #[cfg(feature = "hickory")] + dns: None, + }) + } +} + +impl FileBasedCertVerifier { + pub fn add_trusted_cert( + &self, + host: &str, + hash: CertFingerprint, + ) -> Result<(), std::io::Error> { + let fp = hash.base64(); + let ft = hash.fingerprint_type_str(); + // TODO: remove eprintln!() + eprintln!("Warning: adding {} cert with FP {}", &host, &fp); + + self.map.insert( + host.to_owned(), + SelfsignedCert { + fingerprint: hash, + expires: 0, // TODO after implementing cert parsing in tokio-gemini + }, + ); + + // trick with cloning file descriptor + // because we are not allowed to mutate &self + let f = std::fs::File::from(self.fd.lock().unwrap().try_clone()?); + let mut bw = BufWriter::new(f); + bw.write_all(host.as_bytes())?; + bw.write_all(b" 0 ")?; // TODO after implementing `expires` + bw.write_all(ft.as_bytes())?; + bw.write_all(b" ")?; + bw.write_all(fp.as_bytes())?; + bw.write_all(b"\n")?; + bw.flush()?; + + Ok(()) + } + + #[cfg(feature = "hickory")] + pub fn dane( + &self, + cert: &CertificateDer, + host: &str, + port: u16, + ) -> Result { + let mut dns = if let Some(dns) = &self.dns { + dns.clone() + } else { + return Err(LibError::HostLookupError); + }; + + let rt = tokio::runtime::Handle::try_current()?; + + let mut sha256: Option = None; + let mut sha512: Option = None; + + for tlsa_fp in rt.block_on(dns.query_tlsa(host, port))? { + match tlsa_fp { + CertFingerprint::Sha256(_) => { + if sha256.is_none() { + sha256 = Some(CertFingerprint::new_sha256(cert)); + } + let this_fp = sha256.as_ref().unwrap(); + if this_fp == &tlsa_fp { + return Ok(sha256.unwrap()); + } + } + CertFingerprint::Sha512(_) => { + if sha512.is_none() { + sha512 = Some(CertFingerprint::new_sha512(cert)); + } + let this_fp = sha512.as_ref().unwrap(); + if this_fp == &tlsa_fp { + return Ok(sha512.unwrap()); + } + } + CertFingerprint::Raw(_) => { + let this_fp = CertFingerprint::new_raw(cert); + if this_fp == tlsa_fp { + return Ok(CertFingerprint::new_sha256(cert)); + } + } + } + } + + if let Some(sha256) = sha256 { + Ok(sha256) + } else if let Some(sha512) = sha512 { + Ok(sha512) + } else { + Ok(CertFingerprint::new_sha256(cert)) + } } } @@ -97,49 +201,35 @@ impl SelfsignedCertVerifier for FileBasedCertVerifier { _now: UnixTime, ) -> Result { // - // TODO: remove eprintln!()s and do overall code cleanup + // TODO: remove eprintln!()s // if let Some(known_cert) = self.map.get(host) { // if host is found in known_hosts, compare certs - let this_fp = match known_cert.fingerprint { + let this_hash = match known_cert.fingerprint { CertFingerprint::Sha256(_) => CertFingerprint::new_sha256(cert), CertFingerprint::Sha512(_) => CertFingerprint::new_sha512(cert), - _ => unreachable!(), + CertFingerprint::Raw(_) => CertFingerprint::new_raw(cert), }; - Ok(this_fp == known_cert.fingerprint) + Ok(this_hash == known_cert.fingerprint) } else { // host is unknown, generate hash and add to known_hosts + #[cfg(feature = "hickory")] + let this_hash = match self.dane(cert, host, 1965) { + Ok(hash) => hash, + Err(e) => { + eprintln!("DANE verification failed: {:?}", e); + CertFingerprint::new_sha256(cert) + } + }; + + #[cfg(not(feature = "hickory"))] let this_hash = CertFingerprint::new_sha256(cert); - let this_fp = this_hash.base64(); - // TODO: DANE cert check, use this_hash.hex() for this - eprintln!( - "Warning: updating known_hosts with cert {} for {}", - &this_fp, &host, - ); - (|| { - // trick with cloning file descriptor - // because we are not allowed to mutate &self - let mut f = std::fs::File::from(self.fd.lock().unwrap().try_clone()?); - f.write_all(host.as_bytes())?; - f.write_all(b" 0 sha256 ")?; // TODO after implementing `expires` - f.write_all(this_fp.as_bytes())?; - f.write_all(b"\n")?; - Ok::<(), std::io::Error>(()) - })() - .unwrap_or_else(|e| { - eprintln!("Could not add cert to file: {:?}", e); + self.add_trusted_cert(host, this_hash).unwrap_or_else(|e| { + eprintln!("Unable to add new cert: {:?}", e); }); - self.map.insert( - host.to_owned(), - SelfsignedCert { - fingerprint: this_hash, - expires: 0, // TODO after implementing cert parsing in tokio-gemini - }, - ); - Ok(true) } } diff --git a/src/certs/fingerprint.rs b/src/certs/fingerprint.rs index 57598f6..77197d7 100644 --- a/src/certs/fingerprint.rs +++ b/src/certs/fingerprint.rs @@ -125,3 +125,13 @@ impl CertFingerprint { } } } + +impl CertFingerprint { + pub fn fingerprint_type_str(&self) -> &'static str { + match self { + Self::Sha256(_) => "sha256", + Self::Sha512(_) => "sha512", + Self::Raw(_) => "raw", + } + } +} diff --git a/src/error.rs b/src/error.rs index 429fdce..1a0c435 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,6 +4,8 @@ use hickory_client::{ error::ClientError as HickoryClientError, proto::error::ProtoError as HickoryProtoError, }; +#[cfg(feature = "hickory")] +use tokio::runtime::TryCurrentError; /// Main error structure, also a wrapper for everything else #[derive(Debug)] @@ -29,6 +31,10 @@ pub enum LibError { /// Hickory Proto error #[cfg(feature = "hickory")] DnsProtoError(HickoryProtoError), + /// Could not get Tokio runtime handle + /// inside Rustls cert verifier + #[cfg(feature = "hickory")] + NoTokioRuntime(TryCurrentError), } impl From for LibError { @@ -89,6 +95,14 @@ impl From for LibError { } } +#[cfg(feature = "hickory")] +impl From for LibError { + #[inline] + fn from(err: TryCurrentError) -> Self { + Self::NoTokioRuntime(err) + } +} + /// URL parse or check error #[derive(Debug)] pub enum InvalidUrl {