/*
 * Copyright 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.compose.runtime.snapshots

import androidx.collection.MutableObjectIntMap
import androidx.collection.MutableScatterMap
import androidx.collection.MutableScatterSet
import androidx.compose.runtime.DerivedState
import androidx.compose.runtime.DerivedStateObserver
import androidx.compose.runtime.TestOnly
import androidx.compose.runtime.collection.ScopeMap
import androidx.compose.runtime.collection.fastForEach
import androidx.compose.runtime.collection.mutableVectorOf
import androidx.compose.runtime.composeRuntimeError
import androidx.compose.runtime.internal.AtomicReference
import androidx.compose.runtime.internal.currentThreadId
import androidx.compose.runtime.internal.currentThreadName
import androidx.compose.runtime.observeDerivedStateRecalculations
import androidx.compose.runtime.platform.makeSynchronizedObject
import androidx.compose.runtime.platform.synchronized
import androidx.compose.runtime.requirePrecondition
import androidx.compose.runtime.structuralEqualityPolicy

/**
 * Helper class to efficiently observe snapshot state reads. See [observeReads] for more details.
 *
 * NOTE: This class is not thread-safe. All calls to [observeReads] on one instance must happen on
 * the same thread; [observeReads] detects and rejects nested observation from a different thread.
 * Snapshots may be applied on other threads: the apply observer registered by [start] only touches
 * the observed scope maps while holding an internal lock.
 *
 * @param onChangedExecutor schedules the delivery of change notifications. It receives a callback
 *   that invokes the `onValueChangedForScope` callbacks passed to [observeReads] for every
 *   invalidated scope.
 */
@Suppress("NotCloseable") // we can't implement AutoCloseable from commonMain
class SnapshotStateObserver(private val onChangedExecutor: (callback: () -> Unit) -> Unit) {
    /**
     * Lock-free queue of applied change sets that have not been drained yet. See [addChanges] for
     * how the queue is encoded in a single atomic reference.
     */
    private val pendingChanges = AtomicReference<Any?>(null)

    /**
     * `true` while [sendNotifications] is dispatching callbacks. Always read and written while
     * holding [observedScopeMapsLock].
     */
    private var sendingNotifications = false

    /**
     * Apply observer registered by [start]. It queues the applied set and, if any observed scope
     * was invalidated by the queued changes, schedules notifications.
     */
    private val applyObserver: (Set<Any>, Snapshot) -> Unit = { applied, _ ->
        addChanges(applied)
        if (drainChanges()) sendNotifications()
    }

    /**
     * Drain the pending changes from the pending changes queue, invalidating any scope maps that
     * contain objects in any of the sets. Returns true if changes were found.
     *
     * This immediately returns with false if notifications are already being sent. It is the
     * responsibility of any user of this function to ensure that the queue is re-checked after
     * dispatching notifications.
     */
    private fun drainChanges(): Boolean {
        // Don't modify the scope maps while notifications are being sent either by the caller or
        // on another thread
        if (synchronized(observedScopeMapsLock) { sendingNotifications }) return false

        // Remove all pending changes and return true if any of the objects are observed
        var hasValues = false
        while (true) {
            val notifications = removeChanges() ?: return hasValues
            forEachScopeMap { scopeMap ->
                hasValues = scopeMap.recordInvalidation(notifications) || hasValues
            }
        }
    }

    /**
     * Send any pending notifications. Uses [onChangedExecutor] to schedule this work.
     *
     * Call this method only when a call to `drainChanges()` has returned `true`.
     */
    private fun sendNotifications() {
        onChangedExecutor {
            while (true) {
                synchronized(observedScopeMapsLock) {
                    if (!sendingNotifications) {
                        sendingNotifications = true
                        try {
                            observedScopeMaps.forEach { scopeMap ->
                                scopeMap.notifyInvalidatedScopes()
                            }
                        } finally {
                            sendingNotifications = false
                        }
                    }
                }

                // If any changes arrived while we were notifying, send the new changes.
                if (!drainChanges()) break
            }
        }
    }

    /**
     * Add changes to the changes queue. This uses an atomic reference as a queue to minimize the
     * number of allocations required in the normal case.
     *
     * The type of the referenced object depends on the queue size:
     * - 0 sets -> `null`
     * - 1 set -> the `Set<Any?>` itself (no allocation)
     * - 2 or more sets -> a `List<Set<Any?>>` holding the sets in insertion order
     *
     * The update is retried with compare-and-set until it succeeds, so concurrent callers do not
     * lose each other's sets.
     */
    private fun addChanges(set: Set<Any>) {
        while (true) {
            val old = pendingChanges.get()
            val new =
                when (old) {
                    null -> set
                    is Set<*> -> listOf(old, set)
                    is List<*> -> old + listOf(set)
                    else -> report()
                }
            if (pendingChanges.compareAndSet(old, new)) break
        }
    }

