• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.systemui.kairos.internal
18 
19 import com.android.systemui.kairos.internal.store.MutableMapK
20 import com.android.systemui.kairos.internal.store.SingletonMapK
21 import com.android.systemui.kairos.internal.store.asSingle
22 import com.android.systemui.kairos.internal.store.singleOf
23 import com.android.systemui.kairos.internal.util.LogIndent
24 import com.android.systemui.kairos.internal.util.hashString
25 import com.android.systemui.kairos.internal.util.logDuration
26 import com.android.systemui.kairos.util.Maybe
27 import com.android.systemui.kairos.util.Maybe.Absent
28 import com.android.systemui.kairos.util.Maybe.Present
29 
/**
 * [MuxNode] that applies patches to its set of switched-in upstreams during the same transaction
 * in which the patch is emitted.
 *
 * Patches arrive through [PatchNode], which stores them in [patchData] and schedules this node.
 * When visited, the node applies any pending patch, then emits the upstream results collected for
 * the transaction:
 * - applying a patch severs removed or replaced branches and activates new ones;
 * - the emitted results include those from newly switched-in branches whose results are already
 *   available.
 *
 * @param name optional debug name; when non-null, patch processing is logged and the name is
 *   appended to [toString].
 * @param spec activator stored back into the lifecycle (as `MuxLifecycleState.Inactive(spec)`)
 *   when this node is deactivated.
 */
