mirror of
https://github.com/DNSCrypt/doh-server.git
synced 2025-03-31 11:47:36 +03:00
Reorganize a little bit
This commit is contained in:
parent
bf42e95368
commit
06b91af009
12 changed files with 312 additions and 275 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,5 +1,7 @@
|
|||
#*#
|
||||
**/*.rs.bk
|
||||
*~
|
||||
Cargo.lock
|
||||
/target/
|
||||
/src/libdoh/target/
|
||||
|
||||
|
|
12
Cargo.toml
12
Cargo.toml
|
@ -13,19 +13,13 @@ readme = "README.md"
|
|||
|
||||
[features]
|
||||
default = []
|
||||
tls = ["native-tls", "tokio-tls"]
|
||||
tls = ["libdoh/tls"]
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
byteorder = "1.3"
|
||||
base64 = "0.11"
|
||||
clap = "2.33.0"
|
||||
futures = { version = "0.3" }
|
||||
hyper = { version = "0.13", default-features = false, features = ["stream"] }
|
||||
libdoh = { path = "src/libdoh" }
|
||||
clap = "2"
|
||||
jemallocator = "0"
|
||||
native-tls = { version = "0.2.3", optional = true }
|
||||
tokio = { version = "0.2", features = ["rt-threaded", "time", "tcp", "udp", "stream"] }
|
||||
tokio-tls = { version = "0.3", optional = true }
|
||||
|
||||
[package.metadata.deb]
|
||||
extended-description = """\
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use libdoh::*;
|
||||
|
||||
use crate::constants::*;
|
||||
use crate::globals::*;
|
||||
|
||||
use clap::Arg;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
|
||||
|
|
|
@ -1,10 +1,5 @@
|
|||
pub const BLOCK_SIZE: usize = 128;
|
||||
pub const DNS_QUERY_PARAM: &str = "dns";
|
||||
pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
|
||||
pub const MAX_CLIENTS: usize = 512;
|
||||
pub const MAX_DNS_QUESTION_LEN: usize = 512;
|
||||
pub const MAX_DNS_RESPONSE_LEN: usize = 4096;
|
||||
pub const MIN_DNS_PACKET_LEN: usize = 17;
|
||||
pub const PATH: &str = "/dns-query";
|
||||
pub const SERVER_ADDRESS: &str = "9.9.9.9:53";
|
||||
pub const TIMEOUT_SEC: u64 = 10;
|
||||
|
|
28
src/libdoh/Cargo.toml
Normal file
28
src/libdoh/Cargo.toml
Normal file
|
@ -0,0 +1,28 @@
|
|||
[package]
|
||||
name = "libdoh"
|
||||
version = "0.2.0"
|
||||
authors = ["Frank Denis <github@pureftpd.org>"]
|
||||
license = "MIT"
|
||||
edition = "2018"
|
||||
publish = false
|
||||
|
||||
[features]
|
||||
default = []
|
||||
tls = ["native-tls", "tokio-tls"]
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
byteorder = "1.3"
|
||||
base64 = "0.11"
|
||||
futures = { version = "0.3" }
|
||||
hyper = { version = "0.13", default-features = false, features = ["stream"] }
|
||||
native-tls = { version = "0.2.3", optional = true }
|
||||
tokio = { version = "0.2", features = ["rt-threaded", "time", "tcp", "udp", "stream"] }
|
||||
tokio-tls = { version = "0.3", optional = true }
|
||||
|
||||
[profile.release]
|
||||
codegen-units = 1
|
||||
incremental = false
|
||||
lto = "fat"
|
||||
opt-level = 3
|
||||
panic = "abort"
|
5
src/libdoh/src/constants.rs
Normal file
5
src/libdoh/src/constants.rs
Normal file
|
@ -0,0 +1,5 @@
|
|||
pub const BLOCK_SIZE: usize = 128;
|
||||
pub const DNS_QUERY_PARAM: &str = "dns";
|
||||
pub const MAX_DNS_QUESTION_LEN: usize = 512;
|
||||
pub const MAX_DNS_RESPONSE_LEN: usize = 4096;
|
||||
pub const MIN_DNS_PACKET_LEN: usize = 17;
|
|
@ -243,6 +243,11 @@ pub fn add_edns_padding(packet: &mut Vec<u8>, block_size: usize) -> Result<(), E
|
|||
let edns_rdlen_offset: usize = edns_offset + 8;
|
||||
ensure!(packet_len - edns_rdlen_offset >= 2, "Short packet");
|
||||
let edns_rdlen = BigEndian::read_u16(&packet[edns_rdlen_offset..]);
|
||||
dbg!(edns_rdlen);
|
||||
ensure!(
|
||||
edns_offset + edns_rdlen as usize <= packet_len,
|
||||
"Out of range EDNS size"
|
||||
);
|
||||
ensure!(
|
||||
0xffff - edns_rdlen as usize >= edns_padding_prr_len,
|
||||
"EDNS section too large for padding"
|
264
src/libdoh/src/lib.rs
Normal file
264
src/libdoh/src/lib.rs
Normal file
|
@ -0,0 +1,264 @@
|
|||
mod constants;
|
||||
mod dns;
|
||||
mod errors;
|
||||
mod globals;
|
||||
#[cfg(feature = "tls")]
|
||||
mod tls;
|
||||
|
||||
use crate::constants::*;
|
||||
pub use crate::errors::*;
|
||||
pub use crate::globals::*;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use crate::tls::*;
|
||||
|
||||
use futures::prelude::*;
|
||||
use futures::task::{Context, Poll};
|
||||
use hyper::http;
|
||||
use hyper::server::conn::Http;
|
||||
use hyper::{Body, Method, Request, Response, StatusCode};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::{TcpListener, UdpSocket};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DoH {
|
||||
pub globals: Arc<Globals>,
|
||||
}
|
||||
|
||||
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
|
||||
let response = Response::builder()
|
||||
.status(status_code)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
impl hyper::service::Service<http::Request<Body>> for DoH {
|
||||
type Response = Response<Body>;
|
||||
type Error = http::Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
let globals = &self.globals;
|
||||
if req.uri().path() != globals.path {
|
||||
return Box::pin(async { http_error(StatusCode::NOT_FOUND) });
|
||||
}
|
||||
let self_inner = self.clone();
|
||||
match *req.method() {
|
||||
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
|
||||
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
|
||||
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DoH {
|
||||
async fn serve_post(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
|
||||
if self.globals.disable_post {
|
||||
return http_error(StatusCode::METHOD_NOT_ALLOWED);
|
||||
}
|
||||
if let Err(response) = Self::check_content_type(&req) {
|
||||
return Ok(response);
|
||||
}
|
||||
match self.read_body_and_proxy(req.into_body()).await {
|
||||
Err(e) => http_error(StatusCode::from(e)),
|
||||
Ok(res) => Ok(res),
|
||||
}
|
||||
}
|
||||
|
||||
async fn serve_get(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
|
||||
let query = req.uri().query().unwrap_or("");
|
||||
let mut question_str = None;
|
||||
for parts in query.split('&') {
|
||||
let mut kv = parts.split('=');
|
||||
if let Some(k) = kv.next() {
|
||||
if k == DNS_QUERY_PARAM {
|
||||
question_str = kv.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
let question = match question_str.and_then(|question_str| {
|
||||
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
|
||||
}) {
|
||||
Some(question) => question,
|
||||
_ => {
|
||||
return http_error(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
};
|
||||
match self.proxy(question).await {
|
||||
Err(e) => http_error(StatusCode::from(e)),
|
||||
Ok(res) => Ok(res),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_content_type(req: &Request<Body>) -> Result<(), Response<Body>> {
|
||||
let headers = req.headers();
|
||||
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
|
||||
None => {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::NOT_ACCEPTABLE)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Err(response);
|
||||
}
|
||||
Some(content_type) => content_type.to_str(),
|
||||
};
|
||||
let content_type = match content_type {
|
||||
Err(_) => {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Err(response);
|
||||
}
|
||||
Ok(content_type) => content_type.to_lowercase(),
|
||||
};
|
||||
if content_type != "application/dns-message" {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Err(response);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_body_and_proxy(&self, mut body: Body) -> Result<Response<Body>, DoHError> {
|
||||
let mut sum_size = 0;
|
||||
let mut query = vec![];
|
||||
while let Some(chunk) = body.next().await {
|
||||
let chunk = chunk.map_err(|_| DoHError::TooLarge)?;
|
||||
sum_size += chunk.len();
|
||||
if sum_size >= MAX_DNS_QUESTION_LEN {
|
||||
return Err(DoHError::TooLarge);
|
||||
}
|
||||
query.extend(chunk);
|
||||
}
|
||||
let response = self.proxy(query).await?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn proxy(&self, mut query: Vec<u8>) -> Result<Response<Body>, DoHError> {
|
||||
if query.len() < MIN_DNS_PACKET_LEN {
|
||||
return Err(DoHError::Incomplete);
|
||||
}
|
||||
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _);
|
||||
let globals = &self.globals;
|
||||
let mut socket = UdpSocket::bind(&globals.local_bind_address)
|
||||
.await
|
||||
.map_err(DoHError::Io)?;
|
||||
let expected_server_address = globals.server_address;
|
||||
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
|
||||
socket
|
||||
.send_to(&query, &globals.server_address)
|
||||
.map_err(DoHError::Io)
|
||||
.await?;
|
||||
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
|
||||
let (len, response_server_address) =
|
||||
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
|
||||
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
|
||||
return Err(DoHError::UpstreamIssue);
|
||||
}
|
||||
packet.truncate(len);
|
||||
let ttl = if dns::is_recoverable_error(&packet) {
|
||||
err_ttl
|
||||
} else {
|
||||
match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) {
|
||||
Err(_) => return Err(DoHError::UpstreamIssue),
|
||||
Ok(ttl) => ttl,
|
||||
}
|
||||
};
|
||||
dns::add_edns_padding(&mut packet, BLOCK_SIZE).map_err(|_| DoHError::TooLarge)?;
|
||||
let packet_len = packet.len();
|
||||
let response = Response::builder()
|
||||
.header(hyper::header::CONTENT_LENGTH, packet_len)
|
||||
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
|
||||
.header(
|
||||
hyper::header::CACHE_CONTROL,
|
||||
format!("max-age={}", ttl).as_str(),
|
||||
)
|
||||
.body(Body::from(packet))
|
||||
.unwrap();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn client_serve<I>(self, stream: I, server: Http)
|
||||
where
|
||||
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
let clients_count = self.globals.clients_count.clone();
|
||||
if clients_count.increment() > self.globals.max_clients {
|
||||
clients_count.decrement();
|
||||
return;
|
||||
}
|
||||
tokio::spawn(async move {
|
||||
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
|
||||
.await
|
||||
.ok();
|
||||
clients_count.decrement();
|
||||
});
|
||||
}
|
||||
|
||||
async fn start_without_tls(
|
||||
self,
|
||||
mut listener: TcpListener,
|
||||
server: Http,
|
||||
) -> Result<(), DoHError> {
|
||||
let listener_service = async {
|
||||
while let Some(stream) = listener.incoming().next().await {
|
||||
let stream = match stream {
|
||||
Ok(stream) => stream,
|
||||
Err(_) => continue,
|
||||
};
|
||||
self.clone().client_serve(stream, server.clone()).await;
|
||||
}
|
||||
Ok(()) as Result<(), DoHError>
|
||||
};
|
||||
listener_service.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn entrypoint(self) -> Result<(), DoHError> {
|
||||
let listen_address = self.globals.listen_address;
|
||||
let listener = TcpListener::bind(&listen_address)
|
||||
.await
|
||||
.map_err(DoHError::Io)?;
|
||||
let path = &self.globals.path;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
|
||||
(Some(tls_cert_path), Some(tls_cert_password)) => {
|
||||
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
#[cfg(not(feature = "tls"))]
|
||||
let tls_acceptor: Option<()> = None;
|
||||
|
||||
if tls_acceptor.is_some() {
|
||||
println!("Listening on https://{}{}", listen_address, path);
|
||||
} else {
|
||||
println!("Listening on http://{}{}", listen_address, path);
|
||||
}
|
||||
|
||||
let mut server = Http::new();
|
||||
server.keep_alive(self.globals.keepalive);
|
||||
server.pipeline_flush(true);
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
{
|
||||
if let Some(tls_acceptor) = tls_acceptor {
|
||||
self.start_with_tls(tls_acceptor, listener, server).await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
self.start_without_tls(listener, server).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
263
src/main.rs
263
src/main.rs
|
@ -6,273 +6,16 @@ extern crate clap;
|
|||
|
||||
mod config;
|
||||
mod constants;
|
||||
mod dns;
|
||||
mod errors;
|
||||
mod globals;
|
||||
#[cfg(feature = "tls")]
|
||||
mod tls;
|
||||
mod utils;
|
||||
|
||||
use libdoh::*;
|
||||
|
||||
use crate::config::*;
|
||||
use crate::constants::*;
|
||||
use crate::errors::*;
|
||||
use crate::globals::*;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use crate::tls::*;
|
||||
|
||||
use futures::prelude::*;
|
||||
use futures::task::{Context, Poll};
|
||||
use hyper::http;
|
||||
use hyper::server::conn::Http;
|
||||
use hyper::{Body, Method, Request, Response, StatusCode};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::{TcpListener, UdpSocket};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DoH {
|
||||
pub globals: Arc<Globals>,
|
||||
}
|
||||
|
||||
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
|
||||
let response = Response::builder()
|
||||
.status(status_code)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
impl hyper::service::Service<http::Request<Body>> for DoH {
|
||||
type Response = Response<Body>;
|
||||
type Error = http::Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
let globals = &self.globals;
|
||||
if req.uri().path() != globals.path {
|
||||
return Box::pin(async { http_error(StatusCode::NOT_FOUND) });
|
||||
}
|
||||
let self_inner = self.clone();
|
||||
match *req.method() {
|
||||
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
|
||||
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
|
||||
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DoH {
|
||||
async fn serve_post(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
|
||||
if self.globals.disable_post {
|
||||
return http_error(StatusCode::METHOD_NOT_ALLOWED);
|
||||
}
|
||||
if let Err(response) = Self::check_content_type(&req) {
|
||||
return Ok(response);
|
||||
}
|
||||
match self.read_body_and_proxy(req.into_body()).await {
|
||||
Err(e) => http_error(StatusCode::from(e)),
|
||||
Ok(res) => Ok(res),
|
||||
}
|
||||
}
|
||||
|
||||
async fn serve_get(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
|
||||
let query = req.uri().query().unwrap_or("");
|
||||
let mut question_str = None;
|
||||
for parts in query.split('&') {
|
||||
let mut kv = parts.split('=');
|
||||
if let Some(k) = kv.next() {
|
||||
if k == DNS_QUERY_PARAM {
|
||||
question_str = kv.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
let question = match question_str.and_then(|question_str| {
|
||||
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
|
||||
}) {
|
||||
Some(question) => question,
|
||||
_ => {
|
||||
return http_error(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
};
|
||||
match self.proxy(question).await {
|
||||
Err(e) => http_error(StatusCode::from(e)),
|
||||
Ok(res) => Ok(res),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_content_type(req: &Request<Body>) -> Result<(), Response<Body>> {
|
||||
let headers = req.headers();
|
||||
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
|
||||
None => {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::NOT_ACCEPTABLE)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Err(response);
|
||||
}
|
||||
Some(content_type) => content_type.to_str(),
|
||||
};
|
||||
let content_type = match content_type {
|
||||
Err(_) => {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Err(response);
|
||||
}
|
||||
Ok(content_type) => content_type.to_lowercase(),
|
||||
};
|
||||
if content_type != "application/dns-message" {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Err(response);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_body_and_proxy(&self, mut body: Body) -> Result<Response<Body>, DoHError> {
|
||||
let mut sum_size = 0;
|
||||
let mut query = vec![];
|
||||
while let Some(chunk) = body.next().await {
|
||||
let chunk = chunk.map_err(|_| DoHError::TooLarge)?;
|
||||
sum_size += chunk.len();
|
||||
if sum_size >= MAX_DNS_QUESTION_LEN {
|
||||
return Err(DoHError::TooLarge);
|
||||
}
|
||||
query.extend(chunk);
|
||||
}
|
||||
let response = self.proxy(query).await?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn proxy(&self, mut query: Vec<u8>) -> Result<Response<Body>, DoHError> {
|
||||
if query.len() < MIN_DNS_PACKET_LEN {
|
||||
return Err(DoHError::Incomplete);
|
||||
}
|
||||
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _);
|
||||
let globals = &self.globals;
|
||||
let mut socket = UdpSocket::bind(&globals.local_bind_address)
|
||||
.await
|
||||
.map_err(DoHError::Io)?;
|
||||
let expected_server_address = globals.server_address;
|
||||
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
|
||||
socket
|
||||
.send_to(&query, &globals.server_address)
|
||||
.map_err(DoHError::Io)
|
||||
.await?;
|
||||
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
|
||||
let (len, response_server_address) =
|
||||
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
|
||||
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
|
||||
return Err(DoHError::UpstreamIssue);
|
||||
}
|
||||
packet.truncate(len);
|
||||
let ttl = if dns::is_recoverable_error(&packet) {
|
||||
err_ttl
|
||||
} else {
|
||||
match dns::min_ttl(&packet, min_ttl, max_ttl, err_ttl) {
|
||||
Err(_) => return Err(DoHError::UpstreamIssue),
|
||||
Ok(ttl) => ttl,
|
||||
}
|
||||
};
|
||||
dns::add_edns_padding(&mut packet, BLOCK_SIZE).map_err(|_| DoHError::TooLarge)?;
|
||||
let packet_len = packet.len();
|
||||
let response = Response::builder()
|
||||
.header(hyper::header::CONTENT_LENGTH, packet_len)
|
||||
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
|
||||
.header(
|
||||
hyper::header::CACHE_CONTROL,
|
||||
format!("max-age={}", ttl).as_str(),
|
||||
)
|
||||
.body(Body::from(packet))
|
||||
.unwrap();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn client_serve<I>(self, stream: I, server: Http)
|
||||
where
|
||||
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
let clients_count = self.globals.clients_count.clone();
|
||||
if clients_count.increment() > self.globals.max_clients {
|
||||
clients_count.decrement();
|
||||
return;
|
||||
}
|
||||
tokio::spawn(async move {
|
||||
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
|
||||
.await
|
||||
.ok();
|
||||
clients_count.decrement();
|
||||
});
|
||||
}
|
||||
|
||||
async fn start_without_tls(
|
||||
self,
|
||||
mut listener: TcpListener,
|
||||
server: Http,
|
||||
) -> Result<(), DoHError> {
|
||||
let listener_service = async {
|
||||
while let Some(stream) = listener.incoming().next().await {
|
||||
let stream = match stream {
|
||||
Ok(stream) => stream,
|
||||
Err(_) => continue,
|
||||
};
|
||||
self.clone().client_serve(stream, server.clone()).await;
|
||||
}
|
||||
Ok(()) as Result<(), DoHError>
|
||||
};
|
||||
listener_service.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn entrypoint(self) -> Result<(), DoHError> {
|
||||
let listen_address = self.globals.listen_address;
|
||||
let listener = TcpListener::bind(&listen_address)
|
||||
.await
|
||||
.map_err(DoHError::Io)?;
|
||||
let path = &self.globals.path;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
|
||||
(Some(tls_cert_path), Some(tls_cert_password)) => {
|
||||
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
#[cfg(not(feature = "tls"))]
|
||||
let tls_acceptor: Option<()> = None;
|
||||
|
||||
if tls_acceptor.is_some() {
|
||||
println!("Listening on https://{}{}", listen_address, path);
|
||||
} else {
|
||||
println!("Listening on http://{}{}", listen_address, path);
|
||||
}
|
||||
|
||||
let mut server = Http::new();
|
||||
server.keep_alive(self.globals.keepalive);
|
||||
server.pipeline_flush(true);
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
{
|
||||
if let Some(tls_acceptor) = tls_acceptor {
|
||||
self.start_with_tls(tls_acceptor, listener, server).await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
self.start_without_tls(listener, server).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let mut globals = Globals {
|
||||
|
@ -287,7 +30,7 @@ fn main() {
|
|||
path: PATH.to_string(),
|
||||
max_clients: MAX_CLIENTS,
|
||||
timeout: Duration::from_secs(TIMEOUT_SEC),
|
||||
clients_count: ClientsCount::default(),
|
||||
clients_count: Default::default(),
|
||||
min_ttl: MIN_TTL,
|
||||
max_ttl: MAX_TTL,
|
||||
err_ttl: ERR_TTL,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue