This commit is contained in:
Frank Denis 2019-12-23 16:32:02 +01:00
parent 1b850b2f41
commit 0d55bf73c6

View file

@ -27,6 +27,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UdpSocket};
#[cfg(feature = "tls")]
@ -250,55 +251,97 @@ impl DoH {
Ok(response)
}
async fn entrypoint(self) -> Result<(), Error> {
let listen_address = self.globals.listen_address;
let mut listener = TcpListener::bind(&listen_address).await?;
let path = &self.globals.path;
async fn client_serve<I>(self, stream: I, server: Http)
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let clients_count = self.globals.clients_count.clone();
if clients_count.increment() > self.globals.max_clients {
clients_count.decrement();
return;
}
tokio::spawn(async move {
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
.await
.ok();
clients_count.decrement();
});
}
#[cfg(feature = "tls")]
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
(Some(tls_cert_path), Some(tls_cert_password)) => {
println!("Listening on https://{}{}", listen_address, path);
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
}
_ => {
println!("Listening on http://{}{}", listen_address, path);
None
}
};
#[cfg(not(feature = "tls"))]
println!("Listening on http://{}{}", listen_address, path);
let mut server = Http::new();
server.keep_alive(self.globals.keepalive);
async fn start_without_tls(self, mut listener: TcpListener, server: Http) -> Result<(), Error> {
let listener_service = async {
while let Some(stream) = listener.incoming().next().await {
let stream = match stream {
Ok(stream) => stream,
Err(_) => continue,
};
let clients_count = self.globals.clients_count.clone();
if clients_count.increment() > self.globals.max_clients {
clients_count.decrement();
continue;
}
let self_inner = self.clone();
let server_inner = server.clone();
tokio::spawn(async move {
tokio::time::timeout(
self_inner.globals.timeout,
server_inner.serve_connection(stream, self_inner),
)
.await
.ok();
clients_count.decrement();
});
self.clone().client_serve(stream, server.clone()).await;
}
Ok(()) as Result<(), Error>
};
listener_service.await?;
Ok(())
}
#[cfg(feature = "tls")]
async fn start_with_tls(
self,
tls_acceptor: TlsAcceptor,
mut listener: TcpListener,
server: Http,
) -> Result<(), Error> {
let listener_service = async {
while let Some(raw_stream) = listener.incoming().next().await {
let raw_stream = match raw_stream {
Ok(raw_stream) => raw_stream,
Err(_) => continue,
};
let stream = match tls_acceptor.accept(raw_stream).await {
Ok(stream) => stream,
Err(_) => continue,
};
self.clone().client_serve(stream, server.clone()).await;
}
Ok(()) as Result<(), Error>
};
listener_service.await?;
Ok(())
}
async fn entrypoint(self) -> Result<(), Error> {
let listen_address = self.globals.listen_address;
let listener = TcpListener::bind(&listen_address).await?;
let path = &self.globals.path;
#[cfg(feature = "tls")]
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
(Some(tls_cert_path), Some(tls_cert_password)) => {
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
}
_ => None,
};
#[cfg(not(feature = "tls"))]
let tls_acceptor: Option<()> = None;
if tls_acceptor.is_some() {
println!("Listening on https://{}{}", listen_address, path);
} else {
println!("Listening on http://{}{}", listen_address, path);
}
let mut server = Http::new();
server.keep_alive(self.globals.keepalive);
#[cfg(feature = "tls")]
{
if let Some(tls_acceptor) = tls_acceptor {
self.start_with_tls(tls_acceptor, listener, server).await?;
return Ok(());
}
}
self.start_without_tls(listener, server).await?;
Ok(())
}
}
fn main() {