• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
17 
18 #include <algorithm>
19 #include <map>
20 #include <memory>
21 #include <sstream>
22 #include <unordered_map>
23 #include <utility>
24 
25 #include "backend/common/graph_kernel/graph_kernel_flags.h"
26 #include "backend/common/graph_kernel/model/graph_builder.h"
27 #include "backend/common/graph_kernel/model/node.h"
28 #include "backend/common/graph_kernel/model/op_node.h"
29 #include "mindspore/core/ops/conv_pool_ops.h"
30 #include "mindspore/core/ops/math_ops.h"
31 #include "mindspore/core/ops/sequence_ops.h"
32 #include "runtime/hardware/device_context_manager.h"
33 #include "utils/anf_utils.h"
34 #include "utils/ms_context.h"
35 
36 namespace mindspore::graphkernel {
37 namespace {
GetOutputSymbolicShape(const AnfNodePtr & node,size_t i)38 ListSymbolPtr GetOutputSymbolicShape(const AnfNodePtr &node, size_t i) {
39   if (node == nullptr) {
40     return nullptr;
41   }
42   auto abstract = node->abstract();
43   if (abstract == nullptr) {
44     return nullptr;
45   }
46   auto symbol_shape = abstract->GetSymbolicShape();
47   if (symbol_shape == nullptr) {
48     return nullptr;
49   }
50   if (abstract->isa<abstract::AbstractSequence>()) {
51     // multiple outputs
52     if (i >= symbol_shape->size()) {
53       MS_LOG(WARNING) << "Output idx '" << i << "' is out of range [0, " << symbol_shape->size()
54                       << ") for node: " << node->ToString();
55       return nullptr;
56     }
57     auto shape_i = symbol_shape->symbols()[i];
58     if (shape_i == nullptr) {
59       return nullptr;
60     }
61     return shape_i->as_sptr_noexcept<ListSymbol>();
62   }
63   // single output
64   return symbol_shape;
65 }
66 }  // namespace
67 
ExtractGraphKernelName(const AnfNodePtrList & nodes,const std::string & prefix,const std::string & postfix)68 std::string GkUtils::ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix,
69                                             const std::string &postfix) {
70   std::stringstream name;
71   if (!prefix.empty()) {
72     name << prefix << "_";
73   }
74   for (const auto &node : nodes) {
75     if (AnfUtils::IsGraphKernel(node)) {
76       auto fg_flag_val = GetCNodeFuncGraph(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
77       name << GetValue<std::string>(fg_flag_val) << "_";
78     } else if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
79       name << GetCNodePrimitive(node)->name() << "_";
80     }
81   }
82   if (!postfix.empty()) {
83     name << postfix;
84   }
85   return name.str();
86 }
87 
// Flattens nodes[begin_index..] into a single list.
// Every MakeTuple CNode is replaced (recursively) by its element inputs, starting at input 1 because input 0
// is the primitive. All other nodes are kept as-is, in order.
AnfNodePtrList GkUtils::SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
  AnfNodePtrList flattened;
  const size_t total = nodes.size();
  for (size_t idx = begin_index; idx < total; ++idx) {
    const auto &cur = nodes[idx];
    if (!IsPrimitiveCNode(cur, prim::kPrimMakeTuple)) {
      flattened.push_back(cur);
      continue;
    }
    // Nested tuples are handled by the recursive call.
    const auto inner = SpreadTuples(cur->cast<CNodePtr>()->inputs(), 1);
    (void)flattened.insert(flattened.end(), inner.begin(), inner.end());
  }
  return flattened;
}
102 
GetValidOps(const std::vector<OpWithLevel> & ops_with_level,unsigned int level,const std::vector<std::string> & enable_ops_only,const std::vector<std::string> & enable_ops,const std::vector<std::string> & disable_ops)103 std::vector<PrimitivePtr> GkUtils::GetValidOps(const std::vector<OpWithLevel> &ops_with_level, unsigned int level,
104                                                const std::vector<std::string> &enable_ops_only,
105                                                const std::vector<std::string> &enable_ops,
106                                                const std::vector<std::string> &disable_ops) {
107   std::vector<PrimitivePtr> ops;
108   auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
109   if (!enable_ops_only.empty()) {
110     (void)std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(ops), new_prim);
111     return ops;
112   }
113   auto target = Callback::Instance()->GetTargetFromContext();
114   for (const auto &[op_target, op_level, op] : ops_with_level) {
115     if (op_target == kAllTarget || op_target == target) {
116       if (level >= op_level) {
117         (void)ops.emplace_back(op);
118       }
119     }
120   }
121   if (!enable_ops.empty()) {
122     (void)std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(ops), new_prim);
123   }
124   if (!disable_ops.empty()) {
125     auto iter = std::remove_if(ops.begin(), ops.end(), [&disable_ops](const PrimitivePtr &p) {
126       return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
127     });
128     (void)ops.erase(iter, ops.cend());
129   }
130   return ops;
131 }
132 
FilterExcludedOps(const std::vector<PrimitivePtr> & ops)133 std::vector<PrimitivePtr> GkUtils::FilterExcludedOps(const std::vector<PrimitivePtr> &ops) {
134 #ifndef MSLITE_ENABLE_GRAPH_KERNEL
135   if (Callback::Instance()->GetTargetFromContext() != kGPUDevice) {
136     return ops;
137   }
138   std::vector<PrimitivePtr> dst_ops;
139   const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
140     {kGPUDevice, MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
141   MS_EXCEPTION_IF_NULL(device_context);
142   auto deprecated_ptr = device_context->GetDeprecatedInterface();
143   MS_EXCEPTION_IF_NULL(deprecated_ptr);
144   auto major_compute_capability = deprecated_ptr->GetGPUCapabilityMajor();
145   std::unordered_map<std::string, int> limited_capacity_ops = {
146     {prim::kPrimConv2D->name(), 7}, {prim::kPrimMatMul->name(), 7}, {prim::kPrimBatchMatMul->name(), 7}};
147   std::vector<std::string> final_filter_ops;
148   for (auto op : ops) {
149     if (limited_capacity_ops.find(op->name()) != limited_capacity_ops.end() &&
150         limited_capacity_ops[op->name()] != major_compute_capability) {
151       (void)final_filter_ops.emplace_back(op->name());
152     } else {
153       (void)dst_ops.emplace_back(op);
154     }
155   }
156   // Give hint for excluded src_ops.
157   static bool give_hint = false;
158   if (!give_hint && final_filter_ops.size() > 0) {
159     give_hint = true;
160     for (size_t i = 0; i < final_filter_ops.size(); ++i) {
161       MS_LOG(INFO) << "For op : " << final_filter_ops[i]
162                    << " can not be enabled in GraphKernel because the current device's computing capacity is "
163                    << major_compute_capability << ", which is != " << limited_capacity_ops[final_filter_ops[i]];
164     }
165   }
166   return dst_ops;
167 #else
168   return ops;
169 #endif
170 }
171 
IsKeepBasicNode(const AnfNodePtr & node)172 bool GkUtils::IsKeepBasicNode(const AnfNodePtr &node) {
173   MS_EXCEPTION_IF_NULL(node);
174   auto prim = GetCNodePrimitive(node);
175   auto target = Callback::Instance()->GetTargetFromContext();
176   if (prim == nullptr) {
177     return false;
178   }
179   // Heterogeneous computing is not support yet
180   // so if node's primitive_target is inconsistent with target from context
181   // the node cannot be added to the cluster list.
182   if (prim->HasAttr("primitive_target") && GetValue<std::string>(prim->GetAttr("primitive_target")) != target) {
183     return true;
184   }
185 
186   // the "skip" is used by inplace node.
187   // the kAttrIsInternalOutputNopNode is used by internal output of KernelGraph.
188   const std::vector<std::string> exclude_bool_attrs = {"skip", kAttrIsInternalOutputNopNode};
189   if (std::any_of(exclude_bool_attrs.cbegin(), exclude_bool_attrs.cend(), [&prim](const std::string &attr_name) {
190         return prim->HasAttr(attr_name) && GetValue<bool>(prim->GetAttr(attr_name));
191       })) {
192     return true;
193   }
194 
195   // If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is.
196   const std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
197                                                      "aggregate", "aggregate_input_index"};
198   if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
199                   [&prim](const std::string &attr_name) -> bool { return prim->HasAttr(attr_name); })) {
200     return true;
201   }
202   auto cnode = node->cast<CNodePtr>();
203   return (cnode != nullptr && cnode->HasAttr("keep_basic"));
204 }
205 
NewRealCNode(const std::vector<AnfNodePtr> & inputs,const FuncGraphPtr & func_graph,const std::vector<inner::NodeBase> & out_info_list,const CallbackPtr & cb)206 CNodePtr GkUtils::NewRealCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph,
207                                const std::vector<inner::NodeBase> &out_info_list, const CallbackPtr &cb) {
208   auto cnode = func_graph->NewCNode(inputs);
209   MS_EXCEPTION_IF_NULL(cnode);
210 
211   if (out_info_list.size() == 0) {
212     MS_LOG(EXCEPTION) << "CNode must have output!";
213   }
214 
215   // Setup abstract.
216   AbstractBasePtrList abs_list;
217   (void)std::transform(
218     out_info_list.begin(), out_info_list.end(), std::back_inserter(abs_list), [](const inner::NodeBase &out_info) {
219       auto abs_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(out_info.type), out_info.shape);
220       return abs_tensor;
221     });
222   if (abs_list.size() == 1) {
223     cnode->set_abstract(abs_list[0]);
224   } else {
225     cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
226   }
227 
228   // Setup kernel build info.
229   cb->SetBasicNodeKernelInfo(cnode, out_info_list);
230   func_graph->AddNode(cnode);
231   return cnode;
232 }
233 
LiteGraph2AnfGraph(const inner::LiteGraphPtr & lite_graph,const CallbackPtr & cb)234 FuncGraphPtr GkUtils::LiteGraph2AnfGraph(const inner::LiteGraphPtr &lite_graph, const CallbackPtr &cb) {
235   auto func_graph = std::make_shared<FuncGraph>();
236   std::map<inner::NodePtr, AnfNodePtr> node_map;
237   for (const auto &inp : lite_graph->inputs()) {
238     auto param = func_graph->add_parameter();
239     node_map[inp] = param;
240     param->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(inp->type), inp->shape));
241     cb->SetBasicNodeKernelInfo(param, {{inp->shape, inp->type, inp->format}});
242   }
243   // Create CNodes.
244   for (const auto &op_node : lite_graph->GetOrderedNodes()) {
245     if (op_node->NodeType() != inner::NType::Primitive) {
246       MS_LOG(EXCEPTION) << "Node " << op_node->debug_name() << " should be a Primitive node";
247     }
248     auto op = std::static_pointer_cast<inner::PrimOp>(op_node);
249     auto primitive = std::make_shared<Primitive>(op->op(), op->attrs());
250     auto prim = GetOpsPrim(primitive->name());
251     if (prim != nullptr) {
252       (void)primitive->AddAttr(kAttrInputNames, prim->GetAttr(kAttrInputNames));
253       (void)primitive->AddAttr(kAttrOutputNames, prim->GetAttr(kAttrOutputNames));
254     }
255     AnfNodePtrList inputs = {NewValueNode(primitive)};
256     (void)std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(inputs),
257                          [&node_map, &cb](const inner::NodePtr &inp) -> AnfNodePtr {
258                            const auto iter = node_map.find(inp);
259                            if (iter != node_map.end()) {
260                              return iter->second;
261                            } else {
262                              auto node_type = inp->NodeType();
263                              if (node_type != inner::NType::Tensor && node_type != inner::NType::Scalar &&
264                                  node_type != inner::NType::Tuple) {
265                                MS_LOG(EXCEPTION)
266                                  << "Node " << inp->debug_name() << " should be a Tensor or Scalar node";
267                              }
268                              ValuePtr inp_value = nullptr;
269                              if (node_type == inner::NType::Tensor) {
270                                inp_value = inp->As<inner::ConstTensorNode>()->data();
271                              } else if (node_type == inner::NType::Scalar) {
272                                inp_value = inp->As<inner::ConstScalarNode>()->data();
273                              } else {
274                                inp_value = inp->As<inner::ConstTupleNode>()->data();
275                              }
276                              auto value_node = NewValueNode(inp_value);
277                              value_node->set_abstract(inp_value->ToAbstract());
278                              cb->SetBasicNodeKernelInfo(value_node, {{inp->shape, inp->type, inp->format}});
279                              return value_node;
280                            }
281                          });
282     auto output_info_list = op->outputs();
283     if (output_info_list.empty()) {
284       (void)output_info_list.emplace_back(static_cast<inner::NodeBase>(*op));
285     }
286     auto cnode = NewRealCNode(inputs, func_graph, output_info_list, cb);
287     MS_EXCEPTION_IF_NULL(cnode);
288     node_map[op_node] = cnode;
289   }
290   if (lite_graph->GetOutputs().empty()) {
291     MS_LOG(EXCEPTION) << "The output of LiteGraph " << lite_graph->name() << " is empty.";
292   } else if (lite_graph->GetOutputs().size() == 1) {
293     func_graph->set_output(node_map[lite_graph->GetOutputs()[0]]);
294   } else {
295     AnfNodePtrList mt_inputs;
296     AbstractBasePtrList out_abs_list;
297     (void)std::transform(lite_graph->GetOutputs().begin(), lite_graph->GetOutputs().end(),
298                          std::back_inserter(mt_inputs), [&node_map, &out_abs_list](const inner::NodePtr &out) {
299                            auto out_node = node_map[out];
300                            MS_EXCEPTION_IF_NULL(out_node);
301                            (void)out_abs_list.emplace_back(out_node->abstract());
302                            return out_node;
303                          });
304     auto mt = func_graph->NewCNode(prim::kPrimMakeTuple, mt_inputs);
305     mt->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
306     cb->SetEmptyKernelInfo(mt);
307     func_graph->AddNode(mt);
308     func_graph->set_output(mt);
309   }
310   return func_graph;
311 }
312 
InputValue2Tensor(ValuePtr input_value)313 tensor::TensorPtr InputValue2Tensor(ValuePtr input_value) {
314   // input value of a cnode can be one of tensor, valuesequence and int,
315   // in order to emit litegraph node by gb.Value, convert the type of value to tensor anyway
316   tensor::TensorPtr input_tensor = nullptr;
317   if (input_value->isa<Int32Imm>() || input_value->isa<Int64Imm>()) {
318     auto input_num = AnfUtils::GetIntValue(input_value);
319     input_tensor = std::make_shared<tensor::Tensor>(input_num);
320   } else if (input_value->isa<ValueSequence>()) {
321     auto input_seq = input_value->cast<ValueSequencePtr>()->value();
322     std::vector<int64_t> input_vec;
323     (void)std::transform(input_seq.begin(), input_seq.end(), std::back_inserter(input_vec),
324                          [](auto v) { return AnfUtils::GetIntValue(v); });
325     input_tensor = std::make_shared<tensor::Tensor>(input_vec);
326   } else if (input_value->isa<tensor::Tensor>()) {
327     input_tensor = input_value->cast<tensor::TensorPtr>();
328   } else if (input_value->isa<BoolImm>()) {
329     auto input_bool = GetValue<bool>(input_value);
330     input_tensor = std::make_shared<tensor::Tensor>(input_bool);
331   } else {
332     MS_LOG(EXCEPTION) << "Unsupported Type in InputValue2Tensor";
333   }
334   return input_tensor;
335 }
336 
AnfGraph2LiteGraph(const FuncGraphPtr & func_graph,HashMap<inner::NodePtr,AnfNodePtr> * op_node_map)337 inner::LiteGraphPtr GkUtils::AnfGraph2LiteGraph(const FuncGraphPtr &func_graph,
338                                                 HashMap<inner::NodePtr, AnfNodePtr> *op_node_map) {
339   std::string name = "Default";
340   if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
341     name = GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
342   }
343   inner::GraphBuilder gb(name);
344   std::map<AnfNodePtr, inner::NodePtr> node_map;
345   auto todos = TopoSort(func_graph->output(), SuccIncoming,
346                         [](const AnfNodePtr &node) { return node->isa<CNode>() ? FOLLOW : EXCLUDE; });
347   const auto &params = func_graph->parameters();
348   auto cb = Callback::Instance();
349   auto ExtractBuildInfo = [&cb](const AnfNodePtr &node) -> inner::NodeBaseList {
350     inner::NodeBaseList listinfo;
351     size_t output_num = AnfUtils::GetOutputTensorNum(node);
352     for (size_t i = 0; i < output_num; ++i) {
353       auto shape = cb->GetOutputShape(node, i);
354       auto type = cb->GetOutputType(node, i);
355       auto format = cb->GetOutputFormat(node, i);
356       auto symbol_shape = GetOutputSymbolicShape(node, i);
357       listinfo.push_back(inner::NodeBase({shape, type, format, symbol_shape}));
358     }
359     return listinfo;
360   };
361   // set inputs
362   for (auto &p : params) {
363     node_map[p] = gb.Parameter(ExtractBuildInfo(p)[0]);
364   }
365   // set ops
366   for (auto node : todos) {
367     auto cnode = node->cast<CNodePtr>();
368     MS_EXCEPTION_IF_NULL(cnode);
369     if (node == func_graph->output() && IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
370       break;
371     }
372     auto prim = GetCNodePrimitive(cnode);
373     MS_EXCEPTION_IF_NULL(prim);
374     inner::NodePtrList inputs;
375     for (size_t i = 1; i < cnode->size(); ++i) {
376       auto input_i = cnode->input(i);
377       const auto iter = node_map.find(input_i);
378       if (iter != node_map.end()) {
379         // input is parameter or cnode
380         inputs.push_back(iter->second);
381         continue;
382       }
383       // input is valuenode
384       auto input_value_node = input_i->cast<ValueNodePtr>();
385       auto input_value = input_value_node->value();
386       constexpr size_t idx = 2;
387       inner::NodePtr input_node;
388       if ((IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) && i == idx) {
389         input_node = std::make_shared<inner::ConstScalarNode>(input_value);
390       } else {
391         auto tensor = InputValue2Tensor(input_value);
392         MS_EXCEPTION_IF_NULL(tensor);
393         input_node = gb.Value(tensor);
394       }
395       inputs.push_back(input_node);
396     }
397     auto op = gb.Op(AnfUtils::GetCNodeName(node), ExtractBuildInfo(node), inputs, prim->attrs());
398     node_map[node] = op;
399     if (op_node_map != nullptr) {
400       (*op_node_map)[op] = node;
401     }
402   }
403   // set outputs
404   auto output_node = func_graph->output();
405   if (!IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
406     gb.SetOutputs({node_map[output_node]});
407     return gb.Get();
408   }
409   inner::NodePtrList outputs;
410   auto mt = output_node->cast<CNodePtr>();
411   (void)std::transform(mt->inputs().begin() + 1, mt->inputs().end(), std::back_inserter(outputs),
412                        [&node_map](const AnfNodePtr &no) { return node_map[no]; });
413   gb.SetOutputs(std::move(outputs));
414   return gb.Get();
415 }
416 
GetFuncGraphManager(const FuncGraphPtr & func_graph)417 FuncGraphManagerPtr GkUtils::GetFuncGraphManager(const FuncGraphPtr &func_graph) {
418   MS_EXCEPTION_IF_NULL(func_graph);
419   FuncGraphManagerPtr manager = func_graph->manager();
420   if (manager == nullptr) {
421     manager = Manage(func_graph, true);
422     func_graph->set_manager(manager);
423   }
424   return manager;
425 }
426 
UpdateFuncGraphManager(const FuncGraphManagerPtr & mng,const FuncGraphPtr & func_graph)427 void GkUtils::UpdateFuncGraphManager(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph) {
428   mng->RemoveRoots();
429   mng->KeepRoots({func_graph});
430 }
431 
GetOpsPrim(const std::string & name)432 PrimitivePtr GkUtils::GetOpsPrim(const std::string &name) {
433   const auto &op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
434   auto const iter = op_primc_fns.find(name);
435   if (iter == op_primc_fns.end()) {
436     return nullptr;
437   }
438   return iter->second();
439 }
440 
GetValidKernelNodes(const FuncGraphPtr & func_graph,AnfNodePtrList * node_list,AnfNodePtrList * input_list,AnfNodePtrList * output_list)441 void GkUtils::GetValidKernelNodes(const FuncGraphPtr &func_graph, AnfNodePtrList *node_list, AnfNodePtrList *input_list,
442                                   AnfNodePtrList *output_list) {
443   MS_EXCEPTION_IF_NULL(func_graph);
444   MS_EXCEPTION_IF_NULL(node_list);
445   AnfNodePtrList todos = TopoSort(func_graph->output());
446   (void)std::copy_if(todos.cbegin(), todos.cend(), std::back_inserter(*node_list), AnfUtils::IsRealCNodeKernel);
447 
448   if (input_list != nullptr) {
449     const auto &parameters = func_graph->parameters();
450     (void)input_list->insert(input_list->cend(), parameters.cbegin(), parameters.cend());
451   }
452   if (output_list != nullptr) {
453     if (IsPrimitiveCNode(todos.back(), prim::kPrimMakeTuple)) {
454       auto fg_output = todos.back()->cast<CNodePtr>();
455       MS_EXCEPTION_IF_NULL(fg_output);
456       auto output_inputs = fg_output->inputs();
457       (void)output_list->insert(output_list->cend(), output_inputs.cbegin() + 1, output_inputs.cend());
458     } else {
459       (void)output_list->emplace_back(func_graph->output());
460     }
461   }
462 }
463 
GetChannelInConvFormat(const std::string & format_string)464 int64_t GkUtils::GetChannelInConvFormat(const std::string &format_string) {
465   constexpr size_t nchwc_len = 5;
466   if (format_string.size() <= nchwc_len || format_string.find("NCHW") != 0) {
467     MS_LOG(EXCEPTION) << "Format must be NCHWnc, but got [" << format_string << "]";
468   }
469   constexpr size_t n_pos = 4;
470   auto channel = format_string.substr(n_pos, format_string.size() - nchwc_len);
471   return std::stol(channel);
472 }
473 
GetGraphKernelNodes(const FuncGraphPtr & func_graph)474 AnfNodePtrList GkUtils::GetGraphKernelNodes(const FuncGraphPtr &func_graph) {
475   AnfNodePtrList todos = TopoSort(func_graph->output());
476   AnfNodePtrList node_list;
477   (void)std::copy_if(todos.cbegin(), todos.cend(), std::back_inserter(node_list), AnfUtils::IsGraphKernel);
478   return node_list;
479 }
480 
UseAkgCceLib(const AnfNodePtr & node)481 bool GkUtils::UseAkgCceLib(const AnfNodePtr &node) {
482   if (node->isa<CNode>()) {
483     auto cnode = dyn_cast_ptr<CNode>(node);
484     if (cnode == nullptr) {
485       return false;
486     }
487     return cnode->HasAttr("use_akg_cce");
488   }
489   return false;
490 }
491 }  // namespace mindspore::graphkernel
492