• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
4 
5 package kotlinx.coroutines.test
6 
7 import kotlinx.atomicfu.*
8 import kotlinx.coroutines.*
9 import kotlinx.coroutines.channels.*
10 import kotlinx.coroutines.channels.Channel.Factory.CONFLATED
11 import kotlinx.coroutines.internal.*
12 import kotlinx.coroutines.selects.*
13 import kotlin.coroutines.*
14 import kotlin.jvm.*
15 import kotlin.time.*
16 
/**
 * This is a scheduler for coroutines used in tests, providing the delay-skipping behavior.
 *
 * [Test dispatchers][TestDispatcher] are parameterized with a scheduler. Several dispatchers can share the
 * same scheduler, in which case their knowledge about the virtual time will be synchronized. When the dispatchers
 * require scheduling an event at a later point in time, they notify the scheduler, which will establish the order of
 * the tasks.
 *
 * The scheduler can be queried to advance the time (via [advanceTimeBy]), run all the scheduled tasks advancing the
 * virtual time as needed (via [advanceUntilIdle]), or run the tasks that are scheduled to run as soon as possible but
 * haven't yet been dispatched (via [runCurrent]).
 */
@ExperimentalCoroutinesApi
public class TestCoroutineScheduler : AbstractCoroutineContextElement(TestCoroutineScheduler),
    CoroutineContext.Element {

    /** @suppress */
    public companion object Key : CoroutineContext.Key<TestCoroutineScheduler>

    /** This heap stores the knowledge about which dispatchers are interested in which moments of virtual time. */
    // TODO: all the synchronization is done via a separate lock, so a non-thread-safe priority queue can be used.
    private val events = ThreadSafeHeap<TestDispatchEvent<Any>>()

    /** Establishes that [currentTime] can't exceed the time of the earliest event in [events]. */
    private val lock = SynchronizedObject()

    /** This counter establishes some order on the events that happen at the same virtual time. */
    private val count = atomic(0L)

    /**
     * The current virtual time in milliseconds.
     *
     * Reads are performed under [lock]; the value can only be changed from inside this class.
     */
    @ExperimentalCoroutinesApi
    public var currentTime: Long = 0
        get() = synchronized(lock) { field }
        private set

    /** A channel for notifying about the fact that a dispatch recently happened. */
    private val dispatchEvents: Channel<Unit> = Channel(CONFLATED)

    /**
     * Registers a request for the scheduler to notify [dispatcher] at a virtual moment [timeDeltaMillis] milliseconds
     * later via [TestDispatcher.processEvent], which will be called with the provided [marker] object.
     *
     * The event is treated as a foreground one unless [context] contains the [BackgroundWork] element.
     * The target time is computed with clamping, so it never exceeds [Long.MAX_VALUE].
     *
     * Returns the handler which can be used to cancel the registration.
     *
     * @throws IllegalArgumentException if [timeDeltaMillis] is negative.
     * @throws IllegalStateException if [context] contains a [TestCoroutineScheduler] other than this one.
     */
    internal fun <T : Any> registerEvent(
        dispatcher: TestDispatcher,
        timeDeltaMillis: Long,
        marker: T,
        context: CoroutineContext,
        isCancelled: (T) -> Boolean
    ): DisposableHandle {
        require(timeDeltaMillis >= 0) { "Attempted scheduling an event earlier in time (with the time delta $timeDeltaMillis)" }
        checkSchedulerInContext(this, context)
        // The sequence number is what orders events that share the same virtual time.
        val count = count.getAndIncrement()
        val isForeground = context[BackgroundWork] === null
        return synchronized(lock) {
            val time = addClamping(currentTime, timeDeltaMillis)
            val event = TestDispatchEvent(dispatcher, count, time, marker as Any, isForeground) { isCancelled(marker) }
            events.addLast(event)
            /** can't be moved above: otherwise, [onDispatchEvent] could consume the token sent here before there's
             * actually anything in the event queue. */
            sendDispatchEvent(context)
            DisposableHandle {
                synchronized(lock) {
                    events.remove(event)
                }
            }
        }
    }

    /**
     * Runs the next enqueued task, advancing the virtual time to the time of its scheduled awakening,
     * unless [condition] holds.
     *
     * [condition] is evaluated under the lock; the task itself is processed after the lock is released.
     *
     * @return `true` if a task was run; `false` if [condition] held or there were no enqueued tasks.
     */
    internal fun tryRunNextTaskUnless(condition: () -> Boolean): Boolean {
        val event = synchronized(lock) {
            if (condition()) return false
            val event = events.removeFirstOrNull() ?: return false
            // The earliest event must never lie in the past of the virtual clock.
            if (currentTime > event.time)
                currentTimeAheadOfEvents()
            currentTime = event.time
            event
        }
        event.dispatcher.processEvent(event.time, event.marker)
        return true
    }

    /**
     * Runs the enqueued tasks in the specified order, advancing the virtual time as needed until there are no more
     * tasks associated with the dispatchers linked to this scheduler.
     *
     * The loop stops as soon as no foreground events remain in the queue, so events registered with
     * [BackgroundWork] in their context may be left unprocessed.
     *
     * A breaking change from [TestCoroutineDispatcher.advanceUntilIdle] is that it no longer returns the total number
     * of milliseconds by which the execution of this method has advanced the virtual time. If you want to recreate
     * that functionality, query [currentTime] before and after the execution to achieve the same result.
     */
    @ExperimentalCoroutinesApi
    public fun advanceUntilIdle(): Unit = advanceUntilIdleOr { events.none(TestDispatchEvent<*>::isForeground) }

    /**
     * Keeps running the enqueued tasks, one at a time, until either [condition] holds or the queue is exhausted.
     *
     * [condition]: guaranteed to be invoked under the lock.
     */
    internal fun advanceUntilIdleOr(condition: () -> Boolean) {
        while (true) {
            if (!tryRunNextTaskUnless(condition))
                return
        }
    }

    /**
     * Runs the tasks that are scheduled to execute at this moment of virtual time.
     *
     * The virtual time is sampled once on entry and is not changed by this function. Every event whose time does
     * not exceed that sample is processed, including the ones that get registered while this function is running.
     */
    @ExperimentalCoroutinesApi
    public fun runCurrent() {
        val timeMark = synchronized(lock) { currentTime }
        while (true) {
            val event = synchronized(lock) {
                events.removeFirstIf { it.time <= timeMark } ?: return
            }
            event.dispatcher.processEvent(event.time, event.marker)
        }
    }

    /**
     * Moves the virtual clock of this dispatcher forward by [the specified amount][delayTimeMillis], running the
     * scheduled tasks in the meantime.
     *
     * Breaking changes from [TestCoroutineDispatcher.advanceTimeBy]:
     * * Intentionally doesn't return a `Long` value, as its use cases are unclear. We may restore it in the future;
     *   please describe your use cases at [the issue tracker](https://github.com/Kotlin/kotlinx.coroutines/issues/).
     *   For now, it's possible to query [currentTime] before and after execution of this method, to the same effect.
     * * It doesn't run the tasks that are scheduled at exactly [currentTime] + [delayTimeMillis]. For example,
     *   advancing the time by one millisecond used to run the tasks at the current millisecond *and* the next
     *   millisecond, but now will stop just before executing any task starting at the next millisecond.
     * * Overflowing the target time used to lead to nothing being done, but will now run the tasks scheduled at up to
     *   (but not including) [Long.MAX_VALUE].
     *
     * @throws IllegalArgumentException if passed a negative [delay][delayTimeMillis].
     */
    @ExperimentalCoroutinesApi
    public fun advanceTimeBy(delayTimeMillis: Long) {
        require(delayTimeMillis >= 0) { "Can not advance time by a negative delay: $delayTimeMillis" }
        val startingTime = currentTime
        // Clamped so that an overflowing target saturates at Long.MAX_VALUE instead of wrapping around.
        val targetTime = addClamping(startingTime, delayTimeMillis)
        while (true) {
            val event = synchronized(lock) {
                val timeMark = currentTime
                // Strict comparison: events scheduled exactly at the target time are left in the queue.
                val event = events.removeFirstIf { targetTime > it.time }
                when {
                    event == null -> {
                        // Nothing left before the target: jump the clock there and finish.
                        currentTime = targetTime
                        return
                    }
                    timeMark > event.time -> currentTimeAheadOfEvents()
                    else -> {
                        currentTime = event.time
                        event
                    }
                }
            }
            event.dispatcher.processEvent(event.time, event.marker)
        }
    }

    /**
     * Checks that the only tasks remaining in the scheduler are cancelled.
     *
     * If [strict] is `true`, the queue must be completely empty for the scheduler to be considered idle;
     * otherwise, the events that report themselves as cancelled are ignored.
     */
    internal fun isIdle(strict: Boolean = true): Boolean =
        synchronized(lock) {
            if (strict) events.isEmpty else events.none { !it.isCancelled() }
        }

    /**
     * Notifies this scheduler about a dispatch event.
     *
     * [context] is the context in which the task will be dispatched. No notification is sent if [context]
     * contains the [BackgroundWork] element.
     */
    internal fun sendDispatchEvent(context: CoroutineContext) {
        if (context[BackgroundWork] !== BackgroundWork)
            dispatchEvents.trySend(Unit)
    }

    /**
     * Consumes the knowledge that a dispatch event happened recently.
     */
    internal val onDispatchEvent: SelectClause1<Unit> get() = dispatchEvents.onReceive

    /**
     * Returns the [TimeSource] representation of the virtual time of this scheduler.
     *
     * The readings are [currentTime] values, interpreted as milliseconds.
     */
    @ExperimentalCoroutinesApi
    @ExperimentalTime
    public val timeSource: TimeSource = object : AbstractLongTimeSource(DurationUnit.MILLISECONDS) {
        override fun read(): Long = currentTime
    }
}
212 
// Dedicated error-throwing entry points: each one adds its own frame, which keeps stack traces descriptive.

/** Reports that the virtual clock got ahead of the earliest event that is still enqueued. */
private fun currentTimeAheadOfEvents(): Nothing {
    invalidSchedulerState()
}

/** Always throws an [IllegalStateException] asking the user to report the broken scheduler state. */
private fun invalidSchedulerState(): Nothing {
    throw IllegalStateException("The test scheduler entered an invalid state. Please report this at https://github.com/Kotlin/kotlinx.coroutines/issues.")
}
218 
/**
 * A node of [ThreadSafeHeap] describing one scheduled task.
 *
 * Nodes are ordered by the planned execution [time]; nodes with equal times are ordered by [count].
 */
private class TestDispatchEvent<T>(
    @JvmField val dispatcher: TestDispatcher,
    private val count: Long,
    @JvmField val time: Long,
    @JvmField val marker: T,
    @JvmField val isForeground: Boolean,
    // TODO: remove once the deprecated API is gone
    @JvmField val isCancelled: () -> Boolean
) : Comparable<TestDispatchEvent<*>>, ThreadSafeHeapNode {
    override var heap: ThreadSafeHeap<*>? = null
    override var index: Int = 0

    /** Compares by [time] first and falls back to [count] to break ties. */
    override fun compareTo(other: TestDispatchEvent<*>): Int {
        val byTime = time.compareTo(other.time)
        if (byTime != 0) return byTime
        return count.compareTo(other.count)
    }

    /** Renders the time and the dispatcher, with a `background` suffix for non-foreground events. */
    override fun toString(): String = buildString {
        append("TestDispatchEvent(time=")
        append(time)
        append(", dispatcher=")
        append(dispatcher)
        if (!isForeground) append(", background")
        append(")")
    }
}
237 
// Saturating addition for non-negative `a` and `b`: a negative sum can only mean overflow,
// in which case `Long.MAX_VALUE` is returned instead.
private fun addClamping(a: Long, b: Long): Long {
    val sum = a + b
    return if (sum < 0) Long.MAX_VALUE else sum
}
240 
/**
 * Verifies that [context] does not carry a [TestCoroutineScheduler] different from [scheduler].
 *
 * A context without any scheduler element passes the check.
 *
 * @throws IllegalStateException if [context] contains another scheduler instance.
 */
internal fun checkSchedulerInContext(scheduler: TestCoroutineScheduler, context: CoroutineContext) {
    // No scheduler in the context means there is nothing to conflict with.
    val schedulerInContext = context[TestCoroutineScheduler] ?: return
    check(schedulerInContext === scheduler) {
        "Detected use of different schedulers. If you need to use several test coroutine dispatchers, " +
            "create one `TestCoroutineScheduler` and pass it to each of them."
    }
}
249 
/**
 * A coroutine context element marking the work as the one to be executed in the background.
 * The object serves as its own key.
 * @see [TestScope.backgroundScope]
 */
internal object BackgroundWork : CoroutineContext.Key<BackgroundWork>, CoroutineContext.Element {
    override val key: CoroutineContext.Key<*>
        get() {
            return BackgroundWork
        }

    override fun toString(): String {
        return "BackgroundWork"
    }
}
260 
/** Returns `true` if the heap's `find` yields no element for [predicate]. */
private fun <T> ThreadSafeHeap<T>.none(
    predicate: (T) -> Boolean
): Boolean where T : ThreadSafeHeapNode, T : Comparable<T> {
    val firstMatch = find(predicate)
    return firstMatch == null
}
263