• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
4 
5 package kotlinx.coroutines.test
6 
7 import kotlinx.atomicfu.*
8 import kotlinx.coroutines.*
9 import kotlinx.coroutines.internal.*
10 import kotlin.coroutines.*
11 import kotlin.math.*
12 
/**
 * [CoroutineDispatcher] that performs both immediate and lazy execution of coroutines in tests
 * and implements [DelayController] to control its virtual clock.
 *
 * By default, [TestCoroutineDispatcher] is immediate. That means any tasks scheduled to be run without delay are
 * immediately executed. If they were scheduled with a delay, the virtual clock-time must be advanced via one of the
 * methods on [DelayController].
 *
 * When switched to lazy execution using [pauseDispatcher] any coroutines started via [launch] or [async] will
 * not execute until a call to [DelayController.runCurrent] or the virtual clock-time has been advanced via one of the
 * methods on [DelayController].
 *
 * The virtual clock is measured in milliseconds and starts at `0`.
 *
 * @see DelayController
 */
@ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0
public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayController {
    /**
     * When `true`, [dispatch] runs blocks inline on the calling thread; when `false`, blocks are queued
     * until the clock is driven. Setting it to `true` drains the whole queue (delayed tasks included)
     * via [advanceUntilIdle].
     */
    private var dispatchImmediately = true
        set(value) {
            field = value
            if (value) {
                // there may already be tasks from setup code we need to run
                advanceUntilIdle()
            }
        }

    // Pending tasks, ordered by scheduled time and then by submission order.
    private val queue = ThreadSafeHeap<TimedRunnable>()

    // Per-dispatcher submission counter; breaks ties between tasks scheduled for the same time so they run FIFO.
    private val _counter = atomic(0L)

    // Virtual clock time in milliseconds (the unit of every delay/timeout scheduled on this dispatcher).
    private val _time = atomic(0L)

    /** @suppress */
    override fun dispatch(context: CoroutineContext, block: Runnable) {
        if (dispatchImmediately) {
            block.run()
        } else {
            post(block)
        }
    }

    /** @suppress */
    @InternalCoroutinesApi
    override fun dispatchYield(context: CoroutineContext, block: Runnable) {
        // Always queue: running a yielded block inline would not give other tasks a chance to run first.
        post(block)
    }

    /** @suppress */
    override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation<Unit>) {
        // Wrapped so that cleanupTestCoroutines can tell cancelled delays apart from still-active ones.
        postDelayed(CancellableContinuationRunnable(continuation) { resumeUndispatched(Unit) }, timeMillis)
    }

    /** @suppress */
    override fun invokeOnTimeout(timeMillis: Long, block: Runnable, context: CoroutineContext): DisposableHandle {
        val node = postDelayed(block, timeMillis)
        return object : DisposableHandle {
            // Disposing the handle unschedules the timeout so it never fires.
            override fun dispose() {
                queue.remove(node)
            }
        }
    }

    /** @suppress */
    override fun toString(): String {
        return "TestCoroutineDispatcher[currentTime=${currentTime}ms, queued=${queue.size}]"
    }

    /** Enqueues [block] as an immediate task (scheduled time `0`, i.e. "due now"). */
    private fun post(block: Runnable) =
        queue.addLast(TimedRunnable(block, _counter.getAndIncrement()))

    /**
     * Enqueues [block] to run [delayTime] milliseconds after the current virtual time.
     *
     * @return the queued node, so that the caller can later remove it from the queue.
     */
    private fun postDelayed(block: Runnable, delayTime: Long) =
        TimedRunnable(block, _counter.getAndIncrement(), safePlus(currentTime, delayTime))
            .also {
                queue.addLast(it)
            }

    /**
     * Returns `currentTime + delayTime`, clamped to [Long.MAX_VALUE] instead of wrapping around.
     *
     * @throws IllegalStateException if [delayTime] is negative.
     */
    private fun safePlus(currentTime: Long, delayTime: Long): Long {
        check(delayTime >= 0)
        val result = currentTime + delayTime
        if (result < currentTime) return Long.MAX_VALUE // clamp on overflow
        return result
    }

    /**
     * Runs, in queue order, every task scheduled at or before [targetTime], including tasks that are
     * enqueued into that window by the tasks being run. Before a delayed task runs, the clock is set to
     * its scheduled time; immediate tasks (time `0`) run at the current time.
     */
    private fun doActionsUntil(targetTime: Long) {
        while (true) {
            val current = queue.removeFirstIf { it.time <= targetTime } ?: break
            // If the scheduled time is 0 (immediate) use current virtual time
            if (current.time != 0L) _time.value = current.time
            current.run()
        }
    }

    /** @suppress */
    override val currentTime: Long get() = _time.value

    /**
     * @suppress
     *
     * Returns how far the clock actually moved. The target time is clamped to [Long.MAX_VALUE], so a very large
     * [delayTimeMillis] (e.g. [Long.MAX_VALUE]) runs every pending task instead of overflowing to a negative target.
     */
    override fun advanceTimeBy(delayTimeMillis: Long): Long {
        val oldTime = currentTime
        // Negative delays keep their previous meaning (no clock movement). oldTime is never negative,
        // so the plain subtraction there cannot underflow.
        val targetTime = if (delayTimeMillis >= 0) safePlus(oldTime, delayTimeMillis) else oldTime + delayTimeMillis
        advanceUntilTime(targetTime)
        return currentTime - oldTime
    }

    /**
     * Moves the CoroutineContext's clock-time to a particular moment in time.
     *
     * @param targetTime The point in time to which to move the CoroutineContext's clock (milliseconds).
     */
    private fun advanceUntilTime(targetTime: Long) {
        doActionsUntil(targetTime)
        // Never move the clock backwards.
        _time.update { currentValue -> max(currentValue, targetTime) }
    }

    /** @suppress */
    override fun advanceUntilIdle(): Long {
        val oldTime = currentTime
        while(!queue.isEmpty) {
            runCurrent()
            // Jump straight to the next scheduled task, if any remains.
            val next = queue.peek() ?: break
            advanceUntilTime(next.time)
        }
        return currentTime - oldTime
    }

    /** @suppress */
    override fun runCurrent(): Unit  = doActionsUntil(currentTime)

    /** @suppress */
    override suspend fun pauseDispatcher(block: suspend () -> Unit) {
        val previous = dispatchImmediately
        dispatchImmediately = false
        try {
            block()
        } finally {
            // Restoring `true` also drains the queue (see the dispatchImmediately setter).
            dispatchImmediately = previous
        }
    }

    /** @suppress */
    override fun pauseDispatcher() {
        dispatchImmediately = false
    }

    /** @suppress */
    override fun resumeDispatcher() {
        dispatchImmediately = true
    }

    /**
     * @suppress
     *
     * Runs what is due now, then empties the queue. If anything left behind is still live (a delay whose
     * continuation is active, or any other queued task), this throws [UncompletedCoroutinesError].
     */
    override fun cleanupTestCoroutines() {
        // process any pending cancellations or completions, but don't advance time
        doActionsUntil(currentTime)

        // Drain every remaining task; a cancelled delay is harmless, anything else counts as unfinished work.
        val pendingTasks = generateSequence { queue.removeFirstOrNull() }.toList()
        val hasUnfinishedWork = pendingTasks.any { task ->
            val runnable = task.runnable
            runnable !is CancellableContinuationRunnable<*> || runnable.continuation.isActive
        }
        if (hasUnfinishedWork) {
            throw UncompletedCoroutinesError(
                "Unfinished coroutines during teardown. Ensure all coroutines are" +
                    " completed or cancelled by your test."
            )
        }
    }
}
185 
/**
 * A [Runnable] that, when run, invokes [block] with [continuation] as its receiver.
 *
 * Keeping the continuation reachable lets teardown code recognise scheduled delays and skip, rather than
 * report, those whose continuation has already been cancelled.
 */
