• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
4 
5 package kotlinx.coroutines.flow
6 
7 import kotlinx.coroutines.*
8 import kotlinx.coroutines.channels.*
9 import kotlin.test.*
10 
/**
 * Tests for the [shareIn] operator (plus one [stateIn] case).
 *
 * Covered behavior:
 * - replay of values to late subscribers;
 * - completion and failure of the upstream flow;
 * - the [SharingStarted.Eagerly], [SharingStarted.Lazily] and [SharingStarted.WhileSubscribed] start
 *   strategies, plus a custom threshold-based strategy.
 *
 * Most tests call `expect(k)` / `finish(k)` from [TestBase] with increasing step numbers to check the
 * order in which coroutines run, so statement order and `yield()` points are significant.
 */
class ShareInTest : TestBase() {
    /**
     * [SharingStarted.Eagerly] with no replay argument.
     *
     * The only upstream value is emitted once sharing starts, before any subscriber exists. None of
     * the subscribers launched afterwards may receive it.
     */
    @Test
    fun testReplay0Eager() = runTest {
        expect(1)
        val flow = flowOf("OK")
        val shared = flow.shareIn(this, SharingStarted.Eagerly)
        yield() // actually start sharing
        // all subscribers miss "OK"
        val jobs = List(10) {
            shared.onEach { expectUnreached() }.launchIn(this)
        }
        yield() // ensure nothing is collected
        jobs.forEach { it.cancel() }
        finish(2)
    }

    /** [SharingStarted.Lazily] with `replay = 0`; see [testReplayZeroOrOne]. */
    @Test
    fun testReplay0Lazy() = testReplayZeroOrOne(0)

    /** [SharingStarted.Lazily] with `replay = 1`; see [testReplayZeroOrOne]. */
    @Test
    fun testReplay1Lazy() = testReplayZeroOrOne(1)

    /**
     * Lazily-started sharing with the given [replay].
     *
     * The upstream emits "OK", waits on `doneBarrier`, then emits "DONE". Ten subscribers are
     * launched one after another, and each is launched only after the previous one has subscribed.
     * Expected behavior:
     * - sharing must not start before the first subscriber;
     * - the first subscriber receives "OK";
     * - with `replay == 1`, every later subscriber also receives "OK", in launch order;
     * - with `replay == 0`, the later subscribers never see "OK";
     * - once `doneBarrier` is released, all subscribers receive "DONE" in launch order and stop,
     *   because of `takeWhile`.
     *
     * @param replay replay size passed to [shareIn]; only 0 and 1 are supported by the step numbering.
     */
    private fun testReplayZeroOrOne(replay: Int) = runTest {
        expect(1)
        val doneBarrier = Job()
        val flow = flow {
            expect(2)
            emit("OK")
            doneBarrier.join()
            emit("DONE")
        }
        val sharingJob = Job()
        val shared = flow.shareIn(this + sharingJob, started = SharingStarted.Lazily, replay = replay)
        yield() // should not start sharing
        // first subscriber gets "OK", other subscribers miss "OK"
        val n = 10
        // extra expect() steps consumed by the n - 1 replayed "OK" values when replay == 1
        val replayOfs = replay * (n - 1)
        val subscriberJobs = List(n) { index ->
            val subscribedBarrier = Job()
            val job = shared
                .onSubscription {
                    subscribedBarrier.complete()
                }
                .onEach { value ->
                    when (value) {
                        "OK" -> {
                            expect(3 + index)
                            if (replay == 0) { // only the first subscriber collects "OK" without replay
                                assertEquals(0, index)
                            }
                        }
                        "DONE" -> {
                            expect(4 + index + replayOfs)
                        }
                        else -> expectUnreached()
                    }
                }
                .takeWhile { it != "DONE" }
                .launchIn(this)
            subscribedBarrier.join() // wait until the launched job subscribed before launching the next one
            job
        }
        doneBarrier.complete()
        subscriberJobs.joinAll()
        expect(4 + n + replayOfs)
        sharingJob.cancel()
        finish(5 + n + replayOfs)
    }

