1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 @file:OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
18
19 package com.android.test.tracing.coroutines
20
21 import android.platform.test.flag.junit.SetFlagsRule
22 import androidx.test.ext.junit.runners.AndroidJUnit4
23 import com.android.app.tracing.coroutines.COROUTINE_EXECUTION
24 import com.android.app.tracing.coroutines.createCoroutineTracingContext
25 import com.android.app.tracing.coroutines.traceThreadLocal
26 import com.android.test.tracing.coroutines.util.FakeTraceState
27 import com.android.test.tracing.coroutines.util.FakeTraceState.getOpenTraceSectionsOnCurrentThread
28 import com.android.test.tracing.coroutines.util.ShadowTrace
29 import java.io.PrintWriter
30 import java.io.StringWriter
31 import java.util.concurrent.atomic.AtomicInteger
32 import kotlin.coroutines.CoroutineContext
33 import kotlinx.coroutines.CancellationException
34 import kotlinx.coroutines.CoroutineDispatcher
35 import kotlinx.coroutines.CoroutineExceptionHandler
36 import kotlinx.coroutines.CoroutineScope
37 import kotlinx.coroutines.DelicateCoroutinesApi
38 import kotlinx.coroutines.ExperimentalCoroutinesApi
39 import kotlinx.coroutines.TimeoutCancellationException
40 import kotlinx.coroutines.cancel
41 import kotlinx.coroutines.delay
42 import kotlinx.coroutines.isActive
43 import kotlinx.coroutines.launch
44 import kotlinx.coroutines.newSingleThreadContext
45 import kotlinx.coroutines.runBlocking
46 import kotlinx.coroutines.withContext
47 import kotlinx.coroutines.withTimeout
48 import org.junit.After
49 import org.junit.Assert.assertTrue
50 import org.junit.Assert.fail
51 import org.junit.Before
52 import org.junit.ClassRule
53 import org.junit.Rule
54 import org.junit.runner.RunWith
55 import org.robolectric.annotation.Config
56
/**
 * [AssertionError] describing a mismatch between the expected and the observed trace state
 * (open trace sections, event ordering, or event counts) during a test.
 *
 * @param message human-readable description of the mismatch.
 * @param cause optional underlying failure; `null` when there is none.
 */
class InvalidTraceStateException : AssertionError {
    constructor(message: String, cause: Throwable? = null) : super(message, cause)
}
59
/** Single-threaded dispatcher used as the "main" thread of [TestBase.scope]. */
internal val mainTestDispatcher = newSingleThreadContext(name = "test-main")

// Background single-threaded dispatchers available to tests. TestBase.setup() clears the trace
// thread-local on each of them (and on mainTestDispatcher) before every test.
internal val bgThread1 = newSingleThreadContext(name = "test-bg-1")
internal val bgThread2 = newSingleThreadContext(name = "test-bg-2")
internal val bgThread3 = newSingleThreadContext(name = "test-bg-3")
internal val bgThread4 = newSingleThreadContext(name = "test-bg-4")
65
/**
 * Base class for coroutine tracing tests.
 *
 * Subclasses run their test body with [runTest]. From inside it they call [expect], [expectD],
 * [expectAny] or [expectEndsWith] to compare the trace sections open on the current thread (as
 * recorded by [FakeTraceState]) with the expected ones. Trace mismatches are recorded instead of
 * thrown. They are reported by [tearDown], together with any unexpected coroutine exceptions.
 */
@RunWith(AndroidJUnit4::class)
@Config(shadows = [ShadowTrace::class])
abstract class TestBase {
    companion object {
        @JvmField
        @ClassRule
        val setFlagsClassRule: SetFlagsRule.ClassRule =
            SetFlagsRule.ClassRule(com.android.systemui.Flags::class.java)

        /**
         * Returns true unless the `java.vm.name` system property is "Dalvik". This presumably
         * means the test is running on a host JVM under Robolectric.
         */
        @JvmStatic
        private fun isRobolectricTest(): Boolean {
            return System.getProperty("java.vm.name") != "Dalvik"
        }
    }