    /**
     * Remove the oldest set of changes from the change queue, or return `null` if the queue is
     * empty. See [addChanges] for a description of how this queue works.
     */
    @Suppress("UNCHECKED_CAST")
    private fun removeChanges(): Set<Any>? {
        while (true) {
            val old = pendingChanges.get()
            var result: Set<Any>?
            var new: Any?
            when (old) {
                null -> return null // The queue is empty
                is Set<*> -> {
                    result = old as Set<Any>?
                    new = null
                }
                is List<*> -> {
                    result = old[0] as Set<Any>?
                    new =
                        when {
                            // Collapse back to the single-set encoding
                            old.size == 2 -> old[1]
                            old.size > 2 -> old.subList(1, old.size)
                            else -> null
                        }
                }
                else -> report()
            }
            if (pendingChanges.compareAndSet(old, new)) {
                return result
            }
        }
    }

    private fun report(): Nothing = composeRuntimeError("Unexpected notification")

    /** The observer used by this [SnapshotStateObserver] during [observeReads]. */
    private val readObserver: (Any) -> Unit = { state ->
        if (!isPaused) {
            // This observer is only installed inside observeReads, which sets currentMap first.
            synchronized(observedScopeMapsLock) { currentMap!!.recordRead(state) }
        }
    }

    /**
     * List of all [ObservedScopeMap]s. When [observeReads] is called, there will be a
     * [ObservedScopeMap] associated with its [ObservedScopeMap.onChanged] callback in this list.
     * Maps are added by [observeReads] and removed by [clear] / [clearIf] once they no longer hold
     * any scope observations.
     */
    private val observedScopeMaps = mutableVectorOf<ObservedScopeMap>()
    private val observedScopeMapsLock = makeSynchronizedObject()

    /**
     * Helper for synchronized iteration over [observedScopeMaps]. All observed reads should happen
     * on the same thread, but snapshots can be applied on a different thread, requiring
     * synchronization.
     */
    private inline fun forEachScopeMap(block: (ObservedScopeMap) -> Unit) {
        synchronized(observedScopeMapsLock) { observedScopeMaps.forEach(block) }
    }

    /** Removes, under the lock, every entry of [observedScopeMaps] matching [block]. */
    private inline fun removeScopeMapIf(block: (ObservedScopeMap) -> Boolean) {
        synchronized(observedScopeMapsLock) { observedScopeMaps.removeIf(block) }
    }

    /** Handle for unsubscribing from the apply observer; `null` when not started. */
    private var applyUnsubscribe: ObserverHandle? = null

    /**
     * `true` when [withNoObservations] is called and read observations should not be considered
     * invalidations for the current scope.
     */
    private var isPaused = false

    /**
     * The [ObservedScopeMap] that should be added to when a model is read during [observeReads].
     */
    private var currentMap: ObservedScopeMap? = null

    /** Thread id that has set the [currentMap], or -1 when no [observeReads] call is active. */
    private var currentMapThreadId = -1L

    /**
     * Executes [block], observing state object reads during its execution.
     *
     * The [scope] and [onValueChangedForScope] are associated with any values that are read so that
     * when those values change, [onValueChangedForScope] will be called with the [scope] parameter.
     *
     * Observation can be paused with [Snapshot.withoutReadObservation].
     *
     * @param scope value associated with the observed scope.
     * @param onValueChangedForScope is called with the [scope] when value read within [block] has
     *   been changed. For repeated observations, it is more performant to pass the same instance of
     *   the callback, as [observedScopeMaps] grows with each new callback instance.
     * @param block to observe reads within.
     */
    fun <T : Any> observeReads(scope: T, onValueChangedForScope: (T) -> Unit, block: () -> Unit) {
        val scopeMap = synchronized(observedScopeMapsLock) { ensureMap(onValueChangedForScope) }

        val oldPaused = isPaused
        val oldMap = currentMap
        val oldThreadId = currentMapThreadId

        // A nested observeReads call must come from the thread running the outer one.
        if (oldThreadId != -1L) {
            requirePrecondition(oldThreadId == currentThreadId()) {
                "Detected multithreaded access to SnapshotStateObserver: " +
                    "previousThreadId=$oldThreadId, " +
                    "currentThread={id=${currentThreadId()}, name=${currentThreadName()}}. " +
                    "Note that observation on multiple threads in layout/draw is not supported. " +
                    "Make sure your measure/layout/draw for each Owner (AndroidComposeView) " +
                    "is executed on the same thread."
            }
        }

        try {
            isPaused = false
            currentMap = scopeMap
            currentMapThreadId = currentThreadId()

            scopeMap.observe(scope, readObserver, block)
        } finally {
            currentMap = oldMap
            isPaused = oldPaused
            currentMapThreadId = oldThreadId
        }
    }

    /**
     * Stops observing state object reads while executing [block]. State object reads may be
     * restarted by calling [observeReads] inside [block].
     */
    @Deprecated(
        "Replace with Snapshot.withoutReadObservation()",
        ReplaceWith(
            "Snapshot.withoutReadObservation(block)",
            "androidx.compose.runtime.snapshots.Snapshot"
        )
    )
    fun withNoObservations(block: () -> Unit) {
        val oldPaused = isPaused
        isPaused = true
        try {
            block()
        } finally {
            isPaused = oldPaused
        }
    }

    /**
     * Clears all state read observations for a given [scope]. This clears values for all
     * `onValueChangedForScope` callbacks passed in [observeReads].
     */
    fun clear(scope: Any) {
        removeScopeMapIf { scopeMap ->
            scopeMap.clearScopeObservations(scope)
            // Drop maps that no longer observe anything
            !scopeMap.hasScopeObservations()
        }
    }

    /**
     * Remove observations using [predicate] to identify scopes to be removed. This is used when a
     * scope is no longer in the hierarchy and should not receive any callbacks.
     */
    fun clearIf(predicate: (scope: Any) -> Boolean) {
        removeScopeMapIf { scopeMap ->
            scopeMap.removeScopeIf(predicate)
            !scopeMap.hasScopeObservations()
        }
    }

    /** Starts watching for state commits. */
    fun start() {
        applyUnsubscribe = Snapshot.registerApplyObserver(applyObserver)
    }

    /**
     * Stops watching for state commits. Calling this when not started, or more than once, has no
     * effect.
     */
    fun stop() {
        applyUnsubscribe?.dispose()
        // Drop the stale handle so a repeated stop() does not dispose it again.
        applyUnsubscribe = null
    }

    /**
     * This method is only used for testing. It notifies that [changes] have been made on
     * [snapshot].
     */
    @TestOnly
    fun notifyChanges(changes: Set<Any>, snapshot: Snapshot) {
        applyObserver(changes, snapshot)
    }

    /** Remove all observations. */
    fun clear() {
        forEachScopeMap { scopeMap -> scopeMap.clear() }
    }

    /**
     * Returns the [ObservedScopeMap] within [observedScopeMaps] associated with [onChanged] or a
     * newly-inserted one if it doesn't exist. Callbacks are matched by identity (`===`).
     *
     * Must be called inside a synchronized block.
     */
    @Suppress("UNCHECKED_CAST")
    private fun <T : Any> ensureMap(onChanged: (T) -> Unit): ObservedScopeMap {
        val scopeMap = observedScopeMaps.firstOrNull { it.onChanged === onChanged }
        if (scopeMap == null) {
            val map = ObservedScopeMap(onChanged as ((Any) -> Unit))
            observedScopeMaps += map
            return map
        }
        return scopeMap
    }

    /** Connects observed values to scopes for each [onChanged] callback. */
    @Suppress("UNCHECKED_CAST")
    private class ObservedScopeMap(val onChanged: (Any) -> Unit) {
        /** Currently observed scope. */
        private var currentScope: Any? = null

        /**
         * key: State reads observed in current scope. value: [currentToken] at the time the read
         * was observed in.
         */
        private var currentScopeReads: MutableObjectIntMap<Any>? = null

        /**
         * Token for current observation cycle; usually corresponds to snapshot ID at the time when
         * observation started. -1 means no observation is in progress.
         */
        private var currentToken: Int = -1

        /** Values that have been read during the scope's [SnapshotStateObserver.observeReads]. */
        private val valueToScopes = ScopeMap<Any, Any>()

        /** Reverse index (scope -> values) for faster scope invalidation. */
        private val scopeToValues: MutableScatterMap<Any, MutableObjectIntMap<Any>> =
            MutableScatterMap()

