diff --git a/src/libdoh/src/lib.rs b/src/libdoh/src/lib.rs index 52d25f5..2410331 100644 --- a/src/libdoh/src/lib.rs +++ b/src/libdoh/src/lib.rs @@ -167,7 +167,13 @@ impl DoH { Ok(response) } - async fn proxy(&self, mut query: Vec) -> Result, DoHError> { + async fn proxy(&self, query: Vec) -> Result, DoHError> { + let proxy_timeout = self.globals.timeout; + let timeout_res = tokio::time::timeout(proxy_timeout, self._proxy(query)).await; + timeout_res.map_err(|_| DoHError::UpstreamTimeout)? + } + + async fn _proxy(&self, mut query: Vec) -> Result, DoHError> { if query.len() < MIN_DNS_PACKET_LEN { return Err(DoHError::Incomplete); } @@ -183,20 +189,8 @@ impl DoH { .map_err(DoHError::Io) .await?; let mut packet = vec![0; MAX_DNS_RESPONSE_LEN]; - let socket_timeout = self - .globals - .timeout - .checked_sub(Duration::from_secs(1)) - .unwrap_or(self.globals.timeout); - let timeout_res = tokio::time::timeout( - socket_timeout, - socket.recv_from(&mut packet).map_err(DoHError::Io), - ) - .await; - let (len, response_server_address) = match timeout_res { - Err(_) => return Err(DoHError::UpstreamTimeout), - Ok(recv_res) => recv_res?, - }; + let (len, response_server_address) = + socket.recv_from(&mut packet).map_err(DoHError::Io).await?; if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address { return Err(DoHError::UpstreamIssue); } @@ -233,9 +227,12 @@ impl DoH { return; } self.globals.runtime_handle.clone().spawn(async move { - tokio::time::timeout(self.globals.timeout, server.serve_connection(stream, self)) - .await - .ok(); + tokio::time::timeout( + self.globals.timeout + Duration::from_secs(1), + server.serve_connection(stream, self), + ) + .await + .ok(); clients_count.decrement(); }); }