/*
 * Copyright 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.compose.runtime.internal

import androidx.compose.runtime.platform.makeSynchronizedObject
import androidx.compose.runtime.platform.synchronized

/**
 * This is similar to a [ThreadLocal] but has lower overhead because it avoids a weak reference.
 * This should only be used when the writes are delimited by a try...finally call that will clean up
 * the reference such as [androidx.compose.runtime.snapshots.Snapshot.enter] else the reference
 * could get pinned by the thread local causing a leak.
 *
 * [ThreadLocal] can be used to implement the actual for platforms that do not exhibit the same
 * overhead for thread locals as the JVM and ART.
 *
 * Implementation notes:
 * - The thread whose id equals [MainThreadId] stores its value in a plain field and never touches
 *   the map.
 * - Every other thread stores its value in a [ThreadMap] keyed by [currentThreadId]. Reads are not
 *   synchronized; writes are serialized by [writeMutex]. A thread only ever writes the slot keyed
 *   by its own id.
 */
internal class SnapshotThreadLocal<T> {
    private val map = AtomicReference(emptyThreadMap)
    private val writeMutex = makeSynchronizedObject()

    // Only read or written when currentThreadId() == MainThreadId.
    private var mainThreadValue: T? = null

    /** Returns the value most recently [set] by the calling thread, or `null` if there is none. */
    @Suppress("UNCHECKED_CAST")
    fun get(): T? {
        val threadId = currentThreadId()
        return if (threadId == MainThreadId) {
            mainThreadValue
        } else {
            map.get().get(threadId) as T?
        }
    }

    /**
     * Associates [value] with the calling thread. Passing `null` clears the value; the now-empty
     * slot is dropped the next time a new [ThreadMap] is built.
     */
    fun set(value: T?) {
        val key = currentThreadId()
        if (key == MainThreadId) {
            mainThreadValue = value
        } else {
            synchronized(writeMutex) {
                val current = map.get()
                // Fast path: this thread already owns a slot, so update it in place.
                if (current.trySet(key, value)) return
                // Slow path: build a map that also contains this thread's entry and publish it.
                map.set(current.newWith(key, value))
            }
        }
    }
}

/**
 * A small map from thread id to value, backed by two parallel arrays.
 *
 * The first [size] entries of [keys] are thread ids in ascending order, which is what allows [find]
 * to binary-search them. [values] holds the value for each key at the same index. The set of keys
 * is fixed once a map is constructed. Only the value of an existing key can change (via [trySet]).
 * Adding a key requires building a new map with [newWith].
 */
internal class ThreadMap(
    private val size: Int,
    private val keys: LongArray,
    private val values: Array<Any?>
) {
    /** Returns the value stored for [key], or `null` if [key] is not present. */
    fun get(key: Long): Any? {
        val index = find(key)
        return if (index >= 0) values[index] else null
    }

    /**
     * Set the value if it is already in the map. Otherwise a new map must be allocated to contain
     * the new entry.
     *
     * @return `true` if [key] was present and its value was replaced; `false` if [key] is absent
     *   and nothing was changed.
     */
    fun trySet(key: Long, value: Any?): Boolean {
        val index = find(key)
        if (index < 0) return false
        values[index] = value
        return true
    }

    /**
     * Returns a new map that contains every entry of this map whose value is non-null, plus
     * [key] -> [value]. Keys stay in ascending order.
     *
     * Callers should only use this when [key] is absent, e.g. after [trySet] returned `false`.
     * If [key] is already present with a non-null value, the result may contain [key] twice.
     */
    fun newWith(key: Long, value: Any?): ThreadMap {
        // Entries whose value is null are not carried over; this is how cleared slots are reclaimed.
        val newSize = values.count { it != null } + 1
        val newKeys = LongArray(newSize)
        val newValues = arrayOfNulls<Any?>(newSize)
        if (newSize > 1) {
            var dest = 0
            var source = 0
            // Copy live entries whose keys are smaller than [key], then insert [key] in front of the
            // first larger key.
            while (dest < newSize && source < size) {
                val oldKey = keys[source]
                val oldValue = values[source]
                if (oldKey > key) {
                    newKeys[dest] = key
                    newValues[dest] = value
                    dest++
                    // Continue with a loop without this check
                    break
                }
                if (oldValue != null) {
                    newKeys[dest] = oldKey
                    newValues[dest] = oldValue
                    dest++
                }
                source++
            }
            if (source == size) {
                // Every existing key is smaller than [key]: append it at the end.
                newKeys[newSize - 1] = key
                newValues[newSize - 1] = value
            } else {
                // [key] has been inserted; copy the remaining live entries in order.
                while (dest < newSize) {
                    val oldKey = keys[source]
                    val oldValue = values[source]
                    if (oldValue != null) {
                        newKeys[dest] = oldKey
                        newValues[dest] = oldValue
                        dest++
                    }
                    source++
                }
            }
        } else {
            // No live entries to carry over, so [key] is the only element.
            newKeys[0] = key
            newValues[0] = value
        }
        return ThreadMap(newSize, newKeys, newValues)
    }

    /**
     * Binary-searches the first [size] entries of [keys] for [key].
     *
     * @return the index of [key] if it is present. Otherwise returns `-(insertionPoint + 1)`, where
     *   `insertionPoint` is the index at which [key] would be inserted to keep [keys] sorted. The
     *   result is negative if and only if [key] is absent.
     */
    private fun find(key: Long): Int {
        var high = size - 1
        when (high) {
            -1 -> return -1
            0 -> {
                // Single-entry fast path, using the same insertion-point encoding as the search below.
                val only = keys[0]
                return when {
                    only == key -> 0
                    only > key -> -1 // [key] would be inserted at index 0
                    else -> -2 // [key] would be inserted at index 1
                }
            }
        }
        var low = 0

        while (low <= high) {
            val mid = (low + high).ushr(1)
            val midVal = keys[mid]
            // Compare directly instead of subtracting: `midVal - key` can overflow for widely
            // separated ids and report the wrong ordering.
            when {
                midVal < key -> low = mid + 1
                midVal > key -> high = mid - 1
                else -> return mid
            }
        }
        return -(low + 1)
    }
}

private val emptyThreadMap = ThreadMap(0, LongArray(0), emptyArray())