1 /*
 * Copyright 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.compose.runtime.snapshots
18 
19 import androidx.collection.MutableObjectIntMap
20 import androidx.collection.MutableScatterMap
21 import androidx.collection.MutableScatterSet
22 import androidx.compose.runtime.DerivedState
23 import androidx.compose.runtime.DerivedStateObserver
24 import androidx.compose.runtime.TestOnly
25 import androidx.compose.runtime.collection.ScopeMap
26 import androidx.compose.runtime.collection.fastForEach
27 import androidx.compose.runtime.collection.mutableVectorOf
28 import androidx.compose.runtime.composeRuntimeError
29 import androidx.compose.runtime.internal.AtomicReference
30 import androidx.compose.runtime.internal.currentThreadId
31 import androidx.compose.runtime.internal.currentThreadName
32 import androidx.compose.runtime.observeDerivedStateRecalculations
33 import androidx.compose.runtime.platform.makeSynchronizedObject
34 import androidx.compose.runtime.platform.synchronized
35 import androidx.compose.runtime.requirePrecondition
36 import androidx.compose.runtime.structuralEqualityPolicy
37 
/**
 * Helper class to efficiently observe snapshot state reads. See [observeReads] for more details.
 *
 * Reads are grouped by the `onValueChangedForScope` callback passed to [observeReads]: each
 * distinct callback instance gets its own [ObservedScopeMap]. After [start] registers an apply
 * observer, every applied change set is matched against the recorded reads and the affected
 * scopes are passed to their callbacks, with delivery scheduled through [onChangedExecutor].
 *
 * NOTE: This class is not thread-safe, so implementations should not reuse observer between
 * different threads to avoid race conditions. Reads must be observed on a single thread (this is
 * checked in [observeReads]); apply notifications may arrive on other threads and the shared
 * scope maps are guarded by an internal lock.
 *
 * @param onChangedExecutor schedules notification delivery. It is given a callback that, when
 *   run, dispatches the pending invalidations to the registered `onValueChangedForScope`
 *   callbacks.
 */
@Suppress("NotCloseable") // we can't implement AutoCloseable from commonMain
class SnapshotStateObserver(private val onChangedExecutor: (callback: () -> Unit) -> Unit) {
    /**
     * Applied change sets that have not yet been matched against observed reads. Encoded as
     * described in [addChanges] to avoid allocating in the common single-set case.
     */
    private val pendingChanges = AtomicReference<Any?>(null)

    /**
     * `true` while [sendNotifications] is dispatching callbacks. Only read and written while
     * holding [observedScopeMapsLock].
     */
    private var sendingNotifications = false

    /**
     * Registered with [Snapshot.registerApplyObserver] in [start]: queues the applied changes and,
     * if any of them invalidate an observed scope, schedules notification delivery.
     */
    private val applyObserver: (Set<Any>, Snapshot) -> Unit = { applied, _ ->
        addChanges(applied)
        if (drainChanges()) sendNotifications()
    }

    /**
     * Drain the pending changes from the pending changes queue invalidating any scope maps that
     * contain objects in any of the sets. Return true if changes were found.
     *
     * This immediately returns with false if notifications are already being sent. It is the
     * responsibility of any user of this function to ensure that the queue is re-checked after
     * dispatching notifications.
     */
    private fun drainChanges(): Boolean {
        // Don't modify the scope maps while notifications are being sent either by the caller or
        // on another thread
        if (synchronized(observedScopeMapsLock) { sendingNotifications }) return false

        // Remove all pending changes and return true if any of the objects are observed
        var hasValues = false
        while (true) {
            val notifications = removeChanges() ?: return hasValues
            forEachScopeMap { scopeMap ->
                hasValues = scopeMap.recordInvalidation(notifications) || hasValues
            }
        }
    }

