• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines.test
6 
7 import kotlinx.atomicfu.*
8 import kotlinx.coroutines.*
9 import kotlinx.coroutines.channels.*
10 import kotlinx.coroutines.channels.Channel.Factory.CONFLATED
11 import kotlinx.coroutines.internal.*
12 import kotlinx.coroutines.selects.*
13 import kotlin.coroutines.*
14 import kotlin.jvm.*
15 import kotlin.time.*
16 import kotlin.time.Duration.Companion.milliseconds
17 
/**
 * This is a scheduler for coroutines used in tests, providing the delay-skipping behavior.
 *
 * [Test dispatchers][TestDispatcher] are parameterized with a scheduler. Several dispatchers can share the
 * same scheduler, in which case their knowledge about the virtual time will be synchronized. When the dispatchers
 * require scheduling an event at a later point in time, they notify the scheduler, which will establish the order of
 * the tasks.
 *
 * The scheduler can be queried to advance the time (via [advanceTimeBy]), run all the scheduled tasks advancing the
 * virtual time as needed (via [advanceUntilIdle]), or run the tasks that are scheduled to run as soon as possible but
 * haven't yet been dispatched (via [runCurrent]).
 */
public class TestCoroutineScheduler : AbstractCoroutineContextElement(TestCoroutineScheduler),
    CoroutineContext.Element {

    /** @suppress */
    public companion object Key : CoroutineContext.Key<TestCoroutineScheduler>

    /** This heap stores the knowledge about which dispatchers are interested in which moments of virtual time. */
    // TODO: all the synchronization is done via a separate lock, so a non-thread-safe priority queue can be used.
    private val events = ThreadSafeHeap<TestDispatchEvent<Any>>()

    /** Establishes that [currentTime] can't exceed the time of the earliest event in [events]. */
    private val lock = SynchronizedObject()

    /** This counter establishes some order on the events that happen at the same virtual time. */
    private val count = atomic(0L)

    /** The current virtual time in milliseconds. */
    @ExperimentalCoroutinesApi
    public var currentTime: Long = 0
        get() = synchronized(lock) { field }
        // Only written by this class, always while holding [lock].
        private set

    /** A channel for notifying about the fact that a foreground work dispatch recently happened. */
    private val dispatchEventsForeground: Channel<Unit> = Channel(CONFLATED)

    /** A channel for notifying about the fact that a dispatch recently happened. */
    private val dispatchEvents: Channel<Unit> = Channel(CONFLATED)

    /**
     * Work is foreground unless its [context] carries the [BackgroundWork] marker.
     *
     * Shared by [registerEvent] and [sendDispatchEvent] so both classify work the same way.
     */
    private fun isForegroundWork(context: CoroutineContext): Boolean = context[BackgroundWork] === null

    /**
     * Registers a request for the scheduler to notify [dispatcher] at a virtual moment [timeDeltaMillis] milliseconds
     * later via [TestDispatcher.processEvent], which will be called with the provided [marker] object.
     *
     * Returns the handler which can be used to cancel the registration.
     *
     * @throws IllegalArgumentException if [timeDeltaMillis] is negative.
     * @throws IllegalStateException if [context] contains a different [TestCoroutineScheduler].
     */
    internal fun <T : Any> registerEvent(
        dispatcher: TestDispatcher,
        timeDeltaMillis: Long,
        marker: T,
        context: CoroutineContext,
        isCancelled: (T) -> Boolean
    ): DisposableHandle {
        require(timeDeltaMillis >= 0) { "Attempted scheduling an event earlier in time (with the time delta $timeDeltaMillis)" }
        checkSchedulerInContext(this, context)
        // Tie-breaker for events that end up scheduled at the same virtual moment.
        val sequenceNumber = count.getAndIncrement()
        val isForeground = isForegroundWork(context)
        return synchronized(lock) {
            // Saturates at Long.MAX_VALUE instead of overflowing into the past.
            val time = addClamping(currentTime, timeDeltaMillis)
            val event = TestDispatchEvent<Any>(dispatcher, sequenceNumber, time, marker, isForeground) {
                isCancelled(marker)
            }
            events.addLast(event)
            /** can't be moved above: otherwise, [onDispatchEventForeground] or [onDispatchEvent] could consume the
             * token sent here before there's actually anything in the event queue. */
            sendDispatchEvent(context)
            DisposableHandle {
                synchronized(lock) {
                    events.remove(event)
                }
            }
        }
    }

    /**
     * Runs the next enqueued task, advancing the virtual time to the time of its scheduled awakening,
     * unless [condition] holds.
     *
     * [condition] is evaluated while holding the scheduler lock.
     *
     * @return `true` if a task was run; `false` if [condition] held or there was no task.
     */
    internal fun tryRunNextTaskUnless(condition: () -> Boolean): Boolean {
        val event = synchronized(lock) {
            if (condition()) return false
            val event = events.removeFirstOrNull() ?: return false
            if (currentTime > event.time)
                currentTimeAheadOfEvents()
            currentTime = event.time
            event
        }
        // The event is processed after the lock has been released.
        event.dispatcher.processEvent(event.marker)
        return true
    }

    /**
     * Runs the enqueued tasks in the specified order, advancing the virtual time as needed until there are no more
     * tasks associated with the dispatchers linked to this scheduler.
     *
     * Only foreground tasks keep this loop going. Once every remaining task is background work, the method returns.
     *
     * A breaking change from `TestCoroutineDispatcher.advanceUntilIdle` is that it no longer returns the total number
     * of milliseconds by which the execution of this method has advanced the virtual time. If you want to recreate
     * that functionality, query [currentTime] before and after the execution to achieve the same result.
     */
    public fun advanceUntilIdle(): Unit = advanceUntilIdleOr { events.none(TestDispatchEvent<*>::isForeground) }

    /**
     * Runs tasks one by one (see [tryRunNextTaskUnless]) until [condition] holds or no tasks remain.
     *
     * [condition]: guaranteed to be invoked under the lock.
     */
    internal fun advanceUntilIdleOr(condition: () -> Boolean) {
        while (true) {
            if (!tryRunNextTaskUnless(condition))
                return
        }
    }

    /**
     * Runs the tasks that are scheduled to execute at this moment of virtual time.
     *
     * The moment is captured once on entry. Tasks registered while running that are also due at or before it are
     * run as well. The virtual time is never advanced.
     */
    public fun runCurrent() {
        val timeMark = synchronized(lock) { currentTime }
        while (true) {
            val event = synchronized(lock) {
                events.removeFirstIf { it.time <= timeMark } ?: return
            }
            event.dispatcher.processEvent(event.marker)
        }
    }

    /**
     * Moves the virtual clock of this dispatcher forward by [the specified amount][delayTimeMillis], running the
     * scheduled tasks in the meantime.
     *
     * Breaking changes from [TestCoroutineDispatcher.advanceTimeBy]:
     * * Intentionally doesn't return a `Long` value, as its use cases are unclear. We may restore it in the future;
     *   please describe your use cases at [the issue tracker](https://github.com/Kotlin/kotlinx.coroutines/issues/).
     *   For now, it's possible to query [currentTime] before and after execution of this method, to the same effect.
     * * It doesn't run the tasks that are scheduled at exactly [currentTime] + [delayTimeMillis]. For example,
     *   advancing the time by one millisecond used to run the tasks at the current millisecond *and* the next
     *   millisecond, but now will stop just before executing any task starting at the next millisecond.
     * * Overflowing the target time used to lead to nothing being done, but will now run the tasks scheduled at up to
     *   (but not including) [Long.MAX_VALUE].
     *
     * @throws IllegalArgumentException if passed a negative [delay][delayTimeMillis].
     */
    @ExperimentalCoroutinesApi
    public fun advanceTimeBy(delayTimeMillis: Long): Unit = advanceTimeBy(delayTimeMillis.milliseconds)

    /**
     * Moves the virtual clock of this dispatcher forward by [the specified amount][delayTime], running the
     * scheduled tasks in the meantime.
     *
     * Tasks scheduled strictly before the target moment are run in order. Tasks scheduled at exactly the target
     * moment are not run. The target moment saturates at [Long.MAX_VALUE], and sub-millisecond parts of
     * [delayTime] are truncated.
     *
     * @throws IllegalArgumentException if passed a negative [delay][delayTime].
     */
    public fun advanceTimeBy(delayTime: Duration) {
        require(!delayTime.isNegative()) { "Can not advance time by a negative delay: $delayTime" }
        val targetTime = addClamping(currentTime, delayTime.inWholeMilliseconds)
        while (true) {
            val next = synchronized(lock) {
                val timeMark = currentTime
                val candidate = events.removeFirstIf { targetTime > it.time }
                when {
                    candidate == null -> {
                        currentTime = targetTime
                        return
                    }
                    timeMark > candidate.time -> currentTimeAheadOfEvents()
                    else -> {
                        currentTime = candidate.time
                        candidate
                    }
                }
            }
            next.dispatcher.processEvent(next.marker)
        }
    }

    /**
     * Checks whether the scheduler has no pending tasks.
     *
     * When [strict] is `true`, the event queue must be empty. When it is `false`, queued tasks whose cancellation
     * predicate reports them as cancelled are ignored.
     */
    internal fun isIdle(strict: Boolean = true): Boolean =
        synchronized(lock) {
            if (strict) events.isEmpty else events.none { !it.isCancelled() }
        }

    /**
     * Notifies this scheduler about a dispatch event.
     *
     * [context] is the context in which the task will be dispatched. Every dispatch signals [onDispatchEvent]. Only
     * foreground dispatches also signal [onDispatchEventForeground].
     */
    internal fun sendDispatchEvent(context: CoroutineContext) {
        dispatchEvents.trySend(Unit)
        if (isForegroundWork(context))
            dispatchEventsForeground.trySend(Unit)
    }

    /**
     * Waits for a notification about a dispatch event.
     */
    internal suspend fun receiveDispatchEvent(): Unit = dispatchEvents.receive()

    /**
     * Consumes the knowledge that a dispatch event happened recently.
     */
    internal val onDispatchEvent: SelectClause1<Unit> get() = dispatchEvents.onReceive

    /**
     * Consumes the knowledge that a foreground work dispatch event happened recently.
     */
    internal val onDispatchEventForeground: SelectClause1<Unit> get() = dispatchEventsForeground.onReceive

    /**
     * Returns the [TimeSource] representation of the virtual time of this scheduler.
     */
    @ExperimentalTime
    public val timeSource: TimeSource.WithComparableMarks = object : AbstractLongTimeSource(DurationUnit.MILLISECONDS) {
        override fun read(): Long = currentTime
    }
}
231 
// Dedicated throwing helpers: each broken scheduler invariant gets its own recognizable frame in the stack trace.

