diff --git a/impl/src/attr.rs b/impl/src/attr.rs index 00fb965..8d4b096 100644 --- a/impl/src/attr.rs +++ b/impl/src/attr.rs @@ -10,6 +10,7 @@ use syn::{ pub struct Attrs<'a> { pub display: Option>, pub source: Option>, + pub backtrace: Option>, } #[derive(Clone)] @@ -24,10 +25,15 @@ pub struct Source<'a> { pub original: &'a Attribute, } +pub struct Backtrace<'a> { + pub original: &'a Attribute, +} + pub fn get(input: &[Attribute]) -> Result { let mut attrs = Attrs { display: None, source: None, + backtrace: None, }; for attr in input { @@ -46,6 +52,12 @@ pub fn get(input: &[Attribute]) -> Result { return Err(Error::new_spanned(attr, "duplicate #[source] attribute")); } attrs.source = Some(source); + } else if attr.path.is_ident("backtrace") { + let backtrace = parse_backtrace(attr)?; + if attrs.backtrace.is_some() { + return Err(Error::new_spanned(attr, "duplicate #[backtrace] attribute")); + } + attrs.backtrace = Some(backtrace); } } @@ -118,6 +130,11 @@ fn parse_source(attr: &Attribute) -> Result { Ok(Source { original: attr }) } +fn parse_backtrace(attr: &Attribute) -> Result { + syn::parse2::(attr.tokens.clone())?; + Ok(Backtrace { original: attr }) +} + impl ToTokens for Display<'_> { fn to_tokens(&self, tokens: &mut TokenStream) { let fmt = &self.fmt; diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 91fabfb..25c50c8 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -28,10 +28,22 @@ fn impl_struct(input: Struct) -> TokenStream { } }); - let backtrace_method = input.backtrace_member().map(|backtrace| { + let backtrace_method = input.backtrace_field().map(|backtrace| { + let backtrace = &backtrace.member; + let body = if let Some(source) = input.source_member() { + let dyn_error = quote_spanned!(source.span()=> self.#source.as_dyn_error()); + quote!({ + use thiserror::private::AsDynError; + #dyn_error.backtrace().unwrap_or(&self.#backtrace) + }) + } else { + quote! { + &self.#backtrace + } + }; quote! { fn backtrace(&self) -> std::option::Option<&std::backtrace::Backtrace> { - std::option::Option::Some(&self.#backtrace) + std::option::Option::Some(#body) } } }); @@ -67,10 +79,9 @@ fn impl_enum(input: Enum) -> TokenStream { let ident = &variant.ident; match variant.source_member() { Some(source) => { - let var = quote_spanned!(source.span()=> source); - let dyn_error = quote_spanned!(source.span()=> #var.as_dyn_error()); + let dyn_error = quote_spanned!(source.span()=> source.as_dyn_error()); quote! { - #ty::#ident {#source: #var, ..} => std::option::Option::Some(#dyn_error), + #ty::#ident {#source: source, ..} => std::option::Option::Some(#dyn_error), } } None => quote! { @@ -93,11 +104,28 @@ fn impl_enum(input: Enum) -> TokenStream { let backtrace_method = if input.has_backtrace() { let arms = input.variants.iter().map(|variant| { let ident = &variant.ident; - match variant.backtrace_member() { - Some(backtrace) => quote! { - #ty::#ident {#backtrace: backtrace, ..} => std::option::Option::Some(backtrace), - }, - None => quote! { + match (variant.backtrace_field(), variant.source_member()) { + (Some(backtrace), Some(source)) if backtrace.attrs.backtrace.is_none() => { + let backtrace = &backtrace.member; + let dyn_error = quote_spanned!(source.span()=> source.as_dyn_error()); + quote! { + #ty::#ident { + #backtrace: backtrace, + #source: source, + .. + } => std::option::Option::Some({ + use thiserror::private::AsDynError; + #dyn_error.backtrace().unwrap_or(backtrace) + }), + } + } + (Some(backtrace), _) => { + let backtrace = &backtrace.member; + quote! { + #ty::#ident {#backtrace: backtrace, ..} => std::option::Option::Some(backtrace), + } + } + (None, _) => quote! { #ty::#ident {..} => std::option::Option::None, }, } diff --git a/impl/src/lib.rs b/impl/src/lib.rs index 00bfe7f..2ea966b 100644 --- a/impl/src/lib.rs +++ b/impl/src/lib.rs @@ -10,7 +10,7 @@ mod valid; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput}; -#[proc_macro_derive(Error, attributes(error, source))] +#[proc_macro_derive(Error, attributes(backtrace, error, source))] pub fn derive_error(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand::derive(&input) diff --git a/impl/src/prop.rs b/impl/src/prop.rs index cd4c12d..240de8c 100644 --- a/impl/src/prop.rs +++ b/impl/src/prop.rs @@ -6,8 +6,8 @@ impl Struct<'_> { source_member(&self.fields) } - pub(crate) fn backtrace_member(&self) -> Option<&Member> { - backtrace_member(&self.fields) + pub(crate) fn backtrace_field(&self) -> Option<&Field> { + backtrace_field(&self.fields) } } @@ -21,7 +21,7 @@ impl Enum<'_> { pub(crate) fn has_backtrace(&self) -> bool { self.variants .iter() - .any(|variant| variant.backtrace_member().is_some()) + .any(|variant| variant.backtrace_field().is_some()) } pub(crate) fn has_display(&self) -> bool { @@ -38,14 +38,8 @@ impl Variant<'_> { source_member(&self.fields) } - pub(crate) fn backtrace_member(&self) -> Option<&Member> { - backtrace_member(&self.fields) - } -} - -impl Field<'_> { - fn is_backtrace(&self) -> bool { - type_is_backtrace(self.ty) + pub(crate) fn backtrace_field(&self) -> Option<&Field> { + backtrace_field(&self.fields) } } @@ -64,11 +58,18 @@ fn source_member<'a>(fields: &'a [Field]) -> Option<&'a Member> { None } -fn backtrace_member<'a>(fields: &'a [Field]) -> Option<&'a Member> { - fields - .iter() - .find(|field| field.is_backtrace()) - .map(|field| &field.member) +fn backtrace_field<'a, 'b>(fields: &'a [Field<'b>]) -> Option<&'a Field<'b>> { + for field in fields { + if field.attrs.backtrace.is_some() { + return Some(&field); + } + } + for field in fields { + if type_is_backtrace(field.ty) { + return Some(&field); + } + } + None } fn type_is_backtrace(ty: &Type) -> bool { diff --git a/impl/src/valid.rs b/impl/src/valid.rs index b4be8f2..c4d2fcc 100644 --- a/impl/src/valid.rs +++ b/impl/src/valid.rs @@ -15,8 +15,8 @@ impl Input<'_> { impl Struct<'_> { fn validate(&self) -> Result<()> { - check_no_source(&self.attrs)?; - find_duplicate_source(&self.fields)?; + check_no_source_or_backtrace(&self.attrs)?; + check_no_duplicate_source_or_backtrace(&self.fields)?; for field in &self.fields { field.validate()?; } @@ -26,7 +26,7 @@ impl Struct<'_> { impl Enum<'_> { fn validate(&self) -> Result<()> { - check_no_source(&self.attrs)?; + check_no_source_or_backtrace(&self.attrs)?; let has_display = self.has_display(); for variant in &self.variants { variant.validate()?; @@ -43,8 +43,8 @@ impl Enum<'_> { impl Variant<'_> { fn validate(&self) -> Result<()> { - check_no_source(&self.attrs)?; - find_duplicate_source(&self.fields)?; + check_no_source_or_backtrace(&self.attrs)?; + check_no_duplicate_source_or_backtrace(&self.fields)?; for field in &self.fields { field.validate()?; } @@ -64,18 +64,25 @@ impl Field<'_> { } } -fn check_no_source(attrs: &Attrs) -> Result<()> { +fn check_no_source_or_backtrace(attrs: &Attrs) -> Result<()> { if let Some(source) = &attrs.source { return Err(Error::new_spanned( source.original, "not expected here; the #[source] attribute belongs on a specific field", )); } + if let Some(backtrace) = &attrs.backtrace { + return Err(Error::new_spanned( + backtrace.original, + "not expected here; the #[backtrace] attribute belongs on a specific field", + )); + } Ok(()) } -fn find_duplicate_source(fields: &[Field]) -> Result<()> { +fn check_no_duplicate_source_or_backtrace(fields: &[Field]) -> Result<()> { let mut has_source = false; + let mut has_backtrace = false; for field in fields { if let Some(source) = &field.attrs.source { if has_source { @@ -86,6 +93,15 @@ fn find_duplicate_source(fields: &[Field]) -> Result<()> { } has_source = true; } + if let Some(backtrace) = &field.attrs.backtrace { + if has_backtrace { + return Err(Error::new_spanned( + backtrace.original, + "duplicate #[backtrace] attribute", + )); + } + has_backtrace = true; + } } Ok(()) }