/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 @file:Suppress("NOTHING_TO_INLINE")
18
19 package com.android.systemui.kairos.internal
20
21 import com.android.systemui.kairos.internal.store.MapHolder
22 import com.android.systemui.kairos.internal.store.MapK
23 import com.android.systemui.kairos.internal.store.MutableMapK
24 import com.android.systemui.kairos.internal.store.asMapHolder
25 import com.android.systemui.kairos.internal.util.hashString
26 import com.android.systemui.kairos.internal.util.logDuration
27
/**
 * Value type produced by a [MuxNode]: a keyed store of upstream [PullNode]s. Presumably holds the
 * upstreams that emitted during the current transaction — TODO confirm against subclasses.
 */
internal typealias MuxResult<W, K, V> = MapK<W, K, PullNode<V>>
29
/**
 * Base class for muxing nodes, which have a (potentially dynamic) collection of upstream nodes.
 *
 * Every upstream is attached through an inner [BranchNode] identified by a key of type [K]. The
 * branches currently attached are held in [switchedIn]. Upstream [PullNode]s whose values this
 * node should consume are recorded in [upstreamData].
 *
 * @param lifecycle the [MuxLifecycle] this node belongs to.
 * @param storeFactory factory for the keyed [MutableMapK] stores; exposed to subclasses.
 */