    /**
     * Send any pending notifications. Uses [onChangedExecutor] to schedule this work.
     *
     * The `onValueChangedForScope` callbacks are invoked while holding [observedScopeMapsLock].
     * After each dispatch pass the queue is drained again, and another pass runs if that drain
     * found newly invalidated scopes.
     *
     * This method should only be called if, and only if, a call to `drainChanges()` returns `true`.
     */
    private fun sendNotifications() {
        onChangedExecutor {
            while (true) {
                synchronized(observedScopeMapsLock) {
                    if (!sendingNotifications) {
                        sendingNotifications = true
                        try {
                            observedScopeMaps.forEach { scopeMap ->
                                scopeMap.notifyInvalidatedScopes()
                            }
                        } finally {
                            sendingNotifications = false
                        }
                    }
                }

                // If any changes arrived while we were notifying, send the new changes.
                if (!drainChanges()) break
            }
        }
    }

    /**
     * Add changes to the changes queue. This uses an atomic reference as a queue to minimize the
     * number of allocations required in the normal case. If, for example, only one set is added to
     * the queue, the set itself is the atomic reference. If the queue is empty the reference is
     * null. Only if there are more than one set added to the queue is an allocation required, then
     * the atomic reference is a list containing all the sets in the queue. Given the size of the
     * queue, the type of object referenced is:
     * - 0 sets -> `null`
     * - 1 set -> `Set<Any?>`
     * - 2 or more sets -> `List<Set<Any?>>`
     *
     * The update is retried with compare-and-set until it succeeds, so concurrent callers do not
     * lose each other's sets.
     */
    private fun addChanges(set: Set<Any>) {
        while (true) {
            val old = pendingChanges.get()
            val new =
                when (old) {
                    null -> set
                    is Set<*> -> listOf(old, set)
                    is List<*> -> old + listOf(set)
                    else -> report()
                }
            if (pendingChanges.compareAndSet(old, new)) break
        }
    }

    /**
     * Remove a set of changes from the change queue. See [addChanges] for a description of how this
     * queue works.
     *
     * @return the oldest queued set, or `null` if the queue is empty.
     */
    @Suppress("UNCHECKED_CAST")
    private fun removeChanges(): Set<Any>? {
        while (true) {
            val old = pendingChanges.get()
            var result: Set<Any>?
            var new: Any?
            when (old) {
                null -> return null // The queue is empty
                is Set<*> -> {
                    result = old as Set<Any>?
                    new = null
                }
                is List<*> -> {
                    result = old[0] as Set<Any>?
                    new =
                        when {
                            // Shrinking to one entry collapses back to the bare-set encoding.
                            old.size == 2 -> old[1]
                            old.size > 2 -> old.subList(1, old.size)
                            // Not expected: queued lists always contain at least two sets.
                            else -> null
                        }
                }
                else -> report()
            }
            if (pendingChanges.compareAndSet(old, new)) {
                return result
            }
        }
    }

    /** Fails on a queue entry that does not match the encoding described in [addChanges]. */
    private fun report(): Nothing = composeRuntimeError("Unexpected notification")

    /** The observer used by this [SnapshotStateObserver] during [observeReads]. */
    private val readObserver: (Any) -> Unit = { state ->
        if (!isPaused) {
            // currentMap is assigned in observeReads before reads are observed through this lambda.
            synchronized(observedScopeMapsLock) { currentMap!!.recordRead(state) }
        }
    }

    /**
     * List of all [ObservedScopeMap]s. When [observeReads] is called, there will be a
     * [ObservedScopeMap] associated with its [ObservedScopeMap.onChanged] callback in this list.
     * Entries are removed only by [clear] (per scope) and [clearIf] once a map has no scope
     * observations left.
     */
    private val observedScopeMaps = mutableVectorOf<ObservedScopeMap>()

    /** Guards [observedScopeMaps], the maps' contents and [sendingNotifications]. */
    private val observedScopeMapsLock = makeSynchronizedObject()

    /**
     * Helper for synchronized iteration over [observedScopeMaps]. All observed reads should happen
     * on the same thread, but snapshots can be applied on a different thread, requiring
     * synchronization.
     */
    private inline fun forEachScopeMap(block: (ObservedScopeMap) -> Unit) {
        synchronized(observedScopeMapsLock) { observedScopeMaps.forEach(block) }
    }

