1 // Copyright 2021 gRPC authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef GRPC_SRC_CORE_UTIL_AVL_H 16 #define GRPC_SRC_CORE_UTIL_AVL_H 17 18 #include <grpc/support/port_platform.h> 19 #include <stdlib.h> 20 21 #include <algorithm> // IWYU pragma: keep 22 #include <iterator> 23 #include <utility> 24 25 #include "src/core/util/ref_counted.h" 26 #include "src/core/util/ref_counted_ptr.h" 27 #include "src/core/util/useful.h" 28 29 namespace grpc_core { 30 31 template <class K, class V = void> 32 class AVL { 33 public: AVL()34 AVL() {} 35 Add(K key,V value)36 AVL Add(K key, V value) const { 37 return AVL(AddKey(root_, std::move(key), std::move(value))); 38 } 39 template <typename SomethingLikeK> Remove(const SomethingLikeK & key)40 AVL Remove(const SomethingLikeK& key) const { 41 return AVL(RemoveKey(root_, key)); 42 } 43 template <typename SomethingLikeK> Lookup(const SomethingLikeK & key)44 const V* Lookup(const SomethingLikeK& key) const { 45 NodePtr n = Get(root_, key); 46 return n != nullptr ? &n->kv.second : nullptr; 47 } 48 LookupBelow(const K & key)49 const std::pair<K, V>* LookupBelow(const K& key) const { 50 NodePtr n = GetBelow(root_, *key); 51 return n != nullptr ? &n->kv : nullptr; 52 } 53 Empty()54 bool Empty() const { return root_ == nullptr; } 55 56 template <class F> ForEach(F && f)57 void ForEach(F&& f) const { 58 ForEachImpl(root_.get(), std::forward<F>(f)); 59 } 60 SameIdentity(const AVL & avl)61 bool SameIdentity(const AVL& avl) const { return root_ == avl.root_; } 62 QsortCompare(const AVL & left,const AVL & right)63 friend int QsortCompare(const AVL& left, const AVL& right) { 64 if (left.root_.get() == right.root_.get()) return 0; 65 Iterator a(left.root_); 66 Iterator b(right.root_); 67 for (;;) { 68 Node* p = a.current(); 69 Node* q = b.current(); 70 if (p != q) { 71 if (p == nullptr) return -1; 72 if (q == nullptr) return 1; 73 const int kv = QsortCompare(p->kv, q->kv); 74 if (kv != 0) return kv; 75 } else if (p == nullptr) { 76 return 0; 77 } 78 a.MoveNext(); 79 b.MoveNext(); 80 } 81 } 82 83 bool operator==(const AVL& other) const { 84 return QsortCompare(*this, other) == 0; 85 } 86 87 bool operator<(const AVL& other) const { 88 return QsortCompare(*this, other) < 0; 89 } 90 Height()91 size_t Height() const { 92 if (root_ == nullptr) return 0; 93 return root_->height; 94 } 95 96 private: 97 struct Node; 98 99 typedef RefCountedPtr<Node> NodePtr; 100 struct Node : public RefCounted<Node, NonPolymorphicRefCount> { NodeNode101 Node(K k, V v, NodePtr l, NodePtr r, long h) 102 : kv(std::move(k), std::move(v)), 103 left(std::move(l)), 104 right(std::move(r)), 105 height(h) {} 106 const std::pair<K, V> kv; 107 const NodePtr left; 108 const NodePtr right; 109 const long height; 110 }; 111 NodePtr root_; 112 113 class IteratorStack { 114 public: Push(Node * n)115 void Push(Node* n) { 116 nodes_[depth_] = n; 117 ++depth_; 118 } 119 Pop()120 Node* Pop() { 121 --depth_; 122 return nodes_[depth_]; 123 } 124 Back()125 Node* Back() const { return nodes_[depth_ - 1]; } 126 Empty()127 bool Empty() const { return depth_ == 0; } 128 129 private: 130 size_t depth_{0}; 131 // 32 is the maximum depth we can accept, and corresponds to ~4billion nodes 132 // - which ought to suffice our use cases. 133 Node* nodes_[32]; 134 }; 135 136 class Iterator { 137 public: Iterator(const NodePtr & root)138 explicit Iterator(const NodePtr& root) { 139 auto* n = root.get(); 140 while (n != nullptr) { 141 stack_.Push(n); 142 n = n->left.get(); 143 } 144 } current()145 Node* current() const { return stack_.Empty() ? nullptr : stack_.Back(); } MoveNext()146 void MoveNext() { 147 auto* n = stack_.Pop(); 148 if (n->right != nullptr) { 149 n = n->right.get(); 150 while (n != nullptr) { 151 stack_.Push(n); 152 n = n->left.get(); 153 } 154 } 155 } 156 157 private: 158 IteratorStack stack_; 159 }; 160 AVL(NodePtr root)161 explicit AVL(NodePtr root) : root_(std::move(root)) {} 162 163 template <class F> ForEachImpl(const Node * n,F && f)164 static void ForEachImpl(const Node* n, F&& f) { 165 if (n == nullptr) return; 166 ForEachImpl(n->left.get(), std::forward<F>(f)); 167 f(const_cast<const K&>(n->kv.first), const_cast<const V&>(n->kv.second)); 168 ForEachImpl(n->right.get(), std::forward<F>(f)); 169 } 170 Height(const NodePtr & n)171 static long Height(const NodePtr& n) { return n != nullptr ? n->height : 0; } 172 MakeNode(K key,V value,const NodePtr & left,const NodePtr & right)173 static NodePtr MakeNode(K key, V value, const NodePtr& left, 174 const NodePtr& right) { 175 return MakeRefCounted<Node>(std::move(key), std::move(value), left, right, 176 1 + std::max(Height(left), Height(right))); 177 } 178 179 template <typename SomethingLikeK> Get(const NodePtr & node,const SomethingLikeK & key)180 static NodePtr Get(const NodePtr& node, const SomethingLikeK& key) { 181 if (node == nullptr) { 182 return nullptr; 183 } 184 185 if (node->kv.first > key) { 186 return Get(node->left, key); 187 } else if (node->kv.first < key) { 188 return Get(node->right, key); 189 } else { 190 return node; 191 } 192 } 193 GetBelow(const NodePtr & node,const K & key)194 static NodePtr GetBelow(const NodePtr& node, const K& key) { 195 if (!node) return nullptr; 196 if (node->kv.first > key) { 197 return GetBelow(node->left, key); 198 } else if (node->kv.first < key) { 199 NodePtr n = GetBelow(node->right, key); 200 if (n == nullptr) n = node; 201 return n; 202 } else { 203 return node; 204 } 205 } 206 RotateLeft(K key,V value,const NodePtr & left,const NodePtr & right)207 static NodePtr RotateLeft(K key, V value, const NodePtr& left, 208 const NodePtr& right) { 209 return MakeNode( 210 right->kv.first, right->kv.second, 211 MakeNode(std::move(key), std::move(value), left, right->left), 212 right->right); 213 } 214 RotateRight(K key,V value,const NodePtr & left,const NodePtr & right)215 static NodePtr RotateRight(K key, V value, const NodePtr& left, 216 const NodePtr& right) { 217 return MakeNode( 218 left->kv.first, left->kv.second, left->left, 219 MakeNode(std::move(key), std::move(value), left->right, right)); 220 } 221 RotateLeftRight(K key,V value,const NodePtr & left,const NodePtr & right)222 static NodePtr RotateLeftRight(K key, V value, const NodePtr& left, 223 const NodePtr& right) { 224 // rotate_right(..., rotate_left(left), right) 225 return MakeNode( 226 left->right->kv.first, left->right->kv.second, 227 MakeNode(left->kv.first, left->kv.second, left->left, 228 left->right->left), 229 MakeNode(std::move(key), std::move(value), left->right->right, right)); 230 } 231 RotateRightLeft(K key,V value,const NodePtr & left,const NodePtr & right)232 static NodePtr RotateRightLeft(K key, V value, const NodePtr& left, 233 const NodePtr& right) { 234 // rotate_left(..., left, rotate_right(right)) 235 return MakeNode( 236 right->left->kv.first, right->left->kv.second, 237 MakeNode(std::move(key), std::move(value), left, right->left->left), 238 MakeNode(right->kv.first, right->kv.second, right->left->right, 239 right->right)); 240 } 241 Rebalance(K key,V value,const NodePtr & left,const NodePtr & right)242 static NodePtr Rebalance(K key, V value, const NodePtr& left, 243 const NodePtr& right) { 244 switch (Height(left) - Height(right)) { 245 case 2: 246 if (Height(left->left) - Height(left->right) == -1) { 247 return RotateLeftRight(std::move(key), std::move(value), left, right); 248 } else { 249 return RotateRight(std::move(key), std::move(value), left, right); 250 } 251 case -2: 252 if (Height(right->left) - Height(right->right) == 1) { 253 return RotateRightLeft(std::move(key), std::move(value), left, right); 254 } else { 255 return RotateLeft(std::move(key), std::move(value), left, right); 256 } 257 default: 258 return MakeNode(key, value, left, right); 259 } 260 } 261 AddKey(const NodePtr & node,K key,V value)262 static NodePtr AddKey(const NodePtr& node, K key, V value) { 263 if (node == nullptr) { 264 return MakeNode(std::move(key), std::move(value), nullptr, nullptr); 265 } 266 if (node->kv.first < key) { 267 return Rebalance(node->kv.first, node->kv.second, node->left, 268 AddKey(node->right, std::move(key), std::move(value))); 269 } 270 if (key < node->kv.first) { 271 return Rebalance(node->kv.first, node->kv.second, 272 AddKey(node->left, std::move(key), std::move(value)), 273 node->right); 274 } 275 return MakeNode(std::move(key), std::move(value), node->left, node->right); 276 } 277 InOrderHead(NodePtr node)278 static NodePtr InOrderHead(NodePtr node) { 279 while (node->left != nullptr) { 280 node = node->left; 281 } 282 return node; 283 } 284 InOrderTail(NodePtr node)285 static NodePtr InOrderTail(NodePtr node) { 286 while (node->right != nullptr) { 287 node = node->right; 288 } 289 return node; 290 } 291 292 template <typename SomethingLikeK> RemoveKey(const NodePtr & node,const SomethingLikeK & key)293 static NodePtr RemoveKey(const NodePtr& node, const SomethingLikeK& key) { 294 if (node == nullptr) { 295 return nullptr; 296 } 297 if (key < node->kv.first) { 298 return Rebalance(node->kv.first, node->kv.second, 299 RemoveKey(node->left, key), node->right); 300 } else if (node->kv.first < key) { 301 return Rebalance(node->kv.first, node->kv.second, node->left, 302 RemoveKey(node->right, key)); 303 } else { 304 if (node->left == nullptr) { 305 return node->right; 306 } else if (node->right == nullptr) { 307 return node->left; 308 } else if (node->left->height < node->right->height) { 309 NodePtr h = InOrderHead(node->right); 310 return Rebalance(h->kv.first, h->kv.second, node->left, 311 RemoveKey(node->right, h->kv.first)); 312 } else { 313 NodePtr h = InOrderTail(node->left); 314 return Rebalance(h->kv.first, h->kv.second, 315 RemoveKey(node->left, h->kv.first), node->right); 316 } 317 } 318 abort(); 319 } 320 }; 321 322 } // namespace grpc_core 323 324 #endif // GRPC_SRC_CORE_UTIL_AVL_H 325