diff --git a/impl/src/ast.rs b/impl/src/ast.rs index 5cbae38..77f9583 100644 --- a/impl/src/ast.rs +++ b/impl/src/ast.rs @@ -91,9 +91,13 @@ impl<'a> Enum<'a> { .iter() .map(|node| { let mut variant = Variant::from_syn(node, &scope)?; - if variant.attrs.display.is_none() && variant.attrs.transparent.is_none() { + if variant.attrs.display.is_none() + && variant.attrs.transparent.is_none() + && variant.attrs.fmt.is_none() + { variant.attrs.display.clone_from(&attrs.display); variant.attrs.transparent = attrs.transparent; + variant.attrs.fmt.clone_from(&attrs.fmt); } if let Some(display) = &mut variant.attrs.display { let container = ContainerKind::from_variant(node); diff --git a/impl/src/attr.rs b/impl/src/attr.rs index 2fc04af..f98f731 100644 --- a/impl/src/attr.rs +++ b/impl/src/attr.rs @@ -4,8 +4,8 @@ use std::collections::BTreeSet as Set; use syn::parse::discouraged::Speculative; use syn::parse::{End, ParseStream}; use syn::{ - braced, bracketed, parenthesized, token, Attribute, Error, Ident, Index, LitFloat, LitInt, - LitStr, Meta, Result, Token, + braced, bracketed, parenthesized, token, Attribute, Error, ExprPath, Ident, Index, LitFloat, + LitInt, LitStr, Meta, Result, Token, }; pub struct Attrs<'a> { @@ -14,6 +14,7 @@ pub struct Attrs<'a> { pub backtrace: Option<&'a Attribute>, pub from: Option>, pub transparent: Option>, + pub fmt: Option>, } #[derive(Clone)] @@ -45,6 +46,12 @@ pub struct Transparent<'a> { pub span: Span, } +#[derive(Clone)] +pub struct Fmt<'a> { + pub original: &'a Attribute, + pub path: ExprPath, +} + #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)] pub enum Trait { Debug, @@ -65,6 +72,7 @@ pub fn get(input: &[Attribute]) -> Result { backtrace: None, from: None, transparent: None, + fmt: None, }; for attr in input { @@ -113,14 +121,17 @@ pub fn get(input: &[Attribute]) -> Result { } fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> { - syn::custom_keyword!(transparent); + mod kw { + syn::custom_keyword!(transparent); + syn::custom_keyword!(fmt); + } attr.parse_args_with(|input: ParseStream| { let lookahead = input.lookahead1(); let fmt = if lookahead.peek(LitStr) { input.parse::()? - } else if lookahead.peek(transparent) { - let kw: transparent = input.parse()?; + } else if lookahead.peek(kw::transparent) { + let kw: kw::transparent = input.parse()?; if attrs.transparent.is_some() { return Err(Error::new_spanned( attr, @@ -132,6 +143,21 @@ fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Resu span: kw.span, }); return Ok(()); + } else if lookahead.peek(kw::fmt) { + input.parse::()?; + input.parse::()?; + let path: ExprPath = input.parse()?; + if attrs.fmt.is_some() { + return Err(Error::new_spanned( + attr, + "duplicate #[error(fmt = ...)] attribute", + )); + } + attrs.fmt = Some(Fmt { + original: attr, + path, + }); + return Ok(()); } else { return Err(lookahead.error()); }; diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 64e7891..f6f45f9 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -404,19 +404,23 @@ fn impl_enum(input: Enum) -> TokenStream { }; let arms = input.variants.iter().map(|variant| { let mut display_implied_bounds = Set::new(); - let display = match &variant.attrs.display { - Some(display) => { - display_implied_bounds.clone_from(&display.implied_bounds); - display.to_token_stream() - } - None => { - let only_field = match &variant.fields[0].member { - MemberUnraw::Named(ident) => ident.to_local(), - MemberUnraw::Unnamed(index) => format_ident!("_{}", index), - }; - display_implied_bounds.insert((0, Trait::Display)); - quote!(::core::fmt::Display::fmt(#only_field, __formatter)) - } + let display = if let Some(display) = &variant.attrs.display { + display_implied_bounds.clone_from(&display.implied_bounds); + display.to_token_stream() + } else if let Some(fmt) = &variant.attrs.fmt { + let fmt_path = &fmt.path; + let vars = variant.fields.iter().map(|field| match &field.member { + MemberUnraw::Named(ident) => ident.to_local(), + MemberUnraw::Unnamed(index) => format_ident!("_{}", index), + }); + quote!(#fmt_path(#(#vars,)* __formatter)) + } else { + let only_field = match &variant.fields[0].member { + MemberUnraw::Named(ident) => ident.to_local(), + MemberUnraw::Unnamed(index) => format_ident!("_{}", index), + }; + display_implied_bounds.insert((0, Trait::Display)); + quote!(::core::fmt::Display::fmt(#only_field, __formatter)) }; for (field, bound) in display_implied_bounds { let field = &variant.fields[field]; @@ -494,7 +498,7 @@ fn fields_pat(fields: &[Field]) -> TokenStream { Some(MemberUnraw::Named(_)) => quote!({ #(#members),* }), Some(MemberUnraw::Unnamed(_)) => { let vars = members.map(|member| match member { - MemberUnraw::Unnamed(member) => format_ident!("_{}", member), + MemberUnraw::Unnamed(index) => format_ident!("_{}", index), MemberUnraw::Named(_) => unreachable!(), }); quote!((#(#vars),*)) diff --git a/impl/src/prop.rs b/impl/src/prop.rs index 56ab0c5..0a101fc 100644 --- a/impl/src/prop.rs +++ b/impl/src/prop.rs @@ -38,10 +38,11 @@ impl Enum<'_> { pub(crate) fn has_display(&self) -> bool { self.attrs.display.is_some() || self.attrs.transparent.is_some() + || self.attrs.fmt.is_some() || self .variants .iter() - .any(|variant| variant.attrs.display.is_some()) + .any(|variant| variant.attrs.display.is_some() || variant.attrs.fmt.is_some()) || self .variants .iter() diff --git a/impl/src/valid.rs b/impl/src/valid.rs index 5755b78..21bd885 100644 --- a/impl/src/valid.rs +++ b/impl/src/valid.rs @@ -28,6 +28,12 @@ impl Struct<'_> { )); } } + if let Some(fmt) = &self.attrs.fmt { + return Err(Error::new_spanned( + fmt.original, + "#[error(fmt = ...)] is only supported in enums; for a struct, handwrite your own Display impl", + )); + } check_field_attrs(&self.fields)?; for field in &self.fields { field.validate()?; @@ -42,7 +48,10 @@ impl Enum<'_> { let has_display = self.has_display(); for variant in &self.variants { variant.validate()?; - if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none() + if has_display + && variant.attrs.display.is_none() + && variant.attrs.transparent.is_none() + && variant.attrs.fmt.is_none() { return Err(Error::new_spanned( variant.original, @@ -81,9 +90,15 @@ impl Variant<'_> { impl Field<'_> { fn validate(&self) -> Result<()> { - if let Some(display) = &self.attrs.display { + if let Some(unexpected_display_attr) = if let Some(display) = &self.attrs.display { + Some(display.original) + } else if let Some(fmt) = &self.attrs.fmt { + Some(fmt.original) + } else { + None + } { return Err(Error::new_spanned( - display.original, + unexpected_display_attr, "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant", )); } @@ -110,14 +125,26 @@ fn check_non_field_attrs(attrs: &Attrs) -> Result<()> { "not expected here; the #[backtrace] attribute belongs on a specific field", )); } - if let Some(display) = &attrs.display { - if attrs.transparent.is_some() { + if attrs.transparent.is_some() { + if let Some(display) = &attrs.display { return Err(Error::new_spanned( display.original, "cannot have both #[error(transparent)] and a display attribute", )); } + if let Some(fmt) = &attrs.fmt { + return Err(Error::new_spanned( + fmt.original, + "cannot have both #[error(transparent)] and #[error(fmt = ...)]", + )); + } + } else if let (Some(display), Some(_)) = (&attrs.display, &attrs.fmt) { + return Err(Error::new_spanned( + display.original, + "cannot have both #[error(fmt = ...)] and a format arguments attribute", + )); } + Ok(()) } diff --git a/tests/test_display.rs b/tests/test_display.rs index 91fe9e0..ec4170d 100644 --- a/tests/test_display.rs +++ b/tests/test_display.rs @@ -370,3 +370,69 @@ fn test_raw_str() { assert(r#"raw brace right }"#, Error::BraceRight); assert(r#"raw brace right 2 \x7D"#, Error::BraceRight2); } + +mod util { + use core::fmt::{self, Octal}; + + pub fn octal(value: &T, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "0o{:o}", value) + } +} + +#[test] +fn test_fmt_path() { + fn unit(formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("unit=") + } + + fn pair(k: &i32, v: &i32, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "pair={k}:{v}") + } + + #[derive(Error, Debug)] + pub enum Error { + #[error(fmt = unit)] + Unit, + #[error(fmt = pair)] + Tuple(i32, i32), + #[error(fmt = pair)] + Entry { k: i32, v: i32 }, + #[error(fmt = crate::util::octal)] + I16(i16), + #[error(fmt = crate::util::octal::)] + I32 { n: i32 }, + #[error(fmt = core::fmt::Octal::fmt)] + I64(i64), + #[error("...{0}")] + Other(bool), + } + + assert("unit=", Error::Unit); + assert("pair=10:0", Error::Tuple(10, 0)); + assert("pair=10:0", Error::Entry { k: 10, v: 0 }); + assert("0o777", Error::I16(0o777)); + assert("0o777", Error::I32 { n: 0o777 }); + assert("777", Error::I64(0o777)); + assert("...false", Error::Other(false)); +} + +#[test] +fn test_fmt_path_inherited() { + #[derive(Error, Debug)] + #[error(fmt = crate::util::octal)] + pub enum Error { + I16(i16), + I32 { + n: i32, + }, + #[error(fmt = core::fmt::Octal::fmt)] + I64(i64), + #[error("...{0}")] + Other(bool), + } + + assert("0o777", Error::I16(0o777)); + assert("0o777", Error::I32 { n: 0o777 }); + assert("777", Error::I64(0o777)); + assert("...false", Error::Other(false)); +} diff --git a/tests/ui/concat-display.stderr b/tests/ui/concat-display.stderr index d92e635..9255488 100644 --- a/tests/ui/concat-display.stderr +++ b/tests/ui/concat-display.stderr @@ -1,4 +1,4 @@ -error: expected string literal or `transparent` +error: expected one of: string literal, `transparent`, `fmt` --> tests/ui/concat-display.rs:8:17 | 8 | #[error(concat!("invalid ", $what))] diff --git a/tests/ui/duplicate-fmt.rs b/tests/ui/duplicate-fmt.rs index cb3d678..32f7a23 100644 --- a/tests/ui/duplicate-fmt.rs +++ b/tests/ui/duplicate-fmt.rs @@ -5,4 +5,19 @@ use thiserror::Error; #[error("...")] pub struct Error; +#[derive(Error, Debug)] +#[error(fmt = core::fmt::Octal::fmt)] +#[error(fmt = core::fmt::LowerHex::fmt)] +pub enum FmtFmt {} + +#[derive(Error, Debug)] +#[error(fmt = core::fmt::Octal::fmt)] +#[error(transparent)] +pub enum FmtTransparent {} + +#[derive(Error, Debug)] +#[error(fmt = core::fmt::Octal::fmt)] +#[error("...")] +pub enum FmtDisplay {} + fn main() {} diff --git a/tests/ui/duplicate-fmt.stderr b/tests/ui/duplicate-fmt.stderr index 532b16b..a6c9932 100644 --- a/tests/ui/duplicate-fmt.stderr +++ b/tests/ui/duplicate-fmt.stderr @@ -3,3 +3,21 @@ error: only one #[error(...)] attribute is allowed | 5 | #[error("...")] | ^^^^^^^^^^^^^^^ + +error: duplicate #[error(fmt = ...)] attribute + --> tests/ui/duplicate-fmt.rs:10:1 + | +10 | #[error(fmt = core::fmt::LowerHex::fmt)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error: cannot have both #[error(transparent)] and #[error(fmt = ...)] + --> tests/ui/duplicate-fmt.rs:14:1 + | +14 | #[error(fmt = core::fmt::Octal::fmt)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error: cannot have both #[error(fmt = ...)] and a format arguments attribute + --> tests/ui/duplicate-fmt.rs:20:1 + | +20 | #[error("...")] + | ^^^^^^^^^^^^^^^ diff --git a/tests/ui/struct-with-fmt.rs b/tests/ui/struct-with-fmt.rs new file mode 100644 index 0000000..73bf79f --- /dev/null +++ b/tests/ui/struct-with-fmt.rs @@ -0,0 +1,7 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +#[error(fmt = core::fmt::Octal::fmt)] +pub struct Error(i32); + +fn main() {} diff --git a/tests/ui/struct-with-fmt.stderr b/tests/ui/struct-with-fmt.stderr new file mode 100644 index 0000000..00463be --- /dev/null +++ b/tests/ui/struct-with-fmt.stderr @@ -0,0 +1,5 @@ +error: #[error(fmt = ...)] is only supported in enums; for a struct, handwrite your own Display impl + --> tests/ui/struct-with-fmt.rs:4:1 + | +4 | #[error(fmt = core::fmt::Octal::fmt)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^