• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
<lambda>null2  * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines.flow.internal
6 
7 import kotlinx.coroutines.*
8 import kotlinx.coroutines.channels.*
9 import kotlinx.coroutines.flow.*
10 import kotlinx.coroutines.internal.*
11 import kotlin.coroutines.*
12 import kotlin.coroutines.intrinsics.*
13 import kotlin.jvm.*
14 
/**
 * Views this flow as a [ChannelFlow].
 *
 * A flow that already is a [ChannelFlow] is returned unchanged. Any other flow is wrapped into a
 * [ChannelFlowOperatorImpl] that uses that class's default parameters.
 */
internal fun <T> Flow<T>.asChannelFlow(): ChannelFlow<T> =
    if (this is ChannelFlow<T>) this else ChannelFlowOperatorImpl(this)
17 
/**
 * Contract for flow operators that can merge (fuse) with **downstream** [buffer] and [flowOn]
 * operators instead of adding yet another stage to the pipeline.
 *
 * @suppress **This is an internal API and should not be used from general code.**
 */
@InternalCoroutinesApi
public interface FusibleFlow<T> : Flow<T> {
    /**
     * Invoked by [flowOn] (which passes a [context]) and by [buffer] (which passes a [capacity] and an
     * [onBufferOverflow] strategy) when they are applied on top of this flow. Returns the resulting flow.
     *
     * [Channel.CONFLATED] must never be passed as [capacity]. Callers have to desugar it to
     * `capacity = 0, onBufferOverflow = DROP_OLDEST` first.
     */
    public fun fuse(
        context: CoroutineContext = EmptyCoroutineContext,
        capacity: Int = Channel.OPTIONAL_CHANNEL,
        onBufferOverflow: BufferOverflow = BufferOverflow.SUSPEND
    ): Flow<T>
}
36 
/**
 * Base class for operators whose output goes through a channel. All such operators are always fused
 * with each other.
 *
 * This class serves as a skeletal implementation of [FusibleFlow]. It also provides cross-cutting
 * facilities such as the ability to [produceIn] the corresponding flow, which allows the backing
 * channel to be used directly when one exists (hence the `ChannelFlow` name).
 *
 * @suppress **This is an internal API and should not be used from general code.**
 */