internal sealed class MuxNode<W, K, V>(
    val lifecycle: MuxLifecycle<W, K, V>,
    protected val storeFactory: MutableMapK.Factory<W, K>,
) : PushNode<MuxResult<W, K, V>> {

    /**
     * Upstream [PullNode]s recorded per key.
     *
     * Written by [BranchNode.schedule] when a branch fires. It is also written during activation
     * when the upstream reports `needsEval` (see [initializeUpstream]).
     */
    lateinit var upstreamData: MutableMapK<W, K, PullNode<V>>

    /** The branches currently connected to this node, keyed by their upstream key. */
    lateinit var switchedIn: MutableMapK<W, K, BranchNode>

    // NOTE(review): these flags are not read or written in this file; their meaning is defined by
    // subclasses / the scheduler. Marked @Volatile, presumably for cross-thread visibility.
    @Volatile var markedForCompaction = false
    @Volatile var markedForEvaluation = false

    /** Schedulables notified when this node emits. */
    val downstreamSet: DownstreamSet = DownstreamSet()

    // TODO: inline DepthTracker? would need to be added to PushNode signature
    final override val depthTracker = DepthTracker()

    /** Per-transaction cache of this node's result. */
    val transactionCache = TransactionCache<MuxResult<W, K, V>>()

    /** Epoch of [transactionCache]; compared with [EvalScope.epoch] by [hasCurrentValueLocked]. */
    val epoch
        get() = transactionCache.epoch

    /** Returns true when [transactionCache]'s epoch equals the epoch of [evalScope]. */
    inline fun hasCurrentValueLocked(evalScope: EvalScope): Boolean = epoch == evalScope.epoch

    override fun hasCurrentValue(logIndent: Int, evalScope: EvalScope): Boolean =
        hasCurrentValueLocked(evalScope)

    final override fun addDownstream(downstream: Schedulable) {
        addDownstreamLocked(downstream)
    }

    /**
     * Adds a downstream schedulable to this mux node, such that when this mux node emits a value,
     * it will be scheduled for evaluation within this same transaction.
     *
     * This method performs no synchronization itself; it mutates [downstreamSet] directly.
     * Callers are presumably responsible for serializing access. This class owns no mutex; the
     * "Locked" suffix is kept for API compatibility.
     */
    fun addDownstreamLocked(downstream: Schedulable) {
        downstreamSet.add(downstream)
    }

    /** Removes [downstream] without deactivating this node, even if it was the last one. */
    final override fun removeDownstream(downstream: Schedulable) {
        // TODO: return boolean?
        downstreamSet.remove(downstream)
    }

    /** Removes [downstream], then deactivates this node if no downstream schedulables remain. */
    final override fun removeDownstreamAndDeactivateIfNeeded(downstream: Schedulable) {
        downstreamSet.remove(downstream)
        if (downstreamSet.isEmpty()) {
            doDeactivate()
        }
    }

    /** Deactivates this node immediately if it has no downstream schedulables. */
    final override fun deactivateIfNeeded() {
        if (downstreamSet.isEmpty()) {
            doDeactivate()
        }
    }

    /** visit this node from the scheduler (push eval) */
    abstract fun visit(logIndent: Int, evalScope: EvalScope)

    /** perform deactivation logic, propagating to all upstream nodes. */
    protected abstract fun doDeactivate()

    /**
     * If this node has no downstream schedulables, hands it to [EvalScope.scheduleDeactivation].
     * Unlike [deactivateIfNeeded], this does not call [doDeactivate] directly.
     */
    final override fun scheduleDeactivationIfNeeded(evalScope: EvalScope) {
        if (downstreamSet.isEmpty()) {
            evalScope.scheduleDeactivation(this)
        }
    }

    /**
     * Records that a direct upstream's depth changed from [oldDepth] to [newDepth].
     *
     * Schedules this node with [scheduler] when [depthTracker] returns true.
     */
    fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) {
        if (depthTracker.addDirectUpstream(oldDepth, newDepth)) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Converts an upstream from indirect to direct.
     *
     * The upstream was indirect at [oldIndirectDepth] with roots [oldIndirectRoots]; it becomes
     * direct at [newDepth].
     */
    fun moveIndirectUpstreamToDirect(
        scheduler: Scheduler,
        oldIndirectDepth: Int,
        oldIndirectRoots: Set<MuxDeferredNode<*, *, *>>,
        newDepth: Int,
    ) {
        // `or` is non-short-circuiting: every tracker update runs even if an earlier one already
        // returned true.
        if (
            depthTracker.addDirectUpstream(oldDepth = null, newDepth) or
                depthTracker.removeIndirectUpstream(depth = oldIndirectDepth) or
                depthTracker.updateIndirectRoots(removals = oldIndirectRoots)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Records that an indirect upstream's depth changed from [oldDepth] to [newDepth], and that
     * its indirect roots changed by [removals]/[additions].
     */
    fun adjustIndirectUpstream(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        removals: Set<MuxDeferredNode<*, *, *>>,
        additions: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // Non-short-circuiting `or`: both updates must run. When this node is itself a
        // MuxDeferredNode it is passed as `butNot`, presumably so it is never recorded as one of
        // its own indirect roots — TODO confirm against DepthTracker.
        if (
            depthTracker.addIndirectUpstream(oldDepth, newDepth) or
                depthTracker.updateIndirectRoots(
                    additions,
                    removals,
                    butNot = this as? MuxDeferredNode<*, *, *>,
                )
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Converts an upstream from direct to indirect.
     *
     * The upstream was direct at [oldDepth]; it becomes indirect at [newDepth] with roots
     * [newIndirectSet].
     */
    fun moveDirectUpstreamToIndirect(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        newIndirectSet: Set<MuxDeferredNode<*, *, *>>,
    ) {
        // Non-short-circuiting `or`: all three updates must run.
        if (
            depthTracker.addIndirectUpstream(oldDepth = null, newDepth) or
                depthTracker.removeDirectUpstream(oldDepth) or
                depthTracker.updateIndirectRoots(
                    additions = newIndirectSet,
                    butNot = this as? MuxDeferredNode<*, *, *>,
                )
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /** Detaches the branch for [key] and removes its direct depth contribution ([depth]). */
    fun removeDirectUpstream(scheduler: Scheduler, depth: Int, key: K) {
        switchedIn.remove(key)
        if (depthTracker.removeDirectUpstream(depth)) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Detaches the branch for [key] and removes its indirect depth contribution.
     *
     * Removes both [oldDepth] and the roots in [indirectSet] from [depthTracker].
     */
    fun removeIndirectUpstream(
        scheduler: Scheduler,
        oldDepth: Int,
        indirectSet: Set<MuxDeferredNode<*, *, *>>,
        key: K,
    ) {
        switchedIn.remove(key)
        // Non-short-circuiting `or`: both updates must run.
        if (
            depthTracker.removeIndirectUpstream(oldDepth) or
                depthTracker.updateIndirectRoots(removals = indirectSet)
        ) {
            depthTracker.schedule(scheduler, this)
        }
    }

    /** Applies pending depth changes, if any, propagating them to [downstreamSet]. */
    fun visitCompact(scheduler: Scheduler) {
        if (depthTracker.isDirty()) {
            depthTracker.applyChanges(scheduler, downstreamSet, this)
        }
    }

    /** Schedules this node for evaluation with the scheduler of [evalScope]. */
    fun schedule(evalScope: EvalScope) {
        // TODO: Potential optimization
        // Detect if this node is guaranteed to have a single upstream within this transaction,
        // then bypass scheduling it. Instead immediately schedule its downstream and treat this
        // MuxNode as a Pull (effectively making it a mapCheap).
        depthTracker.schedule(evalScope.scheduler, this)
    }

    /**
     * An input branch of a mux node, associated with a key.
     *
     * Depth notifications received from the upstream are forwarded to the enclosing [MuxNode],
     * tagged with [key] where the outer method needs it.
     */
    inner class BranchNode(val key: K) : SchedulableNode {

        /** [Schedulable] handle for this branch; passed to the upstream as its downstream. */
        val schedulable = Schedulable.N(this)

        /** Connection to this branch's upstream; assigned when the upstream is activated. */
        lateinit var upstream: NodeConnection<V>

        /**
         * Called when the upstream fires.
         *
         * Records the upstream's direct [PullNode] in [upstreamData] under [key], then schedules
         * the enclosing mux node.
         */
        override fun schedule(logIndent: Int, evalScope: EvalScope) {
            logDuration(logIndent, "MuxBranchNode.schedule") {
                if (this@MuxNode is MuxPromptNode && this@MuxNode.name != null) {
                    logLn("[${this@MuxNode}] scheduling $key")
                }
                upstreamData[key] = upstream.directUpstream
                this@MuxNode.schedule(evalScope)
            }
        }

        override fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) {
            this@MuxNode.adjustDirectUpstream(scheduler, oldDepth, newDepth)
        }

        override fun moveIndirectUpstreamToDirect(
            scheduler: Scheduler,
            oldIndirectDepth: Int,
            oldIndirectSet: Set<MuxDeferredNode<*, *, *>>,
            newDirectDepth: Int,
        ) {
            this@MuxNode.moveIndirectUpstreamToDirect(
                scheduler,
                oldIndirectDepth,
                oldIndirectSet,
                newDirectDepth,
            )
        }

        override fun adjustIndirectUpstream(
            scheduler: Scheduler,
            oldDepth: Int,
            newDepth: Int,
            removals: Set<MuxDeferredNode<*, *, *>>,
            additions: Set<MuxDeferredNode<*, *, *>>,
        ) {
            this@MuxNode.adjustIndirectUpstream(scheduler, oldDepth, newDepth, removals, additions)
        }

        override fun moveDirectUpstreamToIndirect(
            scheduler: Scheduler,
            oldDirectDepth: Int,
            newIndirectDepth: Int,
            newIndirectSet: Set<MuxDeferredNode<*, *, *>>,
        ) {
            this@MuxNode.moveDirectUpstreamToIndirect(
                scheduler,
                oldDirectDepth,
                newIndirectDepth,
                newIndirectSet,
            )
        }

        override fun removeDirectUpstream(scheduler: Scheduler, depth: Int) {
            // Explicit outer receiver: this member shadows the outer name, and the 3-arg overload
            // lives on the enclosing MuxNode.
            this@MuxNode.removeDirectUpstream(scheduler, depth, key)
        }

        override fun removeIndirectUpstream(
            scheduler: Scheduler,
            depth: Int,
            indirectSet: Set<MuxDeferredNode<*, *, *>>,
        ) {
            // Explicit outer receiver, for the same reason as removeDirectUpstream above.
            this@MuxNode.removeIndirectUpstream(scheduler, depth, indirectSet, key)
        }

        override fun toString(): String = "MuxBranchNode(key=$key, mux=${this@MuxNode})"
    }
}
270
/** Top-level alias for the inner class [MuxNode.BranchNode], so it can be named without the outer type. */
internal typealias BranchNode<W, K, V> = MuxNode<W, K, V>.BranchNode
272
/**
 * Tracks the lifecycle of a [MuxNode] in the network. Essentially a mutable reference to a
 * [MuxLifecycleState].
 *
 * @property lifecycleState the current state; [activate] moves it out of
 *   [MuxLifecycleState.Inactive].
 */
internal class MuxLifecycle<W, K, V>(var lifecycleState: MuxLifecycleState<W, K, V>) :
    EventsImpl<MuxResult<W, K, V>> {

    override fun toString(): String = "MuxLifecycle[$hashString][$lifecycleState]"

    /**
     * Connects [downstream] to the mux node behind this lifecycle, activating it first if needed.
     *
     * - [MuxLifecycleState.Dead]: returns `null`.
     * - [MuxLifecycleState.Active]: attaches [downstream] to the existing node. `needsEval` is
     *   true when the node's cached epoch equals the epoch of [evalScope].
     * - [MuxLifecycleState.Inactive]: asks the [MuxActivator] for a node.
     *   - If it returns `null`, the lifecycle becomes Dead and `null` is returned.
     *   - Otherwise the lifecycle becomes Active and the optional post-activation callback runs.
     *     Then [downstream] is attached, and the result has `needsEval = false`.
     */
    override fun activate(
        evalScope: EvalScope,
        downstream: Schedulable,
    ): ActivationResult<MuxResult<W, K, V>>? {
        val state = lifecycleState
        return when (state) {
            is MuxLifecycleState.Dead -> null
            is MuxLifecycleState.Active -> {
                val activeNode = state.node
                activeNode.addDownstreamLocked(downstream)
                ActivationResult(
                    connection = NodeConnection(activeNode, activeNode),
                    needsEval = activeNode.hasCurrentValueLocked(evalScope),
                )
            }
            is MuxLifecycleState.Inactive -> {
                val created = state.spec.activate(evalScope, this)
                if (created == null) {
                    lifecycleState = MuxLifecycleState.Dead
                    null
                } else {
                    val (newNode, postActivate) = created
                    // The Active state is published before the callback runs and before the
                    // downstream is attached.
                    lifecycleState = MuxLifecycleState.Active(newNode)
                    postActivate?.invoke()
                    newNode.addDownstreamLocked(downstream)
                    ActivationResult(connection = NodeConnection(newNode, newNode), needsEval = false)
                }
            }
        }
    }
}
313
/** The states a [MuxLifecycle] can be in. */
internal sealed interface MuxLifecycleState<out W, out K, out V> {

    /** Not yet activated; [spec] creates the node on the first activation. */
    class Inactive<W, K, V>(val spec: MuxActivator<W, K, V>) : MuxLifecycleState<W, K, V> {
        override fun toString(): String {
            return "Inactive"
        }
    }

    /** Activated and backed by [node]. */
    class Active<W, K, V>(val node: MuxNode<W, K, V>) : MuxLifecycleState<W, K, V> {
        override fun toString(): String {
            return "Active(node=$node)"
        }
    }

    /**
     * Terminal state: [MuxLifecycle.activate] returns `null` while in this state.
     *
     * It is entered when the [MuxActivator] produces no node.
     */
    data object Dead : MuxLifecycleState<Nothing, Nothing, Nothing>
}
325
/** Creates the [MuxNode] behind a [MuxLifecycle] when it is first activated. */
internal interface MuxActivator<W, K, V> {

    /**
     * Creates and activates a node for [lifecycle].
     *
     * @return either `null` or the new node paired with an optional callback.
     *   - `null` means no node was produced; [MuxLifecycle.activate] then marks the lifecycle as
     *     [MuxLifecycleState.Dead].
     *   - The callback is invoked by [MuxLifecycle.activate] after the node is recorded as active,
     *     but before the first downstream is attached.
     */
    fun activate(
        evalScope: EvalScope,
        lifecycle: MuxLifecycle<W, K, V>,
    ): Pair<MuxNode<W, K, V>, (() -> Unit)?>?
}
332
/**
 * Creates an [EventsImpl] backed by a fresh [MuxLifecycle].
 *
 * The lifecycle starts in [MuxLifecycleState.Inactive] and will use [onSubscribe] to create its
 * node on first activation.
 */
internal inline fun <W, K, V> MuxLifecycle(
    onSubscribe: MuxActivator<W, K, V>
): EventsImpl<MuxResult<W, K, V>> {
    val initialState: MuxLifecycleState<W, K, V> = MuxLifecycleState.Inactive(onSubscribe)
    return MuxLifecycle(initialState)
}
336
/**
 * Converts a [MapHolder]-backed mux result into a plain [Map].
 *
 * Each upstream [PullNode] is replaced by its push event, read via `getPushEvent` in the
 * evaluation scope of the transaction.
 */
internal fun <K, V> EventsImpl<MuxResult<MapHolder.W, K, V>>.awaitValues(): EventsImpl<Map<K, V>> {
    val source = this
    return mapImpl({ source }) { results, logIndent ->
        val scope = this
        results.asMapHolder().unwrapped.mapValues { (_, pullNode) ->
            pullNode.getPushEvent(logIndent, scope)
        }
    }
}
341
342 // activation logic
343
/**
 * Activates every upstream produced by [getStorage] and wires the resulting branches into this
 * node.
 *
 * For each `(key, events)` entry:
 * - A new [BranchNode] is created, and its schedulable is passed to `events.activate` as the
 *   downstream.
 * - Entries whose activation returns `null` are skipped.
 * - Otherwise the branch's upstream is set to the returned connection and the branch is stored
 *   in [MuxNode.switchedIn].
 * - When activation reports `needsEval`, the connection's direct upstream is also stored in
 *   [MuxNode.upstreamData].
 *
 * Both stores are created with a capacity equal to the total number of entries, including entries
 * whose activation failed.
 */
internal fun <W, K, V> MuxNode<W, K, V>.initializeUpstream(
    evalScope: EvalScope,
    getStorage: EvalScope.() -> Iterable<Map.Entry<K, EventsImpl<V>>>,
    storeFactory: MutableMapK.Factory<W, K>,
) {
    // Activate all upstreams first; a null element marks an entry whose activation failed.
    val activations =
        evalScope.getStorage().map { (key, events) ->
            val branch = BranchNode(key)
            events.activate(evalScope, branch.schedulable)?.let { (connection, needsEval) ->
                branch.upstream = connection
                Triple(key, branch, if (needsEval) connection.directUpstream else null)
            }
        }

    // Size both stores for every entry (failed ones included).
    val capacity = activations.size
    switchedIn = storeFactory.create(capacity)
    upstreamData = storeFactory.create(capacity)

    for (activation in activations) {
        val (key, branch, pendingUpstream) = activation ?: continue
        switchedIn[key] = branch
        if (pendingUpstream != null) {
            upstreamData[key] = pendingUpstream
        }
    }
}
373
/**
 * Seeds this node's [MuxNode.depthTracker] from the snapshot depth of every branch in
 * [MuxNode.switchedIn].
 *
 * - A direct upstream contributes a direct depth.
 * - An indirect upstream contributes an indirect depth plus its indirect roots.
 *
 * The tracker's boolean results are ignored and nothing is scheduled here. Presumably this is
 * called right after [initializeUpstream] — TODO confirm.
 */
internal fun <W, K, V> MuxNode<W, K, V>.initializeDepth() {
    switchedIn.forEach { (_, branch) ->
        val upstreamTracker = branch.upstream.depthTracker
        if (!upstreamTracker.snapshotIsDirect) {
            depthTracker.addIndirectUpstream(
                oldDepth = null,
                newDepth = upstreamTracker.snapshotIndirectDepth,
            )
            depthTracker.updateIndirectRoots(
                additions = upstreamTracker.snapshotIndirectRoots,
                butNot = null,
            )
        } else {
            depthTracker.addDirectUpstream(
                oldDepth = null,
                newDepth = upstreamTracker.snapshotDirectDepth,
            )
        }
    }
}
394