        /** Scopes that were invalidated during previous apply step. */
        private val invalidated = MutableScatterSet<Any>()

        /** Reusable vector for re-recording states inside [recordInvalidation] */
        private val statesToReread = mutableVectorOf<DerivedState<*>>()

        // derived state handling

        /** Observer for derived state recalculation */
        val derivedStateObserver =
            object : DerivedStateObserver {
                override fun start(derivedState: DerivedState<*>) {
                    derivedStateScopeCount++
                }

                override fun done(derivedState: DerivedState<*>) {
                    derivedStateScopeCount--
                }
            }

        /**
         * Counter for skipping reads inside derived states. If count is > 0, read happens inside a
         * derived state. Reads for derived states are captured separately through
         * [DerivedState.Record.dependencies].
         */
        private var derivedStateScopeCount = 0

        /** Invalidation index from state objects to derived states reading them. */
        private val dependencyToDerivedStates = ScopeMap<Any, DerivedState<*>>()

        /** Last derived state value recorded during read. */
        private val recordedDerivedStateValues = HashMap<DerivedState<*>, Any?>()

        /**
         * Record that [value] was read in the scope currently being observed. Lazily creates the
         * scope's read map on its first read.
         */
        fun recordRead(value: Any) {
            // Only called from the read observer while observe() has set currentScope.
            val scope = currentScope!!
            recordRead(
                value = value,
                currentToken = currentToken,
                currentScope = scope,
                recordedValues =
                    currentScopeReads
                        ?: MutableObjectIntMap<Any>().also {
                            currentScopeReads = it
                            scopeToValues[scope] = it
                        }
            )
        }

        /**
         * Record that [value] was read in [currentScope], stamping it with [currentToken] in
         * [recordedValues]. For a [DerivedState] read for the first time in this cycle, its current
         * value and dependencies are (re)captured so dependency changes can be checked later.
         */
        private fun recordRead(
            value: Any,
            currentToken: Int,
            currentScope: Any,
            recordedValues: MutableObjectIntMap<Any>
        ) {
            if (derivedStateScopeCount > 0) {
                // Reads coming from derivedStateOf block
                return
            }

            val previousToken = recordedValues.put(value, currentToken, -1)
            if (value is DerivedState<*> && previousToken != currentToken) {
                val record = value.currentRecord
                // re-read the value before removing dependencies, in case the new value wasn't read
                recordedDerivedStateValues[value] = record.currentValue

                val dependencies = record.dependencies
                val dependencyToDerivedStates = dependencyToDerivedStates

                dependencyToDerivedStates.removeScope(value)
                dependencies.forEachKey { dependency ->
                    if (dependency is StateObjectImpl) {
                        dependency.recordReadIn(ReaderKind.SnapshotStateObserver)
                    }
                    dependencyToDerivedStates.add(dependency, value)
                }
            }

            // -1 means the value was not previously recorded for this scope
            if (previousToken == -1) {
                if (value is StateObjectImpl) {
                    value.recordReadIn(ReaderKind.SnapshotStateObserver)
                }
                valueToScopes.add(value, currentScope)
            }
        }

        /**
         * Setup new scope for state read observation, observe them, and cleanup afterwards.
         *
         * The previous scope, read map and token are always restored, even when [block] throws,
         * so a failed observation cannot leave a stale scope or token behind. Obsolete reads are
         * pruned only when [block] completes normally: after a failure the recorded reads are
         * incomplete, so the existing observations are kept.
         */
        fun observe(scope: Any, readObserver: (Any) -> Unit, block: () -> Unit) {
            val previousScope = currentScope
            val previousReads = currentScopeReads
            val previousToken = currentToken

            currentScope = scope
            currentScopeReads = scopeToValues[scope]
            if (currentToken == -1) {
                // Outermost observation picks the token; nested observations reuse it.
                currentToken = currentSnapshot().snapshotId.hashCode()
            }

            try {
                observeDerivedStateRecalculations(derivedStateObserver) {
                    Snapshot.observe(readObserver, null, block)
                }

                clearObsoleteStateReads(scope)
            } finally {
                currentScope = previousScope
                currentScopeReads = previousReads
                currentToken = previousToken
            }
        }

        /** Removes reads of [scope] that were not stamped with the current token. */
        private fun clearObsoleteStateReads(scope: Any) {
            val currentToken = currentToken
            currentScopeReads?.removeIf { value, token ->
                (token != currentToken).also { willRemove ->
                    if (willRemove) {
                        removeObservation(scope, value)
                    }
                }
            }
        }

