1 /*
<lambda>null2 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3 */
4
5 package kotlinx.coroutines
6
7 import kotlinx.coroutines.internal.*
8 import kotlin.coroutines.*
9 import kotlin.coroutines.jvm.internal.CoroutineStackFrame
10
/**
 * Creates a context for a new coroutine started in this scope.
 *
 * The scope's [coroutineContext] and the given [context] are combined, applying the
 * [CopyableThreadContextElement] rules for a *new* coroutine (copyable thread-local facilities on JVM).
 * When debugging is turned on, a fresh [CoroutineId] is attached to the result.
 * [Dispatchers.Default] is installed when no other dispatcher or [ContinuationInterceptor] is specified.
 * See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM.
 */
@ExperimentalCoroutinesApi
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
    val merged = foldCopies(coroutineContext, context, true)
    val withDebugId = if (DEBUG) merged + CoroutineId(COROUTINE_ID.incrementAndGet()) else merged
    // Checked against `merged` (not `withDebugId`), so a context that is exactly Dispatchers.Default
    // is recognised by identity and not augmented again.
    val lacksInterceptor = merged !== Dispatchers.Default && merged[ContinuationInterceptor] == null
    return when {
        lacksInterceptor -> withDebugId + Dispatchers.Default
        else -> withDebugId
    }
}
24
/**
 * Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
 *
 * When [addedContext] contains no [CopyableThreadContextElement], the result is simply `this + addedContext`.
 * Otherwise the two contexts are folded element by element. Copyable elements inherited from this
 * context are not copied, because no new coroutine is created.
 * @suppress
 */
@InternalCoroutinesApi
public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext =
    // Fast path: 'addedContext' typically has one or two elements, so checking it is cheap, and a
    // copy/merge is only required when it contains copyable elements.
    if (addedContext.hasCopyableElements()) foldCopies(this, addedContext, false)
    else this + addedContext
38
/** Returns `true` when at least one element of this context is a [CopyableThreadContextElement]. */
private fun CoroutineContext.hasCopyableElements(): Boolean =
    fold(initial = false) { found, element ->
        if (found) true else element is CopyableThreadContextElement<*>
    }
41
/**
 * Folds two contexts, properly applying [CopyableThreadContextElement] (CTCE) rules when necessary.
 * The rules are the following:
 * * If neither context has a CTCE, the plain sum of the two contexts is returned.
 * * Every CTCE from the left-hand side context that has no matching (by key) element in the right-hand side
 *   context is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`; otherwise
 *   it is kept as is.
 * * Every CTCE from the left-hand side context that has a matching element in the right-hand side context
 *   is [merged][CopyableThreadContextElement.mergeForChild] with it, and that element is removed from the
 *   right-hand side.
 * * Every CTCE from the right-hand side context that hasn't been merged is copied.
 * * Everything else is added to the resulting context as is.
 *
 * @param originalContext the left-hand side context.
 * @param appendContext the right-hand side context; its elements are added after the left-hand side ones.
 * @param isNewCoroutine `true` when the result is for a newly launched coroutine, `false` for
 *   `withContext`-like builders.
 */
