• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
<lambda>null2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 @file:OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
18 
19 package com.android.test.tracing.coroutines
20 
21 import android.platform.test.flag.junit.SetFlagsRule
22 import androidx.test.ext.junit.runners.AndroidJUnit4
23 import com.android.app.tracing.coroutines.COROUTINE_EXECUTION
24 import com.android.app.tracing.coroutines.createCoroutineTracingContext
25 import com.android.app.tracing.coroutines.traceThreadLocal
26 import com.android.test.tracing.coroutines.util.FakeTraceState
27 import com.android.test.tracing.coroutines.util.FakeTraceState.getOpenTraceSectionsOnCurrentThread
28 import com.android.test.tracing.coroutines.util.ShadowTrace
29 import java.io.PrintWriter
30 import java.io.StringWriter
31 import java.util.concurrent.atomic.AtomicInteger
32 import kotlin.coroutines.CoroutineContext
33 import kotlinx.coroutines.CancellationException
34 import kotlinx.coroutines.CoroutineDispatcher
35 import kotlinx.coroutines.CoroutineExceptionHandler
36 import kotlinx.coroutines.CoroutineScope
37 import kotlinx.coroutines.DelicateCoroutinesApi
38 import kotlinx.coroutines.ExperimentalCoroutinesApi
39 import kotlinx.coroutines.TimeoutCancellationException
40 import kotlinx.coroutines.cancel
41 import kotlinx.coroutines.delay
42 import kotlinx.coroutines.isActive
43 import kotlinx.coroutines.launch
44 import kotlinx.coroutines.newSingleThreadContext
45 import kotlinx.coroutines.runBlocking
46 import kotlinx.coroutines.withContext
47 import kotlinx.coroutines.withTimeout
48 import org.junit.After
49 import org.junit.Assert.assertTrue
50 import org.junit.Assert.fail
51 import org.junit.Before
52 import org.junit.ClassRule
53 import org.junit.Rule
54 import org.junit.runner.RunWith
55 import org.robolectric.annotation.Config
56 
/**
 * [AssertionError] describing a mismatch between the expected and the actual trace state (open
 * trace sections or event ordering) observed by a test.
 *
 * @param message description of the mismatch.
 * @param cause optional underlying error, e.g. one of several failed alternatives.
 */
class InvalidTraceStateException : AssertionError {
    constructor(message: String, cause: Throwable? = null) : super(message, cause)
}
59 
/** Single-threaded dispatcher that [TestBase.scope] runs test bodies on. */
internal val mainTestDispatcher = newSingleThreadContext(name = "test-main")

// Additional single-threaded dispatchers that tests can hop between. TestBase.setup() clears the
// tracing thread-local on each of them before every test.
internal val bgThread1 = newSingleThreadContext(name = "test-bg-1")
internal val bgThread2 = newSingleThreadContext(name = "test-bg-2")
internal val bgThread3 = newSingleThreadContext(name = "test-bg-3")
internal val bgThread4 = newSingleThreadContext(name = "test-bg-4")
65 
/**
 * Base class for coroutine tracing tests.
 *
 * Tests call [runTest] to run a coroutine on [scope], and use the `expect*` helpers to compare the
 * trace sections currently open on the calling thread (as recorded by [FakeTraceState]) against
 * the expected ones. Mismatches are collected (unless a helper explicitly asks for them to be
 * thrown) and reported together in [tearDown].
 */
@RunWith(AndroidJUnit4::class)
@Config(shadows = [ShadowTrace::class])
abstract class TestBase {

    // TODO(b/339471826): Robolectric does not execute @ClassRule correctly
    @get:Rule
    val setFlagsRule: SetFlagsRule =
        if (isRobolectricTest()) SetFlagsRule() else setFlagsClassRule.createSetFlagsRule()

    /** Number of ordered events so far, i.e. [expect] calls that were given an event number. */
    private val eventCounter = AtomicInteger(0)

    /** Number of `expect*` checks of any kind performed so far. */
    private val allEventCounter = AtomicInteger(0)

    /** Event number first passed to [checkFinalEvent], or [INVALID_EVENT] if not yet called. */
    private val finalEvent = AtomicInteger(INVALID_EVENT)

    // The exception handler and the expect* helpers may run on any of the test dispatcher threads,
    // so these lists can be appended to concurrently and must be thread-safe.
    private val allExceptions: MutableList<Throwable> = java.util.concurrent.CopyOnWriteArrayList()
    private val assertionErrors: MutableList<AssertionError> =
        java.util.concurrent.CopyOnWriteArrayList()

    /** The scope to be used by the test in [runTest] */
    val scope: CoroutineScope by lazy { CoroutineScope(extraContext + mainTestDispatcher) }

