1 /*
2  * Copyright 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.compose.runtime.benchmark.state
18 
19 import androidx.benchmark.junit4.BenchmarkRule
20 import androidx.benchmark.junit4.measureRepeatedOnMainThread
21 import androidx.compose.runtime.Applier
22 import androidx.compose.runtime.Composition
23 import androidx.compose.runtime.Recomposer
24 import androidx.compose.runtime.derivedStateOf
25 import androidx.compose.runtime.mutableIntStateOf
26 import androidx.compose.runtime.snapshots.Snapshot
27 import androidx.compose.runtime.snapshots.SnapshotApplyResult
28 import androidx.compose.runtime.snapshots.SnapshotStateObserver
29 import androidx.test.filters.LargeTest
30 import kotlin.coroutines.CoroutineContext
31 import kotlin.coroutines.EmptyCoroutineContext
32 import org.junit.Rule
33 import org.junit.Test
34 import org.junit.runner.RunWith
35 import org.junit.runners.Parameterized
36 
/**
 * Benchmarks the cost of reading snapshot state ([mutableIntStateOf] values and [derivedStateOf]
 * results) while those reads are being observed.
 *
 * The class is parameterized by [ReadContext], which selects the observation scope that each
 * measured read runs in:
 * - the initial composition of a fresh [Composition], or
 * - a stack of nested [SnapshotStateObserver.observeReads] scopes inside a mutable snapshot.
 */
@LargeTest
@RunWith(Parameterized::class)
class ComposeStateReadBenchmark(private val readContext: ReadContext) {
    /** The kind of read-observation scope the measured read executes in. */
    enum class ReadContext {
        /** The read happens during `setContent` on a newly created [Composition]. */
        Composition,

        /**
         * The read happens inside [MEASURE_OBSERVATION_DEPTH] nested
         * [SnapshotStateObserver.observeReads] scopes within a mutable snapshot.
         */
        Measure
    }

    companion object {
        /** Number of nested `observeReads` scopes used by [ReadContext.Measure]. */
        private const val MEASURE_OBSERVATION_DEPTH = 5

        /** Change callback passed to `observeReads`; it is a no-op, so invalidations are ignored. */
        private val OnCommitInvalidatingMeasure: (Any) -> Unit = {}

        @Parameterized.Parameters(name = "{0}")
        @JvmStatic
        fun parameters() = arrayOf(ReadContext.Composition, ReadContext.Measure)
    }

    @get:Rule val benchmarkRule = BenchmarkRule()

    /** Reads a plain int state. */
    @Test
    fun readState() {
        val state = mutableIntStateOf(0)

        benchmarkRead { state.value }
    }

    /** Reads a derived state whose value was already computed outside the measured scope. */
    @Test
    fun readDerivedState() {
        val stateA = mutableIntStateOf(0)
        val stateB = mutableIntStateOf(0)
        val derivedState = derivedStateOf { stateA.value + stateB.value }

        derivedState.value // precompute result

        benchmarkRead { derivedState.value }
    }

    /** Reads a derived state that was already read once (untimed) in the same scope. */
    @Test
    fun readDerivedState_secondRead() {
        val stateA = mutableIntStateOf(0)
        val stateB = mutableIntStateOf(0)
        val derivedState = derivedStateOf { stateA.value + stateB.value }

        derivedState.value // precompute result

        benchmarkRead(before = { derivedState.value }) { derivedState.value }
    }

    /** Reads a derived state after one of its dependencies was written (untimed) in the scope. */
    @Test
    fun readDerivedState_afterWrite() {
        val stateA = mutableIntStateOf(0)
        val stateB = mutableIntStateOf(0)
        val derivedState = derivedStateOf { stateA.value + stateB.value }

        derivedState.value // precompute result

        benchmarkRead(before = { stateA.value += 1 }) { derivedState.value }
    }

    /** Reads a plain int state after it was written (untimed) in the same scope. */
    @Test
    fun readState_afterWrite() {
        val stateA = mutableIntStateOf(0)

        benchmarkRead(before = { stateA.value += 1 }) { stateA.value }
    }

    /** Reads one state after a different state was already read (untimed) in the same scope. */
    @Test
    fun readState_preinitialized() {
        val stateA = mutableIntStateOf(0)
        val stateB = mutableIntStateOf(0)

        benchmarkRead(before = { stateA.value }) { stateB.value }
    }

