Reorganize a little bit

This commit is contained in:
Frank Denis 2019-12-24 09:44:43 +01:00
parent bf42e95368
commit 06b91af009
12 changed files with 312 additions and 275 deletions

2
.gitignore vendored
View file

@ -1,5 +1,7 @@
#*#
**/*.rs.bk
*~
Cargo.lock
/target/
/src/libdoh/target/

View file

@ -13,19 +13,13 @@ readme = "README.md"
[features]
default = []
tls = ["native-tls", "tokio-tls"]
tls = ["libdoh/tls"]
[dependencies]
anyhow = "1.0"
byteorder = "1.3"
base64 = "0.11"
clap = "2.33.0"
futures = { version = "0.3" }
hyper = { version = "0.13", default-features = false, features = ["stream"] }
libdoh = { path = "src/libdoh" }
clap = "2"
jemallocator = "0"
native-tls = { version = "0.2.3", optional = true }
tokio = { version = "0.2", features = ["rt-threaded", "time", "tcp", "udp", "stream"] }
tokio-tls = { version = "0.3", optional = true }
[package.metadata.deb]
extended-description = """\

View file

@ -1,5 +1,6 @@
use libdoh::*;
use crate::constants::*;
use crate::globals::*;
use clap::Arg;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};

View file

@ -1,10 +1,5 @@
pub const BLOCK_SIZE: usize = 128;
pub const DNS_QUERY_PARAM: &str = "dns";
pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
pub const MAX_CLIENTS: usize = 512;
pub const MAX_DNS_QUESTION_LEN: usize = 512;
pub const MAX_DNS_RESPONSE_LEN: usize = 4096;
pub const MIN_DNS_PACKET_LEN: usize = 17;
pub const PATH: &str = "/dns-query";
pub const SERVER_ADDRESS: &str = "9.9.9.9:53";
pub const TIMEOUT_SEC: u64 = 10;

28
src/libdoh/Cargo.toml Normal file
View file

@ -0,0 +1,28 @@
[package]
name = "libdoh"
version = "0.2.0"
authors = ["Frank Denis <github@pureftpd.org>"]
license = "MIT"
edition = "2018"
publish = false
[features]
default = []
tls = ["native-tls", "tokio-tls"]
[dependencies]
anyhow = "1.0"
byteorder = "1.3"
base64 = "0.11"
futures = { version = "0.3" }
hyper = { version = "0.13", default-features = false, features = ["stream"] }
native-tls = { version = "0.2.3", optional = true }
tokio = { version = "0.2", features = ["rt-threaded", "time", "tcp", "udp", "stream"] }
tokio-tls = { version = "0.3", optional = true }
[profile.release]
codegen-units = 1
incremental = false
lto = "fat"
opt-level = 3
panic = "abort"

View file

@ -0,0 +1,5 @@
pub const BLOCK_SIZE: usize = 128;
pub const DNS_QUERY_PARAM: &str = "dns";
pub const MAX_DNS_QUESTION_LEN: usize = 512;
pub const MAX_DNS_RESPONSE_LEN: usize = 4096;
pub const MIN_DNS_PACKET_LEN: usize = 17;

View file

@ -243,6 +243,11 @@ pub fn add_edns_padding(packet: &mut Vec<u8>, block_size: usize) -> Result<(), E
let edns_rdlen_offset: usize = edns_offset + 8;
ensure!(packet_len - edns_rdlen_offset >= 2, "Short packet");
let edns_rdlen = BigEndian::read_u16(&packet[edns_rdlen_offset..]);
dbg!(edns_rdlen);
ensure!(
edns_offset + edns_rdlen as usize <= packet_len,
"Out of range EDNS size"
);
ensure!(
0xffff - edns_rdlen as usize >= edns_padding_prr_len,
"EDNS section too large for padding"

264
src/libdoh/src/lib.rs Normal file
View file

@ -0,0 +1,264 @@
mod constants;
mod dns;
mod errors;
mod globals;
#[cfg(feature = "tls")]
mod tls;
use crate::constants::*;
pub use crate::errors::*;
pub use crate::globals::*;
#[cfg(feature = "tls")]
use crate::tls::*;
use futures::prelude::*;
use futures::task::{Context, Poll};
use hyper::http;
use hyper::server::conn::Http;
use hyper::{Body, Method, Request, Response, StatusCode};
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UdpSocket};
#[derive(Clone, Debug)]
pub struct DoH {
pub globals: Arc<Globals>,
}
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
let response = Response::builder()
.status(status_code)
.body(Body::empty())
.unwrap();
Ok(response)
}
impl hyper::service::Service<http::Request<Body>> for DoH {
type Response = Response<Body>;
type Error = http::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let globals = &self.globals;
if req.uri().path() != globals.path {
return Box::pin(async { http_error(StatusCode::NOT_FOUND) });
}
let self_inner = self.clone();
match *req.method() {
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
}
}
}
impl DoH {
async fn serve_post(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
if self.globals.disable_post {
return http_error(StatusCode::METHOD_NOT_ALLOWED);
}
if let Err(response) = Self::check_content_type(&req) {
return Ok(response);
}
match self.read_body_and_proxy(req.into_body()).await {
Err(e) => http_error(StatusCode::from(e)),
Ok(res) => Ok(res),
}
}
async fn serve_get(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
let query = req.uri().query().unwrap_or("");
let mut question_str = None;
for parts in query.split('&') {
let mut kv = parts.split('=');
if let Some(k) = kv.next() {
if k == DNS_QUERY_PARAM {
question_str = kv.next();
}
}
}
let question = match question_str.and_then(|question_str| {
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
}) {
Some(question) => question,
_ => {
return http_error(StatusCode::BAD_REQUEST);
}
};
match self.proxy(question).await {
Err(e) => http_error(StatusCode::from(e)),
Ok(res) => Ok(res),
}
}
fn check_content_type(req: &Request<Body>) -> Result<(), Response<Body>> {
let headers = req.headers();
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
None => {
let response = Response::builder()
.status(StatusCode::NOT_ACCEPTABLE)
.body(Body::empty())
.unwrap();
return Err(response);
}
Some(content_type) => content_type.to_str(),
};
let content_type = match content_type {
Err(_) => {
let response = Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
.unwrap();
return Err(response);
}
Ok(content_type) => content_type.to_lowercase(),
};
if content_type != "application/dns-message" {
let response = Response::builder()
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
.body(Body::empty())
.unwrap();
return Err(response);
}
Ok(())
}
async fn read_body_and_proxy(&self, mut body: Body) -> Result<Response<Body>, DoHError> {
let mut sum_size = 0;
let mut query = vec![];
while let Some(chunk) = body.next().await {
let chunk = chunk.map_err(|_| DoHError::TooLarge)?;
sum_size += chunk.len();
if sum_size >= MAX_DNS_QUESTION_LEN {
return Err(DoHError::TooLarge);
}
query.extend(chunk);
}
let response = self.proxy(query).await?;
Ok(response)
}
async fn proxy(&self, mut query: Vec<u8>) -> Result<Response<Body>, DoHError> {
if query.len() < MIN_DNS_PACKET_LEN {
return Err(DoHError::Incomplete);
}
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _);
let globals = &self.globals;
let mut socket = UdpSocket::bind(&globals.local_bind_address)
.await
.map_err(DoHError::Io)?;
let expected_server_address = globals.server_address;
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
socket
.send_to(&query, &globals.server_address)
.map_err(DoHError::Io)
.await?;
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
let (len, response_server_address) =
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
return Err(DoHError::UpstreamIssue);
}
packet.truncate(len);
let ttl = if dns::is_recoverable_error(&packet) {
err_ttl
} else {
match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) {
Err(_) => return Err(DoHError::UpstreamIssue),
Ok(ttl) => ttl,
}
};
dns::add_edns_padding(&mut packet, BLOCK_SIZE).map_err(|_| DoHError::TooLarge)?;
let packet_len = packet.len();
let response = Response::builder()
.header(hyper::header::CONTENT_LENGTH, packet_len)
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
.header(
hyper::header::CACHE_CONTROL,
format!("max-age={}", ttl).as_str(),
)
.body(Body::from(packet))
.unwrap();
Ok(response)
}
async fn client_serve<I>(self, stream: I, server: Http)
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let clients_count = self.globals.clients_count.clone();
if clients_count.increment() > self.globals.max_clients {
clients_count.decrement();
return;
}
tokio::spawn(async move {
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
.await
.ok();
clients_count.decrement();
});
}
async fn start_without_tls(
self,
mut listener: TcpListener,
server: Http,
) -> Result<(), DoHError> {
let listener_service = async {
while let Some(stream) = listener.incoming().next().await {
let stream = match stream {
Ok(stream) => stream,
Err(_) => continue,
};
self.clone().client_serve(stream, server.clone()).await;
}
Ok(()) as Result<(), DoHError>
};
listener_service.await?;
Ok(())
}
pub async fn entrypoint(self) -> Result<(), DoHError> {
let listen_address = self.globals.listen_address;
let listener = TcpListener::bind(&listen_address)
.await
.map_err(DoHError::Io)?;
let path = &self.globals.path;
#[cfg(feature = "tls")]
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
(Some(tls_cert_path), Some(tls_cert_password)) => {
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
}
_ => None,
};
#[cfg(not(feature = "tls"))]
let tls_acceptor: Option<()> = None;
if tls_acceptor.is_some() {
println!("Listening on https://{}{}", listen_address, path);
} else {
println!("Listening on http://{}{}", listen_address, path);
}
let mut server = Http::new();
server.keep_alive(self.globals.keepalive);
server.pipeline_flush(true);
#[cfg(feature = "tls")]
{
if let Some(tls_acceptor) = tls_acceptor {
self.start_with_tls(tls_acceptor, listener, server).await?;
return Ok(());
}
}
self.start_without_tls(listener, server).await?;
Ok(())
}
}

View file

@ -6,273 +6,16 @@ extern crate clap;
mod config;
mod constants;
mod dns;
mod errors;
mod globals;
#[cfg(feature = "tls")]
mod tls;
mod utils;
use libdoh::*;
use crate::config::*;
use crate::constants::*;
use crate::errors::*;
use crate::globals::*;
#[cfg(feature = "tls")]
use crate::tls::*;
use futures::prelude::*;
use futures::task::{Context, Poll};
use hyper::http;
use hyper::server::conn::Http;
use hyper::{Body, Method, Request, Response, StatusCode};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UdpSocket};
#[derive(Clone, Debug)]
pub struct DoH {
pub globals: Arc<Globals>,
}
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
let response = Response::builder()
.status(status_code)
.body(Body::empty())
.unwrap();
Ok(response)
}
impl hyper::service::Service<http::Request<Body>> for DoH {
type Response = Response<Body>;
type Error = http::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let globals = &self.globals;
if req.uri().path() != globals.path {
return Box::pin(async { http_error(StatusCode::NOT_FOUND) });
}
let self_inner = self.clone();
match *req.method() {
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
}
}
}
impl DoH {
async fn serve_post(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
if self.globals.disable_post {
return http_error(StatusCode::METHOD_NOT_ALLOWED);
}
if let Err(response) = Self::check_content_type(&req) {
return Ok(response);
}
match self.read_body_and_proxy(req.into_body()).await {
Err(e) => http_error(StatusCode::from(e)),
Ok(res) => Ok(res),
}
}
async fn serve_get(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
let query = req.uri().query().unwrap_or("");
let mut question_str = None;
for parts in query.split('&') {
let mut kv = parts.split('=');
if let Some(k) = kv.next() {
if k == DNS_QUERY_PARAM {
question_str = kv.next();
}
}
}
let question = match question_str.and_then(|question_str| {
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
}) {
Some(question) => question,
_ => {
return http_error(StatusCode::BAD_REQUEST);
}
};
match self.proxy(question).await {
Err(e) => http_error(StatusCode::from(e)),
Ok(res) => Ok(res),
}
}
fn check_content_type(req: &Request<Body>) -> Result<(), Response<Body>> {
let headers = req.headers();
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
None => {
let response = Response::builder()
.status(StatusCode::NOT_ACCEPTABLE)
.body(Body::empty())
.unwrap();
return Err(response);
}
Some(content_type) => content_type.to_str(),
};
let content_type = match content_type {
Err(_) => {
let response = Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
.unwrap();
return Err(response);
}
Ok(content_type) => content_type.to_lowercase(),
};
if content_type != "application/dns-message" {
let response = Response::builder()
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
.body(Body::empty())
.unwrap();
return Err(response);
}
Ok(())
}
async fn read_body_and_proxy(&self, mut body: Body) -> Result<Response<Body>, DoHError> {
let mut sum_size = 0;
let mut query = vec![];
while let Some(chunk) = body.next().await {
let chunk = chunk.map_err(|_| DoHError::TooLarge)?;
sum_size += chunk.len();
if sum_size >= MAX_DNS_QUESTION_LEN {
return Err(DoHError::TooLarge);
}
query.extend(chunk);
}
let response = self.proxy(query).await?;
Ok(response)
}
async fn proxy(&self, mut query: Vec<u8>) -> Result<Response<Body>, DoHError> {
if query.len() < MIN_DNS_PACKET_LEN {
return Err(DoHError::Incomplete);
}
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _);
let globals = &self.globals;
let mut socket = UdpSocket::bind(&globals.local_bind_address)
.await
.map_err(DoHError::Io)?;
let expected_server_address = globals.server_address;
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
socket
.send_to(&query, &globals.server_address)
.map_err(DoHError::Io)
.await?;
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
let (len, response_server_address) =
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
return Err(DoHError::UpstreamIssue);
}
packet.truncate(len);
let ttl = if dns::is_recoverable_error(&packet) {
err_ttl
} else {
match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) {
Err(_) => return Err(DoHError::UpstreamIssue),
Ok(ttl) => ttl,
}
};
dns::add_edns_padding(&mut packet, BLOCK_SIZE).map_err(|_| DoHError::TooLarge)?;
let packet_len = packet.len();
let response = Response::builder()
.header(hyper::header::CONTENT_LENGTH, packet_len)
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
.header(
hyper::header::CACHE_CONTROL,
format!("max-age={}", ttl).as_str(),
)
.body(Body::from(packet))
.unwrap();
Ok(response)
}
async fn client_serve<I>(self, stream: I, server: Http)
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let clients_count = self.globals.clients_count.clone();
if clients_count.increment() > self.globals.max_clients {
clients_count.decrement();
return;
}
tokio::spawn(async move {
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
.await
.ok();
clients_count.decrement();
});
}
async fn start_without_tls(
self,
mut listener: TcpListener,
server: Http,
) -> Result<(), DoHError> {
let listener_service = async {
while let Some(stream) = listener.incoming().next().await {
let stream = match stream {
Ok(stream) => stream,
Err(_) => continue,
};
self.clone().client_serve(stream, server.clone()).await;
}
Ok(()) as Result<(), DoHError>
};
listener_service.await?;
Ok(())
}
async fn entrypoint(self) -> Result<(), DoHError> {
let listen_address = self.globals.listen_address;
let listener = TcpListener::bind(&listen_address)
.await
.map_err(DoHError::Io)?;
let path = &self.globals.path;
#[cfg(feature = "tls")]
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
(Some(tls_cert_path), Some(tls_cert_password)) => {
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
}
_ => None,
};
#[cfg(not(feature = "tls"))]
let tls_acceptor: Option<()> = None;
if tls_acceptor.is_some() {
println!("Listening on https://{}{}", listen_address, path);
} else {
println!("Listening on http://{}{}", listen_address, path);
}
let mut server = Http::new();
server.keep_alive(self.globals.keepalive);
server.pipeline_flush(true);
#[cfg(feature = "tls")]
{
if let Some(tls_acceptor) = tls_acceptor {
self.start_with_tls(tls_acceptor, listener, server).await?;
return Ok(());
}
}
self.start_without_tls(listener, server).await?;
Ok(())
}
}
fn main() {
let mut globals = Globals {
@ -287,7 +30,7 @@ fn main() {
path: PATH.to_string(),
max_clients: MAX_CLIENTS,
timeout: Duration::from_secs(TIMEOUT_SEC),
clients_count: ClientsCount::default(),
clients_count: Default::default(),
min_ttl: MIN_TTL,
max_ttl: MAX_TTL,
err_ttl: ERR_TTL,