    /**
     * Context passed to the scope used for the test. If the returned [CoroutineContext] contains a
     * [CoroutineDispatcher] it will be overwritten.
     */
    open val extraContext: CoroutineContext by lazy {
        createCoroutineTracingContext("main", testMode = true)
    }

    @Before
    fun setup() {
        FakeTraceState.isTracingEnabled = true
        FakeTraceState.clearAll()

        // Reset all thread-local state, both on the test thread and on every test dispatcher.
        traceThreadLocal.remove()
        val dispatchers = listOf(mainTestDispatcher, bgThread1, bgThread2, bgThread3, bgThread4)
        runBlocking { dispatchers.forEach { withContext(it) { traceThreadLocal.remove() } } }

        // Initialize scope, which is a lazy type:
        assertTrue(scope.isActive)
    }

    /** Fails the test if any unexpected exception or trace-state mismatch was recorded. */
    @After
    fun tearDown() {
        assertTrue(
            "Test failed due to unexpected exception\n${stackTracesOf(allExceptions)}",
            allExceptions.isEmpty(),
        )
        assertTrue(
            "Test failed due to incorrect trace sections\n${stackTracesOf(assertionErrors)}",
            assertionErrors.isEmpty(),
        )
    }

    /** Renders the stack traces of [errors], one after another, into a single string. */
    private fun stackTracesOf(errors: List<Throwable>): String {
        val writer = StringWriter()
        PrintWriter(writer).use { pw -> errors.forEach { it.printStackTrace(pw) } }
        return writer.toString()
    }

    /**
     * Launches the test on the provided [scope], then uses [runBlocking] to wait for completion.
     * The test will timeout if it takes longer than 200ms.
     *
     * @param isExpectedException if non-null, uncaught exceptions matching this predicate are
     *   treated as expected, and the test fails if none was seen. Non-matching exceptions are
     *   recorded and reported by [tearDown]. [CancellationException]s are always ignored.
     * @param finalEvent if non-null, the number of ordered events (see [expectEvent]) the test
     *   must end on.
     * @param totalEvents if non-null, the total number of `expect*` checks the test must perform.
     * @param block the test body, launched as a coroutine on [scope].
     */
    protected fun runTest(
        isExpectedException: ((Throwable) -> Boolean)? = null,
        finalEvent: Int? = null,
        totalEvents: Int? = null,
        block: suspend CoroutineScope.() -> Unit,
    ) {
        var foundExpectedException = false
        val job =
            scope.launch(
                context =
                    CoroutineExceptionHandler { _, e ->
                        if (e is CancellationException) {
                            // Cancellation is not a test failure; ignore it.
                        } else if (isExpectedException != null && isExpectedException(e)) {
                            foundExpectedException = true
                        } else {
                            allExceptions.add(e)
                        }
                    },
                block = block,
            )

        runBlocking {
            val timeoutMs = 200L
            try {
                withTimeout(timeoutMs) { job.join() }
            } catch (e: TimeoutCancellationException) {
                fail("Timeout running test. Test should complete in less than $timeoutMs ms")
            } finally {
                scope.cancel()
            }
        }

        // Only reached when the job completed in time, so a timeout failure is never masked by
        // the check below.
        if (isExpectedException != null && !foundExpectedException) {
            fail("Expected exceptions, but none were thrown")
        }
        finalEvent?.let { checkFinalEvent(it) }
        totalEvents?.let { checkTotalEvents(it) }
    }

    /** Records (or, if [throwInsteadOfLog], throws) an [InvalidTraceStateException]. */
    private fun logInvalidTraceState(message: String, throwInsteadOfLog: Boolean = false) {
        val e = InvalidTraceStateException(message)
        if (throwInsteadOfLog) {
            throw e
        } else {
            assertionErrors.add(e)
        }
    }

    /**
     * Same as [expect], but also call [delay] for 1ms, calling [expect] before and after the
     * suspension point.
     */
    protected suspend fun expectD(vararg expectedOpenTraceSections: String) {
        expect(*expectedOpenTraceSections)
        delay(1)
        expect(*expectedOpenTraceSections)
    }

    /**
     * Checks that the innermost trace sections open on the current thread match
     * [expectedOpenTraceSections]; any additional, outer sections are ignored.
     */
    protected fun expectEndsWith(vararg expectedOpenTraceSections: String) {
        allEventCounter.incrementAndGet()
        // Inspect trace output to the fake used for recording android.os.Trace API calls:
        val actualSections = getOpenTraceSectionsOnCurrentThread()
        if (expectedOpenTraceSections.size <= actualSections.size) {
            val lastSections =
                actualSections.takeLast(expectedOpenTraceSections.size).toTypedArray()
            assertTraceSectionsEquals(expectedOpenTraceSections, null, lastSections, null)
        } else {
            logInvalidTraceState(
                "Invalid length: expected size (${expectedOpenTraceSections.size}) <= actual size (${actualSections.size})"
            )
        }
    }

