1 /*
<lambda>null2  * Copyright 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 @file:JvmName("SnapshotStateKt")
18 @file:JvmMultifileClass
19 
20 package androidx.compose.runtime
21 
22 import androidx.collection.MutableObjectIntMap
23 import androidx.collection.ObjectIntMap
24 import androidx.collection.emptyObjectIntMap
25 import androidx.compose.runtime.collection.MutableVector
26 import androidx.compose.runtime.internal.IntRef
27 import androidx.compose.runtime.internal.SnapshotThreadLocal
28 import androidx.compose.runtime.internal.identityHashCode
29 import androidx.compose.runtime.snapshots.Snapshot
30 import androidx.compose.runtime.snapshots.SnapshotId
31 import androidx.compose.runtime.snapshots.SnapshotIdZero
32 import androidx.compose.runtime.snapshots.StateFactoryMarker
33 import androidx.compose.runtime.snapshots.StateObject
34 import androidx.compose.runtime.snapshots.StateObjectImpl
35 import androidx.compose.runtime.snapshots.StateRecord
36 import androidx.compose.runtime.snapshots.current
37 import androidx.compose.runtime.snapshots.currentSnapshot
38 import androidx.compose.runtime.snapshots.newWritableRecord
39 import androidx.compose.runtime.snapshots.sync
40 import androidx.compose.runtime.snapshots.withCurrent
41 import kotlin.jvm.JvmMultifileClass
42 import kotlin.jvm.JvmName
43 import kotlin.math.min
44 
/**
 * A [State] whose value is computed from one or more other states.
 *
 * @see derivedStateOf
 */
internal interface DerivedState<T> : State<T> {
    /**
     * Mutation policy that controls how changes are handled once the state's dependencies update.
     * When the policy is `null`, an update is triggered regardless of the value produced, and
     * observers are responsible for invalidating correctly.
     */
    val policy: SnapshotMutationPolicy<T>?

    /** The current [Record] of this derived state. */
    val currentRecord: Record<T>

    /** A cached result of the derived state together with the dependencies that produced it. */
    interface Record<T> {
        /**
         * Maps every dependency that was read while producing [value] or [currentValue] to the
         * nesting level at which it was read.
         *
         * Use this map to decide whether a [StateObject] reported in the apply observer set could
         * affect the value of this derived state.
         */
        val dependencies: ObjectIntMap<StateObject>

        /** The value of the derived state, read without notifying read observers. */
        val currentValue: T
    }
}
77 
/** Current derived-state calculation nesting level, lazily created per [SnapshotThreadLocal]. */
private val calculationBlockNestedLevel = SnapshotThreadLocal<IntRef>()

/**
 * Runs [block] with the [IntRef] that stores the calculation nesting level. On first use the
 * reference is created with level 0 and stored in [calculationBlockNestedLevel].
 *
 * @return whatever [block] returns.
 */
private inline fun <T> withCalculationNestedLevel(block: (IntRef) -> T): T {
    val existing = calculationBlockNestedLevel.get()
    val levelRef =
        if (existing != null) {
            existing
        } else {
            IntRef(0).also { created -> calculationBlockNestedLevel.set(created) }
        }
    return block(levelRef)
}
85 
/**
 * [DerivedState] implementation returned by [derivedStateOf].
 *
 * The value produced by [calculation] is cached in [ResultRecord]s, together with the
 * [StateObject]s read while producing it. A cached record is reused while it holds a result and
 * the hash of its dependencies' current records still matches the hash stored when it was written.
 *
 * @param calculation produces the derived value. Every [StateObject] it reads is recorded as a
 *   dependency. Reading this derived state from inside [calculation] is an error.
 * @param policy decides whether a recalculated value is equivalent to the cached one. When `null`,
 *   the recalculated value is always written, even if it equals the cached value.
 */
private class DerivedSnapshotState<T>(
    private val calculation: () -> T,
    override val policy: SnapshotMutationPolicy<T>?
) : StateObjectImpl(), DerivedState<T> {
    /** Head of this state object's record list; replaced by [prependStateRecord]. */
    private var first: ResultRecord<T> = ResultRecord(currentSnapshot().snapshotId)

    /**
     * One cached calculation result of a [DerivedSnapshotState].
     *
     * [validSnapshotId] and [validSnapshotWriteCount] remember the snapshot, and its write count,
     * in which this record was last confirmed valid. [isValid] only recomputes [readableHash] once
     * either of them differs from the snapshot being checked.
     */
    class ResultRecord<T>(snapshotId: SnapshotId) :
        StateRecord(snapshotId), DerivedState.Record<T> {
        companion object {
            /** Sentinel stored in [result] until a value has been calculated. */
            val Unset = Any()
        }

        var validSnapshotId: SnapshotId = SnapshotIdZero
        var validSnapshotWriteCount: Int = 0

        override var dependencies: ObjectIntMap<StateObject> = emptyObjectIntMap()

        /** The calculated value, or [Unset] if no value has been calculated yet. */
        var result: Any? = Unset

        /** The [readableHash] of [dependencies] computed when this record was last written. */
        var resultHash: Int = 0

        override fun assign(value: StateRecord) {
            @Suppress("UNCHECKED_CAST") val other = value as ResultRecord<T>
            dependencies = other.dependencies
            result = other.result
            resultHash = other.resultHash
        }

        override fun create(): StateRecord = create(currentSnapshot().snapshotId)

        override fun create(snapshotId: SnapshotId): StateRecord = ResultRecord<T>(snapshotId)

        /**
         * Returns whether [result] can be reused in [snapshot].
         *
         * The record is valid when it holds a result and either [snapshot] (its id and write
         * count) is unchanged since the last successful check, or the dependencies still hash to
         * [resultHash]. A successful check stores [snapshot]'s id and write count.
         */
        fun isValid(derivedState: DerivedState<*>, snapshot: Snapshot): Boolean {
            val snapshotChanged = sync {
                validSnapshotId != snapshot.snapshotId ||
                    validSnapshotWriteCount != snapshot.writeCount
            }
            val isValid =
                result !== Unset &&
                    (!snapshotChanged || resultHash == readableHash(derivedState, snapshot))

            if (isValid && snapshotChanged) {
                sync {
                    validSnapshotId = snapshot.snapshotId
                    validSnapshotWriteCount = snapshot.writeCount
                }
            }

            return isValid
        }

        /**
         * Combines the identity and snapshot id of the record that every directly read
         * (level 1) dependency exposes in [snapshot] into a single hash. Dependencies that are
         * themselves [DerivedSnapshotState]s are brought up to date first.
         */
        fun readableHash(derivedState: DerivedState<*>, snapshot: Snapshot): Int {
            var hash = 7
            val dependencies = sync { dependencies }
            if (dependencies.isNotEmpty()) {
                notifyObservers(derivedState) {
                    dependencies.forEach { stateObject, readLevel ->
                        if (readLevel != 1) {
                            return@forEach
                        }

                        // Find the first record without triggering an observer read.
                        val record =
                            if (stateObject is DerivedSnapshotState<*>) {
                                // Eagerly access parent derived states without recording the
                                // read. This guarantees that derived states among the
                                // dependencies are recalculated and reflect their latest values.
                                stateObject.current(snapshot)
                            } else {
                                current(stateObject.firstStateRecord, snapshot)
                            }

                        hash = 31 * hash + identityHashCode(record)
                        hash = 31 * hash + record.snapshotId.hashCode()
                    }
                }
            }
            return hash
        }

        override val currentValue: T
            @Suppress("UNCHECKED_CAST") get() = result as T
    }

    /**
     * Get current record in snapshot. Forces recalculation if record is invalid to refresh state
     * value.
     *
     * @return latest state record for the derived state.
     */
    fun current(snapshot: Snapshot): StateRecord =
        currentRecord(current(first, snapshot), snapshot, false, calculation)

    /**
     * Returns [readable] if it is still valid in [snapshot]; otherwise re-runs [calculation] and
     * returns the record holding the fresh result.
     *
     * @param forceDependencyReads when [readable] is reused, replay the reads of its dependencies
     *   to [snapshot]'s read observer so that observers subscribe to them.
     */
    private fun currentRecord(
        readable: ResultRecord<T>,
        snapshot: Snapshot,
        forceDependencyReads: Boolean,
        calculation: () -> T
    ): ResultRecord<T> {
        if (readable.isValid(this, snapshot)) {
            // If the dependency is not recalculated, emulate nested state reads
            // for correct invalidation later
            if (forceDependencyReads) {
                notifyObservers(this) {
                    val dependencies = readable.dependencies
                    withCalculationNestedLevel { calculationLevelRef ->
                        val invalidationNestedLevel = calculationLevelRef.element
                        try {
                            dependencies.forEach { dependency, nestedLevel ->
                                calculationLevelRef.element = invalidationNestedLevel + nestedLevel
                                snapshot.readObserver?.invoke(dependency)
                            }
                        } finally {
                            // Restore the level even if a read observer throws, so the shared
                            // nesting counter never leaks an elevated value.
                            calculationLevelRef.element = invalidationNestedLevel
                        }
                    }
                }
            }
            return readable
        }

        val newDependencies = MutableObjectIntMap<StateObject>()
        val result = withCalculationNestedLevel { calculationLevelRef ->
            val nestedCalculationLevel = calculationLevelRef.element
            notifyObservers(this) {
                calculationLevelRef.element = nestedCalculationLevel + 1
                try {
                    Snapshot.observe(
                        {
                            if (it === this) error("A derived state calculation cannot read itself")
                            if (it is StateObject) {
                                // Record the shallowest level at which each dependency was read,
                                // relative to this calculation.
                                val readNestedLevel = calculationLevelRef.element
                                newDependencies[it] =
                                    min(
                                        readNestedLevel - nestedCalculationLevel,
                                        newDependencies.getOrDefault(it, Int.MAX_VALUE)
                                    )
                            }
                        },
                        null,
                        calculation
                    )
                } finally {
                    // Restore the level even if the calculation throws. Otherwise the
                    // `element == 0` check below would never succeed again for later
                    // top-level calculations.
                    calculationLevelRef.element = nestedCalculationLevel
                }
            }
        }

        val record = sync {
            val currentSnapshot = Snapshot.current

            if (
                readable.result !== ResultRecord.Unset &&
                    @Suppress("UNCHECKED_CAST") policy?.equivalent(result, readable.result as T) ==
                        true
            ) {
                // The policy considers the new value equivalent: keep the existing record and
                // only refresh what it depends on.
                readable.dependencies = newDependencies
                readable.resultHash = readable.readableHash(this, currentSnapshot)
                readable
            } else {
                val writable = first.newWritableRecord(this, currentSnapshot)
                writable.dependencies = newDependencies
                writable.resultHash = writable.readableHash(this, currentSnapshot)
                writable.result = result
                writable
            }
        }

        // Only the outermost derived-state calculation (nesting level 0) notifies object
        // initialization and caches the snapshot the record is valid in.
        if (calculationBlockNestedLevel.get()?.element == 0) {
            Snapshot.notifyObjectsInitialized()

            sync {
                val currentSnapshot = Snapshot.current
                record.validSnapshotId = currentSnapshot.snapshotId
                record.validSnapshotWriteCount = currentSnapshot.writeCount
            }
        }

        return record
    }

    override val firstStateRecord: StateRecord
        get() = first

    override fun prependStateRecord(value: StateRecord) {
        @Suppress("UNCHECKED_CAST")
        first = value as ResultRecord<T>
    }

    override val value: T
        get() {
            // Unlike most state objects, the record list of a derived state can change during a
            // read, because reading updates the cache. So instead of calling `readable`, which
            // sends the read notification, the read observer is notified directly, and then the
            // current record is used, which does not notify. This allows the read observer to
            // read the value while the cache is updated only once.
            Snapshot.current.readObserver?.invoke(this)
            // Read observer could advance the snapshot, so get current snapshot again
            val snapshot = Snapshot.current
            val record = current(first, snapshot)
            @Suppress("UNCHECKED_CAST")
            return currentRecord(record, snapshot, true, calculation).result as T
        }

    override val currentRecord: DerivedState.Record<T>
        get() {
            val snapshot = Snapshot.current
            val record = current(first, snapshot)
            return currentRecord(record, snapshot, false, calculation)
        }

    override fun toString(): String =
        first.withCurrent { "DerivedState(value=${displayValue()})@${hashCode()}" }

    /**
     * A function used by the debugger to display the value of the current value of the mutable
     * state object without triggering read observers.
     */
    @Suppress("unused")
    val debuggerDisplayValue: T?
        @JvmName("getDebuggerDisplayValue")
        get() =
            first.withCurrent {
                @Suppress("UNCHECKED_CAST")
                if (it.isValid(this, Snapshot.current)) it.result as T else null
            }

    /** Text form of the cached value for [toString], or a placeholder if it is not valid. */
    private fun displayValue(): String {
        first.withCurrent {
            if (it.isValid(this, Snapshot.current)) {
                return it.result.toString()
            }
            return "<Not calculated>"
        }
    }
}
320 
/**
 * Creates a [State] object whose [State.value] is the result of [calculation].
 *
 * The result of the calculation is cached, so reading [State.value] repeatedly does not execute
 * [calculation] multiple times. Reading [State.value] does, however, read every [State] object that
 * [calculation] read in the current [Snapshot]. This makes reads from an observed context, such as
 * a [Composable] function, subscribe correctly to the derived state's dependencies.
 *
 * A derived state without a mutation policy triggers an update on every dependency change. To avoid
 * updates when the recalculated value is equivalent to the previous one, pass a suitable
 * [SnapshotMutationPolicy] to the [derivedStateOf] overload that accepts one.
 *
 * @sample androidx.compose.runtime.samples.DerivedStateSample
 * @param calculation the calculation to create the value this state object represents.
 */
