• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
<lambda>null2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.systemui.kairos.internal
18 
19 import com.android.systemui.kairos.internal.store.MapK
20 import com.android.systemui.kairos.internal.store.MutableArrayMapK
21 import com.android.systemui.kairos.internal.store.MutableMapK
22 import com.android.systemui.kairos.internal.store.SingletonMapK
23 import com.android.systemui.kairos.internal.store.StoreEntry
24 import com.android.systemui.kairos.internal.store.asArrayHolder
25 import com.android.systemui.kairos.internal.store.asSingle
26 import com.android.systemui.kairos.internal.store.singleOf
27 import com.android.systemui.kairos.internal.util.hashString
28 import com.android.systemui.kairos.internal.util.logDuration
29 import com.android.systemui.kairos.internal.util.logLn
30 import com.android.systemui.kairos.util.Maybe
31 import com.android.systemui.kairos.util.Maybe.Absent
32 import com.android.systemui.kairos.util.Maybe.Present
33 import com.android.systemui.kairos.util.These
34 import com.android.systemui.kairos.util.flatMap
35 import com.android.systemui.kairos.util.getMaybe
36 import com.android.systemui.kairos.util.maybeFirst
37 import com.android.systemui.kairos.util.maybeSecond
38 import com.android.systemui.kairos.util.merge
39 import com.android.systemui.kairos.util.orError
40 import com.android.systemui.kairos.util.these
41 
/**
 * A [MuxNode] whose set of switched-in upstream branches is changed by a "patches" upstream, with
 * patch application deferred to the MOVE phase of a transaction (see [performMove]).
 *
 * Emissions of the currently switched-in branches are accumulated in `upstreamData` and forwarded
 * downstream by [visit]. Emissions of [patches] are captured into [patchData] by [scheduleMover] and
 * applied later by [performMove].
 *
 * @param name optional debug name; when non-null, additional trace logging is emitted.
 * @param lifecycle the lifecycle object that owns this node.
 * @param spec the activator that created this node; it is stored back into the lifecycle's
 *   inactive state when this node is deactivated.
 * @param factory factory for the per-key storage used by the underlying [MuxNode].
 */
