From 02ce4c9e9b970ba812b2143cc1f3cbf33ed7fc27 Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Mon, 23 Dec 2019 16:56:56 +0100 Subject: [PATCH] Move the TLS stuff to a dedicated file --- Cargo.toml | 1 - src/config.rs | 9 ++++-- src/errors.rs | 2 -- src/main.rs | 90 +++++++++++---------------------------------------- src/tls.rs | 64 ++++++++++++++++++++++++++++++++++++ 5 files changed, 89 insertions(+), 77 deletions(-) create mode 100644 src/tls.rs diff --git a/Cargo.toml b/Cargo.toml index 71083d4..53e3548 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,6 @@ default = [] tls = ["native-tls", "tokio-tls"] [dependencies] -anyhow = "1.0" base64 = "0.11" clap = "2.33.0" futures = { version = "0.3" } diff --git a/src/config.rs b/src/config.rs index b02211a..6fbfba5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,8 +1,13 @@ +use crate::constants::*; +use crate::globals::*; + +use clap::Arg; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; +use std::time::Duration; + #[cfg(feature = "tls")] use std::path::PathBuf; -use super::*; - pub fn parse_opts(globals: &mut Globals) { use crate::utils::{verify_remote_server, verify_sock_addr}; diff --git a/src/errors.rs b/src/errors.rs index 1eccc52..73cecff 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,5 +1,3 @@ -pub use anyhow::{anyhow, bail, ensure, Error}; - use hyper::StatusCode; use std::io; diff --git a/src/main.rs b/src/main.rs index 6bac0db..e69f8fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,8 @@ mod constants; mod dns; mod errors; mod globals; +#[cfg(feature = "tls")] +mod tls; mod utils; use crate::config::*; @@ -16,62 +18,25 @@ use crate::constants::*; use crate::errors::*; use crate::globals::*; -use clap::Arg; +#[cfg(feature = "tls")] +use crate::tls::*; + use futures::future; use futures::prelude::*; use futures::task::{Context, Poll}; use hyper::http; use hyper::server::conn::Http; use hyper::{Body, Method, Request, Response, StatusCode}; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, UdpSocket}; -#[cfg(feature = "tls")] -use native_tls::{self, Identity}; -#[cfg(feature = "tls")] -use std::fs::File; -#[cfg(feature = "tls")] -use std::io; -#[cfg(feature = "tls")] -use std::io::Read; -#[cfg(feature = "tls")] -use std::path::Path; -#[cfg(feature = "tls")] -use tokio_tls::TlsAcceptor; - #[derive(Clone, Debug)] -struct DoH { - globals: Arc, -} - -#[cfg(feature = "tls")] -fn create_tls_acceptor

(path: P, password: &str) -> io::Result -where - P: AsRef, -{ - let identity_bin = { - let mut fp = File::open(path)?; - let mut identity_bin = vec![]; - fp.read_to_end(&mut identity_bin)?; - identity_bin - }; - let identity = Identity::from_pkcs12(&identity_bin, password).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unusable PKCS12-encoded identity. The encoding and/or the password may be wrong", - ) - })?; - let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unable to use the provided PKCS12-encoded identity", - ) - })?; - Ok(TlsAcceptor::from(native_acceptor)) +pub struct DoH { + pub globals: Arc, } impl hyper::service::Service> for DoH { @@ -268,7 +233,11 @@ impl DoH { }); } - async fn start_without_tls(self, mut listener: TcpListener, server: Http) -> Result<(), Error> { + async fn start_without_tls( + self, + mut listener: TcpListener, + server: Http, + ) -> Result<(), DoHError> { let listener_service = async { while let Some(stream) = listener.incoming().next().await { let stream = match stream { @@ -277,40 +246,17 @@ impl DoH { }; self.clone().client_serve(stream, server.clone()).await; } - Ok(()) as Result<(), Error> + Ok(()) as Result<(), DoHError> }; listener_service.await?; Ok(()) } - #[cfg(feature = "tls")] - async fn start_with_tls( - self, - tls_acceptor: TlsAcceptor, - mut listener: TcpListener, - server: Http, - ) -> Result<(), Error> { - let listener_service = async { - while let Some(raw_stream) = listener.incoming().next().await { - let raw_stream = match raw_stream { - Ok(raw_stream) => raw_stream, - Err(_) => continue, - }; - let stream = match tls_acceptor.accept(raw_stream).await { - Ok(stream) => stream, - Err(_) => continue, - }; - self.clone().client_serve(stream, server.clone()).await; - } - Ok(()) as Result<(), Error> - }; - listener_service.await?; - Ok(()) - } - - async fn entrypoint(self) -> Result<(), Error> { + async fn entrypoint(self) -> Result<(), DoHError> { let listen_address = self.globals.listen_address; - let listener = TcpListener::bind(&listen_address).await?; + let listener = TcpListener::bind(&listen_address) + .await + .map_err(|e| DoHError::Io(e))?; let path = &self.globals.path; #[cfg(feature = "tls")] diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 0000000..607ee34 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1,64 @@ +use crate::errors::*; +use crate::DoH; + +use hyper::server::conn::Http; +use native_tls::{self, Identity}; +use std::fs::File; +use std::io; +use std::io::Read; +use std::path::Path; +use tokio::stream::StreamExt; +pub use tokio_tls::TlsAcceptor; + +use tokio::net::TcpListener; + +pub fn create_tls_acceptor

(path: P, password: &str) -> io::Result +where + P: AsRef, +{ + let identity_bin = { + let mut fp = File::open(path)?; + let mut identity_bin = vec![]; + fp.read_to_end(&mut identity_bin)?; + identity_bin + }; + let identity = Identity::from_pkcs12(&identity_bin, password).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unusable PKCS12-encoded identity. The encoding and/or the password may be wrong", + ) + })?; + let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to use the provided PKCS12-encoded identity", + ) + })?; + Ok(TlsAcceptor::from(native_acceptor)) +} + +impl DoH { + pub async fn start_with_tls( + self, + tls_acceptor: TlsAcceptor, + mut listener: TcpListener, + server: Http, + ) -> Result<(), DoHError> { + let listener_service = async { + while let Some(raw_stream) = listener.incoming().next().await { + let raw_stream = match raw_stream { + Ok(raw_stream) => raw_stream, + Err(_) => continue, + }; + let stream = match tls_acceptor.accept(raw_stream).await { + Ok(stream) => stream, + Err(_) => continue, + }; + self.clone().client_serve(stream, server.clone()).await; + } + Ok(()) as Result<(), DoHError> + }; + listener_service.await?; + Ok(()) + } +}