1 /*
<lambda>null2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.compose.ui.graphics
18 
19 import androidx.annotation.RestrictTo
20 import kotlin.jvm.JvmField
21 import kotlin.math.max
22 import kotlin.math.min
23 
24 // TODO: We should probably move this to androidx.collection
25 
/**
 * Interval in an [IntervalTree]. The interval is defined between a [start] and an [end] coordinate
 * (both inclusive), whose meanings are defined by the caller. An interval can also hold arbitrary
 * [data], which callers can use to identify the intervals returned by queries such as
 * [IntervalTree.findOverlaps].
 */
@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP_PREFIX)
open class Interval<T>(val start: Float, val end: Float, val data: T? = null) {
    /**
     * Returns true if this interval overlaps with [other]. Intervals that merely touch (one ends
     * exactly where the other starts) are considered overlapping.
     */
    fun overlaps(other: Interval<T>) = start <= other.end && end >= other.start

    /**
     * Returns true if this interval overlaps with the interval defined by [start] and [end].
     * [start] must be less than or equal to [end]. Touching endpoints count as an overlap.
     */
    fun overlaps(start: Float, end: Float) = this.start <= end && this.end >= start

    /** Returns true if this interval contains [value], endpoints included. */
    operator fun contains(value: Float) = value in start..end

    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other == null || this::class != other::class) return false

        other as Interval<*>

        // Coordinates are statically typed as Float, so these comparisons follow IEEE 754:
        // -0f == 0f, and NaN != NaN.
        if (start != other.start) return false
        if (end != other.end) return false
        if (data != other.data) return false

        return true
    }

    override fun hashCode(): Int {
        // equals() treats -0f and 0f as equal, but Float.hashCode() gives the two zeros different
        // hash codes. Map -0f to 0f before hashing so that equal intervals always hash equally.
        var result = (if (start == 0f) 0f else start).hashCode()
        result = 31 * result + (if (end == 0f) 0f else end).hashCode()
        result = 31 * result + (data?.hashCode() ?: 0)
        return result
    }

    override fun toString(): String {
        return "Interval(start=$start, end=$end, data=$data)"
    }
}
69 
/**
 * Represents an empty/invalid interval. This is the sentinel returned by
 * [IntervalTree.findFirstOverlap] when nothing overlaps the query; compare against it by
 * identity (`===` / `!==`).
 *
 * Its start is +Inf and its end is -Inf. It therefore contains no value, and it only "overlaps"
 * the unbounded query `-Inf..+Inf`. `Float.MIN_VALUE` must not be used as the end: it is the
 * smallest *positive* float, which would make this sentinel overlap ranges like
 * `0f..Float.MAX_VALUE`.
 */
internal val EmptyInterval: Interval<Any?> =
    Interval(Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY, null)
72 
/**
 * An interval tree holds a list of intervals and allows for fast queries of intervals that overlap
 * any given interval. This can be used for instance to perform fast spatial queries like finding
 * all the segments in a path that overlap with a given vertical interval.
 *
 * Queries use a scratch stack owned by the tree without any synchronization, so a tree must not be
 * queried from several threads at once. The tree must also not be modified while a query or an
 * iteration is in progress.
 */
@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP_PREFIX)
class IntervalTree<T> {
    // Note: this interval tree is implemented as a binary red/black tree, keyed on the start of
    // each interval, that gets re-balanced on updates. Every node also caches the smallest start
    // and the largest end found in its subtree, which lets queries skip entire subtrees. Beyond
    // that, there's nothing notable about this particular data structure compared to what can be
    // found in various descriptions of binary search trees and red/black trees.

    // Sentinel used instead of null for missing children and parents. Its start/end (and thus
    // its min/max) are +Inf/-Inf, the identity values of min()/max(), so it never widens the
    // cached bounds of real nodes. Float.MIN_VALUE would be wrong for the end: it is the
    // smallest *positive* float, not the most negative one.
    @JvmField
    internal val terminator =
        Node(Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY, null, TreeColorBlack)

    init {
        // While the terminator was being constructed, its left/right/parent links were
        // initialized from the `terminator` property, which was not assigned yet. Point them back
        // at the terminator so that walking from it (e.g. iterating an empty tree) never reaches
        // an unset reference.
        terminator.left = terminator
        terminator.right = terminator
        terminator.parent = terminator
    }

    @JvmField internal var root = terminator
    @JvmField internal val stack = ArrayList<Node>()

    /**
     * Clears this tree and prepares it for reuse. After calling [clear], the tree is empty:
     * [findOverlaps] returns no intervals, [findFirstOverlap] returns the empty sentinel, and
     * [contains] returns false.
     */
    fun clear() {
        root = terminator
    }

    /**
     * Finds the first interval that overlaps with the specified [interval]. If no overlap can be
     * found, return [EmptyInterval].
     */
    fun findFirstOverlap(interval: ClosedFloatingPointRange<Float>): Interval<T> =
        findFirstOverlap(interval.start, interval.endInclusive)

