• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 @file:Suppress("NOTHING_TO_INLINE")
18 
19 package com.android.systemui.kairos.internal
20 
21 import com.android.systemui.kairos.internal.store.MapHolder
22 import com.android.systemui.kairos.internal.store.MapK
23 import com.android.systemui.kairos.internal.store.MutableMapK
24 import com.android.systemui.kairos.internal.store.asMapHolder
25 import com.android.systemui.kairos.internal.util.hashString
26 import com.android.systemui.kairos.internal.util.logDuration
27 
/** Output of a mux node: a map from branch key to the [PullNode] supplying that branch's value. */
internal typealias MuxResult<W, K, V> = MapK<W, K, PullNode<V>>

/**
 * Base class for muxing nodes, which have a (potentially dynamic) collection of upstream nodes.
 *
 * Each upstream is attached through a [BranchNode] identified by a key of type [K] and tracked in
 * [switchedIn]. When a branch is scheduled it records its upstream [PullNode] in [upstreamData]
 * and schedules this node. Scheduling depth bookkeeping is delegated to [depthTracker].
 */
internal sealed class MuxNode<W, K, V>(
    val lifecycle: MuxLifecycle<W, K, V>,
    protected val storeFactory: MutableMapK.Factory<W, K>,
) : PushNode<MuxResult<W, K, V>> {

    /** Upstream pull nodes recorded per key; written by [BranchNode.schedule]. */
    lateinit var upstreamData: MutableMapK<W, K, PullNode<V>>

    /** Input branches currently connected to this node, per key. */
    lateinit var switchedIn: MutableMapK<W, K, BranchNode>

    @Volatile var markedForCompaction = false
    @Volatile var markedForEvaluation = false

    /** Schedulables that depend on this node's output. */
    val downstreamSet: DownstreamSet = DownstreamSet()

    // TODO: inline DepthTracker? would need to be added to PushNode signature
    final override val depthTracker = DepthTracker()

    /** Per-transaction result cache; its [epoch] is compared against the evaluating transaction. */
    val transactionCache = TransactionCache<MuxResult<W, K, V>>()

    /** Epoch recorded by [transactionCache]. */
    val epoch
        get() = transactionCache.epoch

    /** True when [transactionCache]'s epoch equals the epoch of [evalScope]. */
    inline fun hasCurrentValueLocked(evalScope: EvalScope): Boolean = epoch == evalScope.epoch

    override fun hasCurrentValue(logIndent: Int, evalScope: EvalScope): Boolean =
        hasCurrentValueLocked(evalScope)

    final override fun addDownstream(downstream: Schedulable) {
        addDownstreamLocked(downstream)
    }

    /**
     * Adds a downstream schedulable to this mux node, such that when this mux node emits a value,
     * it will be scheduled for evaluation within this same transaction.
     *
     * This only inserts into [downstreamSet] and performs no synchronization of its own; callers
     * are responsible for any exclusion the surrounding network requires.
     */
    fun addDownstreamLocked(downstream: Schedulable) {
        downstreamSet.add(downstream)
    }

    final override fun removeDownstream(downstream: Schedulable) {
        // TODO: return boolean?
        downstreamSet.remove(downstream)
    }

    /** Removes [downstream] and, if no downstream remains, deactivates this node immediately. */
    final override fun removeDownstreamAndDeactivateIfNeeded(downstream: Schedulable) {
        removeDownstream(downstream)
        deactivateIfNeeded()
    }

    /** Deactivates this node immediately if it has no remaining downstream. */
    final override fun deactivateIfNeeded() {
        if (downstreamSet.isEmpty()) {
            doDeactivate()
        }
    }

    /** visit this node from the scheduler (push eval) */
    abstract fun visit(logIndent: Int, evalScope: EvalScope)

    /** perform deactivation logic, propagating to all upstream nodes. */
    protected abstract fun doDeactivate()

    /** Asks [evalScope] to deactivate this node later if it has no remaining downstream. */
    final override fun scheduleDeactivationIfNeeded(evalScope: EvalScope) {
        if (downstreamSet.isEmpty()) {
            evalScope.scheduleDeactivation(this)
        }
    }

    /** A direct upstream changed depth from [oldDepth] to [newDepth]. */
    fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) {
        if (depthTracker.addDirectUpstream(oldDepth, newDepth)) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /** An upstream that was indirect (via [oldIndirectRoots]) has become direct at [newDepth]. */
    fun moveIndirectUpstreamToDirect(
        scheduler: Scheduler,
        oldIndirectDepth: Int,
        oldIndirectRoots: Set<MuxDeferredNode<*, *, *>>,
        newDepth: Int,
    ) {
        // Non-short-circuiting `or`: every tracker update must run, in this order.
        if (
            depthTracker.addDirectUpstream(oldDepth = null, newDepth) or
                depthTracker.removeIndirectUpstream(depth = oldIndirectDepth) or
                depthTracker.updateIndirectRoots(removals = oldIndirectRoots)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /** An indirect upstream changed depth and/or its set of indirect roots. */
    fun adjustIndirectUpstream(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        removals: Set<MuxDeferredNode<*, *, *>>,
        additions: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // Non-short-circuiting `or`: both tracker updates must run.
        if (
            depthTracker.addIndirectUpstream(oldDepth, newDepth) or
                depthTracker.updateIndirectRoots(
                    additions,
                    removals,
                    butNot = this as? MuxDeferredNode<*, *, *>,
                )
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /** A direct upstream at [oldDepth] has become indirect at [newDepth] via [newIndirectSet]. */
    fun moveDirectUpstreamToIndirect(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        newIndirectSet: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // Non-short-circuiting `or`: every tracker update must run, in this order.
        if (
            depthTracker.addIndirectUpstream(oldDepth = null, newDepth) or
                depthTracker.removeDirectUpstream(oldDepth) or
                depthTracker.updateIndirectRoots(
                    additions = newIndirectSet,
                    butNot = this as? MuxDeferredNode<*, *, *>,
                )
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /** Drops the branch for [key] and its direct upstream at [depth]. */
    fun removeDirectUpstream(scheduler: Scheduler, depth: Int, key: K) {
        switchedIn.remove(key)
        if (depthTracker.removeDirectUpstream(depth)) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /** Drops the branch for [key] and its indirect upstream at [oldDepth] via [indirectSet]. */
    fun removeIndirectUpstream(
        scheduler: Scheduler,
        oldDepth: Int,
        indirectSet: Set<MuxDeferredNode<*, *, *>>,
        key: K,
    ) {
        switchedIn.remove(key)
        // Non-short-circuiting `or`: both tracker updates must run.
        if (
            depthTracker.removeIndirectUpstream(oldDepth) or
                depthTracker.updateIndirectRoots(removals = indirectSet)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /** Applies pending depth changes, if any, propagating them to [downstreamSet]. */
    fun visitCompact(scheduler: Scheduler) {
        if (depthTracker.isDirty()) {
            depthTracker.applyChanges(scheduler, downstreamSet, this)
        }
    }

    /** Schedules this node with the scheduler of [evalScope] at its tracked depth. */
    fun schedule(evalScope: EvalScope) {
        // TODO: Potential optimization
        //  Detect if this node is guaranteed to have a single upstream within this transaction,
        //  then bypass scheduling it. Instead immediately schedule its downstream and treat this
        //  MuxNode as a Pull (effectively making it a mapCheap).
        depthTracker.schedule(evalScope.scheduler, this)
    }

    /** An input branch of a mux node, associated with a key. */
    inner class BranchNode(val key: K) : SchedulableNode {

        val schedulable = Schedulable.N(this)

        /** Connection to this branch's upstream; assigned once the upstream has been activated. */
        lateinit var upstream: NodeConnection<V>

        /** Records this branch's upstream pull node under [key] and schedules the owning mux. */
        override fun schedule(logIndent: Int, evalScope: EvalScope) {
            logDuration(logIndent, "MuxBranchNode.schedule") {
                if (this@MuxNode is MuxPromptNode && this@MuxNode.name != null) {
                    logLn("[${this@MuxNode}] scheduling $key")
                }
                upstreamData[key] = upstream.directUpstream
                this@MuxNode.schedule(evalScope)
            }
        }

        override fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) {
            this@MuxNode.adjustDirectUpstream(scheduler, oldDepth, newDepth)
        }

        override fun moveIndirectUpstreamToDirect(
            scheduler: Scheduler,
            oldIndirectDepth: Int,
            oldIndirectSet: Set<MuxDeferredNode<*, *, *>>,
            newDirectDepth: Int,
        ) {
            this@MuxNode.moveIndirectUpstreamToDirect(
                scheduler,
                oldIndirectDepth,
                oldIndirectSet,
                newDirectDepth,
            )
        }

        override fun adjustIndirectUpstream(
            scheduler: Scheduler,
            oldDepth: Int,
            newDepth: Int,
            removals: Set<MuxDeferredNode<*, *, *>>,
            additions: Set<MuxDeferredNode<*, *, *>>,
        ) {
            this@MuxNode.adjustIndirectUpstream(scheduler, oldDepth, newDepth, removals, additions)
        }

        override fun moveDirectUpstreamToIndirect(
            scheduler: Scheduler,
            oldDirectDepth: Int,
            newIndirectDepth: Int,
            newIndirectSet: Set<MuxDeferredNode<*, *, *>>,
        ) {
            this@MuxNode.moveDirectUpstreamToIndirect(
                scheduler,
                oldDirectDepth,
                newIndirectDepth,
                newIndirectSet,
            )
        }

        override fun removeDirectUpstream(scheduler: Scheduler, depth: Int) {
            // Delegate to the owning mux, which also drops this branch from switchedIn.
            this@MuxNode.removeDirectUpstream(scheduler, depth, key)
        }

        override fun removeIndirectUpstream(
            scheduler: Scheduler,
            depth: Int,
            indirectSet: Set<MuxDeferredNode<*, *, *>>,
        ) {
            // Delegate to the owning mux, which also drops this branch from switchedIn.
            this@MuxNode.removeIndirectUpstream(scheduler, depth, indirectSet, key)
        }

        override fun toString(): String = "MuxBranchNode(key=$key, mux=${this@MuxNode})"
    }
}

internal typealias BranchNode<W, K, V> = MuxNode<W, K, V>.BranchNode
272 
/**
 * Mutable holder for the [MuxLifecycleState] of a mux. It creates the underlying [MuxNode] the
 * first time [activate] is called while the state is [MuxLifecycleState.Inactive].
 */
internal class MuxLifecycle<W, K, V>(var lifecycleState: MuxLifecycleState<W, K, V>) :
    EventsImpl<MuxResult<W, K, V>> {

    override fun toString(): String = "MuxLifecycle[$hashString][$lifecycleState]"

    /**
     * Connects [downstream] to this mux's node.
     *
     * - Dead: returns `null`.
     * - Active: registers [downstream] on the existing node. `needsEval` reports whether that node
     *   already holds a value for the epoch of [evalScope].
     * - Inactive: asks the activator for a node.
     *   - If it returns `null`, the state becomes Dead and `null` is returned.
     *   - Otherwise the state becomes Active, the optional post-activation callback runs, and then
     *     [downstream] is registered; `needsEval` is `false`.
     */
    override fun activate(
        evalScope: EvalScope,
        downstream: Schedulable,
    ): ActivationResult<MuxResult<W, K, V>>? {
        return when (val current = lifecycleState) {
            is MuxLifecycleState.Dead -> null
            is MuxLifecycleState.Active -> {
                val liveNode = current.node
                liveNode.addDownstreamLocked(downstream)
                ActivationResult(
                    connection = NodeConnection(liveNode, liveNode),
                    needsEval = liveNode.hasCurrentValueLocked(evalScope),
                )
            }
            is MuxLifecycleState.Inactive -> {
                val activation = current.spec.activate(evalScope, this)
                if (activation == null) {
                    lifecycleState = MuxLifecycleState.Dead
                    null
                } else {
                    val (freshNode, postActivate) = activation
                    // Record the new state before running the callback or attaching downstream.
                    lifecycleState = MuxLifecycleState.Active(freshNode)
                    postActivate?.invoke()
                    freshNode.addDownstreamLocked(downstream)
                    ActivationResult(
                        connection = NodeConnection(freshNode, freshNode),
                        needsEval = false,
                    )
                }
            }
        }
    }
}
313 
/** States of a [MuxLifecycle]: not yet activated, backed by a live node, or dead. */
internal sealed interface MuxLifecycleState<out W, out K, out V> {

    /** Activation produced no node; [MuxLifecycle.activate] returns `null` in this state. */
    data object Dead : MuxLifecycleState<Nothing, Nothing, Nothing>

    /** Not yet activated; [spec] creates the node on first activation. */
    class Inactive<W, K, V>(val spec: MuxActivator<W, K, V>) : MuxLifecycleState<W, K, V> {
        override fun toString(): String = "Inactive"
    }

    /** Activated and backed by [node]. */
    class Active<W, K, V>(val node: MuxNode<W, K, V>) : MuxLifecycleState<W, K, V> {
        override fun toString(): String = "Active(node=$node)"
    }
}
325 
/**
 * Creates the [MuxNode] backing a [MuxLifecycle] when it is first activated.
 *
 * Declared as a `fun interface` (it has a single abstract method) so implementations can also be
 * supplied as lambdas. Existing `object : MuxActivator<...>` implementations are unaffected.
 */
internal fun interface MuxActivator<W, K, V> {
    /**
     * Builds the node for [lifecycle].
     *
     * @return `null` if no node could be created; [MuxLifecycle] then transitions to Dead.
     *   Otherwise returns the node and an optional callback. [MuxLifecycle] invokes the callback
     *   after recording the node as Active and before registering the first downstream.
     */
    fun activate(
        evalScope: EvalScope,
        lifecycle: MuxLifecycle<W, K, V>,
    ): Pair<MuxNode<W, K, V>, (() -> Unit)?>?
}
332 
/** Creates a [MuxLifecycle] that starts Inactive and builds its node via [onSubscribe] on first activation. */
internal inline fun <W, K, V> MuxLifecycle(
    onSubscribe: MuxActivator<W, K, V>
): EventsImpl<MuxResult<W, K, V>> {
    val initialState = MuxLifecycleState.Inactive(onSubscribe)
    return MuxLifecycle(initialState)
}
336 
/**
 * Maps each mux result into a plain `Map<K, V>`, obtaining each entry's value by calling
 * `getPushEvent` on that key's [PullNode].
 */
internal fun <K, V> EventsImpl<MuxResult<MapHolder.W, K, V>>.awaitValues(): EventsImpl<Map<K, V>> =
    mapImpl({ this@awaitValues }) { results, logIndent ->
        val pullNodesByKey = results.asMapHolder().unwrapped
        pullNodesByKey.mapValues { (_, pullNode) -> pullNode.getPushEvent(logIndent, this) }
    }
341 
342 // activation logic
343 
/**
 * Activates every upstream returned by [getStorage] and populates [MuxNode.switchedIn] and
 * [MuxNode.upstreamData].
 *
 * - One [BranchNode] is created per entry and activated against its events.
 * - Entries whose activation returns `null` are skipped. They still count toward the capacity
 *   passed to [storeFactory].
 * - A branch's upstream pull node is recorded in [MuxNode.upstreamData] only when activation
 *   reports that it needs evaluation.
 */
internal fun <W, K, V> MuxNode<W, K, V>.initializeUpstream(
    evalScope: EvalScope,
    getStorage: EvalScope.() -> Iterable<Map.Entry<K, EventsImpl<V>>>,
    storeFactory: MutableMapK.Factory<W, K>,
) {
    // One slot per storage entry; null marks an upstream that failed to activate.
    val activations =
        evalScope.getStorage().map { (key, events) ->
            val branch = BranchNode(key)
            events.activate(evalScope, branch.schedulable)?.let { (connection, needsEval) ->
                branch.upstream = connection
                val pending = if (needsEval) connection.directUpstream else null
                Triple(key, branch, pending)
            }
        }
    val capacity = activations.size
    switchedIn = storeFactory.create(capacity)
    upstreamData = storeFactory.create(capacity)
    for (activation in activations) {
        val (key, branch, pending) = activation ?: continue
        switchedIn[key] = branch
        if (pending != null) {
            upstreamData[key] = pending
        }
    }
}
373 
/**
 * Seeds this node's [MuxNode.depthTracker] from the depth snapshots of every switched-in branch's
 * upstream connection.
 *
 * - Direct upstreams contribute their direct depth.
 * - Indirect upstreams contribute their indirect depth and their indirect roots.
 */
internal fun <W, K, V> MuxNode<W, K, V>.initializeDepth() {
    switchedIn.forEach { (_, branchNode) ->
        val connection = branchNode.upstream
        if (!connection.depthTracker.snapshotIsDirect) {
            depthTracker.addIndirectUpstream(
                oldDepth = null,
                newDepth = connection.depthTracker.snapshotIndirectDepth,
            )
            depthTracker.updateIndirectRoots(
                additions = connection.depthTracker.snapshotIndirectRoots,
                butNot = null,
            )
        } else {
            depthTracker.addDirectUpstream(
                oldDepth = null,
                newDepth = connection.depthTracker.snapshotDirectDepth,
            )
        }
    }
}
394