/*
 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package kotlinx.coroutines.test.internal

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.test.*
import kotlin.coroutines.*

/**
 * The testable main dispatcher used by kotlinx-coroutines-test.
 * It is a [MainCoroutineDispatcher] that delegates all actions to a settable delegate.
 *
 * The delegate can be replaced with [setDispatcher] and restored to the original one with [resetDispatcher].
 *
 * @param delegate the dispatcher delegated to initially, and restored as the delegate by [resetDispatcher].
 */
internal class TestMainDispatcher(delegate: CoroutineDispatcher) : MainCoroutineDispatcher(), Delay {

    /** The dispatcher passed at construction; [resetDispatcher] makes it the delegate again. */
    private val mainDispatcher = delegate

    // `val`: the holder itself is never replaced; only its `value` is reassigned by setDispatcher/resetDispatcher.
    private val delegate = NonConcurrentlyModifiable(mainDispatcher, "Dispatchers.Main")

    /** The current delegate if it implements [Delay]; otherwise [defaultDelay]. */
    private val delay: Delay
        get() = delegate.value as? Delay ?: defaultDelay

    /**
     * The current delegate's own [immediate][MainCoroutineDispatcher.immediate] dispatcher when the delegate is a
     * [MainCoroutineDispatcher]; otherwise this dispatcher itself.
     */
    override val immediate: MainCoroutineDispatcher
        get() = (delegate.value as? MainCoroutineDispatcher)?.immediate ?: this

    override fun dispatch(context: CoroutineContext, block: Runnable) = delegate.value.dispatch(context, block)

    override fun isDispatchNeeded(context: CoroutineContext): Boolean = delegate.value.isDispatchNeeded(context)

    override fun dispatchYield(context: CoroutineContext, block: Runnable) = delegate.value.dispatchYield(context, block)

    /**
     * Makes [dispatcher] the delegate for all subsequent operations.
     *
     * @throws IllegalStateException if a concurrent read or modification of the delegate was detected.
     */
    fun setDispatcher(dispatcher: CoroutineDispatcher) {
        delegate.value = dispatcher
    }

    /**
     * Makes the dispatcher passed at construction the delegate again.
     *
     * @throws IllegalStateException if a concurrent read or modification of the delegate was detected.
     */
    fun resetDispatcher() {
        delegate.value = mainDispatcher
    }

    override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation<Unit>) =
        delay.scheduleResumeAfterDelay(timeMillis, continuation)

    override fun invokeOnTimeout(timeMillis: Long, block: Runnable, context: CoroutineContext): DisposableHandle =
        delay.invokeOnTimeout(timeMillis, block, context)

    companion object {
        /**
         * The current delegate of [Dispatchers.Main] when `Dispatchers.Main` is a [TestMainDispatcher] whose
         * delegate is a [TestDispatcher]; `null` otherwise.
         */
        internal val currentTestDispatcher: TestDispatcher?
            get() = (Dispatchers.Main as? TestMainDispatcher)?.delegate?.value as? TestDispatcher

        /** The scheduler of [currentTestDispatcher], or `null` when there is no such dispatcher. */
        internal val currentTestScheduler
            get() = currentTestDispatcher?.scheduler
    }

    /**
     * A wrapper around a value that attempts to throw when writing happens concurrently with reading.
     *
     * The read operations never throw. Instead, the failures detected inside them will be remembered and thrown on the
     * next modification.
     *
     * @param initialValue the value held until the first write.
     * @param name the name used in the messages of the thrown [IllegalStateException]s.
     */
    private class NonConcurrentlyModifiable<T>(initialValue: T, private val name: String) {
        private val readers = atomic(0) // number of concurrent readers
        private val isWriting = atomic(false) // a modification is happening currently
        private val exceptionWhenReading: AtomicRef<Throwable?> = atomic(null) // exception from reading
        private val _value = atomic(initialValue) // the backing field for the value

        private fun concurrentWW() = IllegalStateException("$name is modified concurrently")
        private fun concurrentRW() = IllegalStateException("$name is used concurrently with setting it")

        var value: T
            get() {
                readers.incrementAndGet()
                // Reads never throw: a conflict with an in-progress write is recorded and reported by the next write.
                if (isWriting.value) exceptionWhenReading.value = concurrentRW()
                val result = _value.value
                readers.decrementAndGet()
                return result
            }
            set(value) {
                // First report any conflict a reader recorded since the last call to this setter.
                exceptionWhenReading.getAndSet(null)?.let { throw it }
                if (readers.value != 0) throw concurrentRW()
                if (!isWriting.compareAndSet(expect = false, update = true)) throw concurrentWW()
                _value.value = value
                isWriting.value = false
                // A reader may have started while the write above was in progress.
                if (readers.value != 0) throw concurrentRW()
            }
    }
}

/**
 * The [Delay] that [TestMainDispatcher] falls back to when its current delegate does not implement [Delay].
 * `DefaultDelay` is not normally visible from here, hence the suppression.
 */
@Suppress("INVISIBLE_MEMBER")
private val defaultDelay
    inline get() = DefaultDelay

/**
 * Platform-specific `expect` declaration returning a [TestMainDispatcher];
 * each platform supplies the `actual` implementation.
 */
@Suppress("INVISIBLE_MEMBER")
internal expect fun Dispatchers.getTestMainDispatcher(): TestMainDispatcher