• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
4 
5 package kotlinx.coroutines.test.internal
6 
7 import kotlinx.atomicfu.*
8 import kotlinx.coroutines.*
9 import kotlinx.coroutines.test.*
10 import kotlin.coroutines.*
11 
/**
 * The testable main dispatcher used by kotlinx-coroutines-test.
 * It is a [MainCoroutineDispatcher] that delegates all actions to a settable delegate.
 *
 * The delegate starts out as the dispatcher passed to the constructor. It can be replaced with [setDispatcher] and
 * restored with [resetDispatcher]. [Delay] operations go to the current delegate when it implements [Delay], and to
 * the default delay implementation otherwise.
 */
internal class TestMainDispatcher(delegate: CoroutineDispatcher):
    MainCoroutineDispatcher(),
    Delay
{
    /** The dispatcher this instance was created with; [resetDispatcher] restores it as the delegate. */
    private val mainDispatcher = delegate

    // Only the wrapped value ever changes, never the holder itself, so this is a `val`.
    private val delegate = NonConcurrentlyModifiable(mainDispatcher, "Dispatchers.Main")

    /** The [Delay] to forward to: the current delegate if it implements [Delay], otherwise the default one. */
    private val delay
        get() = delegate.value as? Delay ?: defaultDelay

    /**
     * The current delegate's own [MainCoroutineDispatcher.immediate] if the delegate is a [MainCoroutineDispatcher],
     * otherwise this instance. Evaluated against whichever delegate is current at the time of access.
     */
    override val immediate: MainCoroutineDispatcher
        get() = (delegate.value as? MainCoroutineDispatcher)?.immediate ?: this

    override fun dispatch(context: CoroutineContext, block: Runnable) = delegate.value.dispatch(context, block)

    override fun isDispatchNeeded(context: CoroutineContext): Boolean = delegate.value.isDispatchNeeded(context)

    override fun dispatchYield(context: CoroutineContext, block: Runnable) = delegate.value.dispatchYield(context, block)

    /**
     * Makes [dispatcher] the delegate for all subsequent operations.
     *
     * @throws IllegalStateException if the delegate was detected being read or written concurrently with a write.
     */
    fun setDispatcher(dispatcher: CoroutineDispatcher) {
        delegate.value = dispatcher
    }

    /**
     * Restores the dispatcher this instance was created with as the delegate.
     *
     * @throws IllegalStateException if the delegate was detected being read or written concurrently with a write.
     */
    fun resetDispatcher() {
        delegate.value = mainDispatcher
    }

    override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation<Unit>) =
        delay.scheduleResumeAfterDelay(timeMillis, continuation)

    override fun invokeOnTimeout(timeMillis: Long, block: Runnable, context: CoroutineContext): DisposableHandle =
        delay.invokeOnTimeout(timeMillis, block, context)

    companion object {
        /**
         * The current delegate of [Dispatchers.Main], when `Dispatchers.Main` is a [TestMainDispatcher] whose delegate
         * is a [TestDispatcher]; `null` otherwise.
         */
        internal val currentTestDispatcher
            get() = (Dispatchers.Main as? TestMainDispatcher)?.delegate?.value as? TestDispatcher

        /** The scheduler of [currentTestDispatcher], or `null` when there is no such dispatcher. */
        internal val currentTestScheduler
            get() = currentTestDispatcher?.scheduler
    }

    /**
     * A wrapper around a value that attempts to throw when writing happens concurrently with reading.
     *
     * The read operations never throw. Instead, the failures detected inside them will be remembered and thrown on the
     * next modification.
     *
     * Detection relies on the atomic bookkeeping fields below rather than on locking, so it is best-effort:
     * some interleavings may go unreported.
     */
    private class NonConcurrentlyModifiable<T>(initialValue: T, private val name: String) {
        private val reader: AtomicRef<Throwable?> = atomic(null) // last reader to attempt access
        private val readers = atomic(0) // number of concurrent readers
        private val writer: AtomicRef<Throwable?> = atomic(null) // writer currently performing value modification
        private val exceptionWhenReading: AtomicRef<Throwable?> = atomic(null) // exception from reading
        private val _value = atomic(initialValue) // the backing field for the value

        private fun concurrentWW(location: Throwable) = IllegalStateException("$name is modified concurrently", location)
        private fun concurrentRW(location: Throwable) = IllegalStateException("$name is used concurrently with setting it", location)

        /**
         * The wrapped value.
         *
         * Reading never throws. If a write is observed in progress during a read, the conflict is recorded and
         * thrown by the next write.
         *
         * Writing throws [IllegalStateException] in three cases: an earlier read recorded such a conflict, a read is
         * in progress before or after the value is replaced, or another write is in progress.
         */
        var value: T
            get() {
                // The stack trace of this Throwable identifies the read site in later error reports.
                reader.value = Throwable("reader location")
                readers.incrementAndGet()
                writer.value?.let { exceptionWhenReading.value = concurrentRW(it) }
                val result = _value.value
                readers.decrementAndGet()
                return result
            }
            set(value) {
                // Surface a conflict that an earlier read detected but was not allowed to throw.
                exceptionWhenReading.getAndSet(null)?.let { throw it }
                if (readers.value != 0) reader.value?.let { throw concurrentRW(it) }
                val writerLocation = Throwable("other writer location")
                // Claim the writer slot only while it is free. Unconditionally overwriting a slot held by another
                // writer would stop that writer from clearing it, so every later write would fail spuriously.
                // If the slot was released between the failed CAS and the read, simply retry.
                while (!writer.compareAndSet(null, writerLocation)) {
                    writer.value?.let { throw concurrentWW(it) }
                }
                _value.value = value
                writer.compareAndSet(writerLocation, null)
                // A read may have started while the value was being replaced.
                if (readers.value != 0) reader.value?.let { throw concurrentRW(it) }
            }
    }
}
93 
/**
 * Accessor for `DefaultDelay`, the [Delay] that [TestMainDispatcher] falls back to when its current delegate does not
 * implement [Delay]. `DefaultDelay` is not normally visible from this module, hence the suppression.
 */
@Suppress("INVISIBLE_MEMBER")
private inline val defaultDelay
    get() = DefaultDelay
97 
/**
 * Obtains the [TestMainDispatcher] for the receiver [Dispatchers].
 *
 * This is an `expect` declaration: each platform supplies its own `actual` implementation, which is not visible here.
 * NOTE(review): presumably the `actual`s wrap the platform's real `Dispatchers.Main`, or return an already-installed
 * wrapper — confirm against the platform-specific declarations.
 */
@Suppress("INVISIBLE_MEMBER")
internal expect fun Dispatchers.getTestMainDispatcher(): TestMainDispatcher
100