/*
 * Copyright 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 @file:JvmName("SnapshotStateKt")
18 @file:JvmMultifileClass
19
20 package androidx.compose.runtime
21
22 import androidx.collection.MutableObjectIntMap
23 import androidx.collection.ObjectIntMap
24 import androidx.collection.emptyObjectIntMap
25 import androidx.compose.runtime.collection.MutableVector
26 import androidx.compose.runtime.internal.IntRef
27 import androidx.compose.runtime.internal.SnapshotThreadLocal
28 import androidx.compose.runtime.internal.identityHashCode
29 import androidx.compose.runtime.snapshots.Snapshot
30 import androidx.compose.runtime.snapshots.SnapshotId
31 import androidx.compose.runtime.snapshots.SnapshotIdZero
32 import androidx.compose.runtime.snapshots.StateFactoryMarker
33 import androidx.compose.runtime.snapshots.StateObject
34 import androidx.compose.runtime.snapshots.StateObjectImpl
35 import androidx.compose.runtime.snapshots.StateRecord
36 import androidx.compose.runtime.snapshots.current
37 import androidx.compose.runtime.snapshots.currentSnapshot
38 import androidx.compose.runtime.snapshots.newWritableRecord
39 import androidx.compose.runtime.snapshots.sync
40 import androidx.compose.runtime.snapshots.withCurrent
41 import kotlin.jvm.JvmMultifileClass
42 import kotlin.jvm.JvmName
43 import kotlin.math.min
44
/**
 * A [State] whose value is computed from one or more other states.
 *
 * @see derivedStateOf
 */
internal interface DerivedState<T> : State<T> {
    /** The current [Record] of this derived state. */
    val currentRecord: Record<T>

    /**
     * Mutation policy deciding how a recalculated value is treated once the state's dependencies
     * have changed. When the policy is `null`, an update of the derived state is triggered no
     * matter which value is produced, and correct invalidation is left to the observer.
     */
    val policy: SnapshotMutationPolicy<T>?

    /** A cached value of a derived state together with the states it was computed from. */
    interface Record<T> {
        /** The derived value, obtained without notifying read observers. */
        val currentValue: T

        /**
         * Every dependency that was read to produce [currentValue], mapped to the nesting level
         * at which it was read.
         *
         * Consulted to decide whether a [StateObject] reported in an apply observer's set could
         * affect the value of this derived state.
         */
        val dependencies: ObjectIntMap<StateObject>
    }
}
77
/**
 * Depth counter for nested derived-state calculations. It is held in a [SnapshotThreadLocal] and
 * created on demand by [withCalculationNestedLevel].
 */
private val calculationBlockNestedLevel = SnapshotThreadLocal<IntRef>()

/**
 * Invokes [block] with the calling thread's calculation nesting counter. The first request on a
 * thread creates a counter starting at 0 and stores it for later calls.
 */
private inline fun <T> withCalculationNestedLevel(block: (IntRef) -> T): T {
    var levelRef = calculationBlockNestedLevel.get()
    if (levelRef == null) {
        levelRef = IntRef(0)
        calculationBlockNestedLevel.set(levelRef)
    }
    return block(levelRef)
}
85
/**
 * The [DerivedState] implementation created by [derivedStateOf].
 *
 * The cached result of [calculation] lives in a list of [ResultRecord]s headed by
 * [firstStateRecord]. The record visible to a given [Snapshot] is looked up with `current`. A
 * record is reused while [ResultRecord.isValid] holds and is recalculated otherwise.
 *
 * @param calculation produces the value. Every [StateObject] it reads is recorded as a dependency
 *   together with the nesting level of the read.
 * @param policy when non-null and a recalculated value is [SnapshotMutationPolicy.equivalent] to the
 *   cached one, the existing record is kept (with refreshed dependencies) instead of writing a new
 *   record.
 */
private class DerivedSnapshotState<T>(
    private val calculation: () -> T,
    override val policy: SnapshotMutationPolicy<T>?
) : StateObjectImpl(), DerivedState<T> {
    private var first: ResultRecord<T> = ResultRecord(currentSnapshot().snapshotId)

    /** A snapshot record caching one result of [calculation] and the dependencies behind it. */
    class ResultRecord<T>(snapshotId: SnapshotId) :
        StateRecord(snapshotId), DerivedState.Record<T> {
        companion object {
            /** Sentinel held by [result] until a value has been calculated. */
            val Unset = Any()
        }

        /** Id of the snapshot in which this record was last confirmed valid. */
        var validSnapshotId: SnapshotId = SnapshotIdZero

        /** [Snapshot.writeCount] observed when this record was last confirmed valid. */
        var validSnapshotWriteCount: Int = 0

        override var dependencies: ObjectIntMap<StateObject> = emptyObjectIntMap()

        /** The cached result, or [Unset] if no value has been calculated yet. */
        var result: Any? = Unset

        /** [readableHash] of [dependencies], captured when this record was last written. */
        var resultHash: Int = 0

        /**
         * Copies the cached result, its dependencies and its hash from [value]. The validity
         * stamps ([validSnapshotId], [validSnapshotWriteCount]) are not copied.
         */
        override fun assign(value: StateRecord) {
            @Suppress("UNCHECKED_CAST") val other = value as ResultRecord<T>
            dependencies = other.dependencies
            result = other.result
            resultHash = other.resultHash
        }

        override fun create(): StateRecord = create(currentSnapshot().snapshotId)

        override fun create(snapshotId: SnapshotId): StateRecord = ResultRecord<T>(snapshotId)

        /**
         * Returns whether [result] can be reused in [snapshot].
         *
         * The record is valid when a result exists and either [snapshot] has the same id and
         * write count as when the record was last confirmed, or the current [readableHash] of
         * its dependencies still equals [resultHash]. A successful re-check updates the validity
         * stamps, so the hash is not recomputed again for the same snapshot state.
         */
        fun isValid(derivedState: DerivedState<*>, snapshot: Snapshot): Boolean {
            val snapshotChanged = sync {
                validSnapshotId != snapshot.snapshotId ||
                    validSnapshotWriteCount != snapshot.writeCount
            }
            // The dependency hash is only recomputed when the snapshot moved on since the last check.
            val isValid =
                result !== Unset &&
                    (!snapshotChanged || resultHash == readableHash(derivedState, snapshot))

            if (isValid && snapshotChanged) {
                sync {
                    validSnapshotId = snapshot.snapshotId
                    validSnapshotWriteCount = snapshot.writeCount
                }
            }

            return isValid
        }

        /**
         * Hashes the records that [snapshot] currently sees for the dependencies read directly
         * (read level 1) by the calculation. Each record contributes its identity and its
         * snapshot id.
         *
         * Dependencies that are themselves [DerivedSnapshotState]s are brought up to date through
         * [DerivedSnapshotState.current] before their record is hashed.
         */
        fun readableHash(derivedState: DerivedState<*>, snapshot: Snapshot): Int {
            var hash = 7
            val dependencies = sync { dependencies }
            if (dependencies.isNotEmpty()) {
                notifyObservers(derivedState) {
                    dependencies.forEach { stateObject, readLevel ->
                        // Only dependencies read directly by this calculation contribute.
                        if (readLevel != 1) {
                            return@forEach
                        }

                        // Find the first record without triggering an observer read.
                        val record =
                            if (stateObject is DerivedSnapshotState<*>) {
                                // Eagerly access parent derived states without recording the
                                // read, so they are recalculated and up to date with the latest
                                // values before their record is hashed.
                                stateObject.current(snapshot)
                            } else {
                                current(stateObject.firstStateRecord, snapshot)
                            }

                        hash = 31 * hash + identityHashCode(record)
                        hash = 31 * hash + record.snapshotId.hashCode()
                    }
                }
            }
            return hash
        }

        override val currentValue: T
            @Suppress("UNCHECKED_CAST") get() = result as T
    }

    /**
     * Gets the current record in [snapshot], recalculating it first if it is no longer valid.
     * Dependency reads are not re-reported to [snapshot]'s read observer.
     *
     * @return the latest state record for this derived state.
     */
    fun current(snapshot: Snapshot): StateRecord =
        currentRecord(current(first, snapshot), snapshot, false, calculation)

    /**
     * Returns a valid record for [snapshot], reusing [readable] when possible.
     *
     * If [readable] is still valid, it is returned as-is. When [forceDependencyReads] is set, its
     * dependencies are also re-reported to [snapshot]'s read observer at their original nesting
     * levels.
     *
     * Otherwise [calculation] runs under a read observer that records every [StateObject] read,
     * keyed to the shallowest nesting level at which it was read. The result is stored in
     * [readable] when [policy] deems it equivalent to the cached value. Otherwise it is stored in
     * a new writable record.
     *
     * The thread-local nesting level is restored in `finally` blocks, so an exception from
     * [calculation] or from a read observer cannot leave it elevated. An elevated level would stop
     * later outermost calculations on this thread from being recognised as outermost.
     */
    private fun currentRecord(
        readable: ResultRecord<T>,
        snapshot: Snapshot,
        forceDependencyReads: Boolean,
        calculation: () -> T
    ): ResultRecord<T> {
        if (readable.isValid(this, snapshot)) {
            // The dependency is not recalculated, so emulate the nested state reads to keep
            // invalidation correct later on.
            if (forceDependencyReads) {
                notifyObservers(this) {
                    val dependencies = readable.dependencies
                    withCalculationNestedLevel { calculationLevelRef ->
                        val invalidationNestedLevel = calculationLevelRef.element
                        try {
                            dependencies.forEach { dependency, nestedLevel ->
                                calculationLevelRef.element = invalidationNestedLevel + nestedLevel
                                snapshot.readObserver?.invoke(dependency)
                            }
                        } finally {
                            calculationLevelRef.element = invalidationNestedLevel
                        }
                    }
                }
            }
            return readable
        }

        val newDependencies = MutableObjectIntMap<StateObject>()
        val result = withCalculationNestedLevel { calculationLevelRef ->
            val nestedCalculationLevel = calculationLevelRef.element
            notifyObservers(this) {
                calculationLevelRef.element = nestedCalculationLevel + 1
                try {
                    Snapshot.observe(
                        {
                            if (it === this) error("A derived state calculation cannot read itself")
                            if (it is StateObject) {
                                // Keep the shallowest level at which each dependency was read.
                                val readNestedLevel = calculationLevelRef.element
                                newDependencies[it] =
                                    min(
                                        readNestedLevel - nestedCalculationLevel,
                                        newDependencies.getOrDefault(it, Int.MAX_VALUE)
                                    )
                            }
                        },
                        null,
                        calculation
                    )
                } finally {
                    calculationLevelRef.element = nestedCalculationLevel
                }
            }
        }

        val record = sync {
            val currentSnapshot = Snapshot.current
            val previousResult = readable.result

            @Suppress("UNCHECKED_CAST")
            val equivalentToPrevious =
                previousResult !== ResultRecord.Unset &&
                    policy?.equivalent(result, previousResult as T) == true

            if (equivalentToPrevious) {
                // Keep the existing record and value; only refresh what it depends on.
                readable.dependencies = newDependencies
                readable.resultHash = readable.readableHash(this, currentSnapshot)
                readable
            } else {
                val writable = first.newWritableRecord(this, currentSnapshot)
                writable.dependencies = newDependencies
                writable.resultHash = writable.readableHash(this, currentSnapshot)
                writable.result = result
                writable
            }
        }

        // Only the outermost calculation on this thread (nesting level back at 0) calls
        // notifyObjectsInitialized and stamps the record as valid for the current snapshot.
        if (calculationBlockNestedLevel.get()?.element == 0) {
            Snapshot.notifyObjectsInitialized()

            sync {
                val currentSnapshot = Snapshot.current
                record.validSnapshotId = currentSnapshot.snapshotId
                record.validSnapshotWriteCount = currentSnapshot.writeCount
            }
        }

        return record
    }

    override val firstStateRecord: StateRecord
        get() = first

    override fun prependStateRecord(value: StateRecord) {
        @Suppress("UNCHECKED_CAST")
        first = value as ResultRecord<T>
    }

    override val value: T
        get() {
            // Unlike most state objects, the record list of a derived state can change during a
            // read, because reading updates the cache. To account for this, the read observer is
            // notified directly instead of calling `readable` (which sends the read notification).
            // The current record is then used, which does not notify. This allows the read
            // observer to read the value and only update the cache once.
            Snapshot.current.readObserver?.invoke(this)
            // The read observer could advance the snapshot, so fetch the current snapshot again.
            val snapshot = Snapshot.current
            val record = current(first, snapshot)
            @Suppress("UNCHECKED_CAST")
            return currentRecord(record, snapshot, true, calculation).result as T
        }

    override val currentRecord: DerivedState.Record<T>
        get() {
            val snapshot = Snapshot.current
            val record = current(first, snapshot)
            return currentRecord(record, snapshot, false, calculation)
        }

    override fun toString(): String =
        first.withCurrent { "DerivedState(value=${displayValue()})@${hashCode()}" }

    /**
     * Used by the debugger to display the current value of this state object without triggering
     * read observers. Returns `null` when the cached record is not valid.
     *
     * Note: `isValid` may compute `readableHash`, which brings derived-state dependencies up to
     * date.
     */
    @Suppress("unused")
    val debuggerDisplayValue: T?
        @JvmName("getDebuggerDisplayValue")
        get() =
            first.withCurrent {
                @Suppress("UNCHECKED_CAST")
                if (it.isValid(this, Snapshot.current)) it.result as T else null
            }

    /** Text for [toString]: the cached value if it is still valid, otherwise a placeholder. */
    private fun displayValue(): String {
        first.withCurrent {
            if (it.isValid(this, Snapshot.current)) {
                return it.result.toString()
            }
            return "<Not calculated>"
        }
    }
}
320
/**
 * Creates a [State] whose [State.value] is the result of [calculation].
 *
 * The result is cached, so reading [State.value] repeatedly does not run [calculation] again
 * while its inputs are unchanged. Reading [State.value] also reads, in the current [Snapshot],
 * every [State] that [calculation] read. The reader therefore subscribes correctly to the derived
 * state's inputs when the value is read in an observed context such as a [Composable] function.
 *
 * A derived state created without a mutation policy triggers an update every time one of its
 * dependencies changes. To avoid invalidations when the new result is equivalent to the previous
 * one, use the [derivedStateOf] overload that accepts a [SnapshotMutationPolicy].
 *
 * @sample androidx.compose.runtime.samples.DerivedStateSample
 * @param calculation the calculation that produces the value this state object represents.
 */
