• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // vim: tw=80
2 use std::collections::HashSet;
3 
4 use proc_macro2::TokenStream;
5 use quote::{ToTokens, format_ident, quote};
6 use syn::{
7     *,
8     spanned::Spanned
9 };
10 
11 use crate::{
12     AttrFormatter,
13     MockableStruct,
14     compile_error,
15     gen_mod_ident,
16     mock_function::{self, MockFunction},
17     mock_trait::MockTrait
18 };
19 
phantom_default_inits(generics: &Generics) -> Vec<TokenStream>20 fn phantom_default_inits(generics: &Generics) -> Vec<TokenStream> {
21     generics.params
22     .iter()
23     .enumerate()
24     .map(|(count, _param)| {
25         let phident = format_ident!("_t{count}");
26         quote!(#phident: ::std::marker::PhantomData)
27     }).collect()
28 }
29 
30 /// Generate any PhantomData field definitions
phantom_fields(generics: &Generics) -> Vec<TokenStream>31 fn phantom_fields(generics: &Generics) -> Vec<TokenStream> {
32     generics.params
33     .iter()
34     .enumerate()
35     .filter_map(|(count, param)| {
36         let phident = format_ident!("_t{count}");
37         match param {
38             syn::GenericParam::Lifetime(l) => {
39                 if !l.bounds.is_empty() {
40                     compile_error(l.bounds.span(),
41                         "#automock does not yet support lifetime bounds on structs");
42                 }
43                 let lifetime = &l.lifetime;
44                 Some(
45                 quote!(#phident: ::std::marker::PhantomData<&#lifetime ()>)
46                 )
47             },
48             syn::GenericParam::Type(tp) => {
49                 let ty = &tp.ident;
50                 Some(
51                 quote!(#phident: ::std::marker::PhantomData<#ty>)
52                 )
53             },
54             syn::GenericParam::Const(_) => {
55                 compile_error(param.span(),
56                     "#automock does not yet support generic constants");
57                 None
58             }
59         }
60     }).collect()
61 }
62 
63 /// Filter out multiple copies of the same trait, even if they're implemented on
64 /// different types.  But allow them if they have different attributes, which
65 /// probably indicates that they aren't meant to be compiled together.
unique_trait_iter<'a, I: Iterator<Item = &'a MockTrait>>(i: I) -> impl Iterator<Item = &'a MockTrait>66 fn unique_trait_iter<'a, I: Iterator<Item = &'a MockTrait>>(i: I)
67     -> impl Iterator<Item = &'a MockTrait>
68 {
69     let mut hs = HashSet::<(Path, Vec<Attribute>)>::default();
70     i.filter(move |mt| {
71         let impl_attrs = AttrFormatter::new(&mt.attrs)
72             .async_trait(false)
73             .doc(false)
74             .format();
75         let key = (mt.trait_path.clone(), impl_attrs);
76         if hs.contains(&key) {
77             false
78         } else {
79             hs.insert(key);
80             true
81         }
82     })
83 }
84 
85 /// A collection of methods defined in one spot
86 struct Methods(Vec<MockFunction>);
87 
88 impl Methods {
89     /// Are all of these methods static?
all_static(&self) -> bool90     fn all_static(&self) -> bool {
91         self.0.iter()
92             .all(|meth| meth.is_static())
93     }
94 
checkpoints(&self) -> Vec<impl ToTokens>95     fn checkpoints(&self) -> Vec<impl ToTokens> {
96         self.0.iter()
97             .filter(|meth| !meth.is_static())
98             .map(|meth| meth.checkpoint())
99             .collect::<Vec<_>>()
100     }
101 
102     /// Return a fragment of code to initialize struct fields during default()
default_inits(&self) -> Vec<TokenStream>103     fn default_inits(&self) -> Vec<TokenStream> {
104         self.0.iter()
105             .filter(|meth| !meth.is_static())
106             .map(|meth| {
107                 let name = meth.name();
108                 let attrs = AttrFormatter::new(&meth.attrs)
109                     .doc(false)
110                     .format();
111                 quote!(#(#attrs)* #name: Default::default())
112             }).collect::<Vec<_>>()
113     }
114 
field_definitions(&self, modname: &Ident) -> Vec<TokenStream>115     fn field_definitions(&self, modname: &Ident) -> Vec<TokenStream> {
116         self.0.iter()
117             .filter(|meth| !meth.is_static())
118             .map(|meth| meth.field_definition(Some(modname)))
119             .collect::<Vec<_>>()
120     }
121 
priv_mods(&self) -> Vec<impl ToTokens>122     fn priv_mods(&self) -> Vec<impl ToTokens> {
123         self.0.iter()
124             .map(|meth| meth.priv_module())
125             .collect::<Vec<_>>()
126     }
127 }
128 
129 pub(crate) struct MockItemStruct {
130     attrs: Vec<Attribute>,
131     consts: Vec<ImplItemConst>,
132     generics: Generics,
133     /// Should Mockall generate a Debug implementation?
134     auto_debug: bool,
135     /// Does the original struct have a `new` method?
136     has_new: bool,
137     /// Inherent methods of the mock struct
138     methods: Methods,
139     /// Name of the overall module that holds all of the mock stuff
140     modname: Ident,
141     name: Ident,
142     /// Is this a whole MockStruct or just a substructure for a trait impl?
143     traits: Vec<MockTrait>,
144     vis: Visibility,
145 }
146 
147 impl MockItemStruct {
debug_impl(&self) -> impl ToTokens148     fn debug_impl(&self) -> impl ToTokens {
149         if self.auto_debug {
150             let (ig, tg, wc) = self.generics.split_for_impl();
151             let struct_name = &self.name;
152             let struct_name_str = format!("{}", self.name);
153             quote!(
154                 impl #ig ::std::fmt::Debug for #struct_name #tg #wc {
155                     fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>)
156                         -> ::std::result::Result<(), std::fmt::Error>
157                     {
158                         f.debug_struct(#struct_name_str).finish()
159                     }
160                 }
161             )
162         } else {
163             quote!()
164         }
165     }
166 
new_method(&self) -> impl ToTokens167     fn new_method(&self) -> impl ToTokens {
168         if self.has_new {
169             TokenStream::new()
170         } else {
171             quote!(
172                 /// Create a new mock object with no expectations.
173                 ///
174                 /// This method will not be generated if the real struct
175                 /// already has a `new` method.  However, it *will* be
176                 /// generated if the struct implements a trait with a `new`
177                 /// method.  The trait's `new` method can still be called
178                 /// like `<MockX as TraitY>::new`
179                 pub fn new() -> Self {
180                     Self::default()
181                 }
182             )
183         }
184     }
185 
phantom_default_inits(&self) -> Vec<TokenStream>186     fn phantom_default_inits(&self) -> Vec<TokenStream> {
187         phantom_default_inits(&self.generics)
188     }
189 
phantom_fields(&self) -> Vec<TokenStream>190     fn phantom_fields(&self) -> Vec<TokenStream> {
191         phantom_fields(&self.generics)
192     }
193 }
194 
195 impl From<MockableStruct> for MockItemStruct {
from(mockable: MockableStruct) -> MockItemStruct196     fn from(mockable: MockableStruct) -> MockItemStruct {
197         let auto_debug = mockable.derives_debug();
198         let modname = gen_mod_ident(&mockable.name, None);
199         let generics = mockable.generics.clone();
200         let struct_name = &mockable.name;
201         let vis = mockable.vis;
202         let has_new = mockable.methods.iter()
203             .any(|meth| meth.sig.ident == "new") ||
204             mockable.impls.iter()
205             .any(|impl_|
206                 impl_.items.iter()
207                     .any(|ii| if let ImplItem::Fn(iif) = ii {
208                             iif.sig.ident == "new"
209                         } else {
210                             false
211                         }
212                     )
213             );
214         let methods = Methods(mockable.methods.into_iter()
215             .map(|meth|
216                 mock_function::Builder::new(&meth.sig, &meth.vis)
217                     .attrs(&meth.attrs)
218                     .struct_(struct_name)
219                     .struct_generics(&generics)
220                     .levels(2)
221                     .call_levels(0)
222                     .build()
223             ).collect::<Vec<_>>());
224         let structname = &mockable.name;
225         let traits = mockable.impls.into_iter()
226             .map(|i| MockTrait::new(structname, &generics, i, &vis))
227             .collect();
228 
229         MockItemStruct {
230             attrs: mockable.attrs,
231             auto_debug,
232             consts: mockable.consts,
233             generics,
234             has_new,
235             methods,
236             modname,
237             name: mockable.name,
238             traits,
239             vis
240         }
241     }
242 }
243 
244 impl ToTokens for MockItemStruct {
to_tokens(&self, tokens: &mut TokenStream)245     fn to_tokens(&self, tokens: &mut TokenStream) {
246         let attrs = AttrFormatter::new(&self.attrs)
247             .async_trait(false)
248             .must_use(true)
249             .format();
250         let consts = &self.consts;
251         let debug_impl = self.debug_impl();
252         let struct_name = &self.name;
253         let (ig, tg, wc) = self.generics.split_for_impl();
254         let modname = &self.modname;
255         let calls = self.methods.0.iter()
256             .map(|meth| meth.call(Some(modname)))
257             .collect::<Vec<_>>();
258         let contexts = self.methods.0.iter()
259             .filter(|meth| meth.is_static())
260             .map(|meth| meth.context_fn(Some(modname)))
261             .collect::<Vec<_>>();
262         let expects = self.methods.0.iter()
263             .filter(|meth| !meth.is_static())
264             .map(|meth| meth.expect(modname, None))
265             .collect::<Vec<_>>();
266         let method_checkpoints = self.methods.checkpoints();
267         let new_method = self.new_method();
268         let priv_mods = self.methods.priv_mods();
269         let substructs = unique_trait_iter(self.traits.iter())
270             .map(|trait_| {
271                 MockItemTraitImpl {
272                     attrs: trait_.attrs.clone(),
273                     generics: self.generics.clone(),
274                     fieldname: format_ident!("{}_expectations",
275                                              trait_.ss_name()),
276                     methods: Methods(trait_.methods.clone()),
277                     modname: format_ident!("{}_{}", &self.modname,
278                                            trait_.ss_name()),
279                     name: format_ident!("{}_{}", &self.name, trait_.ss_name()),
280                 }
281             }).collect::<Vec<_>>();
282         let substruct_expectations = substructs.iter()
283             .filter(|ss| !ss.all_static())
284             .map(|ss| {
285                 let attrs = AttrFormatter::new(&ss.attrs)
286                     .async_trait(false)
287                     .doc(false)
288                     .format();
289                 let fieldname = &ss.fieldname;
290                 quote!(#(#attrs)* self.#fieldname.checkpoint();)
291             }).collect::<Vec<_>>();
292         let mut field_definitions = substructs.iter()
293             .filter(|ss| !ss.all_static())
294             .map(|ss| {
295                 let attrs = AttrFormatter::new(&ss.attrs)
296                     .async_trait(false)
297                     .doc(false)
298                     .format();
299                 let fieldname = &ss.fieldname;
300                 let tyname = &ss.name;
301                 quote!(#(#attrs)* #fieldname: #tyname #tg)
302             }).collect::<Vec<_>>();
303         field_definitions.extend(self.methods.field_definitions(modname));
304         field_definitions.extend(self.phantom_fields());
305         let mut default_inits = substructs.iter()
306             .filter(|ss| !ss.all_static())
307             .map(|ss| {
308                 let attrs = AttrFormatter::new(&ss.attrs)
309                     .async_trait(false)
310                     .doc(false)
311                     .format();
312                 let fieldname = &ss.fieldname;
313                 quote!(#(#attrs)* #fieldname: Default::default())
314             }).collect::<Vec<_>>();
315         default_inits.extend(self.methods.default_inits());
316         default_inits.extend(self.phantom_default_inits());
317         let trait_impls = self.traits.iter()
318             .map(|trait_| {
319                 let modname = format_ident!("{}_{}", &self.modname,
320                                             trait_.ss_name());
321                 trait_.trait_impl(&modname)
322             }).collect::<Vec<_>>();
323         let vis = &self.vis;
324         quote!(
325             #[allow(non_snake_case)]
326             #[allow(missing_docs)]
327             pub mod #modname {
328                 use super::*;
329                 #(#priv_mods)*
330             }
331             #[allow(non_camel_case_types)]
332             #[allow(non_snake_case)]
333             #[allow(missing_docs)]
334             #(#attrs)*
335             #vis struct #struct_name #ig #wc
336             {
337                 #(#field_definitions),*
338             }
339             #debug_impl
340             impl #ig ::std::default::Default for #struct_name #tg #wc {
341                 #[allow(clippy::default_trait_access)]
342                 fn default() -> Self {
343                     Self {
344                         #(#default_inits),*
345                     }
346                 }
347             }
348             #(#substructs)*
349             impl #ig #struct_name #tg #wc {
350                 #(#consts)*
351                 #(#calls)*
352                 #(#contexts)*
353                 #(#expects)*
354                 /// Validate that all current expectations for all methods have
355                 /// been satisfied, and discard them.
356                 pub fn checkpoint(&mut self) {
357                     #(#substruct_expectations)*
358                     #(#method_checkpoints)*
359                 }
360                 #new_method
361             }
362             #(#trait_impls)*
363         ).to_tokens(tokens);
364     }
365 }
366 
367 pub(crate) struct MockItemTraitImpl {
368     attrs: Vec<Attribute>,
369     generics: Generics,
370     /// Inherent methods of the mock struct
371     methods: Methods,
372     /// Name of the overall module that holds all of the mock stuff
373     modname: Ident,
374     name: Ident,
375     /// Name of the field of this type in the parent's structure
376     fieldname: Ident,
377 }
378 
379 impl MockItemTraitImpl {
380     /// Are all of this traits's methods static?
all_static(&self) -> bool381     fn all_static(&self) -> bool {
382         self.methods.all_static()
383     }
384 
phantom_default_inits(&self) -> Vec<TokenStream>385     fn phantom_default_inits(&self) -> Vec<TokenStream> {
386         phantom_default_inits(&self.generics)
387     }
388 
phantom_fields(&self) -> Vec<TokenStream>389     fn phantom_fields(&self) -> Vec<TokenStream> {
390         phantom_fields(&self.generics)
391     }
392 }
393 
394 impl ToTokens for MockItemTraitImpl {
to_tokens(&self, tokens: &mut TokenStream)395     fn to_tokens(&self, tokens: &mut TokenStream) {
396         let mod_attrs = AttrFormatter::new(&self.attrs)
397             .async_trait(false)
398             .doc(false)
399             .format();
400         let struct_attrs = AttrFormatter::new(&self.attrs)
401             .async_trait(false)
402             .doc(false)
403             .must_use(false)
404             .format();
405         let impl_attrs = AttrFormatter::new(&self.attrs)
406             .async_trait(false)
407             .doc(false)
408             .format();
409         let struct_name = &self.name;
410         let (ig, tg, wc) = self.generics.split_for_impl();
411         let modname = &self.modname;
412         let method_checkpoints = self.methods.checkpoints();
413         let mut default_inits = self.methods.default_inits();
414         default_inits.extend(self.phantom_default_inits());
415         let mut field_definitions = self.methods.field_definitions(modname);
416         field_definitions.extend(self.phantom_fields());
417         let priv_mods = self.methods.priv_mods();
418         quote!(
419             #[allow(non_snake_case)]
420             #[allow(missing_docs)]
421             #(#mod_attrs)*
422             pub mod #modname {
423                 use super::*;
424                 #(#priv_mods)*
425             }
426             #[allow(non_camel_case_types)]
427             #[allow(non_snake_case)]
428             #[allow(missing_docs)]
429             #(#struct_attrs)*
430             struct #struct_name #ig #wc
431             {
432                 #(#field_definitions),*
433             }
434             #(#impl_attrs)*
435             impl #ig ::std::default::Default for #struct_name #tg #wc {
436                 fn default() -> Self {
437                     Self {
438                         #(#default_inits),*
439                     }
440                 }
441             }
442             #(#impl_attrs)*
443             impl #ig #struct_name #tg #wc {
444                 /// Validate that all current expectations for all methods have
445                 /// been satisfied, and discard them.
446                 pub fn checkpoint(&mut self) {
447                     #(#method_checkpoints)*
448                 }
449             }
450         ).to_tokens(tokens);
451     }
452 }
453