1 /*
2  * Copyright 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.work.impl
18 
19 import android.content.Context
20 import android.os.Build
21 import androidx.concurrent.futures.CallbackToFutureAdapter.Completer
22 import androidx.concurrent.futures.CallbackToFutureAdapter.getFuture
23 import androidx.core.app.NotificationCompat
24 import androidx.test.core.app.ApplicationProvider
25 import androidx.test.ext.junit.runners.AndroidJUnit4
26 import androidx.test.filters.SdkSuppress
27 import androidx.test.filters.SmallTest
28 import androidx.work.Configuration
29 import androidx.work.ForegroundInfo
30 import androidx.work.ListenableWorker
31 import androidx.work.OneTimeWorkRequest
32 import androidx.work.OutOfQuotaPolicy
33 import androidx.work.SystemClock
34 import androidx.work.WorkerFactory
35 import androidx.work.WorkerParameters
36 import androidx.work.impl.foreground.ForegroundProcessor
37 import androidx.work.impl.utils.SerialExecutorImpl
38 import androidx.work.impl.utils.taskexecutor.TaskExecutor
39 import com.google.common.truth.Truth.assertThat
40 import com.google.common.util.concurrent.ListenableFuture
41 import java.util.concurrent.Executor
42 import org.junit.Test
43 import org.junit.runner.RunWith
44 
/**
 * Tests that step a [WorkerWrapper] through its start-up using manually drained executors, so an
 * interruption can be injected at a chosen point and its effect on the worker observed.
 */
@RunWith(AndroidJUnit4::class)
@SmallTest
class ControlledWorkerWrapperTest {
    private val context: Context = ApplicationProvider.getApplicationContext()
    private val taskExecutor = ManualTaskExecutor()
    private val backgroundExecutor = ManualExecutor()
    private val workDatabase =
        WorkDatabase.create(context, taskExecutor.serialTaskExecutor, SystemClock(), true)

    /** Interrupting before the main executor has run anything must prevent `startWork`. */
    @Test
    fun testInterruptionsBefore() {
        val work = OneTimeWorkRequest.Builder(TestWrapperWorker::class.java).build()
        workDatabase.workSpecDao().insertWorkSpec(work.workSpec)
        lateinit var worker: TestWrapperWorker
        val workerWrapper = workerWrapper(work.stringId) { worker = it }
        val future = workerWrapper.launch()

        // The main executor is deliberately NOT drained here: only serial/background tasks run
        // before the interruption. NOTE(review): presumably startWork is dispatched to the main
        // executor by WorkerWrapper - confirm against its implementation.
        drainUntilIdle(taskExecutor.serialTaskExecutor, backgroundExecutor)
        workerWrapper.interrupt(0)
        drainAll()
        assertThat(future.isDone).isTrue()
        assertThat(worker.startWorkWasCalled).isFalse()
    }

    /**
     * Interrupting after `getForegroundInfoAsync` has been called and completed, but before
     * `startWork`, must still prevent `startWork` and finish the wrapper's future.
     */
    @Test
    @SdkSuppress(maxSdkVersion = Build.VERSION_CODES.R) // getForegroundInfoAsync isn't called on S
    fun testInterruptionsBetweenGetForegroundInfoAsyncAndStartWork() {
        val work =
            OneTimeWorkRequest.Builder(TestWrapperWorker::class.java)
                .setExpedited(OutOfQuotaPolicy.DROP_WORK_REQUEST)
                .build()
        workDatabase.workSpecDao().insertWorkSpec(work.workSpec)
        lateinit var worker: TestWrapperWorker
        val workerWrapper = workerWrapper(work.stringId) { worker = it }
        val future = workerWrapper.launch()
        drainAll()
        assertThat(worker.getForegroundInfoAsyncWasCalled).isTrue()
        assertThat(worker.startWorkWasCalled).isFalse()
        // Complete the pending foreground info, then interrupt before any executor runs again.
        worker.foregroundInfoCompleter.set(
            ForegroundInfo(0, NotificationCompat.Builder(context, "test").build())
        )
        workerWrapper.interrupt(0)
        drainAll()
        assertThat(worker.startWorkWasCalled).isFalse()
        assertThat(future.isDone).isTrue()
    }

    /** Runs queued tasks on the serial, background and main executors until all are idle. */
    private fun drainAll() {
        drainUntilIdle(
            taskExecutor.serialTaskExecutor,
            backgroundExecutor,
            taskExecutor.mainExecutor
        )
    }

    /**
     * Drains [executors] in the given order, repeating while any of them still has a pending
     * task (a task on one executor may enqueue work on another).
     */
    private fun drainUntilIdle(vararg executors: ManualExecutor) {
        while (executors.any { it.hasPendingTask() }) {
            executors.forEach { it.drain() }
        }
    }

