/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TREAP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TREAP_H_

#include <cstddef>
#include <functional>
#include <iterator>
#include <stack>
#include <utility>
#include <vector>

namespace mindspore {
namespace dataset {
// A treap is a combination of a binary search tree (ordered by key) and a heap (ordered by priority).
// Each key is given a priority. The priority of any non-leaf node is greater than or equal to the
// priority of its children, where "greater" is decided by KP.
//
// Nodes removed from the treap are parked on an internal free list and recycled by later inserts;
// they are only deleted when the treap is destroyed or move-assigned over. The class performs no
// internal synchronization. Inserting or deleting keys invalidates outstanding iterators.
// @tparam K
//   Data type of key
// @tparam P
//   Data type of priority
// @tparam KC
//   Class to compare keys. Defaults to std::less
// @tparam KP
//   Class to compare priorities. Defaults to std::less
template <typename K, typename P, typename KC = std::less<K>, typename KP = std::less<P>>
class Treap {
 public:
  using key_type = K;
  using priority_type = P;
  using key_compare = KC;
  using priority_compare = KP;

  // The (key, priority) payload stored in every node.
  struct NodeValue {
    key_type key;
    priority_type priority;
  };

  class TreapNode {
   public:
    TreapNode() : left(nullptr), right(nullptr) {}
    NodeValue nv;
    TreapNode *left;
    TreapNode *right;
  };

  // Search API.
  // @param k
  //   Key to search for.
  // @return
  //   A pair. The 2nd value (bool) indicates whether the search is successful.
  //   If true, the first value of the pair holds a copy of the key and its priority;
  //   otherwise it holds value-initialized key and priority.
  std::pair<NodeValue, bool> Search(key_type k) const {
    const TreapNode *n = Search(root_, k);
    if (n == nullptr) {
      return std::make_pair(NodeValue{key_type(), priority_type()}, false);
    }
    return std::make_pair(n->nv, true);
  }

  // @return
  //   The root of the heap, i.e. the entry with the highest priority (not necessarily the first key).
  //   The bool is false (and the value is value-initialized) when the treap is empty.
  std::pair<NodeValue, bool> Top() const {
    if (root_ == nullptr) {
      return std::make_pair(NodeValue{key_type(), priority_type()}, false);
    }
    return std::make_pair(root_->nv, true);
  }

  // Remove the root of the heap (the highest-priority entry). No-op on an empty treap.
  void Pop() {
    if (root_ != nullptr) {
      DeleteKey(root_->nv.key);
    }
  }

  // Insert API. If the key is already present the call is a no-op and the existing priority is kept.
  // @param k
  //   The key to insert.
  // @param p
  //   The priority of the key.
  void Insert(key_type k, priority_type p) { root_ = Insert(root_, k, p); }

  // Delete a key. Deleting a key that is not present is a no-op.
  // @param k
  //   The key to delete.
  void DeleteKey(key_type k) { root_ = DeleteNode(root_, k); }

  Treap() : root_(nullptr), count_(0) { free_list_.reserve(kResvSz); }

  ~Treap() noexcept {
    DeleteTreap(root_);
    ReleaseFreeList();
  }

  // The treap owns its nodes through raw pointers; a member-wise copy would double-free them.
  Treap(const Treap &) = delete;
  Treap &operator=(const Treap &) = delete;

  // Moving transfers every node (live and recycled); the source is left empty.
  Treap(Treap &&other) noexcept
      : root_(std::exchange(other.root_, nullptr)),
        count_(std::exchange(other.count_, 0)),
        free_list_(std::move(other.free_list_)) {
    other.free_list_.clear();
  }

  Treap &operator=(Treap &&other) noexcept {
    if (this != &other) {
      DeleteTreap(root_);
      ReleaseFreeList();
      root_ = std::exchange(other.root_, nullptr);
      count_ = std::exchange(other.count_, 0);
      free_list_ = std::move(other.free_list_);
      other.free_list_.clear();
    }
    return *this;
  }

  // Forward iterator visiting entries in ascending key order (in-order traversal).
  class iterator {
   public:
    using iterator_category = std::forward_iterator_tag;
    using value_type = NodeValue;
    using difference_type = std::ptrdiff_t;
    using pointer = NodeValue *;
    using reference = NodeValue &;

