• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_UTILS_ORDERED_SET_H_
18 #define MINDSPORE_CORE_UTILS_ORDERED_SET_H_
19 
20 #include <algorithm>
21 #include <unordered_map>
22 #include <vector>
23 #include <list>
24 #include <utility>
25 #include <functional>
26 #include <memory>
27 #include "utils/hashing.h"
28 
29 namespace mindspore {
// OrderedSet: a set that iterates in insertion order.
//
// Elements are stored twice:
//   * ordered_data_ (std::list) holds the elements in iteration order;
//   * map_ (std::unordered_map) maps each element to its node in ordered_data_.
// std::list iterators stay valid when other nodes are inserted or erased. Membership tests, lookups and
// erasure by value are therefore hash-map operations, and no element ever has to be shifted.
template <class T, class Hash = std::hash<T>, class KeyEqual = std::equal_to<T>>
class OrderedSet {
 public:
  using element_type = T;
  using hasher = Hash;
  using equal = KeyEqual;
  using sequential_type = std::list<element_type>;
  using vector_type = std::vector<element_type>;
  using iterator = typename sequential_type::iterator;
  using const_iterator = typename sequential_type::const_iterator;
  using reverse_iterator = typename sequential_type::reverse_iterator;
  using const_reverse_iterator = typename sequential_type::const_reverse_iterator;
  using map_type = std::unordered_map<element_type, iterator, hasher, equal>;
  using ordered_set_type = OrderedSet<element_type, hasher, equal>;

  OrderedSet() = default;
  ~OrderedSet() = default;

  // Copy constructor. map_ holds iterators into the *source* list, so a member-wise copy would leave this
  // object's map pointing at the other object's nodes. Every element is re-added instead, which rebuilds
  // map_ against this object's own list.
  OrderedSet(const OrderedSet &os) {
    reserve(os.size());
    for (auto &item : os.ordered_data_) {
      add(item);
    }
  }

  // Move constructor. Moving a std::list hands its nodes over to the new list, so the iterators moved
  // along with map_ keep referring to the same (now owned) elements.
  OrderedSet(OrderedSet &&os) = default;

  // Build a set from a list. Later duplicates are ignored; first occurrences keep their relative order.
  explicit OrderedSet(const sequential_type &other) {
    reserve(other.size());
    for (auto &item : other) {
      add(item);
    }
  }

  // Build a set from a vector. Later duplicates are ignored; first occurrences keep their relative order.
  explicit OrderedSet(const vector_type &other) {
    reserve(other.size());
    for (auto &item : other) {
      add(item);
    }
  }

  // Copy assignment. Rebuilds map_ against this object's list for the same reason as the copy constructor.
  OrderedSet &operator=(const OrderedSet &other) {
    if (this != &other) {
      clear();
      reserve(other.size());
      for (auto &item : other.ordered_data_) {
        add(item);
      }
    }
    return *this;
  }

  OrderedSet &operator=(OrderedSet &&other) = default;

  // Insert e immediately BEFORE pos (pos == end() appends). If e is already present, nothing changes.
  // Returns an iterator to the element (newly inserted or pre-existing) and whether an insertion happened.
  std::pair<iterator, bool> insert(iterator pos, const element_type &e) {
    auto result = map_.emplace(e, ordered_data_.end());
    if (result.second) {
      try {
        result.first->second = ordered_data_.emplace(pos, e);
      } catch (...) {
        // Roll back so map_ never holds an entry without a matching list node.
        (void)map_.erase(result.first);
        throw;
      }
    }
    return {result.first->second, result.second};
  }

  // Append e if it is not already present, discarding the insert() result.
  void add(const element_type &e) { (void)insert(ordered_data_.end(), e); }

  // Append e if it is not already present.
  std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); }

  // Append e if it is not already present.
  void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); }

  // Prepend e if it is not already present (an existing e keeps its position).
  void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); }

  // Remove e. Returns true if it was present, false otherwise.
  bool erase(const element_type &e) {
    auto pos = map_.find(e);
    if (pos == map_.end()) {
      return false;
    }
    // Unlink the list node first; the map iterator `pos` stays valid (e may refer to that node, but it is
    // not read again).
    (void)ordered_data_.erase(pos->second);
    (void)map_.erase(pos);
    return true;
  }

  // Remove the element at pos (must be dereferenceable). Returns the iterator following it.
  iterator erase(iterator pos) {
    (void)map_.erase(*pos);
    return ordered_data_.erase(pos);
  }

  // Remove the element at pos (must be dereferenceable). Returns the iterator following it.
  iterator erase(const_iterator pos) {
    (void)map_.erase(*pos);
    return ordered_data_.erase(pos);
  }

  // Number of elements.
  std::size_t size() const { return map_.size(); }

  // True if the set holds no elements.
  bool empty() const { return map_.empty(); }

  // Remove all elements.
  void clear() {
    map_.clear();
    ordered_data_.clear();
  }

  // Reserve hash buckets for at least num_entries elements.
  void reserve(size_t num_entries) { map_.reserve(num_entries); }

  // Equal only if both sets hold the same elements in the same order.
  bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; }

  bool operator!=(const OrderedSet &other) const { return !(*this == other); }

  // Remove and return the first element. Precondition: the set is not empty.
  element_type pop() {
    element_type e = std::move(ordered_data_.front());
    (void)map_.erase(e);
    (void)ordered_data_.erase(ordered_data_.begin());
    return e;
  }

  // Last / first element. Precondition: the set is not empty.
  element_type &back() { return ordered_data_.back(); }
  element_type &front() { return ordered_data_.front(); }

  const element_type &back() const { return ordered_data_.back(); }
  const element_type &front() const { return ordered_data_.front(); }

  // True if this set and other have no element in common.
  bool is_disjoint(const OrderedSet &other) const {
    for (auto &item : other.ordered_data_) {
      if (contains(item)) {
        return false;
      }
    }
    return true;
  }

  // True if every element of this set is also in other.
  bool is_subset(const OrderedSet &other) const {
    for (auto &item : ordered_data_) {
      if (!other.contains(item)) {
        return false;
      }
    }
    return true;
  }

  // Append the elements of other that are not yet present, in other's order.
  void update(const OrderedSet &other) {
    for (auto &item : other.ordered_data_) {
      add(item);
    }
  }

  // Same as update(*other). Precondition: other is not null.
  void update(const std::shared_ptr<OrderedSet> &other) { update(*other); }

  void update(const sequential_type &other) {
    for (auto &item : other) {
      add(item);
    }
  }

  void update(const vector_type &other) {
    for (auto &item : other) {
      add(item);
    }
  }

  // Union: this set's elements in order, followed by other's new elements in other's order.
  ordered_set_type get_union(const OrderedSet &other) const {
    ordered_set_type res(ordered_data_);
    res.update(other);
    return res;
  }

  // Same as get_union(); builds a new set (copies elements).
  ordered_set_type operator|(const OrderedSet &other) const { return get_union(other); }

  // Intersection: the elements of this set that are also in other, in this set's order.
  ordered_set_type intersection(const OrderedSet &other) const {
    ordered_set_type res;
    for (auto &item : ordered_data_) {
      if (other.contains(item)) {
        res.add(item);
      }
    }
    return res;
  }

  ordered_set_type operator&(const OrderedSet &other) const { return intersection(other); }

  // Symmetric difference: this set's elements not in other (in this set's order), followed by other's
  // elements not in this set (in other's order).
  ordered_set_type symmetric_difference(const OrderedSet &other) const {
    ordered_set_type res(ordered_data_);
    for (auto &item : other.ordered_data_) {
      if (contains(item)) {
        (void)res.erase(item);
      } else {
        res.add(item);
      }
    }
    return res;
  }

  ordered_set_type operator^(const OrderedSet &other) const { return symmetric_difference(other); }

  // Remove every element that also appears in other. The remaining elements keep their order.
  void difference_update(const OrderedSet &other) {
    if (&other == this) {
      // Erasing while ranging over our own list would invalidate the loop iterator; the result is empty.
      clear();
      return;
    }
    for (auto &item : other.ordered_data_) {
      (void)erase(item);
    }
  }

  void difference_update(const sequential_type &other) {
    for (auto &item : other) {
      (void)erase(item);
    }
  }

  void difference_update(const vector_type &other) {
    for (auto &item : other) {
      (void)erase(item);
    }
  }

  // Difference: the elements of this set that are not in other, in this set's order.
  ordered_set_type difference(const OrderedSet &other) const {
    ordered_set_type res;
    for (auto &item : ordered_data_) {
      if (!other.contains(item)) {
        res.add(item);
      }
    }
    return res;
  }

  ordered_set_type operator-(const OrderedSet &other) const { return difference(other); }

  // True if e is in the set.
  bool contains(const element_type &e) const { return (map_.find(e) != map_.end()); }

  // Iterator to e, or end() if e is absent.
  const_iterator find(const element_type &e) const {
    auto iter = map_.find(e);
    if (iter == map_.end()) {
      return ordered_data_.end();
    }
    return iter->second;
  }

  iterator find(const element_type &e) {
    auto iter = map_.find(e);
    if (iter == map_.end()) {
      return ordered_data_.end();
    }
    return iter->second;
  }

  // 1 if e is in the set, 0 otherwise.
  std::size_t count(const element_type &e) const { return map_.count(e); }

  iterator begin() { return ordered_data_.begin(); }
  iterator end() { return ordered_data_.end(); }

  const_iterator begin() const { return ordered_data_.cbegin(); }
  const_iterator end() const { return ordered_data_.cend(); }

  const_iterator cbegin() const { return ordered_data_.cbegin(); }
  const_iterator cend() const { return ordered_data_.cend(); }

  // Reverse iteration (last inserted first).
  reverse_iterator rbegin() { return ordered_data_.rbegin(); }
  reverse_iterator rend() { return ordered_data_.rend(); }

  const_reverse_iterator rbegin() const { return ordered_data_.crbegin(); }
  const_reverse_iterator rend() const { return ordered_data_.crend(); }

  const_reverse_iterator crbegin() const { return ordered_data_.crbegin(); }
  const_reverse_iterator crend() const { return ordered_data_.crend(); }

 private:
  // Element -> its node in ordered_data_.
  map_type map_;
  // Elements in iteration (insertion) order.
  sequential_type ordered_data_;
};
296 
297 // OrderedSet that specially optimized for shared_ptr.
298 template <class T>
299 class OrderedSet<std::shared_ptr<T>> {
300  public:
301   using element_type = std::shared_ptr<T>;
302   using key_type = const T *;
303   using hash_t = PointerHash<T>;
304   using sequential_type = std::list<element_type>;
305   using vector_type = std::vector<element_type>;
306   using iterator = typename sequential_type::iterator;
307   using const_iterator = typename sequential_type::const_iterator;
308   using reverse_iterator = typename sequential_type::reverse_iterator;
309   using const_reverse_iterator = typename sequential_type::const_reverse_iterator;
310   using map_type = std::unordered_map<key_type, iterator, hash_t>;
311   using ordered_set_type = OrderedSet<std::shared_ptr<T>>;
312 
313   OrderedSet() = default;
314   ~OrderedSet() = default;
315 
OrderedSet(const OrderedSet & os)316   OrderedSet(const OrderedSet &os) {
317     for (auto &item : os.ordered_data_) {
318       add(item);
319     }
320   }
321 
322   OrderedSet(OrderedSet &&os) = default;
323 
OrderedSet(const sequential_type & other)324   explicit OrderedSet(const sequential_type &other) {
325     reserve(other.size());
326     for (auto &item : other) {
327       add(item);
328     }
329   }
330 
OrderedSet(const vector_type & other)331   explicit OrderedSet(const vector_type &other) {
332     reserve(other.size());
333     for (auto &item : other) {
334       add(item);
335     }
336   }
337 
338   OrderedSet &operator=(const OrderedSet &other) {
339     if (this != &other) {
340       clear();
341       reserve(other.size());
342       for (auto &item : other.ordered_data_) {
343         add(item);
344       }
345     }
346     return *this;
347   }
348 
349   OrderedSet &operator=(OrderedSet &&other) = default;
350 
insert(iterator pos,const element_type & e)351   std::pair<iterator, bool> insert(iterator pos, const element_type &e) {
352     auto [map_iter, inserted] = map_.emplace(e.get(), iterator{});
353     if (inserted) {
354       map_iter->second = ordered_data_.emplace(pos, e);
355     }
356     return {map_iter->second, inserted};
357   }
358 
insert(iterator pos,element_type && e)359   std::pair<iterator, bool> insert(iterator pos, element_type &&e) {
360     auto [map_iter, inserted] = map_.emplace(e.get(), iterator{});
361     if (inserted) {
362       map_iter->second = ordered_data_.emplace(pos, std::move(e));
363     }
364     return {map_iter->second, inserted};
365   }
366 
add(const element_type & e)367   void add(const element_type &e) { (void)insert(ordered_data_.end(), e); }
368 
add(element_type && e)369   void add(element_type &&e) { (void)insert(ordered_data_.end(), std::move(e)); }
370 
insert(const element_type & e)371   std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); }
372 
insert(element_type && e)373   std::pair<iterator, bool> insert(element_type &&e) { return insert(ordered_data_.end(), std::move(e)); }
374 
push_back(const element_type & e)375   void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); }
376 
push_front(const element_type & e)377   void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); }
378 
erase(const element_type & e)379   bool erase(const element_type &e) {
380     auto pos = map_.find(e.get());
381     if (pos == map_.end()) {
382       return false;
383     }
384     auto iter = pos->second;
385     (void)map_.erase(pos);
386     (void)ordered_data_.erase(iter);
387     return true;
388   }
389 
erase(iterator pos)390   iterator erase(iterator pos) {
391     (void)map_.erase(pos->get());
392     return ordered_data_.erase(pos);
393   }
394 
erase(const_iterator pos)395   iterator erase(const_iterator pos) {
396     (void)map_.erase(pos->get());
397     return ordered_data_.erase(pos);
398   }
399 
size()400   std::size_t size() const { return ordered_data_.size(); }
401 
empty()402   bool empty() const { return ordered_data_.empty(); }
403 
clear()404   void clear() {
405     map_.clear();
406     ordered_data_.clear();
407   }
408 
reserve(size_t num_entries)409   void reserve(size_t num_entries) { map_.reserve(num_entries); }
410 
411   bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; }
412 
pop()413   element_type pop() {
414     element_type e = std::move(ordered_data_.front());
415     (void)map_.erase(e.get());
416     (void)ordered_data_.erase(ordered_data_.begin());
417     return e;
418   }
419 
back()420   element_type &back() { return ordered_data_.back(); }
front()421   element_type &front() { return ordered_data_.front(); }
422 
back()423   const element_type &back() const { return ordered_data_.back(); }
front()424   const element_type &front() const { return ordered_data_.front(); }
425 
426   // Return true if there are no common elements.
is_disjoint(const OrderedSet & other)427   bool is_disjoint(const OrderedSet &other) {
428     return std::all_of(begin(), end(), [&other](const auto &e) { return !other.contains(e); });
429   }
430 
431   // Test whether this is subset of other.
is_subset(const OrderedSet & other)432   bool is_subset(const OrderedSet &other) {
433     return std::all_of(begin(), end(), [&other](const auto &e) { return other.contains(e); });
434   }
435 
436   // Add elements in other to this orderedset.
update(const OrderedSet & other)437   void update(const OrderedSet &other) {
438     for (auto &item : other.ordered_data_) {
439       add(item);
440     }
441   }
442 
update(const std::shared_ptr<OrderedSet> & other)443   void update(const std::shared_ptr<OrderedSet> &other) { update(*other); }
444 
update(const sequential_type & other)445   void update(const sequential_type &other) {
446     for (auto &item : other) {
447       add(item);
448     }
449   }
450 
update(const vector_type & other)451   void update(const vector_type &other) {
452     for (auto &item : other) {
453       add(item);
454     }
455   }
456 
get_union(const OrderedSet & other)457   ordered_set_type get_union(const OrderedSet &other) {
458     ordered_set_type res(ordered_data_);
459     res.update(other);
460     return res;
461   }
462 
463   // Get the union with other set, this operator may cost time because of copy.
464   ordered_set_type operator|(const OrderedSet &other) { return get_union(other); }
465 
466   // Return the intersection of two sets.
intersection(const OrderedSet & other)467   ordered_set_type intersection(const OrderedSet &other) {
468     ordered_set_type res;
469     for (auto &item : ordered_data_) {
470       if (other.contains(item)) {
471         res.add(item);
472       }
473     }
474     return res;
475   }
476 
477   ordered_set_type operator&(const OrderedSet &other) { return intersection(other); }
478 
479   // Return the symmetric difference of two sets.
symmetric_difference(const OrderedSet & other)480   ordered_set_type symmetric_difference(const OrderedSet &other) {
481     ordered_set_type res(ordered_data_);
482     for (auto &item : other) {
483       if (contains(item)) {
484         (void)res.erase(item);
485       } else {
486         res.add(item);
487       }
488     }
489     return res;
490   }
491 
492   ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); }
493 
494   // Remove elements which is also in others.
difference_update(const OrderedSet & other)495   void difference_update(const OrderedSet &other) {
496     for (auto &item : other) {
497       (void)erase(item);
498     }
499   }
500 
difference_update(const sequential_type & other)501   void difference_update(const sequential_type &other) {
502     for (auto &item : other) {
503       (void)erase(item);
504     }
505   }
506 
difference_update(const vector_type & other)507   void difference_update(const vector_type &other) {
508     for (auto &item : other) {
509       (void)erase(item);
510     }
511   }
512 
513   // Return the set with elements that are not in the others.
difference(const OrderedSet & other)514   ordered_set_type difference(const OrderedSet &other) {
515     ordered_set_type res;
516     for (auto &item : ordered_data_) {
517       if (!other.contains(item)) {
518         res.add(item);
519       }
520     }
521     return res;
522   }
523 
524   ordered_set_type operator-(const OrderedSet &other) { return difference(other); }
525 
contains(const element_type & e)526   bool contains(const element_type &e) const { return (map_.find(e.get()) != map_.end()); }
527 
find(const element_type & e)528   const_iterator find(const element_type &e) const {
529     auto iter = map_.find(e.get());
530     if (iter == map_.end()) {
531       return ordered_data_.end();
532     }
533     return iter->second;
534   }
535 
find(const element_type & e)536   iterator find(const element_type &e) {
537     auto iter = map_.find(e.get());
538     if (iter == map_.end()) {
539       return ordered_data_.end();
540     }
541     return iter->second;
542   }
543 
count(const element_type & e)544   std::size_t count(const element_type &e) const { return map_.count(e.get()); }
545 
begin()546   iterator begin() { return ordered_data_.begin(); }
end()547   iterator end() { return ordered_data_.end(); }
548 
begin()549   const_iterator begin() const { return ordered_data_.cbegin(); }
end()550   const_iterator end() const { return ordered_data_.cend(); }
551 
cbegin()552   const_iterator cbegin() const { return ordered_data_.cbegin(); }
cend()553   const_iterator cend() const { return ordered_data_.cend(); }
554 
555  private:
556   map_type map_;
557   sequential_type ordered_data_;
558 };
559 }  // namespace mindspore
560 
561 #endif  // MINDSPORE_CORE_UTILS_ORDERED_SET_H_
562