diff --git a/src/builder.rs b/src/builder.rs new file mode 100644 index 0000000..670e2de --- /dev/null +++ b/src/builder.rs @@ -0,0 +1,116 @@ +use std::sync::Arc; + +use crate::{ + certs::verifier::{CustomCertVerifier, SelfsignedCertVerifier}, + Client, +}; + +use tokio_rustls::rustls::{ + self, + client::{danger::ServerCertVerifier, WebPkiServerVerifier}, + pki_types::TrustAnchor, + SupportedProtocolVersion, +}; + +pub struct ClientBuilder { + root_certs: rustls::RootCertStore, + ss_verifier: Option>, + custom_verifier: Option>, + tls_versions: Option<&'static [&'static SupportedProtocolVersion]>, +} + +impl ClientBuilder { + pub fn new() -> Self { + ClientBuilder { + root_certs: rustls::RootCertStore::empty(), + ss_verifier: None, + custom_verifier: None, + tls_versions: None, + } + } + + pub fn build(self) -> Client { + let provider = rustls::crypto::CryptoProvider::get_default() + .map(|c| c.clone()) + .unwrap_or_else(|| Arc::new(rustls::crypto::ring::default_provider())); + + let tls_config = rustls::ClientConfig::builder_with_provider(provider.clone()) + .with_protocol_versions(if let Some(versions) = self.tls_versions { + versions + } else { + rustls::DEFAULT_VERSIONS + }) + .unwrap(); + + let tls_config = if let Some(cv) = self.custom_verifier { + tls_config.dangerous().with_custom_certificate_verifier(cv) + } else if let Some(ssv) = self.ss_verifier { + tls_config + .dangerous() + .with_custom_certificate_verifier(Arc::new(CustomCertVerifier { + provider: provider.clone(), + webpki_verifier: if !self.root_certs.is_empty() { + Some( + WebPkiServerVerifier::builder_with_provider( + Arc::new(self.root_certs), + provider, + ) + .build() + // panics only if roots are empty (that is checked above) + // or CRLs couldn't be parsed (we didn't provide any) + .unwrap(), + ) + } else { + None + }, + ss_allowed: true, + ss_verifier: ssv, + })) + } else { + tls_config.with_root_certificates(self.root_certs) + }; + + // TODO + let tls_config = tls_config.with_no_client_auth(); + + Client::from(tls_config) + } + + pub fn with_webpki_roots(mut self) -> Self { + self.root_certs + .extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + self + } + + pub fn with_custom_roots( + mut self, + iter: impl IntoIterator>, + ) -> Self { + self.root_certs.extend(iter); + self + } + + pub fn with_selfsigned_cert_verifier( + mut self, + ss_verifier: impl SelfsignedCertVerifier + 'static, + ) -> Self { + self.ss_verifier = Some(Box::new(ss_verifier)); + self + } + + pub fn with_custom_verifier( + mut self, + custom_verifier: impl ServerCertVerifier + 'static, + ) -> Self { + self.custom_verifier = Some(Arc::new(custom_verifier)); + self + } + + pub fn with_tls_versions( + mut self, + versions: &'static [&'static SupportedProtocolVersion], + ) -> Self { + self.tls_versions = Some(versions); + self + } +} diff --git a/src/certs/verifier.rs b/src/certs/verifier.rs index 93ef818..cb7d477 100644 --- a/src/certs/verifier.rs +++ b/src/certs/verifier.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + pub use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use tokio_rustls::rustls::{ @@ -21,10 +23,10 @@ pub struct SelfsignedCert { } pub struct CustomCertVerifier { - provider: rustls::crypto::CryptoProvider, - webpki_verifier: Option, - ss_allowed: bool, - ss_verifier: dyn SelfsignedCertVerifier, + pub(crate) provider: Arc, + pub(crate) webpki_verifier: Option>, + pub(crate) ss_allowed: bool, + pub(crate) ss_verifier: Box, } impl ServerCertVerifier for CustomCertVerifier { diff --git a/src/client.rs b/src/client.rs index 4f8f11f..e9c3953 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,4 @@ -use crate::{error::*, response::Response, status::*}; +use crate::{builder::ClientBuilder, error::*, response::Response, status::*}; use std::net::ToSocketAddrs; use std::sync::Arc; @@ -37,6 +37,12 @@ impl From for Client { } } +impl Client { + pub fn builder() -> ClientBuilder { + ClientBuilder::new() + } +} + impl Client { pub async fn request(self: &Self, url_str: &str) -> Result { let url = Url::parse(url_str).map_err(|e| InvalidUrl::ParseError(e))?; diff --git a/src/lib.rs b/src/lib.rs index e6f456e..d079639 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod builder; pub mod certs; pub mod client; pub mod error;