1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
17
18 #include "base/core_ops.h"
19 #include "ir/graph_utils.h"
20 #include "utils/file_utils.h"
21 #include "utils/context/graph_kernel_flags.h"
22 #include "backend/kernel_compiler/common_utils.h"
23 #include "backend/session/anf_runtime_algorithm.h"
24 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
25 #include "backend/optimizer/pass/getitem_tuple.h"
26 #include "backend/optimizer/graph_kernel/update_state_formatter.h"
27
28 namespace mindspore {
29 namespace opt {
30 using context::OpLevel_0;
31 using context::OpLevel_1;
GetClusterableOpList()32 std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
33 std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> clusterable_ops_with_level = {
34 // all target
35 {kAllTarget, OpLevel_0, prim::kPrimAbs},
36 {kAllTarget, OpLevel_0, prim::kPrimAdd},
37 {kAllTarget, OpLevel_0, prim::kPrimCast},
38 {kAllTarget, OpLevel_0, prim::kPrimEqual},
39 {kAllTarget, OpLevel_0, prim::kPrimExp},
40 {kAllTarget, OpLevel_0, prim::kPrimInplaceAssign},
41 {kAllTarget, OpLevel_0, prim::kPrimLog},
42 {kAllTarget, OpLevel_0, prim::kPrimMaximum},
43 {kAllTarget, OpLevel_0, prim::kPrimMinimum},
44 {kAllTarget, OpLevel_0, prim::kPrimMul},
45 {kAllTarget, OpLevel_0, prim::kPrimNeg},
46 {kAllTarget, OpLevel_0, prim::kPrimPow},
47 {kAllTarget, OpLevel_0, prim::kPrimRealDiv},
48 {kAllTarget, OpLevel_0, prim::kPrimReciprocal},
49 {kAllTarget, OpLevel_1, prim::kPrimReduceSum},
50 {kAllTarget, OpLevel_1, prim::kPrimReshape},
51 {kAllTarget, OpLevel_0, prim::kPrimRound},
52 {kAllTarget, OpLevel_0, prim::kPrimRsqrt},
53 {kAllTarget, OpLevel_0, prim::kPrimSqrt},
54 {kAllTarget, OpLevel_0, prim::kPrimSub},
55 {kAllTarget, OpLevel_0, prim::kPrimTanh},
56 {kAllTarget, OpLevel_1, prim::kPrimTranspose},
57 // ascend
58 {kAscendDevice, OpLevel_1, prim::kPrimMatMul},
59 {kAscendDevice, OpLevel_1, prim::kPrimTransData},
60 {kAscendDevice, OpLevel_1, prim::kPrimBatchMatMul},
61 // gpu
62 {kGPUDevice, OpLevel_0, prim::kPrimACos},
63 {kGPUDevice, OpLevel_0, prim::kPrimAcosh},
64 {kGPUDevice, OpLevel_1, prim::kPrimArgMax},
65 {kGPUDevice, OpLevel_1, prim::kPrimArgMin},
66 {kGPUDevice, OpLevel_0, prim::kPrimAsin},
67 {kGPUDevice, OpLevel_0, prim::kPrimAsinh},
68 {kGPUDevice, OpLevel_0, prim::kPrimAssign},
69 {kGPUDevice, OpLevel_0, prim::kPrimAtan},
70 {kGPUDevice, OpLevel_0, prim::kPrimAtan2},
71 {kGPUDevice, OpLevel_0, prim::kPrimCos},
72 {kGPUDevice, OpLevel_0, prim::kPrimDiv},
73 {kGPUDevice, OpLevel_0, prim::kPrimErf},
74 {kGPUDevice, OpLevel_0, prim::kPrimExpm1},
75 {kGPUDevice, OpLevel_0, prim::kPrimFloor},
76 {kGPUDevice, OpLevel_0, prim::kPrimFloorDiv},
77 {kGPUDevice, OpLevel_0, prim::kPrimFloorMod},
78 {kGPUDevice, OpLevel_0, prim::kPrimGreater},
79 {kGPUDevice, OpLevel_0, prim::kPrimGreaterEqual},
80 {kGPUDevice, OpLevel_0, prim::kPrimIsFinite},
81 {kGPUDevice, OpLevel_0, prim::kPrimIsInf},
82 {kGPUDevice, OpLevel_0, prim::kPrimIsNan},
83 {kGPUDevice, OpLevel_0, prim::kPrimLess},
84 {kGPUDevice, OpLevel_0, prim::kPrimLessEqual},
85 {kGPUDevice, OpLevel_0, prim::kPrimLogicalAnd},
86 {kGPUDevice, OpLevel_0, prim::kPrimLogicalOr},
87 {kGPUDevice, OpLevel_0, prim::kPrimLogicalNot},
88 {kGPUDevice, OpLevel_0, prim::kPrimMod},
89 {kGPUDevice, OpLevel_0, prim::kPrimNotEqual},
90 {kGPUDevice, OpLevel_1, prim::kPrimReduceMax},
91 {kGPUDevice, OpLevel_1, prim::kPrimReduceMin},
92 {kGPUDevice, OpLevel_0, prim::kPrimSelect},
93 {kGPUDevice, OpLevel_0, prim::kPrimSign},
94 {kGPUDevice, OpLevel_0, prim::kPrimSin},
95 {kGPUDevice, OpLevel_0, prim::kPrimStridedSlice},
96 {kGPUDevice, OpLevel_0, prim::kPrimUserDefined},
97 };
98 const auto &flags = context::GraphKernelFlags::GetInstance();
99 std::vector<PrimitivePtr> clusterable_ops = GetValidOps(clusterable_ops_with_level, flags.fusion_ops_level);
100 OpListFilter(&clusterable_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops);
101 return clusterable_ops;
102 }
103
104 namespace {
CountGraphKernelInnerNodes(const AnfNodePtr & node)105 size_t CountGraphKernelInnerNodes(const AnfNodePtr &node) {
106 AnfNodePtrList node_list;
107 kernel::GetValidKernelNodes(AnfAlgo::GetCNodeFuncGraphPtr(node), &node_list);
108 return node_list.size();
109 }
110 } // namespace
111
IsClusterableOp(const AnfNodePtr & node)112 bool GraphKernelCluster::IsClusterableOp(const AnfNodePtr &node) {
113 if (AnfAlgo::IsGraphKernel(node)) {
114 return true;
115 }
116 if (IsKeepBasicNode(node)) {
117 return false;
118 }
119 bool node_in_oplist = std::any_of(op_list_.begin(), op_list_.end(),
120 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
121 if (!node_in_oplist) {
122 return false;
123 }
124 #if ENABLE_D
125 // For AICPU operators, only the Reshape can be clustered.
126 if (AnfAlgo::GetProcessor(node) != kernel::Processor::AICORE && !IsPrimitiveCNode(node, prim::kPrimReshape)) {
127 return false;
128 }
129 #endif
130 return true;
131 }
132
133 class Graph {
134 struct Cluster {
135 size_t cluster_id_; // node_id of the representative.
136 size_t cluster_size_{1}; // size of cluster, composite node is considered as one node.
137 size_t basic_op_cnt_{1}; // basic node count, the inner nodes of composite node are counted.
138 std::set<size_t> inputs_; // inputs' cluster_id.
139 size_t seed_{0}; // visited flag of dfs.
140 size_t max_node_id_; // largest node id of a cluster
141
Clustermindspore::opt::Graph::Cluster142 Cluster(size_t node_id, const AnfNodePtr &node, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map)
143 : cluster_id_(node_id), max_node_id_(node_id) {
144 if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
145 basic_op_cnt_ = 0;
146 } else if (AnfAlgo::IsGraphKernel(node)) {
147 // the basic_op_cnt_ is used to limit the composite op size
148 basic_op_cnt_ = CountGraphKernelInnerNodes(node);
149 }
150 auto cnode = node->cast<CNodePtr>();
151 MS_EXCEPTION_IF_NULL(cnode);
152 for (const auto &inp : cnode->inputs()) {
153 auto iter = node_idx_map.find(inp);
154 if (iter != node_idx_map.end()) {
155 // At the beginning, cluster_id is equal to node_id
156 (void)inputs_.insert(iter->second);
157 }
158 }
159 }
160 ~Cluster() = default;
161
Mergemindspore::opt::Graph::Cluster162 void Merge(Cluster *other_cluster) {
163 other_cluster->cluster_id_ = cluster_id_;
164 max_node_id_ = std::max(other_cluster->max_node_id_, max_node_id_);
165 cluster_size_ += other_cluster->cluster_size_;
166 basic_op_cnt_ += other_cluster->basic_op_cnt_;
167 (void)std::for_each(other_cluster->inputs_.begin(), other_cluster->inputs_.end(),
168 [this](size_t inp) { (void)this->inputs_.insert(inp); });
169 other_cluster->Clean();
170 }
171
172 // clean the info to free memory.
Cleanmindspore::opt::Graph::Cluster173 void Clean() {
174 inputs_.clear();
175 cluster_size_ = 0;
176 basic_op_cnt_ = 0;
177 max_node_id_ = 0;
178 }
179 }; // struct Cluster
180
181 public:
182 // Init and build graph
Graph(const AnfNodePtrList & nodes,const std::unordered_map<AnfNodePtr,size_t> & node_idx_map)183 Graph(const AnfNodePtrList &nodes, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map) {
184 clusters_.reserve(nodes.size());
185 for (size_t i = 0; i < nodes.size(); i++) {
186 (void)clusters_.emplace_back(i, nodes[i], node_idx_map);
187 }
188 }
189 ~Graph() = default;
190
191 // find the representative of the cluster
Find(size_t node_id)192 size_t Find(size_t node_id) {
193 size_t &pre_id = clusters_[node_id].cluster_id_;
194 return (pre_id == clusters_[pre_id].cluster_id_) ? pre_id : (pre_id = Find(pre_id));
195 }
196
197 // merge clusters, the smallest cluster id will be the new cluster id.
Merge(const std::vector<size_t> & candidates)198 void Merge(const std::vector<size_t> &candidates) {
199 size_t min_id = *std::min_element(candidates.begin(), candidates.end());
200 for (auto id : candidates) {
201 if (id == min_id) continue;
202 clusters_[min_id].Merge(&clusters_[id]);
203 }
204 }
205
206 // Collect nodes together that are in the same cluster.
CollectClusters()207 std::vector<std::vector<size_t>> CollectClusters() {
208 std::vector<std::vector<size_t>> cluster_map(clusters_.size());
209 for (size_t i = 0; i < clusters_.size(); i++) {
210 cluster_map[Find(i)].push_back(i);
211 }
212 return cluster_map;
213 }
214
215 // Get cluster's max node id
GetClusterMaxNodeId(size_t cluster_id)216 size_t GetClusterMaxNodeId(size_t cluster_id) { return clusters_[Find(cluster_id)].max_node_id_; }
217
218 using VisitFunc = std::function<IncludeType(size_t)>;
Dfs(size_t node_id,const VisitFunc & visitor)219 void Dfs(size_t node_id, const VisitFunc &visitor) {
220 ++seen_;
221 return DepthFirstSearch(Find(node_id), visitor);
222 }
223
224 // Get cluster size
GetSize(size_t cluster_id)225 size_t GetSize(size_t cluster_id) { return clusters_[Find(cluster_id)].cluster_size_; }
226
227 // Get cluster's basic op count
GetBasicNodeCount(size_t cluster_id)228 size_t GetBasicNodeCount(size_t cluster_id) { return clusters_[Find(cluster_id)].basic_op_cnt_; }
229
230 // Get cluster's inputs
GetInputs(size_t cluster_id)231 const std::set<size_t> &GetInputs(size_t cluster_id) {
232 cluster_id = Find(cluster_id);
233 RefreshInputs(cluster_id);
234 return clusters_[cluster_id].inputs_;
235 }
236
237 private:
RefreshInputs(size_t i)238 void RefreshInputs(size_t i) {
239 auto &inputs = clusters_[i].inputs_;
240 for (auto iter = inputs.begin(); iter != inputs.end();) {
241 size_t new_id = Find(*iter);
242 if (new_id != *iter) {
243 iter = inputs.erase(iter);
244 (void)inputs.insert(new_id);
245 } else {
246 ++iter;
247 }
248 }
249 (void)inputs.erase(i);
250 }
251
DepthFirstSearch(size_t cluster_id,const VisitFunc & visitor)252 void DepthFirstSearch(size_t cluster_id, const VisitFunc &visitor) {
253 if (clusters_[cluster_id].seed_ >= seen_) return;
254 clusters_[cluster_id].seed_ = seen_;
255 if (visitor(cluster_id) != FOLLOW) {
256 return;
257 }
258 // traverse inputs in descending order.
259 const auto &inputs = GetInputs(cluster_id);
260 for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
261 DepthFirstSearch(*iter, visitor);
262 }
263 }
264
265 std::vector<Cluster> clusters_;
266 size_t seen_{0};
267 }; // class Graph
268
269 class CircleChecker {
270 public:
CircleChecker(GraphPtr graph)271 explicit CircleChecker(GraphPtr graph) : graph_(graph) {}
272 ~CircleChecker() = default;
273
RemoveCircle(std::vector<size_t> * candidates)274 void RemoveCircle(std::vector<size_t> *candidates) {
275 if (candidates->size() <= 1) {
276 return;
277 }
278 candidates_.clear();
279 candidates_.insert(candidates->begin(), candidates->end());
280 for (auto iter = candidates->begin(); iter != candidates->end(); ++iter) {
281 if (!candidates_.count(*iter)) continue;
282 circle_nodes_.clear();
283 if (CheckCircle(*iter)) {
284 RemoveCircleNodesFromCandidates();
285 }
286 }
287 (void)candidates->erase(std::remove_if(candidates->begin(), candidates->end(),
288 [this](size_t c) { return this->candidates_.count(c) == 0; }),
289 candidates->end());
290 }
291
292 private:
293 /**
294 * Check circle. the candidate is collected into circle_nodes_ if it will form a circle.
295 *
296 * algorithm:
297 * Search from the basenode's input that is NOT in candidates (the basenode is a candidate),
298 * If it depends on a node that belongs to candidates, it will form a circle.
299 * e.g. A -> x -> ... -> B
300 * -> y -> ... -> C
301 * In this case, A, B and C are candidates while x and y are not.
302 * Both x and y are inputs of A. assumes A is the basenode.
303 * When searching from x, the B will be found and added into circle_nodes list,
304 * and then when searching from y, the C will be found and added into circle_nodes list.
305 */
CheckCircle(size_t basenode)306 bool CheckCircle(size_t basenode) {
307 const auto &inputs = graph_->GetInputs(basenode);
308 std::set<size_t> visited_circle_nodes;
309 for (auto x : inputs) {
310 if (candidates_.count(x)) continue;
311 bool has_circle = false;
312 std::set<size_t> done;
313 auto vis_func = [this, &has_circle, &done, &visited_circle_nodes](size_t node_id) {
314 if (done.count(node_id) || acyclic_nodes_.count(node_id) || visited_circle_nodes.count(node_id)) {
315 return EXCLUDE;
316 }
317 (void)done.insert(node_id);
318 if (candidates_.count(node_id)) {
319 has_circle = true;
320 circle_nodes_.push_back(node_id);
321 return EXCLUDE;
322 }
323 // all nodes are indexed by topo order,
324 // so if the current node's cluster's max node id is less than the minimal candidate, a circle cannot be formed
325 // from this node.
326 if (candidates_.empty() || graph_->GetClusterMaxNodeId(node_id) < *candidates_.begin()) {
327 return EXCLUDE;
328 }
329 return FOLLOW;
330 };
331 graph_->Dfs(x, vis_func);
332 if (has_circle) {
333 visited_circle_nodes.insert(done.begin(), done.end());
334 } else {
335 acyclic_nodes_.insert(done.begin(), done.end());
336 }
337 }
338 return !circle_nodes_.empty();
339 }
340
341 // remove all circle nodes from candidates
RemoveCircleNodesFromCandidates()342 void RemoveCircleNodesFromCandidates() {
343 auto remove_from_candidates = [this](size_t node_id) {
344 if (candidates_.count(node_id)) {
345 (void)candidates_.erase(node_id);
346 return FOLLOW;
347 }
348 return EXCLUDE;
349 };
350 for (auto node : circle_nodes_) {
351 graph_->Dfs(node, remove_from_candidates);
352 }
353 }
354
355 GraphPtr graph_; // bind the global graph
356 std::set<size_t> candidates_; // bind the input candidates
357 std::vector<size_t> circle_nodes_;
358 std::set<size_t> acyclic_nodes_;
359 }; // CircleChecker
360
FindCandidates(size_t basenode_id)361 std::vector<size_t> GraphKernelCluster::FindCandidates(size_t basenode_id) {
362 std::vector<size_t> candidates;
363 auto include = [this, &candidates, func_graph = nodes_[basenode_id]->func_graph()](size_t cluster_id) {
364 const AnfNodePtr &node = this->nodes_[cluster_id];
365 if (node->func_graph() != func_graph) {
366 return EXCLUDE;
367 }
368 if (!IsClusterableOp(node) && !IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
369 return EXCLUDE;
370 }
371 candidates.push_back(cluster_id);
372 // Do not search from clustered node again.
373 if (this->graph_->GetSize(cluster_id) > 1) {
374 return NOFOLLOW;
375 }
376 return FOLLOW;
377 };
378 graph_->Dfs(basenode_id, include);
379 std::reverse(candidates.begin(), candidates.end());
380 return candidates;
381 }
382
Process(const FuncGraphPtr & func_graph)383 bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
384 bool changed = false;
385 for (int i = SizeToInt(nodes_.size()) - 1; i >= 0; i--) {
386 // if the node has been clustered, it has tried to find its previous nodes, so it's unnecessary to try again.
387 if (graph_->GetSize(IntToSize(i)) > 1) {
388 continue;
389 }
390 auto candidates = FindCandidates(IntToSize(i));
391 CircleChecker(graph_).RemoveCircle(&candidates);
392 RemoveWildGetitem(&candidates);
393 if (candidates.empty()) continue;
394 // merge candidates into one cluster
395 graph_->Merge(candidates);
396 }
397
398 // Rebuild func_graphs
399 auto clusters = graph_->CollectClusters();
400 for (size_t i = 0; i < clusters.size(); i++) {
401 auto node_without_getitem = std::count_if(clusters[i].begin(), clusters[i].end(), [this](size_t node_id) {
402 return !IsPrimitiveCNode(this->nodes_[node_id], prim::kPrimTupleGetItem);
403 });
404 if (node_without_getitem == 0) continue;
405 if (node_without_getitem == 1) {
406 // Do not cluster a single GraphKernel again.
407 // Do not cluster a single Assign.
408 const auto &node = nodes_[clusters[i][0]];
409 if (AnfAlgo::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) {
410 continue;
411 }
412 }
413 CreateFuncGraph(func_graph, clusters[i]);
414 changed = true;
415 }
416 return changed;
417 }
418
CreateFuncGraph(const FuncGraphPtr & func_graph,const std::vector<size_t> & nodes_id)419 void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) {
420 AnfNodePtrList old_nodes;
421 AnfNodePtr new_node;
422 (void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
423 [this](size_t id) { return this->nodes_[id]; });
424 std::tie(new_node, std::ignore) = FuseNodesToSubGraph(old_nodes, func_graph, "fusion");
425 std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
426 (void)eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node));
427 if (context::GraphKernelFlags::GetInstance().dump_as_text) {
428 DumpClusterInfo(old_nodes, new_node);
429 }
430 }
431
DumpClusterInfo(const AnfNodePtrList & old_nodes,const AnfNodePtr & new_node)432 void GraphKernelCluster::DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node) {
433 dump_buf_ << "Source nodes of " << new_node->fullname_with_scope() << " = " << new_node->DebugString() << std::endl;
434 for (const auto &node : old_nodes) {
435 dump_buf_ << " " << node->fullname_with_scope() << " = " << node->DebugString() << std::endl;
436 }
437 dump_buf_ << "=======================" << std::endl;
438 }
439
DumpToFile()440 void GraphKernelCluster::DumpToFile() {
441 auto dir_path = FileUtils::CreateNotExistDirs(std::string("./") + kGraphKernelDumpPath);
442 if (!dir_path.has_value()) {
443 MS_LOG(ERROR) << "Failed to CreateNotExistDirs: ./" << kGraphKernelDumpPath;
444 return;
445 }
446 std::string filepath = dir_path.value() + "/" + "graph_kernel_cluster.txt";
447 std::ofstream fout(filepath, std::ios::app);
448 if (!fout.is_open()) {
449 MS_LOG(ERROR) << "Open dump file '" << filepath << "' failed!";
450 return;
451 }
452 fout << dump_buf_.str() << std::endl;
453 fout.close();
454 }
455
456 // The GetItem node should be clustered with its real input.
457 // If its real input is not in the candidates, the GetItem should be excluded.
RemoveWildGetitem(std::vector<size_t> * candidates)458 void GraphKernelCluster::RemoveWildGetitem(std::vector<size_t> *candidates) {
459 bool changed = false;
460 std::set<size_t> candidates_set(candidates->begin(), candidates->end());
461
462 for (auto iter = candidates_set.begin(); iter != candidates_set.end();) {
463 size_t cluster_id = *iter;
464 if (IsPrimitiveCNode(nodes_[cluster_id], prim::kPrimTupleGetItem)) {
465 const auto &inputs = graph_->GetInputs(cluster_id);
466 if (inputs.size() != 1) {
467 MS_LOG(ERROR) << "Input size of GetItem(" << cluster_id << ") should be 1, but got " << inputs.size();
468 candidates->clear();
469 return;
470 }
471 auto prev_id = *(inputs.begin());
472 if (!candidates_set.count(prev_id)) {
473 iter = candidates_set.erase(iter);
474 changed = true;
475 continue;
476 }
477 }
478 ++iter;
479 }
480 if (changed) {
481 (void)candidates->erase(std::remove_if(candidates->begin(), candidates->end(),
482 [&candidates_set](size_t c) { return candidates_set.count(c) == 0; }),
483 candidates->end());
484 }
485 }
486
Init(const FuncGraphPtr & func_graph)487 void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) {
488 op_list_ = GetClusterableOpList();
489 // process cnode only
490 nodes_ = TopoSort(func_graph->get_return(), SuccIncoming,
491 [](const AnfNodePtr &node) { return node->isa<CNode>() ? FOLLOW : EXCLUDE; });
492 for (size_t i = 0; i < nodes_.size(); i++) {
493 node_idx_map_[nodes_[i]] = i;
494 }
495 graph_ = std::make_shared<Graph>(nodes_, node_idx_map_);
496 MS_EXCEPTION_IF_NULL(graph_);
497 }
498
Run(const FuncGraphPtr & func_graph)499 bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) {
500 (void)std::make_shared<ShrinkUpdateState>()->Run(func_graph);
501 auto mng = func_graph->manager();
502 MS_EXCEPTION_IF_NULL(mng);
503 Init(func_graph);
504 bool changed = Process(func_graph);
505 if (changed) {
506 if (context::GraphKernelFlags::GetInstance().dump_as_text) {
507 DumpToFile();
508 }
509 mng->RemoveRoots();
510 mng->KeepRoots({func_graph});
511 }
512 Clean();
513 (void)std::make_shared<SpreadUpdateState>()->Run(func_graph);
514 return changed;
515 }
516 } // namespace opt
517 } // namespace mindspore
518