1 /*
<lambda>null2 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3 */
4
5 package kotlinx.coroutines
6
7 import kotlinx.coroutines.internal.*
8 import kotlin.coroutines.*
9 import kotlin.coroutines.jvm.internal.CoroutineStackFrame
10
/**
 * Creates a context for a new coroutine launched in this scope.
 *
 * The scope's [coroutineContext] and the given [context] are combined with [CopyableThreadContextElement]
 * rules applied for a *new* coroutine (copyable elements are copied or merged, see `foldCopies`).
 * When debugging is enabled (see [DEBUG_PROPERTY_NAME] for the JVM debugging facilities) a fresh
 * [CoroutineId] is attached. [Dispatchers.Default] is installed when no dispatcher or other
 * [ContinuationInterceptor] is present in the combined context.
 */
@ExperimentalCoroutinesApi
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
    val merged = foldCopies(coroutineContext, context, isNewCoroutine = true)
    val withDebugId = when {
        DEBUG -> merged + CoroutineId(COROUTINE_ID.incrementAndGet())
        else -> merged
    }
    // The identity check short-circuits the interceptor lookup when the merged context is Dispatchers.Default itself.
    val lacksInterceptor = merged !== Dispatchers.Default && merged[ContinuationInterceptor] == null
    return if (lacksInterceptor) withDebugId + Dispatchers.Default else withDebugId
}
24
/**
 * Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
 *
 * Elements of [addedContext] override those of this context. [CopyableThreadContextElement]s already in this
 * context are kept as-is (they are not shared with a new coroutine), or merged when [addedContext] overrides
 * their key. Copyable elements that come from [addedContext] are copied.
 * @suppress
 */
@InternalCoroutinesApi
public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext =
    // Fast path: 'addedContext' typically holds one or two elements, and the copy/merge fold
    // is only required when at least one of them is copyable.
    when {
        addedContext.hasCopyableElements() -> foldCopies(this, addedContext, isNewCoroutine = false)
        else -> this + addedContext
    }
38
/** Returns `true` when at least one element of this context is a [CopyableThreadContextElement]. */
private fun CoroutineContext.hasCopyableElements(): Boolean {
    // CoroutineContext offers no iteration API besides fold, so every element is visited.
    return fold(initial = false) { anyCopyable, element ->
        if (anyCopyable) true else element is CopyableThreadContextElement<*>
    }
}
41
/**
 * Folds two contexts, applying [CopyableThreadContextElement] (CTCE) rules where necessary:
 * * If neither context has a CTCE, the plain sum `originalContext + appendContext` is returned.
 * * A CTCE from [originalContext] with no element of the same key in [appendContext] is
 *   [copied][CopyableThreadContextElement.copyForChild] when [isNewCoroutine] is `true`,
 *   and kept as-is otherwise.
 * * A CTCE from [originalContext] whose key is present in [appendContext] is
 *   [merged][CopyableThreadContextElement.mergeForChild] with that element, and the element
 *   is then removed from what is left of [appendContext].
 * * Every CTCE still left in [appendContext] after merging is copied.
 * * All other elements are added to the result as is.
 *
 * Elements left over from [appendContext] are added last, so they override same-key elements of the folded left side.
 */