private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext {
    val leftHasCopyable = originalContext.hasCopyableElements()
    val rightHasCopyable = appendContext.hasCopyableElements()

    // Nothing to fold: the plain sum of the contexts is enough.
    if (!(leftHasCopyable || rightHasCopyable)) return originalContext + appendContext

    // Right-hand elements that have not been merged into a left-hand CTCE.
    var unmerged = appendContext
    val foldedLeft = originalContext.fold<CoroutineContext>(EmptyCoroutineContext) { acc, element ->
        val contribution: CoroutineContext = when (element) {
            is CopyableThreadContextElement<*> -> {
                val overriding = unmerged[element.key]
                if (overriding == null) {
                    // Not overwritten. For 'withContext'-like builders the element is not shared, so it is not copied.
                    if (isNewCoroutine) element.copyForChild() else element
                } else {
                    // Overwritten: first drop the right-hand element, then merge it into the left-hand one.
                    unmerged = unmerged.minusKey(element.key)
                    @Suppress("UNCHECKED_CAST")
                    val mergeable = element as CopyableThreadContextElement<Any?>
                    mergeable.mergeForChild(overriding)
                }
            }
            else -> element
        }
        acc + contribution
    }

    if (rightHasCopyable) {
        // Newly appended CTCEs must be copied, otherwise they may be shared with others.
        unmerged = unmerged.fold<CoroutineContext>(EmptyCoroutineContext) { acc, element ->
            acc + if (element is CopyableThreadContextElement<*>) element.copyForChild() else element
        }
    }
    return foldedLeft + unmerged
}
90
/**
 * Executes [block] using the given coroutine [context].
 *
 * [updateThreadContext] is called for [context] and [countOrElement] before [block] runs.
 * [restoreThreadContext] is then called with the value it returned. The restore happens in a
 * `finally`, so it also runs when [block] throws.
 */
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T {
    val previousState = updateThreadContext(context, countOrElement)
    return try {
        block()
    } finally {
        restoreThreadContext(context, previousState)
    }
}
102
/**
 * Executes [block] using the context of the given [continuation].
 *
 * The thread context is updated for `continuation.context` before [block] runs. If that update replaced
 * anything (i.e. did not return [NO_THREAD_ELEMENTS]), [updateUndispatchedCompletion] looks for an
 * [UndispatchedCoroutine] up the continuation's stack and records the saved state in it.
 * After [block], the thread context is restored, except when that undispatched coroutine's
 * [UndispatchedCoroutine.clearThreadContext] returns `false`.
 */
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T {
    val ctx = continuation.context
    val saved = updateThreadContext(ctx, countOrElement)
    // Only when some values were replaced is the slow path of finding where/how to restore them taken;
    // otherwise there is nothing to restore and no stack walk is attempted.
    val completion: UndispatchedCoroutine<*>? =
        if (saved === NO_THREAD_ELEMENTS) null
        else continuation.updateUndispatchedCompletion(ctx, saved)
    try {
        return block()
    } finally {
        val mustRestore = completion?.clearThreadContext() ?: true
        if (mustRestore) restoreThreadContext(ctx, saved)
    }
}
123
/**
 * Finds the nearest [UndispatchedCoroutine] up this continuation's call stack. If one is found, the
 * pair ([context], [oldValue]) is stored in it via [UndispatchedCoroutine.saveThreadContext].
 *
 * Returns `null` without walking the stack in two cases: this continuation is not a [CoroutineStackFrame],
 * or [context] carries no [UndispatchedMarker].
 */
internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
    if (this !is CoroutineStackFrame) return null
    /*
     * Fast-path to detect whether we have undispatched coroutine at all in our stack.
     *
     * Implementation note.
     * If we ever find that stackwalking for thread-locals is way too slow, here is another idea:
     * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance
     * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker`
     *    from the context when creating dispatched coroutine in `withContext`.
     *    Another option is to "unmark it" instead of removing to save an allocation.
     *    Both options should work, but it requires more careful studying of the performance
     *    and, mostly, maintainability impact.
     */
    if (context[UndispatchedMarker] === null) return null
    return undispatchedCompletion()?.also { it.saveThreadContext(context, oldValue) }
}
144
/**
 * Walks up the call stack starting from this frame. Returns the first [UndispatchedCoroutine] among its
 * callers, or `null` if the walk reaches a [DispatchedCoroutine] or a frame without a known caller.
 */
internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? {
    // A dispatched coroutine ends the search.
    if (this is DispatchedCoroutine<*>) return null
    // Direct completion of this continuation; anything without a caller frame is not supported.
    val caller = callerFrame ?: return null
    // Found it, or keep walking up the call stack with a tail call.
    return if (caller is UndispatchedCoroutine<*>) caller else caller.undispatchedCompletion()
}
154
/**
 * Marker indicating that an [UndispatchedCoroutine] exists somewhere up in the stack.
 * Used as a performance optimization to avoid stack walking where it is not necessary.
 *
 * The object serves as its own context key.
 */