@InternalCoroutinesApi
public abstract class ChannelFlow<T>(
    // context in which the upstream is collected
    @JvmField public val context: CoroutineContext,
    // capacity of the buffer placed between the upstream and downstream contexts
    @JvmField public val capacity: Int,
    // what happens when that buffer overflows
    @JvmField public val onBufferOverflow: BufferOverflow
) : FusibleFlow<T> {
    init {
        // callers are responsible for desugaring CONFLATED into (0, DROP_OLDEST)
        assert { capacity != Channel.CONFLATED }
    }

    // The single place where the collectTo member is turned into a suspending lambda
    internal val collectToFun: suspend (ProducerScope<T>) -> Unit
        get() = { producer -> collectTo(producer) }

    // Capacity used when a channel is actually created: an optional channel becomes a BUFFERED one
    internal val produceCapacity: Int
        get() = when (capacity) {
            Channel.OPTIONAL_CHANNEL -> Channel.BUFFERED
            else -> capacity
        }

    /**
     * Implementations that can work without a channel (i.e. support [Channel.OPTIONAL_CHANNEL]) should
     * return a non-null flow here. A caller can then use that flow directly, without the effect of the
     * additional [flowOn] and [buffer] operators, by incorporating this operator's [context], [capacity]
     * and [onBufferOverflow] into its own implementation. The default implementation returns `null`.
     */
    public open fun dropChannelOperators(): Flow<T>? = null

    public override fun fuse(context: CoroutineContext, capacity: Int, onBufferOverflow: BufferOverflow): Flow<T> {
        // callers are responsible for desugaring CONFLATED into (0, DROP_OLDEST)
        assert { capacity != Channel.CONFLATED }
        // elements of the previously specified (upstream) context take precedence over the new ones
        val fusedContext = context + this.context
        // A non-suspending buffer never suspends, so it replaces the previous buffering configuration.
        // A suspending one adds its capacity to the existing one and keeps the existing overflow strategy.
        val overwrites = onBufferOverflow != BufferOverflow.SUSPEND
        val fusedCapacity = if (overwrites) capacity else combineCapacity(capacity)
        val fusedOverflow = if (overwrites) onBufferOverflow else this.onBufferOverflow
        val unchanged = fusedContext == this.context &&
            fusedCapacity == this.capacity &&
            fusedOverflow == this.onBufferOverflow
        return if (unchanged) this else create(fusedContext, fusedCapacity, fusedOverflow)
    }

    // Combines this operator's capacity with the capacity of an extra suspending buffer.
    // Special capacity values are resolved first; plain sizes are summed.
    private fun combineCapacity(extra: Int): Int = when {
        this.capacity == Channel.OPTIONAL_CHANNEL -> extra
        extra == Channel.OPTIONAL_CHANNEL -> this.capacity
        this.capacity == Channel.BUFFERED -> extra
        extra == Channel.BUFFERED -> this.capacity
        else -> {
            // sanity checks: only plain non-negative sizes are expected at this point
            assert { this.capacity >= 0 }
            assert { extra >= 0 }
            val total = this.capacity + extra
            // an Int overflow of the sum is clamped to UNLIMITED
            if (total < 0) Channel.UNLIMITED else total
        }
    }

    // Returns a copy of this operator configured with the given (already fused) parameters
    protected abstract fun create(context: CoroutineContext, capacity: Int, onBufferOverflow: BufferOverflow): ChannelFlow<T>

    // Channel-backed collection: implementations deliver this flow's elements into [scope]
    protected abstract suspend fun collectTo(scope: ProducerScope<T>)

    /**
     * Launches a producer in [scope] that runs [collectToFun] in [context], using a channel of
     * [produceCapacity] with the [onBufferOverflow] strategy. [produceImpl] is also what backs [flowOn].
     *
     * The producer is started with [CoroutineStart.ATOMIC] on purpose (#1825). With a non-atomic start,
     * the part of the pipeline after [flowOn] could complete successfully (notably its `onCompletion`
     * handlers) while the part before it never runs because it was cancelled during its dispatch. Its
     * `onCompletion` and `finally` blocks would then be skipped, which may lead to various kinds of
     * memory leaks.
     */
    public open fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> =
        scope.produce(
            context,
            produceCapacity,
            onBufferOverflow,
            start = CoroutineStart.ATOMIC,
            block = collectToFun
        )

    // Default collection always goes through a channel produced in a child scope
    override suspend fun collect(collector: FlowCollector<T>): Unit = coroutineScope {
        val channel = produceImpl(this)
        collector.emitAll(channel)
    }

    // Hook that lets subclasses contribute an extra entry to the debug string
    protected open fun additionalToStringProps(): String? = null

    // Debug representation: lists only the parameters that differ from their defaults
    override fun toString(): String {
        val parts = ArrayList<String>(4)
        val extra = additionalToStringProps()
        if (extra != null) parts += extra
        if (context !== EmptyCoroutineContext) parts += "context=$context"
        if (capacity != Channel.OPTIONAL_CHANNEL) parts += "capacity=$capacity"
        if (onBufferOverflow != BufferOverflow.SUSPEND) parts += "onBufferOverflow=$onBufferOverflow"
        return parts.joinToString(separator = ", ", prefix = "$classSimpleName[", postfix = "]")
    }
}
138 
// A ChannelFlow that is built on top of another, upstream flow
internal abstract class ChannelFlowOperator<S, T>(
    @JvmField protected val flow: Flow<S>,
    context: CoroutineContext,
    capacity: Int,
    onBufferOverflow: BufferOverflow
) : ChannelFlow<T>(context, capacity, onBufferOverflow) {
    // Implemented by subclasses: collects this operator's output into [collector] in the current context
    protected abstract suspend fun flowCollect(collector: FlowCollector<T>)

    // Runs flowCollect with newContext as the upstream (collecting) context, starting it undispatched,
    // while every emission is delivered back to [collector] in the caller's original context
    private suspend fun collectWithContextUndispatched(collector: FlowCollector<T>, newContext: CoroutineContext) {
        val emitInCallerContext = collector.withUndispatchedContextCollector(coroutineContext)
        withContextUndispatched(newContext, emitInCallerContext, block = { downstream -> flowCollect(downstream) })
    }

    // Slow path, used when an output channel is required: everything emitted is sent to the producer
    protected override suspend fun collectTo(scope: ProducerScope<T>) {
        flowCollect(SendingCollector(scope))
    }

    // Avoids creating a channel when the operator does not actually need one
    override suspend fun collect(collector: FlowCollector<T>) {
        if (capacity != Channel.OPTIONAL_CHANNEL) {
            // an explicit buffer is configured, so a channel is unavoidable
            super.collect(collector)
            return
        }
        // Fast path: channel creation is optional (flowOn/flowWith operators without a buffer)
        val callerContext = coroutineContext
        // the context the upstream would be collected in
        val upstreamContext = callerContext.newCoroutineContext(context)
        when {
            // the resulting context is the same as before: plain collection in place
            upstreamContext == callerContext -> flowCollect(collector)
            // same dispatcher: switch the context without dispatching and without a channel
            upstreamContext[ContinuationInterceptor] == callerContext[ContinuationInterceptor] ->
                collectWithContextUndispatched(collector, upstreamContext)
            // different dispatcher: fall back to the channel-backed slow path
            else -> super.collect(collector)
        }
    }

    // Debug representation: upstream flow followed by this operator's own description
    override fun toString(): String = "$flow -> ${super.toString()}"
}
179 
/**
 * The simple channel flow operator: a [flowOn], a [buffer], or their fused combination, applied on top of
 * [flow] without altering the elements it emits.
 */
