1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.arch.core.executor.testing;
18 
19 import android.os.SystemClock;
20 
21 import androidx.arch.core.executor.ArchTaskExecutor;
22 import androidx.arch.core.executor.DefaultTaskExecutor;
23 
24 import org.jspecify.annotations.NonNull;
25 import org.junit.rules.TestWatcher;
26 import org.junit.runner.Description;
27 
28 import java.util.concurrent.TimeUnit;
29 import java.util.concurrent.TimeoutException;
30 
31 /**
 * A JUnit Test Rule that swaps the background executor used by the Architecture Components with a
 * different one which counts tasks as they start and finish.
34  * <p>
35  * You can use this rule for your host side tests that use Architecture Components.
36  */
37 public class CountingTaskExecutorRule extends TestWatcher {
38     private final Object mCountLock = new Object();
39     private int mTaskCount = 0;
40 
41     @Override
starting(Description description)42     protected void starting(Description description) {
43         super.starting(description);
44         ArchTaskExecutor.getInstance().setDelegate(new DefaultTaskExecutor() {
45             @Override
46             public void executeOnDiskIO(@NonNull Runnable runnable) {
47                 super.executeOnDiskIO(new CountingRunnable(runnable));
48             }
49 
50             @Override
51             public void postToMainThread(@NonNull Runnable runnable) {
52                 super.postToMainThread(new CountingRunnable(runnable));
53             }
54         });
55     }
56 
57     @Override
finished(Description description)58     protected void finished(Description description) {
59         super.finished(description);
60         ArchTaskExecutor.getInstance().setDelegate(null);
61     }
62 
63     @SuppressWarnings("WeakerAccess") /* synthetic access */
increment()64     void increment() {
65         synchronized (mCountLock) {
66             mTaskCount++;
67         }
68     }
69 
70     @SuppressWarnings("WeakerAccess") /* synthetic access */
decrement()71     void decrement() {
72         synchronized (mCountLock) {
73             mTaskCount--;
74             if (mTaskCount == 0) {
75                 onIdle();
76                 mCountLock.notifyAll();
77             }
78         }
79     }
80 
81     /**
82      * Called when the number of awaiting tasks reaches to 0.
83      *
84      * @see #isIdle()
85      */
onIdle()86     protected void onIdle() {
87 
88     }
89 
90     /**
91      * Returns false if there are tasks waiting to be executed, true otherwise.
92      *
93      * @return False if there are tasks waiting to be executed, true otherwise.
94      *
95      * @see #onIdle()
96      */
isIdle()97     public boolean isIdle() {
98         synchronized (mCountLock) {
99             return mTaskCount == 0;
100         }
101     }
102 
103     /**
104      * Waits until all active tasks are finished.
105      *
106      * @param time The duration to wait
107      * @param timeUnit The time unit for the {@code time} parameter
108      *
109      * @throws InterruptedException If thread is interrupted while waiting
110      * @throws TimeoutException If tasks cannot be drained at the given time
111      */
drainTasks(int time, @NonNull TimeUnit timeUnit)112     public void drainTasks(int time, @NonNull TimeUnit timeUnit)
113             throws InterruptedException, TimeoutException {
114         long end = SystemClock.uptimeMillis() + timeUnit.toMillis(time);
115         synchronized (mCountLock) {
116             while (mTaskCount != 0) {
117                 long now = SystemClock.uptimeMillis();
118                 long remaining = end - now;
119                 if (remaining > 0) {
120                     mCountLock.wait(remaining);
121                 } else {
122                     throw new TimeoutException("could not drain tasks");
123                 }
124             }
125         }
126     }
127 
128     class CountingRunnable implements Runnable {
129         final Runnable mWrapped;
130 
CountingRunnable(Runnable wrapped)131         CountingRunnable(Runnable wrapped) {
132             mWrapped = wrapped;
133             increment();
134         }
135 
136         @Override
run()137         public void run() {
138             try {
139                 mWrapped.run();
140             } finally {
141                 decrement();
142             }
143         }
144     }
145 }
146