diff --git a/Cargo.toml b/Cargo.toml index 8c03da9..6381d5b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,11 +10,10 @@ repository = "https://github.com/jedisct1/rust-doh" categories = ["asynchronous", "network-programming","command-line-utilities"] [dependencies] -base64 = "~0.9" -clap = "~2" -futures = "~0.1" -hyper = "~0.11" -tokio = "~0.1" -tokio-io = "~0.1" -tokio-timer = "~0.2" +base64 = "0.9" +clap = "2" +futures = "0.1" +hyper = "0.12" +tokio = "0.1" +tokio-timer = "0.2" clippy = {version = ">=0", optional = true} diff --git a/src/main.rs b/src/main.rs index 14d9600..6ec009d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,6 @@ extern crate clap; extern crate futures; extern crate hyper; extern crate tokio; -extern crate tokio_io; extern crate tokio_timer; mod dns; @@ -14,12 +13,11 @@ mod dns; use clap::{App, Arg}; use futures::future; use futures::prelude::*; -use hyper::header::{CacheControl, CacheDirective, ContentLength, ContentType}; -use hyper::server::{Http, Request, Response, Service}; -use hyper::{Body, Method, StatusCode}; -use std::cell::RefCell; +use hyper::service::Service; +use hyper::server::conn::Http; +use hyper::{Request, Response, Body, Method, StatusCode}; use std::net::SocketAddr; -use std::rc::Rc; +use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use tokio::executor::current_thread; use tokio::net::{TcpListener, UdpSocket}; @@ -48,21 +46,38 @@ struct DoH { max_clients: u32, timeout: Duration, timer_handle: timer::Handle, - clients_count: Rc>, + clients_count: Arc>, } -impl Service for DoH { - type Request = Request; - type Response = Response; - type Error = hyper::Error; - type Future = Box>; +#[derive(Debug)] +enum Error { + Timeout, + Incomplete, + TooLarge, + Hyper(hyper::Error), +} +impl std::fmt::Display for Error { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { + // Should match, i'm lazy... + std::fmt::Debug::fmt(self, fmt) + } +} +impl std::error::Error for Error {} - fn call(&self, req: Request) -> Self::Future { +impl Service for DoH { + type ReqBody = Body; + type ResBody = Body; + type Error = Error; + type Future = Box, Error = Self::Error> + Send>; + + fn call(&mut self, req: Request) -> Self::Future { { - let count = self.clients_count.borrow_mut(); + let count = self.clients_count.lock().unwrap(); if *count > self.max_clients { - let mut response = Response::new(); - response.set_status(StatusCode::TooManyRequests); + let response = Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .body(Body::empty()) + .unwrap(); return Box::new(future::ok(response)); } (*count).saturating_add(1); @@ -71,35 +86,37 @@ impl Service for DoH { let fut = self .handle_client(req) .then(move |fut| { - (*clients_count_inner).borrow_mut().saturating_sub(1); + (*clients_count_inner).lock().unwrap().saturating_sub(1); fut }) .map_err(|err| { - eprintln!("server error: {:?}", err); + eprintln!("server error: {}", err); err }); let timed = self .timer_handle .deadline(fut.map_err(|_| {}), Instant::now() + self.timeout) - .map_err(|_| hyper::Error::Timeout); + .map_err(|_| Error::Timeout); Box::new(timed) } } impl DoH { - fn handle_client(&self, req: Request) -> Box> { - let mut response = Response::new(); - if req.path() != self.path { - response.set_status(StatusCode::NotFound); + fn handle_client(&self, req: Request) -> Box, Error = Error> + Send> { + if req.uri().path() != self.path { + let response = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap(); return Box::new(future::ok(response)); } match *req.method() { - Method::Post => { - let fut = self.read_body_and_proxy(req.body()); - return Box::new(fut.map_err(|_| hyper::Error::Incomplete)); + Method::POST => { + let fut = self.read_body_and_proxy(req.into_body()); + return Box::new(fut.map_err(|_| Error::Incomplete)); } - Method::Get => { - let query = req.query().unwrap_or(""); + Method::GET => { + let query = req.uri().query().unwrap_or(""); let mut question_str = None; for parts in query.split('&') { let mut kv = parts.split('='); @@ -114,21 +131,27 @@ impl DoH { }) { Some(question) => question, _ => { - response.set_status(StatusCode::BadRequest); + let response = Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::empty()) + .unwrap(); return Box::new(future::ok(response)); } }; let fut = self.proxy(question); - return Box::new(fut.map_err(|_| hyper::Error::Incomplete)); + return Box::new(fut.map_err(|_| Error::Incomplete)); } _ => { - response.set_status(StatusCode::MethodNotAllowed); + let response = Response::builder() + .status(StatusCode::METHOD_NOT_ALLOWED) + .body(Body::empty()) + .unwrap(); + return Box::new(future::ok(response)); } }; - Box::new(future::ok(response)) } - fn proxy(&self, query: Vec) -> Box> { + fn proxy(&self, query: Vec) -> Box, Error = ()> + Send> { let socket = UdpSocket::bind(&self.local_bind_address).unwrap(); let expected_server_address = self.server_address; let fut = socket @@ -148,25 +171,25 @@ impl DoH { Ok(min_ttl) => min_ttl, }; let packet_len = packet.len(); - let mut response = Response::new(); - response.set_body(packet); - let response = response - .with_header(ContentLength(packet_len as u64)) - .with_header(ContentType("application/dns-message".parse().unwrap())) - .with_header(CacheControl(vec![CacheDirective::MaxAge(ttl)])); + let response = Response::builder() + .header(hyper::header::CONTENT_LENGTH, packet_len) + .header(hyper::header::CONTENT_TYPE, "application/dns-message") + .header(hyper::header::CACHE_CONTROL, format!("max-age={}", ttl).as_str()) + .body(Body::from(packet)) + .unwrap(); future::ok(response) }); Box::new(fut) } - fn read_body_and_proxy(&self, body: Body) -> Box> { + fn read_body_and_proxy(&self, body: Body) -> Box, Error = ()> + Send> { let mut sum_size = 0; let inner = self.clone(); let fut = - body.and_then(move |chunk| { + body.map_err(|e| Error::Hyper(e)).and_then(move |chunk| { sum_size += chunk.len(); if sum_size > MAX_DNS_QUESTION_LEN { - Err(hyper::error::Error::TooLarge) + Err(Error::TooLarge) } else { Ok(chunk) } @@ -175,9 +198,9 @@ impl DoH { .map(move |chunk| chunk.to_vec()) .and_then(move |query| { if query.len() < MIN_DNS_PACKET_LEN { - return Box::new(future::err(())) as Box>; + return Box::new(future::err(())) as Box + Send>; } - Box::new(inner.proxy(query)) + inner.proxy(query) }); Box::new(fut) } @@ -191,22 +214,23 @@ fn main() { path: PATH.to_string(), max_clients: MAX_CLIENTS, timeout: Duration::from_secs(TIMEOUT_SEC), - clients_count: Rc::new(RefCell::new(0u32)), + clients_count: Arc::new(Mutex::new(0u32)), timer_handle: Timer::default().handle(), }; parse_opts(&mut doh); let listen_address = doh.listen_address; let listener = TcpListener::bind(&listen_address).unwrap(); println!("Listening on http://{}", listen_address); - let doh = Rc::new(doh); - let server = Http::new() - .keep_alive(false) - .serve_incoming(listener.incoming(), move || Ok(doh.clone())); - let fut = server.for_each(move |client_fut| { - current_thread::spawn(client_fut.map(|_| {}).map_err(|_| {})); + let mut http = Http::new(); + http.keep_alive(false); + let doh = Arc::new(Mutex::new(doh)); + let server = listener.incoming().for_each(move |io| { + let service = doh.lock().unwrap().clone(); + let conn = http.serve_connection(io, service).map_err(|_| {}); + current_thread::spawn(conn); Ok(()) }); - current_thread::block_on_all(fut).unwrap(); + current_thread::block_on_all(server).unwrap(); } fn parse_opts(doh: &mut DoH) {