@StateFactoryMarker
fun <T> derivedStateOf(
    calculation: () -> T,
): State<T> {
    return DerivedSnapshotState(calculation = calculation, policy = null)
}
338 
/**
 * Creates a [State] object whose [State.value] is the result of [calculation].
 *
 * The result of the calculation is cached, so reading [State.value] repeatedly does not execute
 * [calculation] multiple times. Reading [State.value] does, however, read every [State] object that
 * [calculation] read in the current [Snapshot]. This makes reads from an observed context, such as
 * a [Composable] function, subscribe correctly to the derived state's dependencies.
 *
 * @sample androidx.compose.runtime.samples.DerivedStateSample
 * @param policy mutation policy to control when changes to the [calculation] result trigger update.
 * @param calculation the calculation to create the value this state object represents.
 */
@StateFactoryMarker
fun <T> derivedStateOf(
    policy: SnapshotMutationPolicy<T>,
    calculation: () -> T,
): State<T> {
    return DerivedSnapshotState(calculation = calculation, policy = policy)
}
356 
/** Listener notified around the recalculations that derived states perform. */
internal interface DerivedStateObserver {
    /** Invoked right before [derivedState] starts a calculation. */
    fun start(derivedState: DerivedState<*>)

    /** Invoked once the calculation announced through [start] has completed. */
    fun done(derivedState: DerivedState<*>)
}
365 
private val derivedStateObservers = SnapshotThreadLocal<MutableVector<DerivedStateObserver>>()

