Merge pull request #8 from bluetech/hyper-0.12

Upgrade code to hyper 0.12
This commit is contained in:
Frank Denis 2018-07-07 21:17:05 +02:00 committed by GitHub
commit d042aa0f5a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 59 deletions

View file

@ -10,11 +10,10 @@ repository = "https://github.com/jedisct1/rust-doh"
categories = ["asynchronous", "network-programming","command-line-utilities"] categories = ["asynchronous", "network-programming","command-line-utilities"]
[dependencies] [dependencies]
base64 = "~0.9" base64 = "0.9"
clap = "~2" clap = "2"
futures = "~0.1" futures = "0.1"
hyper = "~0.11" hyper = "0.12"
tokio = "~0.1" tokio = "0.1"
tokio-io = "~0.1" tokio-timer = "0.2"
tokio-timer = "~0.2"
clippy = {version = ">=0", optional = true} clippy = {version = ">=0", optional = true}

View file

@ -6,7 +6,6 @@ extern crate clap;
extern crate futures; extern crate futures;
extern crate hyper; extern crate hyper;
extern crate tokio; extern crate tokio;
extern crate tokio_io;
extern crate tokio_timer; extern crate tokio_timer;
mod dns; mod dns;
@ -14,12 +13,11 @@ mod dns;
use clap::{App, Arg}; use clap::{App, Arg};
use futures::future; use futures::future;
use futures::prelude::*; use futures::prelude::*;
use hyper::header::{CacheControl, CacheDirective, ContentLength, ContentType}; use hyper::service::Service;
use hyper::server::{Http, Request, Response, Service}; use hyper::server::conn::Http;
use hyper::{Body, Method, StatusCode}; use hyper::{Request, Response, Body, Method, StatusCode};
use std::cell::RefCell;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::rc::Rc; use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::executor::current_thread; use tokio::executor::current_thread;
use tokio::net::{TcpListener, UdpSocket}; use tokio::net::{TcpListener, UdpSocket};
@ -48,21 +46,38 @@ struct DoH {
max_clients: u32, max_clients: u32,
timeout: Duration, timeout: Duration,
timer_handle: timer::Handle, timer_handle: timer::Handle,
clients_count: Rc<RefCell<u32>>, clients_count: Arc<Mutex<u32>>,
} }
impl Service for DoH { #[derive(Debug)]
type Request = Request; enum Error {
type Response = Response; Timeout,
type Error = hyper::Error; Incomplete,
type Future = Box<Future<Item = Self::Response, Error = Self::Error>>; TooLarge,
Hyper(hyper::Error),
}
impl std::fmt::Display for Error {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
// Should match, i'm lazy...
std::fmt::Debug::fmt(self, fmt)
}
}
impl std::error::Error for Error {}
fn call(&self, req: Request) -> Self::Future { impl Service for DoH {
type ReqBody = Body;
type ResBody = Body;
type Error = Error;
type Future = Box<Future<Item = Response<Body>, Error = Self::Error> + Send>;
fn call(&mut self, req: Request<Body>) -> Self::Future {
{ {
let count = self.clients_count.borrow_mut(); let count = self.clients_count.lock().unwrap();
if *count > self.max_clients { if *count > self.max_clients {
let mut response = Response::new(); let response = Response::builder()
response.set_status(StatusCode::TooManyRequests); .status(StatusCode::TOO_MANY_REQUESTS)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response)); return Box::new(future::ok(response));
} }
(*count).saturating_add(1); (*count).saturating_add(1);
@ -71,35 +86,37 @@ impl Service for DoH {
let fut = self let fut = self
.handle_client(req) .handle_client(req)
.then(move |fut| { .then(move |fut| {
(*clients_count_inner).borrow_mut().saturating_sub(1); (*clients_count_inner).lock().unwrap().saturating_sub(1);
fut fut
}) })
.map_err(|err| { .map_err(|err| {
eprintln!("server error: {:?}", err); eprintln!("server error: {}", err);
err err
}); });
let timed = self let timed = self
.timer_handle .timer_handle
.deadline(fut.map_err(|_| {}), Instant::now() + self.timeout) .deadline(fut.map_err(|_| {}), Instant::now() + self.timeout)
.map_err(|_| hyper::Error::Timeout); .map_err(|_| Error::Timeout);
Box::new(timed) Box::new(timed)
} }
} }
impl DoH { impl DoH {
fn handle_client(&self, req: Request) -> Box<Future<Item = Response, Error = hyper::Error>> { fn handle_client(&self, req: Request<Body>) -> Box<Future<Item = Response<Body>, Error = Error> + Send> {
let mut response = Response::new(); if req.uri().path() != self.path {
if req.path() != self.path { let response = Response::builder()
response.set_status(StatusCode::NotFound); .status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response)); return Box::new(future::ok(response));
} }
match *req.method() { match *req.method() {
Method::Post => { Method::POST => {
let fut = self.read_body_and_proxy(req.body()); let fut = self.read_body_and_proxy(req.into_body());
return Box::new(fut.map_err(|_| hyper::Error::Incomplete)); return Box::new(fut.map_err(|_| Error::Incomplete));
} }
Method::Get => { Method::GET => {
let query = req.query().unwrap_or(""); let query = req.uri().query().unwrap_or("");
let mut question_str = None; let mut question_str = None;
for parts in query.split('&') { for parts in query.split('&') {
let mut kv = parts.split('='); let mut kv = parts.split('=');
@ -114,21 +131,27 @@ impl DoH {
}) { }) {
Some(question) => question, Some(question) => question,
_ => { _ => {
response.set_status(StatusCode::BadRequest); let response = Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response)); return Box::new(future::ok(response));
} }
}; };
let fut = self.proxy(question); let fut = self.proxy(question);
return Box::new(fut.map_err(|_| hyper::Error::Incomplete)); return Box::new(fut.map_err(|_| Error::Incomplete));
} }
_ => { _ => {
response.set_status(StatusCode::MethodNotAllowed); let response = Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::empty())
.unwrap();
return Box::new(future::ok(response));
} }
}; };
Box::new(future::ok(response))
} }
fn proxy(&self, query: Vec<u8>) -> Box<Future<Item = Response, Error = ()>> { fn proxy(&self, query: Vec<u8>) -> Box<Future<Item = Response<Body>, Error = ()> + Send> {
let socket = UdpSocket::bind(&self.local_bind_address).unwrap(); let socket = UdpSocket::bind(&self.local_bind_address).unwrap();
let expected_server_address = self.server_address; let expected_server_address = self.server_address;
let fut = socket let fut = socket
@ -148,25 +171,25 @@ impl DoH {
Ok(min_ttl) => min_ttl, Ok(min_ttl) => min_ttl,
}; };
let packet_len = packet.len(); let packet_len = packet.len();
let mut response = Response::new(); let response = Response::builder()
response.set_body(packet); .header(hyper::header::CONTENT_LENGTH, packet_len)
let response = response .header(hyper::header::CONTENT_TYPE, "application/dns-message")
.with_header(ContentLength(packet_len as u64)) .header(hyper::header::CACHE_CONTROL, format!("max-age={}", ttl).as_str())
.with_header(ContentType("application/dns-message".parse().unwrap())) .body(Body::from(packet))
.with_header(CacheControl(vec![CacheDirective::MaxAge(ttl)])); .unwrap();
future::ok(response) future::ok(response)
}); });
Box::new(fut) Box::new(fut)
} }
fn read_body_and_proxy(&self, body: Body) -> Box<Future<Item = Response, Error = ()>> { fn read_body_and_proxy(&self, body: Body) -> Box<Future<Item = Response<Body>, Error = ()> + Send> {
let mut sum_size = 0; let mut sum_size = 0;
let inner = self.clone(); let inner = self.clone();
let fut = let fut =
body.and_then(move |chunk| { body.map_err(|e| Error::Hyper(e)).and_then(move |chunk| {
sum_size += chunk.len(); sum_size += chunk.len();
if sum_size > MAX_DNS_QUESTION_LEN { if sum_size > MAX_DNS_QUESTION_LEN {
Err(hyper::error::Error::TooLarge) Err(Error::TooLarge)
} else { } else {
Ok(chunk) Ok(chunk)
} }
@ -175,9 +198,9 @@ impl DoH {
.map(move |chunk| chunk.to_vec()) .map(move |chunk| chunk.to_vec())
.and_then(move |query| { .and_then(move |query| {
if query.len() < MIN_DNS_PACKET_LEN { if query.len() < MIN_DNS_PACKET_LEN {
return Box::new(future::err(())) as Box<Future<Item = _, Error = _>>; return Box::new(future::err(())) as Box<Future<Item = _, Error = _> + Send>;
} }
Box::new(inner.proxy(query)) inner.proxy(query)
}); });
Box::new(fut) Box::new(fut)
} }
@ -191,22 +214,23 @@ fn main() {
path: PATH.to_string(), path: PATH.to_string(),
max_clients: MAX_CLIENTS, max_clients: MAX_CLIENTS,
timeout: Duration::from_secs(TIMEOUT_SEC), timeout: Duration::from_secs(TIMEOUT_SEC),
clients_count: Rc::new(RefCell::new(0u32)), clients_count: Arc::new(Mutex::new(0u32)),
timer_handle: Timer::default().handle(), timer_handle: Timer::default().handle(),
}; };
parse_opts(&mut doh); parse_opts(&mut doh);
let listen_address = doh.listen_address; let listen_address = doh.listen_address;
let listener = TcpListener::bind(&listen_address).unwrap(); let listener = TcpListener::bind(&listen_address).unwrap();
println!("Listening on http://{}", listen_address); println!("Listening on http://{}", listen_address);
let doh = Rc::new(doh); let mut http = Http::new();
let server = Http::new() http.keep_alive(false);
.keep_alive(false) let doh = Arc::new(Mutex::new(doh));
.serve_incoming(listener.incoming(), move || Ok(doh.clone())); let server = listener.incoming().for_each(move |io| {
let fut = server.for_each(move |client_fut| { let service = doh.lock().unwrap().clone();
current_thread::spawn(client_fut.map(|_| {}).map_err(|_| {})); let conn = http.serve_connection(io, service).map_err(|_| {});
current_thread::spawn(conn);
Ok(()) Ok(())
}); });
current_thread::block_on_all(fut).unwrap(); current_thread::block_on_all(server).unwrap();
} }
fn parse_opts(doh: &mut DoH) { fn parse_opts(doh: &mut DoH) {