diff --git a/Cargo.toml b/Cargo.toml index f7c9dab..74e73c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,14 +16,15 @@ default = [] tls = ["native-tls", "tokio-tls"] [dependencies] +anyhow = "1.0" base64 = "0.11" clap = "2.33.0" -futures = "0.1.29" -hyper = "0.12.35" +futures = "0.3" +hyper = "0.13" jemallocator = "0" native-tls = { version = "0.2.3", optional = true } -tokio = "0.1.22" -tokio-tls = { version = "0.2.1", optional = true } +tokio = { version = "0.2", features = ["full"] } +tokio-tls = { version = "0.3", optional = true } [package.metadata.deb] extended-description = """\ diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..b02211a --- /dev/null +++ b/src/config.rs @@ -0,0 +1,162 @@ +#[cfg(feature = "tls")] +use std::path::PathBuf; + +use super::*; + +pub fn parse_opts(globals: &mut Globals) { + use crate::utils::{verify_remote_server, verify_sock_addr}; + + let max_clients = MAX_CLIENTS.to_string(); + let timeout_sec = TIMEOUT_SEC.to_string(); + let min_ttl = MIN_TTL.to_string(); + let max_ttl = MAX_TTL.to_string(); + let err_ttl = ERR_TTL.to_string(); + + let _ = include_str!("../Cargo.toml"); + let options = app_from_crate!() + .arg( + Arg::with_name("listen_address") + .short("l") + .long("listen-address") + .takes_value(true) + .default_value(LISTEN_ADDRESS) + .validator(verify_sock_addr) + .help("Address to listen to"), + ) + .arg( + Arg::with_name("server_address") + .short("u") + .long("server-address") + .takes_value(true) + .default_value(SERVER_ADDRESS) + .validator(verify_remote_server) + .help("Address to connect to"), + ) + .arg( + Arg::with_name("local_bind_address") + .short("b") + .long("local-bind-address") + .takes_value(true) + .validator(verify_sock_addr) + .help("Address to connect from"), + ) + .arg( + Arg::with_name("path") + .short("p") + .long("path") + .takes_value(true) + .default_value(PATH) + .help("URI path"), + ) + .arg( + Arg::with_name("max_clients") + .short("c") + .long("max-clients") + .takes_value(true) + .default_value(&max_clients) + .help("Maximum number of simultaneous clients"), + ) + .arg( + Arg::with_name("timeout") + .short("t") + .long("timeout") + .takes_value(true) + .default_value(&timeout_sec) + .help("Timeout, in seconds"), + ) + .arg( + Arg::with_name("min_ttl") + .short("T") + .long("min-ttl") + .takes_value(true) + .default_value(&min_ttl) + .help("Minimum TTL, in seconds"), + ) + .arg( + Arg::with_name("max_ttl") + .short("X") + .long("max-ttl") + .takes_value(true) + .default_value(&max_ttl) + .help("Maximum TTL, in seconds"), + ) + .arg( + Arg::with_name("err_ttl") + .short("E") + .long("err-ttl") + .takes_value(true) + .default_value(&err_ttl) + .help("TTL for errors, in seconds"), + ) + .arg( + Arg::with_name("disable_keepalive") + .short("K") + .long("disable-keepalive") + .help("Disable keepalive"), + ) + .arg( + Arg::with_name("disable_post") + .short("P") + .long("disable-post") + .help("Disable POST queries"), + ); + + #[cfg(feature = "tls")] + let options = options + .arg( + Arg::with_name("tls_cert_path") + .short("i") + .long("tls-cert-path") + .takes_value(true) + .help("Path to a PKCS12-encoded identity (only required for built-in TLS)"), + ) + .arg( + Arg::with_name("tls_cert_password") + .short("I") + .long("tls-cert-password") + .takes_value(true) + .help("Password for the PKCS12-encoded identity (only required for built-in TLS)"), + ); + + let matches = options.get_matches(); + globals.listen_address = matches.value_of("listen_address").unwrap().parse().unwrap(); + + globals.server_address = matches + .value_of("server_address") + .unwrap() + .to_socket_addrs() + .unwrap() + .next() + .unwrap(); + globals.local_bind_address = match matches.value_of("local_bind_address") { + Some(address) => address.parse().unwrap(), + None => match globals.server_address { + SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), + SocketAddr::V6(s) => SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::UNSPECIFIED, + 0, + s.flowinfo(), + s.scope_id(), + )), + }, + }; + globals.path = matches.value_of("path").unwrap().to_string(); + if !globals.path.starts_with('/') { + globals.path = format!("/{}", globals.path); + } + globals.max_clients = matches.value_of("max_clients").unwrap().parse().unwrap(); + globals.timeout = Duration::from_secs(matches.value_of("timeout").unwrap().parse().unwrap()); + globals.min_ttl = matches.value_of("min_ttl").unwrap().parse().unwrap(); + globals.max_ttl = matches.value_of("max_ttl").unwrap().parse().unwrap(); + globals.err_ttl = matches.value_of("err_ttl").unwrap().parse().unwrap(); + globals.keepalive = !matches.is_present("disable_keepalive"); + globals.disable_post = matches.is_present("disable_post"); + + #[cfg(feature = "tls")] + { + globals.tls_cert_path = matches.value_of("tls_cert_path").map(PathBuf::from); + globals.tls_cert_password = matches + .value_of("tls_cert_password") + .map(ToString::to_string); + } +} diff --git a/src/constants.rs b/src/constants.rs new file mode 100644 index 0000000..2f23f3a --- /dev/null +++ b/src/constants.rs @@ -0,0 +1,13 @@ +pub const BLOCK_SIZE: usize = 128; +pub const DNS_QUERY_PARAM: &str = "dns"; +pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000"; +pub const MAX_CLIENTS: usize = 512; +pub const MAX_DNS_QUESTION_LEN: usize = 512; +pub const MAX_DNS_RESPONSE_LEN: usize = 4096; +pub const MIN_DNS_PACKET_LEN: usize = 17; +pub const PATH: &str = "/dns-query"; +pub const SERVER_ADDRESS: &str = "9.9.9.9:53"; +pub const TIMEOUT_SEC: u64 = 10; +pub const MAX_TTL: u32 = 86400 * 7; +pub const MIN_TTL: u32 = 10; +pub const ERR_TTL: u32 = 2; diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..1eccc52 --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,47 @@ +pub use anyhow::{anyhow, bail, ensure, Error}; + +use hyper::StatusCode; +use std::io; + +#[allow(dead_code)] +#[derive(Debug)] +pub enum DoHError { + Incomplete, + InvalidData, + TooLarge, + UpstreamIssue, + Hyper(hyper::Error), + Io(io::Error), +} + +impl std::fmt::Display for DoHError { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + std::fmt::Debug::fmt(self, fmt) + } +} + +impl std::error::Error for DoHError { + fn description(&self) -> &str { + match *self { + DoHError::Incomplete => "Incomplete", + DoHError::InvalidData => "Invalid data", + DoHError::TooLarge => "Too large", + DoHError::UpstreamIssue => "Upstream error", + DoHError::Hyper(_) => self.description(), + DoHError::Io(_) => self.description(), + } + } +} + +impl From for StatusCode { + fn from(e: DoHError) -> StatusCode { + match e { + DoHError::Incomplete => StatusCode::UNPROCESSABLE_ENTITY, + DoHError::InvalidData => StatusCode::BAD_REQUEST, + DoHError::TooLarge => StatusCode::PAYLOAD_TOO_LARGE, + DoHError::UpstreamIssue => StatusCode::BAD_GATEWAY, + DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE, + DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} diff --git a/src/globals.rs b/src/globals.rs new file mode 100644 index 0000000..3877bdd --- /dev/null +++ b/src/globals.rs @@ -0,0 +1,47 @@ +use std::net::SocketAddr; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +#[cfg(feature = "tls")] +use std::path::PathBuf; + +#[derive(Debug)] +pub struct Globals { + #[cfg(feature = "tls")] + pub tls_cert_path: Option, + + #[cfg(feature = "tls")] + pub tls_cert_password: Option, + + pub listen_address: SocketAddr, + pub local_bind_address: SocketAddr, + pub server_address: SocketAddr, + pub path: String, + pub max_clients: usize, + pub timeout: Duration, + pub clients_count: ClientsCount, + pub min_ttl: u32, + pub max_ttl: u32, + pub err_ttl: u32, + pub keepalive: bool, + pub disable_post: bool, +} + +#[derive(Debug, Clone, Default)] +pub struct ClientsCount(Arc); + +impl ClientsCount { + pub fn increment(&self) -> usize { + self.0.fetch_add(1, Ordering::Relaxed) + } + + pub fn decrement(&self) -> usize { + let mut count; + while { + count = self.0.load(Ordering::Relaxed); + count > 0 && self.0.compare_and_swap(count, count - 1, Ordering::Relaxed) != count + } {} + count + } +} diff --git a/src/main.rs b/src/main.rs index c14759b..06f6bfb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,168 +4,157 @@ static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; #[macro_use] extern crate clap; +mod config; +mod constants; mod dns; +mod errors; +mod globals; mod utils; -use base64; +use crate::config::*; +use crate::constants::*; +use crate::errors::*; +use crate::globals::*; + use clap::Arg; use futures::future; use futures::prelude::*; -use futures::stream::Stream; -use hyper; +use futures::task::{Context, Poll}; +use hyper::http; use hyper::server::conn::Http; -use hyper::service::Service; use hyper::{Body, Method, Request, Response, StatusCode}; -use std::io; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; +use tokio::net::{TcpListener, UdpSocket}; #[cfg(feature = "tls")] use native_tls::{self, Identity}; - #[cfg(feature = "tls")] use std::fs::File; - +#[cfg(feature = "tls")] +use std::io; #[cfg(feature = "tls")] use std::io::Read; - #[cfg(feature = "tls")] -use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::time::Duration; -use tokio; -use tokio::net::{TcpListener, UdpSocket}; -use tokio::prelude::{AsyncRead, AsyncWrite, FutureExt}; - +use std::path::Path; #[cfg(feature = "tls")] use tokio_tls::TlsAcceptor; -const BLOCK_SIZE: usize = 128; -const DNS_QUERY_PARAM: &str = "dns"; -const LISTEN_ADDRESS: &str = "127.0.0.1:3000"; -const MAX_CLIENTS: usize = 512; -const MAX_DNS_QUESTION_LEN: usize = 512; -const MAX_DNS_RESPONSE_LEN: usize = 4096; -const MIN_DNS_PACKET_LEN: usize = 17; -const PATH: &str = "/dns-query"; -const SERVER_ADDRESS: &str = "9.9.9.9:53"; -const TIMEOUT_SEC: u64 = 10; -const MAX_TTL: u32 = 86400 * 7; -const MIN_TTL: u32 = 10; -const ERR_TTL: u32 = 2; - -#[derive(Debug, Clone, Default)] -struct ClientsCount(Arc); - -impl ClientsCount { - fn current(&self) -> usize { - self.0.load(Ordering::Relaxed) - } - - fn increment(&self) -> usize { - self.0.fetch_add(1, Ordering::Relaxed) - } - - fn decrement(&self) -> usize { - let mut count; - while { - count = self.0.load(Ordering::Relaxed); - count > 0 && self.0.compare_and_swap(count, count - 1, Ordering::Relaxed) != count - } {} - count - } -} - -#[derive(Debug)] -struct InnerDoH { - #[cfg(feature = "tls")] - tls_cert_path: Option, - - #[cfg(feature = "tls")] - tls_cert_password: Option, - - listen_address: SocketAddr, - local_bind_address: SocketAddr, - server_address: SocketAddr, - path: String, - max_clients: usize, - timeout: Duration, - clients_count: ClientsCount, - min_ttl: u32, - max_ttl: u32, - err_ttl: u32, - keepalive: bool, - disable_post: bool, -} - #[derive(Clone, Debug)] struct DoH { - inner: Arc, + globals: Arc, } -#[allow(dead_code)] -#[derive(Debug)] -enum Error { - Incomplete, - InvalidData, - TooLarge, - UpstreamIssue, - Hyper(hyper::Error), - Io(io::Error), +#[cfg(feature = "tls")] +fn create_tls_acceptor

(path: P, password: &str) -> io::Result +where + P: AsRef, +{ + let identity_bin = { + let mut fp = File::open(path)?; + let mut identity_bin = vec![]; + fp.read_to_end(&mut identity_bin)?; + identity_bin + }; + let identity = Identity::from_pkcs12(&identity_bin, password).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unusable PKCS12-encoded identity. The encoding and/or the password may be wrong", + ) + })?; + let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to use the provided PKCS12-encoded identity", + ) + })?; + Ok(TlsAcceptor::from(native_acceptor)) } -impl std::fmt::Display for Error { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - std::fmt::Debug::fmt(self, fmt) +impl hyper::service::Service> for DoH { + type Response = Response; + type Error = http::Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } -} - -impl std::error::Error for Error { - fn description(&self) -> &str { - match *self { - Error::Incomplete => "Incomplete", - Error::InvalidData => "Invalid data", - Error::TooLarge => "Too large", - Error::UpstreamIssue => "Upstream error", - Error::Hyper(_) => self.description(), - Error::Io(_) => self.description(), - } - } -} - -impl From for StatusCode { - fn from(e: Error) -> StatusCode { - match e { - Error::Incomplete => StatusCode::UNPROCESSABLE_ENTITY, - Error::InvalidData => StatusCode::BAD_REQUEST, - Error::TooLarge => StatusCode::PAYLOAD_TOO_LARGE, - Error::UpstreamIssue => StatusCode::BAD_GATEWAY, - Error::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE, - Error::Io(_) => StatusCode::INTERNAL_SERVER_ERROR, - } - } -} - -impl Service for DoH { - type ReqBody = Body; - type ResBody = Body; - type Error = Error; - type Future = Box, Error = Self::Error> + Send>; fn call(&mut self, req: Request) -> Self::Future { - let inner = &self.inner; - { - let count = inner.clients_count.current(); - if count >= inner.max_clients { + let globals = &self.globals; + if req.uri().path() != globals.path { + let response = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap(); + return Box::pin(async { Ok(response) }); + } + let self_inner = self.clone(); + match *req.method() { + Method::POST => { + if globals.disable_post { + let response = Response::builder() + .status(StatusCode::METHOD_NOT_ALLOWED) + .body(Body::empty()) + .unwrap(); + return Box::pin(async { Ok(response) }); + } + if let Err(response) = Self::check_content_type(&req) { + return Box::pin(async { Ok(response) }); + } + let fut = async move { + match self_inner.read_body_and_proxy(req.into_body()).await { + Err(e) => Response::builder() + .status(StatusCode::from(e)) + .body(Body::empty()), + Ok(res) => Ok(res), + } + }; + Box::pin(fut) + } + Method::GET => { + let query = req.uri().query().unwrap_or(""); + let mut question_str = None; + for parts in query.split('&') { + let mut kv = parts.split('='); + if let Some(k) = kv.next() { + if k == DNS_QUERY_PARAM { + question_str = kv.next(); + } + } + } + let question = match question_str.and_then(|question_str| { + base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok() + }) { + Some(question) => question, + _ => { + let response = Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::empty()) + .unwrap(); + return Box::pin(future::ok(response)); + } + }; + let fut = async move { + match self_inner.proxy(question).await { + Err(e) => Response::builder() + .status(StatusCode::from(e)) + .body(Body::empty()), + Ok(res) => Ok(res), + } + }; + Box::pin(fut) + } + _ => { let response = Response::builder() - .status(StatusCode::TOO_MANY_REQUESTS) + .status(StatusCode::METHOD_NOT_ALLOWED) .body(Body::empty()) .unwrap(); - return Box::new(future::ok(response)); + Box::pin(async { Ok(response) }) } } - let fut = self.handle_client(req); - Box::new(fut) } } @@ -202,241 +191,120 @@ impl DoH { Ok(()) } - fn handle_client( - &self, - req: Request, - ) -> Box, Error = Error> + Send> { - let inner = &self.inner; - if req.uri().path() != inner.path { - let response = Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::empty()) - .unwrap(); - return Box::new(future::ok(response)); - } - match *req.method() { - Method::POST => { - if self.inner.disable_post { - let response = Response::builder() - .status(StatusCode::METHOD_NOT_ALLOWED) - .body(Body::empty()) - .unwrap(); - return Box::new(future::ok(response)); - } - if let Err(response) = Self::check_content_type(&req) { - return Box::new(future::ok(response)); - } - let fut = self.read_body_and_proxy(req.into_body()).or_else(|e| { - let response = Response::builder() - .status(StatusCode::from(e)) - .body(Body::empty()) - .unwrap(); - future::ok(response) - }); - Box::new(fut) - } - Method::GET => { - let query = req.uri().query().unwrap_or(""); - let mut question_str = None; - for parts in query.split('&') { - let mut kv = parts.split('='); - if let Some(k) = kv.next() { - if k == DNS_QUERY_PARAM { - question_str = kv.next(); - } - } - } - let question = match question_str.and_then(|question_str| { - base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok() - }) { - Some(question) => question, - _ => { - let response = Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::empty()) - .unwrap(); - return Box::new(future::ok(response)); - } - }; - let fut = self.proxy(question).or_else(|e| { - let response = Response::builder() - .status(StatusCode::from(e)) - .body(Body::empty()) - .unwrap(); - future::ok(response) - }); - Box::new(fut) - } - _ => { - let response = Response::builder() - .status(StatusCode::METHOD_NOT_ALLOWED) - .body(Body::empty()) - .unwrap(); - Box::new(future::ok(response)) - } + async fn read_body_and_proxy(&self, mut body: Body) -> Result, DoHError> { + let mut sum_size = 0; + let mut query = vec![]; + while let Some(chunk) = body.next().await { + let chunk = chunk.map_err(|_| DoHError::TooLarge)?; + sum_size += chunk.len(); + if sum_size >= MAX_DNS_QUESTION_LEN { + return Err(DoHError::TooLarge); + } + query.extend(chunk); } + let response = self.proxy(query).await?; + Ok(response) } - fn proxy( - &self, - mut query: Vec, - ) -> Box, Error = Error> + Send> { + async fn proxy(&self, mut query: Vec) -> Result, DoHError> { if query.len() < MIN_DNS_PACKET_LEN { - return Box::new(future::err(Error::Incomplete)); + return Err(DoHError::Incomplete); } let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as u16); - let inner = &self.inner; - let socket = UdpSocket::bind(&inner.local_bind_address).unwrap(); - let expected_server_address = inner.server_address; - let (min_ttl, max_ttl, err_ttl) = (inner.min_ttl, inner.max_ttl, inner.err_ttl); - let fut = socket - .send_dgram(query, &inner.server_address) - .map_err(Error::Io) - .and_then(move |(socket, _)| { - let packet = vec![0; MAX_DNS_RESPONSE_LEN]; - socket.recv_dgram(packet).map_err(Error::Io) - }) - .and_then(move |(_socket, mut packet, len, response_server_address)| { - if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address { - return future::err(Error::UpstreamIssue); - } - packet.truncate(len); - let ttl = if dns::is_recoverable_error(&packet) { - err_ttl - } else { - match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) { - Err(_) => return future::err(Error::UpstreamIssue), - Ok(ttl) => ttl, - } + let globals = &self.globals; + let mut socket = UdpSocket::bind(&globals.local_bind_address) + .await + .map_err(DoHError::Io)?; + let expected_server_address = globals.server_address; + let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl); + socket + .send_to(&query, &globals.server_address) + .map_err(DoHError::Io) + .await?; + let mut packet = vec![0; MAX_DNS_RESPONSE_LEN]; + let (len, response_server_address) = + socket.recv_from(&mut packet).map_err(DoHError::Io).await?; + if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address { + return Err(DoHError::UpstreamIssue); + } + packet.truncate(len); + let ttl = if dns::is_recoverable_error(&packet) { + err_ttl + } else { + match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) { + Err(_) => return Err(DoHError::UpstreamIssue), + Ok(ttl) => ttl, + } + }; + let packet_len = packet.len(); + let response = Response::builder() + .header(hyper::header::CONTENT_LENGTH, packet_len) + .header(hyper::header::CONTENT_TYPE, "application/dns-message") + .header("X-Padding", utils::padding_string(packet_len, BLOCK_SIZE)) + .header( + hyper::header::CACHE_CONTROL, + format!("max-age={}", ttl).as_str(), + ) + .body(Body::from(packet)) + .unwrap(); + Ok(response) + } + + async fn entrypoint(self) -> Result<(), Error> { + let listen_address = self.globals.listen_address; + let mut listener = TcpListener::bind(&listen_address).await?; + let path = &self.globals.path; + + #[cfg(feature = "tls")] + let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) { + (Some(tls_cert_path), Some(tls_cert_password)) => { + println!("Listening on https://{}{}", listen_address, path); + Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap()) + } + _ => { + println!("Listening on http://{}{}", listen_address, path); + None + } + }; + #[cfg(not(feature = "tls"))] + println!("Listening on http://{}{}", listen_address, path); + + let mut server = Http::new(); + server.keep_alive(self.globals.keepalive); + let listener_service = async { + while let Some(stream) = listener.incoming().next().await { + let stream = match stream { + Ok(stream) => stream, + Err(_) => continue, }; - let packet_len = packet.len(); - let response = Response::builder() - .header(hyper::header::CONTENT_LENGTH, packet_len) - .header(hyper::header::CONTENT_TYPE, "application/dns-message") - .header("X-Padding", utils::padding_string(packet_len, BLOCK_SIZE)) - .header( - hyper::header::CACHE_CONTROL, - format!("max-age={}", ttl).as_str(), + let clients_count = self.globals.clients_count.clone(); + if clients_count.increment() > self.globals.max_clients { + clients_count.decrement(); + continue; + } + let self_inner = self.clone(); + let server_inner = server.clone(); + tokio::spawn(async move { + tokio::time::timeout( + self_inner.globals.timeout, + server_inner.serve_connection(stream, self_inner), ) - .body(Body::from(packet)) - .unwrap(); - future::ok(response) - }); - Box::new(fut) - } - - fn read_body_and_proxy( - &self, - body: Body, - ) -> Box, Error = Error> + Send> { - let mut sum_size = 0; - let inner = self.clone(); - let fut = body - .map_err(Error::Hyper) - .and_then(move |chunk| { - sum_size += chunk.len(); - if sum_size > MAX_DNS_QUESTION_LEN { - Err(Error::TooLarge) - } else { - Ok(chunk) - } - }) - .concat2() - .map(move |chunk| chunk.to_vec()) - .and_then(move |query| inner.proxy(query)); - Box::new(fut) - } -} - -#[cfg(feature = "tls")] -fn create_tls_acceptor

(path: P, password: &str) -> io::Result -where - P: AsRef, -{ - let identity_bin = { - let mut fp = File::open(path)?; - let mut identity_bin = vec![]; - fp.read_to_end(&mut identity_bin)?; - identity_bin - }; - let identity = Identity::from_pkcs12(&identity_bin, password).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unusable PKCS12-encoded identity. The encoding and/or the password may be wrong", - ) - })?; - let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unable to use the provided PKCS12-encoded identity", - ) - })?; - Ok(TlsAcceptor::from(native_acceptor)) -} - -fn client_serve( - clients_count: ClientsCount, - stream: I, - http: Http, - service: DoH, - timeout: Duration, -) where - I: AsyncRead + AsyncWrite + Send + 'static, -{ - let clients_count_inner = clients_count.clone(); - let conn = http - .serve_connection(stream, service) - .timeout(timeout) - .map_err(|_| {}) - .then(move |fut| { - clients_count_inner.decrement(); - fut - }); - clients_count.increment(); - tokio::spawn(conn); -} - -#[cfg(feature = "tls")] -fn start_with_tls( - tls_acceptor: TlsAcceptor, - listener: TcpListener, - doh: DoH, - http: Http, - timeout: Duration, -) { - let server = listener.incoming().for_each(move |io| { - let service = doh.clone(); - let http = http.clone(); - let clients_count = doh.inner.clients_count.clone(); - tls_acceptor - .accept(io) - .timeout(timeout) - .then(move |stream| { - if let Ok(stream) = stream { - client_serve(clients_count, stream, http, service, timeout); - } - Ok(()) - }) - }); - tokio::run(server.map_err(|_| {})); -} - -fn start_without_tls(listener: TcpListener, doh: DoH, http: Http, timeout: Duration) { - let server = listener.incoming().for_each(move |stream| { - let service = doh.clone(); - let http = http.clone(); - let clients_count = doh.inner.clients_count.clone(); - client_serve(clients_count, stream, http, service, timeout); + .await + .ok(); + clients_count.decrement(); + }); + } + Ok(()) as Result<(), Error> + }; + listener_service.await?; Ok(()) - }); - tokio::run(server.map_err(|_| {})); + } } fn main() { - let mut inner_doh = InnerDoH { + let mut globals = Globals { #[cfg(feature = "tls")] tls_cert_path: None, - #[cfg(feature = "tls")] tls_cert_password: None, @@ -453,196 +321,14 @@ fn main() { keepalive: true, disable_post: false, }; - parse_opts(&mut inner_doh); - let timeout = inner_doh.timeout; - - #[cfg(feature = "tls")] - let path = inner_doh.path.clone(); + parse_opts(&mut globals); let doh = DoH { - inner: Arc::new(inner_doh), + globals: Arc::new(globals), }; - let listen_address = doh.inner.listen_address; - let listener = TcpListener::bind(&listen_address).unwrap(); - - #[cfg(feature = "tls")] - let tls_acceptor = match (&doh.inner.tls_cert_path, &doh.inner.tls_cert_password) { - (Some(tls_cert_path), Some(tls_cert_password)) => { - println!("Listening on https://{}{}", listen_address, path); - Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap()) - } - _ => { - println!("Listening on http://{}{}", listen_address, path); - None - } - }; - - let mut http = Http::new(); - http.keep_alive(doh.inner.keepalive); - - #[cfg(feature = "tls")] - { - if let Some(tls_acceptor) = tls_acceptor { - start_with_tls(tls_acceptor, listener, doh, http, timeout); - return; - } - } - start_without_tls(listener, doh, http, timeout); -} - -fn parse_opts(inner_doh: &mut InnerDoH) { - use crate::utils::{verify_remote_server, verify_sock_addr}; - - let max_clients = MAX_CLIENTS.to_string(); - let timeout_sec = TIMEOUT_SEC.to_string(); - let min_ttl = MIN_TTL.to_string(); - let max_ttl = MAX_TTL.to_string(); - let err_ttl = ERR_TTL.to_string(); - - let _ = include_str!("../Cargo.toml"); - let options = app_from_crate!() - .arg( - Arg::with_name("listen_address") - .short("l") - .long("listen-address") - .takes_value(true) - .default_value(LISTEN_ADDRESS) - .validator(verify_sock_addr) - .help("Address to listen to"), - ) - .arg( - Arg::with_name("server_address") - .short("u") - .long("server-address") - .takes_value(true) - .default_value(SERVER_ADDRESS) - .validator(verify_remote_server) - .help("Address to connect to"), - ) - .arg( - Arg::with_name("local_bind_address") - .short("b") - .long("local-bind-address") - .takes_value(true) - .validator(verify_sock_addr) - .help("Address to connect from"), - ) - .arg( - Arg::with_name("path") - .short("p") - .long("path") - .takes_value(true) - .default_value(PATH) - .help("URI path"), - ) - .arg( - Arg::with_name("max_clients") - .short("c") - .long("max-clients") - .takes_value(true) - .default_value(&max_clients) - .help("Maximum number of simultaneous clients"), - ) - .arg( - Arg::with_name("timeout") - .short("t") - .long("timeout") - .takes_value(true) - .default_value(&timeout_sec) - .help("Timeout, in seconds"), - ) - .arg( - Arg::with_name("min_ttl") - .short("T") - .long("min-ttl") - .takes_value(true) - .default_value(&min_ttl) - .help("Minimum TTL, in seconds"), - ) - .arg( - Arg::with_name("max_ttl") - .short("X") - .long("max-ttl") - .takes_value(true) - .default_value(&max_ttl) - .help("Maximum TTL, in seconds"), - ) - .arg( - Arg::with_name("err_ttl") - .short("E") - .long("err-ttl") - .takes_value(true) - .default_value(&err_ttl) - .help("TTL for errors, in seconds"), - ) - .arg( - Arg::with_name("disable_keepalive") - .short("K") - .long("disable-keepalive") - .help("Disable keepalive"), - ) - .arg( - Arg::with_name("disable_post") - .short("P") - .long("disable-post") - .help("Disable POST queries"), - ); - - #[cfg(feature = "tls")] - let options = options - .arg( - Arg::with_name("tls_cert_path") - .short("i") - .long("tls-cert-path") - .takes_value(true) - .help("Path to a PKCS12-encoded identity (only required for built-in TLS)"), - ) - .arg( - Arg::with_name("tls_cert_password") - .short("I") - .long("tls-cert-password") - .takes_value(true) - .help("Password for the PKCS12-encoded identity (only required for built-in TLS)"), - ); - - let matches = options.get_matches(); - inner_doh.listen_address = matches.value_of("listen_address").unwrap().parse().unwrap(); - - inner_doh.server_address = matches - .value_of("server_address") - .unwrap() - .to_socket_addrs() - .unwrap() - .next() - .unwrap(); - inner_doh.local_bind_address = match matches.value_of("local_bind_address") { - Some(address) => address.parse().unwrap(), - None => match inner_doh.server_address { - SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), - SocketAddr::V6(s) => SocketAddr::V6(SocketAddrV6::new( - Ipv6Addr::UNSPECIFIED, - 0, - s.flowinfo(), - s.scope_id(), - )), - }, - }; - inner_doh.path = matches.value_of("path").unwrap().to_string(); - if !inner_doh.path.starts_with('/') { - inner_doh.path = format!("/{}", inner_doh.path); - } - inner_doh.max_clients = matches.value_of("max_clients").unwrap().parse().unwrap(); - inner_doh.timeout = Duration::from_secs(matches.value_of("timeout").unwrap().parse().unwrap()); - inner_doh.min_ttl = matches.value_of("min_ttl").unwrap().parse().unwrap(); - inner_doh.max_ttl = matches.value_of("max_ttl").unwrap().parse().unwrap(); - inner_doh.err_ttl = matches.value_of("err_ttl").unwrap().parse().unwrap(); - inner_doh.keepalive = !matches.is_present("disable_keepalive"); - inner_doh.disable_post = matches.is_present("disable_post"); - - #[cfg(feature = "tls")] - { - inner_doh.tls_cert_path = matches.value_of("tls_cert_path").map(PathBuf::from); - inner_doh.tls_cert_password = matches - .value_of("tls_cert_password") - .map(ToString::to_string); - } + let mut runtime_builder = tokio::runtime::Builder::new(); + runtime_builder.enable_all(); + runtime_builder.threaded_scheduler(); + runtime_builder.thread_name("doh-proxy"); + let mut runtime = runtime_builder.build().unwrap(); + runtime.block_on(doh.entrypoint()).unwrap(); }