/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"

#include <algorithm>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/graph_kernel/model/graph_builder.h"
#include "backend/common/graph_kernel/model/node.h"
#include "backend/common/graph_kernel/model/op_node.h"
#include "mindspore/core/ops/conv_pool_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "runtime/hardware/device_context_manager.h"
#include "utils/anf_utils.h"
#include "utils/ms_context.h"

namespace mindspore::graphkernel {
namespace {
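// Get the symbolic shape of the i-th output of the node.
// Returns nullptr if the node has no abstract or no symbolic shape, or if i is out of range.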
ListSymbolPtr GetOutputSymbolicShape(const AnfNodePtr &node, size_t i) {
  if (node == nullptr) {
    return nullptr;
  }
  auto abstract = node->abstract();
  if (abstract == nullptr) {
    return nullptr;
  }
  auto symbol_shape = abstract->GetSymbolicShape();
  if (symbol_shape == nullptr) {
    return nullptr;
  }
  if (abstract->isa<abstract::AbstractSequence>()) {
    // multiple outputs
    if (i >= symbol_shape->size()) {
      MS_LOG(WARNING) << "Output idx '" << i << "' is out of range [0, " << symbol_shape->size()
                      << ") for node: " << node->ToString();
      return nullptr;
    }
    auto shape_i = symbol_shape->symbols()[i];
    if (shape_i == nullptr) {
      return nullptr;
    }
    return shape_i->as_sptr_noexcept<ListSymbol>();
  }
  // single output
  return symbol_shape;
}
}  // namespace

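// Concatenate the op/kernel names of the given nodes into a graph kernel name,
// e.g. nodes {Add, Mul} with prefix "Fused" and postfix "fusion" give "Fused_Add_Mul_fusion".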
std::string GkUtils::ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix,
                                            const std::string &postfix) {
  std::stringstream name;
  if (!prefix.empty()) {
    name << prefix << "_";
  }
  for (const auto &node : nodes) {
    if (AnfUtils::IsGraphKernel(node)) {
      auto fg_flag_val = GetCNodeFuncGraph(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
      name << GetValue<std::string>(fg_flag_val) << "_";
    } else if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
      name << GetCNodePrimitive(node)->name() << "_";
    }
  }
  if (!postfix.empty()) {
    name << postfix;
  }
  return name.str();
}

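// Flatten MakeTuple nodes in the list into their elements, starting from begin_index,
// e.g. {a, MakeTuple(b, MakeTuple(c, d))} is spread to {a, b, c, d}.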
AnfNodePtrList GkUtils::SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
  AnfNodePtrList result;
  for (size_t i = begin_index; i < nodes.size(); i++) {
    if (IsPrimitiveCNode(nodes[i], prim::kPrimMakeTuple)) {
      auto mt = nodes[i]->cast<CNodePtr>();
      // recursively spread all inner tuples.
      auto mt_inputs = SpreadTuples(mt->inputs(), 1);
      (void)result.insert(result.cend(), mt_inputs.cbegin(), mt_inputs.cend());
    } else {
      result.push_back(nodes[i]);
    }
  }
  return result;
}

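// Select the valid ops from ops_with_level according to the current target and fusion level,
// then apply the flag lists: if enable_ops_only is set it overrides everything; otherwise ops
// whose level is not greater than `level` on a matching target are kept, enable_ops are appended
// and disable_ops are removed.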
std::vector<PrimitivePtr> GkUtils::GetValidOps(const std::vector<OpWithLevel> &ops_with_level, unsigned int level,
                                               const std::vector<std::string> &enable_ops_only,
                                               const std::vector<std::string> &enable_ops,
                                               const std::vector<std::string> &disable_ops) {
  std::vector<PrimitivePtr> ops;
  auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
  if (!enable_ops_only.empty()) {
    (void)std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(ops), new_prim);
    return ops;
  }
  auto target = Callback::Instance()->GetTargetFromContext();
  for (const auto &[op_target, op_level, op] : ops_with_level) {
    if (op_target == kAllTarget || op_target == target) {
      if (level >= op_level) {
        (void)ops.emplace_back(op);
      }
    }
  }
  if (!enable_ops.empty()) {
    (void)std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(ops), new_prim);
  }
  if (!disable_ops.empty()) {
    auto iter = std::remove_if(ops.begin(), ops.end(), [&disable_ops](const PrimitivePtr &p) {
      return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
    });
    (void)ops.erase(iter, ops.cend());
  }
  return ops;
}

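// On GPU (non-lite build), exclude ops that require a specific compute capability:
// Conv2D/MatMul/BatchMatMul are only kept when the device's major compute capability equals 7,
// otherwise they are filtered out and an INFO hint is logged once.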
std::vector<PrimitivePtr> GkUtils::FilterExcludedOps(const std::vector<PrimitivePtr> &ops) {
#ifndef MSLITE_ENABLE_GRAPH_KERNEL
  if (Callback::Instance()->GetTargetFromContext() != kGPUDevice) {
    return ops;
  }
  std::vector<PrimitivePtr> dst_ops;
  const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {kGPUDevice, MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
  MS_EXCEPTION_IF_NULL(device_context);
  auto deprecated_ptr = device_context->GetDeprecatedInterface();
  MS_EXCEPTION_IF_NULL(deprecated_ptr);
  auto major_compute_capability = deprecated_ptr->GetGPUCapabilityMajor();
  std::unordered_map<std::string, int> limited_capacity_ops = {
    {prim::kPrimConv2D->name(), 7}, {prim::kPrimMatMul->name(), 7}, {prim::kPrimBatchMatMul->name(), 7}};
  std::vector<std::string> final_filter_ops;
  for (auto op : ops) {
    if (limited_capacity_ops.find(op->name()) != limited_capacity_ops.end() &&
        limited_capacity_ops[op->name()] != major_compute_capability) {
      (void)final_filter_ops.emplace_back(op->name());
    } else {
      (void)dst_ops.emplace_back(op);
    }
  }
  // Give a hint for the excluded ops.
  static bool give_hint = false;
  if (!give_hint && final_filter_ops.size() > 0) {
    give_hint = true;
    for (size_t i = 0; i < final_filter_ops.size(); ++i) {
      MS_LOG(INFO) << "Op " << final_filter_ops[i]
                   << " cannot be enabled in GraphKernel because the current device's compute capability is "
                   << major_compute_capability << ", which is not equal to the required "
                   << limited_capacity_ops[final_filter_ops[i]];
    }
  }
  return dst_ops;
#else
  return ops;
#endif
}

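// Check whether a node must be kept as a basic op, i.e. it must not be clustered into a graph kernel.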
bool GkUtils::IsKeepBasicNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto prim = GetCNodePrimitive(node);
  auto target = Callback::Instance()->GetTargetFromContext();
  if (prim == nullptr) {
    return false;
  }
  // Heterogeneous computing is not supported yet, so if the node's "primitive_target" is inconsistent
  // with the target from the context, the node cannot be added to the cluster list.
  if (prim->HasAttr("primitive_target") && GetValue<std::string>(prim->GetAttr("primitive_target")) != target) {
    return true;
  }

  // The "skip" attr is used by inplace nodes.
  // The kAttrIsInternalOutputNopNode attr is used by internal outputs of KernelGraph.
  const std::vector<std::string> exclude_bool_attrs = {"skip", kAttrIsInternalOutputNopNode};
  if (std::any_of(exclude_bool_attrs.cbegin(), exclude_bool_attrs.cend(), [&prim](const std::string &attr_name) {
        return prim->HasAttr(attr_name) && GetValue<bool>(prim->GetAttr(attr_name));
      })) {
    return true;
  }

  // If the node contains any attribute in contagious_attrs, it has to be kept basic no matter what the value is.
  const std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
                                                     "aggregate", "aggregate_input_index"};
  if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
                  [&prim](const std::string &attr_name) -> bool { return prim->HasAttr(attr_name); })) {
    return true;
  }
  auto cnode = node->cast<CNodePtr>();
  return (cnode != nullptr && cnode->HasAttr("keep_basic"));
}

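// Create a new CNode in func_graph from the given inputs, then set its abstract (a tensor for a single
// output, a tuple for multiple outputs) and kernel build info according to out_info_list.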
CNodePtr GkUtils::NewRealCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph,
                               const std::vector<inner::NodeBase> &out_info_list, const CallbackPtr &cb) {
  auto cnode = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);

  if (out_info_list.size() == 0) {
    MS_LOG(EXCEPTION) << "CNode must have output!";
  }

  // Setup abstract.
  AbstractBasePtrList abs_list;
  (void)std::transform(
    out_info_list.begin(), out_info_list.end(), std::back_inserter(abs_list), [](const inner::NodeBase &out_info) {
      auto abs_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(out_info.type), out_info.shape);
      return abs_tensor;
    });
  if (abs_list.size() == 1) {
    cnode->set_abstract(abs_list[0]);
  } else {
    cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
  }

  // Setup kernel build info.
  cb->SetBasicNodeKernelInfo(cnode, out_info_list);
  func_graph->AddNode(cnode);
  return cnode;
}

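// Convert a LiteGraph back to an ANF FuncGraph: graph inputs become Parameters, primitive nodes become
// CNodes, and constant inputs (tensor/scalar/tuple) become ValueNodes. Multiple outputs are packed into
// a MakeTuple node.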
FuncGraphPtr GkUtils::LiteGraph2AnfGraph(const inner::LiteGraphPtr &lite_graph, const CallbackPtr &cb) {
  auto func_graph = std::make_shared<FuncGraph>();
  std::map<inner::NodePtr, AnfNodePtr> node_map;
  for (const auto &inp : lite_graph->inputs()) {
    auto param = func_graph->add_parameter();
    node_map[inp] = param;
    param->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(inp->type), inp->shape));
    cb->SetBasicNodeKernelInfo(param, {{inp->shape, inp->type, inp->format}});
  }
  // Create CNodes.
  for (const auto &op_node : lite_graph->GetOrderedNodes()) {
    if (op_node->NodeType() != inner::NType::Primitive) {
      MS_LOG(EXCEPTION) << "Node " << op_node->debug_name() << " should be a Primitive node";
    }
    auto op = std::static_pointer_cast<inner::PrimOp>(op_node);
    auto primitive = std::make_shared<Primitive>(op->op(), op->attrs());
    auto prim = GetOpsPrim(primitive->name());
    if (prim != nullptr) {
      (void)primitive->AddAttr(kAttrInputNames, prim->GetAttr(kAttrInputNames));
      (void)primitive->AddAttr(kAttrOutputNames, prim->GetAttr(kAttrOutputNames));
    }
    AnfNodePtrList inputs = {NewValueNode(primitive)};
    (void)std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(inputs),
                         [&node_map, &cb](const inner::NodePtr &inp) -> AnfNodePtr {
                           const auto iter = node_map.find(inp);
                           if (iter != node_map.end()) {
                             return iter->second;
                           } else {
                             auto node_type = inp->NodeType();
                             if (node_type != inner::NType::Tensor && node_type != inner::NType::Scalar &&
                                 node_type != inner::NType::Tuple) {
                               MS_LOG(EXCEPTION)
                                 << "Node " << inp->debug_name() << " should be a Tensor, Scalar or Tuple node";
                             }
                             ValuePtr inp_value = nullptr;
                             if (node_type == inner::NType::Tensor) {
                               inp_value = inp->As<inner::ConstTensorNode>()->data();
                             } else if (node_type == inner::NType::Scalar) {
                               inp_value = inp->As<inner::ConstScalarNode>()->data();
                             } else {
                               inp_value = inp->As<inner::ConstTupleNode>()->data();
                             }
                             auto value_node = NewValueNode(inp_value);
                             value_node->set_abstract(inp_value->ToAbstract());
                             cb->SetBasicNodeKernelInfo(value_node, {{inp->shape, inp->type, inp->format}});
                             return value_node;
                           }
                         });
    auto output_info_list = op->outputs();
    if (output_info_list.empty()) {
      (void)output_info_list.emplace_back(static_cast<inner::NodeBase>(*op));
    }
    auto cnode = NewRealCNode(inputs, func_graph, output_info_list, cb);
    MS_EXCEPTION_IF_NULL(cnode);
    node_map[op_node] = cnode;
  }
  if (lite_graph->GetOutputs().empty()) {
    MS_LOG(EXCEPTION) << "The output of LiteGraph " << lite_graph->name() << " is empty.";
  } else if (lite_graph->GetOutputs().size() == 1) {
    func_graph->set_output(node_map[lite_graph->GetOutputs()[0]]);
  } else {
    AnfNodePtrList mt_inputs;
    AbstractBasePtrList out_abs_list;
    (void)std::transform(lite_graph->GetOutputs().begin(), lite_graph->GetOutputs().end(),
                         std::back_inserter(mt_inputs), [&node_map, &out_abs_list](const inner::NodePtr &out) {
                           auto out_node = node_map[out];
                           MS_EXCEPTION_IF_NULL(out_node);
                           (void)out_abs_list.emplace_back(out_node->abstract());
                           return out_node;
                         });
    auto mt = func_graph->NewCNode(prim::kPrimMakeTuple, mt_inputs);
    mt->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
    cb->SetEmptyKernelInfo(mt);
    func_graph->AddNode(mt);
    func_graph->set_output(mt);
  }
  return func_graph;
}

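// Convert an int/bool scalar, an int value sequence or a tensor ValuePtr into a Tensor;
// other value types raise an exception.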
tensor::TensorPtr InputValue2Tensor(ValuePtr input_value) {
  // The input value of a cnode can be a tensor, a value sequence, an int or a bool.
  // To emit a LiteGraph node via gb.Value, convert the value to a tensor in all cases.
  tensor::TensorPtr input_tensor = nullptr;
  if (input_value->isa<Int32Imm>() || input_value->isa<Int64Imm>()) {
    auto input_num = AnfUtils::GetIntValue(input_value);
    input_tensor = std::make_shared<tensor::Tensor>(input_num);
  } else if (input_value->isa<ValueSequence>()) {
    auto input_seq = input_value->cast<ValueSequencePtr>()->value();
    std::vector<int64_t> input_vec;
    (void)std::transform(input_seq.begin(), input_seq.end(), std::back_inserter(input_vec),
                         [](auto v) { return AnfUtils::GetIntValue(v); });
    input_tensor = std::make_shared<tensor::Tensor>(input_vec);
  } else if (input_value->isa<tensor::Tensor>()) {
    input_tensor = input_value->cast<tensor::TensorPtr>();
  } else if (input_value->isa<BoolImm>()) {
    auto input_bool = GetValue<bool>(input_value);
    input_tensor = std::make_shared<tensor::Tensor>(input_bool);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported type in InputValue2Tensor";
  }
  return input_tensor;
}

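// Convert an ANF FuncGraph into a LiteGraph: Parameters become graph inputs, CNodes become primitive ops,
// and ValueNode inputs become constant nodes. If op_node_map is given, it records the mapping from the
// LiteGraph op nodes back to the original AnfNodes.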
inner::LiteGraphPtr GkUtils::AnfGraph2LiteGraph(const FuncGraphPtr &func_graph,
                                                HashMap<inner::NodePtr, AnfNodePtr> *op_node_map) {
  std::string name = "Default";
  if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
    name = GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  }
  inner::GraphBuilder gb(name);
  std::map<AnfNodePtr, inner::NodePtr> node_map;
  auto todos = TopoSort(func_graph->output(), SuccIncoming,
                        [](const AnfNodePtr &node) { return node->isa<CNode>() ? FOLLOW : EXCLUDE; });
  const auto &params = func_graph->parameters();
  auto cb = Callback::Instance();
  auto ExtractBuildInfo = [&cb](const AnfNodePtr &node) -> inner::NodeBaseList {
    inner::NodeBaseList listinfo;
    size_t output_num = AnfUtils::GetOutputTensorNum(node);
    for (size_t i = 0; i < output_num; ++i) {
      auto shape = cb->GetOutputShape(node, i);
      auto type = cb->GetOutputType(node, i);
      auto format = cb->GetOutputFormat(node, i);
      auto symbol_shape = GetOutputSymbolicShape(node, i);
      listinfo.push_back(inner::NodeBase({shape, type, format, symbol_shape}));
    }
    return listinfo;
  };
  // set inputs
  for (auto &p : params) {
    node_map[p] = gb.Parameter(ExtractBuildInfo(p)[0]);
  }
  // set ops
  for (auto node : todos) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (node == func_graph->output() && IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
      break;
    }
    auto prim = GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    inner::NodePtrList inputs;
    for (size_t i = 1; i < cnode->size(); ++i) {
      auto input_i = cnode->input(i);
      const auto iter = node_map.find(input_i);
      if (iter != node_map.end()) {
        // input is a parameter or cnode
        inputs.push_back(iter->second);
        continue;
      }
      // input is a valuenode
      auto input_value_node = input_i->cast<ValueNodePtr>();
      auto input_value = input_value_node->value();
      constexpr size_t idx = 2;
      inner::NodePtr input_node;
      if ((IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) && i == idx) {
        input_node = std::make_shared<inner::ConstScalarNode>(input_value);
      } else {
        auto tensor = InputValue2Tensor(input_value);
        MS_EXCEPTION_IF_NULL(tensor);
        input_node = gb.Value(tensor);
      }
      inputs.push_back(input_node);
    }
    auto op = gb.Op(AnfUtils::GetCNodeName(node), ExtractBuildInfo(node), inputs, prim->attrs());
    node_map[node] = op;
    if (op_node_map != nullptr) {
      (*op_node_map)[op] = node;
    }
  }
  // set outputs
  auto output_node = func_graph->output();
  if (!IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
    gb.SetOutputs({node_map[output_node]});
    return gb.Get();
  }
  inner::NodePtrList outputs;
  auto mt = output_node->cast<CNodePtr>();
  (void)std::transform(mt->inputs().begin() + 1, mt->inputs().end(), std::back_inserter(outputs),
                       [&node_map](const AnfNodePtr &no) { return node_map[no]; });
  gb.SetOutputs(std::move(outputs));
  return gb.Get();
}

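// Return the manager of func_graph, creating and binding a new one if it does not exist yet.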
FuncGraphManagerPtr GkUtils::GetFuncGraphManager(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
    func_graph->set_manager(manager);
  }
  return manager;
}

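// Reset the manager so that func_graph is its only root.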
void GkUtils::UpdateFuncGraphManager(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph) {
  mng->RemoveRoots();
  mng->KeepRoots({func_graph});
}

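// Look up the registered PrimitiveC creator by op name and return a new Primitive instance,
// or nullptr if the op is not registered.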
PrimitivePtr GkUtils::GetOpsPrim(const std::string &name) {
  const auto &op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
  const auto iter = op_primc_fns.find(name);
  if (iter == op_primc_fns.end()) {
    return nullptr;
  }
  return iter->second();
}

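// Collect the real kernel CNodes of func_graph (in topological order) into node_list,
// its parameters into input_list, and its outputs (with MakeTuple spread) into output_list if provided.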
void GkUtils::GetValidKernelNodes(const FuncGraphPtr &func_graph, AnfNodePtrList *node_list, AnfNodePtrList *input_list,
                                  AnfNodePtrList *output_list) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node_list);
  AnfNodePtrList todos = TopoSort(func_graph->output());
  (void)std::copy_if(todos.cbegin(), todos.cend(), std::back_inserter(*node_list), AnfUtils::IsRealCNodeKernel);

  if (input_list != nullptr) {
    const auto &parameters = func_graph->parameters();
    (void)input_list->insert(input_list->cend(), parameters.cbegin(), parameters.cend());
  }
  if (output_list != nullptr) {
    if (IsPrimitiveCNode(todos.back(), prim::kPrimMakeTuple)) {
      auto fg_output = todos.back()->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(fg_output);
      auto output_inputs = fg_output->inputs();
      (void)output_list->insert(output_list->cend(), output_inputs.cbegin() + 1, output_inputs.cend());
    } else {
      (void)output_list->emplace_back(func_graph->output());
    }
  }
}

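// Parse the inner channel number 'n' from an NCHWnc format string, e.g. "NCHW8c" returns 8.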
int64_t GkUtils::GetChannelInConvFormat(const std::string &format_string) {
  constexpr size_t nchwc_len = 5;
  if (format_string.size() <= nchwc_len || format_string.find("NCHW") != 0) {
    MS_LOG(EXCEPTION) << "Format must be NCHWnc, but got [" << format_string << "]";
  }
  constexpr size_t n_pos = 4;
  auto channel = format_string.substr(n_pos, format_string.size() - nchwc_len);
  return std::stol(channel);
}

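// Return all graph kernel (composite) nodes of func_graph in topological order.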
AnfNodePtrList GkUtils::GetGraphKernelNodes(const FuncGraphPtr &func_graph) {
  AnfNodePtrList todos = TopoSort(func_graph->output());
  AnfNodePtrList node_list;
  (void)std::copy_if(todos.cbegin(), todos.cend(), std::back_inserter(node_list), AnfUtils::IsGraphKernel);
  return node_list;
}

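// Check whether the cnode carries the "use_akg_cce" attr, i.e. whether it is marked to use the AKG CCE library.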
bool GkUtils::UseAkgCceLib(const AnfNodePtr &node) {
  if (node->isa<CNode>()) {
    auto cnode = dyn_cast_ptr<CNode>(node);
    if (cnode == nullptr) {
      return false;
    }
    return cnode->HasAttr("use_akg_cce");
  }
  return false;
}
}  // namespace mindspore::graphkernel