mirror of
https://github.com/DNSCrypt/doh-server.git
synced 2025-04-03 04:57:37 +03:00
Factorize a bit
This commit is contained in:
parent
f2e5f13e85
commit
a4938aa962
1 changed files with 51 additions and 68 deletions
119
src/main.rs
119
src/main.rs
|
@ -21,7 +21,6 @@ use crate::globals::*;
|
|||
#[cfg(feature = "tls")]
|
||||
use crate::tls::*;
|
||||
|
||||
use futures::future;
|
||||
use futures::prelude::*;
|
||||
use futures::task::{Context, Poll};
|
||||
use hyper::http;
|
||||
|
@ -39,6 +38,14 @@ pub struct DoH {
|
|||
pub globals: Arc<Globals>,
|
||||
}
|
||||
|
||||
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
|
||||
let response = Response::builder()
|
||||
.status(status_code)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
impl hyper::service::Service<http::Request<Body>> for DoH {
|
||||
type Response = Response<Body>;
|
||||
type Error = http::Error;
|
||||
|
@ -51,80 +58,56 @@ impl hyper::service::Service<http::Request<Body>> for DoH {
|
|||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
let globals = &self.globals;
|
||||
if req.uri().path() != globals.path {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Box::pin(async { Ok(response) });
|
||||
return Box::pin(async { http_error(StatusCode::NOT_FOUND) });
|
||||
}
|
||||
let self_inner = self.clone();
|
||||
match *req.method() {
|
||||
Method::POST => {
|
||||
if globals.disable_post {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::METHOD_NOT_ALLOWED)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Box::pin(async { Ok(response) });
|
||||
}
|
||||
if let Err(response) = Self::check_content_type(&req) {
|
||||
return Box::pin(async { Ok(response) });
|
||||
}
|
||||
let fut = async move {
|
||||
match self_inner.read_body_and_proxy(req.into_body()).await {
|
||||
Err(e) => Response::builder()
|
||||
.status(StatusCode::from(e))
|
||||
.body(Body::empty()),
|
||||
Ok(res) => Ok(res),
|
||||
}
|
||||
};
|
||||
Box::pin(fut)
|
||||
}
|
||||
Method::GET => {
|
||||
let query = req.uri().query().unwrap_or("");
|
||||
let mut question_str = None;
|
||||
for parts in query.split('&') {
|
||||
let mut kv = parts.split('=');
|
||||
if let Some(k) = kv.next() {
|
||||
if k == DNS_QUERY_PARAM {
|
||||
question_str = kv.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
let question = match question_str.and_then(|question_str| {
|
||||
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
|
||||
}) {
|
||||
Some(question) => question,
|
||||
_ => {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return Box::pin(future::ok(response));
|
||||
}
|
||||
};
|
||||
let fut = async move {
|
||||
match self_inner.proxy(question).await {
|
||||
Err(e) => Response::builder()
|
||||
.status(StatusCode::from(e))
|
||||
.body(Body::empty()),
|
||||
Ok(res) => Ok(res),
|
||||
}
|
||||
};
|
||||
Box::pin(fut)
|
||||
}
|
||||
_ => {
|
||||
let response = Response::builder()
|
||||
.status(StatusCode::METHOD_NOT_ALLOWED)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
Box::pin(async { Ok(response) })
|
||||
}
|
||||
Method::POST => Box::pin(async move { self_inner.serve_post(req).await }),
|
||||
Method::GET => Box::pin(async move { self_inner.serve_get(req).await }),
|
||||
_ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DoH {
|
||||
async fn serve_post(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
|
||||
if self.globals.disable_post {
|
||||
return http_error(StatusCode::METHOD_NOT_ALLOWED);
|
||||
}
|
||||
if let Err(response) = Self::check_content_type(&req) {
|
||||
return Ok(response);
|
||||
}
|
||||
match self.read_body_and_proxy(req.into_body()).await {
|
||||
Err(e) => http_error(StatusCode::from(e)),
|
||||
Ok(res) => Ok(res),
|
||||
}
|
||||
}
|
||||
|
||||
async fn serve_get(&self, req: Request<Body>) -> Result<Response<Body>, http::Error> {
|
||||
let query = req.uri().query().unwrap_or("");
|
||||
let mut question_str = None;
|
||||
for parts in query.split('&') {
|
||||
let mut kv = parts.split('=');
|
||||
if let Some(k) = kv.next() {
|
||||
if k == DNS_QUERY_PARAM {
|
||||
question_str = kv.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
let question = match question_str.and_then(|question_str| {
|
||||
base64::decode_config(question_str, base64::URL_SAFE_NO_PAD).ok()
|
||||
}) {
|
||||
Some(question) => question,
|
||||
_ => {
|
||||
return http_error(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
};
|
||||
match self.proxy(question).await {
|
||||
Err(e) => http_error(StatusCode::from(e)),
|
||||
Ok(res) => Ok(res),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_content_type(req: &Request<Body>) -> Result<(), Response<Body>> {
|
||||
let headers = req.headers();
|
||||
let content_type = match headers.get(hyper::header::CONTENT_TYPE) {
|
||||
|
@ -256,7 +239,7 @@ impl DoH {
|
|||
let listen_address = self.globals.listen_address;
|
||||
let listener = TcpListener::bind(&listen_address)
|
||||
.await
|
||||
.map_err(|e| DoHError::Io(e))?;
|
||||
.map_err(DoHError::Io)?;
|
||||
let path = &self.globals.path;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue