Prepare for tokio 0.2/hyper 0.13/async-await migration

This commit is contained in:
Frank Denis 2019-12-23 00:10:40 +01:00
parent 1cb4a11a7b
commit 1b850b2f41
6 changed files with 499 additions and 543 deletions

View file

@ -16,14 +16,15 @@ default = []
tls = ["native-tls", "tokio-tls"]
[dependencies]
anyhow = "1.0"
base64 = "0.11"
clap = "2.33.0"
futures = "0.1.29"
hyper = "0.12.35"
futures = "0.3"
hyper = "0.13"
jemallocator = "0"
native-tls = { version = "0.2.3", optional = true }
tokio = "0.1.22"
tokio-tls = { version = "0.2.1", optional = true }
tokio = { version = "0.2", features = ["full"] }
tokio-tls = { version = "0.3", optional = true }
[package.metadata.deb]
extended-description = """\

162
src/config.rs Normal file
View file

@ -0,0 +1,162 @@
#[cfg(feature = "tls")]
use std::path::PathBuf;
use super::*;
pub fn parse_opts(globals: &mut Globals) {
use crate::utils::{verify_remote_server, verify_sock_addr};
let max_clients = MAX_CLIENTS.to_string();
let timeout_sec = TIMEOUT_SEC.to_string();
let min_ttl = MIN_TTL.to_string();
let max_ttl = MAX_TTL.to_string();
let err_ttl = ERR_TTL.to_string();
let _ = include_str!("../Cargo.toml");
let options = app_from_crate!()
.arg(
Arg::with_name("listen_address")
.short("l")
.long("listen-address")
.takes_value(true)
.default_value(LISTEN_ADDRESS)
.validator(verify_sock_addr)
.help("Address to listen to"),
)
.arg(
Arg::with_name("server_address")
.short("u")
.long("server-address")
.takes_value(true)
.default_value(SERVER_ADDRESS)
.validator(verify_remote_server)
.help("Address to connect to"),
)
.arg(
Arg::with_name("local_bind_address")
.short("b")
.long("local-bind-address")
.takes_value(true)
.validator(verify_sock_addr)
.help("Address to connect from"),
)
.arg(
Arg::with_name("path")
.short("p")
.long("path")
.takes_value(true)
.default_value(PATH)
.help("URI path"),
)
.arg(
Arg::with_name("max_clients")
.short("c")
.long("max-clients")
.takes_value(true)
.default_value(&max_clients)
.help("Maximum number of simultaneous clients"),
)
.arg(
Arg::with_name("timeout")
.short("t")
.long("timeout")
.takes_value(true)
.default_value(&timeout_sec)
.help("Timeout, in seconds"),
)
.arg(
Arg::with_name("min_ttl")
.short("T")
.long("min-ttl")
.takes_value(true)
.default_value(&min_ttl)
.help("Minimum TTL, in seconds"),
)
.arg(
Arg::with_name("max_ttl")
.short("X")
.long("max-ttl")
.takes_value(true)
.default_value(&max_ttl)
.help("Maximum TTL, in seconds"),
)
.arg(
Arg::with_name("err_ttl")
.short("E")
.long("err-ttl")
.takes_value(true)
.default_value(&err_ttl)
.help("TTL for errors, in seconds"),
)
.arg(
Arg::with_name("disable_keepalive")
.short("K")
.long("disable-keepalive")
.help("Disable keepalive"),
)
.arg(
Arg::with_name("disable_post")
.short("P")
.long("disable-post")
.help("Disable POST queries"),
);
#[cfg(feature = "tls")]
let options = options
.arg(
Arg::with_name("tls_cert_path")
.short("i")
.long("tls-cert-path")
.takes_value(true)
.help("Path to a PKCS12-encoded identity (only required for built-in TLS)"),
)
.arg(
Arg::with_name("tls_cert_password")
.short("I")
.long("tls-cert-password")
.takes_value(true)
.help("Password for the PKCS12-encoded identity (only required for built-in TLS)"),
);
let matches = options.get_matches();
globals.listen_address = matches.value_of("listen_address").unwrap().parse().unwrap();
globals.server_address = matches
.value_of("server_address")
.unwrap()
.to_socket_addrs()
.unwrap()
.next()
.unwrap();
globals.local_bind_address = match matches.value_of("local_bind_address") {
Some(address) => address.parse().unwrap(),
None => match globals.server_address {
SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
SocketAddr::V6(s) => SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::UNSPECIFIED,
0,
s.flowinfo(),
s.scope_id(),
)),
},
};
globals.path = matches.value_of("path").unwrap().to_string();
if !globals.path.starts_with('/') {
globals.path = format!("/{}", globals.path);
}
globals.max_clients = matches.value_of("max_clients").unwrap().parse().unwrap();
globals.timeout = Duration::from_secs(matches.value_of("timeout").unwrap().parse().unwrap());
globals.min_ttl = matches.value_of("min_ttl").unwrap().parse().unwrap();
globals.max_ttl = matches.value_of("max_ttl").unwrap().parse().unwrap();
globals.err_ttl = matches.value_of("err_ttl").unwrap().parse().unwrap();
globals.keepalive = !matches.is_present("disable_keepalive");
globals.disable_post = matches.is_present("disable_post");
#[cfg(feature = "tls")]
{
globals.tls_cert_path = matches.value_of("tls_cert_path").map(PathBuf::from);
globals.tls_cert_password = matches
.value_of("tls_cert_password")
.map(ToString::to_string);
}
}

13
src/constants.rs Normal file
View file

@ -0,0 +1,13 @@
pub const BLOCK_SIZE: usize = 128;
pub const DNS_QUERY_PARAM: &str = "dns";
pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
pub const MAX_CLIENTS: usize = 512;
pub const MAX_DNS_QUESTION_LEN: usize = 512;
pub const MAX_DNS_RESPONSE_LEN: usize = 4096;
pub const MIN_DNS_PACKET_LEN: usize = 17;
pub const PATH: &str = "/dns-query";
pub const SERVER_ADDRESS: &str = "9.9.9.9:53";
pub const TIMEOUT_SEC: u64 = 10;
pub const MAX_TTL: u32 = 86400 * 7;
pub const MIN_TTL: u32 = 10;
pub const ERR_TTL: u32 = 2;

47
src/errors.rs Normal file
View file

@ -0,0 +1,47 @@
pub use anyhow::{anyhow, bail, ensure, Error};
use hyper::StatusCode;
use std::io;
#[allow(dead_code)]
#[derive(Debug)]
pub enum DoHError {
Incomplete,
InvalidData,
TooLarge,
UpstreamIssue,
Hyper(hyper::Error),
Io(io::Error),
}
impl std::fmt::Display for DoHError {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
std::fmt::Debug::fmt(self, fmt)
}
}
impl std::error::Error for DoHError {
fn description(&self) -> &str {
match *self {
DoHError::Incomplete => "Incomplete",
DoHError::InvalidData => "Invalid data",
DoHError::TooLarge => "Too large",
DoHError::UpstreamIssue => "Upstream error",
DoHError::Hyper(_) => self.description(),
DoHError::Io(_) => self.description(),
}
}
}
impl From<DoHError> for StatusCode {
fn from(e: DoHError) -> StatusCode {
match e {
DoHError::Incomplete => StatusCode::UNPROCESSABLE_ENTITY,
DoHError::InvalidData => StatusCode::BAD_REQUEST,
DoHError::TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
DoHError::UpstreamIssue => StatusCode::BAD_GATEWAY,
DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
}

47
src/globals.rs Normal file
View file

@ -0,0 +1,47 @@
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "tls")]
use std::path::PathBuf;
#[derive(Debug)]
pub struct Globals {
#[cfg(feature = "tls")]
pub tls_cert_path: Option<PathBuf>,
#[cfg(feature = "tls")]
pub tls_cert_password: Option<String>,
pub listen_address: SocketAddr,
pub local_bind_address: SocketAddr,
pub server_address: SocketAddr,
pub path: String,
pub max_clients: usize,
pub timeout: Duration,
pub clients_count: ClientsCount,
pub min_ttl: u32,
pub max_ttl: u32,
pub err_ttl: u32,
pub keepalive: bool,
pub disable_post: bool,
}
#[derive(Debug, Clone, Default)]
pub struct ClientsCount(Arc<AtomicUsize>);
impl ClientsCount {
pub fn increment(&self) -> usize {
self.0.fetch_add(1, Ordering::Relaxed)
}
pub fn decrement(&self) -> usize {
let mut count;
while {
count = self.0.load(Ordering::Relaxed);
count > 0 && self.0.compare_and_swap(count, count - 1, Ordering::Relaxed) != count
} {}
count
}
}

View file

@ -4,168 +4,157 @@ static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
#[macro_use]
extern crate clap;
mod config;
mod constants;
mod dns;
mod errors;
mod globals;
mod utils;
use base64;
use crate::config::*;
use crate::constants::*;
use crate::errors::*;
use crate::globals::*;
use clap::Arg;
use futures::future;
use futures::prelude::*;
use futures::stream::Stream;
use hyper;
use futures::task::{Context, Poll};
use hyper::http;
use hyper::server::conn::Http;
use hyper::service::Service;
use hyper::{Body, Method, Request, Response, StatusCode};
use std::io;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, UdpSocket};
#[cfg(feature = "tls")]
use native_tls::{self, Identity};
#[cfg(feature = "tls")]
use std::fs::File;
#[cfg(feature = "tls")]
use std::io;
#[cfg(feature = "tls")]
use std::io::Read;
#[cfg(feature = "tls")]
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio;
use tokio::net::{TcpListener, UdpSocket};
use tokio::prelude::{AsyncRead, AsyncWrite, FutureExt};
use std::path::Path;
#[cfg(feature = "tls")]
use tokio_tls::TlsAcceptor;
const BLOCK_SIZE: usize = 128;
const DNS_QUERY_PARAM: &str = "dns";
const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
const MAX_CLIENTS: usize = 512;
const MAX_DNS_QUESTION_LEN: usize = 512;
const MAX_DNS_RESPONSE_LEN: usize = 4096;
const MIN_DNS_PACKET_LEN: usize = 17;
const PATH: &str = "/dns-query";
const SERVER_ADDRESS: &str = "9.9.9.9:53";
const TIMEOUT_SEC: u64 = 10;
const MAX_TTL: u32 = 86400 * 7;
const MIN_TTL: u32 = 10;
const ERR_TTL: u32 = 2;
#[derive(Debug, Clone, Default)]
struct ClientsCount(Arc<AtomicUsize>);
impl ClientsCount {
fn current(&self) -> usize {
self.0.load(Ordering::Relaxed)
}
fn increment(&self) -> usize {
self.0.fetch_add(1, Ordering::Relaxed)
}
fn decrement(&self) -> usize {
let mut count;
while {
count = self.0.load(Ordering::Relaxed);
count > 0 && self.0.compare_and_swap(count, count - 1, Ordering::Relaxed) != count
} {}
count
}
}
#[derive(Debug)]
struct InnerDoH {
#[cfg(feature = "tls")]
tls_cert_path: Option<PathBuf>,
#[cfg(feature = "tls")]
tls_cert_password: Option<String>,
listen_address: SocketAddr,
local_bind_address: SocketAddr,
server_address: SocketAddr,
path: String,
max_clients: usize,
timeout: Duration,
clients_count: ClientsCount,
min_ttl: u32,
max_ttl: u32,
err_ttl: u32,
keepalive: bool,
disable_post: bool,
}
#[derive(Clone, Debug)]
struct DoH {
inner: Arc<InnerDoH>,
globals: Arc<Globals>,
}
#[allow(dead_code)]
#[derive(Debug)]
enum Error {
Incomplete,
InvalidData,
TooLarge,
UpstreamIssue,
Hyper(hyper::Error),
Io(io::Error),
#[cfg(feature = "tls")]
fn create_tls_acceptor<P>(path: P, password: &str) -> io::Result<TlsAcceptor>
where
P: AsRef<Path>,
{
let identity_bin = {
let mut fp = File::open(path)?;
let mut identity_bin = vec![];
fp.read_to_end(&mut identity_bin)?;
identity_bin
};
let identity = Identity::from_pkcs12(&identity_bin, password).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unusable PKCS12-encoded identity. The encoding and/or the password may be wrong",
)
})?;
let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unable to use the provided PKCS12-encoded identity",
)
})?;
Ok(TlsAcceptor::from(native_acceptor))
}
impl std::fmt::Display for Error {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
std::fmt::Debug::fmt(self, fmt)
impl hyper::service::Service<http::Request<Body>> for DoH {
type Response = Response<Body>;
type Error = http::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}
impl std::error::Error for Error {
fn description(&self) -> &str {
match *self {
Error::Incomplete => "Incomplete",
Error::InvalidData => "Invalid data",
Error::TooLarge => "Too large",
Error::UpstreamIssue => "Upstream error",
Error::Hyper(_) => self.description(),
Error::Io(_) => self.description(),
}
}
}
impl From<Error> for StatusCode {
fn from(e: Error) -> StatusCode {
match e {
Error::Incomplete => StatusCode::UNPROCESSABLE_ENTITY,
Error::InvalidData => StatusCode::BAD_REQUEST,
Error::TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
Error::UpstreamIssue => StatusCode::BAD_GATEWAY,
Error::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
Error::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
impl Service for DoH {
type ReqBody = Body;
type ResBody = Body;
type Error = Error;
type Future = Box<dyn Future<Item = Response<Body>, Error = Self::Error> + Send>;
fn call(&mut self, req: Request<Body>) -> Self::Future {
let inner = &self.inner;
{
let count = inner.clients_count.current();
if count >= inner.max_clients {
let globals = &self.globals;
if req.uri().path() != globals.path {
let response = Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap();
return Box::pin(async { Ok(response) });
}
let self_inner = self.clone();
match *req.method() {
Method::POST => {
if globals.disable_post {
let response = Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::empty())
.unwrap();
return Box::pin(async { Ok(response) });
}
if let Err(response) = Self::check_content_type(&req) {
return Box::pin(async { Ok(response) });
}
let fut = async move {
match self_inner.read_body_and_proxy(req.into_body()).await {
Err(e) => Response::builder()
.status(StatusCode::from(e))
.body(Body::empty()),
Ok(res) => Ok(res),
}
};
Box::pin(fut)
}
Method::GET => {
let query = req.uri().query().unwrap_or("");
let mut question_str = None;
for parts in query.split('&') {
let mut kv = parts.split('=');
if let Some(k) = kv.next() {
if k == DNS_QUERY_PARAM {
question_str = kv.next();
}
}
}
let question = match question_str.and_then(|question_str| {
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
}) {
Some(question) => question,
_ => {
let response = Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
.unwrap();
return Box::pin(future::ok(response));
}
};
let fut = async move {
match self_inner.proxy(question).await {
Err(e) => Response::builder()
.status(StatusCode::from(e))
.body(Body::empty()),
Ok(res) => Ok(res),
}
};
Box::pin(fut)
}
_ => {
let response = Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response));
Box::pin(async { Ok(response) })
}
}
let fut = self.handle_client(req);
Box::new(fut)
}
}
@ -202,241 +191,120 @@ impl DoH {
Ok(())
}
fn handle_client(
&self,
req: Request<Body>,
) -> Box<dyn Future<Item = Response<Body>, Error = Error> + Send> {
let inner = &self.inner;
if req.uri().path() != inner.path {
let response = Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response));
}
match *req.method() {
Method::POST => {
if self.inner.disable_post {
let response = Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response));
}
if let Err(response) = Self::check_content_type(&req) {
return Box::new(future::ok(response));
}
let fut = self.read_body_and_proxy(req.into_body()).or_else(|e| {
let response = Response::builder()
.status(StatusCode::from(e))
.body(Body::empty())
.unwrap();
future::ok(response)
});
Box::new(fut)
}
Method::GET => {
let query = req.uri().query().unwrap_or("");
let mut question_str = None;
for parts in query.split('&') {
let mut kv = parts.split('=');
if let Some(k) = kv.next() {
if k == DNS_QUERY_PARAM {
question_str = kv.next();
}
}
}
let question = match question_str.and_then(|question_str| {
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
}) {
Some(question) => question,
_ => {
let response = Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response));
}
};
let fut = self.proxy(question).or_else(|e| {
let response = Response::builder()
.status(StatusCode::from(e))
.body(Body::empty())
.unwrap();
future::ok(response)
});
Box::new(fut)
}
_ => {
let response = Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::empty())
.unwrap();
Box::new(future::ok(response))
}
async fn read_body_and_proxy(&self, mut body: Body) -> Result<Response<Body>, DoHError> {
let mut sum_size = 0;
let mut query = vec![];
while let Some(chunk) = body.next().await {
let chunk = chunk.map_err(|_| DoHError::TooLarge)?;
sum_size += chunk.len();
if sum_size >= MAX_DNS_QUESTION_LEN {
return Err(DoHError::TooLarge);
}
query.extend(chunk);
}
let response = self.proxy(query).await?;
Ok(response)
}
fn proxy(
&self,
mut query: Vec<u8>,
) -> Box<dyn Future<Item = Response<Body>, Error = Error> + Send> {
async fn proxy(&self, mut query: Vec<u8>) -> Result<Response<Body>, DoHError> {
if query.len() < MIN_DNS_PACKET_LEN {
return Box::new(future::err(Error::Incomplete));
return Err(DoHError::Incomplete);
}
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as u16);
let inner = &self.inner;
let socket = UdpSocket::bind(&inner.local_bind_address).unwrap();
let expected_server_address = inner.server_address;
let (min_ttl, max_ttl, err_ttl) = (inner.min_ttl, inner.max_ttl, inner.err_ttl);
let fut = socket
.send_dgram(query, &inner.server_address)
.map_err(Error::Io)
.and_then(move |(socket, _)| {
let packet = vec![0; MAX_DNS_RESPONSE_LEN];
socket.recv_dgram(packet).map_err(Error::Io)
})
.and_then(move |(_socket, mut packet, len, response_server_address)| {
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
return future::err(Error::UpstreamIssue);
}
packet.truncate(len);
let ttl = if dns::is_recoverable_error(&packet) {
err_ttl
} else {
match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) {
Err(_) => return future::err(Error::UpstreamIssue),
Ok(ttl) => ttl,
}
let globals = &self.globals;
let mut socket = UdpSocket::bind(&globals.local_bind_address)
.await
.map_err(DoHError::Io)?;
let expected_server_address = globals.server_address;
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
socket
.send_to(&query, &globals.server_address)
.map_err(DoHError::Io)
.await?;
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
let (len, response_server_address) =
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
return Err(DoHError::UpstreamIssue);
}
packet.truncate(len);
let ttl = if dns::is_recoverable_error(&packet) {
err_ttl
} else {
match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) {
Err(_) => return Err(DoHError::UpstreamIssue),
Ok(ttl) => ttl,
}
};
let packet_len = packet.len();
let response = Response::builder()
.header(hyper::header::CONTENT_LENGTH, packet_len)
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
.header("X-Padding", utils::padding_string(packet_len, BLOCK_SIZE))
.header(
hyper::header::CACHE_CONTROL,
format!("max-age={}", ttl).as_str(),
)
.body(Body::from(packet))
.unwrap();
Ok(response)
}
async fn entrypoint(self) -> Result<(), Error> {
let listen_address = self.globals.listen_address;
let mut listener = TcpListener::bind(&listen_address).await?;
let path = &self.globals.path;
#[cfg(feature = "tls")]
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
(Some(tls_cert_path), Some(tls_cert_password)) => {
println!("Listening on https://{}{}", listen_address, path);
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
}
_ => {
println!("Listening on http://{}{}", listen_address, path);
None
}
};
#[cfg(not(feature = "tls"))]
println!("Listening on http://{}{}", listen_address, path);
let mut server = Http::new();
server.keep_alive(self.globals.keepalive);
let listener_service = async {
while let Some(stream) = listener.incoming().next().await {
let stream = match stream {
Ok(stream) => stream,
Err(_) => continue,
};
let packet_len = packet.len();
let response = Response::builder()
.header(hyper::header::CONTENT_LENGTH, packet_len)
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
.header("X-Padding", utils::padding_string(packet_len, BLOCK_SIZE))
.header(
hyper::header::CACHE_CONTROL,
format!("max-age={}", ttl).as_str(),
let clients_count = self.globals.clients_count.clone();
if clients_count.increment() > self.globals.max_clients {
clients_count.decrement();
continue;
}
let self_inner = self.clone();
let server_inner = server.clone();
tokio::spawn(async move {
tokio::time::timeout(
self_inner.globals.timeout,
server_inner.serve_connection(stream, self_inner),
)
.body(Body::from(packet))
.unwrap();
future::ok(response)
});
Box::new(fut)
}
fn read_body_and_proxy(
&self,
body: Body,
) -> Box<dyn Future<Item = Response<Body>, Error = Error> + Send> {
let mut sum_size = 0;
let inner = self.clone();
let fut = body
.map_err(Error::Hyper)
.and_then(move |chunk| {
sum_size += chunk.len();
if sum_size > MAX_DNS_QUESTION_LEN {
Err(Error::TooLarge)
} else {
Ok(chunk)
}
})
.concat2()
.map(move |chunk| chunk.to_vec())
.and_then(move |query| inner.proxy(query));
Box::new(fut)
}
}
#[cfg(feature = "tls")]
fn create_tls_acceptor<P>(path: P, password: &str) -> io::Result<TlsAcceptor>
where
P: AsRef<Path>,
{
let identity_bin = {
let mut fp = File::open(path)?;
let mut identity_bin = vec![];
fp.read_to_end(&mut identity_bin)?;
identity_bin
};
let identity = Identity::from_pkcs12(&identity_bin, password).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unusable PKCS12-encoded identity. The encoding and/or the password may be wrong",
)
})?;
let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unable to use the provided PKCS12-encoded identity",
)
})?;
Ok(TlsAcceptor::from(native_acceptor))
}
fn client_serve<I>(
clients_count: ClientsCount,
stream: I,
http: Http,
service: DoH,
timeout: Duration,
) where
I: AsyncRead + AsyncWrite + Send + 'static,
{
let clients_count_inner = clients_count.clone();
let conn = http
.serve_connection(stream, service)
.timeout(timeout)
.map_err(|_| {})
.then(move |fut| {
clients_count_inner.decrement();
fut
});
clients_count.increment();
tokio::spawn(conn);
}
#[cfg(feature = "tls")]
fn start_with_tls(
tls_acceptor: TlsAcceptor,
listener: TcpListener,
doh: DoH,
http: Http,
timeout: Duration,
) {
let server = listener.incoming().for_each(move |io| {
let service = doh.clone();
let http = http.clone();
let clients_count = doh.inner.clients_count.clone();
tls_acceptor
.accept(io)
.timeout(timeout)
.then(move |stream| {
if let Ok(stream) = stream {
client_serve(clients_count, stream, http, service, timeout);
}
Ok(())
})
});
tokio::run(server.map_err(|_| {}));
}
fn start_without_tls(listener: TcpListener, doh: DoH, http: Http, timeout: Duration) {
let server = listener.incoming().for_each(move |stream| {
let service = doh.clone();
let http = http.clone();
let clients_count = doh.inner.clients_count.clone();
client_serve(clients_count, stream, http, service, timeout);
.await
.ok();
clients_count.decrement();
});
}
Ok(()) as Result<(), Error>
};
listener_service.await?;
Ok(())
});
tokio::run(server.map_err(|_| {}));
}
}
fn main() {
let mut inner_doh = InnerDoH {
let mut globals = Globals {
#[cfg(feature = "tls")]
tls_cert_path: None,
#[cfg(feature = "tls")]
tls_cert_password: None,
@ -453,196 +321,14 @@ fn main() {
keepalive: true,
disable_post: false,
};
parse_opts(&mut inner_doh);
let timeout = inner_doh.timeout;
#[cfg(feature = "tls")]
let path = inner_doh.path.clone();
parse_opts(&mut globals);
let doh = DoH {
inner: Arc::new(inner_doh),
globals: Arc::new(globals),
};
let listen_address = doh.inner.listen_address;
let listener = TcpListener::bind(&listen_address).unwrap();
#[cfg(feature = "tls")]
let tls_acceptor = match (&doh.inner.tls_cert_path, &doh.inner.tls_cert_password) {
(Some(tls_cert_path), Some(tls_cert_password)) => {
println!("Listening on https://{}{}", listen_address, path);
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
}
_ => {
println!("Listening on http://{}{}", listen_address, path);
None
}
};
let mut http = Http::new();
http.keep_alive(doh.inner.keepalive);
#[cfg(feature = "tls")]
{
if let Some(tls_acceptor) = tls_acceptor {
start_with_tls(tls_acceptor, listener, doh, http, timeout);
return;
}
}
start_without_tls(listener, doh, http, timeout);
}
fn parse_opts(inner_doh: &mut InnerDoH) {
use crate::utils::{verify_remote_server, verify_sock_addr};
let max_clients = MAX_CLIENTS.to_string();
let timeout_sec = TIMEOUT_SEC.to_string();
let min_ttl = MIN_TTL.to_string();
let max_ttl = MAX_TTL.to_string();
let err_ttl = ERR_TTL.to_string();
let _ = include_str!("../Cargo.toml");
let options = app_from_crate!()
.arg(
Arg::with_name("listen_address")
.short("l")
.long("listen-address")
.takes_value(true)
.default_value(LISTEN_ADDRESS)
.validator(verify_sock_addr)
.help("Address to listen to"),
)
.arg(
Arg::with_name("server_address")
.short("u")
.long("server-address")
.takes_value(true)
.default_value(SERVER_ADDRESS)
.validator(verify_remote_server)
.help("Address to connect to"),
)
.arg(
Arg::with_name("local_bind_address")
.short("b")
.long("local-bind-address")
.takes_value(true)
.validator(verify_sock_addr)
.help("Address to connect from"),
)
.arg(
Arg::with_name("path")
.short("p")
.long("path")
.takes_value(true)
.default_value(PATH)
.help("URI path"),
)
.arg(
Arg::with_name("max_clients")
.short("c")
.long("max-clients")
.takes_value(true)
.default_value(&max_clients)
.help("Maximum number of simultaneous clients"),
)
.arg(
Arg::with_name("timeout")
.short("t")
.long("timeout")
.takes_value(true)
.default_value(&timeout_sec)
.help("Timeout, in seconds"),
)
.arg(
Arg::with_name("min_ttl")
.short("T")
.long("min-ttl")
.takes_value(true)
.default_value(&min_ttl)
.help("Minimum TTL, in seconds"),
)
.arg(
Arg::with_name("max_ttl")
.short("X")
.long("max-ttl")
.takes_value(true)
.default_value(&max_ttl)
.help("Maximum TTL, in seconds"),
)
.arg(
Arg::with_name("err_ttl")
.short("E")
.long("err-ttl")
.takes_value(true)
.default_value(&err_ttl)
.help("TTL for errors, in seconds"),
)
.arg(
Arg::with_name("disable_keepalive")
.short("K")
.long("disable-keepalive")
.help("Disable keepalive"),
)
.arg(
Arg::with_name("disable_post")
.short("P")
.long("disable-post")
.help("Disable POST queries"),
);
#[cfg(feature = "tls")]
let options = options
.arg(
Arg::with_name("tls_cert_path")
.short("i")
.long("tls-cert-path")
.takes_value(true)
.help("Path to a PKCS12-encoded identity (only required for built-in TLS)"),
)
.arg(
Arg::with_name("tls_cert_password")
.short("I")
.long("tls-cert-password")
.takes_value(true)
.help("Password for the PKCS12-encoded identity (only required for built-in TLS)"),
);
let matches = options.get_matches();
inner_doh.listen_address = matches.value_of("listen_address").unwrap().parse().unwrap();
inner_doh.server_address = matches
.value_of("server_address")
.unwrap()
.to_socket_addrs()
.unwrap()
.next()
.unwrap();
inner_doh.local_bind_address = match matches.value_of("local_bind_address") {
Some(address) => address.parse().unwrap(),
None => match inner_doh.server_address {
SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
SocketAddr::V6(s) => SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::UNSPECIFIED,
0,
s.flowinfo(),
s.scope_id(),
)),
},
};
inner_doh.path = matches.value_of("path").unwrap().to_string();
if !inner_doh.path.starts_with('/') {
inner_doh.path = format!("/{}", inner_doh.path);
}
inner_doh.max_clients = matches.value_of("max_clients").unwrap().parse().unwrap();
inner_doh.timeout = Duration::from_secs(matches.value_of("timeout").unwrap().parse().unwrap());
inner_doh.min_ttl = matches.value_of("min_ttl").unwrap().parse().unwrap();
inner_doh.max_ttl = matches.value_of("max_ttl").unwrap().parse().unwrap();
inner_doh.err_ttl = matches.value_of("err_ttl").unwrap().parse().unwrap();
inner_doh.keepalive = !matches.is_present("disable_keepalive");
inner_doh.disable_post = matches.is_present("disable_post");
#[cfg(feature = "tls")]
{
inner_doh.tls_cert_path = matches.value_of("tls_cert_path").map(PathBuf::from);
inner_doh.tls_cert_password = matches
.value_of("tls_cert_password")
.map(ToString::to_string);
}
let mut runtime_builder = tokio::runtime::Builder::new();
runtime_builder.enable_all();
runtime_builder.threaded_scheduler();
runtime_builder.thread_name("doh-proxy");
let mut runtime = runtime_builder.build().unwrap();
runtime.block_on(doh.entrypoint()).unwrap();
}