1 // SPDX-License-Identifier: LGPL-2.1
2 /*
3 * Copyright (C) 2023 Google, Steven Rostedt <rostedt@goodmis.org>
4 *
5 */
6 #include <stdlib.h>
7 #include <stdbool.h>
8 #include "trace-local.h"
9 #include "trace-rbtree.h"
10
/*
 * Node colors. RED is zero, so a node that has just been cleared with
 * memset() reads as RED until its color is assigned explicitly.
 */
enum {
	RED	= 0,
	BLACK	= 1,
};
15
trace_rbtree_init(struct trace_rbtree * tree,trace_rbtree_cmp_fn cmp_fn,trace_rbtree_search_fn search_fn)16 void __hidden trace_rbtree_init(struct trace_rbtree *tree, trace_rbtree_cmp_fn cmp_fn,
17 trace_rbtree_search_fn search_fn)
18 {
19 memset(tree, 0, sizeof(*tree));
20 tree->search = search_fn;
21 tree->cmp = cmp_fn;
22 }
23
is_left(struct trace_rbtree_node * node)24 static bool is_left(struct trace_rbtree_node *node)
25 {
26 return node == node->parent->left;
27 }
28
get_parent_ptr(struct trace_rbtree * tree,struct trace_rbtree_node * node)29 static struct trace_rbtree_node **get_parent_ptr(struct trace_rbtree *tree,
30 struct trace_rbtree_node *node)
31 {
32 if (!node->parent)
33 return &tree->node;
34 else if (is_left(node))
35 return &node->parent->left;
36 else
37 return &node->parent->right;
38 }
39
rotate_left(struct trace_rbtree * tree,struct trace_rbtree_node * node)40 static void rotate_left(struct trace_rbtree *tree,
41 struct trace_rbtree_node *node)
42 {
43 struct trace_rbtree_node **parent_ptr = get_parent_ptr(tree, node);
44 struct trace_rbtree_node *parent = node->parent;
45 struct trace_rbtree_node *old_right = node->right;
46
47 *parent_ptr = old_right;
48 node->right = old_right->left;
49 old_right->left = node;
50
51 if (node->right)
52 node->right->parent = node;
53 node->parent = old_right;
54 old_right->parent = parent;
55 }
56
rotate_right(struct trace_rbtree * tree,struct trace_rbtree_node * node)57 static void rotate_right(struct trace_rbtree *tree,
58 struct trace_rbtree_node *node)
59 {
60 struct trace_rbtree_node **parent_ptr = get_parent_ptr(tree, node);
61 struct trace_rbtree_node *parent = node->parent;
62 struct trace_rbtree_node *old_left = node->left;
63
64 *parent_ptr = old_left;
65 node->left = old_left->right;
66 old_left->right = node;
67
68 if (node->left)
69 node->left->parent = node;
70 node->parent = old_left;
71 old_left->parent = parent;
72 }
73
insert_tree(struct trace_rbtree * tree,struct trace_rbtree_node * node)74 static void insert_tree(struct trace_rbtree *tree,
75 struct trace_rbtree_node *node)
76 {
77 struct trace_rbtree_node *next = tree->node;
78 struct trace_rbtree_node *last_next = NULL;
79 bool went_left = false;
80
81 while (next) {
82 last_next = next;
83 if (tree->cmp(next, node) > 0) {
84 next = next->right;
85 went_left = false;
86 } else {
87 next = next->left;
88 went_left = true;
89 }
90 }
91
92 if (!last_next) {
93 tree->node = node;
94 return;
95 }
96
97 if (went_left)
98 last_next->left = node;
99 else
100 last_next->right = node;
101
102 node->parent = last_next;
103 }
104
#if 0
/*
 * Debug helper (compiled out): verify that @node is reachable from the
 * link that should point at it. Returns 0 when consistent, -1 otherwise.
 */
static int check_node(struct trace_rbtree *tree, struct trace_rbtree_node *node)
{
	bool ok;

	if (!node->parent)
		ok = tree->node == node;
	else
		ok = is_left(node) || node->parent->right == node;

	if (ok)
		return 0;

	printf("FAILED ON NODE!");
	breakpoint();
	return -1;
}

/*
 * Check @node and then every node down its left spine.
 * Returns the left-most node reached, or NULL if a check failed.
 */
static struct trace_rbtree_node *check_left_spine(struct trace_rbtree *tree,
						  struct trace_rbtree_node *node)
{
	if (check_node(tree, node))
		return NULL;

	while (node->left) {
		node = node->left;
		if (check_node(tree, node))
			return NULL;
	}
	return node;
}