    // TODO(b/339471826): Robolectric does not execute @ClassRule correctly
    @get:Rule
    val setFlagsRule: SetFlagsRule =
        if (isRobolectricTest()) SetFlagsRule() else setFlagsClassRule.createSetFlagsRule()

    /** Number of ordered events seen by [expectEvent]; set to [FINAL_EVENT] by [checkFinalEvent]. */
    private val eventCounter = AtomicInteger(0)

    /** Number of `expect*` checks performed; set to [FINAL_EVENT] by [checkTotalEvents]. */
    private val allEventCounter = AtomicInteger(0)

    /** Event number given to the first [checkFinalEvent] call, or [INVALID_EVENT] if none yet. */
    private val finalEvent = AtomicInteger(INVALID_EVENT)

    // Appended to from the test's dispatcher threads (the coroutine exception handler and the
    // expect* helpers) and read in tearDown(). They must therefore be thread-safe, like the atomic
    // counters above.
    private val allExceptions: MutableList<Throwable> =
        java.util.concurrent.CopyOnWriteArrayList()
    private val assertionErrors: MutableList<AssertionError> =
        java.util.concurrent.CopyOnWriteArrayList()

    /** The scope to be used by the test in [runTest] */
    val scope: CoroutineScope by lazy { CoroutineScope(extraContext + mainTestDispatcher) }

    /**
     * Context passed to the scope used for the test. If the returned [CoroutineContext] contains a
     * [CoroutineDispatcher] it will be overwritten.
     */
    open val extraContext: CoroutineContext by lazy {
        createCoroutineTracingContext("main", testMode = true)
    }

    /**
     * Enables fake tracing, clears recorded trace state, resets the trace thread-local on the
     * calling thread and on every test dispatcher, and eagerly initializes [scope].
     */
    @Before
    fun setup() {
        FakeTraceState.isTracingEnabled = true
        FakeTraceState.clearAll()

        // Reset all thread-local state
        traceThreadLocal.remove()
        val dispatchers = listOf(mainTestDispatcher, bgThread1, bgThread2, bgThread3, bgThread4)
        runBlocking { dispatchers.forEach { withContext(it) { traceThreadLocal.remove() } } }

        // Initialize scope, which is a lazy type:
        assertTrue(scope.isActive)
    }

    /**
     * Fails the test if any unexpected exception or trace-state mismatch was recorded. The
     * failure message includes the stack traces of everything recorded.
     */
    @After
    fun tearDown() {
        val sw = StringWriter()
        val pw = PrintWriter(sw)

        allExceptions.forEach { it.printStackTrace(pw) }
        assertTrue("Test failed due to unexpected exception\n$sw", allExceptions.isEmpty())

        assertionErrors.forEach { it.printStackTrace(pw) }
        assertTrue("Test failed due to incorrect trace sections\n$sw", assertionErrors.isEmpty())
    }

