/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 package com.android.systemui.kairos.internal
18
19 import com.android.systemui.kairos.internal.store.MutableMapK
20 import com.android.systemui.kairos.internal.store.SingletonMapK
21 import com.android.systemui.kairos.internal.store.asSingle
22 import com.android.systemui.kairos.internal.store.singleOf
23 import com.android.systemui.kairos.internal.util.LogIndent
24 import com.android.systemui.kairos.internal.util.hashString
25 import com.android.systemui.kairos.internal.util.logDuration
26 import com.android.systemui.kairos.util.Maybe
27 import com.android.systemui.kairos.util.Maybe.Absent
28 import com.android.systemui.kairos.util.Maybe.Present
29
/**
 * A [MuxNode] whose set of switched-in upstream [EventsImpl]s is edited by patches delivered
 * through [patches].
 *
 * [visit] applies a pending patch (see [processPatch]) before merging the results held in
 * `upstreamData`. Newly switched-in upstreams whose activation reports a value already available
 * therefore contribute to the current transaction's result. If a newly switched-in upstream has
 * no value available yet, or this node's depth increased, the node reschedules itself instead of
 * emitting.
 *
 * @param name optional debug name; when non-null, patch processing is logged and the name is
 *   appended in [toString].
 * @param lifecycle lifecycle that owns this node; [doDeactivate] moves it to the inactive state.
 * @param spec activator stored in the inactive lifecycle state on deactivation (presumably so the
 *   node can be re-created on a later activation — confirm against [MuxLifecycle]).
 * @param factory store factory forwarded to [MuxNode].
 */