    /**
     * Finds the first interval that overlaps with the interval defined by [start] and [end]. If no
     * overlap can be found, return [EmptyInterval]. [start] *must* be lesser than or equal to
     * [end].
     */
    fun findFirstOverlap(start: Float, end: Float = start): Interval<T> {
        // forEach() does nothing on an empty tree. The non-local return below is safe because
        // forEach() restores its scratch stack in a finally block.
        forEach(start, end) { interval ->
            return interval
        }
        @Suppress("UNCHECKED_CAST") return EmptyInterval as Interval<T>
    }

    /**
     * Finds all the intervals that overlap with the specified [interval]. If [results] is
     * specified, [results] is returned, otherwise a new [MutableList] is returned.
     */
    fun findOverlaps(
        interval: ClosedFloatingPointRange<Float>,
        results: MutableList<Interval<T>> = mutableListOf()
    ): MutableList<Interval<T>> = findOverlaps(interval.start, interval.endInclusive, results)

    /**
     * Finds all the intervals that overlap with the interval defined by [start] and [end]. [start]
     * *must* be lesser than or equal to [end]. Matches are appended to [results], which is then
     * returned; by default a new [MutableList] is used.
     */
    fun findOverlaps(
        start: Float,
        end: Float = start,
        results: MutableList<Interval<T>> = mutableListOf()
    ): MutableList<Interval<T>> {
        forEach(start, end) { interval -> results.add(interval) }
        return results
    }

    /** Executes [block] for each interval that overlaps the specified [interval]. */
    internal inline fun forEach(
        interval: ClosedFloatingPointRange<Float>,
        block: (Interval<T>) -> Unit
    ) = forEach(interval.start, interval.endInclusive, block)

    /**
     * Executes [block] for each interval that overlaps with the interval defined by [start] and
     * [end]. [start] *must* be lesser than or equal to [end].
     *
     * [block] may query this tree again (nested queries are supported) and may exit early through
     * a non-local return, but it must not modify the tree.
     */
    internal inline fun forEach(start: Float, end: Float = start, block: (Interval<T>) -> Unit) {
        if (root !== terminator) {
            val s = stack
            // This traversal only owns the stack entries above `base`. Anything below belongs to
            // an enclosing traversal (when block() queries the tree again) and must be preserved.
            val base = s.size
            s.add(root)
            try {
                while (s.size > base) {
                    val node = s.removeAt(s.size - 1)
                    if (node.overlaps(start, end)) block(node)
                    // Only descend into subtrees whose cached [min, max] bounds can intersect
                    // the query interval.
                    val left = node.left
                    if (left !== terminator && left.max >= start && left.min <= end) {
                        s.add(left)
                    }
                    val right = node.right
                    if (right !== terminator && right.min <= end && right.max >= start) {
                        s.add(right)
                    }
                }
            } finally {
                // Runs even when block() exits through a non-local return (as findFirstOverlap
                // does) or throws. Otherwise stale nodes would stay on the shared stack, and the
                // next query would visit them again (duplicate or outdated results).
                while (s.size > base) s.removeAt(s.size - 1)
            }
        }
    }

    /** Returns true if [value] is inside any of the intervals in this tree. */
    operator fun contains(value: Float): Boolean =
        findFirstOverlap(value, value) !== EmptyInterval

    /** Returns true if the specified [interval] overlaps with any of the intervals in this tree. */
    operator fun contains(interval: ClosedFloatingPointRange<Float>): Boolean =
        findFirstOverlap(interval.start, interval.endInclusive) !== EmptyInterval

    /**
     * Returns an iterator over all the intervals in this tree, in non-decreasing order of
     * [Interval.start]. The tree must not be modified while the iterator is in use.
     */
    operator fun iterator(): Iterator<Interval<T>> {
        return object : Iterator<Interval<T>> {
            // On an empty tree, root is the terminator and lowestNode() returns it directly.
            private var next = root.lowestNode()

            override fun hasNext(): Boolean {
                return next !== terminator
            }

            override fun next(): Interval<T> {
                val node = next
                // Honor the Iterator contract instead of walking past the end of the tree.
                if (node === terminator) throw NoSuchElementException()
                next = node.next()
                return node
            }
        }
    }

    /** Adds the specified [Interval] to the interval tree. */
    operator fun plusAssign(interval: Interval<T>) {
        addInterval(interval.start, interval.end, interval.data)
    }

