mirror of
https://github.com/DNSCrypt/doh-server.git
synced 2025-04-05 14:07:37 +03:00
Refactor
This commit is contained in:
parent
1b850b2f41
commit
0d55bf73c6
1 changed files with 79 additions and 36 deletions
111
src/main.rs
111
src/main.rs
|
@ -27,6 +27,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tokio::net::{TcpListener, UdpSocket};
|
use tokio::net::{TcpListener, UdpSocket};
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
|
@ -250,55 +251,97 @@ impl DoH {
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn entrypoint(self) -> Result<(), Error> {
|
async fn client_serve<I>(self, stream: I, server: Http)
|
||||||
let listen_address = self.globals.listen_address;
|
where
|
||||||
let mut listener = TcpListener::bind(&listen_address).await?;
|
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||||
let path = &self.globals.path;
|
{
|
||||||
|
let clients_count = self.globals.clients_count.clone();
|
||||||
#[cfg(feature = "tls")]
|
if clients_count.increment() > self.globals.max_clients {
|
||||||
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
|
clients_count.decrement();
|
||||||
(Some(tls_cert_path), Some(tls_cert_password)) => {
|
return;
|
||||||
println!("Listening on https://{}{}", listen_address, path);
|
|
||||||
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
|
|
||||||
}
|
}
|
||||||
_ => {
|
tokio::spawn(async move {
|
||||||
println!("Listening on http://{}{}", listen_address, path);
|
tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self))
|
||||||
None
|
.await
|
||||||
|
.ok();
|
||||||
|
clients_count.decrement();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
};
|
|
||||||
#[cfg(not(feature = "tls"))]
|
|
||||||
println!("Listening on http://{}{}", listen_address, path);
|
|
||||||
|
|
||||||
let mut server = Http::new();
|
async fn start_without_tls(self, mut listener: TcpListener, server: Http) -> Result<(), Error> {
|
||||||
server.keep_alive(self.globals.keepalive);
|
|
||||||
let listener_service = async {
|
let listener_service = async {
|
||||||
while let Some(stream) = listener.incoming().next().await {
|
while let Some(stream) = listener.incoming().next().await {
|
||||||
let stream = match stream {
|
let stream = match stream {
|
||||||
Ok(stream) => stream,
|
Ok(stream) => stream,
|
||||||
Err(_) => continue,
|
Err(_) => continue,
|
||||||
};
|
};
|
||||||
let clients_count = self.globals.clients_count.clone();
|
self.clone().client_serve(stream, server.clone()).await;
|
||||||
if clients_count.increment() > self.globals.max_clients {
|
|
||||||
clients_count.decrement();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let self_inner = self.clone();
|
|
||||||
let server_inner = server.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tokio::time::timeout(
|
|
||||||
self_inner.globals.timeout,
|
|
||||||
server_inner.serve_connection(stream, self_inner),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
clients_count.decrement();
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
Ok(()) as Result<(), Error>
|
Ok(()) as Result<(), Error>
|
||||||
};
|
};
|
||||||
listener_service.await?;
|
listener_service.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
async fn start_with_tls(
|
||||||
|
self,
|
||||||
|
tls_acceptor: TlsAcceptor,
|
||||||
|
mut listener: TcpListener,
|
||||||
|
server: Http,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let listener_service = async {
|
||||||
|
while let Some(raw_stream) = listener.incoming().next().await {
|
||||||
|
let raw_stream = match raw_stream {
|
||||||
|
Ok(raw_stream) => raw_stream,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
let stream = match tls_acceptor.accept(raw_stream).await {
|
||||||
|
Ok(stream) => stream,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
self.clone().client_serve(stream, server.clone()).await;
|
||||||
|
}
|
||||||
|
Ok(()) as Result<(), Error>
|
||||||
|
};
|
||||||
|
listener_service.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn entrypoint(self) -> Result<(), Error> {
|
||||||
|
let listen_address = self.globals.listen_address;
|
||||||
|
let listener = TcpListener::bind(&listen_address).await?;
|
||||||
|
let path = &self.globals.path;
|
||||||
|
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
let tls_acceptor = match (&self.globals.tls_cert_path, &self.globals.tls_cert_password) {
|
||||||
|
(Some(tls_cert_path), Some(tls_cert_password)) => {
|
||||||
|
Some(create_tls_acceptor(tls_cert_path, tls_cert_password).unwrap())
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
#[cfg(not(feature = "tls"))]
|
||||||
|
let tls_acceptor: Option<()> = None;
|
||||||
|
|
||||||
|
if tls_acceptor.is_some() {
|
||||||
|
println!("Listening on https://{}{}", listen_address, path);
|
||||||
|
} else {
|
||||||
|
println!("Listening on http://{}{}", listen_address, path);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut server = Http::new();
|
||||||
|
server.keep_alive(self.globals.keepalive);
|
||||||
|
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
{
|
||||||
|
if let Some(tls_acceptor) = tls_acceptor {
|
||||||
|
self.start_with_tls(tls_acceptor, listener, server).await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.start_without_tls(listener, server).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue