mirror of
https://github.com/DNSCrypt/doh-server.git
synced 2025-04-03 04:57:37 +03:00
Refactor
This commit is contained in:
parent
1b850b2f41
commit
0d55bf73c6
1 changed files with 79 additions and 36 deletions
115
src/main.rs
115
src/main.rs
|
@ -27,6 +27,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV
|
|||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::{TcpListener, UdpSocket};
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
|
@ -250,55 +251,97 @@ impl DoH {
|
|||
Ok(response)
|
||||
}
|
||||
|
||||
async fn entrypoint(self) -> Result<(), Error> {
|
||||
let listen_address = self.globals.listen_address;
|
||||
let mut listener = TcpListener::bind(&listen_address).await?;
|
||||
let path = &self.globals.path;
|
||||
async fn client_serve<I>(self, stream: I, server: Http)
|
||||
where
|
||||
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
let clients_count = self.globals.clients_count.clone();
|
||||
if clients_count.increment() > self.globals.max_clients {
|
||||
clients_count.decrement();
|
||||
return;
|
||||
}
|
||||
tokio::spawn(async move {
|
||||
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
|
||||
.await
|
||||
.ok();
|
||||
clients_count.decrement();
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
|
||||
(Some(tls_cert_path), Some(tls_cert_password)) => {
|
||||
println!("Listening on https://{}{}", listen_address, path);
|
||||
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
|
||||
}
|
||||
_ => {
|
||||
println!("Listening on http://{}{}", listen_address, path);
|
||||
None
|
||||
}
|
||||
};
|
||||
#[cfg(not(feature = "tls"))]
|
||||
println!("Listening on http://{}{}", listen_address, path);
|
||||
|
||||
let mut server = Http::new();
|
||||
server.keep_alive(self.globals.keepalive);
|
||||
async fn start_without_tls(self, mut listener: TcpListener, server: Http) -> Result<(), Error> {
|
||||
let listener_service = async {
|
||||
while let Some(stream) = listener.incoming().next().await {
|
||||
let stream = match stream {
|
||||
Ok(stream) => stream,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let clients_count = self.globals.clients_count.clone();
|
||||
if clients_count.increment() > self.globals.max_clients {
|
||||
clients_count.decrement();
|
||||
continue;
|
||||
}
|
||||
let self_inner = self.clone();
|
||||
let server_inner = server.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::time::timeout(
|
||||
self_inner.globals.timeout,
|
||||
server_inner.serve_connection(stream, self_inner),
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
clients_count.decrement();
|
||||
});
|
||||
self.clone().client_serve(stream, server.clone()).await;
|
||||
}
|
||||
Ok(()) as Result<(), Error>
|
||||
};
|
||||
listener_service.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
async fn start_with_tls(
|
||||
self,
|
||||
tls_acceptor: TlsAcceptor,
|
||||
mut listener: TcpListener,
|
||||
server: Http,
|
||||
) -> Result<(), Error> {
|
||||
let listener_service = async {
|
||||
while let Some(raw_stream) = listener.incoming().next().await {
|
||||
let raw_stream = match raw_stream {
|
||||
Ok(raw_stream) => raw_stream,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let stream = match tls_acceptor.accept(raw_stream).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_) => continue,
|
||||
};
|
||||
self.clone().client_serve(stream, server.clone()).await;
|
||||
}
|
||||
Ok(()) as Result<(), Error>
|
||||
};
|
||||
listener_service.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn entrypoint(self) -> Result<(), Error> {
|
||||
let listen_address = self.globals.listen_address;
|
||||
let listener = TcpListener::bind(&listen_address).await?;
|
||||
let path = &self.globals.path;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
|
||||
(Some(tls_cert_path), Some(tls_cert_password)) => {
|
||||
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
#[cfg(not(feature = "tls"))]
|
||||
let tls_acceptor: Option<()> = None;
|
||||
|
||||
if tls_acceptor.is_some() {
|
||||
println!("Listening on https://{}{}", listen_address, path);
|
||||
} else {
|
||||
println!("Listening on http://{}{}", listen_address, path);
|
||||
}
|
||||
|
||||
let mut server = Http::new();
|
||||
server.keep_alive(self.globals.keepalive);
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
{
|
||||
if let Some(tls_acceptor) = tls_acceptor {
|
||||
self.start_with_tls(tls_acceptor, listener, server).await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
self.start_without_tls(listener, server).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue