1 /*
<lambda>null2  * Copyright 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.compose.runtime.snapshots
18 
19 import androidx.compose.runtime.Stable
20 import androidx.compose.runtime.external.kotlinx.collections.immutable.PersistentMap
21 import androidx.compose.runtime.external.kotlinx.collections.immutable.persistentHashMapOf
22 import androidx.compose.runtime.platform.makeSynchronizedObject
23 import androidx.compose.runtime.platform.synchronized
24 import kotlin.jvm.JvmName
25 
/**
 * An implementation of [MutableMap] that can be observed and snapshot. This is the result type
 * created by [androidx.compose.runtime.mutableStateMapOf].
 *
 * This class closely implements the same semantics as [HashMap].
 *
 * @see androidx.compose.runtime.mutableStateMapOf
 */
@Stable
class SnapshotStateMap<K, V> : StateObject, MutableMap<K, V> {
    override var firstStateRecord: StateRecord =
        persistentHashMapOf<K, V>().let { map ->
            val snapshot = currentSnapshot()
            StateMapStateRecord(snapshot.snapshotId, map).also {
                // When created inside a non-global snapshot, also chain a record tagged with the
                // pre-existing snapshot id; both records start out holding the same empty map.
                if (snapshot !is GlobalSnapshot) {
                    it.next =
                        StateMapStateRecord(Snapshot.PreexistingSnapshotId.toSnapshotId(), map)
                }
            }
        }
        private set

    override fun prependStateRecord(value: StateRecord) {
        @Suppress("UNCHECKED_CAST")
        firstStateRecord = value as StateMapStateRecord<K, V>
    }

    /**
     * Returns an immutable map containing all key-value pairs from the original map.
     *
     * The content of the returned map will not change even if the content of this map is changed
     * in the same snapshot. It will also be the same instance until the content is changed. It is
     * not, however, guaranteed to be the same instance for the same content, as adding and
     * removing the same item from this map might produce a different instance with the same
     * content.
     *
     * This operation is O(1) and does not involve physically copying the map. It instead returns
     * the underlying immutable map used internally to store the content of the map.
     *
     * It is recommended to use [toMap] when returning the value of this map from
     * [androidx.compose.runtime.snapshotFlow].
     */
    fun toMap(): Map<K, V> = readable.map

    override val size
        get() = readable.map.size

    override fun containsKey(key: K) = readable.map.containsKey(key)

    override fun containsValue(value: V) = readable.map.containsValue(value)

    override fun get(key: K) = readable.map[key]

    override fun isEmpty() = readable.map.isEmpty()

    override val entries: MutableSet<MutableMap.MutableEntry<K, V>> = SnapshotMapEntrySet(this)
    override val keys: MutableSet<K> = SnapshotMapKeySet(this)
    override val values: MutableCollection<V> = SnapshotMapValueSet(this)

    @Suppress("UNCHECKED_CAST")
    override fun toString(): String =
        (firstStateRecord as StateMapStateRecord<K, V>).withCurrent {
            "SnapshotStateMap(value=${it.map})@${hashCode()}"
        }

    override fun clear() = update { persistentHashMapOf() }

    override fun put(key: K, value: V): V? = mutate { it.put(key, value) }

    override fun putAll(from: Map<out K, V>) = mutate { it.putAll(from) }

    override fun remove(key: K): V? = mutate { it.remove(key) }

    /** Modification counter of the readable record; bumped on every committed write. */
    internal val modification
        get() = readable.modification

    /**
     * Removes the first entry (in iteration order) whose value equals [value].
     *
     * @return `true` if an entry was found and removed, `false` otherwise.
     */
    internal fun removeValue(value: V) =
        entries
            .firstOrNull { it.value == value }
            ?.let {
                remove(it.key)
                true
            } == true

    @Suppress("UNCHECKED_CAST")
    internal val readable: StateMapStateRecord<K, V>
        get() = (firstStateRecord as StateMapStateRecord<K, V>).readable(this)

    /**
     * Removes every entry for which [predicate] returns `true`, as a single write.
     *
     * @return `true` if the committed write removed at least one entry.
     */
    internal inline fun removeIf(predicate: (MutableMap.MutableEntry<K, V>) -> Boolean): Boolean {
        var removed = false
        mutate { builder ->
            // mutate() re-runs this block when a concurrent writer wins the race, so reset the
            // flag: only the attempt that is finally committed may decide the result.
            removed = false
            for (entry in this.entries) {
                if (predicate(entry)) {
                    builder.remove(entry.key)
                    removed = true
                }
            }
        }
        return removed
    }