    /**
     * Reads one derived state after a different derived state over the same dependencies was
     * already read (untimed) in the same scope.
     */
    @Test
    fun readDerivedState_preinitialized() {
        val stateA = mutableIntStateOf(0)
        val stateB = mutableIntStateOf(0)

        val derivedStateA = derivedStateOf { stateA.value + stateB.value }
        val derivedStateB = derivedStateOf { stateB.value + stateA.value }

        benchmarkRead(before = { derivedStateA.value }) { derivedStateB.value }
    }

    /**
     * Repeatedly runs [measure] on the main thread. Each iteration happens inside a fresh
     * read-observation scope (see [runInReadObservationScope]), and only the time spent in
     * [measure] is recorded.
     *
     * @param before untimed work run inside the same scope immediately before [measure].
     * @param after untimed work run inside the same scope immediately after [measure].
     * @param measure the read being benchmarked.
     */
    private fun benchmarkRead(
        before: () -> Unit = {},
        after: () -> Unit = {},
        measure: () -> Unit
    ) {
        val benchmarkState = benchmarkRule.getState()
        benchmarkRule.measureRepeatedOnMainThread {
            // Setting up and tearing down the observation scope is excluded from the timing.
            benchmarkState.pauseTiming()
            runInReadObservationScope {
                before()
                benchmarkState.resumeTiming()

                measure()

                benchmarkState.pauseTiming()
                after()
            }
            benchmarkState.resumeTiming()
        }
    }

    /**
     * Invokes [scopeBlock] inside a newly created read-observation scope chosen by
     * [readContext], then releases what that scope allocated. This happens even if
     * [scopeBlock] throws.
     */
    private fun runInReadObservationScope(scopeBlock: () -> Unit) {
        when (readContext) {
            ReadContext.Composition -> {
                val composition = createComposition()
                try {
                    composition.setContent { scopeBlock() }
                } finally {
                    // A new composition is created per benchmark iteration; dispose it so
                    // iterations do not accumulate live compositions.
                    composition.dispose()
                }
            }
            ReadContext.Measure -> {
                val snapshot = Snapshot.takeMutableSnapshot()
                try {
                    snapshot.enter {
                        val observer = SnapshotStateObserver { it() }
                        val nodes = List(MEASURE_OBSERVATION_DEPTH) { Any() }
                        observer.start()
                        try {
                            observer.recursiveObserve(nodes, nodes.size, scopeBlock)
                        } finally {
                            observer.stop()
                        }
                    }
                    val applyResult = snapshot.apply()
                    check(applyResult !is SnapshotApplyResult.Failure) {
                        "Failed to apply snapshot"
                    }
                } finally {
                    // Dispose the snapshot even when apply fails or the block throws.
                    snapshot.dispose()
                }
            }
        }
    }

    /**
     * Nests [depth] `observeReads` scopes and runs [block] in the innermost one. The scopes
     * are keyed by `nodes[depth - 1]` on the outside down to `nodes[0]` on the inside. When
     * [depth] is 0, [block] runs directly with no additional scope.
     */
    private fun SnapshotStateObserver.recursiveObserve(
        nodes: List<Any>,
        depth: Int,
        block: () -> Unit
    ) {
        if (depth == 0) {
            block()
            return
        }
        observeReads(nodes[depth - 1], OnCommitInvalidatingMeasure) {
            recursiveObserve(nodes, depth - 1, block)
        }
    }

    /**
     * Creates a [Composition] backed by a no-op [UnitApplier] and a new [Recomposer] built
     * from [coroutineContext]. The recomposer's recomposition loop is not started here; this
     * benchmark only uses `setContent`.
     */
    private fun createComposition(
        coroutineContext: CoroutineContext = EmptyCoroutineContext
    ): Composition {
        val applier = UnitApplier()
        val recomposer = Recomposer(coroutineContext)
        return Composition(applier, recomposer)
    }

    /** An [Applier] over [Unit] whose tree operations are all no-ops. */
    private class UnitApplier : Applier<Unit> {
        override val current: Unit = Unit

        override fun clear() {}

        override fun move(from: Int, to: Int, count: Int) {}

        override fun remove(index: Int, count: Int) {}

        override fun up() {}

        override fun insertTopDown(index: Int, instance: Unit) {}

        override fun insertBottomUp(index: Int, instance: Unit) {}

        override fun down(node: Unit) {}
    }
}
204