/** Signals that the virtual clock is already past the time of an event that was still queued. */
private fun currentTimeAheadOfEvents(): Nothing = invalidSchedulerState()

/** Signals an internal inconsistency of the test scheduler; always throws [IllegalStateException]. */
private fun invalidSchedulerState(): Nothing =
    error("The test scheduler entered an invalid state. Please report this at https://github.com/Kotlin/kotlinx.coroutines/issues.")
237 
/**
 * A [ThreadSafeHeap] node describing one scheduled task.
 *
 * Nodes are ordered primarily by [time]. For nodes planned at the same virtual moment, [count] breaks the tie.
 */
private class TestDispatchEvent<T>(
    @JvmField val dispatcher: TestDispatcher,
    private val count: Long,
    @JvmField val time: Long,
    @JvmField val marker: T,
    @JvmField val isForeground: Boolean,
    // TODO: remove once the deprecated API is gone
    @JvmField val isCancelled: () -> Boolean
) : Comparable<TestDispatchEvent<*>>, ThreadSafeHeapNode {
    override var heap: ThreadSafeHeap<*>? = null
    override var index: Int = 0

    override fun compareTo(other: TestDispatchEvent<*>): Int {
        val byTime = time.compareTo(other.time)
        return if (byTime != 0) byTime else count.compareTo(other.count)
    }

    override fun toString(): String {
        val backgroundSuffix = if (isForeground) "" else ", background"
        return "TestDispatchEvent(time=$time, dispatcher=$dispatcher$backgroundSuffix)"
    }
}
256 
/**
 * Saturating addition of two non-negative values.
 *
 * With non-negative operands, a negative sum can only come from overflow. In that case the result is [Long.MAX_VALUE].
 */
