1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package androidx.arch.core.executor.testing; 18 19 import android.os.SystemClock; 20 21 import androidx.arch.core.executor.ArchTaskExecutor; 22 import androidx.arch.core.executor.DefaultTaskExecutor; 23 24 import org.jspecify.annotations.NonNull; 25 import org.junit.rules.TestWatcher; 26 import org.junit.runner.Description; 27 28 import java.util.concurrent.TimeUnit; 29 import java.util.concurrent.TimeoutException; 30 31 /** 32 * A JUnit Test Rule that swaps the background executor used by the Architecture Components with a 33 * different one which counts the tasks as they are start and finish. 34 * <p> 35 * You can use this rule for your host side tests that use Architecture Components. 36 */ 37 public class CountingTaskExecutorRule extends TestWatcher { 38 private final Object mCountLock = new Object(); 39 private int mTaskCount = 0; 40 41 @Override starting(Description description)42 protected void starting(Description description) { 43 super.starting(description); 44 ArchTaskExecutor.getInstance().setDelegate(new DefaultTaskExecutor() { 45 @Override 46 public void executeOnDiskIO(@NonNull Runnable runnable) { 47 super.executeOnDiskIO(new CountingRunnable(runnable)); 48 } 49 50 @Override 51 public void postToMainThread(@NonNull Runnable runnable) { 52 super.postToMainThread(new CountingRunnable(runnable)); 53 } 54 }); 55 } 56 57 @Override finished(Description description)58 protected void finished(Description description) { 59 super.finished(description); 60 ArchTaskExecutor.getInstance().setDelegate(null); 61 } 62 63 @SuppressWarnings("WeakerAccess") /* synthetic access */ increment()64 void increment() { 65 synchronized (mCountLock) { 66 mTaskCount++; 67 } 68 } 69 70 @SuppressWarnings("WeakerAccess") /* synthetic access */ decrement()71 void decrement() { 72 synchronized (mCountLock) { 73 mTaskCount--; 74 if (mTaskCount == 0) { 75 onIdle(); 76 mCountLock.notifyAll(); 77 } 78 } 79 } 80 81 /** 82 * Called when the number of awaiting tasks reaches to 0. 83 * 84 * @see #isIdle() 85 */ onIdle()86 protected void onIdle() { 87 88 } 89 90 /** 91 * Returns false if there are tasks waiting to be executed, true otherwise. 92 * 93 * @return False if there are tasks waiting to be executed, true otherwise. 94 * 95 * @see #onIdle() 96 */ isIdle()97 public boolean isIdle() { 98 synchronized (mCountLock) { 99 return mTaskCount == 0; 100 } 101 } 102 103 /** 104 * Waits until all active tasks are finished. 105 * 106 * @param time The duration to wait 107 * @param timeUnit The time unit for the {@code time} parameter 108 * 109 * @throws InterruptedException If thread is interrupted while waiting 110 * @throws TimeoutException If tasks cannot be drained at the given time 111 */ drainTasks(int time, @NonNull TimeUnit timeUnit)112 public void drainTasks(int time, @NonNull TimeUnit timeUnit) 113 throws InterruptedException, TimeoutException { 114 long end = SystemClock.uptimeMillis() + timeUnit.toMillis(time); 115 synchronized (mCountLock) { 116 while (mTaskCount != 0) { 117 long now = SystemClock.uptimeMillis(); 118 long remaining = end - now; 119 if (remaining > 0) { 120 mCountLock.wait(remaining); 121 } else { 122 throw new TimeoutException("could not drain tasks"); 123 } 124 } 125 } 126 } 127 128 class CountingRunnable implements Runnable { 129 final Runnable mWrapped; 130 CountingRunnable(Runnable wrapped)131 CountingRunnable(Runnable wrapped) { 132 mWrapped = wrapped; 133 increment(); 134 } 135 136 @Override run()137 public void run() { 138 try { 139 mWrapped.run(); 140 } finally { 141 decrement(); 142 } 143 } 144 } 145 } 146