• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
<lambda>null2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.systemui.kairos.internal
18 
19 import com.android.systemui.kairos.internal.store.ConcurrentHashMapK
20 import com.android.systemui.kairos.internal.store.MutableArrayMapK
21 import com.android.systemui.kairos.internal.store.MutableMapK
22 import com.android.systemui.kairos.internal.store.StoreEntry
23 import com.android.systemui.kairos.internal.util.hashString
24 import com.android.systemui.kairos.util.Maybe
25 
/**
 * Internal representation of a state: a [store] that yields the current value, plus the [changes]
 * events that carry updates to it.
 *
 * @property name optional debug name of this state.
 * @property operatorName name of the operator that created this state.
 * @property changes events carrying updates to this state's value.
 * @property store backing storage queried by [getCurrentWithEpoch].
 */
internal open class StateImpl<out A>(
    val name: String?,
    val operatorName: String,
    val changes: EventsImpl<A>,
    val store: StateStore<A>,
) {
    /** Returns the current value paired with the epoch reported by [store]. */
    fun getCurrentWithEpoch(evalScope: EvalScope): Pair<A, Long> {
        val valueAndEpoch = store.getCurrentWithEpoch(evalScope)
        return valueAndEpoch
    }
}
35 
/**
 * Base [StateStore] for states whose value is computed from other states.
 *
 * The last computed value is held in [cache]. [validatedEpoch] is the newest epoch the cache has
 * been checked against. [invalidatedEpoch] is the epoch at which the cached value last changed to a
 * different value. Reads through [getCurrentWithEpoch] are memoized in [transactionCache], keyed by
 * the eval scope (presumably once per transaction; TODO confirm against `TransactionCache`).
 */
internal sealed class StateDerived<A> : StateStore<A>() {

    /** Epoch at which [cache] last changed to a different value; reported alongside every read. */
    @Volatile
    var invalidatedEpoch = Long.MIN_VALUE
        private set

    /** Newest epoch that [cache] has been validated against. */
    @Volatile
    protected var validatedEpoch = Long.MIN_VALUE
        private set

    /** The cached value, or [EmptyCache] when nothing has been cached yet. */
    @Volatile
    protected var cache: Any? = EmptyCache
        private set

    private val transactionCache = TransactionCache<Lazy<Pair<A, Long>>>()

    override fun getCurrentWithEpoch(evalScope: EvalScope): Pair<A, Long> =
        transactionCache.getOrPut(evalScope) { evalScope.deferAsync { pull(evalScope) } }.value

    /**
     * Recomputes the value via [recalc]. If the recomputed epoch is newer than [validatedEpoch], the
     * cache is validated and, when the value differs, replaced (bumping [invalidatedEpoch]).
     *
     * @return the recomputed value, or the cached value when [recalc] returns `null`, paired with
     *   [invalidatedEpoch].
     */
    @Suppress("UNCHECKED_CAST")
    fun pull(evalScope: EvalScope): Pair<A, Long> {
        // Null-check the returned *pair* explicitly. A `recalc(...)?.let { ... } ?: cache` chain
        // also falls back to the cache when the recomputed *value* is null (possible for nullable
        // A), which could return the cached value instead of the freshly recomputed `null`.
        val recalculated = recalc(evalScope)
        val result: A =
            if (recalculated == null) {
                cache as A
            } else {
                val (newValue, epoch) = recalculated
                if (epoch > validatedEpoch) {
                    validatedEpoch = epoch
                    if (cache != newValue) {
                        cache = newValue
                        invalidatedEpoch = epoch
                    }
                }
                newValue
            }
        return result to invalidatedEpoch
    }

    /** Returns the cached value without recomputing, or [Maybe.absent] if nothing is cached yet. */
    fun getCachedUnsafe(): Maybe<A> {
        @Suppress("UNCHECKED_CAST")
        return if (cache == EmptyCache) Maybe.absent else Maybe.present(cache as A)
    }

    /**
     * Computes the current value paired with the epoch it reflects, or returns `null` to signal
     * that [pull] should reuse [cache].
     */
    protected abstract fun recalc(evalScope: EvalScope): Pair<A, Long>?

    /**
     * Stores [value], pushed during [epoch], directly into the cache. Both epochs are set to
     * `epoch + 1`, mirroring `StateSource.updateState`, which records writes at
     * `evalScope.epoch + 1`.
     */
    fun setCacheFromPush(value: A, epoch: Long) {
        cache = value
        validatedEpoch = epoch + 1
        invalidatedEpoch = epoch + 1
    }

    /** Sentinel marking an empty [cache]. */
    private data object EmptyCache
}
87 
/**
 * Storage backing a [StateImpl]: produces the state's current value together with an epoch.
 *
 * Subclasses visible in this file are [StateSource], [DerivedMapCheap], and [StateDerived] (with
 * its own subclasses).
 */
internal sealed class StateStore<out S> {
    /** Returns the current value paired with its epoch, as observed from [evalScope]. */
    abstract fun getCurrentWithEpoch(
        evalScope: EvalScope,
    ): Pair<S, Long>
}
91 
/**
 * [StateStore] whose value is written by [updateState], which takes the push event of
 * [upstreamConnection], rather than being derived from other states.
 *
 * @param init lazily computed initial value.
 */
internal class StateSource<S>(init: Lazy<S>) : StateStore<S>() {
    constructor(init: S) : this(CompletableLazy(init))

    /** Connection whose push event supplies the new value in [updateState]. */
    lateinit var upstreamConnection: NodeConnection<S>

    // Note: Don't need to synchronize; we will never interleave reads and writes, since all writes
    // are performed at the end of a network step, after any reads would have taken place.

    @Volatile private var _current: Lazy<S> = init

    /** Epoch of the latest write: starts at 0 and is set to `evalScope.epoch + 1` by [updateState]. */
    @Volatile
    var writeEpoch = 0L
        private set

    override fun getCurrentWithEpoch(evalScope: EvalScope): Pair<S, Long> =
        _current.value to writeEpoch

    /** called by network after eval phase has completed */
    fun updateState(logIndent: Int, evalScope: EvalScope) {
        // write the latch
        _current = CompletableLazy(upstreamConnection.getPushEvent(logIndent, evalScope))
        writeEpoch = evalScope.epoch + 1
    }

    // Report this class's own name; the previous text said "StateImpl", which is a different type.
    override fun toString(): String = "StateSource(current=$_current, writeEpoch=$writeEpoch)"

    /** Returns the current value only if it has already been computed, without forcing it. */
    fun getStorageUnsafe(): Maybe<S> {
        // Read the volatile field once so the initialization check and the value read refer to the
        // same Lazy instance.
        val current = _current
        return if (current.isInitialized()) Maybe.present(current.value) else Maybe.absent
    }
}
121 
/** Creates a state that always holds [init]; its changes are [neverImpl]. */
internal fun <A> constState(name: String?, operatorName: String, init: A): StateImpl<A> =
    StateImpl(
        name = name,
        operatorName = operatorName,
        changes = neverImpl,
        store = StateSource(init),
    )
124 
/**
 * Creates a [StateSource]-backed state initialized from [init].
 *
 * Its changes are [getChanges] filtered to drop values equal to the store's current value. An
 * output is scheduled on [evalScope] that activates those changes with the store as downstream.
 * On successful activation, the returned connection becomes the store's upstream connection, and
 * the store is scheduled if activation reports that it needs evaluation.
 */
internal inline fun <A> activatedStateSource(
    name: String?,
    operatorName: String,
    evalScope: EvalScope,
    crossinline getChanges: EvalScope.() -> EventsImpl<A>,
    init: Lazy<A>,
): StateImpl<A> {
    val sourceStore = StateSource(init)
    // Suppress emissions that would not change the stored value.
    val dedupedChanges: EventsImpl<A> =
        filterImpl(getChanges) { candidate ->
            candidate != sourceStore.getCurrentWithEpoch(evalScope = this).first
        }
    evalScope.scheduleOutput(
        OneShot {
            val activation =
                dedupedChanges.activate(evalScope = this, downstream = Schedulable.S(sourceStore))
            if (activation != null) {
                val (connection, needsEval) = activation
                sourceStore.upstreamConnection = connection
                if (needsEval) schedule(sourceStore)
            }
        }
    )
    return StateImpl(name, operatorName, dedupedChanges, sourceStore)
}
148 
/**
 * Filters this event stream against [state].
 *
 * An emission equal to the state's current value is dropped. Any other emission is written into
 * the state's cache via [StateDerived.setCacheFromPush] at the current epoch, then forwarded. The
 * filtered stream is returned `.cached()`.
 */
