• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines.debug.junit4
6 
7 import kotlinx.coroutines.debug.*
8 import org.junit.rules.*
9 import org.junit.runner.*
10 import org.junit.runners.model.*
11 import java.io.*
12 import kotlin.test.*
13 
/**
 * Builds the rule chain used by debug JUnit4 tests.
 *
 * A [TestFailureValidation] rule is the outer rule. It wraps a [CoroutinesTimeout] rule, so
 * whatever the timeout rule prints while the test runs is written while the validation rule's
 * stream redirection is active.
 *
 * @param timeoutMs timeout forwarded to [CoroutinesTimeout].
 * @param cancelOnTimeout flag forwarded to [CoroutinesTimeout].
 * @param creationStackTraces flag forwarded to [CoroutinesTimeout].
 * @param specs expected results; they are indexed by [TestResultSpec.testName] (a later spec with
 *   the same name replaces an earlier one).
 */
internal fun TestFailureValidation(
    timeoutMs: Long,
    cancelOnTimeout: Boolean,
    creationStackTraces: Boolean,
    vararg specs: TestResultSpec
): RuleChain {
    val specsByTestName = specs.associateBy(TestResultSpec::testName)
    val validationRule = TestFailureValidation(specsByTestName)
    val timeoutRule = CoroutinesTimeout(timeoutMs, cancelOnTimeout, creationStackTraces)
    return RuleChain.outerRule(validationRule).around(timeoutRule)
}
29 
/**
 * JUnit4 rule that runs each test with `System.out` and `System.err` redirected into an in-memory
 * buffer.
 *
 * After the test finishes, the rule validates the outcome and the captured output against the
 * [TestResultSpec] registered for the test method in [testsSpec]. A test without a spec fails with
 * an [IllegalStateException].
 *
 * @param testsSpec expected results keyed by test method name.
 */
internal class TestFailureValidation(private val testsSpec: Map<String, TestResultSpec>) : TestRule {

    override fun apply(base: Statement, description: Description): Statement =
        TestFailureStatement(base, description)

    /**
     * Statement that evaluates [test] while both standard streams write into [capturedOut].
     *
     * The original streams are always restored before any validation runs.
     */
    inner class TestFailureStatement(private val test: Statement, private val description: Description) : Statement() {
        private val capturedOut = ByteArrayOutputStream()

        override fun evaluate() {
            // Validate outside the redirection, so that the success and failure paths run
            // under identical (real) System.out/System.err.
            val failure = evaluateCapturingOutput()
            if (failure == null) validateSuccess() else validateFailure(failure)
        }

        /**
         * Evaluates [test] with stdout and stderr redirected into [capturedOut].
         *
         * @return the [Throwable] thrown by the test, or `null` if it completed normally.
         */
        private fun evaluateCapturingOutput(): Throwable? {
            // Saved in locals (not lateinit fields), so the finally block always has valid streams to restore.
            val originalOut = System.out
            val originalErr = System.err
            return try {
                System.setOut(PrintStream(capturedOut))
                System.setErr(PrintStream(capturedOut))
                test.evaluate()
                null
            } catch (e: Throwable) {
                e
            } finally {
                System.setOut(originalOut)
                System.setErr(originalErr)
            }
        }

        /**
         * Checks a passed test: the spec must expect no error, and nothing may have been printed.
         */
        private fun validateSuccess() {
            val spec = specForCurrentTest()
            // A spec/result mismatch is a state failure, reported the same way as in validateFailure.
            check(spec.error == null) { "Expected exception of type ${spec.error}, but test successfully passed" }

            val captured = capturedOut.toString()
            assertFalse(captured.contains("Coroutines dump"), "Unexpected coroutines dump for a passed test:\n$captured")
            assertTrue(captured.isEmpty(), captured)
        }

        /**
         * Checks a failed test.
         *
         * [e] must be an instance of the spec's expected error type. If [e] is a
         * [TestTimedOutException], the captured output must also contain a coroutines dump, every
         * part in [TestResultSpec.expectedOutParts], and none of the parts in
         * [TestResultSpec.notExpectedOutParts].
         *
         * @throws IllegalStateException if no error was expected or [e] has an unexpected type.
         */
        private fun validateFailure(e: Throwable) {
            val spec = specForCurrentTest()
            val expectedError = spec.error
            if (expectedError == null || !expectedError.isInstance(e)) {
                throw IllegalStateException("Unexpected failure, expected ${spec.error}, had ${e::class}", e)
            }

            // Captured output is only inspected for timeouts.
            if (e !is TestTimedOutException) return

            val captured = capturedOut.toString()
            assertTrue(captured.contains("Coroutines dump"))
            for (part in spec.expectedOutParts) {
                assertTrue(captured.contains(part), "Expected $part to be part of the\n$captured")
            }

            for (part in spec.notExpectedOutParts) {
                assertFalse(captured.contains(part), "Expected $part not to be part of the\n$captured")
            }
        }

        /** Returns the spec registered for the current test method, or fails if there is none. */
        private fun specForCurrentTest(): TestResultSpec =
            testsSpec[description.methodName] ?: error("Test spec not found: ${description.methodName}")
    }

    companion object {
        init {
            // Turns off DebugProbes stack-trace sanitization for every test using this rule
            // (presumably so dumps keep all frames that expectedOutParts may match — confirm).
            DebugProbes.sanitizeStackTraces = false
        }
    }
}
105 
/**
 * Expected outcome of a single test, as validated by [TestFailureValidation].
 *
 * @property testName name of the test method this spec applies to.
 * @property expectedOutParts substrings that must appear in the captured output; these are checked
 *   only when the test fails with a timeout.
 * @property notExpectedOutParts substrings that must not appear in the captured output; these are
 *   checked only when the test fails with a timeout.
 * @property error expected exception type, or `null` if the test is expected to pass.
 */
data class TestResultSpec(
    val testName: String,
    val expectedOutParts: List<String> = emptyList(),
    val notExpectedOutParts: List<String> = emptyList(),
    val error: Class<out Throwable>? = null
)
110