1 /*
2  * Copyright 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.compose.runtime.snapshots
18 
19 import androidx.compose.runtime.TestOnly
20 import androidx.compose.runtime.collection.fastCopyInto
21 import androidx.compose.runtime.internal.WeakReference
22 import androidx.compose.runtime.internal.identityHashCode
23 
24 private const val INITIAL_CAPACITY = 16
25 
/**
 * A set of value references in which every value is held through a [WeakReference].
 *
 * This is deliberately not a full Set<T> implementation; it only offers the operations that are
 * needed for use in [Snapshot].
 *
 * Entries are kept sorted by the [identityHashCode] of the referenced value so that lookups can
 * use a binary search. Values that share a hash occupy adjacent slots.
 *
 * [add], [find] and [findExactIndex] are copied from IdentityArraySet and adapted to use weak
 * references. Any bugs found in these methods are likely to also be in IdentityArraySet and vice
 * versa.
 */
internal class SnapshotWeakSet<T : Any> {
    /**
     * The number of slots in use. This is an upper bound on the number of live values: a slot
     * still counts towards [size] after the value it refers to has been collected.
     */
    internal var size: Int = 0

    /**
     * The identity hash of each entry, stored at the same index as the entry in [values].
     *
     * The hash is recorded separately because the referenced object may no longer be available,
     * yet its hash is still needed: entries are stored in hash order and located with a binary
     * search.
     */
    internal var hashes = IntArray(INITIAL_CAPACITY)

    /**
     * The weak references to the entries, ordered by the matching element of [hashes]. Slots at
     * index [size] and above are unused.
     */
    internal var values: Array<WeakReference<T>?> = arrayOfNulls(INITIAL_CAPACITY)

    /**
     * Adds [value] to the set.
     *
     * @return `true` if [value] was inserted, `false` if the same instance was already present.
     */
    fun add(value: T): Boolean {
        val count = size
        val hash = identityHashCode(value)

        // With no entries there is nothing to search; -1 encodes "insert at slot 0".
        val located = if (count > 0) find(value, hash) else -1
        if (located >= 0) return false

        // A negative result is -(insertion point + 1); decode it.
        val slot = -(located + 1)
        val currentValues = values
        val currentHashes = hashes

        if (count == currentValues.size) {
            // The arrays are full: double the capacity and copy the entries across, leaving an
            // empty slot at the insertion point.
            val grownCapacity = currentValues.size * 2
            val grownValues = arrayOfNulls<WeakReference<T>?>(grownCapacity)
            val grownHashes = IntArray(grownCapacity)

            currentValues.fastCopyInto(
                destination = grownValues,
                destinationOffset = 0,
                startIndex = 0,
                endIndex = slot
            )
            currentValues.fastCopyInto(
                destination = grownValues,
                destinationOffset = slot + 1,
                startIndex = slot,
                endIndex = count
            )
            currentHashes.copyInto(destination = grownHashes, endIndex = slot)
            currentHashes.copyInto(
                destination = grownHashes,
                destinationOffset = slot + 1,
                startIndex = slot,
                endIndex = count
            )

            values = grownValues
            hashes = grownHashes
        } else {
            // There is spare capacity: shift the tail up by one slot in place.
            currentValues.fastCopyInto(
                destination = currentValues,
                destinationOffset = slot + 1,
                startIndex = slot,
                endIndex = count
            )
            currentHashes.copyInto(
                destination = currentHashes,
                destinationOffset = slot + 1,
                startIndex = slot,
                endIndex = count
            )
        }

        // The slot at the insertion point is now free in both arrays; store the new entry.
        values[slot] = WeakReference(value)
        hashes[slot] = hash
        size = count + 1
        return true
    }