private fun addClamping(a: Long, b: Long): Long {
    val sum = a + b
    return if (sum < 0) Long.MAX_VALUE else sum
}
259 
/**
 * Verifies that [context] does not carry a [TestCoroutineScheduler] other than [scheduler].
 *
 * A context without any scheduler passes the check.
 *
 * @throws IllegalStateException if [context] holds a different scheduler instance.
 */
internal fun checkSchedulerInContext(scheduler: TestCoroutineScheduler, context: CoroutineContext) {
    val schedulerInContext = context[TestCoroutineScheduler] ?: return
    check(schedulerInContext === scheduler) {
        "Detected use of different schedulers. If you need to use several test coroutine dispatchers, " +
            "create one `TestCoroutineScheduler` and pass it to each of them."
    }
}
268 
/**
 * A coroutine context marker denoting that the work is to be executed in the background.
 *
 * The object is both the element and its own key, so `context[BackgroundWork]` is either `null` or this object.
 * @see [TestScope.backgroundScope]
 */
internal object BackgroundWork : CoroutineContext.Key<BackgroundWork>, CoroutineContext.Element {
    override fun toString(): String = "BackgroundWork"

    override val key: CoroutineContext.Key<*> get() = this
}
279 
/** Returns `true` when no node of this heap satisfies [predicate]. */
private fun <T> ThreadSafeHeap<T>.none(predicate: (T) -> Boolean): Boolean
    where T : ThreadSafeHeapNode, T : Comparable<T> {
    val firstMatch = find(predicate)
    return firstMatch == null
}
282