1 /*
2 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3 */
4
5 package kotlinx.coroutines.test
6
7 import kotlinx.atomicfu.*
8 import kotlinx.coroutines.*
9 import kotlinx.coroutines.channels.*
10 import kotlinx.coroutines.channels.Channel.Factory.CONFLATED
11 import kotlinx.coroutines.internal.*
12 import kotlinx.coroutines.selects.*
13 import kotlin.coroutines.*
14 import kotlin.jvm.*
15 import kotlin.time.*
16 import kotlin.time.Duration.Companion.milliseconds
17
/**
 * This is a scheduler for coroutines used in tests, providing the delay-skipping behavior.
 *
 * [Test dispatchers][TestDispatcher] are parameterized with a scheduler. Several dispatchers can share the
 * same scheduler, in which case their knowledge about the virtual time will be synchronized. When the dispatchers
 * require scheduling an event at a later point in time, they notify the scheduler, which will establish the order of
 * the tasks.
 *
 * The scheduler can be queried to advance the time (via [advanceTimeBy]), run all the scheduled tasks advancing the
 * virtual time as needed (via [advanceUntilIdle]), or run the tasks that are scheduled to run as soon as possible but
 * haven't yet been dispatched (via [runCurrent]).
 *
 * The virtual time and the queue of scheduled events are guarded by a single lock; the scheduled tasks themselves
 * are always processed outside of that lock.
 */
public class TestCoroutineScheduler : AbstractCoroutineContextElement(TestCoroutineScheduler),
    CoroutineContext.Element {

    /** @suppress */
    public companion object Key : CoroutineContext.Key<TestCoroutineScheduler>

    /** This heap stores the knowledge about which dispatchers are interested in which moments of virtual time. */
    // TODO: all the synchronization is done via a separate lock, so a non-thread-safe priority queue can be used.
    private val events = ThreadSafeHeap<TestDispatchEvent<Any>>()

    /** Establishes that [currentTime] can't exceed the time of the earliest event in [events]. */
    private val lock = SynchronizedObject()

    /** This counter establishes some order on the events that happen at the same virtual time. */
    private val count = atomic(0L)

    /** The current virtual time in milliseconds. */
    @ExperimentalCoroutinesApi
    public var currentTime: Long = 0
        get() = synchronized(lock) { field }
        private set

    /** A channel for notifying about the fact that a foreground work dispatch recently happened. */
    private val dispatchEventsForeground: Channel<Unit> = Channel(CONFLATED)

    /** A channel for notifying about the fact that a dispatch recently happened. */
    private val dispatchEvents: Channel<Unit> = Channel(CONFLATED)

    /**
     * Registers a request for the scheduler to notify [dispatcher] at a virtual moment [timeDeltaMillis] milliseconds
     * later via [TestDispatcher.processEvent], which will be called with the provided [marker] object.
     *
     * [isCancelled] is stored alongside the event and is consulted by the non-strict [isIdle] check.
     * The scheduled time saturates at [Long.MAX_VALUE] instead of overflowing.
     *
     * Returns the handler which can be used to cancel the registration.
     *
     * @throws IllegalArgumentException if [timeDeltaMillis] is negative.
     * @throws IllegalStateException if [context] contains a scheduler other than this one.
     */
    internal fun <T : Any> registerEvent(
        dispatcher: TestDispatcher,
        timeDeltaMillis: Long,
        marker: T,
        context: CoroutineContext,
        isCancelled: (T) -> Boolean
    ): DisposableHandle {
        require(timeDeltaMillis >= 0) { "Attempted scheduling an event earlier in time (with the time delta $timeDeltaMillis)" }
        checkSchedulerInContext(this, context)
        // Tie-breaker for events scheduled at the same virtual time: registration order.
        val order = count.getAndIncrement()
        val isForeground = context.isForegroundWork()
        return synchronized(lock) {
            val time = addClamping(currentTime, timeDeltaMillis)
            val event = TestDispatchEvent(dispatcher, order, time, marker as Any, isForeground) { isCancelled(marker) }
            events.addLast(event)
            /* Can't be moved above: otherwise, [onDispatchEventForeground] or [onDispatchEvent] could consume the
             * token sent here before there's actually anything in the event queue. */
            sendDispatchEvent(context)
            DisposableHandle {
                synchronized(lock) {
                    events.remove(event)
                }
            }
        }
    }

    /**
     * Runs the next enqueued task, advancing the virtual time to the time of its scheduled awakening,
     * unless [condition] holds.
     *
     * [condition] is evaluated under the lock. The task itself is processed outside the lock.
     *
     * @return `false` if [condition] held or there was no task to run, `true` if a task was processed.
     */
    internal fun tryRunNextTaskUnless(condition: () -> Boolean): Boolean {
        val event = synchronized(lock) {
            if (condition()) return false
            val next = events.removeFirstOrNull() ?: return false
            // The virtual clock must never have moved past a still-queued event.
            if (currentTime > next.time)
                currentTimeAheadOfEvents()
            currentTime = next.time
            next
        }
        event.dispatcher.processEvent(event.marker)
        return true
    }

    /**
     * Runs the enqueued tasks in the specified order, advancing the virtual time as needed until there are no more
     * tasks associated with the dispatchers linked to this scheduler.
     *
     * Only foreground tasks keep this loop going. Once every remaining task is background work, this method returns.
     *
     * A breaking change from `TestCoroutineDispatcher.advanceUntilIdle` is that it no longer returns the total number
     * of milliseconds by which the execution of this method has advanced the virtual time. If you want to recreate
     * that functionality, query [currentTime] before and after the execution to achieve the same result.
     */
    public fun advanceUntilIdle(): Unit = advanceUntilIdleOr { events.none(TestDispatchEvent<*>::isForeground) }

    /**
     * Runs tasks one by one via [tryRunNextTaskUnless] until [condition] holds or no tasks remain.
     *
     * [condition]: guaranteed to be invoked under the lock.
     */
    internal fun advanceUntilIdleOr(condition: () -> Boolean) {
        while (tryRunNextTaskUnless(condition)) {
            // Each iteration already processed one task.
        }
    }

    /**
     * Runs the tasks that are scheduled to execute at this moment of virtual time.
     *
     * The virtual time is not advanced. Tasks that are added at the current moment while this method runs are
     * also executed.
     */
    public fun runCurrent() {
        val timeMark = synchronized(lock) { currentTime }
        while (true) {
            val event = synchronized(lock) {
                events.removeFirstIf { it.time <= timeMark } ?: return
            }
            event.dispatcher.processEvent(event.marker)
        }
    }

    /**
     * Moves the virtual clock of this scheduler forward by [the specified amount][delayTimeMillis], running the
     * scheduled tasks in the meantime.
     *
     * Breaking changes from [TestCoroutineDispatcher.advanceTimeBy]:
     * * Intentionally doesn't return a `Long` value, as its use cases are unclear. We may restore it in the future;
     *   please describe your use cases at [the issue tracker](https://github.com/Kotlin/kotlinx.coroutines/issues/).
     *   For now, it's possible to query [currentTime] before and after execution of this method, to the same effect.
     * * It doesn't run the tasks that are scheduled at exactly [currentTime] + [delayTimeMillis]. For example,
     *   advancing the time by one millisecond used to run the tasks at the current millisecond *and* the next
     *   millisecond, but now will stop just before executing any task starting at the next millisecond.
     * * Overflowing the target time used to lead to nothing being done, but will now run the tasks scheduled at up to
     *   (but not including) [Long.MAX_VALUE].
     *
     * @throws IllegalArgumentException if passed a negative [delay][delayTimeMillis].
     */
    @ExperimentalCoroutinesApi
    public fun advanceTimeBy(delayTimeMillis: Long): Unit = advanceTimeBy(delayTimeMillis.milliseconds)

    /**
     * Moves the virtual clock of this scheduler forward by [the specified amount][delayTime], running the
     * scheduled tasks in the meantime.
     *
     * Tasks scheduled strictly before [currentTime] + [delayTime] are run. Tasks scheduled at exactly that moment
     * are left queued. Afterwards, [currentTime] equals the target time, which saturates at [Long.MAX_VALUE].
     * Sub-millisecond parts of [delayTime] are dropped.
     *
     * @throws IllegalArgumentException if passed a negative [delay][delayTime].
     */
    public fun advanceTimeBy(delayTime: Duration) {
        require(!delayTime.isNegative()) { "Can not advance time by a negative delay: $delayTime" }
        val startingTime = currentTime
        val targetTime = addClamping(startingTime, delayTime.inWholeMilliseconds)
        while (true) {
            val event = synchronized(lock) {
                val timeMark = currentTime
                // Events scheduled exactly at `targetTime` stay in the queue.
                val next = events.removeFirstIf { targetTime > it.time }
                when {
                    next == null -> {
                        currentTime = targetTime
                        return
                    }
                    timeMark > next.time -> currentTimeAheadOfEvents()
                    else -> {
                        currentTime = next.time
                        next
                    }
                }
            }
            event.dispatcher.processEvent(event.marker)
        }
    }

    /**
     * Checks whether the scheduler has no outstanding work.
     *
     * If [strict] is `true`, returns whether no tasks remain at all. Otherwise, returns whether every remaining
     * task is cancelled.
     */
    internal fun isIdle(strict: Boolean = true): Boolean =
        synchronized(lock) {
            if (strict) events.isEmpty else events.none { !it.isCancelled() }
        }

    /**
     * Notifies this scheduler about a dispatch event.
     *
     * [context] is the context in which the task will be dispatched. Every dispatch notifies [onDispatchEvent].
     * Dispatches whose context is not marked with [BackgroundWork] also notify [onDispatchEventForeground].
     */
    internal fun sendDispatchEvent(context: CoroutineContext) {
        dispatchEvents.trySend(Unit)
        if (context.isForegroundWork())
            dispatchEventsForeground.trySend(Unit)
    }

    /** Whether work in this context is foreground work, that is, the context is not marked with [BackgroundWork]. */
    private fun CoroutineContext.isForegroundWork(): Boolean = this[BackgroundWork] === null

    /**
     * Waits for a notification about a dispatch event.
     */
    internal suspend fun receiveDispatchEvent(): Unit = dispatchEvents.receive()

    /**
     * Consumes the knowledge that a dispatch event happened recently.
     */
    internal val onDispatchEvent: SelectClause1<Unit> get() = dispatchEvents.onReceive

    /**
     * Consumes the knowledge that a foreground work dispatch event happened recently.
     */
    internal val onDispatchEventForeground: SelectClause1<Unit> get() = dispatchEventsForeground.onReceive

    /**
     * Returns the [TimeSource] representation of the virtual time of this scheduler.
     */
    @ExperimentalTime
    public val timeSource: TimeSource.WithComparableMarks = object : AbstractLongTimeSource(DurationUnit.MILLISECONDS) {
        override fun read(): Long = currentTime
    }
}
231
// Error-throwing helpers, kept as dedicated functions so that the failure is easy to recognize in stack traces.

/** Signals that the scheduler's virtual time was found to be later than the time of a still-queued event. */
private fun currentTimeAheadOfEvents(): Nothing {
    invalidSchedulerState()
}

/** Throws an [IllegalStateException] reporting that the test scheduler's internal invariants were violated. */
private fun invalidSchedulerState(): Nothing =
    error("The test scheduler entered an invalid state. Please report this at https://github.com/Kotlin/kotlinx.coroutines/issues.")
237
/**
 * A [ThreadSafeHeap] node describing one task scheduled on a [TestDispatcher].
 *
 * Nodes are ordered by their planned execution [time]. Nodes with equal times are ordered by `count`, the
 * tie-breaking counter assigned when the event was created.
 */
private class TestDispatchEvent<T>(
    @JvmField val dispatcher: TestDispatcher,
    private val count: Long,
    @JvmField val time: Long,
    @JvmField val marker: T,
    @JvmField val isForeground: Boolean,
    // TODO: remove once the deprecated API is gone
    @JvmField val isCancelled: () -> Boolean
) : Comparable<TestDispatchEvent<*>>, ThreadSafeHeapNode {
    override var heap: ThreadSafeHeap<*>? = null
    override var index: Int = 0

    override fun compareTo(other: TestDispatchEvent<*>): Int {
        val byTime = time.compareTo(other.time)
        return if (byTime != 0) byTime else count.compareTo(other.count)
    }

    override fun toString(): String {
        val backgroundSuffix = if (isForeground) "" else ", background"
        return "TestDispatchEvent(time=$time, dispatcher=$dispatcher$backgroundSuffix)"
    }
}
256
/**
 * Returns `a + b`, saturating at [Long.MAX_VALUE] instead of wrapping around.
 *
 * Only meant for non-negative [a] and [b]: an overflow is detected by the sum becoming negative.
 */
private fun addClamping(a: Long, b: Long): Long {
    val sum = a + b
    return if (sum < 0) Long.MAX_VALUE else sum
}
259
/**
 * Ensures that [context] does not carry a [TestCoroutineScheduler] other than [scheduler].
 *
 * A [context] that contains no scheduler at all passes the check.
 *
 * @throws IllegalStateException if [context] contains a different scheduler instance.
 */
internal fun checkSchedulerInContext(scheduler: TestCoroutineScheduler, context: CoroutineContext) {
    val schedulerInContext = context[TestCoroutineScheduler] ?: return
    check(schedulerInContext === scheduler) {
        "Detected use of different schedulers. If you need to use several test coroutine dispatchers, " +
            "create one `TestCoroutineScheduler` and pass it to each of them."
    }
}
268
/**
 * A coroutine context key denoting that the work is to be executed in the background.
 *
 * This object serves both as the key and as the element stored under it, so looking it up in a context yields
 * either this object or `null`.
 *
 * @see [TestScope.backgroundScope]
 */
internal object BackgroundWork : CoroutineContext.Key<BackgroundWork>, CoroutineContext.Element {
    override val key: CoroutineContext.Key<*>
        get() = BackgroundWork

    override fun toString(): String {
        return "BackgroundWork"
    }
}
279
/** Returns `true` when [find] locates no element of this heap matching [predicate]. */
private fun <T> ThreadSafeHeap<T>.none(predicate: (T) -> Boolean): Boolean where T : ThreadSafeHeapNode, T : Comparable<T> {
    val firstMatch = find(predicate)
    return firstMatch == null
}
282