• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package android.tools.flicker.junit
18 
19 import android.app.Instrumentation
20 import android.device.collectors.util.SendToInstrumentation
21 import android.os.Bundle
22 import android.tools.Scenario
23 import android.tools.ScenarioBuilder
24 import android.tools.flicker.FlickerService
25 import android.tools.flicker.FlickerServiceResultsCollector.Companion.FLICKER_ASSERTIONS_COUNT_KEY
26 import android.tools.flicker.ScenarioInstance
27 import android.tools.flicker.Utils.captureTrace
28 import android.tools.flicker.annotation.ExpectedScenarios
29 import android.tools.flicker.annotation.FlickerConfigProvider
30 import android.tools.flicker.assertions.ScenarioAssertion
31 import android.tools.flicker.config.FlickerConfig
32 import android.tools.flicker.config.ScenarioId
33 import android.tools.io.Reader
34 import android.tools.traces.getDefaultFlickerOutputDir
35 import android.tools.traces.now
36 import androidx.test.platform.app.InstrumentationRegistry
37 import com.google.common.truth.Truth
38 import java.lang.reflect.Method
39 import org.junit.After
40 import org.junit.Before
41 import org.junit.Rule
42 import org.junit.Test
43 import org.junit.rules.TestRule
44 import org.junit.runner.Description
45 import org.junit.runner.Description.createTestDescription
46 import org.junit.runners.model.FrameworkMember
47 import org.junit.runners.model.FrameworkMethod
48 import org.junit.runners.model.MemberValueConsumer
49 import org.junit.runners.model.Statement
50 import org.junit.runners.model.TestClass
51 
/**
 * Runner decorator that executes each test method of the inner decorator once while capturing a
 * trace, and then injects one additional test case per Flicker service assertion generated from
 * that trace (plus one test case checking that the expected scenarios were detected).
 *
 * @param testClass the JUnit test class being decorated
 * @param paramString optional suffix appended to the class name used to build the scenario and to
 *   the names of the injected test cases
 * @param skipNonBlocking forwarded as-is to every [FlickerServiceCachedTestCase] that is created
 * @param inner the decorator this one wraps; test methods not injected by this decorator are
 *   delegated to it
 * @param instrumentation instrumentation used to report progress and metrics
 * @param flickerService service used to detect scenarios; when `null`, one is lazily created from
 *   the [FlickerConfig] returned by the test class' [FlickerConfigProvider] method
 */