private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext {
    val leftHasCopyable = originalContext.hasCopyableElements()
    val rightHasCopyable = appendContext.hasCopyableElements()

    // Neither side carries a CTCE: the regular context sum is sufficient.
    if (!leftHasCopyable && !rightHasCopyable) return originalContext + appendContext

    // Part of appendContext not yet consumed by a merge with a left-hand CTCE.
    var remaining = appendContext
    val foldedLeft = originalContext.fold<CoroutineContext>(EmptyCoroutineContext) { acc, element ->
        if (element !is CopyableThreadContextElement<*>) {
            acc + element
        } else {
            val overriding = remaining[element.key]
            if (overriding == null) {
                // Not overridden. 'withContext'-like builders keep the element as-is, because it is not shared.
                acc + (if (isNewCoroutine) element.copyForChild() else element)
            } else {
                // Overridden: consume the right-hand element and merge it into the left-hand one.
                remaining = remaining.minusKey(element.key)
                @Suppress("UNCHECKED_CAST")
                val merged = (element as CopyableThreadContextElement<Any?>).mergeForChild(overriding)
                acc + merged
            }
        }
    }

    if (rightHasCopyable) {
        // Newly appended CTCEs must be copied, otherwise they may be shared with others.
        remaining = remaining.fold<CoroutineContext>(EmptyCoroutineContext) { acc, element ->
            val toAdd = if (element is CopyableThreadContextElement<*>) element.copyForChild() else element
            acc + toAdd
        }
    }
    return foldedLeft + remaining
}
90
/**
 * Runs [block] with the thread context elements of [context] installed on the current thread.
 * The previous thread state is restored afterwards, even when [block] throws.
 *
 * [countOrElement] is passed through to `updateThreadContext`. It is presumably a precomputed
 * element count or a single element used as a lookup hint; confirm against ThreadContext.kt.
 */
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T {
    val previousState = updateThreadContext(context, countOrElement)
    return try {
        block()
    } finally {
        restoreThreadContext(context, previousState)
    }
}
102
/**
 * Runs [block] with the thread context elements of [continuation]'s context installed on the current thread.
 *
 * If installing the context actually replaced some thread state, the nearest [UndispatchedCoroutine] up the
 * stack (if any) records that state. In that case, the state is restored here only when the coroutine has not
 * already consumed it (`clearThreadContext()` returns `true`). Otherwise, it is always restored here.
 */
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T {
    val context = continuation.context
    val oldValue = updateThreadContext(context, countOrElement)
    // Fast path: nothing was replaced, so there is nothing for an undispatched coroutine to track.
    val undispatchedCompletion =
        if (oldValue === NO_THREAD_ELEMENTS) null
        else continuation.updateUndispatchedCompletion(context, oldValue)
    return try {
        block()
    } finally {
        val mustRestore = undispatchedCompletion?.clearThreadContext() ?: true
        if (mustRestore) restoreThreadContext(context, oldValue)
    }
}
123
/**
 * Looks up the [UndispatchedCoroutine] (if any) above this continuation in the call stack. If one is found,
 * it records the [context] and [oldValue] on that coroutine for the current thread, so the thread state
 * can be restored later.
 *
 * Returns `null` in any of these cases:
 * * this continuation is not a [CoroutineStackFrame];
 * * [context] carries no [UndispatchedMarker];
 * * the stack walk hits a [DispatchedCoroutine] or the top of the stack before finding an [UndispatchedCoroutine].
 */
internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
    if (this !is CoroutineStackFrame) return null
    /*
     * Fast path: without the marker there cannot be an undispatched coroutine in our stack.
     *
     * Implementation note.
     * If stack-walking for thread-locals ever proves too slow, an alternative is:
     * 1) Store the undispatched coroutine right in the `UndispatchedMarker` instance.
     * 2) To avoid issues across dispatch boundaries, remove `UndispatchedMarker`
     *    from the context when creating a dispatched coroutine in `withContext`.
     *    Another option is to "unmark" it instead of removing it, to save an allocation.
     * Both options should work, but their performance and, mostly, maintainability
     * impact need more careful study.
     */
    if (context[UndispatchedMarker] === null) return null
    return undispatchedCompletion()?.also { it.saveThreadContext(context, oldValue) }
}
144
/**
 * Walks up the caller frames starting at this frame and returns the first [UndispatchedCoroutine] found.
 * Returns `null` when the walk reaches a [DispatchedCoroutine] or a frame without a caller.
 */
internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? {
    // A dispatched coroutine is a boundary: nothing above it is searched.
    if (this is DispatchedCoroutine<*>) return null
    // No caller frame: any other kind of frame is not supported.
    val caller = callerFrame ?: return null
    return if (caller is UndispatchedCoroutine<*>) caller else caller.undispatchedCompletion()
}
154
/**
 * Context element (acting as its own key) signalling that an [UndispatchedCoroutine] may exist
 * somewhere up the stack. Checking for it is a performance shortcut that lets callers skip
 * the stack walk when no such coroutine can be present.
 */
private object UndispatchedMarker : CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
    override val key: CoroutineContext.Key<*> get() = this
}
163
/**
 * Coroutine used by `withContext` when the context changes but the dispatcher stays the same.
 *
 * It adds [UndispatchedMarker] to its context (unless already present), so that
 * [updateUndispatchedCompletion] can skip the stack walk when no such coroutine can be up the stack.
 */
