/*
 * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package kotlinx.coroutines

import org.junit.*
import org.junit.Test
import java.lang.IllegalStateException
import kotlin.test.*

/**
 * Tests for `ThreadLocal.asContextElement`.
 *
 * The tests assert the following behaviour:
 * - A coroutine carrying the element observes the element's value in its thread-local.
 * - That value follows the coroutine across `withContext` dispatcher switches.
 * - A nested element overrides it for the nested scope only.
 * - After the coroutine completes, the calling thread sees its own previous value again.
 */
@Suppress("RedundantAsync")
class ThreadLocalTest : TestBase() {
    private val stringThreadLocal = ThreadLocal<String?>()
    private val intThreadLocal = ThreadLocal<Int?>()

    // Single-thread pool used by the tests to move coroutines onto a different thread.
    private val executor = newFixedThreadPoolContext(1, "threadLocalTest")

    @After
    fun tearDown() {
        executor.close()
    }

    /**
     * The element's value is visible inside the coroutine (including after a `withContext`
     * hop) but never leaks into the caller, before or after `await`.
     */
    @Test
    fun testThreadLocal() = runTest {
        assertNull(stringThreadLocal.get())
        assertFalse(stringThreadLocal.isPresent())
        val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("value")) {
            assertEquals("value", stringThreadLocal.get())
            assertTrue(stringThreadLocal.isPresent())
            withContext(executor) {
                assertTrue(stringThreadLocal.isPresent())
                // intThreadLocal has no element in this context, so ensurePresent must throw.
                assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
                assertEquals("value", stringThreadLocal.get())
            }
            assertTrue(stringThreadLocal.isPresent())
            assertEquals("value", stringThreadLocal.get())
        }

        assertNull(stringThreadLocal.get())
        deferred.await()
        assertNull(stringThreadLocal.get())
        assertFalse(stringThreadLocal.isPresent())
    }

    /**
     * A value already set on the caller's thread is shadowed inside the coroutine and
     * observed again by the caller after `await`.
     */
    @Test
    fun testThreadLocalInitialValue() = runTest {
        intThreadLocal.set(42)
        // A plain set() does not make the thread-local "present" in the coroutine context.
        assertFalse(intThreadLocal.isPresent())
        val deferred = async(Dispatchers.Default + intThreadLocal.asContextElement(239)) {
            assertEquals(239, intThreadLocal.get())
            withContext(executor) {
                intThreadLocal.ensurePresent()
                assertEquals(239, intThreadLocal.get())
            }
            assertEquals(239, intThreadLocal.get())
        }

        deferred.await()
        assertEquals(42, intThreadLocal.get())
    }

    /** Two independent thread-local elements can be carried by the same context. */
    @Test
    fun testMultipleThreadLocals() = runTest {
        stringThreadLocal.set("test")
        intThreadLocal.set(314)

        val deferred = async(
            Dispatchers.Default +
                intThreadLocal.asContextElement(value = 239) +
                stringThreadLocal.asContextElement(value = "pew")
        ) {
            assertEquals(239, intThreadLocal.get())
            assertEquals("pew", stringThreadLocal.get())

            withContext(executor) {
                assertEquals(239, intThreadLocal.get())
                assertEquals("pew", stringThreadLocal.get())
                intThreadLocal.ensurePresent()
                stringThreadLocal.ensurePresent()
            }

            assertEquals(239, intThreadLocal.get())
            assertEquals("pew", stringThreadLocal.get())
        }

        deferred.await()
        assertEquals(314, intThreadLocal.get())
        assertEquals("test", stringThreadLocal.get())
    }

    /**
     * Nested elements for the same thread-local override the outer value only for their own
     * scope. A coroutine started in GlobalScope without the element does not see the value.
     */
    @Test
    fun testConflictingThreadLocals() = runTest {
        intThreadLocal.set(42)

        val deferred = GlobalScope.async(intThreadLocal.asContextElement(1)) {
            assertEquals(1, intThreadLocal.get())

            withContext(executor + intThreadLocal.asContextElement(42)) {
                assertEquals(42, intThreadLocal.get())
            }

            assertEquals(1, intThreadLocal.get())

            val child = async(intThreadLocal.asContextElement(53)) {
                assertEquals(53, intThreadLocal.get())
            }

            child.await()
            assertEquals(1, intThreadLocal.get())

            // GlobalScope does not inherit this coroutine's context, so no element is installed.
            val unrelated = GlobalScope.async(executor) {
                assertNull(intThreadLocal.get())
            }

            unrelated.await()
            assertEquals(1, intThreadLocal.get())
        }

        deferred.await()
        assertEquals(42, intThreadLocal.get())
    }

    /**
     * A direct `set()` inside the coroutine is not captured by the context element. After a
     * nested coroutine or `withContext` returns, the element's own value ("initial") is seen
     * again, not the overridden value.
     */
    @Test
    fun testThreadLocalModification() = runTest {
        stringThreadLocal.set("main")

        val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("initial")) {
            assertEquals("initial", stringThreadLocal.get())

            stringThreadLocal.set("overridden") // <- this value is not reflected in the context, so it's not restored

            withContext(executor + stringThreadLocal.asContextElement("ctx")) {
                assertEquals("ctx", stringThreadLocal.get())
            }

            val child = async(stringThreadLocal.asContextElement("async")) {
                assertEquals("async", stringThreadLocal.get())
            }

            child.await()
            assertEquals("initial", stringThreadLocal.get()) // <- not restored
        }

        deferred.await()
        assertFalse(stringThreadLocal.isPresent())
        assertEquals("main", stringThreadLocal.get())
    }

    // Mutable holder used to check that mutating a carried object is visible to later reads.
    private data class Counter(var cnt: Int)

    private val myCounterLocal = ThreadLocal<Counter>()

    /**
     * Mutating the object held by the element (rather than calling `set()`) is observed
     * later in the same coroutine. Nested elements carry their own, separate objects.
     */
    @Test
    fun testThreadLocalModificationMutableBox() = runTest {
        myCounterLocal.set(Counter(42))

        val deferred = async(Dispatchers.Default + myCounterLocal.asContextElement(Counter(0))) {
            assertEquals(0, myCounterLocal.get().cnt)

            // Mutate the shared box in place; the change is seen after the nested scopes below.
            myCounterLocal.get().cnt = 71

            withContext(executor + myCounterLocal.asContextElement(Counter(-1))) {
                assertEquals(-1, myCounterLocal.get().cnt)
                ++myCounterLocal.get().cnt
            }

            val child = async(myCounterLocal.asContextElement(Counter(31))) {
                assertEquals(31, myCounterLocal.get().cnt)
                ++myCounterLocal.get().cnt
            }

            child.await()
            assertEquals(71, myCounterLocal.get().cnt)
        }

        deferred.await()
        assertEquals(42, myCounterLocal.get().cnt)
    }

    /** Checks element propagation across `async`/`withContext` onto a separate single-thread context. */
    @Test
    fun testWithContext() = runTest {
        expect(1)
        newSingleThreadContext("withContext").use { ctx ->
            val data = 42
            GlobalScope.async(Dispatchers.Default + intThreadLocal.asContextElement(42)) {

                assertEquals(data, intThreadLocal.get())
                expect(2)

                GlobalScope.async(ctx + intThreadLocal.asContextElement(31)) {
                    assertEquals(31, intThreadLocal.get())
                    expect(3)
                }.await()

                withContext(ctx + intThreadLocal.asContextElement(2)) {
                    assertEquals(2, intThreadLocal.get())
                    expect(4)
                }

                // No element in this GlobalScope coroutine's context, so nothing is installed.
                GlobalScope.async(ctx) {
                    assertNull(intThreadLocal.get())
                    expect(5)
                }.await()

                expect(6)
            }.await()
        }

        finish(7)
    }

    /**
     * With no element, a GlobalScope coroutine does not see the caller's value. The
     * no-argument `asContextElement()` captures the caller's current value (42 here).
     */
    @Test
    fun testScope() = runTest {
        intThreadLocal.set(42)
        val mainThread = Thread.currentThread()
        GlobalScope.async {
            assertNull(intThreadLocal.get())
            assertNotSame(mainThread, Thread.currentThread())
        }.await()

        GlobalScope.async(intThreadLocal.asContextElement()) {
            assertEquals(42, intThreadLocal.get())
            assertNotSame(mainThread, Thread.currentThread())
        }.await()
    }

    /** `ensurePresent` throws `IllegalStateException` when the thread-local has no element in the context. */
    @Test
    fun testMissingThreadLocal() = runTest {
        assertFailsWith<IllegalStateException> { stringThreadLocal.ensurePresent() }
        assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
    }
}