1 /*
2 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3 */
4
5 package kotlinx.coroutines.internal
6
7 import kotlinx.coroutines.*
8 import kotlin.coroutines.*
9
/**
 * Sentinel returned by [updateThreadContext] when the context contains no ThreadContextElements.
 * Passing it back to [restoreThreadContext] makes that call a no-op.
 */
@JvmField
internal val NO_THREAD_ELEMENTS: Symbol = Symbol("NO_THREAD_ELEMENTS")
12
/**
 * Holds the previous states of all ThreadContextElements in [context].
 *
 * Only used when the context has two or more active elements; with zero or one element no
 * instance is allocated. [n] is the capacity: the number of elements expected to be appended.
 */
@Suppress("UNCHECKED_CAST")
private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
    // Parallel arrays: savedValues[k] is the old state returned when savedElements[k] was applied.
    private val savedValues = arrayOfNulls<Any>(n)
    private val savedElements = arrayOfNulls<ThreadContextElement<Any?>>(n)
    // Index of the next free slot.
    private var count = 0

    /** Records [element] together with the old state ([value]) it produced. */
    fun append(element: ThreadContextElement<*>, value: Any?) {
        savedValues[count] = value
        savedElements[count] = element as ThreadContextElement<Any?>
        count++
    }

    /** Restores every recorded element, walking the slots from last to first. */
    fun restore(context: CoroutineContext) {
        for (slot in savedElements.lastIndex downTo 0) {
            savedElements[slot]!!.restoreThreadContext(context, savedValues[slot])
        }
    }
}
31
// Fold operation that counts the ThreadContextElements in a context.
// The accumulator is Int | ThreadContextElement: while exactly one element has been seen,
// the element itself is stored instead of the count 1.
private val countAll =
    fun (acc: Any?, element: CoroutineContext.Element): Any? {
        if (element !is ThreadContextElement<*>) return acc
        // A non-Int accumulator means a single element is currently stored, i.e. a count of one.
        val seen = (acc as? Int) ?: 1
        return when (seen) {
            0 -> element
            else -> seen + 1
        }
    }
42
// Fold operation that keeps the first ThreadContextElement encountered in the context.
// Used by restoreThreadContext when the context is known to hold exactly one such element.
private val findOne =
    fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? =
        found ?: (element as? ThreadContextElement<*>)
49
// Fold operation that applies each ThreadContextElement to the current thread and records the
// old state it returned in the accumulating ThreadState. Other elements are passed over.
private val updateState =
    fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
        if (element !is ThreadContextElement<*>) return state
        val previous = element.updateThreadContext(state.context)
        state.append(element, previous)
        return state
    }
58
/**
 * Counts the ThreadContextElements in [context].
 *
 * Returns `0` when there are none, the element itself when there is exactly one, and the count
 * (an Int >= 2) otherwise.
 */
internal actual fun threadContextElements(context: CoroutineContext): Any {
    // Starting from 0, every fold step yields an Int or an element, so the result is never null.
    return context.fold<Any?>(0, countAll)!!
}
60
/**
 * Applies every ThreadContextElement of [context] to the current thread.
 *
 * [countOrElement] is the pre-computed result of [threadContextElements] (for example, cached in a
 * dispatched continuation). When it is `null`, it is computed here.
 *
 * The return value is meant to be passed to [restoreThreadContext]. It is:
 * - [NO_THREAD_ELEMENTS] if the context does not have any ThreadContextElements;
 * - the old state returned by the single element, if there is exactly one;
 * - a ThreadState holding all old states, if there are several.
 */
@Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS", "UNCHECKED_CAST")
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
    val resolved: Any = countOrElement ?: threadContextElements(context)
    // Very fast path: identity comparison is safe because boxed zero is always the same instance.
    if (resolved === 0) return NO_THREAD_ELEMENTS
    if (resolved !is Int) {
        // Fast path for exactly one element: no allocation and no additional context scan.
        return (resolved as ThreadContextElement<Any?>).updateThreadContext(context)
    }
    // Slow path for several elements: allocate a ThreadState to keep every old value.
    return context.fold(ThreadState(context, resolved), updateState)
}
82
/**
 * Undoes the thread changes made by [updateThreadContext] for [context].
 *
 * [oldState] is expected to be the value [updateThreadContext] returned. The three shapes it can
 * take (the no-elements sentinel, a ThreadState, or a single element's old state) are handled
 * separately below.
 */
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
    // Very fast path: nothing was changed.
    if (oldState === NO_THREAD_ELEMENTS) return
    if (oldState is ThreadState) {
        // Slow path: several elements were stored.
        oldState.restore(context)
        return
    }
    // Fast path for one element; it was not stored, so locate it in the context again.
    @Suppress("UNCHECKED_CAST")
    val single = context.fold(null, findOne) as ThreadContextElement<Any?>
    single.restoreThreadContext(context, oldState)
}
98
/**
 * Context key for [ThreadLocalElement].
 *
 * It is a top-level data class so that it has a readable out-of-the-box `toString` and class name.
 * The generated `equals` makes keys wrapping the same [ThreadLocal] instance compare equal.
 */
@PublishedApi
internal data class ThreadLocalKey(
    private val threadLocal: ThreadLocal<*>
) : CoroutineContext.Key<ThreadLocalElement<*>>
102
/**
 * ThreadContextElement bound to a single [ThreadLocal].
 *
 * [updateThreadContext] stores [value] into [threadLocal] and returns the value that was there
 * before. [restoreThreadContext] writes that previous value back.
 */
internal class ThreadLocalElement<T>(
    private val value: T,
    private val threadLocal: ThreadLocal<T>
) : ThreadContextElement<T> {
    override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal)

    /** Installs [value] and returns the previous content of [threadLocal]. */
    override fun updateThreadContext(context: CoroutineContext): T =
        threadLocal.get().also { threadLocal.set(value) }

    /** Puts [oldState] (as returned by [updateThreadContext]) back into [threadLocal]. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: T) =
        threadLocal.set(oldState)

    // Overridden to compare the key by value (==) rather than relying on the inherited implementation.
    override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext =
        if (this.key != key) this else EmptyCoroutineContext

    // Overridden to compare the key by value (==) rather than relying on the inherited implementation.
    @Suppress("UNCHECKED_CAST")
    public override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? =
        if (this.key != key) null else this as E

    override fun toString(): String = "ThreadLocal(value=$value, threadLocal = $threadLocal)"
}
131