diff --git a/src/libdoh/Cargo.toml b/src/libdoh/Cargo.toml index 3df9bbe..a24df07 100644 --- a/src/libdoh/Cargo.toml +++ b/src/libdoh/Cargo.toml @@ -25,6 +25,7 @@ tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true odoh-rs = "0.1.11" rand = "0.7" hpke = "0.5.0" +arc-swap = "1.2.0" [profile.release] codegen-units = 1 diff --git a/src/libdoh/src/constants.rs b/src/libdoh/src/constants.rs index f1dccbb..976175a 100644 --- a/src/libdoh/src/constants.rs +++ b/src/libdoh/src/constants.rs @@ -5,3 +5,4 @@ pub const MIN_DNS_PACKET_LEN: usize = 17; pub const STALE_IF_ERROR_SECS: u32 = 86400; pub const STALE_WHILE_REVALIDATE_SECS: u32 = 60; pub const CERTS_WATCH_DELAY_SECS: u32 = 10; +pub const ODOH_KEY_ROTATION_SECS: u32 = 86400; diff --git a/src/libdoh/src/errors.rs b/src/libdoh/src/errors.rs index 707b6bb..55ec2fb 100644 --- a/src/libdoh/src/errors.rs +++ b/src/libdoh/src/errors.rs @@ -9,6 +9,7 @@ pub enum DoHError { TooLarge, UpstreamIssue, UpstreamTimeout, + StaleKey, Hyper(hyper::Error), Io(io::Error), } @@ -23,6 +24,7 @@ impl std::fmt::Display for DoHError { DoHError::TooLarge => write!(fmt, "Too large"), DoHError::UpstreamIssue => write!(fmt, "Upstream error"), DoHError::UpstreamTimeout => write!(fmt, "Upstream timeout"), + DoHError::StaleKey => write!(fmt, "Stale key material"), DoHError::Hyper(e) => write!(fmt, "HTTP error: {}", e), DoHError::Io(e) => write!(fmt, "IO error: {}", e), } @@ -37,6 +39,7 @@ impl From for StatusCode { DoHError::TooLarge => StatusCode::PAYLOAD_TOO_LARGE, DoHError::UpstreamIssue => StatusCode::BAD_GATEWAY, DoHError::UpstreamTimeout => StatusCode::BAD_GATEWAY, + DoHError::StaleKey => StatusCode::UNAUTHORIZED, DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE, DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR, } diff --git a/src/libdoh/src/globals.rs b/src/libdoh/src/globals.rs index f5327e8..2e8f002 100644 --- a/src/libdoh/src/globals.rs +++ b/src/libdoh/src/globals.rs @@ -1,4 +1,4 @@ -use crate::odoh::ODoHPublicKey; +use crate::odoh::ODoHRotator; use std::net::SocketAddr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -30,7 +30,7 @@ pub struct Globals { pub keepalive: bool, pub disable_post: bool, pub odoh_configs_path: String, - pub odoh_public_key: Arc, + pub odoh_rotator: Arc, pub runtime_handle: runtime::Handle, } diff --git a/src/libdoh/src/lib.rs b/src/libdoh/src/lib.rs index db825ea..9f5329f 100644 --- a/src/libdoh/src/lib.rs +++ b/src/libdoh/src/lib.rs @@ -148,9 +148,10 @@ impl DoH { Err(e) => return http_error(StatusCode::from(e)) }; - let (query, context) = match (*self.globals.odoh_public_key).clone().decrypt_query(query_body).await { + let odoh_public_key = (*self.globals.odoh_rotator).clone().current_key(); + let (query, context) = match (*odoh_public_key).clone().decrypt_query(query_body).await { Ok((q, context)) => (q.to_vec(), context), - Err(_) => return http_error(StatusCode::from(DoHError::InvalidData)) + Err(e) => return http_error(StatusCode::from(e)) }; let resp_body = match self.proxy(query).await { @@ -170,7 +171,8 @@ impl DoH { } async fn serve_odoh_configs(&self) -> Result, http::Error> { - let configs = (*self.globals.odoh_public_key).clone().config(); + let odoh_public_key = (*self.globals.odoh_rotator).clone().current_key(); + let configs = (*odoh_public_key).clone().config(); match self.build_response(configs, 0, "application/octet-stream".to_string()) { Ok(resp) => Ok(resp), Err(e) => http_error(StatusCode::from(e)), diff --git a/src/libdoh/src/odoh.rs b/src/libdoh/src/odoh.rs index b7e499c..5169cc6 100644 --- a/src/libdoh/src/odoh.rs +++ b/src/libdoh/src/odoh.rs @@ -1,11 +1,17 @@ +use crate::constants::ODOH_KEY_ROTATION_SECS; use crate::errors::DoHError; +use std::sync::Arc; +use arc_swap::ArcSwap; +use std::time::Duration; +use tokio::runtime; use hpke::kex::Serializable; use odoh_rs::key_utils::{derive_keypair_from_seed}; use odoh_rs::protocol::{create_response_msg, parse_received_query, RESPONSE_NONCE_SIZE, - ObliviousDoHQueryBody, Serialize, + ObliviousDoHQueryBody, Serialize, Deserialize, ObliviousDoHKeyPair, ObliviousDoHConfigContents, - ObliviousDoHConfig, ObliviousDoHConfigs + ObliviousDoHConfig, ObliviousDoHConfigs, + ObliviousDoHMessage, ObliviousDoHMessageType, }; use rand::Rng; use std::fmt; @@ -72,6 +78,25 @@ impl ODoHPublicKey { } pub async fn decrypt_query(self, encrypted_query: Vec) -> Result<(Vec, ODoHQueryContext), DoHError> { + let odoh_query = match ObliviousDoHMessage::from_bytes(&encrypted_query) { + Ok(q) => { + if q.msg_type != ObliviousDoHMessageType::Query { + return Err(DoHError::InvalidData); + } + q + }, + Err(_) => return Err(DoHError::InvalidData) + }; + + match self.key.public_key.identifier() { + Ok(key_id) => { + if !key_id.eq(&odoh_query.key_id) { + return Err(DoHError::StaleKey); + } + }, + Err(_) => return Err(DoHError::InvalidData) + }; + let (query, server_secret) = match parse_received_query(&self.key, &encrypted_query).await { Ok((pq, ss)) => (pq, ss), Err(_) => return Err(DoHError::InvalidData) @@ -93,4 +118,41 @@ impl ODoHQueryContext { DoHError::InvalidData }) } +} + +#[derive(Clone, Debug)] +pub struct ODoHRotator { + key: Arc>, +} + +impl ODoHRotator { + pub fn new(runtime_handle: runtime::Handle) -> Result { + let odoh_key = match ODoHPublicKey::new() { + Ok(key) => Arc::new(ArcSwap::from_pointee(key)), + Err(e) => panic!("ODoH key rotation error: {}", e), + }; + + let current_key = Arc::clone(&odoh_key); + + runtime_handle.clone().spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(ODOH_KEY_ROTATION_SECS.into())).await; + match ODoHPublicKey::new() { + Ok(key) => { + current_key.store(Arc::new(key)); + }, + Err(e) => eprintln!("ODoH key rotation error: {}", e), + }; + } + }); + + Ok(ODoHRotator{ + key: Arc::clone(&odoh_key) + }) + } + + pub fn current_key(&self) -> Arc { + let key = Arc::clone(&self.key); + Arc::clone(&key.load()) + } } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index fbf249c..86ecfe0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,7 +13,7 @@ use libdoh::*; use crate::config::*; use crate::constants::*; -use libdoh::odoh::ODoHPublicKey; +use libdoh::odoh::ODoHRotator; use libdoh::reexports::tokio; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; @@ -25,9 +25,9 @@ fn main() { runtime_builder.thread_name("doh-proxy"); let runtime = runtime_builder.build().unwrap(); - let odoh_key = match ODoHPublicKey::new() { - Ok(key) => key, - Err(_) => panic!("Failed to generate ODoH public key configuration"), + let rotator = match ODoHRotator::new(runtime.handle().clone()) { + Ok(r) => r, + Err(_) => panic!("Failed to create ODoHRotator"), }; let mut globals = Globals { @@ -50,7 +50,7 @@ fn main() { keepalive: true, disable_post: false, odoh_configs_path: ODOH_CONFIGS_PATH.to_string(), - odoh_public_key: Arc::new(odoh_key), + odoh_rotator: Arc::new(rotator), runtime_handle: runtime.handle().clone(), };