1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_UTILS_ORDERED_SET_H_ 18 #define MINDSPORE_CORE_UTILS_ORDERED_SET_H_ 19 20 #include <algorithm> 21 #include <vector> 22 #include <list> 23 #include <utility> 24 #include <functional> 25 #include <memory> 26 #include "utils/hashing.h" 27 #include "utils/hash_map.h" 28 29 namespace mindspore { 30 // Implementation of OrderedSet that keeps insertion order 31 // using map as set, and use list as a sequential container to record elements to keep insertion order 32 template <class T, class Hash = std::hash<T>, class KeyEqual = std::equal_to<T>> 33 class OrderedSet { 34 public: 35 using element_type = T; 36 using hasher = Hash; 37 using equal = KeyEqual; 38 using sequential_type = std::list<element_type>; 39 using vector_type = std::vector<element_type>; 40 using iterator = typename sequential_type::iterator; 41 using const_iterator = typename sequential_type::const_iterator; 42 using reverse_iterator = typename sequential_type::reverse_iterator; 43 using const_reverse_iterator = typename sequential_type::const_reverse_iterator; 44 using map_type = mindspore::HashMap<element_type, iterator, hasher, equal>; 45 using ordered_set_type = OrderedSet<element_type, hasher, equal>; 46 47 OrderedSet() = default; 48 ~OrderedSet() = default; 49 // OrderedSet use an iterator to list as mapped value to improve the performance of insertion and deletion, 50 // So copy of OrderedSet should re-build value of the map key to make it pointer to the new list,, thus we use 51 // traversal to build elements. OrderedSet(const OrderedSet & os)52 OrderedSet(const OrderedSet &os) { 53 for (auto &item : os.ordered_data_) { 54 add(item); 55 } 56 } 57 58 OrderedSet(OrderedSet &&os) = default; 59 OrderedSet(const sequential_type & other)60 explicit OrderedSet(const sequential_type &other) { 61 for (auto &item : other) { 62 add(item); 63 } 64 } 65 66 // Explicitly construct an OrderedSet use vector OrderedSet(const vector_type & other)67 explicit OrderedSet(const vector_type &other) { 68 for (auto &item : other) { 69 add(item); 70 } 71 } 72 73 OrderedSet &operator=(const OrderedSet &other) { 74 if (this != &other) { 75 clear(); 76 reserve(other.size()); 77 for (auto &item : other.ordered_data_) { 78 add(item); 79 } 80 } 81 return *this; 82 } 83 84 OrderedSet &operator=(OrderedSet &&other) = default; 85 86 // insert an element to the OrderedSet after the given position. insert(const iterator & pos,const element_type & e)87 std::pair<iterator, bool> insert(const iterator &pos, const element_type &e) { 88 auto result = map_.emplace(e, ordered_data_.end()); 89 if (result.second) { 90 result.first->second = ordered_data_.emplace(pos, e); 91 } 92 return {result.first->second, result.second}; 93 } 94 95 // Add an element to the OrderedSet, without judging return value add(const element_type & e)96 void add(const element_type &e) { (void)insert(ordered_data_.end(), e); } 97 98 // insert an element to the end of OrderedSet. insert(const element_type & e)99 std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); } 100 push_back(const element_type & e)101 void push_back(const element_type &e) const { (void)insert(ordered_data_.end(), e); } 102 push_front(const element_type & e)103 void push_front(const element_type &e) const { (void)insert(ordered_data_.begin(), e); } 104 105 // Remove an element, if removed return true, otherwise return false erase(const element_type & e)106 bool erase(const element_type &e) { 107 auto pos = map_.find(e); 108 if (pos == map_.end()) { 109 return false; 110 } 111 // erase the sequential data first 112 (void)ordered_data_.erase(pos->second); 113 (void)map_.erase(pos); 114 return true; 115 } 116 erase(iterator pos)117 iterator erase(iterator pos) { 118 (void)map_.erase(*pos); 119 return ordered_data_.erase(pos); 120 } 121 erase(const_iterator pos)122 iterator erase(const_iterator pos) { 123 (void)map_.erase(*pos); 124 return ordered_data_.erase(pos); 125 } 126 127 // Return the container size size()128 std::size_t size() const { return map_.size(); } 129 empty()130 bool empty() const { return map_.size() == 0; } 131 132 // Clear the elements clear()133 void clear() { 134 if (!map_.empty()) { 135 map_.clear(); 136 ordered_data_.clear(); 137 } 138 } 139 140 // Reserve memory for the number of entries. reserve(size_t num_entries)141 void reserve(size_t num_entries) { map_.reserve(num_entries); } 142 143 // Compare two orderedset, if the order is not equal shall return false 144 bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; } 145 pop()146 element_type pop() { 147 element_type e = std::move(ordered_data_.front()); 148 (void)map_.erase(e); 149 (void)ordered_data_.erase(ordered_data_.begin()); 150 return e; 151 } 152 back()153 element_type &back() { return ordered_data_.back(); } front()154 element_type &front() { return ordered_data_.front(); } 155 back()156 const element_type &back() const { return ordered_data_.back(); } front()157 const element_type &front() const { return ordered_data_.front(); } 158 159 // Return true if there are no common elements is_disjoint(const OrderedSet & other)160 bool is_disjoint(const OrderedSet &other) { 161 for (auto &item : other.ordered_data_) { 162 if (map_.find(item) != map_.end()) { 163 return false; 164 } 165 } 166 return true; 167 } 168 169 // Test whether this is subset of other is_subset(const OrderedSet & other)170 bool is_subset(const OrderedSet &other) { 171 for (auto &item : ordered_data_) { 172 if (other.map_.find(item) == other.map_.end()) { 173 return false; 174 } 175 } 176 return true; 177 } 178 179 // Add elements in other to this orderedset update(const OrderedSet & other)180 void update(const OrderedSet &other) { 181 for (auto &item : other.ordered_data_) { 182 add(item); 183 } 184 } 185 update(const std::shared_ptr<OrderedSet> & other)186 void update(const std::shared_ptr<OrderedSet> &other) { update(*other); } 187 update(const sequential_type & other)188 void update(const sequential_type &other) { 189 for (auto &item : other) { 190 add(item); 191 } 192 } 193 update(const vector_type & other)194 void update(const vector_type &other) { 195 for (auto &item : other) { 196 add(item); 197 } 198 } 199 get_union(const OrderedSet & other)200 ordered_set_type get_union(const OrderedSet &other) { 201 ordered_set_type res(ordered_data_); 202 res.update(other); 203 return res; 204 } 205 206 // Get the union with other set, this operator may cost time because of copy 207 ordered_set_type operator|(const OrderedSet &other) { return get_union(other); } 208 209 // Return the intersection of two sets intersection(const OrderedSet & other)210 ordered_set_type intersection(const OrderedSet &other) { 211 ordered_set_type res(ordered_data_); 212 for (auto &item : ordered_data_) { 213 if (other.map_.find(item) == other.map_.end()) { 214 (void)res.erase(item); 215 } 216 } 217 return res; 218 } 219 ordered_set_type operator&(const OrderedSet &other) { return intersection(other); } 220 221 // Return the symmetric difference of two sets symmetric_difference(const OrderedSet & other)222 ordered_set_type symmetric_difference(const OrderedSet &other) { 223 ordered_set_type res(ordered_data_); 224 for (auto &item : other.ordered_data_) { 225 if (map_.find(item) != map_.end()) { 226 (void)res.erase(item); 227 } else { 228 res.add(item); 229 } 230 } 231 return res; 232 } 233 234 ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); } 235 236 // Remove elements which is also in others. difference_update(const OrderedSet & other)237 void difference_update(const OrderedSet &other) { 238 // use vector traversal, to keep ordrer 239 for (auto &item : other.ordered_data_) { 240 (void)erase(item); 241 } 242 } 243 difference_update(const sequential_type & other)244 void difference_update(const sequential_type &other) { 245 for (auto &item : other) { 246 (void)erase(item); 247 } 248 } 249 difference_update(const vector_type & other)250 void difference_update(const vector_type &other) { 251 for (auto &item : other) { 252 (void)erase(item); 253 } 254 } 255 256 // Return the set with elements that are not in the others difference(const OrderedSet & other)257 ordered_set_type difference(const OrderedSet &other) { 258 ordered_set_type res(ordered_data_); 259 res.difference_update(other); 260 return res; 261 } 262 ordered_set_type operator-(const OrderedSet &other) { return difference(other); } 263 contains(const element_type & e)264 bool contains(const element_type &e) const { return (map_.find(e) != map_.end()); } 265 find(const element_type & e)266 const_iterator find(const element_type &e) const { 267 auto iter = map_.find(e); 268 if (iter == map_.end()) { 269 return ordered_data_.end(); 270 } 271 return iter->second; 272 } 273 find(const element_type & e)274 iterator find(const element_type &e) { 275 auto iter = map_.find(e); 276 if (iter == map_.end()) { 277 return ordered_data_.end(); 278 } 279 return iter->second; 280 } 281 282 // Return the count of an element in set count(const element_type & e)283 std::size_t count(const element_type &e) const { return map_.count(e); } 284 begin()285 iterator begin() { return ordered_data_.begin(); } end()286 iterator end() { return ordered_data_.end(); } 287 begin()288 const_iterator begin() const { return ordered_data_.cbegin(); } end()289 const_iterator end() const { return ordered_data_.cend(); } 290 cbegin()291 const_iterator cbegin() const { return ordered_data_.cbegin(); } cend()292 const_iterator cend() const { return ordered_data_.cend(); } 293 rbegin()294 reverse_iterator rbegin() const { return ordered_data_.rbegin(); } rend()295 reverse_iterator rend() const { return ordered_data_.rend(); } 296 297 private: 298 map_type map_; 299 sequential_type ordered_data_; 300 }; 301 302 // OrderedSet that specially optimized for shared_ptr. 303 template <class T> 304 class OrderedSet<std::shared_ptr<T>> { 305 public: 306 using element_type = std::shared_ptr<T>; 307 using key_type = const T *; 308 using sequential_type = std::list<element_type>; 309 using vector_type = std::vector<element_type>; 310 using iterator = typename sequential_type::iterator; 311 using const_iterator = typename sequential_type::const_iterator; 312 using reverse_iterator = typename sequential_type::reverse_iterator; 313 using const_reverse_iterator = typename sequential_type::const_reverse_iterator; 314 using map_type = mindspore::HashMap<key_type, iterator>; 315 using ordered_set_type = OrderedSet<std::shared_ptr<T>>; 316 317 OrderedSet() = default; 318 ~OrderedSet() = default; 319 OrderedSet(const OrderedSet & os)320 OrderedSet(const OrderedSet &os) { 321 for (auto &item : os.ordered_data_) { 322 add(item); 323 } 324 } 325 326 OrderedSet(OrderedSet &&os) = default; 327 OrderedSet(const sequential_type & other)328 explicit OrderedSet(const sequential_type &other) { 329 reserve(other.size()); 330 for (auto &item : other) { 331 add(item); 332 } 333 } 334 OrderedSet(const vector_type & other)335 explicit OrderedSet(const vector_type &other) { 336 reserve(other.size()); 337 for (auto &item : other) { 338 add(item); 339 } 340 } 341 342 OrderedSet &operator=(const OrderedSet &other) { 343 if (this != &other) { 344 clear(); 345 reserve(other.size()); 346 for (auto &item : other.ordered_data_) { 347 add(item); 348 } 349 } 350 return *this; 351 } 352 353 OrderedSet &operator=(OrderedSet &&other) = default; 354 insert(const iterator & pos,const element_type & e)355 std::pair<iterator, bool> insert(const iterator &pos, const element_type &e) { 356 auto [map_iter, inserted] = map_.emplace(e.get(), iterator{}); 357 if (inserted) { 358 map_iter->second = ordered_data_.emplace(pos, e); 359 } 360 return {map_iter->second, inserted}; 361 } 362 insert(const iterator & pos,element_type && e)363 std::pair<iterator, bool> insert(const iterator &pos, element_type &&e) { 364 auto [map_iter, inserted] = map_.emplace(e.get(), iterator{}); 365 if (inserted) { 366 map_iter->second = ordered_data_.emplace(pos, std::move(e)); 367 } 368 return {map_iter->second, inserted}; 369 } 370 add(const element_type & e)371 void add(const element_type &e) { (void)insert(ordered_data_.end(), e); } 372 add(element_type && e)373 void add(element_type &&e) { (void)insert(ordered_data_.end(), std::move(e)); } 374 insert(const element_type & e)375 std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); } 376 insert(element_type && e)377 std::pair<iterator, bool> insert(element_type &&e) { return insert(ordered_data_.end(), std::move(e)); } 378 push_back(const element_type & e)379 void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); } 380 push_front(const element_type & e)381 void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); } 382 erase(const element_type & e)383 bool erase(const element_type &e) { 384 auto pos = map_.find(e.get()); 385 if (pos == map_.end()) { 386 return false; 387 } 388 auto iter = pos->second; 389 (void)map_.erase(pos); 390 (void)ordered_data_.erase(iter); 391 return true; 392 } 393 erase(const iterator & pos)394 iterator erase(const iterator &pos) { 395 (void)map_.erase(pos->get()); 396 return ordered_data_.erase(pos); 397 } 398 erase(const_iterator pos)399 iterator erase(const_iterator pos) { 400 (void)map_.erase(pos->get()); 401 return ordered_data_.erase(pos); 402 } 403 size()404 std::size_t size() const { return ordered_data_.size(); } 405 empty()406 bool empty() const { return ordered_data_.empty(); } 407 clear()408 void clear() { 409 if (!map_.empty()) { 410 map_.clear(); 411 ordered_data_.clear(); 412 } 413 } 414 reserve(size_t num_entries)415 void reserve(size_t num_entries) { map_.reserve(num_entries); } 416 417 bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; } 418 pop()419 element_type pop() { 420 element_type e = std::move(ordered_data_.front()); 421 (void)map_.erase(e.get()); 422 (void)ordered_data_.erase(ordered_data_.begin()); 423 return e; 424 } 425 back()426 element_type &back() { return ordered_data_.back(); } front()427 element_type &front() { return ordered_data_.front(); } 428 back()429 const element_type &back() const { return ordered_data_.back(); } front()430 const element_type &front() const { return ordered_data_.front(); } 431 432 // Return true if there are no common elements. is_disjoint(const OrderedSet & other)433 bool is_disjoint(const OrderedSet &other) const { 434 return std::all_of(begin(), end(), [&other](const auto &e) { return !other.contains(e); }); 435 } 436 437 // Test whether this is subset of other. is_subset(const OrderedSet & other)438 bool is_subset(const OrderedSet &other) { 439 return std::all_of(begin(), end(), [&other](const auto &e) { return other.contains(e); }); 440 } 441 442 // Add elements in other to this orderedset. update(const OrderedSet & other)443 void update(const OrderedSet &other) { 444 for (auto &item : other.ordered_data_) { 445 add(item); 446 } 447 } 448 update(const std::shared_ptr<OrderedSet> & other)449 void update(const std::shared_ptr<OrderedSet> &other) { update(*other); } 450 update(const sequential_type & other)451 void update(const sequential_type &other) { 452 for (auto &item : other) { 453 add(item); 454 } 455 } 456 update(const vector_type & other)457 void update(const vector_type &other) { 458 for (auto &item : other) { 459 add(item); 460 } 461 } 462 get_union(const OrderedSet & other)463 ordered_set_type get_union(const OrderedSet &other) { 464 ordered_set_type res(ordered_data_); 465 res.update(other); 466 return res; 467 } 468 469 // Get the union with other set, this operator may cost time because of copy. 470 ordered_set_type operator|(const OrderedSet &other) { return get_union(other); } 471 472 // Return the intersection of two sets. intersection(const OrderedSet & other)473 ordered_set_type intersection(const OrderedSet &other) { 474 ordered_set_type res; 475 for (auto &item : ordered_data_) { 476 if (other.contains(item)) { 477 res.add(item); 478 } 479 } 480 return res; 481 } 482 483 ordered_set_type operator&(const OrderedSet &other) { return intersection(other); } 484 485 // Return the symmetric difference of two sets. symmetric_difference(const OrderedSet & other)486 ordered_set_type symmetric_difference(const OrderedSet &other) { 487 ordered_set_type res(ordered_data_); 488 for (auto &item : other) { 489 if (contains(item)) { 490 (void)res.erase(item); 491 } else { 492 res.add(item); 493 } 494 } 495 return res; 496 } 497 498 ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); } 499 500 // Remove elements which is also in others. difference_update(const OrderedSet & other)501 void difference_update(const OrderedSet &other) { 502 for (auto &item : other) { 503 (void)erase(item); 504 } 505 } 506 difference_update(const sequential_type & other)507 void difference_update(const sequential_type &other) { 508 for (auto &item : other) { 509 (void)erase(item); 510 } 511 } 512 difference_update(const vector_type & other)513 void difference_update(const vector_type &other) { 514 for (auto &item : other) { 515 (void)erase(item); 516 } 517 } 518 519 // Return the set with elements that are not in the others. difference(const OrderedSet & other)520 ordered_set_type difference(const OrderedSet &other) { 521 ordered_set_type res; 522 for (auto &item : ordered_data_) { 523 if (!other.contains(item)) { 524 res.add(item); 525 } 526 } 527 return res; 528 } 529 530 ordered_set_type operator-(const OrderedSet &other) { return difference(other); } 531 contains(const element_type & e)532 bool contains(const element_type &e) const { return (map_.find(e.get()) != map_.end()); } 533 find(const element_type & e)534 const_iterator find(const element_type &e) const { 535 auto iter = map_.find(e.get()); 536 if (iter == map_.end()) { 537 return ordered_data_.end(); 538 } 539 return iter->second; 540 } 541 find(const element_type & e)542 iterator find(const element_type &e) { 543 auto iter = map_.find(e.get()); 544 if (iter == map_.end()) { 545 return ordered_data_.end(); 546 } 547 return iter->second; 548 } 549 count(const element_type & e)550 std::size_t count(const element_type &e) const { return map_.count(e.get()); } 551 begin()552 iterator begin() { return ordered_data_.begin(); } end()553 iterator end() { return ordered_data_.end(); } 554 begin()555 const_iterator begin() const { return ordered_data_.cbegin(); } end()556 const_iterator end() const { return ordered_data_.cend(); } 557 cbegin()558 const_iterator cbegin() const { return ordered_data_.cbegin(); } cend()559 const_iterator cend() const { return ordered_data_.cend(); } 560 rbegin()561 reverse_iterator rbegin() { return ordered_data_.rbegin(); } rend()562 reverse_iterator rend() { return ordered_data_.rend(); } 563 crbegin()564 const_reverse_iterator crbegin() const { return ordered_data_.crbegin(); } crend()565 const_reverse_iterator crend() const { return ordered_data_.crend(); } 566 567 private: 568 map_type map_; 569 sequential_type ordered_data_; 570 }; 571 } // namespace mindspore 572 573 #endif // MINDSPORE_CORE_UTILS_ORDERED_SET_H_ 574