@StateFactoryMarker
fun <T> derivedStateOf(
    calculation: () -> T,
): State<T> {
    return DerivedSnapshotState(calculation = calculation, policy = null)
}
338
/**
 * Creates a [State] whose [State.value] is the result of [calculation]. [policy] decides whether
 * a recalculated result counts as a change.
 *
 * The result is cached, so reading [State.value] repeatedly does not run [calculation] again
 * while its inputs are unchanged. Reading [State.value] also reads, in the current [Snapshot],
 * every [State] that [calculation] read. The reader therefore subscribes correctly to the derived
 * state's inputs when the value is read in an observed context such as a [Composable] function.
 *
 * @sample androidx.compose.runtime.samples.DerivedStateSample
 * @param policy mutation policy that controls when a change in the [calculation] result triggers
 *   an update.
 * @param calculation the calculation that produces the value this state object represents.
 */
@StateFactoryMarker
fun <T> derivedStateOf(
    policy: SnapshotMutationPolicy<T>,
    calculation: () -> T,
): State<T> {
    val state = DerivedSnapshotState(calculation = calculation, policy = policy)
    return state
}
356
/** Receives callbacks around the recalculation work performed by derived states. */
internal interface DerivedStateObserver {
    /** Invoked before [derivedState] begins a calculation. */
    fun start(derivedState: DerivedState<*>)

