Merge pull request #516 from str4d/515-cli-common-feature-bugs

age: Fix feature flag combination bugs in `cli_common` module
This commit is contained in:
Jack Grigg 2024-08-23 06:40:22 -07:00 committed by GitHub
commit cf96347fbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 110 additions and 50 deletions

View file

@ -1,7 +1,10 @@
use std::fmt; use std::fmt;
use std::io; use std::io;
use crate::{wfl, wlnfl, DecryptError}; use crate::{wfl, DecryptError};
#[cfg(feature = "plugin")]
use crate::wlnfl;
/// Errors that can occur while reading recipients or identities. /// Errors that can occur while reading recipients or identities.
#[derive(Debug)] #[derive(Debug)]

View file

@ -1,10 +1,10 @@
use std::io::{self, BufReader}; use std::io::{self, BufReader};
use super::{file_io::InputReader, ReadError, StdinGuard, UiCallbacks}; use super::{ReadError, StdinGuard, UiCallbacks};
use crate::{identity::IdentityFile, Identity}; use crate::{identity::IdentityFile, Identity};
#[cfg(feature = "armor")] #[cfg(feature = "armor")]
use crate::armor::ArmoredReader; use crate::{armor::ArmoredReader, cli_common::file_io::InputReader};
/// Reads identities from the provided files. /// Reads identities from the provided files.
/// ///
@ -23,10 +23,12 @@ pub fn read_identities(
max_work_factor, max_work_factor,
stdin_guard, stdin_guard,
&mut identities, &mut identities,
#[cfg(feature = "armor")]
|identities, identity| { |identities, identity| {
identities.push(Box::new(identity)); identities.push(Box::new(identity));
Ok(()) Ok(())
}, },
#[cfg(feature = "ssh")]
|identities, _, identity| { |identities, _, identity| {
identities.push(Box::new(identity.with_callbacks(UiCallbacks))); identities.push(Box::new(identity.with_callbacks(UiCallbacks)));
Ok(()) Ok(())
@ -62,7 +64,7 @@ pub fn read_identities(
/// Parses the provided identity files. /// Parses the provided identity files.
pub(super) fn parse_identity_files<Ctx, E: From<ReadError> + From<io::Error>>( pub(super) fn parse_identity_files<Ctx, E: From<ReadError> + From<io::Error>>(
filenames: Vec<String>, filenames: Vec<String>,
max_work_factor: Option<u8>, _max_work_factor: Option<u8>,
stdin_guard: &mut StdinGuard, stdin_guard: &mut StdinGuard,
ctx: &mut Ctx, ctx: &mut Ctx,
#[cfg(feature = "armor")] encrypted_identity: impl Fn( #[cfg(feature = "armor")] encrypted_identity: impl Fn(
@ -73,6 +75,7 @@ pub(super) fn parse_identity_files<Ctx, E: From<ReadError> + From<io::Error>>(
identity_file_entry: impl Fn(&mut Ctx, crate::IdentityFileEntry) -> Result<(), E>, identity_file_entry: impl Fn(&mut Ctx, crate::IdentityFileEntry) -> Result<(), E>,
) -> Result<(), E> { ) -> Result<(), E> {
for filename in filenames { for filename in filenames {
#[cfg_attr(not(any(feature = "armor", feature = "ssh")), allow(unused_mut))]
let mut reader = PeekableReader::new(BufReader::new( let mut reader = PeekableReader::new(BufReader::new(
stdin_guard.open(filename.clone()).map_err(|e| match e { stdin_guard.open(filename.clone()).map_err(|e| match e {
ReadError::Io(e) if matches!(e.kind(), io::ErrorKind::NotFound) => { ReadError::Io(e) if matches!(e.kind(), io::ErrorKind::NotFound) => {
@ -88,7 +91,7 @@ pub(super) fn parse_identity_files<Ctx, E: From<ReadError> + From<io::Error>>(
ArmoredReader::new_buffered(&mut reader), ArmoredReader::new_buffered(&mut reader),
Some(filename.clone()), Some(filename.clone()),
UiCallbacks, UiCallbacks,
max_work_factor, _max_work_factor,
) )
.is_ok() .is_ok()
{ {
@ -101,7 +104,7 @@ pub(super) fn parse_identity_files<Ctx, E: From<ReadError> + From<io::Error>>(
ArmoredReader::new_buffered(reader.inner), ArmoredReader::new_buffered(reader.inner),
Some(filename.clone()), Some(filename.clone()),
UiCallbacks, UiCallbacks,
max_work_factor, _max_work_factor,
) )
.expect("already parsed the age ciphertext header"); .expect("already parsed the age ciphertext header");
@ -160,6 +163,7 @@ impl<R: io::BufRead> PeekableReader<R> {
} }
} }
#[cfg(any(feature = "armor", feature = "ssh"))]
fn reset(&mut self) -> io::Result<()> { fn reset(&mut self) -> io::Result<()> {
match &mut self.state { match &mut self.state {
PeekState::Peeking { consumed } => { PeekState::Peeking { consumed } => {

View file

@ -1,15 +1,21 @@
use std::io::{self, BufReader}; use std::io::{self, BufReader};
use super::StdinGuard; use super::StdinGuard;
use super::{identities::parse_identity_files, ReadError, UiCallbacks}; use super::{identities::parse_identity_files, ReadError};
use crate::{x25519, EncryptError, IdentityFileEntry, Recipient}; use crate::{x25519, IdentityFileEntry, Recipient};
#[cfg(feature = "plugin")] #[cfg(feature = "plugin")]
use crate::plugin; use crate::{cli_common::UiCallbacks, plugin};
#[cfg(not(feature = "plugin"))]
use std::convert::Infallible;
#[cfg(feature = "ssh")] #[cfg(feature = "ssh")]
use crate::ssh; use crate::ssh;
#[cfg(any(feature = "armor", feature = "plugin"))]
use crate::EncryptError;
/// Handles error mapping for the given SSH recipient parser. /// Handles error mapping for the given SSH recipient parser.
/// ///
/// Returns `Ok(None)` if the parser finds a parseable value that should be ignored. This /// Returns `Ok(None)` if the parser finds a parseable value that should be ignored. This
@ -44,25 +50,35 @@ where
/// Parses a recipient from a string. /// Parses a recipient from a string.
fn parse_recipient( fn parse_recipient(
filename: &str, _filename: &str,
s: String, s: String,
recipients: &mut Vec<Box<dyn Recipient + Send>>, recipients: &mut Vec<Box<dyn Recipient + Send>>,
plugin_recipients: &mut Vec<plugin::Recipient>, #[cfg(feature = "plugin")] plugin_recipients: &mut Vec<plugin::Recipient>,
) -> Result<(), ReadError> { ) -> Result<(), ReadError> {
if let Ok(pk) = s.parse::<x25519::Recipient>() { if let Ok(pk) = s.parse::<x25519::Recipient>() {
recipients.push(Box::new(pk)); recipients.push(Box::new(pk));
} else if let Some(pk) = { } else if let Some(pk) = {
#[cfg(feature = "ssh")] #[cfg(feature = "ssh")]
{ {
parse_ssh_recipient(|| s.parse::<ssh::Recipient>(), || Ok(None), filename)? parse_ssh_recipient(|| s.parse::<ssh::Recipient>(), || Ok(None), _filename)?
} }
#[cfg(not(feature = "ssh"))] #[cfg(not(feature = "ssh"))]
None None
} { } {
recipients.push(pk); recipients.push(pk);
} else if let Ok(recipient) = s.parse::<plugin::Recipient>() { } else if let Some(_recipient) = {
plugin_recipients.push(recipient); #[cfg(feature = "plugin")]
{
// TODO Do something with the error?
s.parse::<plugin::Recipient>().ok()
}
#[cfg(not(feature = "plugin"))]
None::<Infallible>
} {
#[cfg(feature = "plugin")]
plugin_recipients.push(_recipient);
} else { } else {
return Err(ReadError::InvalidRecipient(s)); return Err(ReadError::InvalidRecipient(s));
} }
@ -75,7 +91,7 @@ fn read_recipients_list<R: io::BufRead>(
filename: &str, filename: &str,
buf: R, buf: R,
recipients: &mut Vec<Box<dyn Recipient + Send>>, recipients: &mut Vec<Box<dyn Recipient + Send>>,
plugin_recipients: &mut Vec<plugin::Recipient>, #[cfg(feature = "plugin")] plugin_recipients: &mut Vec<plugin::Recipient>,
) -> Result<(), ReadError> { ) -> Result<(), ReadError> {
for (line_number, line) in buf.lines().enumerate() { for (line_number, line) in buf.lines().enumerate() {
let line = line?; let line = line?;
@ -83,13 +99,19 @@ fn read_recipients_list<R: io::BufRead>(
// Skip empty lines and comments // Skip empty lines and comments
if line.is_empty() || line.find('#') == Some(0) { if line.is_empty() || line.find('#') == Some(0) {
continue; continue;
} else if let Err(e) = parse_recipient(filename, line, recipients, plugin_recipients) { } else if let Err(_e) = parse_recipient(
filename,
line,
recipients,
#[cfg(feature = "plugin")]
plugin_recipients,
) {
#[cfg(feature = "ssh")] #[cfg(feature = "ssh")]
match e { match _e {
ReadError::RsaModulusTooLarge ReadError::RsaModulusTooLarge
| ReadError::RsaModulusTooSmall | ReadError::RsaModulusTooSmall
| ReadError::UnsupportedKey(_, _) => { | ReadError::UnsupportedKey(_, _) => {
return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string()).into()); return Err(io::Error::new(io::ErrorKind::InvalidData, _e.to_string()).into());
} }
_ => (), _ => (),
} }
@ -119,11 +141,19 @@ pub fn read_recipients(
stdin_guard: &mut StdinGuard, stdin_guard: &mut StdinGuard,
) -> Result<Vec<Box<dyn Recipient + Send>>, ReadError> { ) -> Result<Vec<Box<dyn Recipient + Send>>, ReadError> {
let mut recipients: Vec<Box<dyn Recipient + Send>> = vec![]; let mut recipients: Vec<Box<dyn Recipient + Send>> = vec![];
#[cfg(feature = "plugin")]
let mut plugin_recipients: Vec<plugin::Recipient> = vec![]; let mut plugin_recipients: Vec<plugin::Recipient> = vec![];
#[cfg(feature = "plugin")]
let mut plugin_identities: Vec<plugin::Identity> = vec![]; let mut plugin_identities: Vec<plugin::Identity> = vec![];
for arg in recipient_strings { for arg in recipient_strings {
parse_recipient("", arg, &mut recipients, &mut plugin_recipients)?; parse_recipient(
"",
arg,
&mut recipients,
#[cfg(feature = "plugin")]
&mut plugin_recipients,
)?;
} }
for arg in recipients_file_strings { for arg in recipients_file_strings {
@ -134,15 +164,29 @@ pub fn read_recipients(
_ => e, _ => e,
})?; })?;
let buf = BufReader::new(f); let buf = BufReader::new(f);
read_recipients_list(&arg, buf, &mut recipients, &mut plugin_recipients)?; read_recipients_list(
&arg,
buf,
&mut recipients,
#[cfg(feature = "plugin")]
&mut plugin_recipients,
)?;
} }
#[cfg(feature = "plugin")]
let ctx = &mut (&mut recipients, &mut plugin_identities);
#[cfg(not(feature = "plugin"))]
let ctx = &mut recipients;
parse_identity_files::<_, ReadError>( parse_identity_files::<_, ReadError>(
identity_strings, identity_strings,
max_work_factor, max_work_factor,
stdin_guard, stdin_guard,
&mut (&mut recipients, &mut plugin_identities), ctx,
|(recipients, _), identity| { #[cfg(feature = "armor")]
|recipients, identity| {
#[cfg(feature = "plugin")]
let (recipients, _) = recipients;
recipients.extend(identity.recipients().map_err(|e| { recipients.extend(identity.recipients().map_err(|e| {
// Only one error can occur here. // Only one error can occur here.
if let EncryptError::EncryptedIdentities(e) = e { if let EncryptError::EncryptedIdentities(e) = e {
@ -153,7 +197,10 @@ pub fn read_recipients(
})?); })?);
Ok(()) Ok(())
}, },
|(recipients, _), filename, identity| { #[cfg(feature = "ssh")]
|recipients, filename, identity| {
#[cfg(feature = "plugin")]
let (recipients, _) = recipients;
let recipient = parse_ssh_recipient( let recipient = parse_ssh_recipient(
|| ssh::Recipient::try_from(identity), || ssh::Recipient::try_from(identity),
|| Err(ReadError::InvalidRecipient(filename.to_owned())), || Err(ReadError::InvalidRecipient(filename.to_owned())),
@ -163,42 +210,48 @@ pub fn read_recipients(
recipients.push(recipient); recipients.push(recipient);
Ok(()) Ok(())
}, },
|(recipients, plugin_identities), entry| { |recipients, entry| {
#[cfg(feature = "plugin")]
let (recipients, plugin_identities) = recipients;
match entry { match entry {
IdentityFileEntry::Native(i) => recipients.push(Box::new(i.to_public())), IdentityFileEntry::Native(i) => recipients.push(Box::new(i.to_public())),
#[cfg(feature = "plugin")]
IdentityFileEntry::Plugin(i) => plugin_identities.push(i), IdentityFileEntry::Plugin(i) => plugin_identities.push(i),
} }
Ok(()) Ok(())
}, },
)?; )?;
// Collect the names of the required plugins. #[cfg(feature = "plugin")]
let mut plugin_names = plugin_recipients {
.iter() // Collect the names of the required plugins.
.map(|r| r.plugin()) let mut plugin_names = plugin_recipients
.chain(plugin_identities.iter().map(|i| i.plugin())) .iter()
.collect::<Vec<_>>(); .map(|r| r.plugin())
plugin_names.sort_unstable(); .chain(plugin_identities.iter().map(|i| i.plugin()))
plugin_names.dedup(); .collect::<Vec<_>>();
plugin_names.sort_unstable();
plugin_names.dedup();
// Find the required plugins. // Find the required plugins.
for plugin_name in plugin_names { for plugin_name in plugin_names {
recipients.push(Box::new( recipients.push(Box::new(
plugin::RecipientPluginV1::new( plugin::RecipientPluginV1::new(
plugin_name, plugin_name,
&plugin_recipients, &plugin_recipients,
&plugin_identities, &plugin_identities,
UiCallbacks, UiCallbacks,
) )
.map_err(|e| { .map_err(|e| {
// Only one error can occur here. // Only one error can occur here.
if let EncryptError::MissingPlugin { binary_name } = e { if let EncryptError::MissingPlugin { binary_name } = e {
ReadError::MissingPlugin { binary_name } ReadError::MissingPlugin { binary_name }
} else { } else {
unreachable!() unreachable!()
} }
})?, })?,
)) ))
}
} }
Ok(recipients) Ok(recipients)