• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/core/graph_kernel_cluster.h"
17 
18 #include "mindspore/core/ops/sequence_ops.h"
19 #include "mindspore/core/ops/nn_optimizer_ops.h"
20 #include "mindspore/core/ops/nn_ops.h"
21 #include "mindspore/core/ops/math_ops.h"
22 #include "mindspore/core/ops/lite_ops.h"
23 #include "mindspore/core/ops/comparison_ops.h"
24 #include "mindspore/core/ops/array_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "ops/auto_generate/gen_ops_primitive.h"
27 #include "utils/hash_map.h"
28 #include "ir/graph_utils.h"
29 #include "utils/anf_utils.h"
30 #include "utils/file_utils.h"
31 #include "ops/sequence_ops.h"
32 #include "ops/nn_optimizer_ops.h"
33 #include "backend/common/graph_kernel/graph_kernel_flags.h"
34 #include "backend/common/graph_kernel/core/graph_builder.h"
35 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
36 
37 namespace mindspore::graphkernel {
Cluster(size_t node_id,const AnfNodePtr & node,const mindspore::HashMap<AnfNodePtr,size_t> & node_idx_map)38 Graph::Cluster::Cluster(size_t node_id, const AnfNodePtr &node,
39                         const mindspore::HashMap<AnfNodePtr, size_t> &node_idx_map)
40     : cluster_id_(node_id) {
41   auto cnode = node->cast<CNodePtr>();
42   MS_EXCEPTION_IF_NULL(cnode);
43   for (const auto &inp : cnode->inputs()) {
44     auto iter = node_idx_map.find(inp);
45     if (iter != node_idx_map.end()) {
46       // At the beginning, cluster_id is equal to node_id
47       (void)inputs_.insert(iter->second);
48     }
49   }
50 }
51 
Merge(Cluster * other_cluster)52 void Graph::Cluster::Merge(Cluster *other_cluster) {
53   other_cluster->cluster_id_ = cluster_id_;
54   cluster_size_ += other_cluster->cluster_size_;
55   inputs_.insert(other_cluster->inputs_.cbegin(), other_cluster->inputs_.cend());
56   other_cluster->Clean();
57 }
58 
Build(const FuncGraphPtr & func_graph,AnfNodePtrList * nodes,HashMap<AnfNodePtr,size_t> * node_idx_map)59 GraphPtr Graph::Build(const FuncGraphPtr &func_graph, AnfNodePtrList *nodes,
60                       HashMap<AnfNodePtr, size_t> *node_idx_map) {
61   MS_EXCEPTION_IF_NULL(func_graph);
62   auto cnodes = TopoSort(func_graph->output(), SuccIncoming,
63                          [](const AnfNodePtr &node) { return node->isa<CNode>() ? FOLLOW : EXCLUDE; });
64   HashMap<AnfNodePtr, size_t> tmp_node_idx_map;
65   for (size_t i = 0; i < cnodes.size(); i++) {
66     tmp_node_idx_map[cnodes[i]] = i;
67   }
68   auto graph_ptr = std::make_shared<Graph>(cnodes, tmp_node_idx_map);
69   if (nodes != nullptr) {
70     *nodes = std::move(cnodes);
71   }
72   if (node_idx_map != nullptr) {
73     *node_idx_map = std::move(tmp_node_idx_map);
74   }
75   return graph_ptr;
76 }
77 
Graph(const AnfNodePtrList & nodes,const HashMap<AnfNodePtr,size_t> & node_idx_map)78 Graph::Graph(const AnfNodePtrList &nodes, const HashMap<AnfNodePtr, size_t> &node_idx_map) {
79   clusters_.reserve(nodes.size());
80   for (size_t i = 0; i < nodes.size(); i++) {
81     (void)clusters_.emplace_back(i, nodes[i], node_idx_map);
82   }
83 }
84 
Find(size_t node_id)85 size_t Graph::Find(size_t node_id) {
86   size_t &pre_id = clusters_[node_id].cluster_id_;
87   return (pre_id == clusters_[pre_id].cluster_id_) ? pre_id : (pre_id = Find(pre_id));
88 }
89 
Merge(const std::vector<size_t> & candidates)90 void Graph::Merge(const std::vector<size_t> &candidates) {
91   size_t min_id = *std::min_element(candidates.begin(), candidates.end());
92   for (auto id : candidates) {
93     if (id == min_id) {
94       continue;
95     }
96     clusters_[min_id].Merge(&clusters_[id]);
97   }
98 }
99 
CollectClusters()100 std::vector<std::vector<size_t>> Graph::CollectClusters() {
101   std::vector<std::vector<size_t>> cluster_map(clusters_.size());
102   for (size_t i = 0; i < clusters_.size(); i++) {
103     cluster_map[Find(i)].push_back(i);
104   }
105   return cluster_map;
106 }
107 
Dfs(size_t node_id,const Graph::VisitFunc & visitor)108 void Graph::Dfs(size_t node_id, const Graph::VisitFunc &visitor) {
109   ++seen_;
110   return DepthFirstSearch(Find(node_id), visitor);
111 }
112 
GetInputs(size_t cluster_id)113 const std::set<size_t> &Graph::GetInputs(size_t cluster_id) {
114   cluster_id = Find(cluster_id);
115   RefreshInputs(cluster_id);
116   return clusters_[cluster_id].inputs_;
117 }
118 
RefreshInputs(size_t i)119 void Graph::RefreshInputs(size_t i) {
120   auto &inputs = clusters_[i].inputs_;
121   for (auto iter = inputs.cbegin(); iter != inputs.cend();) {
122     size_t new_id = Find(*iter);
123     if (new_id != *iter) {
124       iter = inputs.erase(iter);
125       (void)inputs.insert(new_id);
126     } else {
127       ++iter;
128     }
129   }
130   (void)inputs.erase(i);
131 }
132 
DepthFirstSearch(size_t cluster_id,const VisitFunc & visitor)133 void Graph::DepthFirstSearch(size_t cluster_id, const VisitFunc &visitor) {
134   if (clusters_[cluster_id].seed_ >= seen_) {
135     return;
136   }
137   clusters_[cluster_id].seed_ = seen_;
138   if (visitor(cluster_id) != FOLLOW) {
139     return;
140   }
141   // traverse inputs in descending order.
142   const auto &inputs = GetInputs(cluster_id);
143   for (auto iter = inputs.crbegin(); iter != inputs.crend(); ++iter) {
144     DepthFirstSearch(*iter, visitor);
145   }
146 }
147 
RemoveCircle(std::vector<size_t> * candidates)148 void CircleChecker::RemoveCircle(std::vector<size_t> *candidates) {
149   if (candidates->size() <= 1) {
150     return;
151   }
152   candidates_.clear();
153   candidates_.insert(candidates->cbegin(), candidates->cend());
154   for (auto iter = candidates->cbegin(); iter != candidates->cend(); ++iter) {
155     if (candidates_.count(*iter) == 0) {
156       continue;
157     }
158     circle_nodes_.clear();
159     if (CheckCircle(*iter)) {
160       RemoveCircleNodesFromCandidates();
161     }
162   }
163   (void)candidates->erase(std::remove_if(candidates->begin(), candidates->end(),
164                                          [this](size_t c) { return this->candidates_.count(c) == 0; }),
165                           candidates->end());
166 }
167 
168 /**
169  * Check circle. the candidate is collected into circle_nodes_ if it will form a circle.
170  *
171  * algorithm:
172  * Search from the basenode's input that is NOT in candidates (the basenode is a candidate),
173  * If it depends on a node that belongs to candidates, it will form a circle.
174  *  e.g.     A -> x -> ... -> B
175  *             -> y -> ... -> C
176  * In this case, A, B and C are candidates while x and y are not.
177  * Both x and y are inputs of A. assumes A is the basenode.
178  * When searching from x, the B will be found and added into circle_nodes list,
179  * and then when searching from y, the C will be found and added into circle_nodes list.
180  */
CheckCircle(size_t basenode)181 bool CircleChecker::CheckCircle(size_t basenode) {
182   const auto &inputs = graph_->GetInputs(basenode);
183   std::set<size_t> visited_circle_nodes;
184   for (auto x : inputs) {
185     if (candidates_.count(x) > 0) {
186       continue;
187     }
188     bool has_circle = false;
189     std::set<size_t> done;
190     auto vis_func = [this, &has_circle, &done, &visited_circle_nodes](size_t node_id) {
191       if (done.count(node_id) > 0 || acyclic_nodes_.count(node_id) > 0 || visited_circle_nodes.count(node_id) > 0) {
192         return EXCLUDE;
193       }
194       (void)done.insert(node_id);
195       if (candidates_.count(node_id) > 0) {
196         has_circle = true;
197         circle_nodes_.push_back(node_id);
198         return EXCLUDE;
199       }
200       return FOLLOW;
201     };
202     graph_->Dfs(x, vis_func);
203     if (has_circle) {
204       visited_circle_nodes.insert(done.cbegin(), done.cend());
205     } else {
206       acyclic_nodes_.insert(done.cbegin(), done.cend());
207     }
208   }
209   return !circle_nodes_.empty();
210 }
211 
RemoveCircleNodesFromCandidates()212 void CircleChecker::RemoveCircleNodesFromCandidates() {
213   auto remove_from_candidates = [this](size_t node_id) {
214     if (candidates_.count(node_id) > 0) {
215       (void)candidates_.erase(node_id);
216       return FOLLOW;
217     }
218     return EXCLUDE;
219   };
220   for (auto node : circle_nodes_) {
221     graph_->Dfs(node, remove_from_candidates);
222   }
223 }
224 
FindCandidates(size_t basenode_id)225 std::vector<size_t> GraphKernelCluster::FindCandidates(size_t basenode_id) {
226   std::vector<size_t> candidates;
227   auto include = [this, &candidates, func_graph = nodes_[basenode_id]->func_graph()](size_t cluster_id) {
228     const AnfNodePtr &node = this->nodes_[cluster_id];
229     if (node->func_graph() != func_graph) {
230       return EXCLUDE;
231     }
232     if (!IsClusterableOp(node) && !IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
233       return EXCLUDE;
234     }
235     candidates.push_back(cluster_id);
236     // Do not search from clustered node again.
237     if (this->graph_->GetSize(cluster_id) > 1) {
238       return NOFOLLOW;
239     }
240     return FOLLOW;
241   };
242   graph_->Dfs(basenode_id, include);
243   std::reverse(candidates.begin(), candidates.end());
244   return candidates;
245 }
246 
Process(const FuncGraphPtr & func_graph)247 bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
248   bool changed = false;
249   for (int i = SizeToInt(nodes_.size()) - 1; i >= 0; i--) {
250     // if the node has been clustered, it has tried to find its previous nodes, so it's unnecessary to try again.
251     if (graph_->GetSize(IntToSize(i)) > 1) {
252       continue;
253     }
254     auto candidates = FindCandidates(IntToSize(i));
255     CircleChecker circle_checker(graph_);
256     circle_checker.RemoveCircle(&candidates);
257     RemoveWildGetitem(&candidates);
258     if (candidates.empty()) {
259       continue;
260     }
261     // merge candidates into one cluster
262     graph_->Merge(candidates);
263   }
264 
265   // Rebuild func_graphs
266   auto clusters = graph_->CollectClusters();
267   for (size_t i = 0; i < clusters.size(); i++) {
268     auto node_without_getitem = std::count_if(clusters[i].begin(), clusters[i].end(), [this](size_t node_id) {
269       return !IsPrimitiveCNode(this->nodes_[node_id], prim::kPrimTupleGetItem);
270     });
271     if (node_without_getitem == 0) {
272       continue;
273     }
274     if (node_without_getitem == 1) {
275       // Do not cluster a single GraphKernel again.
276       // Do not cluster a single Assign.
277       const auto &node = nodes_[clusters[i][0]];
278       if (AnfUtils::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) {
279         continue;
280       }
281     }
282     CreateFuncGraph(func_graph, clusters[i]);
283     changed = true;
284   }
285   return changed;
286 }
287 
CreateFuncGraph(const FuncGraphPtr & func_graph,const std::vector<size_t> & nodes_id)288 void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) {
289   AnfNodePtrList old_nodes;
290   (void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
291                        [this](size_t id) { return this->nodes_[id]; });
292   auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");
293   if (GraphKernelFlags::GetInstance().dump_as_text) {
294     DumpClusterInfo(old_nodes, new_node);
295   }
296 }
297 
DumpClusterInfo(const AnfNodePtrList & old_nodes,const AnfNodePtr & new_node)298 void GraphKernelCluster::DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node) {
299   dump_buf_ << "Source nodes of " << new_node->fullname_with_scope() << " = " << new_node->DebugString() << std::endl;
300   for (const auto &node : old_nodes) {
301     dump_buf_ << "  " << node->fullname_with_scope() << " = " << node->DebugString() << std::endl;
302   }
303   dump_buf_ << "=======================" << std::endl;
304 }
305 
DumpToFile()306 void GraphKernelCluster::DumpToFile() {
307   auto dir_path = FileUtils::CreateNotExistDirs(std::string("./") + kGraphKernelDumpPath);
308   if (!dir_path.has_value()) {
309     MS_LOG(WARNING) << "Failed to CreateNotExistDirs: ./" << kGraphKernelDumpPath;
310     return;
311   }
312   std::optional<std::string> whole_path = "";
313   std::optional<std::string> file_name = "graph_kernel_cluster_" + std::to_string(getpid()) + ".txt";
314   FileUtils::ConcatDirAndFileName(&dir_path, &file_name, &whole_path);
315   if (!whole_path.has_value()) {
316     MS_LOG(WARNING) << "Failed to get real path of file: " << file_name.value();
317     return;
318   }
319   auto filepath = whole_path.value();
320   ChangeFileMode(filepath, S_IWUSR);
321   std::ofstream fout(filepath, std::ios::app);
322   if (!fout.is_open()) {
323     MS_LOG(INFO) << "Open dump file '" << filepath << "' failed!";
324     ChangeFileMode(filepath, S_IRUSR);
325     return;
326   }
327   fout << dump_buf_.str() << std::endl;
328   fout.close();
329   ChangeFileMode(filepath, S_IRUSR);
330 }
331 
332 // The GetItem node should be clustered with its real input.
333 // If its real input is not in the candidates, the GetItem should be excluded.
RemoveWildGetitem(std::vector<size_t> * candidates)334 void GraphKernelCluster::RemoveWildGetitem(std::vector<size_t> *candidates) {
335   bool changed = false;
336   std::set<size_t> candidates_set(candidates->begin(), candidates->end());
337 
338   for (auto iter = candidates_set.cbegin(); iter != candidates_set.cend();) {
339     size_t cluster_id = *iter;
340     if (IsPrimitiveCNode(nodes_[cluster_id], prim::kPrimTupleGetItem)) {
341       const auto &inputs = graph_->GetInputs(cluster_id);
342       if (inputs.size() != 1) {
343         MS_LOG(INFO) << "Input size of GetItem(" << cluster_id << ") should be 1, but got " << inputs.size();
344         candidates->clear();
345         return;
346       }
347       auto prev_id = *(inputs.cbegin());
348       if (candidates_set.count(prev_id) == 0) {
349         iter = candidates_set.erase(iter);
350         changed = true;
351         continue;
352       }
353     }
354     ++iter;
355   }
356   if (changed) {
357     (void)candidates->erase(std::remove_if(candidates->begin(), candidates->end(),
358                                            [&candidates_set](size_t c) { return candidates_set.count(c) == 0; }),
359                             candidates->end());
360   }
361 }
362 
Init(const FuncGraphPtr & func_graph)363 void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) {
364   op_list_ = GetClusterableOpList();
365   graph_ = Graph::Build(func_graph, &nodes_);
366   MS_EXCEPTION_IF_NULL(graph_);
367 }
368 
Run(const FuncGraphPtr & func_graph)369 bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) {
370   auto mng = func_graph->manager();
371   MS_EXCEPTION_IF_NULL(mng);
372   Init(func_graph);
373   bool changed = Process(func_graph);
374   if (changed) {
375     if (GraphKernelFlags::GetInstance().dump_as_text) {
376       DumpToFile();
377     }
378     mng->RemoveRoots();
379     mng->KeepRoots({func_graph});
380   }
381   Clean();
382   return changed;
383 }
384 }  // namespace mindspore::graphkernel
385