private inline fun <A> EventsImpl<A>.calm(state: StateDerived<A>): EventsImpl<A> {
    val source: EventsImpl<A> = this
    return filterImpl({ source }) { new ->
            val current = state.getCurrentWithEpoch(evalScope = this).first
            val differs = new != current
            if (differs) {
                state.setCacheFromPush(new, epoch)
            }
            differs
        }
        .cached()
}
160 
/**
 * Maps [stateImpl] through [transform] without a value cache.
 *
 * The store ([DerivedMapCheap]) re-applies [transform] on every read. The changes apply
 * [transform] to each upstream change.
 */
internal fun <A, B> mapStateImplCheap(
    stateImpl: Init<StateImpl<A>>,
    name: String?,
    operatorName: String,
    transform: EvalScope.(A) -> B,
): StateImpl<B> {
    val mappedChanges = mapImpl({ stateImpl.connect(this).changes }) { value, _ -> transform(value) }
    val mappedStore = DerivedMapCheap(stateImpl, transform)
    return StateImpl(name, operatorName, mappedChanges, mappedStore)
}
173 
/**
 * Uncached [StateStore] that applies [transform] to [upstream]'s current value on every read and
 * reports [upstream]'s epoch unchanged.
 */
internal class DerivedMapCheap<A, B>(
    val upstream: Init<StateImpl<A>>,
    private val transform: EvalScope.(A) -> B,
) : StateStore<B>() {

    override fun getCurrentWithEpoch(evalScope: EvalScope): Pair<B, Long> {
        val upstreamState = upstream.connect(evalScope)
        val upstreamCurrent = upstreamState.getCurrentWithEpoch(evalScope)
        val mapped = evalScope.transform(upstreamCurrent.first)
        return Pair(mapped, upstreamCurrent.second)
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
186 
/**
 * Maps [stateImpl] through [transform] using a cached [DerivedMap] store.
 *
 * The mapped changes are passed through `calm`, which drops values equal to the store's current
 * value and pushes new values into its cache.
 */
internal fun <A, B> mapStateImpl(
    stateImpl: InitScope.() -> StateImpl<A>,
    name: String?,
    operatorName: String,
    transform: EvalScope.(A) -> B,
): StateImpl<B> {
    val derivedStore = DerivedMap(stateImpl, transform)
    val upstreamMapped = mapImpl({ stateImpl().changes }) { value, _ -> transform(value) }
    val calmedChanges = upstreamMapped.cached().calm(derivedStore)
    return StateImpl(name, operatorName, calmedChanges, derivedStore)
}
198 
/** Cached [StateDerived] whose value is [transform] applied to [upstream]'s current value. */
internal class DerivedMap<A, B>(
    val upstream: InitScope.() -> StateImpl<A>,
    private val transform: EvalScope.(A) -> B,
) : StateDerived<B>() {
    override fun toString(): String = "${this::class.simpleName}@$hashString"

    /**
     * Re-applies [transform] only when [upstream]'s epoch is newer than [validatedEpoch];
     * otherwise returns `null` so the cached value is reused.
     */
    override fun recalc(evalScope: EvalScope): Pair<B, Long>? {
        val (upstreamValue, upstreamEpoch) = evalScope.upstream().getCurrentWithEpoch(evalScope)
        if (upstreamEpoch <= validatedEpoch) {
            return null
        }
        return evalScope.transform(upstreamValue) to upstreamEpoch
    }
}
214 
/**
 * Flattens a state of states into a state holding the current inner state's value.
 *
 * @param operator operator name recorded on the resulting state.
 */
internal fun <A> flattenStateImpl(
    stateImpl: InitScope.() -> StateImpl<StateImpl<A>>,
    name: String?,
    operator: String,
): StateImpl<A> {
    // Whenever the outer state emits a new inner state, emit that inner state's current value.
    val onInnerSwitched =
        mapImpl({ stateImpl().changes }) { nextInner, _ ->
            nextInner.getCurrentWithEpoch(this).first
        }
    // For each newly emitted inner state: merge the switch event above with the inner state's own
    // changes. When both fire together, the combiner keeps the inner state's own new value.
    val innerPatches =
        mapImpl({ stateImpl().changes }) { nextInner, _ ->
            mergeNodes({ onInnerSwitched }, { nextInner.changes }) { _, fromInner -> fromInner }
        }
    // Start from the current inner state's changes, switching to the patched stream on each update.
    val flattenedChanges: EventsImpl<A> =
        switchPromptImplSingle(
            getStorage = { stateImpl().getCurrentWithEpoch(evalScope = this).first.changes },
            getPatches = { innerPatches },
        )
    val flattenedStore: DerivedFlatten<A> = DerivedFlatten(stateImpl)
    return StateImpl(name, operator, flattenedChanges.calm(flattenedStore), flattenedStore)
}
238 
/** Cached [StateDerived] holding the current value of [upstream]'s current inner state. */
internal class DerivedFlatten<A>(val upstream: InitScope.() -> StateImpl<StateImpl<A>>) :
    StateDerived<A>() {
    /** Always recomputes; the reported epoch is the newer of the outer and inner epochs. */
    override fun recalc(evalScope: EvalScope): Pair<A, Long> {
        val outer = evalScope.upstream().getCurrentWithEpoch(evalScope)
        val inner = outer.first.getCurrentWithEpoch(evalScope)
        val newestEpoch = outer.second.coerceAtLeast(inner.second)
        return Pair(inner.first, newestEpoch)
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
249 
/**
 * Maps [stateImpl] to an unnamed state of states via [transform], then flattens it into a state
 * named [name].
 */
@Suppress("NOTHING_TO_INLINE")
internal inline fun <A, B> flatMapStateImpl(
    noinline stateImpl: InitScope.() -> StateImpl<A>,
    name: String?,
    operatorName: String,
    noinline transform: EvalScope.(A) -> StateImpl<B>,
): StateImpl<B> {
    val stateOfStates: StateImpl<StateImpl<B>> =
        mapStateImpl(
            stateImpl = stateImpl,
            name = null,
            operatorName = operatorName,
            transform = transform,
        )
    return flattenStateImpl(stateImpl = { stateOfStates }, name = name, operator = operatorName)
}
260 
/**
 * Zips two states into one whose value is `transform(l1, l2)`.
 *
 * The inputs are first zipped into an unnamed list state via [zipStateList]. That list is then
 * mapped through [transform] with [mapStateImpl], which carries [name].
 */
internal fun <A, B, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: Init<StateImpl<A>>,
    l2: Init<StateImpl<B>>,
    transform: EvalScope.(A, B) -> Z,
): StateImpl<Z> {
    val asList =
        zipStateList(
            name = null,
            operatorName = operatorName,
            numStates = 2,
            states = init(null) { listOf(l1.connect(this), l2.connect(this)) },
        )
    return mapStateImpl({ asList }, name, operatorName) { values ->
        // Element types were erased when the states were collected into one list.
        @Suppress("UNCHECKED_CAST") transform(values[0] as A, values[1] as B)
    }
}
279 
/**
 * Zips three states into one whose value is `transform(l1, l2, l3)`.
 *
 * The inputs are first zipped into an unnamed list state via [zipStateList]. That list is then
 * mapped through [transform] with [mapStateImpl], which carries [name].
 */
internal fun <A, B, C, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: Init<StateImpl<A>>,
    l2: Init<StateImpl<B>>,
    l3: Init<StateImpl<C>>,
    transform: EvalScope.(A, B, C) -> Z,
): StateImpl<Z> {
    val asList =
        zipStateList(
            name = null,
            operatorName = operatorName,
            numStates = 3,
            states = init(null) { listOf(l1.connect(this), l2.connect(this), l3.connect(this)) },
        )
    return mapStateImpl({ asList }, name, operatorName) { values ->
        // Element types were erased when the states were collected into one list.
        @Suppress("UNCHECKED_CAST") transform(values[0] as A, values[1] as B, values[2] as C)
    }
}
299 
/**
 * Zips four states into one whose value is `transform(l1, l2, l3, l4)`.
 *
 * The inputs are first zipped into an unnamed list state via [zipStateList]. That list is then
 * mapped through [transform] with [mapStateImpl], which carries [name].
 */
internal fun <A, B, C, D, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: Init<StateImpl<A>>,
    l2: Init<StateImpl<B>>,
    l3: Init<StateImpl<C>>,
    l4: Init<StateImpl<D>>,
    transform: EvalScope.(A, B, C, D) -> Z,
): StateImpl<Z> {
    val asList =
        zipStateList(
            name = null,
            operatorName = operatorName,
            numStates = 4,
            states =
                init(null) {
                    listOf(l1.connect(this), l2.connect(this), l3.connect(this), l4.connect(this))
                },
        )
    return mapStateImpl({ asList }, name, operatorName) { values ->
        // Element types were erased when the states were collected into one list.
        @Suppress("UNCHECKED_CAST")
        transform(values[0] as A, values[1] as B, values[2] as C, values[3] as D)
    }
}
322 
/**
 * Zips five states into one whose value is `transform(l1, l2, l3, l4, l5)`.
 *
 * The inputs are first zipped into an unnamed list state via [zipStateList]. That list is then
 * mapped through [transform] with [mapStateImpl], which carries [name].
 */
internal fun <A, B, C, D, E, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: Init<StateImpl<A>>,
    l2: Init<StateImpl<B>>,
    l3: Init<StateImpl<C>>,
    l4: Init<StateImpl<D>>,
    l5: Init<StateImpl<E>>,
    transform: EvalScope.(A, B, C, D, E) -> Z,
): StateImpl<Z> {
    val asList =
        zipStateList(
            name = null,
            operatorName = operatorName,
            numStates = 5,
            states =
                init(null) {
                    listOf(
                        l1.connect(this),
                        l2.connect(this),
                        l3.connect(this),
                        l4.connect(this),
                        l5.connect(this),
                    )
                },
        )
    return mapStateImpl({ asList }, name, operatorName) { values ->
        // Element types were erased when the states were collected into one list.
        @Suppress("UNCHECKED_CAST")
        transform(values[0] as A, values[1] as B, values[2] as C, values[3] as D, values[4] as E)
    }
}
353 
/**
 * Zips a map of [numStates] states into a single state holding a map of their current values,
 * backed by a [ConcurrentHashMapK] store.
 */
