1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_UTILS_ORDERED_SET_H_ 18 #define MINDSPORE_CORE_UTILS_ORDERED_SET_H_ 19 20 #include <algorithm> 21 #include <unordered_map> 22 #include <vector> 23 #include <list> 24 #include <utility> 25 #include <functional> 26 #include <memory> 27 #include "utils/hashing.h" 28 29 namespace mindspore { 30 // Implementation of OrderedSet that keeps insertion order 31 // using map as set, and use list as a sequential container to record elements to keep insertion order 32 template <class T, class Hash = std::hash<T>, class KeyEqual = std::equal_to<T>> 33 class OrderedSet { 34 public: 35 using element_type = T; 36 using hasher = Hash; 37 using equal = KeyEqual; 38 using sequential_type = std::list<element_type>; 39 using vector_type = std::vector<element_type>; 40 using iterator = typename sequential_type::iterator; 41 using const_iterator = typename sequential_type::const_iterator; 42 using reverse_iterator = typename sequential_type::reverse_iterator; 43 using const_reverse_iterator = typename sequential_type::const_reverse_iterator; 44 using map_type = std::unordered_map<element_type, iterator, hasher, equal>; 45 using ordered_set_type = OrderedSet<element_type, hasher, equal>; 46 47 OrderedSet() = default; 48 ~OrderedSet() = default; 49 // OrderedSet use an iterator to list as mapped value to improve the performance of insertion and deletion, 50 // So copy of OrderedSet should re-build value of the map key to make it pointer to the new list,, thus we use 51 // traversal to build elements. OrderedSet(const OrderedSet & os)52 OrderedSet(const OrderedSet &os) { 53 for (auto &item : os.ordered_data_) { 54 add(item); 55 } 56 } 57 58 OrderedSet(OrderedSet &&os) = default; 59 OrderedSet(const sequential_type & other)60 explicit OrderedSet(const sequential_type &other) { 61 for (auto &item : other) { 62 add(item); 63 } 64 } 65 66 // Explicitly construct an OrderedSet use vector OrderedSet(const vector_type & other)67 explicit OrderedSet(const vector_type &other) { 68 for (auto &item : other) { 69 add(item); 70 } 71 } 72 73 OrderedSet &operator=(const OrderedSet &other) { 74 if (this != &other) { 75 clear(); 76 reserve(other.size()); 77 for (auto &item : other.ordered_data_) { 78 add(item); 79 } 80 } 81 return *this; 82 } 83 84 OrderedSet &operator=(OrderedSet &&other) = default; 85 86 // insert an element to the OrderedSet after the given position. insert(iterator pos,const element_type & e)87 std::pair<iterator, bool> insert(iterator pos, const element_type &e) { 88 auto result = map_.emplace(e, ordered_data_.end()); 89 if (result.second) { 90 result.first->second = ordered_data_.emplace(pos, e); 91 } 92 return {result.first->second, result.second}; 93 } 94 95 // Add an element to the OrderedSet, without judging return value add(const element_type & e)96 void add(const element_type &e) { (void)insert(ordered_data_.end(), e); } 97 98 // insert an element to the end of OrderedSet. insert(const element_type & e)99 std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); } 100 push_back(const element_type & e)101 void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); } 102 push_front(const element_type & e)103 void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); } 104 105 // Remove an element, if removed return true, otherwise return false erase(const element_type & e)106 bool erase(const element_type &e) { 107 auto pos = map_.find(e); 108 if (pos == map_.end()) { 109 return false; 110 } 111 // erase the sequential data first 112 (void)ordered_data_.erase(pos->second); 113 (void)map_.erase(pos); 114 return true; 115 } 116 erase(iterator pos)117 iterator erase(iterator pos) { 118 (void)map_.erase(*pos); 119 return ordered_data_.erase(pos); 120 } 121 erase(const_iterator pos)122 iterator erase(const_iterator pos) { 123 (void)map_.erase(*pos); 124 return ordered_data_.erase(pos); 125 } 126 127 // Return the container size size()128 std::size_t size() const { return map_.size(); } 129 empty()130 bool empty() const { return map_.size() == 0; } 131 132 // Clear the elements clear()133 void clear() { 134 map_.clear(); 135 ordered_data_.clear(); 136 } 137 138 // Reserve memory for the number of entries. reserve(size_t num_entries)139 void reserve(size_t num_entries) { map_.reserve(num_entries); } 140 141 // Compare two orderedset, if the order is not equal shall return false 142 bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; } 143 pop()144 element_type pop() { 145 element_type e = std::move(ordered_data_.front()); 146 (void)map_.erase(e); 147 (void)ordered_data_.erase(ordered_data_.begin()); 148 return e; 149 } 150 back()151 element_type &back() { return ordered_data_.back(); } front()152 element_type &front() { return ordered_data_.front(); } 153 back()154 const element_type &back() const { return ordered_data_.back(); } front()155 const element_type &front() const { return ordered_data_.front(); } 156 157 // Return true if there are no common elements is_disjoint(const OrderedSet & other)158 bool is_disjoint(const OrderedSet &other) { 159 for (auto &item : other.ordered_data_) { 160 if (map_.find(item) != map_.end()) { 161 return false; 162 } 163 } 164 return true; 165 } 166 167 // Test whether this is subset of other is_subset(const OrderedSet & other)168 bool is_subset(const OrderedSet &other) { 169 for (auto &item : ordered_data_) { 170 if (other.map_.find(item) == other.map_.end()) { 171 return false; 172 } 173 } 174 return true; 175 } 176 177 // Add elements in other to this orderedset update(const OrderedSet & other)178 void update(const OrderedSet &other) { 179 for (auto &item : other.ordered_data_) { 180 add(item); 181 } 182 } 183 update(const std::shared_ptr<OrderedSet> & other)184 void update(const std::shared_ptr<OrderedSet> &other) { update(*other); } 185 update(const sequential_type & other)186 void update(const sequential_type &other) { 187 for (auto &item : other) { 188 add(item); 189 } 190 } 191 update(const vector_type & other)192 void update(const vector_type &other) { 193 for (auto &item : other) { 194 add(item); 195 } 196 } 197 get_union(const OrderedSet & other)198 ordered_set_type get_union(const OrderedSet &other) { 199 ordered_set_type res(ordered_data_); 200 res.update(other); 201 return res; 202 } 203 204 // Get the union with other set, this operator may cost time because of copy 205 ordered_set_type operator|(const OrderedSet &other) { return get_union(other); } 206 207 // Return the intersection of two sets intersection(const OrderedSet & other)208 ordered_set_type intersection(const OrderedSet &other) { 209 ordered_set_type res(ordered_data_); 210 for (auto &item : ordered_data_) { 211 if (other.map_.find(item) == other.map_.end()) { 212 (void)res.erase(item); 213 } 214 } 215 return res; 216 } 217 ordered_set_type operator&(const OrderedSet &other) { return intersection(other); } 218 219 // Return the symmetric difference of two sets symmetric_difference(const OrderedSet & other)220 ordered_set_type symmetric_difference(const OrderedSet &other) { 221 ordered_set_type res(ordered_data_); 222 for (auto &item : other.ordered_data_) { 223 if (map_.find(item) != map_.end()) { 224 (void)res.erase(item); 225 } else { 226 res.add(item); 227 } 228 } 229 return res; 230 } 231 232 ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); } 233 234 // Remove elements which is also in others. difference_update(const OrderedSet & other)235 void difference_update(const OrderedSet &other) { 236 // use vector traversal, to keep ordrer 237 for (auto &item : other.ordered_data_) { 238 (void)erase(item); 239 } 240 } 241 difference_update(const sequential_type & other)242 void difference_update(const sequential_type &other) { 243 for (auto &item : other) { 244 (void)erase(item); 245 } 246 } 247 difference_update(const vector_type & other)248 void difference_update(const vector_type &other) { 249 for (auto &item : other) { 250 (void)erase(item); 251 } 252 } 253 254 // Return the set with elements that are not in the others difference(const OrderedSet & other)255 ordered_set_type difference(const OrderedSet &other) { 256 ordered_set_type res(ordered_data_); 257 res.difference_update(other); 258 return res; 259 } 260 ordered_set_type operator-(const OrderedSet &other) { return difference(other); } 261 contains(const element_type & e)262 bool contains(const element_type &e) const { return (map_.find(e) != map_.end()); } 263 find(const element_type & e)264 const_iterator find(const element_type &e) const { 265 auto iter = map_.find(e); 266 if (iter == map_.end()) { 267 return ordered_data_.end(); 268 } 269 return iter->second; 270 } 271 find(const element_type & e)272 iterator find(const element_type &e) { 273 auto iter = map_.find(e); 274 if (iter == map_.end()) { 275 return ordered_data_.end(); 276 } 277 return iter->second; 278 } 279 280 // Return the count of an element in set count(const element_type & e)281 std::size_t count(const element_type &e) const { return map_.count(e); } 282 begin()283 iterator begin() { return ordered_data_.begin(); } end()284 iterator end() { return ordered_data_.end(); } 285 begin()286 const_iterator begin() const { return ordered_data_.cbegin(); } end()287 const_iterator end() const { return ordered_data_.cend(); } 288 cbegin()289 const_iterator cbegin() const { return ordered_data_.cbegin(); } cend()290 const_iterator cend() const { return ordered_data_.cend(); } 291 292 private: 293 map_type map_; 294 sequential_type ordered_data_; 295 }; 296 297 // OrderedSet that specially optimized for shared_ptr. 298 template <class T> 299 class OrderedSet<std::shared_ptr<T>> { 300 public: 301 using element_type = std::shared_ptr<T>; 302 using key_type = const T *; 303 using hash_t = PointerHash<T>; 304 using sequential_type = std::list<element_type>; 305 using vector_type = std::vector<element_type>; 306 using iterator = typename sequential_type::iterator; 307 using const_iterator = typename sequential_type::const_iterator; 308 using reverse_iterator = typename sequential_type::reverse_iterator; 309 using const_reverse_iterator = typename sequential_type::const_reverse_iterator; 310 using map_type = std::unordered_map<key_type, iterator, hash_t>; 311 using ordered_set_type = OrderedSet<std::shared_ptr<T>>; 312 313 OrderedSet() = default; 314 ~OrderedSet() = default; 315 OrderedSet(const OrderedSet & os)316 OrderedSet(const OrderedSet &os) { 317 for (auto &item : os.ordered_data_) { 318 add(item); 319 } 320 } 321 322 OrderedSet(OrderedSet &&os) = default; 323 OrderedSet(const sequential_type & other)324 explicit OrderedSet(const sequential_type &other) { 325 reserve(other.size()); 326 for (auto &item : other) { 327 add(item); 328 } 329 } 330 OrderedSet(const vector_type & other)331 explicit OrderedSet(const vector_type &other) { 332 reserve(other.size()); 333 for (auto &item : other) { 334 add(item); 335 } 336 } 337 338 OrderedSet &operator=(const OrderedSet &other) { 339 if (this != &other) { 340 clear(); 341 reserve(other.size()); 342 for (auto &item : other.ordered_data_) { 343 add(item); 344 } 345 } 346 return *this; 347 } 348 349 OrderedSet &operator=(OrderedSet &&other) = default; 350 insert(iterator pos,const element_type & e)351 std::pair<iterator, bool> insert(iterator pos, const element_type &e) { 352 auto [map_iter, inserted] = map_.emplace(e.get(), iterator{}); 353 if (inserted) { 354 map_iter->second = ordered_data_.emplace(pos, e); 355 } 356 return {map_iter->second, inserted}; 357 } 358 insert(iterator pos,element_type && e)359 std::pair<iterator, bool> insert(iterator pos, element_type &&e) { 360 auto [map_iter, inserted] = map_.emplace(e.get(), iterator{}); 361 if (inserted) { 362 map_iter->second = ordered_data_.emplace(pos, std::move(e)); 363 } 364 return {map_iter->second, inserted}; 365 } 366 add(const element_type & e)367 void add(const element_type &e) { (void)insert(ordered_data_.end(), e); } 368 add(element_type && e)369 void add(element_type &&e) { (void)insert(ordered_data_.end(), std::move(e)); } 370 insert(const element_type & e)371 std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); } 372 insert(element_type && e)373 std::pair<iterator, bool> insert(element_type &&e) { return insert(ordered_data_.end(), std::move(e)); } 374 push_back(const element_type & e)375 void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); } 376 push_front(const element_type & e)377 void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); } 378 erase(const element_type & e)379 bool erase(const element_type &e) { 380 auto pos = map_.find(e.get()); 381 if (pos == map_.end()) { 382 return false; 383 } 384 auto iter = pos->second; 385 (void)map_.erase(pos); 386 (void)ordered_data_.erase(iter); 387 return true; 388 } 389 erase(iterator pos)390 iterator erase(iterator pos) { 391 (void)map_.erase(pos->get()); 392 return ordered_data_.erase(pos); 393 } 394 erase(const_iterator pos)395 iterator erase(const_iterator pos) { 396 (void)map_.erase(pos->get()); 397 return ordered_data_.erase(pos); 398 } 399 size()400 std::size_t size() const { return ordered_data_.size(); } 401 empty()402 bool empty() const { return ordered_data_.empty(); } 403 clear()404 void clear() { 405 map_.clear(); 406 ordered_data_.clear(); 407 } 408 reserve(size_t num_entries)409 void reserve(size_t num_entries) { map_.reserve(num_entries); } 410 411 bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; } 412 pop()413 element_type pop() { 414 element_type e = std::move(ordered_data_.front()); 415 (void)map_.erase(e.get()); 416 (void)ordered_data_.erase(ordered_data_.begin()); 417 return e; 418 } 419 back()420 element_type &back() { return ordered_data_.back(); } front()421 element_type &front() { return ordered_data_.front(); } 422 back()423 const element_type &back() const { return ordered_data_.back(); } front()424 const element_type &front() const { return ordered_data_.front(); } 425 426 // Return true if there are no common elements. is_disjoint(const OrderedSet & other)427 bool is_disjoint(const OrderedSet &other) { 428 return std::all_of(begin(), end(), [&other](const auto &e) { return !other.contains(e); }); 429 } 430 431 // Test whether this is subset of other. is_subset(const OrderedSet & other)432 bool is_subset(const OrderedSet &other) { 433 return std::all_of(begin(), end(), [&other](const auto &e) { return other.contains(e); }); 434 } 435 436 // Add elements in other to this orderedset. update(const OrderedSet & other)437 void update(const OrderedSet &other) { 438 for (auto &item : other.ordered_data_) { 439 add(item); 440 } 441 } 442 update(const std::shared_ptr<OrderedSet> & other)443 void update(const std::shared_ptr<OrderedSet> &other) { update(*other); } 444 update(const sequential_type & other)445 void update(const sequential_type &other) { 446 for (auto &item : other) { 447 add(item); 448 } 449 } 450 update(const vector_type & other)451 void update(const vector_type &other) { 452 for (auto &item : other) { 453 add(item); 454 } 455 } 456 get_union(const OrderedSet & other)457 ordered_set_type get_union(const OrderedSet &other) { 458 ordered_set_type res(ordered_data_); 459 res.update(other); 460 return res; 461 } 462 463 // Get the union with other set, this operator may cost time because of copy. 464 ordered_set_type operator|(const OrderedSet &other) { return get_union(other); } 465 466 // Return the intersection of two sets. intersection(const OrderedSet & other)467 ordered_set_type intersection(const OrderedSet &other) { 468 ordered_set_type res; 469 for (auto &item : ordered_data_) { 470 if (other.contains(item)) { 471 res.add(item); 472 } 473 } 474 return res; 475 } 476 477 ordered_set_type operator&(const OrderedSet &other) { return intersection(other); } 478 479 // Return the symmetric difference of two sets. symmetric_difference(const OrderedSet & other)480 ordered_set_type symmetric_difference(const OrderedSet &other) { 481 ordered_set_type res(ordered_data_); 482 for (auto &item : other) { 483 if (contains(item)) { 484 (void)res.erase(item); 485 } else { 486 res.add(item); 487 } 488 } 489 return res; 490 } 491 492 ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); } 493 494 // Remove elements which is also in others. difference_update(const OrderedSet & other)495 void difference_update(const OrderedSet &other) { 496 for (auto &item : other) { 497 (void)erase(item); 498 } 499 } 500 difference_update(const sequential_type & other)501 void difference_update(const sequential_type &other) { 502 for (auto &item : other) { 503 (void)erase(item); 504 } 505 } 506 difference_update(const vector_type & other)507 void difference_update(const vector_type &other) { 508 for (auto &item : other) { 509 (void)erase(item); 510 } 511 } 512 513 // Return the set with elements that are not in the others. difference(const OrderedSet & other)514 ordered_set_type difference(const OrderedSet &other) { 515 ordered_set_type res; 516 for (auto &item : ordered_data_) { 517 if (!other.contains(item)) { 518 res.add(item); 519 } 520 } 521 return res; 522 } 523 524 ordered_set_type operator-(const OrderedSet &other) { return difference(other); } 525 contains(const element_type & e)526 bool contains(const element_type &e) const { return (map_.find(e.get()) != map_.end()); } 527 find(const element_type & e)528 const_iterator find(const element_type &e) const { 529 auto iter = map_.find(e.get()); 530 if (iter == map_.end()) { 531 return ordered_data_.end(); 532 } 533 return iter->second; 534 } 535 find(const element_type & e)536 iterator find(const element_type &e) { 537 auto iter = map_.find(e.get()); 538 if (iter == map_.end()) { 539 return ordered_data_.end(); 540 } 541 return iter->second; 542 } 543 count(const element_type & e)544 std::size_t count(const element_type &e) const { return map_.count(e.get()); } 545 begin()546 iterator begin() { return ordered_data_.begin(); } end()547 iterator end() { return ordered_data_.end(); } 548 begin()549 const_iterator begin() const { return ordered_data_.cbegin(); } end()550 const_iterator end() const { return ordered_data_.cend(); } 551 cbegin()552 const_iterator cbegin() const { return ordered_data_.cbegin(); } cend()553 const_iterator cend() const { return ordered_data_.cend(); } 554 555 private: 556 map_type map_; 557 sequential_type ordered_data_; 558 }; 559 } // namespace mindspore 560 561 #endif // MINDSPORE_CORE_UTILS_ORDERED_SET_H_ 562