• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
4 
5 package kotlinx.coroutines
6 
7 import org.junit.Test
8 import kotlin.coroutines.*
9 import kotlin.test.*
10 import kotlinx.coroutines.flow.*
11 
/**
 * Tests for [ThreadContextElement] and [CopyableThreadContextElement]: installing and restoring
 * `myThreadLocal` around coroutine resumption/suspension, and whether child coroutines inherit such
 * elements by reference or by copy.
 */
class ThreadContextElementTest : TestBase() {

    @Test
    fun testExample() = runTest {
        val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!!
        val mainDispatcher = coroutineContext[ContinuationInterceptor]!!
        val mainThread = Thread.currentThread()
        val data = MyData()
        val element = MyElement(data)
        assertNull(myThreadLocal.get())
        val job = GlobalScope.launch(element + exceptionHandler) {
            assertNotSame(mainThread, Thread.currentThread())
            assertSame(element, coroutineContext[MyElement])
            assertSame(data, myThreadLocal.get())
            withContext(mainDispatcher) {
                assertSame(mainThread, Thread.currentThread())
                assertSame(element, coroutineContext[MyElement])
                assertSame(data, myThreadLocal.get())
            }
            assertNotSame(mainThread, Thread.currentThread())
            assertSame(element, coroutineContext[MyElement])
            assertSame(data, myThreadLocal.get())
        }
        // The element must only be visible inside the launched coroutine, never on this thread.
        assertNull(myThreadLocal.get())
        job.join()
        assertNull(myThreadLocal.get())
    }

    @Test
    fun testUndispatched() = runTest {
        val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!!
        val data = MyData()
        val element = MyElement(data)
        val job = GlobalScope.launch(
            context = Dispatchers.Default + exceptionHandler + element,
            start = CoroutineStart.UNDISPATCHED
        ) {
            // The element must be installed both before and after the first suspension point.
            assertSame(data, myThreadLocal.get())
            yield()
            assertSame(data, myThreadLocal.get())
        }
        assertNull(myThreadLocal.get())
        job.join()
        assertNull(myThreadLocal.get())
    }

    @Test
    fun testWithContext() = runTest {
        expect(1)
        newSingleThreadContext("withContext").use { dispatcher ->
            val data = MyData()
            GlobalScope.async(Dispatchers.Default + MyElement(data)) {
                assertSame(data, myThreadLocal.get())
                expect(2)

                // An explicitly supplied element overrides the inherited one...
                val newData = MyData()
                GlobalScope.async(dispatcher + MyElement(newData)) {
                    assertSame(newData, myThreadLocal.get())
                    expect(3)
                }.await()

                withContext(dispatcher + MyElement(newData)) {
                    assertSame(newData, myThreadLocal.get())
                    expect(4)
                }

                // ...and a GlobalScope coroutine without the element sees no value at all.
                GlobalScope.async(dispatcher) {
                    assertNull(myThreadLocal.get())
                    expect(5)
                }.await()

                expect(6)
            }.await()
        }

        finish(7)
    }

    @Test
    fun testNonCopyableElementReferenceInheritedOnLaunch() = runTest {
        var parentElement: MyElement? = null
        var inheritedElement: MyElement? = null

        newSingleThreadContext("withContext").use { dispatcher ->
            withContext(dispatcher + MyElement(MyData())) {
                parentElement = coroutineContext[MyElement.Key]
                launch {
                    inheritedElement = coroutineContext[MyElement.Key]
                }
            }
        }

        // Guard against both captures being null, which would make assertSame pass vacuously.
        val parent = assertNotNull(parentElement, "Outer coroutine had no MyElement in its context")
        assertSame(parent, inheritedElement,
            "Inner and outer coroutines did not have the same object reference to a" +
                " non-copyable ThreadContextElement")
    }

    @Test
    fun testCopyableElementCopiedOnLaunch() = runTest {
        var parentElement: CopyForChildCoroutineElement? = null
        var inheritedElement: CopyForChildCoroutineElement? = null

        newSingleThreadContext("withContext").use { dispatcher ->
            withContext(dispatcher + CopyForChildCoroutineElement(MyData())) {
                parentElement = coroutineContext[CopyForChildCoroutineElement.Key]
                launch {
                    inheritedElement = coroutineContext[CopyForChildCoroutineElement.Key]
                }
            }
        }

        // A missing element on either side must fail the test rather than satisfy assertNotSame.
        val parent = assertNotNull(parentElement, "Outer coroutine had no CopyForChildCoroutineElement")
        val inherited = assertNotNull(inheritedElement, "Inner coroutine had no CopyForChildCoroutineElement")
        assertNotSame(parent, inherited,
            "Inner coroutine did not copy its copyable ThreadContextElement.")
    }

    @Test
    fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest {
        newFixedThreadPoolContext(nThreads = 4, name = "withContext").use { pool ->
            withContext(pool + CopyForChildCoroutineElement(MyData())) {
                val forBlockData = MyData()
                myThreadLocal.setForBlock(forBlockData) {
                    assertSame(forBlockData, myThreadLocal.get())
                    launch {
                        assertSame(forBlockData, myThreadLocal.get())
                    }
                    launch {
                        assertSame(forBlockData, myThreadLocal.get())
                        // Modify value in child coroutine. Writes to the ThreadLocal and
                        // the (copied) ThreadContextElement's memory are not visible to peer or
                        // ancestor coroutines, so this write is both threadsafe and coroutinesafe.
                        val innerCoroutineData = MyData()
                        myThreadLocal.setForBlock(innerCoroutineData) {
                            assertSame(innerCoroutineData, myThreadLocal.get())
                        }
                        assertSame(forBlockData, myThreadLocal.get()) // Asserts value was restored.
                    }
                    launch {
                        val innerCoroutineData = MyData()
                        myThreadLocal.setForBlock(innerCoroutineData) {
                            assertSame(innerCoroutineData, myThreadLocal.get())
                        }
                        assertSame(forBlockData, myThreadLocal.get())
                    }
                }
                // Asserts value was restored to its origin.
                // NOTE(review): the origin is null rather than the MyData() passed above; presumably
                // withContext installs copyForChild() of that element, snapshotting the (unset)
                // thread local of the calling thread. Confirm against the library's folding rules.
                assertNull(myThreadLocal.get())
            }
        }
    }

    /**
     * Test-only element that records `context.job` every time [updateThreadContext] is invoked,
     * and restores nothing.
     */
    class JobCaptor(val capturees: ArrayList<Job> = ArrayList()) : ThreadContextElement<Unit> {

        override val key: CoroutineContext.Key<JobCaptor> get() = Key

        override fun updateThreadContext(context: CoroutineContext) {
            capturees.add(context.job)
        }

        override fun restoreThreadContext(context: CoroutineContext, oldState: Unit) {
            // Intentionally empty: this element only observes and has no state to restore.
        }

        /** Key of this element; typed to [JobCaptor] so `context[JobCaptor]` yields a [JobCaptor]. */
        companion object Key : CoroutineContext.Key<JobCaptor>
    }

    @Test
    fun testWithContextJobAccess() = runTest {
        val captor = JobCaptor()
        val manuallyCaptured = ArrayList<Job>()
        runBlocking(captor) {
            manuallyCaptured += coroutineContext.job
            withContext(CoroutineName("undispatched")) {
                manuallyCaptured += coroutineContext.job
                withContext(Dispatchers.IO) {
                    manuallyCaptured += coroutineContext.job
                }
                // Context restored, captured again
                manuallyCaptured += coroutineContext.job
            }
            // Context restored, captured again
            manuallyCaptured += coroutineContext.job
        }

        assertEquals(manuallyCaptured, captor.capturees)
    }

    @Test
    fun testThreadLocalFlowOn() = runTest {
        val myData = MyData()
        myThreadLocal.set(myData)
        try {
            expect(1)
            flow {
                assertSame(myData, myThreadLocal.get())
                emit(1)
            }
                .flowOn(myThreadLocal.asContextElement() + Dispatchers.Default)
                .single()
        } finally {
            // Always clear the test thread's value so a failure here cannot leak into other tests.
            myThreadLocal.set(null)
        }
        finish(2)
    }
}
211 
/** Stateless payload with no `equals` override, so two instances are equal only if identical. */
class MyData

/** Per-thread slot that the context elements in this file install [MyData] values into. */
private val myThreadLocal: ThreadLocal<MyData?> = ThreadLocal()
216 
/**
 * Non-copyable [ThreadContextElement] that puts [data] into `myThreadLocal` whenever its coroutine
 * is resumed on a thread, and puts the displaced value back when the coroutine suspends there.
 */
class MyElement(val data: MyData) : ThreadContextElement<MyData?> {

    override val key: CoroutineContext.Key<MyElement>
        get() = Key

    /** Invoked before the coroutine resumes on the current thread; returns the value it displaced. */
    override fun updateThreadContext(context: CoroutineContext): MyData? =
        myThreadLocal.get().also { myThreadLocal.set(data) }

    /** Invoked after the coroutine suspends on the current thread; reinstates [oldState]. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) =
        myThreadLocal.set(oldState)

    /** Key under which this element is stored in (and looked up from) a [CoroutineContext]. */
    companion object Key : CoroutineContext.Key<MyElement>
}
238 
/**
 * A [CopyableThreadContextElement] that installs [data] into `myThreadLocal` while its coroutine
 * runs and gives each child coroutine a fresh instance via [copyForChild].
 */
class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement<MyData?> {

    override val key: CoroutineContext.Key<CopyForChildCoroutineElement>
        get() = Key

    /** Installs [data] into `myThreadLocal` and hands back the value it replaced. */
    override fun updateThreadContext(context: CoroutineContext): MyData? {
        val displaced = myThreadLocal.get()
        myThreadLocal.set(data)
        return displaced
    }

    /** Reinstates the value that [updateThreadContext] displaced. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
        myThreadLocal.set(oldState)
    }

    /**
     * Produces the element a newly launched child coroutine will own.
     *
     * The copy is seeded from the value `myThreadLocal` holds *at this moment*, not from [data], so
     * a value written to the thread local after the parent resumed (e.g. through `setForBlock`) is
     * what the child sees. The result is a distinct instance whose [data] is fixed, so later writes
     * made by the parent are not seen by the child and vice versa.
     *
     * Note that this element never captures writes back into itself: [data] is immutable, and
     * [restoreThreadContext] overwrites the thread local with the pre-resumption value.
     */
    override fun copyForChild(): CopyForChildCoroutineElement =
        CopyForChildCoroutineElement(myThreadLocal.get())

    /** Merging with an overriding element supplied by a child context is not exercised here. */
    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CopyForChildCoroutineElement =
        TODO("Not used in tests")

    /** Key under which this element is stored in a [CoroutineContext]. */
    companion object Key : CoroutineContext.Key<CopyForChildCoroutineElement>
}
277 
278 
/**
 * Calls [block] with `this` [ThreadLocal] set to [value], then puts back the value it held before
 * the call, and returns whatever [block] returns.
 *
 * The previous value is restored in a `finally`, so it is reinstated even when [block] throws.
 *
 * When a [CopyForChildCoroutineElement] for `this` [ThreadLocal] is used within a
 * [CoroutineContext], the value set here is seen by every statement of [block] and by child
 * coroutines launched from it, because those children snapshot the current value when they are
 * created. Writes made to the `ThreadLocal` by child coroutines are not visible to the parent
 * coroutine. Writes made by the parent coroutine _after_ launching a child are not visible to
 * that child.
 *
 * NOTE(review): [CopyForChildCoroutineElement] re-installs its own immutable `data` on every
 * resumption, so a value set here does not appear to survive a suspension inside [block]. Confirm
 * before relying on this across suspension points.
 *
 * @param value the value visible through this thread local for the duration of [block].
 * @param block the code to run while [value] is set.
 * @return the result of [block].
 */
private inline fun <ThreadLocalT, OutputT> ThreadLocal<ThreadLocalT>.setForBlock(
    value: ThreadLocalT,
    crossinline block: () -> OutputT
): OutputT {
    val priorValue = get()
    set(value)
    try {
        return block()
    } finally {
        // Always restore, so an exception in block cannot leak value onto this thread.
        set(priorValue)
    }
}
299 
300