• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::{file, full, gen};
2 use anyhow::Result;
3 use proc_macro2::{Ident, Span, TokenStream};
4 use quote::{format_ident, quote};
5 use syn::Index;
6 use syn_codegen::{Data, Definitions, Features, Node, Type};
7 
/// Destination path handed to `file::write` by `generate`; the generated
/// `Fold` trait and its free `fold_*` functions are written there.
const FOLD_SRC: &str = "src/gen/fold.rs";
9 
simple_visit(item: &str, name: &TokenStream) -> TokenStream10 fn simple_visit(item: &str, name: &TokenStream) -> TokenStream {
11     let ident = gen::under_name(item);
12     let method = format_ident!("fold_{}", ident);
13     quote! {
14         f.#method(#name)
15     }
16 }
17 
visit( ty: &Type, features: &Features, defs: &Definitions, name: &TokenStream, ) -> Option<TokenStream>18 fn visit(
19     ty: &Type,
20     features: &Features,
21     defs: &Definitions,
22     name: &TokenStream,
23 ) -> Option<TokenStream> {
24     match ty {
25         Type::Box(t) => {
26             let res = visit(t, features, defs, &quote!(*#name))?;
27             Some(quote! {
28                 Box::new(#res)
29             })
30         }
31         Type::Vec(t) => {
32             let operand = quote!(it);
33             let val = visit(t, features, defs, &operand)?;
34             Some(quote! {
35                 FoldHelper::lift(#name, |it| #val)
36             })
37         }
38         Type::Punctuated(p) => {
39             let operand = quote!(it);
40             let val = visit(&p.element, features, defs, &operand)?;
41             Some(quote! {
42                 FoldHelper::lift(#name, |it| #val)
43             })
44         }
45         Type::Option(t) => {
46             let it = quote!(it);
47             let val = visit(t, features, defs, &it)?;
48             Some(quote! {
49                 (#name).map(|it| #val)
50             })
51         }
52         Type::Tuple(t) => {
53             let mut code = TokenStream::new();
54             for (i, elem) in t.iter().enumerate() {
55                 let i = Index::from(i);
56                 let it = quote!((#name).#i);
57                 let val = visit(elem, features, defs, &it).unwrap_or(it);
58                 code.extend(val);
59                 code.extend(quote!(,));
60             }
61             Some(quote! {
62                 (#code)
63             })
64         }
65         Type::Syn(t) => {
66             fn requires_full(features: &Features) -> bool {
67                 features.any.contains("full") && features.any.len() == 1
68             }
69             let mut res = simple_visit(t, name);
70             let target = defs.types.iter().find(|ty| ty.ident == *t).unwrap();
71             if requires_full(&target.features) && !requires_full(features) {
72                 res = quote!(full!(#res));
73             }
74             Some(res)
75         }
76         Type::Ext(t) if gen::TERMINAL_TYPES.contains(&&t[..]) => Some(simple_visit(t, name)),
77         Type::Ext(_) | Type::Std(_) | Type::Token(_) | Type::Group(_) => None,
78     }
79 }
80 
node(traits: &mut TokenStream, impls: &mut TokenStream, s: &Node, defs: &Definitions)81 fn node(traits: &mut TokenStream, impls: &mut TokenStream, s: &Node, defs: &Definitions) {
82     let under_name = gen::under_name(&s.ident);
83     let ty = Ident::new(&s.ident, Span::call_site());
84     let fold_fn = format_ident!("fold_{}", under_name);
85 
86     let mut fold_impl = TokenStream::new();
87 
88     match &s.data {
89         Data::Enum(variants) => {
90             let mut fold_variants = TokenStream::new();
91 
92             for (variant, fields) in variants {
93                 let variant_ident = Ident::new(variant, Span::call_site());
94 
95                 if fields.is_empty() {
96                     fold_variants.extend(quote! {
97                         #ty::#variant_ident => {
98                             #ty::#variant_ident
99                         }
100                     });
101                 } else {
102                     let mut bind_fold_fields = TokenStream::new();
103                     let mut fold_fields = TokenStream::new();
104 
105                     for (idx, ty) in fields.iter().enumerate() {
106                         let binding = format_ident!("_binding_{}", idx);
107 
108                         bind_fold_fields.extend(quote! {
109                             #binding,
110                         });
111 
112                         let owned_binding = quote!(#binding);
113 
114                         fold_fields.extend(
115                             visit(ty, &s.features, defs, &owned_binding).unwrap_or(owned_binding),
116                         );
117 
118                         fold_fields.extend(quote!(,));
119                     }
120 
121                     fold_variants.extend(quote! {
122                         #ty::#variant_ident(#bind_fold_fields) => {
123                             #ty::#variant_ident(
124                                 #fold_fields
125                             )
126                         }
127                     });
128                 }
129             }
130 
131             fold_impl.extend(quote! {
132                 match node {
133                     #fold_variants
134                 }
135             });
136         }
137         Data::Struct(fields) => {
138             let mut fold_fields = TokenStream::new();
139 
140             for (field, ty) in fields {
141                 let id = Ident::new(field, Span::call_site());
142                 let ref_toks = quote!(node.#id);
143 
144                 let fold = visit(ty, &s.features, defs, &ref_toks).unwrap_or(ref_toks);
145 
146                 fold_fields.extend(quote! {
147                     #id: #fold,
148                 });
149             }
150 
151             if fields.is_empty() {
152                 if ty == "Ident" {
153                     fold_impl.extend(quote! {
154                         let mut node = node;
155                         let span = f.fold_span(node.span());
156                         node.set_span(span);
157                     });
158                 }
159                 fold_impl.extend(quote! {
160                     node
161                 });
162             } else {
163                 fold_impl.extend(quote! {
164                     #ty {
165                         #fold_fields
166                     }
167                 });
168             }
169         }
170         Data::Private => {
171             if ty == "Ident" {
172                 fold_impl.extend(quote! {
173                     let mut node = node;
174                     let span = f.fold_span(node.span());
175                     node.set_span(span);
176                 });
177             }
178             fold_impl.extend(quote! {
179                 node
180             });
181         }
182     }
183 
184     let fold_span_only =
185         s.data == Data::Private && !gen::TERMINAL_TYPES.contains(&s.ident.as_str());
186     if fold_span_only {
187         fold_impl = quote! {
188             let span = f.fold_span(node.span());
189             let mut node = node;
190             node.set_span(span);
191             node
192         };
193     }
194 
195     traits.extend(quote! {
196         fn #fold_fn(&mut self, i: #ty) -> #ty {
197             #fold_fn(self, i)
198         }
199     });
200 
201     impls.extend(quote! {
202         pub fn #fold_fn<F>(f: &mut F, node: #ty) -> #ty
203         where
204             F: Fold + ?Sized,
205         {
206             #fold_impl
207         }
208     });
209 }
210 
generate(defs: &Definitions) -> Result<()>211 pub fn generate(defs: &Definitions) -> Result<()> {
212     let (traits, impls) = gen::traverse(defs, node);
213     let full_macro = full::get_macro();
214     file::write(
215         FOLD_SRC,
216         quote! {
217             // Unreachable code is generated sometimes without the full feature.
218             #![allow(unreachable_code, unused_variables)]
219             #![allow(
220                 clippy::match_wildcard_for_single_variants,
221                 clippy::needless_match,
222                 clippy::needless_pass_by_ref_mut,
223             )]
224 
225             #[cfg(any(feature = "full", feature = "derive"))]
226             use crate::gen::helper::fold::*;
227             use crate::*;
228             use proc_macro2::Span;
229 
230             #full_macro
231 
232             /// Syntax tree traversal to transform the nodes of an owned syntax tree.
233             ///
234             /// See the [module documentation] for details.
235             ///
236             /// [module documentation]: self
237             pub trait Fold {
238                 #traits
239             }
240 
241             #impls
242         },
243     )?;
244     Ok(())
245 }
246