    /** The upstream completes normally; see [testUpstreamCompletedOrFailed]. */
    @Test
    fun testUpstreamCompleted() =
        testUpstreamCompletedOrFailed(failed = false)

    /** The upstream throws [TestException]; see [testUpstreamCompletedOrFailed]. */
    @Test
    fun testUpstreamFailed() =
        testUpstreamCompletedOrFailed(failed = true)

    /**
     * Eagerly-started sharing with `replay = 1`, run in a scope whose job is the `sharingJob` deferred.
     *
     * Checks that:
     * - the emitted value lands in the replay cache;
     * - once the upstream terminates, joining the completed `sharingJob` finishes sharing;
     * - the replay cache is kept afterwards;
     * - the completion exception of `sharingJob` is [TestException] if the upstream failed, and null
     *   otherwise.
     *
     * @param failed whether the upstream throws [TestException] instead of completing normally.
     */
    private fun testUpstreamCompletedOrFailed(failed: Boolean) = runTest {
        val emitted = Job()
        val terminate = Job()
        val sharingJob = CompletableDeferred<Unit>()
        val upstream = flow {
            emit("OK")
            emitted.complete()
            terminate.join()
            if (failed) throw TestException()
        }
        val shared = upstream.shareIn(this + sharingJob, SharingStarted.Eagerly, replay = 1)
        assertEquals(emptyList(), shared.replayCache)
        emitted.join() // should start sharing, emit & cache
        assertEquals(listOf("OK"), shared.replayCache)
        terminate.complete()
        sharingJob.complete(Unit)
        sharingJob.join() // should complete sharing
        assertEquals(listOf("OK"), shared.replayCache) // cache is still there
        if (failed) {
            assertTrue(sharingJob.getCompletionExceptionOrNull() is TestException)
        } else {
            assertNull(sharingJob.getCompletionExceptionOrNull())
        }
    }

    /** Built-in [SharingStarted.WhileSubscribed] with default arguments; the upstream runs at >= 1 subscriber. */
    @Test
    fun testWhileSubscribedBasic() =
        testWhileSubscribed(1, SharingStarted.WhileSubscribed())

    /** Custom strategy that starts at >= 1 subscriber. */
    @Test
    fun testWhileSubscribedCustomAtLeast1() =
        testWhileSubscribed(1, SharingStarted.WhileSubscribedAtLeast(1))

    /** Custom strategy that starts at >= 2 subscribers. */
    @Test
    fun testWhileSubscribedCustomAtLeast2() =
        testWhileSubscribed(2, SharingStarted.WhileSubscribedAtLeast(2))

    /**
     * Checks a subscription-count-driven strategy over five rounds.
     *
     * Each round:
     * 1. start `n` subscribers one at a time;
     * 2. cancel them one at a time, oldest first.
     *
     * After every step the test checks whether the upstream is running.
     *
     * @param threshold number of simultaneous subscribers at which [started] is expected to run the
     *   upstream. The upstream must keep running while at least this many subscribers remain and must
     *   stop as soon as the count drops below it.
     * @param started the sharing strategy under test.
     */
    private fun testWhileSubscribed(threshold: Int, started: SharingStarted) = runTest {
        expect(1)
        val flowState = FlowState()
        val n = 3 // max number of subscribers
        val log = Channel<String>(2 * n)

        // Called right after the `subscribers`-th subscriber has been launched.
        suspend fun checkStartTransition(subscribers: Int) {
            when (subscribers) {
                in 0 until threshold -> assertFalse(flowState.started)
                threshold -> {
                    flowState.awaitStart() // must eventually start the flow
                    for (i in 1..threshold) {
                        assertEquals("sub$i: OK", log.receive()) // threshold subs must receive the values
                    }
                }
                in threshold + 1..n -> assertTrue(flowState.started)
            }
        }

        // Called right after a cancellation that leaves `subscribers` active subscribers.
        suspend fun checkStopTransition(subscribers: Int) {
            when (subscribers) {
                // still at or above the threshold: the upstream must keep running
                in threshold..n -> assertTrue(flowState.started)
                threshold - 1 -> flowState.awaitStop() // upstream flow must be eventually stopped
                in 0..threshold - 2 -> assertFalse(flowState.started) // should have stopped already
            }
        }

        val flow = flow {
            flowState.track {
                emit("OK")
                delay(Long.MAX_VALUE) // await forever, will get cancelled
            }
        }

        val shared = flow.shareIn(this, started)
        repeat(5) { // repeat scenario a few times
            yield()
            assertFalse(flowState.started) // flow is not running even if we yield
            // start n subscribers, one at a time
            val subs = List(n) { index ->
                val i = index + 1
                val sub = shared
                    .onEach { value -> // only the first threshold subscribers get the value
                        when (i) {
                            in 1..threshold -> log.trySend("sub$i: $value")
                            else -> expectUnreached()
                        }
                    }
                    .onCompletion { log.trySend("sub$i: completion") }
                    .launchIn(this)
                checkStartTransition(i)
                sub
            }
            // now cancel all subscribers, in the order they were started
            subs.forEachIndexed { index, sub ->
                val i = index + 1
                sub.cancel() // cancel subscriber
                assertEquals("sub$i: completion", log.receive()) // subscriber shall shutdown
                checkStopTransition(n - i)
            }
        }
        coroutineContext.cancelChildren() // cancel sharing job
        finish(2)
    }

