/*
 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
4 
5 package kotlinx.coroutines.debug.internal
6 
7 import kotlinx.atomicfu.*
8 import kotlinx.coroutines.internal.*
9 import java.lang.ref.*
10 
// This is a very limited implementation, not suitable as a generic map replacement.
// It has lock-free get and put with synchronized rehash for simplicity (and better CPU usage on contention).
/**
 * Concurrent map that holds its keys through [HashedWeakRef] weak references.
 *
 * * Keys are compared with `==` and located by open addressing: a lookup starts at a slot derived from the key's
 *   hash code and probes downwards (wrapping from slot 0 to the last slot) until it finds the key or an empty slot.
 * * Entries whose key was garbage-collected are removed lazily when a probe walks over them, during rehash, or
 *   eagerly by [runWeakRefQueueCleaningLoopUntilInterrupted] when the map is created with `weakRefQueue = true`.
 * * [get], [put] and [remove] are lock-free on the fast path; only switching to a new table (rehash) is done
 *   under this map's monitor.
 * * [keys] and [entries] are views whose iterators do not support `remove`, whose entries do not support
 *   `setValue`, and which do not support `add`.
 */
@Suppress("UNCHECKED_CAST")
internal class ConcurrentWeakMap<K : Any, V : Any>(
    /**
     * Weak reference queue is needed when a small key is mapped to a large value, and we need to promptly release a
     * reference to the value when the key was already disposed.
     */
    weakRefQueue: Boolean = false
) : AbstractMutableMap<K, V>() {
    // Number of mappings with a non-null value (see `size`).
    private val _size = atomic(0)

    // Current generation of the hash table; replaced under the monitor by putSynchronized after a rehash.
    private val core = atomic(Core(MIN_CAPACITY))

    // Receives HashedWeakRef keys once their referents are collected; only created when requested.
    private val weakRefQueue: ReferenceQueue<K>? = if (weakRefQueue) ReferenceQueue() else null

    /**
     * Number of mappings. Entries whose keys were garbage-collected stay counted until a probe, a rehash,
     * or the weak reference cleaning loop removes them.
     */
    override val size: Int
        get() = _size.value

    private fun decrementSize() { _size.decrementAndGet() }

    /** Returns the value mapped to [key], or `null` if there is none. Never blocks. */
    override fun get(key: K): V? = core.value.getImpl(key)

    /**
     * Maps [key] to [value] and returns the previous value, or `null` if there was none (in which case [size]
     * grows by one). Falls back to the synchronized slow path when the current table needs a rehash.
     */
    override fun put(key: K, value: V): V? {
        var oldValue = core.value.putImpl(key, value)
        if (oldValue === REHASH) oldValue = putSynchronized(key, value)
        if (oldValue == null) _size.incrementAndGet()
        return oldValue as V?
    }

    /**
     * Removes the mapping for [key] and returns its value, or `null` if there was none. Only the value is cleared;
     * the key's slot stays reserved in the current table until the next rehash.
     */
    override fun remove(key: K): V? {
        var oldValue = core.value.putImpl(key, null)
        if (oldValue === REHASH) oldValue = putSynchronized(key, null)
        if (oldValue != null) decrementSize()
        return oldValue as V?
    }

    /**
     * Slow path of [put]/[remove]: under this map's monitor, rehashes into a new core (sized from the current
     * [size]) and installs it, repeating until the operation succeeds. Never returns [REHASH].
     */
    @Synchronized
    private fun putSynchronized(key: K, value: V?): V? {
        // Note: concurrent put leaves chance that we fail to put even after rehash, we retry until successful
        var curCore = core.value
        while (true) {
            val oldValue = curCore.putImpl(key, value)
            if (oldValue !== REHASH) return oldValue as V?
            curCore = curCore.rehash()
            core.value = curCore
        }
    }

    /** View of the keys; each iterator walks the table that is current when the iterator is created. */
    override val keys: MutableSet<K>
        get() = KeyValueSet { k, _ -> k }

    /** View of the mappings; each iterator walks the table that is current when the iterator is created. */
    override val entries: MutableSet<MutableMap.MutableEntry<K, V>>
        get() = KeyValueSet { k, v -> Entry(k, v) }

    // We don't care much about clear's efficiency
    override fun clear() {
        for (k in keys) remove(k)
    }

    /**
     * Blocks the calling thread, removing each entry whose key is delivered by the weak reference queue,
     * until the thread is interrupted. The interrupt flag is restored before returning.
     *
     * @throws IllegalStateException if this map was created with `weakRefQueue = false`.
     */
    fun runWeakRefQueueCleaningLoopUntilInterrupted() {
        check(weakRefQueue != null) { "Must be created with weakRefQueue = true" }
        try {
            while (true) {
                cleanWeakRef(weakRefQueue.remove() as HashedWeakRef<*>)
            }
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
        }
    }

    private fun cleanWeakRef(w: HashedWeakRef<*>) {
        core.value.cleanWeakRef(w)
    }

    /**
     * One generation of the hash table with a fixed number of slots.
     *
     * Slot `i` is the pair `keys[i]` / `values[i]`. Key slots are reserved with CAS and are never cleared within
     * a core: removing a mapping (or cleaning a collected key) only nulls the value, and such slots are dropped
     * only by [rehash], which does not copy them. [rehash] wraps every value in [Marked] to freeze it while copying.
     *
     * @param allocated number of slots. It must be a power of two for [index] to stay in bounds; cores are only
     *   created with `MIN_CAPACITY` slots or with the power-of-two capacity computed by [rehash].
     */
    @Suppress("UNCHECKED_CAST")
    private inner class Core(private val allocated: Int) {
        // With allocated == 2^p this is 32 - p, so index() keeps the top p bits of the scrambled hash.
        private val shift = allocated.countLeadingZeroBits() + 1
        private val threshold = 2 * allocated / 3 // max fill factor at 66% to ensure speedy lookups
        private val load = atomic(0) // counts how many slots are occupied in this core
        private val keys = atomicArrayOfNulls<HashedWeakRef<K>?>(allocated)
        private val values = atomicArrayOfNulls<Any?>(allocated)

        // Multiplicative (golden-ratio) hashing: scramble the hash and take its top bits as the starting slot.
        private fun index(hash: Int) = (hash * MAGIC) ushr shift

        /**
         * Lock-free lookup. Returns the value for [key], unwrapping it if a concurrent rehash has marked it,
         * or `null` when an empty slot is reached before the key. Cleans collected keys met along the way.
         */
        fun getImpl(key: K): V? {
            var index = index(key.hashCode())
            while (true) {
                val w = keys[index].value ?: return null // not found
                val k = w.get()
                if (key == k) {
                    val value = values[index].value
                    return (if (value is Marked) value.ref else value) as V?
                }
                if (k == null) removeCleanedAt(index) // weak ref was here, but collected
                if (index == 0) index = allocated
                index--
            }
        }

        /**
         * Clears the value at [index], whose key was collected, and decrements the map size if this call is the
         * one that cleared it. Does nothing if the value is already `null` or marked.
         */
        private fun removeCleanedAt(index: Int) {
            while (true) {
                val oldValue = values[index].value ?: return // return when already removed
                if (oldValue is Marked) return // cannot remove marked (rehash is working on it, will not copy)
                if (values[index].compareAndSet(oldValue, null)) { // removed
                    decrementSize()
                    return
                }
            }
        }

        /**
         * Lock-free put (or remove, when [value] is `null`).
         *
         * @param weakKey0 an existing weak reference to [key] to store instead of allocating a new one. [rehash]
         *   passes the old reference so that [cleanWeakRef], which matches references by identity, keeps working.
         * @return the previous value (or `null`), or [REHASH] if the value was not stored because this core is
         *   full or is being rehashed.
         */
        // returns REHASH when rehash is needed (the value was not put)
        fun putImpl(key: K, value: V?, weakKey0: HashedWeakRef<K>? = null): Any? {
            var index = index(key.hashCode())
            var loadIncremented = false
            var weakKey: HashedWeakRef<K>? = weakKey0
            while (true) {
                val w = keys[index].value
                if (w == null) { // slot empty => not found => try reserving slot
                    if (value == null) return null // removing missing value, nothing to do here
                    if (!loadIncremented) {
                        // We must increment load before we even try to occupy a slot to avoid overfill during concurrent put
                        load.update { n ->
                            if (n >= threshold) return REHASH // the load is already too big -- rehash
                            n + 1 // otherwise increment
                        }
                        loadIncremented = true
                    }
                    if (weakKey == null) weakKey = HashedWeakRef(key, weakRefQueue)
                    if (keys[index].compareAndSet(null, weakKey)) break // slot reserved !!!
                    continue // retry at this slot on CAS failure (somebody already reserved this slot)
                }
                val k = w.get()
                if (key == k) { // found already reserved slot at index
                    if (loadIncremented) load.decrementAndGet() // undo increment, because found a slot
                    break
                }
                if (k == null) removeCleanedAt(index) // weak ref was here, but collected
                if (index == 0) index = allocated
                index--
            }
            // update value
            var oldValue: Any?
            while (true) {
                oldValue = values[index].value
                if (oldValue is Marked) return REHASH // rehash started, cannot work here
                if (values[index].compareAndSet(oldValue, value)) break
            }
            return oldValue as V?
        }

        /**
         * Freezes every value of this core by marking it and copies all live mappings (reachable key, non-null
         * value) into a new core, which is returned. Starts over with a freshly sized core if the estimate from
         * [size] turns out too small.
         */
        // only one thread can rehash, but may have concurrent puts/gets
        fun rehash(): Core {
            // use size to approximate new required capacity to have at least 25-50% fill factor,
            // may fail due to concurrent modification, will retry
            retry@while (true) {
                val newCapacity = size.coerceAtLeast(MIN_CAPACITY / 4).takeHighestOneBit() * 4
                val newCore = Core(newCapacity)
                for (index in 0 until allocated) {
                    // load the key
                    val w = keys[index].value
                    val k = w?.get()
                    if (w != null && k == null) removeCleanedAt(index) // weak ref was here, but collected
                    // mark value so that it cannot be changed while we rehash to new core
                    var value: Any?
                    while (true) {
                        value = values[index].value
                        if (value is Marked) { // already marked -- good
                            value = value.ref
                            break
                        }
                        // try mark
                        if (values[index].compareAndSet(value, value.mark())) break
                    }
                    if (k != null && value != null) {
                        val oldValue = newCore.putImpl(k, value as V, w)
                        if (oldValue === REHASH) continue@retry // retry if we underestimated capacity
                        assert(oldValue == null)
                    }
                }
                return newCore // rehashed everything successfully
            }
        }

        /**
         * Finds the slot holding exactly [weakRef] (by identity, starting from the hash captured in the reference)
         * and clears its value. Does nothing if an empty slot is reached first.
         */
        fun cleanWeakRef(weakRef: HashedWeakRef<*>) {
            var index = index(weakRef.hash)
            while (true) {
                val w = keys[index].value ?: return // return when slots are over
                if (w === weakRef) { // found
                    removeCleanedAt(index)
                    return
                }
                if (index == 0) index = allocated
                index--
            }
        }

        /** Iterator over the live mappings of this core, each converted with [factory]. */
        fun <E> keyValueIterator(factory: (K, V) -> E): MutableIterator<E> = KeyValueIterator(factory)

        /**
         * Walks the slots in ascending order and yields slots whose key is still reachable and whose value is
         * non-null (unwrapping marked values). The next element is looked up eagerly: in `init` and after each
         * [next]. `index == allocated` means the iteration is exhausted.
         */
        private inner class KeyValueIterator<E>(private val factory: (K, V) -> E) : MutableIterator<E> {
            private var index = -1
            private lateinit var key: K
            private lateinit var value: V

            init { findNext() }

            private fun findNext() {
                while (++index < allocated) {
                    key = keys[index].value?.get() ?: continue
                    var value = values[index].value
                    if (value is Marked) value = value.ref
                    if (value != null) {
                        this.value = value as V
                        return
                    }
                }
            }

            override fun hasNext(): Boolean = index < allocated

            override fun next(): E {
                if (index >= allocated) throw NoSuchElementException()
                return factory(key, value).also { findNext() }
            }

            override fun remove() = noImpl()
        }
    }

    /**
     * Immutable snapshot of one mapping. Equality and hash code follow the `java.util.Map.Entry` contract, so
     * entries compare equal to any other map's entry with the same key and value.
     */
    private class Entry<K, V>(override val key: K, override val value: V) : MutableMap.MutableEntry<K, V> {
        override fun setValue(newValue: V): V = noImpl()

        override fun equals(other: Any?): Boolean =
            other is Map.Entry<*, *> && key == other.key && value == other.value

        override fun hashCode(): Int = key.hashCode() xor value.hashCode()

        override fun toString(): String = "$key=$value"
    }

    /** Set view backed by this map; elements are produced from each mapping with [factory]. */
    private inner class KeyValueSet<E>(
        private val factory: (K, V) -> E
    ) : AbstractMutableSet<E>() {
        override val size: Int get() = this@ConcurrentWeakMap.size
        override fun add(element: E): Boolean = noImpl()
        override fun iterator(): MutableIterator<E> = core.value.keyValueIterator(factory)
    }
}
251 
// Multiplier for multiplicative (Fibonacci) hashing: 2^32 / golden ratio = 2654435769 = 0x9E3779B9, as an Int.
private const val MAGIC = 0x9E3779B9L.toInt() // golden ratio

