• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
<lambda>null2  * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines
6 
7 import kotlinx.coroutines.internal.*
8 import kotlin.coroutines.*
9 import kotlin.coroutines.jvm.internal.CoroutineStackFrame
10 
/**
 * Creates a context for a new coroutine launched from this scope.
 *
 * The scope's context and [context] are combined with copyable-thread-local rules applied for a *new* coroutine.
 * When debugging is enabled, a fresh [CoroutineId] is attached. If the combined context has no
 * [ContinuationInterceptor] (and is not [Dispatchers.Default] itself), [Dispatchers.Default] is added.
 * See [DEBUG_PROPERTY_NAME] for a description of the debugging facilities on JVM.
 */
@ExperimentalCoroutinesApi
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
    val combined = foldCopies(coroutineContext, context, isNewCoroutine = true)
    // The id is allocated before the dispatcher check, exactly once per call when DEBUG is on.
    val withDebugId = when {
        DEBUG -> combined + CoroutineId(COROUTINE_ID.incrementAndGet())
        else -> combined
    }
    val needsDefaultDispatcher = combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null
    return if (needsDefaultDispatcher) withDebugId + Dispatchers.Default else withDebugId
}
24 
/**
 * Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
 *
 * Copyable thread context elements are copied or merged only when [addedContext] itself contains one.
 * Otherwise the two contexts are simply added together.
 * @suppress
 */
@InternalCoroutinesApi
public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext =
    // 'addedContext' is typically tiny (one or two elements), so checking it first is a cheap fast path.
    if (addedContext.hasCopyableElements()) foldCopies(this, addedContext, isNewCoroutine = false)
    else this + addedContext
38 
/** Returns `true` when at least one element of this context is a [CopyableThreadContextElement]. */
private fun CoroutineContext.hasCopyableElements(): Boolean =
    fold(initial = false) { found, element -> if (found) true else element is CopyableThreadContextElement<*> }
41 
/**
 * Combines [originalContext] with [appendContext], applying [CopyableThreadContextElement] rules:
 * * if neither context has a copyable element, the plain sum `originalContext + appendContext` is returned;
 * * a copyable element of [originalContext] whose key is absent from [appendContext] is
 *   [copied][CopyableThreadContextElement.copyForChild] when [isNewCoroutine] is `true`, and kept as is otherwise;
 * * a copyable element of [originalContext] whose key is present in [appendContext] is
 *   [merged][CopyableThreadContextElement.mergeForChild] with that element, which is then not added again;
 * * copyable elements of [appendContext] that were not merged are copied;
 * * every other element is carried over unchanged.
 */
private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext {
    val leftHasCopyable = originalContext.hasCopyableElements()
    val rightHasCopyable = appendContext.hasCopyableElements()
    // Fast path: nothing to copy or merge on either side.
    if (!(leftHasCopyable || rightHasCopyable)) return originalContext + appendContext

    // Elements of appendContext not consumed by a merge yet; shrinks while the fold below runs.
    var unmerged = appendContext
    val fromOriginal = originalContext.fold<CoroutineContext>(EmptyCoroutineContext) { acc, element ->
        if (element !is CopyableThreadContextElement<*>) {
            acc + element
        } else {
            val overriding = unmerged[element.key]
            if (overriding == null) {
                // Not overridden. 'withContext'-like builders (isNewCoroutine == false) keep the element
                // itself, because it is not shared with a new coroutine.
                acc + if (isNewCoroutine) element.copyForChild() else element
            } else {
                // Overridden: drop the right-hand element so it is not appended a second time, then merge.
                unmerged = unmerged.minusKey(element.key)
                @Suppress("UNCHECKED_CAST")
                val merged = (element as CopyableThreadContextElement<Any?>).mergeForChild(overriding)
                acc + merged
            }
        }
    }

    // Newly appended copyable elements are copied; otherwise they could be shared with other coroutines.
    val fromAppended = if (!rightHasCopyable) unmerged else {
        unmerged.fold<CoroutineContext>(EmptyCoroutineContext) { acc, element ->
            acc + if (element is CopyableThreadContextElement<*>) element.copyForChild() else element
        }
    }
    return fromOriginal + fromAppended
}
90 
/**
 * Executes [block] using the given coroutine [context].
 *
 * `updateThreadContext(context, countOrElement)` is called before [block]. Its result is passed to
 * `restoreThreadContext` afterwards, inside `finally`, so the restore also happens when [block] throws.
 */
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T {
    val previousState = updateThreadContext(context, countOrElement)
    return try {
        block()
    } finally {
        restoreThreadContext(context, previousState)
    }
}
102 
/**
 * Executes [block] using the context of [continuation].
 *
 * The thread context is first updated from `continuation.context`. Only if that update replaced something
 * (its result is not `NO_THREAD_ELEMENTS`) is an enclosing [UndispatchedCoroutine] searched for. If one is
 * found, it records the saved state.
 *
 * After [block], the thread context is restored here in two cases:
 * * no such undispatched completion was found; or
 * * the completion's `clearThreadContext()` returns `true`.
 */
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T {
    val context = continuation.context
    val oldValue = updateThreadContext(context, countOrElement)
    // Fast path: with nothing replaced there is nothing to restore, so skip the stack walk entirely.
    val undispatchedCompletion =
        if (oldValue === NO_THREAD_ELEMENTS) null
        else continuation.updateUndispatchedCompletion(context, oldValue)
    return try {
        block()
    } finally {
        val mustRestore = undispatchedCompletion?.clearThreadContext() ?: true
        if (mustRestore) restoreThreadContext(context, oldValue)
    }
}
123 
/**
 * Finds the nearest [UndispatchedCoroutine] above this continuation on the call stack and makes it remember
 * ([UndispatchedCoroutine.saveThreadContext]) the given [context] and [oldValue].
 *
 * Returns `null` in three cases:
 * * the continuation is not a [CoroutineStackFrame];
 * * [context] carries no [UndispatchedMarker];
 * * the stack walk finds no undispatched coroutine.
 */
internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
    /*
     * The marker check is a fast path that detects whether an undispatched coroutine can be on our stack at all.
     *
     * Implementation note.
     * If stack walking for thread-locals ever turns out to be too slow, here is another idea:
     * 1) Store the undispatched coroutine right in the `UndispatchedMarker` instance.
     * 2) To avoid issues across dispatch boundaries, remove `UndispatchedMarker`
     *    from the context when creating a dispatched coroutine in `withContext`.
     *    Another option is to "unmark" it instead of removing it, to save an allocation.
     *    Both options should work, but they require more careful study of the performance
     *    and, mostly, the maintainability impact.
     */
    if (this !is CoroutineStackFrame || context[UndispatchedMarker] == null) return null
    return undispatchedCompletion()?.also { it.saveThreadContext(context, oldValue) }
}
144 
/**
 * Walks up the caller frames starting from this frame and returns the first [UndispatchedCoroutine] found.
 *
 * The walk stops with `null` in two cases:
 * * it reaches a [DispatchedCoroutine];
 * * a frame has no caller (any other frame kind is not supported).
 */
internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? {
    if (this is DispatchedCoroutine<*>) return null
    val caller = callerFrame ?: return null
    // Tail call keeps the walk iterative under `tailrec`.
    return if (caller is UndispatchedCoroutine<*>) caller else caller.undispatchedCompletion()
}
154 
/**
 * Context element signalling that an [UndispatchedCoroutine] may exist somewhere up the call stack.
 *
 * It serves as its own key. Its presence lets `updateUndispatchedCompletion` skip stack walking when no
 * undispatched coroutine can be involved.
 */
private object UndispatchedMarker : CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
    override val key: CoroutineContext.Key<*> get() = this
}
163 
/**
 * Coroutine used by `withContext` when the context changes but the dispatcher stays the same.
 *
 * [UndispatchedMarker] is added to its context (unless already present), so that
 * `updateUndispatchedCompletion` can cheaply tell whether walking the stack for this coroutine is worthwhile.
 */
