1 /*
2 * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3 */
4
5 package kotlinx.coroutines
6
7 import org.junit.Test
8 import kotlin.coroutines.*
9 import kotlin.test.*
10 import kotlinx.coroutines.flow.*
11
/**
 * Tests for propagation of [ThreadContextElement] and [CopyableThreadContextElement] values into
 * `myThreadLocal` across dispatchers, `withContext`, `launch`/`async` and `flowOn`.
 */
class ThreadContextElementTest : TestBase() {

    @Test
    fun testExample() = runTest {
        val exceptionHandler = assertNotNull(coroutineContext[CoroutineExceptionHandler])
        val mainDispatcher = assertNotNull(coroutineContext[ContinuationInterceptor])
        val mainThread = Thread.currentThread()
        val data = MyData()
        val element = MyElement(data)
        assertNull(myThreadLocal.get())
        val job = GlobalScope.launch(element + exceptionHandler) {
            assertNotSame(mainThread, Thread.currentThread())
            assertSame(element, coroutineContext[MyElement])
            assertSame(data, myThreadLocal.get())
            withContext(mainDispatcher) {
                assertSame(mainThread, Thread.currentThread())
                assertSame(element, coroutineContext[MyElement])
                assertSame(data, myThreadLocal.get())
            }
            assertNotSame(mainThread, Thread.currentThread())
            assertSame(element, coroutineContext[MyElement])
            assertSame(data, myThreadLocal.get())
        }
        // The element is only installed while the launched coroutine runs, never on this thread.
        assertNull(myThreadLocal.get())
        job.join()
        assertNull(myThreadLocal.get())
    }

    @Test
    fun testUndispatched() = runTest {
        val exceptionHandler = assertNotNull(coroutineContext[CoroutineExceptionHandler])
        val data = MyData()
        val element = MyElement(data)
        val job = GlobalScope.launch(
            context = Dispatchers.Default + exceptionHandler + element,
            start = CoroutineStart.UNDISPATCHED
        ) {
            // Checked both before and after the first suspension point.
            assertSame(data, myThreadLocal.get())
            yield()
            assertSame(data, myThreadLocal.get())
        }
        assertNull(myThreadLocal.get())
        job.join()
        assertNull(myThreadLocal.get())
    }

    @Test
    fun testWithContext() = runTest {
        expect(1)
        newSingleThreadContext("withContext").use { ctx ->
            val data = MyData()
            GlobalScope.async(Dispatchers.Default + MyElement(data)) {
                assertSame(data, myThreadLocal.get())
                expect(2)

                val newData = MyData()
                GlobalScope.async(ctx + MyElement(newData)) {
                    assertSame(newData, myThreadLocal.get())
                    expect(3)
                }.await()

                withContext(ctx + MyElement(newData)) {
                    assertSame(newData, myThreadLocal.get())
                    expect(4)
                }

                // This context carries no MyElement, so nothing is installed on `ctx`'s thread.
                GlobalScope.async(ctx) {
                    assertNull(myThreadLocal.get())
                    expect(5)
                }.await()

                expect(6)
            }.await()
        }

        finish(7)
    }

    @Test
    fun testNonCopyableElementReferenceInheritedOnLaunch() = runTest {
        var parentElement: MyElement? = null
        var inheritedElement: MyElement? = null

        newSingleThreadContext("withContext").use { ctx ->
            withContext(ctx + MyElement(MyData())) {
                parentElement = coroutineContext[MyElement.Key]
                launch {
                    inheritedElement = coroutineContext[MyElement.Key]
                }
            }
        }

        // Without this guard, assertSame(null, null) would pass vacuously.
        assertNotNull(parentElement)
        assertSame(parentElement, inheritedElement,
            "Inner and outer coroutines did not have the same object reference to a" +
                " ThreadContextElement that did not override `copyForChildCoroutine()`")
    }

