1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package com.android.systemui.kairos.internal
18
19 import com.android.systemui.kairos.internal.store.MapK
20 import com.android.systemui.kairos.internal.store.MutableArrayMapK
21 import com.android.systemui.kairos.internal.store.MutableMapK
22 import com.android.systemui.kairos.internal.store.SingletonMapK
23 import com.android.systemui.kairos.internal.store.StoreEntry
24 import com.android.systemui.kairos.internal.store.asArrayHolder
25 import com.android.systemui.kairos.internal.store.asSingle
26 import com.android.systemui.kairos.internal.store.singleOf
27 import com.android.systemui.kairos.internal.util.hashString
28 import com.android.systemui.kairos.internal.util.logDuration
29 import com.android.systemui.kairos.internal.util.logLn
30 import com.android.systemui.kairos.util.Maybe
31 import com.android.systemui.kairos.util.Maybe.Absent
32 import com.android.systemui.kairos.util.Maybe.Present
33 import com.android.systemui.kairos.util.These
34 import com.android.systemui.kairos.util.flatMap
35 import com.android.systemui.kairos.util.getMaybe
36 import com.android.systemui.kairos.util.maybeFirst
37 import com.android.systemui.kairos.util.maybeSecond
38 import com.android.systemui.kairos.util.merge
39 import com.android.systemui.kairos.util.orError
40 import com.android.systemui.kairos.util.these
41
/**
 * A [MuxNode] whose set of switched-in upstream events can be changed by emissions of [patches].
 *
 * Patch emissions are not applied immediately: [scheduleMover] captures the emitted patch into
 * [patchData] and schedules this node as a mux mover; [performMove] later consumes that patch
 * during the MOVE phase. Each switched-in upstream is connected through a [BranchNode] stored in
 * [switchedIn], and [depthTracker] accounts for both the branch connections and the patches
 * connection.
 *
 * @param name optional debug name; when non-null, this node logs its activity.
 * @param lifecycle the lifecycle that owns this node.
 * @param spec activator stored back into the lifecycle (as [MuxLifecycleState.Inactive]) when
 *   this node is deactivated.
 * @param factory store factory forwarded to [MuxNode].
 */