internal class ChannelFlowOperatorImpl<T>(
    flow: Flow<T>,
    context: CoroutineContext = EmptyCoroutineContext,
    capacity: Int = Channel.OPTIONAL_CHANNEL,
    onBufferOverflow: BufferOverflow = BufferOverflow.SUSPEND
) : ChannelFlowOperator<T, T>(flow, context, capacity, onBufferOverflow) {
    // same upstream flow, new fused parameters
    override fun create(context: CoroutineContext, capacity: Int, onBufferOverflow: BufferOverflow): ChannelFlow<T> =
        ChannelFlowOperatorImpl(
            flow = flow,
            context = context,
            capacity = capacity,
            onBufferOverflow = onBufferOverflow
        )

    // once the channel operators are dropped, only the upstream flow remains
    override fun dropChannelOperators(): Flow<T> = flow

    // the upstream elements are forwarded as they are
    override suspend fun flowCollect(collector: FlowCollector<T>) {
        flow.collect(collector)
    }
}
197 
// Adapts this collector so that each of its emits runs in emitContext (via UndispatchedContextCollector).
// Per the original design note: if the underlying collector accepted concurrent emits, the result does too.
// todo: we might need to generalize this pattern for "thread-safe" operators that can fuse with channels
private fun <T> FlowCollector<T>.withUndispatchedContextCollector(emitContext: CoroutineContext): FlowCollector<T> =
    if (this is SendingCollector<*> || this is NopCollector) {
        // SendingCollector and NopCollector do not care about the context, so no wrapper is needed
        this
    } else {
        UndispatchedContextCollector(this, emitContext)
    }
206 
// Collector that forwards every value to [downstream], running each emit in emitContext
// through withContextUndispatched
private class UndispatchedContextCollector<T>(
    downstream: FlowCollector<T>,
    private val emitContext: CoroutineContext
) : FlowCollector<T> {
    // computed once at construction so that each emit can reuse it
    private val countOrElement = threadContextElements(emitContext)

    // the suspend lambda is allocated once per collector instead of on every emit
    private val emitRef: suspend (T) -> Unit = { element -> downstream.emit(element) }

    override suspend fun emit(value: T) {
        withContextUndispatched(emitContext, value, countOrElement, emitRef)
    }
}
217 
// Computes block(value) with newContext as the coroutine context, without dispatching:
// - block is started in the current thread inside withCoroutineContext(newContext, countOrElement);
// - its result is returned directly if it completes without suspending;
// - otherwise the caller's continuation is later resumed through a StackFrameContinuation.
internal suspend fun <T, V> withContextUndispatched(
    newContext: CoroutineContext,
    value: V,
    countOrElement: Any = threadContextElements(newContext), // can be precomputed for speed
    block: suspend (V) -> T
): T = suspendCoroutineUninterceptedOrReturn { callerCont ->
    val completion = StackFrameContinuation(callerCont, newContext)
    withCoroutineContext(newContext, countOrElement) {
        block.startCoroutineUninterceptedOrReturn(value, completion)
    }
}
230 
// Continuation that forwards its result to uCont while reporting [context] as its own context.
// It exposes uCont (when uCont is a CoroutineStackFrame) as its caller frame, so the chain of
// coroutine stack frames remains walkable.
private class StackFrameContinuation<T>(
    private val uCont: Continuation<T>,
    override val context: CoroutineContext
) : Continuation<T>, CoroutineStackFrame {

    override val callerFrame: CoroutineStackFrame?
        get() = uCont as? CoroutineStackFrame

    // completion is handed straight to the wrapped continuation
    override fun resumeWith(result: Result<T>) = uCont.resumeWith(result)

    // this frame has no stack trace element of its own
    override fun getStackTraceElement(): StackTraceElement? = null
}
245