From 6818fbe8a14920ee9faaac8ab82b45b0582ea60a Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Sun, 25 Dec 2022 12:37:48 +0100 Subject: [PATCH] Update to clap 4 The new API is confusing and very error-prone, with errors being thrown at runtime rather than compile-time. Hopefully nothing got broken in the process. --- Cargo.toml | 4 +- src/config.rs | 124 ++++++++++++++++++++++++++++++++------------------ src/utils.rs | 16 +++---- 3 files changed, 88 insertions(+), 56 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fa976e0..2a09b98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,9 +17,9 @@ tls = ["libdoh/tls"] [dependencies] libdoh = { path = "src/libdoh", version = "0.9.5", default-features = false } -clap = { version = "3", features = ["std", "cargo", "wrap_help"] } +clap = { version = "4", features = ["std", "cargo", "wrap_help", "string"] } dnsstamps = "0.1.9" -mimalloc = { version = "0.1.29", default-features = false } +mimalloc = { version = "0.1.32", default-features = false } [package.metadata.deb] extended-description = """\ diff --git a/src/config.rs b/src/config.rs index 278e80e..643463a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,7 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSoc use std::path::PathBuf; use std::time::Duration; -use clap::Arg; +use clap::{Arg, ArgAction::SetTrue}; use libdoh::*; use crate::constants::*; @@ -24,54 +24,54 @@ pub fn parse_opts(globals: &mut Globals) { Arg::new("hostname") .short('H') .long("hostname") - .takes_value(true) + .num_args(1) .help("Host name (not IP address) DoH clients will use to connect"), ) .arg( Arg::new("public_address") .short('g') .long("public-address") - .takes_value(true) + .num_args(1) .help("External IP address DoH clients will connect to"), ) .arg( Arg::new("public_port") .short('j') .long("public-port") - .takes_value(true) + .num_args(1) .help("External port DoH clients will connect to, if not 443"), ) .arg( Arg::new("listen_address") .short('l') .long("listen-address") - .takes_value(true) + .num_args(1) .default_value(LISTEN_ADDRESS) - .validator(verify_sock_addr) + .value_parser(verify_sock_addr) .help("Address to listen to"), ) .arg( Arg::new("server_address") .short('u') .long("server-address") - .takes_value(true) + .num_args(1) .default_value(SERVER_ADDRESS) - .validator(verify_remote_server) + .value_parser(verify_remote_server) .help("Address to connect to"), ) .arg( Arg::new("local_bind_address") .short('b') .long("local-bind-address") - .takes_value(true) - .validator(verify_sock_addr) + .num_args(1) + .value_parser(verify_sock_addr) .help("Address to connect from"), ) .arg( Arg::new("path") .short('p') .long("path") - .takes_value(true) + .num_args(1) .default_value(PATH) .help("URI path"), ) @@ -79,65 +79,68 @@ pub fn parse_opts(globals: &mut Globals) { Arg::new("max_clients") .short('c') .long("max-clients") - .takes_value(true) - .default_value(&max_clients) + .num_args(1) + .default_value(max_clients) .help("Maximum number of simultaneous clients"), ) .arg( Arg::new("max_concurrent") .short('C') .long("max-concurrent") - .takes_value(true) - .default_value(&max_concurrent_streams) + .num_args(1) + .default_value(max_concurrent_streams) .help("Maximum number of concurrent requests per client"), ) .arg( Arg::new("timeout") .short('t') .long("timeout") - .takes_value(true) - .default_value(&timeout_sec) + .num_args(1) + .default_value(timeout_sec) .help("Timeout, in seconds"), ) .arg( Arg::new("min_ttl") .short('T') .long("min-ttl") - .takes_value(true) - .default_value(&min_ttl) + .num_args(1) + .default_value(min_ttl) .help("Minimum TTL, in seconds"), ) .arg( Arg::new("max_ttl") .short('X') .long("max-ttl") - .takes_value(true) - .default_value(&max_ttl) + .num_args(1) + .default_value(max_ttl) .help("Maximum TTL, in seconds"), ) .arg( Arg::new("err_ttl") .short('E') .long("err-ttl") - .takes_value(true) - .default_value(&err_ttl) + .num_args(1) + .default_value(err_ttl) .help("TTL for errors, in seconds"), ) .arg( Arg::new("disable_keepalive") .short('K') + .action(SetTrue) .long("disable-keepalive") .help("Disable keepalive"), ) .arg( Arg::new("disable_post") .short('P') + .action(SetTrue) .long("disable-post") .help("Disable POST queries"), ) .arg( Arg::new("allow_odoh_post") .short('O') + .action(SetTrue) .long("allow-odoh-post") .help("Allow POST queries over ODoH even if they have been disabed for DoH"), ); @@ -148,7 +151,7 @@ pub fn parse_opts(globals: &mut Globals) { Arg::new("tls_cert_path") .short('i') .long("tls-cert-path") - .takes_value(true) + .num_args(1) .help( "Path to the PEM/PKCS#8-encoded certificates (only required for built-in TLS)", ), @@ -157,21 +160,24 @@ pub fn parse_opts(globals: &mut Globals) { Arg::new("tls_cert_key_path") .short('I') .long("tls-cert-key-path") - .takes_value(true) + .num_args(1) .help("Path to the PEM-encoded secret keys (only required for built-in TLS)"), ); let matches = options.get_matches(); - globals.listen_address = matches.value_of("listen_address").unwrap().parse().unwrap(); - + globals.listen_address = matches + .get_one::("listen_address") + .unwrap() + .parse() + .unwrap(); globals.server_address = matches - .value_of("server_address") + .get_one::("server_address") .unwrap() .to_socket_addrs() .unwrap() .next() .unwrap(); - globals.local_bind_address = match matches.value_of("local_bind_address") { + globals.local_bind_address = match matches.get_one::("local_bind_address") { Some(address) => address.parse().unwrap(), None => match globals.server_address { SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), @@ -183,36 +189,64 @@ pub fn parse_opts(globals: &mut Globals) { )), }, }; - globals.path = matches.value_of("path").unwrap().to_string(); + globals.path = matches.get_one::("path").unwrap().to_string(); if !globals.path.starts_with('/') { globals.path = format!("/{}", globals.path); } - globals.max_clients = matches.value_of("max_clients").unwrap().parse().unwrap(); - globals.timeout = Duration::from_secs(matches.value_of("timeout").unwrap().parse().unwrap()); - globals.max_concurrent_streams = matches.value_of("max_concurrent").unwrap().parse().unwrap(); - globals.min_ttl = matches.value_of("min_ttl").unwrap().parse().unwrap(); - globals.max_ttl = matches.value_of("max_ttl").unwrap().parse().unwrap(); - globals.err_ttl = matches.value_of("err_ttl").unwrap().parse().unwrap(); - globals.keepalive = !matches.is_present("disable_keepalive"); - globals.disable_post = matches.is_present("disable_post"); - globals.allow_odoh_post = matches.is_present("allow_odoh_post"); + globals.max_clients = matches + .get_one::("max_clients") + .unwrap() + .parse() + .unwrap(); + globals.timeout = Duration::from_secs( + matches + .get_one::("timeout") + .unwrap() + .parse() + .unwrap(), + ); + globals.max_concurrent_streams = matches + .get_one::("max_concurrent") + .unwrap() + .parse() + .unwrap(); + globals.min_ttl = matches + .get_one::("min_ttl") + .unwrap() + .parse() + .unwrap(); + globals.max_ttl = matches + .get_one::("max_ttl") + .unwrap() + .parse() + .unwrap(); + globals.err_ttl = matches + .get_one::("err_ttl") + .unwrap() + .parse() + .unwrap(); + globals.keepalive = !matches.get_flag("disable_keepalive"); + globals.disable_post = matches.get_flag("disable_post"); + globals.allow_odoh_post = matches.get_flag("allow_odoh_post"); #[cfg(feature = "tls")] { - globals.tls_cert_path = matches.value_of("tls_cert_path").map(PathBuf::from); + globals.tls_cert_path = matches + .get_one::("tls_cert_path") + .map(PathBuf::from); globals.tls_cert_key_path = matches - .value_of("tls_cert_key_path") + .get_one::("tls_cert_key_path") .map(PathBuf::from) .or_else(|| globals.tls_cert_path.clone()); } - if let Some(hostname) = matches.value_of("hostname") { + if let Some(hostname) = matches.get_one::("hostname") { let mut builder = dnsstamps::DoHBuilder::new(hostname.to_string(), globals.path.to_string()); - if let Some(public_address) = matches.value_of("public_address") { + if let Some(public_address) = matches.get_one::("public_address") { builder = builder.with_address(public_address.to_string()); } - if let Some(public_port) = matches.value_of("public_port") { + if let Some(public_port) = matches.get_one::("public_port") { let public_port = public_port.parse().expect("Invalid public port"); builder = builder.with_port(public_port); } @@ -224,7 +258,7 @@ pub fn parse_opts(globals: &mut Globals) { let mut builder = dnsstamps::ODoHTargetBuilder::new(hostname.to_string(), globals.path.to_string()); - if let Some(public_port) = matches.value_of("public_port") { + if let Some(public_port) = matches.get_one::("public_port") { let public_port = public_port.parse().expect("Invalid public port"); builder = builder.with_port(public_port); } diff --git a/src/utils.rs b/src/utils.rs index 0ec0eff..5440def 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -2,25 +2,23 @@ use std::net::{SocketAddr, ToSocketAddrs}; -pub(crate) fn verify_sock_addr(arg_val: &str) -> Result<(), String> { +pub(crate) fn verify_sock_addr(arg_val: &str) -> Result { match arg_val.parse::() { - Ok(_addr) => Ok(()), + Ok(_addr) => Ok(arg_val.to_string()), Err(_) => Err(format!( - "Could not parse \"{}\" as a valid socket address (with port).", - arg_val + "Could not parse \"{arg_val}\" as a valid socket address (with port)." )), } } -pub(crate) fn verify_remote_server(arg_val: &str) -> Result<(), String> { +pub(crate) fn verify_remote_server(arg_val: &str) -> Result { match arg_val.to_socket_addrs() { Ok(mut addr_iter) => match addr_iter.next() { - Some(_) => Ok(()), + Some(_) => Ok(arg_val.to_string()), None => Err(format!( - "Could not parse \"{}\" as a valid remote uri", - arg_val + "Could not parse \"{arg_val}\" as a valid remote uri" )), }, - Err(err) => Err(format!("{}", err)), + Err(err) => Err(format!("{err}")), } }