    /**
     * Advances the ordered-event counter and records an error if the new event number is not one
     * of [expectedEvent].
     *
     * @return the new event number.
     */
    protected fun expectEvent(expectedEvent: Collection<Int>): Int {
        val previousEvent = eventCounter.getAndIncrement()
        val currentEvent = previousEvent + 1
        if (currentEvent !in expectedEvent) {
            logInvalidTraceState(
                if (previousEvent == FINAL_EVENT) {
                    "Expected event ${expectedEvent.prettyPrintList()}, but finish() was already called"
                } else {
                    "Expected event ${expectedEvent.prettyPrintList()}," +
                        " but the event counter is currently at #$currentEvent"
                }
            )
        }
        return currentEvent
    }

    /**
     * Checks that the trace sections open on the current thread exactly match at least one of
     * [possibleOpenSections]. If none match (including when no candidates are given), a single
     * error listing all candidates is recorded.
     */
    internal fun expectAny(vararg possibleOpenSections: Array<out String>) {
        allEventCounter.incrementAndGet()
        val actualOpenSections = getOpenTraceSectionsOnCurrentThread()
        val failures =
            possibleOpenSections.mapNotNull { expectedSections ->
                try {
                    assertTraceSectionsEquals(
                        expectedSections,
                        expectedEvent = null,
                        actualOpenSections,
                        actualEvent = null,
                        throwInsteadOfLog = true,
                    )
                    null
                } catch (e: AssertionError) {
                    e
                }
            }
        if (failures.size == possibleOpenSections.size) {
            val example = failures.firstOrNull()
            val message =
                if (example == null) {
                    "expectAny() was called without any possible trace sections"
                } else {
                    val allLists =
                        possibleOpenSections.joinToString(separator = ", OR ") {
                            it.prettyPrintList()
                        }
                    "Expected $allLists. For example, ${example.message}"
                }
            // Keep the first mismatch as the cause so its stack trace is reported too.
            assertionErrors.add(InvalidTraceStateException(message, example))
        }
    }

    /** Checks the trace sections open on the current thread, without checking event order. */
    internal fun expect(vararg expectedOpenTraceSections: String) {
        expect(null, *expectedOpenTraceSections)
    }

    /**
     * Checks the trace sections open on the current thread, and that this is ordered event
     * number [expectedEvent].
     */
    internal fun expect(expectedEvent: Int, vararg expectedOpenTraceSections: String) {
        expect(listOf(expectedEvent), *expectedOpenTraceSections)
    }

    /**
     * Checks the currently active trace sections on the current thread, and optionally checks the
     * order of operations if [possibleEventPos] is not null (the new event number must be one of
     * its entries).
     */
    internal fun expect(possibleEventPos: List<Int>?, vararg expectedOpenTraceSections: String) {
        allEventCounter.incrementAndGet()
        val currentEvent = possibleEventPos?.let { expectEvent(it) }
        val actualOpenSections = getOpenTraceSectionsOnCurrentThread()
        assertTraceSectionsEquals(
            expectedOpenTraceSections,
            possibleEventPos,
            actualOpenSections,
            currentEvent,
        )
    }

    /**
     * Compares [expectedOpenTraceSections] with the traced names of [actualOpenSections] (see
     * [getTracedName]), reporting only the first difference.
     */
    private fun assertTraceSectionsEquals(
        expectedOpenTraceSections: Array<out String>,
        expectedEvent: List<Int>?,
        actualOpenSections: Array<String>,
        actualEvent: Int?,
        throwInsteadOfLog: Boolean = false,
    ) {
        val expectedSize = expectedOpenTraceSections.size
        val actualSize = actualOpenSections.size
        if (expectedSize != actualSize) {
            logInvalidTraceState(
                createFailureMessage(
                    expectedOpenTraceSections,
                    expectedEvent,
                    actualOpenSections,
                    actualEvent,
                    "Size mismatch, expected size $expectedSize but was size $actualSize",
                ),
                throwInsteadOfLog,
            )
            return
        }
        expectedOpenTraceSections.forEachIndexed { n, expected ->
            val actual = actualOpenSections[n].getTracedName()
            if (expected != actual) {
                logInvalidTraceState(
                    createFailureMessage(
                        expectedOpenTraceSections,
                        expectedEvent,
                        actualOpenSections,
                        actualEvent,
                        "Differed at index #$n, expected \"$expected\" but was \"$actual\"",
                    ),
                    throwInsteadOfLog,
                )
                return
            }
        }
    }

