1 use std::collections::HashSet;
2
3 use syn;
4 use syn::punctuated::{Pair, Punctuated};
5
6 use internals::ast::{Container, Data};
7 use internals::{attr, ungroup};
8
9 use proc_macro2::Span;
10
11 // Remove the default from every type parameter because in the generated impls
12 // they look like associated types: "error: associated type bindings are not
13 // allowed here".
without_defaults(generics: &syn::Generics) -> syn::Generics14 pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
15 syn::Generics {
16 params: generics
17 .params
18 .iter()
19 .map(|param| match param {
20 syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
21 eq_token: None,
22 default: None,
23 ..param.clone()
24 }),
25 _ => param.clone(),
26 })
27 .collect(),
28 ..generics.clone()
29 }
30 }
31
with_where_predicates( generics: &syn::Generics, predicates: &[syn::WherePredicate], ) -> syn::Generics32 pub fn with_where_predicates(
33 generics: &syn::Generics,
34 predicates: &[syn::WherePredicate],
35 ) -> syn::Generics {
36 let mut generics = generics.clone();
37 generics
38 .make_where_clause()
39 .predicates
40 .extend(predicates.iter().cloned());
41 generics
42 }
43
with_where_predicates_from_fields( cont: &Container, generics: &syn::Generics, from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>, ) -> syn::Generics44 pub fn with_where_predicates_from_fields(
45 cont: &Container,
46 generics: &syn::Generics,
47 from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
48 ) -> syn::Generics {
49 let predicates = cont
50 .data
51 .all_fields()
52 .flat_map(|field| from_field(&field.attrs))
53 .flat_map(|predicates| predicates.to_vec());
54
55 let mut generics = generics.clone();
56 generics.make_where_clause().predicates.extend(predicates);
57 generics
58 }
59
with_where_predicates_from_variants( cont: &Container, generics: &syn::Generics, from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>, ) -> syn::Generics60 pub fn with_where_predicates_from_variants(
61 cont: &Container,
62 generics: &syn::Generics,
63 from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>,
64 ) -> syn::Generics {
65 let variants = match &cont.data {
66 Data::Enum(variants) => variants,
67 Data::Struct(_, _) => {
68 return generics.clone();
69 }
70 };
71
72 let predicates = variants
73 .iter()
74 .flat_map(|variant| from_variant(&variant.attrs))
75 .flat_map(|predicates| predicates.to_vec());
76
77 let mut generics = generics.clone();
78 generics.make_where_clause().predicates.extend(predicates);
79 generics
80 }
81
// Puts the given bound on any generic type parameters that are used in fields
// for which filter returns true.
//
// For example, the following struct needs the bound `A: Serialize, B:
// Serialize`.
//
//     struct S<'b, A, B: 'b, C> {
//         a: A,
//         b: Option<&'b B>,
//         #[serde(skip_serializing)]
//         c: C,
//     }
//
// Besides bare type parameters, a field whose (top-level) type is a path that
// starts with a type parameter followed by `::` (e.g. `T::Assoc`) gets the
// bound applied to that whole path (`T::Assoc: Serialize`).
//
// Parameters:
// - `cont`: the container whose fields are inspected.
// - `generics`: the generics to extend; a modified copy is returned.
// - `filter`: decides per field (plus, for enums, the field's variant
//   attributes) whether the field contributes to the inferred bounds.
// - `bound`: the trait path to apply, e.g. `Serialize`.
//
// Returns a copy of `generics` with one where-predicate appended per relevant
// type parameter (in the order the parameters are declared), followed by one
// per associated-type field (in field order, not deduplicated).
//
// Inference limits: only single-segment paths count as a use of a type
// parameter, so a nested `Vec<T::Assoc>` produces no bound at all; a path
// whose last segment is `PhantomData` is not descended into; and macro
// invocations in type position are ignored.
pub fn with_bound(
    cont: &Container,
    generics: &syn::Generics,
    filter: fn(&attr::Field, Option<&attr::Variant>) -> bool,
    bound: &syn::Path,
) -> syn::Generics {
    // Syntax-tree walker that records which of the container's type
    // parameters appear in the visited field types.
    struct FindTyParams<'ast> {
        // Set of all generic type parameters on the current struct (A, B, C in
        // the example). Initialized up front.
        all_type_params: HashSet<syn::Ident>,

        // Set of generic type parameters used in fields for which filter
        // returns true (A and B in the example). Filled in as the visitor sees
        // them.
        relevant_type_params: HashSet<syn::Ident>,

        // Fields whose type is an associated type of one of the generic type
        // parameters.
        associated_type_usage: Vec<&'ast syn::TypePath>,
    }

    impl<'ast> FindTyParams<'ast> {
        // Entry point per field. If the field's type (after `ungroup`,
        // presumably stripping invisible group delimiters) is a path whose
        // first segment is a type parameter followed by `::`, such as
        // `T::Assoc`, the whole path is recorded so it can be bounded
        // directly. The type is then walked for bare parameter uses.
        fn visit_field(&mut self, field: &'ast syn::Field) {
            if let syn::Type::Path(ty) = ungroup(&field.ty) {
                if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
                    if self.all_type_params.contains(&t.ident) {
                        self.associated_type_usage.push(ty);
                    }
                }
            }
            self.visit_type(&field.ty);
        }

        // Marks a type parameter as relevant when the path is exactly that
        // parameter (no leading `::`, single segment), then recurses into the
        // generic arguments of every segment.
        fn visit_path(&mut self, path: &'ast syn::Path) {
            if let Some(seg) = path.segments.last() {
                if seg.ident == "PhantomData" {
                    // Hardcoded exception, because PhantomData<T> implements
                    // Serialize and Deserialize whether or not T implements it.
                    return;
                }
            }
            if path.leading_colon.is_none() && path.segments.len() == 1 {
                let id = &path.segments[0].ident;
                if self.all_type_params.contains(id) {
                    self.relevant_type_params.insert(id.clone());
                }
            }
            for segment in &path.segments {
                self.visit_path_segment(segment);
            }
        }

        // Everything below is simply traversing the syntax tree.

        fn visit_type(&mut self, ty: &'ast syn::Type) {
            match ty {
                syn::Type::Array(ty) => self.visit_type(&ty.elem),
                syn::Type::BareFn(ty) => {
                    for arg in &ty.inputs {
                        self.visit_type(&arg.ty);
                    }
                    self.visit_return_type(&ty.output);
                }
                syn::Type::Group(ty) => self.visit_type(&ty.elem),
                syn::Type::ImplTrait(ty) => {
                    for bound in &ty.bounds {
                        self.visit_type_param_bound(bound);
                    }
                }
                syn::Type::Macro(ty) => self.visit_macro(&ty.mac),
                syn::Type::Paren(ty) => self.visit_type(&ty.elem),
                syn::Type::Path(ty) => {
                    // In `<X as Trait>::Assoc`, the self type `X` is walked too.
                    if let Some(qself) = &ty.qself {
                        self.visit_type(&qself.ty);
                    }
                    self.visit_path(&ty.path);
                }
                syn::Type::Ptr(ty) => self.visit_type(&ty.elem),
                syn::Type::Reference(ty) => self.visit_type(&ty.elem),
                syn::Type::Slice(ty) => self.visit_type(&ty.elem),
                syn::Type::TraitObject(ty) => {
                    for bound in &ty.bounds {
                        self.visit_type_param_bound(bound);
                    }
                }
                syn::Type::Tuple(ty) => {
                    for elem in &ty.elems {
                        self.visit_type(elem);
                    }
                }

                syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {}

                // In test builds, naming syn's hidden `__TestExhaustive`
                // variant instead of `_` appears intended to force a compile
                // error when syn gains a new `Type` variant; other builds
                // silently ignore unrecognized variants.
                #[cfg(test)]
                syn::Type::__TestExhaustive(_) => unimplemented!(),
                #[cfg(not(test))]
                _ => {}
            }
        }

        fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) {
            self.visit_path_arguments(&segment.arguments);
        }

        // Walks `<...>` arguments (types and `Assoc = Type` bindings only)
        // and `(...) -> ...` sugar arguments.
        fn visit_path_arguments(&mut self, arguments: &'ast syn::PathArguments) {
            match arguments {
                syn::PathArguments::None => {}
                syn::PathArguments::AngleBracketed(arguments) => {
                    for arg in &arguments.args {
                        match arg {
                            syn::GenericArgument::Type(arg) => self.visit_type(arg),
                            syn::GenericArgument::Binding(arg) => self.visit_type(&arg.ty),
                            syn::GenericArgument::Lifetime(_)
                            | syn::GenericArgument::Constraint(_)
                            | syn::GenericArgument::Const(_) => {}
                        }
                    }
                }
                syn::PathArguments::Parenthesized(arguments) => {
                    for argument in &arguments.inputs {
                        self.visit_type(argument);
                    }
                    self.visit_return_type(&arguments.output);
                }
            }
        }

        fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) {
            match return_type {
                syn::ReturnType::Default => {}
                syn::ReturnType::Type(_, output) => self.visit_type(output),
            }
        }

        fn visit_type_param_bound(&mut self, bound: &'ast syn::TypeParamBound) {
            match bound {
                syn::TypeParamBound::Trait(bound) => self.visit_path(&bound.path),
                syn::TypeParamBound::Lifetime(_) => {}
            }
        }

        // Type parameter should not be considered used by a macro path.
        //
        // struct TypeMacro<T> {
        //     mac: T!(),
        //     marker: PhantomData<T>,
        // }
        fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
    }

    let all_type_params = generics
        .type_params()
        .map(|param| param.ident.clone())
        .collect();

    let mut visitor = FindTyParams {
        all_type_params,
        relevant_type_params: HashSet::new(),
        associated_type_usage: Vec::new(),
    };
    // Walk only the fields accepted by `filter`; enum fields are also given
    // their variant's attributes.
    match &cont.data {
        Data::Enum(variants) => {
            for variant in variants.iter() {
                let relevant_fields = variant
                    .fields
                    .iter()
                    .filter(|field| filter(&field.attrs, Some(&variant.attrs)));
                for field in relevant_fields {
                    visitor.visit_field(field.original);
                }
            }
        }
        Data::Struct(_, fields) => {
            for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
                visitor.visit_field(field.original);
            }
        }
    }

    let relevant_type_params = visitor.relevant_type_params;
    let associated_type_usage = visitor.associated_type_usage;
    // Iterating `generics.type_params()` (rather than the HashSet) keeps the
    // emitted predicates in declaration order, so the output is deterministic.
    let new_predicates = generics
        .type_params()
        .map(|param| param.ident.clone())
        .filter(|id| relevant_type_params.contains(id))
        .map(|id| syn::TypePath {
            qself: None,
            path: id.into(),
        })
        .chain(associated_type_usage.into_iter().cloned())
        .map(|bounded_ty| {
            syn::WherePredicate::Type(syn::PredicateType {
                lifetimes: None,
                // the type parameter that is being bounded e.g. T
                bounded_ty: syn::Type::Path(bounded_ty),
                colon_token: <Token![:]>::default(),
                // the bound e.g. Serialize
                bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
                    paren_token: None,
                    modifier: syn::TraitBoundModifier::None,
                    lifetimes: None,
                    path: bound.clone(),
                })]
                .into_iter()
                .collect(),
            })
        });

    let mut generics = generics.clone();
    generics
        .make_where_clause()
        .predicates
        .extend(new_predicates);
    generics
}
309
with_self_bound( cont: &Container, generics: &syn::Generics, bound: &syn::Path, ) -> syn::Generics310 pub fn with_self_bound(
311 cont: &Container,
312 generics: &syn::Generics,
313 bound: &syn::Path,
314 ) -> syn::Generics {
315 let mut generics = generics.clone();
316 generics
317 .make_where_clause()
318 .predicates
319 .push(syn::WherePredicate::Type(syn::PredicateType {
320 lifetimes: None,
321 // the type that is being bounded e.g. MyStruct<'a, T>
322 bounded_ty: type_of_item(cont),
323 colon_token: <Token![:]>::default(),
324 // the bound e.g. Default
325 bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
326 paren_token: None,
327 modifier: syn::TraitBoundModifier::None,
328 lifetimes: None,
329 path: bound.clone(),
330 })]
331 .into_iter()
332 .collect(),
333 }));
334 generics
335 }
336
with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics337 pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics {
338 let bound = syn::Lifetime::new(lifetime, Span::call_site());
339 let def = syn::LifetimeDef {
340 attrs: Vec::new(),
341 lifetime: bound.clone(),
342 colon_token: None,
343 bounds: Punctuated::new(),
344 };
345
346 let params = Some(syn::GenericParam::Lifetime(def))
347 .into_iter()
348 .chain(generics.params.iter().cloned().map(|mut param| {
349 match &mut param {
350 syn::GenericParam::Lifetime(param) => {
351 param.bounds.push(bound.clone());
352 }
353 syn::GenericParam::Type(param) => {
354 param
355 .bounds
356 .push(syn::TypeParamBound::Lifetime(bound.clone()));
357 }
358 syn::GenericParam::Const(_) => {}
359 }
360 param
361 }))
362 .collect();
363
364 syn::Generics {
365 params,
366 ..generics.clone()
367 }
368 }
369
type_of_item(cont: &Container) -> syn::Type370 fn type_of_item(cont: &Container) -> syn::Type {
371 syn::Type::Path(syn::TypePath {
372 qself: None,
373 path: syn::Path {
374 leading_colon: None,
375 segments: vec![syn::PathSegment {
376 ident: cont.ident.clone(),
377 arguments: syn::PathArguments::AngleBracketed(
378 syn::AngleBracketedGenericArguments {
379 colon2_token: None,
380 lt_token: <Token![<]>::default(),
381 args: cont
382 .generics
383 .params
384 .iter()
385 .map(|param| match param {
386 syn::GenericParam::Type(param) => {
387 syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
388 qself: None,
389 path: param.ident.clone().into(),
390 }))
391 }
392 syn::GenericParam::Lifetime(param) => {
393 syn::GenericArgument::Lifetime(param.lifetime.clone())
394 }
395 syn::GenericParam::Const(_) => {
396 panic!("Serde does not support const generics yet");
397 }
398 })
399 .collect(),
400 gt_token: <Token![>]>::default(),
401 },
402 ),
403 }]
404 .into_iter()
405 .collect(),
406 },
407 })
408 }
409