Implement ODoH key rotation.

This commit is contained in:
Christopher Wood 2021-05-01 07:56:10 -07:00
parent 25a9c285db
commit 822d3d9a51
7 changed files with 81 additions and 12 deletions

View file

@ -25,6 +25,7 @@ tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true
odoh-rs = "0.1.11"
rand = "0.7"
hpke = "0.5.0"
arc-swap = "1.2.0"
[profile.release]
codegen-units = 1

View file

@ -5,3 +5,4 @@ pub const MIN_DNS_PACKET_LEN: usize = 17;
pub const STALE_IF_ERROR_SECS: u32 = 86400;
pub const STALE_WHILE_REVALIDATE_SECS: u32 = 60;
pub const CERTS_WATCH_DELAY_SECS: u32 = 10;
pub const ODOH_KEY_ROTATION_SECS: u32 = 86400;

View file

@ -9,6 +9,7 @@ pub enum DoHError {
TooLarge,
UpstreamIssue,
UpstreamTimeout,
StaleKey,
Hyper(hyper::Error),
Io(io::Error),
}
@ -23,6 +24,7 @@ impl std::fmt::Display for DoHError {
DoHError::TooLarge => write!(fmt, "Too large"),
DoHError::UpstreamIssue => write!(fmt, "Upstream error"),
DoHError::UpstreamTimeout => write!(fmt, "Upstream timeout"),
DoHError::StaleKey => write!(fmt, "Stale key material"),
DoHError::Hyper(e) => write!(fmt, "HTTP error: {}", e),
DoHError::Io(e) => write!(fmt, "IO error: {}", e),
}
@ -37,6 +39,7 @@ impl From<DoHError> for StatusCode {
DoHError::TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
DoHError::UpstreamIssue => StatusCode::BAD_GATEWAY,
DoHError::UpstreamTimeout => StatusCode::BAD_GATEWAY,
DoHError::StaleKey => StatusCode::UNAUTHORIZED,
DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
}

View file

@ -1,4 +1,4 @@
use crate::odoh::ODoHPublicKey;
use crate::odoh::ODoHRotator;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
@ -30,7 +30,7 @@ pub struct Globals {
pub keepalive: bool,
pub disable_post: bool,
pub odoh_configs_path: String,
pub odoh_public_key: Arc<ODoHPublicKey>,
pub odoh_rotator: Arc<ODoHRotator>,
pub runtime_handle: runtime::Handle,
}

View file

@ -148,9 +148,10 @@ impl DoH {
Err(e) => return http_error(StatusCode::from(e))
};
let (query, context) = match (*self.globals.odoh_public_key).clone().decrypt_query(query_body).await {
let odoh_public_key = (*self.globals.odoh_rotator).clone().current_key();
let (query, context) = match (*odoh_public_key).clone().decrypt_query(query_body).await {
Ok((q, context)) => (q.to_vec(), context),
Err(_) => return http_error(StatusCode::from(DoHError::InvalidData))
Err(e) => return http_error(StatusCode::from(e))
};
let resp_body = match self.proxy(query).await {
@ -170,7 +171,8 @@ impl DoH {
}
async fn serve_odoh_configs(&self) -> Result<Response<Body>, http::Error> {
let configs = (*self.globals.odoh_public_key).clone().config();
let odoh_public_key = (*self.globals.odoh_rotator).clone().current_key();
let configs = (*odoh_public_key).clone().config();
match self.build_response(configs, 0, "application/octet-stream".to_string()) {
Ok(resp) => Ok(resp),
Err(e) => http_error(StatusCode::from(e)),

View file

@ -1,11 +1,17 @@
use crate::constants::ODOH_KEY_ROTATION_SECS;
use crate::errors::DoHError;
use std::sync::Arc;
use arc_swap::ArcSwap;
use std::time::Duration;
use tokio::runtime;
use hpke::kex::Serializable;
use odoh_rs::key_utils::{derive_keypair_from_seed};
use odoh_rs::protocol::{create_response_msg,
parse_received_query, RESPONSE_NONCE_SIZE,
ObliviousDoHQueryBody, Serialize,
ObliviousDoHQueryBody, Serialize, Deserialize,
ObliviousDoHKeyPair, ObliviousDoHConfigContents,
ObliviousDoHConfig, ObliviousDoHConfigs
ObliviousDoHConfig, ObliviousDoHConfigs,
ObliviousDoHMessage, ObliviousDoHMessageType,
};
use rand::Rng;
use std::fmt;
@ -72,6 +78,25 @@ impl ODoHPublicKey {
}
pub async fn decrypt_query(self, encrypted_query: Vec<u8>) -> Result<(Vec<u8>, ODoHQueryContext), DoHError> {
let odoh_query = match ObliviousDoHMessage::from_bytes(&encrypted_query) {
Ok(q) => {
if q.msg_type != ObliviousDoHMessageType::Query {
return Err(DoHError::InvalidData);
}
q
},
Err(_) => return Err(DoHError::InvalidData)
};
match self.key.public_key.identifier() {
Ok(key_id) => {
if !key_id.eq(&odoh_query.key_id) {
return Err(DoHError::StaleKey);
}
},
Err(_) => return Err(DoHError::InvalidData)
};
let (query, server_secret) = match parse_received_query(&self.key, &encrypted_query).await {
Ok((pq, ss)) => (pq, ss),
Err(_) => return Err(DoHError::InvalidData)
@ -93,4 +118,41 @@ impl ODoHQueryContext {
DoHError::InvalidData
})
}
}
#[derive(Clone, Debug)]
pub struct ODoHRotator {
key: Arc<ArcSwap<ODoHPublicKey>>,
}
impl ODoHRotator {
pub fn new(runtime_handle: runtime::Handle) -> Result<ODoHRotator, DoHError> {
let odoh_key = match ODoHPublicKey::new() {
Ok(key) => Arc::new(ArcSwap::from_pointee(key)),
Err(e) => panic!("ODoH key rotation error: {}", e),
};
let current_key = Arc::clone(&odoh_key);
runtime_handle.clone().spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(ODOH_KEY_ROTATION_SECS.into())).await;
match ODoHPublicKey::new() {
Ok(key) => {
current_key.store(Arc::new(key));
},
Err(e) => eprintln!("ODoH key rotation error: {}", e),
};
}
});
Ok(ODoHRotator{
key: Arc::clone(&odoh_key)
})
}
pub fn current_key(&self) -> Arc<ODoHPublicKey> {
let key = Arc::clone(&self.key);
Arc::clone(&key.load())
}
}

View file

@ -13,7 +13,7 @@ use libdoh::*;
use crate::config::*;
use crate::constants::*;
use libdoh::odoh::ODoHPublicKey;
use libdoh::odoh::ODoHRotator;
use libdoh::reexports::tokio;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
@ -25,9 +25,9 @@ fn main() {
runtime_builder.thread_name("doh-proxy");
let runtime = runtime_builder.build().unwrap();
let odoh_key = match ODoHPublicKey::new() {
Ok(key) => key,
Err(_) => panic!("Failed to generate ODoH public key configuration"),
let rotator = match ODoHRotator::new(runtime.handle().clone()) {
Ok(r) => r,
Err(_) => panic!("Failed to create ODoHRotator"),
};
let mut globals = Globals {
@ -50,7 +50,7 @@ fn main() {
keepalive: true,
disable_post: false,
odoh_configs_path: ODOH_CONFIGS_PATH.to_string(),
odoh_public_key: Arc::new(odoh_key),
odoh_rotator: Arc::new(rotator),
runtime_handle: runtime.handle().clone(),
};