• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use proc_macro2::{Span, TokenStream};
2 use quote::{format_ident, quote};
3 use syn::{spanned::Spanned, Data, DeriveInput, Fields};
4 
5 use crate::helpers::{non_enum_error, snakify, HasStrumVariantProperties};
6 
enum_table_inner(ast: &DeriveInput) -> syn::Result<TokenStream>7 pub fn enum_table_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8     let name = &ast.ident;
9     let gen = &ast.generics;
10     let vis = &ast.vis;
11     let mut doc_comment = format!("A map over the variants of `{}`", name);
12 
13     if gen.lifetimes().count() > 0 {
14         return Err(syn::Error::new(
15             Span::call_site(),
16             "`EnumTable` doesn't support enums with lifetimes.",
17         ));
18     }
19 
20     let variants = match &ast.data {
21         Data::Enum(v) => &v.variants,
22         _ => return Err(non_enum_error()),
23     };
24 
25     let table_name = format_ident!("{}Table", name);
26 
27     // the identifiers of each variant, in PascalCase
28     let mut pascal_idents = Vec::new();
29     // the identifiers of each struct field, in snake_case
30     let mut snake_idents = Vec::new();
31     // match arms in the form `MyEnumTable::Variant => &self.variant,`
32     let mut get_matches = Vec::new();
33     // match arms in the form `MyEnumTable::Variant => &mut self.variant,`
34     let mut get_matches_mut = Vec::new();
35     // match arms in the form `MyEnumTable::Variant => self.variant = new_value`
36     let mut set_matches = Vec::new();
37     // struct fields of the form `variant: func(MyEnum::Variant),*
38     let mut closure_fields = Vec::new();
39     // struct fields of the form `variant: func(MyEnum::Variant, self.variant),`
40     let mut transform_fields = Vec::new();
41 
42     // identifiers for disabled variants
43     let mut disabled_variants = Vec::new();
44     // match arms for disabled variants
45     let mut disabled_matches = Vec::new();
46 
47     for variant in variants {
48         // skip disabled variants
49         if variant.get_variant_properties()?.disabled.is_some() {
50             let disabled_ident = &variant.ident;
51             let panic_message = format!(
52                 "Can't use `{}` with `{}` - variant is disabled for Strum features",
53                 disabled_ident, table_name
54             );
55             disabled_variants.push(disabled_ident);
56             disabled_matches.push(quote!(#name::#disabled_ident => panic!(#panic_message),));
57             continue;
58         }
59 
60         // Error on variants with data
61         if !matches!(variant.fields, Fields::Unit) {
62             return Err(syn::Error::new(
63                 variant.fields.span(),
64                 "`EnumTable` doesn't support enums with non-unit variants",
65             ));
66         };
67 
68         let pascal_case = &variant.ident;
69         let snake_case = format_ident!("_{}", snakify(&pascal_case.to_string()));
70 
71         get_matches.push(quote! {#name::#pascal_case => &self.#snake_case,});
72         get_matches_mut.push(quote! {#name::#pascal_case => &mut self.#snake_case,});
73         set_matches.push(quote! {#name::#pascal_case => self.#snake_case = new_value,});
74         closure_fields.push(quote! {#snake_case: func(#name::#pascal_case),});
75         transform_fields.push(quote! {#snake_case: func(#name::#pascal_case, &self.#snake_case),});
76         pascal_idents.push(pascal_case);
77         snake_idents.push(snake_case);
78     }
79 
80     // Error on empty enums
81     if pascal_idents.is_empty() {
82         return Err(syn::Error::new(
83             variants.span(),
84             "`EnumTable` requires at least one non-disabled variant",
85         ));
86     }
87 
88     // if the index operation can panic, add that to the documentation
89     if !disabled_variants.is_empty() {
90         doc_comment.push_str(&format!(
91             "\n# Panics\nIndexing `{}` with any of the following variants will cause a panic:",
92             table_name
93         ));
94         for variant in disabled_variants {
95             doc_comment.push_str(&format!("\n\n- `{}::{}`", name, variant));
96         }
97     }
98 
99     let doc_new = format!(
100         "Create a new {} with a value for each variant of {}",
101         table_name, name
102     );
103     let doc_closure = format!(
104         "Create a new {} by running a function on each variant of `{}`",
105         table_name, name
106     );
107     let doc_transform = format!("Create a new `{}` by running a function on each variant of `{}` and the corresponding value in the current `{0}`", table_name, name);
108     let doc_filled = format!(
109         "Create a new `{}` with the same value in each field.",
110         table_name
111     );
112     let doc_option_all = format!("Converts `{}<Option<T>>` into `Option<{0}<T>>`. Returns `Some` if all fields are `Some`, otherwise returns `None`.", table_name);
113     let doc_result_all_ok = format!("Converts `{}<Result<T, E>>` into `Result<{0}<T>, E>`. Returns `Ok` if all fields are `Ok`, otherwise returns `Err`.", table_name);
114 
115     Ok(quote! {
116         #[doc = #doc_comment]
117         #[allow(
118             missing_copy_implementations,
119         )]
120         #[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
121         #vis struct #table_name<T> {
122             #(#snake_idents: T,)*
123         }
124 
125         impl<T: Clone> #table_name<T> {
126             #[doc = #doc_filled]
127             #vis fn filled(value: T) -> #table_name<T> {
128                 #table_name {
129                     #(#snake_idents: value.clone(),)*
130                 }
131             }
132         }
133 
134         impl<T> #table_name<T> {
135             #[doc = #doc_new]
136             #[inline]
137             #vis fn new(
138                 #(#snake_idents: T,)*
139             ) -> #table_name<T> {
140                 #table_name {
141                     #(#snake_idents,)*
142                 }
143             }
144 
145             #[doc = #doc_closure]
146             #[inline]
147             #vis fn from_closure<F: Fn(#name)->T>(func: F) -> #table_name<T> {
148               #table_name {
149                 #(#closure_fields)*
150               }
151             }
152 
153             #[doc = #doc_transform]
154             #[inline]
155             #vis fn transform<U, F: Fn(#name, &T)->U>(&self, func: F) -> #table_name<U> {
156               #table_name {
157                 #(#transform_fields)*
158               }
159             }
160 
161         }
162 
163         impl<T> ::core::ops::Index<#name> for #table_name<T> {
164             type Output = T;
165 
166             #[inline]
167             fn index(&self, idx: #name) -> &T {
168                 match idx {
169                     #(#get_matches)*
170                     #(#disabled_matches)*
171                 }
172             }
173         }
174 
175         impl<T> ::core::ops::IndexMut<#name> for #table_name<T> {
176             #[inline]
177             fn index_mut(&mut self, idx: #name) -> &mut T {
178                 match idx {
179                     #(#get_matches_mut)*
180                     #(#disabled_matches)*
181                 }
182             }
183         }
184 
185         impl<T> #table_name<::core::option::Option<T>> {
186             #[doc = #doc_option_all]
187             #[inline]
188             #vis fn all(self) -> ::core::option::Option<#table_name<T>> {
189                 if let #table_name {
190                     #(#snake_idents: ::core::option::Option::Some(#snake_idents),)*
191                 } = self {
192                     ::core::option::Option::Some(#table_name {
193                         #(#snake_idents,)*
194                     })
195                 } else {
196                     ::core::option::Option::None
197                 }
198             }
199         }
200 
201         impl<T, E> #table_name<::core::result::Result<T, E>> {
202             #[doc = #doc_result_all_ok]
203             #[inline]
204             #vis fn all_ok(self) -> ::core::result::Result<#table_name<T>, E> {
205                 ::core::result::Result::Ok(#table_name {
206                     #(#snake_idents: self.#snake_idents?,)*
207                 })
208             }
209         }
210     })
211 }
212