• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.android.modules.utils.testing;
17 
18 import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession;
19 
20 import android.annotation.Nullable;
21 import android.util.Log;
22 
23 import com.android.dx.mockito.inline.extended.StaticMockitoSessionBuilder;
24 import com.android.internal.annotations.VisibleForTesting;
25 import com.android.modules.utils.testing.AbstractExtendedMockitoRule.AbstractBuilder;
26 import com.android.modules.utils.testing.ExtendedMockitoRule.MockStatic;
27 import com.android.modules.utils.testing.ExtendedMockitoRule.MockStaticClasses;
28 import com.android.modules.utils.testing.ExtendedMockitoRule.SpyStatic;
29 import com.android.modules.utils.testing.ExtendedMockitoRule.SpyStaticClasses;
30 
31 import org.junit.rules.TestRule;
32 import org.junit.runner.Description;
33 import org.junit.runners.model.Statement;
34 import org.mockito.Mockito;
35 import org.mockito.MockitoFramework;
36 import org.mockito.MockitoSession;
37 import org.mockito.quality.Strictness;
38 
39 import java.lang.annotation.Annotation;
40 import java.lang.annotation.ElementType;
41 import java.lang.annotation.Repeatable;
42 import java.lang.annotation.Retention;
43 import java.lang.annotation.RetentionPolicy;
44 import java.lang.annotation.Target;
45 import java.util.ArrayList;
46 import java.util.Arrays;
47 import java.util.Collection;
48 import java.util.Collections;
49 import java.util.HashSet;
50 import java.util.List;
51 import java.util.Objects;
52 import java.util.Set;
53 import java.util.function.Function;
54 import java.util.function.Supplier;
55 import java.util.stream.Collectors;
56 
57 /**
58  * Base class for {@link ExtendedMockitoRule} and other rules - it contains the logic so those
59  * classes can be final.
60  */
61 public abstract class AbstractExtendedMockitoRule<R extends AbstractExtendedMockitoRule<R, B>,
62         B extends AbstractBuilder<R, B>> implements TestRule {
63 
64     private static final String TAG = AbstractExtendedMockitoRule.class.getSimpleName();
65 
66     private static final AnnotationFetcher<SpyStatic, SpyStaticClasses>
67         sSpyStaticAnnotationFetcher = new AnnotationFetcher<>(SpyStatic.class,
68                 SpyStaticClasses.class, r -> r.value());
69     private static final AnnotationFetcher<MockStatic, MockStaticClasses>
70         sMockStaticAnnotationFetcher = new AnnotationFetcher<>(MockStatic.class,
71                 MockStaticClasses.class, r -> r.value());
72 
73     private final Object mTestClassInstance;
74     private final Strictness mStrictness;
75     private @Nullable final MockitoFramework mMockitoFramework;
76     private @Nullable final Runnable mAfterSessionFinishedCallback;
77     private final Set<Class<?>> mMockedStaticClasses;
78     private final Set<Class<?>> mSpiedStaticClasses;
79     private final List<StaticMockFixture> mStaticMockFixtures;
80     private final boolean mClearInlineMocks;
81 
82     private MockitoSession mMockitoSession;
83 
AbstractExtendedMockitoRule(B builder)84     protected AbstractExtendedMockitoRule(B builder) {
85         mTestClassInstance = builder.mTestClassInstance;
86         mStrictness = builder.mStrictness;
87         mMockitoFramework = builder.mMockitoFramework;
88         mMockitoSession = builder.mMockitoSession;
89         mAfterSessionFinishedCallback = builder.mAfterSessionFinishedCallback;
90         mMockedStaticClasses = builder.mMockedStaticClasses;
91         mSpiedStaticClasses = builder.mSpiedStaticClasses;
92         mStaticMockFixtures = builder.mStaticMockFixtures == null ? Collections.emptyList()
93                 : builder.mStaticMockFixtures;
94         mClearInlineMocks = builder.mClearInlineMocks;
95         Log.v(TAG, "strictness=" + mStrictness + ", testClassInstance" + mTestClassInstance
96                 + ", mockedStaticClasses=" + mMockedStaticClasses
97                 + ", spiedStaticClasses=" + mSpiedStaticClasses
98                 + ", staticMockFixtures=" + mStaticMockFixtures
99                 + ", afterSessionFinishedCallback=" + mAfterSessionFinishedCallback
100                 + ", mockitoFramework=" + mMockitoFramework
101                 + ", mockitoSession=" + mMockitoSession
102                 + ", clearInlineMocks=" + mClearInlineMocks);
103     }
104 
105     /**
106      * Gets the mocked static classes present in the given test.
107      *
108      * <p>By default, it returns the classes defined by {@link AbstractBuilder#mockStatic(Class)}
109      * plus the classes present in the {@link MockStatic} and {@link MockStaticClasses}
110      * annotations (presents in the test method, its class, or its superclasses).
111      */
getMockedStaticClasses(Description description)112     protected Set<Class<?>> getMockedStaticClasses(Description description) {
113         Set<Class<?>> staticClasses = new HashSet<>(mMockedStaticClasses);
114         sMockStaticAnnotationFetcher.getAnnotations(description)
115                 .forEach(a -> staticClasses.add(a.value()));
116         return Collections.unmodifiableSet(staticClasses);
117     }
118 
119     /**
120      * Gets the spied static classes present in the given test.
121      *
122      * <p>By default, it returns the classes defined by {@link AbstractBuilder#spyStatic(Class)}
123      * plus the classes present in the {@link SpyStatic} and {@link SpyStaticClasses}
124      * annotations (presents in the test method, its class, or its superclasses).
125      */
getSpiedStaticClasses(Description description)126     protected Set<Class<?>> getSpiedStaticClasses(Description description) {
127         Set<Class<?>> staticClasses = new HashSet<>(mSpiedStaticClasses);
128         sSpyStaticAnnotationFetcher.getAnnotations(description)
129                 .forEach(a -> staticClasses.add(a.value()));
130         return Collections.unmodifiableSet(staticClasses);
131     }
132 
133     /**
134      * Gets whether the rule should clear the inline mocks after the given test.
135      *
136      * <p>By default, it returns {@code} (unless the rule was built with
137      * {@link AbstractBuilder#dontClearInlineMocks()}, but subclasses can override to change the
138      * behavior (for example, to decide based on custom annotations).
139      */
getClearInlineMethodsAtTheEnd(Description description)140     protected boolean getClearInlineMethodsAtTheEnd(Description description) {
141         return mClearInlineMocks;
142     }
143 
144     @Override
apply(Statement base, Description description)145     public Statement apply(Statement base, Description description) {
146         return new Statement() {
147             @Override
148             public void evaluate() throws Throwable {
149                 createMockitoSession(base, description);
150                 Throwable error = null;
151                 try {
152                     // TODO(b/296937563): need to add unit tests that make sure the session is
153                     // always closed
154                     base.evaluate();
155                 } catch (Throwable t) {
156                     error = t;
157                 }
158                 try {
159                     tearDown(description, error);
160                 } catch (Throwable t) {
161                     if (error != null) {
162                         Log.e(TAG, "Teardown failed for " + description.getDisplayName()
163                             + ", but not throwing it because test also threw (" + error + ")", t);
164                     } else {
165                         error = t;
166                     }
167                 }
168                 if (error != null) {
169                     // TODO(b/296937563): ideally should also add unit tests to make sure the
170                     // test error is thrown (in case tearDown() above fails)
171                     throw error;
172                 }
173             }
174         };
175     }
176 
177     private void createMockitoSession(Statement base, Description description) {
178         // TODO(b/296937563): might be prudent to save the session statically so it's explicitly
179         // closed in case it fails to be created again if for some reason it was not closed by us
180         // (although that should not happen)
181         Log.v(TAG, "Creating session builder with strictness " + mStrictness);
182         StaticMockitoSessionBuilder mSessionBuilder = mockitoSession().strictness(mStrictness);
183 
184         setUpMockedClasses(description, mSessionBuilder);
185 
186         if (mTestClassInstance != null) {
187             Log.v(TAG, "Initializing mocks on " + description + " using " + mSessionBuilder);
188             mSessionBuilder.initMocks(mTestClassInstance);
189         } else {
190             Log.v(TAG, "NOT Initializing mocks on " + description + " as requested by builder");
191         }
192 
193         if (mMockitoSession != null) {
194             Log.d(TAG, "NOT creating session as set on builder: " + mMockitoSession);
195         } else {
196             Log.d(TAG, "Creating mockito session using " + mSessionBuilder);
197             mMockitoSession = mSessionBuilder.startMocking();
198         }
199 
200         setUpMockBehaviors();
201     }
202 
203     private void setUpMockedClasses(Description description,
204             StaticMockitoSessionBuilder sessionBuilder) {
205         if (!mStaticMockFixtures.isEmpty()) {
206             for (StaticMockFixture fixture : mStaticMockFixtures) {
207                 Log.v(TAG, "Calling setUpMockedClasses(" + sessionBuilder + ") on " + fixture);
208                 fixture.setUpMockedClasses(sessionBuilder);
209             }
210         }
211         for (Class<?> clazz: getMockedStaticClasses(description)) {
212             Log.v(TAG, "Calling mockStatic() on " + clazz);
213             sessionBuilder.mockStatic(clazz);
214         }
215         for (Class<?> clazz: getSpiedStaticClasses(description)) {
216             Log.v(TAG, "Calling spyStatic() on " + clazz);
217             sessionBuilder.spyStatic(clazz);
218         }
219     }
220 
221     private void setUpMockBehaviors() {
222         if (mStaticMockFixtures.isEmpty()) {
223             Log.v(TAG, "setUpMockBehaviors(): not needed, mStaticMockFixtures is empty");
224             return;
225         }
226         for (StaticMockFixture fixture : mStaticMockFixtures) {
227             Log.v(TAG, "Calling setUpMockBehaviors() on " + fixture);
228             fixture.setUpMockBehaviors();
229         }
230     }
231 
232     private void tearDown(Description description, Throwable e) {
233         Log.d(TAG, "Finishing mockito session " + mMockitoSession + " on " + description
234                 + (e == null ? "" : " (which failed with " + e + ")"));
235         try {
236             try {
237                 mMockitoSession.finishMocking(e);
238                 mMockitoSession = null;
239             } finally {
240                 // Must iterate in reverse order
241                 for (int i = mStaticMockFixtures.size() - 1; i >= 0; i--) {
242                     StaticMockFixture fixture = mStaticMockFixtures.get(i);
243                     Log.v(TAG, "Calling tearDown() on " + fixture);
244                     fixture.tearDown();
245                 }
246                 if (mAfterSessionFinishedCallback != null) {
247                     mAfterSessionFinishedCallback.run();
248                 }
249             }
250         } finally {
251             clearInlineMocks(description);
252         }
253     }
254 
255     private void clearInlineMocks(Description description) {
256         boolean clearIt = getClearInlineMethodsAtTheEnd(description);
257         if (!clearIt) {
258             Log.d(TAG, "NOT calling clearInlineMocks() as set on builder");
259             return;
260         }
261         if (mMockitoFramework != null) {
262             Log.v(TAG, "Calling clearInlineMocks() on custom mockito framework: "
263                     + mMockitoFramework);
264             mMockitoFramework.clearInlineMocks();
265             return;
266         }
267         Log.v(TAG, "Calling Mockito.framework().clearInlineMocks()");
268         Mockito.framework().clearInlineMocks();
269     }
270 
271     /**
272      * Builder for the rule.
273      */
274     public static abstract class AbstractBuilder<R extends
275             AbstractExtendedMockitoRule<R, B>, B extends AbstractBuilder<R, B>> {
276         final Object mTestClassInstance;
277         final Set<Class<?>> mMockedStaticClasses = new HashSet<>();
278         final Set<Class<?>> mSpiedStaticClasses = new HashSet<>();
279         @Nullable List<StaticMockFixture> mStaticMockFixtures;
280         Strictness mStrictness = Strictness.LENIENT;
281         @Nullable MockitoFramework mMockitoFramework;
282         @Nullable MockitoSession mMockitoSession;
283         @Nullable Runnable mAfterSessionFinishedCallback;
284         boolean mClearInlineMocks = true;
285 
286         /**
287          * Constructs a builder for the giving test instance (typically {@code this}) and initialize
288          * mocks on it.
289          */
290         protected AbstractBuilder(Object testClassInstance) {
291             mTestClassInstance = Objects.requireNonNull(testClassInstance);
292         }
293 
294         /**
295          * Constructs a builder that doesn't initialize mocks.
296          *
297          * <p>Typically used on test classes that already initialize mocks somewhere else.
298          */
299         protected AbstractBuilder() {
300             mTestClassInstance = null;
301         }
302 
303         /**
304          * Sets the mock strictness.
305          */
306         public final B setStrictness(Strictness strictness) {
307             mStrictness = Objects.requireNonNull(strictness);
308             return thisBuilder();
309         }
310 
311         /**
312          * Same as {@link
313          * com.android.dx.mockito.inline.extended.StaticMockitoSessionBuilder#mockStatic(Class)}.
314          *
315          * @throws IllegalStateException if the same class was already passed to
316          *   {@link #mockStatic(Class)} or {@link #spyStatic(Class)}.
317          */
318         public final B mockStatic(Class<?> clazz) {
319             mMockedStaticClasses.add(checkClassNotMockedOrSpied(clazz));
320             return thisBuilder();
321         }
322 
323         /**
324          * Same as {@link
325          * com.android.dx.mockito.inline.extended.StaticMockitoSessionBuilder#spyStatic(Class)}.
326          *
327          * @throws IllegalStateException if the same class was already passed to
328          *   {@link #mockStatic(Class)} or {@link #spyStatic(Class)}.
329          */
330         public final B spyStatic(Class<?> clazz) {
331             mSpiedStaticClasses.add(checkClassNotMockedOrSpied(clazz));
332             return thisBuilder();
333         }
334 
335         /**
336          * Uses the supplied {@link StaticMockFixture} as well.
337          */
338         @SafeVarargs
339         public final B addStaticMockFixtures(
340                 Supplier<? extends StaticMockFixture>... suppliers) {
341             List<StaticMockFixture> fixtures = Arrays
342                     .stream(Objects.requireNonNull(suppliers)).map(s -> s.get())
343                     .collect(Collectors.toList());
344             if (mStaticMockFixtures == null) {
345                 mStaticMockFixtures = fixtures;
346             } else {
347                 mStaticMockFixtures.addAll(fixtures);
348             }
349             return thisBuilder();
350         }
351 
352         /**
353          * Runs the given {@code runnable} after the session finished.
354          *
355          * <p>Typically used for clean-up code that cannot be executed on {@code &#064;After}, as
356          * those methods are called before the session is finished.
357          */
358         public final B afterSessionFinished(Runnable runnable) {
359             mAfterSessionFinishedCallback = Objects.requireNonNull(runnable);
360             return thisBuilder();
361         }
362 
363         /**
364          * By default, it cleans up inline mocks after the session is closed to prevent OutOfMemory
365          * errors (see <a href="https://github.com/mockito/mockito/issues/1614">external bug</a>
366          * and/or <a href="http://b/259280359">internal bug</a>); use this method to not do so.
367          */
368         public final B dontClearInlineMocks() {
369             mClearInlineMocks  = false;
370             return thisBuilder();
371         }
372 
373         // Used by ExtendedMockitoRuleTest itself
374         @VisibleForTesting
375         final B setMockitoFrameworkForTesting(MockitoFramework mockitoFramework) {
376             mMockitoFramework = Objects.requireNonNull(mockitoFramework);
377             return thisBuilder();
378         }
379 
380         // Used by ExtendedMockitoRuleTest itself
381         @VisibleForTesting
382         final B setMockitoSessionForTesting(MockitoSession mockitoSession) {
383             mMockitoSession = Objects.requireNonNull(mockitoSession);
384             return thisBuilder();
385         }
386 
387         /**
388          * Builds the rule.
389          */
390         public abstract R build();
391 
392         @SuppressWarnings("unchecked")
393         private B thisBuilder() {
394             return (B) this;
395         }
396 
397         private Class<?> checkClassNotMockedOrSpied(Class<?> clazz) {
398             Objects.requireNonNull(clazz);
399             checkState(!mMockedStaticClasses.contains(clazz), "class %s already mocked", clazz);
400             checkState(!mSpiedStaticClasses.contains(clazz), "class %s already spied", clazz);
401             return clazz;
402         }
403     }
404 
405     // Copied from com.android.internal.util.Preconditions, as that method is not available on RVC
406     private static void checkState(boolean expression, String messageTemplate,
407             Object... messageArgs) {
408         if (!expression) {
409             throw new IllegalStateException(String.format(messageTemplate, messageArgs));
410         }
411     }
412 
413     // TODO: make it public so it can be used by other modules
414     private static final class AnnotationFetcher<A extends Annotation, R extends Annotation> {
415 
416         private final Class<A> mAnnotationType;
417         private final Class<R> mRepeatableType;
418         private final Function<R, A[]> mConverter;
419 
420         AnnotationFetcher(Class<A> annotationType, Class<R> repeatableType,
421                 Function<R, A[]> converter) {
422             mAnnotationType = annotationType;
423             mRepeatableType = repeatableType;
424             mConverter = converter;
425         }
426 
427         private void add(Set<A> allAnnotations, R repeatableAnnotation) {
428             A[] repeatedAnnotations = mConverter.apply(repeatableAnnotation);
429             for (A repeatedAnnotation : repeatedAnnotations) {
430                 allAnnotations.add(repeatedAnnotation);
431             }
432         }
433 
434         Set<A> getAnnotations(Description description) {
435             Set<A> allAnnotations = new HashSet<>();
436 
437             // Gets the annotations from the method first
438             Collection<Annotation> annotations = description.getAnnotations();
439             if (annotations != null) {
440                 for (Annotation annotation : annotations) {
441                     if (mAnnotationType.isInstance(annotation)) {
442                         allAnnotations.add(mAnnotationType.cast(annotation));
443                     }
444                     if (mRepeatableType.isInstance(annotation)) {
445                         add(allAnnotations, mRepeatableType.cast(annotation));
446                     }
447                 }
448             }
449 
450             // Then superclasses
451             Class<?> clazz = description.getTestClass();
452             do {
453                 A[] repeatedAnnotations = clazz.getAnnotationsByType(mAnnotationType);
454                 if (repeatedAnnotations != null) {
455                     for (A repeatedAnnotation : repeatedAnnotations) {
456                         allAnnotations.add(repeatedAnnotation);
457                     }
458                 }
459                 R[] repeatableAnnotations = clazz.getAnnotationsByType(mRepeatableType);
460                 if (repeatableAnnotations != null) {
461                     for (R repeatableAnnotation : repeatableAnnotations) {
462                         add(allAnnotations, mRepeatableType.cast(repeatableAnnotation));
463                     }
464                 }
465                 clazz = clazz.getSuperclass();
466             } while (clazz != null);
467 
468             return allAnnotations;
469         }
470     }
471 }
472