internal class MuxDeferredNode<W, K, V>(
    val name: String?,
    lifecycle: MuxLifecycle<W, K, V>,
    val spec: MuxActivator<W, K, V>,
    factory: MutableMapK.Factory<W, K>,
) : MuxNode<W, K, V>(lifecycle, factory) {

    /** Handle used to register this node as a downstream of its [patches] connection. */
    val schedulable = Schedulable.M(this)

    /** Connection to the patches upstream; `null` until connected, and again after removal. */
    var patches: NodeConnection<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>? = null

    /**
     * Patch captured during the current transaction, pending application in [performMove]. Each
     * entry maps a key to [Present] (add or replace that branch) or [Absent] (remove it).
     */
    var patchData: Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>? = null

    /**
     * Forwards the data accumulated from switched-in branches during this transaction.
     *
     * Snapshots and clears `upstreamData`, applies pending depth changes to downstream nodes if the
     * depth tracker is dirty, and, if there was any data, caches it for [getPushEvent] and schedules
     * the downstream nodes. If `scheduleAll` returns false, this node is scheduled for deactivation.
     */
    override fun visit(logIndent: Int, evalScope: EvalScope) {
        check(epoch < evalScope.epoch) { "node unexpectedly visited multiple times in transaction" }
        logDuration(logIndent, "MuxDeferred[$name].visit") {
            val scheduleDownstream: Boolean
            val result: MapK<W, K, PullNode<V>>
            logDuration("copying upstream data", false) {
                scheduleDownstream = upstreamData.isNotEmpty()
                result = upstreamData.readOnlyCopy()
                upstreamData.clear()
            }
            if (name != null) {
                logLn("[${this@MuxDeferredNode}] result = $result")
            }
            val compactDownstream = depthTracker.isDirty()
            if (scheduleDownstream || compactDownstream) {
                if (compactDownstream) {
                    logDuration("compactDownstream", false) {
                        depthTracker.applyChanges(
                            evalScope.scheduler,
                            downstreamSet,
                            muxNode = this@MuxDeferredNode,
                        )
                    }
                }
                if (scheduleDownstream) {
                    logDuration("scheduleDownstream") {
                        if (name != null) {
                            logLn("[${this@MuxDeferredNode}] scheduling")
                        }
                        transactionCache.put(evalScope, result)
                        if (!scheduleAll(currentLogIndent, downstreamSet, evalScope)) {
                            evalScope.scheduleDeactivation(this@MuxDeferredNode)
                        }
                    }
                }
            }
        }
    }

    /** Returns the result cached by [visit] for the current transaction. */
    override fun getPushEvent(logIndent: Int, evalScope: EvalScope): MuxResult<W, K, V> =
        logDuration(logIndent, "MuxDeferred.getPushEvent") {
            transactionCache.getCurrentValue(evalScope).also {
                if (name != null) {
                    logLn("[${this@MuxDeferredNode}] getPushEvent = $it")
                }
            }
        }

    /** Propagates any pending depth changes via the eval scope's compactor. */
    private fun compactIfNeeded(evalScope: EvalScope) {
        depthTracker.propagateChanges(evalScope.compactor, this)
    }

    /**
     * Marks the lifecycle inactive and disconnects this node from every upstream.
     *
     * This is a no-op unless the lifecycle is currently active. Each branch connection and the
     * patches connection are removed, and each is deactivated if needed.
     */
    override fun doDeactivate() {
        // Update lifecycle
        if (lifecycle.lifecycleState !is MuxLifecycleState.Active) return@doDeactivate
        lifecycle.lifecycleState = MuxLifecycleState.Inactive(spec)
        // Process branch nodes
        switchedIn.forEach { (_, branchNode) ->
            branchNode.upstream.removeDownstreamAndDeactivateIfNeeded(branchNode.schedulable)
        }
        // Process patch node
        patches?.removeDownstreamAndDeactivateIfNeeded(schedulable)
    }

    /**
     * Disconnects [branchNode] from its upstream connection and records that connection in
     * [severed]. Deactivation of severed connections is decided later by the caller.
     *
     * It also removes the upstream's contribution from this node's [depthTracker]. For a direct
     * upstream that is its direct depth. For an indirect upstream it is both its indirect depth and
     * its indirect roots.
     */
    private fun severBranch(branchNode: BranchNode, severed: MutableList<NodeConnection<*>>) {
        val conn = branchNode.upstream
        severed.add(conn)
        conn.removeDownstream(downstream = branchNode.schedulable)
        if (conn.depthTracker.snapshotIsDirect) {
            depthTracker.removeDirectUpstream(conn.depthTracker.snapshotDirectDepth)
        } else {
            depthTracker.removeIndirectUpstream(conn.depthTracker.snapshotIndirectDepth)
            depthTracker.updateIndirectRoots(removals = conn.depthTracker.snapshotIndirectRoots)
        }
    }

    // MOVE phase
    //  - concurrent moves may be occurring, but no more evals. all depth recalculations are
    //    deferred to the end of this phase.
    /**
     * Applies the pending [patchData], if any, to the set of switched-in branches.
     *
     * All removals are processed first. Then each addition replaces any existing branch for its key
     * and activates a new branch on the new upstream. Severed connections are considered for
     * deactivation only after all additions have been processed. Finally, depth changes are
     * propagated.
     */
    fun performMove(logIndent: Int, evalScope: EvalScope) {
        if (name != null) {
            logLn(logIndent, "[${this@MuxDeferredNode}] performMove (patchData = $patchData)")
        }

        val patch = patchData ?: return
        patchData = null

        // TODO: this logic is very similar to what's in MuxPrompt, maybe turn into an inline fun?

        // We have a patch, process additions/updates and removals
        val adds = mutableListOf<Pair<K, EventsImpl<V>>>()
        val removes = mutableListOf<K>()
        patch.forEach { (k, newUpstream) ->
            when (newUpstream) {
                is Present -> adds.add(k to newUpstream.value)
                Absent -> removes.add(k)
            }
        }

        // Connections severed by this patch. They are only considered for deactivation after all
        // additions have been processed (see the loop at the end of this function).
        val severed = mutableListOf<NodeConnection<*>>()

        // remove and sever
        removes.forEach { k -> switchedIn.remove(k)?.let { severBranch(it, severed) } }

        // add or replace
        adds.forEach { (k, newUpstream: EventsImpl<V>) ->
            // remove old and sever, if present
            switchedIn.remove(k)?.let { severBranch(it, severed) }

            // add new
            val newBranch = BranchNode(k)
            newUpstream.activate(evalScope, newBranch.schedulable)?.let { (conn, _) ->
                newBranch.upstream = conn
                switchedIn[k] = newBranch
                val branchDepthTracker = newBranch.upstream.depthTracker
                if (branchDepthTracker.snapshotIsDirect) {
                    depthTracker.addDirectUpstream(
                        oldDepth = null,
                        newDepth = branchDepthTracker.snapshotDirectDepth,
                    )
                } else {
                    depthTracker.addIndirectUpstream(
                        oldDepth = null,
                        newDepth = branchDepthTracker.snapshotIndirectDepth,
                    )
                    depthTracker.updateIndirectRoots(
                        additions = branchDepthTracker.snapshotIndirectRoots,
                        butNot = this@MuxDeferredNode,
                    )
                }
            }
        }

        for (severedNode in severed) {
            severedNode.scheduleDeactivationIfNeeded(evalScope)
        }

        compactIfNeeded(evalScope)
    }

    /**
     * Drops a direct connection to the patches upstream.
     *
     * A direct patches connection is represented by this node being its own indirect root. This
     * clears that flag, removes an indirect upstream contribution at depth 0, schedules a depth
     * update if either changed, and clears [patches].
     */
    fun removeDirectPatchNode(scheduler: Scheduler) {
        // `or` (not `||`) so that both updates are always applied.
        if (
            depthTracker.removeIndirectUpstream(depth = 0) or depthTracker.setIsIndirectRoot(false)
        ) {
            depthTracker.schedule(scheduler, this)
        }
        patches = null
    }

    /**
     * Drops an indirect connection to the patches upstream.
     *
     * It removes the indirect roots that connection contributed ([indirectSet]) and its indirect
     * depth ([depth]). If anything changed, it schedules a depth update. It then clears [patches].
     */
    fun removeIndirectPatchNode(
        scheduler: Scheduler,
        depth: Int,
        indirectSet: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        if (
            depthTracker.updateIndirectRoots(removals = indirectSet) or
                depthTracker.removeIndirectUpstream(depth)
        ) {
            depthTracker.schedule(scheduler, this)
        }
        patches = null
    }

    /**
     * Handles the patches upstream changing from an indirect to a direct connection.
     *
     * It removes the old indirect contribution ([oldIndirectSet], [oldIndirectDepth]) and marks this
     * node as an indirect root. If anything changed, it schedules a depth update.
     */
    fun moveIndirectPatchNodeToDirect(
        scheduler: Scheduler,
        oldIndirectDepth: Int,
        oldIndirectSet: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // directly connected patches are stored as an indirect singleton set of the patchNode
        if (
            depthTracker.updateIndirectRoots(removals = oldIndirectSet) or
                depthTracker.removeIndirectUpstream(oldIndirectDepth) or
                depthTracker.setIsIndirectRoot(true)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Handles the patches upstream changing from a direct to an indirect connection.
     *
     * It clears this node's indirect-root flag and adds [newIndirectSet] (excluding this node) as
     * indirect roots. It also adds [newIndirectDepth] as an indirect upstream depth. If anything
     * changed, it schedules a depth update.
     */
    fun moveDirectPatchNodeToIndirect(
        scheduler: Scheduler,
        newIndirectDepth: Int,
        newIndirectSet: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        if (
            depthTracker.setIsIndirectRoot(false) or
                depthTracker.updateIndirectRoots(additions = newIndirectSet, butNot = this) or
                depthTracker.addIndirectUpstream(oldDepth = null, newDepth = newIndirectDepth)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Updates the contribution of an indirectly connected patches upstream.
     *
     * It applies the indirect-root [additions] and [removals], excluding this node. It also moves
     * that upstream's indirect depth from [oldDepth] to [newDepth]. If anything changed, it
     * schedules a depth update.
     */
    fun adjustIndirectPatchNode(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        removals: Set<MuxDeferredNode<*, *, *>>,
        additions: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        if (
            depthTracker.updateIndirectRoots(
                additions = additions,
                removals = removals,
                butNot = this,
            ) or depthTracker.addIndirectUpstream(oldDepth = oldDepth, newDepth = newDepth)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Captures the current emission of [patches] into [patchData] and schedules this node for the
     * MOVE phase.
     *
     * @throws IllegalStateException if [patches] is not connected.
     */
    fun scheduleMover(logIndent: Int, evalScope: EvalScope) {
        logDuration(logIndent, "MuxDeferred.scheduleMover") {
            patchData =
                checkNotNull(patches) { "mux mover scheduled with unset patches upstream node" }
                    .getPushEvent(currentLogIndent, evalScope)
            evalScope.scheduleMuxMover(this@MuxDeferredNode)
        }
    }

    override fun toString(): String =
        "${this::class.simpleName}@$hashString${name?.let { "[$it]" }.orEmpty()}"
}
292 
/**
 * Single-branch specialization of [switchDeferredImpl].
 *
 * The sole branch starts as the events from [getStorage]. Each inner event stream emitted by
 * [getPatches] becomes a patch that replaces that branch. The returned events emit whatever the
 * current branch emits.
 *
 * @param name optional debug name; when non-null, each extracted value is logged.
 */
internal inline fun <A> switchDeferredImplSingle(
    name: String? = null,
    crossinline getStorage: EvalScope.() -> EventsImpl<A>,
    crossinline getPatches: EvalScope.() -> EventsImpl<EventsImpl<A>>,
): EventsImpl<A> {
    // Wrap every replacement stream as a one-entry "add or replace" patch.
    val singlePatches =
        mapImpl(getPatches) { replacement, _ ->
            singleOf(Maybe.present(replacement)).asIterable()
        }
    val singleMux =
        switchDeferredImpl(
            name = name,
            getStorage = { singleOf(getStorage()).asIterable() },
            getPatches = { singlePatches },
            storeFactory = SingletonMapK.Factory(),
        )
    // Unwrap the single entry of the mux result and pull its value.
    return mapImpl({ singleMux }) { muxResult, logIndent ->
        val extracted = muxResult.asSingle().getValue(Unit).getPushEvent(logIndent, this)
        if (name != null) {
            logLn(logIndent, "[$name] extracting single mux: $extracted")
        }
        extracted
    }
}
315 
/**
 * Creates a deferred-switching mux.
 *
 * It starts with the branches produced by [getStorage]. Each patch emitted by [getPatches] is
 * applied at the end of the transaction in which it is emitted (the MOVE phase). A [Maybe.Present]
 * entry adds or replaces the branch for its key; a [Maybe.Absent] entry removes it.
 *
 * @param name optional debug name used for trace logging.
 * @param storeFactory factory for the per-key storage of the mux.
 */
internal fun <W, K, V> switchDeferredImpl(
    name: String? = null,
    getStorage: EvalScope.() -> Iterable<Map.Entry<K, EventsImpl<V>>>,
    getPatches: EvalScope.() -> EventsImpl<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>,
    storeFactory: MutableMapK.Factory<W, K>,
): EventsImpl<MuxResult<W, K, V>> {
    val activator =
        MuxDeferredActivator(
            name = name,
            getStorage = getStorage,
            storeFactory = storeFactory,
            getPatches = getPatches,
        )
    return MuxLifecycle(activator)
}
323 
/**
 * [MuxActivator] that builds a [MuxDeferredNode].
 *
 * The node starts from the initial branches produced by [getStorage] and is later connected to the
 * patches upstream produced by [getPatches].
 */
private class MuxDeferredActivator<W, K, V>(
    private val name: String?,
    private val getStorage: EvalScope.() -> Iterable<Map.Entry<K, EventsImpl<V>>>,
    private val storeFactory: MutableMapK.Factory<W, K>,
    private val getPatches: EvalScope.() -> EventsImpl<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>,
) : MuxActivator<W, K, V> {
    /**
     * Creates and initializes the mux node and returns it with a deferred setup callback.
     *
     * The callback, when invoked, connects the node to its patches upstream, fixes up the node's
     * depth if the provisional "direct patches" assumption was wrong, and schedules a mover if the
     * patches upstream already has an emission pending in this transaction.
     */
    override fun activate(
        evalScope: EvalScope,
        lifecycle: MuxLifecycle<W, K, V>,
    ): Pair<MuxNode<W, K, V>, (() -> Unit)?>? {
        // Initialize mux node and switched-in connections.
        val muxNode =
            MuxDeferredNode(name, lifecycle, this, storeFactory).apply {
                initializeUpstream(evalScope, getStorage, storeFactory)
                // Update depth based on all initial switched-in nodes.
                initializeDepth()
                // We don't have our patches connection established yet, so for now pretend we have
                // a direct connection to patches. We will update downstream nodes later if this
                // turns out to be a lie.
                depthTracker.setIsIndirectRoot(true)
                depthTracker.reset()
            }

        // Schedule for evaluation if any switched-in nodes have already emitted within
        // this transaction.
        if (muxNode.upstreamData.isNotEmpty()) {
            muxNode.schedule(evalScope)
        }

        return muxNode to
            fun() {
                // Setup patches connection; deferring allows for a recursive connection, where
                // muxNode is downstream of itself via patches.
                val (patchesConn, needsEval) =
                    getPatches(evalScope).activate(evalScope, downstream = muxNode.schedulable)
                        ?: run {
                            // Turns out we can't connect to patches, so we are not a root after
                            // all. Update our depth and, as in the indirect-patches case below,
                            // propagate the change to downstream nodes if it altered anything;
                            // they may have observed the provisional root snapshot set above.
                            if (muxNode.depthTracker.setIsIndirectRoot(false)) {
                                muxNode.depthTracker.schedule(evalScope.scheduler, muxNode)
                            }
                            // Returns from the enclosing anonymous function (`run` is inline).
                            return
                        }
                muxNode.patches = patchesConn

                if (!patchesConn.schedulerUpstream.depthTracker.snapshotIsDirect) {
                    // Turns out patches is indirect, so we are not a root. Update depth and
                    // propagate. `or` (not `||`) so that every update is applied.
                    if (
                        muxNode.depthTracker.setIsIndirectRoot(false) or
                            muxNode.depthTracker.addIndirectUpstream(
                                oldDepth = null,
                                newDepth = patchesConn.depthTracker.snapshotIndirectDepth,
                            ) or
                            muxNode.depthTracker.updateIndirectRoots(
                                additions = patchesConn.depthTracker.snapshotIndirectRoots
                            )
                    ) {
                        muxNode.depthTracker.schedule(evalScope.scheduler, muxNode)
                    }
                }
                // Schedule mover to process patch emission at the end of this transaction, if
                // needed.
                if (needsEval) {
                    muxNode.patchData = patchesConn.getPushEvent(0, evalScope)
                    evalScope.scheduleMuxMover(muxNode)
                }
            }
    }
}
391 
/**
 * Merges two event streams of the same type into a single stream.
 *
 * It builds on the two-input [These]-producing `mergeNodes` overload and collapses each [These]
 * result to one value with `merge`. When both values are present, [f] combines them, with the
 * value from [getPulse] as the first argument and the value from [getOther] as the second. The
 * resulting events are cached.
 */
internal inline fun <A> mergeNodes(
    crossinline getPulse: EvalScope.() -> EventsImpl<A>,
    crossinline getOther: EvalScope.() -> EventsImpl<A>,
    name: String? = null,
    crossinline f: EvalScope.(A, A) -> A,
): EventsImpl<A> {
    val pairwise = mergeNodes(name, getPulse, getOther)
    return mapImpl({ pairwise }) { both, _ -> both.merge { left, right -> f(left, right) } }
        .cached()
}
403 
/**
 * Lazily pairs each element with its zero-based position, as map entries keyed by that index.
 *
 * No copy is made. Every iteration of the returned [Iterable] iterates the receiver again.
 */
internal fun <T> Iterable<T>.asIterableWithIndex(): Iterable<Map.Entry<Int, T>> =
    withIndex().asSequence().map { (index, value) -> StoreEntry(index, value) }.asIterable()
406 
/**
 * Merges two event streams of possibly different types into a stream of [These].
 *
 * Values from [getPulse] populate the first side and values from [getOther] populate the second.
 * The two inputs are the branches at indices 0 and 1 of a [switchDeferredImpl] that never receives
 * patches. The resulting events are cached.
 *
 * @param name optional debug name used for trace logging.
 */
internal inline fun <A, B> mergeNodes(
    name: String? = null,
    crossinline getPulse: EvalScope.() -> EventsImpl<A>,
    crossinline getOther: EvalScope.() -> EventsImpl<B>,
): EventsImpl<These<A, B>> {
    // Index 0: getPulse wrapped as These.first; index 1: getOther wrapped as These.second.
    val indexedInputs =
        listOf(
                mapImpl(getPulse) { pulse, _ -> These.first(pulse) },
                mapImpl(getOther) { other, _ -> These.second(other) },
            )
            .asIterableWithIndex()
    val mux =
        switchDeferredImpl(
            name = name,
            getStorage = { indexedInputs },
            getPatches = { neverImpl },
            storeFactory = MutableArrayMapK.Factory(),
        )
    return mapImpl({ mux }) { muxResult, logIndent ->
            val byIndex = muxResult.asArrayHolder()
            val fromPulse =
                byIndex.getMaybe(0).flatMap { branch ->
                    branch.getPushEvent(logIndent, this).maybeFirst()
                }
            val fromOther =
                byIndex.getMaybe(1).flatMap { branch ->
                    branch.getPushEvent(logIndent, this).maybeSecond()
                }
            these(fromPulse, fromOther).orError { "unexpected missing merge result" }
        }
        .cached()
}
436 
/**
 * Merges any number of event streams, produced by [getPulses], into a stream of lists.
 *
 * Each emission is the list of values pulled from the entries of the mux result. The inputs are
 * indexed by position in array-backed storage, so the list is presumably in input order — confirm
 * against `asArrayHolder`. The resulting events are cached.
 */
internal inline fun <A> mergeNodes(
    crossinline getPulses: EvalScope.() -> Iterable<EventsImpl<A>>
): EventsImpl<List<A>> {
    val mux =
        switchDeferredImpl(
            getStorage = { getPulses().asIterableWithIndex() },
            getPatches = { neverImpl },
            storeFactory = MutableArrayMapK.Factory(),
        )
    return mapImpl({ mux }) { muxResult, logIndent ->
            muxResult.asArrayHolder().map { (_, branch) -> branch.getPushEvent(logIndent, this) }
        }
        .cached()
}
453 
/**
 * Merges any number of event streams, produced by [getPulses], keeping only one value per
 * emission.
 *
 * The kept value comes from the first entry of the mux result's values, presumably the
 * lowest-indexed input that emitted — confirm against `asArrayHolder`. The resulting events are
 * cached.
 */
internal inline fun <A> mergeNodesLeft(
    crossinline getPulses: EvalScope.() -> Iterable<EventsImpl<A>>
): EventsImpl<A> {
    val mux =
        switchDeferredImpl(
            getStorage = { getPulses().asIterableWithIndex() },
            getPatches = { neverImpl },
            storeFactory = MutableArrayMapK.Factory(),
        )
    return mapImpl({ mux }) { muxResult, logIndent ->
            val leftmost = muxResult.asArrayHolder().values.first()
            leftmost.getPushEvent(logIndent, this)
        }
        .cached()
}
470