• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines
6 
7 import kotlinx.coroutines.sync.*
8 import java.util.concurrent.*
9 import kotlin.coroutines.*
10 import kotlin.coroutines.intrinsics.*
11 import kotlin.test.*
12 
13 
/**
 * Tests that a [ThreadLocal] installed through `asContextElement` is correctly restored when a
 * coroutine is cancelled inside a `withContext` block that carries that element.
 *
 * Every scenario checks that the thread-local value observed after the block is exactly the value
 * that was in place before it: the inner value ("foo") must not leak, and the outer value must not be lost.
 */
class ThreadLocalStressTest : TestBase() {

    private val threadLocal = ThreadLocal<String>()

    // See the comment in doStress for the machinery
    @Test
    fun testStress() = runTest {
        repeat(100 * stressTestMultiplierSqrt) {
            withContext(Dispatchers.Default) {
                repeat(100) {
                    launch {
                        doStress(null)
                    }
                }
            }
        }
    }

    @Test
    fun testStressWithOuterValue() = runTest {
        repeat(100 * stressTestMultiplierSqrt) {
            withContext(Dispatchers.Default + threadLocal.asContextElement("bar")) {
                repeat(100) {
                    launch {
                        doStress("bar")
                    }
                }
            }
        }
    }

    /**
     * Runs the race-reproducing scenario once and asserts that [threadLocal] holds [expectedValue]
     * both before entering and after leaving (normally or exceptionally) the inner `withContext`.
     */
    private suspend fun doStress(expectedValue: String?) {
        assertEquals(expectedValue, threadLocal.get())
        try {
            /*
             * Here we use a very specific code path to trigger the exact interleaving we want.
             * The bug, in general, has a larger impact, but this particular code pinpoints it:
             *
             * 1) We use an _undispatched_ withContext with a thread element
             * 2) We cancel the coroutine
             * 3) We use 'suspendCancellableCoroutineReusable', which does a _postponed_ cancellation check
             *    and makes the reproduction of this race pretty reliable.
             *
             * Now the following code path is likely to be triggered:
             *
             * T1, from within the 'withContinuationContext' method:
             * Finds 'oldValue', finds the undispatched completion and invokes its 'block' argument.
             * 'block' is this coroutine; it goes to 'trySuspend', checks for postponed cancellation and *dispatches* it.
             * The execution stops _right_ before 'undispatchedCompletion.clearThreadContext()'.
             *
             * T2 now executes the dispatched cancellation and concurrently mutates the state of the undispatched completion.
             * All bets are off: both threads can now leave the thread-local state inconsistent.
             */
            withContext(threadLocal.asContextElement("foo")) {
                yield()
                cancel()
                suspendCancellableCoroutineReusable<Unit> { }
            }
        } finally {
            // Whatever happened inside, the outer value must be visible again here.
            assertEquals(expectedValue, threadLocal.get())
        }
    }

    /*
     * Another set of tests, for non-dispatchable continuations, that do not require the stress test multiplier.
     * Note that `uncaughtExceptionHandler` is used as the only available mechanism to propagate an error out of
     * `resumeWith`.
     */

    @Test
    fun testNonDispatcheableLeak() {
        repeat(100) {
            doTestWithPreparation(
                ::doTest,
                { threadLocal.set(null) }) { threadLocal.get() == null }
            assertNull(threadLocal.get())
        }
    }

    @Test
    fun testNonDispatcheableLeakWithInitial() {
        repeat(100) {
            doTestWithPreparation(::doTest, { threadLocal.set("initial") }) { threadLocal.get() == "initial" }
            assertEquals("initial", threadLocal.get())
        }
    }

    @Test
    fun testNonDispatcheableLeakWithContextSwitch() {
        repeat(100) {
            doTestWithPreparation(
                ::doTestWithContextSwitch,
                { threadLocal.set(null) }) { threadLocal.get() == null }
            assertNull(threadLocal.get())
        }
    }

    @Test
    fun testNonDispatcheableLeakWithInitialWithContextSwitch() {
        repeat(100) {
            doTestWithPreparation(
                ::doTestWithContextSwitch,
                { threadLocal.set("initial") }) { true /* can randomly wake up on the non-main thread */ }
            // Here we are always on the main thread
            assertEquals("initial", threadLocal.get())
        }
    }

    /**
     * Calls [setup], then starts [testBody] on the calling thread with an empty context (no dispatcher)
     * and blocks until it completes.
     *
     * @param testBody the suspending scenario to run.
     * @param setup invoked on the calling thread before [testBody] starts.
     * @param isValid invoked on whichever thread completes [testBody]; a `false` result is reported as an
     *   [IllegalStateException] through that thread's uncaught exception handler.
     */
    private fun doTestWithPreparation(testBody: suspend () -> Unit, setup: () -> Unit, isValid: () -> Boolean) {
        setup()
        val latch = CountDownLatch(1)
        val completion = Continuation<Unit>(EmptyCoroutineContext) { result ->
            // Without this, an unexpected failure of the body would be silently dropped here.
            result.exceptionOrNull()?.let { reportUncaught(it) }
            if (!isValid()) {
                reportUncaught(IllegalStateException("Unexpected error: thread local was not cleaned"))
            }
            latch.countDown()
        }
        // If the body finishes without suspending, the result is returned directly and `completion`
        // is never invoked by the runtime; invoke it ourselves so the latch is released instead of hanging.
        // (A synchronous exception simply propagates out of this call and fails the test.)
        if (testBody.startCoroutineUninterceptedOrReturn(completion) !== COROUTINE_SUSPENDED) {
            completion.resume(Unit)
        }
        latch.await()
    }

    /** Hands [error] to the current thread's uncaught exception handler. */
    private fun reportUncaught(error: Throwable) {
        val thread = Thread.currentThread()
        thread.uncaughtExceptionHandler.uncaughtException(thread, error)
    }

    /**
     * Installs "foo" via `withContext` on the current context (no dispatcher change). Inside, it cancels a
     * nested `coroutineScope` and then calls `acquire()` on a `Semaphore(1, 1)`, i.e. one permit that is
     * already taken. The resulting [CancellationException] is swallowed on purpose; only the
     * thread-local state afterwards matters.
     */
    private suspend fun doTest() {
        withContext(threadLocal.asContextElement("foo")) {
            try {
                coroutineScope {
                    val semaphore = Semaphore(1, 1)
                    cancel()
                    semaphore.acquire()
                }
            } catch (e: CancellationException) {
                // Ignore cancellation
            }
        }
    }

    /**
     * Same as [doTest], but first joins a job launched in [GlobalScope]. Because the context has no
     * dispatcher, the continuation may resume on the thread that completed that job, so the rest of
     * the scenario can run on a different thread than the caller.
     */
    private suspend fun doTestWithContextSwitch() {
        withContext(threadLocal.asContextElement("foo")) {
            try {
                coroutineScope {
                    val semaphore = Semaphore(1, 1)
                    GlobalScope.launch { }.join()
                    cancel()
                    semaphore.acquire()
                }
            } catch (e: CancellationException) {
                // Ignore cancellation
            }
        }
    }
}
166