Implement #[implicit] to automatically generate fields using the new ImplicitField trait

This commit is contained in:
Carl Sverre 2024-12-19 11:40:24 -08:00
parent 2bd29821f4
commit 5195f7ccfa
8 changed files with 223 additions and 12 deletions

View file

@ -12,6 +12,7 @@ pub struct Attrs<'a> {
pub display: Option<Display<'a>>,
pub source: Option<Source<'a>>,
pub backtrace: Option<&'a Attribute>,
pub implicit: Option<&'a Attribute>,
pub from: Option<From<'a>>,
pub transparent: Option<Transparent<'a>>,
pub fmt: Option<Fmt<'a>>,
@ -71,6 +72,7 @@ pub fn get(input: &[Attribute]) -> Result<Attrs> {
display: None,
source: None,
backtrace: None,
implicit: None,
from: None,
transparent: None,
fmt: None,
@ -97,6 +99,12 @@ pub fn get(input: &[Attribute]) -> Result<Attrs> {
return Err(Error::new_spanned(attr, "duplicate #[backtrace] attribute"));
}
attrs.backtrace = Some(attr);
} else if attr.path().is_ident("implicit") {
attr.meta.require_path_only()?;
if attrs.implicit.is_some() {
return Err(Error::new_spanned(attr, "duplicate #[implicit] attribute"));
}
attrs.implicit = Some(attr);
} else if attr.path().is_ident("from") {
match attr.meta {
Meta::Path(_) => {}

View file

@ -166,12 +166,14 @@ fn impl_struct(input: Struct) -> TokenStream {
let from_impl = input.from_field().map(|from_field| {
let span = from_field.attrs.from.unwrap().span;
let backtrace_field = input.distinct_backtrace_field();
let implicit_fields = input.implicit_fields();
let from = unoptional_type(from_field.ty);
let source_var = Ident::new("source", span);
let body = from_initializer(from_field, backtrace_field, &source_var);
let body = from_initializer(from_field, backtrace_field, implicit_fields, &source_var);
let from_impl = quote_spanned! {span=>
#[automatically_derived]
impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
#[track_caller]
fn from(#source_var: #from) -> Self {
#ty #body
}
@ -432,13 +434,15 @@ fn impl_enum(input: Enum) -> TokenStream {
let from_field = variant.from_field()?;
let span = from_field.attrs.from.unwrap().span;
let backtrace_field = variant.distinct_backtrace_field();
let implicit_fields = variant.implicit_fields();
let variant = &variant.ident;
let from = unoptional_type(from_field.ty);
let source_var = Ident::new("source", span);
let body = from_initializer(from_field, backtrace_field, &source_var);
let body = from_initializer(from_field, backtrace_field, implicit_fields, &source_var);
let from_impl = quote_spanned! {span=>
#[automatically_derived]
impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
#[track_caller]
fn from(#source_var: #from) -> Self {
#ty::#variant #body
}
@ -505,6 +509,7 @@ fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
fn from_initializer(
from_field: &Field,
backtrace_field: Option<&Field>,
implicit_fields: Vec<&Field>,
source_var: &Ident,
) -> TokenStream {
let from_member = &from_field.member;
@ -525,7 +530,17 @@ fn from_initializer(
}
}
});
let implicit = implicit_fields
.iter()
.map(|field| {
let member = &field.member;
quote! {
#member: ::thiserror::ImplicitField::generate_with_source(&#source_var),
}
})
.collect::<TokenStream>();
quote!({
#implicit
#from_member: #some_source,
#backtrace
})

View file

@ -32,7 +32,7 @@ mod valid;
use proc_macro::TokenStream;
use syn::{parse_macro_input, DeriveInput};
#[proc_macro_derive(Error, attributes(backtrace, error, from, source))]
#[proc_macro_derive(Error, attributes(backtrace, error, from, source, implicit))]
pub fn derive_error(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
expand::derive(&input).into()

View file

@ -20,6 +20,10 @@ impl Struct<'_> {
let backtrace_field = self.backtrace_field()?;
distinct_backtrace_field(backtrace_field, self.from_field())
}
pub(crate) fn implicit_fields(&self) -> Vec<&Field> {
implicit_fields(&self.fields)
}
}
impl Enum<'_> {
@ -67,6 +71,10 @@ impl Variant<'_> {
let backtrace_field = self.backtrace_field()?;
distinct_backtrace_field(backtrace_field, self.from_field())
}
pub(crate) fn implicit_fields(&self) -> Vec<&Field> {
implicit_fields(&self.fields)
}
}
impl Field<'_> {
@ -74,6 +82,10 @@ impl Field<'_> {
type_is_backtrace(self.ty)
}
pub(crate) fn is_implicit(&self) -> bool {
self.attrs.implicit.is_some()
}
pub(crate) fn source_span(&self) -> Span {
if let Some(source_attr) = &self.attrs.source {
source_attr.span
@ -146,3 +158,10 @@ fn type_is_backtrace(ty: &Type) -> bool {
let last = path.segments.last().unwrap();
last.ident == "Backtrace" && last.arguments.is_empty()
}
fn implicit_fields<'a, 'b>(fields: &'a [Field<'b>]) -> Vec<&'a Field<'b>> {
fields
.iter()
.filter(|field| field.attrs.implicit.is_some())
.collect()
}

View file

@ -125,6 +125,12 @@ fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
"not expected here; the #[backtrace] attribute belongs on a specific field",
));
}
if let Some(implicit) = &attrs.implicit {
return Err(Error::new_spanned(
implicit,
"not expected here; the #[implicit] attribute belongs on a specific field",
));
}
if attrs.transparent.is_some() {
if let Some(display) = &attrs.display {
return Err(Error::new_spanned(
@ -152,7 +158,7 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> {
let mut from_field = None;
let mut source_field = None;
let mut backtrace_field = None;
let mut has_backtrace = false;
let mut first_implicit_field = None;
for field in fields {
if let Some(from) = field.attrs.from {
if from_field.is_some() {
@ -180,7 +186,6 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> {
));
}
backtrace_field = Some(field);
has_backtrace = true;
}
if let Some(transparent) = field.attrs.transparent {
return Err(Error::new_spanned(
@ -188,7 +193,9 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> {
"#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
));
}
has_backtrace |= field.is_backtrace();
if field.attrs.implicit.is_some() && first_implicit_field.is_none() {
first_implicit_field = Some(field);
}
}
if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
if from_field.member != source_field.member {
@ -199,14 +206,17 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> {
}
}
if let Some(from_field) = from_field {
let max_expected_fields = match backtrace_field {
Some(backtrace_field) => 1 + (from_field.member != backtrace_field.member) as usize,
None => 1 + has_backtrace as usize,
};
if fields.len() > max_expected_fields {
let has_unexpected_fields = fields.iter().any(|field| {
field.attrs.from.is_none()
&& field.attrs.source.is_none()
&& field.attrs.backtrace.is_none()
&& !field.is_backtrace()
&& !field.is_implicit()
});
if has_unexpected_fields {
return Err(Error::new_spanned(
from_field.attrs.from.unwrap().original,
"deriving From requires no fields other than source and backtrace",
"deriving From requires no fields other than source, backtrace, and implicit",
));
}
}
@ -218,6 +228,14 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> {
));
}
}
if let Some(first_implicit_field) = first_implicit_field {
if from_field.is_none() {
return Err(Error::new_spanned(
first_implicit_field.original,
"implicit fields require a #[from] field",
));
}
}
Ok(())
}

50
src/implicit.rs Normal file
View file

@ -0,0 +1,50 @@
use core::error::Error;
#[cfg(feature = "std")]
use std::sync::Arc;
pub trait ImplicitField {
// Required method
#[track_caller]
fn generate() -> Self;
// Provided method
#[track_caller]
fn generate_with_source(source: &dyn Error) -> Self
where
Self: Sized,
{
let _ = source;
Self::generate()
}
}
#[cfg(feature = "std")]
impl<T: ImplicitField> ImplicitField for Arc<T> {
#[track_caller]
fn generate() -> Self {
T::generate().into()
}
#[track_caller]
fn generate_with_source(source: &dyn Error) -> Self
where
Self: Sized,
{
T::generate_with_source(source).into()
}
}
impl<T: ImplicitField> ImplicitField for Option<T> {
#[track_caller]
fn generate() -> Self {
T::generate().into()
}
#[track_caller]
fn generate_with_source(source: &dyn Error) -> Self
where
Self: Sized,
{
T::generate_with_source(source).into()
}
}

View file

@ -278,10 +278,12 @@ extern crate std as core;
mod aserror;
mod display;
mod implicit;
#[cfg(error_generic_member_access)]
mod provide;
mod var;
pub use implicit::ImplicitField;
pub use thiserror_impl::*;
// Not public API.

99
tests/test_implicit.rs Normal file
View file

@ -0,0 +1,99 @@
use std::{backtrace, sync::Arc};
use thiserror::{Error, ImplicitField};
#[derive(Error, Debug)]
#[error("Inner")]
pub struct Inner;
#[derive(Debug)]
pub struct ImplicitBacktrace(pub backtrace::Backtrace);
impl ImplicitField for ImplicitBacktrace {
fn generate() -> Self {
Self(backtrace::Backtrace::force_capture())
}
}
#[derive(Debug)]
pub struct Location(pub &'static core::panic::Location<'static>);
impl Default for Location {
#[track_caller]
fn default() -> Self {
Self(core::panic::Location::caller())
}
}
impl ImplicitField for Location {
#[track_caller]
fn generate() -> Self {
Self::default()
}
}
#[derive(Error, Debug)]
#[error("location: {location:?}")]
pub struct ErrorStruct {
#[from]
source: Inner,
#[implicit]
backtrace: ImplicitBacktrace,
#[implicit]
location: Location,
#[implicit]
location_arc: Arc<Location>,
#[implicit]
location_opt: Option<Location>,
}
#[derive(Error, Debug)]
#[error("location: {location:?}")]
pub enum ErrorEnum {
#[error("location: {location:?}")]
Test {
#[from]
source: Inner,
#[implicit]
backtrace: ImplicitBacktrace,
#[implicit]
location: Location,
#[implicit]
location_arc: Arc<Location>,
#[implicit]
location_opt: Option<Location>,
},
}
#[test]
fn test_implicit() {
let base_location = Location::default();
let assert_location = |location: &Location| {
assert_eq!(location.0.file(), file!(), "location: {location:?}");
assert!(
location.0.line() > base_location.0.line(),
"location: {location:?}"
);
};
let error = ErrorStruct::from(Inner);
assert_location(&error.location);
assert_location(&error.location_arc);
assert_location(error.location_opt.as_ref().unwrap());
assert_eq!(
error.backtrace.0.status(),
backtrace::BacktraceStatus::Captured
);
let ErrorEnum::Test {
source: _,
backtrace,
location,
location_arc,
location_opt,
} = ErrorEnum::from(Inner);
assert_location(&location);
assert_location(&location_arc);
assert_location(location_opt.as_ref().unwrap());
assert_eq!(backtrace.0.status(), backtrace::BacktraceStatus::Captured);
}