• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TREAP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TREAP_H_
18 
#include <cstddef>
#include <functional>
#include <iterator>
#include <stack>
#include <utility>
#include <vector>
24 
25 namespace mindspore {
26 namespace dataset {
// A treap is a combination of a binary search tree and a heap. Each key is given a priority. Keys are
// kept in binary-search-tree order (as judged by KC), and the priority of any non-leaf node is greater
// than or equal to (as judged by KP) the priority of its children, so the root always holds the key
// with the highest priority.
//
// Search, Insert and DeleteKey walk a single root-to-leaf path, so their cost is proportional to the
// height of the tree; that height is only logarithmic when priorities are effectively random relative
// to the keys. Nodes that are removed are parked on an internal free list and recycled by later
// insertions; they are released only when the treap is destroyed (or move-assigned over).
// A Treap is movable but not copyable.
//
// @tparam K
//  Data type of key. Must be default constructible and copy assignable.
// @tparam P
//  Data type of priority. Must be default constructible and copy assignable.
// @tparam KC
//  Class to compare keys. Defaults to std::less<K>.
// @tparam KP
//  Class to compare priorities. Defaults to std::less<P>.
template <typename K, typename P, typename KC = std::less<K>, typename KP = std::less<P>>
class Treap {
 public:
  using key_type = K;
  using priority_type = P;
  using key_compare = KC;
  using priority_compare = KP;

  // The payload stored in every node.
  struct NodeValue {
    key_type key;
    priority_type priority;
  };

  // A tree node. The enclosing Treap allocates, links and frees every node.
  class TreapNode {
   public:
    TreapNode() : left(nullptr), right(nullptr) {}
    NodeValue nv;
    TreapNode *left;
    TreapNode *right;
  };

  // search API
  // @param k
  //    key to search for
  // @return
  //    a pair is returned. The 2nd value of type bool indicates if the search is successful.
  //    If true, the first value of the pair contains the key and the priority; otherwise it holds
  //    value-initialized placeholders.
  std::pair<NodeValue, bool> Search(key_type k) const {
    const TreapNode *n = Search(root_, k);
    if (n != nullptr) {
      return std::make_pair(n->nv, true);
    }
    return std::make_pair(NodeValue{key_type(), priority_type()}, false);
  }

  // @return
  //    The root of the heap, i.e. the entry with the highest priority (not necessarily the smallest
  //    key). The bool is false, with placeholder values, when the treap is empty.
  std::pair<NodeValue, bool> Top() const {
    if (root_ != nullptr) {
      return std::make_pair(root_->nv, true);
    }
    return std::make_pair(NodeValue{key_type(), priority_type()}, false);
  }

  // Remove the root of the heap (the highest-priority entry). No-op on an empty treap.
  void Pop() {
    if (root_ != nullptr) {
      // DeleteKey takes the key by value, so the key is copied before the root node is recycled.
      DeleteKey(root_->nv.key);
    }
  }

  // Insert API.
  // @param k
  //    The key to insert. If the key is already present nothing happens; its priority is NOT updated.
  // @param p
  //    The priority of the key.
  void Insert(key_type k, priority_type p) { root_ = Insert(root_, k, p); }

  // Delete a key. No-op if the key is not present.
  // @param k
  //    The key to remove.
  void DeleteKey(key_type k) { root_ = DeleteNode(root_, k); }

  // Constructs an empty treap and pre-sizes the free list.
  Treap() : root_(nullptr), count_(0) { free_list_.reserve(kResvSz); }

  // Nodes are owned through raw pointers, so a member-wise copy would free them twice.
  Treap(const Treap &) = delete;
  Treap &operator=(const Treap &) = delete;

  // Moving transfers every node (live and recycled) and leaves `other` empty but usable.
  Treap(Treap &&other) noexcept
      : root_(std::exchange(other.root_, nullptr)),
        count_(std::exchange(other.count_, 0)),
        free_list_(std::move(other.free_list_)) {
    other.free_list_.clear();
  }

  Treap &operator=(Treap &&other) noexcept {
    if (this != &other) {
      ReleaseAll();
      root_ = std::exchange(other.root_, nullptr);
      count_ = std::exchange(other.count_, 0);
      free_list_ = std::move(other.free_list_);
      other.free_list_.clear();
    }
    return *this;
  }

  ~Treap() noexcept { ReleaseAll(); }

  // Forward iterator that visits the entries in ascending key order (in-order traversal). The path of
  // not-yet-visited ancestors is kept on an explicit stack, so copying an iterator copies that stack.
  // Any Insert, DeleteKey or Pop invalidates all iterators. Callers must not modify the key through
  // the iterator, since that would break the search-tree order.
  class iterator {
   public:
    using iterator_category = std::forward_iterator_tag;
    using value_type = NodeValue;
    using difference_type = std::ptrdiff_t;
    using pointer = NodeValue *;
    using reference = NodeValue &;

    // @param tr
    //    The treap to traverse, or nullptr to construct the end() iterator.
    explicit iterator(Treap *tr) : tr_(tr), cur_(nullptr) {
      if (tr_ != nullptr) {
        PushLeftPath(&stack_, tr_->root_);
      }
      cur_ = stack_.empty() ? nullptr : stack_.top();
    }

    NodeValue &operator*() { return cur_->nv; }

    NodeValue *operator->() { return &(cur_->nv); }

    const NodeValue &operator*() const { return cur_->nv; }

    const NodeValue *operator->() const { return &(cur_->nv); }

    bool operator==(const iterator &rhs) const { return cur_ == rhs.cur_; }

    bool operator!=(const iterator &rhs) const { return cur_ != rhs.cur_; }

    // Prefix increment: move to the in-order successor, or to end() after the last entry.
    iterator &operator++() {
      if (cur_ != nullptr) {
        stack_.pop();
        // The successor is the leftmost node of the right subtree if there is one; otherwise it is the
        // nearest pending ancestor already on the stack.
        PushLeftPath(&stack_, cur_->right);
      }
      cur_ = stack_.empty() ? nullptr : stack_.top();
      return *this;
    }

    // Postfix increment: returns a copy of the iterator taken before advancing.
    iterator operator++(int) {
      iterator tmp(*this);
      ++(*this);
      return tmp;
    }

   private:
    Treap *tr_;
    TreapNode *cur_;
    std::stack<TreapNode *> stack_;
  };

  // Read-only counterpart of iterator; same traversal order and invalidation rules.
  class const_iterator {
   public:
    using iterator_category = std::forward_iterator_tag;
    using value_type = NodeValue;
    using difference_type = std::ptrdiff_t;
    using pointer = const NodeValue *;
    using reference = const NodeValue &;

    // @param tr
    //    The treap to traverse, or nullptr to construct the end() iterator.
    explicit const_iterator(const Treap *tr) : tr_(tr), cur_(nullptr) {
      if (tr_ != nullptr) {
        PushLeftPath(&stack_, tr_->root_);
      }
      cur_ = stack_.empty() ? nullptr : stack_.top();
    }

    const NodeValue &operator*() const { return cur_->nv; }

    const NodeValue *operator->() const { return &(cur_->nv); }

    bool operator==(const const_iterator &rhs) const { return cur_ == rhs.cur_; }

    bool operator!=(const const_iterator &rhs) const { return cur_ != rhs.cur_; }

    // Prefix increment: move to the in-order successor, or to end() after the last entry.
    const_iterator &operator++() {
      if (cur_ != nullptr) {
        stack_.pop();
        PushLeftPath(&stack_, cur_->right);
      }
      cur_ = stack_.empty() ? nullptr : stack_.top();
      return *this;
    }

    // Postfix increment: returns a copy of the iterator taken before advancing.
    const_iterator operator++(int) {
      const_iterator tmp(*this);
      ++(*this);
      return tmp;
    }

   private:
    const Treap *tr_;
    TreapNode *cur_;
    std::stack<TreapNode *> stack_;
  };

  iterator begin() { return iterator(this); }

  iterator end() { return iterator(nullptr); }

  const_iterator begin() const { return const_iterator(this); }

  const_iterator end() const { return const_iterator(nullptr); }

  const_iterator cbegin() const { return const_iterator(this); }

  const_iterator cend() const { return const_iterator(nullptr); }

  bool empty() const { return root_ == nullptr; }

  // Number of keys currently stored (recycled nodes on the free list are not counted).
  size_t size() const { return count_; }

 private:
  // Pushes `n` and every node along its chain of left children onto `path`.
  static void PushLeftPath(std::stack<TreapNode *> *path, TreapNode *n) {
    while (n != nullptr) {
      path->push(n);
      n = n->left;
    }
  }

  // Returns a node with null child links, recycling one from the free list when possible.
  TreapNode *NewNode() {
    if (free_list_.empty()) {
      return new TreapNode();
    }
    TreapNode *n = free_list_.back();
    free_list_.pop_back();
    // Recycled nodes are still live objects (FreeNode never destroys them), so reset the links in
    // place. Constructing a new object over the old one would skip the destructors of the stale
    // key/priority and leak whatever they own; the caller overwrites nv by assignment instead.
    n->left = nullptr;
    n->right = nullptr;
    return n;
  }

  // Parks an unlinked node on the free list for later reuse by NewNode.
  void FreeNode(TreapNode *n) { free_list_.push_back(n); }

  // Deletes every node of the subtree rooted at `n`. Left children are rotated onto the right spine
  // before deleting, so this uses O(1) stack regardless of how unbalanced the tree is.
  void DeleteTreap(TreapNode *n) noexcept {
    while (n != nullptr) {
      if (n->left != nullptr) {
        TreapNode *l = n->left;
        n->left = l->right;
        l->right = n;
        n = l;
      } else {
        TreapNode *r = n->right;
        delete n;
        n = r;
      }
    }
  }

  // Frees every node, both in the tree and on the free list, and leaves the treap empty.
  void ReleaseAll() noexcept {
    DeleteTreap(root_);
    root_ = nullptr;
    count_ = 0;
    for (TreapNode *n : free_list_) {
      delete n;
    }
    free_list_.clear();
  }

  // Rotates the subtree rooted at `y` to the right; returns the new subtree root (y's left child).
  TreapNode *RightRotate(TreapNode *y) {
    TreapNode *x = y->left;
    TreapNode *T2 = x->right;
    x->right = y;
    y->left = T2;
    return x;
  }

  // Rotates the subtree rooted at `x` to the left; returns the new subtree root (x's right child).
  TreapNode *LeftRotate(TreapNode *x) {
    TreapNode *y = x->right;
    TreapNode *T2 = y->left;
    y->left = x;
    x->right = T2;
    return y;
  }

  // Returns the node holding key `k` in the subtree rooted at `n`, or nullptr if absent.
  TreapNode *Search(TreapNode *n, const key_type &k) const {
    key_compare keyCompare;
    while (n != nullptr) {
      if (keyCompare(k, n->nv.key)) {
        n = n->left;
      } else if (keyCompare(n->nv.key, k)) {
        n = n->right;
      } else {
        break;
      }
    }
    return n;
  }

  // Inserts (k, p) into the subtree rooted at `n` and returns the new subtree root. After the
  // recursive insert, a child with higher priority than its parent is rotated above it to restore the
  // heap property. An existing key is left untouched.
  TreapNode *Insert(TreapNode *n, const key_type &k, const priority_type &p) {
    key_compare keyCompare;
    priority_compare priorityCompare;
    if (n == nullptr) {
      n = NewNode();
      n->nv.key = k;
      n->nv.priority = p;
      count_++;
      return n;
    }
    if (keyCompare(k, n->nv.key)) {
      n->left = Insert(n->left, k, p);
      if (priorityCompare(n->nv.priority, n->left->nv.priority)) {
        n = RightRotate(n);
      }
    } else if (keyCompare(n->nv.key, k)) {
      n->right = Insert(n->right, k, p);
      if (priorityCompare(n->nv.priority, n->right->nv.priority)) {
        n = LeftRotate(n);
      }
    }
    // Otherwise the key is already present: do nothing.
    return n;
  }

  // Removes key `k` from the subtree rooted at `n` and returns the new subtree root. A node with two
  // children is rotated downwards (promoting its higher-priority child) until it has at most one child,
  // at which point it is unlinked and recycled.
  TreapNode *DeleteNode(TreapNode *n, const key_type &k) {
    key_compare keyCompare;
    priority_compare priorityCompare;
    if (n == nullptr) {
      return n;
    }
    if (keyCompare(k, n->nv.key)) {
      n->left = DeleteNode(n->left, k);
    } else if (keyCompare(n->nv.key, k)) {
      n->right = DeleteNode(n->right, k);
    } else if (n->left == nullptr) {
      TreapNode *t = n;
      n = n->right;
      FreeNode(t);
      count_--;
    } else if (n->right == nullptr) {
      TreapNode *t = n;
      n = n->left;
      FreeNode(t);
      count_--;
    } else if (priorityCompare(n->left->nv.priority, n->right->nv.priority)) {
      // The right child has the higher priority, so it becomes the subtree root and the doomed node
      // sinks into its left subtree.
      n = LeftRotate(n);
      n->left = DeleteNode(n->left, k);
    } else {
      n = RightRotate(n);
      n->right = DeleteNode(n->right, k);
    }
    return n;
  }

  static constexpr int kResvSz = 512;
  TreapNode *root_;
  size_t count_;
  std::vector<TreapNode *> free_list_;
};
405 }  // namespace dataset
406 }  // namespace mindspore
407 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TREAP_H_
408