private object UndispatchedMarker : CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
    override val key: CoroutineContext.Key<*> get() = this
}
163
/**
 * Used by `withContext` when the context changes but the dispatcher stays the same.
 * On completion, [afterResume] resumes the caller undispatched.
 *
 * [UndispatchedMarker] is added to the given context (unless already present). This lets
 * [updateUndispatchedCompletion] know that a stack walk may find an undispatched coroutine.
 */
internal actual class UndispatchedCoroutine<in T>actual constructor (
    context: CoroutineContext,
    uCont: Continuation<T>
) : ScopeCoroutine<T>(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) {

    /**
     * The state of [ThreadContextElement]s associated with the current undispatched coroutine.
     * It is stored in a thread local because this coroutine can be used concurrently in suspend-resume race scenario.
     * See the following boiled-down example with the body of `withContinuationContext` inlined:
     * ```
     * val state = saveThreadContext(ctx)
     * try {
     *     invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called
     *     // COROUTINE_SUSPENDED is returned
     * } finally {
     *     thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread
     *     // and it also calls saveThreadContext and clearThreadContext
     * }
     * ```
     *
     * Usage note:
     *
     * This part of the code is performance-sensitive.
     * It is a well-established pattern to wrap various activities into system-specific undispatched
     * `withContext` for the sake of logging, MDC, tracing etc., meaning that there may exist thousands of
     * undispatched coroutines.
     * Each access to Java's [ThreadLocal] leaves a footprint in the corresponding Thread's `ThreadLocalMap`
     * that is cleared automatically as soon as the associated thread-local (-> UndispatchedCoroutine) is garbage collected.
     * When such coroutines are promoted to old generation, `ThreadLocalMap`s become bloated and arbitrary accesses to thread locals
     * start to consume a significant amount of CPU because these maps are open-addressed and cleaned up incrementally on each access.
     * (You can read more about this effect as "GC nepotism").
     *
     * To avoid that, we attempt to narrow down the lifetime of this thread local as much as possible:
     * * It's never accessed when we are sure there are no thread context elements
     * * It's cleaned up via [ThreadLocal.remove] as soon as the coroutine is suspended or finished.
     */
    private val threadStateToRecover = ThreadLocal<Pair<CoroutineContext, Any?>>()

    /*
     * Indicates that a coroutine has at least one thread context element associated with it
     * and that 'threadStateToRecover' is going to be set in case of dispatching in order to preserve them.
     * Better than nullable thread-local for easier debugging.
     *
     * It is used as a performance optimization to avoid 'threadStateToRecover' initialization
     * (note: tl.get() initializes thread local),
     * and is prone to false-positives as it is never reset: otherwise
     * it may lead to logical data races between suspension points where
     * the coroutine is still being suspended in one thread while already being resumed
     * in another.
     */
    @Volatile
    private var threadLocalIsSet = false

    init {
        /*
         * This is a hack for a very specific case in #2930 until #3253 is implemented.
         * 'ThreadLocalStressTest' covers this change properly.
         *
         * The scenario this change covers is the following:
         * 1) The coroutine is being started as plain non kotlinx.coroutines related suspend function,
         *    e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that is invoking
         *    `withContext(tlElement)` which creates `UndispatchedCoroutine`.
         * 2) It (original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()`
         *    and goes neither through `DC.run` nor through `resumeUndispatchedWith` that both
         *    do thread context element tracking.
         * 3) So thread locals never get a chance to be properly set up via `saveThreadContext`,
         *    but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`.
         *
         * Here we detect precisely this situation and properly set up the context to recover later.
         *
         */
        if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) {
            /*
             * We cannot just "read" the elements as there is no such API,
             * so we update-restore it immediately and use the intermediate value
             * as the initial state, leveraging the fact that thread context element
             * is idempotent and such situations are increasingly rare.
             */
            val values = updateThreadContext(context, null)
            restoreThreadContext(context, values)
            saveThreadContext(context, values)
        }
    }

    /**
     * Stores ([context], [oldValue]) in [threadStateToRecover] for the current thread and raises
     * [threadLocalIsSet]. [afterResume] later restores the thread context from this pair.
     */
    fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
        threadLocalIsSet = true // Specify that thread-local is touched at all
        threadStateToRecover.set(context to oldValue)
    }

    /**
     * Removes the state saved for the current thread.
     *
     * Returns whether the caller should still restore the thread context itself. The result is `false`
     * only when a state was saved at some point ([threadLocalIsSet]) but none is present on this thread
     * any more — e.g. because [afterResume] has already restored and removed it.
     */
    fun clearThreadContext(): Boolean {
        // Note: `also` applies to the parenthesized condition, so the thread local is removed after the
        // condition is evaluated; the negation is applied to the condition's value.
        return !(threadLocalIsSet && threadStateToRecover.get() == null).also {
            threadStateToRecover.remove()
        }
    }

    /**
     * Restores the thread context saved for the current thread (if any) and clears the thread local.
     * It then resumes [uCont] with the recovered result inside [withContinuationContext], without dispatching.
     */
    override fun afterResume(state: Any?) {
        // Only touch the thread local when it may have been set, to avoid initializing it needlessly.
        if (threadLocalIsSet) {
            threadStateToRecover.get()?.let { (ctx, value) ->
                restoreThreadContext(ctx, value)
            }
            threadStateToRecover.remove()
        }
        // resume undispatched -- update context but stay on the same dispatcher
        val result = recoverResult(state, uCont)
        withContinuationContext(uCont, null) {
            uCont.resumeWith(result)
        }
    }
}
274
/**
 * Debug name of the coroutine that owns this context, in the form `<name>#<id>`.
 * `<name>` is the [CoroutineName] of the context, or `"coroutine"` when there is none.
 * `null` when debugging is disabled or the context has no [CoroutineId].
 */
