mirror of
https://github.com/DNSCrypt/doh-server.git
synced 2025-04-03 04:57:37 +03:00
Prepare for tokio 0.2/hyper 0.13/async-await migration
This commit is contained in:
parent
1cb4a11a7b
commit
1b850b2f41
6 changed files with 499 additions and 543 deletions
|
@ -16,14 +16,15 @@ default = []
|
|||
tls = ["native-tls", "tokio-tls"]
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
base64 = "0.11"
|
||||
clap = "2.33.0"
|
||||
futures = "0.1.29"
|
||||
hyper = "0.12.35"
|
||||
futures = "0.3"
|
||||
hyper = "0.13"
|
||||
jemallocator = "0"
|
||||
native-tls = { version = "0.2.3", optional = true }
|
||||
tokio = "0.1.22"
|
||||
tokio-tls = { version = "0.2.1", optional = true }
|
||||
tokio = { version = "0.2", features = ["full"] }
|
||||
tokio-tls = { version = "0.3", optional = true }
|
||||
|
||||
[package.metadata.deb]
|
||||
extended-description = """\
|
||||
|
|
162
src/config.rs
Normal file
162
src/config.rs
Normal file
|
@ -0,0 +1,162 @@
|
|||
#[cfg(feature = "tls")]
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub fn parse_opts(globals: &mut Globals) {
|
||||
use crate::utils::{verify_remote_server, verify_sock_addr};
|
||||
|
||||
let max_clients = MAX_CLIENTS.to_string();
|
||||
let timeout_sec = TIMEOUT_SEC.to_string();
|
||||
let min_ttl = MIN_TTL.to_string();
|
||||
let max_ttl = MAX_TTL.to_string();
|
||||
let err_ttl = ERR_TTL.to_string();
|
||||
|
||||
let _ = include_str!("../Cargo.toml");
|
||||
let options = app_from_crate!()
|
||||
.arg(
|
||||
Arg::with_name("listen_address")
|
||||
.short("l")
|
||||
.long("listen-address")
|
||||
.takes_value(true)
|
||||
.default_value(LISTEN_ADDRESS)
|
||||
.validator(verify_sock_addr)
|
||||
.help("Address to listen to"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("server_address")
|
||||
.short("u")
|
||||
.long("server-address")
|
||||
.takes_value(true)
|
||||
.default_value(SERVER_ADDRESS)
|
||||
.validator(verify_remote_server)
|
||||
.help("Address to connect to"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("local_bind_address")
|
||||
.short("b")
|
||||
.long("local-bind-address")
|
||||
.takes_value(true)
|
||||
.validator(verify_sock_addr)
|
||||
.help("Address to connect from"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("path")
|
||||
.short("p")
|
||||
.long("path")
|
||||
.takes_value(true)
|
||||
.default_value(PATH)
|
||||
.help("URI path"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("max_clients")
|
||||
.short("c")
|
||||
.long("max-clients")
|
||||
.takes_value(true)
|
||||
.default_value(&max_clients)
|
||||
.help("Maximum number of simultaneous clients"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("timeout")
|
||||
.short("t")
|
||||
.long("timeout")
|
||||
.takes_value(true)
|
||||
.default_value(&timeout_sec)
|
||||
.help("Timeout, in seconds"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("min_ttl")
|
||||
.short("T")
|
||||
.long("min-ttl")
|
||||
.takes_value(true)
|
||||
.default_value(&min_ttl)
|
||||
.help("Minimum TTL, in seconds"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("max_ttl")
|
||||
.short("X")
|
||||
.long("max-ttl")
|
||||
.takes_value(true)
|
||||
.default_value(&max_ttl)
|
||||
.help("Maximum TTL, in seconds"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("err_ttl")
|
||||
.short("E")
|
||||
.long("err-ttl")
|
||||
.takes_value(true)
|
||||
.default_value(&err_ttl)
|
||||
.help("TTL for errors, in seconds"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("disable_keepalive")
|
||||
.short("K")
|
||||
.long("disable-keepalive")
|
||||
.help("Disable keepalive"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("disable_post")
|
||||
.short("P")
|
||||
.long("disable-post")
|
||||
.help("Disable POST queries"),
|
||||
);
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let options = options
|
||||
.arg(
|
||||
Arg::with_name("tls_cert_path")
|
||||
.short("i")
|
||||
.long("tls-cert-path")
|
||||
.takes_value(true)
|
||||
.help("Path to a PKCS12-encoded identity (only required for built-in TLS)"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("tls_cert_password")
|
||||
.short("I")
|
||||
.long("tls-cert-password")
|
||||
.takes_value(true)
|
||||
.help("Password for the PKCS12-encoded identity (only required for built-in TLS)"),
|
||||
);
|
||||
|
||||
let matches = options.get_matches();
|
||||
globals.listen_address = matches.value_of("listen_address").unwrap().parse().unwrap();
|
||||
|
||||
globals.server_address = matches
|
||||
.value_of("server_address")
|
||||
.unwrap()
|
||||
.to_socket_addrs()
|
||||
.unwrap()
|
||||
.next()
|
||||
.unwrap();
|
||||
globals.local_bind_address = match matches.value_of("local_bind_address") {
|
||||
Some(address) => address.parse().unwrap(),
|
||||
None => match globals.server_address {
|
||||
SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||
SocketAddr::V6(s) => SocketAddr::V6(SocketAddrV6::new(
|
||||
Ipv6Addr::UNSPECIFIED,
|
||||
0,
|
||||
s.flowinfo(),
|
||||
s.scope_id(),
|
||||
)),
|
||||
},
|
||||
};
|
||||
globals.path = matches.value_of("path").unwrap().to_string();
|
||||
if !globals.path.starts_with('/') {
|
||||
globals.path = format!("/{}", globals.path);
|
||||
}
|
||||
globals.max_clients = matches.value_of("max_clients").unwrap().parse().unwrap();
|
||||
globals.timeout = Duration::from_secs(matches.value_of("timeout").unwrap().parse().unwrap());
|
||||
globals.min_ttl = matches.value_of("min_ttl").unwrap().parse().unwrap();
|
||||
globals.max_ttl = matches.value_of("max_ttl").unwrap().parse().unwrap();
|
||||
globals.err_ttl = matches.value_of("err_ttl").unwrap().parse().unwrap();
|
||||
globals.keepalive = !matches.is_present("disable_keepalive");
|
||||
globals.disable_post = matches.is_present("disable_post");
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
{
|
||||
globals.tls_cert_path = matches.value_of("tls_cert_path").map(PathBuf::from);
|
||||
globals.tls_cert_password = matches
|
||||
.value_of("tls_cert_password")
|
||||
.map(ToString::to_string);
|
||||
}
|
||||
}
|
13
src/constants.rs
Normal file
13
src/constants.rs
Normal file
|
@ -0,0 +1,13 @@
|
|||
pub const BLOCK_SIZE: usize = 128;
|
||||
pub const DNS_QUERY_PARAM: &str = "dns";
|
||||
pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
|
||||
pub const MAX_CLIENTS: usize = 512;
|
||||
pub const MAX_DNS_QUESTION_LEN: usize = 512;
|
||||
pub const MAX_DNS_RESPONSE_LEN: usize = 4096;
|
||||
pub const MIN_DNS_PACKET_LEN: usize = 17;
|
||||
pub const PATH: &str = "/dns-query";
|
||||
pub const SERVER_ADDRESS: &str = "9.9.9.9:53";
|
||||
pub const TIMEOUT_SEC: u64 = 10;
|
||||
pub const MAX_TTL: u32 = 86400 * 7;
|
||||
pub const MIN_TTL: u32 = 10;
|
||||
pub const ERR_TTL: u32 = 2;
|
47
src/errors.rs
Normal file
47
src/errors.rs
Normal file
|
@ -0,0 +1,47 @@
|
|||
pub use anyhow::{anyhow, bail, ensure, Error};
|
||||
|
||||
use hyper::StatusCode;
|
||||
use std::io;
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug)]
|
||||
pub enum DoHError {
|
||||
Incomplete,
|
||||
InvalidData,
|
||||
TooLarge,
|
||||
UpstreamIssue,
|
||||
Hyper(hyper::Error),
|
||||
Io(io::Error),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DoHError {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
|
||||
std::fmt::Debug::fmt(self, fmt)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for DoHError {
|
||||
fn description(&self) -> &str {
|
||||
match *self {
|
||||
DoHError::Incomplete => "Incomplete",
|
||||
DoHError::InvalidData => "Invalid data",
|
||||
DoHError::TooLarge => "Too large",
|
||||
DoHError::UpstreamIssue => "Upstream error",
|
||||
DoHError::Hyper(_) => self.description(),
|
||||
DoHError::Io(_) => self.description(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DoHError> for StatusCode {
|
||||
fn from(e: DoHError) -> StatusCode {
|
||||
match e {
|
||||
DoHError::Incomplete => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
DoHError::InvalidData => StatusCode::BAD_REQUEST,
|
||||
DoHError::TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
|
||||
DoHError::UpstreamIssue => StatusCode::BAD_GATEWAY,
|
||||
DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
|
||||
DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
}
|
47
src/globals.rs
Normal file
47
src/globals.rs
Normal file
|
@ -0,0 +1,47 @@
|
|||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Globals {
|
||||
#[cfg(feature = "tls")]
|
||||
pub tls_cert_path: Option<PathBuf>,
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub tls_cert_password: Option<String>,
|
||||
|
||||
pub listen_address: SocketAddr,
|
||||
pub local_bind_address: SocketAddr,
|
||||
pub server_address: SocketAddr,
|
||||
pub path: String,
|
||||
pub max_clients: usize,
|
||||
pub timeout: Duration,
|
||||
pub clients_count: ClientsCount,
|
||||
pub min_ttl: u32,
|
||||
pub max_ttl: u32,
|
||||
pub err_ttl: u32,
|
||||
pub keepalive: bool,
|
||||
pub disable_post: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ClientsCount(Arc<AtomicUsize>);
|
||||
|
||||
impl ClientsCount {
|
||||
pub fn increment(&self) -> usize {
|
||||
self.0.fetch_add(1, Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn decrement(&self) -> usize {
|
||||
let mut count;
|
||||
while {
|
||||
count = self.0.load(Ordering::Relaxed);
|
||||
count > 0 && self.0.compare_and_swap(count, count - 1, Ordering::Relaxed) != count
|
||||
} {}
|
||||
count
|
||||
}
|
||||
}
|
764
src/main.rs
764
src/main.rs
|
@ -4,168 +4,157 @@ static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
|
|||
#[macro_use]
|
||||
extern crate clap;
|
||||
|
||||
mod config;
|
||||
mod constants;
|
||||
mod dns;
|
||||
mod errors;
|
||||
mod globals;
|
||||
mod utils;
|
||||
|
||||
use base64;
|
||||
use crate::config::*;
|
||||
use crate::constants::*;
|
||||
use crate::errors::*;
|
||||
use crate::globals::*;
|
||||
|
||||
use clap::Arg;
|
||||
use futures::future;
|
||||
use futures::prelude::*;
|
||||
use futures::stream::Stream;
|
||||
use hyper;
|
||||
use futures::task::{Context, Poll};
|
||||
use hyper::http;
|
||||
use hyper::server::conn::Http;
|
||||
use hyper::service::Service;
|
||||
use hyper::{Body, Method, Request, Response, StatusCode};
|
||||
use std::io;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::{TcpListener, UdpSocket};
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use native_tls::{self, Identity};
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use std::fs::File;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use std::io;
|
||||
#[cfg(feature = "tls")]
|
||||
use std::io::Read;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio;
|
||||
use tokio::net::{TcpListener, UdpSocket};
|
||||
use tokio::prelude::{AsyncRead, AsyncWrite, FutureExt};
|
||||
|
||||
use std::path::Path;
|
||||
#[cfg(feature = "tls")]
|
||||
use tokio_tls::TlsAcceptor;
|
||||
|
||||
const BLOCK_SIZE: usize = 128;
|
||||
const DNS_QUERY_PARAM: &str = "dns";
|
||||
const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
|
||||
const MAX_CLIENTS: usize = 512;
|
||||
const MAX_DNS_QUESTION_LEN: usize = 512;
|
||||
const MAX_DNS_RESPONSE_LEN: usize = 4096;
|
||||
const MIN_DNS_PACKET_LEN: usize = 17;
|
||||
const PATH: &str = "/dns-query";
|
||||
const SERVER_ADDRESS: &str = "9.9.9.9:53";
|
||||
const TIMEOUT_SEC: u64 = 10;
|
||||
const MAX_TTL: u32 = 86400 * 7;
|
||||
const MIN_TTL: u32 = 10;
|
||||
const ERR_TTL: u32 = 2;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct ClientsCount(Arc<AtomicUsize>);
|
||||
|
||||
impl ClientsCount {
|
||||
fn current(&self) -> usize {
|
||||
self.0.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn increment(&self) -> usize {
|
||||
self.0.fetch_add(1, Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn decrement(&self) -> usize {
|
||||
let mut count;
|
||||
while {
|
||||
count = self.0.load(Ordering::Relaxed);
|
||||
count > 0 && self.0.compare_and_swap(count, count - 1, Ordering::Relaxed) != count
|
||||
} {}
|
||||
count
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InnerDoH {
|
||||
#[cfg(feature = "tls")]
|
||||
tls_cert_path: Option<PathBuf>,
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
tls_cert_password: Option<String>,
|
||||
|
||||
listen_address: SocketAddr,
|
||||
local_bind_address: SocketAddr,
|
||||
server_address: SocketAddr,
|
||||
path: String,
|
||||
max_clients: usize,
|
||||
timeout: Duration,
|
||||
clients_count: ClientsCount,
|
||||
min_ttl: u32,
|
||||
max_ttl: u32,
|
||||
err_ttl: u32,
|
||||
keepalive: bool,
|
||||
disable_post: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct DoH {
|
||||
inner: Arc<InnerDoH>,
|
||||
globals: Arc<Globals>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug)]
|
||||
enum Error {
|
||||
Incomplete,
|
||||
InvalidData,
|
||||
TooLarge,
|
||||
UpstreamIssue,
|
||||
Hyper(hyper::Error),
|
||||
Io(io::Error),
|
||||
#[cfg(feature = "tls")]
|
||||
fn create_tls_acceptor<P>(path: P, password: &str) -> io::Result<TlsAcceptor>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
{
|
||||
let identity_bin = {
|
||||
let mut fp = File::open(path)?;
|
||||
let mut identity_bin = vec![];
|
||||
fp.read_to_end(&mut identity_bin)?;
|
||||
identity_bin
|
||||
};
|
||||
let identity = Identity::from_pkcs12(&identity_bin, password).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Unusable PKCS12-encoded identity. The encoding and/or the password may be wrong",
|
||||
)
|
||||
})?;
|
||||
let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Unable to use the provided PKCS12-encoded identity",
|
||||
)
|
||||
})?;
|
||||
Ok(TlsAcceptor::from(native_acceptor))
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
|
||||
std::fmt::Debug::fmt(self, fmt)
|
||||
impl hyper::service::Service<http::Request<Body>> for DoH {
|
||||
type Response = Response<Body>;
|
||||
type Error = http::Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn description(&self) -> &str {
|
||||
match *self {
|
||||
Error::Incomplete => "Incomplete",
|
||||
Error::InvalidData => "Invalid data",
|
||||
Error::TooLarge => "Too large",
|
||||
Error::UpstreamIssue => "Upstream error",
|
||||
Error::Hyper(_) => self.description(),
|
||||
Error::Io(_) => self.description(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Error> for StatusCode {
|
||||
fn from(e: Error) -> StatusCode {
|
||||
match e {
|
||||
Error::Incomplete => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Error::InvalidData => StatusCode::BAD_REQUEST,
|
||||
Error::TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
|
||||
Error::UpstreamIssue => StatusCode::BAD_GATEWAY,
|
||||
Error::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
|
||||
Error::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Service for DoH {
|
||||
type ReqBody = Body;
|
||||
type ResBody = Body;
|
||||
type Error = Error;
|
||||
type Future = Box<dyn Future<Item = Response<Body>, Error = Self::Error> + Send>;
|
||||
|
||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
let inner = &self.inner;
|
||||
{
|
||||
let count = inner.clients_count.current();
|
||||
if count >= inner.max_clients {
|
||||
let globals = &self.globals;
|
||||
if req.uri().path() != globals.path {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Box::pin(async { Ok(response) });
|
||||
}
|
||||
let self_inner = self.clone();
|
||||
match *req.method() {
|
||||
Method::POST => {
|
||||
if globals.disable_post {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::METHOD_NOT_ALLOWED)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Box::pin(async { Ok(response) });
|
||||
}
|
||||
if let Err(response) = Self::check_content_type(&req) {
|
||||
return Box::pin(async { Ok(response) });
|
||||
}
|
||||
let fut = async move {
|
||||
match self_inner.read_body_and_proxy(req.into_body()).await {
|
||||
Err(e) => Response::builder()
|
||||
.status(StatusCode::from(e))
|
||||
.body(Body::empty()),
|
||||
Ok(res) => Ok(res),
|
||||
}
|
||||
};
|
||||
Box::pin(fut)
|
||||
}
|
||||
Method::GET => {
|
||||
let query = req.uri().query().unwrap_or("");
|
||||
let mut question_str = None;
|
||||
for parts in query.split('&') {
|
||||
let mut kv = parts.split('=');
|
||||
if let Some(k) = kv.next() {
|
||||
if k == DNS_QUERY_PARAM {
|
||||
question_str = kv.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
let question = match question_str.and_then(|question_str| {
|
||||
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
|
||||
}) {
|
||||
Some(question) => question,
|
||||
_ => {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Box::pin(future::ok(response));
|
||||
}
|
||||
};
|
||||
let fut = async move {
|
||||
match self_inner.proxy(question).await {
|
||||
Err(e) => Response::builder()
|
||||
.status(StatusCode::from(e))
|
||||
.body(Body::empty()),
|
||||
Ok(res) => Ok(res),
|
||||
}
|
||||
};
|
||||
Box::pin(fut)
|
||||
}
|
||||
_ => {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::TOO_MANY_REQUESTS)
|
||||
.status(StatusCode::METHOD_NOT_ALLOWED)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Box::new(future::ok(response));
|
||||
Box::pin(async { Ok(response) })
|
||||
}
|
||||
}
|
||||
let fut = self.handle_client(req);
|
||||
Box::new(fut)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -202,241 +191,120 @@ impl DoH {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_client(
|
||||
&self,
|
||||
req: Request<Body>,
|
||||
) -> Box<dyn Future<Item = Response<Body>, Error = Error> + Send> {
|
||||
let inner = &self.inner;
|
||||
if req.uri().path() != inner.path {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Box::new(future::ok(response));
|
||||
}
|
||||
match *req.method() {
|
||||
Method::POST => {
|
||||
if self.inner.disable_post {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::METHOD_NOT_ALLOWED)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Box::new(future::ok(response));
|
||||
}
|
||||
if let Err(response) = Self::check_content_type(&req) {
|
||||
return Box::new(future::ok(response));
|
||||
}
|
||||
let fut = self.read_body_and_proxy(req.into_body()).or_else(|e| {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::from(e))
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
future::ok(response)
|
||||
});
|
||||
Box::new(fut)
|
||||
}
|
||||
Method::GET => {
|
||||
let query = req.uri().query().unwrap_or("");
|
||||
let mut question_str = None;
|
||||
for parts in query.split('&') {
|
||||
let mut kv = parts.split('=');
|
||||
if let Some(k) = kv.next() {
|
||||
if k == DNS_QUERY_PARAM {
|
||||
question_str = kv.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
let question = match question_str.and_then(|question_str| {
|
||||
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
|
||||
}) {
|
||||
Some(question) => question,
|
||||
_ => {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Box::new(future::ok(response));
|
||||
}
|
||||
};
|
||||
let fut = self.proxy(question).or_else(|e| {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::from(e))
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
future::ok(response)
|
||||
});
|
||||
Box::new(fut)
|
||||
}
|
||||
_ => {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::METHOD_NOT_ALLOWED)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
Box::new(future::ok(response))
|
||||
}
|
||||
async fn read_body_and_proxy(&self, mut body: Body) -> Result<Response<Body>, DoHError> {
|
||||
let mut sum_size = 0;
|
||||
let mut query = vec![];
|
||||
while let Some(chunk) = body.next().await {
|
||||
let chunk = chunk.map_err(|_| DoHError::TooLarge)?;
|
||||
sum_size += chunk.len();
|
||||
if sum_size >= MAX_DNS_QUESTION_LEN {
|
||||
return Err(DoHError::TooLarge);
|
||||
}
|
||||
query.extend(chunk);
|
||||
}
|
||||
let response = self.proxy(query).await?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
fn proxy(
|
||||
&self,
|
||||
mut query: Vec<u8>,
|
||||
) -> Box<dyn Future<Item = Response<Body>, Error = Error> + Send> {
|
||||
async fn proxy(&self, mut query: Vec<u8>) -> Result<Response<Body>, DoHError> {
|
||||
if query.len() < MIN_DNS_PACKET_LEN {
|
||||
return Box::new(future::err(Error::Incomplete));
|
||||
return Err(DoHError::Incomplete);
|
||||
}
|
||||
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as u16);
|
||||
let inner = &self.inner;
|
||||
let socket = UdpSocket::bind(&inner.local_bind_address).unwrap();
|
||||
let expected_server_address = inner.server_address;
|
||||
let (min_ttl, max_ttl, err_ttl) = (inner.min_ttl, inner.max_ttl, inner.err_ttl);
|
||||
let fut = socket
|
||||
.send_dgram(query, &inner.server_address)
|
||||
.map_err(Error::Io)
|
||||
.and_then(move |(socket, _)| {
|
||||
let packet = vec![0; MAX_DNS_RESPONSE_LEN];
|
||||
socket.recv_dgram(packet).map_err(Error::Io)
|
||||
})
|
||||
.and_then(move |(_socket, mut packet, len, response_server_address)| {
|
||||
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
|
||||
return future::err(Error::UpstreamIssue);
|
||||
}
|
||||
packet.truncate(len);
|
||||
let ttl = if dns::is_recoverable_error(&packet) {
|
||||
err_ttl
|
||||
} else {
|
||||
match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) {
|
||||
Err(_) => return future::err(Error::UpstreamIssue),
|
||||
Ok(ttl) => ttl,
|
||||
}
|
||||
let globals = &self.globals;
|
||||
let mut socket = UdpSocket::bind(&globals.local_bind_address)
|
||||
.await
|
||||
.map_err(DoHError::Io)?;
|
||||
let expected_server_address = globals.server_address;
|
||||
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
|
||||
socket
|
||||
.send_to(&query, &globals.server_address)
|
||||
.map_err(DoHError::Io)
|
||||
.await?;
|
||||
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
|
||||
let (len, response_server_address) =
|
||||
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
|
||||
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
|
||||
return Err(DoHError::UpstreamIssue);
|
||||
}
|
||||
packet.truncate(len);
|
||||
let ttl = if dns::is_recoverable_error(&packet) {
|
||||
err_ttl
|
||||
} else {
|
||||
match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) {
|
||||
Err(_) => return Err(DoHError::UpstreamIssue),
|
||||
Ok(ttl) => ttl,
|
||||
}
|
||||
};
|
||||
let packet_len = packet.len();
|
||||
let response = Response::builder()
|
||||
.header(hyper::header::CONTENT_LENGTH, packet_len)
|
||||
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
|
||||
.header("X-Padding", utils::padding_string(packet_len, BLOCK_SIZE))
|
||||
.header(
|
||||
hyper::header::CACHE_CONTROL,
|
||||
format!("max-age={}", ttl).as_str(),
|
||||
)
|
||||
.body(Body::from(packet))
|
||||
.unwrap();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn entrypoint(self) -> Result<(), Error> {
|
||||
let listen_address = self.globals.listen_address;
|
||||
let mut listener = TcpListener::bind(&listen_address).await?;
|
||||
let path = &self.globals.path;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
|
||||
(Some(tls_cert_path), Some(tls_cert_password)) => {
|
||||
println!("Listening on https://{}{}", listen_address, path);
|
||||
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
|
||||
}
|
||||
_ => {
|
||||
println!("Listening on http://{}{}", listen_address, path);
|
||||
None
|
||||
}
|
||||
};
|
||||
#[cfg(not(feature = "tls"))]
|
||||
println!("Listening on http://{}{}", listen_address, path);
|
||||
|
||||
let mut server = Http::new();
|
||||
server.keep_alive(self.globals.keepalive);
|
||||
let listener_service = async {
|
||||
while let Some(stream) = listener.incoming().next().await {
|
||||
let stream = match stream {
|
||||
Ok(stream) => stream,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let packet_len = packet.len();
|
||||
let response = Response::builder()
|
||||
.header(hyper::header::CONTENT_LENGTH, packet_len)
|
||||
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
|
||||
.header("X-Padding", utils::padding_string(packet_len, BLOCK_SIZE))
|
||||
.header(
|
||||
hyper::header::CACHE_CONTROL,
|
||||
format!("max-age={}", ttl).as_str(),
|
||||
let clients_count = self.globals.clients_count.clone();
|
||||
if clients_count.increment() > self.globals.max_clients {
|
||||
clients_count.decrement();
|
||||
continue;
|
||||
}
|
||||
let self_inner = self.clone();
|
||||
let server_inner = server.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::time::timeout(
|
||||
self_inner.globals.timeout,
|
||||
server_inner.serve_connection(stream, self_inner),
|
||||
)
|
||||
.body(Body::from(packet))
|
||||
.unwrap();
|
||||
future::ok(response)
|
||||
});
|
||||
Box::new(fut)
|
||||
}
|
||||
|
||||
fn read_body_and_proxy(
|
||||
&self,
|
||||
body: Body,
|
||||
) -> Box<dyn Future<Item = Response<Body>, Error = Error> + Send> {
|
||||
let mut sum_size = 0;
|
||||
let inner = self.clone();
|
||||
let fut = body
|
||||
.map_err(Error::Hyper)
|
||||
.and_then(move |chunk| {
|
||||
sum_size += chunk.len();
|
||||
if sum_size > MAX_DNS_QUESTION_LEN {
|
||||
Err(Error::TooLarge)
|
||||
} else {
|
||||
Ok(chunk)
|
||||
}
|
||||
})
|
||||
.concat2()
|
||||
.map(move |chunk| chunk.to_vec())
|
||||
.and_then(move |query| inner.proxy(query));
|
||||
Box::new(fut)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
fn create_tls_acceptor<P>(path: P, password: &str) -> io::Result<TlsAcceptor>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
{
|
||||
let identity_bin = {
|
||||
let mut fp = File::open(path)?;
|
||||
let mut identity_bin = vec![];
|
||||
fp.read_to_end(&mut identity_bin)?;
|
||||
identity_bin
|
||||
};
|
||||
let identity = Identity::from_pkcs12(&identity_bin, password).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Unusable PKCS12-encoded identity. The encoding and/or the password may be wrong",
|
||||
)
|
||||
})?;
|
||||
let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Unable to use the provided PKCS12-encoded identity",
|
||||
)
|
||||
})?;
|
||||
Ok(TlsAcceptor::from(native_acceptor))
|
||||
}
|
||||
|
||||
fn client_serve<I>(
|
||||
clients_count: ClientsCount,
|
||||
stream: I,
|
||||
http: Http,
|
||||
service: DoH,
|
||||
timeout: Duration,
|
||||
) where
|
||||
I: AsyncRead + AsyncWrite + Send + 'static,
|
||||
{
|
||||
let clients_count_inner = clients_count.clone();
|
||||
let conn = http
|
||||
.serve_connection(stream, service)
|
||||
.timeout(timeout)
|
||||
.map_err(|_| {})
|
||||
.then(move |fut| {
|
||||
clients_count_inner.decrement();
|
||||
fut
|
||||
});
|
||||
clients_count.increment();
|
||||
tokio::spawn(conn);
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
fn start_with_tls(
|
||||
tls_acceptor: TlsAcceptor,
|
||||
listener: TcpListener,
|
||||
doh: DoH,
|
||||
http: Http,
|
||||
timeout: Duration,
|
||||
) {
|
||||
let server = listener.incoming().for_each(move |io| {
|
||||
let service = doh.clone();
|
||||
let http = http.clone();
|
||||
let clients_count = doh.inner.clients_count.clone();
|
||||
tls_acceptor
|
||||
.accept(io)
|
||||
.timeout(timeout)
|
||||
.then(move |stream| {
|
||||
if let Ok(stream) = stream {
|
||||
client_serve(clients_count, stream, http, service, timeout);
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
tokio::run(server.map_err(|_| {}));
|
||||
}
|
||||
|
||||
fn start_without_tls(listener: TcpListener, doh: DoH, http: Http, timeout: Duration) {
|
||||
let server = listener.incoming().for_each(move |stream| {
|
||||
let service = doh.clone();
|
||||
let http = http.clone();
|
||||
let clients_count = doh.inner.clients_count.clone();
|
||||
client_serve(clients_count, stream, http, service, timeout);
|
||||
.await
|
||||
.ok();
|
||||
clients_count.decrement();
|
||||
});
|
||||
}
|
||||
Ok(()) as Result<(), Error>
|
||||
};
|
||||
listener_service.await?;
|
||||
Ok(())
|
||||
});
|
||||
tokio::run(server.map_err(|_| {}));
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let mut inner_doh = InnerDoH {
|
||||
let mut globals = Globals {
|
||||
#[cfg(feature = "tls")]
|
||||
tls_cert_path: None,
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
tls_cert_password: None,
|
||||
|
||||
|
@ -453,196 +321,14 @@ fn main() {
|
|||
keepalive: true,
|
||||
disable_post: false,
|
||||
};
|
||||
parse_opts(&mut inner_doh);
|
||||
let timeout = inner_doh.timeout;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let path = inner_doh.path.clone();
|
||||
parse_opts(&mut globals);
|
||||
let doh = DoH {
|
||||
inner: Arc::new(inner_doh),
|
||||
globals: Arc::new(globals),
|
||||
};
|
||||
let listen_address = doh.inner.listen_address;
|
||||
let listener = TcpListener::bind(&listen_address).unwrap();
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let tls_acceptor = match (&doh.inner.tls_cert_path, &doh.inner.tls_cert_password) {
|
||||
(Some(tls_cert_path), Some(tls_cert_password)) => {
|
||||
println!("Listening on https://{}{}", listen_address, path);
|
||||
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
|
||||
}
|
||||
_ => {
|
||||
println!("Listening on http://{}{}", listen_address, path);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let mut http = Http::new();
|
||||
http.keep_alive(doh.inner.keepalive);
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
{
|
||||
if let Some(tls_acceptor) = tls_acceptor {
|
||||
start_with_tls(tls_acceptor, listener, doh, http, timeout);
|
||||
return;
|
||||
}
|
||||
}
|
||||
start_without_tls(listener, doh, http, timeout);
|
||||
}
|
||||
|
||||
fn parse_opts(inner_doh: &mut InnerDoH) {
|
||||
use crate::utils::{verify_remote_server, verify_sock_addr};
|
||||
|
||||
let max_clients = MAX_CLIENTS.to_string();
|
||||
let timeout_sec = TIMEOUT_SEC.to_string();
|
||||
let min_ttl = MIN_TTL.to_string();
|
||||
let max_ttl = MAX_TTL.to_string();
|
||||
let err_ttl = ERR_TTL.to_string();
|
||||
|
||||
let _ = include_str!("../Cargo.toml");
|
||||
let options = app_from_crate!()
|
||||
.arg(
|
||||
Arg::with_name("listen_address")
|
||||
.short("l")
|
||||
.long("listen-address")
|
||||
.takes_value(true)
|
||||
.default_value(LISTEN_ADDRESS)
|
||||
.validator(verify_sock_addr)
|
||||
.help("Address to listen to"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("server_address")
|
||||
.short("u")
|
||||
.long("server-address")
|
||||
.takes_value(true)
|
||||
.default_value(SERVER_ADDRESS)
|
||||
.validator(verify_remote_server)
|
||||
.help("Address to connect to"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("local_bind_address")
|
||||
.short("b")
|
||||
.long("local-bind-address")
|
||||
.takes_value(true)
|
||||
.validator(verify_sock_addr)
|
||||
.help("Address to connect from"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("path")
|
||||
.short("p")
|
||||
.long("path")
|
||||
.takes_value(true)
|
||||
.default_value(PATH)
|
||||
.help("URI path"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("max_clients")
|
||||
.short("c")
|
||||
.long("max-clients")
|
||||
.takes_value(true)
|
||||
.default_value(&max_clients)
|
||||
.help("Maximum number of simultaneous clients"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("timeout")
|
||||
.short("t")
|
||||
.long("timeout")
|
||||
.takes_value(true)
|
||||
.default_value(&timeout_sec)
|
||||
.help("Timeout, in seconds"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("min_ttl")
|
||||
.short("T")
|
||||
.long("min-ttl")
|
||||
.takes_value(true)
|
||||
.default_value(&min_ttl)
|
||||
.help("Minimum TTL, in seconds"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("max_ttl")
|
||||
.short("X")
|
||||
.long("max-ttl")
|
||||
.takes_value(true)
|
||||
.default_value(&max_ttl)
|
||||
.help("Maximum TTL, in seconds"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("err_ttl")
|
||||
.short("E")
|
||||
.long("err-ttl")
|
||||
.takes_value(true)
|
||||
.default_value(&err_ttl)
|
||||
.help("TTL for errors, in seconds"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("disable_keepalive")
|
||||
.short("K")
|
||||
.long("disable-keepalive")
|
||||
.help("Disable keepalive"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("disable_post")
|
||||
.short("P")
|
||||
.long("disable-post")
|
||||
.help("Disable POST queries"),
|
||||
);
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let options = options
|
||||
.arg(
|
||||
Arg::with_name("tls_cert_path")
|
||||
.short("i")
|
||||
.long("tls-cert-path")
|
||||
.takes_value(true)
|
||||
.help("Path to a PKCS12-encoded identity (only required for built-in TLS)"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("tls_cert_password")
|
||||
.short("I")
|
||||
.long("tls-cert-password")
|
||||
.takes_value(true)
|
||||
.help("Password for the PKCS12-encoded identity (only required for built-in TLS)"),
|
||||
);
|
||||
|
||||
let matches = options.get_matches();
|
||||
inner_doh.listen_address = matches.value_of("listen_address").unwrap().parse().unwrap();
|
||||
|
||||
inner_doh.server_address = matches
|
||||
.value_of("server_address")
|
||||
.unwrap()
|
||||
.to_socket_addrs()
|
||||
.unwrap()
|
||||
.next()
|
||||
.unwrap();
|
||||
inner_doh.local_bind_address = match matches.value_of("local_bind_address") {
|
||||
Some(address) => address.parse().unwrap(),
|
||||
None => match inner_doh.server_address {
|
||||
SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||
SocketAddr::V6(s) => SocketAddr::V6(SocketAddrV6::new(
|
||||
Ipv6Addr::UNSPECIFIED,
|
||||
0,
|
||||
s.flowinfo(),
|
||||
s.scope_id(),
|
||||
)),
|
||||
},
|
||||
};
|
||||
inner_doh.path = matches.value_of("path").unwrap().to_string();
|
||||
if !inner_doh.path.starts_with('/') {
|
||||
inner_doh.path = format!("/{}", inner_doh.path);
|
||||
}
|
||||
inner_doh.max_clients = matches.value_of("max_clients").unwrap().parse().unwrap();
|
||||
inner_doh.timeout = Duration::from_secs(matches.value_of("timeout").unwrap().parse().unwrap());
|
||||
inner_doh.min_ttl = matches.value_of("min_ttl").unwrap().parse().unwrap();
|
||||
inner_doh.max_ttl = matches.value_of("max_ttl").unwrap().parse().unwrap();
|
||||
inner_doh.err_ttl = matches.value_of("err_ttl").unwrap().parse().unwrap();
|
||||
inner_doh.keepalive = !matches.is_present("disable_keepalive");
|
||||
inner_doh.disable_post = matches.is_present("disable_post");
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
{
|
||||
inner_doh.tls_cert_path = matches.value_of("tls_cert_path").map(PathBuf::from);
|
||||
inner_doh.tls_cert_password = matches
|
||||
.value_of("tls_cert_password")
|
||||
.map(ToString::to_string);
|
||||
}
|
||||
let mut runtime_builder = tokio::runtime::Builder::new();
|
||||
runtime_builder.enable_all();
|
||||
runtime_builder.threaded_scheduler();
|
||||
runtime_builder.thread_name("doh-proxy");
|
||||
let mut runtime = runtime_builder.build().unwrap();
|
||||
runtime.block_on(doh.entrypoint()).unwrap();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue