Implied bounds for Display and Error impl

This commit is contained in:
David Tolnay 2021-09-04 16:13:12 -07:00
parent 81b881063f
commit 1e6e267914
No known key found for this signature in database
GPG key ID: F9BA143B95FF6D82
7 changed files with 170 additions and 24 deletions

View file

@ -1,8 +1,12 @@
use crate::ast::{Enum, Field, Input, Struct};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::BTreeSet as Set;
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type, Visibility};
use syn::{
parse_quote, Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type,
Visibility,
};
pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
let input = Input::from_syn(node)?;
@ -16,6 +20,8 @@ pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
fn impl_struct(input: Struct) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut error_generics = input.generics.clone();
let error_where_clause = error_generics.make_where_clause();
let source_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
@ -24,6 +30,12 @@ fn impl_struct(input: Struct) -> TokenStream {
})
} else if let Some(source_field) = input.source_field() {
let source = &source_field.member;
if source_field.contains_generic {
let ty = unoptional_type(source_field.ty);
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error + 'static));
}
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
} else {
@ -89,12 +101,14 @@ fn impl_struct(input: Struct) -> TokenStream {
}
});
let mut display_implied_bounds = &Set::new();
let display_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
Some(quote! {
std::fmt::Display::fmt(&self.#only_field, __formatter)
})
} else if let Some(display) = &input.attrs.display {
display_implied_bounds = &display.implied_bounds;
let use_as_display = if display.has_bonus_display {
Some(quote! {
#[allow(unused_imports)]
@ -114,9 +128,20 @@ fn impl_struct(input: Struct) -> TokenStream {
None
};
let display_impl = display_body.map(|body| {
let mut display_generics = input.generics.clone();
let display_where_clause = display_generics.make_where_clause();
for &(field, bound) in display_implied_bounds {
let field = &input.fields[field];
if field.contains_generic {
let field_ty = field.ty;
display_where_clause
.predicates
.push(parse_quote!(#field_ty: #bound));
}
}
quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
#[allow(clippy::used_underscore_binding)]
fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
#body
@ -141,10 +166,15 @@ fn impl_struct(input: Struct) -> TokenStream {
});
let error_trait = spanned_error_trait(input.original);
if input.generics.type_params().next().is_some() {
error_where_clause
.predicates
.push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display));
}
quote! {
#[allow(unused_qualifications)]
impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
#source_method
#backtrace_method
}
@ -156,6 +186,8 @@ fn impl_struct(input: Struct) -> TokenStream {
fn impl_enum(input: Enum) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut error_generics = input.generics.clone();
let error_where_clause = error_generics.make_where_clause();
let source_method = if input.has_source() {
let arms = input.variants.iter().map(|variant| {
@ -168,6 +200,12 @@ fn impl_enum(input: Enum) -> TokenStream {
}
} else if let Some(source_field) = variant.source_field() {
let source = &source_field.member;
if source_field.contains_generic {
let ty = unoptional_type(source_field.ty);
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error + 'static));
}
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
} else {
@ -286,6 +324,8 @@ fn impl_enum(input: Enum) -> TokenStream {
};
let display_impl = if input.has_display() {
let mut display_generics = input.generics.clone();
let display_where_clause = display_generics.make_where_clause();
let use_as_display = if input.variants.iter().any(|v| {
v.attrs
.display
@ -305,8 +345,12 @@ fn impl_enum(input: Enum) -> TokenStream {
None
};
let arms = input.variants.iter().map(|variant| {
let mut display_implied_bounds = &Set::new();
let display = match &variant.attrs.display {
Some(display) => display.to_token_stream(),
Some(display) => {
display_implied_bounds = &display.implied_bounds;
display.to_token_stream()
}
None => {
let only_field = match &variant.fields[0].member {
Member::Named(ident) => ident.clone(),
@ -315,15 +359,25 @@ fn impl_enum(input: Enum) -> TokenStream {
quote!(std::fmt::Display::fmt(#only_field, __formatter))
}
};
for &(field, bound) in display_implied_bounds {
let field = &variant.fields[field];
if field.contains_generic {
let field_ty = field.ty;
display_where_clause
.predicates
.push(parse_quote!(#field_ty: #bound));
}
}
let ident = &variant.ident;
let pat = fields_pat(&variant.fields);
quote! {
#ty::#ident #pat => #display
}
});
let arms = arms.collect::<Vec<_>>();
Some(quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
#use_as_display
#[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
@ -355,10 +409,15 @@ fn impl_enum(input: Enum) -> TokenStream {
});
let error_trait = spanned_error_trait(input.original);
if input.generics.type_params().next().is_some() {
error_where_clause
.predicates
.push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display));
}
quote! {
#[allow(unused_qualifications)]
impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
#source_method
#backtrace_method
}