Replace the clients count Mutex with an Atomic

This commit is contained in:
Frank Denis 2018-07-07 22:14:41 +02:00
parent eddb36b541
commit bc925cc2d5

View file

@ -17,6 +17,7 @@ use hyper::server::conn::Http;
use hyper::service::Service; use hyper::service::Service;
use hyper::{Body, Method, Request, Response, StatusCode}; use hyper::{Body, Method, Request, Response, StatusCode};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::executor::current_thread; use tokio::executor::current_thread;
@ -26,7 +27,7 @@ use tokio_timer::{timer, Timer};
const DNS_QUERY_PARAM: &str = "dns"; const DNS_QUERY_PARAM: &str = "dns";
const LISTEN_ADDRESS: &str = "127.0.0.1:3000"; const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
const LOCAL_BIND_ADDRESS: &str = "0.0.0.0:0"; const LOCAL_BIND_ADDRESS: &str = "0.0.0.0:0";
const MAX_CLIENTS: u32 = 512; const MAX_CLIENTS: usize = 512;
const MAX_DNS_QUESTION_LEN: usize = 512; const MAX_DNS_QUESTION_LEN: usize = 512;
const MAX_DNS_RESPONSE_LEN: usize = 4096; const MAX_DNS_RESPONSE_LEN: usize = 4096;
const MIN_DNS_PACKET_LEN: usize = 17; const MIN_DNS_PACKET_LEN: usize = 17;
@ -43,10 +44,10 @@ struct DoH {
local_bind_address: SocketAddr, local_bind_address: SocketAddr,
server_address: SocketAddr, server_address: SocketAddr,
path: String, path: String,
max_clients: u32, max_clients: usize,
timeout: Duration, timeout: Duration,
timer_handle: timer::Handle, timer_handle: timer::Handle,
clients_count: Arc<Mutex<u32>>, clients_count: Arc<AtomicUsize>,
} }
#[derive(Debug)] #[derive(Debug)]
@ -72,21 +73,21 @@ impl Service for DoH {
fn call(&mut self, req: Request<Body>) -> Self::Future { fn call(&mut self, req: Request<Body>) -> Self::Future {
{ {
let count = self.clients_count.lock().unwrap(); let count = self.clients_count.fetch_add(1, Ordering::Relaxed);
if *count > self.max_clients { if count > self.max_clients {
self.clients_count.fetch_sub(1, Ordering::Relaxed);
let response = Response::builder() let response = Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS) .status(StatusCode::TOO_MANY_REQUESTS)
.body(Body::empty()) .body(Body::empty())
.unwrap(); .unwrap();
return Box::new(future::ok(response)); return Box::new(future::ok(response));
} }
(*count).saturating_add(1);
} }
let clients_count_inner = self.clients_count.clone(); let clients_count_inner = self.clients_count.clone();
let fut = self let fut = self
.handle_client(req) .handle_client(req)
.then(move |fut| { .then(move |fut| {
(*clients_count_inner).lock().unwrap().saturating_sub(1); clients_count_inner.fetch_sub(1, Ordering::Relaxed);
fut fut
}) })
.map_err(|err| { .map_err(|err| {
@ -225,7 +226,7 @@ fn main() {
path: PATH.to_string(), path: PATH.to_string(),
max_clients: MAX_CLIENTS, max_clients: MAX_CLIENTS,
timeout: Duration::from_secs(TIMEOUT_SEC), timeout: Duration::from_secs(TIMEOUT_SEC),
clients_count: Arc::new(Mutex::new(0u32)), clients_count: Arc::new(AtomicUsize::new(0)),
timer_handle: Timer::default().handle(), timer_handle: Timer::default().handle(),
}; };
parse_opts(&mut doh); parse_opts(&mut doh);