diff --git a/Cargo.toml b/Cargo.toml index 0d2d343..46998bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,8 +13,8 @@ categories = ["asynchronous", "network-programming","command-line-utilities"] base64 = "~0.9" clap = "~2" futures = "~0.1" -hyper = "~0.11" +hyper = "~0.12" tokio = "~0.1" tokio-io = "~0.1" -tokio-timer = "~0.1" +tokio-timer = "~0.2" clippy = {version = ">=0", optional = true} diff --git a/src/main.rs b/src/main.rs index ee3e7d5..94f4c3e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,16 +14,18 @@ mod dns; use clap::{App, Arg}; use futures::future; use futures::prelude::*; -use hyper::header::{CacheControl, CacheDirective, ContentLength, ContentType}; -use hyper::server::{Http, Request, Response, Service}; -use hyper::{Body, Method, StatusCode}; +use hyper::header::{CONTENT_LENGTH, CONTENT_TYPE, EXPIRES}; +use hyper::server::conn::Http; +use hyper::service::{NewService, Service}; +use hyper::{Body, Method, Request, Response, Server, StatusCode}; use std::cell::RefCell; +use std::io; use std::net::SocketAddr; use std::rc::Rc; -use std::time::Duration; +use std::time::{Duration, Instant, SystemTime}; use tokio::executor::current_thread; use tokio::net::{TcpListener, UdpSocket}; -use tokio_timer::Timer; +use tokio_timer::{timer, Timer}; const DNS_QUERY_PARAM: &str = "dns"; const LISTEN_ADDRESS: &str = "127.0.0.1:3000"; @@ -47,22 +49,24 @@ struct DoH { path: String, max_clients: u32, timeout: Duration, - timers: Timer, + timer_handle: timer::Handle, clients_count: Rc>, } impl Service for DoH { - type Request = Request; - type Response = Response; - type Error = hyper::Error; - type Future = Box>; + type ReqBody = Body; + type ResBody = Body; + type Error = io::Error; + type Future = Box, Error = io::Error>>; - fn call(&self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { { let count = self.clients_count.borrow_mut(); if *count > self.max_clients { - let mut response = Response::new(); - response.set_status(StatusCode::TooManyRequests); + let response = Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .body(Body::empty()) + .unwrap(); return Box::new(future::ok(response)); } (*count).saturating_add(1); @@ -79,27 +83,34 @@ impl Service for DoH { err }); let timed = self - .timers - .timeout(fut.map_err(|_| {}), self.timeout) - .map_err(|_| hyper::Error::Timeout); + .timer_handle + .deadline(fut.map_err(|_| {}), Instant::now() + self.timeout) + .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Timeout")); Box::new(timed) } } impl DoH { - fn handle_client(&self, req: Request) -> Box> { - let mut response = Response::new(); - if req.path() != self.path { - response.set_status(StatusCode::NotFound); + fn handle_client( + &self, + req: Request, + ) -> Box, Error = io::Error>> { + if req.uri().path() != self.path { + let response = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap(); return Box::new(future::ok(response)); } match *req.method() { - Method::Post => { + Method::POST => { let fut = self.read_body_and_proxy(req.body()); - return Box::new(fut.map_err(|_| hyper::Error::Incomplete)); + return Box::new( + fut.map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "EOF")), + ); } - Method::Get => { - let query = req.query().unwrap_or(""); + Method::GET => { + let query = req.uri().query().unwrap_or(""); let mut question_str = None; for parts in query.split('&') { let mut kv = parts.split('='); @@ -114,21 +125,29 @@ impl DoH { }) { Some(question) => question, _ => { - response.set_status(StatusCode::BadRequest); + let mut response = Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::empty()) + .unwrap(); return Box::new(future::ok(response)); } }; let fut = self.proxy(question); - return Box::new(fut.map_err(|_| hyper::Error::Incomplete)); + return Box::new( + fut.map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "EOF")), + ); } _ => { - response.set_status(StatusCode::MethodNotAllowed); + let mut response = Response::builder() + .status(StatusCode::METHOD_NOT_ALLOWED) + .body(Body::empty()) + .unwrap(); + return Box::new(future::ok(response)); } }; - Box::new(future::ok(response)) } - fn proxy(&self, query: Vec) -> Box> { + fn proxy(&self, query: Vec) -> Box, Error = ()>> { let socket = UdpSocket::bind(&self.local_bind_address).unwrap(); let expected_server_address = self.server_address; let fut = socket @@ -148,43 +167,75 @@ impl DoH { Ok(min_ttl) => min_ttl, }; let packet_len = packet.len(); - let mut response = Response::new(); - response.set_body(packet); - let response = response - .with_header(ContentLength(packet_len as u64)) - .with_header(ContentType( - "application/dns-message".parse().unwrap(), - )) - .with_header(CacheControl(vec![CacheDirective::MaxAge(ttl)])); + let mut response = Response::builder() + .header(CONTENT_LENGTH, format!("{}", packet_len).as_bytes()) + .header(CONTENT_TYPE, "application/dns-message") + .header( + EXPIRES, + format!( + "{}", + (SystemTime::now() + Duration::from_secs(ttl as u64)) + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or(Duration::new(0, 0)) + .as_secs() + ).as_bytes(), + ) + .body(Body::from(packet)) + .unwrap(); future::ok(response) }); Box::new(fut) } - fn read_body_and_proxy(&self, body: Body) -> Box> { + fn read_body_and_proxy(&self, body: &Body) -> Box, Error = ()>> { let mut sum_size = 0; let inner = self.clone(); - let fut = - body.and_then(move |chunk| { + let fut = body + .map_err(move |_err| ()) + .and_then(move |chunk| { sum_size += chunk.len(); if sum_size > MAX_DNS_QUESTION_LEN { - Err(hyper::error::Error::TooLarge) + Err(()) } else { Ok(chunk) } - }).concat2() - .map_err(move |_err| ()) - .map(move |chunk| chunk.to_vec()) - .and_then(move |query| { - if query.len() < MIN_DNS_PACKET_LEN { - return Box::new(future::err(())) as Box>; - } - Box::new(inner.proxy(query)) - }); + }) + .concat2() + .map_err(move |_err| ()) + .map(move |chunk| chunk.to_vec()) + .and_then(move |query| { + if query.len() < MIN_DNS_PACKET_LEN { + return Box::new(future::err(())) as Box>; + } + Box::new(inner.proxy(query)) + }); Box::new(fut) } } +struct DohNewService { + doh: DoH, +} + +impl NewService for DohNewService { + type ReqBody = Body; + type ResBody = Body; + type Error = io::Error; + type Service = DoH; + type Future = Box>; + type InitError = io::Error; + + fn new_service(&self) -> Self::Future { + Box::new(future::ok(self.doh.clone())) + } +} + +impl DohNewService { + fn new(doh: DoH) -> Self { + DohNewService { doh } + } +} + fn main() { let mut doh = DoH { listen_address: LISTEN_ADDRESS.parse().unwrap(), @@ -194,16 +245,18 @@ fn main() { max_clients: MAX_CLIENTS, timeout: Duration::from_secs(TIMEOUT_SEC), clients_count: Rc::new(RefCell::new(0u32)), - timers: tokio_timer::wheel().build(), + timer_handle: Timer::default().handle(), }; parse_opts(&mut doh); let listen_address = doh.listen_address; - let listener = TcpListener::bind(&listen_address).unwrap(); println!("Listening on http://{}", listen_address); - let doh = Rc::new(doh); + let new_service = DohNewService::new(doh); let server = Http::new() .keep_alive(false) - .serve_incoming(listener.incoming(), move || Ok(doh.clone())); + .max_buf_size(0xffff) + .pipeline_flush(true) + .serve_addr(&listen_address, || new_service) + .unwrap(); let fut = server.for_each(move |client_fut| { current_thread::spawn(client_fut.map(|_| {}).map_err(|_| {})); Ok(())