internal actual val CoroutineContext.coroutineName: String?
    get() {
        if (!DEBUG) return null
        val id = this[CoroutineId]?.id ?: return null
        val name = this[CoroutineName]?.name ?: "coroutine"
        return "$name#$id"
    }
281
// Separates a thread's base name from the `<name>#<id>` coroutine suffix that CoroutineId.updateThreadContext appends.
private const val DEBUG_THREAD_NAME_SEPARATOR = " @"
283
/**
 * Context element that carries the numeric debug [id] of a coroutine.
 *
 * As a [ThreadContextElement], it renames the current thread to include `@<name>#<id>` while the
 * coroutine runs, and puts the original name back afterwards.
 */
@IgnoreJreRequirement // desugared hashcode implementation
@PublishedApi
internal data class CoroutineId(
    // Used by the IDEA debugger via reflection and must be kept binary-compatible, see KTIJ-24102
    val id: Long
) : ThreadContextElement<String>, AbstractCoroutineContextElement(CoroutineId) {
    // Used by the IDEA debugger via reflection and must be kept binary-compatible, see KTIJ-24102
    companion object Key : CoroutineContext.Key<CoroutineId>

    override fun toString(): String = "CoroutineId($id)"

    /**
     * Renames the current thread to `<base> @<name>#<id>` and returns the previous thread name.
     *
     * `<base>` is the previous name up to its last [DEBUG_THREAD_NAME_SEPARATOR], or the whole name when
     * the separator is absent. `<name>` is the context's [CoroutineName], defaulting to `"coroutine"`.
     */
    override fun updateThreadContext(context: CoroutineContext): String {
        val displayName = context[CoroutineName]?.name ?: "coroutine"
        val thread = Thread.currentThread()
        val previousName = thread.name
        // Cut off everything from the last separator, so an existing coroutine suffix is replaced, not stacked.
        val separatorAt = previousName.lastIndexOf(DEBUG_THREAD_NAME_SEPARATOR)
        val baseEnd = if (separatorAt >= 0) separatorAt else previousName.length
        val renamed = buildString(baseEnd + displayName.length + 10) {
            append(previousName.substring(0, baseEnd))
            append(DEBUG_THREAD_NAME_SEPARATOR)
            append(displayName)
            append('#')
            append(id)
        }
        thread.name = renamed
        return previousName
    }

    /** Puts back the thread name that [updateThreadContext] returned. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
        Thread.currentThread().name = oldState
    }
}
314