Move the TLS stuff to a dedicated file

This commit is contained in:
Frank Denis 2019-12-23 16:56:56 +01:00
parent f7770951da
commit 02ce4c9e9b
5 changed files with 89 additions and 77 deletions

View file

@ -16,7 +16,6 @@ default = []
tls = ["native-tls", "tokio-tls"]
[dependencies]
anyhow = "1.0"
base64 = "0.11"
clap = "2.33.0"
futures = { version = "0.3" }

View file

@ -1,8 +1,13 @@
use crate::constants::*;
use crate::globals::*;
use clap::Arg;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
use std::time::Duration;
#[cfg(feature = "tls")]
use std::path::PathBuf;
use super::*;
pub fn parse_opts(globals: &mut Globals) {
use crate::utils::{verify_remote_server, verify_sock_addr};

View file

@ -1,5 +1,3 @@
pub use anyhow::{anyhow, bail, ensure, Error};
use hyper::StatusCode;
use std::io;

View file

@ -9,6 +9,8 @@ mod constants;
mod dns;
mod errors;
mod globals;
#[cfg(feature = "tls")]
mod tls;
mod utils;
use crate::config::*;
@ -16,62 +18,25 @@ use crate::constants::*;
use crate::errors::*;
use crate::globals::*;
use clap::Arg;
#[cfg(feature = "tls")]
use crate::tls::*;
use futures::future;
use futures::prelude::*;
use futures::task::{Context, Poll};
use hyper::http;
use hyper::server::conn::Http;
use hyper::{Body, Method, Request, Response, StatusCode};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UdpSocket};
#[cfg(feature = "tls")]
use native_tls::{self, Identity};
#[cfg(feature = "tls")]
use std::fs::File;
#[cfg(feature = "tls")]
use std::io;
#[cfg(feature = "tls")]
use std::io::Read;
#[cfg(feature = "tls")]
use std::path::Path;
#[cfg(feature = "tls")]
use tokio_tls::TlsAcceptor;
#[derive(Clone, Debug)]
struct DoH {
globals: Arc<Globals>,
}
#[cfg(feature = "tls")]
fn create_tls_acceptor<P>(path: P, password: &str) -> io::Result<TlsAcceptor>
where
P: AsRef<Path>,
{
let identity_bin = {
let mut fp = File::open(path)?;
let mut identity_bin = vec![];
fp.read_to_end(&mut identity_bin)?;
identity_bin
};
let identity = Identity::from_pkcs12(&identity_bin, password).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unusable PKCS12-encoded identity. The encoding and/or the password may be wrong",
)
})?;
let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unable to use the provided PKCS12-encoded identity",
)
})?;
Ok(TlsAcceptor::from(native_acceptor))
pub struct DoH {
pub globals: Arc<Globals>,
}
impl hyper::service::Service<http::Request<Body>> for DoH {
@ -268,7 +233,11 @@ impl DoH {
});
}
async fn start_without_tls(self, mut listener: TcpListener, server: Http) -> Result<(), Error> {
async fn start_without_tls(
self,
mut listener: TcpListener,
server: Http,
) -> Result<(), DoHError> {
let listener_service = async {
while let Some(stream) = listener.incoming().next().await {
let stream = match stream {
@ -277,40 +246,17 @@ impl DoH {
};
self.clone().client_serve(stream, server.clone()).await;
}
Ok(()) as Result<(), Error>
Ok(()) as Result<(), DoHError>
};
listener_service.await?;
Ok(())
}
#[cfg(feature = "tls")]
async fn start_with_tls(
self,
tls_acceptor: TlsAcceptor,
mut listener: TcpListener,
server: Http,
) -> Result<(), Error> {
let listener_service = async {
while let Some(raw_stream) = listener.incoming().next().await {
let raw_stream = match raw_stream {
Ok(raw_stream) => raw_stream,
Err(_) => continue,
};
let stream = match tls_acceptor.accept(raw_stream).await {
Ok(stream) => stream,
Err(_) => continue,
};
self.clone().client_serve(stream, server.clone()).await;
}
Ok(()) as Result<(), Error>
};
listener_service.await?;
Ok(())
}
async fn entrypoint(self) -> Result<(), Error> {
async fn entrypoint(self) -> Result<(), DoHError> {
let listen_address = self.globals.listen_address;
let listener = TcpListener::bind(&listen_address).await?;
let listener = TcpListener::bind(&listen_address)
.await
.map_err(|e| DoHError::Io(e))?;
let path = &self.globals.path;
#[cfg(feature = "tls")]

64
src/tls.rs Normal file
View file

@ -0,0 +1,64 @@
use crate::errors::*;
use crate::DoH;
use hyper::server::conn::Http;
use native_tls::{self, Identity};
use std::fs::File;
use std::io;
use std::io::Read;
use std::path::Path;
use tokio::stream::StreamExt;
pub use tokio_tls::TlsAcceptor;
use tokio::net::TcpListener;
pub fn create_tls_acceptor<P>(path: P, password: &str) -> io::Result<TlsAcceptor>
where
P: AsRef<Path>,
{
let identity_bin = {
let mut fp = File::open(path)?;
let mut identity_bin = vec![];
fp.read_to_end(&mut identity_bin)?;
identity_bin
};
let identity = Identity::from_pkcs12(&identity_bin, password).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unusable PKCS12-encoded identity. The encoding and/or the password may be wrong",
)
})?;
let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unable to use the provided PKCS12-encoded identity",
)
})?;
Ok(TlsAcceptor::from(native_acceptor))
}
impl DoH {
pub async fn start_with_tls(
self,
tls_acceptor: TlsAcceptor,
mut listener: TcpListener,
server: Http,
) -> Result<(), DoHError> {
let listener_service = async {
while let Some(raw_stream) = listener.incoming().next().await {
let raw_stream = match raw_stream {
Ok(raw_stream) => raw_stream,
Err(_) => continue,
};
let stream = match tls_acceptor.accept(raw_stream).await {
Ok(stream) => stream,
Err(_) => continue,
};
self.clone().client_serve(stream, server.clone()).await;
}
Ok(()) as Result<(), DoHError>
};
listener_service.await?;
Ok(())
}
}