Don't trust Hyper's executor

This commit is contained in:
Frank Denis 2019-12-24 12:01:47 +01:00
parent 4f2846966e
commit c75ebff959
3 changed files with 38 additions and 8 deletions

View file

@ -2,6 +2,7 @@ use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime;
#[cfg(feature = "tls")]
use std::path::PathBuf;
@ -26,6 +27,8 @@ pub struct Globals {
pub err_ttl: u32,
pub keepalive: bool,
pub disable_post: bool,
pub runtime_handle: runtime::Handle,
}
#[derive(Debug, Clone, Default)]

View file

@ -21,6 +21,7 @@ use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UdpSocket};
use tokio::runtime;
#[derive(Clone, Debug)]
pub struct DoH {
@ -35,6 +36,27 @@ fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
Ok(response)
}
#[derive(Clone, Debug)]
struct LocalExecutor {
runtime_handle: runtime::Handle,
}
impl LocalExecutor {
fn new(runtime_handle: runtime::Handle) -> Self {
LocalExecutor { runtime_handle }
}
}
impl<F> hyper::rt::Executor<F> for LocalExecutor
where
F: std::future::Future + Send + 'static,
F::Output: Send,
{
fn execute(&self, fut: F) {
self.runtime_handle.spawn(fut);
}
}
impl hyper::service::Service<http::Request<Body>> for DoH {
type Response = Response<Body>;
type Error = http::Error;
@ -188,7 +210,7 @@ impl DoH {
Ok(response)
}
async fn client_serve<I>(self, stream: I, server: Http)
async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>)
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
@ -197,7 +219,7 @@ impl DoH {
clients_count.decrement();
return;
}
tokio::spawn(async move {
self.globals.runtime_handle.clone().spawn(async move {
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
.await
.ok();
@ -208,7 +230,7 @@ impl DoH {
async fn start_without_tls(
self,
mut listener: TcpListener,
server: Http,
server: Http<LocalExecutor>,
) -> Result<(), DoHError> {
let listener_service = async {
while let Some(stream) = listener.incoming().next().await {
@ -250,6 +272,8 @@ impl DoH {
let mut server = Http::new();
server.keep_alive(self.globals.keepalive);
server.pipeline_flush(true);
let executor = LocalExecutor::new(self.globals.runtime_handle.clone());
let server = server.with_executor(executor);
#[cfg(feature = "tls")]
{

View file

@ -18,6 +18,12 @@ use std::sync::Arc;
use std::time::Duration;
fn main() {
let mut runtime_builder = tokio::runtime::Builder::new();
runtime_builder.enable_all();
runtime_builder.threaded_scheduler();
runtime_builder.thread_name("doh-proxy");
let mut runtime = runtime_builder.build().unwrap();
let mut globals = Globals {
#[cfg(feature = "tls")]
tls_cert_path: None,
@ -36,15 +42,12 @@ fn main() {
err_ttl: ERR_TTL,
keepalive: true,
disable_post: false,
runtime_handle: runtime.handle().clone(),
};
parse_opts(&mut globals);
let doh = DoH {
globals: Arc::new(globals),
};
let mut runtime_builder = tokio::runtime::Builder::new();
runtime_builder.enable_all();
runtime_builder.threaded_scheduler();
runtime_builder.thread_name("doh-proxy");
let mut runtime = runtime_builder.build().unwrap();
runtime.block_on(doh.entrypoint()).unwrap();
}