/*
 * Copyright 2017-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
5 @file:Suppress("RedundantVisibilityModifier")
7 package kotlinx.atomicfu
9 import java.util.*
10 import java.util.concurrent.atomic.*
11 import java.util.concurrent.locks.*
12 import kotlin.coroutines.*
13 import kotlin.coroutines.intrinsics.*
// Each intercepted atomic step pauses the test thread with probability 1 / PAUSE_EVERY_N_STEPS.
private const val PAUSE_EVERY_N_STEPS = 1000

// A thread that completes no lock-free operation for this long is reported as stalled;
// the same limit bounds how long shutdown waits for test threads to terminate.
private const val STALL_LIMIT_MS = 15L * 1000 // 15 s

// Polling interval while waiting for test threads to terminate at shutdown.
private const val SHUTDOWN_CHECK_MS = 10L // 10 ms

// Value of the shared status word once the test is done; no further pauses are performed after it.
private const val STATUS_DONE = Int.MAX_VALUE

// Upper bound for a single park of a paused thread, so a lost unpark signal delays it by at most 1 ms.
private const val MAX_PARK_NANOS = 1000L * 1000 // 1 ms
23 /**
24 * Environment for performing lock-freedom tests for lock-free data structures
25 * that are written with [atomic] variables.
26 */
/**
 * Environment for performing lock-freedom tests for lock-free data structures
 * that are written with [atomic] variables.
 *
 * Usage: define test threads with [testThread], optionally register [onCompletion] blocks, then call
 * [performTest] once. While the test runs, every intercepted atomic update made by a test thread may
 * randomly pause that thread (at most one thread is paused at any time). The paused thread is resumed
 * only after every other test thread has made progress, i.e. completed a lock-free operation, since the
 * pause. A thread that completes no operation for [STALL_LIMIT_MS] milliseconds fails the test.
 *
 * @param name name of this test; printed in the log and used as a prefix of test thread names.
 * @param allowSuspendedThreads how many test threads may be suspended inside a suspending operation at
 *   the same time while still being counted as making progress. The test requires at least
 *   `2 + allowSuspendedThreads` test threads.
 */
public open class LockFreedomTestEnvironment(
    private val name: String,
    private val allowSuspendedThreads: Int = 0
) {
    private val interceptor = Interceptor()
    private val threads = mutableListOf<TestThread>()
    private val performedOps = LongAdder()
    private val uncaughtException = AtomicReference<Throwable?>()
    private var started = false

    // Written by test threads in resumeImpl() and read by the thread running performTest() in
    // resumeStr(); @Volatile makes the latest value visible to the progress output.
    @Volatile
    private var performedResumes = 0

    @Volatile
    private var completed = false
    private val onCompletion = mutableListOf<() -> Unit>()

    // Logs an uncaught test-thread exception and remembers the first one: performTest() stops its
    // loop once one is recorded, and shutdown() rethrows it.
    private val ueh = Thread.UncaughtExceptionHandler { t, e ->
        synchronized(System.out) {
            println("Uncaught exception in thread $t")
            e.printStackTrace(System.out)
            uncaughtException.compareAndSet(null, e)
        }
    }

    // status < 0 - inverted index of the paused thread
    // status >= 0 - number of resumes performed so far (== last epoch)
    // status == STATUS_DONE - done working
    private val status = AtomicInteger()

    // Number of threads that have made progress in the current pause epoch (see TestThread.makeProgress).
    private val globalPauseProgress = AtomicInteger()

    // Test threads currently waiting for their suspended operation to be resumed, in order of suspension.
    // Accessed only from the @Synchronized addSuspended()/removeSuspended().
    private val suspendedThreads = ArrayList<TestThread>()

    // Cleared by complete(); test threads keep running operations while it is set.
    @Volatile
    private var isActive = true

    // ---------- API ----------

    /**
     * Starts lock-freedom test for a given duration in [seconds].
     *
     * [progress] is invoked once per second while the test runs (`seconds + 1` times), and once more
     * after the test has completed successfully. The test stops early if a test thread throws an
     * uncaught exception; that exception is then rethrown.
     *
     * @throws IllegalStateException if a test was already performed on this instance, if fewer than
     *   `2 + allowSuspendedThreads` test threads were defined, if progress stalled, or if test threads
     *   failed to shut down in time.
     */
    public fun performTest(seconds: Int, progress: () -> Unit = {}) {
        check(isActive) { "Can perform test at most once on this instance" }
        println("=== $name")
        val minThreads = 2 + allowSuspendedThreads
        check(threads.size >= minThreads) { "Must define at least $minThreads test threads" }
        lockAndSetInterceptor(interceptor)
        started = true
        var nextTime = System.currentTimeMillis()
        try {
            // Threads are started inside `try`: if starting any of them fails, complete() still stops
            // the threads that were already started and releases the interceptor.
            threads.forEach { thread ->
                thread.setUncaughtExceptionHandler(ueh)
                thread.lastOpTime = nextTime
                thread.start()
            }
            var second = 0
            while (uncaughtException.get() == null) {
                waitUntil(nextTime)
                println("--- $second: Performed ${performedOps.sum()} operations${resumeStr()}")
                progress()
                checkStalled()
                if (++second > seconds) break
                nextTime += 1000L
            }
        } finally {
            complete()
        }
        println("------ Done with ${performedOps.sum()} operations${resumeStr()}")
        progress()
    }

    /**
     * Finishes the test. It sets [isCompleted], runs the [onCompletion] blocks and signals the test
     * threads to stop. It then polls every [SHUTDOWN_CHECK_MS] milliseconds, for at most
     * [STALL_LIMIT_MS] milliseconds, until every test thread has terminated or is the currently paused
     * thread. [shutdown] is always called at the end.
     *
     * @throws IllegalStateException if some threads failed to shut down in time or progress stalled.
     */
    private fun complete() {
        val activeNonPausedThreads: MutableMap<TestThread, Array<StackTraceElement>> = mutableMapOf()
        val shutdownDeadline = System.currentTimeMillis() + STALL_LIMIT_MS
        try {
            completed = true
            // Perform custom completion blocks. For testing of things like channels, these custom completion
            // blocks close all the channels, so that all suspended coroutines shall get resumed.
            onCompletion.forEach { it() }
            // signal shutdown to all threads (non-paused threads will terminate)
            isActive = false
            // wait for threads to terminate
            while (System.currentTimeMillis() < shutdownDeadline) {
                // Check all threads while shutting down:
                // all terminated threads are considered to make progress for the purpose of resuming stalled ones
                activeNonPausedThreads.clear()
                for (t in threads) {
                    when {
                        !t.isAlive -> t.makeProgress(getPausedEpoch()) // not alive - makes progress
                        t.index.inv() == status.get() -> {} // active, paused -- skip
                        else -> {
                            val stackTrace = t.stackTrace
                            // skip the thread if it terminated while its stack trace was being taken
                            if (t.isAlive) activeNonPausedThreads[t] = stackTrace
                        }
                    }
                }
                if (activeNonPausedThreads.isEmpty()) break
                checkStalled()
                Thread.sleep(SHUTDOWN_CHECK_MS)
            }
            activeNonPausedThreads.forEach { (t, stackTrace) ->
                println("=== $t had failed to shutdown in time")
                stackTrace.forEach { println("\tat $it") }
            }
        } finally {
            shutdown(shutdownDeadline)
        }
        // if no other exception was thrown & we had threads that did not shut down -- still fails
        if (activeNonPausedThreads.isNotEmpty()) error("Some threads had failed to shutdown in time")
    }

    /**
     * Moves the status to [STATUS_DONE] and unparks the paused thread, if any. It then joins the test
     * threads until [shutdownDeadline], calls [TestThread.abortWait] on every thread and releases the
     * interceptor.
     *
     * @throws Throwable the first uncaught exception of a test thread, if any.
     * @throws IllegalStateException if some test thread is still alive afterwards.
     */
    private fun shutdown(shutdownDeadline: Long) {
        // forcefully unpause the paused thread to shut it down (if any left)
        val curStatus = status.getAndSet(STATUS_DONE)
        if (curStatus < 0) LockSupport.unpark(threads[curStatus.inv()])
        threads.forEach {
            val remaining = shutdownDeadline - System.currentTimeMillis()
            if (remaining > 0) it.join(remaining)
        }
        // abort waiting threads (if still any left)
        threads.forEach { it.abortWait() }
        // cleanup & be done
        unlockAndResetInterceptor(interceptor)
        uncaughtException.get()?.let { throw it }
        threads.find { it.isAlive }?.let { dumpThreadsError("A thread is still alive: $it") }
    }

    /**
     * Fails the test if any test thread has not completed a lock-free operation (or started)
     * within the last [STALL_LIMIT_MS] milliseconds.
     */
    private fun checkStalled() {
        val stallLimit = System.currentTimeMillis() - STALL_LIMIT_MS
        val stalled = threads.filter { it.lastOpTime < stallLimit }
        if (stalled.isNotEmpty()) dumpThreadsError("Progress stalled in threads ${stalled.map { it.name }}")
    }

    /** Suffix for progress output with the number of pause/resume cycles, or empty when there were none. */
    private fun resumeStr(): String {
        val resumes = performedResumes
        return if (resumes == 0) "" else " (pause/resumes $resumes)"
    }

    /** Sleeps until the wall-clock time, in milliseconds, reaches [nextTime]. */
    private fun waitUntil(nextTime: Long) {
        while (true) {
            val curTime = System.currentTimeMillis()
            if (curTime >= nextTime) break
            Thread.sleep(nextTime - curTime)
        }
    }

    /** Prints [message] and the stack traces of the test threads, then fails with [message]. */
    private fun dumpThreadsError(message: String): Nothing {
        val traces = threads.associate { it to it.stackTrace }
        println("!!! $message")
        println("=== Dumping live thread stack traces")
        for ((thread, trace) in traces) {
            // an empty trace presumably means the thread is not running (not started or terminated)
            if (trace.isEmpty()) continue
            println("Thread \"${thread.name}\" ${thread.state}")
            for (t in trace) println("\tat ${t.className}.${t.methodName}(${t.fileName}:${t.lineNumber})")
            println()
        }
        println("===")
        error(message)
    }

    /**
     * `true` once the test has completed.
     * It becomes `true` before the [onCompletion] blocks are invoked.
     */
    public val isCompleted: Boolean get() = completed

    /**
     * Performs a given block of code on test's completion.
     */
    public fun onCompletion(block: () -> Unit) {
        onCompletion += block
    }

    /**
     * Creates a new test thread in this environment that executes a given lock-free [operation]
     * in a loop while this environment [isActive].
     * Must be called before [performTest].
     */
    public fun testThread(name: String? = null, operation: suspend TestThread.() -> Unit): TestThread =
        TestThread(name, operation)

    /**
     * Test thread.
     *
     * It repeatedly runs its `operation` while the environment is active. Each invocation counts as one
     * lock-free operation, and [intermission] splits an invocation into several lock-free operations.
     */
    @Suppress("LeakingThis")
    public inner class TestThread internal constructor(
        name: String?,
        private val operation: suspend TestThread.() -> Unit
    ) : Thread(composeThreadName(name)) {
        // Position of this thread in `threads`; its bitwise inverse is stored in `status` while paused.
        internal val index: Int

        // Time of the last completed lock-free operation; read by checkStalled() on another thread.
        @Volatile internal var lastOpTime = 0L

        // Epoch of this thread's current pause; resumeImpl() sets `status` to it when resuming.
        @Volatile internal var pausedEpoch = -1

        private val random = Random()

        // Used by this thread while it runs. complete() may also call makeProgress()
        // for this thread after it has terminated.
        private var operationEpoch = -1
        private var progressEpoch = -1
        private var sink = 0

        init {
            check(!started)
            index = threads.size
            threads += this
        }

        public override fun run() {
            while (isActive) {
                callOperation()
            }
        }

        /**
         * Use it to insert an arbitrary intermission between lock-free operations.
         *
         * The lock-free operation in progress is counted as completed before [block] runs. A new one
         * begins after [block] finishes, even if it throws.
         */
        public inline fun <T> intermission(block: () -> T): T {
            afterLockFreeOperation()
            return try {
                block()
            } finally {
                beforeLockFreeOperation()
            }
        }

        /** Remembers the pause epoch that is current when a lock-free operation starts. */
        @PublishedApi
        internal fun beforeLockFreeOperation() {
            operationEpoch = getPausedEpoch()
        }

        /**
         * Records completion of a lock-free operation. It counts progress for the pause epoch observed
         * when the operation started, updates [lastOpTime] and increments the operation counter.
         */
        @PublishedApi
        internal fun afterLockFreeOperation() {
            makeProgress(operationEpoch)
            lastOpTime = System.currentTimeMillis()
            performedOps.add(1)
        }

        /**
         * Counts this thread as having made progress during the pause [epoch], at most once per epoch.
         * When all threads except the paused one have made progress, the paused thread is resumed.
         */
        internal fun makeProgress(epoch: Int) {
            if (epoch <= progressEpoch) return
            progressEpoch = epoch
            val total = globalPauseProgress.incrementAndGet()
            if (total >= threads.size - 1) {
                check(total == threads.size - 1)
                check(globalPauseProgress.compareAndSet(threads.size - 1, 0))
                resumeImpl()
            }
        }

        /**
         * Inserts random spin wait between multiple lock-free operations in [operation].
         * In 95% of invocations there is no wait at all.
         */
        public fun randomSpinWaitIntermission() {
            intermission {
                if (random.nextInt(100) < 95) return // be quick, no wait 95% of time
                do {
                    val x = random.nextInt(100)
                    repeat(x) { sink += it }
                } while (x >= 90)
            }
        }

        /** Called on each intercepted atomic step; pauses this thread with probability 1 / [PAUSE_EVERY_N_STEPS]. */
        internal fun stepImpl() {
            if (random.nextInt(PAUSE_EVERY_N_STEPS) == 0) pauseImpl()
        }

        /**
         * Pauses this thread unless another thread is already paused or the test is done.
         * It publishes the inverted [index] as the status and parks until the status changes,
         * which happens on resume or on shutdown.
         */
        internal fun pauseImpl() {
            while (true) {
                val curStatus = status.get()
                if (curStatus < 0 || curStatus == STATUS_DONE) return // some other thread paused or done
                pausedEpoch = curStatus + 1
                val newStatus = index.inv()
                if (status.compareAndSet(curStatus, newStatus)) {
                    // bounded park: a lost unpark signal delays this thread by at most MAX_PARK_NANOS
                    while (status.get() == newStatus) LockSupport.parkNanos(MAX_PARK_NANOS) // wait
                    return
                }
            }
        }

        // ----- Lightweight support for suspending operations -----

        /**
         * Runs [operation] once as a coroutine on this thread. If it suspends, this thread blocks in
         * [waitUntilCompletion]. Each resumption of the coroutine, from whatever thread, is handed over
         * to this thread, which then resumes the continuation itself.
         */
        private fun callOperation() {
            beforeLockFreeOperation()
            beginRunningOperation()
            val startResult = operation.startCoroutineUninterceptedOrReturn(this, completion)
            when {
                startResult === Unit -> afterLockFreeOperation() // operation completed w/o suspension -- done
                startResult === COROUTINE_SUSPENDED -> waitUntilCompletion() // operation had suspended
                else -> error("Unexpected result of operation: $startResult")
            }
            try {
                doneRunningOperation()
            } catch (e: IllegalStateException) {
                throw IllegalStateException("${e.message}; original start result=$startResult", e)
            }
        }

        private var runningOperation = false

        // Result handed over by resumeWith(); guarded by this thread's monitor.
        private var result: Result<Any?>? = null

        // Continuation to resume with `result`; null when the whole operation has completed.
        private var continuation: Continuation<Any?>? = null

        /**
         * Drives a suspended operation to completion on this thread. It waits for each resumption and
         * resumes the handed-over continuation here. Each suspension ends the current lock-free operation,
         * and each resumption begins a new one.
         */
        private fun waitUntilCompletion() {
            try {
                while (true) {
                    afterLockFreeOperation()
                    val nextResult: Result<Any?> = waitForResult()
                    val nextContinuation = takeContinuation()
                    if (nextContinuation == null) { // done
                        check(nextResult.getOrThrow() === Unit)
                        return
                    }
                    removeSuspended(this)
                    beforeLockFreeOperation()
                    nextContinuation.resumeWith(nextResult)
                }
            } finally {
                removeSuspended(this)
            }
        }

        private fun beginRunningOperation() {
            runningOperation = true
            result = null
            continuation = null
        }

        /** Checks that the operation finished with no pending result and clears the running flag. */
        @Synchronized
        private fun doneRunningOperation() {
            check(runningOperation) { "Should be running operation" }
            check(result == null && continuation == null) {
                "Callback invoked with result=$result, continuation=$continuation"
            }
            runningOperation = false
        }

        /**
         * Hands [result] and the [continuation] to resume (null on final completion) over to this thread,
         * and wakes it up in [waitForResult]. May be called from any thread.
         */
        @Synchronized
        private fun resumeWith(result: Result<Any?>, continuation: Continuation<Any?>?) {
            check(runningOperation) { "Should be running operation" }
            check(this.result == null && this.continuation == null) {
                "Resumed again with result=$result, continuation=$continuation, when this: result=${this.result}, continuation=${this.continuation}"
            }
            this.result = result
            this.continuation = continuation
            (this as Object).notifyAll()
        }

        /**
         * Blocks until a result is handed over, waiting in slices of at most 10 ms. While waiting, this
         * thread is registered as suspended. If it is among the first [allowSuspendedThreads] suspended
         * threads, it counts as making progress.
         */
        @Synchronized
        private fun waitForResult(): Result<Any?> {
            while (true) {
                val result = this.result
                if (result != null) return result
                val index = addSuspended(this)
                if (index < allowSuspendedThreads) {
                    // This suspension was permitted, so assume progress is happening while it is suspended
                    makeProgress(getPausedEpoch())
                }
                (this as Object).wait(10) // at most 10 ms
            }
        }

        /** Returns the handed-over continuation (null on final completion) and clears the handover state. */
        @Synchronized
        private fun takeContinuation(): Continuation<Any?>? =
            continuation.also {
                this.result = null
                this.continuation = null
            }

        /** Makes a pending [waitForResult] return a failure; used at the end of the test. */
        @Synchronized
        fun abortWait() {
            this.result = Result.failure(IllegalStateException("Aborted at the end of test"))
            (this as Object).notifyAll()
        }

        // Intercepted continuations do not resume directly; they hand the result over to this thread.
        private val interceptor: CoroutineContext = object : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
            override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
                Continuation<T>(this) {
                    @Suppress("UNCHECKED_CAST")
                    resumeWith(it, continuation as Continuation<Any?>)
                }
        }

        // Completion of the whole operation: hands over the final result with no continuation.
        private val completion = Continuation<Unit>(interceptor) {
            resumeWith(it, null)
        }
    }

    // ---------- Implementation ----------

    /** Registers [thread] as suspended (once) and returns its position in the suspension order. */
    @Synchronized
    private fun addSuspended(thread: TestThread): Int {
        val index = suspendedThreads.indexOf(thread)
        if (index >= 0) return index
        suspendedThreads.add(thread)
        return suspendedThreads.size - 1
    }

    @Synchronized
    private fun removeSuspended(thread: TestThread) {
        suspendedThreads.remove(thread)
    }

    /**
     * Returns the epoch of the current pause, or -1 when no thread is paused. The status is re-read
     * to make sure the returned epoch belongs to the thread that is still paused.
     */
    private fun getPausedEpoch(): Int {
        while (true) {
            val curStatus = status.get()
            if (curStatus >= 0) return -1 // not paused
            val thread = threads[curStatus.inv()]
            val pausedEpoch = thread.pausedEpoch
            if (curStatus == status.get()) return pausedEpoch
        }
    }

    /** Intercepted atomic step; has an effect only when called from a [TestThread]. */
    internal fun step() {
        val thread = Thread.currentThread() as? TestThread ?: return
        thread.stepImpl()
    }

    /** Resumes the paused thread by moving the status to its paused epoch (no-op once the test is done). */
    private fun resumeImpl() {
        while (true) {
            val curStatus = status.get()
            if (curStatus == STATUS_DONE) return // done
            check(curStatus < 0)
            val thread = threads[curStatus.inv()]
            performedResumes = thread.pausedEpoch
            if (status.compareAndSet(curStatus, thread.pausedEpoch)) {
                LockSupport.unpark(thread)
                return
            }
        }
    }

    /** Thread name: "<test name>-<threadName>", or "<test name>-<1-based thread number>" when unnamed. */
    private fun composeThreadName(threadName: String?): String {
        if (threadName != null) return "$name-$threadName"
        return name + "-${threads.size + 1}"
    }

    /** Routes every intercepted atomic update to [step]. */
    private inner class Interceptor : AtomicOperationInterceptor() {
        override fun <T> beforeUpdate(ref: AtomicRef<T>) = step()
        override fun beforeUpdate(ref: AtomicInt) = step()
        override fun beforeUpdate(ref: AtomicLong) = step()
        override fun <T> afterSet(ref: AtomicRef<T>, newValue: T) = step()
        override fun afterSet(ref: AtomicInt, newValue: Int) = step()
        override fun afterSet(ref: AtomicLong, newValue: Long) = step()
        override fun <T> afterRMW(ref: AtomicRef<T>, oldValue: T, newValue: T) = step()
        override fun afterRMW(ref: AtomicInt, oldValue: Int, newValue: Int) = step()
        override fun afterRMW(ref: AtomicLong, oldValue: Long, newValue: Long) = step()
        override fun toString(): String = "LockFreedomTestEnvironment($name)"
    }
}
/**
 * Manually pauses the on-going lock-free operation at the point of the call.
 * Intended for targeted debugging of specific places in code.
 *
 * It has no effect when called from a thread that is not a [LockFreedomTestEnvironment.TestThread].
 * Inside a test thread, the pause is skipped if another thread is already paused or the test is done.
 * Otherwise the caller blocks until the environment resumes it.
 *
 * **Don't use it in production code.**
 */
public fun pauseLockFreeOp() {
    (Thread.currentThread() as? LockFreedomTestEnvironment.TestThread)?.pauseImpl()
}