From 0d55bf73c6f6c7334a607d1664e2cc70c2ffbb95 Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Mon, 23 Dec 2019 16:32:02 +0100 Subject: [PATCH] Refactor --- src/main.rs | 115 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 79 insertions(+), 36 deletions(-) diff --git a/src/main.rs b/src/main.rs index 06f6bfb..6bac0db 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,6 +27,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV use std::pin::Pin; use std::sync::Arc; use std::time::Duration; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, UdpSocket}; #[cfg(feature = "tls")] @@ -250,55 +251,97 @@ impl DoH { Ok(response) } - async fn entrypoint(self) -> Result<(), Error> { - let listen_address = self.globals.listen_address; - let mut listener = TcpListener::bind(&listen_address).await?; - let path = &self.globals.path; + async fn client_serve(self, stream: I, server: Http) + where + I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + let clients_count = self.globals.clients_count.clone(); + if clients_count.increment() > self.globals.max_clients { + clients_count.decrement(); + return; + } + tokio::spawn(async move { + tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self)) + .await + .ok(); + clients_count.decrement(); + }); + } - #[cfg(feature = "tls")] - let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) { - (Some(tls_cert_path), Some(tls_cert_password)) => { - println!("Listening on https://{}{}", listen_address, path); - Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap()) - } - _ => { - println!("Listening on http://{}{}", listen_address, path); - None - } - }; - #[cfg(not(feature = "tls"))] - println!("Listening on http://{}{}", listen_address, path); - - let mut server = Http::new(); - server.keep_alive(self.globals.keepalive); + async fn start_without_tls(self, mut listener: TcpListener, server: Http) -> Result<(), Error> { let listener_service = async { while let Some(stream) = listener.incoming().next().await { let stream = match stream { Ok(stream) => stream, Err(_) => continue, }; - let clients_count = self.globals.clients_count.clone(); - if clients_count.increment() > self.globals.max_clients { - clients_count.decrement(); - continue; - } - let self_inner = self.clone(); - let server_inner = server.clone(); - tokio::spawn(async move { - tokio::time::timeout( - self_inner.globals.timeout, - server_inner.serve_connection(stream, self_inner), - ) - .await - .ok(); - clients_count.decrement(); - }); + self.clone().client_serve(stream, server.clone()).await; } Ok(()) as Result<(), Error> }; listener_service.await?; Ok(()) } + + #[cfg(feature = "tls")] + async fn start_with_tls( + self, + tls_acceptor: TlsAcceptor, + mut listener: TcpListener, + server: Http, + ) -> Result<(), Error> { + let listener_service = async { + while let Some(raw_stream) = listener.incoming().next().await { + let raw_stream = match raw_stream { + Ok(raw_stream) => raw_stream, + Err(_) => continue, + }; + let stream = match tls_acceptor.accept(raw_stream).await { + Ok(stream) => stream, + Err(_) => continue, + }; + self.clone().client_serve(stream, server.clone()).await; + } + Ok(()) as Result<(), Error> + }; + listener_service.await?; + Ok(()) + } + + async fn entrypoint(self) -> Result<(), Error> { + let listen_address = self.globals.listen_address; + let listener = TcpListener::bind(&listen_address).await?; + let path = &self.globals.path; + + #[cfg(feature = "tls")] + let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) { + (Some(tls_cert_path), Some(tls_cert_password)) => { + Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap()) + } + _ => None, + }; + #[cfg(not(feature = "tls"))] + let tls_acceptor: Option<()> = None; + + if tls_acceptor.is_some() { + println!("Listening on https://{}{}", listen_address, path); + } else { + println!("Listening on http://{}{}", listen_address, path); + } + + let mut server = Http::new(); + server.keep_alive(self.globals.keepalive); + + #[cfg(feature = "tls")] + { + if let Some(tls_acceptor) = tls_acceptor { + self.start_with_tls(tls_acceptor, listener, server).await?; + return Ok(()); + } + } + self.start_without_tls(listener, server).await?; + Ok(()) + } } fn main() {