    /** Removes, under [observedScopeMapsLock], every scope map for which [block] returns `true`. */
    private inline fun removeScopeMapIf(block: (ObservedScopeMap) -> Boolean) {
        synchronized(observedScopeMapsLock) { observedScopeMaps.removeIf(block) }
    }

    /**
     * Handle for the apply observer registered in [start]; `null` when not started or after
     * [stop].
     */
    private var applyUnsubscribe: ObserverHandle? = null

    /**
     * `true` when [withNoObservations] is called and read observations should not be considered
     * invalidations for the current scope.
     */
    private var isPaused = false

    /**
     * The [ObservedScopeMap] that should be added to when a model is read during [observeReads].
     */
    private var currentMap: ObservedScopeMap? = null

    /** Thread id that has set the [currentMap]; `-1` when no observation is in progress. */
    private var currentMapThreadId = -1L

    /**
     * Executes [block], observing state object reads during its execution.
     *
     * The [scope] and [onValueChangedForScope] are associated with any values that are read so that
     * when those values change, [onValueChangedForScope] will be called with the [scope] parameter.
     *
     * Observation can be paused with [Snapshot.withoutReadObservation].
     *
     * Nested calls are supported, but only on the thread that started the outermost observation;
     * a nested call from another thread fails the precondition check. The previous observation
     * state is restored when [block] returns or throws.
     *
     * @param scope value associated with the observed scope.
     * @param onValueChangedForScope is called with the [scope] when value read within [block] has
     *   been changed. For repeated observations, it is more performant to pass the same instance of
     *   the callback, as [observedScopeMaps] grows with each new callback instance.
     * @param block to observe reads within.
     */
    fun <T : Any> observeReads(scope: T, onValueChangedForScope: (T) -> Unit, block: () -> Unit) {
        val scopeMap = synchronized(observedScopeMapsLock) { ensureMap(onValueChangedForScope) }

        val oldPaused = isPaused
        val oldMap = currentMap
        val oldThreadId = currentMapThreadId

        if (oldThreadId != -1L) {
            requirePrecondition(oldThreadId == currentThreadId()) {
                "Detected multithreaded access to SnapshotStateObserver: " +
                    "previousThreadId=$oldThreadId, " +
                    "currentThread={id=${currentThreadId()}, name=${currentThreadName()}}. " +
                    "Note that observation on multiple threads in layout/draw is not supported. " +
                    "Make sure your measure/layout/draw for each Owner (AndroidComposeView) " +
                    "is executed on the same thread."
            }
        }

        try {
            isPaused = false
            currentMap = scopeMap
            currentMapThreadId = currentThreadId()

            scopeMap.observe(scope, readObserver, block)
        } finally {
            currentMap = oldMap
            isPaused = oldPaused
            currentMapThreadId = oldThreadId
        }
    }

    /**
     * Stops observing state object reads while executing [block]. State object reads may be
     * restarted by calling [observeReads] inside [block].
     */
    @Deprecated(
        "Replace with Snapshot.withoutReadObservation()",
        ReplaceWith(
            "Snapshot.withoutReadObservation(block)",
            "androidx.compose.runtime.snapshots.Snapshot"
        )
    )
    fun withNoObservations(block: () -> Unit) {
        val oldPaused = isPaused
        isPaused = true
        try {
            block()
        } finally {
            isPaused = oldPaused
        }
    }

    /**
     * Clears all state read observations for a given [scope]. This clears values for all
     * `onValueChangedForScope` callbacks passed in [observeReads]. Scope maps left without any
     * observations are dropped.
     */
    fun clear(scope: Any) {
        removeScopeMapIf {
            it.clearScopeObservations(scope)
            !it.hasScopeObservations()
        }
    }

