/*
 * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package kotlinx.coroutines.test

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
import kotlin.math.*

/**
 * [CoroutineDispatcher] that performs both immediate and lazy execution of coroutines in tests
 * and implements [DelayController] to control its virtual clock.
 *
 * By default, [TestCoroutineDispatcher] is immediate. That means any tasks scheduled to be run without delay are
 * immediately executed. If they were scheduled with a delay, the virtual clock-time must be advanced via one of the
 * methods on [DelayController].
 *
 * When switched to lazy execution using [pauseDispatcher] any coroutines started via [launch] or [async] will
 * not execute until a call to [DelayController.runCurrent] or the virtual clock-time has been advanced via one of the
 * methods on [DelayController].
 *
 * The virtual clock is measured in milliseconds and starts at `0`.
 *
 * @see DelayController
 */
@ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0
public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayController {
    /**
     * `true` while dispatched blocks are run inline by [dispatch]; `false` while the dispatcher is paused
     * and dispatched blocks are queued instead.
     *
     * Assigning `true` drains the queue via [advanceUntilIdle], which also moves the virtual clock forward
     * past every delayed task that is still queued.
     */
    private var dispatchImmediately = true
        set(value) {
            field = value
            if (value) {
                // there may already be tasks from setup code we need to run
                advanceUntilIdle()
            }
        }

    // The ordered queue for the runnable tasks, ordered by (time, count) — see TimedRunnable.compareTo.
    private val queue = ThreadSafeHeap<TimedRunnable>()

    // The per-scheduler global order counter; breaks ties between tasks scheduled for the same time (FIFO).
    private val _counter = atomic(0L)

    // The virtual clock, stored in milliseconds (the unit used by delays, currentTime and advanceTimeBy).
    private val _time = atomic(0L)

    /** @suppress */
    override fun dispatch(context: CoroutineContext, block: Runnable) {
        if (dispatchImmediately) {
            block.run()
        } else {
            post(block)
        }
    }

    /** @suppress */
    @InternalCoroutinesApi
    override fun dispatchYield(context: CoroutineContext, block: Runnable) {
        post(block)
    }

    /** @suppress */
    override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation<Unit>) {
        postDelayed(CancellableContinuationRunnable(continuation) { resumeUndispatched(Unit) }, timeMillis)
    }

    /** @suppress */
    override fun invokeOnTimeout(timeMillis: Long, block: Runnable): DisposableHandle {
        val node = postDelayed(block, timeMillis)
        return object : DisposableHandle {
            override fun dispose() {
                // Disposing simply unschedules the timeout so it never runs.
                queue.remove(node)
            }
        }
    }

    /** @suppress */
    override fun toString(): String {
        return "TestCoroutineDispatcher[currentTime=${currentTime}ms, queued=${queue.size}]"
    }

    /**
     * Enqueues [block] without a delay. Its scheduled time is the sentinel `0`, so it sorts ahead of every
     * delayed task and running it never moves the virtual clock (see [doActionsUntil]).
     */
    private fun post(block: Runnable) =
        queue.addLast(TimedRunnable(block, _counter.getAndIncrement()))

    /**
     * Enqueues [block] to run at `currentTime + delayTime` (clamped to [Long.MAX_VALUE]) and returns the queued
     * node so the caller can remove it later.
     */
    private fun postDelayed(block: Runnable, delayTime: Long) =
        TimedRunnable(block, _counter.getAndIncrement(), safePlus(currentTime, delayTime))
            .also {
                queue.addLast(it)
            }

    /**
     * Returns `currentTime + delayTime`, clamped to [Long.MAX_VALUE] instead of wrapping around on overflow.
     *
     * @throws IllegalStateException if [delayTime] is negative.
     */
    private fun safePlus(currentTime: Long, delayTime: Long): Long {
        check(delayTime >= 0)
        val result = currentTime + delayTime
        if (result < currentTime) return Long.MAX_VALUE // clamp on overflow
        return result
    }

    /**
     * Runs, in queue order, every task scheduled at or before [targetTime].
     *
     * Tasks enqueued by the tasks being run are picked up too, as long as they are also due by [targetTime].
     */
    private fun doActionsUntil(targetTime: Long) {
        while (true) {
            val current = queue.removeFirstIf { it.time <= targetTime } ?: break
            // A scheduled time of 0 marks an immediate (posted) task: keep the current virtual time for it
            // instead of rewinding the clock.
            if (current.time != 0L) _time.value = current.time
            current.run()
        }
    }

    /** @suppress */
    override val currentTime get() = _time.value

    /** @suppress */
    override fun advanceTimeBy(delayTimeMillis: Long): Long {
        val oldTime = currentTime
        // Clamp instead of wrapping: a plain `oldTime + delayTimeMillis` overflows to a negative target for very
        // large delays (e.g. Long.MAX_VALUE), which would silently run nothing and leave the clock untouched.
        val targetTime = if (delayTimeMillis >= 0) {
            safePlus(oldTime, delayTimeMillis)
        } else {
            oldTime + delayTimeMillis
        }
        advanceUntilTime(targetTime)
        return currentTime - oldTime
    }

    /**
     * Moves the CoroutineContext's clock-time to a particular moment in time.
     *
     * Runs every task due by [targetTime] first; the clock never moves backwards.
     *
     * @param targetTime The point in time to which to move the CoroutineContext's clock (milliseconds).
     */
    private fun advanceUntilTime(targetTime: Long) {
        doActionsUntil(targetTime)
        _time.update { currentValue -> max(currentValue, targetTime) }
    }

    /** @suppress */
    override fun advanceUntilIdle(): Long {
        val oldTime = currentTime
        while (!queue.isEmpty) {
            runCurrent()
            val next = queue.peek() ?: break
            advanceUntilTime(next.time)
        }
        return currentTime - oldTime
    }

    /** @suppress */
    override fun runCurrent() = doActionsUntil(currentTime)

    /** @suppress */
    override suspend fun pauseDispatcher(block: suspend () -> Unit) {
        val previous = dispatchImmediately
        dispatchImmediately = false
        try {
            block()
        } finally {
            // Restoring `true` here also drains the queue (see the dispatchImmediately setter).
            dispatchImmediately = previous
        }
    }

    /** @suppress */
    override fun pauseDispatcher() {
        dispatchImmediately = false
    }

    /** @suppress */
    override fun resumeDispatcher() {
        dispatchImmediately = true
    }

    /** @suppress */
    override fun cleanupTestCoroutines() {
        // process any pending cancellations or completions, but don't advance time
        doActionsUntil(currentTime)

        // Drain everything that is still queued (all of it is scheduled after currentTime at this point).
        val pendingTasks = generateSequence { queue.removeFirstOrNull() }.toList()

        // Delays whose continuation was already cancelled are harmless leftovers; only active ones are leaks.
        val activeDelays = pendingTasks
            .mapNotNull { it.runnable as? CancellableContinuationRunnable<*> }
            .filter { it.continuation.isActive }

        // Any non-delay task still queued here is a future invokeOnTimeout block that was never disposed.
        val activeTimeouts = pendingTasks.filter { it.runnable !is CancellableContinuationRunnable<*> }
        if (activeDelays.isNotEmpty() || activeTimeouts.isNotEmpty()) {
            throw UncompletedCoroutinesError(
                "Unfinished coroutines during teardown. Ensure all coroutines are" +
                    " completed or cancelled by your test."
            )
        }
    }
}

