• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines.internal
6 
7 import kotlinx.coroutines.*
8 import kotlin.coroutines.*
9 
/**
 * Marker returned by [updateThreadContext] when the context has no [ThreadContextElement]s;
 * [restoreThreadContext] treats receiving it as "nothing to restore".
 */
@JvmField
internal val NO_THREAD_ELEMENTS: Symbol = Symbol("NO_THREAD_ELEMENTS")
12 
/**
 * Snapshot of previous thread-context values, used when a context holds two or more
 * [ThreadContextElement]s.
 *
 * @property context the context this state was built for; passed to each element's
 *   `updateThreadContext` while the state is being filled.
 * @param n number of [ThreadContextElement]s expected; sizes the backing arrays.
 */
@Suppress("UNCHECKED_CAST")
private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
    // Parallel arrays: previousValues[k] is the value recorded together with elements[k].
    private val previousValues = arrayOfNulls<Any>(n)
    private val elements = arrayOfNulls<ThreadContextElement<Any?>>(n)
    private var count = 0

    /** Records [element] together with the [value] it produced when it was installed. */
    fun append(element: ThreadContextElement<*>, value: Any?) {
        previousValues[count] = value
        elements[count] = element as ThreadContextElement<Any?>
        count++
    }

    /** Hands every recorded value back to its element, walking in reverse order of [append]. */
    fun restore(context: CoroutineContext) {
        // NOTE(review): `!!` assumes append was called exactly n times before restore.
        for (index in elements.lastIndex downTo 0) {
            elements[index]!!.restoreThreadContext(context, previousValues[index])
        }
    }
}
31 
/**
 * Fold step that counts the [ThreadContextElement]s of a context, starting from `0`.
 *
 * The accumulator is an `Int` count, except that after exactly one element has been seen it
 * holds that element itself (so the single-element case needs no second lookup).
 */
private val countAll: (Any?, CoroutineContext.Element) -> Any? = { countOrElement, element ->
    if (element !is ThreadContextElement<*>) {
        countOrElement
    } else {
        // A non-Int accumulator means one element was already seen.
        val seenSoFar = countOrElement as? Int ?: 1
        if (seenSoFar == 0) element else seenSoFar + 1
    }
}
42 
/**
 * Fold step that yields the first [ThreadContextElement] of a context (or `null` if none).
 * Used when the context is known to contain exactly one such element.
 */
private val findOne: (ThreadContextElement<*>?, CoroutineContext.Element) -> ThreadContextElement<*>? =
    { found, element -> found ?: element as? ThreadContextElement<*> }
49 
/**
 * Fold step that installs each [ThreadContextElement] and records the value it returned
 * into the given [ThreadState]; other elements are passed over.
 */
private val updateState: (ThreadState, CoroutineContext.Element) -> ThreadState = { state, element ->
    if (element is ThreadContextElement<*>) {
        val previous = element.updateThreadContext(state.context)
        state.append(element, previous)
    }
    state
}
58 
/**
 * Counts the [ThreadContextElement]s in [context].
 *
 * @return `0` when there are none, the element itself when there is exactly one, or the
 *   count as an `Int` when there are two or more.
 */
internal actual fun threadContextElements(context: CoroutineContext): Any {
    // The fold starts from a non-null 0 and countAll never yields null, so `!!` cannot fail.
    val countOrElement = context.fold(0, countAll)
    return countOrElement!!
}
60 
/**
 * Installs every [ThreadContextElement] of [context] and returns what [restoreThreadContext]
 * needs to undo it.
 *
 * @param countOrElement the result of [threadContextElements] for this [context] (pre-cached
 *   in a dispatched continuation), or `null` to compute it here.
 * @return [NO_THREAD_ELEMENTS] if the context has no [ThreadContextElement]s, a private
 *   `ThreadState` holding all previous values if it has several, or the single element's
 *   previous value if it has exactly one.
 */
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
    val resolved: Any = countOrElement ?: threadContextElements(context)
    @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
    return when {
        // Fastest path: identity check is enough because boxed zero is always the same instance.
        resolved === 0 -> NO_THREAD_ELEMENTS
        // Several elements: allocate a ThreadState to hold every previous value.
        resolved is Int -> context.fold(ThreadState(context, resolved), updateState)
        // Exactly one element: it was cached directly, so no allocation and no rescan.
        else -> {
            @Suppress("UNCHECKED_CAST")
            val single = resolved as ThreadContextElement<Any?>
            single.updateThreadContext(context)
        }
    }
}
82 
/**
 * Undoes a previous [updateThreadContext] call on the same [context].
 *
 * @param oldState the value [updateThreadContext] returned for this [context].
 */
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
    // No elements were installed, so there is nothing to undo.
    if (oldState === NO_THREAD_ELEMENTS) return
    // Several elements: the ThreadState restores each of them in reverse order.
    if (oldState is ThreadState) {
        oldState.restore(context)
        return
    }
    // Exactly one element: look it up again, since only its old value was kept.
    @Suppress("UNCHECKED_CAST")
    val single = context.fold(null, findOne) as ThreadContextElement<Any?>
    single.restoreThreadContext(context, oldState)
}
98 
/**
 * Context key of a [ThreadLocalElement]. Two keys are equal when their `threadLocal`
 * properties are equal (data-class equality).
 *
 * Declared as a top-level data class for a readable default `toString` and class name.
 */
@PublishedApi
internal data class ThreadLocalKey(
    private val threadLocal: ThreadLocal<*>
) : CoroutineContext.Key<ThreadLocalElement<*>>
102 
/**
 * [ThreadContextElement] that writes [value] into [threadLocal] in `updateThreadContext`
 * and writes the previous value back in `restoreThreadContext`.
 */
internal class ThreadLocalElement<T>(
    private val value: T,
    private val threadLocal: ThreadLocal<T>
) : ThreadContextElement<T> {
    override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal)

    /** Stores [value] into [threadLocal] and returns what it held just before. */
    override fun updateThreadContext(context: CoroutineContext): T =
        threadLocal.get().also { threadLocal.set(value) }

    /** Writes [oldState] (as returned by [updateThreadContext]) back into [threadLocal]. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: T) =
        threadLocal.set(oldState)

    // Overridden so that keys are matched with value comparison (==).
    override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext =
        if (this.key == key) EmptyCoroutineContext else this

    // Overridden so that keys are matched with value comparison (==).
    @Suppress("UNCHECKED_CAST")
    public override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? {
        if (this.key != key) return null
        return this as E
    }

    override fun toString(): String = "ThreadLocal(value=$value, threadLocal = $threadLocal)"
}
131