/*
 * Copyright 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.compose.ui.graphics

import androidx.annotation.RestrictTo
import kotlin.jvm.JvmField
import kotlin.math.max
import kotlin.math.min

// TODO: We should probably move this to androidx.collection

/**
 * Interval in an [IntervalTree]. The interval is defined between a [start] and an [end] coordinate,
 * whose meanings are defined by the caller. An interval can also hold arbitrary [data] to be used
 * when looking at the results of queries with [IntervalTree.findOverlaps].
 *
 * Both bounds are inclusive: intervals that merely touch at an endpoint are considered overlapping,
 * and [contains] accepts [start] and [end] themselves.
 */
@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP_PREFIX)
open class Interval<T>(val start: Float, val end: Float, val data: T? = null) {
    /** Returns true if this interval overlaps with [other]. */
    fun overlaps(other: Interval<T>) = start <= other.end && end >= other.start

    /**
     * Returns true if this interval overlaps with the interval defined by [start] and [end].
     * [start] must be less than or equal to [end].
     */
    fun overlaps(start: Float, end: Float) = this.start <= end && this.end >= start

    /** Returns true if this interval contains [value]. */
    operator fun contains(value: Float) = value in start..end

    /**
     * Two intervals are equal when they have the same runtime class, bit-identical bounds and equal
     * [data].
     */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other == null || this::class != other::class) return false

        other as Interval<*>

        // Compare bit patterns instead of using IEEE `==`: IEEE equality treats 0f and -0f as
        // equal while their hash codes differ, which would break the equals/hashCode contract.
        return start.toBits() == other.start.toBits() &&
            end.toBits() == other.end.toBits() &&
            data == other.data
    }

    override fun hashCode(): Int {
        // Must hash the same bit patterns that equals() compares.
        var result = start.toBits()
        result = 31 * result + end.toBits()
        result = 31 * result + (data?.hashCode() ?: 0)
        return result
    }

    override fun toString(): String {
        return "Interval(start=$start, end=$end, data=$data)"
    }
}

/**
 * Represents an empty/invalid interval: its start is greater than its end, so [Interval.contains]
 * is always false for it. [IntervalTree.findFirstOverlap] returns this instance when no overlap is
 * found; compare against it by identity (`!==`).
 *
 * Note: [Float.MIN_VALUE] is the smallest *positive* Float, not the most negative one. The interval
 * is empty only because [Float.MAX_VALUE] is greater than it.
 */
internal val EmptyInterval: Interval<Any?> = Interval(Float.MAX_VALUE, Float.MIN_VALUE, null)

/**
 * An interval tree holds a list of intervals and allows for fast queries of intervals that overlap
 * any given interval. This can be used for instance to perform fast spatial queries like finding
 * all the segments in a path that overlap with a given vertical interval.
 *
 * Queries share a single internal traversal stack, so they are not reentrant: do not query or
 * modify the tree from inside a query callback.
 */
@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP_PREFIX)
class IntervalTree<T> {
    // Note: this interval tree is implemented as a binary red/black tree that gets
    // re-balanced on updates. There's nothing notable about this particular data
    // structure beyond what can be found in various descriptions of binary search
    // trees and red/black trees

    /**
     * Sentinel used instead of null for missing children and parents.
     *
     * Its bounds are the identity elements of `min`/`max` (+infinity and -infinity), so
     * [updateNodeData] can fold child bounds without special-casing missing children.
     * [Float.MIN_VALUE] must not be used as the lower sentinel: it is the smallest *positive* Float
     * and would inflate `max` for subtrees whose intervals all end below zero.
     */
    @JvmField
    internal val terminator =
        Node(Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY, null, TreeColorBlack).also {
            // While the sentinel itself is being constructed, `terminator` is not assigned yet, so
            // its links would otherwise be left unset (null at runtime) despite their non-null
            // type. Point them back at the sentinel so walking from it is safe, e.g.
            // lowestNode() on an empty tree.
            it.left = it
            it.right = it
            it.parent = it
        }

    /** Root of the tree, or [terminator] when the tree is empty. */
    @JvmField internal var root = terminator

    /** Reusable traversal stack shared by all queries (see [forEach]). */
    @JvmField internal val stack = ArrayList<Node>()

    /**
     * Clears this tree and prepares it for reuse. After calling [clear], [findOverlaps] returns an
     * empty list, [findFirstOverlap] returns [EmptyInterval], and iteration yields no intervals.
     */
    fun clear() {
        root = terminator
        // Release any node references still held by an interrupted traversal.
        stack.clear()
    }

