• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
20 #define MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
21 
22 #include <map>
23 #include <set>
24 
25 namespace mindspore {
26 template <class T>
27 class UnionFindSet {
28  public:
UnionFindSet()29   UnionFindSet() : union_find_set_() {}
30   ~UnionFindSet() = default;
Add(const T & elem)31   void Add(const T &elem) {
32     if (union_find_set_.find(elem) != union_find_set_.end()) {
33       return;
34     }
35 
36     union_find_set_[elem] = elem;
37   }
38 
Find(const T & key)39   T Find(const T &key) {
40     T key_parent = key;
41     auto iter = union_find_set_.find(key_parent);
42     if (iter == union_find_set_.end()) {
43       MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent;
44     }
45     while (key_parent != iter->second) {
46       key_parent = iter->second;
47       iter = union_find_set_.find(key_parent);
48       if (iter == union_find_set_.end()) {
49         MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent;
50       }
51     }
52 
53     T tmp = key;
54     T tmp_parent;
55     while (tmp != key_parent) {
56       iter = union_find_set_.find(tmp);
57       if (iter == union_find_set_.end()) {
58         MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << tmp;
59       }
60       tmp_parent = iter->second;
61       union_find_set_[tmp] = key_parent;
62       tmp = tmp_parent;
63     }
64     return key_parent;
65   }
66 
Union(const T & left,const T & right)67   void Union(const T &left, const T &right) { union_find_set_[Find(left)] = Find(right); }
68 
GetSets()69   std::map<T, std::set<T>> GetSets() {
70     std::map<T, std::set<T>> result;
71     for (auto &iter : union_find_set_) {
72       (void)Find(iter.first);
73     }
74     for (auto &iter : union_find_set_) {
75       T parent = Find(iter.first);
76       result[parent].insert(iter.first);
77     }
78     return result;
79   }
80 
81  private:
82   std::map<T, T> union_find_set_;
83 };
84 }  // namespace mindspore
85 
86 #endif  // MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
87