diff --git a/src/dns.rs b/src/dns.rs new file mode 100644 index 0000000..64360a0 --- /dev/null +++ b/src/dns.rs @@ -0,0 +1,125 @@ +const DNS_CLASS_IN: u16 = 1; +const DNS_HEADER_SIZE: usize = 12; +const DNS_MAX_HOSTNAME_LEN: usize = 256; +const DNS_MAX_PACKET_SIZE: usize = 65_535; +const DNS_OFFSET_QUESTION: usize = DNS_HEADER_SIZE; +const DNS_TYPE_OPT: u16 = 41; + +#[inline] +fn qdcount(packet: &[u8]) -> u16 { + (u16::from(packet[4]) << 8) | u16::from(packet[5]) +} + +#[inline] +fn ancount(packet: &[u8]) -> u16 { + (u16::from(packet[6]) << 8) | u16::from(packet[7]) +} + +#[inline] +fn nscount(packet: &[u8]) -> u16 { + (u16::from(packet[8]) << 8) | u16::from(packet[9]) +} + +#[inline] +fn arcount(packet: &[u8]) -> u16 { + (u16::from(packet[10]) << 8) | u16::from(packet[11]) +} + +fn skip_name(packet: &[u8], offset: usize) -> Result<(usize, u16), &'static str> { + let packet_len = packet.len(); + if offset >= packet_len - 1 { + return Err("Short packet"); + } + let mut name_len: usize = 0; + let mut offset = offset; + let mut labels_count = 0u16; + loop { + let label_len = match packet[offset] { + len if len & 0xc0 == 0xc0 => { + if 2 > packet_len - offset { + return Err("Incomplete offset"); + } + offset += 2; + break; + } + len if len > 0x3f => return Err("Label too long"), + len => len, + } as usize; + if label_len >= packet_len - offset - 1 { + return Err("Malformed packet with an out-of-bounds name"); + } + name_len += label_len + 1; + if name_len > DNS_MAX_HOSTNAME_LEN { + return Err("Name too long"); + } + offset += label_len + 1; + if label_len == 0 { + break; + } + labels_count += 1; + } + Ok((offset, labels_count)) +} + +pub fn min_ttl( + packet: &[u8], + min_ttl: u32, + max_ttl: u32, + failure_ttl: u32, +) -> Result { + if qdcount(packet) != 1 { + return Err("Unsupported number of questions"); + } + let packet_len = packet.len(); + if packet_len <= DNS_OFFSET_QUESTION { + return Err("Short packet"); + } + if packet_len >= DNS_MAX_PACKET_SIZE { + return Err("Large packet"); + } + let mut offset = match skip_name(packet, DNS_OFFSET_QUESTION) { + Ok(offset) => offset.0, + Err(e) => return Err(e), + }; + assert!(offset > DNS_OFFSET_QUESTION); + if 4 > packet_len - offset { + return Err("Short packet"); + } + offset += 4; + let ancount = ancount(packet); + let nscount = nscount(packet); + let arcount = arcount(packet); + let rrcount = ancount + nscount + arcount; + let mut found_min_ttl = if rrcount > 0 { max_ttl } else { failure_ttl }; + for _ in 0..rrcount { + offset = match skip_name(packet, offset) { + Ok(offset) => offset.0, + Err(e) => return Err(e), + }; + if 10 > packet_len - offset { + return Err("Short packet"); + } + let qtype = u16::from(packet[offset]) << 8 | u16::from(packet[offset + 1]); + let qclass = u16::from(packet[offset + 2]) << 8 | u16::from(packet[offset + 3]); + let ttl = u32::from(packet[offset + 4]) << 24 | u32::from(packet[offset + 5]) << 16 + | u32::from(packet[offset + 6]) << 8 | u32::from(packet[offset + 7]); + let rdlen = (u16::from(packet[offset + 8]) << 8 | u16::from(packet[offset + 9])) as usize; + offset += 10; + if !(qtype == DNS_TYPE_OPT && qclass == DNS_CLASS_IN) { + if ttl < found_min_ttl { + found_min_ttl = ttl; + } + } + if rdlen > packet_len - offset { + return Err("Record length would exceed packet length"); + } + offset += rdlen; + } + if found_min_ttl < min_ttl { + found_min_ttl = min_ttl; + } + if offset != packet_len { + return Err("Garbage after packet"); + } + Ok(found_min_ttl) +} diff --git a/src/main.rs b/src/main.rs index e66c927..35d3abb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,11 +10,13 @@ extern crate tokio_core; extern crate tokio_io; extern crate tokio_timer; +mod dns; + use clap::{App, Arg}; use futures::future; use futures::prelude::*; use hyper::{Body, Method, StatusCode}; -use hyper::header::{ContentLength, ContentType}; +use hyper::header::{CacheControl, CacheDirective, ContentLength, ContentType}; use hyper::server::{Http, Request, Response, Service}; use std::cell::RefCell; use std::rc::Rc; @@ -32,6 +34,9 @@ const MAX_DNS_RESPONSE_LEN: usize = 4096; const MIN_DNS_PACKET_LEN: usize = 17; const SERVER_ADDRESS: &str = "9.9.9.9:53"; const TIMEOUT_SEC: u64 = 10; +const MAX_TTL: u32 = 86400 * 7; +const MIN_TTL: u32 = 1; +const ERR_TTL: u32 = 1; #[derive(Clone, Debug)] struct DoH { @@ -53,15 +58,17 @@ impl Service for DoH { let mut response = Response::new(); match (req.method(), req.path()) { (&Method::Post, "/dns-query") => { - let fut = self.body_read(req.body(), self.handle.clone()).map(|body| { - let body_len = body.len(); - response.set_body(body); - response - .with_header(ContentLength(body_len as u64)) - .with_header(ContentType( - "application/dns-udpwireformat".parse().unwrap(), - )) - }); + let fut = self.body_read(req.body(), self.handle.clone()) + .map(|(body, ttl)| { + let body_len = body.len(); + response.set_body(body); + response + .with_header(ContentLength(body_len as u64)) + .with_header(ContentType( + "application/dns-udpwireformat".parse().unwrap(), + )) + .with_header(CacheControl(vec![CacheDirective::MaxAge(ttl)])) + }); return Box::new(fut.map_err(|_| hyper::Error::Incomplete)); } (&Method::Post, _) => { @@ -77,7 +84,7 @@ impl Service for DoH { impl DoH { #[async] - fn body_read(&self, body: Body, handle: Handle) -> Result, ()> { + fn body_read(&self, body: Body, handle: Handle) -> Result<(Vec, u32), ()> { let query = await!( body.concat2() .map_err(|_err| ()) @@ -97,7 +104,8 @@ impl DoH { return Err(()); } packet.truncate(len); - Ok(packet) + let min_ttl = dns::min_ttl(&packet, MIN_TTL, MAX_TTL, ERR_TTL).map_err(|_| {})?; + Ok((packet, min_ttl)) } }