/*
 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package kotlinx.coroutines.scheduling

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import java.util.concurrent.atomic.*
import kotlin.jvm.internal.Ref.ObjectRef

/** log2 of [BUFFER_CAPACITY]. */
internal const val BUFFER_CAPACITY_BASE = 7

/** Number of slots in the ring buffer of each [WorkQueue] (128 by default). */
internal const val BUFFER_CAPACITY = 1 shl BUFFER_CAPACITY_BASE

/** Maps a monotonically growing queue index onto a ring-buffer slot via `index and MASK` (127 by default). */
internal const val MASK = BUFFER_CAPACITY - 1

/** [WorkQueue.trySteal] result: a task was stolen into the supplied reference. */
internal const val TASK_STOLEN = -1L

/** [WorkQueue.trySteal] result: there was nothing eligible to steal. */
internal const val NOTHING_TO_STEAL = -2L

/**
 * Bit mask selecting which kinds of tasks may be stolen: bit `0b01` ([STEAL_BLOCKING_ONLY]) stands for
 * blocking tasks and bit `0b10` ([STEAL_CPU_ONLY]) for CPU tasks. [STEAL_ANY] is the union of both bits,
 * so eligibility of a task can be tested with `task.maskForStealingMode and stealingMode`.
 */
internal typealias StealingMode = Int
internal const val STEAL_ANY: StealingMode = 3
internal const val STEAL_CPU_ONLY: StealingMode = 2
internal const val STEAL_BLOCKING_ONLY: StealingMode = 1

/** The single [StealingMode] bit matching this task: [STEAL_BLOCKING_ONLY] or [STEAL_CPU_ONLY]. */
internal inline val Task.maskForStealingMode: Int
    get() = if (isBlocking) STEAL_BLOCKING_ONLY else STEAL_CPU_ONLY

/**
 * Tightly coupled with [CoroutineScheduler] queue of pending tasks, but extracted to separate file for simplicity.
 * At any moment the queue is used only by [CoroutineScheduler.Worker] threads. It has only one producer (the worker
 * owning this queue) and any number of consumers (other pool workers which are trying to steal work).
 *
 * ### Fairness
 *
 * [WorkQueue] provides semi-FIFO order, but with priority for the most recently submitted task, assuming
 * that these two (the current one and the submitted one) are communicating and sharing state, thus making
 * such communication extremely fast.
 * E.g. submitted jobs [1, 2, 3, 4] will be executed in [4, 1, 2, 3] order.
 *
 * ### Algorithm and implementation details
 * This is a regular SPMC bounded queue with the additional property that tasks can be removed from the middle of the queue
 * (scheduler workers without a CPU permit steal blocking tasks via this mechanism). This property forces us to use CAS in
 * order to properly claim a value from the buffer.
 * Moreover, [Task] objects are reusable, so it may seem that this queue is prone to the ABA problem.
 * Indeed, it formally has the ABA problem, but the whole processing logic is written in such a way that this ABA is harmless.
 * I have discovered a truly marvelous proof of this, which this KDoc is too narrow to contain.
 */
internal class WorkQueue {

    /*
     * We read two independent counters here.
     * The producer index is incremented only by the owner.
     * The consumer index is incremented both by the owner and by external threads.
     *
     * The only harmful race is:
     * [T1] readProducerIndex (1) preemption(2) readConsumerIndex(5)
     * [T2] changeProducerIndex (3)
     * [T3] changeConsumerIndex (4)
     *
     * Which can lead to the resulting size being negative or bigger than the actual size at any moment of time.
     * This is in general harmless because steal will be blocked by timer.
     * Negative sizes can be observed only when a non-owner reads the size, which happens only
     * for diagnostic toString().
     */
    private val bufferSize: Int get() = producerIndex.value - consumerIndex.value

    /** Approximate number of tasks in this queue: the buffered ones plus the [lastScheduledTask] slot, if occupied. */
    internal val size: Int get() = if (lastScheduledTask.value != null) bufferSize + 1 else bufferSize

    /** Ring buffer of [BUFFER_CAPACITY] slots addressed by `index and MASK`; `null` marks an empty or claimed slot. */
    private val buffer: AtomicReferenceArray<Task?> = AtomicReferenceArray(BUFFER_CAPACITY)

    /** The most recently submitted (non-fair) task; it is kept out of the buffer to give it priority, see "Fairness". */
    private val lastScheduledTask = atomic<Task?>(null)

    /** Logical index of the next buffer slot to fill; advanced only by the owner. */
    private val producerIndex = atomic(0)

    /** Logical index of the next buffer slot to consume; advanced via CAS by the owner and by stealers. */
    private val consumerIndex = atomic(0)

    // Number of blocking tasks currently in `buffer`.
    // Shortcut to avoid scanning queue without blocking tasks
    private val blockingTasksInBuffer = atomic(0)

    /**
     * Retrieves and removes a task from the head of the queue: the [lastScheduledTask] slot first,
     * then the oldest buffered task. Returns `null` if the queue is empty.
     * Invariant: this method is called only by the owner of the queue.
     */
    fun poll(): Task? = lastScheduledTask.getAndSet(null) ?: pollBuffer()

    /**
     * Adds [task] to the queue.
     * Invariant: called only by the owner of the queue.
     *
     * @param fair when `true`, [task] is appended straight to the tail of the buffer (FIFO); when `false`,
     *   [task] takes the [lastScheduledTask] slot and the task previously held there is appended to the buffer.
     * @return `null` if the task was added, or the task that wasn't added because the buffer is full.
     *   In non-fair mode the rejected task is the *previously* scheduled one, not [task].
     */
    fun add(task: Task, fair: Boolean = false): Task? {
        if (fair) return addLast(task)
        val previous = lastScheduledTask.getAndSet(task) ?: return null
        return addLast(previous)
    }

    /**
     * Appends [task] to the tail of the buffer.
     * Invariant: called only by the owner of the queue.
     *
     * @return `null` if the task was added, or [task] itself if the buffer already holds `BUFFER_CAPACITY - 1` tasks.
     */
    private fun addLast(task: Task): Task? {
        if (bufferSize == BUFFER_CAPACITY - 1) return task
        if (task.isBlocking) blockingTasksInBuffer.incrementAndGet()
        val nextIndex = producerIndex.value and MASK
        /*
         * If current element is not null then we're racing with a really slow consumer that committed the consumer index,
         * but hasn't yet nulled out the slot, effectively preventing us from using it.
         * Such situations are very rare in practice (although possible) and we decided to give up a progress guarantee
         * to have a stronger invariant "add to queue with bufferSize == 0 is always successful".
         * This algorithm can still be wait-free for add, but if and only if tasks are not reusable, otherwise
         * nulling out the buffer wouldn't be possible.
         */
        while (buffer[nextIndex] != null) {
            Thread.yield()
        }
        // The slot is filled before the producer index is advanced; consumers only claim indices below it.
        buffer.lazySet(nextIndex, task)
        producerIndex.incrementAndGet()
        return null
    }

    /**
     * Tries stealing from this queue into the [stolenTaskRef] argument.
     *
     * Returns [NOTHING_TO_STEAL] if the queue has nothing to steal, [TASK_STOLEN] if a task was stolen
     * (it is then stored in `stolenTaskRef.element`), or a positive value of how many nanoseconds should pass
     * until the most recently scheduled task of this queue will be available to steal.
     *
     * [StealingMode] controls what tasks to steal:
     * * [STEAL_ANY] is default mode for scheduler, task from the head (in FIFO order) is stolen
     * * [STEAL_BLOCKING_ONLY] is mode for stealing *an arbitrary* blocking task, which is used by the scheduler when helping in Dispatchers.IO mode
     * * [STEAL_CPU_ONLY] is a kludge for `runSingleTaskFromCurrentSystemDispatcher`
     *
     * The buffer is tried first; the [lastScheduledTask] slot is tried only when nothing was taken from the buffer.
     */
    fun trySteal(stealingMode: StealingMode, stolenTaskRef: ObjectRef<Task?>): Long {
        val task = when (stealingMode) {
            STEAL_ANY -> pollBuffer()
            else -> stealWithExclusiveMode(stealingMode)
        }

        if (task != null) {
            stolenTaskRef.element = task
            return TASK_STOLEN
        }
        return tryStealLastScheduled(stealingMode, stolenTaskRef)
    }

    /**
     * Steals only tasks of a particular kind, potentially invoking a full queue scan from the oldest slot
     * towards the newest. [STEAL_BLOCKING_ONLY] selects blocking tasks; any other mode selects CPU tasks.
     * Returns `null` if no matching task could be claimed.
     */
    private fun stealWithExclusiveMode(stealingMode: StealingMode): Task? {
        var start = consumerIndex.value
        val end = producerIndex.value
        val onlyBlocking = stealingMode == STEAL_BLOCKING_ONLY
        while (start != end) {
            // Bail out if there is no blocking work for us
            if (onlyBlocking && blockingTasksInBuffer.value == 0) return null
            val task = tryExtractFromTheMiddle(start++, onlyBlocking)
            if (task != null) return task
        }
        return null
    }

    // Polls for blocking task, invoked only by the owner
    // NB: ONLY for runSingleTask method
    fun pollBlocking(): Task? = pollWithExclusiveMode(onlyBlocking = true /* only blocking */)

    // Polls for CPU task, invoked only by the owner
    // NB: ONLY for runSingleTask method
    fun pollCpu(): Task? = pollWithExclusiveMode(onlyBlocking = false /* only cpu */)

    /**
     * Polls one task of a single kind: only blocking tasks if [onlyBlocking], otherwise only CPU tasks.
     * The [lastScheduledTask] slot is tried first; then the buffer is scanned from the newest slot towards the oldest.
     * Returns `null` if no matching task could be claimed.
     */
    private fun pollWithExclusiveMode(/* Only blocking OR only CPU */ onlyBlocking: Boolean): Task? {
        while (true) { // Poll the slot
            val lastScheduled = lastScheduledTask.value ?: break
            if (lastScheduled.isBlocking != onlyBlocking) break
            if (lastScheduledTask.compareAndSet(lastScheduled, null)) {
                return lastScheduled
            } // Failed -> someone else stole it
        }

        // Failed to poll the slot, scan the queue
        val start = consumerIndex.value
        var end = producerIndex.value
        while (start != end) {
            // Bail out if there is no blocking work for us
            if (onlyBlocking && blockingTasksInBuffer.value == 0) return null
            val task = tryExtractFromTheMiddle(--end, onlyBlocking)
            if (task != null) {
                return task
            }
        }
        return null
    }

    /**
     * Tries to claim the task at logical [index] if its kind matches [onlyBlocking], by CAS-ing its slot to `null`.
     * The consumer index is not advanced here; [pollBuffer] later skips the emptied slot.
     * Returns the claimed task, or `null` if the slot was empty, held a task of the other kind, or the CAS failed.
     */
    private fun tryExtractFromTheMiddle(index: Int, onlyBlocking: Boolean): Task? {
        val arrayIndex = index and MASK
        val value = buffer[arrayIndex]
        if (value != null && value.isBlocking == onlyBlocking && buffer.compareAndSet(arrayIndex, value, null)) {
            if (onlyBlocking) blockingTasksInBuffer.decrementAndGet()
            return value
        }
        return null
    }

    /**
     * Moves every task of this queue to [globalQueue]: the [lastScheduledTask] slot first,
     * then the buffered tasks from oldest to newest.
     */
    fun offloadAllWorkTo(globalQueue: GlobalQueue) {
        lastScheduledTask.getAndSet(null)?.let { globalQueue.addLast(it) }
        while (pollTo(globalQueue)) {
            // Steal everything
        }
    }

    /**
     * Tries to steal the task held in the [lastScheduledTask] slot.
     * Contract on return value is the same as for [trySteal].
     */
    private fun tryStealLastScheduled(stealingMode: StealingMode, stolenTaskRef: ObjectRef<Task?>): Long {
        while (true) {
            val lastScheduled = lastScheduledTask.value ?: return NOTHING_TO_STEAL
            if ((lastScheduled.maskForStealingMode and stealingMode) == 0) {
                return NOTHING_TO_STEAL
            }

            // TODO time wraparound ?
            val time = schedulerTimeSource.nanoTime()
            val staleness = time - lastScheduled.submissionTime
            // Submitted too recently: report how long to wait before it may be stolen.
            if (staleness < WORK_STEALING_TIME_RESOLUTION_NS) {
                return WORK_STEALING_TIME_RESOLUTION_NS - staleness
            }

            /*
             * If CAS has failed, either someone else had stolen this task or the owner executed this task
             * and dispatched another one. In the latter case we should retry to avoid missing task.
             */
            if (lastScheduledTask.compareAndSet(lastScheduled, null)) {
                stolenTaskRef.element = lastScheduled
                return TASK_STOLEN
            }
        }
    }

    /** Moves the oldest buffered task to [queue]; returns `false` if the buffer was empty. */
    private fun pollTo(queue: GlobalQueue): Boolean {
        val task = pollBuffer() ?: return false
        queue.addLast(task)
        return true
    }

    /**
     * Claims the oldest buffered task by CAS-advancing the consumer index, then takes it out of its slot.
     * Returns `null` when the buffer is empty.
     */
    private fun pollBuffer(): Task? {
        while (true) {
            val tailLocal = consumerIndex.value
            if (tailLocal - producerIndex.value == 0) return null
            val index = tailLocal and MASK
            if (consumerIndex.compareAndSet(tailLocal, tailLocal + 1)) {
                // Nulls are allowed when tasks were extracted from the middle of the queue (see tryExtractFromTheMiddle).
                val value = buffer.getAndSet(index, null) ?: continue
                value.decrementIfBlocking()
                return value
            }
        }
    }

    /** Keeps [blockingTasksInBuffer] in sync after a blocking task has been removed from the buffer. */
    private fun Task?.decrementIfBlocking() {
        if (this != null && isBlocking) {
            val value = blockingTasksInBuffer.decrementAndGet()
            assert { value >= 0 }
        }
    }
}