Automatically update the certificates without restarting

This commit is contained in:
Frank Denis 2021-02-15 21:04:01 +01:00
parent 0a99d0d212
commit a2f342379e
4 changed files with 59 additions and 23 deletions

View file

@ -19,7 +19,7 @@ tls = ["libdoh/tls"]
libdoh = { path = "src/libdoh", version = "0.3.7", default-features = false } libdoh = { path = "src/libdoh", version = "0.3.7", default-features = false }
clap = "2.33.3" clap = "2.33.3"
jemallocator = "0.3.2" jemallocator = "0.3.2"
tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time"] } tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] }
[package.metadata.deb] [package.metadata.deb]
extended-description = """\ extended-description = """\

View file

@ -20,7 +20,7 @@ byteorder = "1.4.2"
base64 = "0.13.0" base64 = "0.13.0"
futures = "0.3.12" futures = "0.3.12"
hyper = { version = "0.14.4", default-features = false, features = ["server", "http1", "http2", "stream"] } hyper = { version = "0.14.4", default-features = false, features = ["server", "http1", "http2", "stream"] }
tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time"] } tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] }
tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true } tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true }
[profile.release] [profile.release]

View file

@ -12,6 +12,7 @@ pub use crate::globals::*;
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
use crate::tls::*; use crate::tls::*;
use futures::join;
use futures::prelude::*; use futures::prelude::*;
use futures::task::{Context, Poll}; use futures::task::{Context, Poll};
use hyper::http; use hyper::http;
@ -23,12 +24,14 @@ use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UdpSocket}; use tokio::net::{TcpListener, UdpSocket};
use tokio::runtime; use tokio::runtime;
use tokio::sync::mpsc;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct DoH { pub struct DoH {
pub globals: Arc<Globals>, pub globals: Arc<Globals>,
} }
#[allow(clippy::unnecessary_wraps)]
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> { fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
let response = Response::builder() let response = Response::builder()
.status(status_code) .status(status_code)
@ -58,6 +61,7 @@ where
} }
} }
#[allow(clippy::type_complexity)]
impl hyper::service::Service<http::Request<Body>> for DoH { impl hyper::service::Service<http::Request<Body>> for DoH {
type Response = Response<Body>; type Response = Response<Body>;
type Error = http::Error; type Error = http::Error;
@ -265,17 +269,17 @@ impl DoH {
.map_err(DoHError::Io)?; .map_err(DoHError::Io)?;
let path = &self.globals.path; let path = &self.globals.path;
#[cfg(feature = "tls")] let tls_enabled: bool;
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_key_path) {
(Some(tls_cert_path), Some(tls_cert_key_path)) => {
Some(create_tls_acceptor(tls_cert_path, tls_cert_key_path).unwrap())
}
_ => None,
};
#[cfg(not(feature = "tls"))] #[cfg(not(feature = "tls"))]
let tls_acceptor: Option<()> = None; {
tls_enabled = false;
if tls_acceptor.is_some() { }
#[cfg(feature = "tls")]
{
tls_enabled =
self.globals.tls_cert_path.is_some() && self.globals.tls_cert_key_path.is_some();
}
if tls_enabled {
println!("Listening on https://{}{}", listen_address, path); println!("Listening on https://{}{}", listen_address, path);
} else { } else {
println!("Listening on http://{}{}", listen_address, path); println!("Listening on http://{}{}", listen_address, path);
@ -289,9 +293,26 @@ impl DoH {
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
{ {
if let Some(tls_acceptor) = tls_acceptor { if tls_enabled {
self.start_with_tls(tls_acceptor, listener, server).await?; let certs_path = self.globals.tls_cert_path.as_ref().unwrap().clone();
return Ok(()); let certs_keys_path = self.globals.tls_cert_key_path.as_ref().unwrap().clone();
let (tls_acceptor_sender, tls_acceptor_receiver) = mpsc::channel(1);
let http_service = self.start_with_tls(tls_acceptor_receiver, listener, server);
let cert_service = async {
loop {
match create_tls_acceptor(&certs_path, &certs_keys_path) {
Ok(tls_acceptor) => {
if tls_acceptor_sender.send(tls_acceptor).await.is_err() {
break;
}
}
Err(e) => eprintln!("TLS certificates error: {}", e),
}
tokio::time::sleep(Duration::from_secs(5)).await;
}
Ok::<_, DoHError>(())
};
return join!(http_service, cert_service).0;
} }
} }
self.start_without_tls(listener, server).await?; self.start_without_tls(listener, server).await?;

View file

@ -1,12 +1,13 @@
use crate::errors::*; use crate::errors::*;
use crate::{DoH, LocalExecutor}; use crate::{DoH, LocalExecutor};
use futures::{future::FutureExt, select};
use hyper::server::conn::Http; use hyper::server::conn::Http;
use std::fs::File; use std::fs::File;
use std::io::{self, BufReader, Cursor, Read}; use std::io::{self, BufReader, Cursor, Read};
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use tokio::net::TcpListener; use tokio::{net::TcpListener, sync::mpsc::Receiver};
use tokio_rustls::{ use tokio_rustls::{
rustls::{internal::pemfile, NoClientAuth, ServerConfig}, rustls::{internal::pemfile, NoClientAuth, ServerConfig},
TlsAcceptor, TlsAcceptor,
@ -97,17 +98,31 @@ where
impl DoH { impl DoH {
pub async fn start_with_tls( pub async fn start_with_tls(
self, self,
tls_acceptor: TlsAcceptor, mut tls_acceptor_receiver: Receiver<TlsAcceptor>,
listener: TcpListener, listener: TcpListener,
server: Http<LocalExecutor>, server: Http<LocalExecutor>,
) -> Result<(), DoHError> { ) -> Result<(), DoHError> {
let mut tls_acceptor: Option<TlsAcceptor> = None;
let listener_service = async { let listener_service = async {
while let Ok((raw_stream, _client_addr)) = listener.accept().await { loop {
let stream = match tls_acceptor.accept(raw_stream).await { select! {
Ok(stream) => stream, tcp_cnx = listener.accept().fuse() => {
Err(_) => continue, if tls_acceptor.is_none() || tcp_cnx.is_err() {
}; continue;
self.clone().client_serve(stream, server.clone()).await; }
let (raw_stream, _client_addr) = tcp_cnx.unwrap();
if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await {
self.clone().client_serve(stream, server.clone()).await
}
}
new_tls_acceptor = tls_acceptor_receiver.recv().fuse() => {
if new_tls_acceptor.is_none() {
break;
}
tls_acceptor = new_tls_acceptor;
}
complete => break
}
} }
Ok(()) as Result<(), DoHError> Ok(()) as Result<(), DoHError>
}; };