// Slot count of the initial table; a power of two, as the table's index computation requires.
private const val MIN_CAPACITY = 16

// Sentinel returned by a table's put when the value was not stored and a rehash is needed.
private val REHASH = Symbol("REHASH")

// Shared marks for the most common values, so that marking them during rehash does not allocate.
private val MARKED_NULL = Marked(null)
private val MARKED_TRUE = Marked(true) // When using map as set "true" used as value, optimize its mark allocation
257 
/**
 * A [WeakReference] that remembers its referent's hash code, captured at construction time.
 *
 * After the referent is collected, `get()` returns `null`, but [hash] still locates the hashtable slot that holds
 * this reference. The map can therefore clean that slot as soon as the reference is delivered by its
 * [ReferenceQueue], even while no other modifications are happening.
 *
 * @param ref the referent; its `hashCode()` is stored in [hash].
 * @param queue the queue to register with, or `null` for no registration.
 */
internal class HashedWeakRef<T>(ref: T, queue: ReferenceQueue<T>?) : WeakReference<T>(ref, queue) {
    /** Hash code of the referent, taken while it was still strongly reachable. */
    @JvmField
    val hash: Int

    init {
        hash = ref.hashCode()
    }
}
268 
/**
 * Wrapper that freezes a slot's value once a rehash has started.
 *
 * A marked value must not be modified: lock-free writers that encounter one are forced to synchronize with the
 * ongoing rehash instead. Readers obtain the actual value from [ref].
 */
private class Marked(ref: Any?) {
    /** The frozen value; may be `null`. */
    @JvmField
    val ref: Any? = ref
}
274 
/**
 * Wraps the receiver into a [Marked]. `null` and `true` (the value used when the map serves as a set) reuse
 * shared instances; every other value gets a fresh wrapper.
 */
private fun Any?.mark(): Marked = when {
    this == null -> MARKED_NULL
    this == true -> MARKED_TRUE
    else -> Marked(this)
}
280 
/** Always throws [UnsupportedOperationException]; stands in for operations the map and its views do not support. */
private fun noImpl(): Nothing =
    throw UnsupportedOperationException("not implemented")
284