1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/backend/kernel_graph.h"
17 #include <algorithm>
18 #include <exception>
19 #include <queue>
20 #include <set>
21 #include "abstract/ops/primitive_infer_map.h"
22 #include "backend/common/session/exec_order_builder.h"
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/backend/kernel_info.h"
25 #include "include/backend/optimizer/helper.h"
26 #include "include/common/utils/anfalgo.h"
27 #include "include/common/utils/utils.h"
28 #include "kernel/common_utils.h"
29 #include "kernel/framework_utils.h"
30 #include "kernel/kernel_build_info.h"
31 #include "ops/array_ops.h"
32 #include "ops/op_def.h"
33 #include "ops/framework_ops.h"
34 #include "ops/nn_optimizer_ops.h"
35 #include "ops/other_ops.h"
36 #include "ops/sequence_ops.h"
37 #include "runtime/device/kernel_runtime_manager.h"
38 #include "utils/anf_utils.h"
39 #include "utils/check_convert_utils.h"
40 #include "utils/hash_set.h"
41 
42 namespace mindspore {
43 namespace session {
44 namespace {
45 constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
46 constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
47 constexpr size_t k5dDims = 5;
48 const std::set<std::string> kOpAssignKernelNameList = {mindspore::kAssignOpName, mindspore::kAssignAddOpName,
49                                                        mindspore::kAssignSubOpName};
50 
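// Resolves the real outputs behind a Call node: MakeTuple wrappers are flattened (with duplicates removed),
// and for a Call node the outputs of its child graphs are collected recursively; any other node is returned as-is.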
51 AnfNodePtrList GetCallRealOutputs(const AnfNodePtr &call_node) {
52   auto item_with_index =
53     common::AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
54   AnfNodePtr node = item_with_index.first;
55   MS_EXCEPTION_IF_NULL(node);
56   if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
57     auto outputs = common::AnfAlgo::GetAllOutput(node);
58     std::set<AnfNodePtr> memo;
59     AnfNodePtrList new_output;
60     for (auto &output : outputs) {
61       if (memo.find(output) != memo.end()) {
62         continue;
63       }
64       memo.insert(output);
65       new_output.push_back(output);
66     }
67     if (new_output.size() == 1 && common::AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) {
68       node = new_output[0];
69     }
70   }
71   if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
72     return {node};
73   }
74   AnfNodePtrList real_inputs;
75   auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node->cast<CNodePtr>());
76   for (const auto &child_graph : child_graphs) {
77     MS_EXCEPTION_IF_NULL(child_graph);
78     auto real_input = child_graph->output();
79     auto child_real_inputs = GetCallRealOutputs(real_input);
80     std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs));
81   }
82   return real_inputs;
83 }
84 
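// Two label nodes are considered the same if they are the identical node, or if they share the same primitive
// and carry equal kAttrLabelIndex attribute values.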
85 bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
86   if (left == right) {
87     return true;
88   }
89   if (left == nullptr || right == nullptr) {
90     return false;
91   }
92   if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) {
93     return false;
94   }
95   if (common::AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && common::AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) {
96     return common::AnfAlgo::GetNodeAttr<uint32_t>(left, kAttrLabelIndex) ==
97            common::AnfAlgo::GetNodeAttr<uint32_t>(right, kAttrLabelIndex);
98   }
99   return false;
100 }
101 
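// Copies the device format and type id of each tensor held by the value node into the output vectors.
// Tensors without a device address fall back to the default format and an unknown type id.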
102 void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::string> *device_formats,
103                                std::vector<TypeId> *device_types) {
104   MS_EXCEPTION_IF_NULL(value_node);
105   MS_EXCEPTION_IF_NULL(device_formats);
106   MS_EXCEPTION_IF_NULL(device_types);
107   ValuePtr value = value_node->value();
108   std::vector<tensor::BaseTensorPtr> tensors;
109   TensorValueToTensor(value, &tensors);
110   if (!tensors.empty()) {
111     device_formats->clear();
112     device_types->clear();
113     for (const auto &tensor : tensors) {
114       MS_EXCEPTION_IF_NULL(tensor);
115       auto device_sync = tensor->device_address();
116       if (device_sync != nullptr) {
117         auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);
118         MS_EXCEPTION_IF_NULL(device_address);
119         device_formats->emplace_back(device_address->format());
120         device_types->emplace_back(device_address->type_id());
121         continue;
122       }
123       device_formats->emplace_back(kOpFormat_DEFAULT);
124       device_types->emplace_back(kTypeUnknown);
125     }
126   }
127 }
128 
129 void SetInternalOutputAttr(const AnfNodePtr &node) {
130   if (!common::AnfAlgo::IsNopNode(node)) {
131     return;
132   }
133   auto p = GetCNodePrimitive(node);
134   if (p == nullptr) {
135     return;
136   }
137   auto prim_node = NewValueNode(p->Clone());
138   MS_EXCEPTION_IF_NULL(node);
139   auto cnode = node->cast<CNodePtr>();
140   MS_EXCEPTION_IF_NULL(cnode);
141   cnode->set_input(kAnfPrimitiveIndex, prim_node);
142   common::AnfAlgo::SetNodeAttr(kAttrIsInternalOutputNopNode, MakeValue(true), node);
143 }
144 }  // namespace
145 
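// Clones a value node (value and abstract) and attaches fresh kernel info so it can be used inside this kernel graph.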
146 AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) const {
147   MS_EXCEPTION_IF_NULL(node);
148   auto value_node = node->cast<ValueNodePtr>();
149   if (value_node == nullptr) {
150     return nullptr;
151   }
152   ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
153   MS_EXCEPTION_IF_NULL(new_value_node);
154   new_value_node->set_abstract(value_node->abstract());
155   this->SetKernelInfoForNode(new_value_node);
156   return new_value_node;
157 }
158 
159 AnfNodePtrList KernelGraph::outputs() const {
160   auto graph_output = output();
161   if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
162     auto make_tuple = output()->cast<CNodePtr>();
163     MS_EXCEPTION_IF_NULL(make_tuple);
164     auto &inputs = make_tuple->inputs();
165     return AnfNodePtrList(inputs.begin() + 1, inputs.end());
166   }
167   return AnfNodePtrList(1, graph_output);
168 }
169 
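// Rebuilds node_output_edges_ with a BFS from the Return node: every input of a visited CNode records that
// CNode as one of its users, so the map gives the consumers of each node.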
170 void KernelGraph::SetNodeOutputEdges() {
171   node_output_edges_.clear();
172   std::queue<AnfNodePtr> to_visit;
173   to_visit.emplace(get_return());
174   auto seen = NewSeenGeneration();
175   while (!to_visit.empty()) {
176     auto node = to_visit.front();
177     to_visit.pop();
178     MS_EXCEPTION_IF_NULL(node);
179     if (!node->isa<CNode>()) {
180       continue;
181     }
182     auto cnode = node->cast<CNodePtr>();
183     MS_EXCEPTION_IF_NULL(cnode);
184     for (auto &input : cnode->inputs()) {
185       (void)node_output_edges_[input].emplace_back(node);
186       if (input->seen_ == seen) {
187         continue;
188       }
189       to_visit.emplace(input);
190       input->seen_ = seen;
191     }
192   }
193 }
194 
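// Builds the default execution order via ExecOrderBuilder, then fixes the position of the start label,
// the end goto and any LabelSet nodes (see SortStartLabelAndEndGoto below).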
195 void KernelGraph::SetExecOrderByDefault() {
196   ExecOrderBuilder builder;
197   builder.Build(this, &execution_order_, &node_output_edges_);
198   execution_order_ = SortStartLabelAndEndGoto();
199 }
200 
201 std::vector<CNodePtr> KernelGraph::SortStartLabelAndEndGoto() {
202   std::vector<CNodePtr> re_order;
203   if (start_label_ != nullptr) {
204     re_order.emplace_back(start_label_);
205   }
206   for (auto &node : execution_order_) {
207     if (node == start_label_ || node == end_goto_) {
208       continue;
209     }
210 
211     if (IsSameLabel(node, end_goto_)) {
212       end_goto_ = node;
213       MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id();
214       continue;
215     }
216 
217     if (IsSameLabel(node, start_label_)) {
218       start_label_ = node;
219       MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id();
220       continue;
221     }
222 
223     //
224     // Re-order:
225     //   u = LabelGoto(...)
226     //   x = Mul(...)
227     //   LabelSet(u)
228     // To:
229     //   u = LabelGoto(...)
230     //   LabelSet(u)
231     //   x = Mul(...)
232     // This prevents Mul from being skipped.
233     //
234     if (IsPrimitiveCNode(node, prim::kPrimLabelSet) && (re_order.back() != node->input(1))) {
235       auto iter = std::find(re_order.crbegin() + 1, re_order.crend(), node->input(1));
236       if (iter != re_order.rend()) {
237         re_order.insert(iter.base(), node);
238         continue;
239       }
240     }
241 
242     re_order.emplace_back(node);
243   }
244   if (end_goto_ != nullptr) {
245     re_order.emplace_back(end_goto_);
246   }
247   return re_order;
248 }
249 
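// The NewCNode/NewCNodeWeak overloads below delegate to FuncGraph and then run PostNewCNode, which fills in
// the abstract, kernel info and graph id of the freshly created node.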
250 CNodePtr KernelGraph::NewCNodeWeak(AnfNodeWeakPtrList &&weak_inputs) {
251   auto cnode = FuncGraph::NewCNodeWeak(std::move(weak_inputs));
252   PostNewCNode(cnode);
253   return cnode;
254 }
255 
256 CNodePtr KernelGraph::NewCNodeWeak(const AnfNodeWeakPtrList &weak_inputs) {
257   auto cnode = FuncGraph::NewCNodeWeak(weak_inputs);
258   PostNewCNode(cnode);
259   return cnode;
260 }
261 
262 CNodePtr KernelGraph::NewCNode(AnfNodePtrList &&inputs) {
263   auto cnode = FuncGraph::NewCNode(std::move(inputs));
264   PostNewCNode(cnode);
265   return cnode;
266 }
267 
268 CNodePtr KernelGraph::NewCNode(const AnfNodePtrList &inputs) {
269   auto cnode = FuncGraph::NewCNode(inputs);
270   PostNewCNode(cnode);
271   return cnode;
272 }
273 
274 void KernelGraph::PostNewCNode(const CNodePtr &cnode) const {
275   MS_EXCEPTION_IF_NULL(cnode);
276   if (cnode->abstract() == nullptr) {
277     cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
278   }
279   if (common::AnfAlgo::IsGraphKernel(cnode)) {
280     CreateKernelInfoFromNewParameter(cnode);
281   }
282   if (common::AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
283     common::AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
284   }
285   if (cnode->kernel_info() == nullptr) {
286     SetKernelInfoForNode(cnode);
287   }
288   AnfAlgo::SetGraphId(graph_id_, cnode.get());
289 }
290 
291 CNodePtr KernelGraph::NewCNodeWithInfos(const AnfNodePtrList &inputs, const CNodePtr &ori_cnode) {
292   auto cnode = NewCNode(inputs);
293   if (ori_cnode != nullptr) {
294     cnode->set_attrs(ori_cnode->attrs());
295     cnode->set_primal_attrs(ori_cnode->primal_attrs());
296     cnode->set_primal_debug_infos(ori_cnode->primal_debug_infos());
297   }
298   return cnode;
299 }
300 
301 void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) const {
302   auto func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
303   MS_EXCEPTION_IF_NULL(func_graph);
304 
305   AnfNodePtrList node_list;
306   AnfNodePtrList input_list;
307   AnfNodePtrList output_list;
308   kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
309   for (auto &anf_node : node_list) {
310     MS_EXCEPTION_IF_NULL(anf_node);
311     if (anf_node->kernel_info() == nullptr) {
312       anf_node->set_kernel_info(std::make_shared<device::KernelInfo>());
313     }
314     auto anf_cnode = anf_node->cast<CNodePtr>();
315     MS_EXCEPTION_IF_NULL(anf_cnode);
316     size_t input_num = common::AnfAlgo::GetInputTensorNum(anf_cnode);
317     for (size_t i = 0; i < input_num; ++i) {
318       auto input_node = anf_cnode->input(i + 1);
319       MS_EXCEPTION_IF_NULL(input_node);
320       if (IsValueNode<tensor::Tensor>(input_node)) {
321         auto new_input_node = MakeValueNode(input_node);
322         if (new_input_node != nullptr) {
323           anf_cnode->set_input(i + 1, new_input_node);
324         }
325       }
326     }
327   }
328   for (auto &anf_node : input_list) {
329     MS_EXCEPTION_IF_NULL(anf_node);
330     if (anf_node->kernel_info() == nullptr) {
331       anf_node->set_kernel_info(std::make_shared<device::KernelInfo>());
332     }
333   }
334 }
335 
336 void KernelGraph::ResetAssignInputFeatureMapFlag(const CNodePtr &cnode) const {
337   if (kOpAssignKernelNameList.find(common::AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
338     MS_LOG(EXCEPTION) << "Only the input feature map flag of an [Assign, AssignSub, AssignAdd] node can be "
339                          "changed, but got the node: "
340                       << cnode->DebugString();
341   }
342   auto input_node = common::AnfAlgo::GetInputNode(cnode, 0);
343   MS_EXCEPTION_IF_NULL(input_node);
344   auto assign_value_node = common::AnfAlgo::GetInputNode(cnode, 1);
345   if (AnfAlgo::IsFeatureMapOutput(input_node)) {
346     return;
347   }
348   if (!AnfAlgo::IsFeatureMapOutput(input_node) && AnfAlgo::IsFeatureMapOutput(assign_value_node)) {
349     auto kernel_info = dynamic_cast<device::KernelInfo *>(input_node->kernel_info());
350     MS_EXCEPTION_IF_NULL(kernel_info);
351     kernel_info->set_feature_map_flag(true);
352   }
353 }
354 
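// Creates kernel info for a node. For CNodes the feature map flag and the feature map input indices are
// recorded as attributes; for parameters and value nodes a default kernel build info (format, device type
// and kernel object type) is selected.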
355 void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
356   MS_EXCEPTION_IF_NULL(node);
357   auto kernel_info = std::make_shared<device::KernelInfo>();
358   MS_EXCEPTION_IF_NULL(kernel_info);
359   node->set_kernel_info(kernel_info);
360   if (node->isa<CNode>()) {
361     if (kOpAssignKernelNameList.find(common::AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
362       ResetAssignInputFeatureMapFlag(node->cast<CNodePtr>());
363     }
364 #if defined(__APPLE__)
365     std::vector<int> feature_map_input_indexs;
366 #else
367     std::vector<size_t> feature_map_input_indexs;
368 #endif
369     kernel_info->set_feature_map_flag(false);
370     size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
371     for (size_t index = 0; index < input_num; ++index) {
372       if (AnfAlgo::IsFeatureMapInput(node, index)) {
373         kernel_info->set_feature_map_flag(true);
374         feature_map_input_indexs.push_back(index);
375       }
376     }
377     if (common::AnfAlgo::GetInputTensorNum(node) == 0) {
378       kernel_info->set_feature_map_flag(true);
379     }
380     if (AnfUtils::IsRealKernel(node)) {
381       // If the node only has the primitive (such as GetNext), or any of the node's inputs is a feature map,
382       // then the node's output is a feature map output.
383       common::AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node);
384       common::AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node);
385     }
386     return;
387   }
388   auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
389   MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
390   // set the format of value_node to DEFAULT_FORMAT
391   std::vector<TypeId> types;
392   std::vector<std::string> formats = {kOpFormat_DEFAULT};
393   if (node->isa<ValueNode>()) {
394     kernel_info->set_feature_map_flag(false);
395     (void)types.emplace_back(kTypeUnknown);
396     auto value_node = node->cast<ValueNodePtr>();
397     SyncDeviceInfoToValueNode(value_node, &formats, &types);
398   }
399   if (node->isa<Parameter>()) {
400     auto parameter = node->cast<ParameterPtr>();
401     MS_EXCEPTION_IF_NULL(parameter);
402     bool is_weight = common::AnfAlgo::IsParameterWeight(parameter);
403     kernel_info->set_feature_map_flag(!is_weight);
404     types.push_back(is_weight ? kTypeUnknown : common::AnfAlgo::GetOutputInferDataType(parameter, 0));
405   }
406   // Set the parameter's initial device data type.
407   auto abs = node->abstract();
408   auto abs_type = AnfAlgo::GetAbstractObjectType(abs);
409   auto kernel_object_type = kernel::TypeIdToKernelObjectTypeForTupleUnfold(abs_type);
410   if (common::AnfAlgo::IsDynamicSequence(node) || (node->isa<ValueNode>() && AnfAlgo::IsSequenceOutputOfScalar(node))) {
411     kernel_object_type = kernel::KernelObjectType::TUPLE;
412   } else if (abs_type == kObjectTypeTuple || abs_type == kObjectTypeList) {
413     auto tuple_len = AnfAlgo::GetOutputElementNum(node);
414     formats = std::vector<std::string>(tuple_len, formats[0]);
415     types = std::vector<TypeId>(tuple_len, types[0]);
416   }
417   kernel_build_info_builder->SetOutputsKernelObjectType({kernel_object_type});
418   kernel_build_info_builder->SetOutputsFormat(formats);
419   kernel_build_info_builder->SetOutputsDeviceType(types);
420   MS_LOG(DEBUG) << "Kernel object type is:" << TypeIdLabel(abs_type)
421                 << " for parameter or value node:" << node->fullname_with_scope()
422                 << ", debug name:" << node->DebugString();
423   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get());
424 }
425 
426 CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
427   MS_EXCEPTION_IF_NULL(cnode);
428   auto new_cnode = std::make_shared<CNode>(*cnode);
429   new_cnode->CloneUserData(cnode);
430   new_cnode->set_scope(cnode->scope());
431   new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
432   // If a cnode is not created from the front end, it won't be in the map, so the map shouldn't be updated when replacing it.
433   if (BackendNodeExistInFrontBackendMap(cnode)) {
434     FrontBackendlMapUpdate(cnode, new_cnode);
435   }
436   AnfAlgo::SetGraphId(graph_id_, cnode.get());
437   return new_cnode;
438 }
439 
440 ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
441   auto abstract = parameter == nullptr ? std::make_shared<abstract::AbstractNone>() : parameter->abstract();
442   auto new_parameter = NewParameter(abstract);
443   MS_EXCEPTION_IF_NULL(new_parameter);
444   // If the default argument (parameter = nullptr) is not used, it means creating a new parameter from an old parameter.
445   if (parameter != nullptr) {
446     new_parameter->set_name(parameter->name());
447     if (common::AnfAlgo::IsParameterWeight(parameter)) {
448       new_parameter->set_default_param(parameter->default_param());
449     }
450   } else {
451     // The created parameter name is empty, so set name to ensure that the parameter name is unique.
452     new_parameter->set_name(new_parameter->UniqueName());
453   }
454   // Create kernel_info for the new parameter.
455   SetKernelInfoForNode(new_parameter);
456   AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
457   return new_parameter;
458 }
459 
460 ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) {
461   ParameterPtr new_parameter = add_parameter();
462   MS_EXCEPTION_IF_NULL(new_parameter);
463   new_parameter->set_abstract(abstract);
464   // The created parameter name is empty, so set name to ensure that the parameter name is unique.
465   new_parameter->set_name(new_parameter->UniqueName());
466   // Create kernel_info for the new parameter.
467   SetKernelInfoForNode(new_parameter);
468   AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
469   return new_parameter;
470 }
471 
472 ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) const {
473   MS_EXCEPTION_IF_NULL(value_node);
474   auto new_value_node = MakeValueNode(value_node)->cast<ValueNodePtr>();
475   SetKernelInfoForNode(new_value_node);
476   AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
477   return new_value_node;
478 }
479 
480 ValueNodePtr KernelGraph::NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value) {
481   MS_EXCEPTION_IF_NULL(abstract);
482   MS_EXCEPTION_IF_NULL(value);
483   ValueNodePtr new_value_node = std::make_shared<ValueNode>(value);
484   MS_EXCEPTION_IF_NULL(new_value_node);
485   new_value_node->set_abstract(abstract);
486   SetKernelInfoForNode(new_value_node);
487   AnfAlgo::SetGraphId(graph_id(), new_value_node.get());
488   AddValueNodeToGraph(new_value_node);
489   return new_value_node;
490 }
491 
492 ValueNodePtr KernelGraph::NewValueNode(const tensor::TensorPtr &input_tensor) {
493   MS_EXCEPTION_IF_NULL(input_tensor);
494   ValueNodePtr value_node = nullptr;
495   if (input_tensor->data_type() == kObjectTypeString) {
496     std::string value_string;
497     (void)value_string.assign(static_cast<char *>(input_tensor->data_c()), LongToSize(input_tensor->data().size()));
498     StringImmPtr string_imm_value = std::make_shared<StringImm>(value_string);
499     value_node = std::make_shared<ValueNode>(string_imm_value);
500   } else {
501     value_node = std::make_shared<ValueNode>(input_tensor);
502   }
503   MS_EXCEPTION_IF_NULL(value_node);
504   value_node->set_abstract(input_tensor->ToAbstract());
505   // add value node to graph
506   auto input_value_node = NewValueNode(value_node);
507   AddValueNodeToGraph(input_value_node);
508   return input_value_node;
509 }
510 
511 ValueNodePtr KernelGraph::NewValueNode(const ValuePtr &input_value) {
512   if (input_value->isa<tensor::Tensor>()) {
513     return NewValueNode(input_value->cast<tensor::TensorPtr>());
514   }
515 
516   auto value_node = std::make_shared<ValueNode>(input_value);
517   value_node->set_abstract(input_value->ToAbstract());
518   // add value node to graph
519   auto input_value_node = NewValueNode(value_node);
520   AddValueNodeToGraph(input_value_node);
521   return input_value_node;
522 }
523 
524 AnfNodePtr KernelGraph::TransValueNodeTuple(const AbstractBasePtr &abstract, const ValuePtr &value) {
525   MS_EXCEPTION_IF_NULL(abstract);
526   MS_EXCEPTION_IF_NULL(value);
527   if (!abstract->isa<abstract::AbstractSequence>()) {
528     auto new_value_node = NewValueNode(abstract, value);
529     AddValueNodeToGraph(new_value_node);
530     return new_value_node;
531   }
532   auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
533   auto value_tuple = value->cast<ValueSequencePtr>();
534   MS_EXCEPTION_IF_NULL(tuple_abstract);
535   MS_EXCEPTION_IF_NULL(value_tuple);
536   if (tuple_abstract->size() != value_tuple->size()) {
537     MS_LOG(EXCEPTION) << "Abstract size:" << tuple_abstract->size()
538                       << " is not equal to value size:" << value_tuple->size();
539   }
540   AnfNodePtrList make_tuple_inputs = {
541     mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
542   for (size_t index = 0; index < tuple_abstract->size(); ++index) {
543     make_tuple_inputs.push_back(TransValueNodeTuple((*tuple_abstract)[index], (*value_tuple)[index]));
544   }
545   auto make_tuple = NewCNode(std::move(make_tuple_inputs));
546   MS_EXCEPTION_IF_NULL(make_tuple);
547   make_tuple->set_abstract(tuple_abstract);
548   return make_tuple;
549 }
550 
551 AnfNodePtr KernelGraph::TransParameterTuple(const AbstractBasePtr &abstract) {
552   MS_EXCEPTION_IF_NULL(abstract);
553   if (!abstract->isa<abstract::AbstractSequence>()) {
554     return NewParameter(abstract);
555   }
556   auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
557   MS_EXCEPTION_IF_NULL(tuple_abstract);
558   AnfNodePtrList make_tuple_inputs = {
559     mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
560   for (size_t index = 0; index < tuple_abstract->size(); ++index) {
561     const auto &abs = (*tuple_abstract)[index];
562     if (abs != nullptr && abs->isa<abstract::AbstractSequence>() &&
563         abs->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
564       make_tuple_inputs.push_back(NewParameter(abs));
565       continue;
566     }
567     make_tuple_inputs.push_back(TransParameterTuple(abs));
568   }
569   auto make_tuple = NewCNode(std::move(make_tuple_inputs));
570   make_tuple->set_abstract(tuple_abstract);
571   return make_tuple;
572 }
573 
574 AnfNodePtr KernelGraph::CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx) {
575   auto idx = mindspore::NewValueNode(SizeToLong(output_idx));
576   MS_EXCEPTION_IF_NULL(idx);
577   auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
578   auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
579   idx->set_abstract(abstract_scalar);
580   AnfNodePtr tuple_getitem = NewCNode({mindspore::NewValueNode(prim::kPrimTupleGetItem), node, idx});
581   MS_EXCEPTION_IF_NULL(tuple_getitem);
582   tuple_getitem->set_scope(node->scope());
583   auto abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
584   MS_EXCEPTION_IF_NULL(abs);
585   auto abs_i = abs->elements()[output_idx];
586   MS_EXCEPTION_IF_NULL(abs_i);
587   tuple_getitem->set_abstract(abs_i);
588   return tuple_getitem;
589 }
590 
591 AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) {
592   MS_EXCEPTION_IF_NULL(node);
593   AnfNodePtrList make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)};
594   size_t output_num = AnfAlgo::GetOutputElementNum(node);
595   std::vector<AbstractBasePtr> abstract_list;
596   for (size_t tuple_out_index = 0; tuple_out_index < output_num; ++tuple_out_index) {
597     auto out = CreatTupleGetItemNode(node, tuple_out_index);
598     MS_EXCEPTION_IF_NULL(out);
599     if (common::AnfAlgo::IsTupleOutput(out)) {
600       out = TransCNodeTuple(out->cast<CNodePtr>());
601     }
602     make_tuple_inputs_list.emplace_back(out);
603     MS_EXCEPTION_IF_NULL(out->abstract());
604     abstract_list.emplace_back(out->abstract()->Clone());
605   }
606   auto make_tuple = NewCNode(std::move(make_tuple_inputs_list));
607   make_tuple->set_scope(node->scope());
608   make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
609   return make_tuple;
610 }
611 
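// Expands a tuple-output node into an explicit MakeTuple structure: parameters via TransParameterTuple,
// value nodes via TransValueNodeTuple, and CNodes via TupleGetItem + MakeTuple (TransCNodeTuple).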
612 AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) {
613   MS_EXCEPTION_IF_NULL(node);
614   if (!common::AnfAlgo::IsTupleOutput(node)) {
615     return node;
616   }
617   if (node->isa<Parameter>()) {
618     if (common::AnfAlgo::IsDynamicSequence(node)) {
619       return NewParameter(node->cast<ParameterPtr>());
620     }
621     return TransParameterTuple(node->abstract());
622   } else if (node->isa<ValueNode>()) {
623     auto value_node = node->cast<ValueNodePtr>();
624     MS_EXCEPTION_IF_NULL(value_node);
625     auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value());
626     if (!RemoveValueNodeFromGraph(value_node)) {
627       MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
628     }
629     return make_tuple;
630   } else if (node->isa<CNode>()) {
631     return TransCNodeTuple(node->cast<CNodePtr>());
632   } else {
633     return nullptr;
634   }
635 }
636 
637 const AnfNodePtrList &KernelGraph::inputs() const {
638   MS_EXCEPTION_IF_NULL(inputs_);
639   return *inputs_;
640 }
641 
642 void KernelGraph::FrontBackendMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) {
643   MS_EXCEPTION_IF_NULL(front_anf);
644   MS_EXCEPTION_IF_NULL(backend_anf);
645   if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
646     MS_LOG(INTERNAL_EXCEPTION) << "Anf " << front_anf->DebugString() << " already exists in the front_backend_anf_map_";
647   }
648   front_backend_anf_map_[front_anf] = backend_anf;
649   if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
650     // If a function is defined as func(x, y) and called as func(arg, arg), then parameters x and y share the same param_info "arg".
651     // In this case, the parameter is obtained from param_info and already exists in the map, so it can't be added again.
652     if (backend_anf->isa<Parameter>()) {
653       MS_LOG(INFO) << "Backend parameter already exist, backend parameter:" << backend_anf->DebugString()
654                    << ", exist front parameter:" << backend_front_anf_map_[backend_anf]->DebugString();
655       return;
656     }
657     auto front_node = front_anf->cast<CNodePtr>();
658     MS_EXCEPTION_IF_NULL(front_node);
659     auto attr_input = front_node->input(kAnfPrimitiveIndex);
660     MS_EXCEPTION_IF_NULL(attr_input);
661     if (!attr_input->isa<CNode>()) {
662       MS_LOG(INTERNAL_EXCEPTION) << "Kernel " << backend_anf->DebugString()
663                                  << " already exists in the backend_front_anf_map_";
664     }
665   }
666   backend_front_anf_map_[backend_anf] = front_anf;
667 }
668 
669 void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
670   MS_EXCEPTION_IF_NULL(old_backend_anf);
671   MS_EXCEPTION_IF_NULL(new_backend_anf);
672   if (old_backend_anf == new_backend_anf) {
673     MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString();
674     return;
675   }
676   auto bf_iter = backend_front_anf_map_.find(old_backend_anf);
677   if (bf_iter == backend_front_anf_map_.end()) {
678     MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " does not exist in the map";
679     return;
680   }
681   auto front_anf = bf_iter->second;
682   auto fb_iter = front_backend_anf_map_.find(front_anf);
683   if (fb_iter == front_backend_anf_map_.end()) {
684     MS_LOG(INTERNAL_EXCEPTION) << "Anf does not exist in the map, old " << old_backend_anf->DebugString();
685   }
686   fb_iter->second = new_backend_anf;
687   // Delete old kernel, should be called before add new item to map.
688   (void)backend_front_anf_map_.erase(bf_iter);
689   backend_front_anf_map_[new_backend_anf] = front_anf;
690   if (IsInternalOutput(old_backend_anf)) {
691     ReplaceInternalOutput(old_backend_anf, new_backend_anf);
692   }
693 }
694 
695 // Get the backend node by the front anf node.
696 AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
697   auto iter = front_backend_anf_map_.find(front_anf);
698   if (iter == front_backend_anf_map_.end()) {
699     return nullptr;
700   }
701   return iter->second;
702 }
703 
704 AnfNodePtr KernelGraph::GetFrontAnfByBackendAnf(const AnfNodePtr &backend_anf) const {
705   auto iter = backend_front_anf_map_.find(backend_anf);
706   if (iter == backend_front_anf_map_.end()) {
707     return nullptr;
708   }
709   return iter->second;
710 }
711 
712 bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) {
713   return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end();
714 }
715 
716 ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) {
717   auto iter = tensor_to_value_node_map_.find(tensor);
718   if (iter == tensor_to_value_node_map_.end()) {
719     return nullptr;
720   }
721   return iter->second;
722 }
723 
724 void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) {
725   MS_EXCEPTION_IF_NULL(tensor);
726   MS_EXCEPTION_IF_NULL(value_node);
727   tensor_to_value_node_map_[tensor] = value_node;
728 }
729 
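// graph_value_nodes_ acts as a reference-counted set: adding the same value node again only bumps its count,
// and RemoveValueNodeFromGraph erases the node only once the count has dropped to one.
// A minimal usage sketch (hypothetical graph and value, for illustration only):
//   auto vn = graph->NewValueNode(MakeValue(static_cast<int64_t>(1)));  // added with count 1
//   graph->AddValueNodeToGraph(vn);                                     // count becomes 2
//   graph->RemoveValueNodeFromGraph(vn);                                // count back to 1, node kept
//   graph->RemoveValueNodeFromGraph(vn);                                // node erased from the graph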
730 void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) {
731   if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) {
732     ++graph_value_nodes_[value_node];
733   } else {
734     graph_value_nodes_[value_node] = 1;
735   }
736   MS_LOG(DEBUG) << "graph:" << ToString()
737                 << " add value node:" << (value_node == nullptr ? "null" : value_node->DebugString())
738                 << " num:" << graph_value_nodes_[value_node];
739 }
740 
741 bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
742   if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end() && graph_value_nodes_[value_node] > 1) {
743     --graph_value_nodes_[value_node];
744     return true;
745   }
746   MS_LOG(INFO) << "graph:" << ToString()
747                << " erase value node:" << (value_node == nullptr ? "null" : value_node->DebugString());
748   return graph_value_nodes_.erase(value_node) != 0;
749 }
750 
751 mindspore::HashSet<ValueNodePtr> KernelGraph::graph_value_nodes() const {
752   mindspore::HashSet<ValueNodePtr> value_nodes;
753   (void)std::for_each(graph_value_nodes_.begin(), graph_value_nodes_.end(),
754                       [&value_nodes](const auto &node_pair) { (void)value_nodes.emplace(node_pair.first); });
755   return value_nodes;
756 }
757 
758 bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; }
759 
760 bool KernelGraph::IsRefOutputMapValue(const AnfWithOutIndex &pair) const {
761   return std::any_of(ref_out_in_map_.cbegin(), ref_out_in_map_.cend(),
762                      [&pair](const auto &iter) { return iter.second == pair; });
763 }
764 
765 AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
766   return ref_out_in_map_.at(out_pair);
767 }
768 
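// Follows the ref_out_in_map_ chain until a pair that is not itself a ref output is reached, i.e. the origin
// node/index that actually owns the memory.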
769 AnfWithOutIndex KernelGraph::GetRefNodeRecursive(const AnfWithOutIndex &out_pair) const {
770   if (IsInRefOutputMap(out_pair)) {
771     const auto &origin_pair = GetRefCorrespondOutput(out_pair);
772     return GetRefNodeRecursive(origin_pair);
773   }
774   return out_pair;
775 }
776 
777 void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
778   if (IsInRefOutputMap(final_pair)) {
779     MS_LOG(INTERNAL_EXCEPTION) << "Out_pair is already in RefOutputMap, node is " << final_pair.first->DebugString()
780                                << ", index is " << final_pair.second;
781   }
782   (void)ref_out_in_map_.emplace(final_pair, origin_pair);
783 }
784 
785 void KernelGraph::ReplaceRefPair(const AnfWithOutIndex &old_pair, const AnfWithOutIndex &new_pair) {
786   // replace key
787   if (IsInRefOutputMap(old_pair)) {
788     auto tmp = ref_out_in_map_.extract(old_pair);
789     tmp.key() = new_pair;
790     ref_out_in_map_.insert(std::move(tmp));
791   }
792   // replace value
793   for (auto &item : ref_out_in_map_) {
794     if (item.second == old_pair) {
795       item.second = new_pair;
796     }
797   }
798 }
799 
800 void KernelGraph::SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor) {
801   output_node_to_tensor_ = node_to_tensor;
802   for (const auto &item : output_node_to_tensor_) {
803     auto node = item.first.first;
804     auto out_index = item.first.second;
805     if (!common::AnfAlgo::IsNopNode(node)) {
806       continue;
807     }
808     while (common::AnfAlgo::IsNopNode(node)) {
809       const auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, 0);
810       node = kernel_with_index.first;
811       out_index = kernel_with_index.second;
812     }
813     KernelWithIndex real_output{node, out_index};
814     nop_node_output_map_.emplace(real_output, item.first);
815   }
816 }
817 
818 void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter) {
819   // update graph inputs
820   MS_EXCEPTION_IF_NULL(old_parameter);
821   MS_EXCEPTION_IF_NULL(new_parameter);
822   if (old_parameter == new_parameter) {
823     return;
824   }
825   for (size_t i = 0; i < inputs_->size(); i++) {
826     if ((*inputs_)[i] == old_parameter) {
827       MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_parameter->DebugString()
828                    << ",new graph input:" << new_parameter->DebugString();
829       (*inputs_)[i] = new_parameter;
830       FrontBackendlMapUpdate(old_parameter, new_parameter);
831       break;
832     }
833   }
834 }
835 
836 void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, const AnfNodePtr &new_anf_node) {
837   MS_EXCEPTION_IF_NULL(inputs_);
838   auto it = node_output_edges_.find(old_anf_node);
839   if (it == node_output_edges_.end()) {
840     MS_LOG(WARNING) << "Old node not found " << old_anf_node->DebugString();
841     return;
842   }
843   for (auto &user : it->second) {
844     auto user_cnode = dyn_cast<CNode>(user);
845     MS_EXCEPTION_IF_NULL(user_cnode);
846     auto &inputs = user_cnode->inputs();
847     for (size_t i = 1; i < inputs.size(); i++) {
848       if (inputs[i] == old_anf_node) {
849         user_cnode->set_input(i, new_anf_node);
850       }
851     }
852   }
853 }
854 
855 void KernelGraph::UpdateExecuteKernelStreamLabel() {
856   for (auto &kernel : execution_order_) {
857     AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
858   }
859 }
860 
861 std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
862   std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
863   if (IsLeafGraph()) {
864     leaf_graph_order.push_back(shared_from_this()->cast<KernelGraphPtr>());
865   } else {
866     for (const auto &child_graph : child_graph_order_) {
867       std::shared_ptr<KernelGraph> child_graph_ptr = child_graph.lock();
868       MS_EXCEPTION_IF_NULL(child_graph_ptr);
869       auto child_leaf_graph_order = child_graph_ptr->GetLeafGraphOrder();
870       std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
871     }
872   }
873   return leaf_graph_order;
874 }
875 
876 bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
877 
878 std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
879   std::vector<CNodePtr> result;
880   for (const auto &anf : execution_order_) {
881     MS_EXCEPTION_IF_NULL(anf);
882     if (common::AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
883       result.push_back(anf->cast<CNodePtr>());
884     }
885   }
886   return result;
887 }
888 
889 std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const {
890   std::vector<CNodePtr> result;
891   for (const auto &anf : execution_order_) {
892     MS_EXCEPTION_IF_NULL(anf);
893     for (const auto &primitive : primitive_list) {
894       if (common::AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
895         result.push_back(anf->cast<CNodePtr>());
896       }
897     }
898   }
899   return result;
900 }
901 
PrintGraphExecuteOrder() const902 void KernelGraph::PrintGraphExecuteOrder() const {
903   if (!(IS_OUTPUT_ON(mindspore::kInfo))) {
904     return;
905   }
906   MS_LOG(INFO) << "Graph " << graph_id_ << " execution order:";
907   for (size_t i = 0; i < execution_order_.size(); i++) {
908     CNodePtr cur_cnode_ptr = execution_order_[i];
909     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
910 
911     std::string event_str;
912     if (common::AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
913       event_str =
914         ", event id[" + std::to_string(common::AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
915     }
916 
917     std::string label_str;
918     if (common::AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
919       label_str =
920         ", label id[" + std::to_string(common::AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
921     }
922 
923     if (common::AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
924       auto label_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
925       label_str = ", label id[";
926       for (size_t j = 0; j < label_list.size(); ++j) {
927         label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]");
928       }
929     }
930 
931     std::string active_stream_str;
932     if (common::AnfAlgo::HasNodeAttr(kAttrActiveStreamList, cur_cnode_ptr)) {
933       auto stream_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrActiveStreamList);
934       active_stream_str = ", active stream id[";
935       for (size_t j = 0; j < stream_list.size(); ++j) {
936         active_stream_str += std::to_string(stream_list[j]) + (j + 1 < stream_list.size() ? ", " : "]");
937       }
938     }
939 
940     std::string group_str;
941     if (AnfAlgo::GetKernelType(cur_cnode_ptr) == HCCL_KERNEL &&
942         common::AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) {
943       group_str = ", group[" + common::AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup) + "]";
944     }
945 
946     MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
947                  << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
948                  << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
949                  << event_str << label_str << active_stream_str << group_str;
950   }
951 }
952 
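// Internal outputs are backend nodes whose results are consumed directly by front-end nodes. The maps below
// record the relation in both directions: front node -> (backend node, output index), and backend node ->
// per-output-index front node together with a unique-target flag.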
953 void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, size_t output_idx,
954                                     bool unique_target) {
955   if (front_node == nullptr || node == nullptr) {
956     MS_LOG(INFO) << "Front node or node is nullptr";
957     return;
958   }
959   MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
960   if (common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
961     output_idx = common::AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
962   }
963   front_to_internal_outputs_map_[front_node] = {node, output_idx};
964   SetInternalOutputAttr(node);
965   internal_outputs_to_front_map_[node][output_idx] = std::pair<AnfNodePtr, bool>(front_node, unique_target);
966 }
967 
968 void KernelGraph::AddInternalOutputTensor(const AnfNodePtr &node, size_t output_idx, const tensor::TensorPtr &tensor) {
969   if (node == nullptr) {
970     return;
971   }
972   internal_outputs_tensor_map_[node][output_idx] = tensor;
973 }
974 
975 tensor::TensorPtr KernelGraph::GetInternalOutputTensor(const AnfNodePtr &node, size_t output_idx) {
976   if (node == nullptr) {
977     return nullptr;
978   }
979   auto iter = internal_outputs_tensor_map_.find(node);
980   if (iter == internal_outputs_tensor_map_.end()) {
981     return nullptr;
982   }
983   auto idx_iter = iter->second.find(output_idx);
984   if (idx_iter == iter->second.end()) {
985     return nullptr;
986   }
987   return idx_iter->second;
988 }
989 
990 void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
991   if (new_node == nullptr || node == nullptr) {
992     MS_LOG(INFO) << "New node or node is nullptr";
993     return;
994   }
995   if (node == new_node) {
996     MS_LOG(INFO) << "New node and node is the same";
997     return;
998   }
999   auto iter = internal_outputs_to_front_map_.find(node);
1000   if (iter == internal_outputs_to_front_map_.end()) {
1001     MS_LOG(INFO) << "Node is not internal output";
1002     return;
1003   }
1004   MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
1005   auto front_nodes = std::move(iter->second);
1006   // We should do 'erase(iter)' before modify 'internal_outputs_to_front_map_',
1007   // since the 'iter' may be invalidated after new item added.
1008   internal_outputs_to_front_map_.erase(iter);
1009   // Move all front nodes to new node mapping.
1010   for (const auto &front_node_iter : front_nodes) {
1011     front_to_internal_outputs_map_[front_node_iter.second.first] = {new_node, front_node_iter.first};
1012   }
1013   internal_outputs_to_front_map_[new_node] = std::move(front_nodes);
1014   SetInternalOutputAttr(new_node);
1015 }
1016 
1017 void KernelGraph::EnableRuntimeCache() const {
1018   auto node_list = TopoSort(get_return());
1019   for (auto &node : node_list) {
1020     auto kernel_info = node->kernel_info();
1021     if (!kernel_info) {
1022       continue;
1023     }
1024     auto runtime_cache = kernel_info->runtime_cache();
1025     runtime_cache.runtime_cache().set_is_valid(true);
1026   }
1027 }
1028 
1029 void KernelGraph::DisableRuntimeCache() const {
1030   auto node_list = TopoSort(get_return());
1031   for (auto &node : node_list) {
1032     auto kernel_info = node->kernel_info();
1033     if (!kernel_info) {
1034       continue;
1035     }
1036     auto runtime_cache = kernel_info->runtime_cache();
1037     runtime_cache.runtime_cache().set_is_valid(false);
1038     runtime_cache.runtime_cache().reset();
1039   }
1040 }
1041 
1042 void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, size_t src_output_idx,
1043                                         size_t dst_output_idx) {
1044   if (new_node == nullptr || node == nullptr) {
1045     MS_LOG(INFO) << "New node or node is nullptr";
1046     return;
1047   }
1048   if (node == new_node) {
1049     MS_LOG(INFO) << "New node and node is the same";
1050     return;
1051   }
1052   auto iter = internal_outputs_to_front_map_.find(node);
1053   if (iter == internal_outputs_to_front_map_.end()) {
1054     MS_LOG(INFO) << "Node is not internal output";
1055     return;
1056   }
1057   MS_LOG(INFO) << "Replace internal output node " << node->DebugString() << " to " << new_node->DebugString();
1058   auto &front_nodes = iter->second;
1059   // Move specified front node to new node mapping
1060   auto front_node_iter = front_nodes.find(src_output_idx);
1061   if (front_node_iter == front_nodes.end()) {
1062     MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal output";
1063     return;
1064   }
1065   auto front_node_pair = std::move(front_node_iter->second);
1066   (void)front_nodes.erase(front_node_iter);
1067   if (front_nodes.empty()) {
1068     (void)internal_outputs_to_front_map_.erase(iter);
1069   }
1070   // We should do 'erase' before 'insert', since the 'iter' may be invalidated after new item added.
1071   front_to_internal_outputs_map_[front_node_pair.first] = {new_node, dst_output_idx};
1072   internal_outputs_to_front_map_[new_node][dst_output_idx] = std::move(front_node_pair);
1073   SetInternalOutputAttr(new_node);
1074 }
1075 void KernelGraph::UpdateInternalParameter() {
1076   for (const auto &internal_parameter_to_front_node : internal_parameter_to_front_node_map_) {
1077     const auto &parameter = internal_parameter_to_front_node.first;
1078     const auto &front_node_with_index = internal_parameter_to_front_node.second;
1079     auto front_outputs = common::AnfAlgo::GetAllOutputWithIndex(front_node_with_index.first);
1080     AnfWithOutIndex new_front_node_with_index;
1081     if (front_node_with_index.second < front_outputs.size()) {
1082       new_front_node_with_index = front_outputs[front_node_with_index.second];
1083     } else {
1084       new_front_node_with_index = front_node_with_index;
1085     }
1086 
1087     if (new_front_node_with_index.first == nullptr) {
1088       return;
1089     }
1090     MS_LOG(INFO) << "Cache internal parameter: " << parameter->DebugString()
1091                  << " to front node: " << new_front_node_with_index.first->DebugString()
1092                  << " with index: " << new_front_node_with_index.second
1093                  << ", from front node: " << front_node_with_index.first->DebugString()
1094                  << " with index: " << front_node_with_index.second;
1095     internal_parameter_to_front_node_map_[parameter] = new_front_node_with_index;
1096   }
1097 }
1098 
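// Records the front node (with output index) that feeds an internal parameter of this graph; the entries are
// refreshed by UpdateInternalParameter above once the real front outputs are known.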
1099 void KernelGraph::CacheInternalParameterToFrontNode(const AnfNodePtr &parameter,
1100                                                     const AnfWithOutIndex &front_node_with_index) {
1101   if ((parameter == nullptr) || (front_node_with_index.first == nullptr)) {
1102     return;
1103   }
1104   internal_parameter_to_front_node_map_[parameter] = front_node_with_index;
1105 }
1106 
1107 AnfWithOutIndex KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
1108   auto iter = internal_parameter_to_front_node_map_.find(parameter);
1109   if (iter != internal_parameter_to_front_node_map_.end()) {
1110     // The load/depend node need fetch the real parameter node.
1111     const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {prim::kPrimDepend,
1112                                                                                                 prim::kPrimLoad};
1113     if (IsOneOfPrimitiveCNode(iter->second.first, auto_monad_prims)) {
1114       return common::AnfAlgo::VisitKernelWithReturnType(iter->second.first, iter->second.second, false);
1115     } else {
1116       return iter->second;
1117     }
1118   }
1119 
1120   return AnfWithOutIndex();
1121 }
1122 
1123 AnfWithOutIndex KernelGraph::GetOriginFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
1124   auto iter = internal_parameter_to_front_node_map_.find(parameter);
1125   if (iter != internal_parameter_to_front_node_map_.end()) {
1126     return iter->second;
1127   }
1128   return AnfWithOutIndex();
1129 }
1130 
1131 FuncGraphPtr KernelGraph::GetFuncGraph() {
1132   for (const auto &front_backend_anf : front_backend_anf_map_) {
1133     const auto &front_node = front_backend_anf.first;
1134     const auto &func_graph = front_node->func_graph();
1135     if (func_graph != nullptr) {
1136       return func_graph;
1137     }
1138   }
1139   return nullptr;
1140 }
1141 
1142 void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtrList &backend_outputs,
1143                                                        const AnfNodePtrList &front_outputs) {
1144   MS_LOG(INFO) << "Get graph backend output nodes.";
1145   std::vector<KernelWithIndex> backend_output_nodes;
1146   for (auto &backend_output : backend_outputs) {
1147     auto temp_backend_outputs = common::AnfAlgo::GetAllOutputWithIndex(backend_output);
1148     (void)backend_output_nodes.insert(backend_output_nodes.end(), temp_backend_outputs.cbegin(),
1149                                       temp_backend_outputs.cend());
1150   }
1151 
1152   MS_LOG(INFO) << "Get graph front output nodes.";
1153   std::vector<KernelWithIndex> front_output_nodes;
1154   for (auto &front_output : front_outputs) {
1155     auto temp_front_outputs = common::AnfAlgo::GetAllOutputWithIndex(front_output);
1156     (void)front_output_nodes.insert(front_output_nodes.cend(), temp_front_outputs.cbegin(), temp_front_outputs.cend());
1157   }
1158 
1159   if (backend_output_nodes.size() != front_output_nodes.size()) {
1160     MS_LOG(WARNING) << "The size(" << backend_output_nodes.size() << ") of backend outputs is not equal to the size("
1161                     << front_output_nodes.size() << ") of front outputs for graph:" << ToString();
1162     return;
1163   }
1164 
1165   for (size_t i = 0; i < backend_output_nodes.size(); ++i) {
1166     auto backend_output_node = backend_output_nodes[i];
1167     auto front_output_node = front_output_nodes[i];
1168     graph_output_to_front_node_map_[backend_output_node] = front_output_node;
1169     front_node_to_graph_output_map_[front_output_node] = backend_output_node;
1170     MS_LOG(INFO) << "Backend output: " << backend_output_node.first->fullname_with_scope()
1171                  << " with index: " << backend_output_node.second
1172                  << " map to front node: " << front_output_node.first->fullname_with_scope()
1173                  << " with index: " << front_output_node.second;
1174   }
1175 }
1176 
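// Derives the kernel object type of a TupleGetItem output: prefer the per-element object types recorded in the
// real input's kernel build info (TupleUnfold case), fall back to the input's abstract for TUPLE outputs, and
// finally use the TupleGetItem node's own abstract.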
1177 kernel::KernelObjectType GetTupleGetItemOutputKernelObjectType(const AnfNodePtr &node) {
1178   MS_EXCEPTION_IF_NULL(node);
1179   auto tuple_get_item = node->cast<CNodePtr>();
1180   auto kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(tuple_get_item, 0);
1181   auto input_node = kernel_with_index.first;
1182   MS_EXCEPTION_IF_NULL(input_node);
1183   auto output_idx = kernel_with_index.second;
1184   auto kernel_info = dynamic_cast<device::KernelInfo *>(input_node->kernel_info());
1185   MS_LOG(DEBUG) << "GetItem node:" << node->DebugString() << " real node:" << input_node->DebugString()
1186                 << " index:" << output_idx << " kernel info:" << kernel_info;
1187   if (kernel_info != nullptr && kernel_info->has_build_info()) {
1188     auto build_info = kernel_info->select_kernel_build_info();
1189     const auto &output_kernel_obj_types = build_info->GetAllOutputKernelObjectTypes();
1190     const auto &output_elements_kernel_obj_types = build_info->GetAllOutputElementsKernelObjectTypes();
1191     MS_LOG(DEBUG) << "real node:" << input_node->fullname_with_scope()
1192                   << " output kernel object type:" << output_elements_kernel_obj_types
1193                   << " size:" << output_elements_kernel_obj_types.size();
1194     if (output_idx < output_elements_kernel_obj_types.size() && output_kernel_obj_types.size() == 1 &&
1195         output_kernel_obj_types[0] == kernel::KernelObjectType::TUPLE_UNFOLD) {
1196       MS_LOG(DEBUG) << "return type:" << output_elements_kernel_obj_types[output_idx];
1197       return output_elements_kernel_obj_types[output_idx];
1198     } else if (output_kernel_obj_types.size() == 1 && output_kernel_obj_types[0] == kernel::KernelObjectType::TUPLE &&
1199                input_node->abstract() != nullptr && input_node->abstract()->isa<abstract::AbstractSequence>()) {
1200       const auto &sequence_abstract = input_node->abstract()->cast<abstract::AbstractSequencePtr>();
1201       MS_EXCEPTION_IF_NULL(sequence_abstract);
1202       if (sequence_abstract->dynamic_len()) {
1203         MS_EXCEPTION_IF_NULL(sequence_abstract->dynamic_len_element_abs());
1204         return kernel::TypeIdToKernelObjectType(
1205           AnfAlgo::GetAbstractObjectType(sequence_abstract->dynamic_len_element_abs()));
1206       } else {
1207         if (output_idx < sequence_abstract->size()) {
1208           return kernel::TypeIdToKernelObjectType(
1209             AnfAlgo::GetAbstractObjectType(sequence_abstract->elements()[output_idx]));
1210         } else {
1211           MS_LOG(EXCEPTION) << "Invalid index:" << output_idx << " for abstract:" << sequence_abstract->ToString()
1212                             << " in node:" << input_node->fullname_with_scope()
1213                             << " real node:" << node->fullname_with_scope();
1214         }
1215       }
1216     }
1217   }
1218   if (node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractSequence>()) {
1219     MS_LOG(DEBUG) << "node:" << node->fullname_with_scope() << " abstract:" << node->abstract()->ToString();
1220     const auto &sequence_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
1221     MS_EXCEPTION_IF_NULL(sequence_abs);
1222     if (sequence_abs->dynamic_len()) {
1223       return kernel::KernelObjectType::TUPLE;
1224     }
1225   }
1226   return kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAbstractObjectType(node->abstract()));
1227 }
1228 
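// Fill in kernel object types for "unreal" nodes (MakeTuple and TupleGetItem) that have no kernel build info
// yet, so that later passes can query their input/output object types.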
void KernelGraph::SetKernelObjectTypesForUnrealNodes() const {
  auto SetKernelObjectTypesForUnrealNode = [](const AnfNodePtr &node) {
    MS_EXCEPTION_IF_NULL(node);
    std::vector<kernel::KernelObjectType> output_kernel_object_types;
    std::vector<kernel::KernelObjectType> input_kernel_object_types;
    if (node->isa<CNode>()) {
      auto kernel_info = node->kernel_info_ptr();
      MS_EXCEPTION_IF_NULL(kernel_info);
      if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) &&
          (!kernel_info->has_build_info() || AnfAlgo::GetOutputKernelObjectTypes(node).empty())) {
        const auto &input_object_types = AnfAlgo::GetAllInputObjectType(node);
        input_kernel_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(input_object_types);
        output_kernel_object_types = {kernel::KernelObjectType::TUPLE_UNFOLD};
      }
      if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem) &&
          (!kernel_info->has_build_info() || AnfAlgo::GetOutputKernelObjectTypes(node).empty())) {
        output_kernel_object_types = {GetTupleGetItemOutputKernelObjectType(node)};
        MS_LOG(DEBUG) << "node:" << node->DebugString() << " output kernel object type:" << output_kernel_object_types;
        const auto &input_object_types = AnfAlgo::GetAllInputObjectType(node);
        input_kernel_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(input_object_types);
      }
    }
    if (output_kernel_object_types.empty() && input_kernel_object_types.empty()) {
      return;
    }
    kernel::SetKernelObjectTypeBuildInfo(node, input_kernel_object_types, output_kernel_object_types);
  };

  auto node_list = TopoSort(get_return());
  for (auto &node : node_list) {
    SetKernelObjectTypesForUnrealNode(node);
  }
}

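// Lookup helpers over the output/front-node maps built above and the internal output tables; each returns an
// empty result (or false) when no mapping is recorded.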
AnfWithOutIndex KernelGraph::GetFrontNodeWithIndexByGraphOutput(
  const AnfWithOutIndex &backend_graph_output_with_index) const {
  auto iter = graph_output_to_front_node_map_.find(backend_graph_output_with_index);
  if (iter != graph_output_to_front_node_map_.end()) {
    return iter->second;
  }
  return AnfWithOutIndex();
}

AnfWithOutIndex KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
  auto iter = front_to_internal_outputs_map_.find(front_node);
  if (iter != front_to_internal_outputs_map_.end()) {
    return iter->second;
  }
  return {nullptr, 0};
}

AnfWithOutIndex KernelGraph::GetGraphOutputByFrontNode(const AnfWithOutIndex &front_node) const {
  auto iter = front_node_to_graph_output_map_.find(front_node);
  if (iter != front_node_to_graph_output_map_.end()) {
    return iter->second;
  }
  return AnfWithOutIndex(nullptr, 0);
}

bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
  return internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end();
}

bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, size_t output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  auto &front_nodes = front_nodes_iter->second;
  return front_nodes.find(output_idx) != front_nodes.end();
}

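// Returns the per-index flag stored in internal_outputs_to_front_map_, which records whether the internal
// output at output_idx is mapped to a unique front target.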
bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, size_t output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  auto &front_nodes = front_nodes_iter->second;
  auto idx_iter = front_nodes.find(output_idx);
  if (idx_iter == front_nodes.end()) {
    return false;
  }
  return idx_iter->second.second;
}

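// Rebuild child_graph_order_ by scanning Call/Switch/SwitchLayer nodes in execution order, and re-parent every
// child graph (other than this graph's own parent) to this graph.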
void KernelGraph::UpdateChildGraphOrder() {
  MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
  SetExecOrderByDefault();
  auto call_nodes = FindNodeByPrimitive({std::make_shared<Primitive>(prim::kPrimCall->name()),
                                         std::make_shared<Primitive>(prim::kPrimSwitch->name()),
                                         std::make_shared<Primitive>(prim::kPrimSwitchLayer->name())});
  std::vector<std::weak_ptr<KernelGraph>> child_graph_order;
  for (auto &call_node : call_nodes) {
    MS_EXCEPTION_IF_NULL(call_node);
    auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast<CNodePtr>());
    for (const auto &child_graph : call_child_graphs) {
      MS_EXCEPTION_IF_NULL(child_graph);
      if (child_graph != parent_graph_.lock()) {
        auto shared_this = std::dynamic_pointer_cast<KernelGraph>(shared_from_this());
        MS_EXCEPTION_IF_NULL(shared_this);
        child_graph->set_parent_graph(shared_this);
      }
      child_graph_order.push_back(child_graph);
    }
  }
  for (size_t i = 0; i < child_graph_order.size(); ++i) {
    std::shared_ptr<KernelGraph> child_graph = child_graph_order[i].lock();
    MS_EXCEPTION_IF_NULL(child_graph);
    MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph->graph_id() << "]";
  }
  child_graph_order_ = child_graph_order;
}

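// Erase the node from the front/backend mapping tables and, for value nodes, from the graph's value node set.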
void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto iter = backend_front_anf_map_.find(node);
  if (iter != backend_front_anf_map_.end()) {
    (void)front_backend_anf_map_.erase(iter->second);
    (void)backend_front_anf_map_.erase(iter);
  }
  if (node->isa<ValueNode>()) {
    (void)RemoveValueNodeFromGraph(node->cast<ValueNodePtr>());
  }
}

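// The graph is marked as dynamic shape as soon as any kernel in the execution order is dynamic shape.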
void KernelGraph::UpdateGraphDynamicAttr() {
  for (const auto &cnode : execution_order_) {
    if (common::AnfAlgo::IsDynamicShape(cnode)) {
      MS_LOG(INFO) << "Update Graph Dynamic Attr";
      is_dynamic_shape_ = true;
      return;
    }
  }
  is_dynamic_shape_ = false;
}

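// Collect the flattened graph inputs: tuple/dict parameters are expanded into their element parameters and
// remembered in tuple_backend_front_anf_index_map_, while plain parameters keep a direct front-backend mapping.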
void KernelGraph::SetInputNodes() {
  input_nodes_.clear();
  for (const auto &input_node : inputs()) {
    MS_EXCEPTION_IF_NULL(input_node);
    auto params = common::AnfAlgo::GetAllOutput(input_node);
    auto abs = input_node->abstract();
    MS_EXCEPTION_IF_NULL(abs);
    if (params.size() > 1 ||
        (abs->isa<abstract::AbstractSequence>() && (!common::AnfAlgo::IsDynamicSequence(input_node))) ||
        abs->isa<abstract::AbstractDictionary>()) {
      if (backend_front_anf_map_.find(input_node) == backend_front_anf_map_.end()) {
        MS_LOG(WARNING) << "Cannot find input_node: " << input_node->DebugString() << " in backend_front_anf_map.";
        continue;
      }
      auto front_node = backend_front_anf_map_[input_node];
      for (size_t i = 0; i < params.size(); ++i) {
        // Keep the input_node in the map. Otherwise, the SetInputNodes function is not reentrant.
        tuple_backend_front_anf_index_map_[params[i]] = AnfWithOutIndex(front_node, i);
      }
    } else if (params.size() == 1) {
      FrontBackendlMapUpdate(input_node, params[0]);
    }
    std::copy(params.begin(), params.end(), std::back_inserter(input_nodes_));
  }
}

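// The graph needs the Python GIL if it executes any PyFunc kernel.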
void KernelGraph::UpdateGraphAquireGilAttr() {
  for (const auto &cnode : execution_order_) {
    if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPyFunc)) {
      MS_LOG(INFO) << "The graph requires the GIL. Graph id: " << graph_id_;
      is_need_gil_ = true;
      return;
    }
  }
}

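// Mark the graph as containing optimizer updates: set has_optimizer_ and record the parameter whenever a
// parameter-updating kernel consumes a ref tensor parameter.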
void KernelGraph::SetOptimizerFlag() {
  has_optimizer_ = false;
  for (const auto &cnode : execution_order_) {
    MS_EXCEPTION_IF_NULL(cnode);
    if (!common::AnfAlgo::IsUpdateParameterKernel(cnode)) {
      continue;
    }
    for (auto &input : cnode->inputs()) {
      MS_EXCEPTION_IF_NULL(input);
      auto real_node = common::AnfAlgo::VisitKernel(input, 0).first;
      MS_EXCEPTION_IF_NULL(real_node);
      if (!real_node->isa<Parameter>()) {
        continue;
      }
      auto param = real_node->cast<ParameterPtr>();
      auto abstract = param->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      if (abstract->isa<abstract::AbstractRefTensor>()) {
        has_optimizer_ = true;
        (void)updated_parameters_.insert(param);
      }
    }
  }
}

bool KernelGraph::IsDatasetGraph() const {
  // Check whether the graph contains an InitDataSetQueue node.
  const auto &nodes = execution_order_;
  // The execution order of a dataset graph contains exactly one node.
  if (execution_order_.size() > 1) {
    return false;
  }
  for (const auto &node : nodes) {
    auto node_name = common::AnfAlgo::GetCNodeName(node);
    if (node_name == prim::kPrimInitDataSetQueue->name()) {
      return true;
    }
  }
  return false;
}

std::string KernelGraph::ToString() const {
  std::string prefix = is_from_pynative() ? "pynative_kernel_graph" : "kernel_graph";
  return prefix.append(std::to_string(graph_id_));
}

bool KernelGraph::FrontendNodeExistInFrontBackendMap(const AnfNodePtr &frontend_anf) {
  return front_backend_anf_map_.find(frontend_anf) != front_backend_anf_map_.end();
}

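// Check whether the node is one of the flattened outputs recorded from child graph results.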
bool KernelGraph::IsChildGraphResult(const AnfNodePtr &node) {
  AnfNodePtrList child_graph_results;
  for (const auto &child_graph_result : child_graph_result_) {
    MS_EXCEPTION_IF_NULL(child_graph_result);
    auto outputs = common::AnfAlgo::GetAllOutput(child_graph_result);
    (void)child_graph_results.insert(child_graph_results.cend(), outputs.cbegin(), outputs.cend());
  }

  return find(child_graph_results.begin(), child_graph_results.end(), node) != child_graph_results.end();
}

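// Release the runtime resources bound to this graph id; destructors must not throw, so failures are only logged.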
KernelGraph::~KernelGraph() {
  try {
    device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "KernelGraph destructor failed: " << e.what();
  } catch (...) {
    MS_LOG(ERROR) << "KernelGraph destructor failed";
  }
}

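// Collect the abstracts of the cnode's real inputs (input 0 is the primitive); a missing abstract is an error.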
std::vector<abstract::AbstractBasePtr> FetchInputAbstracts(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<abstract::AbstractBasePtr> abstracts{};
  for (size_t i = 1; i < cnode->size(); ++i) {
    const auto &input = cnode->inputs()[i];
    MS_EXCEPTION_IF_NULL(input);
    const auto &abstract = input->abstract();
    if (abstract == nullptr) {
      MS_LOG(EXCEPTION) << "Invalid abstract for input:" << input->DebugString()
                        << " for node:" << cnode->fullname_with_scope() << " input index:" << i;
    }
    MS_LOG(DEBUG) << "Add abstract:" << abstract->ToString() << " for input:" << input->DebugString();
    abstracts.emplace_back(abstract);
  }
  return abstracts;
}

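// Re-infer the abstract of every primitive cnode in topological order; nodes whose first input is not a
// Primitive value node, and PyExecute nodes, are skipped.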
void KernelGraph::InferType() {
  MS_LOG(DEBUG) << "Start infer type for graph:" << ToString();
  AnfNodePtrList nodes = TopoSort(get_return());
  for (const auto &node : nodes) {
    if (node == nullptr || (!node->isa<CNode>())) {
      continue;
    }
    const auto &cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->inputs().empty() || (!IsValueNode<Primitive>(cnode->input(0))) ||
        common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPyExecute)) {
      continue;
    }
    cnode->set_abstract(nullptr);
    MS_LOG(DEBUG) << "Infer abstract for node:" << node->fullname_with_scope();

    // Fetch input abstracts.
    std::vector<abstract::AbstractBasePtr> abstracts = FetchInputAbstracts(cnode);

    // Fetch the primitive and try to infer the abstract.
    const auto &primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_EXCEPTION_IF_NULL(primitive);
    auto abstract_opt = abstract::TryInferAbstract(primitive, abstracts);
    if (!abstract_opt.has_value()) {
      MS_LOG(EXCEPTION) << "Failed to infer for primitive:" << primitive->ToString()
                        << " in node:" << cnode->fullname_with_scope();
    }
    auto abstract = abstract_opt.value();
    MS_LOG(INFO) << "Set abstract:" << abstract->ToString() << " for node:" << cnode->DebugString();
    cnode->set_abstract(abstract);
  }
}

void KernelGraph::CacheRootWeight(const std::vector<AnfNodePtr> &weights) { root_weights_ = weights; }
}  // namespace session
}  // namespace mindspore