• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::collections::HashSet;
2 
3 use syn;
4 use syn::punctuated::{Pair, Punctuated};
5 
6 use internals::ast::{Container, Data};
7 use internals::{attr, ungroup};
8 
9 use proc_macro2::Span;
10 
11 // Remove the default from every type parameter because in the generated impls
12 // they look like associated types: "error: associated type bindings are not
13 // allowed here".
without_defaults(generics: &syn::Generics) -> syn::Generics14 pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
15     syn::Generics {
16         params: generics
17             .params
18             .iter()
19             .map(|param| match param {
20                 syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
21                     eq_token: None,
22                     default: None,
23                     ..param.clone()
24                 }),
25                 _ => param.clone(),
26             })
27             .collect(),
28         ..generics.clone()
29     }
30 }
31 
with_where_predicates( generics: &syn::Generics, predicates: &[syn::WherePredicate], ) -> syn::Generics32 pub fn with_where_predicates(
33     generics: &syn::Generics,
34     predicates: &[syn::WherePredicate],
35 ) -> syn::Generics {
36     let mut generics = generics.clone();
37     generics
38         .make_where_clause()
39         .predicates
40         .extend(predicates.iter().cloned());
41     generics
42 }
43 
with_where_predicates_from_fields( cont: &Container, generics: &syn::Generics, from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>, ) -> syn::Generics44 pub fn with_where_predicates_from_fields(
45     cont: &Container,
46     generics: &syn::Generics,
47     from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
48 ) -> syn::Generics {
49     let predicates = cont
50         .data
51         .all_fields()
52         .filter_map(|field| from_field(&field.attrs))
53         .flat_map(<[syn::WherePredicate]>::to_vec);
54 
55     let mut generics = generics.clone();
56     generics.make_where_clause().predicates.extend(predicates);
57     generics
58 }
59 
with_where_predicates_from_variants( cont: &Container, generics: &syn::Generics, from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>, ) -> syn::Generics60 pub fn with_where_predicates_from_variants(
61     cont: &Container,
62     generics: &syn::Generics,
63     from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>,
64 ) -> syn::Generics {
65     let variants = match &cont.data {
66         Data::Enum(variants) => variants,
67         Data::Struct(_, _) => {
68             return generics.clone();
69         }
70     };
71 
72     let predicates = variants
73         .iter()
74         .filter_map(|variant| from_variant(&variant.attrs))
75         .flat_map(<[syn::WherePredicate]>::to_vec);
76 
77     let mut generics = generics.clone();
78     generics.make_where_clause().predicates.extend(predicates);
79     generics
80 }
81 
82 // Puts the given bound on any generic type parameters that are used in fields
83 // for which filter returns true.
84 //
85 // For example, the following struct needs the bound `A: Serialize, B:
86 // Serialize`.
87 //
88 //     struct S<'b, A, B: 'b, C> {
89 //         a: A,
90 //         b: Option<&'b B>
91 //         #[serde(skip_serializing)]
92 //         c: C,
93 //     }
with_bound( cont: &Container, generics: &syn::Generics, filter: fn(&attr::Field, Option<&attr::Variant>) -> bool, bound: &syn::Path, ) -> syn::Generics94 pub fn with_bound(
95     cont: &Container,
96     generics: &syn::Generics,
97     filter: fn(&attr::Field, Option<&attr::Variant>) -> bool,
98     bound: &syn::Path,
99 ) -> syn::Generics {
100     struct FindTyParams<'ast> {
101         // Set of all generic type parameters on the current struct (A, B, C in
102         // the example). Initialized up front.
103         all_type_params: HashSet<syn::Ident>,
104 
105         // Set of generic type parameters used in fields for which filter
106         // returns true (A and B in the example). Filled in as the visitor sees
107         // them.
108         relevant_type_params: HashSet<syn::Ident>,
109 
110         // Fields whose type is an associated type of one of the generic type
111         // parameters.
112         associated_type_usage: Vec<&'ast syn::TypePath>,
113     }
114 
115     impl<'ast> FindTyParams<'ast> {
116         fn visit_field(&mut self, field: &'ast syn::Field) {
117             if let syn::Type::Path(ty) = ungroup(&field.ty) {
118                 if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
119                     if self.all_type_params.contains(&t.ident) {
120                         self.associated_type_usage.push(ty);
121                     }
122                 }
123             }
124             self.visit_type(&field.ty);
125         }
126 
127         fn visit_path(&mut self, path: &'ast syn::Path) {
128             if let Some(seg) = path.segments.last() {
129                 if seg.ident == "PhantomData" {
130                     // Hardcoded exception, because PhantomData<T> implements
131                     // Serialize and Deserialize whether or not T implements it.
132                     return;
133                 }
134             }
135             if path.leading_colon.is_none() && path.segments.len() == 1 {
136                 let id = &path.segments[0].ident;
137                 if self.all_type_params.contains(id) {
138                     self.relevant_type_params.insert(id.clone());
139                 }
140             }
141             for segment in &path.segments {
142                 self.visit_path_segment(segment);
143             }
144         }
145 
146         // Everything below is simply traversing the syntax tree.
147 
148         fn visit_type(&mut self, ty: &'ast syn::Type) {
149             match ty {
150                 syn::Type::Array(ty) => self.visit_type(&ty.elem),
151                 syn::Type::BareFn(ty) => {
152                     for arg in &ty.inputs {
153                         self.visit_type(&arg.ty);
154                     }
155                     self.visit_return_type(&ty.output);
156                 }
157                 syn::Type::Group(ty) => self.visit_type(&ty.elem),
158                 syn::Type::ImplTrait(ty) => {
159                     for bound in &ty.bounds {
160                         self.visit_type_param_bound(bound);
161                     }
162                 }
163                 syn::Type::Macro(ty) => self.visit_macro(&ty.mac),
164                 syn::Type::Paren(ty) => self.visit_type(&ty.elem),
165                 syn::Type::Path(ty) => {
166                     if let Some(qself) = &ty.qself {
167                         self.visit_type(&qself.ty);
168                     }
169                     self.visit_path(&ty.path);
170                 }
171                 syn::Type::Ptr(ty) => self.visit_type(&ty.elem),
172                 syn::Type::Reference(ty) => self.visit_type(&ty.elem),
173                 syn::Type::Slice(ty) => self.visit_type(&ty.elem),
174                 syn::Type::TraitObject(ty) => {
175                     for bound in &ty.bounds {
176                         self.visit_type_param_bound(bound);
177                     }
178                 }
179                 syn::Type::Tuple(ty) => {
180                     for elem in &ty.elems {
181                         self.visit_type(elem);
182                     }
183                 }
184 
185                 syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {}
186 
187                 #[cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
188                 _ => {}
189             }
190         }
191 
192         fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) {
193             self.visit_path_arguments(&segment.arguments);
194         }
195 
196         fn visit_path_arguments(&mut self, arguments: &'ast syn::PathArguments) {
197             match arguments {
198                 syn::PathArguments::None => {}
199                 syn::PathArguments::AngleBracketed(arguments) => {
200                     for arg in &arguments.args {
201                         match arg {
202                             syn::GenericArgument::Type(arg) => self.visit_type(arg),
203                             syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty),
204                             syn::GenericArgument::Lifetime(_)
205                             | syn::GenericArgument::Const(_)
206                             | syn::GenericArgument::AssocConst(_)
207                             | syn::GenericArgument::Constraint(_) => {}
208                             #[cfg_attr(
209                                 all(test, exhaustive),
210                                 deny(non_exhaustive_omitted_patterns)
211                             )]
212                             _ => {}
213                         }
214                     }
215                 }
216                 syn::PathArguments::Parenthesized(arguments) => {
217                     for argument in &arguments.inputs {
218                         self.visit_type(argument);
219                     }
220                     self.visit_return_type(&arguments.output);
221                 }
222             }
223         }
224 
225         fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) {
226             match return_type {
227                 syn::ReturnType::Default => {}
228                 syn::ReturnType::Type(_, output) => self.visit_type(output),
229             }
230         }
231 
232         fn visit_type_param_bound(&mut self, bound: &'ast syn::TypeParamBound) {
233             match bound {
234                 syn::TypeParamBound::Trait(bound) => self.visit_path(&bound.path),
235                 syn::TypeParamBound::Lifetime(_) | syn::TypeParamBound::Verbatim(_) => {}
236                 #[cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
237                 _ => {}
238             }
239         }
240 
241         // Type parameter should not be considered used by a macro path.
242         //
243         //     struct TypeMacro<T> {
244         //         mac: T!(),
245         //         marker: PhantomData<T>,
246         //     }
247         fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
248     }
249 
250     let all_type_params = generics
251         .type_params()
252         .map(|param| param.ident.clone())
253         .collect();
254 
255     let mut visitor = FindTyParams {
256         all_type_params,
257         relevant_type_params: HashSet::new(),
258         associated_type_usage: Vec::new(),
259     };
260     match &cont.data {
261         Data::Enum(variants) => {
262             for variant in variants.iter() {
263                 let relevant_fields = variant
264                     .fields
265                     .iter()
266                     .filter(|field| filter(&field.attrs, Some(&variant.attrs)));
267                 for field in relevant_fields {
268                     visitor.visit_field(field.original);
269                 }
270             }
271         }
272         Data::Struct(_, fields) => {
273             for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
274                 visitor.visit_field(field.original);
275             }
276         }
277     }
278 
279     let relevant_type_params = visitor.relevant_type_params;
280     let associated_type_usage = visitor.associated_type_usage;
281     let new_predicates = generics
282         .type_params()
283         .map(|param| param.ident.clone())
284         .filter(|id| relevant_type_params.contains(id))
285         .map(|id| syn::TypePath {
286             qself: None,
287             path: id.into(),
288         })
289         .chain(associated_type_usage.into_iter().cloned())
290         .map(|bounded_ty| {
291             syn::WherePredicate::Type(syn::PredicateType {
292                 lifetimes: None,
293                 // the type parameter that is being bounded e.g. T
294                 bounded_ty: syn::Type::Path(bounded_ty),
295                 colon_token: <Token![:]>::default(),
296                 // the bound e.g. Serialize
297                 bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
298                     paren_token: None,
299                     modifier: syn::TraitBoundModifier::None,
300                     lifetimes: None,
301                     path: bound.clone(),
302                 })]
303                 .into_iter()
304                 .collect(),
305             })
306         });
307 
308     let mut generics = generics.clone();
309     generics
310         .make_where_clause()
311         .predicates
312         .extend(new_predicates);
313     generics
314 }
315 
with_self_bound( cont: &Container, generics: &syn::Generics, bound: &syn::Path, ) -> syn::Generics316 pub fn with_self_bound(
317     cont: &Container,
318     generics: &syn::Generics,
319     bound: &syn::Path,
320 ) -> syn::Generics {
321     let mut generics = generics.clone();
322     generics
323         .make_where_clause()
324         .predicates
325         .push(syn::WherePredicate::Type(syn::PredicateType {
326             lifetimes: None,
327             // the type that is being bounded e.g. MyStruct<'a, T>
328             bounded_ty: type_of_item(cont),
329             colon_token: <Token![:]>::default(),
330             // the bound e.g. Default
331             bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
332                 paren_token: None,
333                 modifier: syn::TraitBoundModifier::None,
334                 lifetimes: None,
335                 path: bound.clone(),
336             })]
337             .into_iter()
338             .collect(),
339         }));
340     generics
341 }
342 
with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics343 pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics {
344     let bound = syn::Lifetime::new(lifetime, Span::call_site());
345     let def = syn::LifetimeParam {
346         attrs: Vec::new(),
347         lifetime: bound.clone(),
348         colon_token: None,
349         bounds: Punctuated::new(),
350     };
351 
352     let params = Some(syn::GenericParam::Lifetime(def))
353         .into_iter()
354         .chain(generics.params.iter().cloned().map(|mut param| {
355             match &mut param {
356                 syn::GenericParam::Lifetime(param) => {
357                     param.bounds.push(bound.clone());
358                 }
359                 syn::GenericParam::Type(param) => {
360                     param
361                         .bounds
362                         .push(syn::TypeParamBound::Lifetime(bound.clone()));
363                 }
364                 syn::GenericParam::Const(_) => {}
365             }
366             param
367         }))
368         .collect();
369 
370     syn::Generics {
371         params,
372         ..generics.clone()
373     }
374 }
375 
type_of_item(cont: &Container) -> syn::Type376 fn type_of_item(cont: &Container) -> syn::Type {
377     syn::Type::Path(syn::TypePath {
378         qself: None,
379         path: syn::Path {
380             leading_colon: None,
381             segments: vec![syn::PathSegment {
382                 ident: cont.ident.clone(),
383                 arguments: syn::PathArguments::AngleBracketed(
384                     syn::AngleBracketedGenericArguments {
385                         colon2_token: None,
386                         lt_token: <Token![<]>::default(),
387                         args: cont
388                             .generics
389                             .params
390                             .iter()
391                             .map(|param| match param {
392                                 syn::GenericParam::Type(param) => {
393                                     syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
394                                         qself: None,
395                                         path: param.ident.clone().into(),
396                                     }))
397                                 }
398                                 syn::GenericParam::Lifetime(param) => {
399                                     syn::GenericArgument::Lifetime(param.lifetime.clone())
400                                 }
401                                 syn::GenericParam::Const(_) => {
402                                     panic!("Serde does not support const generics yet");
403                                 }
404                             })
405                             .collect(),
406                         gt_token: <Token![>]>::default(),
407                     },
408                 ),
409             }]
410             .into_iter()
411             .collect(),
412         },
413     })
414 }
415