1 #include <stdlib.h>
2 #include <search.h>
3 #include "tsearch.h"
4
height(struct node * n)5 static inline int height(struct node *n) { return n ? n->h : 0; }
6
rot(void ** p,struct node * x,int dir)7 static int rot(void **p, struct node *x, int dir /* deeper side */)
8 {
9 struct node *y = x->a[dir];
10 struct node *z = y->a[!dir];
11 int hx = x->h;
12 int hz = height(z);
13 if (hz > height(y->a[dir])) {
14 /*
15 * x
16 * / \ dir z
17 * A y / \
18 * / \ --> x y
19 * z D /| |\
20 * / \ A B C D
21 * B C
22 */
23 x->a[dir] = z->a[!dir];
24 y->a[!dir] = z->a[dir];
25 z->a[!dir] = x;
26 z->a[dir] = y;
27 x->h = hz;
28 y->h = hz;
29 z->h = hz+1;
30 } else {
31 /*
32 * x y
33 * / \ / \
34 * A y --> x D
35 * / \ / \
36 * z D A z
37 */
38 x->a[dir] = z;
39 y->a[!dir] = x;
40 x->h = hz+1;
41 y->h = hz+2;
42 z = y;
43 }
44 *p = z;
45 return z->h - hx;
46 }
47
48 /* balance *p, return 0 if height is unchanged. */
__tsearch_balance(void ** p)49 int __tsearch_balance(void **p)
50 {
51 struct node *n = *p;
52 int h0 = height(n->a[0]);
53 int h1 = height(n->a[1]);
54 if (h0 - h1 + 1u < 3u) {
55 int old = n->h;
56 n->h = h0<h1 ? h1+1 : h0+1;
57 return n->h - old;
58 }
59 return rot(p, n, h0<h1);
60 }
61
tsearch(const void * key,void ** rootp,int (* cmp)(const void *,const void *))62 void *tsearch(const void *key, void **rootp,
63 int (*cmp)(const void *, const void *))
64 {
65 if (!rootp)
66 return 0;
67
68 void **a[MAXH];
69 struct node *n = *rootp;
70 struct node *r;
71 int i=0;
72 a[i++] = rootp;
73 for (;;) {
74 if (!n)
75 break;
76 int c = cmp(key, n->key);
77 if (!c)
78 return n;
79 a[i++] = &n->a[c>0];
80 n = n->a[c>0];
81 }
82 r = malloc(sizeof *r);
83 if (!r)
84 return 0;
85 r->key = key;
86 r->a[0] = r->a[1] = 0;
87 r->h = 1;
88 /* insert new node, rebalance ancestors. */
89 *a[--i] = r;
90 while (i && __tsearch_balance(a[--i]));
91 return r;
92 }
93