/*
 * Debug helper (compiled out): walk the tree from the left-most node
 * onward, running check_node() on the nodes visited and stopping at the
 * first failure.
 */
static void check_tree(struct trace_rbtree *tree)
{
	struct trace_rbtree_node *node = tree->node;

	if (node) {
		node = check_left_spine(tree, node);
		if (!node)
			return;
	}

	while (node) {
		if (check_node(tree, node))
			return;

		if (node->right) {
			node = check_left_spine(tree, node->right);
			if (!node)
				return;
			continue;
		}

		/* Climb until we arrive from a left link (or run out of parents) */
		while (node->parent && !is_left(node)) {
			node = node->parent;
			if (check_node(tree, node))
				return;
		}
		node = node->parent;
	}
}
#else
/* Consistency checking is disabled: this stub does nothing. */
static inline void check_tree(struct trace_rbtree *tree) { }
#endif
165
trace_rbtree_insert(struct trace_rbtree * tree,struct trace_rbtree_node * node)166 int __hidden trace_rbtree_insert(struct trace_rbtree *tree,
167 struct trace_rbtree_node *node)
168 {
169 struct trace_rbtree_node *uncle;
170
171 memset(node, 0, sizeof(*node));
172
173 insert_tree(tree, node);
174 node->color = RED;
175 while (node && node->parent && node->parent->color == RED) {
176 if (is_left(node->parent)) {
177 uncle = node->parent->parent->right;
178 if (uncle && uncle->color == RED) {
179 node->parent->color = BLACK;
180 uncle->color = BLACK;
181 node->parent->parent->color = RED;
182 node = node->parent->parent;
183 } else {
184 if (!is_left(node)) {
185 node = node->parent;
186 rotate_left(tree, node);
187 check_tree(tree);
188 }
189 node->parent->color = BLACK;
190 node->parent->parent->color = RED;
191 rotate_right(tree, node->parent->parent);
192 check_tree(tree);
193 }
194 } else {
195 uncle = node->parent->parent->left;
196 if (uncle && uncle->color == RED) {
197 node->parent->color = BLACK;
198 uncle->color = BLACK;
199 node->parent->parent->color = RED;
200 node = node->parent->parent;
201 } else {
202 if (is_left(node)) {
203 node = node->parent;
204 rotate_right(tree, node);
205 check_tree(tree);
206 }
207 node->parent->color = BLACK;
208 node->parent->parent->color = RED;
209 rotate_left(tree, node->parent->parent);
210 check_tree(tree);
211 }
212 }
213 }
214 check_tree(tree);
215 tree->node->color = BLACK;
216 tree->nr_nodes++;
217 return 0;
218 }
219
trace_rbtree_find(struct trace_rbtree * tree,const void * data)220 struct trace_rbtree_node *trace_rbtree_find(struct trace_rbtree *tree, const void *data)
221 {
222 struct trace_rbtree_node *node = tree->node;
223 int ret;
224
225 while (node) {
226 ret = tree->search(node, data);
227 if (!ret)
228 return node;
229 if (ret > 0)
230 node = node->right;
231 else
232 node = node->left;
233 }
234 return NULL;
235 }
236
next_node(struct trace_rbtree_node * node)237 static struct trace_rbtree_node *next_node(struct trace_rbtree_node *node)
238 {
239 if (node->right) {
240 node = node->right;
241 while (node->left)
242 node = node->left;
243 return node;
244 }
245
246 while (node->parent && !is_left(node))
247 node = node->parent;
248
249 return node->parent;
250 }
251
tree_fixup(struct trace_rbtree * tree,struct trace_rbtree_node * node)252 static void tree_fixup(struct trace_rbtree *tree, struct trace_rbtree_node *node)
253 {
254 while (node->parent && node->color == BLACK) {
255 if (is_left(node)) {
256 struct trace_rbtree_node *old_right = node->parent->right;
257
258 if (old_right->color == RED) {
259 old_right->color = BLACK;
260 node->parent->color = RED;
261 rotate_left(tree, node->parent);
262 old_right = node->parent->right;
263 }
264 if (old_right->left->color == BLACK &&
265 old_right->right->color == BLACK) {
266 old_right->color = RED;
267 node = node->parent;
268 } else {
269 if (old_right->right->color == BLACK) {
270 old_right->left->color = BLACK;
271 old_right->color = RED;
272 rotate_right(tree, old_right);
273 old_right = node->parent->right;
274 }
275 old_right->color = node->parent->color;
276 node->parent->color = BLACK;
277 old_right->right->color = BLACK;
278 rotate_left(tree, node->parent);
279 node = tree->node;
280 }
281 } else {
282 struct trace_rbtree_node *old_left = node->parent->left;
283
284 if (old_left->color == RED) {
285 old_left->color = BLACK;
286 node->parent->color = RED;
287 rotate_right(tree, node->parent);
288 old_left = node->parent->left;
289 }
290 if (old_left->right->color == BLACK &&
291 old_left->left->color == BLACK) {
292 old_left->color = RED;
293 node = node->parent;
294 } else {
295 if (old_left->left->color == BLACK) {
296 old_left->right->color = BLACK;
297 old_left->color = RED;
298 rotate_left(tree, old_left);
299 old_left = node->parent->left;
300 }
301 old_left->color = node->parent->color;
302 node->parent->color = BLACK;
303 old_left->left->color = BLACK;
304 rotate_right(tree, node->parent);
305 node = tree->node;
306 }
307 }
308 }
309 node->color = BLACK;
310 }
311
trace_rbtree_delete(struct trace_rbtree * tree,struct trace_rbtree_node * node)312 void trace_rbtree_delete(struct trace_rbtree *tree, struct trace_rbtree_node *node)
313 {
314 struct trace_rbtree_node *x, *y;
315 bool do_fixup = false;
316
317 if (!node->left && !node->right && !node->parent) {
318 tree->node = NULL;
319 goto out;
320 }
321
322 if (!node->left || !node->right)
323 y = node;
324 else
325 y = next_node(node);
326
327 if (y->left)
328 x = y->left;
329 else
330 x = y->right;
331
332 if (x)
333 x->parent = y->parent;
334
335 if (!y->parent) {
336 tree->node = x;
337 } else {
338 if (is_left(y))
339 y->parent->left = x;
340 else
341 y->parent->right = x;
342 }
343
344 do_fixup = y->color == BLACK;
345
346 if (y != node) {
347 y->color = node->color;
348 y->parent = node->parent;
349 y->left = node->left;
350 y->right = node->right;
351 if (y->left)
352 y->left->parent = y;
353 if (y->right)
354 y->right->parent = y;
355 if (!y->parent) {
356 tree->node = y;
357 } else {
358 if (is_left(node))
359 y->parent->left = y;
360 else
361 y->parent->right = y;
362 }
363 }
364
365 if (do_fixup)
366 tree_fixup(tree, x);
367
368 out:
369 node->parent = node->left = node->right = NULL;
370 tree->nr_nodes--;
371 check_tree(tree);
372 }
373
trace_rbtree_next(struct trace_rbtree * tree,struct trace_rbtree_node * node)374 __hidden struct trace_rbtree_node *trace_rbtree_next(struct trace_rbtree *tree,
375 struct trace_rbtree_node *node)
376 {
377 check_tree(tree);
378 /*
379 * When either starting or the previous iteration returned a
380 * node with a right branch, then go to the first node (if starting)
381 * or the right node, and then return the left most node.
382 */
383 if (!node || node->right) {
384 if (!node)
385 node = tree->node;
386 else
387 node = node->right;
388 while (node && node->left)
389 node = node->left;
390 return node;
391 }
392
393 /*
394 * If we are here, then the previous iteration returned the
395 * left most node of the tree or the right branch. If this
396 * is a left node, then simply return the parent. If this
397 * is a right node, then keep going up until its a left node,
398 * or we finished the iteration.
399 *
400 * If we are here and are the top node, then there is no right
401 * node, and this is finished (return NULL).
402 */
403 if (!node->parent || is_left(node))
404 return node->parent;
405
406 do {
407 node = node->parent;
408 } while (node->parent && !is_left(node));
409
410 return node->parent;
411 }
412
413 /*
414 * Used for freeing a tree, just quickly pop off the children in
415 * no particular order. This will corrupt the tree! That is,
416 * do not do any inserting or deleting of this tree after calling
417 * this function.
418 */
trace_rbtree_pop_nobalance(struct trace_rbtree * tree)419 struct trace_rbtree_node *trace_rbtree_pop_nobalance(struct trace_rbtree *tree)
420 {
421 struct trace_rbtree_node *node = tree->node;
422
423 if (!node)
424 return NULL;
425
426 while (node->left)
427 node = node->left;
428
429 while (node->right)
430 node = node->right;
431
432 if (node->parent) {
433 if (is_left(node))
434 node->parent->left = NULL;
435 else
436 node->parent->right = NULL;
437 } else {
438 tree->node = NULL;
439 }
440
441 return node;
442 }
443