internal class MuxPromptNode<W, K, V>(
    val name: String?,
    lifecycle: MuxLifecycle<W, K, V>,
    private val spec: MuxActivator<W, K, V>,
    factory: MutableMapK.Factory<W, K>,
) : MuxNode<W, K, V>(lifecycle, factory) {

    /** Patch received this transaction that [visit] has not applied yet, or null if none. */
    var patchData: Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>? = null

    /**
     * Connection to the patches stream.
     *
     * Null when not connected: either activation produced no connection, or the upstream was
     * removed via [removeDirectPatchNode] / [removeIndirectPatchNode].
     */
    var patches: PatchNode? = null

    /**
     * Applies the pending patch (if any), then publishes this transaction's results downstream.
     *
     * If the patch switched in an upstream whose result isn't available yet, or increased this
     * node's depth, the node reschedules itself and returns early. In that case `upstreamData` is
     * left untouched, so everything collected so far is emitted on the next visit.
     */
    override fun visit(logIndent: Int, evalScope: EvalScope) {
        check(epoch < evalScope.epoch) { "node unexpectedly visited multiple times in transaction" }
        logDuration(logIndent, "MuxPrompt.visit") {
            // Consume the pending patch up front so it is never applied twice.
            val patch: Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>? = patchData
            patchData = null

            // If there's a patch, process it.
            if (patch != null) {
                val needsReschedule = processPatch(patch, evalScope)
                // We may need to reschedule if newly-switched-in nodes have not yet been
                // visited within this transaction.
                val depthIncreased = depthTracker.dirty_depthIncreased()
                if (needsReschedule || depthIncreased) {
                    if (depthIncreased) {
                        depthTracker.schedule(evalScope.compactor, this@MuxPromptNode)
                    }
                    if (name != null) {
                        logLn(
                            "[${this@MuxPromptNode}] rescheduling (reschedule=$needsReschedule, depthIncrease=$depthIncreased)"
                        )
                    }
                    schedule(evalScope)
                    return
                }
            }
            val results = upstreamData.readOnlyCopy().also { upstreamData.clear() }

            // If we don't need to reschedule, or there wasn't a patch at all, then we proceed
            // with merging pre-switch and post-switch results
            val hasResult = results.isNotEmpty()
            val compactDownstream = depthTracker.isDirty()
            if (hasResult || compactDownstream) {
                if (compactDownstream) {
                    adjustDownstreamDepths(evalScope)
                }
                if (hasResult) {
                    transactionCache.put(evalScope, results)
                    // NOTE(review): presumably scheduleAll returns false when no downstream node
                    // remains, in which case this node is no longer needed — confirm.
                    if (!scheduleAll(currentLogIndent, downstreamSet, evalScope)) {
                        evalScope.scheduleDeactivation(this@MuxPromptNode)
                    }
                }
            }
        }
    }

    /**
     * Applies [patch] to [switchedIn].
     *
     * Entries with [Present] values add or replace the branch for that key; [Absent] entries
     * remove it. Connections that are severed are passed to `scheduleDeactivationIfNeeded` once
     * the whole patch has been applied.
     *
     * Side-effect: this will populate `upstreamData` with any immediately available results.
     *
     * @return true if some newly switched-in upstream activated with `needsEval == false`.
     *   [visit] treats this as "not yet visited in this transaction" and reschedules.
     */
    private fun LogIndent.processPatch(
        patch: Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>,
        evalScope: EvalScope,
    ): Boolean {
        var needsReschedule = false
        // We have a patch, process additions/updates and removals
        val adds = mutableListOf<Pair<K, EventsImpl<V>>>()
        val removes = mutableListOf<K>()
        patch.forEach { (k, newUpstream) ->
            when (newUpstream) {
                is Present -> adds.add(k to newUpstream.value)
                Absent -> removes.add(k)
            }
        }

        val severed = mutableListOf<NodeConnection<*>>()

        // remove and sever
        removes.forEach { k ->
            switchedIn.remove(k)?.let { branchNode: BranchNode ->
                if (name != null) {
                    logLn("[${this@MuxPromptNode}] removing $k")
                }
                severBranch(branchNode, severed)
            }
        }

        // add or replace
        adds.forEach { (k, newUpstream: EventsImpl<V>) ->
            // remove old and sever, if present
            switchedIn.remove(k)?.let { oldBranch: BranchNode ->
                if (name != null) {
                    logLn("[${this@MuxPromptNode}] replacing $k")
                }
                severBranch(oldBranch, severed)
            }

            // add new
            val newBranch = BranchNode(k)
            newUpstream.activate(evalScope, newBranch.schedulable)?.let { (conn, needsEval) ->
                newBranch.upstream = conn
                if (name != null) {
                    logLn("[${this@MuxPromptNode}] switching in $k")
                }
                switchedIn[k] = newBranch
                if (needsEval) {
                    upstreamData[k] = newBranch.upstream.directUpstream
                } else {
                    needsReschedule = true
                }
                val branchDepthTracker = newBranch.upstream.depthTracker
                if (branchDepthTracker.snapshotIsDirect) {
                    depthTracker.addDirectUpstream(
                        oldDepth = null,
                        newDepth = branchDepthTracker.snapshotDirectDepth,
                    )
                } else {
                    depthTracker.addIndirectUpstream(
                        oldDepth = null,
                        newDepth = branchDepthTracker.snapshotIndirectDepth,
                    )
                    depthTracker.updateIndirectRoots(
                        additions = branchDepthTracker.snapshotIndirectRoots,
                        butNot = null,
                    )
                }
            }
        }

        // Deactivation is deferred until every removal/replacement has been processed.
        for (severedNode in severed) {
            severedNode.scheduleDeactivationIfNeeded(evalScope)
        }

        return needsReschedule
    }

    /**
     * Disconnects [branch] from its upstream connection and records that connection in
     * [severed], so the caller can schedule its deactivation later. Also removes the
     * connection's depth contribution (direct, or indirect depth plus indirect roots) from this
     * node's depth tracker.
     *
     * Shared by the removal and replacement paths of [processPatch] so the two cannot diverge.
     */
    private fun severBranch(branch: BranchNode, severed: MutableList<NodeConnection<*>>) {
        val conn: NodeConnection<V> = branch.upstream
        severed.add(conn)
        conn.removeDownstream(downstream = branch.schedulable)
        if (conn.depthTracker.snapshotIsDirect) {
            depthTracker.removeDirectUpstream(conn.depthTracker.snapshotDirectDepth)
        } else {
            depthTracker.removeIndirectUpstream(conn.depthTracker.snapshotIndirectDepth)
            depthTracker.updateIndirectRoots(removals = conn.depthTracker.snapshotIndirectRoots)
        }
    }

    /** Propagates pending depth-tracker changes to downstream nodes. */
    private fun adjustDownstreamDepths(evalScope: EvalScope) {
        if (depthTracker.dirty_depthIncreased()) {
            // schedule downstream nodes on the compaction scheduler; this scheduler is drained at
            // the end of this eval depth, so that all depth increases are applied before we advance
            // the eval step
            depthTracker.schedule(evalScope.compactor, node = this@MuxPromptNode)
        } else if (depthTracker.isDirty()) {
            // schedule downstream nodes on the eval scheduler; this is more efficient and is only
            // safe if the depth hasn't increased
            depthTracker.applyChanges(
                evalScope.scheduler,
                downstreamSet,
                muxNode = this@MuxPromptNode,
            )
        }
    }

    /** Returns the value that [visit] stored in `transactionCache` for the current transaction. */
    override fun getPushEvent(logIndent: Int, evalScope: EvalScope): MuxResult<W, K, V> =
        logDuration(logIndent, "MuxPrompt.getPushEvent") {
            transactionCache.getCurrentValue(evalScope)
        }

    /**
     * Marks the lifecycle inactive (restoring [spec]) and detaches this node from every
     * switched-in branch and from the patches connection. Each of those upstreams may then
     * deactivate itself. Does nothing unless the lifecycle is currently active.
     */
    override fun doDeactivate() {
        // Update lifecycle
        if (lifecycle.lifecycleState !is MuxLifecycleState.Active) return
        lifecycle.lifecycleState = MuxLifecycleState.Inactive(spec)
        // Process branch nodes
        switchedIn.forEach { (_, branchNode) ->
            branchNode.upstream.removeDownstreamAndDeactivateIfNeeded(
                downstream = branchNode.schedulable
            )
        }
        // Process patch node
        patches?.let { patchNode ->
            patchNode.upstream.removeDownstreamAndDeactivateIfNeeded(
                downstream = patchNode.schedulable
            )
        }
    }

    /**
     * Drops the patches connection, which was an indirect upstream at [oldDepth] with roots
     * [indirectSet]. Schedules a depth update on [scheduler] if the depth tracker reports a
     * change.
     */
    fun removeIndirectPatchNode(
        scheduler: Scheduler,
        oldDepth: Int,
        indirectSet: Set<MuxDeferredNode<*, *, *>>,
    ) {
        patches = null
        // Non-short-circuiting `or` so both tracker updates always run.
        if (
            depthTracker.removeIndirectUpstream(oldDepth) or
                depthTracker.updateIndirectRoots(removals = indirectSet)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Drops the patches connection, which was a direct upstream at [depth]. Schedules a depth
     * update on [scheduler] if the depth tracker reports a change.
     */
    fun removeDirectPatchNode(scheduler: Scheduler, depth: Int) {
        patches = null
        if (depthTracker.removeDirectUpstream(depth)) {
            depthTracker.schedule(scheduler, this)
        }
    }

    override fun toString(): String =
        "${this::class.simpleName}@$hashString${name?.let { "[$it]" }.orEmpty()}"

    /**
     * Downstream endpoint of the patches connection.
     *
     * When scheduled, it reads the emitted patch into [patchData] and schedules the enclosing
     * [MuxPromptNode]. Depth-adjustment callbacks are forwarded to the enclosing node; removal
     * callbacks also clear [patches].
     */
    inner class PatchNode : SchedulableNode {

        val schedulable = Schedulable.N(this)

        lateinit var upstream: NodeConnection<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>

        override fun schedule(logIndent: Int, evalScope: EvalScope) {
            logDuration(logIndent, "MuxPromptPatchNode.schedule") {
                patchData = upstream.getPushEvent(currentLogIndent, evalScope)
                this@MuxPromptNode.schedule(evalScope)
            }
        }

        override fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) {
            this@MuxPromptNode.adjustDirectUpstream(scheduler, oldDepth, newDepth)
        }

        override fun moveIndirectUpstreamToDirect(
            scheduler: Scheduler,
            oldIndirectDepth: Int,
            oldIndirectSet: Set<MuxDeferredNode<*, *, *>>,
            newDirectDepth: Int,
        ) {
            this@MuxPromptNode.moveIndirectUpstreamToDirect(
                scheduler,
                oldIndirectDepth,
                oldIndirectSet,
                newDirectDepth,
            )
        }

        override fun adjustIndirectUpstream(
            scheduler: Scheduler,
            oldDepth: Int,
            newDepth: Int,
            removals: Set<MuxDeferredNode<*, *, *>>,
            additions: Set<MuxDeferredNode<*, *, *>>,
        ) {
            this@MuxPromptNode.adjustIndirectUpstream(
                scheduler,
                oldDepth,
                newDepth,
                removals,
                additions,
            )
        }

        override fun moveDirectUpstreamToIndirect(
            scheduler: Scheduler,
            oldDirectDepth: Int,
            newIndirectDepth: Int,
            newIndirectSet: Set<MuxDeferredNode<*, *, *>>,
        ) {
            this@MuxPromptNode.moveDirectUpstreamToIndirect(
                scheduler,
                oldDirectDepth,
                newIndirectDepth,
                newIndirectSet,
            )
        }

        override fun removeDirectUpstream(scheduler: Scheduler, depth: Int) {
            this@MuxPromptNode.removeDirectPatchNode(scheduler, depth)
        }

        override fun removeIndirectUpstream(
            scheduler: Scheduler,
            depth: Int,
            indirectSet: Set<MuxDeferredNode<*, *, *>>,
        ) {
            this@MuxPromptNode.removeIndirectPatchNode(scheduler, depth, indirectSet)
        }
    }
}
318 
/**
 * Single-branch form of [switchPromptImpl].
 *
 * Starts switched in to the events produced by [getStorage]. Every emission of [getPatches]
 * replaces the switched-in events with the emitted ones. The returned events forward the push
 * event of the single (key `Unit`) branch.
 */
internal inline fun <A> switchPromptImplSingle(
    crossinline getStorage: EvalScope.() -> EventsImpl<A>,
    crossinline getPatches: EvalScope.() -> EventsImpl<EventsImpl<A>>,
): EventsImpl<A> {
    val singletonMux =
        switchPromptImpl(
            getStorage = {
                val initialEvents = getStorage()
                singleOf(initialEvents).asIterable()
            },
            getPatches = {
                mapImpl(getPatches) { replacement, _ ->
                    val upsert = Maybe.present(replacement)
                    singleOf(upsert).asIterable()
                }
            },
            storeFactory = SingletonMapK.Factory(),
        )
    return mapImpl({ singletonMux }) { muxResult, logIndent ->
        val onlyBranch = muxResult.asSingle().getValue(Unit)
        onlyBranch.getPushEvent(logIndent, this)
    }
}
337 
/**
 * Builds an events node that multiplexes a keyed set of upstream events.
 *
 * [getStorage] supplies the initial key-to-events entries. Each patch emitted by [getPatches]
 * maps a key to `Present(events)` (add or replace that branch) or `Absent` (remove it). Patches
 * are applied within the transaction that emits them.
 *
 * @param name optional debug name used for logging.
 * @param storeFactory factory for the per-key storage backing the mux node.
 */
internal fun <W, K, V> switchPromptImpl(
    name: String? = null,
    getStorage: EvalScope.() -> Iterable<Map.Entry<K, EventsImpl<V>>>,
    getPatches: EvalScope.() -> EventsImpl<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>,
    storeFactory: MutableMapK.Factory<W, K>,
): EventsImpl<MuxResult<W, K, V>> {
    val activator =
        MuxPromptActivator(
            name = name,
            getStorage = getStorage,
            storeFactory = storeFactory,
            getPatches = getPatches,
        )
    return MuxLifecycle(activator)
}
345 
/**
 * [MuxActivator] that builds a [MuxPromptNode].
 *
 * On activation it connects the node to the initial upstreams from [getStorage] and to the
 * patches stream from [getPatches], then seeds the node's depth from those connections.
 */
private class MuxPromptActivator<W, K, V>(
    private val name: String?,
    private val getStorage: EvalScope.() -> Iterable<Map.Entry<K, EventsImpl<V>>>,
    private val storeFactory: MutableMapK.Factory<W, K>,
    private val getPatches: EvalScope.() -> EventsImpl<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>,
) : MuxActivator<W, K, V> {

    /**
     * Creates and wires up a [MuxPromptNode].
     *
     * The node is scheduled right away if a switched-in upstream or the patches stream has
     * already emitted in this transaction.
     *
     * @return null when the node ended up with neither a patches connection nor any switched-in
     *   branch; otherwise the node paired with a null callback.
     */
    override fun activate(
        evalScope: EvalScope,
        lifecycle: MuxLifecycle<W, K, V>,
    ): Pair<MuxNode<W, K, V>, (() -> Unit)?>? {
        val node = MuxPromptNode(name, lifecycle, this, storeFactory)
        with(node) {
            // Connect every initially switched-in upstream.
            initializeUpstream(evalScope, getStorage, storeFactory)
            // Connect the patches stream.
            connectPatches(evalScope)
            // Seed depth: first from the switched-in branches, then from the patches connection.
            initializeDepth()
            addPatchesDepth()
            // No downstream has been notified yet, so the accumulated adjustments are discarded.
            depthTracker.reset()
        }

        val alreadyEmitted = node.patchData != null || node.upstreamData.isNotEmpty()
        if (alreadyEmitted) {
            node.schedule(evalScope)
        }

        val hasNoUpstream = node.patches == null && node.switchedIn.isEmpty()
        return if (hasNoUpstream) null else (node to null)
    }

    /**
     * Activates the patches stream with a fresh [MuxPromptNode.PatchNode] as its downstream.
     *
     * When activation yields a connection, stores it in `patches`. If that connection reports
     * `needsEval`, its current push event is captured into `patchData`.
     */
    private fun MuxPromptNode<W, K, V>.connectPatches(evalScope: EvalScope) {
        val patchNode = PatchNode()
        val activation =
            getPatches(evalScope).activate(evalScope = evalScope, downstream = patchNode.schedulable)
                ?: return
        val (conn, needsEval) = activation
        patchNode.upstream = conn
        patches = patchNode
        if (needsEval) {
            patchData = conn.getPushEvent(0, evalScope)
        }
    }

    /**
     * Adds the patches connection's depth (direct, or indirect plus its indirect roots) to this
     * node's depth tracker. Does nothing if there is no patches connection.
     */
    private fun MuxPromptNode<W, K, V>.addPatchesDepth() {
        val conn = patches?.upstream ?: return
        val upstreamDepth = conn.depthTracker
        if (upstreamDepth.snapshotIsDirect) {
            depthTracker.addDirectUpstream(
                oldDepth = null,
                newDepth = upstreamDepth.snapshotDirectDepth,
            )
        } else {
            depthTracker.addIndirectUpstream(
                oldDepth = null,
                newDepth = upstreamDepth.snapshotIndirectDepth,
            )
            depthTracker.updateIndirectRoots(
                additions = upstreamDepth.snapshotIndirectRoots,
                butNot = null,
            )
        }
    }
}
408