From 8098d3938ddfa77df280085d482f14fc19238327 Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Tue, 6 Feb 2018 11:32:21 +0100 Subject: [PATCH] Add command-line parser --- Cargo.toml | 2 +- src/main.rs | 115 +++++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 97 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5800a0d..10f63be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "rust-doh" +name = "doh-proxy" version = "0.1.0" authors = ["Frank Denis "] diff --git a/src/main.rs b/src/main.rs index 6a991fc..e66c927 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,36 +3,44 @@ #![cfg_attr(feature = "clippy", feature(plugin))] #![cfg_attr(feature = "clippy", plugin(clippy))] +extern crate clap; extern crate futures_await as futures; extern crate hyper; extern crate tokio_core; extern crate tokio_io; extern crate tokio_timer; -use tokio_core::reactor::Handle; -use std::time::Duration; +use clap::{App, Arg}; +use futures::future; +use futures::prelude::*; use hyper::{Body, Method, StatusCode}; use hyper::header::{ContentLength, ContentType}; use hyper::server::{Http, Request, Response, Service}; -use futures::prelude::*; -use futures::future; -use tokio_core::reactor::Core; -use tokio_core::net::UdpSocket; use std::cell::RefCell; use std::rc::Rc; +use std::net::SocketAddr; +use std::time::Duration; +use tokio_core::net::UdpSocket; +use tokio_core::reactor::Core; +use tokio_core::reactor::Handle; -const TIMEOUT_SEC: u64 = 10; -const LOCAL_ADDRESS: &str = "127.0.0.1:3000"; +const LISTEN_ADDRESS: &str = "127.0.0.1:3000"; const LOCAL_BIND_ADDRESS: &str = "0.0.0.0:0"; -const SERVER_ADDRESS: &str = "9.9.9.9:53"; -const MIN_DNS_PACKET_LEN: usize = 17; +const MAX_CLIENTS: u32 = 512; const MAX_DNS_QUESTION_LEN: usize = 512; const MAX_DNS_RESPONSE_LEN: usize = 4096; -const MAX_CLIENTS: u32 = 512; +const MIN_DNS_PACKET_LEN: usize = 17; +const SERVER_ADDRESS: &str = "9.9.9.9:53"; +const TIMEOUT_SEC: u64 = 10; #[derive(Clone, Debug)] struct DoH { handle: Handle, + listen_address: SocketAddr, + local_bind_address: SocketAddr, + server_address: SocketAddr, + max_clients: u32, + timeout: Duration, } impl Service for DoH { @@ -96,16 +104,22 @@ impl DoH { fn main() { let mut core = Core::new().unwrap(); let handle = core.handle(); - let addr = LOCAL_ADDRESS.parse().unwrap(); let handle_inner = handle.clone(); + let mut doh = DoH { + handle: handle_inner.clone(), + listen_address: LISTEN_ADDRESS.parse().unwrap(), + local_bind_address: LOCAL_BIND_ADDRESS.parse().unwrap(), + server_address: SERVER_ADDRESS.parse().unwrap(), + max_clients: MAX_CLIENTS, + timeout: Duration::from_secs(TIMEOUT_SEC), + }; + parse_opts(&mut doh); + let listen_address = doh.listen_address; + let doh_inner = doh.clone(); let server = Http::new() .keep_alive(false) .max_buf_size(MAX_DNS_QUESTION_LEN) - .serve_addr_handle(&addr, &handle, move || { - Ok(DoH { - handle: handle_inner.clone(), - }) - }) + .serve_addr_handle(&listen_address, &handle, move || Ok(doh_inner.clone())) .unwrap(); println!("Listening on http://{}", server.incoming_ref().local_addr()); let handle_inner = handle.clone(); @@ -114,7 +128,7 @@ fn main() { let fut = server.for_each(move |client_fut| { { let count = client_count.borrow_mut(); - if *count > MAX_CLIENTS { + if *count > doh.max_clients { return Ok(()); } (*count).saturating_add(1); @@ -126,10 +140,73 @@ fn main() { (*client_count_inner.borrow_mut()).saturating_sub(1); }) .map_err(|err| eprintln!("server error: {:?}", err)); - let timed = timers_inner.timeout(fut, Duration::from_secs(TIMEOUT_SEC)); + let timed = timers_inner.timeout(fut, doh.timeout); handle_inner.spawn(timed); Ok(()) }); handle.spawn(fut.map_err(|_| ())); core.run(futures::future::empty::<(), ()>()).unwrap(); } + +fn parse_opts(doh: &mut DoH) { + let max_clients = MAX_CLIENTS.to_string(); + let timeout_sec = TIMEOUT_SEC.to_string(); + let matches = App::new("doh-proxy") + .about("A DNS-over-HTTP server proxy") + .arg( + Arg::with_name("listen_address") + .short("l") + .long("listen_address") + .takes_value(true) + .default_value(LISTEN_ADDRESS) + .help("Address to listen to"), + ) + .arg( + Arg::with_name("server_address") + .short("u") + .long("server_address") + .takes_value(true) + .default_value(SERVER_ADDRESS) + .help("Address to connect to"), + ) + .arg( + Arg::with_name("local_bind_address") + .short("b") + .long("local_bind_address") + .takes_value(true) + .default_value(LOCAL_BIND_ADDRESS) + .help("Address to connect from"), + ) + .arg( + Arg::with_name("max_clients") + .short("c") + .long("max_clients") + .takes_value(true) + .default_value(&max_clients) + .help("Maximum number of simultaneous clients"), + ) + .arg( + Arg::with_name("timeout") + .short("t") + .long("timeout") + .takes_value(true) + .default_value(&timeout_sec) + .help("Timeout, in seconds"), + ) + .get_matches(); + if let Some(listen_address) = matches.value_of("listen_address") { + doh.listen_address = listen_address.parse().unwrap(); + } + if let Some(server_address) = matches.value_of("server_address") { + doh.server_address = server_address.parse().unwrap(); + } + if let Some(local_bind_address) = matches.value_of("local_bind_address") { + doh.local_bind_address = local_bind_address.parse().unwrap(); + } + if let Some(max_clients) = matches.value_of("max_clients") { + doh.max_clients = max_clients.parse().unwrap(); + } + if let Some(timeout) = matches.value_of("timeout") { + doh.timeout = Duration::from_secs(timeout.parse().unwrap()); + } +}