1 /*
2  * Copyright 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.compose.runtime.internal
18 
19 import androidx.compose.runtime.platform.makeSynchronizedObject
20 import androidx.compose.runtime.platform.synchronized
21 
/**
 * Per-thread storage with a cheaper lookup than [ThreadLocal], since no weak reference is
 * involved.
 *
 * The main thread's value (the thread whose id equals [MainThreadId]) lives in a plain field.
 * Every other thread's value lives in a [ThreadMap] keyed by [currentThreadId]. Writes to that
 * map are serialized on `writeMutex`, and a replacement map is published only when the writing
 * thread's id is not already present in the current one.
 *
 * Because nothing here is weakly held, a stored value stays reachable until it is overwritten.
 * Only use this class where every write is undone by a later write in a `try`/`finally` (as in
 * [androidx.compose.runtime.snapshots.Snapshot.enter]); otherwise the value can leak.
 *
 * On platforms whose [ThreadLocal] does not carry the JVM/ART overhead, a [ThreadLocal]-based
 * implementation can be used instead.
 */
internal class SnapshotThreadLocal<T> {
    private val map = AtomicReference(emptyThreadMap)
    private val writeMutex = makeSynchronizedObject()

    // Only ever read or written by the main thread, so it needs no synchronization.
    private var mainThreadValue: T? = null

    /** Returns the value most recently stored by the calling thread, or `null` if there is none. */
    @Suppress("UNCHECKED_CAST")
    fun get(): T? =
        when (val id = currentThreadId()) {
            MainThreadId -> mainThreadValue
            else -> map.get().get(id) as T?
        }

    /**
     * Stores [value] for the calling thread. Storing `null` lets a later map rebuild drop this
     * thread's entry.
     */
    fun set(value: T?) {
        val id = currentThreadId()
        if (id == MainThreadId) {
            mainThreadValue = value
            return
        }
        synchronized(writeMutex) {
            val published = map.get()
            if (!published.trySet(id, value)) {
                // The id is not in the map (new thread, or its null entry was compacted away),
                // so publish a rebuilt map that contains it.
                map.set(published.newWith(id, value))
            }
        }
    }
}
60 
/**
 * An array-backed map from thread id to value. Keys are kept in ascending order so that lookups
 * can binary-search them.
 *
 * The value of an existing entry can be overwritten in place with [trySet]. Adding a new key
 * requires building a new map with [newWith], which also drops entries whose value is `null`.
 *
 * @param size the number of entries; every construction in this file passes arrays of exactly
 *   this length.
 * @param keys thread ids, in ascending order.
 * @param values the value for each key, stored at the same index as its key.
 */
internal class ThreadMap(
    private val size: Int,
    private val keys: LongArray,
    private val values: Array<Any?>
) {
    /** Returns the value stored for [key], or `null` if [key] is not present. */
    fun get(key: Long): Any? {
        val index = find(key)
        return if (index >= 0) values[index] else null
    }

    /**
     * Set the value if it is already in the map. Otherwise a new map must be allocated to contain
     * the new entry.
     *
     * @return `true` if [key] was present and its value was replaced in place; `false` if the
     *   caller must use [newWith] instead.
     */
    fun trySet(key: Long, value: Any?): Boolean {
        val index = find(key)
        if (index < 0) return false
        values[index] = value
        return true
    }

    /**
     * Returns a new map that contains every entry of this map whose value is non-null, plus the
     * entry ([key], [value]) inserted at its sorted position. This map itself is not modified.
     *
     * Intended to be called only when [key] is absent, i.e. after [trySet] returned `false`.
     */
    fun newWith(key: Long, value: Any?): ThreadMap {
        val size = size
        // Null-valued entries are dropped while copying, so only live entries plus the new one
        // are counted.
        val newSize = values.count { it != null } + 1
        val newKeys = LongArray(newSize)
        val newValues = arrayOfNulls<Any?>(newSize)
        if (newSize > 1) {
            var dest = 0
            var source = 0
            // Copy live entries that sort before `key` until the insertion point is found.
            while (dest < newSize && source < size) {
                val oldKey = keys[source]
                val oldValue = values[source]
                if (oldKey > key) {
                    newKeys[dest] = key
                    newValues[dest] = value
                    dest++
                    // Continue with a loop without this check
                    break
                }
                if (oldValue != null) {
                    newKeys[dest] = oldKey
                    newValues[dest] = oldValue
                    dest++
                }
                source++
            }
            if (source == size) {
                // Appending a value to the end.
                newKeys[newSize - 1] = key
                newValues[newSize - 1] = value
            } else {
                // `key` has been placed; copy the remaining live entries after it.
                while (dest < newSize) {
                    val oldKey = keys[source]
                    val oldValue = values[source]
                    if (oldValue != null) {
                        newKeys[dest] = oldKey
                        newValues[dest] = oldValue
                        dest++
                    }
                    source++
                }
            }
        } else {
            // The only element
            newKeys[0] = key
            newValues[0] = value
        }
        return ThreadMap(newSize, newKeys, newValues)
    }

    /**
     * Binary-searches [keys] for [key].
     *
     * @return the index of [key] if present. Otherwise returns `-(insertionPoint + 1)`, where
     *   `insertionPoint` is the index at which [key] would be inserted to keep [keys] sorted; this
     *   value is always negative.
     */
    private fun find(key: Long): Int {
        var high = size - 1
        when (high) {
            -1 -> return -1
            0 -> {
                // Fast path for a single entry: the insertion point is 0 when key sorts before
                // keys[0], and 1 otherwise.
                val only = keys[0]
                return when {
                    only == key -> 0
                    only > key -> -1
                    else -> -2
                }
            }
        }
        var low = 0

        while (low <= high) {
            val mid = (low + high).ushr(1)
            val midVal = keys[mid]
            // Compare directly instead of subtracting. `midVal - key` can overflow for distant
            // ids and would then report the wrong ordering, missing keys that are present.
            when {
                midVal < key -> low = mid + 1
                midVal > key -> high = mid - 1
                else -> return mid
            }
        }
        return -(low + 1)
    }
}
152 
/** Shared zero-entry map; every [SnapshotThreadLocal] starts from this instance. */
private val emptyThreadMap: ThreadMap =
    ThreadMap(size = 0, keys = LongArray(0), values = emptyArray<Any?>())
154