• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_SRC_CORE_UTIL_AVL_H
16 #define GRPC_SRC_CORE_UTIL_AVL_H
17 
18 #include <grpc/support/port_platform.h>
19 #include <stdlib.h>
20 
21 #include <algorithm>  // IWYU pragma: keep
22 #include <iterator>
23 #include <utility>
24 
25 #include "src/core/util/ref_counted.h"
26 #include "src/core/util/ref_counted_ptr.h"
27 #include "src/core/util/useful.h"
28 
29 namespace grpc_core {
30 
31 template <class K, class V = void>
32 class AVL {
33  public:
AVL()34   AVL() {}
35 
Add(K key,V value)36   AVL Add(K key, V value) const {
37     return AVL(AddKey(root_, std::move(key), std::move(value)));
38   }
39   template <typename SomethingLikeK>
Remove(const SomethingLikeK & key)40   AVL Remove(const SomethingLikeK& key) const {
41     return AVL(RemoveKey(root_, key));
42   }
43   template <typename SomethingLikeK>
Lookup(const SomethingLikeK & key)44   const V* Lookup(const SomethingLikeK& key) const {
45     NodePtr n = Get(root_, key);
46     return n != nullptr ? &n->kv.second : nullptr;
47   }
48 
LookupBelow(const K & key)49   const std::pair<K, V>* LookupBelow(const K& key) const {
50     NodePtr n = GetBelow(root_, *key);
51     return n != nullptr ? &n->kv : nullptr;
52   }
53 
Empty()54   bool Empty() const { return root_ == nullptr; }
55 
56   template <class F>
ForEach(F && f)57   void ForEach(F&& f) const {
58     ForEachImpl(root_.get(), std::forward<F>(f));
59   }
60 
SameIdentity(const AVL & avl)61   bool SameIdentity(const AVL& avl) const { return root_ == avl.root_; }
62 
QsortCompare(const AVL & left,const AVL & right)63   friend int QsortCompare(const AVL& left, const AVL& right) {
64     if (left.root_.get() == right.root_.get()) return 0;
65     Iterator a(left.root_);
66     Iterator b(right.root_);
67     for (;;) {
68       Node* p = a.current();
69       Node* q = b.current();
70       if (p != q) {
71         if (p == nullptr) return -1;
72         if (q == nullptr) return 1;
73         const int kv = QsortCompare(p->kv, q->kv);
74         if (kv != 0) return kv;
75       } else if (p == nullptr) {
76         return 0;
77       }
78       a.MoveNext();
79       b.MoveNext();
80     }
81   }
82 
83   bool operator==(const AVL& other) const {
84     return QsortCompare(*this, other) == 0;
85   }
86 
87   bool operator<(const AVL& other) const {
88     return QsortCompare(*this, other) < 0;
89   }
90 
Height()91   size_t Height() const {
92     if (root_ == nullptr) return 0;
93     return root_->height;
94   }
95 
96  private:
97   struct Node;
98 
99   typedef RefCountedPtr<Node> NodePtr;
100   struct Node : public RefCounted<Node, NonPolymorphicRefCount> {
NodeNode101     Node(K k, V v, NodePtr l, NodePtr r, long h)
102         : kv(std::move(k), std::move(v)),
103           left(std::move(l)),
104           right(std::move(r)),
105           height(h) {}
106     const std::pair<K, V> kv;
107     const NodePtr left;
108     const NodePtr right;
109     const long height;
110   };
111   NodePtr root_;
112 
113   class IteratorStack {
114    public:
Push(Node * n)115     void Push(Node* n) {
116       nodes_[depth_] = n;
117       ++depth_;
118     }
119 
Pop()120     Node* Pop() {
121       --depth_;
122       return nodes_[depth_];
123     }
124 
Back()125     Node* Back() const { return nodes_[depth_ - 1]; }
126 
Empty()127     bool Empty() const { return depth_ == 0; }
128 
129    private:
130     size_t depth_{0};
131     // 32 is the maximum depth we can accept, and corresponds to ~4billion nodes
132     // - which ought to suffice our use cases.
133     Node* nodes_[32];
134   };
135 
136   class Iterator {
137    public:
Iterator(const NodePtr & root)138     explicit Iterator(const NodePtr& root) {
139       auto* n = root.get();
140       while (n != nullptr) {
141         stack_.Push(n);
142         n = n->left.get();
143       }
144     }
current()145     Node* current() const { return stack_.Empty() ? nullptr : stack_.Back(); }
MoveNext()146     void MoveNext() {
147       auto* n = stack_.Pop();
148       if (n->right != nullptr) {
149         n = n->right.get();
150         while (n != nullptr) {
151           stack_.Push(n);
152           n = n->left.get();
153         }
154       }
155     }
156 
157    private:
158     IteratorStack stack_;
159   };
160 
AVL(NodePtr root)161   explicit AVL(NodePtr root) : root_(std::move(root)) {}
162 
163   template <class F>
ForEachImpl(const Node * n,F && f)164   static void ForEachImpl(const Node* n, F&& f) {
165     if (n == nullptr) return;
166     ForEachImpl(n->left.get(), std::forward<F>(f));
167     f(const_cast<const K&>(n->kv.first), const_cast<const V&>(n->kv.second));
168     ForEachImpl(n->right.get(), std::forward<F>(f));
169   }
170 
Height(const NodePtr & n)171   static long Height(const NodePtr& n) { return n != nullptr ? n->height : 0; }
172 
MakeNode(K key,V value,const NodePtr & left,const NodePtr & right)173   static NodePtr MakeNode(K key, V value, const NodePtr& left,
174                           const NodePtr& right) {
175     return MakeRefCounted<Node>(std::move(key), std::move(value), left, right,
176                                 1 + std::max(Height(left), Height(right)));
177   }
178 
179   template <typename SomethingLikeK>
Get(const NodePtr & node,const SomethingLikeK & key)180   static NodePtr Get(const NodePtr& node, const SomethingLikeK& key) {
181     if (node == nullptr) {
182       return nullptr;
183     }
184 
185     if (node->kv.first > key) {
186       return Get(node->left, key);
187     } else if (node->kv.first < key) {
188       return Get(node->right, key);
189     } else {
190       return node;
191     }
192   }
193 
GetBelow(const NodePtr & node,const K & key)194   static NodePtr GetBelow(const NodePtr& node, const K& key) {
195     if (!node) return nullptr;
196     if (node->kv.first > key) {
197       return GetBelow(node->left, key);
198     } else if (node->kv.first < key) {
199       NodePtr n = GetBelow(node->right, key);
200       if (n == nullptr) n = node;
201       return n;
202     } else {
203       return node;
204     }
205   }
206 
RotateLeft(K key,V value,const NodePtr & left,const NodePtr & right)207   static NodePtr RotateLeft(K key, V value, const NodePtr& left,
208                             const NodePtr& right) {
209     return MakeNode(
210         right->kv.first, right->kv.second,
211         MakeNode(std::move(key), std::move(value), left, right->left),
212         right->right);
213   }
214 
RotateRight(K key,V value,const NodePtr & left,const NodePtr & right)215   static NodePtr RotateRight(K key, V value, const NodePtr& left,
216                              const NodePtr& right) {
217     return MakeNode(
218         left->kv.first, left->kv.second, left->left,
219         MakeNode(std::move(key), std::move(value), left->right, right));
220   }
221 
RotateLeftRight(K key,V value,const NodePtr & left,const NodePtr & right)222   static NodePtr RotateLeftRight(K key, V value, const NodePtr& left,
223                                  const NodePtr& right) {
224     // rotate_right(..., rotate_left(left), right)
225     return MakeNode(
226         left->right->kv.first, left->right->kv.second,
227         MakeNode(left->kv.first, left->kv.second, left->left,
228                  left->right->left),
229         MakeNode(std::move(key), std::move(value), left->right->right, right));
230   }
231 
RotateRightLeft(K key,V value,const NodePtr & left,const NodePtr & right)232   static NodePtr RotateRightLeft(K key, V value, const NodePtr& left,
233                                  const NodePtr& right) {
234     // rotate_left(..., left, rotate_right(right))
235     return MakeNode(
236         right->left->kv.first, right->left->kv.second,
237         MakeNode(std::move(key), std::move(value), left, right->left->left),
238         MakeNode(right->kv.first, right->kv.second, right->left->right,
239                  right->right));
240   }
241 
Rebalance(K key,V value,const NodePtr & left,const NodePtr & right)242   static NodePtr Rebalance(K key, V value, const NodePtr& left,
243                            const NodePtr& right) {
244     switch (Height(left) - Height(right)) {
245       case 2:
246         if (Height(left->left) - Height(left->right) == -1) {
247           return RotateLeftRight(std::move(key), std::move(value), left, right);
248         } else {
249           return RotateRight(std::move(key), std::move(value), left, right);
250         }
251       case -2:
252         if (Height(right->left) - Height(right->right) == 1) {
253           return RotateRightLeft(std::move(key), std::move(value), left, right);
254         } else {
255           return RotateLeft(std::move(key), std::move(value), left, right);
256         }
257       default:
258         return MakeNode(key, value, left, right);
259     }
260   }
261 
AddKey(const NodePtr & node,K key,V value)262   static NodePtr AddKey(const NodePtr& node, K key, V value) {
263     if (node == nullptr) {
264       return MakeNode(std::move(key), std::move(value), nullptr, nullptr);
265     }
266     if (node->kv.first < key) {
267       return Rebalance(node->kv.first, node->kv.second, node->left,
268                        AddKey(node->right, std::move(key), std::move(value)));
269     }
270     if (key < node->kv.first) {
271       return Rebalance(node->kv.first, node->kv.second,
272                        AddKey(node->left, std::move(key), std::move(value)),
273                        node->right);
274     }
275     return MakeNode(std::move(key), std::move(value), node->left, node->right);
276   }
277 
InOrderHead(NodePtr node)278   static NodePtr InOrderHead(NodePtr node) {
279     while (node->left != nullptr) {
280       node = node->left;
281     }
282     return node;
283   }
284 
InOrderTail(NodePtr node)285   static NodePtr InOrderTail(NodePtr node) {
286     while (node->right != nullptr) {
287       node = node->right;
288     }
289     return node;
290   }
291 
292   template <typename SomethingLikeK>
RemoveKey(const NodePtr & node,const SomethingLikeK & key)293   static NodePtr RemoveKey(const NodePtr& node, const SomethingLikeK& key) {
294     if (node == nullptr) {
295       return nullptr;
296     }
297     if (key < node->kv.first) {
298       return Rebalance(node->kv.first, node->kv.second,
299                        RemoveKey(node->left, key), node->right);
300     } else if (node->kv.first < key) {
301       return Rebalance(node->kv.first, node->kv.second, node->left,
302                        RemoveKey(node->right, key));
303     } else {
304       if (node->left == nullptr) {
305         return node->right;
306       } else if (node->right == nullptr) {
307         return node->left;
308       } else if (node->left->height < node->right->height) {
309         NodePtr h = InOrderHead(node->right);
310         return Rebalance(h->kv.first, h->kv.second, node->left,
311                          RemoveKey(node->right, h->kv.first));
312       } else {
313         NodePtr h = InOrderTail(node->left);
314         return Rebalance(h->kv.first, h->kv.second,
315                          RemoveKey(node->left, h->kv.first), node->right);
316       }
317     }
318     abort();
319   }
320 };
321 
322 }  // namespace grpc_core
323 
324 #endif  // GRPC_SRC_CORE_UTIL_AVL_H
325