    /** Returns `true` if any entry of the readable map satisfies [predicate]. */
    internal inline fun any(predicate: (Map.Entry<K, V>) -> Boolean): Boolean {
        for (entry in readable.map.entries) {
            if (predicate(entry)) return true
        }
        return false
    }

    /** Returns `true` if every entry of the readable map satisfies [predicate]. */
    internal inline fun all(predicate: (Map.Entry<K, V>) -> Boolean): Boolean {
        for (entry in readable.map.entries) {
            if (!predicate(entry)) return false
        }
        return true
    }

    /**
     * An internal function used by the debugger to display the value of the current value of the
     * mutable state object without triggering read observers.
     */
    @Suppress("unused")
    internal val debuggerDisplayValue: Map<K, V>
        @JvmName("getDebuggerDisplayValue") get() = withCurrent { map }

    private inline fun <R> withCurrent(block: StateMapStateRecord<K, V>.() -> R): R =
        @Suppress("UNCHECKED_CAST")
        (firstStateRecord as StateMapStateRecord<K, V>).withCurrent(block)

    private inline fun <R> writable(block: StateMapStateRecord<K, V>.() -> R): R =
        @Suppress("UNCHECKED_CAST")
        (firstStateRecord as StateMapStateRecord<K, V>).writable(this, block)

    /**
     * Runs [block] against a builder seeded with the current map, then tries to commit the result.
     *
     * The commit succeeds only if the record's modification count is unchanged since the map was
     * read. Otherwise, the whole read/build cycle is retried. [block] may therefore run more than
     * once, and its result from the committed attempt is returned.
     */
    private inline fun <R> mutate(block: (MutableMap<K, V>) -> R): R {
        var result: R
        while (true) {
            var oldMap: PersistentMap<K, V>? = null
            var currentModification = 0
            // Read the map and its modification count together under the shared lock.
            synchronized(sync) {
                val current = withCurrent { this }
                oldMap = current.map
                currentModification = current.modification
            }
            val builder = oldMap!!.builder()
            result = block(builder)
            val newMap = builder.build()
            if (newMap == oldMap || writable { attemptUpdate(currentModification, newMap) }) break
        }
        return result
    }

    /**
     * Installs [newMap] and bumps the modification count, but only if nobody else wrote since
     * [currentModification] was observed.
     *
     * @return `true` if the update was applied.
     */
    private fun StateMapStateRecord<K, V>.attemptUpdate(
        currentModification: Int,
        newMap: PersistentMap<K, V>
    ) =
        synchronized(sync) {
            if (modification == currentModification) {
                map = newMap
                modification++
                true
            } else false
        }

    /**
     * Replaces the map with the result of [block] without a modification check. A write is made
     * only when [block] returns a different instance.
     */
    private inline fun update(block: (PersistentMap<K, V>) -> PersistentMap<K, V>) = withCurrent {
        val newMap = block(map)
        if (newMap !== map) writable { commitUpdate(newMap) }
    }

    // NOTE: do not inline this method to avoid class verification failures, see b/369909868
    private fun StateMapStateRecord<K, V>.commitUpdate(newMap: PersistentMap<K, V>) =
        synchronized(sync) {
            map = newMap
            modification++
        }

    /** Implementation class of [SnapshotStateMap]. Do not use. */
    internal class StateMapStateRecord<K, V>
    internal constructor(snapshotId: SnapshotId, internal var map: PersistentMap<K, V>) :
        StateRecord(snapshotId) {
        internal var modification = 0

        override fun assign(value: StateRecord) {
            @Suppress("UNCHECKED_CAST") val other = (value as StateMapStateRecord<K, V>)
            synchronized(sync) {
                map = other.map
                modification = other.modification
            }
        }

        override fun create(): StateRecord = StateMapStateRecord(currentSnapshot().snapshotId, map)

        override fun create(snapshotId: SnapshotId): StateRecord =
            StateMapStateRecord(snapshotId, map)
    }
}
218 
/**
 * Base for the live [MutableSet] views over a [SnapshotStateMap]. Size, emptiness and clearing
 * are forwarded straight to the backing [map].
 */
private abstract class SnapshotMapSet<K, V, E>(val map: SnapshotStateMap<K, V>) : MutableSet<E> {
    override val size: Int
        get() {
            return map.size
        }

    override fun isEmpty(): Boolean {
        return map.isEmpty()
    }

    override fun clear() {
        map.clear()
    }
}
227 
/**
 * Live [MutableSet] view of a [SnapshotStateMap]'s entries.
 *
 * Membership and removal follow the [Map.entries] contract: an entry is only considered present,
 * and only removed, when the map holds a mapping for its key *and* the mapped value equals the
 * entry's value. Adding entries is not supported.
 */