        /** Clear observations for [scope]. */
        fun clearScopeObservations(scope: Any) {
            val recordedValues = scopeToValues.remove(scope) ?: return
            recordedValues.forEach { value, _ -> removeObservation(scope, value) }
        }

        /** Remove observations in scopes matching [predicate]. */
        fun removeScopeIf(predicate: (scope: Any) -> Boolean) {
            scopeToValues.removeIf { scope, valueSet ->
                predicate(scope).also { willRemove ->
                    if (willRemove) {
                        valueSet.forEach { value, _ -> removeObservation(scope, value) }
                    }
                }
            }
        }

        /** Returns `true` if any scope still has recorded reads in this map. */
        fun hasScopeObservations(): Boolean = scopeToValues.isNotEmpty()

        /**
         * Unlink [value] from [scope]. When a derived state is no longer read by any scope, its
         * dependency index entries and recorded value are dropped too.
         */
        private fun removeObservation(scope: Any, value: Any) {
            valueToScopes.remove(value, scope)
            if (value is DerivedState<*> && value !in valueToScopes) {
                dependencyToDerivedStates.removeScope(value)
                recordedDerivedStateValues.remove(value)
            }
        }

        /** Clear all observations. */
        fun clear() {
            valueToScopes.clear()
            scopeToValues.clear()
            dependencyToDerivedStates.clear()
            recordedDerivedStateValues.clear()
        }

        /**
         * Record scope invalidation for given set of values.
         *
         * Scopes that read a changed value directly are invalidated. Scopes that read a derived
         * state depending on a changed value are invalidated only if the derived value is no
         * longer equivalent (per its policy, structural equality by default) to the value recorded
         * on read; otherwise the derived state is re-read to refresh its dependencies.
         *
         * @return whether any scopes observe changed values
         */
        fun recordInvalidation(changes: Set<Any>): Boolean {
            var hasValues = false

            val dependencyToDerivedStates = dependencyToDerivedStates
            val recordedDerivedStateValues = recordedDerivedStateValues
            val valueToScopes = valueToScopes
            val invalidated = invalidated

            changes.fastForEach { value ->
                // Fast path: state objects never read by a SnapshotStateObserver can be skipped
                if (value is StateObjectImpl && !value.isReadIn(ReaderKind.SnapshotStateObserver)) {
                    return@fastForEach
                }

                if (value in dependencyToDerivedStates) {
                    // Find derived state that is invalidated by this change
                    dependencyToDerivedStates.forEachScopeOf(value) { derivedState ->
                        derivedState as DerivedState<Any?>
                        val previousValue = recordedDerivedStateValues[derivedState]
                        val policy = derivedState.policy ?: structuralEqualityPolicy()

                        // Invalidate only if currentValue is different than observed on read
                        if (
                            !policy.equivalent(
                                derivedState.currentRecord.currentValue,
                                previousValue
                            )
                        ) {
                            valueToScopes.forEachScopeOf(derivedState) { scope ->
                                invalidated.add(scope)
                                hasValues = true
                            }
                        } else {
                            // Re-read state to ensure its dependencies are up-to-date
                            statesToReread.add(derivedState)
                        }
                    }
                }

                valueToScopes.forEachScopeOf(value) { scope ->
                    invalidated.add(scope)
                    hasValues = true
                }
            }

            if (statesToReread.isNotEmpty()) {
                statesToReread.forEach { rereadDerivedState(it) }
                statesToReread.clear()
            }

            return hasValues
        }

        /**
         * Re-record a read of [derivedState] in every scope that observes it, using a token from
         * the current snapshot, so its recorded value and dependencies are refreshed.
         */
        fun rereadDerivedState(derivedState: DerivedState<*>) {
            val scopeToValues = scopeToValues
            val token = currentSnapshot().snapshotId.hashCode()

            valueToScopes.forEachScopeOf(derivedState) { scope ->
                recordRead(
                    value = derivedState,
                    currentToken = token,
                    currentScope = scope,
                    recordedValues =
                        scopeToValues[scope]
                            ?: MutableObjectIntMap<Any>().also { scopeToValues[scope] = it }
                )
            }
        }

        /** Call [onChanged] for previously invalidated scopes, then forget them. */
        fun notifyInvalidatedScopes() {
            val invalidated = invalidated
            invalidated.forEach(onChanged)
            invalidated.clear()
        }
    }
}