• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use proc_macro2::TokenStream;
2 use quote::quote;
3 use syn::{Data, DeriveInput, Fields};
4 
5 use crate::helpers::{
6     non_enum_error, occurrence_error, HasStrumVariantProperties, HasTypeProperties,
7 };
8 
from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream>9 pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
10     let name = &ast.ident;
11     let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
12     let variants = match &ast.data {
13         Data::Enum(v) => &v.variants,
14         _ => return Err(non_enum_error()),
15     };
16 
17     let type_properties = ast.get_type_properties()?;
18     let strum_module_path = type_properties.crate_module_path();
19 
20     let mut default_kw = None;
21     let mut default =
22         quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) };
23 
24     let mut phf_exact_match_arms = Vec::new();
25     let mut standard_match_arms = Vec::new();
26     for variant in variants {
27         let ident = &variant.ident;
28         let variant_properties = variant.get_variant_properties()?;
29 
30         if variant_properties.disabled.is_some() {
31             continue;
32         }
33 
34         if let Some(kw) = variant_properties.default {
35             if let Some(fst_kw) = default_kw {
36                 return Err(occurrence_error(fst_kw, kw, "default"));
37             }
38 
39             match &variant.fields {
40                 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {}
41                 _ => {
42                     return Err(syn::Error::new_spanned(
43                         variant,
44                         "Default only works on newtype structs with a single String field",
45                     ))
46                 }
47             }
48 
49             default_kw = Some(kw);
50             default = quote! {
51                 ::core::result::Result::Ok(#name::#ident(s.into()))
52             };
53             continue;
54         }
55 
56         let params = match &variant.fields {
57             Fields::Unit => quote! {},
58             Fields::Unnamed(fields) => {
59                 let defaults =
60                     ::core::iter::repeat(quote!(Default::default())).take(fields.unnamed.len());
61                 quote! { (#(#defaults),*) }
62             }
63             Fields::Named(fields) => {
64                 let fields = fields
65                     .named
66                     .iter()
67                     .map(|field| field.ident.as_ref().unwrap());
68                 quote! { {#(#fields: Default::default()),*} }
69             }
70         };
71 
72         let is_ascii_case_insensitive = variant_properties
73             .ascii_case_insensitive
74             .unwrap_or(type_properties.ascii_case_insensitive);
75 
76         // If we don't have any custom variants, add the default serialized name.
77         for serialization in variant_properties.get_serializations(type_properties.case_style) {
78             if type_properties.use_phf {
79                 phf_exact_match_arms.push(quote! { #serialization => #name::#ident #params, });
80 
81                 if is_ascii_case_insensitive {
82                     // Store the lowercase and UPPERCASE variants in the phf map to capture
83                     let ser_string = serialization.value();
84 
85                     let lower =
86                         syn::LitStr::new(&ser_string.to_ascii_lowercase(), serialization.span());
87                     let upper =
88                         syn::LitStr::new(&ser_string.to_ascii_uppercase(), serialization.span());
89                     phf_exact_match_arms.push(quote! { #lower => #name::#ident #params, });
90                     phf_exact_match_arms.push(quote! { #upper => #name::#ident #params, });
91                     standard_match_arms.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, });
92                 }
93             } else {
94                 standard_match_arms.push(if !is_ascii_case_insensitive {
95                     quote! { #serialization => #name::#ident #params, }
96                 } else {
97                     quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, }
98                 });
99             }
100         }
101     }
102 
103     let phf_body = if phf_exact_match_arms.is_empty() {
104         quote!()
105     } else {
106         quote! {
107             use #strum_module_path::_private_phf_reexport_for_macro_if_phf_feature as phf;
108             static PHF: phf::Map<&'static str, #name> = phf::phf_map! {
109                 #(#phf_exact_match_arms)*
110             };
111             if let Some(value) = PHF.get(s).cloned() {
112                 return ::core::result::Result::Ok(value);
113             }
114         }
115     };
116     let standard_match_body = if standard_match_arms.is_empty() {
117         default
118     } else {
119         quote! {
120             ::core::result::Result::Ok(match s {
121                 #(#standard_match_arms)*
122                 _ => return #default,
123             })
124         }
125     };
126 
127     let from_str = quote! {
128         #[allow(clippy::use_self)]
129         impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
130             type Err = #strum_module_path::ParseError;
131             fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
132                 #phf_body
133                 #standard_match_body
134             }
135         }
136     };
137 
138     let try_from_str = try_from_str(
139         name,
140         &impl_generics,
141         &ty_generics,
142         where_clause,
143         &strum_module_path,
144     );
145 
146     Ok(quote! {
147         #from_str
148         #try_from_str
149     })
150 }
151 
152 #[rustversion::before(1.34)]
try_from_str( _name: &proc_macro2::Ident, _impl_generics: &syn::ImplGenerics, _ty_generics: &syn::TypeGenerics, _where_clause: Option<&syn::WhereClause>, _strum_module_path: &syn::Path, ) -> TokenStream153 fn try_from_str(
154     _name: &proc_macro2::Ident,
155     _impl_generics: &syn::ImplGenerics,
156     _ty_generics: &syn::TypeGenerics,
157     _where_clause: Option<&syn::WhereClause>,
158     _strum_module_path: &syn::Path,
159 ) -> TokenStream {
160     Default::default()
161 }
162 
163 #[rustversion::since(1.34)]
try_from_str( name: &proc_macro2::Ident, impl_generics: &syn::ImplGenerics, ty_generics: &syn::TypeGenerics, where_clause: Option<&syn::WhereClause>, strum_module_path: &syn::Path, ) -> TokenStream164 fn try_from_str(
165     name: &proc_macro2::Ident,
166     impl_generics: &syn::ImplGenerics,
167     ty_generics: &syn::TypeGenerics,
168     where_clause: Option<&syn::WhereClause>,
169     strum_module_path: &syn::Path,
170 ) -> TokenStream {
171     quote! {
172         #[allow(clippy::use_self)]
173         impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
174             type Error = #strum_module_path::ParseError;
175             fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
176                 ::core::str::FromStr::from_str(s)
177             }
178         }
179     }
180 }
181