    /**
     * Remove observations using [predicate] to identify scopes to be removed. This is used when a
     * scope is no longer in the hierarchy and should not receive any callbacks. Scope maps left
     * without any observations are dropped.
     */
    fun clearIf(predicate: (scope: Any) -> Boolean) {
        removeScopeMapIf { scopeMap ->
            scopeMap.removeScopeIf(predicate)
            !scopeMap.hasScopeObservations()
        }
    }

    /**
     * Starts watching for state commits.
     *
     * If this observer was already started, the earlier registration is disposed after the new
     * one is made, so repeated calls never leave an apply observer that [stop] cannot release.
     */
    fun start() {
        val previous = applyUnsubscribe
        applyUnsubscribe = Snapshot.registerApplyObserver(applyObserver)
        previous?.dispose()
    }

    /**
     * Stops watching for state commits. Calling this when not started, or more than once, has no
     * effect.
     */
    fun stop() {
        applyUnsubscribe?.dispose()
        applyUnsubscribe = null
    }

    /**
     * This method is only used for testing. It notifies that [changes] have been made on
     * [snapshot].
     */
    @TestOnly
    fun notifyChanges(changes: Set<Any>, snapshot: Snapshot) {
        applyObserver(changes, snapshot)
    }

    /** Remove all observations. */
    fun clear() {
        forEachScopeMap { scopeMap -> scopeMap.clear() }
    }

    /**
     * Returns the [ObservedScopeMap] within [observedScopeMaps] associated with [onChanged] or a
     * newly-inserted one if it doesn't exist. Callbacks are matched by identity (`===`).
     *
     * Must be called inside a synchronized block.
     */
    @Suppress("UNCHECKED_CAST")
    private fun <T : Any> ensureMap(onChanged: (T) -> Unit): ObservedScopeMap {
        val scopeMap = observedScopeMaps.firstOrNull { it.onChanged === onChanged }
        if (scopeMap == null) {
            val map = ObservedScopeMap(onChanged as ((Any) -> Unit))
            observedScopeMaps += map
            return map
        }
        return scopeMap
    }

    /** Connects observed values to scopes for a single [onChanged] callback. */
    @Suppress("UNCHECKED_CAST")
    private class ObservedScopeMap(val onChanged: (Any) -> Unit) {
        /** Currently observed scope. */
        private var currentScope: Any? = null

        /**
         * key: State reads observed in current scope. value: [currentToken] at the time the read
         * was observed in.
         */
        private var currentScopeReads: MutableObjectIntMap<Any>? = null

        /**
         * Token for current observation cycle; usually corresponds to snapshot ID at the time when
         * observation started. `-1` when no observation is in progress on this map.
         */
        private var currentToken: Int = -1

        /** Values that have been read during the scope's [SnapshotStateObserver.observeReads]. */
        private val valueToScopes = ScopeMap<Any, Any>()

        /** Reverse index (scope -> values) for faster scope invalidation. */
        private val scopeToValues: MutableScatterMap<Any, MutableObjectIntMap<Any>> =
            MutableScatterMap()

        /** Scopes that were invalidated during previous apply step. */
        private val invalidated = MutableScatterSet<Any>()

        /** Reusable vector for re-recording states inside [recordInvalidation] */
        private val statesToReread = mutableVectorOf<DerivedState<*>>()

        // derived state handling

        /** Observer for derived state recalculation */
        val derivedStateObserver =
            object : DerivedStateObserver {
                override fun start(derivedState: DerivedState<*>) {
                    deriveStateScopeCount++
                }

                override fun done(derivedState: DerivedState<*>) {
                    deriveStateScopeCount--
                }
            }

        /**
         * Counter for skipping reads inside derived states. If count is > 0, read happens inside a
         * derived state. Reads for derived states are captured separately through
         * [DerivedState.Record.dependencies].
         */
        private var deriveStateScopeCount = 0

        /** Invalidation index from state objects to derived states reading them. */
        private val dependencyToDerivedStates = ScopeMap<Any, DerivedState<*>>()

        /** Last derived state value recorded during read. */
        private val recordedDerivedStateValues = HashMap<DerivedState<*>, Any?>()

