1 /*
2 * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3 */
4
5 package kotlinx.coroutines.internal
6
7 import kotlinx.coroutines.*
8 import kotlin.coroutines.*
9
10
// Marker returned by updateThreadContext when the context contains no ThreadContextElements;
// restoreThreadContext recognizes it by identity (===) and returns immediately.
private val ZERO: Symbol = Symbol("ZERO")
12
// Used when there are >= 2 active elements in the context.
// A fixed-size buffer of values with a single cursor shared by writes and reads:
// values are appended in order, then start() rewinds the cursor so that take()
// hands them back in the same order.
private class ThreadState(val context: CoroutineContext, n: Int) {
    private val slots = arrayOfNulls<Any>(n)
    private var cursor = 0

    // Stores the value at the cursor position and advances the cursor
    fun append(value: Any?) {
        slots[cursor++] = value
    }

    // Returns the value at the cursor position and advances the cursor
    fun take(): Any? {
        return slots[cursor++]
    }

    // Rewinds the cursor to the first slot
    fun start() {
        cursor = 0
    }
}
22
// Counts ThreadContextElements in the context.
// The accumulator (Any?) is Int | ThreadContextElement: it is the element itself when exactly one
// has been seen so far, and an Int count otherwise (0 for none, 2 or more for several).
private val countAll =
    fun (countOrElement: Any?, element: CoroutineContext.Element): Any? {
        // Elements of other kinds leave the accumulator untouched
        if (element !is ThreadContextElement<*>) return countOrElement
        // A non-Int accumulator is the single element seen so far, i.e. a count of one
        val seenSoFar = (countOrElement as? Int) ?: 1
        return when (seenSoFar) {
            0 -> element // first one found: remember the element itself instead of the number 1
            else -> seenSoFar + 1
        }
    }
33
// Fold operation that finds the first ThreadContextElement in the context;
// it is used when we know there is exactly one.
// Once an element has been found it is carried through unchanged; until then each
// visited element is tested and yields null when it is not a ThreadContextElement.
private val findOne =
    fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? =
        found ?: (element as? ThreadContextElement<*>)
40
// Fold operation that updates the thread context for every ThreadContextElement in the context,
// appending the value each element returns to the given ThreadState. Other elements are skipped.
private val updateState =
    fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
        if (element !is ThreadContextElement<*>) return state
        val previous = element.updateThreadContext(state.context)
        state.append(previous)
        return state
    }
49
// Fold operation that restores state for every ThreadContextElement in the context,
// feeding each one the next value taken from the given ThreadState. Other elements are skipped.
private val restoreState =
    fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
        if (element !is ThreadContextElement<*>) return state
        // The element's state type is unknown here; the stored value is passed back to it as Any?
        @Suppress("UNCHECKED_CAST")
        val typed = element as ThreadContextElement<Any?>
        typed.restoreThreadContext(state.context, state.take())
        return state
    }
59
// Scans the context with countAll, starting from a count of 0.
// The result is the Int 0 when there are no ThreadContextElements, the element itself
// when there is exactly one, and the Int count when there are several.
internal actual fun threadContextElements(context: CoroutineContext): Any {
    val countOrElement: Any? = context.fold(0, countAll)
    // countAll only ever yields its accumulator, an element or an Int, so starting from 0 this is non-null
    return countOrElement!!
}
61
// Updates the thread context for all ThreadContextElements in the context and returns the old state
// that must later be passed to restoreThreadContext:
//  - ZERO when there are no ThreadContextElements,
//  - a ThreadState holding every old value when there are several,
//  - the old value returned by the element itself when there is exactly one.
// countOrElement is pre-cached in dispatched continuation; when null it is computed here.
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
    val resolved = countOrElement ?: threadContextElements(context)
    // very fast path when there are no active ThreadContextElements
    // (identity comparison for speed, we know zero always has the same identity)
    @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
    if (resolved === 0) return ZERO
    if (resolved is Int) {
        // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
        return context.fold(ThreadState(context, resolved), updateState)
    }
    // fast path for one ThreadContextElement (no allocations, no additional context scan)
    @Suppress("UNCHECKED_CAST")
    val element = resolved as ThreadContextElement<Any?>
    return element.updateThreadContext(context)
}
82
// Restores the thread context from the old state previously returned by updateThreadContext.
// The shape of oldState selects the path: ZERO, a ThreadState, or a single element's old value.
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
    // very fast path when there are no ThreadContextElements
    if (oldState === ZERO) return
    if (oldState is ThreadState) {
        // slow path with multiple stored ThreadContextElements:
        // rewind the state and replay the stored values in the order they were appended
        oldState.start()
        context.fold(oldState, restoreState)
        return
    }
    // fast path for one ThreadContextElement, but need to find it
    @Suppress("UNCHECKED_CAST")
    val element = context.fold(null, findOne) as ThreadContextElement<Any?>
    element.restoreThreadContext(context, oldState)
}
99
// Context key for ThreadLocalElement, wrapping the ThreadLocal it belongs to.
// Declared as a top-level data class for a nicer out-of-the-box toString representation and class name;
// being a data class also makes two keys equal (==) when they wrap equal ThreadLocal instances.
@PublishedApi
internal data class ThreadLocalKey(
    private val threadLocal: ThreadLocal<*>
) : CoroutineContext.Key<ThreadLocalElement<*>>
103
// ThreadContextElement backed by a ThreadLocal: updateThreadContext installs the given value into
// the thread-local and hands back the previous one, restoreThreadContext puts a previous value back.
// Its key is a ThreadLocalKey for the same thread-local, and key lookups compare keys by value (==).
internal class ThreadLocalElement<T>(
    private val value: T,
    private val threadLocal: ThreadLocal<T>
) : ThreadContextElement<T> {
    override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal)

    // Reads the current thread-local value first, then overwrites it; the value read is returned
    override fun updateThreadContext(context: CoroutineContext): T =
        threadLocal.get().also { threadLocal.set(value) }

    // Writes the given old state back into the thread-local
    override fun restoreThreadContext(context: CoroutineContext, oldState: T) {
        threadLocal.set(oldState)
    }

    // this method is overridden to perform value comparison (==) on key
    override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext {
        if (this.key != key) return this
        return EmptyCoroutineContext
    }

    // this method is overridden to perform value comparison (==) on key
    public override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? {
        if (this.key != key) return null
        @Suppress("UNCHECKED_CAST")
        return this as E
    }

    override fun toString(): String {
        return "ThreadLocal(value=$value, threadLocal = $threadLocal)"
    }
}
132