/*
 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package kotlinx.coroutines

import kotlinx.coroutines.internal.*
import org.junit.Test
import kotlin.coroutines.*
import kotlin.test.*

/**
 * Verifies that two thread context elements are correctly nested:
 * the restoration order is the reverse of the update order.
 *
 * Both thread-locals are seeded with `"test"` by the test body. Each element swaps in its own
 * coroutine-specific value on update and puts the previous value back on restore, asserting along
 * the way what the *other* thread-local must contain at that moment.
 */
class ThreadContextOrderTest : TestBase() {
    private val transactionalContext = ThreadLocal<String>()
    private val loggingContext = ThreadLocal<String>()

    /**
     * Element that is expected to be updated first and restored last: in both callbacks it asserts
     * that [loggingContext] still (or again) holds the outer `"test"` value.
     */
    private val transactionalElement = object : ThreadContextElement<String> {
        override val key = ThreadLocalKey(transactionalContext)

        override fun updateThreadContext(context: CoroutineContext): String {
            // The logging element must not have been applied yet.
            assertEquals("test", loggingContext.get())
            val previous = transactionalContext.get()
            transactionalContext.set("tr coroutine")
            return previous
        }

        override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
            // The logging element must already have been rolled back.
            assertEquals("test", loggingContext.get())
            assertEquals("tr coroutine", transactionalContext.get())
            transactionalContext.set(oldState)
        }
    }

    /**
     * Element that is expected to be updated last and restored first: on restore it asserts that
     * both thread-locals still hold their coroutine-specific values.
     */
    private val loggingElement = object : ThreadContextElement<String> {
        override val key = ThreadLocalKey(loggingContext)

        override fun updateThreadContext(context: CoroutineContext): String {
            val previous = loggingContext.get()
            loggingContext.set("log coroutine")
            return previous
        }

        override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
            assertEquals("log coroutine", loggingContext.get())
            // The transactional element must not have been rolled back yet.
            assertEquals("tr coroutine", transactionalContext.get())
            loggingContext.set(oldState)
        }
    }

    @Test
    fun testCorrectOrder() = runTest {
        transactionalContext.set("test")
        loggingContext.set("test")
        launch(transactionalElement + loggingElement) {
            assertEquals("log coroutine", loggingContext.get())
            assertEquals("tr coroutine", transactionalContext.get())
        }.join()
        // Joining the child first makes the checks below observe the state after its thread
        // context was restored, instead of possibly running before the child has even started.
        assertEquals("test", loggingContext.get())
        assertEquals("test", transactionalContext.get())
    }
}