mirror of
https://github.com/DNSCrypt/doh-server.git
synced 2025-04-03 04:57:37 +03:00
Get rid of the Mutex
This commit is contained in:
parent
bc925cc2d5
commit
82630f4a31
1 changed files with 34 additions and 24 deletions
58
src/main.rs
58
src/main.rs
|
@ -18,7 +18,7 @@ use hyper::service::Service;
|
|||
use hyper::{Body, Method, Request, Response, StatusCode};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::executor::current_thread;
|
||||
use tokio::net::{TcpListener, UdpSocket};
|
||||
|
@ -38,8 +38,8 @@ const MAX_TTL: u32 = 86400 * 7;
|
|||
const MIN_TTL: u32 = 1;
|
||||
const ERR_TTL: u32 = 1;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct DoH {
|
||||
#[derive(Debug)]
|
||||
struct InnerDoH {
|
||||
listen_address: SocketAddr,
|
||||
local_bind_address: SocketAddr,
|
||||
server_address: SocketAddr,
|
||||
|
@ -50,6 +50,11 @@ struct DoH {
|
|||
clients_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct DoH {
|
||||
inner: Arc<InnerDoH>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Error {
|
||||
Timeout,
|
||||
|
@ -72,10 +77,11 @@ impl Service for DoH {
|
|||
type Future = Box<Future<Item = Response<Body>, Error = Self::Error> + Send>;
|
||||
|
||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
let inner = &self.inner;
|
||||
{
|
||||
let count = self.clients_count.fetch_add(1, Ordering::Relaxed);
|
||||
if count > self.max_clients {
|
||||
self.clients_count.fetch_sub(1, Ordering::Relaxed);
|
||||
let count = inner.clients_count.fetch_add(1, Ordering::Relaxed);
|
||||
if count > inner.max_clients {
|
||||
inner.clients_count.fetch_sub(1, Ordering::Relaxed);
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::TOO_MANY_REQUESTS)
|
||||
.body(Body::empty())
|
||||
|
@ -83,7 +89,7 @@ impl Service for DoH {
|
|||
return Box::new(future::ok(response));
|
||||
}
|
||||
}
|
||||
let clients_count_inner = self.clients_count.clone();
|
||||
let clients_count_inner = inner.clients_count.clone();
|
||||
let fut = self
|
||||
.handle_client(req)
|
||||
.then(move |fut| {
|
||||
|
@ -94,9 +100,9 @@ impl Service for DoH {
|
|||
eprintln!("server error: {}", err);
|
||||
err
|
||||
});
|
||||
let timed = self
|
||||
let timed = inner
|
||||
.timer_handle
|
||||
.deadline(fut.map_err(|_| {}), Instant::now() + self.timeout)
|
||||
.deadline(fut.map_err(|_| {}), Instant::now() + inner.timeout)
|
||||
.map_err(|_| Error::Timeout);
|
||||
Box::new(timed)
|
||||
}
|
||||
|
@ -107,7 +113,8 @@ impl DoH {
|
|||
&self,
|
||||
req: Request<Body>,
|
||||
) -> Box<Future<Item = Response<Body>, Error = Error> + Send> {
|
||||
if req.uri().path() != self.path {
|
||||
let inner = &self.inner;
|
||||
if req.uri().path() != inner.path {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::empty())
|
||||
|
@ -156,10 +163,11 @@ impl DoH {
|
|||
}
|
||||
|
||||
fn proxy(&self, query: Vec<u8>) -> Box<Future<Item = Response<Body>, Error = ()> + Send> {
|
||||
let socket = UdpSocket::bind(&self.local_bind_address).unwrap();
|
||||
let expected_server_address = self.server_address;
|
||||
let inner = &self.inner;
|
||||
let socket = UdpSocket::bind(&inner.local_bind_address).unwrap();
|
||||
let expected_server_address = inner.server_address;
|
||||
let fut = socket
|
||||
.send_dgram(query, &self.server_address)
|
||||
.send_dgram(query, &inner.server_address)
|
||||
.map_err(|_| ())
|
||||
.and_then(move |(socket, _)| {
|
||||
let packet = vec![0; MAX_DNS_RESPONSE_LEN];
|
||||
|
@ -219,7 +227,7 @@ impl DoH {
|
|||
}
|
||||
|
||||
fn main() {
|
||||
let mut doh = DoH {
|
||||
let mut inner_doh = InnerDoH {
|
||||
listen_address: LISTEN_ADDRESS.parse().unwrap(),
|
||||
local_bind_address: LOCAL_BIND_ADDRESS.parse().unwrap(),
|
||||
server_address: SERVER_ADDRESS.parse().unwrap(),
|
||||
|
@ -229,15 +237,17 @@ fn main() {
|
|||
clients_count: Arc::new(AtomicUsize::new(0)),
|
||||
timer_handle: Timer::default().handle(),
|
||||
};
|
||||
parse_opts(&mut doh);
|
||||
let listen_address = doh.listen_address;
|
||||
parse_opts(&mut inner_doh);
|
||||
let doh = DoH {
|
||||
inner: Arc::new(inner_doh),
|
||||
};
|
||||
let listen_address = doh.inner.listen_address;
|
||||
let listener = TcpListener::bind(&listen_address).unwrap();
|
||||
println!("Listening on http://{}", listen_address);
|
||||
let mut http = Http::new();
|
||||
http.keep_alive(false);
|
||||
let doh = Arc::new(Mutex::new(doh));
|
||||
let server = listener.incoming().for_each(move |io| {
|
||||
let service = doh.lock().unwrap().clone();
|
||||
let service = doh.clone();
|
||||
let conn = http.serve_connection(io, service).map_err(|_| {});
|
||||
current_thread::spawn(conn);
|
||||
Ok(())
|
||||
|
@ -245,7 +255,7 @@ fn main() {
|
|||
current_thread::block_on_all(server).unwrap();
|
||||
}
|
||||
|
||||
fn parse_opts(doh: &mut DoH) {
|
||||
fn parse_opts(inner_doh: &mut InnerDoH) {
|
||||
let max_clients = MAX_CLIENTS.to_string();
|
||||
let timeout_sec = TIMEOUT_SEC.to_string();
|
||||
let matches = App::new("doh-proxy")
|
||||
|
@ -300,18 +310,18 @@ fn parse_opts(doh: &mut DoH) {
|
|||
)
|
||||
.get_matches();
|
||||
if let Some(listen_address) = matches.value_of("listen_address") {
|
||||
doh.listen_address = listen_address.parse().unwrap();
|
||||
inner_doh.listen_address = listen_address.parse().unwrap();
|
||||
}
|
||||
if let Some(server_address) = matches.value_of("server_address") {
|
||||
doh.server_address = server_address.parse().unwrap();
|
||||
inner_doh.server_address = server_address.parse().unwrap();
|
||||
}
|
||||
if let Some(local_bind_address) = matches.value_of("local_bind_address") {
|
||||
doh.local_bind_address = local_bind_address.parse().unwrap();
|
||||
inner_doh.local_bind_address = local_bind_address.parse().unwrap();
|
||||
}
|
||||
if let Some(max_clients) = matches.value_of("max_clients") {
|
||||
doh.max_clients = max_clients.parse().unwrap();
|
||||
inner_doh.max_clients = max_clients.parse().unwrap();
|
||||
}
|
||||
if let Some(timeout) = matches.value_of("timeout") {
|
||||
doh.timeout = Duration::from_secs(timeout.parse().unwrap());
|
||||
inner_doh.timeout = Duration::from_secs(timeout.parse().unwrap());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue