• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_UTILS_ORDERED_SET_H_
18 #define MINDSPORE_CORE_UTILS_ORDERED_SET_H_
19 
20 #include <algorithm>
21 #include <vector>
22 #include <list>
23 #include <utility>
24 #include <functional>
25 #include <memory>
26 #include "utils/hashing.h"
27 #include "utils/hash_map.h"
28 
29 namespace mindspore {
30 // Implementation of OrderedSet that keeps insertion order
31 // using map as set, and use list as a sequential container to record elements to keep insertion order
32 template <class T, class Hash = std::hash<T>, class KeyEqual = std::equal_to<T>>
33 class OrderedSet {
34  public:
35   using element_type = T;
36   using hasher = Hash;
37   using equal = KeyEqual;
38   using sequential_type = std::list<element_type>;
39   using vector_type = std::vector<element_type>;
40   using iterator = typename sequential_type::iterator;
41   using const_iterator = typename sequential_type::const_iterator;
42   using reverse_iterator = typename sequential_type::reverse_iterator;
43   using const_reverse_iterator = typename sequential_type::const_reverse_iterator;
44   using map_type = mindspore::HashMap<element_type, iterator, hasher, equal>;
45   using ordered_set_type = OrderedSet<element_type, hasher, equal>;
46 
47   OrderedSet() = default;
48   ~OrderedSet() = default;
49   // OrderedSet use an iterator to list as mapped value to improve the performance of insertion and deletion,
50   // So copy of OrderedSet should re-build value of the map key to make it pointer to the new list,, thus we use
51   // traversal to build elements.
OrderedSet(const OrderedSet & os)52   OrderedSet(const OrderedSet &os) {
53     for (auto &item : os.ordered_data_) {
54       add(item);
55     }
56   }
57 
58   OrderedSet(OrderedSet &&os) = default;
59 
OrderedSet(const sequential_type & other)60   explicit OrderedSet(const sequential_type &other) {
61     for (auto &item : other) {
62       add(item);
63     }
64   }
65 
66   // Explicitly construct an OrderedSet use vector
OrderedSet(const vector_type & other)67   explicit OrderedSet(const vector_type &other) {
68     for (auto &item : other) {
69       add(item);
70     }
71   }
72 
73   OrderedSet &operator=(const OrderedSet &other) {
74     if (this != &other) {
75       clear();
76       reserve(other.size());
77       for (auto &item : other.ordered_data_) {
78         add(item);
79       }
80     }
81     return *this;
82   }
83 
84   OrderedSet &operator=(OrderedSet &&other) = default;
85 
86   // insert an element to the OrderedSet after the given position.
insert(const iterator & pos,const element_type & e)87   std::pair<iterator, bool> insert(const iterator &pos, const element_type &e) {
88     auto result = map_.emplace(e, ordered_data_.end());
89     if (result.second) {
90       result.first->second = ordered_data_.emplace(pos, e);
91     }
92     return {result.first->second, result.second};
93   }
94 
95   // Add an element to the OrderedSet, without judging return value
add(const element_type & e)96   void add(const element_type &e) { (void)insert(ordered_data_.end(), e); }
97 
98   // insert an element to the end of OrderedSet.
insert(const element_type & e)99   std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); }
100 
push_back(const element_type & e)101   void push_back(const element_type &e) const { (void)insert(ordered_data_.end(), e); }
102 
push_front(const element_type & e)103   void push_front(const element_type &e) const { (void)insert(ordered_data_.begin(), e); }
104 
105   // Remove an element, if removed return true, otherwise return false
erase(const element_type & e)106   bool erase(const element_type &e) {
107     auto pos = map_.find(e);
108     if (pos == map_.end()) {
109       return false;
110     }
111     // erase the sequential data first
112     (void)ordered_data_.erase(pos->second);
113     (void)map_.erase(pos);
114     return true;
115   }
116 
erase(iterator pos)117   iterator erase(iterator pos) {
118     (void)map_.erase(*pos);
119     return ordered_data_.erase(pos);
120   }
121 
erase(const_iterator pos)122   iterator erase(const_iterator pos) {
123     (void)map_.erase(*pos);
124     return ordered_data_.erase(pos);
125   }
126 
127   // Return the container size
size()128   std::size_t size() const { return map_.size(); }
129 
empty()130   bool empty() const { return map_.size() == 0; }
131 
132   // Clear the elements
clear()133   void clear() {
134     if (!map_.empty()) {
135       map_.clear();
136       ordered_data_.clear();
137     }
138   }
139 
140   // Reserve memory for the number of entries.
reserve(size_t num_entries)141   void reserve(size_t num_entries) { map_.reserve(num_entries); }
142 
143   // Compare two orderedset, if the order is not equal shall return false
144   bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; }
145 
pop()146   element_type pop() {
147     element_type e = std::move(ordered_data_.front());
148     (void)map_.erase(e);
149     (void)ordered_data_.erase(ordered_data_.begin());
150     return e;
151   }
152 
back()153   element_type &back() { return ordered_data_.back(); }
front()154   element_type &front() { return ordered_data_.front(); }
155 
back()156   const element_type &back() const { return ordered_data_.back(); }
front()157   const element_type &front() const { return ordered_data_.front(); }
158 
159   // Return true if there are no common elements
is_disjoint(const OrderedSet & other)160   bool is_disjoint(const OrderedSet &other) {
161     for (auto &item : other.ordered_data_) {
162       if (map_.find(item) != map_.end()) {
163         return false;
164       }
165     }
166     return true;
167   }
168 
169   // Test whether this is subset of other
is_subset(const OrderedSet & other)170   bool is_subset(const OrderedSet &other) {
171     for (auto &item : ordered_data_) {
172       if (other.map_.find(item) == other.map_.end()) {
173         return false;
174       }
175     }
176     return true;
177   }
178 
179   // Add elements in other to this orderedset
update(const OrderedSet & other)180   void update(const OrderedSet &other) {
181     for (auto &item : other.ordered_data_) {
182       add(item);
183     }
184   }
185 
update(const std::shared_ptr<OrderedSet> & other)186   void update(const std::shared_ptr<OrderedSet> &other) { update(*other); }
187 
update(const sequential_type & other)188   void update(const sequential_type &other) {
189     for (auto &item : other) {
190       add(item);
191     }
192   }
193 
update(const vector_type & other)194   void update(const vector_type &other) {
195     for (auto &item : other) {
196       add(item);
197     }
198   }
199 
get_union(const OrderedSet & other)200   ordered_set_type get_union(const OrderedSet &other) {
201     ordered_set_type res(ordered_data_);
202     res.update(other);
203     return res;
204   }
205 
206   // Get the union with other set, this operator may cost time because of copy
207   ordered_set_type operator|(const OrderedSet &other) { return get_union(other); }
208 
209   // Return the intersection of two sets
intersection(const OrderedSet & other)210   ordered_set_type intersection(const OrderedSet &other) {
211     ordered_set_type res(ordered_data_);
212     for (auto &item : ordered_data_) {
213       if (other.map_.find(item) == other.map_.end()) {
214         (void)res.erase(item);
215       }
216     }
217     return res;
218   }
219   ordered_set_type operator&(const OrderedSet &other) { return intersection(other); }
220 
221   // Return the symmetric difference of two sets
symmetric_difference(const OrderedSet & other)222   ordered_set_type symmetric_difference(const OrderedSet &other) {
223     ordered_set_type res(ordered_data_);
224     for (auto &item : other.ordered_data_) {
225       if (map_.find(item) != map_.end()) {
226         (void)res.erase(item);
227       } else {
228         res.add(item);
229       }
230     }
231     return res;
232   }
233 
234   ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); }
235 
236   // Remove elements which is also in others.
difference_update(const OrderedSet & other)237   void difference_update(const OrderedSet &other) {
238     // use vector traversal, to keep ordrer
239     for (auto &item : other.ordered_data_) {
240       (void)erase(item);
241     }
242   }
243 
difference_update(const sequential_type & other)244   void difference_update(const sequential_type &other) {
245     for (auto &item : other) {
246       (void)erase(item);
247     }
248   }
249 
difference_update(const vector_type & other)250   void difference_update(const vector_type &other) {
251     for (auto &item : other) {
252       (void)erase(item);
253     }
254   }
255 
256   // Return the set with elements that are not in the others
difference(const OrderedSet & other)257   ordered_set_type difference(const OrderedSet &other) {
258     ordered_set_type res(ordered_data_);
259     res.difference_update(other);
260     return res;
261   }
262   ordered_set_type operator-(const OrderedSet &other) { return difference(other); }
263 
contains(const element_type & e)264   bool contains(const element_type &e) const { return (map_.find(e) != map_.end()); }
265 
find(const element_type & e)266   const_iterator find(const element_type &e) const {
267     auto iter = map_.find(e);
268     if (iter == map_.end()) {
269       return ordered_data_.end();
270     }
271     return iter->second;
272   }
273 
find(const element_type & e)274   iterator find(const element_type &e) {
275     auto iter = map_.find(e);
276     if (iter == map_.end()) {
277       return ordered_data_.end();
278     }
279     return iter->second;
280   }
281 
282   // Return the count of an element in set
count(const element_type & e)283   std::size_t count(const element_type &e) const { return map_.count(e); }
284 
begin()285   iterator begin() { return ordered_data_.begin(); }
end()286   iterator end() { return ordered_data_.end(); }
287 
begin()288   const_iterator begin() const { return ordered_data_.cbegin(); }
end()289   const_iterator end() const { return ordered_data_.cend(); }
290 
cbegin()291   const_iterator cbegin() const { return ordered_data_.cbegin(); }
cend()292   const_iterator cend() const { return ordered_data_.cend(); }
293 
rbegin()294   reverse_iterator rbegin() const { return ordered_data_.rbegin(); }
rend()295   reverse_iterator rend() const { return ordered_data_.rend(); }
296 
297  private:
298   map_type map_;
299   sequential_type ordered_data_;
300 };
301 
302 // OrderedSet that specially optimized for shared_ptr.
303 template <class T>
304 class OrderedSet<std::shared_ptr<T>> {
305  public:
306   using element_type = std::shared_ptr<T>;
307   using key_type = const T *;
308   using sequential_type = std::list<element_type>;
309   using vector_type = std::vector<element_type>;
310   using iterator = typename sequential_type::iterator;
311   using const_iterator = typename sequential_type::const_iterator;
312   using reverse_iterator = typename sequential_type::reverse_iterator;
313   using const_reverse_iterator = typename sequential_type::const_reverse_iterator;
314   using map_type = mindspore::HashMap<key_type, iterator>;
315   using ordered_set_type = OrderedSet<std::shared_ptr<T>>;
316 
317   OrderedSet() = default;
318   ~OrderedSet() = default;
319 
OrderedSet(const OrderedSet & os)320   OrderedSet(const OrderedSet &os) {
321     for (auto &item : os.ordered_data_) {
322       add(item);
323     }
324   }
325 
326   OrderedSet(OrderedSet &&os) = default;
327 
OrderedSet(const sequential_type & other)328   explicit OrderedSet(const sequential_type &other) {
329     reserve(other.size());
330     for (auto &item : other) {
331       add(item);
332     }
333   }
334 
OrderedSet(const vector_type & other)335   explicit OrderedSet(const vector_type &other) {
336     reserve(other.size());
337     for (auto &item : other) {
338       add(item);
339     }
340   }
341 
342   OrderedSet &operator=(const OrderedSet &other) {
343     if (this != &other) {
344       clear();
345       reserve(other.size());
346       for (auto &item : other.ordered_data_) {
347         add(item);
348       }
349     }
350     return *this;
351   }
352 
353   OrderedSet &operator=(OrderedSet &&other) = default;
354 
insert(const iterator & pos,const element_type & e)355   std::pair<iterator, bool> insert(const iterator &pos, const element_type &e) {
356     auto [map_iter, inserted] = map_.emplace(e.get(), iterator{});
357     if (inserted) {
358       map_iter->second = ordered_data_.emplace(pos, e);
359     }
360     return {map_iter->second, inserted};
361   }
362 
insert(const iterator & pos,element_type && e)363   std::pair<iterator, bool> insert(const iterator &pos, element_type &&e) {
364     auto [map_iter, inserted] = map_.emplace(e.get(), iterator{});
365     if (inserted) {
366       map_iter->second = ordered_data_.emplace(pos, std::move(e));
367     }
368     return {map_iter->second, inserted};
369   }
370 
add(const element_type & e)371   void add(const element_type &e) { (void)insert(ordered_data_.end(), e); }
372 
add(element_type && e)373   void add(element_type &&e) { (void)insert(ordered_data_.end(), std::move(e)); }
374 
insert(const element_type & e)375   std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); }
376 
insert(element_type && e)377   std::pair<iterator, bool> insert(element_type &&e) { return insert(ordered_data_.end(), std::move(e)); }
378 
push_back(const element_type & e)379   void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); }
380 
push_front(const element_type & e)381   void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); }
382 
erase(const element_type & e)383   bool erase(const element_type &e) {
384     auto pos = map_.find(e.get());
385     if (pos == map_.end()) {
386       return false;
387     }
388     auto iter = pos->second;
389     (void)map_.erase(pos);
390     (void)ordered_data_.erase(iter);
391     return true;
392   }
393 
erase(const iterator & pos)394   iterator erase(const iterator &pos) {
395     (void)map_.erase(pos->get());
396     return ordered_data_.erase(pos);
397   }
398 
erase(const_iterator pos)399   iterator erase(const_iterator pos) {
400     (void)map_.erase(pos->get());
401     return ordered_data_.erase(pos);
402   }
403 
size()404   std::size_t size() const { return ordered_data_.size(); }
405 
empty()406   bool empty() const { return ordered_data_.empty(); }
407 
clear()408   void clear() {
409     if (!map_.empty()) {
410       map_.clear();
411       ordered_data_.clear();
412     }
413   }
414 
reserve(size_t num_entries)415   void reserve(size_t num_entries) { map_.reserve(num_entries); }
416 
417   bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; }
418 
pop()419   element_type pop() {
420     element_type e = std::move(ordered_data_.front());
421     (void)map_.erase(e.get());
422     (void)ordered_data_.erase(ordered_data_.begin());
423     return e;
424   }
425 
back()426   element_type &back() { return ordered_data_.back(); }
front()427   element_type &front() { return ordered_data_.front(); }
428 
back()429   const element_type &back() const { return ordered_data_.back(); }
front()430   const element_type &front() const { return ordered_data_.front(); }
431 
432   // Return true if there are no common elements.
is_disjoint(const OrderedSet & other)433   bool is_disjoint(const OrderedSet &other) const {
434     return std::all_of(begin(), end(), [&other](const auto &e) { return !other.contains(e); });
435   }
436 
437   // Test whether this is subset of other.
is_subset(const OrderedSet & other)438   bool is_subset(const OrderedSet &other) {
439     return std::all_of(begin(), end(), [&other](const auto &e) { return other.contains(e); });
440   }
441 
442   // Add elements in other to this orderedset.
update(const OrderedSet & other)443   void update(const OrderedSet &other) {
444     for (auto &item : other.ordered_data_) {
445       add(item);
446     }
447   }
448 
update(const std::shared_ptr<OrderedSet> & other)449   void update(const std::shared_ptr<OrderedSet> &other) { update(*other); }
450 
update(const sequential_type & other)451   void update(const sequential_type &other) {
452     for (auto &item : other) {
453       add(item);
454     }
455   }
456 
update(const vector_type & other)457   void update(const vector_type &other) {
458     for (auto &item : other) {
459       add(item);
460     }
461   }
462 
get_union(const OrderedSet & other)463   ordered_set_type get_union(const OrderedSet &other) {
464     ordered_set_type res(ordered_data_);
465     res.update(other);
466     return res;
467   }
468 
469   // Get the union with other set, this operator may cost time because of copy.
470   ordered_set_type operator|(const OrderedSet &other) { return get_union(other); }
471 
472   // Return the intersection of two sets.
intersection(const OrderedSet & other)473   ordered_set_type intersection(const OrderedSet &other) {
474     ordered_set_type res;
475     for (auto &item : ordered_data_) {
476       if (other.contains(item)) {
477         res.add(item);
478       }
479     }
480     return res;
481   }
482 
483   ordered_set_type operator&(const OrderedSet &other) { return intersection(other); }
484 
485   // Return the symmetric difference of two sets.
symmetric_difference(const OrderedSet & other)486   ordered_set_type symmetric_difference(const OrderedSet &other) {
487     ordered_set_type res(ordered_data_);
488     for (auto &item : other) {
489       if (contains(item)) {
490         (void)res.erase(item);
491       } else {
492         res.add(item);
493       }
494     }
495     return res;
496   }
497 
498   ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); }
499 
500   // Remove elements which is also in others.
difference_update(const OrderedSet & other)501   void difference_update(const OrderedSet &other) {
502     for (auto &item : other) {
503       (void)erase(item);
504     }
505   }
506 
difference_update(const sequential_type & other)507   void difference_update(const sequential_type &other) {
508     for (auto &item : other) {
509       (void)erase(item);
510     }
511   }
512 
difference_update(const vector_type & other)513   void difference_update(const vector_type &other) {
514     for (auto &item : other) {
515       (void)erase(item);
516     }
517   }
518 
519   // Return the set with elements that are not in the others.
difference(const OrderedSet & other)520   ordered_set_type difference(const OrderedSet &other) {
521     ordered_set_type res;
522     for (auto &item : ordered_data_) {
523       if (!other.contains(item)) {
524         res.add(item);
525       }
526     }
527     return res;
528   }
529 
530   ordered_set_type operator-(const OrderedSet &other) { return difference(other); }
531 
contains(const element_type & e)532   bool contains(const element_type &e) const { return (map_.find(e.get()) != map_.end()); }
533 
find(const element_type & e)534   const_iterator find(const element_type &e) const {
535     auto iter = map_.find(e.get());
536     if (iter == map_.end()) {
537       return ordered_data_.end();
538     }
539     return iter->second;
540   }
541 
find(const element_type & e)542   iterator find(const element_type &e) {
543     auto iter = map_.find(e.get());
544     if (iter == map_.end()) {
545       return ordered_data_.end();
546     }
547     return iter->second;
548   }
549 
count(const element_type & e)550   std::size_t count(const element_type &e) const { return map_.count(e.get()); }
551 
begin()552   iterator begin() { return ordered_data_.begin(); }
end()553   iterator end() { return ordered_data_.end(); }
554 
begin()555   const_iterator begin() const { return ordered_data_.cbegin(); }
end()556   const_iterator end() const { return ordered_data_.cend(); }
557 
cbegin()558   const_iterator cbegin() const { return ordered_data_.cbegin(); }
cend()559   const_iterator cend() const { return ordered_data_.cend(); }
560 
rbegin()561   reverse_iterator rbegin() { return ordered_data_.rbegin(); }
rend()562   reverse_iterator rend() { return ordered_data_.rend(); }
563 
crbegin()564   const_reverse_iterator crbegin() const { return ordered_data_.crbegin(); }
crend()565   const_reverse_iterator crend() const { return ordered_data_.crend(); }
566 
567  private:
568   map_type map_;
569   sequential_type ordered_data_;
570 };
571 }  // namespace mindspore
572 
573 #endif  // MINDSPORE_CORE_UTILS_ORDERED_SET_H_
574