    /** Invoked once the calculation announced by [start] has finished, whether or not it succeeded. */
    fun done(derivedState: DerivedState<*>)
}
365
/**
 * Observers registered through [observeDerivedStateRecalculations]. The list is held in a
 * [SnapshotThreadLocal] and used as a stack: pushed and popped at its end.
 */
private val derivedStateObservers = SnapshotThreadLocal<MutableVector<DerivedStateObserver>>()

/** Returns the calling thread's observer list, creating and storing an empty one on first use. */
internal fun derivedStateObservers(): MutableVector<DerivedStateObserver> {
    val existing = derivedStateObservers.get()
    if (existing != null) return existing
    val created = MutableVector<DerivedStateObserver>(0)
    derivedStateObservers.set(created)
    return created
}
371
/**
 * Runs [block] between `start` and `done` callbacks for [derivedState], delivered to every
 * observer currently registered on this thread. `done` is delivered even if [block] throws.
 *
 * @return the value produced by [block].
 */
private inline fun <R> notifyObservers(derivedState: DerivedState<*>, block: () -> R): R {
    val activeObservers = derivedStateObservers()
    activeObservers.forEach { observer -> observer.start(derivedState) }
    try {
        return block()
    } finally {
        activeObservers.forEach { observer -> observer.done(derivedState) }
    }
}
381
/**
 * Observes the recalculations performed by any derived state that is recalculated during the
 * execution of [block].
 *
 * [observer] is appended to the calling thread's observer list for the duration of [block]. It is
 * removed afterwards, even if [block] throws. Removal pops the last entry, which relies on nested
 * calls unwinding in LIFO order.
 *
 * @param observer receives [DerivedStateObserver.start] / [DerivedStateObserver.done] callbacks
 *   for derived-state work performed inside [block].
 * @param block the block of code to observe. Its result is discarded.
 */
internal inline fun <R> observeDerivedStateRecalculations(
    observer: DerivedStateObserver,
    block: () -> R
) {
    val observers = derivedStateObservers()
    // Register before entering `try`. If registration itself failed inside the `try`, the
    // `finally` below would pop an observer that belongs to an enclosing call.
    observers.add(observer)
    try {
        block()
    } finally {
        observers.removeAt(observers.lastIndex)
    }
}
401