    /**
     * Launches the test on the provided [scope], then uses [runBlocking] to wait for completion.
     * The test will timeout if it takes longer than 200ms.
     *
     * @param isExpectedException if non-null, a non-cancellation exception for which this returns
     *   true is treated as expected, and the test fails if no such exception was thrown. Other
     *   non-cancellation exceptions are recorded and reported by [tearDown].
     * @param finalEvent if non-null, the value the ordered-event counter must have reached, i.e.
     *   the number of [expectEvent] calls (made directly or via [expect] with an event number).
     * @param totalEvents if non-null, the total number of `expect*` checks the test must make.
     * @param block the test body, launched in [scope].
     */
    protected fun runTest(
        isExpectedException: ((Throwable) -> Boolean)? = null,
        finalEvent: Int? = null,
        totalEvents: Int? = null,
        block: suspend CoroutineScope.() -> Unit,
    ) {
        var foundExpectedException = false
        val job =
            scope.launch(
                context =
                    CoroutineExceptionHandler { _, e ->
                        if (e is CancellationException)
                            return@CoroutineExceptionHandler // ignore it
                        if (isExpectedException != null && isExpectedException(e)) {
                            foundExpectedException = true
                        } else {
                            allExceptions.add(e)
                        }
                    },
                block = block,
            )

        runBlocking {
            val timeoutMs = 200L
            try {
                withTimeout(timeoutMs) { job.join() }
            } catch (e: TimeoutCancellationException) {
                // Report the timeout as a test failure, keeping the original exception as cause.
                throw AssertionError(
                    "Timeout running test. Test should complete in less than $timeoutMs ms",
                    e,
                )
            } finally {
                scope.cancel()
            }
        }

        // Checked only after the job finished in time, so that this check cannot hide a timeout
        // failure.
        if (isExpectedException != null && !foundExpectedException) {
            fail("Expected exceptions, but none were thrown")
        }
        if (finalEvent != null) {
            checkFinalEvent(finalEvent)
        }
        if (totalEvents != null) {
            checkTotalEvents(totalEvents)
        }
    }

    /**
     * Records an [InvalidTraceStateException] with [message] for [tearDown] to report. If
     * [throwInsteadOfLog] is true, the exception is thrown instead.
     */
    private fun logInvalidTraceState(message: String, throwInsteadOfLog: Boolean = false) {
        val e = InvalidTraceStateException(message)
        if (throwInsteadOfLog) {
            throw e
        } else {
            assertionErrors.add(e)
        }
    }

    /**
     * Same as [expect], but also call [delay] for 1ms, calling [expect] before and after the
     * suspension point.
     */
    protected suspend fun expectD(vararg expectedOpenTraceSections: String) {
        expect(*expectedOpenTraceSections)
        delay(1)
        expect(*expectedOpenTraceSections)
    }

    /**
     * Checks that the innermost (last) trace sections open on the current thread match
     * [expectedOpenTraceSections]. Any sections opened before them are ignored.
     */
    protected fun expectEndsWith(vararg expectedOpenTraceSections: String) {
        allEventCounter.getAndAdd(1)
        // Inspect trace output to the fake used for recording android.os.Trace API calls:
        val actualSections = getOpenTraceSectionsOnCurrentThread()
        if (expectedOpenTraceSections.size <= actualSections.size) {
            val lastSections =
                actualSections.takeLast(expectedOpenTraceSections.size).toTypedArray()
            assertTraceSectionsEquals(expectedOpenTraceSections, null, lastSections, null)
        } else {
            logInvalidTraceState(
                "Invalid length: expected size (${expectedOpenTraceSections.size}) <= actual size (${actualSections.size})"
            )
        }
    }

    /**
     * Advances the ordered-event counter and records a failure unless the new event number is in
     * [expectedEvent].
     *
     * @return the new event number.
     */
    protected fun expectEvent(expectedEvent: Collection<Int>): Int {
        val previousEvent = eventCounter.getAndAdd(1)
        val currentEvent = previousEvent + 1
        if (!expectedEvent.contains(currentEvent)) {
            logInvalidTraceState(
                if (previousEvent == FINAL_EVENT) {
                    "Expected event ${expectedEvent.prettyPrintList()}, but finish() was already called"
                } else {
                    "Expected event ${expectedEvent.prettyPrintList()}," +
                        " but the event counter is currently at #$currentEvent"
                }
            )
        }
        return currentEvent
    }

