• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
4 
5 package kotlinx.coroutines.debug.internal
6 
7 import kotlinx.atomicfu.*
8 import kotlinx.coroutines.internal.*
9 import java.lang.ref.*
10 
/**
 * A concurrent hash map whose keys are held through weak references.
 *
 * This is a deliberately limited implementation, not suitable as a generic map replacement:
 * - [get], [put] and [remove] are lock-free; only rehashing (see [putSynchronized]) takes the map's monitor,
 *   which keeps the code simple and gives better CPU usage on contention.
 * - Entries whose keys were garbage-collected are dropped lazily: when a lookup or insertion probes over them,
 *   during rehash, or promptly via [runWeakRefQueueCleaningLoopUntilInterrupted].
 * - The [keys] and [entries] views are read-only. They do not support adding, removing (including
 *   `Iterator.remove`), or [MutableMap.MutableEntry.setValue].
 * - Iterators walk the table that was current when they were created. They may or may not reflect
 *   modifications made concurrently.
 */
@OptIn(ExperimentalStdlibApi::class)
@Suppress("UNCHECKED_CAST")
internal class ConcurrentWeakMap<K : Any, V: Any>(
    /**
     * Weak reference queue is needed when a small key is mapped to a large value and we need to promptly release a
     * reference to the value when the key was already disposed.
     */
    weakRefQueue: Boolean = false
) : AbstractMutableMap<K, V>() {
    private val _size = atomic(0)
    private val core = atomic(Core(MIN_CAPACITY))
    private val weakRefQueue: ReferenceQueue<K>? = if (weakRefQueue) ReferenceQueue() else null

    /**
     * Number of mappings. Under concurrent modification this is only approximate. It still counts entries whose
     * keys were collected until those entries are cleaned up.
     */
    override val size: Int
        get() = _size.value

    private fun decrementSize() { _size.decrementAndGet() }

    /** Returns the value mapped to [key], or `null` if there is none. Lock-free. */
    override fun get(key: K): V? = core.value.getImpl(key)

    /**
     * Returns `true` if [key] is currently mapped.
     *
     * Values are never `null` (`V : Any`), so a non-null [get] result is equivalent to the key being present.
     * This avoids the linear scan over [entries] that the inherited `java.util.AbstractMap.containsKey` performs.
     */
    override fun containsKey(key: K): Boolean = get(key) != null

    /**
     * Associates [value] with [key] and returns the previous value, or `null` if there was none.
     *
     * Lock-free unless the current table must be (or is being) rehashed. In that case the update is retried under
     * the map's monitor in [putSynchronized].
     */
    override fun put(key: K, value: V): V? {
        var oldValue = core.value.putImpl(key, value)
        if (oldValue === REHASH) oldValue = putSynchronized(key, value)
        if (oldValue == null) _size.incrementAndGet()
        return oldValue as V?
    }

    /**
     * Removes the mapping for [key] by storing a `null` value in its slot.
     * Returns the removed value, or `null` if there was none.
     */
    override fun remove(key: K): V? {
        var oldValue = core.value.putImpl(key, null)
        if (oldValue === REHASH) oldValue = putSynchronized(key, null)
        if (oldValue != null) decrementSize()
        return oldValue as V?
    }

    /**
     * Slow path of [put]/[remove]. It retries the update under the map's monitor, rehashing into a new [Core]
     * until the update succeeds. Because this is synchronized, only one thread rehashes at a time.
     */
    @Synchronized
    private fun putSynchronized(key: K, value: V?): V? {
        // Note: concurrent put leaves chance that we fail to put even after rehash, we retry until successful
        var curCore = core.value
        while (true) {
            val oldValue = curCore.putImpl(key, value)
            if (oldValue !== REHASH) return oldValue as V?
            curCore = curCore.rehash()
            core.value = curCore
        }
    }

    /** Read-only view of the keys that are still reachable and have a value. */
    override val keys: MutableSet<K>
        get() = KeyValueSet { k, _ -> k }

    /** Read-only view of the mappings; each entry is an immutable snapshot of a key and its value. */
    override val entries: MutableSet<MutableMap.MutableEntry<K, V>>
        get() = KeyValueSet { k, v -> Entry(k, v) }

    // We don't care much about clear's efficiency
    override fun clear() {
        for (k in keys) remove(k)
    }

    /**
     * Blocks the calling thread. Each key reference taken from the reference queue has its entry removed.
     * Returns when the thread is interrupted, restoring the thread's interrupt flag.
     *
     * @throws IllegalStateException if the map was created with `weakRefQueue = false`.
     */
    fun runWeakRefQueueCleaningLoopUntilInterrupted() {
        check(weakRefQueue != null) { "Must be created with weakRefQueue = true" }
        try {
            while (true) {
                cleanWeakRef(weakRefQueue.remove() as HashedWeakRef<*>)
            }
        } catch (e: InterruptedException) {
            // Preserve the interrupt status for the caller.
            Thread.currentThread().interrupt()
        }
    }

    private fun cleanWeakRef(w: HashedWeakRef<*>) {
        core.value.cleanWeakRef(w)
    }

    /**
     * One generation of the hash table.
     *
     * Uses open addressing with linear probing: each probe moves downwards and wraps from `0` to `allocated - 1`.
     * A [Core] never grows. When it needs rehashing, [rehash] copies the live entries into a fresh [Core]
     * (sized from the current map [size]), which then replaces it.
     */
    @Suppress("UNCHECKED_CAST")
    private inner class Core(private val allocated: Int) {
        // `allocated` is a power of two (MIN_CAPACITY, or highestOneBit * 4 in rehash). So the top
        // (32 - shift) == log2(allocated) bits of the hash product form a valid index.
        private val shift = allocated.countLeadingZeroBits() + 1
        private val threshold = 2 * allocated / 3 // max fill factor at 66% to ensure speedy lookups
        private val load = atomic(0) // counts how many slots are occupied in this core
        // A key slot, once reserved, is never cleared in this core; removal only nulls out the value slot.
        private val keys = atomicArrayOfNulls<HashedWeakRef<K>?>(allocated)
        private val values = atomicArrayOfNulls<Any?>(allocated)

        // Multiplicative (Fibonacci) hashing: multiply by MAGIC and keep the top bits.
        private fun index(hash: Int) = (hash * MAGIC) ushr shift

        // get is always lock-free, unwraps the value that was marked by concurrent rehash.
        // Probes until it finds the key or an empty key slot; the fill threshold guarantees an empty slot exists.
        fun getImpl(key: K): V? {
            var index = index(key.hashCode())
            while (true) {
                val w = keys[index].value ?: return null // not found
                val k = w.get()
                if (key == k) {
                    val value = values[index].value
                    return (if (value is Marked) value.ref else value) as V?
                }
                if (k == null) removeCleanedAt(index) // weak ref was here, but collected
                if (index == 0) index = allocated
                index--
            }
        }

        // Clears the value at `index` whose key was collected. Only the thread whose CAS succeeds decrements the
        // map size, so the size is decremented at most once per removed value.
        private fun removeCleanedAt(index: Int) {
            while (true) {
                val oldValue = values[index].value ?: return // return when already removed
                if (oldValue is Marked) return // cannot remove marked (rehash is working on it, will not copy)
                if (values[index].compareAndSet(oldValue, null)) { // removed
                    decrementSize()
                    return
                }
            }
        }

        // Stores `value` for `key`; a `null` value removes the mapping. `weakKey0` lets rehash reuse the
        // existing weak reference instead of allocating a new one.
        // Returns the previous value (possibly null), or REHASH when rehash is needed (the value was not put).
        fun putImpl(key: K, value: V?, weakKey0: HashedWeakRef<K>? = null): Any? {
            var index = index(key.hashCode())
            var loadIncremented = false
            var weakKey: HashedWeakRef<K>? = weakKey0
            while (true) {
                val w = keys[index].value
                if (w == null) { // slot empty => not found => try reserving slot
                    if (value == null) return null // removing missing value, nothing to do here
                    if (!loadIncremented) {
                        // We must increment load before we even try to occupy a slot to avoid overfill during concurrent put
                        load.update { n ->
                            if (n >= threshold) return REHASH // the load is already too big -- rehash
                            n + 1 // otherwise increment
                        }
                        loadIncremented = true
                    }
                    if (weakKey == null) weakKey = HashedWeakRef(key, weakRefQueue)
                    if (keys[index].compareAndSet(null, weakKey)) break // slot reserved !!!
                    continue // retry at this slot on CAS failure (somebody already reserved this slot)
                }
                val k = w.get()
                if (key == k) { // found already reserved slot at index
                    if (loadIncremented) load.decrementAndGet() // undo increment, because found a slot
                    break
                }
                if (k == null) removeCleanedAt(index) // weak ref was here, but collected
                if (index == 0) index = allocated
                index--
            }
            // update value
            var oldValue: Any?
            while (true) {
                oldValue = values[index].value
                if (oldValue is Marked) return REHASH // rehash started, cannot work here
                if (values[index].compareAndSet(oldValue, value)) break
            }
            return oldValue as V?
        }

        // only one thread can rehash, but may have concurrent puts/gets.
        // Marks every value slot so that lock-free writers to this core fail with REHASH,
        // then copies the live entries into the new core.
        fun rehash(): Core {
            // use size to approximate new required capacity to have at least 25-50% fill factor,
            // may fail due to concurrent modification, will retry
            retry@while (true) {
                val newCapacity = size.coerceAtLeast(MIN_CAPACITY / 4).takeHighestOneBit() * 4
                val newCore = Core(newCapacity)
                for (index in 0 until allocated) {
                    // load the key
                    val w = keys[index].value
                    val k = w?.get()
                    if (w != null && k == null) removeCleanedAt(index) // weak ref was here, but collected
                    // mark value so that it cannot be changed while we rehash to new core
                    var value: Any?
                    while (true) {
                        value = values[index].value
                        if (value is Marked) { // already marked -- good
                            value = value.ref
                            break
                        }
                        // try mark
                        if (values[index].compareAndSet(value, value.mark())) break
                    }
                    if (k != null && value != null) {
                        val oldValue = newCore.putImpl(k, value as V, w)
                        if (oldValue === REHASH) continue@retry // retry if we underestimated capacity
                        assert(oldValue == null)
                    }
                }
                return newCore // rehashed everything successfully
            }
        }

        // Removes the entry whose key reference is exactly `weakRef` (identity comparison).
        // It probes from the bucket given by the hash that was stored when the reference was created.
        fun cleanWeakRef(weakRef: HashedWeakRef<*>) {
            var index = index(weakRef.hash)
            while (true) {
                val w = keys[index].value ?: return // return when slots are over
                if (w === weakRef) { // found
                    removeCleanedAt(index)
                    return
                }
                if (index == 0) index = allocated
                index--
            }
        }

        fun <E> keyValueIterator(factory: (K, V) -> E): MutableIterator<E> = KeyValueIterator(factory)

        // Iterates slots in ascending index order. It yields only slots whose key is still reachable and whose
        // value is non-null, unwrapping marked values. The next element is prefetched so that hasNext is a plain
        // bounds check.
        private inner class KeyValueIterator<E>(private val factory: (K, V) -> E) : MutableIterator<E> {
            private var index = -1
            private lateinit var key: K
            private lateinit var value: V

            init { findNext() }

            private fun findNext() {
                while (++index < allocated) {
                    key = keys[index].value?.get() ?: continue
                    var value = values[index].value
                    if (value is Marked) value = value.ref
                    if (value != null) {
                        this.value = value as V
                        return
                    }
                }
            }

            override fun hasNext(): Boolean = index < allocated

            override fun next(): E {
                if (index >= allocated) throw NoSuchElementException()
                return factory(key, value).also { findNext() }
            }

            override fun remove() = noImpl()
        }
    }

    /**
     * Immutable snapshot of a mapping, as returned by [entries].
     *
     * It implements the `java.util.Map.Entry` contract:
     * - Two entries are equal when their keys and values are equal.
     * - The hash code is `key.hashCode() xor value.hashCode()`.
     *
     * This makes entry-set lookups and comparisons work. It also makes the inherited `AbstractMap.hashCode()` of
     * this map (the sum of entry hash codes) deterministic and consistent with `equals`.
     */
    private class Entry<K, V>(override val key: K, override val value: V) : MutableMap.MutableEntry<K, V> {
        override fun setValue(newValue: V): V = noImpl()

        override fun equals(other: Any?): Boolean =
            other is Map.Entry<*, *> && key == other.key && value == other.value

        override fun hashCode(): Int = key.hashCode() xor value.hashCode()

        override fun toString(): String = "$key=$value"
    }

    // Read-only set view over the current core. `factory` turns each live key/value pair into an element.
    private inner class KeyValueSet<E>(
        private val factory: (K, V) -> E
    ) : AbstractMutableSet<E>() {
        override val size: Int get() = this@ConcurrentWeakMap.size
        override fun add(element: E): Boolean = noImpl()
        override fun iterator(): MutableIterator<E> = core.value.keyValueIterator(factory)
    }
}
252 
/**
 * Multiplier for multiplicative (Fibonacci) hashing in `Core.index`.
 * It is 2^32 divided by the golden ratio (0x9E3779B9 == 2654435769), reinterpreted as a (negative) [Int].
 */