internal fun <K, V> zipStateMap(
    name: String?,
    operatorName: String,
    numStates: Int,
    states: Init<Map<K, StateImpl<V>>>,
): StateImpl<Map<K, V>> {
    val entries = init(null) { states.connect(this).asIterable() }
    return zipStates(
        name = name,
        operatorName = operatorName,
        numStates = numStates,
        states = entries,
        storeFactory = ConcurrentHashMapK.Factory(),
    )
}
367 
/**
 * Zips a list of [numStates] states into a single state holding the list of their current values.
 *
 * The states are first zipped into an index-keyed [MutableArrayMapK] store. The resulting state
 * exposes that store's values as a [List].
 */
internal fun <V> zipStateList(
    name: String?,
    operatorName: String,
    numStates: Int,
    states: Init<List<StateImpl<V>>>,
): StateImpl<List<V>> {
    val indexed =
        zipStates(
            name = name,
            operatorName = operatorName,
            numStates = numStates,
            states = init(name) { states.connect(this).asIterableWithIndex() },
            storeFactory = MutableArrayMapK.Factory(),
        )
    // Like mapCheap, but with caching (or like map, but without the calm changes, as they are not
    // necessary).
    return StateImpl(
        name = name,
        operatorName = operatorName,
        changes = mapImpl({ indexed.changes }) { arrayStore, _ -> arrayStore.values.toList() },
        store =
            DerivedMap(
                upstream = { indexed },
                transform = { arrayStore -> arrayStore.values.toList() },
            ),
    )
}
391 
/**
 * Zips [numStates] keyed states into a single state holding a [MutableMapK] of their current
 * values, with the store created by [storeFactory].
 *
 * When there are no states, returns a constant empty store.
 */
internal fun <W, K, A> zipStates(
    name: String?,
    operatorName: String,
    numStates: Int,
    states: Init<Iterable<Map.Entry<K, StateImpl<A>>>>,
    storeFactory: MutableMapK.Factory<W, K>,
): StateImpl<MutableMapK<W, K, A>> {
    if (numStates == 0) return constState(name, operatorName, storeFactory.create(0))

    val stateStore = DerivedZipped(numStates, states, storeFactory)
    // No need for calm; invariant ensures that changes will only emit when there's a difference
    val upstreamSwitch =
        switchDeferredImpl(
            getStorage = {
                states
                    .connect(this)
                    .asSequence()
                    .map { (key, state) -> StoreEntry(key, state.changes) }
                    .asIterable()
            },
            getPatches = { neverImpl },
            storeFactory = storeFactory,
        )
    val changes =
        mapImpl({ upstreamSwitch }) { patch, logIndent ->
                // Build the new snapshot: take the pushed value for keys in this patch, and the
                // current value for every other key.
                val snapshot = storeFactory.create<A>(numStates)
                for ((key, state) in states.connect(this)) {
                    val value =
                        if (!patch.contains(key)) {
                            state.getCurrentWithEpoch(evalScope = this@mapImpl).first
                        } else {
                            patch.getValue(key).getPushEvent(logIndent, evalScope = this@mapImpl)
                        }
                    snapshot[key] = value
                }
                // Read the current value so that it is cached in this transaction and won't be
                // clobbered by the cache write
                stateStore.getCurrentWithEpoch(evalScope = this)
                stateStore.setCacheFromPush(snapshot, epoch)
                snapshot
            }
            .cached()
    return StateImpl(name, operatorName, changes, stateStore)
}
435 
/**
 * Cached [StateDerived] holding a store, created by [storeFactory], that maps each key of
 * [upstream] to the current value of its state.
 */
