mirror of
https://github.com/DNSCrypt/doh-server.git
synced 2025-04-05 14:07:37 +03:00
parent
474701ec1e
commit
485afd5976
5 changed files with 78 additions and 17 deletions
|
@ -6,3 +6,4 @@ pub const STALE_IF_ERROR_SECS: u32 = 86400;
|
||||||
pub const STALE_WHILE_REVALIDATE_SECS: u32 = 60;
|
pub const STALE_WHILE_REVALIDATE_SECS: u32 = 60;
|
||||||
pub const CERTS_WATCH_DELAY_SECS: u32 = 10;
|
pub const CERTS_WATCH_DELAY_SECS: u32 = 10;
|
||||||
pub const ODOH_KEY_ROTATION_SECS: u32 = 86400;
|
pub const ODOH_KEY_ROTATION_SECS: u32 = 86400;
|
||||||
|
pub const UDP_TCP_RATIO: usize = 8;
|
||||||
|
|
|
@ -2,9 +2,13 @@ use anyhow::{ensure, Error};
|
||||||
use byteorder::{BigEndian, ByteOrder};
|
use byteorder::{BigEndian, ByteOrder};
|
||||||
|
|
||||||
const DNS_HEADER_SIZE: usize = 12;
|
const DNS_HEADER_SIZE: usize = 12;
|
||||||
|
pub const DNS_OFFSET_FLAGS: usize = 2;
|
||||||
const DNS_MAX_HOSTNAME_SIZE: usize = 256;
|
const DNS_MAX_HOSTNAME_SIZE: usize = 256;
|
||||||
const DNS_MAX_PACKET_SIZE: usize = 4096;
|
const DNS_MAX_PACKET_SIZE: usize = 4096;
|
||||||
const DNS_OFFSET_QUESTION: usize = DNS_HEADER_SIZE;
|
const DNS_OFFSET_QUESTION: usize = DNS_HEADER_SIZE;
|
||||||
|
|
||||||
|
const DNS_FLAGS_TC: u16 = 1u16 << 9;
|
||||||
|
|
||||||
const DNS_TYPE_OPT: u16 = 41;
|
const DNS_TYPE_OPT: u16 = 41;
|
||||||
|
|
||||||
const DNS_PTYPE_PADDING: u16 = 12;
|
const DNS_PTYPE_PADDING: u16 = 12;
|
||||||
|
@ -51,6 +55,11 @@ pub fn is_recoverable_error(packet: &[u8]) -> bool {
|
||||||
rcode == DNS_RCODE_SERVFAIL || rcode == DNS_RCODE_REFUSED
|
rcode == DNS_RCODE_SERVFAIL || rcode == DNS_RCODE_REFUSED
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn is_truncated(packet: &[u8]) -> bool {
|
||||||
|
BigEndian::read_u16(&packet[DNS_OFFSET_FLAGS..]) & DNS_FLAGS_TC == DNS_FLAGS_TC
|
||||||
|
}
|
||||||
|
|
||||||
fn skip_name(packet: &[u8], offset: usize) -> Result<usize, Error> {
|
fn skip_name(packet: &[u8], offset: usize) -> Result<usize, Error> {
|
||||||
let packet_len = packet.len();
|
let packet_len = packet.len();
|
||||||
ensure!(offset < packet_len - 1, "Short packet");
|
ensure!(offset < packet_len - 1, "Short packet");
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use std::io;
|
use std::io;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum DoHError {
|
pub enum DoHError {
|
||||||
Incomplete,
|
Incomplete,
|
||||||
|
@ -13,6 +12,7 @@ pub enum DoHError {
|
||||||
Hyper(hyper::Error),
|
Hyper(hyper::Error),
|
||||||
Io(io::Error),
|
Io(io::Error),
|
||||||
ODoHConfigError(anyhow::Error),
|
ODoHConfigError(anyhow::Error),
|
||||||
|
TooManyTcpSessions,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::error::Error for DoHError {}
|
impl std::error::Error for DoHError {}
|
||||||
|
@ -29,6 +29,7 @@ impl std::fmt::Display for DoHError {
|
||||||
DoHError::Hyper(e) => write!(fmt, "HTTP error: {}", e),
|
DoHError::Hyper(e) => write!(fmt, "HTTP error: {}", e),
|
||||||
DoHError::Io(e) => write!(fmt, "IO error: {}", e),
|
DoHError::Io(e) => write!(fmt, "IO error: {}", e),
|
||||||
DoHError::ODoHConfigError(e) => write!(fmt, "ODoH config error: {}", e),
|
DoHError::ODoHConfigError(e) => write!(fmt, "ODoH config error: {}", e),
|
||||||
|
DoHError::TooManyTcpSessions => write!(fmt, "Too many TCP sessions"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -45,6 +46,7 @@ impl From<DoHError> for StatusCode {
|
||||||
DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
|
DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
|
||||||
DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
DoHError::ODoHConfigError(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
DoHError::ODoHConfigError(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
DoHError::TooManyTcpSessions => StatusCode::SERVICE_UNAVAILABLE,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,6 +40,10 @@ pub struct Globals {
|
||||||
pub struct ClientsCount(Arc<AtomicUsize>);
|
pub struct ClientsCount(Arc<AtomicUsize>);
|
||||||
|
|
||||||
impl ClientsCount {
|
impl ClientsCount {
|
||||||
|
pub fn current(&self) -> usize {
|
||||||
|
self.0.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn increment(&self) -> usize {
|
pub fn increment(&self) -> usize {
|
||||||
self.0.fetch_add(1, Ordering::Relaxed)
|
self.0.fetch_add(1, Ordering::Relaxed)
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,16 +10,18 @@ use crate::constants::*;
|
||||||
pub use crate::errors::*;
|
pub use crate::errors::*;
|
||||||
pub use crate::globals::*;
|
pub use crate::globals::*;
|
||||||
|
|
||||||
|
use byteorder::{BigEndian, ByteOrder};
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use futures::task::{Context, Poll};
|
use futures::task::{Context, Poll};
|
||||||
use hyper::http;
|
use hyper::http;
|
||||||
use hyper::server::conn::Http;
|
use hyper::server::conn::Http;
|
||||||
use hyper::{Body, HeaderMap, Method, Request, Response, StatusCode};
|
use hyper::{Body, HeaderMap, Method, Request, Response, StatusCode};
|
||||||
|
use std::net::SocketAddr;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, UdpSocket};
|
use tokio::net::{TcpListener, TcpSocket, UdpSocket};
|
||||||
use tokio::runtime;
|
use tokio::runtime;
|
||||||
|
|
||||||
pub mod reexports {
|
pub mod reexports {
|
||||||
|
@ -340,22 +342,65 @@ impl DoH {
|
||||||
}
|
}
|
||||||
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _);
|
let _ = dns::set_edns_max_payload_size(&mut query, MAX_DNS_RESPONSE_LEN as _);
|
||||||
let globals = &self.globals;
|
let globals = &self.globals;
|
||||||
let socket = UdpSocket::bind(&globals.local_bind_address)
|
|
||||||
.await
|
|
||||||
.map_err(DoHError::Io)?;
|
|
||||||
let expected_server_address = globals.server_address;
|
|
||||||
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
|
|
||||||
socket
|
|
||||||
.send_to(&query, &globals.server_address)
|
|
||||||
.map_err(DoHError::Io)
|
|
||||||
.await?;
|
|
||||||
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
|
let mut packet = vec![0; MAX_DNS_RESPONSE_LEN];
|
||||||
let (len, response_server_address) =
|
let (min_ttl, max_ttl, err_ttl) = (globals.min_ttl, globals.max_ttl, globals.err_ttl);
|
||||||
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
|
|
||||||
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
|
// UDP
|
||||||
return Err(DoHError::UpstreamIssue);
|
{
|
||||||
|
let socket = UdpSocket::bind(&globals.local_bind_address)
|
||||||
|
.await
|
||||||
|
.map_err(DoHError::Io)?;
|
||||||
|
let expected_server_address = globals.server_address;
|
||||||
|
socket
|
||||||
|
.send_to(&query, &globals.server_address)
|
||||||
|
.map_err(DoHError::Io)
|
||||||
|
.await?;
|
||||||
|
let (len, response_server_address) =
|
||||||
|
socket.recv_from(&mut packet).map_err(DoHError::Io).await?;
|
||||||
|
if len < MIN_DNS_PACKET_LEN || expected_server_address != response_server_address {
|
||||||
|
return Err(DoHError::UpstreamIssue);
|
||||||
|
}
|
||||||
|
packet.truncate(len);
|
||||||
}
|
}
|
||||||
packet.truncate(len);
|
|
||||||
|
// TCP
|
||||||
|
if dns::is_truncated(&packet) {
|
||||||
|
let clients_count = self.globals.clients_count.current();
|
||||||
|
if self.globals.max_clients >= UDP_TCP_RATIO
|
||||||
|
&& clients_count >= self.globals.max_clients / UDP_TCP_RATIO
|
||||||
|
{
|
||||||
|
return Err(DoHError::TooManyTcpSessions);
|
||||||
|
}
|
||||||
|
let socket = match globals.server_address {
|
||||||
|
SocketAddr::V4(_) => TcpSocket::new_v4(),
|
||||||
|
SocketAddr::V6(_) => TcpSocket::new_v6(),
|
||||||
|
}
|
||||||
|
.map_err(DoHError::Io)?;
|
||||||
|
let mut ext_socket = socket
|
||||||
|
.connect(globals.server_address)
|
||||||
|
.await
|
||||||
|
.map_err(DoHError::Io)?;
|
||||||
|
ext_socket.set_nodelay(true).map_err(DoHError::Io)?;
|
||||||
|
let mut binlen = [0u8, 0];
|
||||||
|
BigEndian::write_u16(&mut binlen, query.len() as u16);
|
||||||
|
ext_socket.write_all(&binlen).await.map_err(DoHError::Io)?;
|
||||||
|
ext_socket.write_all(&query).await.map_err(DoHError::Io)?;
|
||||||
|
ext_socket.flush().await.map_err(DoHError::Io)?;
|
||||||
|
ext_socket
|
||||||
|
.read_exact(&mut binlen)
|
||||||
|
.await
|
||||||
|
.map_err(DoHError::Io)?;
|
||||||
|
let packet_len = BigEndian::read_u16(&binlen) as usize;
|
||||||
|
if packet_len < MIN_DNS_PACKET_LEN || packet_len > MAX_DNS_RESPONSE_LEN {
|
||||||
|
return Err(DoHError::UpstreamIssue);
|
||||||
|
}
|
||||||
|
packet = vec![0u8; packet_len];
|
||||||
|
ext_socket
|
||||||
|
.read_exact(&mut packet)
|
||||||
|
.await
|
||||||
|
.map_err(DoHError::Io)?;
|
||||||
|
}
|
||||||
|
|
||||||
let ttl = if dns::is_recoverable_error(&packet) {
|
let ttl = if dns::is_recoverable_error(&packet) {
|
||||||
err_ttl
|
err_ttl
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue