/*
 * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlin.test.*

/**
 * Tests for [shareIn] (and one [stateIn] case): replay behavior with eager/lazy start,
 * propagation of upstream completion/failure to the sharing job, and start/stop of the
 * upstream under [SharingStarted] strategies driven by the subscription count.
 */
class ShareInTest : TestBase() {
    /** With eager start and no replay, subscribers that arrive after the single emission see nothing. */
    @Test
    fun testReplay0Eager() = runTest {
        expect(1)
        val flow = flowOf("OK")
        val shared = flow.shareIn(this, SharingStarted.Eagerly)
        yield() // actually start sharing
        // all subscribers miss "OK"
        val jobs = List(10) {
            shared.onEach { expectUnreached() }.launchIn(this)
        }
        yield() // ensure nothing is collected
        jobs.forEach { it.cancel() }
        finish(2)
    }

    @Test
    fun testReplay0Lazy() = testReplayZeroOrOne(0)

    @Test
    fun testReplay1Lazy() = testReplayZeroOrOne(1)

    /**
     * Lazily-started sharing with the given [replay] (0 or 1).
     *
     * The first subscriber starts the upstream and receives "OK" directly. Later subscribers
     * receive "OK" only from the replay cache, so none get it when `replay == 0`. Every
     * subscriber then receives "DONE".
     */
    private fun testReplayZeroOrOne(replay: Int) = runTest {
        expect(1)
        val doneBarrier = Job()
        val flow = flow {
            expect(2)
            emit("OK")
            doneBarrier.join()
            emit("DONE")
        }
        val sharingJob = Job()
        val shared = flow.shareIn(this + sharingJob, started = SharingStarted.Lazily, replay = replay)
        yield() // should not start sharing
        // first subscriber gets "OK" from upstream; later ones get it only via replay (if replay > 0)
        val n = 10
        // extra expect() steps taken by the n - 1 late subscribers that receive "OK" from the replay cache
        val replayOfs = replay * (n - 1)
        val subscriberJobs = List(n) { index ->
            val subscribedBarrier = Job()
            val job = shared
                .onSubscription {
                    subscribedBarrier.complete()
                }
                .onEach { value ->
                    when (value) {
                        "OK" -> {
                            expect(3 + index)
                            if (replay == 0) { // only the first subscriber collects "OK" without replay
                                assertEquals(0, index)
                            }
                        }
                        "DONE" -> {
                            expect(4 + index + replayOfs)
                        }
                        else -> expectUnreached()
                    }
                }
                .takeWhile { it != "DONE" }
                .launchIn(this)
            subscribedBarrier.join() // wait until the launched job subscribed before launching the next one
            job
        }
        doneBarrier.complete()
        subscriberJobs.joinAll()
        expect(4 + n + replayOfs)
        sharingJob.cancel()
        finish(5 + n + replayOfs)
    }

    @Test
    fun testUpstreamCompleted() =
        testUpstreamCompletedOrFailed(failed = false)

    @Test
    fun testUpstreamFailed() =
        testUpstreamCompletedOrFailed(failed = true)

    /**
     * Checks that the replay cache outlives upstream termination. When [failed] is true, the
     * sharing job must also complete with the upstream's [TestException]; otherwise it must
     * complete normally.
     */
    private fun testUpstreamCompletedOrFailed(failed: Boolean) = runTest {
        val emitted = Job()
        val terminate = Job()
        val sharingJob = CompletableDeferred<Unit>()
        val upstream = flow {
            emit("OK")
            emitted.complete()
            terminate.join()
            if (failed) throw TestException()
        }
        val shared = upstream.shareIn(this + sharingJob, SharingStarted.Eagerly, 1)
        assertEquals(emptyList(), shared.replayCache)
        emitted.join() // should start sharing, emit & cache
        assertEquals(listOf("OK"), shared.replayCache)
        terminate.complete()
        sharingJob.complete(Unit)
        sharingJob.join() // should complete sharing
        assertEquals(listOf("OK"), shared.replayCache) // cache is still there
        if (failed) {
            assertTrue(sharingJob.getCompletionExceptionOrNull() is TestException)
        } else {
            assertNull(sharingJob.getCompletionExceptionOrNull())
        }
    }

    @Test
    fun testWhileSubscribedBasic() =
        testWhileSubscribed(1, SharingStarted.WhileSubscribed())

    @Test
    fun testWhileSubscribedCustomAtLeast1() =
        testWhileSubscribed(1, SharingStarted.WhileSubscribedAtLeast(1))

    @Test
    fun testWhileSubscribedCustomAtLeast2() =
        testWhileSubscribed(2, SharingStarted.WhileSubscribedAtLeast(2))

    /**
     * Repeats a scenario five times. Each round adds up to `n = 3` subscribers one by one and
     * then cancels them one by one. The upstream must start exactly when the subscriber count
     * reaches [threshold] and must stop once it drops below [threshold].
     */
    private fun testWhileSubscribed(threshold: Int, started: SharingStarted) = runTest {
        expect(1)
        val flowState = FlowState()
        val n = 3 // max number of subscribers
        val log = Channel<String>(2 * n)

        suspend fun checkStartTransition(subscribers: Int) {
            when (subscribers) {
                in 0 until threshold -> assertFalse(flowState.started)
                threshold -> {
                    flowState.awaitStart() // must eventually start the flow
                    for (i in 1..threshold) {
                        assertEquals("sub$i: OK", log.receive()) // threshold subs must receive the values
                    }
                }
                in threshold + 1..n -> assertTrue(flowState.started)
            }
        }

        suspend fun checkStopTransition(subscribers: Int) {
            when (subscribers) {
                in threshold + 1..n -> assertTrue(flowState.started)
                threshold - 1 -> flowState.awaitStop() // upstream flow must be eventually stopped
                in 0..threshold - 2 -> assertFalse(flowState.started) // should have stopped already
            }
        }

        val flow = flow {
            flowState.track {
                emit("OK")
                delay(Long.MAX_VALUE) // await forever, will get cancelled
            }
        }

        val shared = flow.shareIn(this, started)
        repeat(5) { // repeat scenario a few times
            yield()
            assertFalse(flowState.started) // flow is not running even if we yield
            // start 3 subscribers
            val subs = ArrayList<Job>()
            for (i in 1..n) {
                subs += shared
                    .onEach { value -> // only the first threshold subscribers get the value
                        when (i) {
                            in 1..threshold -> log.trySend("sub$i: $value")
                            else -> expectUnreached()
                        }
                    }
                    .onCompletion { log.trySend("sub$i: completion") }
                    .launchIn(this)
                checkStartTransition(i)
            }
            // now cancel all subscribers
            for (i in 1..n) {
                // removeAt(0) rather than removeFirst(): on JDK 21+ the latter may bind to
                // java.util.List.removeFirst, which older runtimes lack
                subs.removeAt(0).cancel() // cancel subscriber
                assertEquals("sub$i: completion", log.receive()) // subscriber shall shutdown
                checkStopTransition(n - i)
            }
        }
        coroutineContext.cancelChildren() // cancel sharing job
        finish(2)
    }

    /** Custom strategy: START while at least [threshold] subscribers are present, STOP otherwise. */
    @Suppress("TestFunctionName")
    private fun SharingStarted.Companion.WhileSubscribedAtLeast(threshold: Int) =
        SharingStarted { subscriptionCount ->
            subscriptionCount.map { if (it >= threshold) SharingCommand.START else SharingCommand.STOP }
        }

    /**
     * Records whether the upstream flow body is currently running.
     * [start]/[stop] fail (via `check`) when called out of order. The await functions
     * time out after 10 seconds.
     */
    private class FlowState {
        private val timeLimit = 10000L
        private val _started = MutableStateFlow(false)
        val started: Boolean get() = _started.value
        fun start() = check(_started.compareAndSet(expect = false, update = true))
        fun stop() = check(_started.compareAndSet(expect = true, update = false))
        suspend fun awaitStart() = withTimeout(timeLimit) { _started.first { it } }
        suspend fun awaitStop() = withTimeout(timeLimit) { _started.first { !it } }
    }

    /** Marks the upstream as started while [block] runs; `finally` marks it stopped even on cancellation. */
    private suspend fun FlowState.track(block: suspend () -> Unit) {
        start()
        try {
            block()
        } finally {
            stop()
        }
    }

    /** A lazily-shared upstream still starts even though its only subscriber aborts in `onSubscription`. */
    @Test
    fun testShouldStart() = runTest {
        val flow = flow {
            expect(2)
            emit(1)
            expect(3)
        }.shareIn(this, SharingStarted.Lazily)

        expect(1)
        flow.onSubscription { throw CancellationException("") }
            .catch { e -> assertTrue { e is CancellationException } }
            .collect()
        yield()
        finish(4)
    }

    /** With `stateIn(Lazily)`, the first read sees the initial value; after a yield the upstream value is visible. */
    @Test
    fun testShouldStartScalar() = runTest {
        val j = Job()
        val shared = flowOf(239).stateIn(this + j, SharingStarted.Lazily, 42)
        assertEquals(42, shared.first())
        yield()
        assertEquals(239, shared.first())
        j.cancel()
    }
}