internal class DerivedZipped<W, K, A>(
    private val upstreamSize: Int,
    val upstream: Init<Iterable<Map.Entry<K, StateImpl<A>>>>,
    private val storeFactory: MutableMapK.Factory<W, K>,
) : StateDerived<MutableMapK<W, K, A>>() {
    /**
     * Always rebuilds the store. The reported epoch is the newest upstream epoch, and never less
     * than 0.
     */
    override fun recalc(evalScope: EvalScope): Pair<MutableMapK<W, K, A>, Long> {
        val snapshot = storeFactory.create<A>(upstreamSize)
        val newestEpoch =
            upstream.connect(evalScope).fold(0L) { newestSoFar, (key, state) ->
                val (value, epoch) = state.getCurrentWithEpoch(evalScope)
                snapshot[key] = value
                maxOf(newestSoFar, epoch)
            }
        return snapshot to newestEpoch
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
454 
/**
 * Zips a list of [numStates] states into a single state holding the list of their values.
 *
 * @param name debug name of the resulting state; applied in both the empty and non-empty cases.
 * @return a constant empty-list state when [numStates] is not positive; otherwise the
 *   [zipStateList] result.
 */
@Suppress("NOTHING_TO_INLINE")
internal inline fun <A> zipStates(
    name: String?,
    operatorName: String,
    numStates: Int,
    states: Init<List<StateImpl<A>>>,
): StateImpl<List<A>> =
    if (numStates <= 0) {
        constState(name, operatorName, emptyList())
    } else {
        // This result is returned directly (there is no outer named wrapper, unlike the fixed-arity
        // overloads), so it must carry the caller's name rather than `null`.
        zipStateList(name, operatorName, numStates, states)
    }
467