diff --git a/Cargo.toml b/Cargo.toml index 53e3548..4b14aa6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,8 @@ default = [] tls = ["native-tls", "tokio-tls"] [dependencies] +anyhow = "1.0" +byteorder = "1.3" base64 = "0.11" clap = "2.33.0" futures = { version = "0.3" } diff --git a/src/dns.rs b/src/dns.rs index b06256a..4d86a65 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -1,5 +1,8 @@ +use anyhow::{ensure, Error}; +use byteorder::{BigEndian, ByteOrder}; + const DNS_HEADER_SIZE: usize = 12; -const DNS_MAX_HOSTNAME_LEN: usize = 256; +const DNS_MAX_HOSTNAME_SIZE: usize = 256; const DNS_MAX_PACKET_SIZE: usize = 65_535; const DNS_OFFSET_QUESTION: usize = DNS_HEADER_SIZE; const DNS_TYPE_OPT: u16 = 41; @@ -7,172 +10,133 @@ const DNS_TYPE_OPT: u16 = 41; const DNS_RCODE_SERVFAIL: u8 = 2; const DNS_RCODE_REFUSED: u8 = 5; -#[inline] -fn qdcount(packet: &[u8]) -> u16 { - (u16::from(packet[4]) << 8) | u16::from(packet[5]) -} - -#[inline] -fn ancount(packet: &[u8]) -> u16 { - (u16::from(packet[6]) << 8) | u16::from(packet[7]) -} - -#[inline] -fn nscount(packet: &[u8]) -> u16 { - (u16::from(packet[8]) << 8) | u16::from(packet[9]) -} - -#[inline] -fn arcount(packet: &[u8]) -> u16 { - (u16::from(packet[10]) << 8) | u16::from(packet[11]) -} - #[inline] pub fn rcode(packet: &[u8]) -> u8 { packet[3] & 0x0f } +#[inline] +pub fn qdcount(packet: &[u8]) -> u16 { + BigEndian::read_u16(&packet[4..]) +} + +#[inline] +pub fn ancount(packet: &[u8]) -> u16 { + BigEndian::read_u16(&packet[6..]) +} + +#[inline] +pub fn arcount(packet: &[u8]) -> u16 { + BigEndian::read_u16(&packet[10..]) +} + +fn arcount_inc(packet: &mut [u8]) -> Result<(), Error> { + let mut arcount = arcount(packet); + ensure!(arcount < 0xffff, "Too many additional records"); + arcount += 1; + BigEndian::write_u16(&mut packet[10..], arcount); + Ok(()) +} + +#[inline] +fn nscount(packet: &[u8]) -> u16 { + BigEndian::read_u16(&packet[8..]) +} + +#[inline] pub fn is_recoverable_error(packet: &[u8]) -> bool { let rcode = rcode(packet); rcode == DNS_RCODE_SERVFAIL || rcode == DNS_RCODE_REFUSED } -fn arcount_inc(packet: &mut [u8]) -> Result<(), &'static str> { - let mut arcount = arcount(packet); - if arcount == 0xffff { - return Err("Too many additional records"); - } - arcount += 1; - packet[10] = (arcount >> 8) as u8; - packet[11] = arcount as u8; - Ok(()) -} - -fn skip_name(packet: &[u8], offset: usize) -> Result<(usize, u16), &'static str> { +fn skip_name(packet: &[u8], offset: usize) -> Result { let packet_len = packet.len(); - if offset >= packet_len - 1 { - return Err("Short packet"); - } - let mut name_len: usize = 0; + ensure!(offset < packet_len - 1, "Short packet"); + let mut qname_len: usize = 0; let mut offset = offset; - let mut labels_count = 0u16; loop { - let label_len = match packet[offset] { - len if len & 0xc0 == 0xc0 => { - if 2 > packet_len - offset { - return Err("Incomplete offset"); - } + let label_len = match packet[offset] as usize { + label_len if label_len & 0xc0 == 0xc0 => { + ensure!(packet_len - offset >= 2, "Incomplete offset"); offset += 2; break; } - len if len > 0x3f => return Err("Label too long"), - len => len, + label_len => label_len, } as usize; - if label_len >= packet_len - offset - 1 { - return Err("Malformed packet with an out-of-bounds name"); - } - name_len += label_len + 1; - if name_len > DNS_MAX_HOSTNAME_LEN { - return Err("Name too long"); - } + ensure!( + packet_len - offset - 1 > label_len, + "Malformed packet with an out-of-bounds name" + ); + qname_len += label_len + 1; + ensure!(qname_len <= DNS_MAX_HOSTNAME_SIZE, "Name too long"); offset += label_len + 1; if label_len == 0 { break; } - labels_count += 1; } - Ok((offset, labels_count)) + Ok(offset) } -fn traverse_rrs Result<(), &'static str>>( +fn traverse_rrs Result<(), Error>>( packet: &[u8], mut offset: usize, rrcount: u16, mut cb: F, -) -> Result { +) -> Result { let packet_len = packet.len(); for _ in 0..rrcount { - offset = match skip_name(packet, offset) { - Ok(offset) => offset.0, - Err(e) => return Err(e), - }; - if 10 > packet_len - offset { - return Err("Short packet"); - } + offset = skip_name(packet, offset)?; + ensure!(packet_len - offset >= 10, "Short packet"); cb(offset)?; - let rdlen = (u16::from(packet[offset + 8]) << 8 | u16::from(packet[offset + 9])) as usize; + let rdlen = BigEndian::read_u16(&packet[offset + 8..]) as usize; offset += 10; - if rdlen > packet_len - offset { - return Err("Record length would exceed packet length"); - } + ensure!( + packet_len - offset >= rdlen, + "Record length would exceed packet length" + ); offset += rdlen; } Ok(offset) } -fn traverse_rrs_mut Result<(), &'static str>>( +fn traverse_rrs_mut Result<(), Error>>( packet: &mut [u8], mut offset: usize, rrcount: u16, mut cb: F, -) -> Result { +) -> Result { let packet_len = packet.len(); for _ in 0..rrcount { - offset = match skip_name(packet, offset) { - Ok(offset) => offset.0, - Err(e) => return Err(e), - }; - if 10 > packet_len - offset { - return Err("Short packet"); - } + offset = skip_name(packet, offset)?; + ensure!(packet_len - offset >= 10, "Short packet"); cb(packet, offset)?; - let rdlen = (u16::from(packet[offset + 8]) << 8 | u16::from(packet[offset + 9])) as usize; + let rdlen = BigEndian::read_u16(&packet[offset + 8..]) as usize; offset += 10; - if rdlen > packet_len - offset { - return Err("Record length would exceed packet length"); - } + ensure!( + packet_len - offset >= rdlen, + "Record length would exceed packet length" + ); offset += rdlen; } Ok(offset) } -pub(crate) fn min_ttl( - packet: &[u8], - min_ttl: u32, - max_ttl: u32, - failure_ttl: u32, -) -> Result { - if qdcount(packet) != 1 { - return Err("Unsupported number of questions"); - } +pub fn min_ttl(packet: &[u8], min_ttl: u32, max_ttl: u32, failure_ttl: u32) -> Result { + ensure!(qdcount(packet) == 1, "Unsupported number of questions"); let packet_len = packet.len(); - if packet_len <= DNS_OFFSET_QUESTION { - return Err("Short packet"); - } - if packet_len >= DNS_MAX_PACKET_SIZE { - return Err("Large packet"); - } - let mut offset = match skip_name(packet, DNS_OFFSET_QUESTION) { - Ok(offset) => offset.0, - Err(e) => return Err(e), - }; + ensure!(packet_len > DNS_OFFSET_QUESTION, "Short packet"); + ensure!(packet_len <= DNS_MAX_PACKET_SIZE, "Large packet"); + let mut offset = skip_name(packet, DNS_OFFSET_QUESTION)?; assert!(offset > DNS_OFFSET_QUESTION); - if 4 > packet_len - offset { - return Err("Short packet"); - } + ensure!(packet_len - offset > 4, "Short packet"); offset += 4; - let ancount = ancount(packet); - let nscount = nscount(packet); - let arcount = arcount(packet); + let (ancount, nscount, arcount) = (ancount(packet), nscount(packet), arcount(packet)); let rrcount = ancount + nscount + arcount; - let mut found_min_ttl = if rrcount > 0 { max_ttl } else { failure_ttl }; + offset = traverse_rrs(packet, offset, rrcount, |offset| { - let qtype = u16::from(packet[offset]) << 8 | u16::from(packet[offset + 1]); - let ttl = u32::from(packet[offset + 4]) << 24 - | u32::from(packet[offset + 5]) << 16 - | u32::from(packet[offset + 6]) << 8 - | u32::from(packet[offset + 7]); + let qtype = BigEndian::read_u16(&packet[offset..]); + let ttl = BigEndian::read_u32(&packet[offset + 4..]); if qtype != DNS_TYPE_OPT && ttl < found_min_ttl { found_min_ttl = ttl; } @@ -181,13 +145,11 @@ pub(crate) fn min_ttl( if found_min_ttl < min_ttl { found_min_ttl = min_ttl; } - if offset != packet_len { - return Err("Garbage after packet"); - } + ensure!(packet_len == offset, "Garbage after packet"); Ok(found_min_ttl) } -fn add_edns_section(packet: &mut Vec, max_payload_size: u16) -> Result<(), &'static str> { +fn add_edns_section(packet: &mut Vec, max_payload_size: u16) -> Result<(), Error> { let opt_rr: [u8; 11] = [ 0, (DNS_TYPE_OPT >> 8) as u8, @@ -201,61 +163,40 @@ fn add_edns_section(packet: &mut Vec, max_payload_size: u16) -> Result<(), & 0, 0, ]; - if DNS_MAX_PACKET_SIZE - packet.len() < opt_rr.len() { - return Err("Packet would be too large to add a new record"); - } + ensure!( + DNS_MAX_PACKET_SIZE - packet.len() >= opt_rr.len(), + "Packet would be too large to add a new record" + ); arcount_inc(packet)?; packet.extend(&opt_rr); Ok(()) } -pub(crate) fn set_edns_max_payload_size( - packet: &mut Vec, - max_payload_size: u16, -) -> Result<(), &'static str> { - if qdcount(packet) != 1 { - return Err("Unsupported number of questions"); - } +pub fn set_edns_max_payload_size(packet: &mut Vec, max_payload_size: u16) -> Result<(), Error> { + ensure!(qdcount(packet) == 1, "Unsupported number of questions"); let packet_len = packet.len(); - if packet_len <= DNS_OFFSET_QUESTION { - return Err("Short packet"); - } - if packet_len >= DNS_MAX_PACKET_SIZE { - return Err("Large packet"); - } - let mut offset = match skip_name(packet, DNS_OFFSET_QUESTION) { - Ok(offset) => offset.0, - Err(e) => return Err(e), - }; + ensure!(packet_len > DNS_OFFSET_QUESTION, "Short packet"); + ensure!(packet_len <= DNS_MAX_PACKET_SIZE, "Large packet"); + + let mut offset = skip_name(packet, DNS_OFFSET_QUESTION)?; assert!(offset > DNS_OFFSET_QUESTION); - if 4 > packet_len - offset { - return Err("Short packet"); - } + ensure!(packet_len - offset >= 4, "Short packet"); offset += 4; - let ancount = ancount(packet); - let nscount = nscount(packet); - let arcount = arcount(packet); - + let (ancount, nscount, arcount) = (ancount(packet), nscount(packet), arcount(packet)); offset = traverse_rrs(packet, offset, ancount + nscount, |_offset| Ok(()))?; - let mut edns_payload_set = false; traverse_rrs_mut(packet, offset, arcount, |packet, offset| { - let qtype = u16::from(packet[offset]) << 8 | u16::from(packet[offset + 1]); + let qtype = BigEndian::read_u16(&packet[offset..]); if qtype == DNS_TYPE_OPT { - if edns_payload_set { - return Err("Duplicate OPT RR found"); - } - packet[offset + 2] = (max_payload_size >> 8) as u8; - packet[offset + 3] = max_payload_size as u8; + ensure!(!edns_payload_set, "Duplicate OPT RR found"); + BigEndian::write_u16(&mut packet[offset + 2..], max_payload_size); edns_payload_set = true; } Ok(()) })?; - if edns_payload_set { return Ok(()); } add_edns_section(packet, max_payload_size)?; - Ok(()) }