/**
 * This class exists to allow cleanup code to avoid throwing for cancelled continuations scheduled
 * in the future.
 *
 * @property continuation the continuation whose `isActive` state cleanup inspects.
 * @param block the action applied to [continuation] when this runnable is run.
 */
private class CancellableContinuationRunnable<T>(
    @JvmField val continuation: CancellableContinuation<T>,
    private val block: CancellableContinuation<T>.() -> Unit
) : Runnable {
    override fun run() = continuation.block()
}

/**
 * A Runnable for our event loop that represents a task to perform at a time.
 *
 * @property runnable the wrapped task; [run] delegates to it.
 * @param count insertion order, used to keep tasks with equal [time] in FIFO order.
 * @property time the virtual time (milliseconds) at which to run; `0` marks an immediate task.
 */
private class TimedRunnable(
    @JvmField val runnable: Runnable,
    private val count: Long = 0,
    @JvmField val time: Long = 0
) : Comparable<TimedRunnable>, Runnable by runnable, ThreadSafeHeapNode {
    override var heap: ThreadSafeHeap<*>? = null
    override var index: Int = 0

    // Order by scheduled time, then by insertion order.
    override fun compareTo(other: TimedRunnable) = if (time == other.time) {
        count.compareTo(other.count)
    } else {
        time.compareTo(other.time)
    }

    override fun toString() = "TimedRunnable(time=$time, run=$runnable)"
}