From c75ebff95987554fc45c5a0dd78e1fe78f56a362 Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Tue, 24 Dec 2019 12:01:47 +0100 Subject: [PATCH] Don't trust Hyper's executor --- src/libdoh/src/globals.rs | 3 +++ src/libdoh/src/lib.rs | 30 +++++++++++++++++++++++++++--- src/main.rs | 13 ++++++++----- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/src/libdoh/src/globals.rs b/src/libdoh/src/globals.rs index 3877bdd..7f998d2 100644 --- a/src/libdoh/src/globals.rs +++ b/src/libdoh/src/globals.rs @@ -2,6 +2,7 @@ use std::net::SocketAddr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; +use tokio::runtime; #[cfg(feature = "tls")] use std::path::PathBuf; @@ -26,6 +27,8 @@ pub struct Globals { pub err_ttl: u32, pub keepalive: bool, pub disable_post: bool, + + pub runtime_handle: runtime::Handle, } #[derive(Debug, Clone, Default)] diff --git a/src/libdoh/src/lib.rs b/src/libdoh/src/lib.rs index 629477f..bd6bc5b 100644 --- a/src/libdoh/src/lib.rs +++ b/src/libdoh/src/lib.rs @@ -21,6 +21,7 @@ use std::pin::Pin; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, UdpSocket}; +use tokio::runtime; #[derive(Clone, Debug)] pub struct DoH { @@ -35,6 +36,27 @@ fn http_error(status_code: StatusCode) -> Result, http::Error> { Ok(response) } +#[derive(Clone, Debug)] +struct LocalExecutor { + runtime_handle: runtime::Handle, +} + +impl LocalExecutor { + fn new(runtime_handle: runtime::Handle) -> Self { + LocalExecutor { runtime_handle } + } +} + +impl hyper::rt::Executor for LocalExecutor +where + F: std::future::Future + Send + 'static, + F::Output: Send, +{ + fn execute(&self, fut: F) { + self.runtime_handle.spawn(fut); + } +} + impl hyper::service::Service> for DoH { type Response = Response; type Error = http::Error; @@ -188,7 +210,7 @@ impl DoH { Ok(response) } - async fn client_serve(self, stream: I, server: Http) + async fn client_serve(self, stream: I, server: Http) where I: AsyncRead + AsyncWrite + Send + Unpin + 'static, { @@ -197,7 +219,7 @@ impl DoH { clients_count.decrement(); return; } - tokio::spawn(async move { + self.globals.runtime_handle.clone().spawn(async move { tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self)) .await .ok(); @@ -208,7 +230,7 @@ impl DoH { async fn start_without_tls( self, mut listener: TcpListener, - server: Http, + server: Http, ) -> Result<(), DoHError> { let listener_service = async { while let Some(stream) = listener.incoming().next().await { @@ -250,6 +272,8 @@ impl DoH { let mut server = Http::new(); server.keep_alive(self.globals.keepalive); server.pipeline_flush(true); + let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); + let server = server.with_executor(executor); #[cfg(feature = "tls")] { diff --git a/src/main.rs b/src/main.rs index 06086e9..92b1b86 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,12 @@ use std::sync::Arc; use std::time::Duration; fn main() { + let mut runtime_builder = tokio::runtime::Builder::new(); + runtime_builder.enable_all(); + runtime_builder.threaded_scheduler(); + runtime_builder.thread_name("doh-proxy"); + let mut runtime = runtime_builder.build().unwrap(); + let mut globals = Globals { #[cfg(feature = "tls")] tls_cert_path: None, @@ -36,15 +42,12 @@ fn main() { err_ttl: ERR_TTL, keepalive: true, disable_post: false, + + runtime_handle: runtime.handle().clone(), }; parse_opts(&mut globals); let doh = DoH { globals: Arc::new(globals), }; - let mut runtime_builder = tokio::runtime::Builder::new(); - runtime_builder.enable_all(); - runtime_builder.threaded_scheduler(); - runtime_builder.thread_name("doh-proxy"); - let mut runtime = runtime_builder.build().unwrap(); runtime.block_on(doh.entrypoint()).unwrap(); }