diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 3b821e7..51b9c9e 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -1,7 +1,6 @@ use crate::ast::{Enum, Field, Input, Struct}; -use crate::valid; use proc_macro2::TokenStream; -use quote::{format_ident, quote, quote_spanned}; +use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::spanned::Spanned; use syn::{DeriveInput, Member, PathArguments, Result, Type}; @@ -18,7 +17,12 @@ fn impl_struct(input: Struct) -> TokenStream { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let source_method = input.source_field().map(|source_field| { + let source_body = if input.attrs.transparent.is_some() { + let only_field = &input.fields[0].member; + Some(quote! { + std::error::Error::source(self.#only_field.as_dyn_error()) + }) + } else if let Some(source_field) = input.source_field() { let source = &source_field.member; let asref = if type_is_option(source_field.ty) { Some(quote_spanned!(source.span()=> .as_ref()?)) @@ -26,10 +30,17 @@ fn impl_struct(input: Struct) -> TokenStream { None }; let dyn_error = quote_spanned!(source.span()=> self.#source #asref.as_dyn_error()); + Some(quote! { + std::option::Option::Some(#dyn_error) + }) + } else { + None + }; + let source_method = source_body.map(|body| { quote! { fn source(&self) -> std::option::Option<&(dyn std::error::Error + 'static)> { use thiserror::private::AsDynError; - std::option::Option::Some(#dyn_error) + #body } } }); @@ -76,7 +87,12 @@ fn impl_struct(input: Struct) -> TokenStream { } }); - let display_impl = input.attrs.display.as_ref().map(|display| { + let display_body = if input.attrs.transparent.is_some() { + let only_field = &input.fields[0].member; + Some(quote! { + std::fmt::Display::fmt(&self.#only_field, __formatter) + }) + } else if let Some(display) = &input.attrs.display { let use_as_display = if display.has_bonus_display { Some(quote! { #[allow(unused_imports)] @@ -86,13 +102,20 @@ fn impl_struct(input: Struct) -> TokenStream { None }; let pat = fields_pat(&input.fields); + Some(quote! { + #use_as_display + #[allow(unused_variables)] + let Self #pat = self; + #display + }) + } else { + None + }; + let display_impl = display_body.map(|body| { quote! { impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - #use_as_display - #[allow(unused_variables)] - let Self #pat = self; - #display + #body } } } @@ -128,22 +151,27 @@ fn impl_enum(input: Enum) -> TokenStream { let source_method = if input.has_source() { let arms = input.variants.iter().map(|variant| { let ident = &variant.ident; - match variant.source_field() { - Some(source_field) => { - let source = &source_field.member; - let asref = if type_is_option(source_field.ty) { - Some(quote_spanned!(source.span()=> .as_ref()?)) - } else { - None - }; - let dyn_error = quote_spanned!(source.span()=> source #asref.as_dyn_error()); - quote! { - #ty::#ident {#source: source, ..} => std::option::Option::Some(#dyn_error), - } + if variant.attrs.transparent.is_some() { + let only_field = &variant.fields[0].member; + let source = quote!(std::error::Error::source(transparent.as_dyn_error())); + quote! { + #ty::#ident {#only_field: transparent} => #source, } - None => quote! { + } else if let Some(source_field) = variant.source_field() { + let source = &source_field.member; + let asref = if type_is_option(source_field.ty) { + Some(quote_spanned!(source.span()=> .as_ref()?)) + } else { + None + }; + let dyn_error = quote_spanned!(source.span()=> source #asref.as_dyn_error()); + quote! { + #ty::#ident {#source: source, ..} => std::option::Option::Some(#dyn_error), + } + } else { + quote! { #ty::#ident {..} => std::option::Option::None, - }, + } } }); Some(quote! { @@ -228,8 +256,7 @@ fn impl_enum(input: Enum) -> TokenStream { v.attrs .display .as_ref() - .expect(valid::CHECKED) - .has_bonus_display + .map_or(false, |display| display.has_bonus_display) }) { Some(quote! { #[allow(unused_imports)] @@ -244,7 +271,16 @@ fn impl_enum(input: Enum) -> TokenStream { None }; let arms = input.variants.iter().map(|variant| { - let display = variant.attrs.display.as_ref().expect(valid::CHECKED); + let display = match &variant.attrs.display { + Some(display) => display.to_token_stream(), + None => { + let only_field = match &variant.fields[0].member { + Member::Named(ident) => ident.clone(), + Member::Unnamed(index) => format_ident!("_{}", index), + }; + quote!(std::fmt::Display::fmt(#only_field, __formatter)) + } + }; let ident = &variant.ident; let pat = fields_pat(&variant.fields); quote! { @@ -297,7 +333,7 @@ fn fields_pat(fields: &[Field]) -> TokenStream { Some(Member::Named(_)) => quote!({ #(#members),* }), Some(Member::Unnamed(_)) => { let vars = members.map(|member| match member { - Member::Unnamed(member) => format_ident!("_{}", member.index), + Member::Unnamed(member) => format_ident!("_{}", member), Member::Named(_) => unreachable!(), }); quote!((#(#vars),*)) diff --git a/impl/src/prop.rs b/impl/src/prop.rs index 940b4f8..e011848 100644 --- a/impl/src/prop.rs +++ b/impl/src/prop.rs @@ -19,7 +19,7 @@ impl Enum<'_> { pub(crate) fn has_source(&self) -> bool { self.variants .iter() - .any(|variant| variant.source_field().is_some()) + .any(|variant| variant.source_field().is_some() || variant.attrs.transparent.is_some()) } pub(crate) fn has_backtrace(&self) -> bool { @@ -30,10 +30,15 @@ impl Enum<'_> { pub(crate) fn has_display(&self) -> bool { self.attrs.display.is_some() + || self.attrs.transparent.is_some() || self .variants .iter() .any(|variant| variant.attrs.display.is_some()) + || self + .variants + .iter() + .all(|variant| variant.attrs.transparent.is_some()) } } diff --git a/impl/src/valid.rs b/impl/src/valid.rs index e97fb8a..ffe7488 100644 --- a/impl/src/valid.rs +++ b/impl/src/valid.rs @@ -4,8 +4,6 @@ use quote::ToTokens; use std::collections::BTreeSet as Set; use syn::{Error, Member, Result}; -pub(crate) const CHECKED: &str = "checked in validation"; - impl Input<'_> { pub(crate) fn validate(&self) -> Result<()> { match self {