diff --git a/Cargo.toml b/Cargo.toml index 51228ba..1da9faf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ tls = ["native-tls", "tokio-tls"] [dependencies] base64 = "0.10" -clap = "2" +clap = "2.33.0" futures = "0.1.27" hyper = "0.12.30" jemallocator = "0" diff --git a/src/main.rs b/src/main.rs index 69cf371..6b541af 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,6 +14,7 @@ use hyper::server::conn::Http; use hyper::service::Service; use hyper::{Body, Method, Request, Response, StatusCode}; use std::io; +use std::net::{SocketAddr, ToSocketAddrs}; #[cfg(feature = "tls")] use native_tls::{self, Identity}; @@ -23,7 +24,6 @@ use std::fs::File; #[cfg(feature = "tls")] use std::io::{self, Read}; -use std::net::SocketAddr; #[cfg(feature = "tls")] use std::path::{Path, PathBuf}; @@ -483,6 +483,8 @@ fn main() { } fn parse_opts(inner_doh: &mut InnerDoH) { + use crate::utils::{verify_remote_server, verify_sock_addr}; + let max_clients = MAX_CLIENTS.to_string(); let timeout_sec = TIMEOUT_SEC.to_string(); let min_ttl = MIN_TTL.to_string(); @@ -497,6 +499,7 @@ fn parse_opts(inner_doh: &mut InnerDoH) { .long("listen-address") .takes_value(true) .default_value(LISTEN_ADDRESS) + .validator(verify_sock_addr) .help("Address to listen to"), ) .arg( @@ -505,6 +508,7 @@ fn parse_opts(inner_doh: &mut InnerDoH) { .long("server-address") .takes_value(true) .default_value(SERVER_ADDRESS) + .validator(verify_remote_server) .help("Address to connect to"), ) .arg( @@ -513,6 +517,7 @@ fn parse_opts(inner_doh: &mut InnerDoH) { .long("local-bind-address") .takes_value(true) .default_value(LOCAL_BIND_ADDRESS) + .validator(verify_sock_addr) .help("Address to connect from"), ) .arg( @@ -595,7 +600,14 @@ fn parse_opts(inner_doh: &mut InnerDoH) { let matches = options.get_matches(); inner_doh.listen_address = matches.value_of("listen_address").unwrap().parse().unwrap(); - inner_doh.server_address = matches.value_of("server_address").unwrap().parse().unwrap(); + + inner_doh.server_address = matches + .value_of("server_address") + .unwrap() + .to_socket_addrs() + .unwrap() + .next() + .unwrap(); inner_doh.local_bind_address = matches .value_of("local_bind_address") .unwrap() diff --git a/src/utils.rs b/src/utils.rs index 984e843..b170fa8 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,31 @@ +use std::net::{SocketAddr, ToSocketAddrs}; + pub(crate) fn padding_string(input_size: usize, block_size: usize) -> String { let block_size_ = block_size - 1; let padding_len = block_size_ - ((input_size + block_size_) & block_size_); String::from_utf8(vec![b'X'; padding_len]).unwrap() } + +// functions to verify the startup arguments as correct +pub(crate) fn verify_sock_addr(arg_val: String) -> Result<(), String> { + match arg_val.parse::() { + Ok(_addr) => Ok(()), + Err(_) => Err(format!( + "Could not parse \"{}\" as a valid socket address (with port).", + arg_val + )), + } +} + +pub(crate) fn verify_remote_server(arg_val: String) -> Result<(), String> { + match arg_val.to_socket_addrs() { + Ok(mut addr_iter) => match addr_iter.next() { + Some(_) => Ok(()), + None => Err(format!( + "Could not parse \"{}\" as a valid remote uri", + arg_val + )), + }, + Err(err) => Err(format!("{}", err)), + } +}