/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/adapter/graph_kernel_cluster_cloud.h"
#include <algorithm>
#include <set>
#include <string>
#include <vector>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ir/graph_utils.h"
#include "include/common/utils/anfalgo.h"
#include "utils/anf_utils.h"
#include "utils/ms_context.h"
#include "utils/file_utils.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/graph_kernel/core/graph_kernel_callback.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/graph_kernel/core/value_depend_op_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"

namespace mindspore::graphkernel {
namespace {
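// Floating-point types that the DVM kernel generator supports.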
std::set<TypeId> dvm_float_types{kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeBFloat16};

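// Check whether a node's formats allow clustering: a dynamic shape node must use the
// default format, and mixing a special format (FRACTAL/C0) with a different format
// among the inputs is rejected.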
bool CheckFormat(const AnfNodePtr &node) {
  if (common::AnfAlgo::IsDynamicShape(node) && !CheckDefaultFormat(node)) {
    // DVM kernel shape inference uses the inputs' device shapes, but the output abstract shape
    // inferred from a device shape is not unique if some shape values are not a multiple of 16.
    MS_LOG(DEBUG) << "skip node: " << node->fullname_with_scope()
                  << " because only default format is supported in dynamic shape";
    return false;
  }
  auto cb = Callback::Instance();
  MS_EXCEPTION_IF_NULL(cb);
  auto input_num = AnfUtils::GetInputTensorNum(node);
  if (input_num > 0) {
    bool has_special_format = false;
    auto base_format = cb->GetInputFormat(node, 0);
    for (size_t i = 0; i < input_num; ++i) {
      auto input_format = cb->GetInputFormat(node, i);
      if (!has_special_format &&
          (input_format.find("FRACTAL") != std::string::npos || input_format.find("C0") != std::string::npos)) {
        has_special_format = true;
      }
      if (has_special_format && input_format != base_format) {
        // mixing special formats with the default format is not supported, because an extra
        // Reshape/TransData would be needed
        return false;
      }
    }
  }
  return true;
}

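// Check whether a Slice/StridedSlice node is supported by DVM: the node must have a
// static rank of at most 3, a StridedSlice must have a constant step of all ones, and
// the output type must be a supported float type or int32.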
bool DvmSliceSupported(const AnfNodePtr &node, TypeId node_output_type) {
  constexpr size_t input_num = 3;
  if (common::AnfAlgo::IsDynamicRankNode(node) || GetShape(node).size() > input_num) {
    return false;
  }
  if (IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
    auto cnode = node->cast<CNodePtr>();
    auto step_node = cnode->input(kIndex4)->cast<ValueNodePtr>();
    if (step_node == nullptr) {
      return false;
    }
    auto step_vector = GetValue<std::vector<int64_t>>(step_node->value());
    if (std::any_of(step_vector.begin(), step_vector.end(), [](int64_t i) { return i != 1; })) {
      return false;
    }
  }
  return (dvm_float_types.find(node_output_type) != dvm_float_types.end() || node_output_type == kNumberTypeInt32);
}

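// Check whether a node can be compiled by the DVM kernel generator, based on its
// primitive and its input/output data types.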
bool DvmSupported(const AnfNodePtr &node) {
  // check format
  if (!CheckFormat(node)) {
    return false;
  }
  auto cb = Callback::Instance();
  MS_EXCEPTION_IF_NULL(cb);
  auto node_output_type = cb->GetOutputType(node, 0);
  // Cast op
  if (IsPrimitiveCNode(node, prim::kPrimCast)) {
    static std::set<TypeId> supported_types{kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeBool, kNumberTypeInt32,
                                            kNumberTypeBFloat16};
    auto node_input_type = cb->GetInputType(node, 0);
    return supported_types.find(node_input_type) != supported_types.end() &&
           supported_types.find(node_output_type) != supported_types.end();
  }
  // ReduceSum op
  if (IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
    auto prim = GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(prim);
    auto skip_mode_attr = prim->GetAttr(kAttrSkipMode);
    MS_EXCEPTION_IF_NULL(skip_mode_attr);
    auto skip_mode = GetValue<bool>(skip_mode_attr);
    if (skip_mode) {
      return false;
    }
  }
  // compare ops
  static std::vector<PrimitivePtr> compare_ops{prim::kPrimEqual,        prim::kPrimNotEqual, prim::kPrimGreater,
                                               prim::kPrimGreaterEqual, prim::kPrimLess,     prim::kPrimLessEqual};
  if (std::any_of(compare_ops.begin(), compare_ops.end(),
                  [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
    auto node_input_type = cb->GetInputType(node, 0);
    return (dvm_float_types.find(node_input_type) != dvm_float_types.end() || node_input_type == kNumberTypeInt32);
  }
  // ops that also support int32
  static std::vector<PrimitivePtr> int_ops{
    prim::kPrimAdd, prim::kPrimSub, prim::kPrimMul,    prim::kPrimMaximum, prim::kPrimMinimum,
    prim::kPrimNeg, prim::kPrimAbs, prim::kPrimSelect, prim::kPrimAssign,  prim::kPrimBroadcastTo};
  if (std::any_of(int_ops.begin(), int_ops.end(),
                  [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
    return (dvm_float_types.find(node_output_type) != dvm_float_types.end() || node_output_type == kNumberTypeInt32);
  }
  // slice ops
  static std::vector<PrimitivePtr> slice_ops{prim::kPrimSlice, prim::kPrimStridedSlice};
  if (std::any_of(slice_ops.begin(), slice_ops.end(),
                  [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
    return DvmSliceSupported(node, node_output_type);
  }
  // matmul ops
  static std::vector<PrimitivePtr> matmul_ops{prim::kPrimMatMul, prim::kPrimBatchMatMul};
  if (std::any_of(matmul_ops.begin(), matmul_ops.end(),
                  [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
    return node_output_type == kNumberTypeFloat16 || node_output_type == kNumberTypeBFloat16;
  }
  if (IsPrimitiveCNode(node, prim::kPrimTranspose)) {
    // for bf16 an extra cast will be inserted; TODO: move ConvertBFloat16 after graph kernel split
    return node_output_type == kNumberTypeFloat16 || node_output_type == kNumberTypeFloat32;
  }
  // other ops
  return dvm_float_types.find(node_output_type) != dvm_float_types.end();
}

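// Clusterable ops for the default kernel generator, grouped by target device.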
const std::vector<OpWithLevel> clusterable_ops_with_level = {
  // all target
  {kAllTarget, OpLevel_0, prim::kPrimAbs},
  {kAllTarget, OpLevel_0, prim::kPrimAdd},
  {kAllTarget, OpLevel_0, prim::kPrimCast},
  {kAllTarget, OpLevel_0, prim::kPrimEqual},
  {kAllTarget, OpLevel_0, prim::kPrimExp},
  {kAllTarget, OpLevel_0, prim::kPrimLog},
  {kAllTarget, OpLevel_0, prim::kPrimMaximum},
  {kAllTarget, OpLevel_0, prim::kPrimMinimum},
  {kAllTarget, OpLevel_0, prim::kPrimMul},
  {kAllTarget, OpLevel_0, prim::kPrimNeg},
  {kAllTarget, OpLevel_0, prim::kPrimPow},
  {kAllTarget, OpLevel_0, prim::kPrimRealDiv},
  {kAllTarget, OpLevel_0, prim::kPrimReciprocal},
  {kAllTarget, OpLevel_1, prim::kPrimReduceSum},
  {kAllTarget, OpLevel_1, prim::kPrimReshape},
  {kAllTarget, OpLevel_0, prim::kPrimRound},
  {kAllTarget, OpLevel_0, prim::kPrimRsqrt},
  {kAllTarget, OpLevel_0, prim::kPrimSqrt},
  {kAllTarget, OpLevel_0, prim::kPrimSub},
  {kAllTarget, OpLevel_0, prim::kPrimTanh},
  {kAllTarget, OpLevel_1, prim::kPrimTranspose},
  // ascend
  {kAscendDevice, OpLevel_1, prim::kPrimMatMul},
  {kAscendDevice, OpLevel_1, prim::kPrimTransData},
  {kAscendDevice, OpLevel_1, prim::kPrimBatchMatMul},
  // gpu
  {kGPUDevice, OpLevel_0, prim::kPrimACos},
  {kGPUDevice, OpLevel_0, prim::kPrimAcosh},
  {kGPUDevice, OpLevel_2, prim::kPrimArgMax},
  {kGPUDevice, OpLevel_2, prim::kPrimArgmin},
  {kGPUDevice, OpLevel_0, prim::kPrimAsin},
  {kGPUDevice, OpLevel_0, prim::kPrimAsinh},
  {kGPUDevice, OpLevel_0, prim::kPrimAssign},
  {kGPUDevice, OpLevel_0, prim::kPrimAtan},
  {kGPUDevice, OpLevel_0, prim::kPrimAtan2},
  {kGPUDevice, OpLevel_0, prim::kPrimCos},
  {kGPUDevice, OpLevel_0, prim::kPrimDiv},
  {kGPUDevice, OpLevel_0, prim::kPrimErf},
  {kGPUDevice, OpLevel_0, prim::kPrimExpm1},
  {kGPUDevice, OpLevel_0, prim::kPrimFloor},
  {kGPUDevice, OpLevel_0, prim::kPrimFloorDiv},
  {kGPUDevice, OpLevel_0, prim::kPrimFloorMod},
  {kGPUDevice, OpLevel_0, prim::kPrimGreater},
  {kGPUDevice, OpLevel_0, prim::kPrimGreaterEqual},
  {kGPUDevice, OpLevel_0, prim::kPrimIsFinite},
  {kGPUDevice, OpLevel_0, prim::kPrimIsInf},
  {kGPUDevice, OpLevel_0, prim::kPrimIsNan},
  {kGPUDevice, OpLevel_0, prim::kPrimLess},
  {kGPUDevice, OpLevel_0, prim::kPrimLessEqual},
  {kGPUDevice, OpLevel_0, prim::kPrimLogicalAnd},
  {kGPUDevice, OpLevel_0, prim::kPrimLogicalOr},
  {kGPUDevice, OpLevel_0, prim::kPrimLogicalNot},
  {kGPUDevice, OpLevel_0, prim::kPrimMod},
  {kGPUDevice, OpLevel_0, prim::kPrimNotEqual},
  {kGPUDevice, OpLevel_1, prim::kPrimReduceMax},
  {kGPUDevice, OpLevel_1, prim::kPrimReduceMin},
  {kGPUDevice, OpLevel_0, prim::kPrimSelect},
  {kGPUDevice, OpLevel_0, prim::kPrimSign},
  {kGPUDevice, OpLevel_0, prim::kPrimSin},
  {kGPUDevice, OpLevel_0, prim::kPrimStridedSlice},
  {kGPUDevice, OpLevel_1, prim::kPrimCumSum},
  {kGPUDevice, OpLevel_1, prim::kPrimOneHot},
  // cpu
  {kCPUDevice, OpLevel_0, prim::kPrimLogicalNot},
  {kCPUDevice, OpLevel_0, prim::kPrimMod},
  {kCPUDevice, OpLevel_1, prim::kPrimReduceMax},
  {kCPUDevice, OpLevel_0, prim::kPrimSelect},
  {kCPUDevice, OpLevel_0, prim::kPrimLess},
  {kCPUDevice, OpLevel_0, prim::kPrimLessEqual},
};

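// Extra ops that become clusterable when the kernel generator is AKG_V2.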
const std::vector<OpWithLevel> clusterable_ops_with_level_v2 = {
  // cpu
  {kCPUDevice, OpLevel_0, prim::kPrimNotEqual},
  {kCPUDevice, OpLevel_0, prim::kPrimGreaterEqual},
  {kCPUDevice, OpLevel_0, prim::kPrimGreater},
  {kCPUDevice, OpLevel_0, prim::kPrimFloor},
  {kCPUDevice, OpLevel_0, prim::kPrimIsNan},
  {kCPUDevice, OpLevel_0, prim::kPrimAssign},
  {kCPUDevice, OpLevel_0, prim::kPrimBroadcastTo},
  {kCPUDevice, OpLevel_0, prim::kPrimTile},
  {kCPUDevice, OpLevel_0, prim::kPrimLogicalAnd},
  {kCPUDevice, OpLevel_0, prim::kPrimCos},
  {kCPUDevice, OpLevel_0, prim::kPrimSin},
  {kCPUDevice, OpLevel_0, prim::kPrimACos},
  {kCPUDevice, OpLevel_0, prim::kPrimAsin},
  {kCPUDevice, OpLevel_0, prim::kPrimTanh},
  {kCPUDevice, OpLevel_0, prim::kPrimAtan2},
  {kCPUDevice, OpLevel_0, prim::kPrimMinimum},
  {kCPUDevice, OpLevel_0, prim::kPrimMaximum},
  {kCPUDevice, OpLevel_0, prim::kPrimReduceAll},
  {kCPUDevice, OpLevel_0, prim::kPrimStridedSlice},
  // gpu
  {kGPUDevice, OpLevel_0, prim::kPrimNotEqual},
  {kGPUDevice, OpLevel_0, prim::kPrimSelect},
  {kGPUDevice, OpLevel_0, prim::kPrimTile},
  {kGPUDevice, OpLevel_0, prim::kPrimLogicalAnd},
  {kGPUDevice, OpLevel_0, prim::kPrimCos},
  {kGPUDevice, OpLevel_0, prim::kPrimSin},
  {kGPUDevice, OpLevel_0, prim::kPrimMinimum},
  {kGPUDevice, OpLevel_0, prim::kPrimMaximum},
  {kGPUDevice, OpLevel_0, prim::kPrimAssign},
};

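// Ops disabled for AKG_V2 on GPU unless they are explicitly enabled by flags.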
const std::vector<std::string> disable_cluster_op_list_v2 = {"OneHot", "CumSum",      "Transpose",   "BatchMatMul",
                                                             "MatMul", "BroadcastTo", "StridedSlice"};

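// Clusterable ops for the DVM kernel generator; all of them target Ascend.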
const std::vector<OpWithLevel> clusterable_ops_with_level_dvm = {
  {kAscendDevice, OpLevel_0, prim::kPrimAbs},          {kAscendDevice, OpLevel_0, prim::kPrimAdd},
  {kAscendDevice, OpLevel_0, prim::kPrimBroadcastTo},  {kAscendDevice, OpLevel_0, prim::kPrimCast},
  {kAscendDevice, OpLevel_0, prim::kPrimExp},          {kAscendDevice, OpLevel_0, prim::kPrimLog},
  {kAscendDevice, OpLevel_0, prim::kPrimMaximum},      {kAscendDevice, OpLevel_0, prim::kPrimMinimum},
  {kAscendDevice, OpLevel_0, prim::kPrimMul},          {kAscendDevice, OpLevel_0, prim::kPrimNeg},
  {kAscendDevice, OpLevel_0, prim::kPrimPow},          {kAscendDevice, OpLevel_0, prim::kPrimDiv},
  {kAscendDevice, OpLevel_0, prim::kPrimRealDiv},      {kAscendDevice, OpLevel_0, prim::kPrimReciprocal},
  {kAscendDevice, OpLevel_0, prim::kPrimRsqrt},        {kAscendDevice, OpLevel_0, prim::kPrimSqrt},
  {kAscendDevice, OpLevel_0, prim::kPrimSub},          {kAscendDevice, OpLevel_0, prim::kPrimEqual},
  {kAscendDevice, OpLevel_0, prim::kPrimNotEqual},     {kAscendDevice, OpLevel_0, prim::kPrimGreater},
  {kAscendDevice, OpLevel_0, prim::kPrimGreaterEqual}, {kAscendDevice, OpLevel_0, prim::kPrimLess},
  {kAscendDevice, OpLevel_0, prim::kPrimLessEqual},    {kAscendDevice, OpLevel_0, prim::kPrimLogicalAnd},
  {kAscendDevice, OpLevel_0, prim::kPrimLogicalOr},    {kAscendDevice, OpLevel_0, prim::kPrimLogicalNot},
  {kAscendDevice, OpLevel_0, prim::kPrimSelect},       {kAscendDevice, OpLevel_0, prim::kPrimAssign},
  {kAscendDevice, OpLevel_0, prim::kPrimReduceSum},    {kAscendDevice, OpLevel_0, prim::kPrimIsFinite},
  {kAscendDevice, OpLevel_1, prim::kPrimReshape},      {kAscendDevice, OpLevel_0, prim::kPrimTranspose},
};
}  // namespace

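// Build the final clusterable op list for the current kernel generator, applying the
// enable/disable flags from GraphKernelFlags.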
std::vector<PrimitivePtr> StaticShapeCluster::GetClusterOps() {
  const auto &flags = GraphKernelFlags::GetInstance();
  std::vector<std::string> disable_cluster_ops = flags.disable_cluster_ops;
  auto cb = Callback::Instance();
  MS_EXCEPTION_IF_NULL(cb);

  std::vector<OpWithLevel> clusterable_ops;
  if (flags.kernel_generator == "AKG_V2") {
    clusterable_ops = clusterable_ops_with_level;
    clusterable_ops.insert(clusterable_ops.end(), clusterable_ops_with_level_v2.begin(),
                           clusterable_ops_with_level_v2.end());
    if (cb->GetTargetFromContext() == kCPUDevice &&
        std::find(flags.enable_cluster_ops.begin(), flags.enable_cluster_ops.end(), "Reshape") ==
          flags.enable_cluster_ops.end()) {
      disable_cluster_ops.push_back("Reshape");
    }
    if (cb->GetTargetFromContext() == kGPUDevice) {
      for (const std::string &item : disable_cluster_op_list_v2) {
        if (std::find(flags.enable_cluster_ops.begin(), flags.enable_cluster_ops.end(), item) ==
            flags.enable_cluster_ops.end()) {
          disable_cluster_ops.push_back(item);
        }
      }
    }
  } else if (flags.kernel_generator == "DVM") {
    clusterable_ops = clusterable_ops_with_level_dvm;
  } else {
    clusterable_ops = clusterable_ops_with_level;
  }
  auto ops = GkUtils::GetValidOps(clusterable_ops, flags.fusion_ops_level, flags.enable_cluster_ops_only,
                                  flags.enable_cluster_ops, disable_cluster_ops);
  return GkUtils::FilterExcludedOps(ops);
}

std::vector<PrimitivePtr> StaticShapeCluster::GetClusterableOpList() { return StaticShapeCluster::GetClusterOps(); }

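// When kernel packet is enabled for DVM, a node whose inputs come from backoff (host)
// kernels is skipped here so that it can be fused by the kernelpacket pass instead.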
bool SkipHostInputNode(const AnfNodePtr &node, bool is_dvm) {
  if (is_dvm && GraphKernelFlags::GetInstance().IsEnableKernelPacket()) {
    auto cnode = node->cast<CNodePtr>();
    return cnode != nullptr &&
           std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), AnfAlgo::IsKernelSelectBackoffOp);
  }
  return false;
}

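// Check whether a node can join a static shape cluster: it must be in the op list,
// must not involve complex types, and must pass the generator specific checks above.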
bool StaticShapeCluster::IsClusterableOp(const AnfNodePtr &node) {
  if (AnfUtils::IsGraphKernel(node)) {
    auto sub_graph = GetCNodeFuncGraph(node);
    if (auto type = sub_graph->get_attr("composite_type")) {
      if (GetValue<std::string>(type) == "inplace_assign_builder") {
        return false;
      }
    }
    return true;
  }
  if (GkUtils::IsKeepBasicNode(node)) {
    return false;
  }
  bool is_dvm = (GraphKernelFlags::GetInstance().kernel_generator == "DVM");
  if (!is_dvm && common::AnfAlgo::IsDynamicShape(node)) {
    return false;
  }
  bool node_in_oplist = std::any_of(op_list_.begin(), op_list_.end(),
                                    [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  if (!node_in_oplist) {
    return false;
  }

  auto cb = Callback::Instance();
  MS_EXCEPTION_IF_NULL(cb);
  // nodes whose output type is complex64 or complex128 cannot be added to the cluster list
  auto node_output_type = cb->GetOutputType(node, 0);
  if (node_output_type == kNumberTypeComplex64 || node_output_type == kNumberTypeComplex128) {
    return false;
  }
  if (IsPrimitiveCNode(node, prim::kPrimCast)) {
    auto node_input_type = cb->GetInputType(node, 0);
    if ((node_input_type == kNumberTypeComplex64) || (node_input_type == kNumberTypeComplex128)) {
      return false;
    }
  }

  if (is_dvm && !DvmSupported(node)) {
    return false;
  }

  if (IsPrimitiveCNode(node, prim::kPrimReshape)) {
    auto output_format = cb->GetOutputFormat(node, 0);
    if (output_format != kOpFormat_DEFAULT) {
      auto primitive = GetCNodePrimitive(node);
      MS_EXCEPTION_IF_NULL(primitive);
      primitive = primitive->Clone();
      // the "format" attr is used by ReshapeOp::InferFormat
      primitive->AddAttr("format", MakeValue(output_format));
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      cnode->set_input(kAnfPrimitiveIndex, NewValueNode(primitive));
    }
  }
  if (!ValueDependOpUtils::IsConstInput(node)) {
    return false;
  }
  if (SkipHostInputNode(node, is_dvm)) {
    // this node can be fused with input host ops by kernelpacket
    return false;
  }

  return true;
}

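// Ops supported by dynamic shape clustering.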
std::vector<PrimitivePtr> DynamicShapeCluster::GetClusterableOpList() {
  std::vector<PrimitivePtr> dyn_clusterable_ops_list = {
    prim::kPrimAdd, prim::kPrimCast, prim::kPrimMul,  prim::kPrimRealDiv,   prim::kPrimSub,
    prim::kPrimAbs, prim::kPrimExp,  prim::kPrimLog,  prim::kPrimMaximum,   prim::kPrimMinimum,
    prim::kPrimNeg, prim::kPrimPow,  prim::kPrimSqrt, prim::kPrimTranspose, prim::kPrimReduceSum};
  return dyn_clusterable_ops_list;
}

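// A node is clusterable in dynamic shape mode if it is in the op list, is not dynamic
// rank, is not a keep-basic node, and all its value-depend inputs are constant.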
bool DynamicShapeCluster::IsClusterableOp(const AnfNodePtr &node) {
  bool node_in_oplist = std::any_of(op_list_.begin(), op_list_.end(),
                                    [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  if (!node_in_oplist || common::AnfAlgo::IsDynamicRankNode(node)) {
    return false;
  }
  if (GkUtils::IsKeepBasicNode(node)) {
    return false;
  }
  if (!ValueDependOpUtils::IsConstInput(node)) {
    return false;
  }
  return true;
}

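// Run dynamic shape clustering on the whole graph; refresh the manager roots when the
// graph has changed.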
bool DynamicShapeCluster::Run(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  Init(func_graph);
  bool changed = Process(func_graph);
  if (changed) {
    mng->RemoveRoots();
    mng->KeepRoots({func_graph});
  }
  Clean();
  return changed;
}
}  // namespace mindspore::graphkernel