Factorize a bit

This commit is contained in:
Frank Denis 2019-12-23 20:22:00 +01:00
parent f2e5f13e85
commit a4938aa962

View file

@ -21,7 +21,6 @@ use crate::globals::*;
#[cfg(feature = "tls")]
use crate::tls::*;
use futures::future;
use futures::prelude::*;
use futures::task::{Context, Poll};
use hyper::http;
@ -39,6 +38,14 @@ pub struct DoH {
pub globals: Arc<Globals>,
}
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
let response = Response::builder()
.status(status_code)
.body(Body::empty())
.unwrap();
Ok(response)
}
impl hyper::service::Service<http::Request<Body>> for DoH {
type Response = Response<Body>;
type Error = http::Error;
@ -51,80 +58,56 @@ impl hyper::service::Service<http::Request<Body>> for DoH {
fn call(&mut self, req: Request<Body>) -> Self::Future {
let globals = &self.globals;
if req.uri().path() != globals.path {
let response = Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap();
return Box::pin(async { Ok(response) });
return Box::pin(async { http_error(StatusCode::NOT_FOUND) });
}
let self_inner = self.clone();
match *req.method() {
Method::POST => {
if globals.disable_post {
let response = Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::empty())
.unwrap();
return Box::pin(async { Ok(response) });
}
if let Err(response) = Self::check_content_type(&req) {
return Box::pin(async { Ok(response) });
}
let fut = async move {
match self_inner.read_body_and_proxy(req.into_body()).await {
Err(e) => Response::builder()
.status(StatusCode::from(e))
.body(Body::empty()),
Ok(res) => Ok(res),
}
};
Box::pin(fut)
}
Method::GET => {
let query = req.uri().query().unwrap_or("");
let mut question_str = None;
for parts in query.split('&') {
let mut kv = parts.split('=');
if let Some(k) = kv.next() {
if k == DNS_QUERY_PARAM {
question_str = kv.next();
}
}
}
let question = match question_str.and_then(|question_str| {
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
}) {
Some(question) => question,
_ => {
let response = Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
.unwrap();
return Box::pin(future::ok(response));
}
};
let fut = async move {
match self_inner.proxy(question).await {
Err(e) => Response::builder()
.status(StatusCode::from(e))
.body(Body::empty()),
Ok(res) => Ok(res),
}
};
Box::pin(fut)
}
_ => {
let response = Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::empty())
.unwrap();
Box::pin(async { Ok(response) })
}
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
}
}
}
impl DoH {
async fn serve_post(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
if self.globals.disable_post {
return http_error(StatusCode::METHOD_NOT_ALLOWED);
}
if let Err(response) = Self::check_content_type(&req) {
return Ok(response);
}
match self.read_body_and_proxy(req.into_body()).await {
Err(e) => http_error(StatusCode::from(e)),
Ok(res) => Ok(res),
}
}
async fn serve_get(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
let query = req.uri().query().unwrap_or("");
let mut question_str = None;
for parts in query.split('&') {
let mut kv = parts.split('=');
if let Some(k) = kv.next() {
if k == DNS_QUERY_PARAM {
question_str = kv.next();
}
}
}
let question = match question_str.and_then(|question_str| {
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
}) {
Some(question) => question,
_ => {
return http_error(StatusCode::BAD_REQUEST);
}
};
match self.proxy(question).await {
Err(e) => http_error(StatusCode::from(e)),
Ok(res) => Ok(res),
}
}
fn check_content_type(req: &Request<Body>) -> Result<(), Response<Body>> {
let headers = req.headers();
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
@ -256,7 +239,7 @@ impl DoH {
let listen_address = self.globals.listen_address;
let listener = TcpListener::bind(&listen_address)
.await
.map_err(|e| DoHError::Io(e))?;
.map_err(DoHError::Io)?;
let path = &self.globals.path;
#[cfg(feature = "tls")]