    /**
     * Builds a [WorkerWrapper] for the work spec stored under [id].
     *
     * @param id id of a work spec that has already been inserted into [workDatabase].
     * @param workerInterceptor invoked with each [TestWrapperWorker] the factory creates, before
     *   the worker is returned from the factory.
     * @throws IllegalStateException if no work spec with [id] is present in the database.
     */
    private fun workerWrapper(
        id: String,
        workerInterceptor: (TestWrapperWorker) -> Unit
    ): WorkerWrapper {
        val interceptingFactory =
            object : WorkerFactory() {
                override fun createWorker(
                    appContext: Context,
                    workerClassName: String,
                    workerParameters: WorkerParameters
                ): ListenableWorker =
                    TestWrapperWorker(appContext, workerParameters).also(workerInterceptor)
            }
        val config =
            Configuration.Builder()
                .setExecutor(backgroundExecutor)
                .setWorkerFactory(interceptingFactory)
                .build()
        // Fail with a descriptive message instead of a bare NullPointerException.
        val workSpec =
            checkNotNull(workDatabase.workSpecDao().getWorkSpec(id)) {
                "No WorkSpec found for id $id; insert it before creating a WorkerWrapper"
            }
        return WorkerWrapper.Builder(
                context,
                config,
                taskExecutor,
                NoOpForegroundProcessor,
                workDatabase,
                workSpec,
                emptyList()
            )
            .build()
    }
}
146 
/**
 * Worker used by the tests in this file: it records which callbacks were invoked and exposes the
 * completer of the foreground-info future so the test decides when that future completes.
 */
internal class TestWrapperWorker(
    appContext: Context,
    workerParams: WorkerParameters,
) : ListenableWorker(appContext, workerParams) {
    /** Set to `true` once [getForegroundInfoAsync] has been invoked. */
    var getForegroundInfoAsyncWasCalled: Boolean = false

    /** Set to `true` once [startWork] has been invoked. */
    var startWorkWasCalled: Boolean = false

    /** Completer of the future handed out by [getForegroundInfoAsync]; assigned by that call. */
    lateinit var foregroundInfoCompleter: Completer<ForegroundInfo>

    /** Returns a future that stays pending until [foregroundInfoCompleter] is completed. */
    override fun getForegroundInfoAsync(): ListenableFuture<ForegroundInfo> {
        getForegroundInfoAsyncWasCalled = true
        val pendingInfo: ListenableFuture<ForegroundInfo> = getFuture { completer ->
            foregroundInfoCompleter = completer
            "getForegroundInfoAsync completer"
        }
        return pendingInfo
    }

    /** Marks the call and returns a future whose resolver sets `Result.success()` right away. */
    override fun startWork(): ListenableFuture<Result> {
        startWorkWasCalled = true
        return getFuture { completer -> completer.set(Result.success()) }
    }
}
168 
/** [ForegroundProcessor] whose [startForeground] does nothing. */
object NoOpForegroundProcessor : ForegroundProcessor {
    override fun startForeground(workSpecId: String, foregroundInfo: ForegroundInfo) = Unit
}
172 
/**
 * [Executor] that never runs anything on its own: submitted runnables are queued and only
 * executed, in submission order, when [drain] is called.
 */
class ManualExecutor : Executor {
    private val pending = ArrayDeque<Runnable>(10)

    /** Queues [runnable] behind everything already submitted; does not run it. */
    override fun execute(runnable: Runnable) {
        pending.addLast(runnable)
    }

    /**
     * Runs queued tasks on the calling thread until the queue is empty, including tasks that
     * are submitted to this executor while draining.
     */
    fun drain() {
        while (true) {
            val next = pending.removeFirstOrNull() ?: break
            next.run()
        }
    }

    /** Returns `true` when at least one task is waiting to be run. */
    fun hasPendingTask(): Boolean = !pending.isEmpty()
}
189 
/**
 * [TaskExecutor] backed by [ManualExecutor]s, so a test controls when main-thread and serial
 * tasks actually run.
 */
class ManualTaskExecutor : TaskExecutor {
    /** Executor returned from [getMainThreadExecutor]. */
    val mainExecutor: ManualExecutor = ManualExecutor()

    /** Underlying executor that the serial executor delegates to. */
    val serialTaskExecutor: ManualExecutor = ManualExecutor()

    // Must be created after serialTaskExecutor, which it wraps.
    private val serialWrapper: SerialExecutorImpl = SerialExecutorImpl(serialTaskExecutor)

    override fun getMainThreadExecutor(): ManualExecutor {
        return mainExecutor
    }

    override fun getSerialTaskExecutor(): SerialExecutorImpl {
        return serialWrapper
    }
}
199