1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/session/kernel_graph.h"
17 #include <algorithm>
18 #include <queue>
19 #include <unordered_set>
20 #include <set>
21 #include <exception>
22 #include "base/core_ops.h"
23 #include "ir/param_info.h"
24 #include "utils/utils.h"
25 #include "utils/check_convert_utils.h"
26 #include "backend/session/anf_runtime_algorithm.h"
27 #include "runtime/device/kernel_info.h"
28 #include "backend/kernel_compiler/kernel_build_info.h"
29 #include "runtime/device/kernel_runtime_manager.h"
30 #include "backend/kernel_compiler/common_utils.h"
31 
32 namespace mindspore {
33 namespace session {
34 namespace {
35 constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
36 constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
37 constexpr size_t k5dDims = 5;
38 const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(),
39                                                        prim::kPrimAssignSub->name()};
40 
41 void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
42                        std::unordered_set<AnfNodePtr> *visited_nodes) {
43   MS_EXCEPTION_IF_NULL(node);
44   MS_EXCEPTION_IF_NULL(que);
45   MS_EXCEPTION_IF_NULL(visited_nodes);
46   if (visited_nodes->find(node) == visited_nodes->end()) {
47     que->push(node);
48     (void)visited_nodes->insert(node);
49     MS_LOG(DEBUG) << "Push que:" << node->DebugString();
50   }
51 }
52 
53 std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
54   auto item_with_index =
55     AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
56   AnfNodePtr node = item_with_index.first;
57   MS_EXCEPTION_IF_NULL(node);
58   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
59     auto outputs = AnfAlgo::GetAllOutput(node);
60     std::set<AnfNodePtr> memo;
61     std::vector<AnfNodePtr> new_output;
62     for (auto &output : outputs) {
63       if (memo.find(output) != memo.end()) {
64         continue;
65       }
66       memo.insert(output);
67       new_output.push_back(output);
68     }
69     if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) {
70       node = new_output[0];
71     }
72   }
73   if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
74     return {node};
75   }
76   std::vector<AnfNodePtr> real_inputs;
77   auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node->cast<CNodePtr>());
78   for (const auto &child_graph : child_graphs) {
79     MS_EXCEPTION_IF_NULL(child_graph);
80     auto real_input = child_graph->output();
81     auto child_real_inputs = GetCallRealOutputs(real_input);
82     std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs));
83   }
84   return real_inputs;
85 }
86 
87 bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
88   if (left == right) {
89     return true;
90   }
91   if (left == nullptr || right == nullptr) {
92     return false;
93   }
94   if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) {
95     return false;
96   }
97   if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) {
98     return AnfAlgo::GetNodeAttr<uint32_t>(left, kAttrLabelIndex) ==
99            AnfAlgo::GetNodeAttr<uint32_t>(right, kAttrLabelIndex);
100   }
101   return false;
102 }
103 
104 void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::string> *device_formats,
105                                std::vector<TypeId> *device_types) {
106   MS_EXCEPTION_IF_NULL(value_node);
107   MS_EXCEPTION_IF_NULL(device_formats);
108   MS_EXCEPTION_IF_NULL(device_types);
109   ValuePtr value = value_node->value();
110   std::vector<tensor::TensorPtr> tensors;
111   TensorValueToTensor(value, &tensors);
112   if (!tensors.empty()) {
113     device_formats->clear();
114     device_types->clear();
115     for (const auto &tensor : tensors) {
116       MS_EXCEPTION_IF_NULL(tensor);
117       auto device_sync = tensor->device_address();
118       if (device_sync != nullptr) {
119         auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);
120         MS_EXCEPTION_IF_NULL(device_address);
121         device_formats->emplace_back(device_address->format());
122         device_types->emplace_back(device_address->type_id());
123         continue;
124       }
125       device_formats->emplace_back(kOpFormat_DEFAULT);
126       device_types->emplace_back(kTypeUnknown);
127     }
128   }
129 }
130 
131 std::string GetNodeGroup(const AnfNodePtr &node) {
132   MS_EXCEPTION_IF_NULL(node);
133   auto cnode = node->cast<CNodePtr>();
134   if (AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
135     return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
136   }
137   return "";
138 }
139 }  // namespace
140 
141 AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) const {
142   MS_EXCEPTION_IF_NULL(node);
143   auto value_node = node->cast<ValueNodePtr>();
144   if (value_node == nullptr) {
145     return nullptr;
146   }
147   ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
148   MS_EXCEPTION_IF_NULL(new_value_node);
149   new_value_node->set_abstract(value_node->abstract());
150   this->SetKernelInfoForNode(new_value_node);
151   return new_value_node;
152 }
153 
154 std::vector<AnfNodePtr> KernelGraph::outputs() const {
155   auto graph_output = output();
156   if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
157     auto make_tuple = output()->cast<CNodePtr>();
158     MS_EXCEPTION_IF_NULL(make_tuple);
159     auto &inputs = make_tuple->inputs();
160     return std::vector<AnfNodePtr>(inputs.begin() + 1, inputs.end());
161   }
162   return std::vector<AnfNodePtr>(1, graph_output);
163 }
164 
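// EnqueueActiveNodes walks the output edges of `node`, decreases each consumer's pending input
// count, and pushes consumers whose count reaches zero into `visit_queue`. With the default
// comm_first == true, newly activated communication ops are queued ahead of the other activated
// nodes (the reverse holds when comm_first is false); Load nodes are expanded immediately rather
// than queued.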
165 void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
166                                      std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
167   MS_EXCEPTION_IF_NULL(visit_queue);
168   MS_EXCEPTION_IF_NULL(visited_nodes);
169   auto it = node_output_edges_.find(node);
170   if (it == node_output_edges_.end()) {
171     // Value nodes and parameters have no inputs, so there is no need to print a log.
172     if (node->isa<CNode>()) {
173       MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
174     }
175     return;
176   }
177   // visit AllReduce (communication) nodes first, then other nodes
178   std::vector<AnfNodePtr> active_nodes;
179   for (const auto &output_edge : it->second) {
180     auto next_node = output_edge.first;
181     MS_EXCEPTION_IF_NULL(next_node);
182     if (node_input_num_.find(next_node) == node_input_num_.end()) {
183       MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
184     }
185     MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
186                   << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
187     if (node_input_num_[next_node] < output_edge.second) {
188       MS_LOG(DEBUG) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node]
189                     << ",depend edge:" << output_edge.second;
190       continue;
191     }
192     node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
193     // allreduce first
194     if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
195       (void)visited_nodes->insert(next_node);
196       bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node);
197       if (AnfAlgo::CheckPrimitiveType(next_node, prim::kPrimLoad)) {
198         EnqueueActiveNodes(next_node, visit_queue, visited_nodes);
199       } else if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
200         MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
201         visit_queue->push(next_node);
202       } else {
203         active_nodes.emplace_back(next_node);
204       }
205     }
206   }
207   for (auto &active_node : active_nodes) {
208     visit_queue->push(active_node);
209   }
210 }
211 
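// SetExecOrderByDefault produces execution_order_ with a breadth-first topological sort over the
// dependency edges built by UpdateNodeEdgeList. Fused communication ops are parked on
// delay_comm_stack and released together once a fusible group (one not containing kSyncBnGroup)
// has been chosen, and descendants of communication ops are drained before the common queue so
// communication kernels appear early in the order. CheckLoop then verifies the graph is acyclic,
// and SortStartLabelAndEndGoto pins the start label and end goto at the boundaries.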
212 void KernelGraph::SetExecOrderByDefault() {
213   std::queue<AnfNodePtr> seed_nodes;
214   UpdateNodeEdgeList(&seed_nodes);
215   execution_order_.clear();
216   std::unordered_set<AnfNodePtr> visited_nodes;
217   std::queue<AnfNodePtr> zero_input_nodes;
218   std::queue<AnfNodePtr> delay_comm_stack;
219   std::queue<AnfNodePtr> communication_descendants;
220   std::string optimized_comm_group;
221   while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
222     // seed nodes first, then delayed comm nodes
223     if (seed_nodes.empty()) {
224       EnqueueActiveNodes(delay_comm_stack.front(), &communication_descendants, &visited_nodes, false);
225       delay_comm_stack.pop();
226     } else {
227       zero_input_nodes.push(seed_nodes.front());
228       seed_nodes.pop();
229     }
230     // communication descendants first, then the common queue
231     while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
232       AnfNodePtr node = nullptr;
233       bool is_communication_descendant = false;
234       if (communication_descendants.empty()) {
235         node = zero_input_nodes.front();
236         zero_input_nodes.pop();
237       } else {
238         node = communication_descendants.front();
239         communication_descendants.pop();
240         is_communication_descendant = true;
241       }
242       // add execute node
243       MS_EXCEPTION_IF_NULL(node);
244       if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
245         execution_order_.push_back(node->cast<CNodePtr>());
246       }
247       // delay execution of comm ops that need to be optimized
248       bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node);
249       bool optimize_comm = false;
250       if (is_fused_comm && optimized_comm_group.empty()) {
251         auto node_group = GetNodeGroup(node);
252         if (node_group.find(kSyncBnGroup) == string::npos) {
253           optimized_comm_group = node_group;
254           optimize_comm = true;
255         }
256       }
257       if (optimize_comm) {
258         while (!delay_comm_stack.empty()) {
259           EnqueueActiveNodes(delay_comm_stack.front(), &communication_descendants, &visited_nodes, false);
260           delay_comm_stack.pop();
261         }
262         delay_comm_stack.push(node);
263       } else if (is_fused_comm) {
264         delay_comm_stack.push(node);
265       } else if (is_communication_descendant) {
266         EnqueueActiveNodes(node, &communication_descendants, &visited_nodes);
267       } else {
268         EnqueueActiveNodes(node, &zero_input_nodes, &visited_nodes);
269       }
270     }
271   }
272   CheckLoop();
273   // resort start label / end goto
274   execution_order_ = SortStartLabelAndEndGoto();
275 }
276 
277 std::vector<CNodePtr> KernelGraph::SortStartLabelAndEndGoto() {
278   std::vector<CNodePtr> re_order;
279   if (start_label_ != nullptr) {
280     re_order.push_back(start_label_);
281   }
282   for (auto &node : execution_order_) {
283     if (node == start_label_ || node == end_goto_) {
284       continue;
285     }
286 
287     if (IsSameLabel(node, end_goto_)) {
288       end_goto_ = node;
289       MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id();
290       continue;
291     }
292 
293     if (IsSameLabel(node, start_label_)) {
294       start_label_ = node;
295       MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id();
296       continue;
297     }
298 
299     //
300     // Re-order:
301     //   u = LabelGoto(...)
302     //   x = Mul(...)
303     //   LabelSet(u)
304     // To:
305     //   u = LabelGoto(...)
306     //   LabelSet(u)
307     //   x = Mul(...)
308     // This prevents Mul from being skipped.
309     //
310     if (IsPrimitiveCNode(node, prim::kPrimLabelSet) && (re_order.back() != node->input(1))) {
311       auto iter = std::find(re_order.rbegin() + 1, re_order.rend(), node->input(1));
312       if (iter != re_order.rend()) {
313         re_order.insert(iter.base(), node);
314         continue;
315       }
316     }
317 
318     re_order.push_back(node);
319   }
320   if (end_goto_ != nullptr) {
321     re_order.push_back(end_goto_);
322   }
323   return re_order;
324 }
325 
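// GetLoopNodesByDFS looks for a cycle among nodes whose pending input count is still non-zero.
// It follows input edges depth-first, recording the path in edge_to_; when it reaches a node
// already visited in the current search it has closed a loop, logs every node on that loop, and
// increments *loop_num.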
326 void KernelGraph::GetLoopNodesByDFS(const AnfNodePtr &node, uint32_t *loop_num) {
327   MS_EXCEPTION_IF_NULL(node);
328   auto node_input_it = node_input_edges_.find(node);
329   if (node_input_it == node_input_edges_.end()) {
330     MS_LOG(DEBUG) << "Node [" << node->DebugString() << "] don't have input edges.";
331     return;
332   }
333   if (*loop_num != 0) {
334     return;
335   }
336   (void)visited_nodes_.insert(node);
337   for (auto &input_edge : node_input_edges_[node]) {
338     size_t input_num = node_input_num_[input_edge.first];
339     if (input_num == 0) {
340       continue;
341     }
342     if (find(visited_nodes_.begin(), visited_nodes_.end(), input_edge.first) == visited_nodes_.end()) {
343       MS_EXCEPTION_IF_NULL(input_edge.first);
344       edge_to_[input_edge.first] = node;
345       GetLoopNodesByDFS(input_edge.first, loop_num);
346     } else {
347       AnfNodePtr node_iter = node;
348       MS_EXCEPTION_IF_NULL(node_iter);
349       MS_LOG(INFO) << "Print loop nodes start:";
350       for (; node_iter != input_edge.first && node_iter != nullptr; node_iter = edge_to_[node_iter]) {
351         loop_nodes_.push(node_iter);
352         node_input_num_[node_iter]--;
353         MS_LOG(INFO) << "Get loop node:" << node_iter->DebugString();
354       }
355       if (node_iter != nullptr) {
356         loop_nodes_.push(node_iter);
357         loop_nodes_.push(node);
358         (*loop_num)++;
359         node_input_num_[node_iter]--;
360         MS_LOG(INFO) << "Get loop node:" << node_iter->DebugString();
361         MS_LOG(INFO) << "Get loop node:" << node->DebugString();
362         MS_LOG(INFO) << "Print loop nodes end, Loop num:" << *loop_num;
363         while (!loop_nodes_.empty()) {
364           loop_nodes_.pop();
365         }
366         return;
367       }
368     }
369   }
370 }
371 
372 uint32_t KernelGraph::GetLoopNum(const std::map<AnfNodePtr, size_t> &none_zero_nodes) {
373   uint32_t loop_num = 0;
374   for (auto &iter : none_zero_nodes) {
375     auto node = iter.first;
376     MS_EXCEPTION_IF_NULL(node);
377     if (node_input_num_[node] == 0) {
378       continue;
379     }
380     edge_to_.clear();
381     visited_nodes_.clear();
382     GetLoopNodesByDFS(node, &loop_num);
383   }
384   return loop_num;
385 }
386 
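// CheckLoop collects every node whose pending input count is still non-zero after the topological
// sort. Such leftovers can only be caused by a cycle in the dependency edges, so the loops are
// reported through GetLoopNum and an exception is raised.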
387 void KernelGraph::CheckLoop() {
388   std::map<AnfNodePtr, size_t> none_zero_nodes;
389   if (node_input_edges_.size() != node_input_num_.size()) {
390     MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size()
391                       << "not equal to node_input_num_ size:" << node_input_num_.size();
392   }
393   for (auto &it : node_input_num_) {
394     MS_EXCEPTION_IF_NULL(it.first);
395     string str;
396     auto node_input_it = node_input_edges_.find(it.first);
397     if (node_input_it == node_input_edges_.end()) {
398       MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]";
399     }
400     if (it.second != 0) {
401       for (const auto &input_edge : node_input_edges_[it.first]) {
402         MS_EXCEPTION_IF_NULL(input_edge.first);
403         str = str.append(input_edge.first->DebugString()).append("|");
404       }
405       MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second;
406       none_zero_nodes[it.first] = it.second;
407     }
408   }
409   // If any node still has unresolved inputs, the graph contains a loop and an exception is thrown.
410   if (!none_zero_nodes.empty()) {
411     MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes);
412     MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
413   }
414 }
415 
416 CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
417   auto cnode = FuncGraph::NewCNode(inputs);
418   MS_EXCEPTION_IF_NULL(cnode);
419   cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
420   if (AnfAlgo::IsGraphKernel(cnode)) {
421     CreateKernelInfoFromNewParameter(cnode);
422   }
423   if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
424     AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
425   }
426   SetKernelInfoForNode(cnode);
427   AnfAlgo::SetGraphId(graph_id_, cnode.get());
428   return cnode;
429 }
430 
431 CNodePtr KernelGraph::NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode) {
432   auto cnode = NewCNode(inputs);
433   if (ori_cnode != nullptr) {
434     cnode->set_attrs(ori_cnode->attrs());
435     cnode->set_primal_attrs(ori_cnode->primal_attrs());
436     cnode->set_primal_debug_infos(ori_cnode->primal_debug_infos());
437   }
438   return cnode;
439 }
440 
441 void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
442   auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
443   MS_EXCEPTION_IF_NULL(func_graph);
444 
445   std::vector<AnfNodePtr> node_list;
446   std::vector<AnfNodePtr> input_list;
447   std::vector<AnfNodePtr> output_list;
448   kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
449   for (auto &anf_node : node_list) {
450     MS_EXCEPTION_IF_NULL(anf_node);
451     if (anf_node->kernel_info() == nullptr) {
452       anf_node->set_kernel_info(std::make_shared<device::KernelInfo>());
453     }
454     auto anf_cnode = anf_node->cast<CNodePtr>();
455     MS_EXCEPTION_IF_NULL(anf_cnode);
456     size_t input_num = AnfAlgo::GetInputTensorNum(anf_cnode);
457     for (size_t i = 0; i < input_num; ++i) {
458       auto input_node = anf_cnode->input(i + 1);
459       MS_EXCEPTION_IF_NULL(input_node);
460       if (IsValueNode<tensor::Tensor>(input_node)) {
461         auto new_input_node = MakeValueNode(input_node);
462         if (new_input_node != nullptr) {
463           anf_cnode->set_input(i + 1, new_input_node);
464         }
465       }
466     }
467   }
468   for (auto &anf_node : input_list) {
469     MS_EXCEPTION_IF_NULL(anf_node);
470     if (anf_node->kernel_info() == nullptr) {
471       anf_node->set_kernel_info(std::make_shared<device::KernelInfo>());
472     }
473   }
474 }
475 
476 void KernelGraph::ResetAssignInputFeatureMapFlag(const CNodePtr &cnode) const {
477   if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
478     MS_LOG(EXCEPTION) << "Only supported to change the node [Assign , AssignSub, AssignAdd] node's input feature map "
479                          "flag but got the node :"
480                       << cnode->DebugString();
481   }
482   auto input_node = AnfAlgo::GetInputNode(cnode, 0);
483   MS_EXCEPTION_IF_NULL(input_node);
484   auto assign_value_node = AnfAlgo::GetInputNode(cnode, 1);
485   if (AnfAlgo::IsFeatureMapOutput(input_node)) {
486     return;
487   }
488   if (!AnfAlgo::IsFeatureMapOutput(input_node) && AnfAlgo::IsFeatureMapOutput(assign_value_node)) {
489     auto kernel_info = dynamic_cast<device::KernelInfo *>(input_node->kernel_info());
490     MS_EXCEPTION_IF_NULL(kernel_info);
491     kernel_info->set_feature_map_flag(true);
492   }
493 }
494 
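// SetKernelInfoForNode attaches a fresh device::KernelInfo to the node. For CNodes it derives the
// feature-map flag from the inputs and, for real kernels, records the kIsFeatureMapOutput and
// kIsFeatureMapInputList attributes. For value nodes and parameters it builds a default kernel
// build info: DEFAULT format, with the device type synced from the tensor for value nodes,
// kTypeUnknown for weights, or the inferred data type for non-weight parameters.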
495 void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
496   MS_EXCEPTION_IF_NULL(node);
497   auto kernel_info = std::make_shared<device::KernelInfo>();
498   MS_EXCEPTION_IF_NULL(kernel_info);
499   node->set_kernel_info(kernel_info);
500   if (node->isa<CNode>()) {
501     if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
502       ResetAssignInputFeatureMapFlag(node->cast<CNodePtr>());
503     }
504 #if defined(__APPLE__)
505     std::vector<int> feature_map_input_indexs;
506 #else
507     std::vector<size_t> feature_map_input_indexs;
508 #endif
509     kernel_info->set_feature_map_flag(false);
510     size_t input_num = AnfAlgo::GetInputTensorNum(node);
511     for (size_t index = 0; index < input_num; ++index) {
512       if (AnfAlgo::IsFeatureMapInput(node, index)) {
513         kernel_info->set_feature_map_flag(true);
514         feature_map_input_indexs.push_back(index);
515       }
516     }
517     if (AnfAlgo::GetInputTensorNum(node) == 0) {
518       kernel_info->set_feature_map_flag(true);
519     }
520     if (AnfAlgo::IsRealKernel(node)) {
521       // if the node has only the primitive (such as GetNext) or one of its inputs is a feature map,
522       // then the node's output is a feature map output
523       AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node);
524       AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node);
525     }
526     return;
527   }
528   auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
529   MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
530   // set the format of value_node to DEFAULT_FORMAT
531   std::vector<TypeId> types;
532   std::vector<std::string> formats = {kOpFormat_DEFAULT};
533   if (node->isa<ValueNode>()) {
534     kernel_info->set_feature_map_flag(false);
535     (void)types.emplace_back(kTypeUnknown);
536     auto value_node = node->cast<ValueNodePtr>();
537     SyncDeviceInfoToValueNode(value_node, &formats, &types);
538   }
539   if (node->isa<Parameter>()) {
540     auto parameter = node->cast<ParameterPtr>();
541     MS_EXCEPTION_IF_NULL(parameter);
542     bool is_weight = AnfAlgo::IsParameterWeight(parameter);
543     kernel_info->set_feature_map_flag(!is_weight);
544     types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
545   }
546   // set the parameter's initial device data type
547   kernel_build_info_builder->SetOutputsFormat(formats);
548   kernel_build_info_builder->SetOutputsDeviceType(types);
549   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get());
550 }
551 
552 CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
553   MS_EXCEPTION_IF_NULL(cnode);
554   auto new_cnode = std::make_shared<CNode>(*cnode);
555   // If a cnode is not created from the front end, it is not in the map, so the map should not be updated when it is replaced.
556   if (BackendNodeExistInFrontBackendMap(cnode)) {
557     FrontBackendlMapUpdate(cnode, new_cnode);
558   }
559   AnfAlgo::SetGraphId(graph_id_, cnode.get());
560   return new_cnode;
561 }
562 
563 ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
564   auto abstract = parameter == nullptr ? std::make_shared<abstract::AbstractNone>() : parameter->abstract();
565   auto new_parameter = NewParameter(abstract);
566   // If the default argument (parameter == nullptr) is not used, a new parameter is created from the old parameter.
567   if (parameter != nullptr) {
568     new_parameter->set_name(parameter->name());
569     if (AnfAlgo::IsParameterWeight(parameter)) {
570       new_parameter->set_default_param(parameter->default_param());
571     }
572   }
573   // create kernel_info for the new parameter
574   SetKernelInfoForNode(new_parameter);
575   AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
576   return new_parameter;
577 }
578 
579 ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) {
580   ParameterPtr new_parameter = add_parameter();
581   new_parameter->set_abstract(abstract);
582   // create kernel_info for the new parameter
583   SetKernelInfoForNode(new_parameter);
584   AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
585   return new_parameter;
586 }
587 
588 ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
589   MS_EXCEPTION_IF_NULL(value_node);
590   auto new_value_node = MakeValueNode(value_node)->cast<ValueNodePtr>();
591   AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
592   return new_value_node;
593 }
594 
595 ValueNodePtr KernelGraph::NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value) {
596   MS_EXCEPTION_IF_NULL(abstract);
597   MS_EXCEPTION_IF_NULL(value);
598   ValueNodePtr new_value_node = std::make_shared<ValueNode>(value);
599   MS_EXCEPTION_IF_NULL(new_value_node);
600   new_value_node->set_abstract(abstract);
601   SetKernelInfoForNode(new_value_node);
602   AnfAlgo::SetGraphId(graph_id(), new_value_node.get());
603   return new_value_node;
604 }
605 
606 ValueNodePtr KernelGraph::NewValueNode(const tensor::TensorPtr &input_tensor) {
607   MS_EXCEPTION_IF_NULL(input_tensor);
608   auto value_node = std::make_shared<ValueNode>(input_tensor);
609   MS_EXCEPTION_IF_NULL(value_node);
610   // construct abstract of value node
611   auto type_of_tensor = input_tensor->Dtype();
612   auto shape_of_tensor = input_tensor->shape();
613   auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
614   value_node->set_abstract(abstract);
615   // add value node to graph
616   auto input_value_node = NewValueNode(value_node);
617   AddValueNodeToGraph(input_value_node);
618   return input_value_node;
619 }
620 
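// TransValueNodeTuple recursively unpacks a tuple-typed value: leaves become individual value
// nodes registered in the graph, and each tuple level becomes a MakeTuple CNode whose abstract
// mirrors the original AbstractTuple. The abstract and value tuples must have the same size.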
621 AnfNodePtr KernelGraph::TransValueNodeTuple(const AbstractBasePtr &abstract, const ValuePtr &value) {
622   MS_EXCEPTION_IF_NULL(abstract);
623   MS_EXCEPTION_IF_NULL(value);
624   if (!abstract->isa<abstract::AbstractTuple>()) {
625     auto new_value_node = NewValueNode(abstract, value);
626     AddValueNodeToGraph(new_value_node);
627     return new_value_node;
628   }
629   auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
630   auto value_tuple = value->cast<ValueTuplePtr>();
631   MS_EXCEPTION_IF_NULL(tuple_abstract);
632   MS_EXCEPTION_IF_NULL(value_tuple);
633   if (tuple_abstract->size() != value_tuple->size()) {
634     MS_LOG(EXCEPTION) << "Abstract size:" << tuple_abstract->size()
635                       << " is not equal to value size:" << value_tuple->size();
636   }
637   std::vector<AnfNodePtr> make_tuple_inputs = {
638     mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
639   for (size_t index = 0; index < tuple_abstract->size(); ++index) {
640     make_tuple_inputs.push_back(TransValueNodeTuple((*tuple_abstract)[index], (*value_tuple)[index]));
641   }
642   auto make_tuple = NewCNode(make_tuple_inputs);
643   MS_EXCEPTION_IF_NULL(make_tuple);
644   make_tuple->set_abstract(tuple_abstract);
645   return make_tuple;
646 }
647 
648 AnfNodePtr KernelGraph::TransParameterTuple(const AbstractBasePtr &abstract) {
649   MS_EXCEPTION_IF_NULL(abstract);
650   if (!abstract->isa<abstract::AbstractTuple>()) {
651     return NewParameter(abstract);
652   }
653   auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
654   MS_EXCEPTION_IF_NULL(tuple_abstract);
655   std::vector<AnfNodePtr> make_tuple_inputs = {
656     mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
657   for (size_t index = 0; index < tuple_abstract->size(); ++index) {
658     make_tuple_inputs.push_back(TransParameterTuple((*tuple_abstract)[index]));
659   }
660   auto make_tuple = NewCNode(make_tuple_inputs);
661   make_tuple->set_abstract(tuple_abstract);
662   return make_tuple;
663 }
664 
665 AnfNodePtr KernelGraph::CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx) {
666   auto idx = mindspore::NewValueNode(SizeToLong(output_idx));
667   MS_EXCEPTION_IF_NULL(idx);
668   auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
669   auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
670   idx->set_abstract(abstract_scalar);
671   AnfNodePtr tuple_getitem = NewCNode({mindspore::NewValueNode(prim::kPrimTupleGetItem), node, idx});
672   MS_EXCEPTION_IF_NULL(tuple_getitem);
673   tuple_getitem->set_scope(node->scope());
674   std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
675   TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
676   AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
677   return tuple_getitem;
678 }
679 
680 AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) {
681   MS_EXCEPTION_IF_NULL(node);
682   std::vector<TypeId> types;
683   std::vector<std::vector<size_t>> shapes;
684   std::vector<AnfNodePtr> make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)};
685   size_t output_num = AnfAlgo::GetOutputTensorNum(node);
686   for (size_t tuple_out_index = 0; tuple_out_index < output_num; ++tuple_out_index) {
687     make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(node, tuple_out_index));
688     types.push_back(AnfAlgo::GetOutputInferDataType(node, tuple_out_index));
689     shapes.emplace_back(AnfAlgo::GetOutputInferShape(node, tuple_out_index));
690   }
691   auto make_tuple = NewCNode(make_tuple_inputs_list);
692   AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
693   return make_tuple;
694 }
695 
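// TransTupleToMakeTuple normalizes a node with tuple output into an explicit MakeTuple:
// parameters go through TransParameterTuple, value nodes through TransValueNodeTuple, and CNodes
// through TransCNodeTuple (one TupleGetItem per output). Nodes without tuple output are returned
// unchanged.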
696 AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) {
697   MS_EXCEPTION_IF_NULL(node);
698   if (!AnfAlgo::IsTupleOutput(node)) {
699     return node;
700   }
701   if (node->isa<Parameter>()) {
702     return TransParameterTuple(node->abstract());
703   } else if (node->isa<ValueNode>()) {
704     auto value_node = node->cast<ValueNodePtr>();
705     MS_EXCEPTION_IF_NULL(value_node);
706     auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value());
707     if (!RemoveValueNodeFromGraph(value_node)) {
708       MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
709     }
710     return make_tuple;
711   } else if (node->isa<CNode>()) {
712     return TransCNodeTuple(node->cast<CNodePtr>());
713   } else {
714     return nullptr;
715   }
716 }
717 
718 const std::vector<AnfNodePtr> &KernelGraph::inputs() const {
719   MS_EXCEPTION_IF_NULL(inputs_);
720   return *inputs_;
721 }
722 
723 void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) {
724   MS_EXCEPTION_IF_NULL(front_anf);
725   MS_EXCEPTION_IF_NULL(backend_anf);
726   if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
727     MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
728   }
729   if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
730     auto front_node = front_anf->cast<CNodePtr>();
731     MS_EXCEPTION_IF_NULL(front_node);
732     auto attr_input = front_node->input(kAnfPrimitiveIndex);
733     MS_EXCEPTION_IF_NULL(attr_input);
734     if (!attr_input->isa<CNode>()) {
735       MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
736     }
737   }
738   front_backend_anf_map_[front_anf] = backend_anf;
739   backend_front_anf_map_[backend_anf] = front_anf;
740 }
741 
742 void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
743   MS_EXCEPTION_IF_NULL(old_backend_anf);
744   MS_EXCEPTION_IF_NULL(new_backend_anf);
745   if (old_backend_anf == new_backend_anf) {
746     MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString();
747     return;
748   }
749   if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
750     MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
751     return;
752   }
753   if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
754     MS_LOG(EXCEPTION) << "Anf is not exist in the map ,old " << old_backend_anf->DebugString();
755   }
756   if (IsInternalOutput(old_backend_anf)) {
757     ReplaceInternalOutput(old_backend_anf, new_backend_anf);
758   }
759   front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
760   backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
761   // delete old kernel
762   (void)backend_front_anf_map_.erase(old_backend_anf);
763 }
764 
765 // get kernel by anf
766 AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
767   if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) {
768     return nullptr;
769   }
770   return front_backend_anf_map_[front_anf];
771 }
772 
773 AnfNodePtr KernelGraph::GetFrontAnfByBackendAnf(const AnfNodePtr &backend_anf) {
774   if (backend_front_anf_map_.find(backend_anf) == backend_front_anf_map_.end()) {
775     return nullptr;
776   }
777   return backend_front_anf_map_[backend_anf];
778 }
779 
780 bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) {
781   return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end();
782 }
783 
784 ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) {
785   if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) {
786     return nullptr;
787   }
788   return tensor_to_value_node_map_[tensor];
789 }
790 
791 void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) {
792   MS_EXCEPTION_IF_NULL(tensor);
793   MS_EXCEPTION_IF_NULL(value_node);
794   tensor_to_value_node_map_[tensor] = value_node;
795 }
796 
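// AddDependEdge records one dependency of weight depend_edge_num in the three bookkeeping maps:
// node_output_edges_[input] gains (node, num), node_input_edges_[node] gains (input, num), and
// node_input_num_[node] grows by num. These are the counters that SetExecOrderByDefault later
// decrements to decide when a node becomes ready.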
797 void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) {
798   MS_EXCEPTION_IF_NULL(node);
799   MS_EXCEPTION_IF_NULL(input);
800   MS_LOG(DEBUG) << "Input:" << input->DebugString() << ",  node:" << node->DebugString() << ",num:" << depend_edge_num;
801   auto output_depend_edge = std::pair<AnfNodePtr, size_t>(node, depend_edge_num);
802   // add output depend edge of input
803   auto output_it = node_output_edges_.find(input);
804   if (output_it == node_output_edges_.end()) {
805     node_output_edges_[input] = std::vector<std::pair<AnfNodePtr, size_t>>{output_depend_edge};
806   } else {
807     output_it->second.push_back(output_depend_edge);
808   }
809   // add input depend edge of output
810   auto input_depend_edge = std::pair<AnfNodePtr, size_t>(input, depend_edge_num);
811   auto input_it = node_input_edges_.find(node);
812   if (input_it == node_input_edges_.end()) {
813     node_input_edges_[node] = std::vector<std::pair<AnfNodePtr, size_t>>{input_depend_edge};
814   } else {
815     input_it->second.push_back(input_depend_edge);
816   }
817   // add node input depend num
818   auto depend_it = node_input_num_.find(node);
819   if (depend_it == node_input_num_.end()) {
820     node_input_num_[node] = depend_edge_num;
821   } else {
822     depend_it->second += depend_edge_num;
823   }
824 }
825 
826 std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
827   MS_EXCEPTION_IF_NULL(node);
828   auto it = node_output_edges_.find(node);
829   if (it == node_output_edges_.end()) {
830     MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]";
831   }
832   std::vector<AnfNodePtr> output_nodes;
833   auto trans = [](const std::pair<AnfNodePtr, size_t> &pair) -> AnfNodePtr { return pair.first; };
834   (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans);
835   return output_nodes;
836 }
837 
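// UpdateNodeEdgeList rebuilds the dependency maps from scratch with a BFS starting at the Return
// node. Parameters and value nodes are collected as seed nodes; for every CNode, each input
// contributes a depend edge of weight 1, and inputs are pushed from right to left so they are
// evaluated from left to right.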
838 void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
839   MS_EXCEPTION_IF_NULL(seed_nodes);
840   node_output_edges_.clear();
841   node_input_num_.clear();
842   node_input_edges_.clear();
843   std::unordered_set<AnfNodePtr> visited_nodes;
844   std::queue<AnfNodePtr> que;
845   que.push(get_return());
846   while (!que.empty()) {
847     auto node = que.front();
848     que.pop();
849     MS_EXCEPTION_IF_NULL(node);
850     if (node->isa<Parameter>() || node->isa<ValueNode>()) {
851       seed_nodes->push(node);
852       continue;
853     }
854     auto cnode = dyn_cast<CNode>(node);
855     if (cnode == nullptr) {
856       continue;
857     }
858     auto &inputs = cnode->inputs();
859     // We push inputs from right to left so that they can be evaluated from left to right.
860     for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
861       auto &input = *iter;
862       PushNoVisitedNode(input, &que, &visited_nodes);
863       AddDependEdge(node, input, 1);
864     }
865   }
866 }
867 
868 void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); }
869 
870 bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; }
871 
872 AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
873   if (!IsInRefOutputMap(out_pair)) {
874     MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap, node is " << out_pair.first->DebugString() << ", index is "
875                       << out_pair.second;
876   }
877   return ref_out_in_map_.at(out_pair);
878 }
879 
880 void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
881   if (IsInRefOutputMap(final_pair)) {
882     MS_LOG(EXCEPTION) << "Out_pair is already in RefOutputMap, node is " << final_pair.first->DebugString()
883                       << ", index is " << final_pair.second;
884   }
885   (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair));
886 }
887 
888 bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
889   if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) {
890     (void)graph_value_nodes_.erase(value_node);
891     return true;
892   }
893   return false;
894 }
895 
896 void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter) {
897   // update graph inputs
898   MS_EXCEPTION_IF_NULL(old_parameter);
899   MS_EXCEPTION_IF_NULL(new_parameter);
900   if (old_parameter == new_parameter) {
901     return;
902   }
903   for (size_t i = 0; i < inputs_->size(); i++) {
904     if ((*inputs_)[i] == old_parameter) {
905       MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_parameter->DebugString()
906                    << ",new graph input:" << new_parameter->DebugString();
907       (*inputs_)[i] = new_parameter;
908       break;
909     }
910   }
911 }
912 
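// ReplaceNode rewires every user of old_anf_node to new_anf_node (skipping pure control edges,
// whose weight is 0), keeps the front/backend map in sync via FrontBackendlMapUpdate, and
// refreshes the dependency edge lists before and after the rewrite.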
913 void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, const AnfNodePtr &new_anf_node) {
914   MS_EXCEPTION_IF_NULL(inputs_);
915   {
916     std::queue<AnfNodePtr> seed_nodes;
917     UpdateNodeEdgeList(&seed_nodes);
918   }
919   auto it = node_output_edges_.find(old_anf_node);
920   if (it != node_output_edges_.end()) {
921     const auto &outputs = it->second;
922     for (auto &output_node : outputs) {
923       MS_EXCEPTION_IF_NULL(output_node.first);
924       auto output_cnode = output_node.first->cast<CNodePtr>();
925       MS_EXCEPTION_IF_NULL(output_cnode);
926       auto &output_node_inputs = output_cnode->inputs();
927       // don't replace node if it is a control edge  => output_node.second == 0
928       if (output_node.second == 0) {
929         continue;
930       }
931       for (size_t i = 1; i < output_node_inputs.size(); i++) {
932         if (output_node_inputs[i] == old_anf_node) {
933           output_cnode->set_input(i, new_anf_node);
934         }
935       }
936     }
937     // update front to backend map
938     FrontBackendlMapUpdate(old_anf_node, new_anf_node);
939   }
940   {
941     std::queue<AnfNodePtr> seed_nodes;
942     UpdateNodeEdgeList(&seed_nodes);
943   }
944 }
945 
946 void KernelGraph::UpdateExecuteKernelStreamLabel() {
947   for (auto &kernel : execution_order_) {
948     AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
949   }
950 }
951 
952 std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
953   std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
954   if (IsLeafGraph()) {
955     leaf_graph_order.push_back(shared_from_this()->cast<KernelGraphPtr>());
956   } else {
957     for (const auto &child_graph : child_graph_order_) {
958       std::shared_ptr<KernelGraph> child_graph_ptr = child_graph.lock();
959       MS_EXCEPTION_IF_NULL(child_graph_ptr);
960       auto child_leaf_graph_order = child_graph_ptr->GetLeafGraphOrder();
961       std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
962     }
963   }
964   return leaf_graph_order;
965 }
966 
967 bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
968 
969 std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
970   std::vector<CNodePtr> result;
971   for (const auto &anf : execution_order_) {
972     MS_EXCEPTION_IF_NULL(anf);
973     if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
974       result.push_back(anf->cast<CNodePtr>());
975     }
976   }
977   return result;
978 }
979 
980 std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const {
981   std::vector<CNodePtr> result;
982   for (const auto &anf : execution_order_) {
983     MS_EXCEPTION_IF_NULL(anf);
984     for (const auto &primitive : primitive_list) {
985       if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
986         result.push_back(anf->cast<CNodePtr>());
987       }
988     }
989   }
990   return result;
991 }
992 
993 void KernelGraph::PrintGraphExecuteOrder() const {
994   if (!(IS_OUTPUT_ON(INFO))) {
995     return;
996   }
997   MS_LOG(INFO) << "Graph " << graph_id_ << " execution order:";
998   for (size_t i = 0; i < execution_order_.size(); i++) {
999     CNodePtr cur_cnode_ptr = execution_order_[i];
1000     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1001 
1002     std::string event_str;
1003     if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
1004       event_str = ", event id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
1005     }
1006 
1007     std::string label_str;
1008     if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
1009       label_str = ", label id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
1010     }
1011 
1012     if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
1013       auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
1014       label_str = ", label id[";
1015       for (size_t j = 0; j < label_list.size(); ++j) {
1016         label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]");
1017       }
1018     }
1019 
1020     std::string active_stream_str;
1021     if (AnfAlgo::HasNodeAttr(kAttrActiveStreamList, cur_cnode_ptr)) {
1022       auto stream_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrActiveStreamList);
1023       active_stream_str = ", active stream id[";
1024       for (size_t j = 0; j < stream_list.size(); ++j) {
1025         active_stream_str += std::to_string(stream_list[j]) + (j + 1 < stream_list.size() ? ", " : "]");
1026       }
1027     }
1028 
1029     std::string group_str;
1030     if (AnfAlgo::GetKernelType(cur_cnode_ptr) == HCCL_KERNEL && AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) {
1031       group_str = ", group[" + AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup) + "]";
1032     }
1033 
1034     MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
1035                  << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
1036                  << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
1037                  << event_str << label_str << active_stream_str << group_str;
1038   }
1039 }
1040 
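// AddInternalOutput records a two-way mapping between the backend node's output_idx-th output and
// the front-end node it backs. For a front TupleGetItem the real output index is read from the
// getitem itself, and unique_target remembers whether that front node is the only consumer.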
1041 void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, size_t output_idx,
1042                                     bool unique_target) {
1043   if (front_node == nullptr || node == nullptr) {
1044     MS_LOG(INFO) << "Front node or node is nullptr";
1045     return;
1046   }
1047   MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
1048   front_to_internal_outputs_map_[front_node] = node;
1049   if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
1050     output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
1051   }
1052   internal_outputs_to_front_map_[node][output_idx] = std::pair<AnfNodePtr, bool>(front_node, unique_target);
1053 }
1054 
1055 void KernelGraph::AddInternalOutputTensor(const AnfNodePtr &node, size_t output_idx, const tensor::TensorPtr &tensor) {
1056   if (node == nullptr) {
1057     return;
1058   }
1059   internal_outputs_tensor_map_[node][output_idx] = tensor;
1060 }
1061 
1062 tensor::TensorPtr KernelGraph::GetInternalOutputTensor(const AnfNodePtr &node, size_t output_idx) {
1063   if (node == nullptr) {
1064     return nullptr;
1065   }
1066   auto iter = internal_outputs_tensor_map_.find(node);
1067   if (iter == internal_outputs_tensor_map_.end()) {
1068     return nullptr;
1069   }
1070   auto idx_iter = iter->second.find(output_idx);
1071   if (idx_iter == iter->second.end()) {
1072     return nullptr;
1073   }
1074   return idx_iter->second;
1075 }
1076 
1077 void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
1078   if (new_node == nullptr || node == nullptr) {
1079     MS_LOG(INFO) << "New node or node is nullptr";
1080     return;
1081   }
1082   if (node == new_node) {
1083     MS_LOG(INFO) << "New node and node is the same";
1084     return;
1085   }
1086   auto iter = internal_outputs_to_front_map_.find(node);
1087   if (iter == internal_outputs_to_front_map_.end()) {
1088     MS_LOG(INFO) << "Node is not internal output";
1089     return;
1090   }
1091   MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
1092   auto &front_nodes = iter->second;
1093   // Move all front nodes to new node mapping
1094   internal_outputs_to_front_map_[new_node] = front_nodes;
1095   for (const auto &front_node_iter : front_nodes) {
1096     front_to_internal_outputs_map_[front_node_iter.second.first] = new_node;
1097   }
1098   internal_outputs_to_front_map_.erase(iter);
1099 }
1100 
1101 void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, size_t src_output_idx,
1102                                         size_t dst_output_idx) {
1103   if (new_node == nullptr || node == nullptr) {
1104     MS_LOG(INFO) << "New node or node is nullptr";
1105     return;
1106   }
1107   if (node == new_node) {
1108     MS_LOG(INFO) << "New node and node is the same";
1109     return;
1110   }
1111   auto iter = internal_outputs_to_front_map_.find(node);
1112   if (iter == internal_outputs_to_front_map_.end()) {
1113     MS_LOG(INFO) << "Node is not internal output";
1114     return;
1115   }
1116   MS_LOG(INFO) << "Replace internal output node " << node->DebugString() << " to " << new_node->DebugString();
1117   auto &front_nodes = iter->second;
1118   // Move specified front node to new node mapping
1119   auto front_node_iter = front_nodes.find(src_output_idx);
1120   if (front_node_iter == front_nodes.end()) {
1121     MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node";
1122     return;
1123   }
1124   auto front_node_pair = front_node_iter->second;
1125   internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node_pair;
1126   front_to_internal_outputs_map_[front_node_pair.first] = new_node;
1127   front_nodes.erase(src_output_idx);
1128   if (front_nodes.empty()) {
1129     internal_outputs_to_front_map_.erase(iter);
1130   }
1131 }
1132 
1133 void KernelGraph::CacheInternalParameterToFrontNode(const AnfNodePtr &parameter,
1134                                                     const AnfWithOutIndex &front_node_with_index) {
1135   if ((parameter == nullptr) || (front_node_with_index.first == nullptr)) {
1136     return;
1137   }
1138 
1139   auto front_outputs = AnfAlgo::GetAllOutputWithIndex(front_node_with_index.first);
1140   AnfWithOutIndex new_front_node_with_index;
1141   if (front_node_with_index.second < front_outputs.size()) {
1142     new_front_node_with_index = front_outputs[front_node_with_index.second];
1143   } else {
1144     new_front_node_with_index = front_node_with_index;
1145   }
1146 
1147   if (new_front_node_with_index.first == nullptr) {
1148     return;
1149   }
1150   MS_LOG(INFO) << "Cache internal parameter: " << parameter->DebugString()
1151                << " to front node: " << new_front_node_with_index.first->DebugString()
1152                << " with index: " << new_front_node_with_index.second
1153                << ", from front node: " << front_node_with_index.first->DebugString()
1154                << " with index: " << front_node_with_index.second;
1155   internal_parameter_to_front_node_map_[parameter] = new_front_node_with_index;
1156 }
1157 
1158 AnfWithOutIndex KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
1159   const auto &iter = internal_parameter_to_front_node_map_.find(parameter);
1160   if (iter != internal_parameter_to_front_node_map_.end()) {
1161     return iter->second;
1162   }
1163   return AnfWithOutIndex();
1164 }
1165 
1166 FuncGraphPtr KernelGraph::GetFuncGraph() {
1167   if (front_backend_anf_map_.empty()) {
1168     return nullptr;
1169   }
1170 
1171   for (const auto &front_backend_anf : front_backend_anf_map_) {
1172     const auto &front_node = front_backend_anf.first;
1173     const auto &func_graph = front_node->func_graph();
1174     if (func_graph != nullptr) {
1175       return func_graph;
1176     }
1177   }
1178   return nullptr;
1179 }
1180 
1181 void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output,
1182                                                        const AnfNodePtr &front_node) {
1183   if ((backend_graph_output == nullptr) || (front_node == nullptr)) {
1184     return;
1185   }
1186 
1187   auto backend_outputs = AnfAlgo::GetAllOutputWithIndex(backend_graph_output);
1188   auto front_outputs = AnfAlgo::GetAllOutputWithIndex(front_node);
1189   if (backend_outputs.size() != front_outputs.size()) {
1190     MS_LOG(INFO) << "The size(" << backend_outputs.size()
1191                  << ") of backend output: " << backend_graph_output->DebugString() << " is not equal to the size("
1192                  << front_outputs.size() << ") of front output: " << front_node->DebugString();
1193     return;
1194   }
1195 
1196   for (size_t i = 0; i < backend_outputs.size(); ++i) {
1197     auto backend_output = backend_outputs[i];
1198     auto front_output = front_outputs[i];
1199     graph_output_to_front_node_map_[backend_output] = front_output;
1200     MS_LOG(INFO) << "Backend output: " << backend_output.first->fullname_with_scope()
1201                  << " with index: " << backend_output.second
1202                  << " map to front node: " << front_output.first->fullname_with_scope()
1203                  << " with index: " << front_output.second;
1204   }
1205 }
1206 
1207 AnfWithOutIndex KernelGraph::GetFrontNodeWithIndexByGraphOutput(
1208   const AnfWithOutIndex &backend_graph_output_with_index) const {
1209   const auto &iter = graph_output_to_front_node_map_.find(backend_graph_output_with_index);
1210   if (iter != graph_output_to_front_node_map_.end()) {
1211     return iter->second;
1212   }
1213   return AnfWithOutIndex();
1214 }
1215 
1216 void KernelGraph::UpdateGraphOutputMap(const std::vector<AnfWithOutIndex> &old_outputs,
1217                                        const std::vector<AnfWithOutIndex> &new_outputs) {
1218   MS_LOG(INFO) << "The size of old outputs: " << old_outputs.size()
1219                << ", the size of new outputs: " << new_outputs.size();
1220   if (old_outputs.size() != new_outputs.size()) {
1221     MS_LOG(EXCEPTION) << "The size of old outputs is not equal to the size of new outputs.";
1222   }
1223 
1224   for (size_t i = 0; i < old_outputs.size(); ++i) {
1225     auto old_output = old_outputs[i];
1226     auto new_output = new_outputs[i];
1227     if (old_output == new_output) {
1228       continue;
1229     }
1230     // Update the graph output map.
1231     if (graph_output_to_front_node_map_.count(old_output) > 0) {
1232       MS_LOG(INFO) << "Replace backend output node " << old_output.first->fullname_with_scope() << " with index "
1233                    << old_output.second << " to " << new_output.first->fullname_with_scope() << " with index "
1234                    << new_output.second;
1235       graph_output_to_front_node_map_[new_output] = graph_output_to_front_node_map_[old_output];
1236       (void)graph_output_to_front_node_map_.erase(old_output);
1237     }
1238 
1239     if (old_output.first == new_output.first) {
1240       continue;
1241     }
1242     // Update the front backend node map.
1243     if ((backend_front_anf_map_.count(old_output.first) > 0) && old_output.first->isa<CNode>() &&
1244         new_output.first->isa<CNode>()) {
1245       MS_LOG(INFO) << "Replace backend output node " << old_output.first->fullname_with_scope() << " to "
1246                    << new_output.first->fullname_with_scope();
1247       auto front_node = backend_front_anf_map_[old_output.first];
1248       front_backend_anf_map_[front_node] = new_output.first;
1249       backend_front_anf_map_[new_output.first] = front_node;
1250       (void)backend_front_anf_map_.erase(old_output.first);
1251     }
1252   }
1253 }
1254 
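// Returns the backend internal-output node recorded for the given front node, or
// nullptr when the front node has no internal output in this graph.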
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
  auto iter = front_to_internal_outputs_map_.find(front_node);
  if (iter != front_to_internal_outputs_map_.end()) {
    return iter->second;
  }
  return nullptr;
}

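// The two IsInternalOutput overloads below answer membership queries against
// internal_outputs_to_front_map_: the first checks whether the node has any internal
// output recorded, the second checks a specific output index of that node.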
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  return true;
}

bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, size_t output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  auto &front_nodes = front_nodes_iter->second;
  if (front_nodes.find(output_idx) == front_nodes.end()) {
    return false;
  }
  return true;
}

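// Reports the unique-target flag stored with the (node, output_idx) entry in
// internal_outputs_to_front_map_; returns false when the node or the index is not
// recorded as an internal output.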
bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, size_t output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  auto &front_nodes = front_nodes_iter->second;
  auto idx_iter = front_nodes.find(output_idx);
  if (idx_iter == front_nodes.end()) {
    return false;
  }
  return idx_iter->second.second;
}

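// Rebuilds child_graph_order_: recompute the execution order, collect all Call, Switch
// and SwitchLayer nodes, and record the kernel graphs they reference in encounter order.
// Every child graph other than this graph's own parent also gets its parent_graph set to
// this graph.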
void KernelGraph::UpdateChildGraphOrder() {
  MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
  SetExecOrderByDefault();
  auto call_nodes = FindNodeByPrimitive({std::make_shared<Primitive>(prim::kPrimCall->name()),
                                         std::make_shared<Primitive>(prim::kPrimSwitch->name()),
                                         std::make_shared<Primitive>(prim::kPrimSwitchLayer->name())});
  std::vector<std::weak_ptr<KernelGraph>> child_graph_order;
  for (auto &call_node : call_nodes) {
    MS_EXCEPTION_IF_NULL(call_node);
    auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast<CNodePtr>());
    for (const auto &child_graph : call_child_graphs) {
      MS_EXCEPTION_IF_NULL(child_graph);
      if (child_graph != parent_graph_.lock()) {
        auto shared_this = std::dynamic_pointer_cast<KernelGraph>(shared_from_this());
        MS_EXCEPTION_IF_NULL(shared_this);
        child_graph->set_parent_graph(shared_this);
      }
      child_graph_order.push_back(child_graph);
    }
  }
  for (size_t i = 0; i < child_graph_order.size(); ++i) {
    std::shared_ptr<KernelGraph> child_graph = child_graph_order[i].lock();
    MS_EXCEPTION_IF_NULL(child_graph);
    MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph->graph_id() << "]";
  }
  child_graph_order_ = child_graph_order;
}

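// Detaches a node from the graph's bookkeeping: erases it from the backend/front anf
// maps and, for value nodes, from graph_value_nodes_. The node object itself is not
// freed here.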
void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (backend_front_anf_map_.find(node) != backend_front_anf_map_.end()) {
    auto front_node = backend_front_anf_map_[node];
    (void)backend_front_anf_map_.erase(node);
    (void)front_backend_anf_map_.erase(front_node);
  }
  if (node->isa<ValueNode>()) {
    if (graph_value_nodes_.find(node->cast<ValueNodePtr>()) != graph_value_nodes_.end()) {
      (void)graph_value_nodes_.erase(node->cast<ValueNodePtr>());
    }
  }
}

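// Marks the graph as dynamic-shape if any kernel in the execution order is dynamic;
// otherwise clears the flag.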
void KernelGraph::UpdateGraphDynamicAttr() {
  for (const auto &cnode : execution_order_) {
    if (AnfAlgo::IsDynamicShape(cnode)) {
      MS_LOG(INFO) << "Update Graph Dynamic Attr";
      is_dynamic_shape_ = true;
      return;
    }
  }
  is_dynamic_shape_ = false;
}

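// Rebuilds input_nodes_ by expanding every graph input through AnfAlgo::GetAllOutput,
// so multi-output inputs contribute each of their outputs individually.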
void KernelGraph::SetInputNodes() {
  input_nodes_.clear();
  for (const auto &input_node : inputs()) {
    auto params = AnfAlgo::GetAllOutput(input_node);
    std::copy(params.begin(), params.end(), std::back_inserter(input_nodes_));
  }
}

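// Scans the execution order to decide whether the graph runs an optimizer: any kernel
// in kOptOperatorSet sets has_optimizer_; in addition, for such kernels and for nodes
// whose name contains "Assign", every Parameter input with an AbstractRef abstract also
// sets the flag and is recorded in updated_parameters_. Nodes marked with the async
// attribute are skipped.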
void KernelGraph::SetOptimizerFlag() {
  has_optimizer_ = false;
  for (const auto &cnode : execution_order_) {
    MS_EXCEPTION_IF_NULL(cnode);
    auto node_name = AnfAlgo::GetCNodeName(cnode);
    if (AnfAlgo::HasNodeAttr(kAttrAsync, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrAsync)) {
      continue;
    }
    if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) {
      has_optimizer_ = true;
    } else if (node_name.find("Assign") == string::npos) {
      continue;
    }
    for (auto &input : cnode->inputs()) {
      MS_EXCEPTION_IF_NULL(input);
      auto real_node = AnfAlgo::VisitKernel(input, 0).first;
      MS_EXCEPTION_IF_NULL(real_node);
      if (!real_node->isa<Parameter>()) {
        continue;
      }
      auto param = real_node->cast<ParameterPtr>();
      auto abstract = param->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      if (abstract->isa<abstract::AbstractRef>()) {
        has_optimizer_ = true;
        (void)updated_parameters_.insert(param);
      }
    }
  }
}

bool KernelGraph::IsDatasetGraph() const {
  // Check whether the graph contains an InitDataSetQueue node.
  const auto &nodes = execution_order_;
  for (const auto &node : nodes) {
    auto node_name = AnfAlgo::GetCNodeName(node);
    if (node_name == prim::kPrimInitDataSetQueue->name()) {
      return true;
    }
  }
  return false;
}

std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }

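// The destructor releases per-kernel resources and the runtime resources bound to this
// graph id; any exception is caught and logged so destruction never throws.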
KernelGraph::~KernelGraph() {
  try {
    // Release the kernel resource.
    for (const auto &kernel : execution_order_) {
      auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
      if (kernel_mod != nullptr) {
        kernel_mod->ReleaseResource();
      }
    }
    device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "KernelGraph call destructor failed: " << e.what();
  } catch (...) {
    MS_LOG(ERROR) << "KernelGraph call destructor failed";
  }
}
}  // namespace session
}  // namespace mindspore