    /**
     * Custom [SharingStarted] built from a lambda.
     *
     * It emits [SharingCommand.START] while the subscription count is at least [threshold], and
     * [SharingCommand.STOP] otherwise.
     */
    @Suppress("TestFunctionName")
    private fun SharingStarted.Companion.WhileSubscribedAtLeast(threshold: Int) =
        SharingStarted { subscriptionCount ->
            subscriptionCount.map { if (it >= threshold) SharingCommand.START else SharingCommand.STOP }
        }

    /**
     * Records whether the upstream body wrapped in [track] is currently running.
     *
     * Unbalanced [start]/[stop] calls fail with [IllegalStateException]. [awaitStart]/[awaitStop]
     * suspend until the matching state is reached, or throw on [withTimeout] expiry after [timeLimit] ms.
     */
    private class FlowState {
        private val timeLimit = 10000L // milliseconds, passed to withTimeout
        private val _started = MutableStateFlow(false)
        val started: Boolean get() = _started.value
        fun start() = check(_started.compareAndSet(expect = false, update = true)) {
            "upstream flow is already started"
        }
        fun stop() = check(_started.compareAndSet(expect = true, update = false)) {
            "upstream flow is not started"
        }
        suspend fun awaitStart() = withTimeout(timeLimit) { _started.first { it } }
        suspend fun awaitStop() = withTimeout(timeLimit) { _started.first { !it } }
    }

    /**
     * Marks this state as started, runs [block], then marks it stopped.
     *
     * The stop happens in `finally`, so it also runs when [block] is cancelled or throws.
     */
    private suspend fun FlowState.track(block: suspend () -> Unit) {
        start()
        try {
            block()
        } finally {
            stop()
        }
    }

    /**
     * A [SharingStarted.Lazily] shared flow must start the upstream even when its first subscriber
     * fails in `onSubscription` with a [CancellationException].
     *
     * The upstream's `expect(2)`/`expect(3)` must run during the following `yield()`.
     */
    @Test
    fun testShouldStart() = runTest {
        val flow = flow {
            expect(2)
            emit(1)
            expect(3)
        }.shareIn(this, SharingStarted.Lazily)

        expect(1)
        flow.onSubscription { throw CancellationException("") }
            .catch { e -> assertTrue { e is CancellationException } }
            .collect()
        yield()
        finish(4)
    }

    /**
     * [stateIn] with [SharingStarted.Lazily].
     *
     * The first `first()` returns the initial value 42. After a `yield()` lets sharing run, the
     * upstream value 239 is observed.
     */
    @Test
    fun testShouldStartScalar() = runTest {
        val j = Job()
        val shared = flowOf(239).stateIn(this + j, SharingStarted.Lazily, 42)
        assertEquals(42, shared.first())
        yield()
        assertEquals(239, shared.first())
        j.cancel()
    }
}
240