Merge pull request #59 from chris-wood/caw/add-odoh

Add Oblivious DoH target support as a default feature.
This commit is contained in:
Frank Denis 2021-05-11 22:49:57 +02:00 committed by GitHub
commit 4e54008b10
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 302 additions and 28 deletions

View file

@ -2,6 +2,7 @@ pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
pub const MAX_CLIENTS: usize = 512;
pub const MAX_CONCURRENT_STREAMS: u32 = 16;
pub const PATH: &str = "/dns-query";
pub const ODOH_CONFIGS_PATH: &str = "/.well-known/odohconfigs";
pub const SERVER_ADDRESS: &str = "9.9.9.9:53";
pub const TIMEOUT_SEC: u64 = 10;
pub const MAX_TTL: u32 = 86400 * 7;

View file

@ -22,6 +22,10 @@ futures = "0.3.13"
hyper = { version = "0.14.4", default-features = false, features = ["server", "http1", "http2", "stream"] }
tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] }
tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true }
odoh-rs = "0.1.11"
rand = "0.7"
hpke = "0.5.0"
arc-swap = "1.2.0"
[profile.release]
codegen-units = 1

View file

@ -5,3 +5,4 @@ pub const MIN_DNS_PACKET_LEN: usize = 17;
pub const STALE_IF_ERROR_SECS: u32 = 86400;
pub const STALE_WHILE_REVALIDATE_SECS: u32 = 60;
pub const CERTS_WATCH_DELAY_SECS: u32 = 10;
pub const ODOH_KEY_ROTATION_SECS: u32 = 86400;

View file

@ -9,6 +9,7 @@ pub enum DoHError {
TooLarge,
UpstreamIssue,
UpstreamTimeout,
StaleKey,
Hyper(hyper::Error),
Io(io::Error),
}
@ -23,6 +24,7 @@ impl std::fmt::Display for DoHError {
DoHError::TooLarge => write!(fmt, "Too large"),
DoHError::UpstreamIssue => write!(fmt, "Upstream error"),
DoHError::UpstreamTimeout => write!(fmt, "Upstream timeout"),
DoHError::StaleKey => write!(fmt, "Stale key material"),
DoHError::Hyper(e) => write!(fmt, "HTTP error: {}", e),
DoHError::Io(e) => write!(fmt, "IO error: {}", e),
}
@ -37,6 +39,7 @@ impl From<DoHError> for StatusCode {
DoHError::TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
DoHError::UpstreamIssue => StatusCode::BAD_GATEWAY,
DoHError::UpstreamTimeout => StatusCode::BAD_GATEWAY,
DoHError::StaleKey => StatusCode::UNAUTHORIZED,
DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
}

View file

@ -1,3 +1,4 @@
use crate::odoh::ODoHRotator;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
@ -28,6 +29,8 @@ pub struct Globals {
pub err_ttl: u32,
pub keepalive: bool,
pub disable_post: bool,
pub odoh_configs_path: String,
pub odoh_rotator: Arc<ODoHRotator>,
pub runtime_handle: runtime::Handle,
}

View file