private class CancellableContinuationRunnable<T>(
    @JvmField val continuation: CancellableContinuation<T>,
    private val block: CancellableContinuation<T>.() -> Unit
) : Runnable {
    override fun run() {
        block.invoke(continuation)
    }
}
196 
/**
 * A queue entry pairing a [runnable] with the virtual time at which it is due.
 *
 * Entries compare by [time] first and, for equal times, by [count] (submission order), so tasks scheduled
 * for the same moment run first-in-first-out. The default [time] of `0` is what the dispatcher uses for
 * tasks that are due immediately. [heap] and [index] are the bookkeeping slots required by [ThreadSafeHeapNode].
 */
private class TimedRunnable(
    @JvmField val runnable: Runnable,
    private val count: Long = 0,
    @JvmField val time: Long = 0
) : Comparable<TimedRunnable>, Runnable by runnable, ThreadSafeHeapNode {
    override var heap: ThreadSafeHeap<*>? = null
    override var index: Int = 0

    override fun compareTo(other: TimedRunnable): Int {
        val byTime = time.compareTo(other.time)
        // Same scheduled time: fall back to submission order.
        return if (byTime != 0) byTime else count.compareTo(other.count)
    }

    override fun toString(): String = "TimedRunnable(time=$time, run=$runnable)"
}
216