Merge branch 'master' of github.com:jedisct1/rust-doh

* 'master' of github.com:jedisct1/rust-doh:
  Add retries over TCP
This commit is contained in:
Frank Denis 2021-06-10 22:28:33 +02:00
commit f4cc9bb0f9
5 changed files with 78 additions and 17 deletions

View file

@ -6,3 +6,4 @@ pub const STALE_IF_ERROR_SECS: u32 = 86400;
pub const STALE_WHILE_REVALIDATE_SECS: u32 = 60;
pub const CERTS_WATCH_DELAY_SECS: u32 = 10;
pub const ODOH_KEY_ROTATION_SECS: u32 = 86400;
pub const UDP_TCP_RATIO: usize = 8;

View file

@ -2,9 +2,13 @@ use anyhow::{ensure, Error};
use byteorder::{BigEndian, ByteOrder};
const DNS_HEADER_SIZE: usize = 12;
pub const DNS_OFFSET_FLAGS: usize = 2;
const DNS_MAX_HOSTNAME_SIZE: usize = 256;
const DNS_MAX_PACKET_SIZE: usize = 4096;
const DNS_OFFSET_QUESTION: usize = DNS_HEADER_SIZE;
const DNS_FLAGS_TC: u16 = 1u16 << 9;
const DNS_TYPE_OPT: u16 = 41;
const DNS_PTYPE_PADDING: u16 = 12;
@ -51,6 +55,11 @@ pub fn is_recoverable_error(packet: &[u8]) -> bool {
rcode == DNS_RCODE_SERVFAIL || rcode == DNS_RCODE_REFUSED
}
#[inline]
pub fn is_truncated(packet: &[u8]) -> bool {
BigEndian::read_u16(&packet[DNS_OFFSET_FLAGS..]) & DNS_FLAGS_TC == DNS_FLAGS_TC
}
fn skip_name(packet: &[u8], offset: usize) -> Result<usize, Error> {
let packet_len = packet.len();
ensure!(offset < packet_len - 1, "Short packet");

View file

@ -1,7 +1,6 @@
use hyper::StatusCode;
use std::io;
#[allow(dead_code)]
#[derive(Debug)]
pub enum DoHError {
Incomplete,
@ -13,6 +12,7 @@ pub enum DoHError {
Hyper(hyper::Error),
Io(io::Error),
ODoHConfigError(anyhow::Error),
TooManyTcpSessions,
}
impl std::error::Error for DoHError {}
@ -29,6 +29,7 @@ impl std::fmt::Display for DoHError {
DoHError::Hyper(e) => write!(fmt, "HTTP error: {}", e),
DoHError::Io(e) => write!(fmt, "IO error: {}", e),
DoHError::ODoHConfigError(e) => write!(fmt, "ODoH config error: {}", e),
DoHError::TooManyTcpSessions => write!(fmt, "Too many TCP sessions"),
}
}
}
@ -45,6 +46,7 @@ impl From<DoHError> for StatusCode {
DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
DoHError::ODoHConfigError(_) => StatusCode::INTERNAL_SERVER_ERROR,
DoHError::TooManyTcpSessions => StatusCode::SERVICE_UNAVAILABLE,
}
}
}

View file

@ -40,6 +40,10 @@ pub struct Globals {
pub struct ClientsCount(Arc<AtomicUsize>);
impl ClientsCount {
pub fn current(&self) -> usize {
self.0.load(Ordering::Relaxed)
}
pub fn increment(&self) -> usize {
self.0.fetch_add(1, Ordering::Relaxed)
}

View file

@ -10,16 +10,18 @@ use crate::constants::*;
pub use crate::errors::*;
pub use crate::globals::*;
use byteorder::{BigEndian, ByteOrder};
use futures::prelude::*;
use futures::task::{Context, Poll};
use hyper::http;
use hyper::server::conn::Http;
use hyper::{Body, HeaderMap, Method, Request, Response, StatusCode};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UdpSocket};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpSocket, UdpSocket};
use tokio::runtime;
pub mod reexports {
@ -340,22 +342,65 @@ impl DoH {
}
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _);
let globals = &self.globals;
let socket = UdpSocket::bind(&globals.local_bind_address)
.await
.map_err(DoHError::Io)?;
let expected_server_address = globals.server_address;
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
socket
.send_to(&query, &globals.server_address)
.map_err(DoHError::Io)
.await?;
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
let (len, response_server_address) =
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
return Err(DoHError::UpstreamIssue);
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
// UDP
{
let socket = UdpSocket::bind(&globals.local_bind_address)
.await
.map_err(DoHError::Io)?;
let expected_server_address = globals.server_address;
socket
.send_to(&query, &globals.server_address)
.map_err(DoHError::Io)
.await?;
let (len, response_server_address) =
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
return Err(DoHError::UpstreamIssue);
}
packet.truncate(len);
}
packet.truncate(len);
// TCP
if dns::is_truncated(&packet) {
let clients_count = self.globals.clients_count.current();
if self.globals.max_clients >= UDP_TCP_RATIO
&& clients_count >= self.globals.max_clients / UDP_TCP_RATIO
{
return Err(DoHError::TooManyTcpSessions);
}
let socket = match globals.server_address {
SocketAddr::V4(_) => TcpSocket::new_v4(),
SocketAddr::V6(_) => TcpSocket::new_v6(),
}
.map_err(DoHError::Io)?;
let mut ext_socket = socket
.connect(globals.server_address)
.await
.map_err(DoHError::Io)?;
ext_socket.set_nodelay(true).map_err(DoHError::Io)?;
let mut binlen = [0u8, 0];
BigEndian::write_u16(&mut binlen, query.len() as u16);
ext_socket.write_all(&binlen).await.map_err(DoHError::Io)?;
ext_socket.write_all(&query).await.map_err(DoHError::Io)?;
ext_socket.flush().await.map_err(DoHError::Io)?;
ext_socket
.read_exact(&mut binlen)
.await
.map_err(DoHError::Io)?;
let packet_len = BigEndian::read_u16(&binlen) as usize;
if packet_len < MIN_DNS_PACKET_LEN || packet_len > MAX_DNS_RESPONSE_LEN {
return Err(DoHError::UpstreamIssue);
}
packet = vec![0u8; packet_len];
ext_socket
.read_exact(&mut packet)
.await
.map_err(DoHError::Io)?;
}
let ttl = if dns::is_recoverable_error(&packet) {
err_ttl
} else {