    @Test
    fun testCopyableElementCopiedOnLaunch() = runTest {
        var parentElement: CopyForChildCoroutineElement? = null
        var inheritedElement: CopyForChildCoroutineElement? = null

        newSingleThreadContext("withContext").use { ctx ->
            withContext(ctx + CopyForChildCoroutineElement(MyData())) {
                parentElement = coroutineContext[CopyForChildCoroutineElement.Key]
                launch {
                    inheritedElement = coroutineContext[CopyForChildCoroutineElement.Key]
                }
            }
        }

        // Without these guards, a missing child element (null) would make assertNotSame pass vacuously.
        assertNotNull(parentElement)
        assertNotNull(inheritedElement)
        assertNotSame(parentElement, inheritedElement,
            "Inner coroutine did not copy its copyable ThreadContextElement.")
    }

    @Test
    fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest {
        newFixedThreadPoolContext(nThreads = 4, name = "withContext").use { ctx ->
            val startData = MyData()
            withContext(ctx + CopyForChildCoroutineElement(startData)) {
                val forBlockData = MyData()
                myThreadLocal.setForBlock(forBlockData) {
                    assertSame(forBlockData, myThreadLocal.get())
                    launch {
                        assertSame(forBlockData, myThreadLocal.get())
                    }
                    launch {
                        assertSame(forBlockData, myThreadLocal.get())
                        // Modify value in child coroutine. Writes to the ThreadLocal and
                        // the (copied) ThreadLocalElement's memory are not visible to peer or
                        // ancestor coroutines, so this write is both threadsafe and coroutinesafe.
                        val innerCoroutineData = MyData()
                        myThreadLocal.setForBlock(innerCoroutineData) {
                            assertSame(innerCoroutineData, myThreadLocal.get())
                        }
                        assertSame(forBlockData, myThreadLocal.get()) // Asserts value was restored.
                    }
                    launch {
                        val innerCoroutineData = MyData()
                        myThreadLocal.setForBlock(innerCoroutineData) {
                            assertSame(innerCoroutineData, myThreadLocal.get())
                        }
                        assertSame(forBlockData, myThreadLocal.get())
                    }
                }
                // The element installed startData when this coroutine started running here, so
                // that (not null) is the origin setForBlock must restore.
                assertSame(startData, myThreadLocal.get())
            }
        }
    }

    /**
     * A [ThreadContextElement] that appends `context.job` to [capturees] every time
     * [updateThreadContext] is invoked, and restores nothing.
     */
    class JobCaptor(val capturees: ArrayList<Job> = ArrayList()) : ThreadContextElement<Unit> {

        // The key must be typed to this element so `context[JobCaptor]` yields a JobCaptor.
        companion object Key : CoroutineContext.Key<JobCaptor>

        override val key: CoroutineContext.Key<JobCaptor> get() = Key

        override fun updateThreadContext(context: CoroutineContext) {
            capturees.add(context.job)
        }

        // Nothing to undo: updateThreadContext changes no thread state.
        override fun restoreThreadContext(context: CoroutineContext, oldState: Unit) {
        }
    }

    @Test
    fun testWithContextJobAccess() = runTest {
        val captor = JobCaptor()
        val manuallyCaptured = ArrayList<Job>()
        // manuallyCaptured records the Job at each point where the captor is expected to run.
        runBlocking(captor) {
            manuallyCaptured += coroutineContext.job
            withContext(CoroutineName("undispatched")) {
                manuallyCaptured += coroutineContext.job
                withContext(Dispatchers.IO) {
                    manuallyCaptured += coroutineContext.job
                }
                // Context restored, captured again
                manuallyCaptured += coroutineContext.job
            }
            // Context restored, captured again
            manuallyCaptured += coroutineContext.job
        }

        assertEquals(manuallyCaptured, captor.capturees)
    }

    @Test
    fun testThreadLocalFlowOn() = runTest {
        val myData = MyData()
        myThreadLocal.set(myData)
        expect(1)
        try {
            flow {
                // asContextElement() below is evaluated while myData is set on the test thread.
                assertSame(myData, myThreadLocal.get())
                emit(1)
            }
                .flowOn(myThreadLocal.asContextElement() + Dispatchers.Default)
                .single()
        } finally {
            // Clear even when the flow fails, so the value cannot leak past this test.
            myThreadLocal.set(null)
        }
        finish(2)
    }
}
211
/**
 * Stateless payload stored in `myThreadLocal`. It declares no `equals`, so two instances are
 * equal only when they are the same object.
 */
class MyData

/** Thread-local slot that [MyElement] and [CopyForChildCoroutineElement] write [MyData] into. */
private val myThreadLocal: ThreadLocal<MyData?> = ThreadLocal()
216
/**
 * Non-copyable [ThreadContextElement] that places [data] into `myThreadLocal` while a coroutine
 * carrying it is running on a thread, and puts back the thread's previous value afterwards.
 */
class MyElement(val data: MyData) : ThreadContextElement<MyData?> {

    // Exposes the companion as this element's key inside a CoroutineContext.
    override val key: CoroutineContext.Key<MyElement>
        get() = Key

    /**
     * Invoked before the coroutine is resumed on the current thread: installs [data] and returns
     * the value it replaced, which later comes back as `oldState` in [restoreThreadContext].
     */
    override fun updateThreadContext(context: CoroutineContext): MyData? =
        myThreadLocal.get().also { myThreadLocal.set(data) }

    /** Invoked after the coroutine has suspended on the current thread: reinstates [oldState]. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) =
        myThreadLocal.set(oldState)

    // Companion doubles as the context key, so `coroutineContext[MyElement]` works.
    companion object Key : CoroutineContext.Key<MyElement>
}
238
/**
 * A [CopyableThreadContextElement] that places [data] into `myThreadLocal` while its coroutine runs
 * and implements copy semantics in [copyForChild]: every child coroutine gets its own element
 * carrying the thread-local value current at the child's launch point.
 */
class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement<MyData?> {
    companion object Key : CoroutineContext.Key<CopyForChildCoroutineElement>

    override val key: CoroutineContext.Key<CopyForChildCoroutineElement>
        get() = Key

    /** Installs [data] into `myThreadLocal` and returns the value it replaced. */
    override fun updateThreadContext(context: CoroutineContext): MyData? {
        val oldState = myThreadLocal.get()
        myThreadLocal.set(data)
        return oldState
    }

    /**
     * Resolves a clash between this (inherited) element and an element of the same key supplied
     * explicitly for the child: the explicitly supplied element wins, as `+` does for ordinary
     * elements. NOTE(review): no test in this file reaches this path.
     */
    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CopyForChildCoroutineElement =
        // Same Key as this element, so the overwriting element is a CopyForChildCoroutineElement.
        overwritingElement as CopyForChildCoroutineElement

    /** Reinstates the thread's previous `myThreadLocal` value. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
        myThreadLocal.set(oldState)
    }

    /**
     * At coroutine launch time, the _current value of the ThreadLocal_ is inherited by the new
     * child coroutine and copied into a new, unique [CopyForChildCoroutineElement] that only that
     * child uses.
     *
     * n.b. the value copied to the child must be the __current value of the ThreadLocal__ and not
     * [data], so that writes made to the ThreadLocal between coroutine resumption and the child's
     * launch point are visible to the child. Such writes are *not* stored back into this element:
     * [data] is immutable, [restoreThreadContext] only reinstates the thread's previous value, and
     * the next [updateThreadContext] installs [data] again.
     */
    override fun copyForChild(): CopyForChildCoroutineElement {
        return CopyForChildCoroutineElement(myThreadLocal.get())
    }
}
277
278
/**
 * Calls [block] with [this] [ThreadLocal] set to [value], restores the previous value afterwards
 * (also when [block] throws), and returns [block]'s result.
 *
 * [block] is a non-suspending lambda, so the value is held only for a synchronous stretch of code.
 * When a [CopyForChildCoroutineElement] is in the current [CoroutineContext], child coroutines
 * launched inside [block] start with [value], because the element copies the thread-local's
 * current value at launch time. Each child gets its own element, so writes made to the
 * `ThreadLocal` by a child are not visible to the parent. Writes made by the parent _after_
 * launching a child are not visible to that child either.
 */
private inline fun <ThreadLocalT, OutputT> ThreadLocal<ThreadLocalT>.setForBlock(
    value: ThreadLocalT,
    crossinline block: () -> OutputT
): OutputT {
    val priorValue = get()
    set(value)
    return try {
        block()
    } finally {
        // Restore unconditionally so a failing block cannot leave a pooled thread polluted.
        set(priorValue)
    }
}
299
300