1 package kotlinx.coroutines.internal
2
3 import kotlinx.coroutines.*
4 import kotlin.coroutines.*
5
/**
 * Marker that [updateThreadContext] returns when the context has no [ThreadContextElement];
 * [restoreThreadContext] recognises it by identity and does nothing.
 */
@JvmField
internal val NO_THREAD_ELEMENTS: Symbol = Symbol("NO_THREAD_ELEMENTS")
8
/**
 * Bookkeeping for the case of two or more active [ThreadContextElement]s in a context.
 *
 * Elements and the values returned by their `updateThreadContext` are stored side by side,
 * in the order in which [append] was called, and are handed back by [restore] in reverse order.
 *
 * @param context the coroutine context whose elements are being tracked
 * @param n capacity of the backing arrays, i.e. the number of elements that will be appended
 */
@Suppress("UNCHECKED_CAST")
private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
    private val savedValues = arrayOfNulls<Any>(n)
    private val trackedElements = arrayOfNulls<ThreadContextElement<Any?>>(n)
    private var size = 0

    /** Records [element] together with the [value] that has to be passed back to it on restore. */
    fun append(element: ThreadContextElement<*>, value: Any?) {
        val slot = size
        savedValues[slot] = value
        trackedElements[slot] = element as ThreadContextElement<Any?>
        size = slot + 1
    }

    /** Invokes `restoreThreadContext` on every stored element, walking from the last slot to the first. */
    fun restore(context: CoroutineContext) {
        for (slot in trackedElements.lastIndex downTo 0) {
            // Every slot is expected to be filled by append; an empty slot fails here with an NPE.
            val element = trackedElements[slot]!!
            element.restoreThreadContext(context, savedValues[slot])
        }
    }
}
27
// Fold operation that counts the ThreadContextElements of a context.
// The accumulator (Any?) is either an Int count or, when exactly one element
// has been seen so far, that ThreadContextElement itself.
private val countAll: (Any?, CoroutineContext.Element) -> Any? = { countOrElement, element ->
    if (element !is ThreadContextElement<*>) {
        // Not a thread context element: the accumulator passes through untouched.
        countOrElement
    } else {
        // A non-Int accumulator means a single element has been seen already, i.e. a count of one.
        val seenSoFar = (countOrElement as? Int) ?: 1
        if (seenSoFar == 0) element else seenSoFar + 1
    }
}
38
// Fold operation that yields the first ThreadContextElement met in the context;
// it is used when the context is known to contain exactly one such element.
private val findOne: (ThreadContextElement<*>?, CoroutineContext.Element) -> ThreadContextElement<*>? =
    { found, element -> found ?: (element as? ThreadContextElement<*>) }
45
// Fold operation that calls updateThreadContext on every ThreadContextElement of the context
// and records the element with the value it returned in the given ThreadState.
private val updateState: (ThreadState, CoroutineContext.Element) -> ThreadState = { state, element ->
    if (element is ThreadContextElement<*>) {
        val previousValue = element.updateThreadContext(state.context)
        state.append(element, previousValue)
    }
    state
}
54
/**
 * Folds [context] with [countAll], starting from zero.
 *
 * @return the resulting accumulator: an `Int` count, or the single [ThreadContextElement]
 * when the context contains exactly one
 */
internal actual fun threadContextElements(context: CoroutineContext): Any {
    val countOrElement = context.fold(0, countAll)
    return countOrElement!!
}
56
/**
 * Calls `updateThreadContext` on the [ThreadContextElement]s of [context] and returns the state
 * that must later be handed to [restoreThreadContext].
 *
 * @param countOrElement result of [threadContextElements] for this context (pre-cached in the
 * dispatched continuation); when `null` it is computed here
 * @return [NO_THREAD_ELEMENTS] if the context does not have any [ThreadContextElement]s,
 * a [ThreadState] when there are several of them, or the value returned by the only element
 */
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
    val cached = countOrElement ?: threadContextElements(context)

    // Very fast path: no active ThreadContextElements.
    // Identity comparison for speed, we know zero always has the same identity.
    @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
    val hasNoElements = cached === 0
    if (hasNoElements) return NO_THREAD_ELEMENTS

    // Slow path: multiple active elements, allocates a ThreadState to keep all the old values.
    if (cached is Int) return context.fold(ThreadState(context, cached), updateState)

    // Fast path: exactly one element (no allocations, no additional context scan).
    @Suppress("UNCHECKED_CAST")
    val single = cached as ThreadContextElement<Any?>
    return single.updateThreadContext(context)
}
78
/**
 * Undoes a previous [updateThreadContext] call for [context].
 *
 * @param oldState the value that [updateThreadContext] returned: [NO_THREAD_ELEMENTS],
 * a [ThreadState], or the old value of the single [ThreadContextElement]
 */
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
    // Very fast path: there were no ThreadContextElements, nothing to restore.
    if (oldState === NO_THREAD_ELEMENTS) return

    // Slow path: several elements, their old values are stored in the ThreadState.
    if (oldState is ThreadState) {
        oldState.restore(context)
        return
    }

    // Fast path for one element: oldState is its old value, but the element itself must be found.
    @Suppress("UNCHECKED_CAST")
    val single = context.fold(null, findOne) as ThreadContextElement<Any?>
    single.restoreThreadContext(context, oldState)
}
94
/**
 * Context key of a [ThreadLocalElement], wrapping the thread-local it belongs to.
 *
 * Kept as a top-level data class for a nicer out-of-the-box `toString` representation and class name.
 */
@PublishedApi
internal data class ThreadLocalKey(
    private val threadLocal: ThreadLocal<*>
) : CoroutineContext.Key<ThreadLocalElement<*>>
98
/**
 * [ThreadContextElement] that sets [threadLocal] to [value] in `updateThreadContext` and puts
 * the previously stored value back in `restoreThreadContext`.
 *
 * Its [key] is a [ThreadLocalKey] built from [threadLocal]; key lookups below use `==` on it.
 */
internal class ThreadLocalElement<T>(
    private val value: T,
    private val threadLocal: ThreadLocal<T>
) : ThreadContextElement<T> {
    override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal)

    /** Writes [value] into the thread-local and returns what was stored there before. */
    override fun updateThreadContext(context: CoroutineContext): T =
        threadLocal.get().also { threadLocal.set(value) }

    /** Writes [oldState] back into the thread-local. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: T) {
        threadLocal.set(oldState)
    }

    // Overridden so that the key is matched by value comparison (==).
    override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext =
        if (this.key == key) EmptyCoroutineContext else this

    // Overridden so that the key is matched by value comparison (==).
    @Suppress("UNCHECKED_CAST")
    public override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? =
        if (this.key == key) this as E else null

    override fun toString(): String = "ThreadLocal(value=$value, threadLocal = $threadLocal)"
}
127