Make TLS an optional feature

This commit is contained in:
Frank Denis 2019-05-19 11:38:55 +02:00
parent badcb6104d
commit 1706ec0dcb
4 changed files with 99 additions and 1164 deletions

1118
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -10,14 +10,18 @@ repository = "https://github.com/jedisct1/rust-doh"
categories = ["asynchronous", "network-programming","command-line-utilities"] categories = ["asynchronous", "network-programming","command-line-utilities"]
edition = "2018" edition = "2018"
[features]
default = []
tls = ["native-tls", "tokio-tls"]
[dependencies] [dependencies]
base64 = "0.10" base64 = "0.10"
clap = "2" clap = "2"
futures = "0.1.27" futures = "0.1.27"
hyper = "0.12.28" hyper = "0.12.29"
native-tls = "0.2.3" jemallocator = "0"
jemallocator = "0.3" native-tls = { version = "0.2.3", optional = true }
tokio = "0.1.20" tokio = "0.1.20"
tokio-current-thread = "0.1.6" tokio-current-thread = "0.1.6"
tokio-timer = "0.2.11" tokio-timer = "0.2.11"
tokio-tls = "0.2.1" tokio-tls = { version = "0.2.1", optional = true }

View file

@ -37,6 +37,8 @@ OPTIONS:
-i, --tls-cert-path <tls_cert_path> Path to a PKCS12-encoded identity (only required for built-in TLS) -i, --tls-cert-path <tls_cert_path> Path to a PKCS12-encoded identity (only required for built-in TLS)
``` ```
## HTTP/2 termination
## Clients ## Clients
`doh-proxy` can be used with [dnscrypt-proxy](https://github.com/jedisct1/dnscrypt-proxy) `doh-proxy` can be used with [dnscrypt-proxy](https://github.com/jedisct1/dnscrypt-proxy)

View file

@ -13,10 +13,18 @@ use hyper;
use hyper::server::conn::Http; use hyper::server::conn::Http;
use hyper::service::Service; use hyper::service::Service;
use hyper::{Body, Method, Request, Response, StatusCode}; use hyper::{Body, Method, Request, Response, StatusCode};
#[cfg(feature = "tls")]
use native_tls::{self, Identity}; use native_tls::{self, Identity};
#[cfg(feature = "tls")]
use std::fs::File; use std::fs::File;
#[cfg(feature = "tls")]
use std::io::{self, Read}; use std::io::{self, Read};
use std::net::SocketAddr; use std::net::SocketAddr;
#[cfg(feature = "tls")]
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
@ -24,6 +32,8 @@ use std::time::Duration;
use tokio; use tokio;
use tokio::net::{TcpListener, UdpSocket}; use tokio::net::{TcpListener, UdpSocket};
use tokio::prelude::{AsyncRead, AsyncWrite, FutureExt}; use tokio::prelude::{AsyncRead, AsyncWrite, FutureExt};
#[cfg(feature = "tls")]
use tokio_tls::TlsAcceptor; use tokio_tls::TlsAcceptor;
const BLOCK_SIZE: usize = 128; const BLOCK_SIZE: usize = 128;
@ -65,8 +75,12 @@ impl ClientsCount {
#[derive(Debug)] #[derive(Debug)]
struct InnerDoH { struct InnerDoH {
#[cfg(feature = "tls")]
tls_cert_path: Option<PathBuf>, tls_cert_path: Option<PathBuf>,
#[cfg(feature = "tls")]
tls_cert_password: Option<String>, tls_cert_password: Option<String>,
listen_address: SocketAddr, listen_address: SocketAddr,
local_bind_address: SocketAddr, local_bind_address: SocketAddr,
server_address: SocketAddr, server_address: SocketAddr,
@ -254,6 +268,7 @@ impl DoH {
} }
} }
#[cfg(feature = "tls")]
fn create_tls_acceptor<P>(path: P, password: &str) -> io::Result<TlsAcceptor> fn create_tls_acceptor<P>(path: P, password: &str) -> io::Result<TlsAcceptor>
where where
P: AsRef<Path>, P: AsRef<Path>,
@ -301,10 +316,50 @@ fn client_serve<I>(
tokio::spawn(conn); tokio::spawn(conn);
} }
#[cfg(feature = "tls")]
fn start_with_tls(
tls_acceptor: TlsAcceptor,
listener: TcpListener,
doh: DoH,
http: Http,
timeout: Duration,
) {
let server = listener.incoming().for_each(move |io| {
let service = doh.clone();
let http = http.clone();
let clients_count = doh.inner.clients_count.clone();
tls_acceptor
.accept(io)
.timeout(timeout)
.then(move |stream| {
if let Ok(stream) = stream {
client_serve(clients_count, stream, http, service, timeout);
}
Ok(())
})
});
tokio::run(server.map_err(|_| {}));
}
fn start_without_tls(listener: TcpListener, doh: DoH, http: Http, timeout: Duration) {
let server = listener.incoming().for_each(move |stream| {
let service = doh.clone();
let http = http.clone();
let clients_count = doh.inner.clients_count.clone();
client_serve(clients_count, stream, http, service, timeout);
Ok(())
});
tokio::run(server.map_err(|_| {}));
}
fn main() { fn main() {
let mut inner_doh = InnerDoH { let mut inner_doh = InnerDoH {
#[cfg(feature = "tls")]
tls_cert_path: None, tls_cert_path: None,
#[cfg(feature = "tls")]
tls_cert_password: None, tls_cert_password: None,
listen_address: LISTEN_ADDRESS.parse().unwrap(), listen_address: LISTEN_ADDRESS.parse().unwrap(),
local_bind_address: LOCAL_BIND_ADDRESS.parse().unwrap(), local_bind_address: LOCAL_BIND_ADDRESS.parse().unwrap(),
server_address: SERVER_ADDRESS.parse().unwrap(), server_address: SERVER_ADDRESS.parse().unwrap(),
@ -319,57 +374,40 @@ fn main() {
}; };
parse_opts(&mut inner_doh); parse_opts(&mut inner_doh);
let timeout = inner_doh.timeout; let timeout = inner_doh.timeout;
#[cfg(feature = "tls")]
let path = inner_doh.path.clone(); let path = inner_doh.path.clone();
let doh = DoH { let doh = DoH {
inner: Arc::new(inner_doh), inner: Arc::new(inner_doh),
}; };
let listen_address = doh.inner.listen_address; let listen_address = doh.inner.listen_address;
let listener = TcpListener::bind(&listen_address).unwrap();
// openssl pkcs12 -export -out Cert.p12 -in cert.pem -inkey key.pem -passin pass:root -passout pass:root // openssl pkcs12 -export -out Cert.p12 -in cert.pem -inkey key.pem -passin pass:root -passout pass:root
#[cfg(feature = "tls")]
let tls_acceptor = match (&doh.inner.tls_cert_path, &doh.inner.tls_cert_password) { let tls_acceptor = match (&doh.inner.tls_cert_path, &doh.inner.tls_cert_password) {
(Some(tls_cert_path), Some(tls_cert_password)) => { (Some(tls_cert_path), Some(tls_cert_password)) => {
println!("Listening on https://{}{}", listen_address, path);
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap()) Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
} }
_ => None, _ => {
}; println!("Listening on http://{}{}", listen_address, path);
None
let listener = TcpListener::bind(&listen_address).unwrap(); }
match tls_acceptor {
Some(_) => println!("Listening on https://{}{}", listen_address, path),
None => println!("Listening on http://{}{}", listen_address, path),
}; };
let mut http = Http::new(); let mut http = Http::new();
http.keep_alive(doh.inner.keepalive); http.keep_alive(doh.inner.keepalive);
if let Some(tls_acceptor) = tls_acceptor { #[cfg(feature = "tls")]
let server = listener.incoming().for_each(move |io| { {
let service = doh.clone(); if let Some(tls_acceptor) = tls_acceptor {
let http = http.clone(); start_with_tls(tls_acceptor, listener, doh, http, timeout);
let clients_count = doh.inner.clients_count.clone(); return;
tls_acceptor }
.accept(io) }
.timeout(timeout) start_without_tls(listener, doh, http, timeout);
.then(move |stream| {
if let Ok(stream) = stream {
client_serve(clients_count, stream, http, service, timeout);
}
Ok(())
})
});
tokio::run(server.map_err(|_| {}));
} else {
let server = listener.incoming().for_each(move |stream| {
let service = doh.clone();
let http = http.clone();
let clients_count = doh.inner.clients_count.clone();
client_serve(clients_count, stream, http, service, timeout);
Ok(())
});
tokio::run(server.map_err(|_| {}));
};
} }
fn parse_opts(inner_doh: &mut InnerDoH) { fn parse_opts(inner_doh: &mut InnerDoH) {
@ -378,7 +416,8 @@ fn parse_opts(inner_doh: &mut InnerDoH) {
let min_ttl = MIN_TTL.to_string(); let min_ttl = MIN_TTL.to_string();
let max_ttl = MAX_TTL.to_string(); let max_ttl = MAX_TTL.to_string();
let err_ttl = ERR_TTL.to_string(); let err_ttl = ERR_TTL.to_string();
let matches = App::new("doh-proxy")
let options = App::new("doh-proxy")
.about("A DNS-over-HTTP server proxy") .about("A DNS-over-HTTP server proxy")
.arg( .arg(
Arg::with_name("listen_address") Arg::with_name("listen_address")
@ -457,7 +496,10 @@ fn parse_opts(inner_doh: &mut InnerDoH) {
.short("K") .short("K")
.long("disable-keepalive") .long("disable-keepalive")
.help("Disable keepalive"), .help("Disable keepalive"),
) );
#[cfg(feature = "tls")]
let options = options
.arg( .arg(
Arg::with_name("tls_cert_path") Arg::with_name("tls_cert_path")
.short("i") .short("i")
@ -471,8 +513,9 @@ fn parse_opts(inner_doh: &mut InnerDoH) {
.long("tls-cert-password") .long("tls-cert-password")
.takes_value(true) .takes_value(true)
.help("Password for the PKCS12-encoded identity (only required for built-in TLS)"), .help("Password for the PKCS12-encoded identity (only required for built-in TLS)"),
) );
.get_matches();
let matches = options.get_matches();
inner_doh.listen_address = matches.value_of("listen_address").unwrap().parse().unwrap(); inner_doh.listen_address = matches.value_of("listen_address").unwrap().parse().unwrap();
inner_doh.server_address = matches.value_of("server_address").unwrap().parse().unwrap(); inner_doh.server_address = matches.value_of("server_address").unwrap().parse().unwrap();
inner_doh.local_bind_address = matches inner_doh.local_bind_address = matches
@ -490,8 +533,12 @@ fn parse_opts(inner_doh: &mut InnerDoH) {
inner_doh.max_ttl = matches.value_of("max_ttl").unwrap().parse().unwrap(); inner_doh.max_ttl = matches.value_of("max_ttl").unwrap().parse().unwrap();
inner_doh.err_ttl = matches.value_of("err_ttl").unwrap().parse().unwrap(); inner_doh.err_ttl = matches.value_of("err_ttl").unwrap().parse().unwrap();
inner_doh.keepalive = !matches.is_present("disable_keepalive"); inner_doh.keepalive = !matches.is_present("disable_keepalive");
inner_doh.tls_cert_path = matches.value_of("tls_cert_path").map(PathBuf::from);
inner_doh.tls_cert_password = matches #[cfg(feature = "tls")]
.value_of("tls_cert_password") {
.map(ToString::to_string); inner_doh.tls_cert_path = matches.value_of("tls_cert_path").map(PathBuf::from);
inner_doh.tls_cert_password = matches
.value_of("tls_cert_password")
.map(ToString::to_string);
}
} }