        /**
         * Records that [value] was read in [currentScope] with [currentToken], lazily creating and
         * registering the scope's read map on the first read.
         */
        fun recordRead(value: Any) {
            val scope = currentScope!!
            recordRead(
                value = value,
                currentToken = currentToken,
                currentScope = scope,
                recordedValues =
                    currentScopeReads
                        ?: MutableObjectIntMap<Any>().also {
                            currentScopeReads = it
                            scopeToValues[scope] = it
                        }
            )
        }

        /**
         * Record that [value] was read in [currentScope].
         *
         * Reads made while a derived state is being calculated are ignored. A derived state read
         * with a new token has its current value and dependency index refreshed. A value read for
         * the first time in this scope is added to [valueToScopes].
         */
        private fun recordRead(
            value: Any,
            currentToken: Int,
            currentScope: Any,
            recordedValues: MutableObjectIntMap<Any>
        ) {
            if (deriveStateScopeCount > 0) {
                // Reads coming from derivedStateOf block
                return
            }

            val previousToken = recordedValues.put(value, currentToken, -1)
            if (value is DerivedState<*> && previousToken != currentToken) {
                val record = value.currentRecord
                // re-read the value before removing dependencies, in case the new value wasn't read
                recordedDerivedStateValues[value] = record.currentValue

                val dependencies = record.dependencies
                val dependencyToDerivedStates = dependencyToDerivedStates

                dependencyToDerivedStates.removeScope(value)
                dependencies.forEachKey { dependency ->
                    if (dependency is StateObjectImpl) {
                        dependency.recordReadIn(ReaderKind.SnapshotStateObserver)
                    }
                    dependencyToDerivedStates.add(dependency, value)
                }
            }

            if (previousToken == -1) {
                if (value is StateObjectImpl) {
                    value.recordReadIn(ReaderKind.SnapshotStateObserver)
                }
                valueToScopes.add(value, currentScope)
            }
        }

        /**
         * Setup new scope for state read observation, observe them, and cleanup afterwards.
         *
         * Reads made while [block] runs are reported to [readObserver], with [derivedStateObserver]
         * tracking derived state recalculations. When [block] completes normally, previous reads
         * of [scope] that were not repeated during this observation are removed.
         *
         * The previous scope, reads and token are restored even if [block] throws, so a failed
         * observation cannot leave a stale token behind for later observations. Obsolete reads are
         * only cleared on normal completion, because after an exception the recorded reads are
         * incomplete.
         */
        fun observe(scope: Any, readObserver: (Any) -> Unit, block: () -> Unit) {
            val previousScope = currentScope
            val previousReads = currentScopeReads
            val previousToken = currentToken

            try {
                currentScope = scope
                currentScopeReads = scopeToValues[scope]
                if (currentToken == -1) {
                    // Outermost observation on this map starts a new token; nested observations
                    // keep the outer observation's token.
                    currentToken = currentSnapshot().snapshotId.hashCode()
                }

                observeDerivedStateRecalculations(derivedStateObserver) {
                    Snapshot.observe(readObserver, null, block)
                }

                clearObsoleteStateReads(scope)
            } finally {
                currentScope = previousScope
                currentScopeReads = previousReads
                currentToken = previousToken
            }
        }

        /** Removes reads of [scope] whose token differs from the current observation's token. */
        private fun clearObsoleteStateReads(scope: Any) {
            val currentToken = currentToken
            currentScopeReads?.removeIf { value, token ->
                (token != currentToken).also { willRemove ->
                    if (willRemove) {
                        removeObservation(scope, value)
                    }
                }
            }
        }

        /** Clear observations for [scope]. */
        fun clearScopeObservations(scope: Any) {
            val recordedValues = scopeToValues.remove(scope) ?: return
            recordedValues.forEach { value, _ -> removeObservation(scope, value) }
        }

        /** Remove observations in scopes matching [predicate]. */
        fun removeScopeIf(predicate: (scope: Any) -> Boolean) {
            scopeToValues.removeIf { scope, valueSet ->
                predicate(scope).also { willRemove ->
                    if (willRemove) {
                        valueSet.forEach { value, _ -> removeObservation(scope, value) }
                    }
                }
            }
        }