internal class MuxDeferredNode<W, K, V>(
    val name: String?,
    lifecycle: MuxLifecycle<W, K, V>,
    val spec: MuxActivator<W, K, V>,
    factory: MutableMapK.Factory<W, K>,
) : MuxNode<W, K, V>(lifecycle, factory) {

    /** Handle used to register this node as a downstream of its [patches] connection. */
    val schedulable = Schedulable.M(this)

    /** Connection to the upstream patches node, or `null` when not (or no longer) connected. */
    var patches: NodeConnection<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>? = null

    /**
     * Patch captured by [scheduleMover] and consumed (then cleared) by [performMove]. Each entry
     * maps a key to [Present] (add or replace that branch) or [Absent] (remove that branch).
     */
    var patchData: Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>? = null

    /**
     * Evaluates this node for the current transaction.
     *
     * Snapshots and clears [upstreamData], applies pending depth changes, and, if any upstream
     * emitted, caches the snapshot in [transactionCache] and schedules downstream nodes. If
     * [scheduleAll] reports `false`, this node is scheduled for deactivation instead.
     */
    override fun visit(logIndent: Int, evalScope: EvalScope) {
        check(epoch < evalScope.epoch) { "node unexpectedly visited multiple times in transaction" }
        logDuration(logIndent, "MuxDeferred[$name].visit") {
            val scheduleDownstream: Boolean
            val result: MapK<W, K, PullNode<V>>
            logDuration("copying upstream data", false) {
                // Snapshot this transaction's upstream emissions, then reset the buffer.
                scheduleDownstream = upstreamData.isNotEmpty()
                result = upstreamData.readOnlyCopy()
                upstreamData.clear()
            }
            if (name != null) {
                logLn("[${this@MuxDeferredNode}] result = $result")
            }
            val compactDownstream = depthTracker.isDirty()
            if (scheduleDownstream || compactDownstream) {
                // Apply pending depth changes first, before scheduling downstream nodes.
                if (compactDownstream) {
                    logDuration("compactDownstream", false) {
                        depthTracker.applyChanges(
                            evalScope.scheduler,
                            downstreamSet,
                            muxNode = this@MuxDeferredNode,
                        )
                    }
                }
                if (scheduleDownstream) {
                    logDuration("scheduleDownstream") {
                        if (name != null) {
                            logLn("[${this@MuxDeferredNode}] scheduling")
                        }
                        transactionCache.put(evalScope, result)
                        if (!scheduleAll(currentLogIndent, downstreamSet, evalScope)) {
                            evalScope.scheduleDeactivation(this@MuxDeferredNode)
                        }
                    }
                }
            }
        }
    }

    /** Returns the result cached by [visit] for the current transaction. */
    override fun getPushEvent(logIndent: Int, evalScope: EvalScope): MuxResult<W, K, V> =
        logDuration(logIndent, "MuxDeferred.getPushEvent") {
            transactionCache.getCurrentValue(evalScope).also {
                if (name != null) {
                    logLn("[${this@MuxDeferredNode}] getPushEvent = $it")
                }
            }
        }

    /** Propagates any pending depth changes through the transaction's compactor. */
    private fun compactIfNeeded(evalScope: EvalScope) {
        depthTracker.propagateChanges(evalScope.compactor, this)
    }

    /**
     * Tears down an active node. Moves the lifecycle to [MuxLifecycleState.Inactive], then
     * disconnects every switched-in branch and the patches connection, deactivating those
     * upstreams if needed. Does nothing if the lifecycle is not currently active.
     */
    override fun doDeactivate() {
        // Update lifecycle
        if (lifecycle.lifecycleState !is MuxLifecycleState.Active) return
        lifecycle.lifecycleState = MuxLifecycleState.Inactive(spec)
        // Process branch nodes
        switchedIn.forEach { (_, branchNode) ->
            branchNode.upstream.removeDownstreamAndDeactivateIfNeeded(branchNode.schedulable)
        }
        // Process patch node
        patches?.removeDownstreamAndDeactivateIfNeeded(schedulable)
    }

    // MOVE phase
    //  - concurrent moves may be occurring, but no more evals. all depth recalculations are
    //    deferred to the end of this phase.
    /**
     * Applies the pending [patchData], if any, to [switchedIn].
     *
     * All [Absent] entries are processed first. Then each [Present] entry replaces (or adds) its
     * branch. Connections severed along the way are only checked for deactivation after every
     * edit is applied, and finally pending depth changes are propagated.
     */
    fun performMove(logIndent: Int, evalScope: EvalScope) {
        if (name != null) {
            logLn(logIndent, "[${this@MuxDeferredNode}] performMove (patchData = $patchData)")
        }

        val patch = patchData ?: return
        patchData = null

        // TODO: this logic is very similar to what's in MuxPrompt, maybe turn into an inline fun?

        // We have a patch, process additions/updates and removals
        val adds = mutableListOf<Pair<K, EventsImpl<V>>>()
        val removes = mutableListOf<K>()
        patch.forEach { (k, newUpstream) ->
            when (newUpstream) {
                is Present -> adds.add(k to newUpstream.value)
                Absent -> removes.add(k)
            }
        }

        // Connections detached during this move; deactivation is considered once all edits are in.
        val severed = mutableListOf<NodeConnection<*>>()

        // remove and sever
        removes.forEach { k -> removeAndSever(k, severed) }

        // add or replace
        adds.forEach { (k, newUpstream: EventsImpl<V>) ->
            // remove old and sever, if present
            removeAndSever(k, severed)
            // add new
            switchInBranch(evalScope, k, newUpstream)
        }

        for (severedNode in severed) {
            severedNode.scheduleDeactivationIfNeeded(evalScope)
        }

        compactIfNeeded(evalScope)
    }

    /**
     * Removes the branch for [key] from [switchedIn], if present. Detaches the branch from its
     * upstream connection, records that connection in [severed], and subtracts the connection's
     * snapshot depth (direct or indirect, plus indirect roots) from this node's [depthTracker].
     */
    private fun removeAndSever(key: K, severed: MutableList<NodeConnection<*>>) {
        val branchNode = switchedIn.remove(key) ?: return
        val conn = branchNode.upstream
        severed.add(conn)
        conn.removeDownstream(downstream = branchNode.schedulable)
        if (conn.depthTracker.snapshotIsDirect) {
            depthTracker.removeDirectUpstream(conn.depthTracker.snapshotDirectDepth)
        } else {
            depthTracker.removeIndirectUpstream(conn.depthTracker.snapshotIndirectDepth)
            depthTracker.updateIndirectRoots(removals = conn.depthTracker.snapshotIndirectRoots)
        }
    }

    /**
     * Activates [newUpstream] into a fresh [BranchNode] for [key]. On successful activation, the
     * branch is stored in [switchedIn] and its snapshot depth is added to this node's
     * [depthTracker]. This node is excluded from the added indirect roots. If activation returns
     * `null`, nothing is switched in.
     */
    private fun switchInBranch(evalScope: EvalScope, key: K, newUpstream: EventsImpl<V>) {
        val newBranch = BranchNode(key)
        newUpstream.activate(evalScope, newBranch.schedulable)?.let { (conn, _) ->
            newBranch.upstream = conn
            switchedIn[key] = newBranch
            val branchDepthTracker = newBranch.upstream.depthTracker
            if (branchDepthTracker.snapshotIsDirect) {
                depthTracker.addDirectUpstream(
                    oldDepth = null,
                    newDepth = branchDepthTracker.snapshotDirectDepth,
                )
            } else {
                depthTracker.addIndirectUpstream(
                    oldDepth = null,
                    newDepth = branchDepthTracker.snapshotIndirectDepth,
                )
                depthTracker.updateIndirectRoots(
                    additions = branchDepthTracker.snapshotIndirectRoots,
                    butNot = this@MuxDeferredNode,
                )
            }
        }
    }

    /**
     * Forgets a directly-connected patches node. Clears the indirect-root status and the depth-0
     * indirect upstream entry, and schedules depth propagation if either changed.
     */
    fun removeDirectPatchNode(scheduler: Scheduler) {
        if (
            depthTracker.removeIndirectUpstream(depth = 0) or depthTracker.setIsIndirectRoot(false)
        ) {
            depthTracker.schedule(scheduler, this)
        }
        patches = null
    }

    /**
     * Forgets an indirectly-connected patches node that contributed [depth] and [indirectSet]
     * roots. Schedules depth propagation if anything changed.
     */
    fun removeIndirectPatchNode(
        scheduler: Scheduler,
        depth: Int,
        indirectSet: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        if (
            depthTracker.updateIndirectRoots(removals = indirectSet) or
                depthTracker.removeIndirectUpstream(depth)
        ) {
            depthTracker.schedule(scheduler, this)
        }
        patches = null
    }

    /**
     * Records that the patches node changed from indirect (at [oldIndirectDepth], with roots
     * [oldIndirectSet]) to direct. This node then becomes an indirect root itself.
     */
    fun moveIndirectPatchNodeToDirect(
        scheduler: Scheduler,
        oldIndirectDepth: Int,
        oldIndirectSet: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // directly connected patches are stored as an indirect singleton set of the patchNode
        if (
            depthTracker.updateIndirectRoots(removals = oldIndirectSet) or
                depthTracker.removeIndirectUpstream(oldIndirectDepth) or
                depthTracker.setIsIndirectRoot(true)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Records that the patches node changed from direct to indirect (at [newIndirectDepth], with
     * roots [newIndirectSet]). This node then stops being an indirect root.
     */
    fun moveDirectPatchNodeToIndirect(
        scheduler: Scheduler,
        newIndirectDepth: Int,
        newIndirectSet: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        if (
            depthTracker.setIsIndirectRoot(false) or
                depthTracker.updateIndirectRoots(additions = newIndirectSet, butNot = this) or
                depthTracker.addIndirectUpstream(oldDepth = null, newDepth = newIndirectDepth)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Updates the depth and indirect roots contributed by an indirectly-connected patches node.
     * Schedules depth propagation if anything changed.
     */
    fun adjustIndirectPatchNode(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        removals: Set<MuxDeferredNode<*, *, *>>,
        additions: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        if (
            depthTracker.updateIndirectRoots(
                additions = additions,
                removals = removals,
                butNot = this,
            ) or depthTracker.addIndirectUpstream(oldDepth = oldDepth, newDepth = newDepth)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Captures the current emission of [patches] into [patchData] and schedules this node as a
     * mux mover, so the patch is applied later by [performMove].
     *
     * @throws IllegalStateException if [patches] is not connected.
     */
    fun scheduleMover(logIndent: Int, evalScope: EvalScope) {
        logDuration(logIndent, "MuxDeferred.scheduleMover") {
            patchData =
                checkNotNull(patches) { "mux mover scheduled with unset patches upstream node" }
                    .getPushEvent(currentLogIndent, evalScope)
            evalScope.scheduleMuxMover(this@MuxDeferredNode)
        }
    }

    override fun toString(): String =
        "${this::class.simpleName}@$hashString${name?.let { "[$it]" }.orEmpty()}"
}
292
/**
 * Single-slot variant of [switchDeferredImpl]: switches between whole event streams.
 *
 * Starts switched in to the events produced by [getStorage]. Each emission of [getPatches]
 * replaces the switched-in events with the emitted one; the replacement is always [Present], so
 * the slot is never removed. The resulting events unwrap the single mux slot.
 *
 * @param name optional debug name; when non-null, extraction of each emission is logged.
 */
internal inline fun <A> switchDeferredImplSingle(
    name: String? = null,
    crossinline getStorage: EvalScope.() -> EventsImpl<A>,
    crossinline getPatches: EvalScope.() -> EventsImpl<EventsImpl<A>>,
): EventsImpl<A> {
    // Lift each replacement stream into a single-entry patch.
    val singlePatches =
        mapImpl(getPatches) { replacement, _ -> singleOf(Maybe.present(replacement)).asIterable() }
    val muxed =
        switchDeferredImpl(
            name = name,
            getStorage = { singleOf(getStorage()).asIterable() },
            getPatches = { singlePatches },
            storeFactory = SingletonMapK.Factory(),
        )
    return mapImpl({ muxed }) { muxResult, logIndent ->
        val extracted = muxResult.asSingle().getValue(Unit).getPushEvent(logIndent, this)
        if (name != null) {
            logLn(logIndent, "[$name] extracting single mux: $extracted")
        }
        extracted
    }
}
315
/**
 * Creates a deferred-switch mux.
 *
 * The mux is initially switched in to the entries produced by [getStorage]. Emissions of
 * [getPatches] update the switched-in set: a [Present] entry adds or replaces that key, and an
 * [Absent] entry removes it.
 *
 * @param name optional debug name used for logging.
 * @param storeFactory store factory used by the underlying mux node.
 */
internal fun <W, K, V> switchDeferredImpl(
    name: String? = null,
    getStorage: EvalScope.() -> Iterable<Map.Entry<K, EventsImpl<V>>>,
    getPatches: EvalScope.() -> EventsImpl<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>,
    storeFactory: MutableMapK.Factory<W, K>,
): EventsImpl<MuxResult<W, K, V>> {
    val activator =
        MuxDeferredActivator(
            name = name,
            getStorage = getStorage,
            storeFactory = storeFactory,
            getPatches = getPatches,
        )
    return MuxLifecycle(activator)
}
323
/**
 * [MuxActivator] that builds a [MuxDeferredNode].
 *
 * The initial switched-in upstreams from [getStorage] are connected immediately. Connecting to
 * [getPatches] is returned as a deferred callback, which allows a recursive setup in which the
 * mux node is downstream of itself via its patches.
 */
private class MuxDeferredActivator<W, K, V>(
    private val name: String?,
    private val getStorage: EvalScope.() -> Iterable<Map.Entry<K, EventsImpl<V>>>,
    private val storeFactory: MutableMapK.Factory<W, K>,
    private val getPatches: EvalScope.() -> EventsImpl<Iterable<Map.Entry<K, Maybe<EventsImpl<V>>>>>,
) : MuxActivator<W, K, V> {
    override fun activate(
        evalScope: EvalScope,
        lifecycle: MuxLifecycle<W, K, V>,
    ): Pair<MuxNode<W, K, V>, (() -> Unit)?>? {
        // Initialize mux node and switched-in connections.
        val muxNode =
            MuxDeferredNode(name, lifecycle, this, storeFactory).apply {
                initializeUpstream(evalScope, getStorage, storeFactory)
                // Update depth based on all initial switched-in nodes.
                initializeDepth()
                // We don't have our patches connection established yet, so for now pretend we have
                // a direct connection to patches. We will update downstream nodes later if this
                // turns out to be a lie.
                depthTracker.setIsIndirectRoot(true)
                depthTracker.reset()
            }

        // Schedule for evaluation if any switched-in nodes have already emitted within
        // this transaction.
        if (muxNode.upstreamData.isNotEmpty()) {
            muxNode.schedule(evalScope)
        }

        return muxNode to
            fun() {
                // Setup patches connection; deferring allows for a recursive connection, where
                // muxNode is downstream of itself via patches.
                val (patchesConn, needsEval) =
                    getPatches(evalScope).activate(evalScope, downstream = muxNode.schedulable)
                        ?: run {
                            // Turns out we can't connect to patches, so we are not an indirect
                            // root after all. Downstream nodes may already have observed the
                            // provisional root status set above, so propagate the change just
                            // like every other depth update does.
                            if (muxNode.depthTracker.setIsIndirectRoot(false)) {
                                muxNode.depthTracker.schedule(evalScope.scheduler, muxNode)
                            }
                            return
                        }
                muxNode.patches = patchesConn

                if (!patchesConn.schedulerUpstream.depthTracker.snapshotIsDirect) {
                    // Turns out patches is indirect, so we are not a root. Update depth and
                    // propagate.
                    if (
                        muxNode.depthTracker.setIsIndirectRoot(false) or
                            muxNode.depthTracker.addIndirectUpstream(
                                oldDepth = null,
                                newDepth = patchesConn.depthTracker.snapshotIndirectDepth,
                            ) or
                            muxNode.depthTracker.updateIndirectRoots(
                                additions = patchesConn.depthTracker.snapshotIndirectRoots
                            )
                    ) {
                        muxNode.depthTracker.schedule(evalScope.scheduler, muxNode)
                    }
                }
                // Schedule mover to process patch emission at the end of this transaction, if
                // needed.
                if (needsEval) {
                    muxNode.patchData = patchesConn.getPushEvent(0, evalScope)
                    evalScope.scheduleMuxMover(muxNode)
                }
            }
    }
}
391
/**
 * Merges two event streams of the same type into one. Simultaneous emissions are combined with
 * [f] through [These.merge]. The result is cached.
 *
 * @param name optional debug name forwarded to the underlying merge.
 */
internal inline fun <A> mergeNodes(
    crossinline getPulse: EvalScope.() -> EventsImpl<A>,
    crossinline getOther: EvalScope.() -> EventsImpl<A>,
    name: String? = null,
    crossinline f: EvalScope.(A, A) -> A,
): EventsImpl<A> {
    val pairwise = mergeNodes(name, getPulse, getOther)
    return mapImpl({ pairwise }) { both, _ -> both.merge { left, right -> f(left, right) } }
        .cached()
}
403
/**
 * Returns a lazy view that pairs each element with its zero-based index, exposed as a
 * [Map.Entry] whose key is the index. Every iteration of the view re-iterates the receiver
 * from the start.
 */
internal fun <T> Iterable<T>.asIterableWithIndex(): Iterable<Map.Entry<Int, T>> =
    withIndex().asSequence().map { (index, element) -> StoreEntry(index, element) }.asIterable()
406
/**
 * Merges two event streams into one stream of [These]. The result reports the first stream's
 * emission, the second's, or both when they emit in the same transaction. The result is cached.
 *
 * @param name optional debug name forwarded to the underlying mux.
 */
internal inline fun <A, B> mergeNodes(
    name: String? = null,
    crossinline getPulse: EvalScope.() -> EventsImpl<A>,
    crossinline getOther: EvalScope.() -> EventsImpl<B>,
): EventsImpl<These<A, B>> {
    // Slot 0 carries the first stream, slot 1 the second.
    val storage =
        listOf(
                mapImpl(getPulse) { a, _ -> These.first(a) },
                mapImpl(getOther) { b, _ -> These.second(b) },
            )
            .asIterableWithIndex()
    val switchNode =
        switchDeferredImpl(
            name = name,
            getStorage = { storage },
            getPatches = { neverImpl },
            storeFactory = MutableArrayMapK.Factory(),
        )
    val merged =
        mapImpl({ switchNode }) { muxResult, logIndent ->
            val slots = muxResult.asArrayHolder()
            val fromPulse =
                slots.getMaybe(0).flatMap { node -> node.getPushEvent(logIndent, this).maybeFirst() }
            val fromOther =
                slots.getMaybe(1).flatMap { node -> node.getPushEvent(logIndent, this).maybeSecond() }
            these(fromPulse, fromOther).orError { "unexpected missing merge result" }
        }
    return merged.cached()
}
436
/**
 * Merges any number of event streams. Each emission is the list of values from the streams
 * that emitted in the current transaction. The result is cached.
 */
internal inline fun <A> mergeNodes(
    crossinline getPulses: EvalScope.() -> Iterable<EventsImpl<A>>
): EventsImpl<List<A>> {
    val switchNode =
        switchDeferredImpl(
            getStorage = { getPulses().asIterableWithIndex() },
            getPatches = { neverImpl },
            storeFactory = MutableArrayMapK.Factory(),
        )
    return mapImpl({ switchNode }) { muxResult, logIndent ->
            muxResult.asArrayHolder().map { (_, node) -> node.getPushEvent(logIndent, this) }
        }
        .cached()
}
453
/**
 * Merges any number of event streams. When several emit in the same transaction, only the
 * value of the first entry in the merge result is kept. The result is cached.
 */
internal inline fun <A> mergeNodesLeft(
    crossinline getPulses: EvalScope.() -> Iterable<EventsImpl<A>>
): EventsImpl<A> {
    val switchNode =
        switchDeferredImpl(
            getStorage = { getPulses().asIterableWithIndex() },
            getPatches = { neverImpl },
            storeFactory = MutableArrayMapK.Factory(),
        )
    return mapImpl({ switchNode }) { muxResult, logIndent ->
            val leftmost = muxResult.asArrayHolder().values.first()
            leftmost.getPushEvent(logIndent, this)
        }
        .cached()
}
470