1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package android.car.test.mocks; 17 18 import static com.android.dx.mockito.inline.extended.ExtendedMockito.doAnswer; 19 import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession; 20 21 import static org.mockito.ArgumentMatchers.any; 22 import static org.mockito.ArgumentMatchers.anyInt; 23 import static org.mockito.ArgumentMatchers.anyString; 24 import static org.mockito.ArgumentMatchers.notNull; 25 import static org.mockito.Mockito.when; 26 27 import static java.lang.annotation.ElementType.METHOD; 28 import static java.lang.annotation.RetentionPolicy.RUNTIME; 29 30 import android.annotation.NonNull; 31 import android.annotation.Nullable; 32 import android.annotation.UserIdInt; 33 import android.app.ActivityManager; 34 import android.os.Binder; 35 import android.os.Handler; 36 import android.os.HandlerThread; 37 import android.os.Trace; 38 import android.os.UserManager; 39 import android.provider.Settings; 40 import android.util.Log; 41 import android.util.Slog; 42 import android.util.TimingsTraceLog; 43 44 import com.android.dx.mockito.inline.extended.StaticMockitoSessionBuilder; 45 import com.android.internal.util.Preconditions; 46 47 import org.junit.After; 48 import org.junit.Before; 49 import org.junit.Rule; 50 import org.junit.rules.TestRule; 51 import org.junit.runner.Description; 52 import org.junit.runners.model.Statement; 53 import org.mockito.MockitoSession; 54 import org.mockito.invocation.InvocationOnMock; 55 import org.mockito.quality.Strictness; 56 import org.mockito.session.MockitoSessionBuilder; 57 import org.mockito.stubbing.Answer; 58 59 import java.lang.annotation.Retention; 60 import java.lang.annotation.Target; 61 import java.lang.reflect.Method; 62 import java.util.ArrayList; 63 import java.util.HashMap; 64 import java.util.List; 65 import java.util.Set; 66 67 /** 68 * Base class for tests that must use {@link com.android.dx.mockito.inline.extended.ExtendedMockito} 69 * to mock static classes and final methods. 70 * 71 * <p><b>Note: </b> this class automatically spy on {@link Log} and {@link Slog} and fail tests that 72 * all any of their {@code wtf()} methods. If a test is expect to call {@code wtf()}, it should be 73 * annotated with {@link ExpectWtf}. 74 * 75 * <p><b>Note: </b>when using this class, you must include the following 76 * dependencies on {@code Android.bp} (or {@code Android.mk}: 77 * <pre><code> 78 jni_libs: [ 79 "libdexmakerjvmtiagent", 80 "libstaticjvmtiagent", 81 ], 82 83 LOCAL_JNI_SHARED_LIBRARIES := \ 84 libdexmakerjvmtiagent \ 85 libstaticjvmtiagent \ 86 * </code></pre> 87 */ 88 public abstract class AbstractExtendedMockitoTestCase { 89 90 private static final String TAG = AbstractExtendedMockitoTestCase.class.getSimpleName(); 91 92 private static final boolean TRACE = false; 93 private static final boolean VERBOSE = false; 94 95 private final List<Class<?>> mStaticSpiedClasses = new ArrayList<>(); 96 97 // Tracks (S)Log.wtf() calls made during code execution, then used on verifyWtfNeverLogged() 98 private final List<RuntimeException> mWtfs = new ArrayList<>(); 99 100 private MockitoSession mSession; 101 private MockSettings mSettings; 102 103 @Nullable 104 private final TimingsTraceLog mTracer; 105 106 @Rule 107 public final WtfCheckerRule mWtfCheckerRule = new WtfCheckerRule(); 108 AbstractExtendedMockitoTestCase()109 protected AbstractExtendedMockitoTestCase() { 110 mTracer = TRACE ? new TimingsTraceLog(TAG, Trace.TRACE_TAG_APP) : null; 111 } 112 113 @Before startSession()114 public final void startSession() { 115 beginTrace("startSession()"); 116 117 beginTrace("startMocking()"); 118 mSession = newSessionBuilder().startMocking(); 119 endTrace(); 120 121 beginTrace("MockSettings()"); 122 mSettings = new MockSettings(); 123 endTrace(); 124 125 beginTrace("interceptWtfCalls()"); 126 interceptWtfCalls(); 127 endTrace(); 128 129 endTrace(); // startSession 130 } 131 132 @After finishSession()133 public final void finishSession() { 134 beginTrace("finishSession()"); 135 completeAllHandlerThreadTasks(); 136 if (mSession != null) { 137 beginTrace("finishMocking()"); 138 mSession.finishMocking(); 139 endTrace(); 140 } else { 141 Log.w(TAG, getClass().getSimpleName() + ".finishSession(): no session"); 142 } 143 endTrace(); 144 } 145 146 /** 147 * Waits for completion of all pending Handler tasks for all HandlerThread in the process. 148 * 149 * <p>This can prevent pending Handler tasks of one test from affecting another. This does not 150 * work if the message is posted with delay. 151 */ completeAllHandlerThreadTasks()152 protected void completeAllHandlerThreadTasks() { 153 beginTrace("completeAllHandlerThreadTasks"); 154 Set<Thread> threadSet = Thread.getAllStackTraces().keySet(); 155 ArrayList<HandlerThread> handlerThreads = new ArrayList<>(threadSet.size()); 156 Thread currentThread = Thread.currentThread(); 157 for (Thread t : threadSet) { 158 if (t != currentThread && t instanceof HandlerThread) { 159 handlerThreads.add((HandlerThread) t); 160 } 161 } 162 ArrayList<SyncRunnable> syncs = new ArrayList<>(handlerThreads.size()); 163 Log.i(TAG, "will wait for " + handlerThreads.size() + " HandlerThreads"); 164 for (int i = 0; i < handlerThreads.size(); i++) { 165 Handler handler = new Handler(handlerThreads.get(i).getLooper()); 166 SyncRunnable sr = new SyncRunnable(() -> { }); 167 handler.post(sr); 168 syncs.add(sr); 169 } 170 beginTrace("waitForComplete"); 171 for (int i = 0; i < syncs.size(); i++) { 172 syncs.get(i).waitForComplete(); 173 } 174 endTrace(); // waitForComplete 175 endTrace(); // completeAllHandlerThreadTasks 176 } 177 178 /** 179 * Adds key-value(int) pair in mocked Settings.Global and Settings.Secure 180 */ putSettingsInt(@onNull String key, int value)181 protected void putSettingsInt(@NonNull String key, int value) { 182 mSettings.insertObject(key, value); 183 } 184 185 /** 186 * Gets value(int) from mocked Settings.Global and Settings.Secure 187 */ getSettingsInt(@onNull String key)188 protected int getSettingsInt(@NonNull String key) { 189 return mSettings.getInt(key); 190 } 191 192 /** 193 * Adds key-value(String) pair in mocked Settings.Global and Settings.Secure 194 */ putSettingsString(@onNull String key, @NonNull String value)195 protected void putSettingsString(@NonNull String key, @NonNull String value) { 196 mSettings.insertObject(key, value); 197 } 198 199 /** 200 * Gets value(String) from mocked Settings.Global and Settings.Secure 201 */ getSettingsString(@onNull String key)202 protected String getSettingsString(@NonNull String key) { 203 return mSettings.getString(key); 204 } 205 206 /** 207 * Asserts that the giving settings was not set. 208 */ assertSettingsNotSet(String key)209 protected void assertSettingsNotSet(String key) { 210 mSettings.assertDoesNotContainsKey(key); 211 } 212 213 /** 214 * Subclasses can use this method to initialize the Mockito session that's started before every 215 * test on {@link #startSession()}. 216 * 217 * <p>Typically, it should be overridden when mocking static methods. 218 */ onSessionBuilder(@onNull CustomMockitoSessionBuilder session)219 protected void onSessionBuilder(@NonNull CustomMockitoSessionBuilder session) { 220 if (VERBOSE) Log.v(TAG, getLogPrefix() + "onSessionBuilder()"); 221 } 222 223 /** 224 * Changes the value of the session created by 225 * {@link #onSessionBuilder(CustomMockitoSessionBuilder)}. 226 * 227 * <p>By default it's set to {@link Strictness.LENIENT}, but subclasses can overwrite this 228 * method to change the behavior. 229 */ 230 @NonNull getSessionStrictness()231 protected Strictness getSessionStrictness() { 232 return Strictness.LENIENT; 233 } 234 235 /** 236 * Mocks a call to {@link ActivityManager#getCurrentUser()}. 237 * 238 * @param userId result of such call 239 * 240 * @throws IllegalStateException if class didn't override {@link #newSessionBuilder()} and 241 * called {@code spyStatic(ActivityManager.class)} on the session passed to it. 242 */ mockGetCurrentUser(@serIdInt int userId)243 protected final void mockGetCurrentUser(@UserIdInt int userId) { 244 if (VERBOSE) Log.v(TAG, getLogPrefix() + "mockGetCurrentUser(" + userId + ")"); 245 assertSpied(ActivityManager.class); 246 247 beginTrace("mockAmGetCurrentUser-" + userId); 248 AndroidMockitoHelper.mockAmGetCurrentUser(userId); 249 endTrace(); 250 } 251 252 /** 253 * Mocks a call to {@link UserManager#isHeadlessSystemUserMode()}. 254 * 255 * @param mode result of such call 256 * 257 * @throws IllegalStateException if class didn't override {@link #newSessionBuilder()} and 258 * called {@code spyStatic(UserManager.class)} on the session passed to it. 259 */ mockIsHeadlessSystemUserMode(boolean mode)260 protected final void mockIsHeadlessSystemUserMode(boolean mode) { 261 if (VERBOSE) Log.v(TAG, getLogPrefix() + "mockIsHeadlessSystemUserMode(" + mode + ")"); 262 assertSpied(UserManager.class); 263 264 beginTrace("mockUmIsHeadlessSystemUserMode"); 265 AndroidMockitoHelper.mockUmIsHeadlessSystemUserMode(mode); 266 endTrace(); 267 } 268 269 /** 270 * Mocks a call to {@link Binder.getCallingUserHandle()}. 271 * 272 * @throws IllegalStateException if class didn't override {@link #newSessionBuilder()} and 273 * called {@code spyStatic(Binder.class)} on the session passed to it. 274 */ mockGetCallingUserHandle(@serIdInt int userId)275 protected final void mockGetCallingUserHandle(@UserIdInt int userId) { 276 if (VERBOSE) Log.v(TAG, getLogPrefix() + "mockBinderCallingUser(" + userId + ")"); 277 assertSpied(Binder.class); 278 279 beginTrace("mockBinderCallingUser"); 280 AndroidMockitoHelper.mockBinderGetCallingUserHandle(userId); 281 endTrace(); 282 } 283 284 /** 285 * Starts a tracing message. 286 * 287 * <p>MUST be followed by a {@link #endTrace()} calls. 288 * 289 * <p>Ignored if {@value #VERBOSE} is {@code false}. 290 */ beginTrace(@onNull String message)291 protected final void beginTrace(@NonNull String message) { 292 if (mTracer == null) return; 293 294 Log.d(TAG, getLogPrefix() + message); 295 mTracer.traceBegin(message); 296 } 297 298 /** 299 * Ends a tracing call. 300 * 301 * <p>MUST be called after {@link #beginTrace(String)}. 302 * 303 * <p>Ignored if {@value #VERBOSE} is {@code false}. 304 */ endTrace()305 protected final void endTrace() { 306 if (mTracer == null) return; 307 308 mTracer.traceEnd(); 309 } 310 interceptWtfCalls()311 private void interceptWtfCalls() { 312 doAnswer((invocation) -> { 313 return addWtf(invocation); 314 }).when(() -> Log.wtf(anyString(), anyString())); 315 doAnswer((invocation) -> { 316 return addWtf(invocation); 317 }).when(() -> Log.wtf(anyString(), anyString(), notNull())); 318 doAnswer((invocation) -> { 319 return addWtf(invocation); 320 }).when(() -> Slog.wtf(anyString(), anyString())); 321 doAnswer((invocation) -> { 322 return addWtf(invocation); 323 }).when(() -> Slog.wtf(anyString(), anyString(), any(Throwable.class))); 324 } 325 addWtf(InvocationOnMock invocation)326 private Object addWtf(InvocationOnMock invocation) { 327 String message = "Called " + invocation; 328 Log.d(TAG, message); // Log always, as some test expect it 329 mWtfs.add(new IllegalStateException(message)); 330 return null; 331 } 332 verifyWtfLogged()333 private void verifyWtfLogged() { 334 Preconditions.checkState(!mWtfs.isEmpty(), "no wtf() called"); 335 } 336 verifyWtfNeverLogged()337 private void verifyWtfNeverLogged() { 338 int size = mWtfs.size(); 339 340 switch (size) { 341 case 0: 342 return; 343 case 1: 344 throw mWtfs.get(0); 345 default: 346 StringBuilder msg = new StringBuilder("wtf called ").append(size).append(" times") 347 .append(": ").append(mWtfs); 348 throw new AssertionError(msg.toString()); 349 } 350 } 351 352 @NonNull newSessionBuilder()353 private MockitoSessionBuilder newSessionBuilder() { 354 // TODO (b/155523104): change from mock to spy 355 StaticMockitoSessionBuilder builder = mockitoSession() 356 .strictness(getSessionStrictness()) 357 .mockStatic(Settings.Global.class) 358 .mockStatic(Settings.System.class) 359 .mockStatic(Settings.Secure.class); 360 361 CustomMockitoSessionBuilder customBuilder = 362 new CustomMockitoSessionBuilder(builder, mStaticSpiedClasses) 363 .spyStatic(Log.class) 364 .spyStatic(Slog.class); 365 366 onSessionBuilder(customBuilder); 367 368 if (VERBOSE) Log.v(TAG, "spied classes" + customBuilder.mStaticSpiedClasses); 369 370 return builder.initMocks(this); 371 } 372 373 /** 374 * Gets a prefix for {@link Log} calls 375 */ getLogPrefix()376 protected String getLogPrefix() { 377 return getClass().getSimpleName() + "."; 378 } 379 380 /** 381 * Asserts the given class is being spied in the Mockito session. 382 */ assertSpied(Class<?> clazz)383 protected void assertSpied(Class<?> clazz) { 384 Preconditions.checkArgument(mStaticSpiedClasses.contains(clazz), 385 "did not call spyStatic() on %s", clazz.getName()); 386 } 387 388 /** 389 * Custom {@code MockitoSessionBuilder} used to make sure some pre-defined mock stations 390 * (like {@link AbstractExtendedMockitoTestCase#mockGetCurrentUser(int)} fail if the test case 391 * didn't explicitly set it to spy / mock the required classes. 392 * 393 * <p><b>NOTE: </b>for now it only provides simple {@link #spyStatic(Class)}, but more methods 394 * (as provided by {@link StaticMockitoSessionBuilder}) could be provided as needed. 395 */ 396 public static final class CustomMockitoSessionBuilder { 397 private final StaticMockitoSessionBuilder mBuilder; 398 private final List<Class<?>> mStaticSpiedClasses; 399 CustomMockitoSessionBuilder(StaticMockitoSessionBuilder builder, List<Class<?>> staticSpiedClasses)400 private CustomMockitoSessionBuilder(StaticMockitoSessionBuilder builder, 401 List<Class<?>> staticSpiedClasses) { 402 mBuilder = builder; 403 mStaticSpiedClasses = staticSpiedClasses; 404 } 405 406 /** 407 * Same as {@link StaticMockitoSessionBuilder#spyStatic(Class)}. 408 */ spyStatic(Class<T> clazz)409 public <T> CustomMockitoSessionBuilder spyStatic(Class<T> clazz) { 410 Preconditions.checkState(!mStaticSpiedClasses.contains(clazz), 411 "already called spyStatic() on " + clazz); 412 mStaticSpiedClasses.add(clazz); 413 mBuilder.spyStatic(clazz); 414 return this; 415 } 416 } 417 418 private final class WtfCheckerRule implements TestRule { 419 420 @Override apply(Statement base, Description description)421 public Statement apply(Statement base, Description description) { 422 return new Statement() { 423 @Override 424 public void evaluate() throws Throwable { 425 String testName = description.getMethodName(); 426 if (VERBOSE) Log.v(TAG, "running " + testName); 427 beginTrace("evaluate-" + testName); 428 base.evaluate(); 429 endTrace(); 430 431 Method testMethod = AbstractExtendedMockitoTestCase.this.getClass() 432 .getMethod(testName); 433 ExpectWtf expectWtfAnnotation = testMethod.getAnnotation(ExpectWtf.class); 434 435 beginTrace("verify-wtfs"); 436 try { 437 if (expectWtfAnnotation != null) { 438 if (VERBOSE) Log.v(TAG, "expecting wtf()"); 439 verifyWtfLogged(); 440 } else { 441 if (VERBOSE) Log.v(TAG, "NOT expecting wtf()"); 442 verifyWtfNeverLogged(); 443 } 444 } finally { 445 endTrace(); 446 } 447 } 448 }; 449 } 450 } 451 452 // TODO (b/155523104): Add log 453 // TODO (b/156033195): Clean settings API 454 private static final class MockSettings { 455 private static final int INVALID_DEFAULT_INDEX = -1; 456 private HashMap<String, Object> mSettingsMapping = new HashMap<>(); 457 458 MockSettings() { 459 460 Answer<Object> insertObjectAnswer = 461 invocation -> insertObjectFromInvocation(invocation, 1, 2); 462 Answer<Integer> getIntAnswer = invocation -> 463 getAnswer(invocation, Integer.class, 1, 2); 464 Answer<String> getStringAnswer = invocation -> 465 getAnswer(invocation, String.class, 1, INVALID_DEFAULT_INDEX); 466 467 when(Settings.Global.putInt(any(), any(), anyInt())).thenAnswer(insertObjectAnswer); 468 469 when(Settings.Global.getInt(any(), any(), anyInt())).thenAnswer(getIntAnswer); 470 471 when(Settings.Secure.putIntForUser(any(), any(), anyInt(), anyInt())) 472 .thenAnswer(insertObjectAnswer); 473 474 when(Settings.Secure.getIntForUser(any(), any(), anyInt(), anyInt())) 475 .thenAnswer(getIntAnswer); 476 477 when(Settings.Secure.putStringForUser(any(), anyString(), anyString(), anyInt())) 478 .thenAnswer(insertObjectAnswer); 479 480 when(Settings.Global.putString(any(), any(), any())) 481 .thenAnswer(insertObjectAnswer); 482 483 when(Settings.Global.getString(any(), any())).thenAnswer(getStringAnswer); 484 485 when(Settings.System.putIntForUser(any(), any(), anyInt(), anyInt())) 486 .thenAnswer(insertObjectAnswer); 487 488 when(Settings.System.getIntForUser(any(), any(), anyInt(), anyInt())) 489 .thenAnswer(getIntAnswer); 490 491 when(Settings.System.putStringForUser(any(), any(), anyString(), anyInt())) 492 .thenAnswer(insertObjectAnswer); 493 } 494 495 private Object insertObjectFromInvocation(InvocationOnMock invocation, 496 int keyIndex, int valueIndex) { 497 String key = (String) invocation.getArguments()[keyIndex]; 498 Object value = invocation.getArguments()[valueIndex]; 499 insertObject(key, value); 500 return null; 501 } 502 503 private void insertObject(String key, Object value) { 504 if (VERBOSE) Log.v(TAG, "Inserting Setting " + key + ": " + value); 505 mSettingsMapping.put(key, value); 506 } 507 508 private <T> T getAnswer(InvocationOnMock invocation, Class<T> clazz, 509 int keyIndex, int defaultValueIndex) { 510 String key = (String) invocation.getArguments()[keyIndex]; 511 T defaultValue = null; 512 if (defaultValueIndex > INVALID_DEFAULT_INDEX) { 513 defaultValue = safeCast(invocation.getArguments()[defaultValueIndex], clazz); 514 } 515 return get(key, defaultValue, clazz); 516 } 517 518 @Nullable 519 private <T> T get(String key, T defaultValue, Class<T> clazz) { 520 if (VERBOSE) { 521 Log.v(TAG, "get(): key=" + key + ", default=" + defaultValue + ", class=" + clazz); 522 } 523 Object value = mSettingsMapping.get(key); 524 if (value == null) { 525 if (VERBOSE) Log.v(TAG, "not found"); 526 return defaultValue; 527 } 528 529 if (VERBOSE) Log.v(TAG, "returning " + value); 530 return safeCast(value, clazz); 531 } 532 533 private static <T> T safeCast(Object value, Class<T> clazz) { 534 if (value == null) { 535 return null; 536 } 537 Preconditions.checkArgument(value.getClass() == clazz, 538 "Setting value has class %s but requires class %s", 539 value.getClass(), clazz); 540 return clazz.cast(value); 541 } 542 543 private String getString(String key) { 544 return get(key, null, String.class); 545 } 546 547 public int getInt(String key) { 548 return get(key, null, Integer.class); 549 } 550 551 public void assertDoesNotContainsKey(String key) { 552 if (mSettingsMapping.containsKey(key)) { 553 throw new AssertionError("Should not have key " + key + ", but has: " 554 + mSettingsMapping.get(key)); 555 } 556 } 557 } 558 559 /** 560 * Annotation used on test methods that are expect to call {@code wtf()} methods on {@link Log} 561 * or {@link Slog} - if such methods are not annotated with this annotation, they will fail. 562 */ 563 @Retention(RUNTIME) 564 @Target({METHOD}) 565 public static @interface ExpectWtf { 566 } 567 568 private static final class SyncRunnable implements Runnable { 569 private final Runnable mTarget; 570 private volatile boolean mComplete = false; 571 572 private SyncRunnable(Runnable target) { 573 mTarget = target; 574 } 575 576 @Override 577 public void run() { 578 mTarget.run(); 579 synchronized (this) { 580 mComplete = true; 581 notifyAll(); 582 } 583 } 584 585 private void waitForComplete() { 586 synchronized (this) { 587 while (!mComplete) { 588 try { 589 wait(); 590 } catch (InterruptedException e) { 591 } 592 } 593 } 594 } 595 } 596 } 597