mirror of
https://github.com/DNSCrypt/doh-server.git
synced 2025-04-05 22:17:38 +03:00
Don't trust Hyper's executor
This commit is contained in:
parent
4f2846966e
commit
c75ebff959
3 changed files with 38 additions and 8 deletions
|
@ -2,6 +2,7 @@ use std::net::SocketAddr;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use tokio::runtime;
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
@ -26,6 +27,8 @@ pub struct Globals {
|
||||||
pub err_ttl: u32,
|
pub err_ttl: u32,
|
||||||
pub keepalive: bool,
|
pub keepalive: bool,
|
||||||
pub disable_post: bool,
|
pub disable_post: bool,
|
||||||
|
|
||||||
|
pub runtime_handle: runtime::Handle,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
|
|
|
@ -21,6 +21,7 @@ use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tokio::net::{TcpListener, UdpSocket};
|
use tokio::net::{TcpListener, UdpSocket};
|
||||||
|
use tokio::runtime;
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct DoH {
|
pub struct DoH {
|
||||||
|
@ -35,6 +36,27 @@ fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct LocalExecutor {
|
||||||
|
runtime_handle: runtime::Handle,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LocalExecutor {
|
||||||
|
fn new(runtime_handle: runtime::Handle) -> Self {
|
||||||
|
LocalExecutor { runtime_handle }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> hyper::rt::Executor<F> for LocalExecutor
|
||||||
|
where
|
||||||
|
F: std::future::Future + Send + 'static,
|
||||||
|
F::Output: Send,
|
||||||
|
{
|
||||||
|
fn execute(&self, fut: F) {
|
||||||
|
self.runtime_handle.spawn(fut);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl hyper::service::Service<http::Request<Body>> for DoH {
|
impl hyper::service::Service<http::Request<Body>> for DoH {
|
||||||
type Response = Response<Body>;
|
type Response = Response<Body>;
|
||||||
type Error = http::Error;
|
type Error = http::Error;
|
||||||
|
@ -188,7 +210,7 @@ impl DoH {
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn client_serve<I>(self, stream: I, server: Http)
|
async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>)
|
||||||
where
|
where
|
||||||
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||||
{
|
{
|
||||||
|
@ -197,7 +219,7 @@ impl DoH {
|
||||||
clients_count.decrement();
|
clients_count.decrement();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
tokio::spawn(async move {
|
self.globals.runtime_handle.clone().spawn(async move {
|
||||||
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
|
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
|
||||||
.await
|
.await
|
||||||
.ok();
|
.ok();
|
||||||
|
@ -208,7 +230,7 @@ impl DoH {
|
||||||
async fn start_without_tls(
|
async fn start_without_tls(
|
||||||
self,
|
self,
|
||||||
mut listener: TcpListener,
|
mut listener: TcpListener,
|
||||||
server: Http,
|
server: Http<LocalExecutor>,
|
||||||
) -> Result<(), DoHError> {
|
) -> Result<(), DoHError> {
|
||||||
let listener_service = async {
|
let listener_service = async {
|
||||||
while let Some(stream) = listener.incoming().next().await {
|
while let Some(stream) = listener.incoming().next().await {
|
||||||
|
@ -250,6 +272,8 @@ impl DoH {
|
||||||
let mut server = Http::new();
|
let mut server = Http::new();
|
||||||
server.keep_alive(self.globals.keepalive);
|
server.keep_alive(self.globals.keepalive);
|
||||||
server.pipeline_flush(true);
|
server.pipeline_flush(true);
|
||||||
|
let executor = LocalExecutor::new(self.globals.runtime_handle.clone());
|
||||||
|
let server = server.with_executor(executor);
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
{
|
{
|
||||||
|
|
13
src/main.rs
13
src/main.rs
|
@ -18,6 +18,12 @@ use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
let mut runtime_builder = tokio::runtime::Builder::new();
|
||||||
|
runtime_builder.enable_all();
|
||||||
|
runtime_builder.threaded_scheduler();
|
||||||
|
runtime_builder.thread_name("doh-proxy");
|
||||||
|
let mut runtime = runtime_builder.build().unwrap();
|
||||||
|
|
||||||
let mut globals = Globals {
|
let mut globals = Globals {
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
tls_cert_path: None,
|
tls_cert_path: None,
|
||||||
|
@ -36,15 +42,12 @@ fn main() {
|
||||||
err_ttl: ERR_TTL,
|
err_ttl: ERR_TTL,
|
||||||
keepalive: true,
|
keepalive: true,
|
||||||
disable_post: false,
|
disable_post: false,
|
||||||
|
|
||||||
|
runtime_handle: runtime.handle().clone(),
|
||||||
};
|
};
|
||||||
parse_opts(&mut globals);
|
parse_opts(&mut globals);
|
||||||
let doh = DoH {
|
let doh = DoH {
|
||||||
globals: Arc::new(globals),
|
globals: Arc::new(globals),
|
||||||
};
|
};
|
||||||
let mut runtime_builder = tokio::runtime::Builder::new();
|
|
||||||
runtime_builder.enable_all();
|
|
||||||
runtime_builder.threaded_scheduler();
|
|
||||||
runtime_builder.thread_name("doh-proxy");
|
|
||||||
let mut runtime = runtime_builder.build().unwrap();
|
|
||||||
runtime.block_on(doh.entrypoint()).unwrap();
|
runtime.block_on(doh.entrypoint()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue