diff --git a/Cargo.toml b/Cargo.toml index 46998bd..0d2d343 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,8 +13,8 @@ categories = ["asynchronous", "network-programming","command-line-utilities"] base64 = "~0.9" clap = "~2" futures = "~0.1" -hyper = "~0.12" +hyper = "~0.11" tokio = "~0.1" tokio-io = "~0.1" -tokio-timer = "~0.2" +tokio-timer = "~0.1" clippy = {version = ">=0", optional = true} diff --git a/src/main.rs b/src/main.rs index 94f4c3e..ee3e7d5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,18 +14,16 @@ mod dns; use clap::{App, Arg}; use futures::future; use futures::prelude::*; -use hyper::header::{CONTENT_LENGTH, CONTENT_TYPE, EXPIRES}; -use hyper::server::conn::Http; -use hyper::service::{NewService, Service}; -use hyper::{Body, Method, Request, Response, Server, StatusCode}; +use hyper::header::{CacheControl, CacheDirective, ContentLength, ContentType}; +use hyper::server::{Http, Request, Response, Service}; +use hyper::{Body, Method, StatusCode}; use std::cell::RefCell; -use std::io; use std::net::SocketAddr; use std::rc::Rc; -use std::time::{Duration, Instant, SystemTime}; +use std::time::Duration; use tokio::executor::current_thread; use tokio::net::{TcpListener, UdpSocket}; -use tokio_timer::{timer, Timer}; +use tokio_timer::Timer; const DNS_QUERY_PARAM: &str = "dns"; const LISTEN_ADDRESS: &str = "127.0.0.1:3000"; @@ -49,24 +47,22 @@ struct DoH { path: String, max_clients: u32, timeout: Duration, - timer_handle: timer::Handle, + timers: Timer, clients_count: Rc>, } impl Service for DoH { - type ReqBody = Body; - type ResBody = Body; - type Error = io::Error; - type Future = Box, Error = io::Error>>; + type Request = Request; + type Response = Response; + type Error = hyper::Error; + type Future = Box>; - fn call(&mut self, req: Request) -> Self::Future { + fn call(&self, req: Request) -> Self::Future { { let count = self.clients_count.borrow_mut(); if *count > self.max_clients { - let response = Response::builder() - .status(StatusCode::TOO_MANY_REQUESTS) - .body(Body::empty()) - .unwrap(); + let mut response = Response::new(); + response.set_status(StatusCode::TooManyRequests); return Box::new(future::ok(response)); } (*count).saturating_add(1); @@ -83,34 +79,27 @@ impl Service for DoH { err }); let timed = self - .timer_handle - .deadline(fut.map_err(|_| {}), Instant::now() + self.timeout) - .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Timeout")); + .timers + .timeout(fut.map_err(|_| {}), self.timeout) + .map_err(|_| hyper::Error::Timeout); Box::new(timed) } } impl DoH { - fn handle_client( - &self, - req: Request, - ) -> Box, Error = io::Error>> { - if req.uri().path() != self.path { - let response = Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::empty()) - .unwrap(); + fn handle_client(&self, req: Request) -> Box> { + let mut response = Response::new(); + if req.path() != self.path { + response.set_status(StatusCode::NotFound); return Box::new(future::ok(response)); } match *req.method() { - Method::POST => { + Method::Post => { let fut = self.read_body_and_proxy(req.body()); - return Box::new( - fut.map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "EOF")), - ); + return Box::new(fut.map_err(|_| hyper::Error::Incomplete)); } - Method::GET => { - let query = req.uri().query().unwrap_or(""); + Method::Get => { + let query = req.query().unwrap_or(""); let mut question_str = None; for parts in query.split('&') { let mut kv = parts.split('='); @@ -125,29 +114,21 @@ impl DoH { }) { Some(question) => question, _ => { - let mut response = Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::empty()) - .unwrap(); + response.set_status(StatusCode::BadRequest); return Box::new(future::ok(response)); } }; let fut = self.proxy(question); - return Box::new( - fut.map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "EOF")), - ); + return Box::new(fut.map_err(|_| hyper::Error::Incomplete)); } _ => { - let mut response = Response::builder() - .status(StatusCode::METHOD_NOT_ALLOWED) - .body(Body::empty()) - .unwrap(); - return Box::new(future::ok(response)); + response.set_status(StatusCode::MethodNotAllowed); } }; + Box::new(future::ok(response)) } - fn proxy(&self, query: Vec) -> Box, Error = ()>> { + fn proxy(&self, query: Vec) -> Box> { let socket = UdpSocket::bind(&self.local_bind_address).unwrap(); let expected_server_address = self.server_address; let fut = socket @@ -167,75 +148,43 @@ impl DoH { Ok(min_ttl) => min_ttl, }; let packet_len = packet.len(); - let mut response = Response::builder() - .header(CONTENT_LENGTH, format!("{}", packet_len).as_bytes()) - .header(CONTENT_TYPE, "application/dns-message") - .header( - EXPIRES, - format!( - "{}", - (SystemTime::now() + Duration::from_secs(ttl as u64)) - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap_or(Duration::new(0, 0)) - .as_secs() - ).as_bytes(), - ) - .body(Body::from(packet)) - .unwrap(); + let mut response = Response::new(); + response.set_body(packet); + let response = response + .with_header(ContentLength(packet_len as u64)) + .with_header(ContentType( + "application/dns-message".parse().unwrap(), + )) + .with_header(CacheControl(vec![CacheDirective::MaxAge(ttl)])); future::ok(response) }); Box::new(fut) } - fn read_body_and_proxy(&self, body: &Body) -> Box, Error = ()>> { + fn read_body_and_proxy(&self, body: Body) -> Box> { let mut sum_size = 0; let inner = self.clone(); - let fut = body - .map_err(move |_err| ()) - .and_then(move |chunk| { + let fut = + body.and_then(move |chunk| { sum_size += chunk.len(); if sum_size > MAX_DNS_QUESTION_LEN { - Err(()) + Err(hyper::error::Error::TooLarge) } else { Ok(chunk) } - }) - .concat2() - .map_err(move |_err| ()) - .map(move |chunk| chunk.to_vec()) - .and_then(move |query| { - if query.len() < MIN_DNS_PACKET_LEN { - return Box::new(future::err(())) as Box>; - } - Box::new(inner.proxy(query)) - }); + }).concat2() + .map_err(move |_err| ()) + .map(move |chunk| chunk.to_vec()) + .and_then(move |query| { + if query.len() < MIN_DNS_PACKET_LEN { + return Box::new(future::err(())) as Box>; + } + Box::new(inner.proxy(query)) + }); Box::new(fut) } } -struct DohNewService { - doh: DoH, -} - -impl NewService for DohNewService { - type ReqBody = Body; - type ResBody = Body; - type Error = io::Error; - type Service = DoH; - type Future = Box>; - type InitError = io::Error; - - fn new_service(&self) -> Self::Future { - Box::new(future::ok(self.doh.clone())) - } -} - -impl DohNewService { - fn new(doh: DoH) -> Self { - DohNewService { doh } - } -} - fn main() { let mut doh = DoH { listen_address: LISTEN_ADDRESS.parse().unwrap(), @@ -245,18 +194,16 @@ fn main() { max_clients: MAX_CLIENTS, timeout: Duration::from_secs(TIMEOUT_SEC), clients_count: Rc::new(RefCell::new(0u32)), - timer_handle: Timer::default().handle(), + timers: tokio_timer::wheel().build(), }; parse_opts(&mut doh); let listen_address = doh.listen_address; + let listener = TcpListener::bind(&listen_address).unwrap(); println!("Listening on http://{}", listen_address); - let new_service = DohNewService::new(doh); + let doh = Rc::new(doh); let server = Http::new() .keep_alive(false) - .max_buf_size(0xffff) - .pipeline_flush(true) - .serve_addr(&listen_address, || new_service) - .unwrap(); + .serve_incoming(listener.incoming(), move || Ok(doh.clone())); let fut = server.for_each(move |client_fut| { current_thread::spawn(client_fut.map(|_| {}).map_err(|_| {})); Ok(())