    // A null treap pointer produces the end iterator.
    explicit iterator(Treap *tr) : tr_(tr), cur_(nullptr) {
      if (tr_ != nullptr) {
        PushLeftSpine(tr_->root_);
      }
      cur_ = stack_.empty() ? nullptr : stack_.top();
    }

    NodeValue &operator*() { return cur_->nv; }

    NodeValue *operator->() { return &(cur_->nv); }

    // Kept for backward compatibility: a const iterator object dereferences to the node itself.
    const TreapNode &operator*() const { return *cur_; }

    const TreapNode *operator->() const { return cur_; }

    bool operator==(const iterator &rhs) const { return cur_ == rhs.cur_; }

    bool operator!=(const iterator &rhs) const { return cur_ != rhs.cur_; }

    // Prefix increment: the in-order successor is the leftmost node of the right subtree,
    // or else the nearest ancestor still on the stack.
    iterator &operator++() {
      if (cur_ != nullptr) {
        stack_.pop();
        PushLeftSpine(cur_->right);
      }
      cur_ = stack_.empty() ? nullptr : stack_.top();
      return *this;
    }

    // Postfix increment.
    iterator operator++(int) {
      iterator tmp(*this);
      ++(*this);
      return tmp;
    }

   private:
    // Push n and all of its left descendants, so the smallest key ends up on top.
    void PushLeftSpine(TreapNode *n) {
      while (n != nullptr) {
        stack_.push(n);
        n = n->left;
      }
    }

    Treap *tr_;
    TreapNode *cur_;
    std::stack<TreapNode *> stack_;
  };

  // Read-only forward iterator visiting entries in ascending key order (in-order traversal).
  class const_iterator {
   public:
    using iterator_category = std::forward_iterator_tag;
    using value_type = NodeValue;
    using difference_type = std::ptrdiff_t;
    using pointer = const NodeValue *;
    using reference = const NodeValue &;

    // A null treap pointer produces the end iterator.
    explicit const_iterator(const Treap *tr) : tr_(tr), cur_(nullptr) {
      if (tr_ != nullptr) {
        PushLeftSpine(tr_->root_);
      }
      cur_ = stack_.empty() ? nullptr : stack_.top();
    }

    const NodeValue &operator*() const { return cur_->nv; }

    const NodeValue *operator->() const { return &(cur_->nv); }

    bool operator==(const const_iterator &rhs) const { return cur_ == rhs.cur_; }

    bool operator!=(const const_iterator &rhs) const { return cur_ != rhs.cur_; }

    // Prefix increment.
    const_iterator &operator++() {
      if (cur_ != nullptr) {
        stack_.pop();
        PushLeftSpine(cur_->right);
      }
      cur_ = stack_.empty() ? nullptr : stack_.top();
      return *this;
    }

    // Postfix increment. The saved copy must be a const_iterator (not an iterator).
    const_iterator operator++(int) {
      const_iterator tmp(*this);
      ++(*this);
      return tmp;
    }

   private:
    // Push n and all of its left descendants, so the smallest key ends up on top.
    void PushLeftSpine(TreapNode *n) {
      while (n != nullptr) {
        stack_.push(n);
        n = n->left;
      }
    }

    const Treap *tr_;
    TreapNode *cur_;
    std::stack<TreapNode *> stack_;
  };

  iterator begin() { return iterator(this); }

  iterator end() { return iterator(nullptr); }

  const_iterator begin() const { return const_iterator(this); }

  const_iterator end() const { return const_iterator(nullptr); }

  const_iterator cbegin() const { return const_iterator(this); }

  const_iterator cend() const { return const_iterator(nullptr); }

  bool empty() const { return root_ == nullptr; }

  size_t size() const { return count_; }

 private:
  // Take a node from the free list if possible, otherwise allocate one. The returned node has
  // null children; its payload is overwritten by the caller (Insert).
  TreapNode *NewNode() {
    if (free_list_.empty()) {
      return new TreapNode();
    }
    TreapNode *n = free_list_.back();
    free_list_.pop_back();
    // Recycled nodes may still carry stale child links from their previous position.
    n->left = nullptr;
    n->right = nullptr;
    return n;
  }