private const val MAGIC = 0x9E3779B9.toInt()

/** Capacity of a freshly created table; a power of two, as `Core.index` requires. */
private const val MIN_CAPACITY = 16

/** Sentinel returned by `putImpl` when the value was not stored and the table has to be rehashed first. */
private val REHASH = Symbol("REHASH")

// Preallocated marks for the two most common slot values, so that marking them during rehash does not allocate.
private val MARKED_NULL = Marked(null)
private val MARKED_TRUE = Marked(true) // "true" is the value stored when the map is used as a set
258 
/**
 * A [WeakReference] that remembers the hash code of its referent from construction time.
 *
 * Because [hash] remains available after the referent is garbage-collected, a reference taken from a
 * [ReferenceQueue] can still be located in the hash table (by probing from its original bucket). Its entry can then
 * be removed promptly, even when no other modifications of the table are in progress.
 */
internal class HashedWeakRef<T>(ref: T, queue: ReferenceQueue<T>?) : WeakReference<T>(ref, queue) {
    /** Hash code of the referent (0 for a `null` referent), captured eagerly so it survives collection. */
    @JvmField
    val hash: Int = ref.hashCode()
}
269 
/**
 * Wrapper stored in a value slot once rehash has started on it.
 *
 * Lock-free writers refuse to overwrite a [Marked] value. This forces concurrent modifications to go through the
 * synchronized rehash path. Readers unwrap the original value through [ref].
 */
private class Marked(
    /** The slot value at the moment it was marked; may be `null`. */
    @JvmField val ref: Any?
)
275 
/**
 * Wraps this slot value in a [Marked].
 *
 * `null` and `true` (the value used when the map serves as a set) reuse shared, preallocated marks, so marking
 * them does not allocate.
 */
private fun Any?.mark(): Marked = when {
    this == null -> MARKED_NULL
    this == true -> MARKED_TRUE
    else -> Marked(this)
}
281 
/** Always throws [UnsupportedOperationException]; used for the mutations the map's views and entries reject. */
private fun noImpl(): Nothing =
    throw UnsupportedOperationException("not implemented")
285