    /**
     * Checks that the trace sections currently open on this thread exactly match at least one of
     * [possibleOpenSections]. If none match, or no alternatives were passed, one
     * [InvalidTraceStateException] is recorded. Event order is not checked.
     */
    internal fun expectAny(vararg possibleOpenSections: Array<out String>) {
        allEventCounter.getAndAdd(1)
        if (possibleOpenSections.isEmpty()) {
            // Without this guard the "nothing matched" branch below would index an empty list.
            logInvalidTraceState("expectAny() requires at least one list of expected sections")
            return
        }
        val actualOpenSections = getOpenTraceSectionsOnCurrentThread()
        val caughtExceptions = mutableListOf<AssertionError>()
        possibleOpenSections.forEach { expectedSections ->
            try {
                assertTraceSectionsEquals(
                    expectedSections,
                    expectedEvent = null,
                    actualOpenSections,
                    actualEvent = null,
                    throwInsteadOfLog = true,
                )
            } catch (e: AssertionError) {
                caughtExceptions.add(e)
            }
        }
        if (caughtExceptions.size == possibleOpenSections.size) {
            val e = caughtExceptions.first()
            val allLists =
                possibleOpenSections.joinToString(separator = ", OR ") { it.prettyPrintList() }
            // Keep the first mismatch itself as the cause so that its stack trace is reported.
            assertionErrors.add(
                InvalidTraceStateException("Expected $allLists. For example, ${e.message}", e)
            )
        }
    }

    /** Checks the trace sections currently open on this thread, without checking event order. */
    internal fun expect(vararg expectedOpenTraceSections: String) {
        expect(null, *expectedOpenTraceSections)
    }

    /**
     * Checks the trace sections currently open on this thread, and that this is ordered event
     * number [expectedEvent].
     */
    internal fun expect(expectedEvent: Int, vararg expectedOpenTraceSections: String) {
        expect(listOf(expectedEvent), *expectedOpenTraceSections)
    }

    /**
     * Checks the currently active trace sections on the current thread, and optionally checks the
     * order of operations if [possibleEventPos] is not null.
     */
    internal fun expect(possibleEventPos: List<Int>?, vararg expectedOpenTraceSections: String) {
        var currentEvent: Int? = null
        allEventCounter.getAndAdd(1)
        if (possibleEventPos != null) {
            currentEvent = expectEvent(possibleEventPos)
        }
        val actualOpenSections = getOpenTraceSectionsOnCurrentThread()
        assertTraceSectionsEquals(
            expectedOpenTraceSections,
            possibleEventPos,
            actualOpenSections,
            currentEvent,
        )
    }

    /**
     * Compares [expectedOpenTraceSections] with [actualOpenSections]. Actual sections are compared
     * by their traced name. On the first difference, a failure is recorded, or thrown if
     * [throwInsteadOfLog] is true.
     */
    private fun assertTraceSectionsEquals(
        expectedOpenTraceSections: Array<out String>,
        expectedEvent: List<Int>?,
        actualOpenSections: Array<String>,
        actualEvent: Int?,
        throwInsteadOfLog: Boolean = false,
    ) {
        val expectedSize = expectedOpenTraceSections.size
        val actualSize = actualOpenSections.size
        if (expectedSize != actualSize) {
            logInvalidTraceState(
                createFailureMessage(
                    expectedOpenTraceSections,
                    expectedEvent,
                    actualOpenSections,
                    actualEvent,
                    "Size mismatch, expected size $expectedSize but was size $actualSize",
                ),
                throwInsteadOfLog,
            )
        } else {
            expectedOpenTraceSections.forEachIndexed { n, expected ->
                val actualTrace = actualOpenSections[n]
                val actual = actualTrace.getTracedName()
                if (expected != actual) {
                    logInvalidTraceState(
                        createFailureMessage(
                            expectedOpenTraceSections,
                            expectedEvent,
                            actualOpenSections,
                            actualEvent,
                            "Differed at index #$n, expected \"$expected\" but was \"$actual\"",
                        ),
                        throwInsteadOfLog,
                    )
                    return
                }
            }
        }
    }

