Automatically update the certificates without restarting

This commit is contained in:
Frank Denis 2021-02-15 21:04:01 +01:00
parent 0a99d0d212
commit a2f342379e
4 changed files with 59 additions and 23 deletions

View file

@ -19,7 +19,7 @@ tls = ["libdoh/tls"]
libdoh = { path = "src/libdoh", version = "0.3.7", default-features = false }
clap = "2.33.3"
jemallocator = "0.3.2"
tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time"] }
tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] }
[package.metadata.deb]
extended-description = """\

View file

@ -20,7 +20,7 @@ byteorder = "1.4.2"
base64 = "0.13.0"
futures = "0.3.12"
hyper = { version = "0.14.4", default-features = false, features = ["server", "http1", "http2", "stream"] }
tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time"] }
tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] }
tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true }
[profile.release]

View file

@ -12,6 +12,7 @@ pub use crate::globals::*;
#[cfg(feature = "tls")]
use crate::tls::*;
use futures::join;
use futures::prelude::*;
use futures::task::{Context, Poll};
use hyper::http;
@ -23,12 +24,14 @@ use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UdpSocket};
use tokio::runtime;
use tokio::sync::mpsc;
#[derive(Clone, Debug)]
pub struct DoH {
pub globals: Arc<Globals>,
}
#[allow(clippy::unnecessary_wraps)]
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
let response = Response::builder()
.status(status_code)
@ -58,6 +61,7 @@ where
}
}
#[allow(clippy::type_complexity)]
impl hyper::service::Service<http::Request<Body>> for DoH {
type Response = Response<Body>;
type Error = http::Error;
@ -265,17 +269,17 @@ impl DoH {
.map_err(DoHError::Io)?;
let path = &self.globals.path;
#[cfg(feature = "tls")]
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_key_path) {
(Some(tls_cert_path), Some(tls_cert_key_path)) => {
Some(create_tls_acceptor(tls_cert_path, tls_cert_key_path).unwrap())
}
_ => None,
};
let tls_enabled: bool;
#[cfg(not(feature = "tls"))]
let tls_acceptor: Option<()> = None;
if tls_acceptor.is_some() {
{
tls_enabled = false;
}
#[cfg(feature = "tls")]
{
tls_enabled =
self.globals.tls_cert_path.is_some() && self.globals.tls_cert_key_path.is_some();
}
if tls_enabled {
println!("Listening on https://{}{}", listen_address, path);
} else {
println!("Listening on http://{}{}", listen_address, path);
@ -289,9 +293,26 @@ impl DoH {
#[cfg(feature = "tls")]
{
if let Some(tls_acceptor) = tls_acceptor {
self.start_with_tls(tls_acceptor, listener, server).await?;
return Ok(());
if tls_enabled {
let certs_path = self.globals.tls_cert_path.as_ref().unwrap().clone();
let certs_keys_path = self.globals.tls_cert_key_path.as_ref().unwrap().clone();
let (tls_acceptor_sender, tls_acceptor_receiver) = mpsc::channel(1);
let http_service = self.start_with_tls(tls_acceptor_receiver, listener, server);
let cert_service = async {
loop {
match create_tls_acceptor(&certs_path, &certs_keys_path) {
Ok(tls_acceptor) => {
if tls_acceptor_sender.send(tls_acceptor).await.is_err() {
break;
}
}
Err(e) => eprintln!("TLS certificates error: {}", e),
}
tokio::time::sleep(Duration::from_secs(5)).await;
}
Ok::<_, DoHError>(())
};
return join!(http_service, cert_service).0;
}
}
self.start_without_tls(listener, server).await?;

View file

@ -1,12 +1,13 @@
use crate::errors::*;
use crate::{DoH, LocalExecutor};
use futures::{future::FutureExt, select};
use hyper::server::conn::Http;
use std::fs::File;
use std::io::{self, BufReader, Cursor, Read};
use std::path::Path;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::{net::TcpListener, sync::mpsc::Receiver};
use tokio_rustls::{
rustls::{internal::pemfile, NoClientAuth, ServerConfig},
TlsAcceptor,
@ -97,17 +98,31 @@ where
impl DoH {
pub async fn start_with_tls(
self,
tls_acceptor: TlsAcceptor,
mut tls_acceptor_receiver: Receiver<TlsAcceptor>,
listener: TcpListener,
server: Http<LocalExecutor>,
) -> Result<(), DoHError> {
let mut tls_acceptor: Option<TlsAcceptor> = None;
let listener_service = async {
while let Ok((raw_stream, _client_addr)) = listener.accept().await {
let stream = match tls_acceptor.accept(raw_stream).await {
Ok(stream) => stream,
Err(_) => continue,
};
self.clone().client_serve(stream, server.clone()).await;
loop {
select! {
tcp_cnx = listener.accept().fuse() => {
if tls_acceptor.is_none() || tcp_cnx.is_err() {
continue;
}
let (raw_stream, _client_addr) = tcp_cnx.unwrap();
if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await {
self.clone().client_serve(stream, server.clone()).await
}
}
new_tls_acceptor = tls_acceptor_receiver.recv().fuse() => {
if new_tls_acceptor.is_none() {
break;
}
tls_acceptor = new_tls_acceptor;
}
complete => break
}
}
Ok(()) as Result<(), DoHError>
};