    /**
     * Removes every entry for which [block] returns `true`.
     *
     * Entries whose weak reference no longer refers to an object are discarded as well, without
     * [block] being called for them.
     *
     * This call is inline to avoid allocations while enumerating the set.
     */
    inline fun removeIf(block: (T) -> Boolean) {
        val originalSize = size
        var kept = 0
        var index = 0

        // Compact the surviving entries towards the front of the arrays. `block` is only
        // consulted for entries whose referent is still available.
        while (index < originalSize) {
            val reference = values[index]
            val referent = reference?.get()
            val discard = referent == null || block(referent)
            if (!discard) {
                // Move the survivor down only if a gap has opened up below it.
                if (kept != index) {
                    values[kept] = reference
                    hashes[kept] = hashes[index]
                }
                kept++
            }
            index++
        }

        // Reset the slots vacated by the compaction.
        var stale = kept
        while (stale < originalSize) {
            values[stale] = null
            hashes[stale] = 0
            stale++
        }

        // Record the new size only when something was actually removed.
        if (kept != originalSize) {
            this.size = kept
        }
    }

    /**
     * Binary searches [hashes] for [hash] and then looks for the exact instance [value].
     *
     * @return the index of [value] if it is in the set, otherwise `-(insertion index + 1)` where
     *   the insertion index is the slot [value] would occupy.
     */
    private fun find(value: T, hash: Int): Int {
        var lower = 0
        var upper = size - 1

        while (lower <= upper) {
            val probe = (lower + upper) ushr 1
            val probeHash = hashes[probe]
            if (probeHash < hash) {
                lower = probe + 1
            } else if (probeHash > hash) {
                upper = probe - 1
            } else {
                // The hash matches; the instance may be here or in a neighbouring slot that
                // shares the same hash.
                return if (values[probe]?.get() === value) {
                    probe
                } else {
                    findExactIndex(probe, value, hash)
                }
            }
        }
        return -(lower + 1)
    }

    /**
     * Resolves a lookup among entries that share the same [identityHashCode].
     *
     * The slot at [midIndex] is assumed to have been checked already and not to hold [value].
     * The neighbouring slots with hash [valueHash] are examined, first below [midIndex] and then
     * above it.
     *
     * @return the index of [value] if found, otherwise `-(insertion index + 1)` where the
     *   insertion index is just past the last entry with the same hash.
     */
    private fun findExactIndex(midIndex: Int, value: T, valueHash: Int): Int {
        // Walk down while the hash still matches.
        var below = midIndex - 1
        while (below >= 0 && hashes[below] == valueHash) {
            if (values[below]?.get() === value) return below
            below--
        }

        // Walk up; the first slot with a different hash is the insertion point.
        val end = size
        var above = midIndex + 1
        while (above < end) {
            if (hashes[above] != valueHash) return -(above + 1)
            if (values[above]?.get() === value) return above
            above++
        }

        // Every remaining entry shares the hash, so the insertion point is the end.
        return -(end + 1)
    }

    /**
     * Checks the internal invariants of the set.
     *
     * @return `true` when [size] does not exceed the capacity, the used slots are non-null and
     *   in non-decreasing hash order with each recorded hash matching its still-available value,
     *   and every unused slot is cleared.
     */
    @TestOnly
    internal fun isValid(): Boolean {
        val count = size
        val entries = values
        val entryHashes = hashes
        val capacity = entries.size

        // The used region must fit in the backing array.
        if (count > capacity) return false

        // Used slots: ordered hashes, a reference in every slot, and a hash that agrees with
        // the referent unless the referent has been collected.
        var lastHash = Int.MIN_VALUE
        for (i in 0 until count) {
            val hash = entryHashes[i]
            if (hash < lastHash) return false
            val reference = entries[i] ?: return false
            val referent = reference.get()
            if (referent != null && hash != identityHashCode(referent)) return false
            lastHash = hash
        }

        // Unused slots: must hold a zero hash and no reference.
        for (i in count until capacity) {
            if (entryHashes[i] != 0 || entries[i] != null) return false
        }

        return true
    }
}
243