/**
 * Returns the [DerivedStateObserver]s stored in [derivedStateObservers]. On first access an empty
 * vector is created, stored and returned.
 */
internal fun derivedStateObservers(): MutableVector<DerivedStateObserver> {
    val existing = derivedStateObservers.get()
    if (existing != null) return existing
    val created = MutableVector<DerivedStateObserver>(0)
    derivedStateObservers.set(created)
    return created
}
371 
/**
 * Runs [block] bracketed by [DerivedStateObserver.start] and [DerivedStateObserver.done] calls for
 * [derivedState] on every registered observer. The `done` callbacks run even if [block] throws.
 *
 * @return the result of [block].
 */
private inline fun <R> notifyObservers(derivedState: DerivedState<*>, block: () -> R): R {
    val registered = derivedStateObservers()
    registered.forEach { observer -> observer.start(derivedState) }
    try {
        return block()
    } finally {
        registered.forEach { observer -> observer.done(derivedState) }
    }
}
381 
/**
 * Observe the recalculations performed by any derived state that is recalculated during the
 * execution of [block].
 *
 * [observer] is appended to the observer vector returned by [derivedStateObservers] for the
 * duration of [block], and the last entry is removed afterwards, even if [block] throws.
 *
 * @param observer called for every calculation of a derived state in the [block].
 * @param block the block of code to observe.
 */
internal inline fun <R> observeDerivedStateRecalculations(
    observer: DerivedStateObserver,
    block: () -> R
) {
    val observers = derivedStateObservers()
    // Register before entering `try`: if registration itself fails, the `finally` must not pop an
    // observer that belongs to an enclosing call (or fail on an empty vector and mask the error).
    observers.add(observer)
    try {
        block()
    } finally {
        observers.removeAt(observers.lastIndex)
    }
}
401