mirror of
https://github.com/DNSCrypt/doh-server.git
synced 2025-04-01 20:27:34 +03:00
Reorganize a little bit
This commit is contained in:
parent
bf42e95368
commit
06b91af009
12 changed files with 312 additions and 275 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,5 +1,7 @@
|
||||||
#*#
|
#*#
|
||||||
**/*.rs.bk
|
**/*.rs.bk
|
||||||
*~
|
*~
|
||||||
|
Cargo.lock
|
||||||
/target/
|
/target/
|
||||||
|
/src/libdoh/target/
|
||||||
|
|
||||||
|
|
12
Cargo.toml
12
Cargo.toml
|
@ -13,19 +13,13 @@ readme = "README.md"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
tls = ["native-tls", "tokio-tls"]
|
tls = ["libdoh/tls"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0"
|
libdoh = { path = "src/libdoh" }
|
||||||
byteorder = "1.3"
|
clap = "2"
|
||||||
base64 = "0.11"
|
|
||||||
clap = "2.33.0"
|
|
||||||
futures = { version = "0.3" }
|
|
||||||
hyper = { version = "0.13", default-features = false, features = ["stream"] }
|
|
||||||
jemallocator = "0"
|
jemallocator = "0"
|
||||||
native-tls = { version = "0.2.3", optional = true }
|
|
||||||
tokio = { version = "0.2", features = ["rt-threaded", "time", "tcp", "udp", "stream"] }
|
tokio = { version = "0.2", features = ["rt-threaded", "time", "tcp", "udp", "stream"] }
|
||||||
tokio-tls = { version = "0.3", optional = true }
|
|
||||||
|
|
||||||
[package.metadata.deb]
|
[package.metadata.deb]
|
||||||
extended-description = """\
|
extended-description = """\
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
|
use libdoh::*;
|
||||||
|
|
||||||
use crate::constants::*;
|
use crate::constants::*;
|
||||||
use crate::globals::*;
|
|
||||||
|
|
||||||
use clap::Arg;
|
use clap::Arg;
|
||||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
|
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
|
||||||
|
|
|
@ -1,10 +1,5 @@
|
||||||
pub const BLOCK_SIZE: usize = 128;
|
|
||||||
pub const DNS_QUERY_PARAM: &str = "dns";
|
|
||||||
pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
|
pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
|
||||||
pub const MAX_CLIENTS: usize = 512;
|
pub const MAX_CLIENTS: usize = 512;
|
||||||
pub const MAX_DNS_QUESTION_LEN: usize = 512;
|
|
||||||
pub const MAX_DNS_RESPONSE_LEN: usize = 4096;
|
|
||||||
pub const MIN_DNS_PACKET_LEN: usize = 17;
|
|
||||||
pub const PATH: &str = "/dns-query";
|
pub const PATH: &str = "/dns-query";
|
||||||
pub const SERVER_ADDRESS: &str = "9.9.9.9:53";
|
pub const SERVER_ADDRESS: &str = "9.9.9.9:53";
|
||||||
pub const TIMEOUT_SEC: u64 = 10;
|
pub const TIMEOUT_SEC: u64 = 10;
|
||||||
|
|
28
src/libdoh/Cargo.toml
Normal file
28
src/libdoh/Cargo.toml
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
[package]
|
||||||
|
name = "libdoh"
|
||||||
|
version = "0.2.0"
|
||||||
|
authors = ["Frank Denis <github@pureftpd.org>"]
|
||||||
|
license = "MIT"
|
||||||
|
edition = "2018"
|
||||||
|
publish = false
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
tls = ["native-tls", "tokio-tls"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1.0"
|
||||||
|
byteorder = "1.3"
|
||||||
|
base64 = "0.11"
|
||||||
|
futures = { version = "0.3" }
|
||||||
|
hyper = { version = "0.13", default-features = false, features = ["stream"] }
|
||||||
|
native-tls = { version = "0.2.3", optional = true }
|
||||||
|
tokio = { version = "0.2", features = ["rt-threaded", "time", "tcp", "udp", "stream"] }
|
||||||
|
tokio-tls = { version = "0.3", optional = true }
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
codegen-units = 1
|
||||||
|
incremental = false
|
||||||
|
lto = "fat"
|
||||||
|
opt-level = 3
|
||||||
|
panic = "abort"
|
5
src/libdoh/src/constants.rs
Normal file
5
src/libdoh/src/constants.rs
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
pub const BLOCK_SIZE: usize = 128;
|
||||||
|
pub const DNS_QUERY_PARAM: &str = "dns";
|
||||||
|
pub const MAX_DNS_QUESTION_LEN: usize = 512;
|
||||||
|
pub const MAX_DNS_RESPONSE_LEN: usize = 4096;
|
||||||
|
pub const MIN_DNS_PACKET_LEN: usize = 17;
|
|
@ -243,6 +243,11 @@ pub fn add_edns_padding(packet: &mut Vec<u8>, block_size: usize) -> Result<(), E
|
||||||
let edns_rdlen_offset: usize = edns_offset + 8;
|
let edns_rdlen_offset: usize = edns_offset + 8;
|
||||||
ensure!(packet_len - edns_rdlen_offset >= 2, "Short packet");
|
ensure!(packet_len - edns_rdlen_offset >= 2, "Short packet");
|
||||||
let edns_rdlen = BigEndian::read_u16(&packet[edns_rdlen_offset..]);
|
let edns_rdlen = BigEndian::read_u16(&packet[edns_rdlen_offset..]);
|
||||||
|
dbg!(edns_rdlen);
|
||||||
|
ensure!(
|
||||||
|
edns_offset + edns_rdlen as usize <= packet_len,
|
||||||
|
"Out of range EDNS size"
|
||||||
|
);
|
||||||
ensure!(
|
ensure!(
|
||||||
0xffff - edns_rdlen as usize >= edns_padding_prr_len,
|
0xffff - edns_rdlen as usize >= edns_padding_prr_len,
|
||||||
"EDNS section too large for padding"
|
"EDNS section too large for padding"
|
264
src/libdoh/src/lib.rs
Normal file
264
src/libdoh/src/lib.rs
Normal file
|
@ -0,0 +1,264 @@
|
||||||
|
mod constants;
|
||||||
|
mod dns;
|
||||||
|
mod errors;
|
||||||
|
mod globals;
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
mod tls;
|
||||||
|
|
||||||
|
use crate::constants::*;
|
||||||
|
pub use crate::errors::*;
|
||||||
|
pub use crate::globals::*;
|
||||||
|
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
use crate::tls::*;
|
||||||
|
|
||||||
|
use futures::prelude::*;
|
||||||
|
use futures::task::{Context, Poll};
|
||||||
|
use hyper::http;
|
||||||
|
use hyper::server::conn::Http;
|
||||||
|
use hyper::{Body, Method, Request, Response, StatusCode};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tokio::net::{TcpListener, UdpSocket};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct DoH {
|
||||||
|
pub globals: Arc<Globals>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
|
||||||
|
let response = Response::builder()
|
||||||
|
.status(status_code)
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl hyper::service::Service<http::Request<Body>> for DoH {
|
||||||
|
type Response = Response<Body>;
|
||||||
|
type Error = http::Error;
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||||
|
let globals = &self.globals;
|
||||||
|
if req.uri().path() != globals.path {
|
||||||
|
return Box::pin(async { http_error(StatusCode::NOT_FOUND) });
|
||||||
|
}
|
||||||
|
let self_inner = self.clone();
|
||||||
|
match *req.method() {
|
||||||
|
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
|
||||||
|
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
|
||||||
|
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DoH {
|
||||||
|
async fn serve_post(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
|
||||||
|
if self.globals.disable_post {
|
||||||
|
return http_error(StatusCode::METHOD_NOT_ALLOWED);
|
||||||
|
}
|
||||||
|
if let Err(response) = Self::check_content_type(&req) {
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
match self.read_body_and_proxy(req.into_body()).await {
|
||||||
|
Err(e) => http_error(StatusCode::from(e)),
|
||||||
|
Ok(res) => Ok(res),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn serve_get(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
|
||||||
|
let query = req.uri().query().unwrap_or("");
|
||||||
|
let mut question_str = None;
|
||||||
|
for parts in query.split('&') {
|
||||||
|
let mut kv = parts.split('=');
|
||||||
|
if let Some(k) = kv.next() {
|
||||||
|
if k == DNS_QUERY_PARAM {
|
||||||
|
question_str = kv.next();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let question = match question_str.and_then(|question_str| {
|
||||||
|
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
|
||||||
|
}) {
|
||||||
|
Some(question) => question,
|
||||||
|
_ => {
|
||||||
|
return http_error(StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match self.proxy(question).await {
|
||||||
|
Err(e) => http_error(StatusCode::from(e)),
|
||||||
|
Ok(res) => Ok(res),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_content_type(req: &Request<Body>) -> Result<(), Response<Body>> {
|
||||||
|
let headers = req.headers();
|
||||||
|
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
|
||||||
|
None => {
|
||||||
|
let response = Response::builder()
|
||||||
|
.status(StatusCode::NOT_ACCEPTABLE)
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
return Err(response);
|
||||||
|
}
|
||||||
|
Some(content_type) => content_type.to_str(),
|
||||||
|
};
|
||||||
|
let content_type = match content_type {
|
||||||
|
Err(_) => {
|
||||||
|
let response = Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
return Err(response);
|
||||||
|
}
|
||||||
|
Ok(content_type) => content_type.to_lowercase(),
|
||||||
|
};
|
||||||
|
if content_type != "application/dns-message" {
|
||||||
|
let response = Response::builder()
|
||||||
|
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
return Err(response);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn read_body_and_proxy(&self, mut body: Body) -> Result<Response<Body>, DoHError> {
|
||||||
|
let mut sum_size = 0;
|
||||||
|
let mut query = vec![];
|
||||||
|
while let Some(chunk) = body.next().await {
|
||||||
|
let chunk = chunk.map_err(|_| DoHError::TooLarge)?;
|
||||||
|
sum_size += chunk.len();
|
||||||
|
if sum_size >= MAX_DNS_QUESTION_LEN {
|
||||||
|
return Err(DoHError::TooLarge);
|
||||||
|
}
|
||||||
|
query.extend(chunk);
|
||||||
|
}
|
||||||
|
let response = self.proxy(query).await?;
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn proxy(&self, mut query: Vec<u8>) -> Result<Response<Body>, DoHError> {
|
||||||
|
if query.len() < MIN_DNS_PACKET_LEN {
|
||||||
|
return Err(DoHError::Incomplete);
|
||||||
|
}
|
||||||
|
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _);
|
||||||
|
let globals = &self.globals;
|
||||||
|
let mut socket = UdpSocket::bind(&globals.local_bind_address)
|
||||||
|
.await
|
||||||
|
.map_err(DoHError::Io)?;
|
||||||
|
let expected_server_address = globals.server_address;
|
||||||
|
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
|
||||||
|
socket
|
||||||
|
.send_to(&query, &globals.server_address)
|
||||||
|
.map_err(DoHError::Io)
|
||||||
|
.await?;
|
||||||
|
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
|
||||||
|
let (len, response_server_address) =
|
||||||
|
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
|
||||||
|
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
|
||||||
|
return Err(DoHError::UpstreamIssue);
|
||||||
|
}
|
||||||
|
packet.truncate(len);
|
||||||
|
let ttl = if dns::is_recoverable_error(&packet) {
|
||||||
|
err_ttl
|
||||||
|
} else {
|
||||||
|
match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) {
|
||||||
|
Err(_) => return Err(DoHError::UpstreamIssue),
|
||||||
|
Ok(ttl) => ttl,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
dns::add_edns_padding(&mut packet, BLOCK_SIZE).map_err(|_| DoHError::TooLarge)?;
|
||||||
|
let packet_len = packet.len();
|
||||||
|
let response = Response::builder()
|
||||||
|
.header(hyper::header::CONTENT_LENGTH, packet_len)
|
||||||
|
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
|
||||||
|
.header(
|
||||||
|
hyper::header::CACHE_CONTROL,
|
||||||
|
format!("max-age={}", ttl).as_str(),
|
||||||
|
)
|
||||||
|
.body(Body::from(packet))
|
||||||
|
.unwrap();
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn client_serve<I>(self, stream: I, server: Http)
|
||||||
|
where
|
||||||
|
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||||
|
{
|
||||||
|
let clients_count = self.globals.clients_count.clone();
|
||||||
|
if clients_count.increment() > self.globals.max_clients {
|
||||||
|
clients_count.decrement();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
clients_count.decrement();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_without_tls(
|
||||||
|
self,
|
||||||
|
mut listener: TcpListener,
|
||||||
|
server: Http,
|
||||||
|
) -> Result<(), DoHError> {
|
||||||
|
let listener_service = async {
|
||||||
|
while let Some(stream) = listener.incoming().next().await {
|
||||||
|
let stream = match stream {
|
||||||
|
Ok(stream) => stream,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
self.clone().client_serve(stream, server.clone()).await;
|
||||||
|
}
|
||||||
|
Ok(()) as Result<(), DoHError>
|
||||||
|
};
|
||||||
|
listener_service.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn entrypoint(self) -> Result<(), DoHError> {
|
||||||
|
let listen_address = self.globals.listen_address;
|
||||||
|
let listener = TcpListener::bind(&listen_address)
|
||||||
|
.await
|
||||||
|
.map_err(DoHError::Io)?;
|
||||||
|
let path = &self.globals.path;
|
||||||
|
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
|
||||||
|
(Some(tls_cert_path), Some(tls_cert_password)) => {
|
||||||
|
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
#[cfg(not(feature = "tls"))]
|
||||||
|
let tls_acceptor: Option<()> = None;
|
||||||
|
|
||||||
|
if tls_acceptor.is_some() {
|
||||||
|
println!("Listening on https://{}{}", listen_address, path);
|
||||||
|
} else {
|
||||||
|
println!("Listening on http://{}{}", listen_address, path);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut server = Http::new();
|
||||||
|
server.keep_alive(self.globals.keepalive);
|
||||||
|
server.pipeline_flush(true);
|
||||||
|
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
{
|
||||||
|
if let Some(tls_acceptor) = tls_acceptor {
|
||||||
|
self.start_with_tls(tls_acceptor, listener, server).await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.start_without_tls(listener, server).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
263
src/main.rs
263
src/main.rs
|
@ -6,273 +6,16 @@ extern crate clap;
|
||||||
|
|
||||||
mod config;
|
mod config;
|
||||||
mod constants;
|
mod constants;
|
||||||
mod dns;
|
|
||||||
mod errors;
|
|
||||||
mod globals;
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
mod tls;
|
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
|
use libdoh::*;
|
||||||
|
|
||||||
use crate::config::*;
|
use crate::config::*;
|
||||||
use crate::constants::*;
|
use crate::constants::*;
|
||||||
use crate::errors::*;
|
|
||||||
use crate::globals::*;
|
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
use crate::tls::*;
|
|
||||||
|
|
||||||
use futures::prelude::*;
|
|
||||||
use futures::task::{Context, Poll};
|
|
||||||
use hyper::http;
|
|
||||||
use hyper::server::conn::Http;
|
|
||||||
use hyper::{Body, Method, Request, Response, StatusCode};
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use std::pin::Pin;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
|
||||||
use tokio::net::{TcpListener, UdpSocket};
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct DoH {
|
|
||||||
pub globals: Arc<Globals>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
|
|
||||||
let response = Response::builder()
|
|
||||||
.status(status_code)
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl hyper::service::Service<http::Request<Body>> for DoH {
|
|
||||||
type Response = Response<Body>;
|
|
||||||
type Error = http::Error;
|
|
||||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
|
||||||
|
|
||||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
|
||||||
let globals = &self.globals;
|
|
||||||
if req.uri().path() != globals.path {
|
|
||||||
return Box::pin(async { http_error(StatusCode::NOT_FOUND) });
|
|
||||||
}
|
|
||||||
let self_inner = self.clone();
|
|
||||||
match *req.method() {
|
|
||||||
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
|
|
||||||
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
|
|
||||||
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DoH {
|
|
||||||
async fn serve_post(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
|
|
||||||
if self.globals.disable_post {
|
|
||||||
return http_error(StatusCode::METHOD_NOT_ALLOWED);
|
|
||||||
}
|
|
||||||
if let Err(response) = Self::check_content_type(&req) {
|
|
||||||
return Ok(response);
|
|
||||||
}
|
|
||||||
match self.read_body_and_proxy(req.into_body()).await {
|
|
||||||
Err(e) => http_error(StatusCode::from(e)),
|
|
||||||
Ok(res) => Ok(res),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn serve_get(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
|
|
||||||
let query = req.uri().query().unwrap_or("");
|
|
||||||
let mut question_str = None;
|
|
||||||
for parts in query.split('&') {
|
|
||||||
let mut kv = parts.split('=');
|
|
||||||
if let Some(k) = kv.next() {
|
|
||||||
if k == DNS_QUERY_PARAM {
|
|
||||||
question_str = kv.next();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let question = match question_str.and_then(|question_str| {
|
|
||||||
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
|
|
||||||
}) {
|
|
||||||
Some(question) => question,
|
|
||||||
_ => {
|
|
||||||
return http_error(StatusCode::BAD_REQUEST);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
match self.proxy(question).await {
|
|
||||||
Err(e) => http_error(StatusCode::from(e)),
|
|
||||||
Ok(res) => Ok(res),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn check_content_type(req: &Request<Body>) -> Result<(), Response<Body>> {
|
|
||||||
let headers = req.headers();
|
|
||||||
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
|
|
||||||
None => {
|
|
||||||
let response = Response::builder()
|
|
||||||
.status(StatusCode::NOT_ACCEPTABLE)
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
return Err(response);
|
|
||||||
}
|
|
||||||
Some(content_type) => content_type.to_str(),
|
|
||||||
};
|
|
||||||
let content_type = match content_type {
|
|
||||||
Err(_) => {
|
|
||||||
let response = Response::builder()
|
|
||||||
.status(StatusCode::BAD_REQUEST)
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
return Err(response);
|
|
||||||
}
|
|
||||||
Ok(content_type) => content_type.to_lowercase(),
|
|
||||||
};
|
|
||||||
if content_type != "application/dns-message" {
|
|
||||||
let response = Response::builder()
|
|
||||||
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
return Err(response);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn read_body_and_proxy(&self, mut body: Body) -> Result<Response<Body>, DoHError> {
|
|
||||||
let mut sum_size = 0;
|
|
||||||
let mut query = vec![];
|
|
||||||
while let Some(chunk) = body.next().await {
|
|
||||||
let chunk = chunk.map_err(|_| DoHError::TooLarge)?;
|
|
||||||
sum_size += chunk.len();
|
|
||||||
if sum_size >= MAX_DNS_QUESTION_LEN {
|
|
||||||
return Err(DoHError::TooLarge);
|
|
||||||
}
|
|
||||||
query.extend(chunk);
|
|
||||||
}
|
|
||||||
let response = self.proxy(query).await?;
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn proxy(&self, mut query: Vec<u8>) -> Result<Response<Body>, DoHError> {
|
|
||||||
if query.len() < MIN_DNS_PACKET_LEN {
|
|
||||||
return Err(DoHError::Incomplete);
|
|
||||||
}
|
|
||||||
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _);
|
|
||||||
let globals = &self.globals;
|
|
||||||
let mut socket = UdpSocket::bind(&globals.local_bind_address)
|
|
||||||
.await
|
|
||||||
.map_err(DoHError::Io)?;
|
|
||||||
let expected_server_address = globals.server_address;
|
|
||||||
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
|
|
||||||
socket
|
|
||||||
.send_to(&query, &globals.server_address)
|
|
||||||
.map_err(DoHError::Io)
|
|
||||||
.await?;
|
|
||||||
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
|
|
||||||
let (len, response_server_address) =
|
|
||||||
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
|
|
||||||
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
|
|
||||||
return Err(DoHError::UpstreamIssue);
|
|
||||||
}
|
|
||||||
packet.truncate(len);
|
|
||||||
let ttl = if dns::is_recoverable_error(&packet) {
|
|
||||||
err_ttl
|
|
||||||
} else {
|
|
||||||
match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) {
|
|
||||||
Err(_) => return Err(DoHError::UpstreamIssue),
|
|
||||||
Ok(ttl) => ttl,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
dns::add_edns_padding(&mut packet, BLOCK_SIZE).map_err(|_| DoHError::TooLarge)?;
|
|
||||||
let packet_len = packet.len();
|
|
||||||
let response = Response::builder()
|
|
||||||
.header(hyper::header::CONTENT_LENGTH, packet_len)
|
|
||||||
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
|
|
||||||
.header(
|
|
||||||
hyper::header::CACHE_CONTROL,
|
|
||||||
format!("max-age={}", ttl).as_str(),
|
|
||||||
)
|
|
||||||
.body(Body::from(packet))
|
|
||||||
.unwrap();
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn client_serve<I>(self, stream: I, server: Http)
|
|
||||||
where
|
|
||||||
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
|
||||||
{
|
|
||||||
let clients_count = self.globals.clients_count.clone();
|
|
||||||
if clients_count.increment() > self.globals.max_clients {
|
|
||||||
clients_count.decrement();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
clients_count.decrement();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn start_without_tls(
|
|
||||||
self,
|
|
||||||
mut listener: TcpListener,
|
|
||||||
server: Http,
|
|
||||||
) -> Result<(), DoHError> {
|
|
||||||
let listener_service = async {
|
|
||||||
while let Some(stream) = listener.incoming().next().await {
|
|
||||||
let stream = match stream {
|
|
||||||
Ok(stream) => stream,
|
|
||||||
Err(_) => continue,
|
|
||||||
};
|
|
||||||
self.clone().client_serve(stream, server.clone()).await;
|
|
||||||
}
|
|
||||||
Ok(()) as Result<(), DoHError>
|
|
||||||
};
|
|
||||||
listener_service.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn entrypoint(self) -> Result<(), DoHError> {
|
|
||||||
let listen_address = self.globals.listen_address;
|
|
||||||
let listener = TcpListener::bind(&listen_address)
|
|
||||||
.await
|
|
||||||
.map_err(DoHError::Io)?;
|
|
||||||
let path = &self.globals.path;
|
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
|
|
||||||
(Some(tls_cert_path), Some(tls_cert_password)) => {
|
|
||||||
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
#[cfg(not(feature = "tls"))]
|
|
||||||
let tls_acceptor: Option<()> = None;
|
|
||||||
|
|
||||||
if tls_acceptor.is_some() {
|
|
||||||
println!("Listening on https://{}{}", listen_address, path);
|
|
||||||
} else {
|
|
||||||
println!("Listening on http://{}{}", listen_address, path);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut server = Http::new();
|
|
||||||
server.keep_alive(self.globals.keepalive);
|
|
||||||
server.pipeline_flush(true);
|
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
{
|
|
||||||
if let Some(tls_acceptor) = tls_acceptor {
|
|
||||||
self.start_with_tls(tls_acceptor, listener, server).await?;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.start_without_tls(listener, server).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let mut globals = Globals {
|
let mut globals = Globals {
|
||||||
|
@ -287,7 +30,7 @@ fn main() {
|
||||||
path: PATH.to_string(),
|
path: PATH.to_string(),
|
||||||
max_clients: MAX_CLIENTS,
|
max_clients: MAX_CLIENTS,
|
||||||
timeout: Duration::from_secs(TIMEOUT_SEC),
|
timeout: Duration::from_secs(TIMEOUT_SEC),
|
||||||
clients_count: ClientsCount::default(),
|
clients_count: Default::default(),
|
||||||
min_ttl: MIN_TTL,
|
min_ttl: MIN_TTL,
|
||||||
max_ttl: MAX_TTL,
|
max_ttl: MAX_TTL,
|
||||||
err_ttl: ERR_TTL,
|
err_ttl: ERR_TTL,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue