1 /*
<lambda>null2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package com.android.systemui.kairos.internal
18
19 import com.android.systemui.kairos.internal.store.ConcurrentHashMapK
20 import com.android.systemui.kairos.internal.store.MutableArrayMapK
21 import com.android.systemui.kairos.internal.store.MutableMapK
22 import com.android.systemui.kairos.internal.store.StoreEntry
23 import com.android.systemui.kairos.internal.util.hashString
24 import com.android.systemui.kairos.util.Maybe
25
/**
 * A reactive state: a stream of updates ([changes]) paired with a [store] that can produce the
 * current value on demand.
 *
 * @property name optional debug name for this state
 * @property operatorName name of the operator that produced this state
 * @property changes the event stream associated with this state's updates
 * @property store backing storage used to read the current value
 */
internal open class StateImpl<out A>(
    val name: String?,
    val operatorName: String,
    val changes: EventsImpl<A>,
    val store: StateStore<A>,
) {
    /**
     * Returns the current value together with the epoch reported by [store]; this simply
     * delegates to [StateStore.getCurrentWithEpoch].
     */
    fun getCurrentWithEpoch(evalScope: EvalScope): Pair<A, Long> {
        return store.getCurrentWithEpoch(evalScope)
    }
}
35
/**
 * Base class for [StateStore]s whose value is computed from other states.
 *
 * The most recently computed (or pushed) value is kept in [cache], alongside two epochs:
 * - [validatedEpoch]: the newest upstream epoch the cache has been checked against.
 * - [invalidatedEpoch]: the epoch at which the cached value last changed. This is the epoch
 *   reported alongside the value by [pull] and [getCurrentWithEpoch].
 */
internal sealed class StateDerived<A> : StateStore<A>() {

    @Volatile
    var invalidatedEpoch = Long.MIN_VALUE
        private set

    @Volatile
    protected var validatedEpoch = Long.MIN_VALUE
        private set

    // Holds the EmptyCache sentinel until a value is first computed or pushed.
    @Volatile
    protected var cache: Any? = EmptyCache
        private set

    private val transactionCache = TransactionCache<Lazy<Pair<A, Long>>>()

    /**
     * Returns the current value paired with [invalidatedEpoch].
     *
     * The [pull] is wrapped in [EvalScope.deferAsync] and memoized in [transactionCache]. It
     * presumably runs at most once per transaction — TODO confirm against [TransactionCache].
     */
    override fun getCurrentWithEpoch(evalScope: EvalScope): Pair<A, Long> =
        transactionCache.getOrPut(evalScope) { evalScope.deferAsync { pull(evalScope) } }.value

    /**
     * Recomputes the value via [recalc] and updates the cache.
     *
     * If [recalc] returns `null`, the cached value is returned unchanged. Otherwise the
     * recomputed value is returned. If its epoch is newer than [validatedEpoch], the cache is
     * revalidated at that epoch. When the value also differs from the cache, the cache is
     * replaced and [invalidatedEpoch] is bumped.
     *
     * @return the resulting value paired with [invalidatedEpoch]
     */
    fun pull(evalScope: EvalScope): Pair<A, Long> {
        // Null-check the Pair itself rather than using `?.let { ... } ?: cache`. With a nullable
        // [A], a recomputed `null` value would otherwise fall through to the (possibly stale) cache.
        val recalculated = recalc(evalScope)
        @Suppress("UNCHECKED_CAST")
        val result: A =
            if (recalculated == null) {
                cache as A
            } else {
                val (newValue, epoch) = recalculated
                if (epoch > validatedEpoch) {
                    validatedEpoch = epoch
                    if (cache != newValue) {
                        cache = newValue
                        invalidatedEpoch = epoch
                    }
                }
                newValue
            }
        return result to invalidatedEpoch
    }

    /**
     * Returns the cached value without recomputing it, or [Maybe.absent] if nothing has been
     * computed or pushed yet.
     */
    fun getCachedUnsafe(): Maybe<A> {
        @Suppress("UNCHECKED_CAST")
        return if (cache == EmptyCache) Maybe.absent else Maybe.present(cache as A)
    }

    /**
     * Recomputes the value from upstream, returning it with the newest upstream epoch observed.
     * Returns `null` to signal that the cached value should be reused.
     */
    protected abstract fun recalc(evalScope: EvalScope): Pair<A, Long>?

    /**
     * Stores [value] pushed during transaction [epoch].
     *
     * Both epochs are set to `epoch + 1`. This presumably lines up with the `writeEpoch` that
     * upstream [StateSource]s receive for the same transaction.
     */
    fun setCacheFromPush(value: A, epoch: Long) {
        cache = value
        validatedEpoch = epoch + 1
        invalidatedEpoch = epoch + 1
    }

    private data object EmptyCache
}
87
/** Storage behind a [StateImpl]: produces the state's current value on demand. */
internal sealed class StateStore<out S> {
    /**
     * Returns the current value paired with an epoch associated with it. For example,
     * [StateSource] reports its write epoch.
     */
    abstract fun getCurrentWithEpoch(evalScope: EvalScope): Pair<S, Long>
}
91
/**
 * A [StateStore] whose value is latched from an upstream connection by the network, rather than
 * derived from other states.
 *
 * @param init lazily-evaluated initial value
 */
internal class StateSource<S>(init: Lazy<S>) : StateStore<S>() {
    constructor(init: S) : this(CompletableLazy(init))

    /** Upstream connection; [updateState] reads the pushed value from it. Must be set first. */
    lateinit var upstreamConnection: NodeConnection<S>

    // Note: Don't need to synchronize; we will never interleave reads and writes, since all writes
    // are performed at the end of a network step, after any reads would have taken place.

    @Volatile private var _current: Lazy<S> = init

    /** Epoch stamped on the current value: 0 initially, then `epoch + 1` after [updateState]. */
    @Volatile
    var writeEpoch = 0L
        private set

    override fun getCurrentWithEpoch(evalScope: EvalScope): Pair<S, Long> =
        _current.value to writeEpoch

    /** called by network after eval phase has completed */
    fun updateState(logIndent: Int, evalScope: EvalScope) {
        // write the latch
        _current = CompletableLazy(upstreamConnection.getPushEvent(logIndent, evalScope))
        writeEpoch = evalScope.epoch + 1
    }

    // Use this class's own name so debug output isn't confused with a StateImpl.
    override fun toString(): String = "StateSource(current=$_current, writeEpoch=$writeEpoch)"

    /** Returns the stored value if it has already been evaluated, without forcing it. */
    fun getStorageUnsafe(): Maybe<S> =
        if (_current.isInitialized()) Maybe.present(_current.value) else Maybe.absent
}
121
/**
 * Creates a [StateImpl] backed by a [StateSource] holding [init], with [neverImpl] as its
 * changes.
 */
internal fun <A> constState(name: String?, operatorName: String, init: A): StateImpl<A> {
    val store: StateStore<A> = StateSource(init)
    return StateImpl(name = name, operatorName = operatorName, changes = neverImpl, store = store)
}
124
/**
 * Creates a [StateSource]-backed [StateImpl], initialized from [init], whose changes are the
 * events from [getChanges] that differ from the source's current value.
 *
 * Activation of those changes is scheduled as an output on [evalScope]. Once activated, the
 * resulting connection is stored as the source's upstream. The source is scheduled immediately
 * if activation reports that evaluation is needed.
 */
internal inline fun <A> activatedStateSource(
    name: String?,
    operatorName: String,
    evalScope: EvalScope,
    crossinline getChanges: EvalScope.() -> EventsImpl<A>,
    init: Lazy<A>,
): StateImpl<A> {
    val source = StateSource(init)
    // Drop any upstream emission that equals the value the source already holds.
    val distinctChanges: EventsImpl<A> =
        filterImpl(getChanges) { candidate ->
            val (currentValue, _) = source.getCurrentWithEpoch(evalScope = this)
            candidate != currentValue
        }
    evalScope.scheduleOutput(
        OneShot {
            distinctChanges.activate(evalScope = this, downstream = Schedulable.S(source))?.let {
                (connection, needsEval) ->
                source.upstreamConnection = connection
                if (needsEval) schedule(source)
            }
        }
    )
    return StateImpl(name, operatorName, distinctChanges, source)
}
148
/**
 * Filters this event stream down to emissions that differ from the current value of [state].
 * Each surviving emission is written into [state]'s cache via [StateDerived.setCacheFromPush].
 * The filtered stream is then [cached].
 */
private inline fun <A> EventsImpl<A>.calm(state: StateDerived<A>): EventsImpl<A> {
    val upstream = this
    return filterImpl({ upstream }) { new ->
            val changed = new != state.getCurrentWithEpoch(evalScope = this).first
            if (changed) state.setCacheFromPush(new, epoch)
            changed
        }
        .cached()
}
160
/**
 * Maps [stateImpl] through [transform] without caching the result. Every read of the returned
 * state re-applies [transform] to the upstream value (see [DerivedMapCheap]).
 */
internal fun <A, B> mapStateImplCheap(
    stateImpl: Init<StateImpl<A>>,
    name: String?,
    operatorName: String,
    transform: EvalScope.(A) -> B,
): StateImpl<B> {
    val mappedChanges: EventsImpl<B> =
        mapImpl({ stateImpl.connect(this).changes }) { value, _ -> transform(value) }
    return StateImpl(name, operatorName, mappedChanges, DerivedMapCheap(stateImpl, transform))
}
173
/**
 * A [StateStore] that recomputes its value on every read by applying [transform] to the current
 * value of [upstream]. The upstream epoch is passed through unchanged.
 */
internal class DerivedMapCheap<A, B>(
    val upstream: Init<StateImpl<A>>,
    private val transform: EvalScope.(A) -> B,
) : StateStore<B>() {

    override fun getCurrentWithEpoch(evalScope: EvalScope): Pair<B, Long> {
        val upstreamState = upstream.connect(evalScope)
        val (upstreamValue, upstreamEpoch) = upstreamState.getCurrentWithEpoch(evalScope)
        val mapped = transform(evalScope, upstreamValue)
        return Pair(mapped, upstreamEpoch)
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
186
/**
 * Maps [stateImpl] through [transform], caching the result in a [DerivedMap] store.
 *
 * The changes of the returned state are the upstream changes mapped through [transform] and
 * [cached]. They are then calmed against the store, so only values that differ from the current
 * one are emitted.
 */
internal fun <A, B> mapStateImpl(
    stateImpl: InitScope.() -> StateImpl<A>,
    name: String?,
    operatorName: String,
    transform: EvalScope.(A) -> B,
): StateImpl<B> {
    val derivedStore = DerivedMap(stateImpl, transform)
    val upstreamMapped: EventsImpl<B> =
        mapImpl({ stateImpl().changes }) { value, _ -> transform(value) }.cached()
    return StateImpl(
        name = name,
        operatorName = operatorName,
        changes = upstreamMapped.calm(derivedStore),
        store = derivedStore,
    )
}
198
/**
 * A [StateDerived] store that applies [transform] to the current value of [upstream]. It only
 * recomputes when the upstream epoch is newer than the last validated epoch.
 */
internal class DerivedMap<A, B>(
    val upstream: InitScope.() -> StateImpl<A>,
    private val transform: EvalScope.(A) -> B,
) : StateDerived<B>() {
    override fun toString(): String = "${this::class.simpleName}@$hashString"

    /** Returns `null` (reuse the cache) when upstream hasn't advanced past [validatedEpoch]. */
    override fun recalc(evalScope: EvalScope): Pair<B, Long>? {
        val (upstreamValue, upstreamEpoch) = evalScope.upstream().getCurrentWithEpoch(evalScope)
        if (upstreamEpoch <= validatedEpoch) return null
        return evalScope.transform(upstreamValue) to upstreamEpoch
    }
}
214
/**
 * Flattens a state whose value is itself a state into a state of the inner value.
 *
 * Changes are obtained with [switchPromptImplSingle]. Its storage is the current inner state's
 * changes, and its patches carry the changes of any newly emitted inner state. The result is
 * calmed against a [DerivedFlatten] store.
 */
internal fun <A> flattenStateImpl(
    stateImpl: InitScope.() -> StateImpl<StateImpl<A>>,
    name: String?,
    operator: String,
): StateImpl<A> {
    // emits the current value of the new inner state, when that state is emitted
    val switchEvents =
        mapImpl({ stateImpl().changes }) { emittedInner, _ ->
            val (innerValue, _) = emittedInner.getCurrentWithEpoch(this)
            innerValue
        }
    // emits the new value of the new inner state when that state is emitted, or
    // falls back to the current value if a new state is *not* being emitted this
    // transaction
    val innerChanges =
        mapImpl({ stateImpl().changes }) { emittedInner, _ ->
            mergeNodes({ switchEvents }, { emittedInner.changes }) { _, fromInner -> fromInner }
        }
    val switchedChanges: EventsImpl<A> =
        switchPromptImplSingle(
            getStorage = {
                val (currentInner, _) = stateImpl().getCurrentWithEpoch(evalScope = this)
                currentInner.changes
            },
            getPatches = { innerChanges },
        )
    val flattenedStore = DerivedFlatten(stateImpl)
    return StateImpl(name, operator, switchedChanges.calm(flattenedStore), flattenedStore)
}
238
/**
 * A [StateDerived] store that reads the current inner state from [upstream], then that inner
 * state's value. The reported epoch is the newer of the outer and inner epochs.
 */
internal class DerivedFlatten<A>(val upstream: InitScope.() -> StateImpl<StateImpl<A>>) :
    StateDerived<A>() {
    override fun recalc(evalScope: EvalScope): Pair<A, Long> {
        val outer = evalScope.upstream().getCurrentWithEpoch(evalScope)
        val inner = outer.first.getCurrentWithEpoch(evalScope)
        return Pair(inner.first, maxOf(outer.second, inner.second))
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
249
/**
 * Maps [stateImpl] to a state-of-states with [transform], then flattens it.
 * Only the flattened result carries [name]; the intermediate mapped state is unnamed.
 */
@Suppress("NOTHING_TO_INLINE")
internal inline fun <A, B> flatMapStateImpl(
    noinline stateImpl: InitScope.() -> StateImpl<A>,
    name: String?,
    operatorName: String,
    noinline transform: EvalScope.(A) -> StateImpl<B>,
): StateImpl<B> {
    val stateOfStates =
        mapStateImpl(
            stateImpl = stateImpl,
            name = null,
            operatorName = operatorName,
            transform = transform,
        )
    return flattenStateImpl(stateImpl = { stateOfStates }, name = name, operator = operatorName)
}
260
/**
 * Zips two states into one whose value is [transform] applied to their current values.
 * The inputs are combined with an unnamed [zipStateList]; [name] goes on the mapped result.
 */
internal fun <A, B, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: Init<StateImpl<A>>,
    l2: Init<StateImpl<B>>,
    transform: EvalScope.(A, B) -> Z,
): StateImpl<Z> {
    val zipped =
        zipStateList(
            name = null,
            operatorName = operatorName,
            numStates = 2,
            states = init(null) { listOf(l1.connect(this), l2.connect(this)) },
        )
    return mapStateImpl(stateImpl = { zipped }, name = name, operatorName = operatorName) { values ->
        @Suppress("UNCHECKED_CAST") transform(values[0] as A, values[1] as B)
    }
}
279
/**
 * Zips three states into one whose value is [transform] applied to their current values.
 * The inputs are combined with an unnamed [zipStateList]; [name] goes on the mapped result.
 */
internal fun <A, B, C, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: Init<StateImpl<A>>,
    l2: Init<StateImpl<B>>,
    l3: Init<StateImpl<C>>,
    transform: EvalScope.(A, B, C) -> Z,
): StateImpl<Z> {
    val zipped =
        zipStateList(
            name = null,
            operatorName = operatorName,
            numStates = 3,
            states = init(null) { listOf(l1.connect(this), l2.connect(this), l3.connect(this)) },
        )
    return mapStateImpl(stateImpl = { zipped }, name = name, operatorName = operatorName) { values ->
        @Suppress("UNCHECKED_CAST") transform(values[0] as A, values[1] as B, values[2] as C)
    }
}
299
/**
 * Zips four states into one whose value is [transform] applied to their current values.
 * The inputs are combined with an unnamed [zipStateList]; [name] goes on the mapped result.
 */
internal fun <A, B, C, D, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: Init<StateImpl<A>>,
    l2: Init<StateImpl<B>>,
    l3: Init<StateImpl<C>>,
    l4: Init<StateImpl<D>>,
    transform: EvalScope.(A, B, C, D) -> Z,
): StateImpl<Z> {
    val zipped =
        zipStateList(
            name = null,
            operatorName = operatorName,
            numStates = 4,
            states =
                init(null) {
                    listOf(l1.connect(this), l2.connect(this), l3.connect(this), l4.connect(this))
                },
        )
    return mapStateImpl(stateImpl = { zipped }, name = name, operatorName = operatorName) { values ->
        @Suppress("UNCHECKED_CAST")
        transform(values[0] as A, values[1] as B, values[2] as C, values[3] as D)
    }
}
322
/**
 * Zips five states into one whose value is [transform] applied to their current values.
 * The inputs are combined with an unnamed [zipStateList]; [name] goes on the mapped result.
 */
internal fun <A, B, C, D, E, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: Init<StateImpl<A>>,
    l2: Init<StateImpl<B>>,
    l3: Init<StateImpl<C>>,
    l4: Init<StateImpl<D>>,
    l5: Init<StateImpl<E>>,
    transform: EvalScope.(A, B, C, D, E) -> Z,
): StateImpl<Z> {
    val zipped =
        zipStateList(
            name = null,
            operatorName = operatorName,
            numStates = 5,
            states =
                init(null) {
                    listOf(
                        l1.connect(this),
                        l2.connect(this),
                        l3.connect(this),
                        l4.connect(this),
                        l5.connect(this),
                    )
                },
        )
    return mapStateImpl(stateImpl = { zipped }, name = name, operatorName = operatorName) { values ->
        @Suppress("UNCHECKED_CAST")
        transform(
            values[0] as A,
            values[1] as B,
            values[2] as C,
            values[3] as D,
            values[4] as E,
        )
    }
}
353
/**
 * Zips a map of states into a single state holding a map of their current values, stored in a
 * [ConcurrentHashMapK].
 */
internal fun <K, V> zipStateMap(
    name: String?,
    operatorName: String,
    numStates: Int,
    states: Init<Map<K, StateImpl<V>>>,
): StateImpl<Map<K, V>> {
    val entries = init(null) { states.connect(this).asIterable() }
    return zipStates(
        name = name,
        operatorName = operatorName,
        numStates = numStates,
        states = entries,
        storeFactory = ConcurrentHashMapK.Factory(),
    )
}
367
/**
 * Zips a list of states into a single state holding the list of their current values, in index
 * order. The values are gathered into a [MutableArrayMapK] and converted to a [List].
 */
internal fun <V> zipStateList(
    name: String?,
    operatorName: String,
    numStates: Int,
    states: Init<List<StateImpl<V>>>,
): StateImpl<List<V>> {
    val indexedZip =
        zipStates(
            name = name,
            operatorName = operatorName,
            numStates = numStates,
            states = init(name) { states.connect(this).asIterableWithIndex() },
            storeFactory = MutableArrayMapK.Factory(),
        )
    // Like mapCheap, but with caching (or like map, but without the calm changes, as they are not
    // necessary).
    return StateImpl(
        name = name,
        operatorName = operatorName,
        changes = mapImpl({ indexedZip.changes }) { indexed, _ -> indexed.values.toList() },
        store =
            DerivedMap(upstream = { indexedZip }, transform = { indexed -> indexed.values.toList() }),
    )
}
391
/**
 * Zips [numStates] keyed states into a single state holding a [MutableMapK] of their values,
 * created via [storeFactory].
 *
 * With zero inputs the result is a constant empty store. Otherwise the inputs' changes are
 * multiplexed with [switchDeferredImpl]. On each emission, a fresh store is filled with the
 * pushed value for every input present in the patch and the current value for every other input.
 * That store is then written into the [DerivedZipped] cache.
 */
internal fun <W, K, A> zipStates(
    name: String?,
    operatorName: String,
    numStates: Int,
    states: Init<Iterable<Map.Entry<K, StateImpl<A>>>>,
    storeFactory: MutableMapK.Factory<W, K>,
): StateImpl<MutableMapK<W, K, A>> {
    if (numStates == 0) {
        return constState(name, operatorName, storeFactory.create(0))
    }
    val stateStore = DerivedZipped(numStates, states, storeFactory)
    // No need for calm; invariant ensures that changes will only emit when there's a difference
    val muxedChanges =
        switchDeferredImpl(
            getStorage = {
                states
                    .connect(this)
                    .asSequence()
                    .map { (key, state) -> StoreEntry(key, state.changes) }
                    .asIterable()
            },
            getPatches = { neverImpl },
            storeFactory = storeFactory,
        )
    val changes =
        mapImpl({ muxedChanges }) { patch, logIndent ->
                val muxStore = storeFactory.create<A>(numStates)
                for ((key, state) in states.connect(this)) {
                    muxStore[key] =
                        if (patch.contains(key)) {
                            patch.getValue(key).getPushEvent(logIndent, evalScope = this@mapImpl)
                        } else {
                            state.getCurrentWithEpoch(evalScope = this@mapImpl).first
                        }
                }
                // Read the current value so that it is cached in this transaction and won't be
                // clobbered by the cache write
                stateStore.getCurrentWithEpoch(evalScope = this)
                stateStore.setCacheFromPush(muxStore, epoch)
                muxStore
            }
            .cached()
    return StateImpl(name, operatorName, changes, stateStore)
}
435
/**
 * A [StateDerived] store that collects the current value of every upstream state into a fresh
 * [MutableMapK] built by [storeFactory]. The reported epoch is the maximum upstream epoch,
 * starting from 0.
 */
internal class DerivedZipped<W, K, A>(
    private val upstreamSize: Int,
    val upstream: Init<Iterable<Map.Entry<K, StateImpl<A>>>>,
    private val storeFactory: MutableMapK.Factory<W, K>,
) : StateDerived<MutableMapK<W, K, A>>() {
    override fun recalc(evalScope: EvalScope): Pair<MutableMapK<W, K, A>, Long> {
        val store = storeFactory.create<A>(upstreamSize)
        val latestEpoch =
            upstream.connect(evalScope).fold(0L) { maxSoFar, (key, state) ->
                val (value, epoch) = state.getCurrentWithEpoch(evalScope)
                store[key] = value
                maxOf(maxSoFar, epoch)
            }
        return store to latestEpoch
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
454
/**
 * Zips the states in [states] into a single state holding the list of their current values, in
 * order.
 *
 * @param numStates number of states in [states]; when `<= 0`, a constant empty list is returned
 *   without connecting to [states]
 */
@Suppress("NOTHING_TO_INLINE")
internal inline fun <A> zipStates(
    name: String?,
    operatorName: String,
    numStates: Int,
    states: Init<List<StateImpl<A>>>,
): StateImpl<List<A>> =
    if (numStates <= 0) {
        constState(name, operatorName, emptyList())
    } else {
        // Forward [name] (rather than null) so the result is named just like the empty case;
        // nothing downstream re-applies it.
        zipStateList(name, operatorName, numStates, states)
    }
467