@ -1,5 +1,6 @@
mod constants;
pub mod dns;
pub mod odoh;
mod errors;
mod globals;
#[cfg(feature = "tls")]
@ -25,9 +26,30 @@ pub mod reexports {
pub use tokio;
}
#[derive(Clone, Debug)]
struct DnsResponse {
packet: Vec<u8>,
ttl: u32,
}
#[derive(Clone, Debug)]
enum DoHType {
Standard,
Oblivious,
}
impl DoHType {
fn as_str(&self) -> String {
match self {
DoHType::Standard => String::from("application/dns-message"),
DoHType::Oblivious => String::from("application/oblivious-dns-message"),
}
}
}
#[derive(Clone, Debug)]
pub struct DoH {
pub globals: Arc<Globals>,
pub globals: Arc<Globals>
}
#[allow(clippy::unnecessary_wraps)]
@ -72,14 +94,20 @@ impl hyper::service::Service<http::Request<Body>> for DoH {
fn call(&mut self, req: Request<Body>) -> Self::Future {
let globals = &self.globals;
if req.uri().path() != globals.path {
return Box::pin(async { http_error(StatusCode::NOT_FOUND) });
}
let self_inner = self.clone();
match *req.method() {
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
if req.uri().path() == globals.path {
match *req.method() {
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
}
} else if req.uri().path() == globals.odoh_configs_path {
match *req.method() {
Method::GET => Box::pin(async move { self_inner.serve_odoh_configs().await }),
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
}
} else {
Box::pin(async { http_error(StatusCode::NOT_FOUND) })
}
}
}
@ -89,12 +117,65 @@ impl DoH {
if self.globals.disable_post {
return http_error(StatusCode::METHOD_NOT_ALLOWED);
}
if let Err(response) = Self::check_content_type(&req) {
return Ok(response);
match Self::parse_content_type(&req) {
Ok(DoHType::Standard) => self.serve_doh_post(req).await,
Ok(DoHType::Oblivious) => self.serve_odoh_post(req).await,
Err(response) => return Ok(response)
}
match self.read_body_and_proxy(req.into_body()).await {
}
async fn serve_doh_post(&self, req:Request<Body>) -> Result<Response<Body>, http::Error> {
let query = match self.read_body(req.into_body()).await {
Ok(q) => q,
Err(e) => return http_error(StatusCode::from(e))
};
let resp = match self.proxy(query).await {
Ok(resp) => self.build_response(resp.packet, resp.ttl, DoHType::Standard.as_str()),
Err(e) => return http_error(StatusCode::from(e)),
};
match resp {
Ok(resp) => Ok(resp),
Err(e) => http_error(StatusCode::from(e)),
}
}
async fn serve_odoh_post(&self, req:Request<Body>) -> Result<Response<Body>, http::Error> {
let query_body = match self.read_body(req.into_body()).await {
Ok(q) => q,
Err(e) => return http_error(StatusCode::from(e))
};
let odoh_public_key = (*self.globals.odoh_rotator).clone().current_key();
let (query, context) = match (*odoh_public_key).clone().decrypt_query(query_body).await {
Ok((q, context)) => (q.to_vec(), context),
Err(e) => return http_error(StatusCode::from(e))
};
let resp_body = match self.proxy(query).await {
Ok(resp) => resp,
Err(e) => return http_error(StatusCode::from(e))
};
let resp = match context.encrypt_response(resp_body.packet).await {
Ok(resp) => self.build_response(resp, 0u32, DoHType::Oblivious.as_str()),
Err(e) => return http_error(StatusCode::from(e)),
};
match resp {
Ok(resp) => Ok(resp),
Err(e) => http_error(StatusCode::from(e)),
}
}
async fn serve_odoh_configs(&self) -> Result<Response<Body>, http::Error> {
let odoh_public_key = (*self.globals.odoh_rotator).clone().current_key();
let configs = (*odoh_public_key).clone().config();
match self.build_response(configs, 0, "application/octet-stream".to_string()) {
Ok(resp) => Ok(resp),
Err(e) => http_error(StatusCode::from(e)),
Ok(res) => Ok(res),
}
}
@ -117,13 +198,18 @@ impl DoH {
return http_error(StatusCode::BAD_REQUEST);
}
};
match self.proxy(question).await {
let resp = match self.proxy(question).await {
Ok(dns_resp) => self.build_response(dns_resp.packet, dns_resp.ttl, DoHType::Standard.as_str()),
Err(e) => Err(e),
};
match resp {
Ok(resp) => Ok(resp),
Err(e) => http_error(StatusCode::from(e)),
Ok(res) => Ok(res),
}
}
fn check_content_type(req: &Request<Body>) -> Result<(), Response<Body>> {
fn parse_content_type(req: &Request<Body>) -> Result<DoHType, Response<Body>> {
let headers = req.headers();
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
None => {
@ -145,17 +231,21 @@ impl DoH {
}
Ok(content_type) => content_type.to_lowercase(),
};
if content_type != "application/dns-message" {
let response = Response::builder()
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
.body(Body::empty())
.unwrap();
return Err(response);
match content_type.as_str() {
"application/dns-message" => Ok(DoHType::Standard),
"application/oblivious-dns-message" => Ok(DoHType::Oblivious),
_ => {
let response = Response::builder()
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
.body(Body::empty())
.unwrap();
return Err(response);
}
}
Ok(())
}
async fn read_body_and_proxy(&self, mut body: Body) -> Result<Response<Body>, DoHError> {
async fn read_body(&self, mut body: Body) -> Result<Vec<u8>, DoHError> {
let mut sum_size = 0;
let mut query = vec![];
while let Some(chunk) = body.next().await {
@ -166,17 +256,16 @@ impl DoH {
}
query.extend(chunk);
}
let response = self.proxy(query).await?;
Ok(response)
Ok(query)
}
async fn proxy(&self, query: Vec<u8>) -> Result<Response<Body>, DoHError> {
async fn proxy(&self, query: Vec<u8>) -> Result<DnsResponse, DoHError> {
let proxy_timeout = self.globals.timeout;
let timeout_res = tokio::time::timeout(proxy_timeout, self._proxy(query)).await;
timeout_res.map_err(|_| DoHError::UpstreamTimeout)?
}
async fn _proxy(&self, mut query: Vec<u8>) -> Result<Response<Body>, DoHError> {
async fn _proxy(&self, mut query: Vec<u8>) -> Result<DnsResponse, DoHError> {
if query.len() < MIN_DNS_PACKET_LEN {
return Err(DoHError::Incomplete);
}
@ -209,10 +298,17 @@ impl DoH {
dns::add_edns_padding(&mut packet)
.map_err(|_| DoHError::TooLarge)
.ok();
Ok(DnsResponse{
packet,
ttl,
})
}
fn build_response(&self, packet: Vec<u8>, ttl: u32, content_type: String) -> Result<Response<Body>, DoHError> {
let packet_len = packet.len();
let response = Response::builder()
.header(hyper::header::CONTENT_LENGTH, packet_len)
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
.header(hyper::header::CONTENT_TYPE, content_type.as_str())
.header(
hyper::header::CACHE_CONTROL,
format!(

158
src/libdoh/src/odoh.rs Normal file
View file

@ -0,0 +1,158 @@
use crate::constants::ODOH_KEY_ROTATION_SECS;
use crate::errors::DoHError;
use std::sync::Arc;
use arc_swap::ArcSwap;
use std::time::Duration;
use tokio::runtime;
use hpke::kex::Serializable;
use odoh_rs::key_utils::{derive_keypair_from_seed};
use odoh_rs::protocol::{create_response_msg,
parse_received_query, RESPONSE_NONCE_SIZE,
ObliviousDoHQueryBody, Serialize, Deserialize,
ObliviousDoHKeyPair, ObliviousDoHConfigContents,
ObliviousDoHConfig, ObliviousDoHConfigs,
ObliviousDoHMessage, ObliviousDoHMessageType,
};
use rand::Rng;
use std::fmt;
// https://cfrg.github.io/draft-irtf-cfrg-hpke/draft-irtf-cfrg-hpke.html#name-algorithm-identifiers
const DEFAULT_HPKE_SEED_SIZE: usize = 32;
const DEFAULT_HPKE_KEM: u16 = 0x0020; // DHKEM(X25519, HKDF-SHA256)
const DEFAULT_HPKE_KDF: u16 = 0x0001; // KDF(SHA-256)
const DEFAULT_HPKE_AEAD: u16 = 0x0001; // AEAD(AES-GCM-128)
#[derive(Clone)]
pub struct ODoHPublicKey {
key: ObliviousDoHKeyPair,
serialized_configs: Vec<u8>,
}
impl fmt::Debug for ODoHPublicKey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ODoHPublicKey").finish()
}
}
#[derive(Clone, Debug)]
pub struct ODoHQueryContext {
query: ObliviousDoHQueryBody,
secret: Vec<u8>,
}
fn generate_key_pair() -> ObliviousDoHKeyPair {
let ikm = rand::thread_rng().gen::<[u8; DEFAULT_HPKE_SEED_SIZE]>();
let (secret_key, public_key) = derive_keypair_from_seed(&ikm);
let public_key_bytes = public_key.to_bytes().to_vec();
let odoh_public_key = ObliviousDoHConfigContents {
kem_id: DEFAULT_HPKE_KEM,
kdf_id: DEFAULT_HPKE_KDF,
aead_id: DEFAULT_HPKE_AEAD,
public_key: public_key_bytes,
};
ObliviousDoHKeyPair {
private_key: secret_key,
public_key: odoh_public_key,
}
}
impl ODoHPublicKey {
pub fn new() -> Result<ODoHPublicKey, DoHError> {
let key_pair = generate_key_pair();
let config = ObliviousDoHConfig::new(&key_pair.public_key.clone().to_bytes().unwrap()).unwrap();
let serialized_configs = ObliviousDoHConfigs {
configs: vec![config.clone()],
}
.to_bytes()
.unwrap()
.to_vec();
Ok(ODoHPublicKey{
key: key_pair,
serialized_configs: serialized_configs
})
}
pub fn config(self) -> Vec<u8> {
self.serialized_configs
}
pub async fn decrypt_query(self, encrypted_query: Vec<u8>) -> Result<(Vec<u8>, ODoHQueryContext), DoHError> {
let odoh_query = match ObliviousDoHMessage::from_bytes(&encrypted_query) {
Ok(q) => {
if q.msg_type != ObliviousDoHMessageType::Query {
return Err(DoHError::InvalidData);
}
q
},
Err(_) => return Err(DoHError::InvalidData)
};
match self.key.public_key.identifier() {
Ok(key_id) => {
if !key_id.eq(&odoh_query.key_id) {
return Err(DoHError::StaleKey);
}
},
Err(_) => return Err(DoHError::InvalidData)
};
let (query, server_secret) = match parse_received_query(&self.key, &encrypted_query).await {
Ok((pq, ss)) => (pq, ss),
Err(_) => return Err(DoHError::InvalidData)
};
let context = ODoHQueryContext{
query: query.clone(),
secret: server_secret,
};
Ok((query.dns_msg.clone(), context))
}
}
impl ODoHQueryContext {
pub async fn encrypt_response(self, response_body: Vec<u8>) -> Result<Vec<u8>, DoHError> {
let response_nonce = rand::thread_rng().gen::<[u8; RESPONSE_NONCE_SIZE]>();
create_response_msg(&self.secret, &response_body, None, Some(response_nonce.to_vec()), &self.query)
.await
.map_err(|_| {
DoHError::InvalidData
})
}
}
#[derive(Clone, Debug)]
pub struct ODoHRotator {
key: Arc<ArcSwap<ODoHPublicKey>>,
}
impl ODoHRotator {
pub fn new(runtime_handle: runtime::Handle) -> Result<ODoHRotator, DoHError> {
let odoh_key = match ODoHPublicKey::new() {
Ok(key) => Arc::new(ArcSwap::from_pointee(key)),
Err(e) => panic!("ODoH key rotation error: {}", e),
};
let current_key = Arc::clone(&odoh_key);
runtime_handle.clone().spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(ODOH_KEY_ROTATION_SECS.into())).await;
match ODoHPublicKey::new() {
Ok(key) => {
current_key.store(Arc::new(key));
},
Err(e) => eprintln!("ODoH key rotation error: {}", e),
};
}
});
Ok(ODoHRotator{
key: Arc::clone(&odoh_key)
})
}
pub fn current_key(&self) -> Arc<ODoHPublicKey> {
let key = Arc::clone(&self.key);
Arc::clone(&key.load())
}
}

View file

@ -13,6 +13,7 @@ use libdoh::*;
use crate::config::*;
use crate::constants::*;
use libdoh::odoh::ODoHRotator;
use libdoh::reexports::tokio;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
@ -24,6 +25,11 @@ fn main() {
runtime_builder.thread_name("doh-proxy");
let runtime = runtime_builder.build().unwrap();
let rotator = match ODoHRotator::new(runtime.handle().clone()) {
Ok(r) => r,
Err(_) => panic!("Failed to create ODoHRotator"),
};
let mut globals = Globals {
#[cfg(feature = "tls")]
tls_cert_path: None,
@ -43,6 +49,8 @@ fn main() {
err_ttl: ERR_TTL,
keepalive: true,
disable_post: false,
odoh_configs_path: ODOH_CONFIGS_PATH.to_string(),
odoh_rotator: Arc::new(rotator),
runtime_handle: runtime.handle().clone(),
};