• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use proc_macro2::TokenStream;
2 use quote::quote;
3 use syn::{parse_quote, Data, DeriveInput, Fields, Path};
4 
5 use crate::helpers::{
6     missing_parse_err_attr_error, non_enum_error, occurrence_error, HasInnerVariantProperties,
7     HasStrumVariantProperties, HasTypeProperties,
8 };
9 
from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream>10 pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
11     let name = &ast.ident;
12     let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
13     let variants = match &ast.data {
14         Data::Enum(v) => &v.variants,
15         _ => return Err(non_enum_error()),
16     };
17 
18     let type_properties = ast.get_type_properties()?;
19     let strum_module_path = type_properties.crate_module_path();
20 
21     let mut default_kw = None;
22     let (mut default_err_ty, mut default) = match (
23         type_properties.parse_err_ty,
24         type_properties.parse_err_fn,
25     ) {
26         (None, None) => (
27             quote! { #strum_module_path::ParseError },
28             quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) },
29         ),
30         (Some(ty), Some(f)) => {
31             let ty_path: Path = parse_quote!(#ty);
32             let fn_path: Path = parse_quote!(#f);
33 
34             (
35                 quote! { #ty_path },
36                 quote! { ::core::result::Result::Err(#fn_path(s)) },
37             )
38         }
39         _ => return Err(missing_parse_err_attr_error()),
40     };
41     let mut phf_exact_match_arms = Vec::new();
42     let mut standard_match_arms = Vec::new();
43     for variant in variants {
44         let ident = &variant.ident;
45         let variant_properties = variant.get_variant_properties()?;
46 
47         if variant_properties.disabled.is_some() {
48             continue;
49         }
50 
51         if let Some(kw) = variant_properties.default {
52             if let Some(fst_kw) = default_kw {
53                 return Err(occurrence_error(fst_kw, kw, "default"));
54             }
55 
56             default_kw = Some(kw);
57             default_err_ty = quote! { #strum_module_path::ParseError };
58 
59             match &variant.fields {
60                 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
61                     default = quote! {
62                         ::core::result::Result::Ok(#name::#ident(s.into()))
63                     };
64                 }
65                 Fields::Named(ref f) if f.named.len() == 1 => {
66                     let field_name = f.named.last().unwrap().ident.as_ref().unwrap();
67                     default = quote! {
68                         ::core::result::Result::Ok(#name::#ident { #field_name : s.into() } )
69                     };
70                 }
71                 _ => {
72                     return Err(syn::Error::new_spanned(
73                         variant,
74                         "Default only works on newtype structs with a single String field",
75                     ))
76                 }
77             }
78 
79             continue;
80         }
81 
82         let params = match &variant.fields {
83             Fields::Unit => quote! {},
84             Fields::Unnamed(fields) => {
85                 if let Some(ref value) = variant_properties.default_with {
86                     let func = proc_macro2::Ident::new(&value.value(), value.span());
87                     let defaults = vec![quote! { #func() }];
88                     quote! { (#(#defaults),*) }
89                 } else {
90                     let defaults =
91                         ::core::iter::repeat(quote!(Default::default())).take(fields.unnamed.len());
92                     quote! { (#(#defaults),*) }
93                 }
94             }
95             Fields::Named(fields) => {
96                 let mut defaults = vec![];
97                 for field in &fields.named {
98                     let meta = field.get_variant_inner_properties()?;
99                     let field = field.ident.as_ref().unwrap();
100 
101                     if let Some(default_with) = meta.default_with {
102                         let func =
103                             proc_macro2::Ident::new(&default_with.value(), default_with.span());
104                         defaults.push(quote! {
105                             #field: #func()
106                         });
107                     } else {
108                         defaults.push(quote! { #field: Default::default() });
109                     }
110                 }
111 
112                 quote! { {#(#defaults),*} }
113             }
114         };
115 
116         let is_ascii_case_insensitive = variant_properties
117             .ascii_case_insensitive
118             .unwrap_or(type_properties.ascii_case_insensitive);
119 
120         // If we don't have any custom variants, add the default serialized name.
121         for serialization in variant_properties.get_serializations(type_properties.case_style) {
122             if type_properties.use_phf {
123                 phf_exact_match_arms.push(quote! { #serialization => #name::#ident #params, });
124 
125                 if is_ascii_case_insensitive {
126                     // Store the lowercase and UPPERCASE variants in the phf map to capture
127                     let ser_string = serialization.value();
128 
129                     let lower =
130                         syn::LitStr::new(&ser_string.to_ascii_lowercase(), serialization.span());
131                     let upper =
132                         syn::LitStr::new(&ser_string.to_ascii_uppercase(), serialization.span());
133                     phf_exact_match_arms.push(quote! { #lower => #name::#ident #params, });
134                     phf_exact_match_arms.push(quote! { #upper => #name::#ident #params, });
135                     standard_match_arms.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, });
136                 }
137             } else {
138                 standard_match_arms.push(if !is_ascii_case_insensitive {
139                     quote! { #serialization => #name::#ident #params, }
140                 } else {
141                     quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, }
142                 });
143             }
144         }
145     }
146 
147     let phf_body = if phf_exact_match_arms.is_empty() {
148         quote!()
149     } else {
150         quote! {
151             use #strum_module_path::_private_phf_reexport_for_macro_if_phf_feature as phf;
152             static PHF: phf::Map<&'static str, #name> = phf::phf_map! {
153                 #(#phf_exact_match_arms)*
154             };
155             if let Some(value) = PHF.get(s).cloned() {
156                 return ::core::result::Result::Ok(value);
157             }
158         }
159     };
160 
161     let standard_match_body = if standard_match_arms.is_empty() {
162         default
163     } else {
164         quote! {
165             ::core::result::Result::Ok(match s {
166                 #(#standard_match_arms)*
167                 _ => return #default,
168             })
169         }
170     };
171 
172     let from_str = quote! {
173         #[allow(clippy::use_self)]
174         impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
175             type Err = #default_err_ty;
176 
177             #[inline]
178             fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
179                 #phf_body
180                 #standard_match_body
181             }
182         }
183     };
184     let try_from_str = try_from_str(
185         name,
186         &impl_generics,
187         &ty_generics,
188         where_clause,
189         &default_err_ty,
190     );
191 
192     Ok(quote! {
193         #from_str
194         #try_from_str
195     })
196 }
197 
198 #[rustversion::before(1.34)]
try_from_str( _name: &proc_macro2::Ident, _impl_generics: &syn::ImplGenerics, _ty_generics: &syn::TypeGenerics, _where_clause: Option<&syn::WhereClause>, _strum_module_path: &syn::Path, ) -> TokenStream199 fn try_from_str(
200     _name: &proc_macro2::Ident,
201     _impl_generics: &syn::ImplGenerics,
202     _ty_generics: &syn::TypeGenerics,
203     _where_clause: Option<&syn::WhereClause>,
204     _strum_module_path: &syn::Path,
205 ) -> TokenStream {
206     Default::default()
207 }
208 
209 #[rustversion::since(1.34)]
try_from_str( name: &proc_macro2::Ident, impl_generics: &syn::ImplGenerics, ty_generics: &syn::TypeGenerics, where_clause: Option<&syn::WhereClause>, default_err_ty: &TokenStream, ) -> TokenStream210 fn try_from_str(
211     name: &proc_macro2::Ident,
212     impl_generics: &syn::ImplGenerics,
213     ty_generics: &syn::TypeGenerics,
214     where_clause: Option<&syn::WhereClause>,
215     default_err_ty: &TokenStream,
216 ) -> TokenStream {
217     quote! {
218         #[allow(clippy::use_self)]
219         impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
220             type Error = #default_err_ty;
221 
222             #[inline]
223             fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
224                 ::core::str::FromStr::from_str(s)
225             }
226         }
227     }
228 }
229