From 46be8b96622cc580c914cdf2010684519c756653 Mon Sep 17 00:00:00 2001
From: Frank Denis <github@pureftpd.org>
Date: Fri, 29 Oct 2021 20:13:47 +0200
Subject: [PATCH] Painful update of rustls

---
 src/libdoh/Cargo.toml | 11 ++++-----
 src/libdoh/src/tls.rs | 52 +++++++++++++++++++++++++------------------
 2 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/src/libdoh/Cargo.toml b/src/libdoh/Cargo.toml
index d5a63a8..9318afa 100644
--- a/src/libdoh/Cargo.toml
+++ b/src/libdoh/Cargo.toml
@@ -15,18 +15,19 @@ default = ["tls"]
 tls = ["tokio-rustls"]
 
 [dependencies]
-anyhow = "1.0.43"
-arc-swap = "1.3.2"
+anyhow = "1.0.44"
+arc-swap = "1.4.0"
 base64 = "0.13.0"
 byteorder = "1.4.3"
 bytes = "1.1.0"
 futures = "0.3.17"
 hpke = "0.5.1"
-hyper = { version = "0.14.12", default-features = false, features = ["server", "http1", "http2", "stream"] }
+hyper = { version = "0.14.14", default-features = false, features = ["server", "http1", "http2", "stream"] }
 odoh-rs = "1.0.0-alpha.1"
 rand = "0.8.4"
-tokio = { version = "1.11.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] }
-tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true }
+tokio = { version = "1.13.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] }
+tokio-rustls = { version = "0.23.0", features = ["early-data"], optional = true }
+rustls-pemfile = "0.2.1"
 
 [profile.release]
 codegen-units = 1
diff --git a/src/libdoh/src/tls.rs b/src/libdoh/src/tls.rs
index 289b9dd..80dfc2b 100644
--- a/src/libdoh/src/tls.rs
+++ b/src/libdoh/src/tls.rs
@@ -14,7 +14,7 @@ use tokio::{
     sync::mpsc::{self, Receiver},
 };
 use tokio_rustls::{
-    rustls::{internal::pemfile, NoClientAuth, ServerConfig},
+    rustls::{Certificate, PrivateKey, ServerConfig},
     TlsAcceptor,
 };
 
@@ -23,7 +23,7 @@ where
     P: AsRef<Path>,
     P2: AsRef<Path>,
 {
-    let certs = {
+    let certs: Vec<_> = {
         let certs_path_str = certs_path.as_ref().display().to_string();
         let mut reader = BufReader::new(File::open(certs_path).map_err(|e| {
             io::Error::new(
@@ -31,18 +31,21 @@ where
                 format!(
                     "Unable to load the certificates [{}]: {}",
                     certs_path_str,
-                    e.to_string()
+                    e
                 ),
             )
         })?);
-        pemfile::certs(&mut reader).map_err(|_| {
+        rustls_pemfile::certs(&mut reader).map_err(|_| {
             io::Error::new(
                 io::ErrorKind::InvalidInput,
                 "Unable to parse the certificates",
             )
         })?
-    };
-    let certs_keys = {
+    }
+    .drain(..)
+    .map(Certificate)
+    .collect();
+    let certs_keys: Vec<_> = {
         let certs_keys_path_str = certs_keys_path.as_ref().display().to_string();
         let encoded_keys = {
             let mut encoded_keys = vec![];
@@ -53,7 +56,7 @@ where
                         format!(
                             "Unable to load the certificate keys [{}]: {}",
                             certs_keys_path_str,
-                            e.to_string()
+                            e
                         ),
                     )
                 })?
@@ -61,14 +64,14 @@ where
             encoded_keys
         };
         let mut reader = Cursor::new(encoded_keys);
-        let pkcs8_keys = pemfile::pkcs8_private_keys(&mut reader).map_err(|_| {
+        let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| {
             io::Error::new(
                 io::ErrorKind::InvalidInput,
                 "Unable to parse the certificates private keys (PKCS8)",
             )
         })?;
         reader.set_position(0);
-        let mut rsa_keys = pemfile::rsa_private_keys(&mut reader).map_err(|_| {
+        let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| {
             io::Error::new(
                 io::ErrorKind::InvalidInput,
                 "Unable to parse the certificates private keys (RSA)",
@@ -82,21 +85,26 @@ where
                 "No private keys found - Make sure that they are in PKCS#8/PEM format",
             ));
         }
-        keys
+        keys.drain(..).map(PrivateKey).collect()
     };
-    let mut server_config = ServerConfig::new(NoClientAuth::new());
-    server_config.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
-    let has_valid_cert_and_key = certs_keys.into_iter().any(|certs_key| {
-        server_config
-            .set_single_cert(certs.clone(), certs_key)
-            .is_ok()
-    });
-    if !has_valid_cert_and_key {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidInput,
-            "Invalid private key for the given certificate",
-        ));
+
+    let mut server_config = None;
+    for certs_key in certs_keys {
+        let server_config_builder = ServerConfig::builder()
+            .with_safe_defaults()
+            .with_no_client_auth();
+        if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) {
+            server_config = Some(found_config);
+            break;
+        }
     }
+    let mut server_config = server_config.ok_or_else(|| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "Unable to find a valid certificate and key",
+        )
+    })?;
+    server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
     Ok(TlsAcceptor::from(Arc::new(server_config)))
 }