• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines
6 
7 import kotlinx.coroutines.internal.*
8 import org.junit.Test
9 import kotlin.coroutines.*
10 import kotlin.test.*
11 
/**
 * Verifies that two [ThreadContextElement]s installed in the same coroutine context are correctly nested:
 * the test expects them to be updated in context order (`transactionalElement` first, then `loggingElement`)
 * and restored in the reverse order.
 */
class ThreadContextOrderTest : TestBase() {
    // Thread-local state managed by the two elements below. Nullable because a thread may have no value yet.
    private val transactionalContext = ThreadLocal<String?>()
    private val loggingContext = ThreadLocal<String?>()

    /**
     * First element of the tested context. Its assertions check that on both update and restore
     * the logging thread-local still holds the caller's value ("test"). In other words, it must be
     * updated before, and restored after, [loggingElement].
     */
    private val transactionalElement = object : ThreadContextElement<String?> {
        override val key = ThreadLocalKey(transactionalContext)

        override fun updateThreadContext(context: CoroutineContext): String? {
            // The logging element must not have been applied yet.
            assertEquals("test", loggingContext.get())
            val previous = transactionalContext.get()
            transactionalContext.set("tr coroutine")
            // May be null if the thread had no value; restoreThreadContext accepts that.
            return previous
        }

        override fun restoreThreadContext(context: CoroutineContext, oldState: String?) {
            // The logging element must already have been restored.
            assertEquals("test", loggingContext.get())
            assertEquals("tr coroutine", transactionalContext.get())
            transactionalContext.set(oldState)
        }
    }

    /**
     * Second element of the tested context. Its restore assertion checks that the transactional
     * thread-local still holds the coroutine value, i.e. that it is restored before [transactionalElement].
     */
    private val loggingElement = object : ThreadContextElement<String?> {
        override val key = ThreadLocalKey(loggingContext)

        override fun updateThreadContext(context: CoroutineContext): String? {
            val previous = loggingContext.get()
            loggingContext.set("log coroutine")
            return previous
        }

        override fun restoreThreadContext(context: CoroutineContext, oldState: String?) {
            assertEquals("log coroutine", loggingContext.get())
            // The transactional element must not have been restored yet.
            assertEquals("tr coroutine", transactionalContext.get())
            loggingContext.set(oldState)
        }
    }

    @Test
    fun testCorrectOrder() = runTest {
        transactionalContext.set("test")
        loggingContext.set("test")
        // With the default start mode, launch only schedules the child. join() is required so that
        // the assertions below run after the child has executed and its context elements were restored.
        launch(transactionalElement + loggingElement) {
            assertEquals("log coroutine", loggingContext.get())
            assertEquals("tr coroutine", transactionalContext.get())
        }.join()
        // Once the child has finished, both elements must have restored the caller's values.
        assertEquals("test", loggingContext.get())
        assertEquals("test", transactionalContext.get())
    }
}
66