private class SnapshotMapEntrySet<K, V>(map: SnapshotStateMap<K, V>) :
    SnapshotMapSet<K, V, MutableMap.MutableEntry<K, V>>(map) {
    override fun add(element: MutableMap.MutableEntry<K, V>) = unsupported()

    override fun addAll(elements: Collection<MutableMap.MutableEntry<K, V>>) = unsupported()

    override fun iterator(): MutableIterator<MutableMap.MutableEntry<K, V>> =
        StateMapMutableEntriesIterator(map, map.readable.map.entries.iterator())

    /**
     * Removes the mapping for [element]'s key only if it currently maps to [element]'s value.
     * The membership check and the removal are two separate map operations.
     *
     * @return `true` if a matching entry was present and removed.
     */
    override fun remove(element: MutableMap.MutableEntry<K, V>): Boolean {
        if (!contains(element)) return false
        map.remove(element.key)
        return true
    }

    /** Removes every entry of [elements] that is present; returns `true` if any was removed. */
    override fun removeAll(elements: Collection<MutableMap.MutableEntry<K, V>>): Boolean {
        var removed = false
        for (element in elements) {
            removed = remove(element) || removed
        }
        return removed
    }

    /** Keeps only entries equal (by key and value) to some entry of [elements]. */
    override fun retainAll(elements: Collection<MutableMap.MutableEntry<K, V>>): Boolean {
        // A set of (key, value) pairs keeps every candidate; associate() would collapse
        // duplicate keys to the last value and wrongly drop entries matched by an earlier one.
        val retained = elements.mapTo(HashSet()) { it.key to it.value }
        return map.removeIf { (it.key to it.value) !in retained }
    }

    /**
     * An entry is contained when its key is mapped to an equal value. A `null` value only matches
     * if the key is actually present, not merely because `get` returned `null` for a missing key.
     */
    override fun contains(element: MutableMap.MutableEntry<K, V>): Boolean {
        val key = element.key
        val value = map[key]
        return value == element.value && (value != null || map.containsKey(key))
    }

    override fun containsAll(elements: Collection<MutableMap.MutableEntry<K, V>>): Boolean {
        return elements.all { contains(it) }
    }
}
260 
/**
 * Live [MutableSet] view of a [SnapshotStateMap]'s keys. Removing a key removes its mapping;
 * adding keys is not supported.
 */
private class SnapshotMapKeySet<K, V>(map: SnapshotStateMap<K, V>) : SnapshotMapSet<K, V, K>(map) {
    override fun add(element: K) = unsupported()

    override fun addAll(elements: Collection<K>) = unsupported()

    override fun iterator() = StateMapMutableKeysIterator(map, map.readable.map.entries.iterator())

    override fun remove(element: K): Boolean {
        val previous = map.remove(element)
        return previous != null
    }

    override fun removeAll(elements: Collection<K>): Boolean {
        var anyRemoved = false
        for (key in elements) {
            // Always attempt the removal, even once something has already been removed.
            if (map.remove(key) != null) anyRemoved = true
        }
        return anyRemoved
    }

    override fun retainAll(elements: Collection<K>): Boolean {
        val keep = elements.toHashSet()
        return map.removeIf { entry -> entry.key !in keep }
    }

    override fun contains(element: K) = map.containsKey(element)

    override fun containsAll(elements: Collection<K>): Boolean =
        elements.all { key -> map.containsKey(key) }
}
285 
/**
 * Live [MutableSet] view of a [SnapshotStateMap]'s values. Removing a value removes the mapping(s)
 * holding it; adding values is not supported.
 */
private class SnapshotMapValueSet<K, V>(map: SnapshotStateMap<K, V>) :
    SnapshotMapSet<K, V, V>(map) {
    override fun add(element: V) = unsupported()

    override fun addAll(elements: Collection<V>) = unsupported()

    override fun iterator() =
        StateMapMutableValuesIterator(map, map.readable.map.entries.iterator())

    override fun remove(element: V): Boolean = map.removeValue(element)

    override fun removeAll(elements: Collection<V>): Boolean {
        val doomed = elements.toHashSet()
        return map.removeIf { entry -> entry.value in doomed }
    }

    override fun retainAll(elements: Collection<V>): Boolean {
        val kept = elements.toHashSet()
        return map.removeIf { entry -> entry.value !in kept }
    }

    override fun contains(element: V): Boolean = map.containsValue(element)

    override fun containsAll(elements: Collection<V>): Boolean =
        elements.all(map::containsValue)
}
313 
/**
 * This lock is used to ensure that the value of modification and the map in the state record, when
 * used together, are atomically read and written.
 *
 * A global sync object is used to avoid having to allocate a sync object and initialize a monitor
 * for each instance of the map. This avoids additional allocations but introduces some contention
 * between maps. As there is already contention on the global snapshot lock for writes, the
 * additional contention introduced by this lock is nominal.
 *
 * In code that requires this lock and also calls `writable` (or any other operation that acquires
 * the global snapshot lock), this lock *MUST* be acquired last to avoid deadlocks. In other words,
 * if `writable` is used, this lock must be taken inside the `writable` lambda.
 */