    /**
     * Adds the interval defined between a [start] and an [end] coordinate.
     *
     * @param start The start coordinate of the interval
     * @param end The end coordinate of the interval, must be >= [start]
     * @param data Data to associate with the interval
     */
    fun addInterval(start: Float, end: Float, data: T?) {
        val node = Node(start, end, data, TreeColorRed)

        // Update the tree without doing any balancing: walk down from the root, going left on
        // ties, until a free slot is found.
        var current = root
        var parent = terminator

        while (current !== terminator) {
            parent = current
            current =
                if (node.start <= current.start) {
                    current.left
                } else {
                    current.right
                }
        }

        node.parent = parent

        if (parent === terminator) {
            root = node
        } else {
            if (node.start <= parent.start) {
                parent.left = node
            } else {
                parent.right = node
            }
        }

        // Refresh the cached min/max of the new node and all of its ancestors.
        updateNodeData(node)

        rebalance(node)
    }

    /** Restores the red/black invariants after inserting the red node [target]. */
    private fun rebalance(target: Node) {
        var node = target

        while (node !== root && node.parent.color == TreeColorRed) {
            // The parent is red, so it is not the root and the grandparent is a real node.
            val ancestor = node.parent.parent
            if (node.parent === ancestor.left) {
                val right = ancestor.right
                if (right.color == TreeColorRed) {
                    // Red uncle: recolor and continue from the grandparent.
                    right.color = TreeColorBlack
                    node.parent.color = TreeColorBlack
                    ancestor.color = TreeColorRed
                    node = ancestor
                } else {
                    if (node === node.parent.right) {
                        node = node.parent
                        rotateLeft(node)
                    }
                    node.parent.color = TreeColorBlack
                    ancestor.color = TreeColorRed
                    rotateRight(ancestor)
                }
            } else {
                val left = ancestor.left
                if (left.color == TreeColorRed) {
                    // Red uncle: recolor and continue from the grandparent.
                    left.color = TreeColorBlack
                    node.parent.color = TreeColorBlack
                    ancestor.color = TreeColorRed
                    node = ancestor
                } else {
                    if (node === node.parent.left) {
                        node = node.parent
                        rotateRight(node)
                    }
                    node.parent.color = TreeColorBlack
                    ancestor.color = TreeColorRed
                    rotateLeft(ancestor)
                }
            }
        }

        root.color = TreeColorBlack
    }

    /** Rotates [node] down to the left; its right child takes its place. */
    private fun rotateLeft(node: Node) {
        val right = node.right
        node.right = right.left

        if (right.left !== terminator) {
            right.left.parent = node
        }

        right.parent = node.parent

        if (node.parent === terminator) {
            root = right
        } else {
            if (node.parent.left === node) {
                node.parent.left = right
            } else {
                node.parent.right = right
            }
        }

        right.left = node
        node.parent = right

        // Recompute bounds bottom-up, starting from the node that moved down.
        updateNodeData(node)
    }

    /** Rotates [node] down to the right; its left child takes its place. */
    private fun rotateRight(node: Node) {
        val left = node.left
        node.left = left.right

        if (left.right !== terminator) {
            left.right.parent = node
        }

        left.parent = node.parent

        if (node.parent === terminator) {
            root = left
        } else {
            if (node.parent.right === node) {
                node.parent.right = left
            } else {
                node.parent.left = left
            }
        }

        left.right = node
        node.parent = left

        // Recompute bounds bottom-up, starting from the node that moved down.
        updateNodeData(node)
    }

    /**
     * Recomputes the cached subtree bounds ([Node.min], [Node.max]) of [node] and of each of its
     * ancestors, up to the root.
     */
    private fun updateNodeData(node: Node) {
        var current = node
        while (current !== terminator) {
            current.min = min(current.start, min(current.left.min, current.right.min))
            current.max = max(current.end, max(current.left.max, current.right.max))
            current = current.parent
        }
    }

    /** A node of the red/black tree; it is itself the [Interval] it stores. */
    internal inner class Node(start: Float, end: Float, data: T?, var color: TreeColor) :
        Interval<T>(start, end, data) {
        /** Smallest start in the subtree rooted at this node (maintained by updateNodeData). */
        var min: Float = start
        /** Largest end in the subtree rooted at this node (maintained by updateNodeData). */
        var max: Float = end

        var left: Node = terminator
        var right: Node = terminator
        var parent: Node = terminator

        /** Returns the left-most node of the subtree rooted at this node. */
        fun lowestNode(): Node {
            var node = this
            while (node.left !== terminator) {
                node = node.left
            }
            return node
        }

        /** Returns the in-order successor of this node, or the terminator if there is none. */
        fun next(): Node {
            if (right !== terminator) {
                return right.lowestNode()
            }

            var a = this
            var b = parent
            while (b !== terminator && a === b.right) {
                a = b
                b = b.parent
            }

            return b
        }
    }
}
375 
/** Color of a node in the red/black tree backing [IntervalTree]. */
private typealias TreeColor = Int

/** Color assigned to newly inserted nodes. */
private const val TreeColorRed: TreeColor = 0

/** Color of the root and of the terminator sentinel. */
private const val TreeColorBlack: TreeColor = 1
380