    /**
     * Builds a multi-line failure message listing the expected and actual sections. When both
     * event numbers are known, the message also includes the event location.
     */
    private fun createFailureMessage(
        expectedOpenTraceSections: Array<out String>,
        expectedEventNumber: List<Int>?,
        actualOpenSections: Array<String>,
        actualEventNumber: Int?,
        extraMessage: String,
    ): String {
        val locationMarker =
            when {
                expectedEventNumber == null || actualEventNumber == null -> ""
                expectedEventNumber.contains(actualEventNumber) -> " at event #$actualEventNumber"
                else ->
                    ", expected event ${expectedEventNumber.prettyPrintList()}, actual event #$actualEventNumber"
            }
        return """
            Incorrect trace$locationMarker. $extraMessage
            Expected : {${expectedOpenTraceSections.prettyPrintList()}}
            Actual : {${actualOpenSections.prettyPrintList()}}
            """
            .trimIndent()
    }

    /**
     * Records a failure unless exactly [expectedEvent] ordered events occurred, then marks the
     * counter as finished.
     *
     * @return the counter value before it was marked finished.
     */
    private fun checkFinalEvent(expectedEvent: Int): Int {
        finalEvent.compareAndSet(INVALID_EVENT, expectedEvent)
        val previousEvent = eventCounter.getAndSet(FINAL_EVENT)
        if (expectedEvent != previousEvent) {
            logInvalidTraceState(
                "Expected to finish with event #$expectedEvent, but " +
                    if (previousEvent == FINAL_EVENT)
                        "finish() was already called with event #${finalEvent.get()}"
                    else "the event counter is currently at #$previousEvent"
            )
        }
        return previousEvent
    }

    /**
     * Records a failure unless exactly [totalEvents] `expect*` checks were made, then marks the
     * counter as finished.
     *
     * @return the counter value before it was marked finished.
     */
    private fun checkTotalEvents(totalEvents: Int): Int {
        val previousEvent = allEventCounter.getAndSet(FINAL_EVENT)
        if (totalEvents != previousEvent) {
            logInvalidTraceState(
                "Expected test to end with a total of $totalEvents events, but " +
                    if (previousEvent == FINAL_EVENT) "the total event count was already checked"
                    else "instead there were $previousEvent events"
            )
        }
        return previousEvent
    }
}
378
/**
 * Returns the name portion of a recorded trace section.
 *
 * For sections that start with [COROUTINE_EXECUTION], such as
 * "coroutine execution;scope-name;c=1234;p=5678", this returns the second `;`-separated field
 * ("scope-name"). For any other section it returns everything before the first `;`.
 */
private fun String.getTracedName(): String {
    val remainder = if (startsWith(COROUTINE_EXECUTION)) substringAfter(";") else this
    return remainder.substringBefore(";")
}
385
/** Placeholder held by `TestBase.finalEvent` until a final event number has been recorded. */
private const val INVALID_EVENT: Int = -1

/** Value written into the event counters once their final/total check has run. */
private const val FINAL_EVENT: Int = Int.MIN_VALUE
389
/**
 * Formats event numbers for failure messages: "" when empty, "#n" for a single number, and
 * "{#a, #b, ...}" otherwise.
 */
private fun Collection<Int>.prettyPrintList(): String =
    when (size) {
        0 -> ""
        1 -> "#${first()}"
        else -> {
            val joined =
                joinToString(separator = ", #", prefix = "#", postfix = "") { it.toString() }
            "{$joined}"
        }
    }
403
/**
 * Formats trace sections for failure messages. Each section is reduced to its traced name and
 * quoted, and the results are comma-separated, e.g. `"a", "b"`. An empty array yields "".
 */
private fun Array<out String>.prettyPrintList(): String {
    if (isEmpty()) return ""
    val tracedNames = map { section -> section.getTracedName() }
    return tracedNames.joinToString(separator = "\", \"", prefix = "\"", postfix = "\"")
}
411