private val sync = makeSynchronizedObject()
328 
/**
 * Shared machinery for the key, value and entry iterators of a [SnapshotStateMap].
 *
 * The underlying [iterator] is consumed one element ahead: [next] is the element the following
 * `next()` call will produce, and [current] is the element most recently produced (cleared by
 * [remove]). Mutations made through the iterator first compare the map's modification count with
 * the one recorded at creation or after the last successful mutation through this iterator.
 */
private abstract class StateMapMutableIterator<K, V>(
    val map: SnapshotStateMap<K, V>,
    val iterator: Iterator<Map.Entry<K, V>>
) {
    protected var modification = map.modification
    protected var current: Map.Entry<K, V>? = null
    protected var next: Map.Entry<K, V>? = null

    init {
        // Prime the look-ahead so hasNext() can answer without consulting the source iterator.
        advance()
    }

    /** Removes the most recently returned element from [map]; throws if there is none. */
    fun remove() = modify {
        val last = current ?: throw IllegalStateException()
        map.remove(last.key)
        current = null
    }

    fun hasNext(): Boolean = next != null

    /** Shifts the look-ahead into [current] and pulls the following element, if any. */
    protected fun advance() {
        current = next
        next = when {
            iterator.hasNext() -> iterator.next()
            else -> null
        }
    }

    /**
     * Runs [block] after verifying nobody else modified [map], then records the map's new
     * modification count. A non-local return out of [block] skips that bookkeeping.
     */
    protected inline fun <T> modify(block: () -> T): T {
        if (modification != map.modification) throw ConcurrentModificationException()
        val outcome = block()
        modification = map.modification
        return outcome
    }
}
366 
/**
 * Iterator over a [SnapshotStateMap]'s entries. Each produced entry supports [setValue], which
 * writes through to the map.
 */
private class StateMapMutableEntriesIterator<K, V>(
    map: SnapshotStateMap<K, V>,
    iterator: Iterator<Map.Entry<K, V>>
) : StateMapMutableIterator<K, V>(map, iterator), MutableIterator<MutableMap.MutableEntry<K, V>> {
    /**
     * Advances and returns a mutable view of the new current entry.
     *
     * @throws IllegalStateException if the iteration is exhausted.
     */
    override fun next(): MutableMap.MutableEntry<K, V> {
        advance()
        // Capture once: `current` is a mutable property and cannot be smart-cast.
        val entry = current ?: throw IllegalStateException()
        return object : MutableMap.MutableEntry<K, V> {
            override val key = entry.key
            override var value = entry.value

            override fun setValue(newValue: V): V = modify {
                val previous = value
                map[key] = newValue
                value = newValue
                // Yield `previous` as the lambda result rather than a non-local `return`. A
                // `return` would skip modify()'s bookkeeping and make the next remove()/setValue()
                // on this iterator throw ConcurrentModificationException.
                previous
            }
        }
    }
}
390 
/** Iterator over a [SnapshotStateMap]'s keys. */
private class StateMapMutableKeysIterator<K, V>(
    map: SnapshotStateMap<K, V>,
    iterator: Iterator<Map.Entry<K, V>>
) : StateMapMutableIterator<K, V>(map, iterator), MutableIterator<K> {
    /** Returns the key of the look-ahead entry and then moves the look-ahead forward. */
    override fun next(): K = (next ?: throw IllegalStateException()).also { advance() }.key
}
401 
/** Iterator over a [SnapshotStateMap]'s values. */
private class StateMapMutableValuesIterator<K, V>(
    map: SnapshotStateMap<K, V>,
    iterator: Iterator<Map.Entry<K, V>>
) : StateMapMutableIterator<K, V>(map, iterator), MutableIterator<V> {
    /** Returns the value of the look-ahead entry and then moves the look-ahead forward. */
    override fun next(): V = (next ?: throw IllegalStateException()).also { advance() }.value
}
412 
/** Always throws [UnsupportedOperationException]; used by the read-only parts of the map views. */
internal fun unsupported(): Nothing = throw UnsupportedOperationException()
416