1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/backend/kernel_graph.h"
17 #include <algorithm>
18 #include <exception>
19 #include <queue>
20 #include <set>
21 #include "abstract/ops/primitive_infer_map.h"
22 #include "backend/common/session/exec_order_builder.h"
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/backend/kernel_info.h"
25 #include "include/backend/optimizer/helper.h"
26 #include "include/common/utils/anfalgo.h"
27 #include "include/common/utils/utils.h"
28 #include "kernel/common_utils.h"
29 #include "kernel/framework_utils.h"
30 #include "kernel/kernel_build_info.h"
31 #include "ops/array_ops.h"
32 #include "ops/op_def.h"
33 #include "ops/framework_ops.h"
34 #include "ops/nn_optimizer_ops.h"
35 #include "ops/other_ops.h"
36 #include "ops/sequence_ops.h"
37 #include "runtime/device/kernel_runtime_manager.h"
38 #include "utils/anf_utils.h"
39 #include "utils/check_convert_utils.h"
40 #include "utils/hash_set.h"
41
42 namespace mindspore {
43 namespace session {
44 namespace {
45 constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
46 constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
47 constexpr size_t k5dDims = 5;
48 const std::set<std::string> kOpAssignKernelNameList = {mindspore::kAssignOpName, mindspore::kAssignAddOpName,
49 mindspore::kAssignSubOpName};
50
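// Resolve the real output nodes behind a call node: MakeTuple outputs are flattened and deduplicated, and for Call
// nodes the outputs of the called child graphs are collected recursively.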
AnfNodePtrList GetCallRealOutputs(const AnfNodePtr &call_node) {
52 auto item_with_index =
53 common::AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
54 AnfNodePtr node = item_with_index.first;
55 MS_EXCEPTION_IF_NULL(node);
56 if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
57 auto outputs = common::AnfAlgo::GetAllOutput(node);
58 std::set<AnfNodePtr> memo;
59 AnfNodePtrList new_output;
60 for (auto &output : outputs) {
61 if (memo.find(output) != memo.end()) {
62 continue;
63 }
64 memo.insert(output);
65 new_output.push_back(output);
66 }
67 if (new_output.size() == 1 && common::AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) {
68 node = new_output[0];
69 }
70 }
71 if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
72 return {node};
73 }
74 AnfNodePtrList real_inputs;
75 auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node->cast<CNodePtr>());
76 for (const auto &child_graph : child_graphs) {
77 MS_EXCEPTION_IF_NULL(child_graph);
78 auto real_input = child_graph->output();
79 auto child_real_inputs = GetCallRealOutputs(real_input);
80 std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs));
81 }
82 return real_inputs;
83 }
84
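// Two label nodes are considered the same label if they share the same primitive and the same kAttrLabelIndex value.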
bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
86 if (left == right) {
87 return true;
88 }
89 if (left == nullptr || right == nullptr) {
90 return false;
91 }
92 if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) {
93 return false;
94 }
95 if (common::AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && common::AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) {
96 return common::AnfAlgo::GetNodeAttr<uint32_t>(left, kAttrLabelIndex) ==
97 common::AnfAlgo::GetNodeAttr<uint32_t>(right, kAttrLabelIndex);
98 }
99 return false;
100 }
101
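// Sync the device format and type id of every tensor held by the value node into the output vectors; tensors
// without a device address fall back to the default format and an unknown type.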
void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::string> *device_formats,
                               std::vector<TypeId> *device_types) {
104 MS_EXCEPTION_IF_NULL(value_node);
105 MS_EXCEPTION_IF_NULL(device_formats);
106 MS_EXCEPTION_IF_NULL(device_types);
107 ValuePtr value = value_node->value();
108 std::vector<tensor::BaseTensorPtr> tensors;
109 TensorValueToTensor(value, &tensors);
110 if (!tensors.empty()) {
111 device_formats->clear();
112 device_types->clear();
113 for (const auto &tensor : tensors) {
114 MS_EXCEPTION_IF_NULL(tensor);
115 auto device_sync = tensor->device_address();
116 if (device_sync != nullptr) {
117 auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);
118 MS_EXCEPTION_IF_NULL(device_address);
119 device_formats->emplace_back(device_address->format());
120 device_types->emplace_back(device_address->type_id());
121 continue;
122 }
123 device_formats->emplace_back(kOpFormat_DEFAULT);
124 device_types->emplace_back(kTypeUnknown);
125 }
126 }
127 }
128
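// Mark a nop node that serves as an internal output: clone its primitive and tag it with kAttrIsInternalOutputNopNode.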
void SetInternalOutputAttr(const AnfNodePtr &node) {
130 if (!common::AnfAlgo::IsNopNode(node)) {
131 return;
132 }
133 auto p = GetCNodePrimitive(node);
134 if (p == nullptr) {
135 return;
136 }
137 auto prim_node = NewValueNode(p->Clone());
138 MS_EXCEPTION_IF_NULL(node);
139 auto cnode = node->cast<CNodePtr>();
140 MS_EXCEPTION_IF_NULL(cnode);
141 cnode->set_input(kAnfPrimitiveIndex, prim_node);
142 common::AnfAlgo::SetNodeAttr(kAttrIsInternalOutputNopNode, MakeValue(true), node);
143 }
144 } // namespace
145
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) const {
147 MS_EXCEPTION_IF_NULL(node);
148 auto value_node = node->cast<ValueNodePtr>();
149 if (value_node == nullptr) {
150 return nullptr;
151 }
152 ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
153 MS_EXCEPTION_IF_NULL(new_value_node);
154 new_value_node->set_abstract(value_node->abstract());
155 this->SetKernelInfoForNode(new_value_node);
156 return new_value_node;
157 }
158
AnfNodePtrList KernelGraph::outputs() const {
160 auto graph_output = output();
161 if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
162 auto make_tuple = output()->cast<CNodePtr>();
163 MS_EXCEPTION_IF_NULL(make_tuple);
164 auto &inputs = make_tuple->inputs();
165 return AnfNodePtrList(inputs.begin() + 1, inputs.end());
166 }
167 return AnfNodePtrList(1, graph_output);
168 }
169
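// Rebuild node_output_edges_ by walking the graph from the return node and recording, for every input node,
// the CNodes that consume it.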
void KernelGraph::SetNodeOutputEdges() {
171 node_output_edges_.clear();
172 std::queue<AnfNodePtr> to_visit;
173 to_visit.emplace(get_return());
174 auto seen = NewSeenGeneration();
175 while (!to_visit.empty()) {
176 auto node = to_visit.front();
177 to_visit.pop();
178 MS_EXCEPTION_IF_NULL(node);
179 if (!node->isa<CNode>()) {
180 continue;
181 }
182 auto cnode = node->cast<CNodePtr>();
183 MS_EXCEPTION_IF_NULL(cnode);
184 for (auto &input : cnode->inputs()) {
185 (void)node_output_edges_[input].emplace_back(node);
186 if (input->seen_ == seen) {
187 continue;
188 }
189 to_visit.emplace(input);
190 input->seen_ = seen;
191 }
192 }
193 }
194
void KernelGraph::SetExecOrderByDefault() {
196 ExecOrderBuilder builder;
197 builder.Build(this, &execution_order_, &node_output_edges_);
198 execution_order_ = SortStartLabelAndEndGoto();
199 }
200
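// Reorder the execution order so that start_label_ is first, end_goto_ is last, and every LabelSet immediately
// follows its matching LabelGoto (see the comment below on why this matters).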
std::vector<CNodePtr> KernelGraph::SortStartLabelAndEndGoto() {
202 std::vector<CNodePtr> re_order;
203 if (start_label_ != nullptr) {
204 re_order.emplace_back(start_label_);
205 }
206 for (auto &node : execution_order_) {
207 if (node == start_label_ || node == end_goto_) {
208 continue;
209 }
210
211 if (IsSameLabel(node, end_goto_)) {
212 end_goto_ = node;
213 MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id();
214 continue;
215 }
216
217 if (IsSameLabel(node, start_label_)) {
218 start_label_ = node;
219 MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id();
220 continue;
221 }
222
    //
    // Re-order:
    //   u = LabelGoto(...)
    //   x = Mul(...)
    //   LabelSet(u)
    // To:
    //   u = LabelGoto(...)
    //   LabelSet(u)
    //   x = Mul(...)
    // This prevents Mul from being skipped.
    //
234 if (IsPrimitiveCNode(node, prim::kPrimLabelSet) && (re_order.back() != node->input(1))) {
235 auto iter = std::find(re_order.crbegin() + 1, re_order.crend(), node->input(1));
236 if (iter != re_order.rend()) {
237 re_order.insert(iter.base(), node);
238 continue;
239 }
240 }
241
242 re_order.emplace_back(node);
243 }
244 if (end_goto_ != nullptr) {
245 re_order.emplace_back(end_goto_);
246 }
247 return re_order;
248 }
249
CNodePtr KernelGraph::NewCNodeWeak(AnfNodeWeakPtrList &&weak_inputs) {
251 auto cnode = FuncGraph::NewCNodeWeak(std::move(weak_inputs));
252 PostNewCNode(cnode);
253 return cnode;
254 }
255
CNodePtr KernelGraph::NewCNodeWeak(const AnfNodeWeakPtrList &weak_inputs) {
257 auto cnode = FuncGraph::NewCNodeWeak(weak_inputs);
258 PostNewCNode(cnode);
259 return cnode;
260 }
261
CNodePtr KernelGraph::NewCNode(AnfNodePtrList &&inputs) {
263 auto cnode = FuncGraph::NewCNode(std::move(inputs));
264 PostNewCNode(cnode);
265 return cnode;
266 }
267
CNodePtr KernelGraph::NewCNode(const AnfNodePtrList &inputs) {
269 auto cnode = FuncGraph::NewCNode(inputs);
270 PostNewCNode(cnode);
271 return cnode;
272 }
273
void KernelGraph::PostNewCNode(const CNodePtr &cnode) const {
275 MS_EXCEPTION_IF_NULL(cnode);
276 if (cnode->abstract() == nullptr) {
277 cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
278 }
279 if (common::AnfAlgo::IsGraphKernel(cnode)) {
280 CreateKernelInfoFromNewParameter(cnode);
281 }
282 if (common::AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
283 common::AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
284 }
285 if (cnode->kernel_info() == nullptr) {
286 SetKernelInfoForNode(cnode);
287 }
288 AnfAlgo::SetGraphId(graph_id_, cnode.get());
289 }
290
CNodePtr KernelGraph::NewCNodeWithInfos(const AnfNodePtrList &inputs, const CNodePtr &ori_cnode) {
292 auto cnode = NewCNode(inputs);
293 if (ori_cnode != nullptr) {
294 cnode->set_attrs(ori_cnode->attrs());
295 cnode->set_primal_attrs(ori_cnode->primal_attrs());
296 cnode->set_primal_debug_infos(ori_cnode->primal_debug_infos());
297 }
298 return cnode;
299 }
300
void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) const {
302 auto func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
303 MS_EXCEPTION_IF_NULL(func_graph);
304
305 AnfNodePtrList node_list;
306 AnfNodePtrList input_list;
307 AnfNodePtrList output_list;
308 kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
309 for (auto &anf_node : node_list) {
310 MS_EXCEPTION_IF_NULL(anf_node);
311 if (anf_node->kernel_info() == nullptr) {
312 anf_node->set_kernel_info(std::make_shared<device::KernelInfo>());
313 }
314 auto anf_cnode = anf_node->cast<CNodePtr>();
315 MS_EXCEPTION_IF_NULL(anf_cnode);
316 size_t input_num = common::AnfAlgo::GetInputTensorNum(anf_cnode);
317 for (size_t i = 0; i < input_num; ++i) {
318 auto input_node = anf_cnode->input(i + 1);
319 MS_EXCEPTION_IF_NULL(input_node);
320 if (IsValueNode<tensor::Tensor>(input_node)) {
321 auto new_input_node = MakeValueNode(input_node);
322 if (new_input_node != nullptr) {
323 anf_cnode->set_input(i + 1, new_input_node);
324 }
325 }
326 }
327 }
328 for (auto &anf_node : input_list) {
329 MS_EXCEPTION_IF_NULL(anf_node);
330 if (anf_node->kernel_info() == nullptr) {
331 anf_node->set_kernel_info(std::make_shared<device::KernelInfo>());
332 }
333 }
334 }
335
void KernelGraph::ResetAssignInputFeatureMapFlag(const CNodePtr &cnode) const {
337 if (kOpAssignKernelNameList.find(common::AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
    MS_LOG(EXCEPTION) << "Only the input feature map flag of an [Assign, AssignSub, AssignAdd] node can be reset, "
                         "but got the node: "
                      << cnode->DebugString();
341 }
342 auto input_node = common::AnfAlgo::GetInputNode(cnode, 0);
343 MS_EXCEPTION_IF_NULL(input_node);
344 auto assign_value_node = common::AnfAlgo::GetInputNode(cnode, 1);
345 if (AnfAlgo::IsFeatureMapOutput(input_node)) {
346 return;
347 }
348 if (!AnfAlgo::IsFeatureMapOutput(input_node) && AnfAlgo::IsFeatureMapOutput(assign_value_node)) {
349 auto kernel_info = dynamic_cast<device::KernelInfo *>(input_node->kernel_info());
350 MS_EXCEPTION_IF_NULL(kernel_info);
351 kernel_info->set_feature_map_flag(true);
352 }
353 }
354
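// Create and attach kernel info for a node. For CNodes the feature map flags are derived from the inputs; for
// parameters and value nodes a default kernel build info (format, device type, kernel object type) is built.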
void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
356 MS_EXCEPTION_IF_NULL(node);
357 auto kernel_info = std::make_shared<device::KernelInfo>();
358 MS_EXCEPTION_IF_NULL(kernel_info);
359 node->set_kernel_info(kernel_info);
360 if (node->isa<CNode>()) {
361 if (kOpAssignKernelNameList.find(common::AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
362 ResetAssignInputFeatureMapFlag(node->cast<CNodePtr>());
363 }
364 #if defined(__APPLE__)
365 std::vector<int> feature_map_input_indexs;
366 #else
367 std::vector<size_t> feature_map_input_indexs;
368 #endif
369 kernel_info->set_feature_map_flag(false);
370 size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
371 for (size_t index = 0; index < input_num; ++index) {
372 if (AnfAlgo::IsFeatureMapInput(node, index)) {
373 kernel_info->set_feature_map_flag(true);
374 feature_map_input_indexs.push_back(index);
375 }
376 }
377 if (common::AnfAlgo::GetInputTensorNum(node) == 0) {
378 kernel_info->set_feature_map_flag(true);
379 }
380 if (AnfUtils::IsRealKernel(node)) {
      // If the node only has the primitive (such as GetNext), or any of its inputs is a feature map,
      // then the node's output is a feature map output.
383 common::AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node);
384 common::AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node);
385 }
386 return;
387 }
388 auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
389 MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
390 // set the format of value_node to DEFAULT_FORMAT
391 std::vector<TypeId> types;
392 std::vector<std::string> formats = {kOpFormat_DEFAULT};
393 if (node->isa<ValueNode>()) {
394 kernel_info->set_feature_map_flag(false);
395 (void)types.emplace_back(kTypeUnknown);
396 auto value_node = node->cast<ValueNodePtr>();
397 SyncDeviceInfoToValueNode(value_node, &formats, &types);
398 }
399 if (node->isa<Parameter>()) {
400 auto parameter = node->cast<ParameterPtr>();
401 MS_EXCEPTION_IF_NULL(parameter);
402 bool is_weight = common::AnfAlgo::IsParameterWeight(parameter);
403 kernel_info->set_feature_map_flag(!is_weight);
404 types.push_back(is_weight ? kTypeUnknown : common::AnfAlgo::GetOutputInferDataType(parameter, 0));
405 }
  // Set the parameter's initial device data type.
407 auto abs = node->abstract();
408 auto abs_type = AnfAlgo::GetAbstractObjectType(abs);
409 auto kernel_object_type = kernel::TypeIdToKernelObjectTypeForTupleUnfold(abs_type);
410 if (common::AnfAlgo::IsDynamicSequence(node) || (node->isa<ValueNode>() && AnfAlgo::IsSequenceOutputOfScalar(node))) {
411 kernel_object_type = kernel::KernelObjectType::TUPLE;
412 } else if (abs_type == kObjectTypeTuple || abs_type == kObjectTypeList) {
413 auto tuple_len = AnfAlgo::GetOutputElementNum(node);
414 formats = std::vector<std::string>(tuple_len, formats[0]);
415 types = std::vector<TypeId>(tuple_len, types[0]);
416 }
417 kernel_build_info_builder->SetOutputsKernelObjectType({kernel_object_type});
418 kernel_build_info_builder->SetOutputsFormat(formats);
419 kernel_build_info_builder->SetOutputsDeviceType(types);
420 MS_LOG(DEBUG) << "Kernel object type is:" << TypeIdLabel(abs_type)
421 << " for parameter or value node:" << node->fullname_with_scope()
422 << ", debug name:" << node->DebugString();
423 AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get());
424 }
425
CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
427 MS_EXCEPTION_IF_NULL(cnode);
428 auto new_cnode = std::make_shared<CNode>(*cnode);
429 new_cnode->CloneUserData(cnode);
430 new_cnode->set_scope(cnode->scope());
431 new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
  // If a cnode was not created from the front end, it is not in the map, so the map should not be updated when replacing it.
433 if (BackendNodeExistInFrontBackendMap(cnode)) {
434 FrontBackendlMapUpdate(cnode, new_cnode);
435 }
436 AnfAlgo::SetGraphId(graph_id_, cnode.get());
437 return new_cnode;
438 }
439
ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
441 auto abstract = parameter == nullptr ? std::make_shared<abstract::AbstractNone>() : parameter->abstract();
442 auto new_parameter = NewParameter(abstract);
443 MS_EXCEPTION_IF_NULL(new_parameter);
  // A non-null parameter means the new parameter is created from an existing one.
445 if (parameter != nullptr) {
446 new_parameter->set_name(parameter->name());
447 if (common::AnfAlgo::IsParameterWeight(parameter)) {
448 new_parameter->set_default_param(parameter->default_param());
449 }
450 } else {
451 // The created parameter name is empty, so set name to ensure that the parameter name is unique.
452 new_parameter->set_name(new_parameter->UniqueName());
453 }
  // Create kernel_info for the new parameter.
455 SetKernelInfoForNode(new_parameter);
456 AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
457 return new_parameter;
458 }
459
ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) {
461 ParameterPtr new_parameter = add_parameter();
462 MS_EXCEPTION_IF_NULL(new_parameter);
463 new_parameter->set_abstract(abstract);
464 // The created parameter name is empty, so set name to ensure that the parameter name is unique.
465 new_parameter->set_name(new_parameter->UniqueName());
  // Create kernel_info for the new parameter.
467 SetKernelInfoForNode(new_parameter);
468 AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
469 return new_parameter;
470 }
471
ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) const {
473 MS_EXCEPTION_IF_NULL(value_node);
474 auto new_value_node = MakeValueNode(value_node)->cast<ValueNodePtr>();
475 SetKernelInfoForNode(new_value_node);
476 AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
477 return new_value_node;
478 }
479
ValueNodePtr KernelGraph::NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value) {
481 MS_EXCEPTION_IF_NULL(abstract);
482 MS_EXCEPTION_IF_NULL(value);
483 ValueNodePtr new_value_node = std::make_shared<ValueNode>(value);
484 MS_EXCEPTION_IF_NULL(new_value_node);
485 new_value_node->set_abstract(abstract);
486 SetKernelInfoForNode(new_value_node);
487 AnfAlgo::SetGraphId(graph_id(), new_value_node.get());
488 AddValueNodeToGraph(new_value_node);
489 return new_value_node;
490 }
491
ValueNodePtr KernelGraph::NewValueNode(const tensor::TensorPtr &input_tensor) {
493 MS_EXCEPTION_IF_NULL(input_tensor);
494 ValueNodePtr value_node = nullptr;
495 if (input_tensor->data_type() == kObjectTypeString) {
496 std::string value_string;
497 (void)value_string.assign(static_cast<char *>(input_tensor->data_c()), LongToSize(input_tensor->data().size()));
498 StringImmPtr string_imm_value = std::make_shared<StringImm>(value_string);
499 value_node = std::make_shared<ValueNode>(string_imm_value);
500 } else {
501 value_node = std::make_shared<ValueNode>(input_tensor);
502 }
503 MS_EXCEPTION_IF_NULL(value_node);
504 value_node->set_abstract(input_tensor->ToAbstract());
505 // add value node to graph
506 auto input_value_node = NewValueNode(value_node);
507 AddValueNodeToGraph(input_value_node);
508 return input_value_node;
509 }
510
ValueNodePtr KernelGraph::NewValueNode(const ValuePtr &input_value) {
512 if (input_value->isa<tensor::Tensor>()) {
513 return NewValueNode(input_value->cast<tensor::TensorPtr>());
514 }
515
516 auto value_node = std::make_shared<ValueNode>(input_value);
517 value_node->set_abstract(input_value->ToAbstract());
518 // add value node to graph
519 auto input_value_node = NewValueNode(value_node);
520 AddValueNodeToGraph(input_value_node);
521 return input_value_node;
522 }
523
AnfNodePtr KernelGraph::TransValueNodeTuple(const AbstractBasePtr &abstract, const ValuePtr &value) {
525 MS_EXCEPTION_IF_NULL(abstract);
526 MS_EXCEPTION_IF_NULL(value);
527 if (!abstract->isa<abstract::AbstractSequence>()) {
528 auto new_value_node = NewValueNode(abstract, value);
529 AddValueNodeToGraph(new_value_node);
530 return new_value_node;
531 }
532 auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
533 auto value_tuple = value->cast<ValueSequencePtr>();
534 MS_EXCEPTION_IF_NULL(tuple_abstract);
535 MS_EXCEPTION_IF_NULL(value_tuple);
536 if (tuple_abstract->size() != value_tuple->size()) {
537 MS_LOG(EXCEPTION) << "Abstract size:" << tuple_abstract->size()
538 << " is not equal to value size:" << value_tuple->size();
539 }
540 AnfNodePtrList make_tuple_inputs = {
541 mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
542 for (size_t index = 0; index < tuple_abstract->size(); ++index) {
543 make_tuple_inputs.push_back(TransValueNodeTuple((*tuple_abstract)[index], (*value_tuple)[index]));
544 }
545 auto make_tuple = NewCNode(std::move(make_tuple_inputs));
546 MS_EXCEPTION_IF_NULL(make_tuple);
547 make_tuple->set_abstract(tuple_abstract);
548 return make_tuple;
549 }
550
AnfNodePtr KernelGraph::TransParameterTuple(const AbstractBasePtr &abstract) {
552 MS_EXCEPTION_IF_NULL(abstract);
553 if (!abstract->isa<abstract::AbstractSequence>()) {
554 return NewParameter(abstract);
555 }
556 auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
557 MS_EXCEPTION_IF_NULL(tuple_abstract);
558 AnfNodePtrList make_tuple_inputs = {
559 mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
560 for (size_t index = 0; index < tuple_abstract->size(); ++index) {
561 const auto &abs = (*tuple_abstract)[index];
562 if (abs != nullptr && abs->isa<abstract::AbstractSequence>() &&
563 abs->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
564 make_tuple_inputs.push_back(NewParameter(abs));
565 continue;
566 }
567 make_tuple_inputs.push_back(TransParameterTuple(abs));
568 }
569 auto make_tuple = NewCNode(std::move(make_tuple_inputs));
570 make_tuple->set_abstract(tuple_abstract);
571 return make_tuple;
572 }
573
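// Create a TupleGetItem(node, output_idx) CNode whose abstract is taken from the corresponding element of the
// input node's sequence abstract.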
AnfNodePtr KernelGraph::CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx) {
575 auto idx = mindspore::NewValueNode(SizeToLong(output_idx));
576 MS_EXCEPTION_IF_NULL(idx);
577 auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
578 auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
579 idx->set_abstract(abstract_scalar);
580 AnfNodePtr tuple_getitem = NewCNode({mindspore::NewValueNode(prim::kPrimTupleGetItem), node, idx});
581 MS_EXCEPTION_IF_NULL(tuple_getitem);
582 tuple_getitem->set_scope(node->scope());
583 auto abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
584 MS_EXCEPTION_IF_NULL(abs);
585 auto abs_i = abs->elements()[output_idx];
586 MS_EXCEPTION_IF_NULL(abs_i);
587 tuple_getitem->set_abstract(abs_i);
588 return tuple_getitem;
589 }
590
AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) {
592 MS_EXCEPTION_IF_NULL(node);
593 AnfNodePtrList make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)};
594 size_t output_num = AnfAlgo::GetOutputElementNum(node);
595 std::vector<AbstractBasePtr> abstract_list;
596 for (size_t tuple_out_index = 0; tuple_out_index < output_num; ++tuple_out_index) {
597 auto out = CreatTupleGetItemNode(node, tuple_out_index);
598 MS_EXCEPTION_IF_NULL(out);
599 if (common::AnfAlgo::IsTupleOutput(out)) {
600 out = TransCNodeTuple(out->cast<CNodePtr>());
601 }
602 make_tuple_inputs_list.emplace_back(out);
603 MS_EXCEPTION_IF_NULL(out->abstract());
604 abstract_list.emplace_back(out->abstract()->Clone());
605 }
606 auto make_tuple = NewCNode(std::move(make_tuple_inputs_list));
607 make_tuple->set_scope(node->scope());
608 make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
609 return make_tuple;
610 }
611
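// Expand a tuple-output node into an explicit MakeTuple structure; parameters, value nodes and CNodes are handled
// separately, and nodes without a tuple output are returned unchanged.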
AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) {
613 MS_EXCEPTION_IF_NULL(node);
614 if (!common::AnfAlgo::IsTupleOutput(node)) {
615 return node;
616 }
617 if (node->isa<Parameter>()) {
618 if (common::AnfAlgo::IsDynamicSequence(node)) {
619 return NewParameter(node->cast<ParameterPtr>());
620 }
621 return TransParameterTuple(node->abstract());
622 } else if (node->isa<ValueNode>()) {
623 auto value_node = node->cast<ValueNodePtr>();
624 MS_EXCEPTION_IF_NULL(value_node);
625 auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value());
626 if (!RemoveValueNodeFromGraph(value_node)) {
627 MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
628 }
629 return make_tuple;
630 } else if (node->isa<CNode>()) {
631 return TransCNodeTuple(node->cast<CNodePtr>());
632 } else {
633 return nullptr;
634 }
635 }
636
const AnfNodePtrList &KernelGraph::inputs() const {
638 MS_EXCEPTION_IF_NULL(inputs_);
639 return *inputs_;
640 }
641
void KernelGraph::FrontBackendMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) {
643 MS_EXCEPTION_IF_NULL(front_anf);
644 MS_EXCEPTION_IF_NULL(backend_anf);
645 if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Anf " << front_anf->DebugString() << " already exists in the front_backend_anf_map_";
647 }
648 front_backend_anf_map_[front_anf] = backend_anf;
649 if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
    // If func(x, y) is defined and called as func(arg, arg), the parameters x and y share the same param_info "arg".
    // In this case the parameter is obtained from param_info and already exists in the map, so it cannot be added again.
652 if (backend_anf->isa<Parameter>()) {
653 MS_LOG(INFO) << "Backend parameter already exist, backend parameter:" << backend_anf->DebugString()
654 << ", exist front parameter:" << backend_front_anf_map_[backend_anf]->DebugString();
655 return;
656 }
657 auto front_node = front_anf->cast<CNodePtr>();
658 MS_EXCEPTION_IF_NULL(front_node);
659 auto attr_input = front_node->input(kAnfPrimitiveIndex);
660 MS_EXCEPTION_IF_NULL(attr_input);
661 if (!attr_input->isa<CNode>()) {
      MS_LOG(INTERNAL_EXCEPTION) << "Kernel " << backend_anf->DebugString()
                                 << " already exists in the backend_front_anf_map_";
664 }
665 }
666 backend_front_anf_map_[backend_anf] = front_anf;
667 }
668
void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
670 MS_EXCEPTION_IF_NULL(old_backend_anf);
671 MS_EXCEPTION_IF_NULL(new_backend_anf);
672 if (old_backend_anf == new_backend_anf) {
673 MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString();
674 return;
675 }
676 auto bf_iter = backend_front_anf_map_.find(old_backend_anf);
677 if (bf_iter == backend_front_anf_map_.end()) {
    MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " does not exist in the map";
679 return;
680 }
681 auto front_anf = bf_iter->second;
682 auto fb_iter = front_backend_anf_map_.find(front_anf);
683 if (fb_iter == front_backend_anf_map_.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Anf does not exist in the map, old " << old_backend_anf->DebugString();
685 }
686 fb_iter->second = new_backend_anf;
687 // Delete old kernel, should be called before add new item to map.
688 (void)backend_front_anf_map_.erase(bf_iter);
689 backend_front_anf_map_[new_backend_anf] = front_anf;
690 if (IsInternalOutput(old_backend_anf)) {
691 ReplaceInternalOutput(old_backend_anf, new_backend_anf);
692 }
693 }
694
695 // get kernel by anf
AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
697 auto iter = front_backend_anf_map_.find(front_anf);
698 if (iter == front_backend_anf_map_.end()) {
699 return nullptr;
700 }
701 return iter->second;
702 }
703
AnfNodePtr KernelGraph::GetFrontAnfByBackendAnf(const AnfNodePtr &backend_anf) const {
705 auto iter = backend_front_anf_map_.find(backend_anf);
706 if (iter == backend_front_anf_map_.end()) {
707 return nullptr;
708 }
709 return iter->second;
710 }
711
bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) {
713 return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end();
714 }
715
ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) {
717 auto iter = tensor_to_value_node_map_.find(tensor);
718 if (iter == tensor_to_value_node_map_.end()) {
719 return nullptr;
720 }
721 return iter->second;
722 }
723
void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) {
725 MS_EXCEPTION_IF_NULL(tensor);
726 MS_EXCEPTION_IF_NULL(value_node);
727 tensor_to_value_node_map_[tensor] = value_node;
728 }
729
void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) {
731 if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) {
732 ++graph_value_nodes_[value_node];
733 } else {
734 graph_value_nodes_[value_node] = 1;
735 }
736 MS_LOG(DEBUG) << "graph:" << ToString()
737 << " add value node:" << (value_node == nullptr ? "null" : value_node->DebugString())
738 << " num:" << graph_value_nodes_[value_node];
739 }
740
bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
742 if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end() && graph_value_nodes_[value_node] > 1) {
743 --graph_value_nodes_[value_node];
744 return true;
745 }
746 MS_LOG(INFO) << "graph:" << ToString()
747 << " erase value node:" << (value_node == nullptr ? "null" : value_node->DebugString());
748 return graph_value_nodes_.erase(value_node) != 0;
749 }
750
mindspore::HashSet<ValueNodePtr> KernelGraph::graph_value_nodes() const {
752 mindspore::HashSet<ValueNodePtr> value_nodes;
753 (void)std::for_each(graph_value_nodes_.begin(), graph_value_nodes_.end(),
754 [&value_nodes](const auto &node_pair) { (void)value_nodes.emplace(node_pair.first); });
755 return value_nodes;
756 }
757
bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; }
759
bool KernelGraph::IsRefOutputMapValue(const AnfWithOutIndex &pair) const {
761 return std::any_of(ref_out_in_map_.cbegin(), ref_out_in_map_.cend(),
762 [&pair](const auto &iter) { return iter.second == pair; });
763 }
764
AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
766 return ref_out_in_map_.at(out_pair);
767 }
768
AnfWithOutIndex KernelGraph::GetRefNodeRecursive(const AnfWithOutIndex &out_pair) const {
770 if (IsInRefOutputMap(out_pair)) {
771 const auto &origin_pair = GetRefCorrespondOutput(out_pair);
772 return GetRefNodeRecursive(origin_pair);
773 }
774 return out_pair;
775 }
776
void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
778 if (IsInRefOutputMap(final_pair)) {
779 MS_LOG(INTERNAL_EXCEPTION) << "Out_pair is already in RefOutputMap, node is " << final_pair.first->DebugString()
780 << ", index is " << final_pair.second;
781 }
782 (void)ref_out_in_map_.emplace(final_pair, origin_pair);
783 }
784
void KernelGraph::ReplaceRefPair(const AnfWithOutIndex &old_pair, const AnfWithOutIndex &new_pair) {
786 // replace key
787 if (IsInRefOutputMap(old_pair)) {
788 auto tmp = ref_out_in_map_.extract(old_pair);
789 tmp.key() = new_pair;
790 ref_out_in_map_.insert(std::move(tmp));
791 }
792 // replace value
793 for (auto &item : ref_out_in_map_) {
794 if (item.second == old_pair) {
795 item.second = new_pair;
796 }
797 }
798 }
799
void KernelGraph::SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor) {
801 output_node_to_tensor_ = node_to_tensor;
802 for (const auto &item : output_node_to_tensor_) {
803 auto node = item.first.first;
804 auto out_index = item.first.second;
805 if (!common::AnfAlgo::IsNopNode(node)) {
806 continue;
807 }
808 while (common::AnfAlgo::IsNopNode(node)) {
809 const auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, 0);
810 node = kernel_with_index.first;
811 out_index = kernel_with_index.second;
812 }
813 KernelWithIndex real_output{node, out_index};
814 nop_node_output_map_.emplace(real_output, item.first);
815 }
816 }
817
void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter) {
819 // update graph inputs
820 MS_EXCEPTION_IF_NULL(old_parameter);
821 MS_EXCEPTION_IF_NULL(new_parameter);
822 if (old_parameter == new_parameter) {
823 return;
824 }
825 for (size_t i = 0; i < inputs_->size(); i++) {
826 if ((*inputs_)[i] == old_parameter) {
827 MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_parameter->DebugString()
828 << ",new graph input:" << new_parameter->DebugString();
829 (*inputs_)[i] = new_parameter;
830 FrontBackendlMapUpdate(old_parameter, new_parameter);
831 break;
832 }
833 }
834 }
835
void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, const AnfNodePtr &new_anf_node) {
837 MS_EXCEPTION_IF_NULL(inputs_);
838 auto it = node_output_edges_.find(old_anf_node);
839 if (it == node_output_edges_.end()) {
840 MS_LOG(WARNING) << "Old node not found " << old_anf_node->DebugString();
841 return;
842 }
843 for (auto &user : it->second) {
844 auto user_cnode = dyn_cast<CNode>(user);
845 MS_EXCEPTION_IF_NULL(user_cnode);
846 auto &inputs = user_cnode->inputs();
847 for (size_t i = 1; i < inputs.size(); i++) {
848 if (inputs[i] == old_anf_node) {
849 user_cnode->set_input(i, new_anf_node);
850 }
851 }
852 }
853 }
854
void KernelGraph::UpdateExecuteKernelStreamLabel() {
856 for (auto &kernel : execution_order_) {
857 AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
858 }
859 }
860
std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
862 std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
863 if (IsLeafGraph()) {
864 leaf_graph_order.push_back(shared_from_this()->cast<KernelGraphPtr>());
865 } else {
866 for (const auto &child_graph : child_graph_order_) {
867 std::shared_ptr<KernelGraph> child_graph_ptr = child_graph.lock();
868 MS_EXCEPTION_IF_NULL(child_graph_ptr);
869 auto child_leaf_graph_order = child_graph_ptr->GetLeafGraphOrder();
870 std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
871 }
872 }
873 return leaf_graph_order;
874 }
875
bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
877
std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
879 std::vector<CNodePtr> result;
880 for (const auto &anf : execution_order_) {
881 MS_EXCEPTION_IF_NULL(anf);
882 if (common::AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
883 result.push_back(anf->cast<CNodePtr>());
884 }
885 }
886 return result;
887 }
888
std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const {
890 std::vector<CNodePtr> result;
891 for (const auto &anf : execution_order_) {
892 MS_EXCEPTION_IF_NULL(anf);
893 for (const auto &primitive : primitive_list) {
894 if (common::AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
895 result.push_back(anf->cast<CNodePtr>());
896 }
897 }
898 }
899 return result;
900 }
901
void KernelGraph::PrintGraphExecuteOrder() const {
903 if (!(IS_OUTPUT_ON(mindspore::kInfo))) {
904 return;
905 }
906 MS_LOG(INFO) << "Graph " << graph_id_ << " execution order:";
907 for (size_t i = 0; i < execution_order_.size(); i++) {
908 CNodePtr cur_cnode_ptr = execution_order_[i];
909 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
910
911 std::string event_str;
912 if (common::AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
913 event_str =
914 ", event id[" + std::to_string(common::AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
915 }
916
917 std::string label_str;
918 if (common::AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
919 label_str =
920 ", label id[" + std::to_string(common::AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
921 }
922
923 if (common::AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
924 auto label_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
925 label_str = ", label id[";
926 for (size_t j = 0; j < label_list.size(); ++j) {
927 label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]");
928 }
929 }
930
931 std::string active_stream_str;
932 if (common::AnfAlgo::HasNodeAttr(kAttrActiveStreamList, cur_cnode_ptr)) {
933 auto stream_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrActiveStreamList);
934 active_stream_str = ", active stream id[";
935 for (size_t j = 0; j < stream_list.size(); ++j) {
936 active_stream_str += std::to_string(stream_list[j]) + (j + 1 < stream_list.size() ? ", " : "]");
937 }
938 }
939
940 std::string group_str;
941 if (AnfAlgo::GetKernelType(cur_cnode_ptr) == HCCL_KERNEL &&
942 common::AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) {
943 group_str = ", group[" + common::AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup) + "]";
944 }
945
946 MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
947 << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
948 << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
949 << event_str << label_str << active_stream_str << group_str;
950 }
951 }
952
void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, size_t output_idx,
                                    bool unique_target) {
955 if (front_node == nullptr || node == nullptr) {
956 MS_LOG(INFO) << "Front node or node is nullptr";
957 return;
958 }
959 MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
960 if (common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
961 output_idx = common::AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
962 }
963 front_to_internal_outputs_map_[front_node] = {node, output_idx};
964 SetInternalOutputAttr(node);
965 internal_outputs_to_front_map_[node][output_idx] = std::pair<AnfNodePtr, bool>(front_node, unique_target);
966 }
967
void KernelGraph::AddInternalOutputTensor(const AnfNodePtr &node, size_t output_idx, const tensor::TensorPtr &tensor) {
969 if (node == nullptr) {
970 return;
971 }
972 internal_outputs_tensor_map_[node][output_idx] = tensor;
973 }
974
tensor::TensorPtr KernelGraph::GetInternalOutputTensor(const AnfNodePtr &node, size_t output_idx) {
976 if (node == nullptr) {
977 return nullptr;
978 }
979 auto iter = internal_outputs_tensor_map_.find(node);
980 if (iter == internal_outputs_tensor_map_.end()) {
981 return nullptr;
982 }
983 auto idx_iter = iter->second.find(output_idx);
984 if (idx_iter == iter->second.end()) {
985 return nullptr;
986 }
987 return idx_iter->second;
988 }
989
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
991 if (new_node == nullptr || node == nullptr) {
992 MS_LOG(INFO) << "New node or node is nullptr";
993 return;
994 }
995 if (node == new_node) {
996 MS_LOG(INFO) << "New node and node is the same";
997 return;
998 }
999 auto iter = internal_outputs_to_front_map_.find(node);
1000 if (iter == internal_outputs_to_front_map_.end()) {
1001 MS_LOG(INFO) << "Node is not internal output";
1002 return;
1003 }
1004 MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
1005 auto front_nodes = std::move(iter->second);
  // We should do 'erase(iter)' before modifying 'internal_outputs_to_front_map_',
  // since 'iter' may be invalidated after a new item is added.
1008 internal_outputs_to_front_map_.erase(iter);
1009 // Move all front nodes to new node mapping.
1010 for (const auto &front_node_iter : front_nodes) {
1011 front_to_internal_outputs_map_[front_node_iter.second.first] = {new_node, front_node_iter.first};
1012 }
1013 internal_outputs_to_front_map_[new_node] = std::move(front_nodes);
1014 SetInternalOutputAttr(new_node);
1015 }
1016
void KernelGraph::EnableRuntimeCache() const {
1018 auto node_list = TopoSort(get_return());
1019 for (auto &node : node_list) {
1020 auto kernel_info = node->kernel_info();
1021 if (!kernel_info) {
1022 continue;
1023 }
1024 auto runtime_cache = kernel_info->runtime_cache();
1025 runtime_cache.runtime_cache().set_is_valid(true);
1026 }
1027 }
1028
void KernelGraph::DisableRuntimeCache() const {
1030 auto node_list = TopoSort(get_return());
1031 for (auto &node : node_list) {
1032 auto kernel_info = node->kernel_info();
1033 if (!kernel_info) {
1034 continue;
1035 }
1036 auto runtime_cache = kernel_info->runtime_cache();
1037 runtime_cache.runtime_cache().set_is_valid(false);
1038 runtime_cache.runtime_cache().reset();
1039 }
1040 }
1041
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, size_t src_output_idx,
                                        size_t dst_output_idx) {
1044 if (new_node == nullptr || node == nullptr) {
1045 MS_LOG(INFO) << "New node or node is nullptr";
1046 return;
1047 }
1048 if (node == new_node) {
1049 MS_LOG(INFO) << "New node and node is the same";
1050 return;
1051 }
1052 auto iter = internal_outputs_to_front_map_.find(node);
1053 if (iter == internal_outputs_to_front_map_.end()) {
1054 MS_LOG(INFO) << "Node is not internal output";
1055 return;
1056 }
1057 MS_LOG(INFO) << "Replace internal output node " << node->DebugString() << " to " << new_node->DebugString();
1058 auto &front_nodes = iter->second;
1059 // Move specified front node to new node mapping
1060 auto front_node_iter = front_nodes.find(src_output_idx);
1061 if (front_node_iter == front_nodes.end()) {
1062 MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node";
1063 return;
1064 }
1065 auto front_node_pair = std::move(front_node_iter->second);
1066 (void)front_nodes.erase(front_node_iter);
1067 if (front_nodes.empty()) {
1068 (void)internal_outputs_to_front_map_.erase(iter);
1069 }
  // We should do 'erase' before 'insert', since 'iter' may be invalidated after a new item is added.
1071 front_to_internal_outputs_map_[front_node_pair.first] = {new_node, dst_output_idx};
1072 internal_outputs_to_front_map_[new_node][dst_output_idx] = std::move(front_node_pair);
1073 SetInternalOutputAttr(new_node);
1074 }
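
// Refresh internal_parameter_to_front_node_map_ so that every internal parameter points to the current real output
// of its cached front node.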
void KernelGraph::UpdateInternalParameter() {
1076 for (const auto &internal_parameter_to_front_node : internal_parameter_to_front_node_map_) {
1077 const auto ¶meter = internal_parameter_to_front_node.first;
1078 const auto &front_node_with_index = internal_parameter_to_front_node.second;
1079 auto front_outputs = common::AnfAlgo::GetAllOutputWithIndex(front_node_with_index.first);
1080 AnfWithOutIndex new_front_node_with_index;
1081 if (front_node_with_index.second < front_outputs.size()) {
1082 new_front_node_with_index = front_outputs[front_node_with_index.second];
1083 } else {
1084 new_front_node_with_index = front_node_with_index;
1085 }
1086
1087 if (new_front_node_with_index.first == nullptr) {
1088 return;
1089 }
1090 MS_LOG(INFO) << "Cache internal parameter: " << parameter->DebugString()
1091 << " to front node: " << new_front_node_with_index.first->DebugString()
1092 << " with index: " << new_front_node_with_index.second
1093 << ", from front node: " << front_node_with_index.first->DebugString()
1094 << " with index: " << front_node_with_index.second;
1095 internal_parameter_to_front_node_map_[parameter] = new_front_node_with_index;
1096 }
1097 }
1098
void KernelGraph::CacheInternalParameterToFrontNode(const AnfNodePtr &parameter,
                                                    const AnfWithOutIndex &front_node_with_index) {
1101 if ((parameter == nullptr) || (front_node_with_index.first == nullptr)) {
1102 return;
1103 }
1104 internal_parameter_to_front_node_map_[parameter] = front_node_with_index;
1105 }
1106
AnfWithOutIndex KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
1108 auto iter = internal_parameter_to_front_node_map_.find(parameter);
1109 if (iter != internal_parameter_to_front_node_map_.end()) {
    // The Load/Depend node needs to fetch the real parameter node.
1111 const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {prim::kPrimDepend,
1112 prim::kPrimLoad};
1113 if (IsOneOfPrimitiveCNode(iter->second.first, auto_monad_prims)) {
1114 return common::AnfAlgo::VisitKernelWithReturnType(iter->second.first, iter->second.second, false);
1115 } else {
1116 return iter->second;
1117 }
1118 }
1119
1120 return AnfWithOutIndex();
1121 }
1122
AnfWithOutIndex KernelGraph::GetOriginFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
1124 auto iter = internal_parameter_to_front_node_map_.find(parameter);
1125 if (iter != internal_parameter_to_front_node_map_.end()) {
1126 return iter->second;
1127 }
1128 return AnfWithOutIndex();
1129 }
1130
FuncGraphPtr KernelGraph::GetFuncGraph() {
1132 for (const auto &front_backend_anf : front_backend_anf_map_) {
1133 const auto &front_node = front_backend_anf.first;
1134 const auto &func_graph = front_node->func_graph();
1135 if (func_graph != nullptr) {
1136 return func_graph;
1137 }
1138 }
1139 return nullptr;
1140 }
1141
void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtrList &backend_outputs,
                                                       const AnfNodePtrList &front_outputs) {
1144 MS_LOG(INFO) << "Get graph backend output nodes.";
1145 std::vector<KernelWithIndex> backend_output_nodes;
1146 for (auto &backend_output : backend_outputs) {
1147 auto temp_backend_outputs = common::AnfAlgo::GetAllOutputWithIndex(backend_output);
1148 (void)backend_output_nodes.insert(backend_output_nodes.end(), temp_backend_outputs.cbegin(),
1149 temp_backend_outputs.cend());
1150 }
1151
1152 MS_LOG(INFO) << "Get graph front output nodes.";
1153 std::vector<KernelWithIndex> front_output_nodes;
1154 for (auto &front_output : front_outputs) {
1155 auto temp_front_outputs = common::AnfAlgo::GetAllOutputWithIndex(front_output);
1156 (void)front_output_nodes.insert(front_output_nodes.cend(), temp_front_outputs.cbegin(), temp_front_outputs.cend());
1157 }
1158
1159 if (backend_output_nodes.size() != front_output_nodes.size()) {
1160 MS_LOG(WARNING) << "The size(" << backend_output_nodes.size() << ") of backend outputs is not equal to the size("
1161 << front_output_nodes.size() << ") of front outputs for graph:" << ToString();
1162 return;
1163 }
1164
1165 for (size_t i = 0; i < backend_output_nodes.size(); ++i) {
1166 auto backend_output_node = backend_output_nodes[i];
1167 auto front_output_node = front_output_nodes[i];
1168 graph_output_to_front_node_map_[backend_output_node] = front_output_node;
1169 front_node_to_graph_output_map_[front_output_node] = backend_output_node;
1170 MS_LOG(INFO) << "Backend output: " << backend_output_node.first->fullname_with_scope()
1171 << " with index: " << backend_output_node.second
1172 << " map to front node: " << front_output_node.first->fullname_with_scope()
1173 << " with index: " << front_output_node.second;
1174 }
1175 }
1176
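// Deduce the kernel object type of a TupleGetItem output from the build info, or from the abstract, of the real
// node it reads from.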
kernel::KernelObjectType GetTupleGetItemOutputKernelObjectType(const AnfNodePtr &node) {
1178 MS_EXCEPTION_IF_NULL(node);
1179 auto tuple_get_item = node->cast<CNodePtr>();
1180 auto kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(tuple_get_item, 0);
1181 auto input_node = kernel_with_index.first;
1182 MS_EXCEPTION_IF_NULL(input_node);
1183 auto output_idx = kernel_with_index.second;
1184 auto kernel_info = dynamic_cast<device::KernelInfo *>(input_node->kernel_info());
1185 MS_LOG(DEBUG) << "GetItem node:" << node->DebugString() << " real node:" << input_node->DebugString()
1186 << " index:" << output_idx << " kernel info:" << kernel_info;
1187 if (kernel_info != nullptr && kernel_info->has_build_info()) {
1188 auto build_info = kernel_info->select_kernel_build_info();
1189 const auto &output_kernel_obj_types = build_info->GetAllOutputKernelObjectTypes();
1190 const auto &output_elements_kernel_obj_types = build_info->GetAllOutputElementsKernelObjectTypes();
1191 MS_LOG(DEBUG) << "real node:" << input_node->fullname_with_scope()
1192 << " output kernel object type:" << output_elements_kernel_obj_types
1193 << " size:" << output_elements_kernel_obj_types.size();
1194 if (output_idx < output_elements_kernel_obj_types.size() && output_kernel_obj_types.size() == 1 &&
1195 output_kernel_obj_types[0] == kernel::KernelObjectType::TUPLE_UNFOLD) {
1196 MS_LOG(DEBUG) << "return type:" << output_elements_kernel_obj_types[output_idx];
1197 return output_elements_kernel_obj_types[output_idx];
1198 } else if (output_kernel_obj_types.size() == 1 && output_kernel_obj_types[0] == kernel::KernelObjectType::TUPLE &&
1199 input_node->abstract() != nullptr && input_node->abstract()->isa<abstract::AbstractSequence>()) {
1200 const auto &sequence_abstract = input_node->abstract()->cast<abstract::AbstractSequencePtr>();
1201 MS_EXCEPTION_IF_NULL(sequence_abstract);
1202 if (sequence_abstract->dynamic_len()) {
1203 MS_EXCEPTION_IF_NULL(sequence_abstract->dynamic_len_element_abs());
1204 return kernel::TypeIdToKernelObjectType(
1205 AnfAlgo::GetAbstractObjectType(sequence_abstract->dynamic_len_element_abs()));
1206 } else {
1207 if (output_idx < sequence_abstract->size()) {
1208 return kernel::TypeIdToKernelObjectType(
1209 AnfAlgo::GetAbstractObjectType(sequence_abstract->elements()[output_idx]));
1210 } else {
1211 MS_LOG(EXCEPTION) << "Invalid index:" << output_idx << " for abstract:" << sequence_abstract->ToString()
1212 << " in node:" << input_node->fullname_with_scope()
1213 << " real node:" << node->fullname_with_scope();
1214 }
1215 }
1216 }
1217 }
1218 if (node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractSequence>()) {
1219 MS_LOG(DEBUG) << "node:" << node->fullname_with_scope() << " abstract:" << node->abstract()->ToString();
1220 const auto &sequence_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
1221 MS_EXCEPTION_IF_NULL(sequence_abs);
1222 if (sequence_abs->dynamic_len()) {
1223 return kernel::KernelObjectType::TUPLE;
1224 }
1225 }
1226 return kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAbstractObjectType(node->abstract()));
1227 }
1228
void KernelGraph::SetKernelObjectTypesForUnrealNodes() const {
1230 auto SetKernelObjectTypesForUnrealNode = [](const AnfNodePtr &node) {
1231 MS_EXCEPTION_IF_NULL(node);
1232 std::vector<kernel::KernelObjectType> output_kernel_object_types;
1233 std::vector<kernel::KernelObjectType> input_kernel_object_types;
1234 if (node->isa<CNode>()) {
1235 auto kernel_info = node->kernel_info_ptr();
1236 MS_EXCEPTION_IF_NULL(kernel_info);
1237 if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) &&
1238 (!kernel_info->has_build_info() || AnfAlgo::GetOutputKernelObjectTypes(node).empty())) {
1239 const auto &input_object_types = AnfAlgo::GetAllInputObjectType(node);
1240 input_kernel_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(input_object_types);
1241 output_kernel_object_types = {kernel::KernelObjectType::TUPLE_UNFOLD};
1242 }
1243 if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem) &&
1244 (!kernel_info->has_build_info() || AnfAlgo::GetOutputKernelObjectTypes(node).empty())) {
1245 output_kernel_object_types = {GetTupleGetItemOutputKernelObjectType(node)};
1246 MS_LOG(DEBUG) << "node:" << node->DebugString() << " output kernel object type:" << output_kernel_object_types;
1247 const auto &input_object_types = AnfAlgo::GetAllInputObjectType(node);
1248 input_kernel_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(input_object_types);
1249 }
1250 }
1251 if (output_kernel_object_types.empty() && input_kernel_object_types.empty()) {
1252 return;
1253 }
1254 kernel::SetKernelObjectTypeBuildInfo(node, input_kernel_object_types, output_kernel_object_types);
1255 };
1256
1257 auto node_list = TopoSort(get_return());
1258 for (auto &node : node_list) {
1259 SetKernelObjectTypesForUnrealNode(node);
1260 }
1261 }
1262
AnfWithOutIndex KernelGraph::GetFrontNodeWithIndexByGraphOutput(
  const AnfWithOutIndex &backend_graph_output_with_index) const {
1265 auto iter = graph_output_to_front_node_map_.find(backend_graph_output_with_index);
1266 if (iter != graph_output_to_front_node_map_.end()) {
1267 return iter->second;
1268 }
1269 return AnfWithOutIndex();
1270 }
1271
GetInternalOutputByFrontNode(const AnfNodePtr & front_node) const1272 AnfWithOutIndex KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
1273 auto iter = front_to_internal_outputs_map_.find(front_node);
1274 if (iter != front_to_internal_outputs_map_.end()) {
1275 return iter->second;
1276 }
1277 return {nullptr, 0};
1278 }
1279
GetGraphOutputByFrontNode(const AnfWithOutIndex & front_node) const1280 AnfWithOutIndex KernelGraph::GetGraphOutputByFrontNode(const AnfWithOutIndex &front_node) const {
1281 auto iter = front_node_to_graph_output_map_.find(front_node);
1282 if (iter != front_node_to_graph_output_map_.end()) {
1283 return iter->second;
1284 }
1285 return AnfWithOutIndex(nullptr, 0);
1286 }
1287
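// Internal outputs are backend kernels whose results are consumed directly by front-end nodes. The queries below
// check whether a node (or a specific output index of it) is registered as an internal output, and whether that
// output feeds a unique front-end target.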
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
  return internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end();
}

bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, size_t output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  auto &front_nodes = front_nodes_iter->second;
  return front_nodes.find(output_idx) != front_nodes.end();
}

bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, size_t output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  auto &front_nodes = front_nodes_iter->second;
  auto idx_iter = front_nodes.find(output_idx);
  if (idx_iter == front_nodes.end()) {
    return false;
  }
  return idx_iter->second.second;
}

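// Rebuild the child graph order from the current execution order: every graph reached through a Call, Switch or
// SwitchLayer node becomes a child graph, and each child (other than this graph's own parent) gets its parent
// pointer reset to this graph.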
void KernelGraph::UpdateChildGraphOrder() {
  MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
  SetExecOrderByDefault();
  auto call_nodes = FindNodeByPrimitive({std::make_shared<Primitive>(prim::kPrimCall->name()),
                                         std::make_shared<Primitive>(prim::kPrimSwitch->name()),
                                         std::make_shared<Primitive>(prim::kPrimSwitchLayer->name())});
  std::vector<std::weak_ptr<KernelGraph>> child_graph_order;
  for (auto &call_node : call_nodes) {
    MS_EXCEPTION_IF_NULL(call_node);
    auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast<CNodePtr>());
    for (const auto &child_graph : call_child_graphs) {
      MS_EXCEPTION_IF_NULL(child_graph);
      if (child_graph != parent_graph_.lock()) {
        auto shared_this = std::dynamic_pointer_cast<KernelGraph>(shared_from_this());
        MS_EXCEPTION_IF_NULL(shared_this);
        child_graph->set_parent_graph(shared_this);
      }
      child_graph_order.push_back(child_graph);
    }
  }
  for (size_t i = 0; i < child_graph_order.size(); ++i) {
    std::shared_ptr<KernelGraph> child_graph = child_graph_order[i].lock();
    MS_EXCEPTION_IF_NULL(child_graph);
    MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph->graph_id() << "]";
  }
  child_graph_order_ = child_graph_order;
}

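// Erase a node from the front/backend mapping tables; value nodes are also removed from the graph's value node set.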
void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto iter = backend_front_anf_map_.find(node);
  if (iter != backend_front_anf_map_.end()) {
    (void)front_backend_anf_map_.erase(iter->second);
    (void)backend_front_anf_map_.erase(iter);
  }
  if (node->isa<ValueNode>()) {
    (void)RemoveValueNodeFromGraph(node->cast<ValueNodePtr>());
  }
}

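// Mark the graph as dynamic shape if any kernel in its execution order is dynamic shape.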
void KernelGraph::UpdateGraphDynamicAttr() {
  for (const auto &cnode : execution_order_) {
    if (common::AnfAlgo::IsDynamicShape(cnode)) {
      MS_LOG(INFO) << "Update Graph Dynamic Attr";
      is_dynamic_shape_ = true;
      return;
    }
  }
  is_dynamic_shape_ = false;
}

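// Collect the real input nodes of the graph. Tuple and dictionary parameters (except dynamic sequences) are expanded
// into their element parameters, and every element is recorded against the front-end node with its index so that
// repeated calls remain reentrant.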
void KernelGraph::SetInputNodes() {
  input_nodes_.clear();
  for (const auto &input_node : inputs()) {
    MS_EXCEPTION_IF_NULL(input_node);
    auto params = common::AnfAlgo::GetAllOutput(input_node);
    auto abs = input_node->abstract();
    MS_EXCEPTION_IF_NULL(abs);
    if (params.size() > 1 ||
        (abs->isa<abstract::AbstractSequence>() && (!common::AnfAlgo::IsDynamicSequence(input_node))) ||
        abs->isa<abstract::AbstractDictionary>()) {
      if (backend_front_anf_map_.find(input_node) == backend_front_anf_map_.end()) {
        MS_LOG(WARNING) << "Cannot find input_node: " << input_node->DebugString() << " in backend_front_anf_map.";
        continue;
      }
      auto front_node = backend_front_anf_map_[input_node];
      for (size_t i = 0; i < params.size(); ++i) {
        // Keep the input_node in the map. Otherwise, the SetInputNodes function is not reentrant.
        tuple_backend_front_anf_index_map_[params[i]] = AnfWithOutIndex(front_node, i);
      }
    } else if (params.size() == 1) {
      FrontBackendlMapUpdate(input_node, params[0]);
    }
    std::copy(params.begin(), params.end(), std::back_inserter(input_nodes_));
  }
}

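// A graph has to hold the Python GIL at run time if it contains a PyFunc kernel.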
void KernelGraph::UpdateGraphAquireGilAttr() {
  for (const auto &cnode : execution_order_) {
    if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPyFunc)) {
      MS_LOG(INFO) << "The graph requires the GIL. Graph id: " << graph_id_;
      is_need_gil_ = true;
      return;
    }
  }
}

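// Detect optimizer kernels: if any kernel updates a parameter whose abstract is a ref tensor, flag the graph as
// containing an optimizer and record the updated parameters.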
void KernelGraph::SetOptimizerFlag() {
  has_optimizer_ = false;
  for (const auto &cnode : execution_order_) {
    MS_EXCEPTION_IF_NULL(cnode);
    if (!common::AnfAlgo::IsUpdateParameterKernel(cnode)) {
      continue;
    }
    for (auto &input : cnode->inputs()) {
      MS_EXCEPTION_IF_NULL(input);
      auto real_node = common::AnfAlgo::VisitKernel(input, 0).first;
      MS_EXCEPTION_IF_NULL(real_node);
      if (!real_node->isa<Parameter>()) {
        continue;
      }
      auto param = real_node->cast<ParameterPtr>();
      auto abstract = param->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      if (abstract->isa<abstract::AbstractRefTensor>()) {
        has_optimizer_ = true;
        (void)updated_parameters_.insert(param);
      }
    }
  }
}

bool KernelGraph::IsDatasetGraph() const {
  // A dataset graph contains exactly one kernel: the InitDataSetQueue node.
  const auto &nodes = execution_order_;
  if (execution_order_.size() > 1) {
    return false;
  }
  for (const auto &node : nodes) {
    auto node_name = common::AnfAlgo::GetCNodeName(node);
    if (node_name == prim::kPrimInitDataSetQueue->name()) {
      return true;
    }
  }
  return false;
}

std::string KernelGraph::ToString() const {
  std::string prefix = is_from_pynative() ? "pynative_kernel_graph" : "kernel_graph";
  return prefix.append(std::to_string(graph_id_));
}

bool KernelGraph::FrontendNodeExistInFrontBackendMap(const AnfNodePtr &frontend_anf) {
  return front_backend_anf_map_.find(frontend_anf) != front_backend_anf_map_.end();
}

bool KernelGraph::IsChildGraphResult(const AnfNodePtr &node) {
  AnfNodePtrList child_graph_results;
  for (const auto &child_graph_result : child_graph_result_) {
    MS_EXCEPTION_IF_NULL(child_graph_result);
    auto outputs = common::AnfAlgo::GetAllOutput(child_graph_result);
    (void)child_graph_results.insert(child_graph_results.cend(), outputs.cbegin(), outputs.cend());
  }

  return std::find(child_graph_results.begin(), child_graph_results.end(), node) != child_graph_results.end();
}

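// Release the runtime resources bound to this graph id. Destructors must not throw, so failures are only logged.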
KernelGraph::~KernelGraph() {
  try {
    device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "KernelGraph destructor failed: " << e.what();
  } catch (...) {
    MS_LOG(ERROR) << "KernelGraph destructor failed";
  }
}

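// Collect the abstracts of a cnode's real inputs (input 0 is the primitive and is skipped); every input is required
// to already carry an abstract.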
std::vector<abstract::AbstractBasePtr> FetchInputAbstracts(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<abstract::AbstractBasePtr> abstracts{};
  for (size_t i = 1; i < cnode->size(); ++i) {
    const auto &input = cnode->inputs()[i];
    MS_EXCEPTION_IF_NULL(input);
    const auto &abstract = input->abstract();
    if (abstract == nullptr) {
      MS_LOG(EXCEPTION) << "Invalid abstract for input:" << input->DebugString()
                        << " for node:" << cnode->fullname_with_scope() << " input index:" << i;
    }
    MS_LOG(DEBUG) << "Add abstract:" << abstract->ToString() << " for input:" << input->DebugString();
    abstracts.emplace_back(abstract);
  }
  return abstracts;
}

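// Re-run abstract (type/shape) inference for every primitive cnode in the graph, except PyExecute nodes. The old
// abstract is cleared first so stale information cannot leak into the new inference result.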
void KernelGraph::InferType() {
  MS_LOG(DEBUG) << "Start infer type for graph:" << ToString();
  AnfNodePtrList nodes = TopoSort(get_return());
  for (const auto &node : nodes) {
    if (node == nullptr || (!node->isa<CNode>())) {
      continue;
    }
    const auto &cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->inputs().empty() || (!IsValueNode<Primitive>(cnode->input(0))) ||
        common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPyExecute)) {
      continue;
    }
    cnode->set_abstract(nullptr);
    MS_LOG(DEBUG) << "Infer abstract for node:" << node->fullname_with_scope();

    // Fetch input abstracts.
    std::vector<abstract::AbstractBasePtr> abstracts = FetchInputAbstracts(cnode);

    // Fetch infer function.
    const auto &primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_EXCEPTION_IF_NULL(primitive);
    auto abstract_opt = abstract::TryInferAbstract(primitive, abstracts);
    if (!abstract_opt.has_value()) {
      MS_LOG(EXCEPTION) << "Failed to infer for primitive:" << primitive->ToString()
                        << " in node:" << cnode->fullname_with_scope();
    }
    auto abstract = abstract_opt.value();
    MS_LOG(INFO) << "Set abstract:" << abstract->ToString() << " for node:" << cnode->DebugString();
    cnode->set_abstract(abstract);
  }
}

void KernelGraph::CacheRootWeight(const std::vector<AnfNodePtr> &weights) { root_weights_ = weights; }
}  // namespace session
}  // namespace mindspore