1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/adapter/graph_kernel_cluster_cloud.h"
17 #include <set>
18 #include "mindspore/core/ops/sequence_ops.h"
19 #include "mindspore/core/ops/nn_optimizer_ops.h"
20 #include "mindspore/core/ops/nn_ops.h"
21 #include "mindspore/core/ops/math_ops.h"
22 #include "mindspore/core/ops/lite_ops.h"
23 #include "mindspore/core/ops/comparison_ops.h"
24 #include "mindspore/core/ops/array_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "ir/graph_utils.h"
27 #include "include/common/utils/anfalgo.h"
28 #include "utils/anf_utils.h"
29 #include "utils/ms_context.h"
30 #include "utils/file_utils.h"
31 #include "backend/common/graph_kernel/graph_kernel_flags.h"
32 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
33 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
34 #include "backend/common/graph_kernel/core/value_depend_op_utils.h"
35 #include "backend/common/graph_kernel/graph_kernel_helper.h"
36
37 namespace mindspore::graphkernel {
38 namespace {
39 std::set<TypeId> dvm_float_types{kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeBFloat16};
40
CheckFormat(const AnfNodePtr & node)41 bool CheckFormat(const AnfNodePtr &node) {
42 if (common::AnfAlgo::IsDynamicShape(node) && !CheckDefaultFormat(node)) {
43 // dvm kernel infer shape use inputs device shape, but the output abstract shape inferred from device shape is
44 // not unique if some shape value are not a multiple of 16
45 MS_LOG(DEBUG) << "skip node: " << node->fullname_with_scope()
46 << " because only default format is supported in dynamic shape";
47 return false;
48 }
49 auto cb = Callback::Instance();
50 MS_EXCEPTION_IF_NULL(cb);
51 auto input_num = AnfUtils::GetInputTensorNum(node);
52 if (input_num > 0) {
53 bool has_special_format = false;
54 auto base_format = cb->GetInputFormat(node, 0);
55 for (size_t i = 0; i < input_num; ++i) {
56 auto input_format = cb->GetInputFormat(node, i);
57 if (!has_special_format &&
58 (input_format.find("FRACTAL") != std::string::npos || input_format.find("C0") != std::string::npos)) {
59 has_special_format = true;
60 }
61 if (has_special_format && input_format != base_format) {
62 // mixed special format and default format is not supported, because extra Reshape/TransData is needed
63 return false;
64 }
65 }
66 }
67 return true;
68 }
69
DvmSliceSupported(const AnfNodePtr & node,TypeId node_output_type)70 bool DvmSliceSupported(const AnfNodePtr &node, TypeId node_output_type) {
71 constexpr size_t input_num = 3;
72 if (common::AnfAlgo::IsDynamicRankNode(node) || GetShape(node).size() > input_num) {
73 return false;
74 }
75 if (IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
76 auto cnode = node->cast<CNodePtr>();
77 auto step_node = cnode->input(kIndex4)->cast<ValueNodePtr>();
78 if (step_node == nullptr) {
79 return false;
80 }
81 auto step_vector = GetValue<std::vector<int64_t>>(step_node->value());
82 if (std::any_of(step_vector.begin(), step_vector.end(), [](int i) { return i != 1; })) {
83 return false;
84 }
85 }
86 return (dvm_float_types.find(node_output_type) != dvm_float_types.end() || node_output_type == kNumberTypeInt32);
87 }
88
DvmSupported(const AnfNodePtr & node)89 bool DvmSupported(const AnfNodePtr &node) {
90 // check format
91 if (!CheckFormat(node)) {
92 return false;
93 }
94 auto cb = Callback::Instance();
95 MS_EXCEPTION_IF_NULL(cb);
96 auto node_output_type = cb->GetOutputType(node, 0);
97 // cast op
98 if (IsPrimitiveCNode(node, prim::kPrimCast)) {
99 static std::set<TypeId> supported_types{kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeBool, kNumberTypeInt32,
100 kNumberTypeBFloat16};
101 auto node_input_type = cb->GetInputType(node, 0);
102 return !(supported_types.find(node_input_type) == supported_types.end() ||
103 supported_types.find(node_output_type) == supported_types.end());
104 }
105 // reduceSum op
106 if (IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
107 auto prim = GetCNodePrimitive(node);
108 MS_EXCEPTION_IF_NULL(prim);
109 auto skip_mode_attr = prim->GetAttr(kAttrSkipMode);
110 MS_EXCEPTION_IF_NULL(skip_mode_attr);
111 auto skip_mode = GetValue<bool>(skip_mode_attr);
112 if (skip_mode == true) {
113 return false;
114 }
115 }
116 // compare op
117 static std::vector<PrimitivePtr> compare_ops{prim::kPrimEqual, prim::kPrimNotEqual, prim::kPrimGreater,
118 prim::kPrimGreaterEqual, prim::kPrimLess, prim::kPrimLessEqual};
119 if (std::any_of(compare_ops.begin(), compare_ops.end(),
120 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
121 auto node_input_type = cb->GetInputType(node, 0);
122 return (dvm_float_types.find(node_input_type) != dvm_float_types.end() || node_input_type == kNumberTypeInt32);
123 }
124 // int op
125 static std::vector<PrimitivePtr> int_ops{
126 prim::kPrimAdd, prim::kPrimSub, prim::kPrimMul, prim::kPrimMaximum, prim::kPrimMinimum,
127 prim::kPrimNeg, prim::kPrimAbs, prim::kPrimSelect, prim::kPrimAssign, prim::kPrimBroadcastTo};
128 if (std::any_of(int_ops.begin(), int_ops.end(),
129 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
130 return (dvm_float_types.find(node_output_type) != dvm_float_types.end() || node_output_type == kNumberTypeInt32);
131 }
132 // slice op
133 static std::vector<PrimitivePtr> slice_ops{prim::kPrimSlice, prim::kPrimStridedSlice};
134 if (std::any_of(slice_ops.begin(), slice_ops.end(),
135 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
136 return DvmSliceSupported(node, node_output_type);
137 }
138 // matmul op
139 static std::vector<PrimitivePtr> matmul_ops{prim::kPrimMatMul, prim::kPrimBatchMatMul};
140 if (std::any_of(matmul_ops.begin(), matmul_ops.end(),
141 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
142 return node_output_type == kNumberTypeFloat16 || node_output_type == kNumberTypeBFloat16;
143 }
144 if (IsPrimitiveCNode(node, prim::kPrimTranspose)) {
145 // for bf16, extra cast will be inserted, to do: move ConvertBFloat16 after garph kernel split
146 return node_output_type == kNumberTypeFloat16 || node_output_type == kNumberTypeFloat32;
147 }
148 // other op
149 return dvm_float_types.find(node_output_type) != dvm_float_types.end();
150 }
151
152 const std::vector<OpWithLevel> clusterable_ops_with_level = {
153 // all target
154 {kAllTarget, OpLevel_0, prim::kPrimAbs},
155 {kAllTarget, OpLevel_0, prim::kPrimAdd},
156 {kAllTarget, OpLevel_0, prim::kPrimCast},
157 {kAllTarget, OpLevel_0, prim::kPrimEqual},
158 {kAllTarget, OpLevel_0, prim::kPrimExp},
159 {kAllTarget, OpLevel_0, prim::kPrimLog},
160 {kAllTarget, OpLevel_0, prim::kPrimMaximum},
161 {kAllTarget, OpLevel_0, prim::kPrimMinimum},
162 {kAllTarget, OpLevel_0, prim::kPrimMul},
163 {kAllTarget, OpLevel_0, prim::kPrimNeg},
164 {kAllTarget, OpLevel_0, prim::kPrimPow},
165 {kAllTarget, OpLevel_0, prim::kPrimRealDiv},
166 {kAllTarget, OpLevel_0, prim::kPrimReciprocal},
167 {kAllTarget, OpLevel_1, prim::kPrimReduceSum},
168 {kAllTarget, OpLevel_1, prim::kPrimReshape},
169 {kAllTarget, OpLevel_0, prim::kPrimRound},
170 {kAllTarget, OpLevel_0, prim::kPrimRsqrt},
171 {kAllTarget, OpLevel_0, prim::kPrimSqrt},
172 {kAllTarget, OpLevel_0, prim::kPrimSub},
173 {kAllTarget, OpLevel_0, prim::kPrimTanh},
174 {kAllTarget, OpLevel_1, prim::kPrimTranspose},
175 // ascend
176 {kAscendDevice, OpLevel_1, prim::kPrimMatMul},
177 {kAscendDevice, OpLevel_1, prim::kPrimTransData},
178 {kAscendDevice, OpLevel_1, prim::kPrimBatchMatMul},
179 // gpu
180 {kGPUDevice, OpLevel_0, prim::kPrimACos},
181 {kGPUDevice, OpLevel_0, prim::kPrimAcosh},
182 {kGPUDevice, OpLevel_2, prim::kPrimArgMax},
183 {kGPUDevice, OpLevel_2, prim::kPrimArgmin},
184 {kGPUDevice, OpLevel_0, prim::kPrimAsin},
185 {kGPUDevice, OpLevel_0, prim::kPrimAsinh},
186 {kGPUDevice, OpLevel_0, prim::kPrimAssign},
187 {kGPUDevice, OpLevel_0, prim::kPrimAtan},
188 {kGPUDevice, OpLevel_0, prim::kPrimAtan2},
189 {kGPUDevice, OpLevel_0, prim::kPrimCos},
190 {kGPUDevice, OpLevel_0, prim::kPrimDiv},
191 {kGPUDevice, OpLevel_0, prim::kPrimErf},
192 {kGPUDevice, OpLevel_0, prim::kPrimExpm1},
193 {kGPUDevice, OpLevel_0, prim::kPrimFloor},
194 {kGPUDevice, OpLevel_0, prim::kPrimFloorDiv},
195 {kGPUDevice, OpLevel_0, prim::kPrimFloorMod},
196 {kGPUDevice, OpLevel_0, prim::kPrimGreater},
197 {kGPUDevice, OpLevel_0, prim::kPrimGreaterEqual},
198 {kGPUDevice, OpLevel_0, prim::kPrimIsFinite},
199 {kGPUDevice, OpLevel_0, prim::kPrimIsInf},
200 {kGPUDevice, OpLevel_0, prim::kPrimIsNan},
201 {kGPUDevice, OpLevel_0, prim::kPrimLess},
202 {kGPUDevice, OpLevel_0, prim::kPrimLessEqual},
203 {kGPUDevice, OpLevel_0, prim::kPrimLogicalAnd},
204 {kGPUDevice, OpLevel_0, prim::kPrimLogicalOr},
205 {kGPUDevice, OpLevel_0, prim::kPrimLogicalNot},
206 {kGPUDevice, OpLevel_0, prim::kPrimMod},
207 {kGPUDevice, OpLevel_0, prim::kPrimNotEqual},
208 {kGPUDevice, OpLevel_1, prim::kPrimReduceMax},
209 {kGPUDevice, OpLevel_1, prim::kPrimReduceMin},
210 {kGPUDevice, OpLevel_0, prim::kPrimSelect},
211 {kGPUDevice, OpLevel_0, prim::kPrimSign},
212 {kGPUDevice, OpLevel_0, prim::kPrimSin},
213 {kGPUDevice, OpLevel_0, prim::kPrimStridedSlice},
214 {kGPUDevice, OpLevel_1, prim::kPrimCumSum},
215 {kGPUDevice, OpLevel_1, prim::kPrimOneHot},
216 // cpu
217 {kCPUDevice, OpLevel_0, prim::kPrimLogicalNot},
218 {kCPUDevice, OpLevel_0, prim::kPrimMod},
219 {kCPUDevice, OpLevel_1, prim::kPrimReduceMax},
220 {kCPUDevice, OpLevel_0, prim::kPrimSelect},
221 {kCPUDevice, OpLevel_0, prim::kPrimLess},
222 {kCPUDevice, OpLevel_0, prim::kPrimLessEqual},
223 };
224
225 const std::vector<OpWithLevel> clusterable_ops_with_level_v2 = {
226 // cpu
227 {kCPUDevice, OpLevel_0, prim::kPrimNotEqual},
228 {kCPUDevice, OpLevel_0, prim::kPrimGreaterEqual},
229 {kCPUDevice, OpLevel_0, prim::kPrimGreater},
230 {kCPUDevice, OpLevel_0, prim::kPrimFloor},
231 {kCPUDevice, OpLevel_0, prim::kPrimIsNan},
232 {kCPUDevice, OpLevel_0, prim::kPrimAssign},
233 {kCPUDevice, OpLevel_0, prim::kPrimBroadcastTo},
234 {kCPUDevice, OpLevel_0, prim::kPrimTile},
235 {kCPUDevice, OpLevel_0, prim::kPrimLogicalAnd},
236 {kCPUDevice, OpLevel_0, prim::kPrimCos},
237 {kCPUDevice, OpLevel_0, prim::kPrimSin},
238 {kCPUDevice, OpLevel_0, prim::kPrimACos},
239 {kCPUDevice, OpLevel_0, prim::kPrimAsin},
240 {kCPUDevice, OpLevel_0, prim::kPrimTanh},
241 {kCPUDevice, OpLevel_0, prim::kPrimAtan2},
242 {kCPUDevice, OpLevel_0, prim::kPrimMinimum},
243 {kCPUDevice, OpLevel_0, prim::kPrimMaximum},
244 {kCPUDevice, OpLevel_0, prim::kPrimReduceAll},
245 {kCPUDevice, OpLevel_0, prim::kPrimStridedSlice},
246 // gpu
247 {kGPUDevice, OpLevel_0, prim::kPrimNotEqual},
248 {kGPUDevice, OpLevel_0, prim::kPrimSelect},
249 {kGPUDevice, OpLevel_0, prim::kPrimTile},
250 {kGPUDevice, OpLevel_0, prim::kPrimLogicalAnd},
251 {kGPUDevice, OpLevel_0, prim::kPrimCos},
252 {kGPUDevice, OpLevel_0, prim::kPrimSin},
253 {kGPUDevice, OpLevel_0, prim::kPrimMinimum},
254 {kGPUDevice, OpLevel_0, prim::kPrimMaximum},
255 {kGPUDevice, OpLevel_0, prim::kPrimAssign},
256 };
257
// Op names that GetClusterOps adds to the disabled list on GPU with "AKG_V2",
// unless the user lists them in enable_cluster_ops.
const std::vector<std::string> disable_cluster_op_list_v2 = {
  "OneHot", "CumSum", "Transpose", "BatchMatMul", "MatMul", "BroadcastTo", "StridedSlice",
};
260
261 const std::vector<OpWithLevel> clusterable_ops_with_level_dvm = {
262 {kAscendDevice, OpLevel_0, prim::kPrimAbs}, {kAscendDevice, OpLevel_0, prim::kPrimAdd},
263 {kAscendDevice, OpLevel_0, prim::kPrimBroadcastTo}, {kAscendDevice, OpLevel_0, prim::kPrimCast},
264 {kAscendDevice, OpLevel_0, prim::kPrimExp}, {kAscendDevice, OpLevel_0, prim::kPrimLog},
265 {kAscendDevice, OpLevel_0, prim::kPrimMaximum}, {kAscendDevice, OpLevel_0, prim::kPrimMinimum},
266 {kAscendDevice, OpLevel_0, prim::kPrimMul}, {kAscendDevice, OpLevel_0, prim::kPrimNeg},
267 {kAscendDevice, OpLevel_0, prim::kPrimPow}, {kAscendDevice, OpLevel_0, prim::kPrimDiv},
268 {kAscendDevice, OpLevel_0, prim::kPrimRealDiv}, {kAscendDevice, OpLevel_0, prim::kPrimReciprocal},
269 {kAscendDevice, OpLevel_0, prim::kPrimRsqrt}, {kAscendDevice, OpLevel_0, prim::kPrimSqrt},
270 {kAscendDevice, OpLevel_0, prim::kPrimSub}, {kAscendDevice, OpLevel_0, prim::kPrimEqual},
271 {kAscendDevice, OpLevel_0, prim::kPrimNotEqual}, {kAscendDevice, OpLevel_0, prim::kPrimGreater},
272 {kAscendDevice, OpLevel_0, prim::kPrimGreaterEqual}, {kAscendDevice, OpLevel_0, prim::kPrimLess},
273 {kAscendDevice, OpLevel_0, prim::kPrimLessEqual}, {kAscendDevice, OpLevel_0, prim::kPrimLogicalAnd},
274 {kAscendDevice, OpLevel_0, prim::kPrimLogicalOr}, {kAscendDevice, OpLevel_0, prim::kPrimLogicalNot},
275 {kAscendDevice, OpLevel_0, prim::kPrimSelect}, {kAscendDevice, OpLevel_0, prim::kPrimAssign},
276 {kAscendDevice, OpLevel_0, prim::kPrimReduceSum}, {kAscendDevice, OpLevel_0, prim::kPrimIsFinite},
277 {kAscendDevice, OpLevel_1, prim::kPrimReshape}, {kAscendDevice, OpLevel_0, prim::kPrimTranspose},
278 };
279 } // namespace
280
GetClusterOps()281 std::vector<PrimitivePtr> StaticShapeCluster::GetClusterOps() {
282 const auto &flags = GraphKernelFlags::GetInstance();
283 std::vector<std::string> disable_cluster_ops = flags.disable_cluster_ops;
284 auto cb = Callback::Instance();
285
286 std::vector<OpWithLevel> clusterable_ops;
287 if (flags.kernel_generator == "AKG_V2") {
288 clusterable_ops = clusterable_ops_with_level;
289 clusterable_ops.insert(clusterable_ops.end(), clusterable_ops_with_level_v2.begin(),
290 clusterable_ops_with_level_v2.end());
291 if (cb->GetTargetFromContext() == kCPUDevice &&
292 std::find(flags.enable_cluster_ops.begin(), flags.enable_cluster_ops.end(), "Reshape") ==
293 flags.enable_cluster_ops.end()) {
294 disable_cluster_ops.push_back("Reshape");
295 }
296 if (cb->GetTargetFromContext() == kGPUDevice) {
297 for (const std::string &item : disable_cluster_op_list_v2) {
298 if (std::find(flags.enable_cluster_ops.begin(), flags.enable_cluster_ops.end(), item) ==
299 flags.enable_cluster_ops.end()) {
300 disable_cluster_ops.push_back(item);
301 }
302 }
303 }
304 } else if (flags.kernel_generator == "DVM") {
305 clusterable_ops = clusterable_ops_with_level_dvm;
306 } else {
307 clusterable_ops = clusterable_ops_with_level;
308 }
309 auto ops = GkUtils::GetValidOps(clusterable_ops, flags.fusion_ops_level, flags.enable_cluster_ops_only,
310 flags.enable_cluster_ops, disable_cluster_ops);
311 return GkUtils::FilterExcludedOps(ops);
312 }
313
GetClusterableOpList()314 std::vector<PrimitivePtr> StaticShapeCluster::GetClusterableOpList() { return StaticShapeCluster::GetClusterOps(); }
315
SkipHostInputNode(const AnfNodePtr & node,bool is_dvm)316 bool SkipHostInputNode(const AnfNodePtr &node, bool is_dvm) {
317 if (is_dvm && GraphKernelFlags::GetInstance().IsEnableKernelPacket()) {
318 auto cnode = node->cast<CNodePtr>();
319 return cnode != nullptr &&
320 std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), AnfAlgo::IsKernelSelectBackoffOp);
321 }
322 return false;
323 }
324
IsClusterableOp(const AnfNodePtr & node)325 bool StaticShapeCluster::IsClusterableOp(const AnfNodePtr &node) {
326 if (AnfUtils::IsGraphKernel(node)) {
327 auto sub_graph = GetCNodeFuncGraph(node);
328 if (auto type = sub_graph->get_attr("composite_type")) {
329 if (GetValue<std::string>(type) == "inplace_assign_builder") {
330 return false;
331 }
332 }
333 return true;
334 }
335 if (GkUtils::IsKeepBasicNode(node)) {
336 return false;
337 }
338 bool is_dvm = (GraphKernelFlags::GetInstance().kernel_generator == "DVM");
339 if (!is_dvm && common::AnfAlgo::IsDynamicShape(node)) {
340 return false;
341 }
342 bool node_in_oplist = std::any_of(op_list_.begin(), op_list_.end(),
343 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
344 if (!node_in_oplist) {
345 return false;
346 }
347
348 auto cb = Callback::Instance();
349 MS_EXCEPTION_IF_NULL(cb);
350 // if node's output type is complex64 or complex128, cannot be added to the cluster list.
351 auto node_output_type = cb->GetOutputType(node, 0);
352 if (node_output_type == kNumberTypeComplex64 || node_output_type == kNumberTypeComplex128) {
353 return false;
354 }
355 if (IsPrimitiveCNode(node, prim::kPrimCast)) {
356 auto node_input_type = cb->GetInputType(node, 0);
357 if ((node_input_type == kNumberTypeComplex64) || (node_input_type == kNumberTypeComplex128)) {
358 return false;
359 }
360 }
361
362 if (is_dvm && !DvmSupported(node)) {
363 return false;
364 }
365
366 if (IsPrimitiveCNode(node, prim::kPrimReshape)) {
367 auto output_format = cb->GetOutputFormat(node, 0);
368 if (output_format != kOpFormat_DEFAULT) {
369 auto primitive = GetCNodePrimitive(node);
370 MS_EXCEPTION_IF_NULL(primitive);
371 primitive = primitive->Clone();
372 // format attr used by ReshapeOp::InferFormat
373 primitive->AddAttr("format", MakeValue(output_format));
374 auto cnode = node->cast<CNodePtr>();
375 MS_EXCEPTION_IF_NULL(cnode);
376 cnode->set_input(kAnfPrimitiveIndex, NewValueNode(primitive));
377 }
378 }
379 if (!ValueDependOpUtils::IsConstInput(node)) {
380 return false;
381 }
382 if (SkipHostInputNode(node, is_dvm)) {
383 // this node can be fused with input host ops by kernelpacket
384 return false;
385 }
386
387 return true;
388 }
389
GetClusterableOpList()390 std::vector<PrimitivePtr> DynamicShapeCluster::GetClusterableOpList() {
391 std::vector<PrimitivePtr> dyn_clusterable_ops_list = {
392 prim::kPrimAdd, prim::kPrimCast, prim::kPrimMul, prim::kPrimRealDiv, prim::kPrimSub,
393 prim::kPrimAbs, prim::kPrimExp, prim::kPrimLog, prim::kPrimMaximum, prim::kPrimMinimum,
394 prim::kPrimNeg, prim::kPrimPow, prim::kPrimSqrt, prim::kPrimTranspose, prim::kPrimReduceSum};
395 return dyn_clusterable_ops_list;
396 }
397
IsClusterableOp(const AnfNodePtr & node)398 bool DynamicShapeCluster::IsClusterableOp(const AnfNodePtr &node) {
399 bool node_in_oplist = std::any_of(op_list_.begin(), op_list_.end(),
400 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
401 if (!node_in_oplist || common::AnfAlgo::IsDynamicRankNode(node)) {
402 return false;
403 }
404 if (GkUtils::IsKeepBasicNode(node)) {
405 return false;
406 }
407 if (!ValueDependOpUtils::IsConstInput(node)) {
408 return false;
409 }
410 return true;
411 }
412
Run(const FuncGraphPtr & func_graph)413 bool DynamicShapeCluster::Run(const FuncGraphPtr &func_graph) {
414 auto mng = func_graph->manager();
415 MS_EXCEPTION_IF_NULL(mng);
416 Init(func_graph);
417 bool changed = Process(func_graph);
418 if (changed) {
419 mng->RemoveRoots();
420 mng->KeepRoots({func_graph});
421 }
422 Clean();
423 return changed;
424 }
425 } // namespace mindspore::graphkernel
426