From 05a60818ce8d632a74d3306dd09c99d4b3ee5a5a Mon Sep 17 00:00:00 2001 From: Christopher Wood Date: Mon, 26 Apr 2021 13:05:52 -0700 Subject: [PATCH] Add Oblivious DoH target support as a default feature. This change adds Oblivious DoH (ODoH) target support to doh-server. This change does include support for ODoH key rotation or algorithm agility. ODoH is a default feature and not conditionally compiled out. --- src/constants.rs | 1 + src/libdoh/Cargo.toml | 3 + src/libdoh/src/globals.rs | 3 + src/libdoh/src/lib.rs | 151 +++++++++++++++++++++++++++++++------- src/libdoh/src/odoh.rs | 96 ++++++++++++++++++++++++ src/main.rs | 8 ++ 6 files changed, 234 insertions(+), 28 deletions(-) create mode 100644 src/libdoh/src/odoh.rs diff --git a/src/constants.rs b/src/constants.rs index a784f2a..e7e549c 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -2,6 +2,7 @@ pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000"; pub const MAX_CLIENTS: usize = 512; pub const MAX_CONCURRENT_STREAMS: u32 = 16; pub const PATH: &str = "/dns-query"; +pub const ODOH_CONFIGS_PATH: &str = "/.well-known/odohconfigs"; pub const SERVER_ADDRESS: &str = "9.9.9.9:53"; pub const TIMEOUT_SEC: u64 = 10; pub const MAX_TTL: u32 = 86400 * 7; diff --git a/src/libdoh/Cargo.toml b/src/libdoh/Cargo.toml index be30591..3df9bbe 100644 --- a/src/libdoh/Cargo.toml +++ b/src/libdoh/Cargo.toml @@ -22,6 +22,9 @@ futures = "0.3.13" hyper = { version = "0.14.4", default-features = false, features = ["server", "http1", "http2", "stream"] } tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] } tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true } +odoh-rs = "0.1.11" +rand = "0.7" +hpke = "0.5.0" [profile.release] codegen-units = 1 diff --git a/src/libdoh/src/globals.rs b/src/libdoh/src/globals.rs index 3fbba84..f5327e8 100644 --- a/src/libdoh/src/globals.rs +++ b/src/libdoh/src/globals.rs @@ -1,3 +1,4 @@ +use crate::odoh::ODoHPublicKey; use std::net::SocketAddr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -28,6 +29,8 @@ pub struct Globals { pub err_ttl: u32, pub keepalive: bool, pub disable_post: bool, + pub odoh_configs_path: String, + pub odoh_public_key: Arc, pub runtime_handle: runtime::Handle, } diff --git a/src/libdoh/src/lib.rs b/src/libdoh/src/lib.rs index 03eb14b..19313b0 100644 --- a/src/libdoh/src/lib.rs +++ b/src/libdoh/src/lib.rs @@ -1,5 +1,6 @@ mod constants; pub mod dns; +pub mod odoh; mod errors; mod globals; #[cfg(feature = "tls")] @@ -25,9 +26,30 @@ pub mod reexports { pub use tokio; } +#[derive(Clone, Debug)] +struct DnsResponse { + packet: Vec, + ttl: u32, +} + +#[derive(Clone, Debug)] +enum DoHType { + Standard, + Oblivious, +} + +impl DoHType { + fn as_str(&self) -> String { + match self { + DoHType::Standard => String::from("application/dns-message"), + DoHType::Oblivious => String::from("application/oblivious-dns-message"), + } + } +} + #[derive(Clone, Debug)] pub struct DoH { - pub globals: Arc, + pub globals: Arc } #[allow(clippy::unnecessary_wraps)] @@ -72,14 +94,20 @@ impl hyper::service::Service> for DoH { fn call(&mut self, req: Request) -> Self::Future { let globals = &self.globals; - if req.uri().path() != globals.path { - return Box::pin(async { http_error(StatusCode::NOT_FOUND) }); - } let self_inner = self.clone(); - match *req.method() { - Method::POST => Box::pin(async move { self_inner.serve_post(req).await }), - Method::GET => Box::pin(async move { self_inner.serve_get(req).await }), - _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), + if req.uri().path() == globals.path { + match *req.method() { + Method::POST => Box::pin(async move { self_inner.serve_post(req).await }), + Method::GET => Box::pin(async move { self_inner.serve_get(req).await }), + _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), + } + } else if req.uri().path() == globals.odoh_configs_path { + match *req.method() { + Method::GET => Box::pin(async move { self_inner.serve_odoh_configs().await }), + _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), + } + } else { + Box::pin(async { http_error(StatusCode::NOT_FOUND) }) } } } @@ -89,12 +117,64 @@ impl DoH { if self.globals.disable_post { return http_error(StatusCode::METHOD_NOT_ALLOWED); } - if let Err(response) = Self::check_content_type(&req) { - return Ok(response); + + match Self::parse_content_type(&req) { + Ok(DoHType::Standard) => self.serve_doh_post(req).await, + Ok(DoHType::Oblivious) => self.serve_odoh_post(req).await, + Err(response) => return Ok(response) } - match self.read_body_and_proxy(req.into_body()).await { + } + + async fn serve_doh_post(&self, req:Request) -> Result, http::Error> { + let query = match self.read_body(req.into_body()).await { + Ok(q) => q, + Err(e) => return http_error(StatusCode::from(e)) + }; + + let resp = match self.proxy(query).await { + Ok(resp) => self.build_response(resp.packet, resp.ttl, DoHType::Standard.as_str()), + Err(e) => return http_error(StatusCode::from(e)), + }; + + match resp { + Ok(resp) => Ok(resp), + Err(e) => http_error(StatusCode::from(e)), + } + } + + // #[cfg(feature = "odoh")] + async fn serve_odoh_post(&self, req:Request) -> Result, http::Error> { + let query_body = match self.read_body(req.into_body()).await { + Ok(q) => q, + Err(e) => return http_error(StatusCode::from(e)) + }; + + let (query, context) = match (*self.globals.odoh_public_key).clone().decrypt_query(query_body).await { + Ok((q, context)) => (q.to_vec(), context), + Err(_) => return http_error(StatusCode::from(DoHError::InvalidData)) + }; + + let resp_body = match self.proxy(query).await { + Ok(resp) => resp, + Err(e) => return http_error(StatusCode::from(e)) + }; + + let resp = match context.encrypt_response(resp_body.packet).await { + Ok(resp) => self.build_response(resp, 0u32, DoHType::Oblivious.as_str()), + Err(e) => return http_error(StatusCode::from(e)), + }; + + match resp { + Ok(resp) => Ok(resp), + Err(e) => http_error(StatusCode::from(e)), + } + } + + async fn serve_odoh_configs(&self) -> Result, http::Error> { + let configs = (*self.globals.odoh_public_key).clone().config(); + match self.build_response(configs, 0, "application/octet-stream".to_string()) { + Ok(resp) => Ok(resp), Err(e) => http_error(StatusCode::from(e)), - Ok(res) => Ok(res), } } @@ -117,13 +197,18 @@ impl DoH { return http_error(StatusCode::BAD_REQUEST); } }; - match self.proxy(question).await { + + let resp = match self.proxy(question).await { + Ok(dns_resp) => self.build_response(dns_resp.packet, dns_resp.ttl, DoHType::Standard.as_str()), + Err(e) => Err(e), + }; + match resp { + Ok(resp) => Ok(resp), Err(e) => http_error(StatusCode::from(e)), - Ok(res) => Ok(res), } } - fn check_content_type(req: &Request) -> Result<(), Response> { + fn parse_content_type(req: &Request) -> Result> { let headers = req.headers(); let content_type = match headers.get(hyper::header::CONTENT_TYPE) { None => { @@ -145,17 +230,21 @@ impl DoH { } Ok(content_type) => content_type.to_lowercase(), }; - if content_type != "application/dns-message" { - let response = Response::builder() - .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) - .body(Body::empty()) - .unwrap(); - return Err(response); + + match content_type.as_str() { + "application/dns-message" => Ok(DoHType::Standard), + "application/oblivious-dns-message" => Ok(DoHType::Oblivious), + _ => { + let response = Response::builder() + .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) + .body(Body::empty()) + .unwrap(); + return Err(response); + } } - Ok(()) } - async fn read_body_and_proxy(&self, mut body: Body) -> Result, DoHError> { + async fn read_body(&self, mut body: Body) -> Result, DoHError> { let mut sum_size = 0; let mut query = vec![]; while let Some(chunk) = body.next().await { @@ -166,17 +255,16 @@ impl DoH { } query.extend(chunk); } - let response = self.proxy(query).await?; - Ok(response) + Ok(query) } - async fn proxy(&self, query: Vec) -> Result, DoHError> { + async fn proxy(&self, query: Vec) -> Result { let proxy_timeout = self.globals.timeout; let timeout_res = tokio::time::timeout(proxy_timeout, self._proxy(query)).await; timeout_res.map_err(|_| DoHError::UpstreamTimeout)? } - async fn _proxy(&self, mut query: Vec) -> Result, DoHError> { + async fn _proxy(&self, mut query: Vec) -> Result { if query.len() < MIN_DNS_PACKET_LEN { return Err(DoHError::Incomplete); } @@ -209,10 +297,17 @@ impl DoH { dns::add_edns_padding(&mut packet) .map_err(|_| DoHError::TooLarge) .ok(); + Ok(DnsResponse{ + packet, + ttl, + }) + } + + fn build_response(&self, packet: Vec, ttl: u32, content_type: String) -> Result, DoHError> { let packet_len = packet.len(); let response = Response::builder() .header(hyper::header::CONTENT_LENGTH, packet_len) - .header(hyper::header::CONTENT_TYPE, "application/dns-message") + .header(hyper::header::CONTENT_TYPE, content_type.as_str()) .header( hyper::header::CACHE_CONTROL, format!( diff --git a/src/libdoh/src/odoh.rs b/src/libdoh/src/odoh.rs new file mode 100644 index 0000000..b7e499c --- /dev/null +++ b/src/libdoh/src/odoh.rs @@ -0,0 +1,96 @@ +use crate::errors::DoHError; +use hpke::kex::Serializable; +use odoh_rs::key_utils::{derive_keypair_from_seed}; +use odoh_rs::protocol::{create_response_msg, + parse_received_query, RESPONSE_NONCE_SIZE, + ObliviousDoHQueryBody, Serialize, + ObliviousDoHKeyPair, ObliviousDoHConfigContents, + ObliviousDoHConfig, ObliviousDoHConfigs +}; +use rand::Rng; +use std::fmt; + +// https://cfrg.github.io/draft-irtf-cfrg-hpke/draft-irtf-cfrg-hpke.html#name-algorithm-identifiers +const DEFAULT_HPKE_SEED_SIZE: usize = 32; +const DEFAULT_HPKE_KEM: u16 = 0x0020; // DHKEM(X25519, HKDF-SHA256) +const DEFAULT_HPKE_KDF: u16 = 0x0001; // KDF(SHA-256) +const DEFAULT_HPKE_AEAD: u16 = 0x0001; // AEAD(AES-GCM-128) + +#[derive(Clone)] +pub struct ODoHPublicKey { + key: ObliviousDoHKeyPair, + serialized_configs: Vec, +} + +impl fmt::Debug for ODoHPublicKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ODoHPublicKey").finish() + } +} + +#[derive(Clone, Debug)] +pub struct ODoHQueryContext { + query: ObliviousDoHQueryBody, + secret: Vec, +} + +fn generate_key_pair() -> ObliviousDoHKeyPair { + let ikm = rand::thread_rng().gen::<[u8; DEFAULT_HPKE_SEED_SIZE]>(); + let (secret_key, public_key) = derive_keypair_from_seed(&ikm); + let public_key_bytes = public_key.to_bytes().to_vec(); + let odoh_public_key = ObliviousDoHConfigContents { + kem_id: DEFAULT_HPKE_KEM, + kdf_id: DEFAULT_HPKE_KDF, + aead_id: DEFAULT_HPKE_AEAD, + public_key: public_key_bytes, + }; + ObliviousDoHKeyPair { + private_key: secret_key, + public_key: odoh_public_key, + } +} + +impl ODoHPublicKey { + pub fn new() -> Result { + let key_pair = generate_key_pair(); + let config = ObliviousDoHConfig::new(&key_pair.public_key.clone().to_bytes().unwrap()).unwrap(); + let serialized_configs = ObliviousDoHConfigs { + configs: vec![config.clone()], + } + .to_bytes() + .unwrap() + .to_vec(); + + Ok(ODoHPublicKey{ + key: key_pair, + serialized_configs: serialized_configs + }) + } + + pub fn config(self) -> Vec { + self.serialized_configs + } + + pub async fn decrypt_query(self, encrypted_query: Vec) -> Result<(Vec, ODoHQueryContext), DoHError> { + let (query, server_secret) = match parse_received_query(&self.key, &encrypted_query).await { + Ok((pq, ss)) => (pq, ss), + Err(_) => return Err(DoHError::InvalidData) + }; + let context = ODoHQueryContext{ + query: query.clone(), + secret: server_secret, + }; + Ok((query.dns_msg.clone(), context)) + } +} + +impl ODoHQueryContext { + pub async fn encrypt_response(self, response_body: Vec) -> Result, DoHError> { + let response_nonce = rand::thread_rng().gen::<[u8; RESPONSE_NONCE_SIZE]>(); + create_response_msg(&self.secret, &response_body, None, Some(response_nonce.to_vec()), &self.query) + .await + .map_err(|_| { + DoHError::InvalidData + }) + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index f7f69a4..fbf249c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ use libdoh::*; use crate::config::*; use crate::constants::*; +use libdoh::odoh::ODoHPublicKey; use libdoh::reexports::tokio; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; @@ -24,6 +25,11 @@ fn main() { runtime_builder.thread_name("doh-proxy"); let runtime = runtime_builder.build().unwrap(); + let odoh_key = match ODoHPublicKey::new() { + Ok(key) => key, + Err(_) => panic!("Failed to generate ODoH public key configuration"), + }; + let mut globals = Globals { #[cfg(feature = "tls")] tls_cert_path: None, @@ -43,6 +49,8 @@ fn main() { err_ttl: ERR_TTL, keepalive: true, disable_post: false, + odoh_configs_path: ODOH_CONFIGS_PATH.to_string(), + odoh_public_key: Arc::new(odoh_key), runtime_handle: runtime.handle().clone(), };