1 // vim: tw=80
2 use std::collections::HashSet;
3
4 use proc_macro2::TokenStream;
5 use quote::{ToTokens, format_ident, quote};
6 use syn::{
7 *,
8 spanned::Spanned
9 };
10
11 use crate::{
12 AttrFormatter,
13 MockableStruct,
14 compile_error,
15 gen_mod_ident,
16 mock_function::{self, MockFunction},
17 mock_trait::MockTrait
18 };
19
phantom_default_inits(generics: &Generics) -> Vec<TokenStream>20 fn phantom_default_inits(generics: &Generics) -> Vec<TokenStream> {
21 generics.params
22 .iter()
23 .enumerate()
24 .map(|(count, _param)| {
25 let phident = format_ident!("_t{count}");
26 quote!(#phident: ::std::marker::PhantomData)
27 }).collect()
28 }
29
30 /// Generate any PhantomData field definitions
phantom_fields(generics: &Generics) -> Vec<TokenStream>31 fn phantom_fields(generics: &Generics) -> Vec<TokenStream> {
32 generics.params
33 .iter()
34 .enumerate()
35 .filter_map(|(count, param)| {
36 let phident = format_ident!("_t{count}");
37 match param {
38 syn::GenericParam::Lifetime(l) => {
39 if !l.bounds.is_empty() {
40 compile_error(l.bounds.span(),
41 "#automock does not yet support lifetime bounds on structs");
42 }
43 let lifetime = &l.lifetime;
44 Some(
45 quote!(#phident: ::std::marker::PhantomData<&#lifetime ()>)
46 )
47 },
48 syn::GenericParam::Type(tp) => {
49 let ty = &tp.ident;
50 Some(
51 quote!(#phident: ::std::marker::PhantomData<#ty>)
52 )
53 },
54 syn::GenericParam::Const(_) => {
55 compile_error(param.span(),
56 "#automock does not yet support generic constants");
57 None
58 }
59 }
60 }).collect()
61 }
62
63 /// Filter out multiple copies of the same trait, even if they're implemented on
64 /// different types. But allow them if they have different attributes, which
65 /// probably indicates that they aren't meant to be compiled together.
unique_trait_iter<'a, I: Iterator<Item = &'a MockTrait>>(i: I) -> impl Iterator<Item = &'a MockTrait>66 fn unique_trait_iter<'a, I: Iterator<Item = &'a MockTrait>>(i: I)
67 -> impl Iterator<Item = &'a MockTrait>
68 {
69 let mut hs = HashSet::<(Path, Vec<Attribute>)>::default();
70 i.filter(move |mt| {
71 let impl_attrs = AttrFormatter::new(&mt.attrs)
72 .async_trait(false)
73 .doc(false)
74 .format();
75 let key = (mt.trait_path.clone(), impl_attrs);
76 if hs.contains(&key) {
77 false
78 } else {
79 hs.insert(key);
80 true
81 }
82 })
83 }
84
85 /// A collection of methods defined in one spot
86 struct Methods(Vec<MockFunction>);
87
88 impl Methods {
89 /// Are all of these methods static?
all_static(&self) -> bool90 fn all_static(&self) -> bool {
91 self.0.iter()
92 .all(|meth| meth.is_static())
93 }
94
checkpoints(&self) -> Vec<impl ToTokens>95 fn checkpoints(&self) -> Vec<impl ToTokens> {
96 self.0.iter()
97 .filter(|meth| !meth.is_static())
98 .map(|meth| meth.checkpoint())
99 .collect::<Vec<_>>()
100 }
101
102 /// Return a fragment of code to initialize struct fields during default()
default_inits(&self) -> Vec<TokenStream>103 fn default_inits(&self) -> Vec<TokenStream> {
104 self.0.iter()
105 .filter(|meth| !meth.is_static())
106 .map(|meth| {
107 let name = meth.name();
108 let attrs = AttrFormatter::new(&meth.attrs)
109 .doc(false)
110 .format();
111 quote!(#(#attrs)* #name: Default::default())
112 }).collect::<Vec<_>>()
113 }
114
field_definitions(&self, modname: &Ident) -> Vec<TokenStream>115 fn field_definitions(&self, modname: &Ident) -> Vec<TokenStream> {
116 self.0.iter()
117 .filter(|meth| !meth.is_static())
118 .map(|meth| meth.field_definition(Some(modname)))
119 .collect::<Vec<_>>()
120 }
121
priv_mods(&self) -> Vec<impl ToTokens>122 fn priv_mods(&self) -> Vec<impl ToTokens> {
123 self.0.iter()
124 .map(|meth| meth.priv_module())
125 .collect::<Vec<_>>()
126 }
127 }
128
129 pub(crate) struct MockItemStruct {
130 attrs: Vec<Attribute>,
131 consts: Vec<ImplItemConst>,
132 generics: Generics,
133 /// Should Mockall generate a Debug implementation?
134 auto_debug: bool,
135 /// Does the original struct have a `new` method?
136 has_new: bool,
137 /// Inherent methods of the mock struct
138 methods: Methods,
139 /// Name of the overall module that holds all of the mock stuff
140 modname: Ident,
141 name: Ident,
142 /// Is this a whole MockStruct or just a substructure for a trait impl?
143 traits: Vec<MockTrait>,
144 vis: Visibility,
145 }
146
147 impl MockItemStruct {
debug_impl(&self) -> impl ToTokens148 fn debug_impl(&self) -> impl ToTokens {
149 if self.auto_debug {
150 let (ig, tg, wc) = self.generics.split_for_impl();
151 let struct_name = &self.name;
152 let struct_name_str = format!("{}", self.name);
153 quote!(
154 impl #ig ::std::fmt::Debug for #struct_name #tg #wc {
155 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>)
156 -> ::std::result::Result<(), std::fmt::Error>
157 {
158 f.debug_struct(#struct_name_str).finish()
159 }
160 }
161 )
162 } else {
163 quote!()
164 }
165 }
166
new_method(&self) -> impl ToTokens167 fn new_method(&self) -> impl ToTokens {
168 if self.has_new {
169 TokenStream::new()
170 } else {
171 quote!(
172 /// Create a new mock object with no expectations.
173 ///
174 /// This method will not be generated if the real struct
175 /// already has a `new` method. However, it *will* be
176 /// generated if the struct implements a trait with a `new`
177 /// method. The trait's `new` method can still be called
178 /// like `<MockX as TraitY>::new`
179 pub fn new() -> Self {
180 Self::default()
181 }
182 )
183 }
184 }
185
phantom_default_inits(&self) -> Vec<TokenStream>186 fn phantom_default_inits(&self) -> Vec<TokenStream> {
187 phantom_default_inits(&self.generics)
188 }
189
phantom_fields(&self) -> Vec<TokenStream>190 fn phantom_fields(&self) -> Vec<TokenStream> {
191 phantom_fields(&self.generics)
192 }
193 }
194
195 impl From<MockableStruct> for MockItemStruct {
from(mockable: MockableStruct) -> MockItemStruct196 fn from(mockable: MockableStruct) -> MockItemStruct {
197 let auto_debug = mockable.derives_debug();
198 let modname = gen_mod_ident(&mockable.name, None);
199 let generics = mockable.generics.clone();
200 let struct_name = &mockable.name;
201 let vis = mockable.vis;
202 let has_new = mockable.methods.iter()
203 .any(|meth| meth.sig.ident == "new") ||
204 mockable.impls.iter()
205 .any(|impl_|
206 impl_.items.iter()
207 .any(|ii| if let ImplItem::Fn(iif) = ii {
208 iif.sig.ident == "new"
209 } else {
210 false
211 }
212 )
213 );
214 let methods = Methods(mockable.methods.into_iter()
215 .map(|meth|
216 mock_function::Builder::new(&meth.sig, &meth.vis)
217 .attrs(&meth.attrs)
218 .struct_(struct_name)
219 .struct_generics(&generics)
220 .levels(2)
221 .call_levels(0)
222 .build()
223 ).collect::<Vec<_>>());
224 let structname = &mockable.name;
225 let traits = mockable.impls.into_iter()
226 .map(|i| MockTrait::new(structname, &generics, i, &vis))
227 .collect();
228
229 MockItemStruct {
230 attrs: mockable.attrs,
231 auto_debug,
232 consts: mockable.consts,
233 generics,
234 has_new,
235 methods,
236 modname,
237 name: mockable.name,
238 traits,
239 vis
240 }
241 }
242 }
243
244 impl ToTokens for MockItemStruct {
to_tokens(&self, tokens: &mut TokenStream)245 fn to_tokens(&self, tokens: &mut TokenStream) {
246 let attrs = AttrFormatter::new(&self.attrs)
247 .async_trait(false)
248 .must_use(true)
249 .format();
250 let consts = &self.consts;
251 let debug_impl = self.debug_impl();
252 let struct_name = &self.name;
253 let (ig, tg, wc) = self.generics.split_for_impl();
254 let modname = &self.modname;
255 let calls = self.methods.0.iter()
256 .map(|meth| meth.call(Some(modname)))
257 .collect::<Vec<_>>();
258 let contexts = self.methods.0.iter()
259 .filter(|meth| meth.is_static())
260 .map(|meth| meth.context_fn(Some(modname)))
261 .collect::<Vec<_>>();
262 let expects = self.methods.0.iter()
263 .filter(|meth| !meth.is_static())
264 .map(|meth| meth.expect(modname, None))
265 .collect::<Vec<_>>();
266 let method_checkpoints = self.methods.checkpoints();
267 let new_method = self.new_method();
268 let priv_mods = self.methods.priv_mods();
269 let substructs = unique_trait_iter(self.traits.iter())
270 .map(|trait_| {
271 MockItemTraitImpl {
272 attrs: trait_.attrs.clone(),
273 generics: self.generics.clone(),
274 fieldname: format_ident!("{}_expectations",
275 trait_.ss_name()),
276 methods: Methods(trait_.methods.clone()),
277 modname: format_ident!("{}_{}", &self.modname,
278 trait_.ss_name()),
279 name: format_ident!("{}_{}", &self.name, trait_.ss_name()),
280 }
281 }).collect::<Vec<_>>();
282 let substruct_expectations = substructs.iter()
283 .filter(|ss| !ss.all_static())
284 .map(|ss| {
285 let attrs = AttrFormatter::new(&ss.attrs)
286 .async_trait(false)
287 .doc(false)
288 .format();
289 let fieldname = &ss.fieldname;
290 quote!(#(#attrs)* self.#fieldname.checkpoint();)
291 }).collect::<Vec<_>>();
292 let mut field_definitions = substructs.iter()
293 .filter(|ss| !ss.all_static())
294 .map(|ss| {
295 let attrs = AttrFormatter::new(&ss.attrs)
296 .async_trait(false)
297 .doc(false)
298 .format();
299 let fieldname = &ss.fieldname;
300 let tyname = &ss.name;
301 quote!(#(#attrs)* #fieldname: #tyname #tg)
302 }).collect::<Vec<_>>();
303 field_definitions.extend(self.methods.field_definitions(modname));
304 field_definitions.extend(self.phantom_fields());
305 let mut default_inits = substructs.iter()
306 .filter(|ss| !ss.all_static())
307 .map(|ss| {
308 let attrs = AttrFormatter::new(&ss.attrs)
309 .async_trait(false)
310 .doc(false)
311 .format();
312 let fieldname = &ss.fieldname;
313 quote!(#(#attrs)* #fieldname: Default::default())
314 }).collect::<Vec<_>>();
315 default_inits.extend(self.methods.default_inits());
316 default_inits.extend(self.phantom_default_inits());
317 let trait_impls = self.traits.iter()
318 .map(|trait_| {
319 let modname = format_ident!("{}_{}", &self.modname,
320 trait_.ss_name());
321 trait_.trait_impl(&modname)
322 }).collect::<Vec<_>>();
323 let vis = &self.vis;
324 quote!(
325 #[allow(non_snake_case)]
326 #[allow(missing_docs)]
327 pub mod #modname {
328 use super::*;
329 #(#priv_mods)*
330 }
331 #[allow(non_camel_case_types)]
332 #[allow(non_snake_case)]
333 #[allow(missing_docs)]
334 #(#attrs)*
335 #vis struct #struct_name #ig #wc
336 {
337 #(#field_definitions),*
338 }
339 #debug_impl
340 impl #ig ::std::default::Default for #struct_name #tg #wc {
341 #[allow(clippy::default_trait_access)]
342 fn default() -> Self {
343 Self {
344 #(#default_inits),*
345 }
346 }
347 }
348 #(#substructs)*
349 impl #ig #struct_name #tg #wc {
350 #(#consts)*
351 #(#calls)*
352 #(#contexts)*
353 #(#expects)*
354 /// Validate that all current expectations for all methods have
355 /// been satisfied, and discard them.
356 pub fn checkpoint(&mut self) {
357 #(#substruct_expectations)*
358 #(#method_checkpoints)*
359 }
360 #new_method
361 }
362 #(#trait_impls)*
363 ).to_tokens(tokens);
364 }
365 }
366
367 pub(crate) struct MockItemTraitImpl {
368 attrs: Vec<Attribute>,
369 generics: Generics,
370 /// Inherent methods of the mock struct
371 methods: Methods,
372 /// Name of the overall module that holds all of the mock stuff
373 modname: Ident,
374 name: Ident,
375 /// Name of the field of this type in the parent's structure
376 fieldname: Ident,
377 }
378
379 impl MockItemTraitImpl {
380 /// Are all of this traits's methods static?
all_static(&self) -> bool381 fn all_static(&self) -> bool {
382 self.methods.all_static()
383 }
384
phantom_default_inits(&self) -> Vec<TokenStream>385 fn phantom_default_inits(&self) -> Vec<TokenStream> {
386 phantom_default_inits(&self.generics)
387 }
388
phantom_fields(&self) -> Vec<TokenStream>389 fn phantom_fields(&self) -> Vec<TokenStream> {
390 phantom_fields(&self.generics)
391 }
392 }
393
394 impl ToTokens for MockItemTraitImpl {
to_tokens(&self, tokens: &mut TokenStream)395 fn to_tokens(&self, tokens: &mut TokenStream) {
396 let mod_attrs = AttrFormatter::new(&self.attrs)
397 .async_trait(false)
398 .doc(false)
399 .format();
400 let struct_attrs = AttrFormatter::new(&self.attrs)
401 .async_trait(false)
402 .doc(false)
403 .must_use(false)
404 .format();
405 let impl_attrs = AttrFormatter::new(&self.attrs)
406 .async_trait(false)
407 .doc(false)
408 .format();
409 let struct_name = &self.name;
410 let (ig, tg, wc) = self.generics.split_for_impl();
411 let modname = &self.modname;
412 let method_checkpoints = self.methods.checkpoints();
413 let mut default_inits = self.methods.default_inits();
414 default_inits.extend(self.phantom_default_inits());
415 let mut field_definitions = self.methods.field_definitions(modname);
416 field_definitions.extend(self.phantom_fields());
417 let priv_mods = self.methods.priv_mods();
418 quote!(
419 #[allow(non_snake_case)]
420 #[allow(missing_docs)]
421 #(#mod_attrs)*
422 pub mod #modname {
423 use super::*;
424 #(#priv_mods)*
425 }
426 #[allow(non_camel_case_types)]
427 #[allow(non_snake_case)]
428 #[allow(missing_docs)]
429 #(#struct_attrs)*
430 struct #struct_name #ig #wc
431 {
432 #(#field_definitions),*
433 }
434 #(#impl_attrs)*
435 impl #ig ::std::default::Default for #struct_name #tg #wc {
436 fn default() -> Self {
437 Self {
438 #(#default_inits),*
439 }
440 }
441 }
442 #(#impl_attrs)*
443 impl #ig #struct_name #tg #wc {
444 /// Validate that all current expectations for all methods have
445 /// been satisfied, and discard them.
446 pub fn checkpoint(&mut self) {
447 #(#method_checkpoints)*
448 }
449 }
450 ).to_tokens(tokens);
451 }
452 }
453