• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 package kotlinx.coroutines.internal
2 
3 import kotlinx.coroutines.*
4 import kotlin.coroutines.*
5 
// Marker returned by updateThreadContext (and recognized by restoreThreadContext)
// when the context has no ThreadContextElements; it is compared by identity.
@JvmField
internal val NO_THREAD_ELEMENTS: Symbol = Symbol("NO_THREAD_ELEMENTS")
8 
// Used when there are >= 2 active elements in the context.
// Keeps each appended ThreadContextElement next to the value that must be handed back to it on restore;
// `n` is the capacity of both backing arrays.
@Suppress("UNCHECKED_CAST")
private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
    private val savedValues = arrayOfNulls<Any>(n)
    private val savedElements = arrayOfNulls<ThreadContextElement<Any?>>(n)
    private var size = 0

    // Records the element and its associated value in the next free slot
    fun append(element: ThreadContextElement<*>, value: Any?) {
        savedValues[size] = value
        savedElements[size] = element as ThreadContextElement<Any?>
        size++
    }

    // Walks the slots from last to first, passing each stored value back to its element
    fun restore(context: CoroutineContext) {
        for (slot in savedElements.lastIndex downTo 0) {
            savedElements[slot]!!.restoreThreadContext(context, savedValues[slot])
        }
    }
}
27 
// Fold operation that counts ThreadContextElements in the context.
// The accumulator (Any?) is Int | ThreadContextElement: it is the element itself when the count is one.
private val countAll =
    fun (countOrElement: Any?, element: CoroutineContext.Element): Any? {
        // Elements of other kinds leave the accumulator untouched
        if (element !is ThreadContextElement<*>) return countOrElement
        // A non-Int accumulator stands for a count of one
        val seenSoFar = countOrElement as? Int ?: 1
        return when (seenSoFar) {
            0 -> element // first one found: carry the element itself instead of the number 1
            else -> seenSoFar + 1
        }
    }
38 
// Fold operation that finds one (the first) ThreadContextElement in the context;
// it is used when we know there is exactly one. Once an element is found it is carried through unchanged.
private val findOne =
    fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? =
        found ?: (element as? ThreadContextElement<*>)
45 
// Fold operation that updates the thread context for every ThreadContextElement in the context,
// appending each element and the value returned by its update to the given ThreadState
private val updateState =
    fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
        if (element !is ThreadContextElement<*>) return state
        val valueToRestore = element.updateThreadContext(state.context)
        state.append(element, valueToRestore)
        return state
    }
54 
// Folds the context with countAll starting from zero; the result is either an Int count
// or the single ThreadContextElement itself (see countAll)
internal actual fun threadContextElements(context: CoroutineContext): Any {
    val countOrElement: Any? = context.fold(0, countAll)
    return countOrElement!!
}
56 
// countOrElement is pre-cached in dispatched continuation; when it is null it is computed here
// via threadContextElements(context).
// Returns NO_THREAD_ELEMENTS if the context does not have any ThreadContextElements,
// a ThreadState when there are several of them, or the single element's own update result otherwise.
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
    val resolved = countOrElement ?: threadContextElements(context)

    // very fast path when there are no active ThreadContextElements:
    // identity comparison for speed, we know zero always has the same identity
    @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
    val hasNoElements = resolved === 0
    if (hasNoElements) return NO_THREAD_ELEMENTS

    if (resolved is Int) {
        // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
        return context.fold(ThreadState(context, resolved), updateState)
    }

    // fast path for one ThreadContextElement (no allocations, no additional context scan)
    @Suppress("UNCHECKED_CAST")
    val single = resolved as ThreadContextElement<Any?>
    return single.updateThreadContext(context)
}
78 
// Restores the thread context from oldState, which is expected to be the value
// previously returned by updateThreadContext for the same context
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
    // very fast path when there are no ThreadContextElements
    if (oldState === NO_THREAD_ELEMENTS) return

    if (oldState is ThreadState) {
        // slow path with multiple stored ThreadContextElements
        oldState.restore(context)
        return
    }

    // fast path for one ThreadContextElement, but need to find it
    @Suppress("UNCHECKED_CAST")
    val single = context.fold(null, findOne) as ThreadContextElement<Any?>
    single.restoreThreadContext(context, oldState)
}
94 
// Context key wrapping a ThreadLocal.
// Declared as a top-level data class for a nicer out-of-the-box toString representation and class name.
@PublishedApi
internal data class ThreadLocalKey(
    private val threadLocal: ThreadLocal<*>
) : CoroutineContext.Key<ThreadLocalElement<*>>
98 
// ThreadContextElement that sets the given value into the given ThreadLocal on update
// and puts the previously held value back on restore
internal class ThreadLocalElement<T>(
    private val value: T,
    private val threadLocal: ThreadLocal<T>
) : ThreadContextElement<T> {
    override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal)

    // Reads the current thread-local value, then replaces it with this element's value;
    // the value that was read is returned so it can be restored later
    override fun updateThreadContext(context: CoroutineContext): T =
        threadLocal.get().also { threadLocal.set(value) }

    // Writes the previously saved value back into the thread-local
    override fun restoreThreadContext(context: CoroutineContext, oldState: T) {
        threadLocal.set(oldState)
    }

    // this method is overridden to perform value comparison (==) on key
    override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext =
        if (this.key != key) this else EmptyCoroutineContext

    // this method is overridden to perform value comparison (==) on key
    public override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? {
        if (this.key != key) return null
        @Suppress("UNCHECKED_CAST")
        return this as E
    }

    override fun toString(): String {
        return "ThreadLocal(value=$value, threadLocal = $threadLocal)"
    }
}
127