        /** Returns `true` if any scope still has recorded reads in this map. */
        fun hasScopeObservations(): Boolean = scopeToValues.isNotEmpty()

        /**
         * Unlinks [value] from [scope]. When a derived state is no longer observed by any scope,
         * its dependency index entries and recorded value are dropped as well.
         */
        private fun removeObservation(scope: Any, value: Any) {
            valueToScopes.remove(value, scope)
            if (value is DerivedState<*> && value !in valueToScopes) {
                dependencyToDerivedStates.removeScope(value)
                recordedDerivedStateValues.remove(value)
            }
        }

        /** Clear all observations. */
        fun clear() {
            valueToScopes.clear()
            scopeToValues.clear()
            dependencyToDerivedStates.clear()
            recordedDerivedStateValues.clear()
        }

        /**
         * Record scope invalidation for given set of values.
         *
         * A scope is added to [invalidated] when it read a changed value directly, or read a
         * derived state depending on a changed value whose current value is not equivalent (per
         * the derived state's policy, defaulting to structural equality) to the value recorded
         * at read time. Derived states whose value is equivalent are re-read so that their
         * recorded dependencies are refreshed. [StateObjectImpl] values never read by a
         * SnapshotStateObserver are skipped.
         *
         * @return whether any scopes observe changed values
         */
        fun recordInvalidation(changes: Set<Any>): Boolean {
            var hasValues = false

            val dependencyToDerivedStates = dependencyToDerivedStates
            val recordedDerivedStateValues = recordedDerivedStateValues
            val valueToScopes = valueToScopes
            val invalidated = invalidated

            changes.fastForEach { value ->
                if (value is StateObjectImpl && !value.isReadIn(ReaderKind.SnapshotStateObserver)) {
                    return@fastForEach
                }

                if (value in dependencyToDerivedStates) {
                    // Find derived state that is invalidated by this change
                    dependencyToDerivedStates.forEachScopeOf(value) { derivedState ->
                        derivedState as DerivedState<Any?>
                        val previousValue = recordedDerivedStateValues[derivedState]
                        val policy = derivedState.policy ?: structuralEqualityPolicy()

                        // Invalidate only if currentValue is different than observed on read
                        if (
                            !policy.equivalent(
                                derivedState.currentRecord.currentValue,
                                previousValue
                            )
                        ) {
                            valueToScopes.forEachScopeOf(derivedState) { scope ->
                                invalidated.add(scope)
                                hasValues = true
                            }
                        } else {
                            // Re-read state to ensure its dependencies are up-to-date
                            statesToReread.add(derivedState)
                        }
                    }
                }

                valueToScopes.forEachScopeOf(value) { scope ->
                    invalidated.add(scope)
                    hasValues = true
                }
            }

            if (statesToReread.isNotEmpty()) {
                statesToReread.forEach { rereadDerivedState(it) }
                statesToReread.clear()
            }

            return hasValues
        }

        /**
         * Records a fresh read of [derivedState] in every scope that observes it, using a token
         * taken from the current snapshot, so its recorded value and dependencies are refreshed
         * when that token differs from the previously recorded one.
         */
        fun rereadDerivedState(derivedState: DerivedState<*>) {
            val scopeToValues = scopeToValues
            val token = currentSnapshot().snapshotId.hashCode()

            valueToScopes.forEachScopeOf(derivedState) { scope ->
                recordRead(
                    value = derivedState,
                    currentToken = token,
                    currentScope = scope,
                    recordedValues =
                        scopeToValues[scope]
                            ?: MutableObjectIntMap<Any>().also { scopeToValues[scope] = it }
                )
            }
        }

        /** Call [onChanged] for previously invalidated scopes, then forget them. */
        fun notifyInvalidatedScopes() {
            val invalidated = invalidated
            invalidated.forEach(onChanged)
            invalidated.clear()
        }
    }
}
597