diff --git a/src/libdoh/src/errors.rs b/src/libdoh/src/errors.rs index 73cecff..e4e1e59 100644 --- a/src/libdoh/src/errors.rs +++ b/src/libdoh/src/errors.rs @@ -8,6 +8,7 @@ pub enum DoHError { InvalidData, TooLarge, UpstreamIssue, + UpstreamTimeout, Hyper(hyper::Error), Io(io::Error), } @@ -25,6 +26,7 @@ impl std::error::Error for DoHError { DoHError::InvalidData => "Invalid data", DoHError::TooLarge => "Too large", DoHError::UpstreamIssue => "Upstream error", + DoHError::UpstreamTimeout => "Upstream timeout", DoHError::Hyper(_) => self.description(), DoHError::Io(_) => self.description(), } @@ -38,6 +40,7 @@ impl From for StatusCode { DoHError::InvalidData => StatusCode::BAD_REQUEST, DoHError::TooLarge => StatusCode::PAYLOAD_TOO_LARGE, DoHError::UpstreamIssue => StatusCode::BAD_GATEWAY, + DoHError::UpstreamTimeout => StatusCode::BAD_GATEWAY, DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE, DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR, } diff --git a/src/libdoh/src/lib.rs b/src/libdoh/src/lib.rs index 6c52c2d..52d25f5 100644 --- a/src/libdoh/src/lib.rs +++ b/src/libdoh/src/lib.rs @@ -19,6 +19,7 @@ use hyper::server::conn::Http; use hyper::{Body, Method, Request, Response, StatusCode}; use std::pin::Pin; use std::sync::Arc; +use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, UdpSocket}; use tokio::runtime; @@ -182,8 +183,20 @@ impl DoH { .map_err(DoHError::Io) .await?; let mut packet = vec![0; MAX_DNS_RESPONSE_LEN]; - let (len, response_server_address) = - socket.recv_from(&mut packet).map_err(DoHError::Io).await?; + let socket_timeout = self + .globals + .timeout + .checked_sub(Duration::from_secs(1)) + .unwrap_or(self.globals.timeout); + let timeout_res = tokio::time::timeout( + socket_timeout, + socket.recv_from(&mut packet).map_err(DoHError::Io), + ) + .await; + let (len, response_server_address) = match timeout_res { + Err(_) => return Err(DoHError::UpstreamTimeout), + Ok(recv_res) => recv_res?, + }; if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address { return Err(DoHError::UpstreamIssue); }