• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2016 Mockito contributors
3  * This program is made available under the terms of the MIT License.
4  */
5 package org.mockito.internal.creation.bytebuddy;
6 
7 import net.bytebuddy.asm.Advice;
8 import net.bytebuddy.description.method.MethodDescription;
9 import net.bytebuddy.description.type.TypeDescription;
10 import net.bytebuddy.dynamic.scaffold.MethodGraph;
11 import net.bytebuddy.implementation.bind.annotation.Argument;
12 import net.bytebuddy.implementation.bind.annotation.This;
13 import net.bytebuddy.implementation.bytecode.assign.Assigner;
14 import org.mockito.exceptions.base.MockitoException;
15 import org.mockito.internal.debugging.LocationImpl;
16 import org.mockito.internal.exceptions.stacktrace.ConditionalStackTraceFilter;
17 import org.mockito.internal.invocation.RealMethod;
18 import org.mockito.internal.invocation.SerializableMethod;
19 import org.mockito.internal.invocation.mockref.MockReference;
20 import org.mockito.internal.invocation.mockref.MockWeakReference;
21 import org.mockito.internal.util.concurrent.WeakConcurrentMap;
22 
23 import java.io.IOException;
24 import java.io.ObjectInputStream;
25 import java.io.Serializable;
26 import java.lang.annotation.Retention;
27 import java.lang.annotation.RetentionPolicy;
28 import java.lang.ref.SoftReference;
29 import java.lang.reflect.InvocationTargetException;
30 import java.lang.reflect.Method;
31 import java.lang.reflect.Modifier;
32 import java.util.ArrayList;
33 import java.util.List;
34 import java.util.concurrent.Callable;
35 
36 public class MockMethodAdvice extends MockMethodDispatcher {
37 
38     final WeakConcurrentMap<Object, MockMethodInterceptor> interceptors;
39 
40     private final String identifier;
41 
42     private final SelfCallInfo selfCallInfo = new SelfCallInfo();
43     private final MethodGraph.Compiler compiler = MethodGraph.Compiler.Default.forJavaHierarchy();
44     private final WeakConcurrentMap<Class<?>, SoftReference<MethodGraph>> graphs
45         = new WeakConcurrentMap.WithInlinedExpunction<Class<?>, SoftReference<MethodGraph>>();
46 
MockMethodAdvice(WeakConcurrentMap<Object, MockMethodInterceptor> interceptors, String identifier)47     public MockMethodAdvice(WeakConcurrentMap<Object, MockMethodInterceptor> interceptors, String identifier) {
48         this.interceptors = interceptors;
49         this.identifier = identifier;
50     }
51 
52     @SuppressWarnings("unused")
53     @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
enter(@dentifier String identifier, @Advice.This Object mock, @Advice.Origin Method origin, @Advice.AllArguments Object[] arguments)54     private static Callable<?> enter(@Identifier String identifier,
55                                      @Advice.This Object mock,
56                                      @Advice.Origin Method origin,
57                                      @Advice.AllArguments Object[] arguments) throws Throwable {
58         MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, mock);
59         if (dispatcher == null || !dispatcher.isMocked(mock) || dispatcher.isOverridden(mock, origin)) {
60             return null;
61         } else {
62             return dispatcher.handle(mock, origin, arguments);
63         }
64     }
65 
66     @SuppressWarnings({"unused", "UnusedAssignment"})
67     @Advice.OnMethodExit
exit(@dvice.ReturnreadOnly = false, typing = Assigner.Typing.DYNAMIC) Object returned, @Advice.Enter Callable<?> mocked)68     private static void exit(@Advice.Return(readOnly = false, typing = Assigner.Typing.DYNAMIC) Object returned,
69                              @Advice.Enter Callable<?> mocked) throws Throwable {
70         if (mocked != null) {
71             returned = mocked.call();
72         }
73     }
74 
hideRecursiveCall(Throwable throwable, int current, Class<?> targetType)75     static Throwable hideRecursiveCall(Throwable throwable, int current, Class<?> targetType) {
76         try {
77             StackTraceElement[] stack = throwable.getStackTrace();
78             int skip = 0;
79             StackTraceElement next;
80             do {
81                 next = stack[stack.length - current - ++skip];
82             } while (!next.getClassName().equals(targetType.getName()));
83             int top = stack.length - current - skip;
84             StackTraceElement[] cleared = new StackTraceElement[stack.length - skip];
85             System.arraycopy(stack, 0, cleared, 0, top);
86             System.arraycopy(stack, top + skip, cleared, top, current);
87             throwable.setStackTrace(cleared);
88             return throwable;
89         } catch (RuntimeException ignored) {
90             // This should not happen unless someone instrumented or manipulated exception stack traces.
91             return throwable;
92         }
93     }
94 
95     @Override
handle(Object instance, Method origin, Object[] arguments)96     public Callable<?> handle(Object instance, Method origin, Object[] arguments) throws Throwable {
97         MockMethodInterceptor interceptor = interceptors.get(instance);
98         if (interceptor == null) {
99             return null;
100         }
101         RealMethod realMethod;
102         if (instance instanceof Serializable) {
103             realMethod = new SerializableRealMethodCall(identifier, origin, instance, arguments);
104         } else {
105             realMethod = new RealMethodCall(selfCallInfo, origin, instance, arguments);
106         }
107         Throwable t = new Throwable();
108         t.setStackTrace(skipInlineMethodElement(t.getStackTrace()));
109         return new ReturnValueWrapper(interceptor.doIntercept(instance,
110                 origin,
111                 arguments,
112                 realMethod,
113                 new LocationImpl(t)));
114     }
115 
116     @Override
isMock(Object instance)117     public boolean isMock(Object instance) {
118         // We need to exclude 'interceptors.target' explicitly to avoid a recursive check on whether
119         // the map is a mock object what requires reading from the map.
120         return instance != interceptors.target && interceptors.containsKey(instance);
121     }
122 
123     @Override
isMocked(Object instance)124     public boolean isMocked(Object instance) {
125         return selfCallInfo.checkSuperCall(instance) && isMock(instance);
126     }
127 
128     @Override
isOverridden(Object instance, Method origin)129     public boolean isOverridden(Object instance, Method origin) {
130         SoftReference<MethodGraph> reference = graphs.get(instance.getClass());
131         MethodGraph methodGraph = reference == null ? null : reference.get();
132         if (methodGraph == null) {
133             methodGraph = compiler.compile(new TypeDescription.ForLoadedType(instance.getClass()));
134             graphs.put(instance.getClass(), new SoftReference<MethodGraph>(methodGraph));
135         }
136         MethodGraph.Node node = methodGraph.locate(new MethodDescription.ForLoadedMethod(origin).asSignatureToken());
137         return !node.getSort().isResolved() || !node.getRepresentative().asDefined().getDeclaringType().represents(origin.getDeclaringClass());
138     }
139 
140     private static class RealMethodCall implements RealMethod {
141 
142         private final SelfCallInfo selfCallInfo;
143 
144         private final Method origin;
145 
146         private final MockWeakReference<Object> instanceRef;
147 
148         private final Object[] arguments;
149 
RealMethodCall(SelfCallInfo selfCallInfo, Method origin, Object instance, Object[] arguments)150         private RealMethodCall(SelfCallInfo selfCallInfo, Method origin, Object instance, Object[] arguments) {
151             this.selfCallInfo = selfCallInfo;
152             this.origin = origin;
153             this.instanceRef = new MockWeakReference<Object>(instance);
154             this.arguments = arguments;
155         }
156 
157         @Override
isInvokable()158         public boolean isInvokable() {
159             return true;
160         }
161 
162         @Override
invoke()163         public Object invoke() throws Throwable {
164             if (!Modifier.isPublic(origin.getDeclaringClass().getModifiers() & origin.getModifiers())) {
165                 origin.setAccessible(true);
166             }
167             selfCallInfo.set(instanceRef.get());
168             return tryInvoke(origin, instanceRef.get(), arguments);
169         }
170 
171     }
172 
173     private static class SerializableRealMethodCall implements RealMethod {
174 
175         private final String identifier;
176 
177         private final SerializableMethod origin;
178 
179         private final MockReference<Object> instanceRef;
180 
181         private final Object[] arguments;
182 
SerializableRealMethodCall(String identifier, Method origin, Object instance, Object[] arguments)183         private SerializableRealMethodCall(String identifier, Method origin, Object instance, Object[] arguments) {
184             this.origin = new SerializableMethod(origin);
185             this.identifier = identifier;
186             this.instanceRef = new MockWeakReference<Object>(instance);
187             this.arguments = arguments;
188         }
189 
190         @Override
isInvokable()191         public boolean isInvokable() {
192             return true;
193         }
194 
195         @Override
invoke()196         public Object invoke() throws Throwable {
197             Method method = origin.getJavaMethod();
198             if (!Modifier.isPublic(method.getDeclaringClass().getModifiers() & method.getModifiers())) {
199                 method.setAccessible(true);
200             }
201             MockMethodDispatcher mockMethodDispatcher = MockMethodDispatcher.get(identifier, instanceRef.get());
202             if (!(mockMethodDispatcher instanceof MockMethodAdvice)) {
203                 throw new MockitoException("Unexpected dispatcher for advice-based super call");
204             }
205             Object previous = ((MockMethodAdvice) mockMethodDispatcher).selfCallInfo.replace(instanceRef.get());
206             try {
207                 return tryInvoke(method, instanceRef.get(), arguments);
208             } finally {
209                 ((MockMethodAdvice) mockMethodDispatcher).selfCallInfo.set(previous);
210             }
211         }
212     }
213 
tryInvoke(Method origin, Object instance, Object[] arguments)214     private static Object tryInvoke(Method origin, Object instance, Object[] arguments) throws Throwable {
215         try {
216             return origin.invoke(instance, arguments);
217         } catch (InvocationTargetException exception) {
218             Throwable cause = exception.getCause();
219             new ConditionalStackTraceFilter().filter(hideRecursiveCall(cause, new Throwable().getStackTrace().length, origin.getDeclaringClass()));
220             throw cause;
221         }
222     }
223 
224     // With inline mocking, mocks for concrete classes are not subclassed, so elements of the stubbing methods are not filtered out.
225     // Therefore, if the method is inlined, skip the element.
skipInlineMethodElement(StackTraceElement[] elements)226     private static StackTraceElement[] skipInlineMethodElement(StackTraceElement[] elements) {
227         List<StackTraceElement> list = new ArrayList<StackTraceElement>(elements.length);
228         for (int i = 0; i < elements.length; i++) {
229             StackTraceElement element = elements[i];
230             list.add(element);
231             if (element.getClassName().equals(MockMethodAdvice.class.getName()) && element.getMethodName().equals("handle")) {
232                 // If the current element is MockMethodAdvice#handle(), the next is assumed to be an inlined method.
233                 i++;
234             }
235         }
236         return list.toArray(new StackTraceElement[list.size()]);
237     }
238 
239     private static class ReturnValueWrapper implements Callable<Object> {
240 
241         private final Object returned;
242 
ReturnValueWrapper(Object returned)243         private ReturnValueWrapper(Object returned) {
244             this.returned = returned;
245         }
246 
247         @Override
call()248         public Object call() {
249             return returned;
250         }
251     }
252 
253     private static class SelfCallInfo extends ThreadLocal<Object> {
254 
replace(Object value)255         Object replace(Object value) {
256             Object current = get();
257             set(value);
258             return current;
259         }
260 
checkSuperCall(Object value)261         boolean checkSuperCall(Object value) {
262             if (value == get()) {
263                 set(null);
264                 return false;
265             } else {
266                 return true;
267             }
268         }
269     }
270 
271     @Retention(RetentionPolicy.RUNTIME)
272     @interface Identifier {
273 
274     }
275 
276     static class ForHashCode {
277 
278         @SuppressWarnings("unused")
279         @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
enter(@dentifier String id, @Advice.This Object self)280         private static boolean enter(@Identifier String id,
281                                      @Advice.This Object self) {
282             MockMethodDispatcher dispatcher = MockMethodDispatcher.get(id, self);
283             return dispatcher != null && dispatcher.isMock(self);
284         }
285 
286         @SuppressWarnings({"unused", "UnusedAssignment"})
287         @Advice.OnMethodExit
enter(@dvice.This Object self, @Advice.Return(readOnly = false) int hashCode, @Advice.Enter boolean skipped)288         private static void enter(@Advice.This Object self,
289                                   @Advice.Return(readOnly = false) int hashCode,
290                                   @Advice.Enter boolean skipped) {
291             if (skipped) {
292                 hashCode = System.identityHashCode(self);
293             }
294         }
295     }
296 
297     static class ForEquals {
298 
299         @SuppressWarnings("unused")
300         @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
enter(@dentifier String identifier, @Advice.This Object self)301         private static boolean enter(@Identifier String identifier,
302                                      @Advice.This Object self) {
303             MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, self);
304             return dispatcher != null && dispatcher.isMock(self);
305         }
306 
307         @SuppressWarnings({"unused", "UnusedAssignment"})
308         @Advice.OnMethodExit
enter(@dvice.This Object self, @Advice.Argument(0) Object other, @Advice.Return(readOnly = false) boolean equals, @Advice.Enter boolean skipped)309         private static void enter(@Advice.This Object self,
310                                   @Advice.Argument(0) Object other,
311                                   @Advice.Return(readOnly = false) boolean equals,
312                                   @Advice.Enter boolean skipped) {
313             if (skipped) {
314                 equals = self == other;
315             }
316         }
317     }
318 
319     public static class ForReadObject {
320 
321         @SuppressWarnings("unused")
doReadObject(@dentifier String identifier, @This MockAccess thiz, @Argument(0) ObjectInputStream objectInputStream)322         public static void doReadObject(@Identifier String identifier,
323                                         @This MockAccess thiz,
324                                         @Argument(0) ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
325             objectInputStream.defaultReadObject();
326             MockMethodAdvice mockMethodAdvice = (MockMethodAdvice) MockMethodDispatcher.get(identifier, thiz);
327             if (mockMethodAdvice != null) {
328                 mockMethodAdvice.interceptors.put(thiz, thiz.getMockitoInterceptor());
329             }
330         }
331     }
332 }
333