    /**
     * Finds the first interval that overlaps with the specified [interval]. If no overlap can be
     * found, returns [EmptyInterval].
     */
    fun findFirstOverlap(interval: ClosedFloatingPointRange<Float>) =
        findFirstOverlap(interval.start, interval.endInclusive)

    /**
     * Finds the first interval that overlaps with the interval defined by [start] and [end]. If no
     * overlap can be found, returns [EmptyInterval]. [start] *must* be less than or equal to
     * [end].
     */
    fun findFirstOverlap(start: Float, end: Float = start): Interval<T> {
        if (root !== terminator) {
            forEach(start, end) { interval ->
                // Non-local return: leaves forEach early, with its traversal stack not emptied.
                return interval
            }
        }
        @Suppress("UNCHECKED_CAST") return EmptyInterval as Interval<T>
    }

    /**
     * Finds all the intervals that overlap with the specified [interval]. If [results] is
     * specified, [results] is returned, otherwise a new [MutableList] is returned.
     */
    fun findOverlaps(
        interval: ClosedFloatingPointRange<Float>,
        results: MutableList<Interval<T>> = mutableListOf()
    ) = findOverlaps(interval.start, interval.endInclusive, results)

    /**
     * Finds all the intervals that overlap with the interval defined by [start] and [end]. [start]
     * *must* be less than or equal to [end]. If [results] is specified, matches are appended to it
     * and it is returned, otherwise a new [MutableList] is returned.
     */
    fun findOverlaps(
        start: Float,
        end: Float = start,
        results: MutableList<Interval<T>> = mutableListOf()
    ): MutableList<Interval<T>> {
        forEach(start, end) { interval -> results.add(interval) }
        return results
    }

    /** Executes [block] for each interval that overlaps the specified [interval]. */
    internal inline fun forEach(
        interval: ClosedFloatingPointRange<Float>,
        block: (Interval<T>) -> Unit
    ) = forEach(interval.start, interval.endInclusive, block)

    /**
     * Executes [block] for each interval that overlaps with the interval defined by [start] and
     * [end]. [start] *must* be less than or equal to [end].
     *
     * Uses the shared [stack], so [block] must not query this tree again.
     */
    internal inline fun forEach(start: Float, end: Float = start, block: (Interval<T>) -> Unit) {
        if (root !== terminator) {
            val s = stack
            // A previous traversal may have been abandoned via a non-local return from its block
            // (findFirstOverlap does this), leaving stale nodes behind. Start from an empty stack
            // so those nodes are not visited, and reported, a second time.
            s.clear()
            s.add(root)
            while (s.isNotEmpty()) {
                val node = s.removeAt(s.lastIndex)
                if (node.overlaps(start, end)) block(node)
                // Only descend into subtrees whose bounding range can still reach [start, end].
                if (node.left !== terminator && node.left.max >= start) {
                    s.add(node.left)
                }
                if (node.right !== terminator && node.right.min <= end) {
                    s.add(node.right)
                }
            }
        }
    }

    /** Returns true if [value] is inside any of the intervals in this tree. */
    operator fun contains(value: Float) = findFirstOverlap(value, value) !== EmptyInterval

    /** Returns true if the specified [interval] overlaps with any of the intervals in this tree. */
    operator fun contains(interval: ClosedFloatingPointRange<Float>) =
        findFirstOverlap(interval.start, interval.endInclusive) !== EmptyInterval

    /**
     * Returns an iterator over all intervals in ascending order of [Interval.start] (in-order
     * traversal). The iterator does not support concurrent modification of the tree.
     */
    operator fun iterator(): Iterator<Interval<T>> {
        return object : Iterator<Interval<T>> {
            private var next = root.lowestNode()

            override fun hasNext(): Boolean {
                return next !== terminator
            }

            override fun next(): Interval<T> {
                val node = next
                if (node === terminator) throw NoSuchElementException()
                next = node.next()
                return node
            }
        }
    }

    /** Adds the specified [Interval] to the interval tree. */
    operator fun plusAssign(interval: Interval<T>) {
        addInterval(interval.start, interval.end, interval.data)
    }

