• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::internals::ast::{Container, Data};
2 use crate::internals::{attr, ungroup};
3 use proc_macro2::Span;
4 use std::collections::HashSet;
5 use syn::punctuated::{Pair, Punctuated};
6 use syn::Token;
7 
8 // Remove the default from every type parameter because in the generated impls
9 // they look like associated types: "error: associated type bindings are not
10 // allowed here".
without_defaults(generics: &syn::Generics) -> syn::Generics11 pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
12     syn::Generics {
13         params: generics
14             .params
15             .iter()
16             .map(|param| match param {
17                 syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
18                     eq_token: None,
19                     default: None,
20                     ..param.clone()
21                 }),
22                 _ => param.clone(),
23             })
24             .collect(),
25         ..generics.clone()
26     }
27 }
28 
with_where_predicates( generics: &syn::Generics, predicates: &[syn::WherePredicate], ) -> syn::Generics29 pub fn with_where_predicates(
30     generics: &syn::Generics,
31     predicates: &[syn::WherePredicate],
32 ) -> syn::Generics {
33     let mut generics = generics.clone();
34     generics
35         .make_where_clause()
36         .predicates
37         .extend(predicates.iter().cloned());
38     generics
39 }
40 
with_where_predicates_from_fields( cont: &Container, generics: &syn::Generics, from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>, ) -> syn::Generics41 pub fn with_where_predicates_from_fields(
42     cont: &Container,
43     generics: &syn::Generics,
44     from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
45 ) -> syn::Generics {
46     let predicates = cont
47         .data
48         .all_fields()
49         .filter_map(|field| from_field(&field.attrs))
50         .flat_map(<[syn::WherePredicate]>::to_vec);
51 
52     let mut generics = generics.clone();
53     generics.make_where_clause().predicates.extend(predicates);
54     generics
55 }
56 
with_where_predicates_from_variants( cont: &Container, generics: &syn::Generics, from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>, ) -> syn::Generics57 pub fn with_where_predicates_from_variants(
58     cont: &Container,
59     generics: &syn::Generics,
60     from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>,
61 ) -> syn::Generics {
62     let variants = match &cont.data {
63         Data::Enum(variants) => variants,
64         Data::Struct(_, _) => {
65             return generics.clone();
66         }
67     };
68 
69     let predicates = variants
70         .iter()
71         .filter_map(|variant| from_variant(&variant.attrs))
72         .flat_map(<[syn::WherePredicate]>::to_vec);
73 
74     let mut generics = generics.clone();
75     generics.make_where_clause().predicates.extend(predicates);
76     generics
77 }
78 
// Puts the given bound on any generic type parameters that are used in fields
// for which filter returns true.
//
// For example, the following struct needs the bound `A: Serialize, B:
// Serialize`.
//
//     struct S<'b, A, B: 'b, C> {
//         a: A,
//         b: Option<&'b B>,
//         #[serde(skip_serializing)]
//         c: C,
//     }
//
// In addition, when a selected field's type is itself a path rooted at a type
// parameter, such as `T::Assoc`, that whole path receives the bound as well
// (`T::Assoc: Serialize`).
//
// Any path whose last segment is named `PhantomData` is skipped entirely,
// including its generic arguments. Type parameters appearing only inside
// macro invocations are also not counted.
//
// The result is a clone of `generics` with the new predicates appended to its
// where clause. The predicates are listed in this order:
// - first, relevant type parameters, in declaration order;
// - then, associated-type usages, in field visitation order.
// A where clause is created even if no predicates are added.
pub fn with_bound(
    cont: &Container,
    generics: &syn::Generics,
    filter: fn(&attr::Field, Option<&attr::Variant>) -> bool,
    bound: &syn::Path,
) -> syn::Generics {
    struct FindTyParams<'ast> {
        // Set of all generic type parameters on the current struct (A, B, C in
        // the example). Initialized up front.
        all_type_params: HashSet<syn::Ident>,

        // Set of generic type parameters used in fields for which filter
        // returns true (A and B in the example). Filled in as the visitor sees
        // them.
        relevant_type_params: HashSet<syn::Ident>,

        // Types of selected fields that are paths rooted at one of the generic
        // type parameters (e.g. `T::Assoc`). May contain duplicates if several
        // fields share the same type.
        associated_type_usage: Vec<&'ast syn::TypePath>,
    }

    impl<'ast> FindTyParams<'ast> {
        fn visit_field(&mut self, field: &'ast syn::Field) {
            // `Pair::Punctuated` on the first segment means at least one more
            // segment follows it, i.e. the type is `T::Something`, not just `T`.
            if let syn::Type::Path(ty) = ungroup(&field.ty) {
                if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
                    if self.all_type_params.contains(&t.ident) {
                        self.associated_type_usage.push(ty);
                    }
                }
            }
            self.visit_type(&field.ty);
        }

        fn visit_path(&mut self, path: &'ast syn::Path) {
            if let Some(seg) = path.segments.last() {
                if seg.ident == "PhantomData" {
                    // Hardcoded exception, because PhantomData<T> implements
                    // Serialize and Deserialize whether or not T implements it.
                    return;
                }
            }
            // Only a bare single-segment path (`T`, not `::T` or `T::X`) counts
            // as a direct use of a type parameter.
            if path.leading_colon.is_none() && path.segments.len() == 1 {
                let id = &path.segments[0].ident;
                if self.all_type_params.contains(id) {
                    self.relevant_type_params.insert(id.clone());
                }
            }
            // Recurse into generic arguments of every segment, e.g. the `T` in
            // `Vec<T>`.
            for segment in &path.segments {
                self.visit_path_segment(segment);
            }
        }

        // Everything below is simply traversing the syntax tree.

        fn visit_type(&mut self, ty: &'ast syn::Type) {
            // The inner cfg_attr only takes effect in test builds with
            // `cfg(exhaustive)`, where it reports syn::Type variants that are
            // not named explicitly below.
            match ty {
                #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
                syn::Type::Array(ty) => self.visit_type(&ty.elem),
                syn::Type::BareFn(ty) => {
                    for arg in &ty.inputs {
                        self.visit_type(&arg.ty);
                    }
                    self.visit_return_type(&ty.output);
                }
                syn::Type::Group(ty) => self.visit_type(&ty.elem),
                syn::Type::ImplTrait(ty) => {
                    for bound in &ty.bounds {
                        self.visit_type_param_bound(bound);
                    }
                }
                syn::Type::Macro(ty) => self.visit_macro(&ty.mac),
                syn::Type::Paren(ty) => self.visit_type(&ty.elem),
                syn::Type::Path(ty) => {
                    // For `<X as Trait>::Assoc`, the self type `X` is visited too.
                    if let Some(qself) = &ty.qself {
                        self.visit_type(&qself.ty);
                    }
                    self.visit_path(&ty.path);
                }
                syn::Type::Ptr(ty) => self.visit_type(&ty.elem),
                syn::Type::Reference(ty) => self.visit_type(&ty.elem),
                syn::Type::Slice(ty) => self.visit_type(&ty.elem),
                syn::Type::TraitObject(ty) => {
                    for bound in &ty.bounds {
                        self.visit_type_param_bound(bound);
                    }
                }
                syn::Type::Tuple(ty) => {
                    for elem in &ty.elems {
                        self.visit_type(elem);
                    }
                }

                syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {}

                // Variants not listed above are ignored.
                _ => {}
            }
        }

        fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) {
            self.visit_path_arguments(&segment.arguments);
        }

        fn visit_path_arguments(&mut self, arguments: &'ast syn::PathArguments) {
            match arguments {
                syn::PathArguments::None => {}
                syn::PathArguments::AngleBracketed(arguments) => {
                    for arg in &arguments.args {
                        match arg {
                            #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
                            syn::GenericArgument::Type(arg) => self.visit_type(arg),
                            // `Item = T` inside e.g. `Iterator<Item = T>`: the
                            // bound type on the right-hand side is visited.
                            syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty),
                            syn::GenericArgument::Lifetime(_)
                            | syn::GenericArgument::Const(_)
                            | syn::GenericArgument::AssocConst(_)
                            | syn::GenericArgument::Constraint(_) => {}
                            _ => {}
                        }
                    }
                }
                // Fn-sugar arguments such as `Fn(A) -> B`.
                syn::PathArguments::Parenthesized(arguments) => {
                    for argument in &arguments.inputs {
                        self.visit_type(argument);
                    }
                    self.visit_return_type(&arguments.output);
                }
            }
        }

        fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) {
            match return_type {
                syn::ReturnType::Default => {}
                syn::ReturnType::Type(_, output) => self.visit_type(output),
            }
        }

        fn visit_type_param_bound(&mut self, bound: &'ast syn::TypeParamBound) {
            match bound {
                #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
                syn::TypeParamBound::Trait(bound) => self.visit_path(&bound.path),
                syn::TypeParamBound::Lifetime(_)
                | syn::TypeParamBound::PreciseCapture(_)
                | syn::TypeParamBound::Verbatim(_) => {}
                _ => {}
            }
        }

        // Type parameter should not be considered used by a macro path.
        //
        //     struct TypeMacro<T> {
        //         mac: T!(),
        //         marker: PhantomData<T>,
        //     }
        fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
    }

    let all_type_params = generics
        .type_params()
        .map(|param| param.ident.clone())
        .collect();

    let mut visitor = FindTyParams {
        all_type_params,
        relevant_type_params: HashSet::new(),
        associated_type_usage: Vec::new(),
    };
    // Visit only the fields selected by `filter`. Enum fields are filtered
    // together with their variant's attributes; struct fields are filtered
    // with `None`.
    match &cont.data {
        Data::Enum(variants) => {
            for variant in variants {
                let relevant_fields = variant
                    .fields
                    .iter()
                    .filter(|field| filter(&field.attrs, Some(&variant.attrs)));
                for field in relevant_fields {
                    visitor.visit_field(field.original);
                }
            }
        }
        Data::Struct(_, fields) => {
            for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
                visitor.visit_field(field.original);
            }
        }
    }

    let relevant_type_params = visitor.relevant_type_params;
    let associated_type_usage = visitor.associated_type_usage;
    // Iterate `generics.type_params()` rather than the HashSet so the emitted
    // predicates follow declaration order.
    let new_predicates = generics
        .type_params()
        .map(|param| param.ident.clone())
        .filter(|id| relevant_type_params.contains(id))
        .map(|id| syn::TypePath {
            qself: None,
            path: id.into(),
        })
        .chain(associated_type_usage.into_iter().cloned())
        .map(|bounded_ty| {
            syn::WherePredicate::Type(syn::PredicateType {
                lifetimes: None,
                // the type parameter that is being bounded e.g. T
                bounded_ty: syn::Type::Path(bounded_ty),
                colon_token: <Token![:]>::default(),
                // the bound e.g. Serialize
                bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
                    paren_token: None,
                    modifier: syn::TraitBoundModifier::None,
                    lifetimes: None,
                    path: bound.clone(),
                })]
                .into_iter()
                .collect(),
            })
        });

    let mut generics = generics.clone();
    generics
        .make_where_clause()
        .predicates
        .extend(new_predicates);
    generics
}
311 
with_self_bound( cont: &Container, generics: &syn::Generics, bound: &syn::Path, ) -> syn::Generics312 pub fn with_self_bound(
313     cont: &Container,
314     generics: &syn::Generics,
315     bound: &syn::Path,
316 ) -> syn::Generics {
317     let mut generics = generics.clone();
318     generics
319         .make_where_clause()
320         .predicates
321         .push(syn::WherePredicate::Type(syn::PredicateType {
322             lifetimes: None,
323             // the type that is being bounded e.g. MyStruct<'a, T>
324             bounded_ty: type_of_item(cont),
325             colon_token: <Token![:]>::default(),
326             // the bound e.g. Default
327             bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
328                 paren_token: None,
329                 modifier: syn::TraitBoundModifier::None,
330                 lifetimes: None,
331                 path: bound.clone(),
332             })]
333             .into_iter()
334             .collect(),
335         }));
336     generics
337 }
338 
with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics339 pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics {
340     let bound = syn::Lifetime::new(lifetime, Span::call_site());
341     let def = syn::LifetimeParam {
342         attrs: Vec::new(),
343         lifetime: bound.clone(),
344         colon_token: None,
345         bounds: Punctuated::new(),
346     };
347 
348     let params = Some(syn::GenericParam::Lifetime(def))
349         .into_iter()
350         .chain(generics.params.iter().cloned().map(|mut param| {
351             match &mut param {
352                 syn::GenericParam::Lifetime(param) => {
353                     param.bounds.push(bound.clone());
354                 }
355                 syn::GenericParam::Type(param) => {
356                     param
357                         .bounds
358                         .push(syn::TypeParamBound::Lifetime(bound.clone()));
359                 }
360                 syn::GenericParam::Const(_) => {}
361             }
362             param
363         }))
364         .collect();
365 
366     syn::Generics {
367         params,
368         ..generics.clone()
369     }
370 }
371 
type_of_item(cont: &Container) -> syn::Type372 fn type_of_item(cont: &Container) -> syn::Type {
373     syn::Type::Path(syn::TypePath {
374         qself: None,
375         path: syn::Path {
376             leading_colon: None,
377             segments: vec![syn::PathSegment {
378                 ident: cont.ident.clone(),
379                 arguments: syn::PathArguments::AngleBracketed(
380                     syn::AngleBracketedGenericArguments {
381                         colon2_token: None,
382                         lt_token: <Token![<]>::default(),
383                         args: cont
384                             .generics
385                             .params
386                             .iter()
387                             .map(|param| match param {
388                                 syn::GenericParam::Type(param) => {
389                                     syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
390                                         qself: None,
391                                         path: param.ident.clone().into(),
392                                     }))
393                                 }
394                                 syn::GenericParam::Lifetime(param) => {
395                                     syn::GenericArgument::Lifetime(param.lifetime.clone())
396                                 }
397                                 syn::GenericParam::Const(_) => {
398                                     panic!("Serde does not support const generics yet");
399                                 }
400                             })
401                             .collect(),
402                         gt_token: <Token![>]>::default(),
403                     },
404                 ),
405             }]
406             .into_iter()
407             .collect(),
408         },
409     })
410 }
411