Revert "..."

This reverts commit 5571b6c405.
This commit is contained in:
Frank Denis 2018-07-02 13:43:26 +02:00
parent 5571b6c405
commit 4ca54eb71b
2 changed files with 56 additions and 109 deletions

View file

@ -13,8 +13,8 @@ categories = ["asynchronous", "network-programming","command-line-utilities"]
base64 = "~0.9" base64 = "~0.9"
clap = "~2" clap = "~2"
futures = "~0.1" futures = "~0.1"
hyper = "~0.12" hyper = "~0.11"
tokio = "~0.1" tokio = "~0.1"
tokio-io = "~0.1" tokio-io = "~0.1"
tokio-timer = "~0.2" tokio-timer = "~0.1"
clippy = {version = ">=0", optional = true} clippy = {version = ">=0", optional = true}

View file

@ -14,18 +14,16 @@ mod dns;
use clap::{App, Arg}; use clap::{App, Arg};
use futures::future; use futures::future;
use futures::prelude::*; use futures::prelude::*;
use hyper::header::{CONTENT_LENGTH, CONTENT_TYPE, EXPIRES}; use hyper::header::{CacheControl, CacheDirective, ContentLength, ContentType};
use hyper::server::conn::Http; use hyper::server::{Http, Request, Response, Service};
use hyper::service::{NewService, Service}; use hyper::{Body, Method, StatusCode};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use std::cell::RefCell; use std::cell::RefCell;
use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::rc::Rc; use std::rc::Rc;
use std::time::{Duration, Instant, SystemTime}; use std::time::Duration;
use tokio::executor::current_thread; use tokio::executor::current_thread;
use tokio::net::{TcpListener, UdpSocket}; use tokio::net::{TcpListener, UdpSocket};
use tokio_timer::{timer, Timer}; use tokio_timer::Timer;
const DNS_QUERY_PARAM: &str = "dns"; const DNS_QUERY_PARAM: &str = "dns";
const LISTEN_ADDRESS: &str = "127.0.0.1:3000"; const LISTEN_ADDRESS: &str = "127.0.0.1:3000";
@ -49,24 +47,22 @@ struct DoH {
path: String, path: String,
max_clients: u32, max_clients: u32,
timeout: Duration, timeout: Duration,
timer_handle: timer::Handle, timers: Timer,
clients_count: Rc<RefCell<u32>>, clients_count: Rc<RefCell<u32>>,
} }
impl Service for DoH { impl Service for DoH {
type ReqBody = Body; type Request = Request;
type ResBody = Body; type Response = Response;
type Error = io::Error; type Error = hyper::Error;
type Future = Box<Future<Item = Response<Self::ResBody>, Error = io::Error>>; type Future = Box<Future<Item = Self::Response, Error = Self::Error>>;
fn call(&mut self, req: Request<Self::ReqBody>) -> Self::Future { fn call(&self, req: Request) -> Self::Future {
{ {
let count = self.clients_count.borrow_mut(); let count = self.clients_count.borrow_mut();
if *count > self.max_clients { if *count > self.max_clients {
let response = Response::builder() let mut response = Response::new();
.status(StatusCode::TOO_MANY_REQUESTS) response.set_status(StatusCode::TooManyRequests);
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response)); return Box::new(future::ok(response));
} }
(*count).saturating_add(1); (*count).saturating_add(1);
@ -83,34 +79,27 @@ impl Service for DoH {
err err
}); });
let timed = self let timed = self
.timer_handle .timers
.deadline(fut.map_err(|_| {}), Instant::now() + self.timeout) .timeout(fut.map_err(|_| {}), self.timeout)
.map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Timeout")); .map_err(|_| hyper::Error::Timeout);
Box::new(timed) Box::new(timed)
} }
} }
impl DoH { impl DoH {
fn handle_client( fn handle_client(&self, req: Request) -> Box<Future<Item = Response, Error = hyper::Error>> {
&self, let mut response = Response::new();
req: Request<Body>, if req.path() != self.path {
) -> Box<Future<Item = Response<Body>, Error = io::Error>> { response.set_status(StatusCode::NotFound);
if req.uri().path() != self.path {
let response = Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response)); return Box::new(future::ok(response));
} }
match *req.method() { match *req.method() {
Method::POST => { Method::Post => {
let fut = self.read_body_and_proxy(req.body()); let fut = self.read_body_and_proxy(req.body());
return Box::new( return Box::new(fut.map_err(|_| hyper::Error::Incomplete));
fut.map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "EOF")),
);
} }
Method::GET => { Method::Get => {
let query = req.uri().query().unwrap_or(""); let query = req.query().unwrap_or("");
let mut question_str = None; let mut question_str = None;
for parts in query.split('&') { for parts in query.split('&') {
let mut kv = parts.split('='); let mut kv = parts.split('=');
@ -125,29 +114,21 @@ impl DoH {
}) { }) {
Some(question) => question, Some(question) => question,
_ => { _ => {
let mut response = Response::builder() response.set_status(StatusCode::BadRequest);
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response)); return Box::new(future::ok(response));
} }
}; };
let fut = self.proxy(question); let fut = self.proxy(question);
return Box::new( return Box::new(fut.map_err(|_| hyper::Error::Incomplete));
fut.map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "EOF")),
);
} }
_ => { _ => {
let mut response = Response::builder() response.set_status(StatusCode::MethodNotAllowed);
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response));
} }
}; };
Box::new(future::ok(response))
} }
fn proxy(&self, query: Vec<u8>) -> Box<Future<Item = Response<Body>, Error = ()>> { fn proxy(&self, query: Vec<u8>) -> Box<Future<Item = Response, Error = ()>> {
let socket = UdpSocket::bind(&self.local_bind_address).unwrap(); let socket = UdpSocket::bind(&self.local_bind_address).unwrap();
let expected_server_address = self.server_address; let expected_server_address = self.server_address;
let fut = socket let fut = socket
@ -167,40 +148,31 @@ impl DoH {
Ok(min_ttl) => min_ttl, Ok(min_ttl) => min_ttl,
}; };
let packet_len = packet.len(); let packet_len = packet.len();
let mut response = Response::builder() let mut response = Response::new();
.header(CONTENT_LENGTH, format!("{}", packet_len).as_bytes()) response.set_body(packet);
.header(CONTENT_TYPE, "application/dns-message") let response = response
.header( .with_header(ContentLength(packet_len as u64))
EXPIRES, .with_header(ContentType(
format!( "application/dns-message".parse().unwrap(),
"{}", ))
(SystemTime::now() + Duration::from_secs(ttl as u64)) .with_header(CacheControl(vec![CacheDirective::MaxAge(ttl)]));
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or(Duration::new(0, 0))
.as_secs()
).as_bytes(),
)
.body(Body::from(packet))
.unwrap();
future::ok(response) future::ok(response)
}); });
Box::new(fut) Box::new(fut)
} }
fn read_body_and_proxy(&self, body: &Body) -> Box<Future<Item = Response<Body>, Error = ()>> { fn read_body_and_proxy(&self, body: Body) -> Box<Future<Item = Response, Error = ()>> {
let mut sum_size = 0; let mut sum_size = 0;
let inner = self.clone(); let inner = self.clone();
let fut = body let fut =
.map_err(move |_err| ()) body.and_then(move |chunk| {
.and_then(move |chunk| {
sum_size += chunk.len(); sum_size += chunk.len();
if sum_size > MAX_DNS_QUESTION_LEN { if sum_size > MAX_DNS_QUESTION_LEN {
Err(()) Err(hyper::error::Error::TooLarge)
} else { } else {
Ok(chunk) Ok(chunk)
} }
}) }).concat2()
.concat2()
.map_err(move |_err| ()) .map_err(move |_err| ())
.map(move |chunk| chunk.to_vec()) .map(move |chunk| chunk.to_vec())
.and_then(move |query| { .and_then(move |query| {
@ -213,29 +185,6 @@ impl DoH {
} }
} }
struct DohNewService {
doh: DoH,
}
impl NewService for DohNewService {
type ReqBody = Body;
type ResBody = Body;
type Error = io::Error;
type Service = DoH;
type Future = Box<Future<Item = Self::Service, Error = Self::InitError>>;
type InitError = io::Error;
fn new_service(&self) -> Self::Future {
Box::new(future::ok(self.doh.clone()))
}
}
impl DohNewService {
fn new(doh: DoH) -> Self {
DohNewService { doh }
}
}
fn main() { fn main() {
let mut doh = DoH { let mut doh = DoH {
listen_address: LISTEN_ADDRESS.parse().unwrap(), listen_address: LISTEN_ADDRESS.parse().unwrap(),
@ -245,18 +194,16 @@ fn main() {
max_clients: MAX_CLIENTS, max_clients: MAX_CLIENTS,
timeout: Duration::from_secs(TIMEOUT_SEC), timeout: Duration::from_secs(TIMEOUT_SEC),
clients_count: Rc::new(RefCell::new(0u32)), clients_count: Rc::new(RefCell::new(0u32)),
timer_handle: Timer::default().handle(), timers: tokio_timer::wheel().build(),
}; };
parse_opts(&mut doh); parse_opts(&mut doh);
let listen_address = doh.listen_address; let listen_address = doh.listen_address;
let listener = TcpListener::bind(&listen_address).unwrap();
println!("Listening on http://{}", listen_address); println!("Listening on http://{}", listen_address);
let new_service = DohNewService::new(doh); let doh = Rc::new(doh);
let server = Http::new() let server = Http::new()
.keep_alive(false) .keep_alive(false)
.max_buf_size(0xffff) .serve_incoming(listener.incoming(), move || Ok(doh.clone()));
.pipeline_flush(true)
.serve_addr(&listen_address, || new_service)
.unwrap();
let fut = server.for_each(move |client_fut| { let fut = server.for_each(move |client_fut| {
current_thread::spawn(client_fut.map(|_| {}).map_err(|_| {})); current_thread::spawn(client_fut.map(|_| {}).map_err(|_| {}));
Ok(()) Ok(())