• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines.internal
6 
7 import kotlinx.coroutines.*
8 import kotlin.coroutines.*
9 
10 
/**
 * Marker state returned by [updateThreadContext] when the context has no active
 * ThreadContextElements; [restoreThreadContext] recognises it by identity and returns immediately.
 */
private val ZERO: Symbol = Symbol("ZERO")
12 
/**
 * Saved thread-context state used when the context holds two or more active [ThreadContextElement]s.
 *
 * [updateThreadContext] fills it in context-fold order through [append]; [restore] later undoes
 * those updates in reverse (last-updated-first) order, so elements whose updates build on each
 * other's thread state are unwound like a stack and the original thread state is left in place.
 *
 * @param context the context whose elements are being installed; passed to each element's
 *   `updateThreadContext`.
 * @param n number of active [ThreadContextElement]s in [context], as counted by [countAll].
 */
@Suppress("UNCHECKED_CAST")
private class ThreadState(val context: CoroutineContext, n: Int) {
    // Old values returned by each element's updateThreadContext, in update order.
    private val values = arrayOfNulls<Any>(n)
    // The elements that produced `values`, captured so restoring does not need to re-scan the context.
    private val elements = arrayOfNulls<ThreadContextElement<Any?>>(n)
    // Number of slots filled so far by append.
    private var size = 0

    /** Records that [element] was updated and returned [value] as its previous state. */
    fun append(element: ThreadContextElement<*>, value: Any?) {
        values[size] = value
        elements[size] = element as ThreadContextElement<Any?>
        size++
    }

    /** Restores every recorded element with its saved value, most recently updated first. */
    fun restore(context: CoroutineContext) {
        for (index in size - 1 downTo 0) {
            // Every slot below `size` was filled by append, so the element is never null here.
            elements[index]!!.restoreThreadContext(context, values[index])
        }
    }
}

// Counts ThreadContextElements in the context.
// The accumulator is Int | ThreadContextElement: while exactly one element has been seen it holds
// that element itself, so the single-element case needs no second scan.
private val countAll =
    fun (countOrElement: Any?, element: CoroutineContext.Element): Any? {
        if (element is ThreadContextElement<*>) {
            // An element accumulator means "one seen so far".
            val inCount = countOrElement as? Int ?: 1
            return if (inCount == 0) element else inCount + 1
        }
        return countOrElement
    }

// Finds the first ThreadContextElement in the context; used when we know there is exactly one.
private val findOne =
    fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? {
        if (found != null) return found
        return element as? ThreadContextElement<*>
    }

// Installs each ThreadContextElement of the context and records it, with its old value, in the ThreadState.
private val updateState =
    fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
        if (element is ThreadContextElement<*>) {
            state.append(element, element.updateThreadContext(state.context))
        }
        return state
    }

/**
 * Returns the number of active [ThreadContextElement]s in [context] as an `Int`, or, when there is
 * exactly one, that element itself.
 */
internal actual fun threadContextElements(context: CoroutineContext): Any =
    // The fold starts from the non-null 0 and countAll only ever returns an Int or an element,
    // so the result is never null and `!!` cannot throw.
    context.fold(0, countAll)!!

/**
 * Installs every active [ThreadContextElement] of [context] on the current thread and returns the
 * state that [restoreThreadContext] needs to undo it.
 *
 * @param countOrElement the result of [threadContextElements] for [context] (pre-cached in the
 *   dispatched continuation), or `null` to compute it here.
 * @return [ZERO] when there are no active elements, a [ThreadState] when there are several, or the
 *   single element's own old state when there is exactly one.
 */
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
    @Suppress("NAME_SHADOWING")
    val countOrElement = countOrElement ?: threadContextElements(context)
    @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
    return when {
        countOrElement === 0 -> ZERO // very fast path when there are no active ThreadContextElements
        //    ^^^ identity comparison for speed, we know zero always has the same identity
        countOrElement is Int -> {
            // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
            context.fold(ThreadState(context, countOrElement), updateState)
        }
        else -> {
            // fast path for one ThreadContextElement (no allocations, no additional context scan)
            @Suppress("UNCHECKED_CAST")
            val element = countOrElement as ThreadContextElement<Any?>
            element.updateThreadContext(context)
        }
    }
}

/**
 * Undoes a previous [updateThreadContext] call.
 *
 * @param context the context whose elements were installed.
 * @param oldState expected to be the value [updateThreadContext] returned for this context.
 */
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
    when {
        oldState === ZERO -> return // very fast path when there are no ThreadContextElements
        // slow path with multiple stored ThreadContextElements: unwind them in reverse update order
        oldState is ThreadState -> oldState.restore(context)
        else -> {
            // fast path for one ThreadContextElement, but need to find it
            @Suppress("UNCHECKED_CAST")
            val element = context.fold(null, findOne) as ThreadContextElement<Any?>
            element.restoreThreadContext(context, oldState)
        }
    }
}
99 
/**
 * Context key of a [ThreadLocalElement], built from the [ThreadLocal] that element manages.
 *
 * It is a top-level data class so that two keys wrapping equal thread locals are equal, and so that
 * `toString` and the class name read nicely out of the box.
 */
@PublishedApi
internal data class ThreadLocalKey(
    private val threadLocal: ThreadLocal<*>
) : CoroutineContext.Key<ThreadLocalElement<*>>
103 
/**
 * A [ThreadContextElement] that sets [threadLocal] to [value] when the context is installed on a
 * thread and puts the previous value back when it is removed.
 *
 * Its key is a [ThreadLocalKey], and [get]/[minusKey] compare keys with `==`.
 */
internal class ThreadLocalElement<T>(
    private val value: T,
    private val threadLocal: ThreadLocal<T>
) : ThreadContextElement<T> {
    override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal)

    /** Stores [value] into [threadLocal] and returns whatever it held before. */
    override fun updateThreadContext(context: CoroutineContext): T =
        threadLocal.get().also { threadLocal.set(value) }

    /** Writes [oldState], as returned by [updateThreadContext], back into [threadLocal]. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: T) =
        threadLocal.set(oldState)

    // Overridden to perform value comparison (==) on the key.
    override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext =
        if (this.key == key) EmptyCoroutineContext else this

    // Overridden to perform value comparison (==) on the key.
    @Suppress("UNCHECKED_CAST")
    override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? {
        if (this.key != key) return null
        return this as E
    }

    override fun toString(): String {
        return "ThreadLocal(value=$value, threadLocal = $threadLocal)"
    }
}
132