diff --git a/src/constants.rs b/src/constants.rs index a784f2a..e7e549c 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -2,6 +2,7 @@ pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000"; pub const MAX_CLIENTS: usize = 512; pub const MAX_CONCURRENT_STREAMS: u32 = 16; pub const PATH: &str = "/dns-query"; +pub const ODOH_CONFIGS_PATH: &str = "/.well-known/odohconfigs"; pub const SERVER_ADDRESS: &str = "9.9.9.9:53"; pub const TIMEOUT_SEC: u64 = 10; pub const MAX_TTL: u32 = 86400 * 7; diff --git a/src/libdoh/Cargo.toml b/src/libdoh/Cargo.toml index be30591..a24df07 100644 --- a/src/libdoh/Cargo.toml +++ b/src/libdoh/Cargo.toml @@ -22,6 +22,10 @@ futures = "0.3.13" hyper = { version = "0.14.4", default-features = false, features = ["server", "http1", "http2", "stream"] } tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] } tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true } +odoh-rs = "0.1.11" +rand = "0.7" +hpke = "0.5.0" +arc-swap = "1.2.0" [profile.release] codegen-units = 1 diff --git a/src/libdoh/src/constants.rs b/src/libdoh/src/constants.rs index f1dccbb..976175a 100644 --- a/src/libdoh/src/constants.rs +++ b/src/libdoh/src/constants.rs @@ -5,3 +5,4 @@ pub const MIN_DNS_PACKET_LEN: usize = 17; pub const STALE_IF_ERROR_SECS: u32 = 86400; pub const STALE_WHILE_REVALIDATE_SECS: u32 = 60; pub const CERTS_WATCH_DELAY_SECS: u32 = 10; +pub const ODOH_KEY_ROTATION_SECS: u32 = 86400; diff --git a/src/libdoh/src/errors.rs b/src/libdoh/src/errors.rs index 707b6bb..55ec2fb 100644 --- a/src/libdoh/src/errors.rs +++ b/src/libdoh/src/errors.rs @@ -9,6 +9,7 @@ pub enum DoHError { TooLarge, UpstreamIssue, UpstreamTimeout, + StaleKey, Hyper(hyper::Error), Io(io::Error), } @@ -23,6 +24,7 @@ impl std::fmt::Display for DoHError { DoHError::TooLarge => write!(fmt, "Too large"), DoHError::UpstreamIssue => write!(fmt, "Upstream error"), DoHError::UpstreamTimeout => write!(fmt, "Upstream timeout"), + DoHError::StaleKey => write!(fmt, "Stale key material"), DoHError::Hyper(e) => write!(fmt, "HTTP error: {}", e), DoHError::Io(e) => write!(fmt, "IO error: {}", e), } @@ -37,6 +39,7 @@ impl From for StatusCode { DoHError::TooLarge => StatusCode::PAYLOAD_TOO_LARGE, DoHError::UpstreamIssue => StatusCode::BAD_GATEWAY, DoHError::UpstreamTimeout => StatusCode::BAD_GATEWAY, + DoHError::StaleKey => StatusCode::UNAUTHORIZED, DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE, DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR, } diff --git a/src/libdoh/src/globals.rs b/src/libdoh/src/globals.rs index 3fbba84..2e8f002 100644 --- a/src/libdoh/src/globals.rs +++ b/src/libdoh/src/globals.rs @@ -1,3 +1,4 @@ +use crate::odoh::ODoHRotator; use std::net::SocketAddr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -28,6 +29,8 @@ pub struct Globals { pub err_ttl: u32, pub keepalive: bool, pub disable_post: bool, + pub odoh_configs_path: String, + pub odoh_rotator: Arc, pub runtime_handle: runtime::Handle, } diff --git a/src/libdoh/src/lib.rs b/src/libdoh/src/lib.rs index 03eb14b..9f5329f 100644 --- a/src/libdoh/src/lib.rs +++ b/src/libdoh/src/lib.rs @@ -1,5 +1,6 @@ mod constants; pub mod dns; +pub mod odoh; mod errors; mod globals; #[cfg(feature = "tls")] @@ -25,9 +26,30 @@ pub mod reexports { pub use tokio; } +#[derive(Clone, Debug)] +struct DnsResponse { + packet: Vec, + ttl: u32, +} + +#[derive(Clone, Debug)] +enum DoHType { + Standard, + Oblivious, +} + +impl DoHType { + fn as_str(&self) -> String { + match self { + DoHType::Standard => String::from("application/dns-message"), + DoHType::Oblivious => String::from("application/oblivious-dns-message"), + } + } +} + #[derive(Clone, Debug)] pub struct DoH { - pub globals: Arc, + pub globals: Arc } #[allow(clippy::unnecessary_wraps)] @@ -72,14 +94,20 @@ impl hyper::service::Service> for DoH { fn call(&mut self, req: Request) -> Self::Future { let globals = &self.globals; - if req.uri().path() != globals.path { - return Box::pin(async { http_error(StatusCode::NOT_FOUND) }); - } let self_inner = self.clone(); - match *req.method() { - Method::POST => Box::pin(async move { self_inner.serve_post(req).await }), - Method::GET => Box::pin(async move { self_inner.serve_get(req).await }), - _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), + if req.uri().path() == globals.path { + match *req.method() { + Method::POST => Box::pin(async move { self_inner.serve_post(req).await }), + Method::GET => Box::pin(async move { self_inner.serve_get(req).await }), + _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), + } + } else if req.uri().path() == globals.odoh_configs_path { + match *req.method() { + Method::GET => Box::pin(async move { self_inner.serve_odoh_configs().await }), + _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), + } + } else { + Box::pin(async { http_error(StatusCode::NOT_FOUND) }) } } } @@ -89,12 +117,65 @@ impl DoH { if self.globals.disable_post { return http_error(StatusCode::METHOD_NOT_ALLOWED); } - if let Err(response) = Self::check_content_type(&req) { - return Ok(response); + + match Self::parse_content_type(&req) { + Ok(DoHType::Standard) => self.serve_doh_post(req).await, + Ok(DoHType::Oblivious) => self.serve_odoh_post(req).await, + Err(response) => return Ok(response) } - match self.read_body_and_proxy(req.into_body()).await { + } + + async fn serve_doh_post(&self, req:Request) -> Result, http::Error> { + let query = match self.read_body(req.into_body()).await { + Ok(q) => q, + Err(e) => return http_error(StatusCode::from(e)) + }; + + let resp = match self.proxy(query).await { + Ok(resp) => self.build_response(resp.packet, resp.ttl, DoHType::Standard.as_str()), + Err(e) => return http_error(StatusCode::from(e)), + }; + + match resp { + Ok(resp) => Ok(resp), + Err(e) => http_error(StatusCode::from(e)), + } + } + + async fn serve_odoh_post(&self, req:Request) -> Result, http::Error> { + let query_body = match self.read_body(req.into_body()).await { + Ok(q) => q, + Err(e) => return http_error(StatusCode::from(e)) + }; + + let odoh_public_key = (*self.globals.odoh_rotator).clone().current_key(); + let (query, context) = match (*odoh_public_key).clone().decrypt_query(query_body).await { + Ok((q, context)) => (q.to_vec(), context), + Err(e) => return http_error(StatusCode::from(e)) + }; + + let resp_body = match self.proxy(query).await { + Ok(resp) => resp, + Err(e) => return http_error(StatusCode::from(e)) + }; + + let resp = match context.encrypt_response(resp_body.packet).await { + Ok(resp) => self.build_response(resp, 0u32, DoHType::Oblivious.as_str()), + Err(e) => return http_error(StatusCode::from(e)), + }; + + match resp { + Ok(resp) => Ok(resp), + Err(e) => http_error(StatusCode::from(e)), + } + } + + async fn serve_odoh_configs(&self) -> Result, http::Error> { + let odoh_public_key = (*self.globals.odoh_rotator).clone().current_key(); + let configs = (*odoh_public_key).clone().config(); + match self.build_response(configs, 0, "application/octet-stream".to_string()) { + Ok(resp) => Ok(resp), Err(e) => http_error(StatusCode::from(e)), - Ok(res) => Ok(res), } } @@ -117,13 +198,18 @@ impl DoH { return http_error(StatusCode::BAD_REQUEST); } }; - match self.proxy(question).await { + + let resp = match self.proxy(question).await { + Ok(dns_resp) => self.build_response(dns_resp.packet, dns_resp.ttl, DoHType::Standard.as_str()), + Err(e) => Err(e), + }; + match resp { + Ok(resp) => Ok(resp), Err(e) => http_error(StatusCode::from(e)), - Ok(res) => Ok(res), } } - fn check_content_type(req: &Request) -> Result<(), Response> { + fn parse_content_type(req: &Request) -> Result> { let headers = req.headers(); let content_type = match headers.get(hyper::header::CONTENT_TYPE) { None => { @@ -145,17 +231,21 @@ impl DoH { } Ok(content_type) => content_type.to_lowercase(), }; - if content_type != "application/dns-message" { - let response = Response::builder() - .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) - .body(Body::empty()) - .unwrap(); - return Err(response); + + match content_type.as_str() { + "application/dns-message" => Ok(DoHType::Standard), + "application/oblivious-dns-message" => Ok(DoHType::Oblivious), + _ => { + let response = Response::builder() + .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) + .body(Body::empty()) + .unwrap(); + return Err(response); + } } - Ok(()) } - async fn read_body_and_proxy(&self, mut body: Body) -> Result, DoHError> { + async fn read_body(&self, mut body: Body) -> Result, DoHError> { let mut sum_size = 0; let mut query = vec![]; while let Some(chunk) = body.next().await { @@ -166,17 +256,16 @@ impl DoH { } query.extend(chunk); } - let response = self.proxy(query).await?; - Ok(response) + Ok(query) } - async fn proxy(&self, query: Vec) -> Result, DoHError> { + async fn proxy(&self, query: Vec) -> Result { let proxy_timeout = self.globals.timeout; let timeout_res = tokio::time::timeout(proxy_timeout, self._proxy(query)).await; timeout_res.map_err(|_| DoHError::UpstreamTimeout)? } - async fn _proxy(&self, mut query: Vec) -> Result, DoHError> { + async fn _proxy(&self, mut query: Vec) -> Result { if query.len() < MIN_DNS_PACKET_LEN { return Err(DoHError::Incomplete); } @@ -209,10 +298,17 @@ impl DoH { dns::add_edns_padding(&mut packet) .map_err(|_| DoHError::TooLarge) .ok(); + Ok(DnsResponse{ + packet, + ttl, + }) + } + + fn build_response(&self, packet: Vec, ttl: u32, content_type: String) -> Result, DoHError> { let packet_len = packet.len(); let response = Response::builder() .header(hyper::header::CONTENT_LENGTH, packet_len) - .header(hyper::header::CONTENT_TYPE, "application/dns-message") + .header(hyper::header::CONTENT_TYPE, content_type.as_str()) .header( hyper::header::CACHE_CONTROL, format!( diff --git a/src/libdoh/src/odoh.rs b/src/libdoh/src/odoh.rs new file mode 100644 index 0000000..5169cc6 --- /dev/null +++ b/src/libdoh/src/odoh.rs @@ -0,0 +1,158 @@ +use crate::constants::ODOH_KEY_ROTATION_SECS; +use crate::errors::DoHError; +use std::sync::Arc; +use arc_swap::ArcSwap; +use std::time::Duration; +use tokio::runtime; +use hpke::kex::Serializable; +use odoh_rs::key_utils::{derive_keypair_from_seed}; +use odoh_rs::protocol::{create_response_msg, + parse_received_query, RESPONSE_NONCE_SIZE, + ObliviousDoHQueryBody, Serialize, Deserialize, + ObliviousDoHKeyPair, ObliviousDoHConfigContents, + ObliviousDoHConfig, ObliviousDoHConfigs, + ObliviousDoHMessage, ObliviousDoHMessageType, +}; +use rand::Rng; +use std::fmt; + +// https://cfrg.github.io/draft-irtf-cfrg-hpke/draft-irtf-cfrg-hpke.html#name-algorithm-identifiers +const DEFAULT_HPKE_SEED_SIZE: usize = 32; +const DEFAULT_HPKE_KEM: u16 = 0x0020; // DHKEM(X25519, HKDF-SHA256) +const DEFAULT_HPKE_KDF: u16 = 0x0001; // KDF(SHA-256) +const DEFAULT_HPKE_AEAD: u16 = 0x0001; // AEAD(AES-GCM-128) + +#[derive(Clone)] +pub struct ODoHPublicKey { + key: ObliviousDoHKeyPair, + serialized_configs: Vec, +} + +impl fmt::Debug for ODoHPublicKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ODoHPublicKey").finish() + } +} + +#[derive(Clone, Debug)] +pub struct ODoHQueryContext { + query: ObliviousDoHQueryBody, + secret: Vec, +} + +fn generate_key_pair() -> ObliviousDoHKeyPair { + let ikm = rand::thread_rng().gen::<[u8; DEFAULT_HPKE_SEED_SIZE]>(); + let (secret_key, public_key) = derive_keypair_from_seed(&ikm); + let public_key_bytes = public_key.to_bytes().to_vec(); + let odoh_public_key = ObliviousDoHConfigContents { + kem_id: DEFAULT_HPKE_KEM, + kdf_id: DEFAULT_HPKE_KDF, + aead_id: DEFAULT_HPKE_AEAD, + public_key: public_key_bytes, + }; + ObliviousDoHKeyPair { + private_key: secret_key, + public_key: odoh_public_key, + } +} + +impl ODoHPublicKey { + pub fn new() -> Result { + let key_pair = generate_key_pair(); + let config = ObliviousDoHConfig::new(&key_pair.public_key.clone().to_bytes().unwrap()).unwrap(); + let serialized_configs = ObliviousDoHConfigs { + configs: vec![config.clone()], + } + .to_bytes() + .unwrap() + .to_vec(); + + Ok(ODoHPublicKey{ + key: key_pair, + serialized_configs: serialized_configs + }) + } + + pub fn config(self) -> Vec { + self.serialized_configs + } + + pub async fn decrypt_query(self, encrypted_query: Vec) -> Result<(Vec, ODoHQueryContext), DoHError> { + let odoh_query = match ObliviousDoHMessage::from_bytes(&encrypted_query) { + Ok(q) => { + if q.msg_type != ObliviousDoHMessageType::Query { + return Err(DoHError::InvalidData); + } + q + }, + Err(_) => return Err(DoHError::InvalidData) + }; + + match self.key.public_key.identifier() { + Ok(key_id) => { + if !key_id.eq(&odoh_query.key_id) { + return Err(DoHError::StaleKey); + } + }, + Err(_) => return Err(DoHError::InvalidData) + }; + + let (query, server_secret) = match parse_received_query(&self.key, &encrypted_query).await { + Ok((pq, ss)) => (pq, ss), + Err(_) => return Err(DoHError::InvalidData) + }; + let context = ODoHQueryContext{ + query: query.clone(), + secret: server_secret, + }; + Ok((query.dns_msg.clone(), context)) + } +} + +impl ODoHQueryContext { + pub async fn encrypt_response(self, response_body: Vec) -> Result, DoHError> { + let response_nonce = rand::thread_rng().gen::<[u8; RESPONSE_NONCE_SIZE]>(); + create_response_msg(&self.secret, &response_body, None, Some(response_nonce.to_vec()), &self.query) + .await + .map_err(|_| { + DoHError::InvalidData + }) + } +} + +#[derive(Clone, Debug)] +pub struct ODoHRotator { + key: Arc>, +} + +impl ODoHRotator { + pub fn new(runtime_handle: runtime::Handle) -> Result { + let odoh_key = match ODoHPublicKey::new() { + Ok(key) => Arc::new(ArcSwap::from_pointee(key)), + Err(e) => panic!("ODoH key rotation error: {}", e), + }; + + let current_key = Arc::clone(&odoh_key); + + runtime_handle.clone().spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(ODOH_KEY_ROTATION_SECS.into())).await; + match ODoHPublicKey::new() { + Ok(key) => { + current_key.store(Arc::new(key)); + }, + Err(e) => eprintln!("ODoH key rotation error: {}", e), + }; + } + }); + + Ok(ODoHRotator{ + key: Arc::clone(&odoh_key) + }) + } + + pub fn current_key(&self) -> Arc { + let key = Arc::clone(&self.key); + Arc::clone(&key.load()) + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index f7f69a4..86ecfe0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ use libdoh::*; use crate::config::*; use crate::constants::*; +use libdoh::odoh::ODoHRotator; use libdoh::reexports::tokio; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; @@ -24,6 +25,11 @@ fn main() { runtime_builder.thread_name("doh-proxy"); let runtime = runtime_builder.build().unwrap(); + let rotator = match ODoHRotator::new(runtime.handle().clone()) { + Ok(r) => r, + Err(_) => panic!("Failed to create ODoHRotator"), + }; + let mut globals = Globals { #[cfg(feature = "tls")] tls_cert_path: None, @@ -43,6 +49,8 @@ fn main() { err_ttl: ERR_TTL, keepalive: true, disable_post: false, + odoh_configs_path: ODOH_CONFIGS_PATH.to_string(), + odoh_rotator: Arc::new(rotator), runtime_handle: runtime.handle().clone(), };