• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 package kotlinx.coroutines.debug.junit4
2 
3 import kotlinx.coroutines.debug.*
4 import org.junit.rules.*
5 import org.junit.runner.*
6 import org.junit.runners.model.*
7 import java.io.*
8 import kotlin.test.*
9 
/**
 * Builds the rule chain for a test class: a [TestFailureValidation] rule as the outer rule,
 * wrapping a [CoroutinesTimeout] rule created from [timeoutMs], [cancelOnTimeout] and
 * [creationStackTraces] (passed positionally, in that order).
 *
 * @param specs expected results, indexed by [TestResultSpec.testName] for the validation rule.
 * @return a [RuleChain] whose outer rule validates the outcome of the timeout-guarded test.
 */
internal fun TestFailureValidation(
    timeoutMs: Long,
    cancelOnTimeout: Boolean,
    creationStackTraces: Boolean,
    vararg specs: TestResultSpec
): RuleChain {
    // The validation rule is created first (before the timeout rule) to keep the original construction order.
    val validationRule = TestFailureValidation(specs.associateBy { spec -> spec.testName })
    val chain = RuleChain.outerRule(validationRule)
    return chain.around(CoroutinesTimeout(timeoutMs, cancelOnTimeout, creationStackTraces))
}
25 
/**
 * JUnit rule that runs a test with `System.out` and `System.err` redirected into one in-memory
 * buffer, then validates the test outcome and the captured output against the [TestResultSpec]
 * registered for the test method in [testsSpec].
 *
 * - Passing test: its spec must not declare an [TestResultSpec.error], and nothing may be printed.
 * - Failing test: the thrown exception must be an instance of the spec's [TestResultSpec.error].
 *   If it is a [TestTimedOutException], the captured output must contain "Coroutines dump",
 *   every entry of [TestResultSpec.expectedOutParts] and none of [TestResultSpec.notExpectedOutParts].
 *
 * @param testsSpec specs keyed by test method name (as reported by [Description.getMethodName]).
 */
internal class TestFailureValidation(private val testsSpec: Map<String, TestResultSpec>) : TestRule {

    override fun apply(base: Statement, description: Description): Statement =
        TestFailureStatement(base, description)

    /**
     * Statement that evaluates [test] with standard output/error captured, restores the original
     * streams, and then validates the result against the spec for [description].
     */
    inner class TestFailureStatement(
        private val test: Statement,
        private val description: Description
    ) : Statement() {
        // Both redirected streams write into this single buffer, so stdout and stderr are validated together.
        private val capturedOut = ByteArrayOutputStream()

        override fun evaluate() {
            val originalOut = System.out
            val originalErr = System.err
            // Restore the real streams before validating, so the success and failure paths are
            // validated under the same conditions (previously failure validation ran while redirected).
            val failure: Throwable? = try {
                System.setOut(PrintStream(capturedOut))
                System.setErr(PrintStream(capturedOut))
                test.evaluate()
                null
            } catch (e: Throwable) {
                e
            } finally {
                System.setOut(originalOut)
                System.setErr(originalErr)
            }

            // Validation runs outside the try so that validator assertion errors are not mistaken for test failures.
            if (failure == null) validateSuccess() else validateFailure(failure)
        }

        /** Returns the spec for the current test method, failing if none was registered. */
        private fun currentSpec(): TestResultSpec =
            testsSpec[description.methodName] ?: error("Test spec not found: ${description.methodName}")

        /** The test passed: its spec must not expect an error, and the test must not have printed anything. */
        private fun validateSuccess() {
            val spec = currentSpec()
            check(spec.error == null) { "Expected exception of type ${spec.error}, but test successfully passed" }

            val captured = capturedOut.toString()
            assertFalse(captured.contains("Coroutines dump"))
            assertTrue(captured.isEmpty(), captured)
        }

        /**
         * The test threw [e]: it must match the spec's expected error type; for timeouts, the captured
         * output is additionally checked for the coroutine dump and the expected/forbidden parts.
         *
         * @throws IllegalStateException if [e] is not an instance of the expected error type
         *   (the original failure is attached as the cause).
         */
        private fun validateFailure(e: Throwable) {
            val spec = currentSpec()
            val expectedError = spec.error
            if (expectedError == null || !expectedError.isInstance(e)) {
                throw IllegalStateException("Unexpected failure, expected ${spec.error}, had ${e::class}", e)
            }

            // Output expectations apply only to timeouts, where a coroutine dump is expected in the output.
            if (e !is TestTimedOutException) return

            val captured = capturedOut.toString()
            assertTrue(captured.contains("Coroutines dump"), "Expected coroutines dump to be part of the\n$captured")
            for (part in spec.expectedOutParts) {
                assertTrue(captured.contains(part), "Expected $part to be part of the\n$captured")
            }

            for (part in spec.notExpectedOutParts) {
                assertFalse(captured.contains(part), "Expected $part not to be part of the\n$captured")
            }
        }
    }

    companion object {
        init {
            // Runs once when this class is initialized. Presumably keeps the printed coroutine dumps
            // unsanitized so spec output parts can match full frames — TODO confirm against DebugProbes docs.
            DebugProbes.sanitizeStackTraces = false
        }
    }
}
101 
/**
 * Expected outcome of one test method, looked up by [testName] when validating test results.
 *
 * @property testName name of the test method this spec describes.
 * @property expectedOutParts substrings that must occur in the captured output of a timed-out test.
 * @property notExpectedOutParts substrings that must not occur in the captured output of a timed-out test.
 * @property error type of the exception the test is expected to throw, or `null` if it should pass.
 */
data class TestResultSpec(
    val testName: String,
    val expectedOutParts: List<String> = emptyList(),
    val notExpectedOutParts: List<String> = emptyList(),
    val error: Class<out Throwable>? = null
)
106