use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{Data, DeriveInput, Fields, Ident}; use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties}; pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result { let name = &ast.ident; let gen = &ast.generics; let (impl_generics, ty_generics, where_clause) = gen.split_for_impl(); let vis = &ast.vis; let type_properties = ast.get_type_properties()?; let strum_module_path = type_properties.crate_module_path(); let doc_comment = format!("An iterator over the variants of [{}]", name); if gen.lifetimes().count() > 0 { return Err(syn::Error::new( Span::call_site(), "This macro doesn't support enums with lifetimes. \ The resulting enums would be unbounded.", )); } let phantom_data = if gen.type_params().count() > 0 { let g = gen.type_params().map(|param| ¶m.ident); quote! { < ( #(#g),* ) > } } else { quote! { < () > } }; let variants = match &ast.data { Data::Enum(v) => &v.variants, _ => return Err(non_enum_error()), }; let mut arms = Vec::new(); let mut idx = 0usize; for variant in variants { if variant.get_variant_properties()?.disabled.is_some() { continue; } let ident = &variant.ident; let params = match &variant.fields { Fields::Unit => quote! {}, Fields::Unnamed(fields) => { let defaults = ::core::iter::repeat(quote!(::core::default::Default::default())) .take(fields.unnamed.len()); quote! { (#(#defaults),*) } } Fields::Named(fields) => { let fields = fields .named .iter() .map(|field| field.ident.as_ref().unwrap()); quote! { {#(#fields: ::core::default::Default::default()),*} } } }; arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)}); idx += 1; } let variant_count = arms.len(); arms.push(quote! { _ => ::core::option::Option::None }); let iter_name = syn::parse_str::(&format!("{}Iter", name)).unwrap(); // Create a string literal "MyEnumIter" to use in the debug impl. let iter_name_debug_struct = syn::parse_str::(&format!("\"{}\"", iter_name)).unwrap(); Ok(quote! { #[doc = #doc_comment] #[allow( missing_copy_implementations, )] #vis struct #iter_name #ty_generics { idx: usize, back_idx: usize, marker: ::core::marker::PhantomData #phantom_data, } impl #impl_generics ::core::fmt::Debug for #iter_name #ty_generics #where_clause { fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { // We don't know if the variants implement debug themselves so the only thing we // can really show is how many elements are left. f.debug_struct(#iter_name_debug_struct) .field("len", &self.len()) .finish() } } impl #impl_generics #iter_name #ty_generics #where_clause { fn get(&self, idx: usize) -> ::core::option::Option<#name #ty_generics> { match idx { #(#arms),* } } } impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause { type Iterator = #iter_name #ty_generics; fn iter() -> #iter_name #ty_generics { #iter_name { idx: 0, back_idx: 0, marker: ::core::marker::PhantomData, } } } impl #impl_generics Iterator for #iter_name #ty_generics #where_clause { type Item = #name #ty_generics; fn next(&mut self) -> ::core::option::Option<::Item> { self.nth(0) } fn size_hint(&self) -> (usize, ::core::option::Option) { let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx }; (t, Some(t)) } fn nth(&mut self, n: usize) -> ::core::option::Option<::Item> { let idx = self.idx + n + 1; if idx + self.back_idx > #variant_count { // We went past the end of the iterator. Freeze idx at #variant_count // so that it doesn't overflow if the user calls this repeatedly. // See PR #76 for context. self.idx = #variant_count; ::core::option::Option::None } else { self.idx = idx; self.get(idx - 1) } } } impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause { fn len(&self) -> usize { self.size_hint().0 } } impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause { fn next_back(&mut self) -> ::core::option::Option<::Item> { let back_idx = self.back_idx + 1; if self.idx + back_idx > #variant_count { // We went past the end of the iterator. Freeze back_idx at #variant_count // so that it doesn't overflow if the user calls this repeatedly. // See PR #76 for context. self.back_idx = #variant_count; ::core::option::Option::None } else { self.back_idx = back_idx; self.get(#variant_count - self.back_idx) } } } impl #impl_generics Clone for #iter_name #ty_generics #where_clause { fn clone(&self) -> #iter_name #ty_generics { #iter_name { idx: self.idx, back_idx: self.back_idx, marker: self.marker.clone(), } } } }) }