internal actual class UndispatchedCoroutine<in T> actual constructor(
    context: CoroutineContext,
    uCont: Continuation<T>
) : ScopeCoroutine<T>(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) {

    /*
     * The state is thread-local because this coroutine can be used concurrently.
     * Scenario of usage (withContinuationContext):
     * val state = saveThreadContext(ctx)
     * try {
     *     invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called
     *     // COROUTINE_SUSPENDED is returned
     * } finally {
     *     thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread
     *     // and it also calls saveThreadContext and clearThreadContext
     * }
     *
     * The reference itself never changes, hence `val`. Per-thread state is cleared with `remove()` rather than
     * `set(null)`. Both `set(null)` and `get()` on an unset thread local leave an entry in the thread's
     * ThreadLocalMap, whereas `remove()` drops it. Observable `get()` results are the same either way.
     * Note: this property must stay declared before `init`, which calls `saveThreadContext`.
     */
    private val threadStateToRecover = ThreadLocal<Pair<CoroutineContext, Any?>>()

    init {
        /*
         * This is a hack for a very specific case in #2930 unless #3253 is implemented.
         * 'ThreadLocalStressTest' covers this change properly.
         *
         * The scenario this change covers is the following:
         * 1) The coroutine is started as a plain, non-kotlinx.coroutines suspend function,
         *    e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that is invoking
         *    `withContext(tlElement)` which creates `UndispatchedCoroutine`.
         * 2) It (the original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()`
         *    and goes neither through `DC.run` nor through `resumeUndispatchedWith`, both of which
         *    do thread context element tracking.
         * 3) So thread locals never get a chance to be set up properly via `saveThreadContext`,
         *    but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`.
         *
         * Here we detect precisely this situation and properly set up the context to recover later.
         */
        if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) {
            /*
             * We cannot just "read" the elements as there is no such API,
             * so we update-restore immediately and use the intermediate value
             * as the initial state. This relies on thread context elements being
             * idempotent and on such situations being increasingly rare.
             */
            val values = updateThreadContext(context, null)
            restoreThreadContext(context, values)
            saveThreadContext(context, values)
        }
    }

    /** Records, for the current thread only, the [context] and [oldValue] that [afterResume] should restore. */
    fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
        threadStateToRecover.set(context to oldValue)
    }

    /**
     * Discards the state saved by [saveThreadContext] on the current thread.
     *
     * @return `true` if a saved state was present (it is now discarded), or `false` if nothing was saved on
     *   this thread, e.g. because [afterResume] already consumed it. In this file, `withContinuationContext`
     *   restores the thread context itself only when this returns `true`.
     */
    fun clearThreadContext(): Boolean {
        val hadSavedState = threadStateToRecover.get() != null
        // remove() also drops the entry that get() creates on a thread where nothing was set.
        threadStateToRecover.remove()
        return hadSavedState
    }

    override fun afterResume(state: Any?) {
        // Restore the thread state saved on this thread, if any, then forget it.
        threadStateToRecover.get()?.let { (ctx, value) -> restoreThreadContext(ctx, value) }
        threadStateToRecover.remove()
        // resume undispatched -- update context but stay on the same dispatcher
        val result = recoverResult(state, uCont)
        withContinuationContext(uCont, null) {
            uCont.resumeWith(result)
        }
    }
}
237 
/**
 * Debug name of the coroutine owning this context, formatted as `<name>#<id>`.
 *
 * `<name>` is the [CoroutineName] element, or `"coroutine"` when absent. The value is `null` when
 * debugging is off or the context has no [CoroutineId].
 */
internal actual val CoroutineContext.coroutineName: String?
    get() = if (DEBUG) {
        this[CoroutineId]?.let { idElement ->
            val label = this[CoroutineName]?.name ?: "coroutine"
            "$label#${idElement.id}"
        }
    } else {
        null
    }
244 
/** Separator placed between a thread's base name and the coroutine debug suffix appended by [CoroutineId]. */
private const val DEBUG_THREAD_NAME_SEPARATOR = " @"

/**
 * Debug identifier of a coroutine; `newCoroutineContext` attaches it when [DEBUG] is on.
 *
 * As a [ThreadContextElement], it does two things:
 * * while the coroutine runs, it renames the current thread to `<base> @<name>#<id>`;
 * * afterwards, it puts the previous name back.
 */
@IgnoreJreRequirement // desugared hashcode implementation
internal data class CoroutineId(
    val id: Long
) : ThreadContextElement<String>, AbstractCoroutineContextElement(CoroutineId) {
    companion object Key : CoroutineContext.Key<CoroutineId>

    override fun toString(): String = "CoroutineId($id)"

    /**
     * Renames the current thread to `<base> @<name>#<id>`, where:
     * * `<base>` is the current name with everything from its last [DEBUG_THREAD_NAME_SEPARATOR] on stripped
     *   (or the whole name if there is no separator);
     * * `<name>` is the [CoroutineName] from [context], or `"coroutine"`.
     *
     * @return the thread name before renaming, used later by [restoreThreadContext].
     */
    override fun updateThreadContext(context: CoroutineContext): String {
        val thread = Thread.currentThread()
        val previousName = thread.name
        val displayName = context[CoroutineName]?.name ?: "coroutine"
        val baseName = previousName.substringBeforeLast(DEBUG_THREAD_NAME_SEPARATOR)
        thread.name = buildString(baseName.length + displayName.length + 10) {
            append(baseName)
            append(DEBUG_THREAD_NAME_SEPARATOR)
            append(displayName)
            append('#')
            append(id)
        }
        return previousName
    }

    /** Puts back the thread name returned by [updateThreadContext]. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
        Thread.currentThread().name = oldState
    }
}
274