internal actual class UndispatchedCoroutine<in T> actual constructor(
    context: CoroutineContext,
    uCont: Continuation<T>
) : ScopeCoroutine<T>(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) {

    /*
     * The state is thread-local because this coroutine can be used concurrently.
     * Scenario of usage (withContinuationContext):
     * val state = saveThreadContext(ctx)
     * try {
     *     invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called
     *     // COROUTINE_SUSPENDED is returned
     * } finally {
     *     thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread
     *     // and it also calls saveThreadContext and clearThreadContext
     * }
     *
     * Saved state is dropped with ThreadLocal.remove() rather than set(null). get() still yields null
     * afterwards, but the current thread's ThreadLocalMap entry is released instead of lingering with a null value.
     */
    private val threadStateToRecover = ThreadLocal<Pair<CoroutineContext, Any?>>()

    init {
        /*
         * This is a hack for a very specific case in #2930 until #3253 is implemented.
         * 'ThreadLocalStressTest' covers this change properly.
         *
         * The scenario this change covers is the following:
         * 1) The coroutine is started as a plain, non-kotlinx.coroutines suspend function,
         *    e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that invokes
         *    `withContext(tlElement)`, which creates `UndispatchedCoroutine`.
         * 2) The original continuation is then not wrapped into `DispatchedContinuation` via `intercept()`
         *    and goes through neither `DC.run` nor `resumeUndispatchedWith`, both of which
         *    do thread context element tracking.
         * 3) So thread locals never get a chance to be properly set up via `saveThreadContext`,
         *    but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`.
         *
         * Here we detect precisely this situation and properly set up the context to recover later.
         */
        if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) {
            /*
             * We cannot just "read" the elements, as there is no such API,
             * so we update-restore them immediately and use the intermediate value
             * as the initial state. This relies on thread context elements being
             * idempotent and on such situations being increasingly rare.
             */
            val values = updateThreadContext(context, null)
            restoreThreadContext(context, values)
            saveThreadContext(context, values)
        }
    }

    /**
     * Remembers, for the current thread only, the [context] whose thread context elements were updated
     * and the [oldValue] returned by that update. [afterResume] restores it later on this thread.
     */
    fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
        threadStateToRecover.set(context to oldValue)
    }

    /**
     * Drops the state saved on the current thread.
     *
     * @return `true` if saved state was still present and has now been cleared. In that case the caller
     *         remains responsible for restoring the thread context. Returns `false` if nothing was saved or
     *         the state was already consumed (for example by [afterResume] on this thread).
     */
    fun clearThreadContext(): Boolean {
        if (threadStateToRecover.get() == null) return false
        threadStateToRecover.remove()
        return true
    }

    /**
     * Restores any thread context saved on the current thread, then resumes [uCont] undispatched:
     * its context is installed, but execution stays on the same dispatcher.
     */
    override fun afterResume(state: Any?) {
        threadStateToRecover.get()?.let { (ctx, value) ->
            restoreThreadContext(ctx, value)
            threadStateToRecover.remove()
        }
        // resume undispatched -- update context but stay on the same dispatcher
        val result = recoverResult(state, uCont)
        withContinuationContext(uCont, null) {
            uCont.resumeWith(result)
        }
    }
}
237
/**
 * Debug name of the coroutine in the form `<name>#<id>`, where `<name>` falls back to "coroutine"
 * when no [CoroutineName] is set. It is `null` when debug mode is off or the context has no [CoroutineId].
 */
internal actual val CoroutineContext.coroutineName: String?
    get() {
        if (!DEBUG) return null
        val id = this[CoroutineId]?.id ?: return null
        val name = this[CoroutineName]?.name ?: "coroutine"
        return "$name#$id"
    }
244
/** Separator placed between a thread's own name and the coroutine suffix appended by [CoroutineId]. */
private const val DEBUG_THREAD_NAME_SEPARATOR = " @"

/**
 * Debug identifier of a coroutine.
 *
 * As a [ThreadContextElement], it renames the current thread to `<base name> @<coroutine name>#<id>`
 * while the coroutine runs, and puts the previous name back afterwards.
 */
@IgnoreJreRequirement // desugared hashcode implementation
internal data class CoroutineId(
    val id: Long
) : ThreadContextElement<String>, AbstractCoroutineContextElement(CoroutineId) {
    companion object Key : CoroutineContext.Key<CoroutineId>

    override fun toString(): String = "CoroutineId($id)"

    /**
     * Builds the new thread name from the current one. Everything from the last
     * [DEBUG_THREAD_NAME_SEPARATOR] onward is cut off (presumably a suffix added earlier by this mechanism),
     * then ` @<name>#<id>` is appended. Returns the name the thread had before.
     */
    override fun updateThreadContext(context: CoroutineContext): String {
        val name = context[CoroutineName]?.name ?: "coroutine"
        val thread = Thread.currentThread()
        val previousName = thread.name
        val separatorAt = previousName.lastIndexOf(DEBUG_THREAD_NAME_SEPARATOR)
        val baseLength = if (separatorAt >= 0) separatorAt else previousName.length
        val renamed = StringBuilder(baseLength + name.length + 10)
            .append(previousName, 0, baseLength)
            .append(DEBUG_THREAD_NAME_SEPARATOR)
            .append(name)
            .append('#')
            .append(id)
            .toString()
        thread.name = renamed
        return previousName
    }

    /** Puts back the thread name returned by [updateThreadContext]. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
        Thread.currentThread().name = oldState
    }
}
274