internal class MuxPromptNode<W, K, V>(
    val name: String?,
    lifecycle: MuxLifecycle<W, K, V>,
    private val spec: MuxActivator<W, K, V>,
    factory: MutableMapK.Factory<W, K>,
) : MuxNode<W, K, V>(lifecycle, factory) {

    /** Patch received through [patches] that [visit] has not consumed yet, or `null`. */
    var patchData: Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>? = null

    /**
     * Connection to the upstream patch events; `null` if the patch upstream could not be
     * activated or has since been removed ([removeDirectPatchNode] / [removeIndirectPatchNode]).
     */
    var patches: PatchNode? = null

    /**
     * Applies any pending patch, then publishes the merged upstream results for this transaction
     * to `transactionCache` and schedules the downstream nodes.
     *
     * If applying the patch left a newly switched-in upstream without an available value, or
     * increased this node's depth, the node reschedules itself and returns without emitting.
     * When `scheduleAll` reports `false`, deactivation of this node is scheduled.
     */
    override fun visit(logIndent: Int, evalScope: EvalScope) {
        check(epoch < evalScope.epoch) { "node unexpectedly visited multiple times in transaction" }
        logDuration(logIndent, "MuxPrompt.visit") {
            val patch: Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>? = patchData
            patchData = null

            // If there's a patch, process it.
            patch?.let {
                val needsReschedule = processPatch(patch, evalScope)
                // We may need to reschedule if newly-switched-in nodes have not yet been
                // visited within this transaction.
                val depthIncreased = depthTracker.dirty_depthIncreased()
                if (needsReschedule || depthIncreased) {
                    if (depthIncreased) {
                        depthTracker.schedule(evalScope.compactor, this@MuxPromptNode)
                    }
                    if (name != null) {
                        logLn(
                            "[${this@MuxPromptNode}] rescheduling (reschedule=$needsReschedule, depthIncrease=$depthIncreased)"
                        )
                    }
                    schedule(evalScope)
                    return
                }
            }
            val results = upstreamData.readOnlyCopy().also { upstreamData.clear() }

            // If we don't need to reschedule, or there wasn't a patch at all, then we proceed
            // with merging pre-switch and post-switch results
            val hasResult = results.isNotEmpty()
            val compactDownstream = depthTracker.isDirty()
            if (hasResult || compactDownstream) {
                if (compactDownstream) {
                    adjustDownstreamDepths(evalScope)
                }
                if (hasResult) {
                    transactionCache.put(evalScope, results)
                    if (!scheduleAll(currentLogIndent, downstreamSet, evalScope)) {
                        evalScope.scheduleDeactivation(this@MuxPromptNode)
                    }
                }
            }
        }
    }

    /**
     * Applies [patch] to `switchedIn`. All [Absent] entries are processed first: each severs and
     * removes the key's branch. Then every [Present] entry severs any existing branch for its key
     * and switches in the new upstream. [depthTracker] is updated for every upstream added or
     * removed. After the whole patch has been applied, `scheduleDeactivationIfNeeded` is called
     * on every severed connection.
     *
     * Side effect: populates `upstreamData` with any immediately available results.
     *
     * @return `true` if some newly switched-in upstream reported `needsEval == false` on
     *   activation, in which case [visit] reschedules this node.
     */
    private fun LogIndent.processPatch(
        patch: Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>,
        evalScope: EvalScope,
    ): Boolean {
        var needsReschedule = false
        // We have a patch, process additions/updates and removals
        val adds = mutableListOf<Pair<K, EventsImpl<V>>>()
        val removes = mutableListOf<K>()
        patch.forEach { (k, newUpstream) ->
            when (newUpstream) {
                is Present -> adds.add(k to newUpstream.value)
                Absent -> removes.add(k)
            }
        }

        val severed = mutableListOf<NodeConnection<*>>()

        // remove and sever
        removes.forEach { k ->
            switchedIn.remove(k)?.let { branchNode: BranchNode ->
                if (name != null) {
                    logLn("[${this@MuxPromptNode}] removing $k")
                }
                severBranch(branchNode, severed)
            }
        }

        // add or replace
        adds.forEach { (k, newUpstream: EventsImpl<V>) ->
            // remove old and sever, if present
            switchedIn.remove(k)?.let { oldBranch: BranchNode ->
                if (name != null) {
                    logLn("[${this@MuxPromptNode}] replacing $k")
                }
                severBranch(oldBranch, severed)
            }

            // add new
            val newBranch = BranchNode(k)
            newUpstream.activate(evalScope, newBranch.schedulable)?.let { (conn, needsEval) ->
                newBranch.upstream = conn
                if (name != null) {
                    logLn("[${this@MuxPromptNode}] switching in $k")
                }
                switchedIn[k] = newBranch
                if (needsEval) {
                    upstreamData[k] = newBranch.upstream.directUpstream
                } else {
                    needsReschedule = true
                }
                val branchDepthTracker = newBranch.upstream.depthTracker
                if (branchDepthTracker.snapshotIsDirect) {
                    depthTracker.addDirectUpstream(
                        oldDepth = null,
                        newDepth = branchDepthTracker.snapshotDirectDepth,
                    )
                } else {
                    depthTracker.addIndirectUpstream(
                        oldDepth = null,
                        newDepth = branchDepthTracker.snapshotIndirectDepth,
                    )
                    depthTracker.updateIndirectRoots(
                        additions = branchDepthTracker.snapshotIndirectRoots,
                        butNot = null,
                    )
                }
            }
        }

        for (severedNode in severed) {
            severedNode.scheduleDeactivationIfNeeded(evalScope)
        }

        return needsReschedule
    }

    /**
     * Detaches [branchNode] from its upstream connection and removes that upstream's
     * contribution from [depthTracker], direct or indirect according to its depth snapshot.
     *
     * The connection is appended to [severed] rather than deactivated here. The caller calls
     * `scheduleDeactivationIfNeeded` on it only after the whole patch has been applied.
     */
    private fun severBranch(branchNode: BranchNode, severed: MutableList<NodeConnection<*>>) {
        val conn: NodeConnection<V> = branchNode.upstream
        severed.add(conn)
        conn.removeDownstream(downstream = branchNode.schedulable)
        if (conn.depthTracker.snapshotIsDirect) {
            depthTracker.removeDirectUpstream(conn.depthTracker.snapshotDirectDepth)
        } else {
            depthTracker.removeIndirectUpstream(conn.depthTracker.snapshotIndirectDepth)
            depthTracker.updateIndirectRoots(removals = conn.depthTracker.snapshotIndirectRoots)
        }
    }

    private fun adjustDownstreamDepths(evalScope: EvalScope) {
        if (depthTracker.dirty_depthIncreased()) {
            // schedule downstream nodes on the compaction scheduler; this scheduler is drained at
            // the end of this eval depth, so that all depth increases are applied before we advance
            // the eval step
            depthTracker.schedule(evalScope.compactor, node = this@MuxPromptNode)
        } else if (depthTracker.isDirty()) {
            // schedule downstream nodes on the eval scheduler; this is more efficient and is only
            // safe if the depth hasn't increased
            depthTracker.applyChanges(
                evalScope.scheduler,
                downstreamSet,
                muxNode = this@MuxPromptNode,
            )
        }
    }

    override fun getPushEvent(logIndent: Int, evalScope: EvalScope): MuxResult<W, K, V> =
        logDuration(logIndent, "MuxPrompt.getPushEvent") {
            transactionCache.getCurrentValue(evalScope)
        }

    /**
     * Moves an active lifecycle to [MuxLifecycleState.Inactive] and detaches every switched-in
     * branch and the patch node from their upstreams. Does nothing if the lifecycle is not
     * currently active.
     */
    override fun doDeactivate() {
        // Update lifecycle
        if (lifecycle.lifecycleState !is MuxLifecycleState.Active) return
        lifecycle.lifecycleState = MuxLifecycleState.Inactive(spec)
        // Process branch nodes
        switchedIn.forEach { (_, branchNode) ->
            branchNode.upstream.removeDownstreamAndDeactivateIfNeeded(
                downstream = branchNode.schedulable
            )
        }
        // Process patch node
        patches?.let { patches ->
            patches.upstream.removeDownstreamAndDeactivateIfNeeded(downstream = patches.schedulable)
        }
    }

    /**
     * Drops the patch connection whose upstream was indirect, and schedules a depth update on
     * [scheduler] if the depth tracker changed.
     */
    fun removeIndirectPatchNode(
        scheduler: Scheduler,
        oldDepth: Int,
        indirectSet: Set<MuxDeferredNode<*, *, *>>,
    ) {
        patches = null
        // Non-short-circuit `or`: both tracker updates must always run.
        if (
            depthTracker.removeIndirectUpstream(oldDepth) or
                depthTracker.updateIndirectRoots(removals = indirectSet)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Drops the patch connection whose upstream was direct, and schedules a depth update on
     * [scheduler] if the depth tracker changed.
     */
    fun removeDirectPatchNode(scheduler: Scheduler, depth: Int) {
        patches = null
        if (depthTracker.removeDirectUpstream(depth)) {
            depthTracker.schedule(scheduler, this)
        }
    }

    override fun toString(): String =
        "${this::class.simpleName}@$hashString${name?.let { "[$it]" }.orEmpty()}"

    /**
     * Downstream end of the patch connection. Each incoming patch is stored in [patchData] and
     * the enclosing node is scheduled. Depth changes from the patch upstream are forwarded to the
     * enclosing [MuxPromptNode], and removal of the patch upstream clears [patches].
     */
    inner class PatchNode : SchedulableNode {

        val schedulable = Schedulable.N(this)

        lateinit var upstream: NodeConnection<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>

        override fun schedule(logIndent: Int, evalScope: EvalScope) {
            logDuration(logIndent, "MuxPromptPatchNode.schedule") {
                patchData = upstream.getPushEvent(currentLogIndent, evalScope)
                this@MuxPromptNode.schedule(evalScope)
            }
        }

        override fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) {
            this@MuxPromptNode.adjustDirectUpstream(scheduler, oldDepth, newDepth)
        }

        override fun moveIndirectUpstreamToDirect(
            scheduler: Scheduler,
            oldIndirectDepth: Int,
            oldIndirectSet: Set<MuxDeferredNode<*, *, *>>,
            newDirectDepth: Int,
        ) {
            this@MuxPromptNode.moveIndirectUpstreamToDirect(
                scheduler,
                oldIndirectDepth,
                oldIndirectSet,
                newDirectDepth,
            )
        }

        override fun adjustIndirectUpstream(
            scheduler: Scheduler,
            oldDepth: Int,
            newDepth: Int,
            removals: Set<MuxDeferredNode<*, *, *>>,
            additions: Set<MuxDeferredNode<*, *, *>>,
        ) {
            this@MuxPromptNode.adjustIndirectUpstream(
                scheduler,
                oldDepth,
                newDepth,
                removals,
                additions,
            )
        }

        override fun moveDirectUpstreamToIndirect(
            scheduler: Scheduler,
            oldDirectDepth: Int,
            newIndirectDepth: Int,
            newIndirectSet: Set<MuxDeferredNode<*, *, *>>,
        ) {
            this@MuxPromptNode.moveDirectUpstreamToIndirect(
                scheduler,
                oldDirectDepth,
                newIndirectDepth,
                newIndirectSet,
            )
        }

        override fun removeDirectUpstream(scheduler: Scheduler, depth: Int) {
            this@MuxPromptNode.removeDirectPatchNode(scheduler, depth)
        }

        override fun removeIndirectUpstream(
            scheduler: Scheduler,
            depth: Int,
            indirectSet: Set<MuxDeferredNode<*, *, *>>,
        ) {
            this@MuxPromptNode.removeIndirectPatchNode(scheduler, depth, indirectSet)
        }
    }
}
318
/**
 * Single-upstream specialization of [switchPromptImpl].
 *
 * The initial upstream from [getStorage] and every replacement emitted by [getPatches] are
 * wrapped into one-entry collections (via `singleOf`), so the keyed mux can be reused with a
 * [SingletonMapK] store. Each mux result is then unwrapped back to the event of the single
 * switched-in upstream. Patches are always `Maybe.present`, so the single entry is only ever
 * replaced, never removed.
 */
internal inline fun <A> switchPromptImplSingle(
    crossinline getStorage: EvalScope.() -> EventsImpl<A>,
    crossinline getPatches: EvalScope.() -> EventsImpl<EventsImpl<A>>,
): EventsImpl<A> {
    // Lift the single-upstream problem into the keyed mux.
    val muxedEvents =
        switchPromptImpl(
            getStorage = { singleOf(getStorage()).asIterable() },
            getPatches = {
                mapImpl(getPatches) { replacement, _ ->
                    singleOf(Maybe.present(replacement)).asIterable()
                }
            },
            storeFactory = SingletonMapK.Factory(),
        )
    // Unwrap: the singleton store holds exactly one entry, keyed by Unit.
    return mapImpl({ muxedEvents }) { muxResult, logIndent ->
        muxResult.asSingle().getValue(Unit).getPushEvent(logIndent, this)
    }
}
337
/**
 * Builds a mux over keyed upstream events whose membership is edited by patches.
 *
 * @param name optional debug name for the resulting [MuxPromptNode].
 * @param getStorage supplies the initial key → upstream entries each time the mux is activated.
 * @param getPatches supplies the patch events. A present value switches in (or replaces) the
 *   upstream for its key, and an absent value removes that key.
 * @param storeFactory creates the key → value stores used by the mux node.
 */
internal fun <W, K, V> switchPromptImpl(
    name: String? = null,
    getStorage: EvalScope.() -> Iterable<Map.Entry<K, EventsImpl<V>>>,
    getPatches: EvalScope.() -> EventsImpl<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>,
    storeFactory: MutableMapK.Factory<W, K>,
): EventsImpl<MuxResult<W, K, V>> {
    val activator =
        MuxPromptActivator(
            name = name,
            getStorage = getStorage,
            storeFactory = storeFactory,
            getPatches = getPatches,
        )
    return MuxLifecycle(activator)
}
345
/**
 * [MuxActivator] that creates and wires a fresh [MuxPromptNode] on each activation.
 *
 * Activation connects the initial upstreams from [getStorage] and the patch stream from
 * [getPatches] and records their depths. It then schedules the node if either side already has
 * data in the current transaction.
 */
private class MuxPromptActivator<W, K, V>(
    private val name: String?,
    private val getStorage: EvalScope.() -> Iterable<Map.Entry<K, EventsImpl<V>>>,
    private val storeFactory: MutableMapK.Factory<W, K>,
    private val getPatches: EvalScope.() -> EventsImpl<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>,
) : MuxActivator<W, K, V> {

    /**
     * Creates and wires a new [MuxPromptNode].
     *
     * @return the node paired with a `null` callback, or `null` when neither the patch stream
     *   nor any initial upstream ended up connected.
     */
    override fun activate(
        evalScope: EvalScope,
        lifecycle: MuxLifecycle<W, K, V>,
    ): Pair<MuxNode<W, K, V>, (() -> Unit)?>? {
        val muxNode = MuxPromptNode(name, lifecycle, this, storeFactory)
        muxNode.connectUpstreamsAndPatches(evalScope)

        // A patch or a switched-in upstream already emitted within this transaction, so the node
        // has to be evaluated now.
        val hasPendingInput = muxNode.patchData != null || muxNode.upstreamData.isNotEmpty()
        if (hasPendingInput) {
            muxNode.schedule(evalScope)
        }

        val nothingConnected = muxNode.patches == null && muxNode.switchedIn.isEmpty()
        if (nothingConnected) {
            return null
        }
        return muxNode to null
    }

    /**
     * Connects the initial switched-in upstreams and the patch stream to this node. Initializes
     * its depth from all of them, then resets the depth adjustments because no downstream has
     * been notified yet.
     */
    private fun MuxPromptNode<W, K, V>.connectUpstreamsAndPatches(evalScope: EvalScope) {
        initializeUpstream(evalScope, getStorage, storeFactory)

        // Set up the patches connection; keep the patch if it is already available.
        val patchNode = PatchNode()
        getPatches(evalScope)
            .activate(evalScope = evalScope, downstream = patchNode.schedulable)
            ?.let { (patchConn, needsEval) ->
                patchNode.upstream = patchConn
                patches = patchNode
                if (needsEval) {
                    patchData = patchConn.getPushEvent(0, evalScope)
                }
            }

        // Depth contributed by the initial switched-in nodes.
        initializeDepth()

        // Depth contributed by the patches connection, if it was established.
        val patchUpstream = patches?.upstream
        if (patchUpstream != null) {
            if (patchUpstream.depthTracker.snapshotIsDirect) {
                depthTracker.addDirectUpstream(
                    oldDepth = null,
                    newDepth = patchUpstream.depthTracker.snapshotDirectDepth,
                )
            } else {
                depthTracker.addIndirectUpstream(
                    oldDepth = null,
                    newDepth = patchUpstream.depthTracker.snapshotIndirectDepth,
                )
                depthTracker.updateIndirectRoots(
                    additions = patchUpstream.depthTracker.snapshotIndirectRoots,
                    butNot = null,
                )
            }
        }

        // Reset all depth adjustments, since no downstream has been notified
        depthTracker.reset()
    }
}
408