/* * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. */ package kotlinx.coroutines.test import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.internal.* import kotlin.coroutines.* import kotlin.math.* /** * [CoroutineDispatcher] that performs both immediate and lazy execution of coroutines in tests * and implements [DelayController] to control its virtual clock. * * By default, [TestCoroutineDispatcher] is immediate. That means any tasks scheduled to be run without delay are * immediately executed. If they were scheduled with a delay, the virtual clock-time must be advanced via one of the * methods on [DelayController]. * * When switched to lazy execution using [pauseDispatcher] any coroutines started via [launch] or [async] will * not execute until a call to [DelayController.runCurrent] or the virtual clock-time has been advanced via one of the * methods on [DelayController]. * * @see DelayController */ @ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0 public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayController { private var dispatchImmediately = true set(value) { field = value if (value) { // there may already be tasks from setup code we need to run advanceUntilIdle() } } // The ordered queue for the runnable tasks. private val queue = ThreadSafeHeap() // The per-scheduler global order counter. private val _counter = atomic(0L) // Storing time in nanoseconds internally. private val _time = atomic(0L) /** @suppress */ override fun dispatch(context: CoroutineContext, block: Runnable) { if (dispatchImmediately) { block.run() } else { post(block) } } /** @suppress */ @InternalCoroutinesApi override fun dispatchYield(context: CoroutineContext, block: Runnable) { post(block) } /** @suppress */ override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation) { postDelayed(CancellableContinuationRunnable(continuation) { resumeUndispatched(Unit) }, timeMillis) } /** @suppress */ override fun invokeOnTimeout(timeMillis: Long, block: Runnable, context: CoroutineContext): DisposableHandle { val node = postDelayed(block, timeMillis) return object : DisposableHandle { override fun dispose() { queue.remove(node) } } } /** @suppress */ override fun toString(): String { return "TestCoroutineDispatcher[currentTime=${currentTime}ms, queued=${queue.size}]" } private fun post(block: Runnable) = queue.addLast(TimedRunnable(block, _counter.getAndIncrement())) private fun postDelayed(block: Runnable, delayTime: Long) = TimedRunnable(block, _counter.getAndIncrement(), safePlus(currentTime, delayTime)) .also { queue.addLast(it) } private fun safePlus(currentTime: Long, delayTime: Long): Long { check(delayTime >= 0) val result = currentTime + delayTime if (result < currentTime) return Long.MAX_VALUE // clam on overflow return result } private fun doActionsUntil(targetTime: Long) { while (true) { val current = queue.removeFirstIf { it.time <= targetTime } ?: break // If the scheduled time is 0 (immediate) use current virtual time if (current.time != 0L) _time.value = current.time current.run() } } /** @suppress */ override val currentTime: Long get() = _time.value /** @suppress */ override fun advanceTimeBy(delayTimeMillis: Long): Long { val oldTime = currentTime advanceUntilTime(oldTime + delayTimeMillis) return currentTime - oldTime } /** * Moves the CoroutineContext's clock-time to a particular moment in time. * * @param targetTime The point in time to which to move the CoroutineContext's clock (milliseconds). */ private fun advanceUntilTime(targetTime: Long) { doActionsUntil(targetTime) _time.update { currentValue -> max(currentValue, targetTime) } } /** @suppress */ override fun advanceUntilIdle(): Long { val oldTime = currentTime while(!queue.isEmpty) { runCurrent() val next = queue.peek() ?: break advanceUntilTime(next.time) } return currentTime - oldTime } /** @suppress */ override fun runCurrent(): Unit = doActionsUntil(currentTime) /** @suppress */ override suspend fun pauseDispatcher(block: suspend () -> Unit) { val previous = dispatchImmediately dispatchImmediately = false try { block() } finally { dispatchImmediately = previous } } /** @suppress */ override fun pauseDispatcher() { dispatchImmediately = false } /** @suppress */ override fun resumeDispatcher() { dispatchImmediately = true } /** @suppress */ override fun cleanupTestCoroutines() { // process any pending cancellations or completions, but don't advance time doActionsUntil(currentTime) // run through all pending tasks, ignore any submitted coroutines that are not active val pendingTasks = mutableListOf() while (true) { pendingTasks += queue.removeFirstOrNull() ?: break } val activeDelays = pendingTasks .mapNotNull { it.runnable as? CancellableContinuationRunnable<*> } .filter { it.continuation.isActive } val activeTimeouts = pendingTasks.filter { it.runnable !is CancellableContinuationRunnable<*> } if (activeDelays.isNotEmpty() || activeTimeouts.isNotEmpty()) { throw UncompletedCoroutinesError( "Unfinished coroutines during teardown. Ensure all coroutines are" + " completed or cancelled by your test." ) } } } /** * This class exists to allow cleanup code to avoid throwing for cancelled continuations scheduled * in the future. */ private class CancellableContinuationRunnable( @JvmField val continuation: CancellableContinuation, private val block: CancellableContinuation.() -> Unit ) : Runnable { override fun run() = continuation.block() } /** * A Runnable for our event loop that represents a task to perform at a time. */ private class TimedRunnable( @JvmField val runnable: Runnable, private val count: Long = 0, @JvmField val time: Long = 0 ) : Comparable, Runnable by runnable, ThreadSafeHeapNode { override var heap: ThreadSafeHeap<*>? = null override var index: Int = 0 override fun compareTo(other: TimedRunnable) = if (time == other.time) { count.compareTo(other.count) } else { time.compareTo(other.time) } override fun toString() = "TimedRunnable(time=$time, run=$runnable)" }