    /**
     * Adds the interval defined between a [start] and an [end] coordinate.
     *
     * @param start The start coordinate of the interval
     * @param end The end coordinate of the interval, must be >= [start]
     * @param data Data to associate with the interval
     */
    fun addInterval(start: Float, end: Float, data: T?) {
        val node = Node(start, end, data, TreeColorRed)

        // Plain binary-search-tree insertion keyed on start (ties go left); balancing comes after.
        var current = root
        var parent = terminator
        while (current !== terminator) {
            parent = current
            current = if (node.start <= current.start) current.left else current.right
        }

        node.parent = parent
        when {
            parent === terminator -> root = node
            node.start <= parent.start -> parent.left = node
            else -> parent.right = node
        }

        // Propagate the new node's bounds to all of its ancestors.
        updateNodeData(node)

        rebalance(node)
    }

    /** Restores the red/black invariants after [target] was inserted as a red node. */
    private fun rebalance(target: Node) {
        var node = target

        while (node !== root && node.parent.color == TreeColorRed) {
            // The parent is red, so it is not the root and the grandparent is a real node.
            val ancestor = node.parent.parent
            if (node.parent === ancestor.left) {
                val right = ancestor.right
                if (right.color == TreeColorRed) {
                    // Red uncle: recolor and continue from the grandparent.
                    right.color = TreeColorBlack
                    node.parent.color = TreeColorBlack
                    ancestor.color = TreeColorRed
                    node = ancestor
                } else {
                    if (node === node.parent.right) {
                        node = node.parent
                        rotateLeft(node)
                    }
                    node.parent.color = TreeColorBlack
                    ancestor.color = TreeColorRed
                    rotateRight(ancestor)
                }
            } else {
                val left = ancestor.left
                if (left.color == TreeColorRed) {
                    // Red uncle: recolor and continue from the grandparent.
                    left.color = TreeColorBlack
                    node.parent.color = TreeColorBlack
                    ancestor.color = TreeColorRed
                    node = ancestor
                } else {
                    if (node === node.parent.left) {
                        node = node.parent
                        rotateRight(node)
                    }
                    node.parent.color = TreeColorBlack
                    ancestor.color = TreeColorRed
                    rotateLeft(ancestor)
                }
            }
        }

        root.color = TreeColorBlack
    }

    /** Rotates [node] down to the left; its right child takes its place. */
    private fun rotateLeft(node: Node) {
        val right = node.right
        node.right = right.left

        if (right.left !== terminator) {
            right.left.parent = node
        }

        right.parent = node.parent

        if (node.parent === terminator) {
            root = right
        } else {
            if (node.parent.left === node) {
                node.parent.left = right
            } else {
                node.parent.right = right
            }
        }

        right.left = node
        node.parent = right

        // Walks up from node, so the new parent (right) is refreshed as well.
        updateNodeData(node)
    }

    /** Rotates [node] down to the right; its left child takes its place. */
    private fun rotateRight(node: Node) {
        val left = node.left
        node.left = left.right

        if (left.right !== terminator) {
            left.right.parent = node
        }

        left.parent = node.parent

        if (node.parent === terminator) {
            root = left
        } else {
            if (node.parent.right === node) {
                node.parent.right = left
            } else {
                node.parent.left = left
            }
        }

        left.right = node
        node.parent = left

        // Walks up from node, so the new parent (left) is refreshed as well.
        updateNodeData(node)
    }

    /** Recomputes [Node.min] and [Node.max] for [node] and every ancestor up to the root. */
    private fun updateNodeData(node: Node) {
        var current = node
        while (current !== terminator) {
            current.min = min(current.start, min(current.left.min, current.right.min))
            current.max = max(current.end, max(current.left.max, current.right.max))
            current = current.parent
        }
    }

    /**
     * A tree node holding one interval. [min] and [max] are the smallest start and largest end
     * across this node's whole subtree, and are used to prune queries.
     */
    internal inner class Node(start: Float, end: Float, data: T?, var color: TreeColor) :
        Interval<T>(start, end, data) {
        var min: Float = start
        var max: Float = end

        var left: Node = terminator
        var right: Node = terminator
        var parent: Node = terminator

        /** Returns the leftmost node of this subtree (the node itself if it has no left child). */
        fun lowestNode(): Node {
            var node = this
            while (node.left !== terminator) {
                node = node.left
            }
            return node
        }

        /** Returns the in-order successor of this node, or [terminator] if there is none. */
        fun next(): Node {
            if (right !== terminator) {
                return right.lowestNode()
            }

            var a = this
            var b = parent
            while (b !== terminator && a === b.right) {
                a = b
                b = b.parent
            }

            return b
        }
    }
}

private typealias TreeColor = Int

private const val TreeColorRed = 0
private const val TreeColorBlack = 1