Implement #[error(fmt = ...)]

This commit is contained in:
David Tolnay 2024-11-04 23:20:03 -05:00
parent 5e4b7a5117
commit ba9af4522e
No known key found for this signature in database
GPG key ID: F9BA143B95FF6D82
11 changed files with 200 additions and 27 deletions

View file

@ -91,9 +91,13 @@ impl<'a> Enum<'a> {
.iter()
.map(|node| {
let mut variant = Variant::from_syn(node, &scope)?;
if variant.attrs.display.is_none() && variant.attrs.transparent.is_none() {
if variant.attrs.display.is_none()
&& variant.attrs.transparent.is_none()
&& variant.attrs.fmt.is_none()
{
variant.attrs.display.clone_from(&attrs.display);
variant.attrs.transparent = attrs.transparent;
variant.attrs.fmt.clone_from(&attrs.fmt);
}
if let Some(display) = &mut variant.attrs.display {
let container = ContainerKind::from_variant(node);

View file

@ -4,8 +4,8 @@ use std::collections::BTreeSet as Set;
use syn::parse::discouraged::Speculative;
use syn::parse::{End, ParseStream};
use syn::{
braced, bracketed, parenthesized, token, Attribute, Error, Ident, Index, LitFloat, LitInt,
LitStr, Meta, Result, Token,
braced, bracketed, parenthesized, token, Attribute, Error, ExprPath, Ident, Index, LitFloat,
LitInt, LitStr, Meta, Result, Token,
};
pub struct Attrs<'a> {
@ -14,6 +14,7 @@ pub struct Attrs<'a> {
pub backtrace: Option<&'a Attribute>,
pub from: Option<From<'a>>,
pub transparent: Option<Transparent<'a>>,
pub fmt: Option<Fmt<'a>>,
}
#[derive(Clone)]
@ -45,6 +46,12 @@ pub struct Transparent<'a> {
pub span: Span,
}
#[derive(Clone)]
pub struct Fmt<'a> {
pub original: &'a Attribute,
pub path: ExprPath,
}
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
pub enum Trait {
Debug,
@ -65,6 +72,7 @@ pub fn get(input: &[Attribute]) -> Result<Attrs> {
backtrace: None,
from: None,
transparent: None,
fmt: None,
};
for attr in input {
@ -113,14 +121,17 @@ pub fn get(input: &[Attribute]) -> Result<Attrs> {
}
fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
syn::custom_keyword!(transparent);
mod kw {
syn::custom_keyword!(transparent);
syn::custom_keyword!(fmt);
}
attr.parse_args_with(|input: ParseStream| {
let lookahead = input.lookahead1();
let fmt = if lookahead.peek(LitStr) {
input.parse::<LitStr>()?
} else if lookahead.peek(transparent) {
let kw: transparent = input.parse()?;
} else if lookahead.peek(kw::transparent) {
let kw: kw::transparent = input.parse()?;
if attrs.transparent.is_some() {
return Err(Error::new_spanned(
attr,
@ -132,6 +143,21 @@ fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Resu
span: kw.span,
});
return Ok(());
} else if lookahead.peek(kw::fmt) {
input.parse::<kw::fmt>()?;
input.parse::<Token![=]>()?;
let path: ExprPath = input.parse()?;
if attrs.fmt.is_some() {
return Err(Error::new_spanned(
attr,
"duplicate #[error(fmt = ...)] attribute",
));
}
attrs.fmt = Some(Fmt {
original: attr,
path,
});
return Ok(());
} else {
return Err(lookahead.error());
};

View file

@ -404,19 +404,23 @@ fn impl_enum(input: Enum) -> TokenStream {
};
let arms = input.variants.iter().map(|variant| {
let mut display_implied_bounds = Set::new();
let display = match &variant.attrs.display {
Some(display) => {
display_implied_bounds.clone_from(&display.implied_bounds);
display.to_token_stream()
}
None => {
let only_field = match &variant.fields[0].member {
MemberUnraw::Named(ident) => ident.to_local(),
MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
};
display_implied_bounds.insert((0, Trait::Display));
quote!(::core::fmt::Display::fmt(#only_field, __formatter))
}
let display = if let Some(display) = &variant.attrs.display {
display_implied_bounds.clone_from(&display.implied_bounds);
display.to_token_stream()
} else if let Some(fmt) = &variant.attrs.fmt {
let fmt_path = &fmt.path;
let vars = variant.fields.iter().map(|field| match &field.member {
MemberUnraw::Named(ident) => ident.to_local(),
MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
});
quote!(#fmt_path(#(#vars,)* __formatter))
} else {
let only_field = match &variant.fields[0].member {
MemberUnraw::Named(ident) => ident.to_local(),
MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
};
display_implied_bounds.insert((0, Trait::Display));
quote!(::core::fmt::Display::fmt(#only_field, __formatter))
};
for (field, bound) in display_implied_bounds {
let field = &variant.fields[field];
@ -494,7 +498,7 @@ fn fields_pat(fields: &[Field]) -> TokenStream {
Some(MemberUnraw::Named(_)) => quote!({ #(#members),* }),
Some(MemberUnraw::Unnamed(_)) => {
let vars = members.map(|member| match member {
MemberUnraw::Unnamed(member) => format_ident!("_{}", member),
MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
MemberUnraw::Named(_) => unreachable!(),
});
quote!((#(#vars),*))

View file

@ -38,10 +38,11 @@ impl Enum<'_> {
pub(crate) fn has_display(&self) -> bool {
self.attrs.display.is_some()
|| self.attrs.transparent.is_some()
|| self.attrs.fmt.is_some()
|| self
.variants
.iter()
.any(|variant| variant.attrs.display.is_some())
.any(|variant| variant.attrs.display.is_some() || variant.attrs.fmt.is_some())
|| self
.variants
.iter()

View file

@ -28,6 +28,12 @@ impl Struct<'_> {
));
}
}
if let Some(fmt) = &self.attrs.fmt {
return Err(Error::new_spanned(
fmt.original,
"#[error(fmt = ...)] is only supported in enums; for a struct, handwrite your own Display impl",
));
}
check_field_attrs(&self.fields)?;
for field in &self.fields {
field.validate()?;
@ -42,7 +48,10 @@ impl Enum<'_> {
let has_display = self.has_display();
for variant in &self.variants {
variant.validate()?;
if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none()
if has_display
&& variant.attrs.display.is_none()
&& variant.attrs.transparent.is_none()
&& variant.attrs.fmt.is_none()
{
return Err(Error::new_spanned(
variant.original,
@ -81,9 +90,15 @@ impl Variant<'_> {
impl Field<'_> {
fn validate(&self) -> Result<()> {
if let Some(display) = &self.attrs.display {
if let Some(unexpected_display_attr) = if let Some(display) = &self.attrs.display {
Some(display.original)
} else if let Some(fmt) = &self.attrs.fmt {
Some(fmt.original)
} else {
None
} {
return Err(Error::new_spanned(
display.original,
unexpected_display_attr,
"not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
));
}
@ -110,14 +125,26 @@ fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
"not expected here; the #[backtrace] attribute belongs on a specific field",
));
}
if let Some(display) = &attrs.display {
if attrs.transparent.is_some() {
if attrs.transparent.is_some() {
if let Some(display) = &attrs.display {
return Err(Error::new_spanned(
display.original,
"cannot have both #[error(transparent)] and a display attribute",
));
}
if let Some(fmt) = &attrs.fmt {
return Err(Error::new_spanned(
fmt.original,
"cannot have both #[error(transparent)] and #[error(fmt = ...)]",
));
}
} else if let (Some(display), Some(_)) = (&attrs.display, &attrs.fmt) {
return Err(Error::new_spanned(
display.original,
"cannot have both #[error(fmt = ...)] and a format arguments attribute",
));
}
Ok(())
}

View file

@ -370,3 +370,69 @@ fn test_raw_str() {
assert(r#"raw brace right }"#, Error::BraceRight);
assert(r#"raw brace right 2 \x7D"#, Error::BraceRight2);
}
mod util {
use core::fmt::{self, Octal};
pub fn octal<T: Octal>(value: &T, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "0o{:o}", value)
}
}
#[test]
fn test_fmt_path() {
fn unit(formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("unit=")
}
fn pair(k: &i32, v: &i32, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "pair={k}:{v}")
}
#[derive(Error, Debug)]
pub enum Error {
#[error(fmt = unit)]
Unit,
#[error(fmt = pair)]
Tuple(i32, i32),
#[error(fmt = pair)]
Entry { k: i32, v: i32 },
#[error(fmt = crate::util::octal)]
I16(i16),
#[error(fmt = crate::util::octal::<i32>)]
I32 { n: i32 },
#[error(fmt = core::fmt::Octal::fmt)]
I64(i64),
#[error("...{0}")]
Other(bool),
}
assert("unit=", Error::Unit);
assert("pair=10:0", Error::Tuple(10, 0));
assert("pair=10:0", Error::Entry { k: 10, v: 0 });
assert("0o777", Error::I16(0o777));
assert("0o777", Error::I32 { n: 0o777 });
assert("777", Error::I64(0o777));
assert("...false", Error::Other(false));
}
#[test]
fn test_fmt_path_inherited() {
#[derive(Error, Debug)]
#[error(fmt = crate::util::octal)]
pub enum Error {
I16(i16),
I32 {
n: i32,
},
#[error(fmt = core::fmt::Octal::fmt)]
I64(i64),
#[error("...{0}")]
Other(bool),
}
assert("0o777", Error::I16(0o777));
assert("0o777", Error::I32 { n: 0o777 });
assert("777", Error::I64(0o777));
assert("...false", Error::Other(false));
}

View file

@ -1,4 +1,4 @@
error: expected string literal or `transparent`
error: expected one of: string literal, `transparent`, `fmt`
--> tests/ui/concat-display.rs:8:17
|
8 | #[error(concat!("invalid ", $what))]

View file

@ -5,4 +5,19 @@ use thiserror::Error;
#[error("...")]
pub struct Error;
#[derive(Error, Debug)]
#[error(fmt = core::fmt::Octal::fmt)]
#[error(fmt = core::fmt::LowerHex::fmt)]
pub enum FmtFmt {}
#[derive(Error, Debug)]
#[error(fmt = core::fmt::Octal::fmt)]
#[error(transparent)]
pub enum FmtTransparent {}
#[derive(Error, Debug)]
#[error(fmt = core::fmt::Octal::fmt)]
#[error("...")]
pub enum FmtDisplay {}
fn main() {}

View file

@ -3,3 +3,21 @@ error: only one #[error(...)] attribute is allowed
|
5 | #[error("...")]
| ^^^^^^^^^^^^^^^
error: duplicate #[error(fmt = ...)] attribute
--> tests/ui/duplicate-fmt.rs:10:1
|
10 | #[error(fmt = core::fmt::LowerHex::fmt)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
error: cannot have both #[error(transparent)] and #[error(fmt = ...)]
--> tests/ui/duplicate-fmt.rs:14:1
|
14 | #[error(fmt = core::fmt::Octal::fmt)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
error: cannot have both #[error(fmt = ...)] and a format arguments attribute
--> tests/ui/duplicate-fmt.rs:20:1
|
20 | #[error("...")]
| ^^^^^^^^^^^^^^^

View file

@ -0,0 +1,7 @@
use thiserror::Error;
#[derive(Error, Debug)]
#[error(fmt = core::fmt::Octal::fmt)]
pub struct Error(i32);
fn main() {}

View file

@ -0,0 +1,5 @@
error: #[error(fmt = ...)] is only supported in enums; for a struct, handwrite your own Display impl
--> tests/ui/struct-with-fmt.rs:4:1
|
4 | #[error(fmt = core::fmt::Octal::fmt)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^