Get rid of the Mutex

This commit is contained in:
Frank Denis 2018-07-07 22:29:41 +02:00
parent bc925cc2d5
commit 82630f4a31

View file

@ -18,7 +18,7 @@ use hyper::service::Service;
use hyper::{Body, Method, Request, Response, StatusCode};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::executor::current_thread;
use tokio::net::{TcpListener, UdpSocket};
@ -38,8 +38,8 @@ const MAX_TTL: u32 = 86400 * 7;
const MIN_TTL: u32 = 1;
const ERR_TTL: u32 = 1;
#[derive(Clone, Debug)]
struct DoH {
#[derive(Debug)]
struct InnerDoH {
listen_address: SocketAddr,
local_bind_address: SocketAddr,
server_address: SocketAddr,
@ -50,6 +50,11 @@ struct DoH {
clients_count: Arc<AtomicUsize>,
}
#[derive(Clone, Debug)]
struct DoH {
inner: Arc<InnerDoH>,
}
#[derive(Debug)]
enum Error {
Timeout,
@ -72,10 +77,11 @@ impl Service for DoH {
type Future = Box<Future<Item = Response<Body>, Error = Self::Error> + Send>;
fn call(&mut self, req: Request<Body>) -> Self::Future {
let inner = &self.inner;
{
let count = self.clients_count.fetch_add(1, Ordering::Relaxed);
if count > self.max_clients {
self.clients_count.fetch_sub(1, Ordering::Relaxed);
let count = inner.clients_count.fetch_add(1, Ordering::Relaxed);
if count > inner.max_clients {
inner.clients_count.fetch_sub(1, Ordering::Relaxed);
let response = Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.body(Body::empty())
@ -83,7 +89,7 @@ impl Service for DoH {
return Box::new(future::ok(response));
}
}
let clients_count_inner = self.clients_count.clone();
let clients_count_inner = inner.clients_count.clone();
let fut = self
.handle_client(req)
.then(move |fut| {
@ -94,9 +100,9 @@ impl Service for DoH {
eprintln!("server error: {}", err);
err
});
let timed = self
let timed = inner
.timer_handle
.deadline(fut.map_err(|_| {}), Instant::now() + self.timeout)
.deadline(fut.map_err(|_| {}), Instant::now() + inner.timeout)
.map_err(|_| Error::Timeout);
Box::new(timed)
}
@ -107,7 +113,8 @@ impl DoH {
&self,
req: Request<Body>,
) -> Box<Future<Item = Response<Body>, Error = Error> + Send> {
if req.uri().path() != self.path {
let inner = &self.inner;
if req.uri().path() != inner.path {
let response = Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
@ -156,10 +163,11 @@ impl DoH {
}
fn proxy(&self, query: Vec<u8>) -> Box<Future<Item = Response<Body>, Error = ()> + Send> {
let socket = UdpSocket::bind(&self.local_bind_address).unwrap();
let expected_server_address = self.server_address;
let inner = &self.inner;
let socket = UdpSocket::bind(&inner.local_bind_address).unwrap();
let expected_server_address = inner.server_address;
let fut = socket
.send_dgram(query, &self.server_address)
.send_dgram(query, &inner.server_address)
.map_err(|_| ())
.and_then(move |(socket, _)| {
let packet = vec![0; MAX_DNS_RESPONSE_LEN];
@ -219,7 +227,7 @@ impl DoH {
}
fn main() {
let mut doh = DoH {
let mut inner_doh = InnerDoH {
listen_address: LISTEN_ADDRESS.parse().unwrap(),
local_bind_address: LOCAL_BIND_ADDRESS.parse().unwrap(),
server_address: SERVER_ADDRESS.parse().unwrap(),
@ -229,15 +237,17 @@ fn main() {
clients_count: Arc::new(AtomicUsize::new(0)),
timer_handle: Timer::default().handle(),
};
parse_opts(&mut doh);
let listen_address = doh.listen_address;
parse_opts(&mut inner_doh);
let doh = DoH {
inner: Arc::new(inner_doh),
};
let listen_address = doh.inner.listen_address;
let listener = TcpListener::bind(&listen_address).unwrap();
println!("Listening on http://{}", listen_address);
let mut http = Http::new();
http.keep_alive(false);
let doh = Arc::new(Mutex::new(doh));
let server = listener.incoming().for_each(move |io| {
let service = doh.lock().unwrap().clone();
let service = doh.clone();
let conn = http.serve_connection(io, service).map_err(|_| {});
current_thread::spawn(conn);
Ok(())
@ -245,7 +255,7 @@ fn main() {
current_thread::block_on_all(server).unwrap();
}
fn parse_opts(doh: &mut DoH) {
fn parse_opts(inner_doh: &mut InnerDoH) {
let max_clients = MAX_CLIENTS.to_string();
let timeout_sec = TIMEOUT_SEC.to_string();
let matches = App::new("doh-proxy")
@ -300,18 +310,18 @@ fn parse_opts(doh: &mut DoH) {
)
.get_matches();
if let Some(listen_address) = matches.value_of("listen_address") {
doh.listen_address = listen_address.parse().unwrap();
inner_doh.listen_address = listen_address.parse().unwrap();
}
if let Some(server_address) = matches.value_of("server_address") {
doh.server_address = server_address.parse().unwrap();
inner_doh.server_address = server_address.parse().unwrap();
}
if let Some(local_bind_address) = matches.value_of("local_bind_address") {
doh.local_bind_address = local_bind_address.parse().unwrap();
inner_doh.local_bind_address = local_bind_address.parse().unwrap();
}
if let Some(max_clients) = matches.value_of("max_clients") {
doh.max_clients = max_clients.parse().unwrap();
inner_doh.max_clients = max_clients.parse().unwrap();
}
if let Some(timeout) = matches.value_of("timeout") {
doh.timeout = Duration::from_secs(timeout.parse().unwrap());
inner_doh.timeout = Duration::from_secs(timeout.parse().unwrap());
}
}