mirror of
https://github.com/DNSCrypt/doh-server.git
synced 2025-04-05 05:57:38 +03:00
Merge pull request #59 from chris-wood/caw/add-odoh
Add Oblivious DoH target support as a default feature.
This commit is contained in:
commit
4e54008b10
8 changed files with 302 additions and 28 deletions
|
@ -2,6 +2,7 @@ pub const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
|
||||||
pub const MAX_CLIENTS: usize = 512;
|
pub const MAX_CLIENTS: usize = 512;
|
||||||
pub const MAX_CONCURRENT_STREAMS: u32 = 16;
|
pub const MAX_CONCURRENT_STREAMS: u32 = 16;
|
||||||
pub const PATH: &str = "/dns-query";
|
pub const PATH: &str = "/dns-query";
|
||||||
|
pub const ODOH_CONFIGS_PATH: &str = "/.well-known/odohconfigs";
|
||||||
pub const SERVER_ADDRESS: &str = "9.9.9.9:53";
|
pub const SERVER_ADDRESS: &str = "9.9.9.9:53";
|
||||||
pub const TIMEOUT_SEC: u64 = 10;
|
pub const TIMEOUT_SEC: u64 = 10;
|
||||||
pub const MAX_TTL: u32 = 86400 * 7;
|
pub const MAX_TTL: u32 = 86400 * 7;
|
||||||
|
|
|
@ -22,6 +22,10 @@ futures = "0.3.13"
|
||||||
hyper = { version = "0.14.4", default-features = false, features = ["server", "http1", "http2", "stream"] }
|
hyper = { version = "0.14.4", default-features = false, features = ["server", "http1", "http2", "stream"] }
|
||||||
tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] }
|
tokio = { version = "1.2.0", features = ["net", "rt-multi-thread", "parking_lot", "time", "sync"] }
|
||||||
tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true }
|
tokio-rustls = { version = "0.22.0", features = ["early-data"], optional = true }
|
||||||
|
odoh-rs = "0.1.11"
|
||||||
|
rand = "0.7"
|
||||||
|
hpke = "0.5.0"
|
||||||
|
arc-swap = "1.2.0"
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
codegen-units = 1
|
codegen-units = 1
|
||||||
|
|
|
@ -5,3 +5,4 @@ pub const MIN_DNS_PACKET_LEN: usize = 17;
|
||||||
pub const STALE_IF_ERROR_SECS: u32 = 86400;
|
pub const STALE_IF_ERROR_SECS: u32 = 86400;
|
||||||
pub const STALE_WHILE_REVALIDATE_SECS: u32 = 60;
|
pub const STALE_WHILE_REVALIDATE_SECS: u32 = 60;
|
||||||
pub const CERTS_WATCH_DELAY_SECS: u32 = 10;
|
pub const CERTS_WATCH_DELAY_SECS: u32 = 10;
|
||||||
|
pub const ODOH_KEY_ROTATION_SECS: u32 = 86400;
|
||||||
|
|
|
@ -9,6 +9,7 @@ pub enum DoHError {
|
||||||
TooLarge,
|
TooLarge,
|
||||||
UpstreamIssue,
|
UpstreamIssue,
|
||||||
UpstreamTimeout,
|
UpstreamTimeout,
|
||||||
|
StaleKey,
|
||||||
Hyper(hyper::Error),
|
Hyper(hyper::Error),
|
||||||
Io(io::Error),
|
Io(io::Error),
|
||||||
}
|
}
|
||||||
|
@ -23,6 +24,7 @@ impl std::fmt::Display for DoHError {
|
||||||
DoHError::TooLarge => write!(fmt, "Too large"),
|
DoHError::TooLarge => write!(fmt, "Too large"),
|
||||||
DoHError::UpstreamIssue => write!(fmt, "Upstream error"),
|
DoHError::UpstreamIssue => write!(fmt, "Upstream error"),
|
||||||
DoHError::UpstreamTimeout => write!(fmt, "Upstream timeout"),
|
DoHError::UpstreamTimeout => write!(fmt, "Upstream timeout"),
|
||||||
|
DoHError::StaleKey => write!(fmt, "Stale key material"),
|
||||||
DoHError::Hyper(e) => write!(fmt, "HTTP error: {}", e),
|
DoHError::Hyper(e) => write!(fmt, "HTTP error: {}", e),
|
||||||
DoHError::Io(e) => write!(fmt, "IO error: {}", e),
|
DoHError::Io(e) => write!(fmt, "IO error: {}", e),
|
||||||
}
|
}
|
||||||
|
@ -37,6 +39,7 @@ impl From<DoHError> for StatusCode {
|
||||||
DoHError::TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
|
DoHError::TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
|
||||||
DoHError::UpstreamIssue => StatusCode::BAD_GATEWAY,
|
DoHError::UpstreamIssue => StatusCode::BAD_GATEWAY,
|
||||||
DoHError::UpstreamTimeout => StatusCode::BAD_GATEWAY,
|
DoHError::UpstreamTimeout => StatusCode::BAD_GATEWAY,
|
||||||
|
DoHError::StaleKey => StatusCode::UNAUTHORIZED,
|
||||||
DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
|
DoHError::Hyper(_) => StatusCode::SERVICE_UNAVAILABLE,
|
||||||
DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
DoHError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::odoh::ODoHRotator;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -28,6 +29,8 @@ pub struct Globals {
|
||||||
pub err_ttl: u32,
|
pub err_ttl: u32,
|
||||||
pub keepalive: bool,
|
pub keepalive: bool,
|
||||||
pub disable_post: bool,
|
pub disable_post: bool,
|
||||||
|
pub odoh_configs_path: String,
|
||||||
|
pub odoh_rotator: Arc<ODoHRotator>,
|
||||||
|
|
||||||
pub runtime_handle: runtime::Handle,
|
pub runtime_handle: runtime::Handle,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
mod constants;
|
mod constants;
|
||||||
pub mod dns;
|
pub mod dns;
|
||||||
|
pub mod odoh;
|
||||||
mod errors;
|
mod errors;
|
||||||
mod globals;
|
mod globals;
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
|
@ -25,9 +26,30 @@ pub mod reexports {
|
||||||
pub use tokio;
|
pub use tokio;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct DnsResponse {
|
||||||
|
packet: Vec<u8>,
|
||||||
|
ttl: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
enum DoHType {
|
||||||
|
Standard,
|
||||||
|
Oblivious,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DoHType {
|
||||||
|
fn as_str(&self) -> String {
|
||||||
|
match self {
|
||||||
|
DoHType::Standard => String::from("application/dns-message"),
|
||||||
|
DoHType::Oblivious => String::from("application/oblivious-dns-message"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct DoH {
|
pub struct DoH {
|
||||||
pub globals: Arc<Globals>,
|
pub globals: Arc<Globals>
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::unnecessary_wraps)]
|
#[allow(clippy::unnecessary_wraps)]
|
||||||
|
@ -72,14 +94,20 @@ impl hyper::service::Service<http::Request<Body>> for DoH {
|
||||||
|
|
||||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||||
let globals = &self.globals;
|
let globals = &self.globals;
|
||||||
if req.uri().path() != globals.path {
|
|
||||||
return Box::pin(async { http_error(StatusCode::NOT_FOUND) });
|
|
||||||
}
|
|
||||||
let self_inner = self.clone();
|
let self_inner = self.clone();
|
||||||
match *req.method() {
|
if req.uri().path() == globals.path {
|
||||||
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
|
match *req.method() {
|
||||||
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
|
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
|
||||||
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
|
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
|
||||||
|
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
|
||||||
|
}
|
||||||
|
} else if req.uri().path() == globals.odoh_configs_path {
|
||||||
|
match *req.method() {
|
||||||
|
Method::GET => Box::pin(async move { self_inner.serve_odoh_configs().await }),
|
||||||
|
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Box::pin(async { http_error(StatusCode::NOT_FOUND) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -89,12 +117,65 @@ impl DoH {
|
||||||
if self.globals.disable_post {
|
if self.globals.disable_post {
|
||||||
return http_error(StatusCode::METHOD_NOT_ALLOWED);
|
return http_error(StatusCode::METHOD_NOT_ALLOWED);
|
||||||
}
|
}
|
||||||
if let Err(response) = Self::check_content_type(&req) {
|
|
||||||
return Ok(response);
|
match Self::parse_content_type(&req) {
|
||||||
|
Ok(DoHType::Standard) => self.serve_doh_post(req).await,
|
||||||
|
Ok(DoHType::Oblivious) => self.serve_odoh_post(req).await,
|
||||||
|
Err(response) => return Ok(response)
|
||||||
}
|
}
|
||||||
match self.read_body_and_proxy(req.into_body()).await {
|
}
|
||||||
|
|
||||||
|
async fn serve_doh_post(&self, req:Request<Body>) -> Result<Response<Body>, http::Error> {
|
||||||
|
let query = match self.read_body(req.into_body()).await {
|
||||||
|
Ok(q) => q,
|
||||||
|
Err(e) => return http_error(StatusCode::from(e))
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp = match self.proxy(query).await {
|
||||||
|
Ok(resp) => self.build_response(resp.packet, resp.ttl, DoHType::Standard.as_str()),
|
||||||
|
Err(e) => return http_error(StatusCode::from(e)),
|
||||||
|
};
|
||||||
|
|
||||||
|
match resp {
|
||||||
|
Ok(resp) => Ok(resp),
|
||||||
|
Err(e) => http_error(StatusCode::from(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn serve_odoh_post(&self, req:Request<Body>) -> Result<Response<Body>, http::Error> {
|
||||||
|
let query_body = match self.read_body(req.into_body()).await {
|
||||||
|
Ok(q) => q,
|
||||||
|
Err(e) => return http_error(StatusCode::from(e))
|
||||||
|
};
|
||||||
|
|
||||||
|
let odoh_public_key = (*self.globals.odoh_rotator).clone().current_key();
|
||||||
|
let (query, context) = match (*odoh_public_key).clone().decrypt_query(query_body).await {
|
||||||
|
Ok((q, context)) => (q.to_vec(), context),
|
||||||
|
Err(e) => return http_error(StatusCode::from(e))
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp_body = match self.proxy(query).await {
|
||||||
|
Ok(resp) => resp,
|
||||||
|
Err(e) => return http_error(StatusCode::from(e))
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp = match context.encrypt_response(resp_body.packet).await {
|
||||||
|
Ok(resp) => self.build_response(resp, 0u32, DoHType::Oblivious.as_str()),
|
||||||
|
Err(e) => return http_error(StatusCode::from(e)),
|
||||||
|
};
|
||||||
|
|
||||||
|
match resp {
|
||||||
|
Ok(resp) => Ok(resp),
|
||||||
|
Err(e) => http_error(StatusCode::from(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn serve_odoh_configs(&self) -> Result<Response<Body>, http::Error> {
|
||||||
|
let odoh_public_key = (*self.globals.odoh_rotator).clone().current_key();
|
||||||
|
let configs = (*odoh_public_key).clone().config();
|
||||||
|
match self.build_response(configs, 0, "application/octet-stream".to_string()) {
|
||||||
|
Ok(resp) => Ok(resp),
|
||||||
Err(e) => http_error(StatusCode::from(e)),
|
Err(e) => http_error(StatusCode::from(e)),
|
||||||
Ok(res) => Ok(res),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,13 +198,18 @@ impl DoH {
|
||||||
return http_error(StatusCode::BAD_REQUEST);
|
return http_error(StatusCode::BAD_REQUEST);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
match self.proxy(question).await {
|
|
||||||
|
let resp = match self.proxy(question).await {
|
||||||
|
Ok(dns_resp) => self.build_response(dns_resp.packet, dns_resp.ttl, DoHType::Standard.as_str()),
|
||||||
|
Err(e) => Err(e),
|
||||||
|
};
|
||||||
|
match resp {
|
||||||
|
Ok(resp) => Ok(resp),
|
||||||
Err(e) => http_error(StatusCode::from(e)),
|
Err(e) => http_error(StatusCode::from(e)),
|
||||||
Ok(res) => Ok(res),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_content_type(req: &Request<Body>) -> Result<(), Response<Body>> {
|
fn parse_content_type(req: &Request<Body>) -> Result<DoHType, Response<Body>> {
|
||||||
let headers = req.headers();
|
let headers = req.headers();
|
||||||
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
|
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
|
||||||
None => {
|
None => {
|
||||||
|
@ -145,17 +231,21 @@ impl DoH {
|
||||||
}
|
}
|
||||||
Ok(content_type) => content_type.to_lowercase(),
|
Ok(content_type) => content_type.to_lowercase(),
|
||||||
};
|
};
|
||||||
if content_type != "application/dns-message" {
|
|
||||||
let response = Response::builder()
|
match content_type.as_str() {
|
||||||
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
|
"application/dns-message" => Ok(DoHType::Standard),
|
||||||
.body(Body::empty())
|
"application/oblivious-dns-message" => Ok(DoHType::Oblivious),
|
||||||
.unwrap();
|
_ => {
|
||||||
return Err(response);
|
let response = Response::builder()
|
||||||
|
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
return Err(response);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn read_body_and_proxy(&self, mut body: Body) -> Result<Response<Body>, DoHError> {
|
async fn read_body(&self, mut body: Body) -> Result<Vec<u8>, DoHError> {
|
||||||
let mut sum_size = 0;
|
let mut sum_size = 0;
|
||||||
let mut query = vec![];
|
let mut query = vec![];
|
||||||
while let Some(chunk) = body.next().await {
|
while let Some(chunk) = body.next().await {
|
||||||
|
@ -166,17 +256,16 @@ impl DoH {
|
||||||
}
|
}
|
||||||
query.extend(chunk);
|
query.extend(chunk);
|
||||||
}
|
}
|
||||||
let response = self.proxy(query).await?;
|
Ok(query)
|
||||||
Ok(response)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn proxy(&self, query: Vec<u8>) -> Result<Response<Body>, DoHError> {
|
async fn proxy(&self, query: Vec<u8>) -> Result<DnsResponse, DoHError> {
|
||||||
let proxy_timeout = self.globals.timeout;
|
let proxy_timeout = self.globals.timeout;
|
||||||
let timeout_res = tokio::time::timeout(proxy_timeout, self._proxy(query)).await;
|
let timeout_res = tokio::time::timeout(proxy_timeout, self._proxy(query)).await;
|
||||||
timeout_res.map_err(|_| DoHError::UpstreamTimeout)?
|
timeout_res.map_err(|_| DoHError::UpstreamTimeout)?
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn _proxy(&self, mut query: Vec<u8>) -> Result<Response<Body>, DoHError> {
|
async fn _proxy(&self, mut query: Vec<u8>) -> Result<DnsResponse, DoHError> {
|
||||||
if query.len() < MIN_DNS_PACKET_LEN {
|
if query.len() < MIN_DNS_PACKET_LEN {
|
||||||
return Err(DoHError::Incomplete);
|
return Err(DoHError::Incomplete);
|
||||||
}
|
}
|
||||||
|
@ -209,10 +298,17 @@ impl DoH {
|
||||||
dns::add_edns_padding(&mut packet)
|
dns::add_edns_padding(&mut packet)
|
||||||
.map_err(|_| DoHError::TooLarge)
|
.map_err(|_| DoHError::TooLarge)
|
||||||
.ok();
|
.ok();
|
||||||
|
Ok(DnsResponse{
|
||||||
|
packet,
|
||||||
|
ttl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_response(&self, packet: Vec<u8>, ttl: u32, content_type: String) -> Result<Response<Body>, DoHError> {
|
||||||
let packet_len = packet.len();
|
let packet_len = packet.len();
|
||||||
let response = Response::builder()
|
let response = Response::builder()
|
||||||
.header(hyper::header::CONTENT_LENGTH, packet_len)
|
.header(hyper::header::CONTENT_LENGTH, packet_len)
|
||||||
.header(hyper::header::CONTENT_TYPE, "application/dns-message")
|
.header(hyper::header::CONTENT_TYPE, content_type.as_str())
|
||||||
.header(
|
.header(
|
||||||
hyper::header::CACHE_CONTROL,
|
hyper::header::CACHE_CONTROL,
|
||||||
format!(
|
format!(
|
||||||
|
|
158
src/libdoh/src/odoh.rs
Normal file
158
src/libdoh/src/odoh.rs
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
use crate::constants::ODOH_KEY_ROTATION_SECS;
|
||||||
|
use crate::errors::DoHError;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use arc_swap::ArcSwap;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::runtime;
|
||||||
|
use hpke::kex::Serializable;
|
||||||
|
use odoh_rs::key_utils::{derive_keypair_from_seed};
|
||||||
|
use odoh_rs::protocol::{create_response_msg,
|
||||||
|
parse_received_query, RESPONSE_NONCE_SIZE,
|
||||||
|
ObliviousDoHQueryBody, Serialize, Deserialize,
|
||||||
|
ObliviousDoHKeyPair, ObliviousDoHConfigContents,
|
||||||
|
ObliviousDoHConfig, ObliviousDoHConfigs,
|
||||||
|
ObliviousDoHMessage, ObliviousDoHMessageType,
|
||||||
|
};
|
||||||
|
use rand::Rng;
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
// https://cfrg.github.io/draft-irtf-cfrg-hpke/draft-irtf-cfrg-hpke.html#name-algorithm-identifiers
|
||||||
|
const DEFAULT_HPKE_SEED_SIZE: usize = 32;
|
||||||
|
const DEFAULT_HPKE_KEM: u16 = 0x0020; // DHKEM(X25519, HKDF-SHA256)
|
||||||
|
const DEFAULT_HPKE_KDF: u16 = 0x0001; // KDF(SHA-256)
|
||||||
|
const DEFAULT_HPKE_AEAD: u16 = 0x0001; // AEAD(AES-GCM-128)
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ODoHPublicKey {
|
||||||
|
key: ObliviousDoHKeyPair,
|
||||||
|
serialized_configs: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for ODoHPublicKey {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("ODoHPublicKey").finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ODoHQueryContext {
|
||||||
|
query: ObliviousDoHQueryBody,
|
||||||
|
secret: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_key_pair() -> ObliviousDoHKeyPair {
|
||||||
|
let ikm = rand::thread_rng().gen::<[u8; DEFAULT_HPKE_SEED_SIZE]>();
|
||||||
|
let (secret_key, public_key) = derive_keypair_from_seed(&ikm);
|
||||||
|
let public_key_bytes = public_key.to_bytes().to_vec();
|
||||||
|
let odoh_public_key = ObliviousDoHConfigContents {
|
||||||
|
kem_id: DEFAULT_HPKE_KEM,
|
||||||
|
kdf_id: DEFAULT_HPKE_KDF,
|
||||||
|
aead_id: DEFAULT_HPKE_AEAD,
|
||||||
|
public_key: public_key_bytes,
|
||||||
|
};
|
||||||
|
ObliviousDoHKeyPair {
|
||||||
|
private_key: secret_key,
|
||||||
|
public_key: odoh_public_key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ODoHPublicKey {
|
||||||
|
pub fn new() -> Result<ODoHPublicKey, DoHError> {
|
||||||
|
let key_pair = generate_key_pair();
|
||||||
|
let config = ObliviousDoHConfig::new(&key_pair.public_key.clone().to_bytes().unwrap()).unwrap();
|
||||||
|
let serialized_configs = ObliviousDoHConfigs {
|
||||||
|
configs: vec![config.clone()],
|
||||||
|
}
|
||||||
|
.to_bytes()
|
||||||
|
.unwrap()
|
||||||
|
.to_vec();
|
||||||
|
|
||||||
|
Ok(ODoHPublicKey{
|
||||||
|
key: key_pair,
|
||||||
|
serialized_configs: serialized_configs
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn config(self) -> Vec<u8> {
|
||||||
|
self.serialized_configs
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn decrypt_query(self, encrypted_query: Vec<u8>) -> Result<(Vec<u8>, ODoHQueryContext), DoHError> {
|
||||||
|
let odoh_query = match ObliviousDoHMessage::from_bytes(&encrypted_query) {
|
||||||
|
Ok(q) => {
|
||||||
|
if q.msg_type != ObliviousDoHMessageType::Query {
|
||||||
|
return Err(DoHError::InvalidData);
|
||||||
|
}
|
||||||
|
q
|
||||||
|
},
|
||||||
|
Err(_) => return Err(DoHError::InvalidData)
|
||||||
|
};
|
||||||
|
|
||||||
|
match self.key.public_key.identifier() {
|
||||||
|
Ok(key_id) => {
|
||||||
|
if !key_id.eq(&odoh_query.key_id) {
|
||||||
|
return Err(DoHError::StaleKey);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(_) => return Err(DoHError::InvalidData)
|
||||||
|
};
|
||||||
|
|
||||||
|
let (query, server_secret) = match parse_received_query(&self.key, &encrypted_query).await {
|
||||||
|
Ok((pq, ss)) => (pq, ss),
|
||||||
|
Err(_) => return Err(DoHError::InvalidData)
|
||||||
|
};
|
||||||
|
let context = ODoHQueryContext{
|
||||||
|
query: query.clone(),
|
||||||
|
secret: server_secret,
|
||||||
|
};
|
||||||
|
Ok((query.dns_msg.clone(), context))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ODoHQueryContext {
|
||||||
|
pub async fn encrypt_response(self, response_body: Vec<u8>) -> Result<Vec<u8>, DoHError> {
|
||||||
|
let response_nonce = rand::thread_rng().gen::<[u8; RESPONSE_NONCE_SIZE]>();
|
||||||
|
create_response_msg(&self.secret, &response_body, None, Some(response_nonce.to_vec()), &self.query)
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
DoHError::InvalidData
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ODoHRotator {
|
||||||
|
key: Arc<ArcSwap<ODoHPublicKey>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ODoHRotator {
|
||||||
|
pub fn new(runtime_handle: runtime::Handle) -> Result<ODoHRotator, DoHError> {
|
||||||
|
let odoh_key = match ODoHPublicKey::new() {
|
||||||
|
Ok(key) => Arc::new(ArcSwap::from_pointee(key)),
|
||||||
|
Err(e) => panic!("ODoH key rotation error: {}", e),
|
||||||
|
};
|
||||||
|
|
||||||
|
let current_key = Arc::clone(&odoh_key);
|
||||||
|
|
||||||
|
runtime_handle.clone().spawn(async move {
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(Duration::from_secs(ODOH_KEY_ROTATION_SECS.into())).await;
|
||||||
|
match ODoHPublicKey::new() {
|
||||||
|
Ok(key) => {
|
||||||
|
current_key.store(Arc::new(key));
|
||||||
|
},
|
||||||
|
Err(e) => eprintln!("ODoH key rotation error: {}", e),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(ODoHRotator{
|
||||||
|
key: Arc::clone(&odoh_key)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn current_key(&self) -> Arc<ODoHPublicKey> {
|
||||||
|
let key = Arc::clone(&self.key);
|
||||||
|
Arc::clone(&key.load())
|
||||||
|
}
|
||||||
|
}
|
|
@ -13,6 +13,7 @@ use libdoh::*;
|
||||||
use crate::config::*;
|
use crate::config::*;
|
||||||
use crate::constants::*;
|
use crate::constants::*;
|
||||||
|
|
||||||
|
use libdoh::odoh::ODoHRotator;
|
||||||
use libdoh::reexports::tokio;
|
use libdoh::reexports::tokio;
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -24,6 +25,11 @@ fn main() {
|
||||||
runtime_builder.thread_name("doh-proxy");
|
runtime_builder.thread_name("doh-proxy");
|
||||||
let runtime = runtime_builder.build().unwrap();
|
let runtime = runtime_builder.build().unwrap();
|
||||||
|
|
||||||
|
let rotator = match ODoHRotator::new(runtime.handle().clone()) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(_) => panic!("Failed to create ODoHRotator"),
|
||||||
|
};
|
||||||
|
|
||||||
let mut globals = Globals {
|
let mut globals = Globals {
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
tls_cert_path: None,
|
tls_cert_path: None,
|
||||||
|
@ -43,6 +49,8 @@ fn main() {
|
||||||
err_ttl: ERR_TTL,
|
err_ttl: ERR_TTL,
|
||||||
keepalive: true,
|
keepalive: true,
|
||||||
disable_post: false,
|
disable_post: false,
|
||||||
|
odoh_configs_path: ODOH_CONFIGS_PATH.to_string(),
|
||||||
|
odoh_rotator: Arc::new(rotator),
|
||||||
|
|
||||||
runtime_handle: runtime.handle().clone(),
|
runtime_handle: runtime.handle().clone(),
|
||||||
};
|
};
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue