diff --git a/.gitignore b/.gitignore index bbbf46e..d9ce6c6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ #*# **/*.rs.bk *~ +Cargo.lock /target/ +/src/libdoh/target/ diff --git a/Cargo.toml b/Cargo.toml index 4b14aa6..2b4ea83 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,19 +13,13 @@ readme = "README.md" [features] default = [] -tls = ["native-tls", "tokio-tls"] +tls = ["libdoh/tls"] [dependencies] -anyhow = "1.0" -byteorder = "1.3" -base64 = "0.11" -clap = "2.33.0" -futures = { version = "0.3" } -hyper = { version = "0.13", default-features = false, features = ["stream"] } +libdoh = { path = "src/libdoh" } +clap = "2" jemallocator = "0" -native-tls = { version = "0.2.3", optional = true } tokio = { version = "0.2", features = ["rt-threaded", "time", "tcp", "udp", "stream"] } -tokio-tls = { version = "0.3", optional = true } [package.metadata.deb] extended-description = """\ diff --git a/src/config.rs b/src/config.rs index 6fbfba5..dbe0c3f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,6 @@ +use libdoh::*; + use crate::constants::*; -use crate::globals::*; use clap::Arg; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; diff --git a/src/constants.rs b/src/constants.rs index 2f23f3a..7a133f5 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,10 +1,5 @@ -pub const BLOCK_SIZE: usize = 128; -pub const DNS_QUERY_PARAM: &str = "dns"; pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000"; pub const MAX_CLIENTS: usize = 512; -pub const MAX_DNS_QUESTION_LEN: usize = 512; -pub const MAX_DNS_RESPONSE_LEN: usize = 4096; -pub const MIN_DNS_PACKET_LEN: usize = 17; pub const PATH: &str = "/dns-query"; pub const SERVER_ADDRESS: &str = "9.9.9.9:53"; pub const TIMEOUT_SEC: u64 = 10; diff --git a/src/libdoh/Cargo.toml b/src/libdoh/Cargo.toml new file mode 100644 index 0000000..07643e0 --- /dev/null +++ b/src/libdoh/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "libdoh" +version = "0.2.0" +authors = ["Frank Denis "] +license = "MIT" +edition = "2018" +publish = false + +[features] +default = [] +tls = ["native-tls", "tokio-tls"] + +[dependencies] +anyhow = "1.0" +byteorder = "1.3" +base64 = "0.11" +futures = { version = "0.3" } +hyper = { version = "0.13", default-features = false, features = ["stream"] } +native-tls = { version = "0.2.3", optional = true } +tokio = { version = "0.2", features = ["rt-threaded", "time", "tcp", "udp", "stream"] } +tokio-tls = { version = "0.3", optional = true } + +[profile.release] +codegen-units = 1 +incremental = false +lto = "fat" +opt-level = 3 +panic = "abort" diff --git a/src/libdoh/src/constants.rs b/src/libdoh/src/constants.rs new file mode 100644 index 0000000..f1a676b --- /dev/null +++ b/src/libdoh/src/constants.rs @@ -0,0 +1,5 @@ +pub const BLOCK_SIZE: usize = 128; +pub const DNS_QUERY_PARAM: &str = "dns"; +pub const MAX_DNS_QUESTION_LEN: usize = 512; +pub const MAX_DNS_RESPONSE_LEN: usize = 4096; +pub const MIN_DNS_PACKET_LEN: usize = 17; diff --git a/src/dns.rs b/src/libdoh/src/dns.rs similarity index 98% rename from src/dns.rs rename to src/libdoh/src/dns.rs index 54b9b1c..93bf9a9 100644 --- a/src/dns.rs +++ b/src/libdoh/src/dns.rs @@ -243,6 +243,11 @@ pub fn add_edns_padding(packet: &mut Vec, block_size: usize) -> Result<(), E let edns_rdlen_offset: usize = edns_offset + 8; ensure!(packet_len - edns_rdlen_offset >= 2, "Short packet"); let edns_rdlen = BigEndian::read_u16(&packet[edns_rdlen_offset..]); + dbg!(edns_rdlen); + ensure!( + edns_offset + edns_rdlen as usize <= packet_len, + "Out of range EDNS size" + ); ensure!( 0xffff - edns_rdlen as usize >= edns_padding_prr_len, "EDNS section too large for padding" diff --git a/src/errors.rs b/src/libdoh/src/errors.rs similarity index 100% rename from src/errors.rs rename to src/libdoh/src/errors.rs diff --git a/src/globals.rs b/src/libdoh/src/globals.rs similarity index 100% rename from src/globals.rs rename to src/libdoh/src/globals.rs diff --git a/src/libdoh/src/lib.rs b/src/libdoh/src/lib.rs new file mode 100644 index 0000000..ead4c12 --- /dev/null +++ b/src/libdoh/src/lib.rs @@ -0,0 +1,264 @@ +mod constants; +mod dns; +mod errors; +mod globals; +#[cfg(feature = "tls")] +mod tls; + +use crate::constants::*; +pub use crate::errors::*; +pub use crate::globals::*; + +#[cfg(feature = "tls")] +use crate::tls::*; + +use futures::prelude::*; +use futures::task::{Context, Poll}; +use hyper::http; +use hyper::server::conn::Http; +use hyper::{Body, Method, Request, Response, StatusCode}; +use std::pin::Pin; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::{TcpListener, UdpSocket}; + +#[derive(Clone, Debug)] +pub struct DoH { + pub globals: Arc, +} + +fn http_error(status_code: StatusCode) -> Result, http::Error> { + let response = Response::builder() + .status(status_code) + .body(Body::empty()) + .unwrap(); + Ok(response) +} + +impl hyper::service::Service> for DoH { + type Response = Response; + type Error = http::Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let globals = &self.globals; + if req.uri().path() != globals.path { + return Box::pin(async { http_error(StatusCode::NOT_FOUND) }); + } + let self_inner = self.clone(); + match *req.method() { + Method::POST => Box::pin(async move { self_inner.serve_post(req).await }), + Method::GET => Box::pin(async move { self_inner.serve_get(req).await }), + _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), + } + } +} + +impl DoH { + async fn serve_post(&self, req: Request) -> Result, http::Error> { + if self.globals.disable_post { + return http_error(StatusCode::METHOD_NOT_ALLOWED); + } + if let Err(response) = Self::check_content_type(&req) { + return Ok(response); + } + match self.read_body_and_proxy(req.into_body()).await { + Err(e) => http_error(StatusCode::from(e)), + Ok(res) => Ok(res), + } + } + + async fn serve_get(&self, req: Request) -> Result, http::Error> { + let query = req.uri().query().unwrap_or(""); + let mut question_str = None; + for parts in query.split('&') { + let mut kv = parts.split('='); + if let Some(k) = kv.next() { + if k == DNS_QUERY_PARAM { + question_str = kv.next(); + } + } + } + let question = match question_str.and_then(|question_str| { + base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok() + }) { + Some(question) => question, + _ => { + return http_error(StatusCode::BAD_REQUEST); + } + }; + match self.proxy(question).await { + Err(e) => http_error(StatusCode::from(e)), + Ok(res) => Ok(res), + } + } + + fn check_content_type(req: &Request) -> Result<(), Response> { + let headers = req.headers(); + let content_type = match headers.get(hyper::header::CONTENT_TYPE) { + None => { + let response = Response::builder() + .status(StatusCode::NOT_ACCEPTABLE) + .body(Body::empty()) + .unwrap(); + return Err(response); + } + Some(content_type) => content_type.to_str(), + }; + let content_type = match content_type { + Err(_) => { + let response = Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::empty()) + .unwrap(); + return Err(response); + } + Ok(content_type) => content_type.to_lowercase(), + }; + if content_type != "application/dns-message" { + let response = Response::builder() + .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) + .body(Body::empty()) + .unwrap(); + return Err(response); + } + Ok(()) + } + + async fn read_body_and_proxy(&self, mut body: Body) -> Result, DoHError> { + let mut sum_size = 0; + let mut query = vec![]; + while let Some(chunk) = body.next().await { + let chunk = chunk.map_err(|_| DoHError::TooLarge)?; + sum_size += chunk.len(); + if sum_size >= MAX_DNS_QUESTION_LEN { + return Err(DoHError::TooLarge); + } + query.extend(chunk); + } + let response = self.proxy(query).await?; + Ok(response) + } + + async fn proxy(&self, mut query: Vec) -> Result, DoHError> { + if query.len() < MIN_DNS_PACKET_LEN { + return Err(DoHError::Incomplete); + } + let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _); + let globals = &self.globals; + let mut socket = UdpSocket::bind(&globals.local_bind_address) + .await + .map_err(DoHError::Io)?; + let expected_server_address = globals.server_address; + let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl); + socket + .send_to(&query, &globals.server_address) + .map_err(DoHError::Io) + .await?; + let mut packet = vec![0; MAX_DNS_RESPONSE_LEN]; + let (len, response_server_address) = + socket.recv_from(&mut packet).map_err(DoHError::Io).await?; + if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address { + return Err(DoHError::UpstreamIssue); + } + packet.truncate(len); + let ttl = if dns::is_recoverable_error(&packet) { + err_ttl + } else { + match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) { + Err(_) => return Err(DoHError::UpstreamIssue), + Ok(ttl) => ttl, + } + }; + dns::add_edns_padding(&mut packet, BLOCK_SIZE).map_err(|_| DoHError::TooLarge)?; + let packet_len = packet.len(); + let response = Response::builder() + .header(hyper::header::CONTENT_LENGTH, packet_len) + .header(hyper::header::CONTENT_TYPE, "application/dns-message") + .header( + hyper::header::CACHE_CONTROL, + format!("max-age={}", ttl).as_str(), + ) + .body(Body::from(packet)) + .unwrap(); + Ok(response) + } + + async fn client_serve(self, stream: I, server: Http) + where + I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + let clients_count = self.globals.clients_count.clone(); + if clients_count.increment() > self.globals.max_clients { + clients_count.decrement(); + return; + } + tokio::spawn(async move { + tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self)) + .await + .ok(); + clients_count.decrement(); + }); + } + + async fn start_without_tls( + self, + mut listener: TcpListener, + server: Http, + ) -> Result<(), DoHError> { + let listener_service = async { + while let Some(stream) = listener.incoming().next().await { + let stream = match stream { + Ok(stream) => stream, + Err(_) => continue, + }; + self.clone().client_serve(stream, server.clone()).await; + } + Ok(()) as Result<(), DoHError> + }; + listener_service.await?; + Ok(()) + } + + pub async fn entrypoint(self) -> Result<(), DoHError> { + let listen_address = self.globals.listen_address; + let listener = TcpListener::bind(&listen_address) + .await + .map_err(DoHError::Io)?; + let path = &self.globals.path; + + #[cfg(feature = "tls")] + let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) { + (Some(tls_cert_path), Some(tls_cert_password)) => { + Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap()) + } + _ => None, + }; + #[cfg(not(feature = "tls"))] + let tls_acceptor: Option<()> = None; + + if tls_acceptor.is_some() { + println!("Listening on https://{}{}", listen_address, path); + } else { + println!("Listening on http://{}{}", listen_address, path); + } + + let mut server = Http::new(); + server.keep_alive(self.globals.keepalive); + server.pipeline_flush(true); + + #[cfg(feature = "tls")] + { + if let Some(tls_acceptor) = tls_acceptor { + self.start_with_tls(tls_acceptor, listener, server).await?; + return Ok(()); + } + } + self.start_without_tls(listener, server).await?; + Ok(()) + } +} diff --git a/src/tls.rs b/src/libdoh/src/tls.rs similarity index 100% rename from src/tls.rs rename to src/libdoh/src/tls.rs diff --git a/src/main.rs b/src/main.rs index 914b412..06086e9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,273 +6,16 @@ extern crate clap; mod config; mod constants; -mod dns; -mod errors; -mod globals; -#[cfg(feature = "tls")] -mod tls; mod utils; +use libdoh::*; + use crate::config::*; use crate::constants::*; -use crate::errors::*; -use crate::globals::*; -#[cfg(feature = "tls")] -use crate::tls::*; - -use futures::prelude::*; -use futures::task::{Context, Poll}; -use hyper::http; -use hyper::server::conn::Http; -use hyper::{Body, Method, Request, Response, StatusCode}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::pin::Pin; use std::sync::Arc; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::{TcpListener, UdpSocket}; - -#[derive(Clone, Debug)] -pub struct DoH { - pub globals: Arc, -} - -fn http_error(status_code: StatusCode) -> Result, http::Error> { - let response = Response::builder() - .status(status_code) - .body(Body::empty()) - .unwrap(); - Ok(response) -} - -impl hyper::service::Service> for DoH { - type Response = Response; - type Error = http::Error; - type Future = Pin> + Send>>; - - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - let globals = &self.globals; - if req.uri().path() != globals.path { - return Box::pin(async { http_error(StatusCode::NOT_FOUND) }); - } - let self_inner = self.clone(); - match *req.method() { - Method::POST => Box::pin(async move { self_inner.serve_post(req).await }), - Method::GET => Box::pin(async move { self_inner.serve_get(req).await }), - _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), - } - } -} - -impl DoH { - async fn serve_post(&self, req: Request) -> Result, http::Error> { - if self.globals.disable_post { - return http_error(StatusCode::METHOD_NOT_ALLOWED); - } - if let Err(response) = Self::check_content_type(&req) { - return Ok(response); - } - match self.read_body_and_proxy(req.into_body()).await { - Err(e) => http_error(StatusCode::from(e)), - Ok(res) => Ok(res), - } - } - - async fn serve_get(&self, req: Request) -> Result, http::Error> { - let query = req.uri().query().unwrap_or(""); - let mut question_str = None; - for parts in query.split('&') { - let mut kv = parts.split('='); - if let Some(k) = kv.next() { - if k == DNS_QUERY_PARAM { - question_str = kv.next(); - } - } - } - let question = match question_str.and_then(|question_str| { - base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok() - }) { - Some(question) => question, - _ => { - return http_error(StatusCode::BAD_REQUEST); - } - }; - match self.proxy(question).await { - Err(e) => http_error(StatusCode::from(e)), - Ok(res) => Ok(res), - } - } - - fn check_content_type(req: &Request) -> Result<(), Response> { - let headers = req.headers(); - let content_type = match headers.get(hyper::header::CONTENT_TYPE) { - None => { - let response = Response::builder() - .status(StatusCode::NOT_ACCEPTABLE) - .body(Body::empty()) - .unwrap(); - return Err(response); - } - Some(content_type) => content_type.to_str(), - }; - let content_type = match content_type { - Err(_) => { - let response = Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::empty()) - .unwrap(); - return Err(response); - } - Ok(content_type) => content_type.to_lowercase(), - }; - if content_type != "application/dns-message" { - let response = Response::builder() - .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) - .body(Body::empty()) - .unwrap(); - return Err(response); - } - Ok(()) - } - - async fn read_body_and_proxy(&self, mut body: Body) -> Result, DoHError> { - let mut sum_size = 0; - let mut query = vec![]; - while let Some(chunk) = body.next().await { - let chunk = chunk.map_err(|_| DoHError::TooLarge)?; - sum_size += chunk.len(); - if sum_size >= MAX_DNS_QUESTION_LEN { - return Err(DoHError::TooLarge); - } - query.extend(chunk); - } - let response = self.proxy(query).await?; - Ok(response) - } - - async fn proxy(&self, mut query: Vec) -> Result, DoHError> { - if query.len() < MIN_DNS_PACKET_LEN { - return Err(DoHError::Incomplete); - } - let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _); - let globals = &self.globals; - let mut socket = UdpSocket::bind(&globals.local_bind_address) - .await - .map_err(DoHError::Io)?; - let expected_server_address = globals.server_address; - let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl); - socket - .send_to(&query, &globals.server_address) - .map_err(DoHError::Io) - .await?; - let mut packet = vec![0; MAX_DNS_RESPONSE_LEN]; - let (len, response_server_address) = - socket.recv_from(&mut packet).map_err(DoHError::Io).await?; - if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address { - return Err(DoHError::UpstreamIssue); - } - packet.truncate(len); - let ttl = if dns::is_recoverable_error(&packet) { - err_ttl - } else { - match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) { - Err(_) => return Err(DoHError::UpstreamIssue), - Ok(ttl) => ttl, - } - }; - dns::add_edns_padding(&mut packet, BLOCK_SIZE).map_err(|_| DoHError::TooLarge)?; - let packet_len = packet.len(); - let response = Response::builder() - .header(hyper::header::CONTENT_LENGTH, packet_len) - .header(hyper::header::CONTENT_TYPE, "application/dns-message") - .header( - hyper::header::CACHE_CONTROL, - format!("max-age={}", ttl).as_str(), - ) - .body(Body::from(packet)) - .unwrap(); - Ok(response) - } - - async fn client_serve(self, stream: I, server: Http) - where - I: AsyncRead + AsyncWrite + Send + Unpin + 'static, - { - let clients_count = self.globals.clients_count.clone(); - if clients_count.increment() > self.globals.max_clients { - clients_count.decrement(); - return; - } - tokio::spawn(async move { - tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self)) - .await - .ok(); - clients_count.decrement(); - }); - } - - async fn start_without_tls( - self, - mut listener: TcpListener, - server: Http, - ) -> Result<(), DoHError> { - let listener_service = async { - while let Some(stream) = listener.incoming().next().await { - let stream = match stream { - Ok(stream) => stream, - Err(_) => continue, - }; - self.clone().client_serve(stream, server.clone()).await; - } - Ok(()) as Result<(), DoHError> - }; - listener_service.await?; - Ok(()) - } - - async fn entrypoint(self) -> Result<(), DoHError> { - let listen_address = self.globals.listen_address; - let listener = TcpListener::bind(&listen_address) - .await - .map_err(DoHError::Io)?; - let path = &self.globals.path; - - #[cfg(feature = "tls")] - let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) { - (Some(tls_cert_path), Some(tls_cert_password)) => { - Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap()) - } - _ => None, - }; - #[cfg(not(feature = "tls"))] - let tls_acceptor: Option<()> = None; - - if tls_acceptor.is_some() { - println!("Listening on https://{}{}", listen_address, path); - } else { - println!("Listening on http://{}{}", listen_address, path); - } - - let mut server = Http::new(); - server.keep_alive(self.globals.keepalive); - server.pipeline_flush(true); - - #[cfg(feature = "tls")] - { - if let Some(tls_acceptor) = tls_acceptor { - self.start_with_tls(tls_acceptor, listener, server).await?; - return Ok(()); - } - } - self.start_without_tls(listener, server).await?; - Ok(()) - } -} fn main() { let mut globals = Globals { @@ -287,7 +30,7 @@ fn main() { path: PATH.to_string(), max_clients: MAX_CLIENTS, timeout: Duration::from_secs(TIMEOUT_SEC), - clients_count: ClientsCount::default(), + clients_count: Default::default(), min_ttl: MIN_TTL, max_ttl: MAX_TTL, err_ttl: ERR_TTL,