  // Park a node for reuse instead of deleting it.
  void FreeNode(TreapNode *n) { free_list_.push_back(n); }

  // Delete every node on the free list.
  void ReleaseFreeList() noexcept {
    for (TreapNode *n : free_list_) {
      delete n;
    }
    free_list_.clear();
  }

  // Delete the whole subtree rooted at n. Iterative with O(1) extra memory: while the current node
  // has a left child, rotate it right; once it has none, delete it and continue with its right child.
  void DeleteTreap(TreapNode *n) noexcept {
    while (n != nullptr) {
      if (n->left != nullptr) {
        TreapNode *l = n->left;
        n->left = l->right;
        l->right = n;
        n = l;
      } else {
        TreapNode *r = n->right;
        delete n;
        n = r;
      }
    }
  }

  // Lift y's left child above y; returns the new subtree root.
  TreapNode *RightRotate(TreapNode *y) {
    TreapNode *x = y->left;
    TreapNode *T2 = x->right;
    x->right = y;
    y->left = T2;
    return x;
  }

  // Lift x's right child above x; returns the new subtree root.
  TreapNode *LeftRotate(TreapNode *x) {
    TreapNode *y = x->right;
    TreapNode *T2 = y->left;
    y->left = x;
    x->right = T2;
    return y;
  }

  // Plain BST lookup under key_compare; returns nullptr when k is absent.
  TreapNode *Search(TreapNode *n, key_type k) const {
    key_compare keyCompare;
    while (n != nullptr) {
      if (keyCompare(k, n->nv.key)) {
        n = n->left;
      } else if (keyCompare(n->nv.key, k)) {
        n = n->right;
      } else {
        break;
      }
    }
    return n;
  }

  // BST insert followed by rotations that restore the heap order on the way back up.
  // Returns the new root of the subtree.
  TreapNode *Insert(TreapNode *n, key_type k, priority_type p) {
    key_compare keyCompare;
    priority_compare priorityCompare;
    if (n == nullptr) {
      n = NewNode();
      n->nv.key = k;
      n->nv.priority = p;
      count_++;
      return n;
    }
    if (keyCompare(k, n->nv.key)) {
      n->left = Insert(n->left, k, p);
      // The child now outranks its parent: lift it.
      if (priorityCompare(n->nv.priority, n->left->nv.priority)) {
        n = RightRotate(n);
      }
    } else if (keyCompare(n->nv.key, k)) {
      n->right = Insert(n->right, k, p);
      if (priorityCompare(n->nv.priority, n->right->nv.priority)) {
        n = LeftRotate(n);
      }
    }
    // Otherwise the key already exists: leave the treap unchanged.
    return n;
  }

  // Remove k from the subtree rooted at n and return the new subtree root. A node with two children
  // is rotated down (lifting its higher-priority child) until it has at most one child, then unlinked.
  TreapNode *DeleteNode(TreapNode *n, key_type k) {
    key_compare keyCompare;
    priority_compare priorityCompare;
    if (n == nullptr) {
      return n;
    }
    if (keyCompare(k, n->nv.key)) {
      n->left = DeleteNode(n->left, k);
    } else if (keyCompare(n->nv.key, k)) {
      n->right = DeleteNode(n->right, k);
    } else if (n->left == nullptr) {
      TreapNode *t = n;
      n = n->right;
      FreeNode(t);
      count_--;
    } else if (n->right == nullptr) {
      TreapNode *t = n;
      n = n->left;
      FreeNode(t);
      count_--;
    } else if (priorityCompare(n->left->nv.priority, n->right->nv.priority)) {
      // Right child has the higher priority: it becomes the parent, the target moves left.
      n = LeftRotate(n);
      n->left = DeleteNode(n->left, k);
    } else {
      n = RightRotate(n);
      n->right = DeleteNode(n->right, k);
    }
    return n;
  }

  static constexpr size_t kResvSz = 512;
  TreapNode *root_;
  size_t count_;
  std::vector<TreapNode *> free_list_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TREAP_H_