• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use proc_macro2::{Span, TokenStream, TokenTree};
2 use quote::{quote, ToTokens};
3 use syn::parse_quote;
4 use syn::{Data, DeriveInput, Fields};
5 
6 use crate::helpers::{non_enum_error, strum_discriminants_passthrough_error, HasTypeProperties};
7 
/// Allow-list of attribute names carried over from each variant of the input enum onto the
/// matching variant of the generated discriminant enum.
///
/// Any attribute whose name is not listed here is dropped. Such attributes may belong to other
/// `proc_macro`s applied to the input enum and could break compilation if copied across.
const ATTRIBUTES_TO_COPY: &[&str] = &[
    "doc",
    "cfg",
    "allow",
    "deny",
    "strum_discriminants",
];
13 
enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream>14 pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
15     let name = &ast.ident;
16     let vis = &ast.vis;
17 
18     let variants = match &ast.data {
19         Data::Enum(v) => &v.variants,
20         _ => return Err(non_enum_error()),
21     };
22 
23     // Derives for the generated enum
24     let type_properties = ast.get_type_properties()?;
25 
26     let derives = type_properties.discriminant_derives;
27 
28     let derives = quote! {
29         #[derive(Clone, Copy, Debug, PartialEq, Eq, #(#derives),*)]
30     };
31 
32     // Work out the name
33     let default_name = syn::Ident::new(&format!("{}Discriminants", name), Span::call_site());
34 
35     let discriminants_name = type_properties.discriminant_name.unwrap_or(default_name);
36     let discriminants_vis = type_properties
37         .discriminant_vis
38         .unwrap_or_else(|| vis.clone());
39 
40     // Pass through all other attributes
41     let pass_though_attributes = type_properties.discriminant_others;
42 
43     // Add the variants without fields, but exclude the `strum` meta item
44     let mut discriminants = Vec::new();
45     for variant in variants {
46         let ident = &variant.ident;
47 
48         // Don't copy across the "strum" meta attribute. Only passthrough the whitelisted
49         // attributes and proxy `#[strum_discriminants(...)]` attributes
50         let attrs = variant
51             .attrs
52             .iter()
53             .filter(|attr| {
54                 ATTRIBUTES_TO_COPY
55                     .iter()
56                     .any(|attr_whitelisted| attr.path().is_ident(attr_whitelisted))
57             })
58             .map(|attr| {
59                 if attr.path().is_ident("strum_discriminants") {
60                     let mut ts = attr.meta.require_list()?.to_token_stream().into_iter();
61 
62                     // Discard strum_discriminants(...)
63                     let _ = ts.next();
64 
65                     let passthrough_group = ts
66                         .next()
67                         .ok_or_else(|| strum_discriminants_passthrough_error(attr))?;
68                     let passthrough_attribute = match passthrough_group {
69                         TokenTree::Group(ref group) => group.stream(),
70                         _ => {
71                             return Err(strum_discriminants_passthrough_error(&passthrough_group));
72                         }
73                     };
74                     if passthrough_attribute.is_empty() {
75                         return Err(strum_discriminants_passthrough_error(&passthrough_group));
76                     }
77                     Ok(quote! { #[#passthrough_attribute] })
78                 } else {
79                     Ok(attr.to_token_stream())
80                 }
81             })
82             .collect::<Result<Vec<_>, _>>()?;
83 
84         discriminants.push(quote! { #(#attrs)* #ident });
85     }
86 
87     // Ideally:
88     //
89     // * For `Copy` types, we `impl From<TheEnum> for TheEnumDiscriminants`
90     // * For `!Copy` types, we `impl<'enum> From<&'enum TheEnum> for TheEnumDiscriminants`
91     //
92     // That way we ensure users are not able to pass a `Copy` type by reference. However, the
93     // `#[derive(..)]` attributes are not in the parsed tokens, so we are not able to check if a
94     // type is `Copy`, so we just implement both.
95     //
96     // See <https://github.com/dtolnay/syn/issues/433>
97     // ---
98     // let is_copy = unique_meta_list(type_meta.iter(), "derive")
99     //     .map(extract_list_metas)
100     //     .map(|metas| {
101     //         metas
102     //             .filter_map(get_meta_ident)
103     //             .any(|derive| derive.to_string() == "Copy")
104     //     }).unwrap_or(false);
105 
106     let arms = variants
107         .iter()
108         .map(|variant| {
109             let ident = &variant.ident;
110             let params = match &variant.fields {
111                 Fields::Unit => quote! {},
112                 Fields::Unnamed(_fields) => {
113                     quote! { (..) }
114                 }
115                 Fields::Named(_fields) => {
116                     quote! { { .. } }
117                 }
118             };
119 
120             quote! { #name::#ident #params => #discriminants_name::#ident }
121         })
122         .collect::<Vec<_>>();
123 
124     let from_fn_body = quote! { match val { #(#arms),* } };
125 
126     let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
127     let impl_from = quote! {
128         impl #impl_generics ::core::convert::From< #name #ty_generics > for #discriminants_name #where_clause {
129             fn from(val: #name #ty_generics) -> #discriminants_name {
130                 #from_fn_body
131             }
132         }
133     };
134     let impl_from_ref = {
135         let mut generics = ast.generics.clone();
136 
137         let lifetime = parse_quote!('_enum);
138         let enum_life = quote! { & #lifetime };
139         generics.params.push(lifetime);
140 
141         // Shadows the earlier `impl_generics`
142         let (impl_generics, _, _) = generics.split_for_impl();
143 
144         quote! {
145             impl #impl_generics ::core::convert::From< #enum_life #name #ty_generics > for #discriminants_name #where_clause {
146                 fn from(val: #enum_life #name #ty_generics) -> #discriminants_name {
147                     #from_fn_body
148                 }
149             }
150         }
151     };
152 
153     Ok(quote! {
154         /// Auto-generated discriminant enum variants
155         #derives
156         #(#[ #pass_though_attributes ])*
157         #discriminants_vis enum #discriminants_name {
158             #(#discriminants),*
159         }
160 
161         #impl_from
162         #impl_from_ref
163     })
164 }
165