1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/core/graph_kernel_cluster.h"
17
18 #include "mindspore/core/ops/sequence_ops.h"
19 #include "mindspore/core/ops/nn_optimizer_ops.h"
20 #include "mindspore/core/ops/nn_ops.h"
21 #include "mindspore/core/ops/math_ops.h"
22 #include "mindspore/core/ops/lite_ops.h"
23 #include "mindspore/core/ops/comparison_ops.h"
24 #include "mindspore/core/ops/array_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "ops/auto_generate/gen_ops_primitive.h"
27 #include "utils/hash_map.h"
28 #include "ir/graph_utils.h"
29 #include "utils/anf_utils.h"
30 #include "utils/file_utils.h"
31 #include "ops/sequence_ops.h"
32 #include "ops/nn_optimizer_ops.h"
33 #include "backend/common/graph_kernel/graph_kernel_flags.h"
34 #include "backend/common/graph_kernel/core/graph_builder.h"
35 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
36
37 namespace mindspore::graphkernel {
Cluster(size_t node_id,const AnfNodePtr & node,const mindspore::HashMap<AnfNodePtr,size_t> & node_idx_map)38 Graph::Cluster::Cluster(size_t node_id, const AnfNodePtr &node,
39 const mindspore::HashMap<AnfNodePtr, size_t> &node_idx_map)
40 : cluster_id_(node_id) {
41 auto cnode = node->cast<CNodePtr>();
42 MS_EXCEPTION_IF_NULL(cnode);
43 for (const auto &inp : cnode->inputs()) {
44 auto iter = node_idx_map.find(inp);
45 if (iter != node_idx_map.end()) {
46 // At the beginning, cluster_id is equal to node_id
47 (void)inputs_.insert(iter->second);
48 }
49 }
50 }
51
Merge(Cluster * other_cluster)52 void Graph::Cluster::Merge(Cluster *other_cluster) {
53 other_cluster->cluster_id_ = cluster_id_;
54 cluster_size_ += other_cluster->cluster_size_;
55 inputs_.insert(other_cluster->inputs_.cbegin(), other_cluster->inputs_.cend());
56 other_cluster->Clean();
57 }
58
Build(const FuncGraphPtr & func_graph,AnfNodePtrList * nodes,HashMap<AnfNodePtr,size_t> * node_idx_map)59 GraphPtr Graph::Build(const FuncGraphPtr &func_graph, AnfNodePtrList *nodes,
60 HashMap<AnfNodePtr, size_t> *node_idx_map) {
61 MS_EXCEPTION_IF_NULL(func_graph);
62 auto cnodes = TopoSort(func_graph->output(), SuccIncoming,
63 [](const AnfNodePtr &node) { return node->isa<CNode>() ? FOLLOW : EXCLUDE; });
64 HashMap<AnfNodePtr, size_t> tmp_node_idx_map;
65 for (size_t i = 0; i < cnodes.size(); i++) {
66 tmp_node_idx_map[cnodes[i]] = i;
67 }
68 auto graph_ptr = std::make_shared<Graph>(cnodes, tmp_node_idx_map);
69 if (nodes != nullptr) {
70 *nodes = std::move(cnodes);
71 }
72 if (node_idx_map != nullptr) {
73 *node_idx_map = std::move(tmp_node_idx_map);
74 }
75 return graph_ptr;
76 }
77
Graph(const AnfNodePtrList & nodes,const HashMap<AnfNodePtr,size_t> & node_idx_map)78 Graph::Graph(const AnfNodePtrList &nodes, const HashMap<AnfNodePtr, size_t> &node_idx_map) {
79 clusters_.reserve(nodes.size());
80 for (size_t i = 0; i < nodes.size(); i++) {
81 (void)clusters_.emplace_back(i, nodes[i], node_idx_map);
82 }
83 }
84
Find(size_t node_id)85 size_t Graph::Find(size_t node_id) {
86 size_t &pre_id = clusters_[node_id].cluster_id_;
87 return (pre_id == clusters_[pre_id].cluster_id_) ? pre_id : (pre_id = Find(pre_id));
88 }
89
Merge(const std::vector<size_t> & candidates)90 void Graph::Merge(const std::vector<size_t> &candidates) {
91 size_t min_id = *std::min_element(candidates.begin(), candidates.end());
92 for (auto id : candidates) {
93 if (id == min_id) {
94 continue;
95 }
96 clusters_[min_id].Merge(&clusters_[id]);
97 }
98 }
99
CollectClusters()100 std::vector<std::vector<size_t>> Graph::CollectClusters() {
101 std::vector<std::vector<size_t>> cluster_map(clusters_.size());
102 for (size_t i = 0; i < clusters_.size(); i++) {
103 cluster_map[Find(i)].push_back(i);
104 }
105 return cluster_map;
106 }
107
Dfs(size_t node_id,const Graph::VisitFunc & visitor)108 void Graph::Dfs(size_t node_id, const Graph::VisitFunc &visitor) {
109 ++seen_;
110 return DepthFirstSearch(Find(node_id), visitor);
111 }
112
GetInputs(size_t cluster_id)113 const std::set<size_t> &Graph::GetInputs(size_t cluster_id) {
114 cluster_id = Find(cluster_id);
115 RefreshInputs(cluster_id);
116 return clusters_[cluster_id].inputs_;
117 }
118
RefreshInputs(size_t i)119 void Graph::RefreshInputs(size_t i) {
120 auto &inputs = clusters_[i].inputs_;
121 for (auto iter = inputs.cbegin(); iter != inputs.cend();) {
122 size_t new_id = Find(*iter);
123 if (new_id != *iter) {
124 iter = inputs.erase(iter);
125 (void)inputs.insert(new_id);
126 } else {
127 ++iter;
128 }
129 }
130 (void)inputs.erase(i);
131 }
132
DepthFirstSearch(size_t cluster_id,const VisitFunc & visitor)133 void Graph::DepthFirstSearch(size_t cluster_id, const VisitFunc &visitor) {
134 if (clusters_[cluster_id].seed_ >= seen_) {
135 return;
136 }
137 clusters_[cluster_id].seed_ = seen_;
138 if (visitor(cluster_id) != FOLLOW) {
139 return;
140 }
141 // traverse inputs in descending order.
142 const auto &inputs = GetInputs(cluster_id);
143 for (auto iter = inputs.crbegin(); iter != inputs.crend(); ++iter) {
144 DepthFirstSearch(*iter, visitor);
145 }
146 }
147
RemoveCircle(std::vector<size_t> * candidates)148 void CircleChecker::RemoveCircle(std::vector<size_t> *candidates) {
149 if (candidates->size() <= 1) {
150 return;
151 }
152 candidates_.clear();
153 candidates_.insert(candidates->cbegin(), candidates->cend());
154 for (auto iter = candidates->cbegin(); iter != candidates->cend(); ++iter) {
155 if (candidates_.count(*iter) == 0) {
156 continue;
157 }
158 circle_nodes_.clear();
159 if (CheckCircle(*iter)) {
160 RemoveCircleNodesFromCandidates();
161 }
162 }
163 (void)candidates->erase(std::remove_if(candidates->begin(), candidates->end(),
164 [this](size_t c) { return this->candidates_.count(c) == 0; }),
165 candidates->end());
166 }
167
168 /**
169 * Check circle. the candidate is collected into circle_nodes_ if it will form a circle.
170 *
171 * algorithm:
172 * Search from the basenode's input that is NOT in candidates (the basenode is a candidate),
173 * If it depends on a node that belongs to candidates, it will form a circle.
174 * e.g. A -> x -> ... -> B
175 * -> y -> ... -> C
176 * In this case, A, B and C are candidates while x and y are not.
177 * Both x and y are inputs of A. assumes A is the basenode.
178 * When searching from x, the B will be found and added into circle_nodes list,
179 * and then when searching from y, the C will be found and added into circle_nodes list.
180 */
CheckCircle(size_t basenode)181 bool CircleChecker::CheckCircle(size_t basenode) {
182 const auto &inputs = graph_->GetInputs(basenode);
183 std::set<size_t> visited_circle_nodes;
184 for (auto x : inputs) {
185 if (candidates_.count(x) > 0) {
186 continue;
187 }
188 bool has_circle = false;
189 std::set<size_t> done;
190 auto vis_func = [this, &has_circle, &done, &visited_circle_nodes](size_t node_id) {
191 if (done.count(node_id) > 0 || acyclic_nodes_.count(node_id) > 0 || visited_circle_nodes.count(node_id) > 0) {
192 return EXCLUDE;
193 }
194 (void)done.insert(node_id);
195 if (candidates_.count(node_id) > 0) {
196 has_circle = true;
197 circle_nodes_.push_back(node_id);
198 return EXCLUDE;
199 }
200 return FOLLOW;
201 };
202 graph_->Dfs(x, vis_func);
203 if (has_circle) {
204 visited_circle_nodes.insert(done.cbegin(), done.cend());
205 } else {
206 acyclic_nodes_.insert(done.cbegin(), done.cend());
207 }
208 }
209 return !circle_nodes_.empty();
210 }
211
RemoveCircleNodesFromCandidates()212 void CircleChecker::RemoveCircleNodesFromCandidates() {
213 auto remove_from_candidates = [this](size_t node_id) {
214 if (candidates_.count(node_id) > 0) {
215 (void)candidates_.erase(node_id);
216 return FOLLOW;
217 }
218 return EXCLUDE;
219 };
220 for (auto node : circle_nodes_) {
221 graph_->Dfs(node, remove_from_candidates);
222 }
223 }
224
FindCandidates(size_t basenode_id)225 std::vector<size_t> GraphKernelCluster::FindCandidates(size_t basenode_id) {
226 std::vector<size_t> candidates;
227 auto include = [this, &candidates, func_graph = nodes_[basenode_id]->func_graph()](size_t cluster_id) {
228 const AnfNodePtr &node = this->nodes_[cluster_id];
229 if (node->func_graph() != func_graph) {
230 return EXCLUDE;
231 }
232 if (!IsClusterableOp(node) && !IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
233 return EXCLUDE;
234 }
235 candidates.push_back(cluster_id);
236 // Do not search from clustered node again.
237 if (this->graph_->GetSize(cluster_id) > 1) {
238 return NOFOLLOW;
239 }
240 return FOLLOW;
241 };
242 graph_->Dfs(basenode_id, include);
243 std::reverse(candidates.begin(), candidates.end());
244 return candidates;
245 }
246
Process(const FuncGraphPtr & func_graph)247 bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
248 bool changed = false;
249 for (int i = SizeToInt(nodes_.size()) - 1; i >= 0; i--) {
250 // if the node has been clustered, it has tried to find its previous nodes, so it's unnecessary to try again.
251 if (graph_->GetSize(IntToSize(i)) > 1) {
252 continue;
253 }
254 auto candidates = FindCandidates(IntToSize(i));
255 CircleChecker circle_checker(graph_);
256 circle_checker.RemoveCircle(&candidates);
257 RemoveWildGetitem(&candidates);
258 if (candidates.empty()) {
259 continue;
260 }
261 // merge candidates into one cluster
262 graph_->Merge(candidates);
263 }
264
265 // Rebuild func_graphs
266 auto clusters = graph_->CollectClusters();
267 for (size_t i = 0; i < clusters.size(); i++) {
268 auto node_without_getitem = std::count_if(clusters[i].begin(), clusters[i].end(), [this](size_t node_id) {
269 return !IsPrimitiveCNode(this->nodes_[node_id], prim::kPrimTupleGetItem);
270 });
271 if (node_without_getitem == 0) {
272 continue;
273 }
274 if (node_without_getitem == 1) {
275 // Do not cluster a single GraphKernel again.
276 // Do not cluster a single Assign.
277 const auto &node = nodes_[clusters[i][0]];
278 if (AnfUtils::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) {
279 continue;
280 }
281 }
282 CreateFuncGraph(func_graph, clusters[i]);
283 changed = true;
284 }
285 return changed;
286 }
287
CreateFuncGraph(const FuncGraphPtr & func_graph,const std::vector<size_t> & nodes_id)288 void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) {
289 AnfNodePtrList old_nodes;
290 (void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
291 [this](size_t id) { return this->nodes_[id]; });
292 auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");
293 if (GraphKernelFlags::GetInstance().dump_as_text) {
294 DumpClusterInfo(old_nodes, new_node);
295 }
296 }
297
DumpClusterInfo(const AnfNodePtrList & old_nodes,const AnfNodePtr & new_node)298 void GraphKernelCluster::DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node) {
299 dump_buf_ << "Source nodes of " << new_node->fullname_with_scope() << " = " << new_node->DebugString() << std::endl;
300 for (const auto &node : old_nodes) {
301 dump_buf_ << " " << node->fullname_with_scope() << " = " << node->DebugString() << std::endl;
302 }
303 dump_buf_ << "=======================" << std::endl;
304 }
305
DumpToFile()306 void GraphKernelCluster::DumpToFile() {
307 auto dir_path = FileUtils::CreateNotExistDirs(std::string("./") + kGraphKernelDumpPath);
308 if (!dir_path.has_value()) {
309 MS_LOG(WARNING) << "Failed to CreateNotExistDirs: ./" << kGraphKernelDumpPath;
310 return;
311 }
312 std::optional<std::string> whole_path = "";
313 std::optional<std::string> file_name = "graph_kernel_cluster_" + std::to_string(getpid()) + ".txt";
314 FileUtils::ConcatDirAndFileName(&dir_path, &file_name, &whole_path);
315 if (!whole_path.has_value()) {
316 MS_LOG(WARNING) << "Failed to get real path of file: " << file_name.value();
317 return;
318 }
319 auto filepath = whole_path.value();
320 ChangeFileMode(filepath, S_IWUSR);
321 std::ofstream fout(filepath, std::ios::app);
322 if (!fout.is_open()) {
323 MS_LOG(INFO) << "Open dump file '" << filepath << "' failed!";
324 ChangeFileMode(filepath, S_IRUSR);
325 return;
326 }
327 fout << dump_buf_.str() << std::endl;
328 fout.close();
329 ChangeFileMode(filepath, S_IRUSR);
330 }
331
332 // The GetItem node should be clustered with its real input.
333 // If its real input is not in the candidates, the GetItem should be excluded.
RemoveWildGetitem(std::vector<size_t> * candidates)334 void GraphKernelCluster::RemoveWildGetitem(std::vector<size_t> *candidates) {
335 bool changed = false;
336 std::set<size_t> candidates_set(candidates->begin(), candidates->end());
337
338 for (auto iter = candidates_set.cbegin(); iter != candidates_set.cend();) {
339 size_t cluster_id = *iter;
340 if (IsPrimitiveCNode(nodes_[cluster_id], prim::kPrimTupleGetItem)) {
341 const auto &inputs = graph_->GetInputs(cluster_id);
342 if (inputs.size() != 1) {
343 MS_LOG(INFO) << "Input size of GetItem(" << cluster_id << ") should be 1, but got " << inputs.size();
344 candidates->clear();
345 return;
346 }
347 auto prev_id = *(inputs.cbegin());
348 if (candidates_set.count(prev_id) == 0) {
349 iter = candidates_set.erase(iter);
350 changed = true;
351 continue;
352 }
353 }
354 ++iter;
355 }
356 if (changed) {
357 (void)candidates->erase(std::remove_if(candidates->begin(), candidates->end(),
358 [&candidates_set](size_t c) { return candidates_set.count(c) == 0; }),
359 candidates->end());
360 }
361 }
362
Init(const FuncGraphPtr & func_graph)363 void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) {
364 op_list_ = GetClusterableOpList();
365 graph_ = Graph::Build(func_graph, &nodes_);
366 MS_EXCEPTION_IF_NULL(graph_);
367 }
368
Run(const FuncGraphPtr & func_graph)369 bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) {
370 auto mng = func_graph->manager();
371 MS_EXCEPTION_IF_NULL(mng);
372 Init(func_graph);
373 bool changed = Process(func_graph);
374 if (changed) {
375 if (GraphKernelFlags::GetInstance().dump_as_text) {
376 DumpToFile();
377 }
378 mng->RemoveRoots();
379 mng->KeepRoots({func_graph});
380 }
381 Clean();
382 return changed;
383 }
384 } // namespace mindspore::graphkernel
385