    /** Builds a multi-line message describing a mismatch between expected and actual sections. */
    private fun createFailureMessage(
        expectedOpenTraceSections: Array<out String>,
        expectedEventNumber: List<Int>?,
        actualOpenSections: Array<String>,
        actualEventNumber: Int?,
        extraMessage: String,
    ): String {
        val locationMarker =
            when {
                expectedEventNumber == null || actualEventNumber == null -> ""
                expectedEventNumber.contains(actualEventNumber) -> " at event #$actualEventNumber"
                else ->
                    ", expected event ${expectedEventNumber.prettyPrintList()}, actual event #$actualEventNumber"
            }
        return """
                Incorrect trace$locationMarker. $extraMessage
                  Expected : {${expectedOpenTraceSections.prettyPrintList()}}
                  Actual   : {${actualOpenSections.prettyPrintList()}}
                """
            .trimIndent()
    }

    /**
     * Records an error unless exactly [expectedEvent] ordered events have happened, then marks the
     * ordered-event counter as finished with [FINAL_EVENT].
     *
     * @return the value of the ordered-event counter before it was marked finished.
     */
    private fun checkFinalEvent(expectedEvent: Int): Int {
        finalEvent.compareAndSet(INVALID_EVENT, expectedEvent)
        val previousEvent = eventCounter.getAndSet(FINAL_EVENT)
        if (expectedEvent != previousEvent) {
            logInvalidTraceState(
                "Expected to finish with event #$expectedEvent, but " +
                    if (previousEvent == FINAL_EVENT)
                        "finish() was already called with event #${finalEvent.get()}"
                    else "the event counter is currently at #$previousEvent"
            )
        }
        return previousEvent
    }

    /**
     * Records an error unless exactly [totalEvents] `expect*` checks have been performed, then
     * marks the total-check counter as finished with [FINAL_EVENT].
     *
     * @return the value of the total-check counter before it was marked finished.
     */
    private fun checkTotalEvents(totalEvents: Int): Int {
        // Note: unlike checkFinalEvent there is no compareAndSet here; allEventCounter holds a
        // live count, and overwriting it with `totalEvents` would make this check pass vacuously.
        val previousEvent = allEventCounter.getAndSet(FINAL_EVENT)
        if (totalEvents != previousEvent) {
            logInvalidTraceState(
                "Expected test to end with a total of $totalEvents events, but " +
                    if (previousEvent == FINAL_EVENT)
                        "finish() was already called at event #${finalEvent.get()}"
                    else "instead there were $previousEvent events"
            )
        }
        return previousEvent
    }

    companion object {
        @JvmField
        @ClassRule
        val setFlagsClassRule: SetFlagsRule.ClassRule =
            SetFlagsRule.ClassRule(com.android.systemui.Flags::class.java)

        /**
         * Returns true when the `java.vm.name` system property is anything other than "Dalvik",
         * presumably meaning the test is running on a host JVM (Robolectric) rather than a device.
         */
        @JvmStatic
        private fun isRobolectricTest(): Boolean = System.getProperty("java.vm.name") != "Dalvik"
    }
}
378 
/**
 * Returns the human-readable name of a raw trace section.
 *
 * Coroutine execution sections look like "coroutine execution;scope-name;c=1234;p=5678"; for
 * those the second `;`-separated field ("scope-name") is returned. For any other section, the
 * text before the first `;` (or the whole string if there is none) is returned.
 */
private fun String.getTracedName(): String {
    // For coroutine sections, drop the leading marker field first; either way, keep one field.
    val fromNameField = if (startsWith(COROUTINE_EXECUTION)) substringAfter(";") else this
    return fromNameField.substringBefore(";")
}
385 
/** Initial value of `TestBase.finalEvent`, meaning no final event has been recorded yet. */
private const val INVALID_EVENT: Int = -1

/** Value written into the event counters once the end-of-test event checks have run. */
private const val FINAL_EVENT: Int = Int.MIN_VALUE
389 
/**
 * Formats event numbers for failure messages: "" when empty, "#n" for a single event, and
 * "{#a, #b, ...}" for several.
 */
private fun Collection<Int>.prettyPrintList(): String =
    when (size) {
        0 -> ""
        1 -> "#${first()}"
        else -> joinToString(separator = ", ", prefix = "{", postfix = "}") { event -> "#$event" }
    }
403 
/**
 * Formats trace sections for failure messages as a comma-separated list of quoted traced names
 * (see [getTracedName]), e.g. `"a", "b"`; returns "" when the array is empty.
 */
private fun Array<out String>.prettyPrintList(): String =
    if (isEmpty()) {
        ""
    } else {
        joinToString(separator = ", ") { section -> "\"${section.getTracedName()}\"" }
    }
411