1 // vim: tw=80
2 use proc_macro2::Span;
3 use quote::{ToTokens, format_ident, quote};
4 use std::{
5 collections::hash_map::DefaultHasher,
6 fmt::Write,
7 hash::{Hash, Hasher}
8 };
9 use syn::{
10 *,
11 spanned::Spanned
12 };
13
14 use crate::{
15 AttrFormatter,
16 mock_function::{self, MockFunction},
17 compile_error
18 };
19
20 pub(crate) struct MockTrait {
21 pub attrs: Vec<Attribute>,
22 pub consts: Vec<ImplItemConst>,
23 pub generics: Generics,
24 pub methods: Vec<MockFunction>,
25 /// Internally-used name of the trait used.
26 pub ss_name: Ident,
27 /// Fully-qualified name of the trait
28 pub trait_path: Path,
29 /// Path on which the trait is implemented. Usually will be the same as
30 /// structname, but might include concrete generic parameters.
31 self_path: PathSegment,
32 pub types: Vec<ImplItemType>,
33 pub unsafety: Option<Token![unsafe]>
34 }
35
36 impl MockTrait {
ss_name_priv(trait_path: &Path) -> Ident37 fn ss_name_priv(trait_path: &Path) -> Ident {
38 let id = join_path_segments(trait_path, "__");
39 // Skip the hashing step for easy debugging of generated code
40 if let Some(hash) = hash_path_arguments(trait_path) {
41 // Hash the path args to permit mocking structs that implement
42 // multiple traits distinguished only by their path args
43 format_ident!("{id}_{hash}")
44 } else {
45 format_ident!("{id}")
46 }
47 }
48
ss_name(&self) -> &Ident49 pub fn ss_name(&self) -> &Ident {
50 &self.ss_name
51 }
52
53 /// Create a new MockTrait
54 ///
55 /// # Arguments
56 /// * `structname` - name of the struct that implements this trait
57 /// * `struct_generics` - Generics of the parent structure
58 /// * `impl_` - Mockable ItemImpl for a trait
59 /// * `vis` - Visibility of the struct
new(structname: &Ident, struct_generics: &Generics, impl_: ItemImpl, vis: &Visibility) -> Self60 pub fn new(structname: &Ident,
61 struct_generics: &Generics,
62 impl_: ItemImpl,
63 vis: &Visibility) -> Self
64 {
65 let mut consts = Vec::new();
66 let mut methods = Vec::new();
67 let mut types = Vec::new();
68 let trait_path = if let Some((_, path, _)) = impl_.trait_ {
69 path
70 } else {
71 compile_error(impl_.span(), "impl block must implement a trait");
72 Path::from(format_ident!("__mockall_invalid"))
73 };
74 let ss_name = MockTrait::ss_name_priv(&trait_path);
75 let self_path = match *impl_.self_ty {
76 Type::Path(mut type_path) =>
77 type_path.path.segments.pop().unwrap().into_value(),
78 x => {
79 compile_error(x.span(),
80 "mockall_derive only supports mocking traits and structs");
81 PathSegment::from(Ident::new("", Span::call_site()))
82 }
83 };
84
85 for ii in impl_.items.into_iter() {
86 match ii {
87 ImplItem::Const(iic) => {
88 consts.push(iic);
89 },
90 ImplItem::Fn(iif) => {
91 let mf = mock_function::Builder::new(&iif.sig, vis)
92 .attrs(&iif.attrs)
93 .levels(2)
94 .call_levels(0)
95 .struct_(structname)
96 .struct_generics(struct_generics)
97 .trait_(&ss_name)
98 .build();
99 methods.push(mf);
100 },
101 ImplItem::Type(iit) => {
102 types.push(iit);
103 },
104 _ => {
105 compile_error(ii.span(),
106 "This impl item is not yet supported by MockAll");
107 }
108 }
109 }
110 MockTrait {
111 attrs: impl_.attrs,
112 consts,
113 generics: impl_.generics,
114 methods,
115 ss_name,
116 trait_path,
117 self_path,
118 types,
119 unsafety: impl_.unsafety
120 }
121 }
122
123 /// Generate code for the trait implementation on the mock struct
124 ///
125 /// # Arguments
126 ///
127 /// * `modname`: Name of the parent struct's private module
128 // Supplying modname is an unfortunately hack. Ideally MockTrait
129 // wouldn't need to know that.
trait_impl(&self, modname: &Ident) -> impl ToTokens130 pub fn trait_impl(&self, modname: &Ident) -> impl ToTokens {
131 let trait_impl_attrs = AttrFormatter::new(&self.attrs)
132 .must_use(false)
133 .format();
134 let impl_attrs = AttrFormatter::new(&self.attrs)
135 .async_trait(false)
136 .doc(false)
137 .format();
138 let (ig, _tg, wc) = self.generics.split_for_impl();
139 let consts = &self.consts;
140 let path_args = &self.self_path.arguments;
141 let calls = self.methods.iter()
142 .map(|meth| meth.call(Some(modname)))
143 .collect::<Vec<_>>();
144 let contexts = self.methods.iter()
145 .filter(|meth| meth.is_static())
146 .map(|meth| meth.context_fn(Some(modname)))
147 .collect::<Vec<_>>();
148 let expects = self.methods.iter()
149 .filter(|meth| !meth.is_static())
150 .map(|meth| {
151 if meth.is_method_generic() {
152 // Specific impls with generic methods are TODO.
153 meth.expect(modname, None)
154 } else {
155 meth.expect(modname, Some(path_args))
156 }
157 }).collect::<Vec<_>>();
158 let trait_path = &self.trait_path;
159 let self_path = &self.self_path;
160 let types = &self.types;
161 let unsafety = &self.unsafety;
162 quote!(
163 #(#trait_impl_attrs)*
164 #unsafety impl #ig #trait_path for #self_path #wc {
165 #(#consts)*
166 #(#types)*
167 #(#calls)*
168 }
169 #(#impl_attrs)*
170 impl #ig #self_path #wc {
171 #(#expects)*
172 #(#contexts)*
173 }
174 )
175 }
176 }
177
join_path_segments(path: &Path, sep: &str) -> String178 fn join_path_segments(path: &Path, sep: &str) -> String {
179 let mut output = String::new();
180 for segment in &path.segments {
181 if write!(
182 output,
183 "{}{}",
184 if output.is_empty() { "" } else { sep },
185 segment.ident
186 )
187 .is_err()
188 {
189 break;
190 };
191 }
192 output
193 }
194
hash_path_arguments(path: &Path) -> Option<u64>195 fn hash_path_arguments(path: &Path) -> Option<u64> {
196 let mut hasher = DefaultHasher::new();
197 let mut is_some = false;
198 for arguments in path
199 .segments
200 .iter()
201 .map(|segment| &segment.arguments)
202 .filter(|arguments| !arguments.is_empty())
203 {
204 arguments.hash(&mut hasher);
205 is_some = true;
206 }
207 is_some.then(|| hasher.finish())
208 }
209