1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2020 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ 20 #define MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ 21 22 #include <map> 23 #include <set> 24 25 namespace mindspore { 26 template <class T> 27 class UnionFindSet { 28 public: UnionFindSet()29 UnionFindSet() : union_find_set_() {} 30 ~UnionFindSet() = default; Add(const T & elem)31 void Add(const T &elem) { 32 if (union_find_set_.find(elem) != union_find_set_.end()) { 33 return; 34 } 35 36 union_find_set_[elem] = elem; 37 } 38 Find(const T & key)39 T Find(const T &key) { 40 T key_parent = key; 41 auto iter = union_find_set_.find(key_parent); 42 if (iter == union_find_set_.end()) { 43 MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent; 44 } 45 while (key_parent != iter->second) { 46 key_parent = iter->second; 47 iter = union_find_set_.find(key_parent); 48 if (iter == union_find_set_.end()) { 49 MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent; 50 } 51 } 52 53 T tmp = key; 54 T tmp_parent; 55 while (tmp != key_parent) { 56 iter = union_find_set_.find(tmp); 57 if (iter == union_find_set_.end()) { 58 MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << tmp; 59 } 60 tmp_parent = iter->second; 61 union_find_set_[tmp] = key_parent; 62 tmp = tmp_parent; 63 } 64 return key_parent; 65 } 66 Union(const T & left,const T & right)67 void Union(const T &left, const T &right) { union_find_set_[Find(left)] = Find(right); } 68 GetSets()69 std::map<T, std::set<T>> GetSets() { 70 std::map<T, std::set<T>> result; 71 for (auto &iter : union_find_set_) { 72 (void)Find(iter.first); 73 } 74 for (auto &iter : union_find_set_) { 75 T parent = Find(iter.first); 76 result[parent].insert(iter.first); 77 } 78 return result; 79 } 80 81 private: 82 std::map<T, T> union_find_set_; 83 }; 84 } // namespace mindspore 85 86 #endif // MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ 87