class FlickerServiceDecorator(
    testClass: TestClass,
    val paramString: String?,
    private val skipNonBlocking: Boolean,
    inner: IFlickerJUnitDecorator?,
    instrumentation: Instrumentation = InstrumentationRegistry.getInstrumentation(),
    flickerService: FlickerService? = null,
) : AbstractFlickerRunnerDecorator(testClass, inner, instrumentation) {
    private val flickerService by lazy { flickerService ?: FlickerService(getFlickerConfig()) }

    /** Scenario built from the test class name followed by [paramString] (if any). */
    private val testClassName =
        ScenarioBuilder().forClass("${testClass.name}${paramString ?: ""}").build()

    /** Flicker service test cases computed for each inner test method that ran successfully. */
    private val flickerServiceMethodsFor =
        mutableMapOf<FrameworkMethod, Collection<InjectedTestCase>>()

    /**
     * Outcome of each inner test method already executed by [getTestMethods]: `null` when it
     * succeeded, otherwise the [Throwable] that made it (or the test case computation) fail.
     */
    private val innerMethodsResults = mutableMapOf<FrameworkMethod, Throwable?>()

    /**
     * Returns the description of [method]: built locally for test cases injected by this
     * decorator, delegated to the inner decorator otherwise.
     */
    override fun getChildDescription(method: FrameworkMethod): Description =
        if (isMethodHandledByDecorator(method)) {
            createTestDescription(testClass.javaClass, method.name, *method.annotations)
        } else {
            inner?.getChildDescription(method) ?: error("No child descriptor found")
        }

    /**
     * Collects every [TestRule] exposed through a [Rule] annotated method or field, reading the
     * values from a freshly constructed instance of the test class.
     */
    private fun getTestRules(): MutableList<TestRule> {
        val collector = RuleCollector<TestRule>()
        val instance = testClass.onlyConstructor.newInstance()
        testClass.collectAnnotatedMethodValues<TestRule>(
            instance,
            Rule::class.java,
            TestRule::class.java,
            collector,
        )
        testClass.collectAnnotatedFieldValues<TestRule>(
            instance,
            Rule::class.java,
            TestRule::class.java,
            collector,
        )
        return collector.result
    }

    /** [MemberValueConsumer] that simply accumulates every value it is handed into [result]. */
    class RuleCollector<T> internal constructor() : MemberValueConsumer<T> {
        val result: MutableList<T> = ArrayList()

        override fun accept(member: FrameworkMember<*>, value: T) {
            result.add(value)
        }
    }

    /**
     * Returns the inner decorator's test methods followed by the Flicker service test cases
     * computed for each of them.
     *
     * Unless this is a dry run (see [shouldComputeTestMethods]), every inner method that has not
     * been executed yet is run here, wrapped in the test class' rules, so that the trace needed to
     * compute the injected test cases is available.
     *
     * @throws IllegalStateException if there is no inner decorator
     */
    override fun getTestMethods(test: Any): List<FrameworkMethod> {
        val innerMethods =
            inner?.getTestMethods(test)
                ?: error("FlickerServiceDecorator requires a non-null inner decorator")
        val testMethods = innerMethods.toMutableList()

        val ruleContainer = RuleContainer()
        getTestRules().forEach { ruleContainer.add(it) }

        if (!shouldComputeTestMethods()) {
            return testMethods
        }

        for (method in innerMethods) {
            if (!innerMethodsResults.containsKey(method)) {
                val description = createTestDescription(testClass.javaClass.name, method.name)
                ruleContainer
                    .apply(
                        method,
                        description,
                        testClass.onlyConstructor.newInstance(),
                        createTransitionStatement(method, test),
                    )
                    .evaluate()
            }

            if (innerMethodsResults[method] == null) {
                // No entry exists if a rule returned without evaluating the wrapped statement;
                // in that case there is nothing to inject rather than a reason to crash.
                testMethods.addAll(flickerServiceMethodsFor[method].orEmpty())
            }
        }

        return testMethods
    }

    /**
     * Builds the statement that runs the `@Before` methods, [method] and the `@After` methods on
     * [test] while a trace is captured, records the outcome in [innerMethodsResults] and, on
     * success, stores the computed Flicker service test cases in [flickerServiceMethodsFor].
     */
    private fun createTransitionStatement(method: FrameworkMethod, test: Any): Statement =
        object : Statement() {
            override fun evaluate() {
                // TODO: Maybe don't use null but wrap in another object
                var transitionFailure: Throwable? = null
                val reader =
                    captureTrace(testClassName, getDefaultFlickerOutputDir()) { writer ->
                        try {
                            notifyProgress("Running setup")
                            testClass.getAnnotatedMethods(Before::class.java).forEach {
                                it.invokeExplosively(test)
                            }

                            notifyProgress("Running transition")

                            val traceStartTime = now()
                            notifyProgress("Setting trace start time to :: $traceStartTime")
                            writer.setTransitionStartTime(traceStartTime)

                            method.invokeExplosively(test)

                            val traceEndTime = now()
                            notifyProgress("Setting trace end time to :: $traceEndTime")
                            writer.setTransitionEndTime(traceEndTime)

                            // NOTE(review): teardown is skipped when setup or the transition
                            // throws, unlike regular JUnit @After handling - confirm intended.
                            notifyProgress("Running teardown")
                            testClass.getAnnotatedMethods(After::class.java).forEach {
                                it.invokeExplosively(test)
                            }
                        } catch (e: Throwable) {
                            transitionFailure = e
                        } finally {
                            innerMethodsResults[method] = transitionFailure
                        }
                    }

                if (transitionFailure != null) {
                    return
                }

                notifyProgress("Computing Flicker service tests")
                try {
                    flickerServiceMethodsFor[method] =
                        computeFlickerServiceTests(reader, testClassName, method)
                } catch (e: Throwable) {
                    // Failed to compute flicker service methods
                    innerMethodsResults[method] = e
                }
            }
        }

    /** Reports [message] as runner progress for the scenario of this test class. */
    private fun notifyProgress(message: String) {
        Utils.notifyRunnerProgress(testClassName, message, instrumentation)
    }

    // TODO: Common with LegacyFlickerServiceDecorator, might be worth extracting this up
    /**
     * Returns `false` when the current call stack shows that we are being invoked from
     * `validateInstanceMethods` or from one of the non-executing (dry-run) AndroidX runners.
     */
    private fun shouldComputeTestMethods(): Boolean {
        // Don't compute when called from validateInstanceMethods since this will fail
        // as the parameters will not be set. And AndroidLogOnlyBuilder is a non-executing runner
        // used to run tests in dry-run mode, so we don't want to execute in flicker transition in
        // that case either.
        val dryRunClassNames =
            setOf(
                "androidx.test.internal.runner.AndroidLogOnlyBuilder",
                "androidx.test.internal.runner.NonExecutingRunner",
            )
        return Thread.currentThread().stackTrace.none { frame ->
            frame.methodName == "validateInstanceMethods" || frame.className in dryRunClassNames
        }
    }

    /**
     * Returns a statement that executes injected test cases directly, rethrows the recorded
     * failure (if any) of inner methods already run by [getTestMethods], and delegates any other
     * method to the inner decorator.
     */
    override fun getMethodInvoker(method: FrameworkMethod, test: Any): Statement {
        return object : Statement() {
            @Throws(Throwable::class)
            override fun evaluate() {
                val description = getChildDescription(method)
                when {
                    isMethodHandledByDecorator(method) ->
                        (method as InjectedTestCase).execute(description)
                    // Already executed while computing the test methods: only replay the result.
                    innerMethodsResults.containsKey(method) ->
                        innerMethodsResults[method]?.let { throw it }
                    else -> inner?.getMethodInvoker(method, test)?.evaluate()
                }
            }
        }
    }

    /**
     * Validates that the test class has at most one `@Test` method, exactly one public static
     * [FlickerConfigProvider] method returning a [FlickerConfig], and that every scenario listed
     * in an [ExpectedScenarios] annotation is registered in that config.
     *
     * @return the errors reported by the superclass followed by the ones found here
     */
    override fun doValidateInstanceMethods(): List<Throwable> {
        val errors = super.doValidateInstanceMethods().toMutableList()

        if (testClass.getAnnotatedMethods(Test::class.java).size > 1) {
            errors.add(IllegalArgumentException("Only one @Test annotated method is supported"))
        }

        // Validate Registry provider
        val configProviders =
            testClass.getAnnotatedMethods(FlickerConfigProvider::class.java).filter {
                it.isStatic && it.isPublic
            }
        when {
            configProviders.isEmpty() ->
                errors.add(
                    IllegalArgumentException(
                        "A public static function returning a " +
                            "${FlickerConfig::class.simpleName} annotated with " +
                            "@${FlickerConfigProvider::class.simpleName} should be provided."
                    )
                )
            configProviders.size > 1 ->
                errors.add(
                    IllegalArgumentException(
                        "Only one @${FlickerConfigProvider::class.simpleName} " +
                            "annotated method is supported."
                    )
                )
            configProviders.first().returnType.name != FlickerConfig::class.qualifiedName ->
                errors.add(
                    IllegalArgumentException(
                        // The annotation being checked is @FlickerConfigProvider; FlickerConfig
                        // is the expected return type, not an annotation.
                        "Expected method annotated with " +
                            "@${FlickerConfigProvider::class.simpleName} to return " +
                            "${FlickerConfig::class.qualifiedName} but was " +
                            "${configProviders.first().returnType.name} instead."
                    )
                )
            else -> {
                // Validate @ExpectedScenarios annotation
                // NOTE(review): getFlickerConfig() throws (instead of adding to errors) when the
                // class does not have exactly one @ExpectedScenarios method.
                val registeredScenarios =
                    getFlickerConfig().getEntries().map { it.scenarioId.name }
                val expectedScenarioAnnotations =
                    testClass.getAnnotatedMethods(ExpectedScenarios::class.java).map {
                        it.getAnnotation(ExpectedScenarios::class.java)
                    }
                for (annotation in expectedScenarioAnnotations) {
                    for (expectedScenario in annotation.expectedScenarios) {
                        if (expectedScenario !in registeredScenarios) {
                            errors.add(
                                IllegalArgumentException(
                                    "Provided scenarios that are not registered to " +
                                        "@${ExpectedScenarios::class.simpleName} annotation. " +
                                        "$expectedScenario is not registered in the " +
                                        "${FlickerConfig::class.simpleName}. Available " +
                                        "scenarios are [${registeredScenarios.joinToString()}]."
                                )
                            )
                        }
                    }
                }
            }
        }

        return errors
    }

    /**
     * Invokes the first [FlickerConfigProvider] annotated method of the test class and returns
     * its result.
     *
     * @throws IllegalArgumentException if the test class does not have exactly one
     *   [ExpectedScenarios] annotated method
     */
    private fun getFlickerConfig(): FlickerConfig {
        require(testClass.getAnnotatedMethods(ExpectedScenarios::class.java).size == 1) {
            "@ExpectedScenarios missing. " +
                "getFlickerConfig() may have been called before validation."
        }

        val configProvider =
            testClass.getAnnotatedMethods(FlickerConfigProvider::class.java).first()
        // TODO: Pass the correct target
        return configProvider.invokeExplosively(testClass) as FlickerConfig
    }

    /** Always `false`: this decorator runs the `@Before` methods itself in [getTestMethods]. */
    override fun shouldRunBeforeOn(method: FrameworkMethod): Boolean = false

    /** Always `false`: this decorator runs the `@After` methods itself in [getTestMethods]. */
    override fun shouldRunAfterOn(method: FrameworkMethod): Boolean = false

    /** Whether [method] is a test case that was injected by this very decorator. */
    private fun isMethodHandledByDecorator(method: FrameworkMethod): Boolean =
        method is InjectedTestCase && method.injectedBy == this

    /**
     * Builds the Flicker service test cases for [method] from the trace exposed by [reader],
     * using the scenarios listed in the method's [ExpectedScenarios] annotation (none when the
     * annotation is absent) as the scenarios expected to be detected.
     */
    private fun computeFlickerServiceTests(
        reader: Reader,
        testScenario: Scenario,
        method: FrameworkMethod,
    ): Collection<InjectedTestCase> {
        val expectedScenarioNames =
            method.annotations.filterIsInstance<ExpectedScenarios>().firstOrNull()
                ?.expectedScenarios ?: emptyArray()
        val expectedScenarios = expectedScenarioNames.map { ScenarioId(it) }.toSet()

        return getFaasTestCases(
            testScenario,
            expectedScenarios,
            paramString ?: "",
            reader,
            flickerService,
            instrumentation,
            this,
            skipNonBlocking,
        )
    }

    companion object {
        /** Returns the distinct scenario types detected for [testScenario]. */
        private fun getDetectedScenarios(
            testScenario: Scenario,
            reader: Reader,
            flickerService: FlickerService,
        ): Collection<ScenarioId> =
            getGroupedAssertions(testScenario, reader, flickerService)
                .keys
                .map { it.type }
                .distinct()

        /** Reflective handle on [InjectedTestCase]'s `execute(Description)` method. */
        private fun getCachedResultMethod(): Method =
            InjectedTestCase::class.java.getMethod("execute", Description::class.java)

        /**
         * Returns the assertions of every detected scenario instance for [testScenario].
         *
         * Scenarios are only detected (and their assertions generated) when the data store does
         * not already hold a result for [testScenario]; the outcome is then stored there and
         * read back.
         */
        private fun getGroupedAssertions(
            testScenario: Scenario,
            reader: Reader,
            flickerService: FlickerService,
        ): Map<ScenarioInstance, Collection<ScenarioAssertion>> {
            val alreadyComputed =
                android.tools.flicker.datastore.DataStore.containsFlickerServiceResult(
                    testScenario
                )
            if (!alreadyComputed) {
                val assertionsByInstance =
                    flickerService.detectScenarios(reader).associateWith {
                        it.generateAssertions()
                    }
                android.tools.flicker.datastore.DataStore.addFlickerServiceAssertions(
                    testScenario,
                    assertionsByInstance,
                )
            }

            return android.tools.flicker.datastore.DataStore.getFlickerServiceAssertions(
                testScenario
            )
        }

        /**
         * Creates one [FlickerServiceCachedTestCase] per assertion of every scenario instance
         * detected for [testScenario], followed by a single test case that reports the number of
         * assertions and checks that all [expectedScenarios] were detected.
         *
         * When several instances of the same scenario type are detected, the 1-based index of the
         * instance is appended (as `_<index>`) to [paramString] for its test cases. Only the very
         * last assertion test case is created with `isLast = true`.
         */
        internal fun getFaasTestCases(
            testScenario: Scenario,
            expectedScenarios: Set<ScenarioId>,
            paramString: String,
            reader: Reader,
            flickerService: FlickerService,
            instrumentation: Instrumentation,
            caller: IFlickerJUnitDecorator,
            skipNonBlocking: Boolean,
        ): Collection<InjectedTestCase> {
            val groupedAssertions = getGroupedAssertions(testScenario, reader, flickerService)
            val instancesByType = groupedAssertions.keys.groupBy { it.type }

            // Flatten first so that the last test case can be identified reliably. The previous
            // check compared sizes against 0-based indices and therefore could never be true.
            val pendingCases = mutableListOf<Pair<ScenarioAssertion, String>>()
            for (instancesOfSameType in instancesByType.values) {
                instancesOfSameType.forEachIndexed { instanceIndex, scenarioInstance ->
                    val instanceSuffix =
                        if (instancesOfSameType.size > 1) "_${instanceIndex + 1}" else ""
                    groupedAssertions.getValue(scenarioInstance).forEach { assertion ->
                        pendingCases.add(assertion to "$paramString$instanceSuffix")
                    }
                }
            }

            val faasTestCases =
                pendingCases.mapIndexed { caseIndex, (assertion, caseParamString) ->
                    FlickerServiceCachedTestCase(
                        assertion = assertion,
                        method = getCachedResultMethod(),
                        skipNonBlocking = skipNonBlocking,
                        isLast = caseIndex == pendingCases.lastIndex,
                        injectedBy = caller,
                        paramString = caseParamString,
                        instrumentation = instrumentation,
                    )
                }

            val detectedScenarioTestCase =
                AnonymousInjectedTestCase(
                    getCachedResultMethod(),
                    "FaaS_DetectedExpectedScenarios$paramString",
                    injectedBy = caller,
                ) {
                    val metricBundle = Bundle()
                    metricBundle.putString(FLICKER_ASSERTIONS_COUNT_KEY, "${faasTestCases.size}")
                    SendToInstrumentation.sendBundle(instrumentation, metricBundle)

                    Truth.assertThat(getDetectedScenarios(testScenario, reader, flickerService))
                        .containsAtLeastElementsIn(expectedScenarios)
                }

            return faasTestCases + listOf(detectedScenarioTestCase)
        }
    }
}
469