• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use proc_macro2::{Span, TokenStream};
2 use quote::quote;
3 use syn::{Data, DeriveInput, Fields, Ident};
4 
5 use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
6 
enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream>7 pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8     let name = &ast.ident;
9     let gen = &ast.generics;
10     let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
11     let vis = &ast.vis;
12     let type_properties = ast.get_type_properties()?;
13     let strum_module_path = type_properties.crate_module_path();
14     let doc_comment = format!("An iterator over the variants of [{}]", name);
15 
16     if gen.lifetimes().count() > 0 {
17         return Err(syn::Error::new(
18             Span::call_site(),
19             "This macro doesn't support enums with lifetimes. \
20              The resulting enums would be unbounded.",
21         ));
22     }
23 
24     let phantom_data = if gen.type_params().count() > 0 {
25         let g = gen.type_params().map(|param| &param.ident);
26         quote! { < ( #(#g),* ) > }
27     } else {
28         quote! { < () > }
29     };
30 
31     let variants = match &ast.data {
32         Data::Enum(v) => &v.variants,
33         _ => return Err(non_enum_error()),
34     };
35 
36     let mut arms = Vec::new();
37     let mut idx = 0usize;
38     for variant in variants {
39         if variant.get_variant_properties()?.disabled.is_some() {
40             continue;
41         }
42 
43         let ident = &variant.ident;
44         let params = match &variant.fields {
45             Fields::Unit => quote! {},
46             Fields::Unnamed(fields) => {
47                 let defaults = ::core::iter::repeat(quote!(::core::default::Default::default()))
48                     .take(fields.unnamed.len());
49                 quote! { (#(#defaults),*) }
50             }
51             Fields::Named(fields) => {
52                 let fields = fields
53                     .named
54                     .iter()
55                     .map(|field| field.ident.as_ref().unwrap());
56                 quote! { {#(#fields: ::core::default::Default::default()),*} }
57             }
58         };
59 
60         arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)});
61         idx += 1;
62     }
63 
64     let variant_count = arms.len();
65     arms.push(quote! { _ => ::core::option::Option::None });
66     let iter_name = syn::parse_str::<Ident>(&format!("{}Iter", name)).unwrap();
67 
68     // Create a string literal "MyEnumIter" to use in the debug impl.
69     let iter_name_debug_struct =
70         syn::parse_str::<syn::LitStr>(&format!("\"{}\"", iter_name)).unwrap();
71 
72     Ok(quote! {
73         #[doc = #doc_comment]
74         #[allow(
75             missing_copy_implementations,
76         )]
77         #vis struct #iter_name #ty_generics {
78             idx: usize,
79             back_idx: usize,
80             marker: ::core::marker::PhantomData #phantom_data,
81         }
82 
83         impl #impl_generics ::core::fmt::Debug for #iter_name #ty_generics #where_clause {
84             fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
85                 // We don't know if the variants implement debug themselves so the only thing we
86                 // can really show is how many elements are left.
87                 f.debug_struct(#iter_name_debug_struct)
88                     .field("len", &self.len())
89                     .finish()
90             }
91         }
92 
93         impl #impl_generics #iter_name #ty_generics #where_clause {
94             fn get(&self, idx: usize) -> ::core::option::Option<#name #ty_generics> {
95                 match idx {
96                     #(#arms),*
97                 }
98             }
99         }
100 
101         impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
102             type Iterator = #iter_name #ty_generics;
103             fn iter() -> #iter_name #ty_generics {
104                 #iter_name {
105                     idx: 0,
106                     back_idx: 0,
107                     marker: ::core::marker::PhantomData,
108                 }
109             }
110         }
111 
112         impl #impl_generics Iterator for #iter_name #ty_generics #where_clause {
113             type Item = #name #ty_generics;
114 
115             fn next(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
116                 self.nth(0)
117             }
118 
119             fn size_hint(&self) -> (usize, ::core::option::Option<usize>) {
120                 let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx };
121                 (t, Some(t))
122             }
123 
124             fn nth(&mut self, n: usize) -> ::core::option::Option<<Self as Iterator>::Item> {
125                 let idx = self.idx + n + 1;
126                 if idx + self.back_idx > #variant_count {
127                     // We went past the end of the iterator. Freeze idx at #variant_count
128                     // so that it doesn't overflow if the user calls this repeatedly.
129                     // See PR #76 for context.
130                     self.idx = #variant_count;
131                     ::core::option::Option::None
132                 } else {
133                     self.idx = idx;
134                     self.get(idx - 1)
135                 }
136             }
137         }
138 
139         impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause {
140             fn len(&self) -> usize {
141                 self.size_hint().0
142             }
143         }
144 
145         impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause {
146             fn next_back(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
147                 let back_idx = self.back_idx + 1;
148 
149                 if self.idx + back_idx > #variant_count {
150                     // We went past the end of the iterator. Freeze back_idx at #variant_count
151                     // so that it doesn't overflow if the user calls this repeatedly.
152                     // See PR #76 for context.
153                     self.back_idx = #variant_count;
154                     ::core::option::Option::None
155                 } else {
156                     self.back_idx = back_idx;
157                     self.get(#variant_count - self.back_idx)
158                 }
159             }
160         }
161 
162         impl #impl_generics Clone for #iter_name #ty_generics #where_clause {
163             fn clone(&self) -> #iter_name #ty_generics {
164                 #iter_name {
165                     idx: self.idx,
166                     back_idx: self.back_idx,
167                     marker: self.marker.clone(),
168                 }
169             }
170         }
171     })
172 }
173