diff --git a/src/main.rs b/src/main.rs index 317bab2..e952c5e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,8 @@ use futures::prelude::*; use futures::future; use tokio_core::reactor::Core; use tokio_core::net::UdpSocket; +use std::cell::RefCell; +use std::rc::Rc; const TIMEOUT_SEC: u64 = 10; const LOCAL_ADDRESS: &str = "127.0.0.1:3000"; @@ -26,11 +28,11 @@ const SERVER_ADDRESS: &str = "9.9.9.9:53"; const MIN_DNS_PACKET_LEN: usize = 17; const MAX_DNS_QUESTION_LEN: usize = 512; const MAX_DNS_RESPONSE_LEN: usize = 4096; +const MAX_CLIENTS: u32 = 512; #[derive(Clone, Debug)] struct DoH { handle: Handle, - timers: tokio_timer::Timer, } impl Service for DoH { @@ -52,9 +54,7 @@ impl Service for DoH { "application/dns-udpwireformat".parse().unwrap(), )) }); - let timed = self.timers - .timeout(fut.map_err(|_| ()), Duration::from_secs(TIMEOUT_SEC)); - return Box::new(timed.map_err(|_| hyper::Error::Timeout)); + return Box::new(fut.map_err(|_| hyper::Error::Incomplete)); } (&Method::Post, _) => { response.set_status(StatusCode::NotFound); @@ -96,33 +96,40 @@ impl DoH { fn main() { let mut core = Core::new().unwrap(); let handle = core.handle(); - let addr = LOCAL_ADDRESS.parse().unwrap(); let handle_inner = handle.clone(); - let timers = tokio_timer::wheel().build(); - let server = Http::new() .keep_alive(false) .max_buf_size(MAX_DNS_QUESTION_LEN) .serve_addr_handle(&addr, &handle, move || { Ok(DoH { handle: handle_inner.clone(), - timers: timers.clone(), }) }) .unwrap(); println!("Listening on http://{}", server.incoming_ref().local_addr()); let handle_inner = handle.clone(); - handle.spawn( - server - .for_each(move |conn| { - handle_inner.spawn( - conn.map(|_| ()) - .map_err(|err| eprintln!("server error: {:?}", err)), - ); - Ok(()) + let timers = tokio_timer::wheel().build(); + let client_count = Rc::new(RefCell::new(0u32)); + let fut = server.for_each(move |client_fut| { + let client_count_inner = client_count.clone(); + { + let count = client_count_inner.borrow_mut(); + if *count > MAX_CLIENTS { + return Ok(()); + } + (*count).saturating_add(1); + } + let timers_inner = timers.clone(); + let fut = client_fut + .map(move |_| { + (*client_count_inner.borrow_mut()).saturating_sub(1); }) - .map_err(|_| ()), - ); + .map_err(|err| eprintln!("server error: {:?}", err)); + let timed = timers_inner.timeout(fut, Duration::from_secs(TIMEOUT_SEC)); + handle_inner.spawn(timed); + Ok(()) + }); + handle.spawn(fut.map_err(|_| ())); core.run(futures::future::empty::<(), ()>()).unwrap(); }