• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
17 
18 #include "base/core_ops.h"
19 #include "ir/graph_utils.h"
20 #include "utils/file_utils.h"
21 #include "utils/context/graph_kernel_flags.h"
22 #include "backend/kernel_compiler/common_utils.h"
23 #include "backend/session/anf_runtime_algorithm.h"
24 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
25 #include "backend/optimizer/pass/getitem_tuple.h"
26 #include "backend/optimizer/graph_kernel/update_state_formatter.h"
27 
28 namespace mindspore {
29 namespace opt {
30 using context::OpLevel_0;
31 using context::OpLevel_1;
GetClusterableOpList()32 std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
33   std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> clusterable_ops_with_level = {
34     // all target
35     {kAllTarget, OpLevel_0, prim::kPrimAbs},
36     {kAllTarget, OpLevel_0, prim::kPrimAdd},
37     {kAllTarget, OpLevel_0, prim::kPrimCast},
38     {kAllTarget, OpLevel_0, prim::kPrimEqual},
39     {kAllTarget, OpLevel_0, prim::kPrimExp},
40     {kAllTarget, OpLevel_0, prim::kPrimInplaceAssign},
41     {kAllTarget, OpLevel_0, prim::kPrimLog},
42     {kAllTarget, OpLevel_0, prim::kPrimMaximum},
43     {kAllTarget, OpLevel_0, prim::kPrimMinimum},
44     {kAllTarget, OpLevel_0, prim::kPrimMul},
45     {kAllTarget, OpLevel_0, prim::kPrimNeg},
46     {kAllTarget, OpLevel_0, prim::kPrimPow},
47     {kAllTarget, OpLevel_0, prim::kPrimRealDiv},
48     {kAllTarget, OpLevel_0, prim::kPrimReciprocal},
49     {kAllTarget, OpLevel_1, prim::kPrimReduceSum},
50     {kAllTarget, OpLevel_1, prim::kPrimReshape},
51     {kAllTarget, OpLevel_0, prim::kPrimRound},
52     {kAllTarget, OpLevel_0, prim::kPrimRsqrt},
53     {kAllTarget, OpLevel_0, prim::kPrimSqrt},
54     {kAllTarget, OpLevel_0, prim::kPrimSub},
55     {kAllTarget, OpLevel_0, prim::kPrimTanh},
56     {kAllTarget, OpLevel_1, prim::kPrimTranspose},
57     // ascend
58     {kAscendDevice, OpLevel_1, prim::kPrimMatMul},
59     {kAscendDevice, OpLevel_1, prim::kPrimTransData},
60     {kAscendDevice, OpLevel_1, prim::kPrimBatchMatMul},
61     // gpu
62     {kGPUDevice, OpLevel_0, prim::kPrimACos},
63     {kGPUDevice, OpLevel_0, prim::kPrimAcosh},
64     {kGPUDevice, OpLevel_1, prim::kPrimArgMax},
65     {kGPUDevice, OpLevel_1, prim::kPrimArgMin},
66     {kGPUDevice, OpLevel_0, prim::kPrimAsin},
67     {kGPUDevice, OpLevel_0, prim::kPrimAsinh},
68     {kGPUDevice, OpLevel_0, prim::kPrimAssign},
69     {kGPUDevice, OpLevel_0, prim::kPrimAtan},
70     {kGPUDevice, OpLevel_0, prim::kPrimAtan2},
71     {kGPUDevice, OpLevel_0, prim::kPrimCos},
72     {kGPUDevice, OpLevel_0, prim::kPrimDiv},
73     {kGPUDevice, OpLevel_0, prim::kPrimErf},
74     {kGPUDevice, OpLevel_0, prim::kPrimExpm1},
75     {kGPUDevice, OpLevel_0, prim::kPrimFloor},
76     {kGPUDevice, OpLevel_0, prim::kPrimFloorDiv},
77     {kGPUDevice, OpLevel_0, prim::kPrimFloorMod},
78     {kGPUDevice, OpLevel_0, prim::kPrimGreater},
79     {kGPUDevice, OpLevel_0, prim::kPrimGreaterEqual},
80     {kGPUDevice, OpLevel_0, prim::kPrimIsFinite},
81     {kGPUDevice, OpLevel_0, prim::kPrimIsInf},
82     {kGPUDevice, OpLevel_0, prim::kPrimIsNan},
83     {kGPUDevice, OpLevel_0, prim::kPrimLess},
84     {kGPUDevice, OpLevel_0, prim::kPrimLessEqual},
85     {kGPUDevice, OpLevel_0, prim::kPrimLogicalAnd},
86     {kGPUDevice, OpLevel_0, prim::kPrimLogicalOr},
87     {kGPUDevice, OpLevel_0, prim::kPrimLogicalNot},
88     {kGPUDevice, OpLevel_0, prim::kPrimMod},
89     {kGPUDevice, OpLevel_0, prim::kPrimNotEqual},
90     {kGPUDevice, OpLevel_1, prim::kPrimReduceMax},
91     {kGPUDevice, OpLevel_1, prim::kPrimReduceMin},
92     {kGPUDevice, OpLevel_0, prim::kPrimSelect},
93     {kGPUDevice, OpLevel_0, prim::kPrimSign},
94     {kGPUDevice, OpLevel_0, prim::kPrimSin},
95     {kGPUDevice, OpLevel_0, prim::kPrimStridedSlice},
96     {kGPUDevice, OpLevel_0, prim::kPrimUserDefined},
97   };
98   const auto &flags = context::GraphKernelFlags::GetInstance();
99   std::vector<PrimitivePtr> clusterable_ops = GetValidOps(clusterable_ops_with_level, flags.fusion_ops_level);
100   OpListFilter(&clusterable_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops);
101   return clusterable_ops;
102 }
103 
104 namespace {
CountGraphKernelInnerNodes(const AnfNodePtr & node)105 size_t CountGraphKernelInnerNodes(const AnfNodePtr &node) {
106   AnfNodePtrList node_list;
107   kernel::GetValidKernelNodes(AnfAlgo::GetCNodeFuncGraphPtr(node), &node_list);
108   return node_list.size();
109 }
110 }  // namespace
111 
IsClusterableOp(const AnfNodePtr & node)112 bool GraphKernelCluster::IsClusterableOp(const AnfNodePtr &node) {
113   if (AnfAlgo::IsGraphKernel(node)) {
114     return true;
115   }
116   if (IsKeepBasicNode(node)) {
117     return false;
118   }
119   bool node_in_oplist = std::any_of(op_list_.begin(), op_list_.end(),
120                                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
121   if (!node_in_oplist) {
122     return false;
123   }
124 #if ENABLE_D
125   // For AICPU operators, only the Reshape can be clustered.
126   if (AnfAlgo::GetProcessor(node) != kernel::Processor::AICORE && !IsPrimitiveCNode(node, prim::kPrimReshape)) {
127     return false;
128   }
129 #endif
130   return true;
131 }
132 
133 class Graph {
134   struct Cluster {
135     size_t cluster_id_;        // node_id of the representative.
136     size_t cluster_size_{1};   // size of cluster, composite node is considered as one node.
137     size_t basic_op_cnt_{1};   // basic node count, the inner nodes of composite node are counted.
138     std::set<size_t> inputs_;  // inputs' cluster_id.
139     size_t seed_{0};           // visited flag of dfs.
140     size_t max_node_id_;       // largest node id of a cluster
141 
Clustermindspore::opt::Graph::Cluster142     Cluster(size_t node_id, const AnfNodePtr &node, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map)
143         : cluster_id_(node_id), max_node_id_(node_id) {
144       if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
145         basic_op_cnt_ = 0;
146       } else if (AnfAlgo::IsGraphKernel(node)) {
147         // the basic_op_cnt_ is used to limit the composite op size
148         basic_op_cnt_ = CountGraphKernelInnerNodes(node);
149       }
150       auto cnode = node->cast<CNodePtr>();
151       MS_EXCEPTION_IF_NULL(cnode);
152       for (const auto &inp : cnode->inputs()) {
153         auto iter = node_idx_map.find(inp);
154         if (iter != node_idx_map.end()) {
155           // At the beginning, cluster_id is equal to node_id
156           (void)inputs_.insert(iter->second);
157         }
158       }
159     }
160     ~Cluster() = default;
161 
Mergemindspore::opt::Graph::Cluster162     void Merge(Cluster *other_cluster) {
163       other_cluster->cluster_id_ = cluster_id_;
164       max_node_id_ = std::max(other_cluster->max_node_id_, max_node_id_);
165       cluster_size_ += other_cluster->cluster_size_;
166       basic_op_cnt_ += other_cluster->basic_op_cnt_;
167       (void)std::for_each(other_cluster->inputs_.begin(), other_cluster->inputs_.end(),
168                           [this](size_t inp) { (void)this->inputs_.insert(inp); });
169       other_cluster->Clean();
170     }
171 
172     // clean the info to free memory.
Cleanmindspore::opt::Graph::Cluster173     void Clean() {
174       inputs_.clear();
175       cluster_size_ = 0;
176       basic_op_cnt_ = 0;
177       max_node_id_ = 0;
178     }
179   };  // struct Cluster
180 
181  public:
182   // Init and build graph
Graph(const AnfNodePtrList & nodes,const std::unordered_map<AnfNodePtr,size_t> & node_idx_map)183   Graph(const AnfNodePtrList &nodes, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map) {
184     clusters_.reserve(nodes.size());
185     for (size_t i = 0; i < nodes.size(); i++) {
186       (void)clusters_.emplace_back(i, nodes[i], node_idx_map);
187     }
188   }
189   ~Graph() = default;
190 
191   // find the representative of the cluster
Find(size_t node_id)192   size_t Find(size_t node_id) {
193     size_t &pre_id = clusters_[node_id].cluster_id_;
194     return (pre_id == clusters_[pre_id].cluster_id_) ? pre_id : (pre_id = Find(pre_id));
195   }
196 
197   // merge clusters, the smallest cluster id will be the new cluster id.
Merge(const std::vector<size_t> & candidates)198   void Merge(const std::vector<size_t> &candidates) {
199     size_t min_id = *std::min_element(candidates.begin(), candidates.end());
200     for (auto id : candidates) {
201       if (id == min_id) continue;
202       clusters_[min_id].Merge(&clusters_[id]);
203     }
204   }
205 
206   // Collect nodes together that are in the same cluster.
CollectClusters()207   std::vector<std::vector<size_t>> CollectClusters() {
208     std::vector<std::vector<size_t>> cluster_map(clusters_.size());
209     for (size_t i = 0; i < clusters_.size(); i++) {
210       cluster_map[Find(i)].push_back(i);
211     }
212     return cluster_map;
213   }
214 
215   // Get cluster's max node id
GetClusterMaxNodeId(size_t cluster_id)216   size_t GetClusterMaxNodeId(size_t cluster_id) { return clusters_[Find(cluster_id)].max_node_id_; }
217 
218   using VisitFunc = std::function<IncludeType(size_t)>;
Dfs(size_t node_id,const VisitFunc & visitor)219   void Dfs(size_t node_id, const VisitFunc &visitor) {
220     ++seen_;
221     return DepthFirstSearch(Find(node_id), visitor);
222   }
223 
224   // Get cluster size
GetSize(size_t cluster_id)225   size_t GetSize(size_t cluster_id) { return clusters_[Find(cluster_id)].cluster_size_; }
226 
227   // Get cluster's basic op count
GetBasicNodeCount(size_t cluster_id)228   size_t GetBasicNodeCount(size_t cluster_id) { return clusters_[Find(cluster_id)].basic_op_cnt_; }
229 
230   // Get cluster's inputs
GetInputs(size_t cluster_id)231   const std::set<size_t> &GetInputs(size_t cluster_id) {
232     cluster_id = Find(cluster_id);
233     RefreshInputs(cluster_id);
234     return clusters_[cluster_id].inputs_;
235   }
236 
237  private:
RefreshInputs(size_t i)238   void RefreshInputs(size_t i) {
239     auto &inputs = clusters_[i].inputs_;
240     for (auto iter = inputs.begin(); iter != inputs.end();) {
241       size_t new_id = Find(*iter);
242       if (new_id != *iter) {
243         iter = inputs.erase(iter);
244         (void)inputs.insert(new_id);
245       } else {
246         ++iter;
247       }
248     }
249     (void)inputs.erase(i);
250   }
251 
DepthFirstSearch(size_t cluster_id,const VisitFunc & visitor)252   void DepthFirstSearch(size_t cluster_id, const VisitFunc &visitor) {
253     if (clusters_[cluster_id].seed_ >= seen_) return;
254     clusters_[cluster_id].seed_ = seen_;
255     if (visitor(cluster_id) != FOLLOW) {
256       return;
257     }
258     // traverse inputs in descending order.
259     const auto &inputs = GetInputs(cluster_id);
260     for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
261       DepthFirstSearch(*iter, visitor);
262     }
263   }
264 
265   std::vector<Cluster> clusters_;
266   size_t seen_{0};
267 };  // class Graph
268 
269 class CircleChecker {
270  public:
CircleChecker(GraphPtr graph)271   explicit CircleChecker(GraphPtr graph) : graph_(graph) {}
272   ~CircleChecker() = default;
273 
RemoveCircle(std::vector<size_t> * candidates)274   void RemoveCircle(std::vector<size_t> *candidates) {
275     if (candidates->size() <= 1) {
276       return;
277     }
278     candidates_.clear();
279     candidates_.insert(candidates->begin(), candidates->end());
280     for (auto iter = candidates->begin(); iter != candidates->end(); ++iter) {
281       if (!candidates_.count(*iter)) continue;
282       circle_nodes_.clear();
283       if (CheckCircle(*iter)) {
284         RemoveCircleNodesFromCandidates();
285       }
286     }
287     (void)candidates->erase(std::remove_if(candidates->begin(), candidates->end(),
288                                            [this](size_t c) { return this->candidates_.count(c) == 0; }),
289                             candidates->end());
290   }
291 
292  private:
293   /**
294    * Check circle. the candidate is collected into circle_nodes_ if it will form a circle.
295    *
296    * algorithm:
297    * Search from the basenode's input that is NOT in candidates (the basenode is a candidate),
298    * If it depends on a node that belongs to candidates, it will form a circle.
299    *  e.g.     A -> x -> ... -> B
300    *             -> y -> ... -> C
301    * In this case, A, B and C are candidates while x and y are not.
302    * Both x and y are inputs of A. assumes A is the basenode.
303    * When searching from x, the B will be found and added into circle_nodes list,
304    * and then when searching from y, the C will be found and added into circle_nodes list.
305    */
CheckCircle(size_t basenode)306   bool CheckCircle(size_t basenode) {
307     const auto &inputs = graph_->GetInputs(basenode);
308     std::set<size_t> visited_circle_nodes;
309     for (auto x : inputs) {
310       if (candidates_.count(x)) continue;
311       bool has_circle = false;
312       std::set<size_t> done;
313       auto vis_func = [this, &has_circle, &done, &visited_circle_nodes](size_t node_id) {
314         if (done.count(node_id) || acyclic_nodes_.count(node_id) || visited_circle_nodes.count(node_id)) {
315           return EXCLUDE;
316         }
317         (void)done.insert(node_id);
318         if (candidates_.count(node_id)) {
319           has_circle = true;
320           circle_nodes_.push_back(node_id);
321           return EXCLUDE;
322         }
323         // all nodes are indexed by topo order,
324         // so if the current node's cluster's max node id is less than the minimal candidate, a circle cannot be formed
325         // from this node.
326         if (candidates_.empty() || graph_->GetClusterMaxNodeId(node_id) < *candidates_.begin()) {
327           return EXCLUDE;
328         }
329         return FOLLOW;
330       };
331       graph_->Dfs(x, vis_func);
332       if (has_circle) {
333         visited_circle_nodes.insert(done.begin(), done.end());
334       } else {
335         acyclic_nodes_.insert(done.begin(), done.end());
336       }
337     }
338     return !circle_nodes_.empty();
339   }
340 
341   // remove all circle nodes from candidates
RemoveCircleNodesFromCandidates()342   void RemoveCircleNodesFromCandidates() {
343     auto remove_from_candidates = [this](size_t node_id) {
344       if (candidates_.count(node_id)) {
345         (void)candidates_.erase(node_id);
346         return FOLLOW;
347       }
348       return EXCLUDE;
349     };
350     for (auto node : circle_nodes_) {
351       graph_->Dfs(node, remove_from_candidates);
352     }
353   }
354 
355   GraphPtr graph_;               // bind the global graph
356   std::set<size_t> candidates_;  // bind the input candidates
357   std::vector<size_t> circle_nodes_;
358   std::set<size_t> acyclic_nodes_;
359 };  // CircleChecker
360 
FindCandidates(size_t basenode_id)361 std::vector<size_t> GraphKernelCluster::FindCandidates(size_t basenode_id) {
362   std::vector<size_t> candidates;
363   auto include = [this, &candidates, func_graph = nodes_[basenode_id]->func_graph()](size_t cluster_id) {
364     const AnfNodePtr &node = this->nodes_[cluster_id];
365     if (node->func_graph() != func_graph) {
366       return EXCLUDE;
367     }
368     if (!IsClusterableOp(node) && !IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
369       return EXCLUDE;
370     }
371     candidates.push_back(cluster_id);
372     // Do not search from clustered node again.
373     if (this->graph_->GetSize(cluster_id) > 1) {
374       return NOFOLLOW;
375     }
376     return FOLLOW;
377   };
378   graph_->Dfs(basenode_id, include);
379   std::reverse(candidates.begin(), candidates.end());
380   return candidates;
381 }
382 
Process(const FuncGraphPtr & func_graph)383 bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
384   bool changed = false;
385   for (int i = SizeToInt(nodes_.size()) - 1; i >= 0; i--) {
386     // if the node has been clustered, it has tried to find its previous nodes, so it's unnecessary to try again.
387     if (graph_->GetSize(IntToSize(i)) > 1) {
388       continue;
389     }
390     auto candidates = FindCandidates(IntToSize(i));
391     CircleChecker(graph_).RemoveCircle(&candidates);
392     RemoveWildGetitem(&candidates);
393     if (candidates.empty()) continue;
394     // merge candidates into one cluster
395     graph_->Merge(candidates);
396   }
397 
398   // Rebuild func_graphs
399   auto clusters = graph_->CollectClusters();
400   for (size_t i = 0; i < clusters.size(); i++) {
401     auto node_without_getitem = std::count_if(clusters[i].begin(), clusters[i].end(), [this](size_t node_id) {
402       return !IsPrimitiveCNode(this->nodes_[node_id], prim::kPrimTupleGetItem);
403     });
404     if (node_without_getitem == 0) continue;
405     if (node_without_getitem == 1) {
406       // Do not cluster a single GraphKernel again.
407       // Do not cluster a single Assign.
408       const auto &node = nodes_[clusters[i][0]];
409       if (AnfAlgo::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) {
410         continue;
411       }
412     }
413     CreateFuncGraph(func_graph, clusters[i]);
414     changed = true;
415   }
416   return changed;
417 }
418 
CreateFuncGraph(const FuncGraphPtr & func_graph,const std::vector<size_t> & nodes_id)419 void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) {
420   AnfNodePtrList old_nodes;
421   AnfNodePtr new_node;
422   (void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
423                        [this](size_t id) { return this->nodes_[id]; });
424   std::tie(new_node, std::ignore) = FuseNodesToSubGraph(old_nodes, func_graph, "fusion");
425   std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
426   (void)eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node));
427   if (context::GraphKernelFlags::GetInstance().dump_as_text) {
428     DumpClusterInfo(old_nodes, new_node);
429   }
430 }
431 
DumpClusterInfo(const AnfNodePtrList & old_nodes,const AnfNodePtr & new_node)432 void GraphKernelCluster::DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node) {
433   dump_buf_ << "Source nodes of " << new_node->fullname_with_scope() << " = " << new_node->DebugString() << std::endl;
434   for (const auto &node : old_nodes) {
435     dump_buf_ << "  " << node->fullname_with_scope() << " = " << node->DebugString() << std::endl;
436   }
437   dump_buf_ << "=======================" << std::endl;
438 }
439 
DumpToFile()440 void GraphKernelCluster::DumpToFile() {
441   auto dir_path = FileUtils::CreateNotExistDirs(std::string("./") + kGraphKernelDumpPath);
442   if (!dir_path.has_value()) {
443     MS_LOG(ERROR) << "Failed to CreateNotExistDirs: ./" << kGraphKernelDumpPath;
444     return;
445   }
446   std::string filepath = dir_path.value() + "/" + "graph_kernel_cluster.txt";
447   std::ofstream fout(filepath, std::ios::app);
448   if (!fout.is_open()) {
449     MS_LOG(ERROR) << "Open dump file '" << filepath << "' failed!";
450     return;
451   }
452   fout << dump_buf_.str() << std::endl;
453   fout.close();
454 }
455 
456 // The GetItem node should be clustered with its real input.
457 // If its real input is not in the candidates, the GetItem should be excluded.
RemoveWildGetitem(std::vector<size_t> * candidates)458 void GraphKernelCluster::RemoveWildGetitem(std::vector<size_t> *candidates) {
459   bool changed = false;
460   std::set<size_t> candidates_set(candidates->begin(), candidates->end());
461 
462   for (auto iter = candidates_set.begin(); iter != candidates_set.end();) {
463     size_t cluster_id = *iter;
464     if (IsPrimitiveCNode(nodes_[cluster_id], prim::kPrimTupleGetItem)) {
465       const auto &inputs = graph_->GetInputs(cluster_id);
466       if (inputs.size() != 1) {
467         MS_LOG(ERROR) << "Input size of GetItem(" << cluster_id << ") should be 1, but got " << inputs.size();
468         candidates->clear();
469         return;
470       }
471       auto prev_id = *(inputs.begin());
472       if (!candidates_set.count(prev_id)) {
473         iter = candidates_set.erase(iter);
474         changed = true;
475         continue;
476       }
477     }
478     ++iter;
479   }
480   if (changed) {
481     (void)candidates->erase(std::remove_if(candidates->begin(), candidates->end(),
482                                            [&candidates_set](size_t c) { return candidates_set.count(c) == 0; }),
483                             candidates->end());
484   }
485 }
486 
Init(const FuncGraphPtr & func_graph)487 void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) {
488   op_list_ = GetClusterableOpList();
489   // process cnode only
490   nodes_ = TopoSort(func_graph->get_return(), SuccIncoming,
491                     [](const AnfNodePtr &node) { return node->isa<CNode>() ? FOLLOW : EXCLUDE; });
492   for (size_t i = 0; i < nodes_.size(); i++) {
493     node_idx_map_[nodes_[i]] = i;
494   }
495   graph_ = std::make_shared<Graph>(nodes_, node_idx_map_);
496   MS_EXCEPTION_IF_NULL(graph_);
497 }
498 
Run(const FuncGraphPtr & func_graph)499 bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) {
500   (void)std::make_shared<ShrinkUpdateState>()->Run(func_graph);
501   auto mng = func_graph->manager();
502   MS_EXCEPTION_IF_NULL(mng);
503   Init(func_graph);
504   bool changed = Process(func_graph);
505   if (changed) {
506     if (context::GraphKernelFlags::GetInstance().dump_as_text) {
507       DumpToFile();
508     }
509     mng->RemoveRoots();
510     mng->KeepRoots({func_graph});
511   }
512   Clean();
513   (void)std::make_shared<SpreadUpdateState>()->Run(func_graph);
514   return changed;
515 }
516 }  // namespace opt
517 }  // namespace mindspore
518