• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/backend/optimizer/helper.h"
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <algorithm>
23 #include <map>
24 #include <set>
25 #include <deque>
26 #include <vector>
27 #include "kernel/kernel_build_info.h"
28 #include "mindspore/core/ops/sequence_ops.h"
29 #include "mindspore/core/ops/nn_ops.h"
30 #include "mindspore/core/ops/array_ops.h"
31 #include "mindspore/core/ops/framework_ops.h"
32 #include "utils/hash_set.h"
33 #include "include/common/utils/utils.h"
34 #include "base/base_ref.h"
35 #include "include/backend/anf_runtime_algorithm.h"
36 #include "include/common/utils/anfalgo.h"
37 #include "utils/log_adapter.h"
38 #include "utils/ms_utils.h"
39 #include "include/common/utils/convert_utils.h"
40 #include "include/backend/kernel_info.h"
41 #include "utils/ms_context.h"
42 #include "utils/trace_base.h"
43 #include "backend/common/pass/const_input_to_attr.h"
44 #include "backend/operator/ops_backend_infer_function.h"
45 #include "frontend/operator/ops_front_infer_function.h"
46 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
47 #include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
48 #include "include/common/profiler.h"
49 #include "abstract/ops/primitive_infer_map.h"
50 
51 namespace mindspore {
52 namespace opt {
53 namespace {
// Byte widths used when copying scalar values into tensor buffers.
constexpr size_t kType32Len = 4;
constexpr size_t kType64Len = 8;
// A nop node forwards its first real input (input index 1; index 0 is the primitive).
constexpr auto kNopNodeRealInputIndex = 1;
// Required input dtypes per op, keyed by primitive name and input index.
// Currently only GroupedMatmul is constrained. NOTE(review): the index->dtype
// pairs presumably mirror the kernel's expected operand layout — confirm
// against the GroupedMatmul kernel definition.
const std::map<std::string, std::map<size_t, TypeId>> OpInputDtypeMap = {{prim::kPrimGroupedMatmul->name(),
                                                                          {{2, TypeId::kNumberTypeFloat16},
                                                                           {3, TypeId::kNumberTypeUInt64},
                                                                           {4, TypeId::kNumberTypeFloat32},
                                                                           {5, TypeId::kNumberTypeFloat16},
                                                                           {6, TypeId::kNumberTypeFloat16}}}};
63 
UpdateDumpFlagAndDebugInfo(const CNodePtr & node,const std::vector<AnfNodePtr> & orig_nodes)64 void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
65   MS_EXCEPTION_IF_NULL(node);
66   std::vector<AnfNodePtr> orig_real_cnodes;
67   for (auto &orig_node : orig_nodes) {
68     MS_EXCEPTION_IF_NULL(orig_node);
69     if (AnfUtils::IsRealCNodeKernel(orig_node)) {
70       auto orig_cnode = orig_node->cast<CNodePtr>();
71       MS_EXCEPTION_IF_NULL(orig_cnode);
72       if (common::AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
73         common::AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
74       }
75       orig_real_cnodes.push_back(orig_node);
76     }
77   }
78 
79   node->AddFusedDebugInfoList(orig_real_cnodes);
80 }
81 }  // namespace
82 
IsDepend(const FuncGraph & graph,const AnfNodePtr & node,const std::vector<AnfNodePtr> & nodes)83 bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
84   mindspore::HashSet<AnfNodePtr> visited_nodes;
85   return IsDepend(graph, node, nodes, &visited_nodes);
86 }
87 
IsDepend(const FuncGraph & graph,const AnfNodePtr & node,const std::vector<AnfNodePtr> & nodes,mindspore::HashSet<AnfNodePtr> * visited_nodes)88 bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes,
89               mindspore::HashSet<AnfNodePtr> *visited_nodes) {
90   MS_EXCEPTION_IF_NULL(node);
91   MS_EXCEPTION_IF_NULL(visited_nodes);
92   FuncGraphManagerPtr manager = graph.manager();
93   MS_EXCEPTION_IF_NULL(manager);
94 
95   std::deque<AnfNodePtr> todo{node};
96   while (!todo.empty()) {
97     AnfNodePtr nd = todo.front();
98     todo.pop_front();
99     if (visited_nodes->count(nd) > 0 || !manager->all_nodes().contains(nd)) {
100       continue;
101     }
102     (void)visited_nodes->insert(nd);
103 
104     if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
105       return true;
106     }
107     if (nd->isa<CNode>()) {
108       auto cnode = nd->cast<CNodePtr>();
109       MS_EXCEPTION_IF_NULL(cnode);
110       auto inputs = cnode->inputs();
111       (void)todo.insert(todo.cend(), inputs.cbegin(), inputs.cend());
112     }
113   }
114   return false;
115 }
116 
UnVisited(const BaseRef & n)117 bool UnVisited(const BaseRef &n) {
118   if (utils::isa<AnfNodePtr>(n)) {
119     AnfNodePtr in = utils::cast<AnfNodePtr>(n);
120     MS_EXCEPTION_IF_NULL(in);
121     if (IsValueNode<Primitive>(in)) {
122       auto value_node = in->cast<ValueNodePtr>();
123       MS_EXCEPTION_IF_NULL(value_node);
124       auto value = value_node->value();
125       MS_EXCEPTION_IF_NULL(value);
126       auto prim_py = value->cast<PrimitivePtr>();
127       MS_EXCEPTION_IF_NULL(prim_py);
128       return !prim_py->HasAttr(kAttrVisited);
129     } else if (IsValueNode<FuncGraph>(in)) {
130       auto func_graph = GetValueNode<FuncGraphPtr>(in);
131       MS_EXCEPTION_IF_NULL(func_graph);
132       return !func_graph->has_flag(kAttrVisited);
133     }
134     return false;
135   }
136   return false;
137 }
138 
NewCNode(const std::vector<AnfNodePtr> & inputs,const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & orig_nodes)139 CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
140                   const std::vector<AnfNodePtr> &orig_nodes) {
141   MS_EXCEPTION_IF_NULL(fg);
142   auto node = fg->NewCNode(inputs);
143   MS_EXCEPTION_IF_NULL(node);
144   UpdateDumpFlagAndDebugInfo(node, orig_nodes);
145   return node;
146 }
147 
// Clones `cnode` into the kernel graph `fg`, inheriting dump flag and fused
// debug info from the nodes it replaces.
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
  MS_EXCEPTION_IF_NULL(fg);
  auto new_node = fg->NewCNode(cnode);
  MS_EXCEPTION_IF_NULL(new_node);
  UpdateDumpFlagAndDebugInfo(new_node, orig_nodes);
  return new_node;
}
155 
CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr & node,size_t input_size)156 CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size) {
157   MS_EXCEPTION_IF_NULL(node);
158   if (!node->isa<CNode>()) {
159     MS_LOG(INTERNAL_EXCEPTION) << "The node is expected to be a cnode";
160   }
161   auto cnode = node->cast<CNodePtr>();
162   CheckCNodeInputSize(cnode, input_size);
163   return cnode;
164 }
165 
CheckCNodeInputSize(const CNodePtr & cnode,size_t input_tensor_size)166 void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) {
167   MS_EXCEPTION_IF_NULL(cnode);
168   auto real_input_tensor_num = common::AnfAlgo::GetInputTensorNum(cnode);
169   if (real_input_tensor_num != input_tensor_size) {
170     MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num
171                       << "] of node [" + cnode->DebugString() + "] is not equal to " << input_tensor_size
172                       << trace::DumpSourceLines(cnode);
173   }
174 }
175 
HasSymmetricalKernelInfo(const AnfNodePtr & node_x,const AnfNodePtr & node_y)176 bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) {
177   MS_EXCEPTION_IF_NULL(node_x);
178   MS_EXCEPTION_IF_NULL(node_y);
179   return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) &&
180           AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0));
181 }
182 
// Rewrites TransOp(Depend(TransOp(x), attach)) into Depend(x, attach): the two
// trans ops cancel out while the control edge carried by Depend is preserved.
// Returns the freshly built Depend node carrying x's abstract.
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  // `node` must be a trans op whose single tensor input is a Depend, whose
  // first input is in turn another trans op; each helper call validates shape.
  auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
  MS_EXCEPTION_IF_NULL(transop_cnode);
  auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum);
  auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputTensorNum);
  // The tensor originally fed into the inner trans op; both trans ops drop out.
  auto transed_node = prev_transop_cnode->input(1);
  MS_EXCEPTION_IF_NULL(transed_node);

  // Rebuild Depend(transed_node, attach) so the control dependency survives.
  std::vector<AnfNodePtr> replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node,
                                                depend_cnode->input(kDependAttachNodeIndex)};
  AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs);
  MS_EXCEPTION_IF_NULL(replace_depend);
  // The replacement yields the un-transformed tensor, so take its abstract.
  auto transed_abstract = transed_node->abstract();
  replace_depend->set_abstract(transed_abstract);
  return replace_depend;
}
202 
Visited(const BaseRef & n)203 bool Visited(const BaseRef &n) {
204   if (utils::isa<AnfNodePtr>(n)) {
205     AnfNodePtr in = utils::cast<AnfNodePtr>(n);
206     MS_EXCEPTION_IF_NULL(in);
207     if (IsValueNode<Primitive>(in)) {
208       auto value_node = in->cast<ValueNodePtr>();
209       MS_EXCEPTION_IF_NULL(value_node);
210       auto value = value_node->value();
211       MS_EXCEPTION_IF_NULL(value);
212       auto prim_py = value->cast<PrimitivePtr>();
213       MS_EXCEPTION_IF_NULL(prim_py);
214       return prim_py->HasAttr(kAttrVisited);
215     } else if (IsValueNode<FuncGraph>(in)) {
216       auto func_graph = GetValueNode<FuncGraphPtr>(in);
217       MS_EXCEPTION_IF_NULL(func_graph);
218       return func_graph->has_flag(kAttrVisited);
219     }
220     return false;
221   }
222   return false;
223 }
224 
CreateMultipleOutputsOfAnfNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_num,std::vector<AnfNodePtr> * outputs)225 void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
226                                     std::vector<AnfNodePtr> *outputs) {
227   MS_EXCEPTION_IF_NULL(func_graph);
228   MS_EXCEPTION_IF_NULL(node);
229   MS_EXCEPTION_IF_NULL(outputs);
230   auto type_ptr = node->Type();
231   for (size_t i = 0; i < output_num; i++) {
232     int64_t temp = SizeToLong(i);
233     auto idx = NewValueNode(temp);
234     MS_EXCEPTION_IF_NULL(idx);
235     auto imm = std::make_shared<Int64Imm>(temp);
236     auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
237     idx->set_abstract(abstract_scalar);
238     auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
239     MS_EXCEPTION_IF_NULL(tuple_getitem);
240     tuple_getitem->set_scope(node->scope());
241     common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(type_ptr, i)},
242                                                 {common::AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get());
243     (*outputs).push_back(tuple_getitem);
244   }
245 }
246 
247 template <typename T>
CreateTensorWithValueTuple(const ValueTuplePtr & value_tuple_ptr,const TypePtr & type_ptr,size_t data_length)248 tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
249                                              size_t data_length) {
250   MS_EXCEPTION_IF_NULL(value_tuple_ptr);
251   MS_EXCEPTION_IF_NULL(type_ptr);
252   std::vector<T> values;
253   for (const auto &v : value_tuple_ptr->value()) {
254     MS_EXCEPTION_IF_NULL(v);
255     if (v->isa<Scalar>()) {
256       ScalarPtr scalar = v->cast<ScalarPtr>();
257       values.push_back(GetValue<T>(scalar));
258     } else {
259       MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
260       return nullptr;
261     }
262   }
263   std::vector<int64_t> tensor_shape = {SizeToLong(values.size())};
264   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
265   MS_EXCEPTION_IF_NULL(tensor);
266   tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
267   tensor->set_device_info(device_info);
268   auto data_ptr = tensor->data_c();
269   MS_EXCEPTION_IF_NULL(data_ptr);
270   auto elem_num = values.size() * data_length;
271   auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
272   if (ret_code != EOK) {
273     MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
274   }
275   return tensor;
276 }
277 
CreateEmptyTupleTensor(const ValueTuplePtr & value_tuple)278 tensor::TensorPtr CreateEmptyTupleTensor(const ValueTuplePtr &value_tuple) {
279   std::vector<int64_t> tensor_shape = {0};
280   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kInt64->type_id(), tensor_shape);
281   MS_EXCEPTION_IF_NULL(tensor);
282   tensor::DeviceInfo device_info{kOpFormat_DEFAULT, kInt64};
283   tensor->set_device_info(device_info);
284   tensor->set_user_data(kTensorValueIsEmpty, value_tuple);
285   return tensor;
286 }
287 
// Converts a value node holding a scalar / tuple / list into a tensor value
// node so the constant can be fed to kernels as a tensor. Returns nullptr when
// the value cannot be represented (e.g. a tuple of non-scalars). When
// `kernel_graph` is non-null, the new value node is registered in the graph
// and mapped back to the original front-end node.
AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(input_node);
  auto value_node = input_node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  tensor::TensorPtr tensor_ptr = nullptr;
  if (value->isa<Scalar>()) {
    tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
  } else if (value->isa<ValueTuple>()) {
    tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
  } else if (value->isa<ValueList>()) {
    // Lists are converted by wrapping their elements into a tuple first.
    tensor_ptr = CreateTupleTensor(std::make_shared<ValueTuple>(value->cast<ValueListPtr>()->value()));
  } else {
    MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple";
  }
  if (tensor_ptr == nullptr) {
    MS_LOG(DEBUG) << "Create tensor failed";
    return nullptr;
  }
  auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
  MS_EXCEPTION_IF_NULL(tensor_input);
  tensor_input->set_abstract(tensor_ptr->ToAbstract());
  if (kernel_graph != nullptr) {
    // Let the kernel graph own the value node and keep the front/backend node
    // mapping in sync so the original node can still be located later.
    tensor_input = kernel_graph->NewValueNode(tensor_input);
    kernel_graph->AddValueNodeToGraph(tensor_input);
    kernel_graph->FrontBackendlMapUpdate(input_node, tensor_input);
  } else {
    tensor_input = MakeValueNode(tensor_input);
  }
  tensor_input->set_scope(input_node->scope());
  return tensor_input;
}
321 
CreateTupleTensor(const ValueTuplePtr & value_tuple)322 tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
323   MS_EXCEPTION_IF_NULL(value_tuple);
324   tensor::TensorPtr tensor = nullptr;
325   if (value_tuple->value().empty()) {
326     tensor = CreateEmptyTupleTensor(value_tuple);
327     return tensor;
328   }
329   ValuePtr v = *(value_tuple->value().begin());
330   MS_EXCEPTION_IF_NULL(v);
331   // Currently we only deal with the scalar tuple
332   if (!v->isa<Scalar>()) {
333     MS_LOG(DEBUG) << "The value " << v << "of tuple is not a scalar";
334     return nullptr;
335   }
336   ScalarPtr scalar = v->cast<ScalarPtr>();
337   MS_EXCEPTION_IF_NULL(scalar);
338   if (scalar->isa<Int32Imm>()) {
339     tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, kInt32, sizeof(int32_t));
340   } else if (scalar->isa<Int64Imm>()) {
341     tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, kInt64, sizeof(int64_t));
342   } else if (scalar->isa<FloatImm>()) {
343     tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, sizeof(float));
344   } else {
345     auto type = scalar->type();
346     auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
347     MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
348     return nullptr;
349   }
350   return tensor;
351 }
352 
// Wraps `node` in a TensorMove cnode that carries the same abstract and scope.
AnfNodePtr CreateTensorMoveOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  std::vector<AnfNodePtr> tensor_move_inputs = {NewValueNode(std::make_shared<Primitive>(kTensorMoveOpName)), node};
  auto tensor_move = graph->NewCNode(tensor_move_inputs);
  MS_EXCEPTION_IF_NULL(tensor_move);
  tensor_move->set_abstract(node->abstract());
  tensor_move->set_scope(node->scope());
  // Mark with an empty original-names list for data dump bookkeeping.
  common::AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), tensor_move);
  return tensor_move;
}
365 
// Inserts a TensorMove after every user of `node` that resolves (through nop
// nodes and ref-output correspondences) to a real graph output, replacing that
// user with the TensorMove. Returns the inserted TensorMove nodes.
std::vector<AnfNodePtr> InsertTensorMoveForGraphOutput(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);

  std::vector<AnfNodePtr> ret;
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto &node_users = manager->node_users();
  auto iter = node_users.find(node);
  if (iter == node_users.end()) {
    // `node` has no users: nothing to insert.
    return ret;
  }
  for (auto &item : iter->second) {
    MS_EXCEPTION_IF_NULL(item.first);
    auto next_node = item.first->cast<CNodePtr>();
    bool find = false;
    // NOTE(review): the output list is recomputed for every user; this looks
    // necessary because manager->Replace below mutates the graph — confirm
    // before hoisting this call out of the loop.
    auto graph_outputs_pair =
      common::AnfAlgo::GetAllOutputIndexByReturnTypes(graph->output(), {prim::kPrimTupleGetItem});
    for (auto output_pair : graph_outputs_pair) {
      // Resolve each recorded output down to the kernel that really produces it.
      while (AnfUtils::IsRealCNodeKernel(output_pair.first)) {
        auto output_kernel = output_pair.first;
        MS_EXCEPTION_IF_NULL(output_kernel);
        auto cnode = output_kernel->cast<CNodePtr>();
        // nop node: look through to its single real input.
        if (common::AnfAlgo::IsNopNode(cnode)) {
          output_pair = common::AnfAlgo::VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, true);
          continue;
        }
        // ref node: follow the ref-output correspondence instead.
        if (kernel_graph->IsInRefOutputMap(output_pair)) {
          output_pair = kernel_graph->GetRefCorrespondOutput(output_pair);
          continue;
        }
        break;
      }
      MS_EXCEPTION_IF_NULL(output_pair.first);
      if (next_node == output_pair.first->cast<CNodePtr>()) {
        find = true;
        break;
      }
    }
    if (!find) {
      continue;
    }
    // The user is a graph output: insert a TensorMove and substitute it in.
    auto tensor_move = CreateTensorMoveOp(graph, next_node);
    auto kernel_info = std::make_shared<device::KernelInfo>();
    MS_EXCEPTION_IF_NULL(tensor_move);
    tensor_move->set_kernel_info(kernel_info);
    (void)manager->Replace(next_node, tensor_move);
    ret.push_back(tensor_move);
    MS_LOG(DEBUG) << "Insert Output TensorMove for op " << node->fullname_with_scope();
  }
  return ret;
}
422 
IsAllNopNode(const session::KernelGraph * const graph)423 bool IsAllNopNode(const session::KernelGraph *const graph) {
424   MS_EXCEPTION_IF_NULL(graph);
425   auto execution_order = graph->execution_order();
426   for (auto &cnode : execution_order) {
427     MS_EXCEPTION_IF_NULL(cnode);
428     if (!common::AnfAlgo::IsNopNode(cnode)) {
429       return false;
430     }
431   }
432   return true;
433 }
434 
NeedHideNode(const std::vector<AnfNodePtr> & outputs,const AnfNodePtr & node,bool need_keep_output_nop_node)435 bool NeedHideNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool need_keep_output_nop_node) {
436   MS_EXCEPTION_IF_NULL(node);
437   // if node is not a nop node, keep it in execution order
438   if (!common::AnfAlgo::IsNopNode(node)) {
439     return false;
440   }
441   // if node is nop node and the graph is dynamic graph, check if the nop node is graph's output.
442   if (need_keep_output_nop_node) {
443     auto iter = find(outputs.begin(), outputs.end(), node);
444     if (iter != outputs.end()) {
445       return false;
446     }
447   }
448   return true;
449 }
450 
HideNopNode(session::KernelGraph * const graph)451 void HideNopNode(session::KernelGraph *const graph) {
452   MS_EXCEPTION_IF_NULL(graph);
453   if (IsAllNopNode(graph) == true) {
454     return;
455   }
456   auto execution_order = graph->execution_order();
457   auto outputs = common::AnfAlgo::GetAllOutput(graph->output());
458   // If the graph has flag kFlagEnableZeroCopyInGraph, it means in subgraph sink mode, the inputs and outputs memory of
459   // graph should not be allocated, and the node should not be skipped.
460   bool need_keep_output_nop_node = (graph->is_dynamic_shape() || graph->has_flag(kFlagEnableZeroCopyInGraph));
461   MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
462   std::vector<CNodePtr> new_nodes;
463   for (auto &cnode : execution_order) {
464     MS_EXCEPTION_IF_NULL(cnode);
465     if (NeedHideNode(outputs, cnode, need_keep_output_nop_node)) {
466       common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
467       common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpExecution, MakeValue(true), cnode);
468     } else {
469       new_nodes.push_back(cnode);
470     }
471   }
472   graph->set_execution_order(new_nodes);
473   MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size();
474 }
475 
// Removes nop nodes from the execution order and rewires each consumer to read
// the nop node's single real input directly. Repeats until a fixed point,
// since bypassing one nop node can expose another behind it.
void RemoveNopNode(session::KernelGraph *const graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Graphs consisting solely of nop nodes are left untouched.
  if (IsAllNopNode(graph) == true) {
    return;
  }
  bool changed = true;
  while (changed) {
    changed = false;
    std::vector<CNodePtr> new_nodes;
    auto outputs = graph->outputs();
    bool is_dynamic_graph = graph->is_dynamic_shape();
    for (auto &cnode : graph->execution_order()) {
      MS_EXCEPTION_IF_NULL(cnode);
      // ignore nop node itself
      if (NeedHideNode(outputs, cnode, is_dynamic_graph)) {
        common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
        common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpExecution, MakeValue(true), cnode);
        continue;
      }
      // Replace the input which is nop node
      std::vector<AnfNodePtr> new_inputs;
      new_inputs.push_back(cnode->input(0));  // input 0 is the primitive — always kept
      bool need_update = false;
      for (size_t i = 1; i < cnode->size(); ++i) {
        auto input = cnode->input(i);
        MS_EXCEPTION_IF_NULL(input);
        auto cinput = input->cast<CNodePtr>();
        if (cinput == nullptr || !common::AnfAlgo::IsNopNode(cinput)) {
          new_inputs.push_back(input);
          continue;
        }
        // Only bypass nop nodes of the simple form (primitive + one input).
        constexpr auto kInputSize = 2;
        if (cinput->size() == kInputSize) {
          new_inputs.push_back(cinput->input(1));
          need_update = true;
          changed = true;
        } else {
          new_inputs.push_back(input);
        }
      }
      if (need_update) {
        cnode->set_inputs(new_inputs);
      }
      // push into new execution list
      new_nodes.push_back(cnode);
    }
    graph->set_execution_order(new_nodes);
  }
}
525 
GetRealNodeNum(const FuncGraphPtr & graph,const AnfNodePtr & node)526 size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
527   auto out_list = GetRealNodeUsedList(graph, node);
528   MS_EXCEPTION_IF_NULL(out_list);
529   return out_list->size();
530 }
531 
GetRealNodeUsedList(const FuncGraphPtr & graph,const AnfNodePtr & node)532 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
533                                                                              const AnfNodePtr &node) {
534   auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
535   MS_EXCEPTION_IF_NULL(graph);
536   auto manager = graph->manager();
537   MS_EXCEPTION_IF_NULL(manager);
538   auto iter = manager->node_users().find(node);
539   if (iter == manager->node_users().end()) {
540     return output_node_list;
541   }
542   auto output_info_list = iter->second;
543   for (const auto &output_info : output_info_list) {
544     auto cnode_name = common::AnfAlgo::GetCNodeName(output_info.first);
545     if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
546         (cnode_name == prim::kPrimUpdateState->name())) {
547       continue;
548     }
549     output_node_list->push_back(output_info);
550   }
551   return output_node_list;
552 }
553 
// Returns the users of `node` that really consume its `output_index`-th
// output. Depend attach edges and UpdateState users are skipped (control
// dependencies only). Throws when the manager has no users for `node`.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
                                                                                        const AnfNodePtr &node,
                                                                                        size_t output_index) {
  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "node has no output in manager";
  }
  auto output_info_list = iter->second;
  for (const auto &output_info : output_info_list) {
    auto cnode_name = common::AnfAlgo::GetCNodeName(output_info.first);
    if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
        (cnode_name == prim::kPrimUpdateState->name())) {
      continue;
    }
    // Work out which output index of `node` this user actually reads.
    size_t used_output_index;
    if (cnode_name == prim::kPrimTupleGetItem->name()) {
      // A TupleGetItem user records the index it extracts.
      used_output_index = common::AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
    } else if (common::AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
      // When `node` is itself a getitem, the caller's index is used as given.
      used_output_index = output_index;
    } else {
      // General case: resolve the user's input edge (input index is 1-based,
      // hence the -1) back to the producing (node, output-index) pair.
      auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(output_info.first, IntToSize(output_info.second - 1));
      if (kernel_with_index.first.get() != node.get()) {
        MS_LOG(INTERNAL_EXCEPTION) << "Get used node failed for op[" << common::AnfAlgo::GetCNodeName(node) << "]";
      }
      used_output_index = kernel_with_index.second;
    }
    if (used_output_index == output_index) {
      output_node_list->push_back(output_info);
    }
  }
  return output_node_list;
}
590 
IsUsedByOthers(const FuncGraphPtr & graph,const AnfNodePtr & node)591 bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
592   MS_EXCEPTION_IF_NULL(graph);
593   MS_EXCEPTION_IF_NULL(node);
594   auto output_node_list = GetRealNodeUsedList(graph, node);
595   MS_EXCEPTION_IF_NULL(output_node_list);
596   return output_node_list->size() > 1;
597 }
598 
IsNotRealUsedByOthers(const FuncGraphPtr & graph,const AnfNodePtr & node)599 bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
600   MS_EXCEPTION_IF_NULL(graph);
601   MS_EXCEPTION_IF_NULL(node);
602   auto output_node_list = GetRealNodeUsedList(graph, node);
603   MS_EXCEPTION_IF_NULL(output_node_list);
604   if (output_node_list->empty()) {
605     return true;
606   }
607   for (const auto &output : *output_node_list) {
608     auto out_node = output.first;
609     auto name = common::AnfAlgo::GetCNodeName(out_node);
610     if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
611         name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) {
612       auto result = IsNotRealUsedByOthers(graph, out_node);
613       if (!result) {
614         return result;
615       }
616       continue;
617     }
618     return false;
619   }
620   return true;
621 }
622 
CreatTupleGetItemNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_idx)623 CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
624   MS_EXCEPTION_IF_NULL(func_graph);
625   auto idx = NewValueNode(SizeToLong(output_idx));
626   MS_EXCEPTION_IF_NULL(idx);
627   auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
628   auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
629   idx->set_abstract(abstract_scalar);
630   CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
631   MS_EXCEPTION_IF_NULL(tuple_getitem);
632   tuple_getitem->set_scope(node->scope());
633   auto abs = node->abstract()->cast<abstract::AbstractTuplePtr>();
634   MS_EXCEPTION_IF_NULL(abs);
635   auto abs_i = abs->elements()[output_idx];
636   MS_EXCEPTION_IF_NULL(abs_i);
637   tuple_getitem->set_abstract(abs_i);
638   return tuple_getitem;
639 }
640 
CreateMakeTupleNode(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & tuple_inputs)641 CNodePtr CreateMakeTupleNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &tuple_inputs) {
642   MS_EXCEPTION_IF_NULL(func_graph);
643   std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
644   AbstractBasePtrList make_tuple_abstract;
645   std::for_each(tuple_inputs.cbegin(), tuple_inputs.cend(),
646                 [&make_tuple_inputs, &make_tuple_abstract](const AnfNodePtr &node) {
647                   MS_EXCEPTION_IF_NULL(node);
648                   (void)make_tuple_inputs.emplace_back(node);
649                   (void)make_tuple_abstract.emplace_back(node->abstract());
650                 });
651   auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
652   MS_EXCEPTION_IF_NULL(make_tuple);
653   make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(make_tuple_abstract));
654   return make_tuple;
655 }
656 
CreateShapeValueNode(const FuncGraphPtr & func_graph,const std::vector<int64_t> & shape,bool to_tensor)657 ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape, bool to_tensor) {
658   MS_EXCEPTION_IF_NULL(func_graph);
659   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
660   MS_EXCEPTION_IF_NULL(kernel_graph);
661   ValuePtr shape_value = nullptr;
662   AbstractBasePtr abstract = nullptr;
663   if (to_tensor) {
664     // create Tensor
665     int64_t shape_dim = SizeToLong(shape.size());
666     std::vector<int64_t> shape_vec_shape = {shape_dim};
667     auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
668     MS_EXCEPTION_IF_NULL(shape_tensor);
669     auto data_ptr = shape_tensor->data_c();
670     MS_EXCEPTION_IF_NULL(data_ptr);
671     auto elem_num = shape.size() * kType64Len;
672     auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
673     if (ret_code != EOK) {
674       MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
675     }
676     shape_value = shape_tensor;
677     abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
678   } else {
679     // create ValueTuple
680     std::vector<ValuePtr> dim_values{};
681     abstract::AbstractBasePtrList abs{};
682     for (const auto &dim : shape) {
683       dim_values.push_back(MakeValue(dim));
684       abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
685     }
686     shape_value = std::make_shared<ValueTuple>(dim_values);
687     abstract = std::make_shared<abstract::AbstractTuple>(abs);
688   }
689   MS_EXCEPTION_IF_NULL(shape_value);
690   MS_EXCEPTION_IF_NULL(abstract);
691   auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
692   MS_EXCEPTION_IF_NULL(shape_value_node);
693   kernel_graph->AddValueNodeToGraph(shape_value_node);
694   return shape_value_node;
695 }
696 
// Insert a Cast-to-`dst_type` node. With is_input set, the cast wraps the
// `input_index`-th input of `node`; otherwise it wraps `node`'s own output.
CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, const CNodePtr &node, const bool is_input,
                     const size_t input_index) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  std::vector<AnfNodePtr> cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name()))};
  BaseShapePtr cast_shape;
  if (is_input) {
    auto input_node = common::AnfAlgo::GetInputNode(node, input_index);
    (void)cast_inputs.emplace_back(input_node);
    cast_shape = AnfAlgo::GetOutputDetailShape(input_node, 0);
  } else {
    (void)cast_inputs.emplace_back(node);
    cast_shape = AnfAlgo::GetOutputDetailShape(node, 0);
  }
  CNodePtr cast_node = NewCNode(cast_inputs, func_graph, {node});
  MS_EXCEPTION_IF_NULL(cast_node);
  // Inherit scope and abstract from the anchor node, then record the target type.
  cast_node->set_scope(node->scope());
  cast_node->set_abstract(node->abstract());
  common::AnfAlgo::SetNodeAttr(kAttrDstType, MakeValue(static_cast<size_t>(dst_type)), cast_node);
  common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {cast_shape}, cast_node.get());
  return cast_node;
}
719 
// Create a CNode from `new_node_inputs` that mirrors `node`'s kernel info slot,
// scope, abstract, and first-output infer type/shape.
AnfNodePtr CreateNodeBase(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &new_node_inputs,
                          const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto created = graph->NewCNode(new_node_inputs);
  MS_EXCEPTION_IF_NULL(created);

  created->set_kernel_info(std::make_shared<device::KernelInfo>());
  created->set_scope(node->scope());
  created->set_abstract(node->abstract());

  // Copy output 0's inferred type and shape over to the new node.
  common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(node, 0)},
                                              {common::AnfAlgo::GetOutputInferShape(node, 0)}, created.get());
  return created;
}
737 
// Value-aware equality used by the pattern matcher.
// For a pair of AnfNodes: primitive value nodes compare by primitive name; other
// value nodes compare by wrapped value (tensors via content equality ValueEqual).
// Anything not handled above — including CNode pairs and FuncGraph pairs — falls
// through to BaseRef::operator==.
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    MS_EXCEPTION_IF_NULL(a_node);
    MS_EXCEPTION_IF_NULL(b_node);
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      // Both are primitive value nodes: equal iff the primitive names match.
      auto a_value_node = a_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(a_value_node);
      auto a_value = a_value_node->value();
      MS_EXCEPTION_IF_NULL(a_value);
      auto a_prim = a_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(a_prim);

      auto b_value_node = b_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(b_value_node);
      auto b_value = b_value_node->value();
      MS_EXCEPTION_IF_NULL(b_value);
      auto b_prim = b_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(b_prim);

      return a_prim->name() == b_prim->name();
    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      // Both are (non-primitive) value nodes: compare the wrapped values.
      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
      if (a_value_node_ptr == nullptr) {
        MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail.";
      }
      auto a_value_ptr = a_value_node_ptr->value();
      if (a_value_ptr == nullptr) {
        MS_LOG(INTERNAL_EXCEPTION) << "Value ptr is nullptr.";
      }

      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
      if (b_value_node_ptr == nullptr) {
        MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail.";
      }
      auto b_value_ptr = b_value_node_ptr->value();
      if (b_value_ptr == nullptr) {
        MS_LOG(INTERNAL_EXCEPTION) << "Value ptr is nullptr.";
      }
      if (a_value_ptr->isa<tensor::Tensor>() && b_value_ptr->isa<tensor::Tensor>()) {
        // Tensors compare by content, not by pointer identity.
        auto a_tensor_ptr = a_value_ptr->cast<tensor::TensorPtr>();
        auto b_tensor_ptr = b_value_ptr->cast<tensor::TensorPtr>();
        if (a_tensor_ptr == nullptr || b_tensor_ptr == nullptr) {
          MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail.";
        }
        return a_tensor_ptr->ValueEqual(*b_tensor_ptr);
      } else {
        return (*a_value_ptr) == (*b_value_ptr);
      }
    }
    MS_LOG(DEBUG) << "check AnfNodePtr equal";
  }
  if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
    MS_LOG(DEBUG) << "check GraphPtr equal";
  }
  // Fallback: BaseRef's own equality (pointer/type based).
  return a == b;
}
796 
CNodeTypeEqual(const BaseRef & a,const BaseRef & b)797 bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
798   // To matchCNode and Kernel's type
799   if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
800     return true;
801   }
802   return a.type() == b.type();
803 }
804 
805 namespace {
CreateValueNodeWithSexp(const BaseRef & sexp,PrimitiveVarMap * primitive_vars)806 ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp, PrimitiveVarMap *primitive_vars) {
807   if (utils::isa<int>(sexp)) {
808     return NewValueNode(utils::cast<int>(sexp));
809   }
810   if (utils::isa<int64_t>(sexp)) {
811     return NewValueNode(utils::cast<int64_t>(sexp));
812   }
813   if (utils::isa<float>(sexp)) {
814     return NewValueNode(utils::cast<float>(sexp));
815   }
816   if (utils::isa<bool>(sexp)) {
817     return NewValueNode(utils::cast<bool>(sexp));
818   }
819   if (utils::isa<ValuePtr>(sexp)) {
820     auto value = utils::cast<ValuePtr>(sexp);
821     if (utils::isa<PrimitivePtr>(sexp)) {
822       auto prim = utils::cast<PrimitivePtr>(sexp);
823       if (primitive_vars->find(prim) != primitive_vars->end()) {
824         prim = std::make_shared<Primitive>(prim->name());
825         value = prim;
826       }
827       (*primitive_vars)[prim] = std::make_shared<Var>(prim);
828     }
829     return NewValueNode(value);
830   }
831   return nullptr;
832 }
833 
CreateCNodeWithGraph(const std::vector<AnfNodePtr> & input_nodes,const BaseRef & graph)834 CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
835   if (utils::isa<FuncGraphPtr>(graph)) {
836     return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
837   }
838   if (utils::isa<VarPtr>(graph)) {
839     return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
840   }
841   return nullptr;
842 }
843 
// Wrap the pattern Var `sexp` into a VarNode. A Var `graph` means the node has no
// owning graph (nullptr); a FuncGraph `graph` attaches the node to that graph;
// anything else logs an error and yields nullptr.
// NOTE(review): both branches cast `sexp` to VarPtr without checking — presumably
// callers guarantee `sexp` is a Var here; confirm against SexpToNode.
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  if (utils::isa<VarPtr>(graph)) {
    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  }
  if (utils::isa<FuncGraphPtr>(graph)) {
    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  }
  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  return nullptr;
}
856 
HandleSexpVector(const BaseRef & sexp,const BaseRef & graph,PrimitiveVarMap * primitive_vars,bool multigraph)857 AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
858                             bool multigraph) {
859   MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
860   std::vector<AnfNodePtr> input_nodes;
861   const auto &tuple = utils::cast<VectorRef>(sexp);
862   if (multigraph && utils::isa<VarPtr>(graph)) {
863     for (auto &x : tuple) {
864       AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
865       input_nodes.push_back(node);
866     }
867     VarPtr var_ptr = utils::cast<VarPtr>(graph);
868     return std::make_shared<CNode>(input_nodes, var_ptr);
869   }
870 
871   for (auto &x : tuple) {
872     AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
873     input_nodes.push_back(node);
874   }
875   return CreateCNodeWithGraph(input_nodes, graph);
876 }
877 
RectifyAbstractFromStructuralAttr(const ValuePtr & value,const AbstractBasePtrList & input_abstract,const std::vector<size_t> & list_start_vec,size_t input_index)878 std::pair<AbstractBasePtr, size_t> RectifyAbstractFromStructuralAttr(const ValuePtr &value,
879                                                                      const AbstractBasePtrList &input_abstract,
880                                                                      const std::vector<size_t> &list_start_vec,
881                                                                      size_t input_index) {
882   MS_EXCEPTION_IF_NULL(value);
883   auto begin_iter = input_abstract.begin() + input_index;
884   if (value->isa<ValueSequence>()) {
885     size_t offset = 0;
886     std::vector<AbstractBasePtr> abs_list;
887     auto seq_value = value->cast_ptr<ValueSequence>();
888     for (size_t i = 0; i < seq_value->size(); ++i) {
889       auto [abs, offset_inner] =
890         RectifyAbstractFromStructuralAttr((*seq_value)[i], input_abstract, list_start_vec, input_index + offset);
891       MS_EXCEPTION_IF_NULL(abs);
892       if (abs->isa<abstract::AbstractSequence>() &&
893           std::find(list_start_vec.begin(), list_start_vec.end(), input_index + offset) != list_start_vec.end()) {
894         auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
895         const auto &elements = abs_seq->elements();
896         bool is_nested = std::any_of(elements.begin(), elements.end(),
897                                      [](const AbstractBasePtr &abs) { return abs->isa<abstract::AbstractSequence>(); });
898         if (!is_nested) {
899           const auto &first_abs_in_list = input_abstract[input_index + offset];
900           MS_EXCEPTION_IF_NULL(first_abs_in_list);
901           if (!first_abs_in_list->has_user_data<kernel::PyExecuteOutputUserData>()) {
902             MS_LOG(INTERNAL_EXCEPTION) << "List input abstract PyExecuteOutputUserData not found.";
903           }
904           const auto &list_user_data = first_abs_in_list->user_data<kernel::PyExecuteOutputUserData>();
905           abs->set_user_data<kernel::PyExecuteOutputUserData>(list_user_data);
906         }
907       }
908       (void)abs_list.emplace_back(abs);
909       offset += offset_inner;
910     }
911     (void)std::for_each(begin_iter, begin_iter + offset, [](AbstractBasePtr abs) -> void {
912       MS_LOG(DEBUG) << "The convert abs is :" << abs->ToString();
913     });
914     return std::make_pair(std::make_shared<abstract::AbstractTuple>(abs_list), offset);
915   }
916 
917   const auto num_value = GetValue<int64_t>(value);
918 
919   constexpr auto kNotDynamicFlag = -1;
920   if (num_value == kNotDynamicFlag) {
921     return std::make_pair(*begin_iter, 1);
922   } else {
923     MS_LOG(EXCEPTION) << "The attr of structural must all value -1 but got " << num_value;
924   }
925 }
926 
RectifyEmptyTupleAbstract(const ValuePtr & structural)927 AbstractBasePtr RectifyEmptyTupleAbstract(const ValuePtr &structural) {
928   MS_EXCEPTION_IF_NULL(structural);
929   if (!structural->isa<ValueTuple>()) {
930     MS_LOG(EXCEPTION) << "input abstract is out of range.";
931   }
932 
933   auto value_tuple = structural->cast_ptr<ValueTuple>();
934   std::vector<AbstractBasePtr> abs_list;
935   MS_EXCEPTION_IF_NULL(value_tuple);
936   for (size_t i = 0; i < value_tuple->size(); ++i) {
937     auto item = (*value_tuple)[i];
938     (void)abs_list.emplace_back(RectifyEmptyTupleAbstract(item));
939   }
940 
941   return std::make_shared<abstract::AbstractTuple>(abs_list);
942 }
943 
// Rebuild the abstract list of an op whose tuple inputs were flattened, guided by
// the kAttrTupleInputStructural attr. `list_start` records the flat positions at
// which python lists began (used for PyExecute user-data propagation).
AbstractBasePtrList RectifyAbstractFromTupleInputStructural(const ValuePtr &tuple_structural,
                                                            const AbstractBasePtrList &input_abstract,
                                                            const ValuePtrList &list_start) {
  if (tuple_structural == nullptr) {
    return input_abstract;
  }
  auto tuple_structural_value = tuple_structural->cast_ptr<ValueSequence>();
  MS_EXCEPTION_IF_NULL(tuple_structural_value);
  // Fix: hoisted out of the loop — this conversion is loop-invariant and was
  // recomputed on every iteration; the lambda also copied the ValuePtr by value.
  std::vector<size_t> list_start_vec;
  list_start_vec.reserve(list_start.size());
  (void)std::transform(list_start.begin(), list_start.end(), std::back_inserter(list_start_vec),
                       [](const ValuePtr &val) { return GetValue<size_t>(val); });
  AbstractBasePtrList rectifyed_abs_list;
  size_t input_index = 0;
  for (size_t i = 0; i < tuple_structural_value->size(); ++i) {
    auto item = (*tuple_structural_value)[i];
    MS_EXCEPTION_IF_NULL(item);
    if (input_abstract.size() <= input_index) {
      // The Ori  Node : Oper(a, b, ())  ==> Oper(a, b)  with structural --> (-1, -1 , ())
      // The abstract size will be smaller than the attr of tuple input structural.
      (void)rectifyed_abs_list.emplace_back(RectifyEmptyTupleAbstract(item));
      // Fix: without this `continue` the loop fell through and indexed past the
      // end of input_abstract inside RectifyAbstractFromStructuralAttr.
      continue;
    }
    auto [abs, offset] = RectifyAbstractFromStructuralAttr(item, input_abstract, list_start_vec, input_index);
    input_index += offset;
    (void)rectifyed_abs_list.emplace_back(abs);
    MS_LOG(DEBUG) << "Rectify abs :" << item->ToString() << ", from structural " << abs->ToString();
  }

  return rectifyed_abs_list;
}
973 
// Regroup a flat abstract list according to the primitive's kAttrDynInputSizes:
// a -1 entry passes one abstract through unchanged; a positive entry N folds the
// next N abstracts into one AbstractTuple (a dynamic input group). A too-short
// abstract list is tolerated with a warning for PyExecute only; other primitives
// raise an exception.
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
                                                    const AbstractBasePtrList &input_abstract) {
  MS_EXCEPTION_IF_NULL(prim);
  auto dyn_input_list = prim->GetAttr(kAttrDynInputSizes);
  if (dyn_input_list == nullptr) {
    // No dynamic inputs declared: nothing to rectify.
    return input_abstract;
  }
  AbstractBasePtrList rectifyed_abs_list;
  const int kNotDynamicFlag = -1;
  auto dynamic_input_index = GetValue<std::vector<int64_t>>(dyn_input_list);
  size_t input_index = 0;
  for (auto item : dynamic_input_index) {
    if (item == kNotDynamicFlag) {
      // Ordinary (non-dynamic) input: take the next abstract as-is.
      if (input_index >= input_abstract.size()) {
        if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
          MS_LOG(WARNING) << "For primitive \'PyExecute\', index " << input_index
                          << " is out of range in input abstract " << input_abstract.size();
          continue;
        }
        MS_LOG(EXCEPTION) << "For primitive \'" << prim->name() << "\', index " << input_index
                          << " is out of range in input abstract " << input_abstract.size();
      }
      (void)rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
    } else {
      if (item < 0) {
        MS_LOG(EXCEPTION) << "The dynamic input size check error the index should be -1 or positive number but got "
                          << item;
      }
      // Dynamic input: gather the next `item` abstracts into one tuple.
      AbstractBasePtrList dynamic_inputs_abs;
      for (auto index = item; index > 0; --index) {
        if (input_index >= input_abstract.size()) {
          if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
            MS_LOG(WARNING) << "For primitive \'PyExecute\', index " << input_index
                            << " is out of range in input abstract " << input_abstract.size();
            continue;
          }
          MS_LOG(EXCEPTION) << "For primitive \'" << prim->name() << "\', index " << input_index
                            << " is out of range in input abstract " << input_abstract.size();
        }
        (void)dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
      }
      (void)rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
    }
  }
  return rectifyed_abs_list;
}
1020 
// Choose the rectification strategy for `prim`'s input abstracts: a tuple-input
// structural attr takes priority (with optional list-start positions); otherwise
// fall back to dynamic-input-size regrouping.
AbstractBasePtrList RectifyAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &input_abstract) {
  auto input_structural = prim->GetAttr(kAttrTupleInputStructural);
  if (input_structural == nullptr) {
    return RectifyAbstractFromDynamicInput(prim, input_abstract);
  }
  if (!prim->HasAttr(kAttrListStartIndex)) {
    return RectifyAbstractFromTupleInputStructural(input_structural, input_abstract, {});
  }
  auto list_start_index = prim->GetAttr(kAttrListStartIndex);
  MS_EXCEPTION_IF_NULL(list_start_index);
  auto list_start_seq = list_start_index->cast_ptr<ValueSequence>();
  MS_EXCEPTION_IF_NULL(list_start_seq);
  return RectifyAbstractFromTupleInputStructural(input_structural, input_abstract, list_start_seq->value());
}
1035 
// Infer only the shape of `cnode`, keeping the type from `orig_abs`.
// Resolution order: new-style func-impl shape infer; then the backend infer map
// (full InferShapeAndType for dynamic sequences, shape-only update otherwise);
// else the op does not support dynamic shape and we raise.
inline AbstractBasePtr InferShapeWithCheck(const PrimitivePtr &prim, const PrimitivePtr &prim_clone,
                                           const AbstractBasePtrList &infer_spec_list, const AbstractBasePtr &orig_abs,
                                           const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(prim_clone);
  MS_EXCEPTION_IF_NULL(orig_abs);
  AbstractBasePtr out_abs;
  if (auto shape_optional = abstract::InferShapeByFuncImpl(prim_clone, infer_spec_list); shape_optional.has_value()) {
    // Keep the original type; only the shape is refreshed.
    out_abs = orig_abs->Clone();
    out_abs->set_shape(shape_optional.value());
  } else if (auto found = abstract::GetBackendPrimitiveInferImpl(prim_clone); found.has_value()) {
    auto infer = found.value();
    MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-shape implement for backend!");
    MS_EXCEPTION_IF_NULL(cnode);
    if (common::AnfAlgo::IsDynamicSequence(cnode)) {
      // Dynamic-length sequences need a full re-infer; a shape-only update is not enough.
      out_abs = infer.InferShapeAndType(nullptr, prim_clone, infer_spec_list);
    } else {
      out_abs = orig_abs->Clone();
      auto shape = infer.InferShape(prim_clone, infer_spec_list);
      if (shape == nullptr) {
        MS_LOG(EXCEPTION) << "Infer shape with backend function failed";
      }
      out_abs->set_shape(shape);
    }
  } else {
    MS_EXCEPTION_IF_NULL(prim);
    MS_LOG(EXCEPTION) << "Get infer functions failed, the operator is not support dynamic shape yet, primitive name:"
                      << prim->name() << " primitive type:" << prim->type_name();
  }
  return out_abs;
}
1066 }  // namespace
1067 
SexpToNode(const BaseRef & sexp,const BaseRef & graph,PrimitiveVarMap * primitive_vars,bool multigraph)1068 AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
1069   MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
1070   MS_EXCEPTION_IF_NULL(primitive_vars);
1071   if (utils::isa<VectorRef>(sexp)) {
1072     return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
1073   }
1074   if (utils::isa<VarPtr>(sexp)) {
1075     auto var_ptr = utils::cast<VarPtr>(sexp);
1076     MS_EXCEPTION_IF_NULL(var_ptr);
1077     if (var_ptr->primitive()) {
1078       (*primitive_vars)[var_ptr->primitive()] = var_ptr;
1079       return NewValueNode(var_ptr->primitive());
1080     }
1081     return CreateVarNodeWithSexp(sexp, graph);
1082   }
1083   if (utils::isa<AnfNodePtr>(sexp)) {
1084     return utils::cast<AnfNodePtr>(sexp);
1085   }
1086   auto value_node = CreateValueNodeWithSexp(sexp, primitive_vars);
1087   if (value_node == nullptr) {
1088     MS_LOG(INTERNAL_EXCEPTION) << "Sexp cannot converted, sexp: " + sexp.ToString();
1089   }
1090   return value_node;
1091 }
1092 
IsSameNode(const EquivPtr & equiv1,const EquivPtr & equiv2,const VarPtr & var_node)1093 bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
1094   MS_EXCEPTION_IF_NULL(equiv1);
1095   MS_EXCEPTION_IF_NULL(equiv2);
1096   MS_EXCEPTION_IF_NULL(var_node);
1097   auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
1098   MS_EXCEPTION_IF_NULL(equiv1_node);
1099   auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
1100   MS_EXCEPTION_IF_NULL(equiv2_node);
1101   return *equiv1_node == *equiv2_node;
1102 }
1103 
GetAnfNodeByVar(const EquivPtr & equiv,const VarPtr & var_node)1104 AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
1105   MS_EXCEPTION_IF_NULL(equiv);
1106   MS_EXCEPTION_IF_NULL(var_node);
1107   auto iter = (*equiv).find(var_node);
1108   if (iter == (*equiv).cend()) {
1109     MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
1110     return nullptr;
1111   }
1112   auto res = utils::cast<AnfNodePtr>(iter->second);
1113   if (res == nullptr) {
1114     MS_LOG(INTERNAL_EXCEPTION) << "Cast fail! Maybe var is not a anf node";
1115   }
1116   return res;
1117 }
1118 
GetGetitemIndex(const AnfNodePtr & getitem)1119 int64_t GetGetitemIndex(const AnfNodePtr &getitem) {
1120   if (!getitem->isa<CNode>() || IsPrimitive(getitem, prim::kPrimTupleGetItem)) {
1121     MS_LOG(INTERNAL_EXCEPTION) << "Expect TupleGetItem, but got " << getitem->DebugString();
1122   }
1123   auto vnode = GetValueNode(getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
1124   return GetValue<int64_t>(vnode);
1125 }
1126 
CompareTupleGetitem(const AnfNodePtr & n1,const AnfNodePtr & n2)1127 bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
1128   MS_EXCEPTION_IF_NULL(n1);
1129   MS_EXCEPTION_IF_NULL(n2);
1130   auto n1_cnode = n1->cast<CNodePtr>();
1131   auto n2_cnode = n2->cast<CNodePtr>();
1132   MS_EXCEPTION_IF_NULL(n1_cnode);
1133   MS_EXCEPTION_IF_NULL(n2_cnode);
1134   auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem);
1135   MS_EXCEPTION_IF_NULL(index_input1);
1136   auto value_node1 = index_input1->cast<ValueNodePtr>();
1137   MS_EXCEPTION_IF_NULL(value_node1);
1138   auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem);
1139   MS_EXCEPTION_IF_NULL(index_input2);
1140   auto value_node2 = index_input2->cast<ValueNodePtr>();
1141   MS_EXCEPTION_IF_NULL(value_node2);
1142   return GetValue<int64_t>(value_node1->value()) < GetValue<int64_t>(value_node2->value());
1143 }
1144 
GetBoolAttr(const AnfNodePtr & node,const std::string & attr_name)1145 bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
1146   MS_EXCEPTION_IF_NULL(node);
1147   if (!node->isa<CNode>()) {
1148     MS_LOG(INFO) << "node is not a cnode";
1149     return false;
1150   }
1151   auto cnode = node->cast<CNodePtr>();
1152   MS_EXCEPTION_IF_NULL(cnode);
1153   return common::AnfAlgo::HasNodeAttr(attr_name, cnode) && common::AnfAlgo::GetNodeAttr<bool>(node, attr_name);
1154 }
1155 
CheckSupportDataType(const AnfNodePtr & node,const std::set<TypeId> & supported_data_type_set)1156 bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
1157   MS_EXCEPTION_IF_NULL(node);
1158   TypeId data_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
1159   if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
1160     return true;
1161   }
1162   MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
1163   return false;
1164 }
1165 
// Clone a value node and attach fresh kernel info (default output format, every
// output device type initialized to kTypeUnknown) so backend kernel selection
// can run on the copy.
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  MS_EXCEPTION_IF_NULL(new_value_node);
  new_value_node->set_abstract(value_node->abstract());
  // Create kernel_info for the new value node.
  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_value_node->set_kernel_info(kernel_info);
  // Create kernel_build_info for the new value node.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
  // Set the format of value_node to DEFAULT_FORMAT.
  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  // Set value node initial device data type = infer data type (unknown for every
  // output). Fill-constructor replaces the original manual push_back loop.
  size_t output_num = AnfAlgo::GetOutputTensorNum(value_node);
  std::vector<TypeId> types(output_num, kTypeUnknown);
  kernel_build_info_builder->SetOutputsDeviceType(types);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  return new_value_node;
}
1189 
TransferDependOrUpdateState(const CNodePtr & old_node,const FuncGraphPtr & graph,const CNodePtr & new_node)1190 void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
1191   MS_EXCEPTION_IF_NULL(old_node);
1192   MS_EXCEPTION_IF_NULL(graph);
1193   auto manager = graph->manager();
1194   MS_EXCEPTION_IF_NULL(manager);
1195   // Find BatchNorm's output which is a Depend or UpdateState.
1196   auto node_users = manager->node_users()[old_node];
1197   for (const auto &node_index : node_users) {
1198     AnfNodePtr output = node_index.first;
1199     MS_EXCEPTION_IF_NULL(output);
1200     if (common::AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
1201         common::AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
1202       auto depend = output->cast<CNodePtr>();
1203       MS_EXCEPTION_IF_NULL(depend);
1204       manager->SetEdge(depend, node_index.second, new_node);
1205     }
1206   }
1207 }
1208 
GetPrimitiveChangeInfo(const PrimitivePtr & prim,std::string * me_name,bool * ir_change)1209 void GetPrimitiveChangeInfo(const PrimitivePtr &prim, std::string *me_name, bool *ir_change) {
1210   MS_EXCEPTION_IF_NULL(prim);
1211   MS_EXCEPTION_IF_NULL(me_name);
1212   MS_EXCEPTION_IF_NULL(ir_change);
1213   if (prim->HasAttr(kAttrMeOpName)) {
1214     *me_name = GetValue<std::string>(prim->GetAttr(kAttrMeOpName));
1215   }
1216   if (prim->HasAttr(kAttrIRChange)) {
1217     *ir_change = GetValue<bool>(prim->GetAttr(kAttrIRChange));
1218   }
1219   if (*ir_change || !me_name->empty()) {
1220     MS_LOG(DEBUG) << "Note: primitive(" << prim->ToString() << ", me_name:" << *me_name
1221                   << ", ori_name: " << prim->name() << ", ir_change" << *ir_change << ") "
1222                   << "has been changed in ascend vm pass, it should been rectify abstract before infer or provide a "
1223                      "new infer func";
1224   }
1225 }
1226 
// Re-infer the shape of `cnode` in place (type is kept from its current abstract).
// When the primitive was renamed by an ascend vm pass, infer runs on a clone
// carrying the me_name; the clone's state is copied back afterwards with the
// original name restored, so `prim` itself is never observed mid-rename.
void CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(cnode);
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelInferInner,
                                     prim->name(), true);
  auto old_abs = cnode->abstract();
  MS_EXCEPTION_IF_NULL(old_abs);

  // Dump the input abstracts only when debug logging is on.
  if (IS_OUTPUT_ON(mindspore::kDebug)) {
    MS_LOG(DEBUG) << "Infer name = " << cnode->fullname_with_scope();
    for (size_t i = 0; i < args_spec_list.size(); i++) {
      MS_LOG(DEBUG) << "Infer name '" << cnode->fullname_with_scope() << "', The input[" << i
                    << "] abs is : " << args_spec_list[i]->ToString();
    }
  }

  PrimitivePtr prim_clone = prim;
  MS_EXCEPTION_IF_NULL(prim_clone);
  std::string me_name;
  std::string ori_name;
  bool ir_change = false;
  GetPrimitiveChangeInfo(prim, &me_name, &ir_change);
  if (!me_name.empty()) {
    // Infer against a clone carrying the me_name; the original stays untouched.
    prim_clone = prim->Clone();
    ori_name = prim->name();
    prim_clone->set_name(me_name);
  }

  // Regroup flattened tuple/dynamic inputs before inferring.
  auto infer_spec_list = RectifyAbstract(prim_clone, args_spec_list);
  AbstractBasePtr out_abs = InferShapeWithCheck(prim, prim_clone, infer_spec_list, old_abs, cnode);

  if (prim_clone != prim) {
    // Copy any state infer added to the clone back to the original primitive,
    // restoring the original name.
    *prim = *prim_clone;
    prim->set_name(ori_name);
  }
  cnode->set_abstract(out_abs);
  MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << old_abs << " to "
                << out_abs;
}
1266 
// Infer the full output abstract (shape and type) for `prim` applied to
// `args_spec_list`. Resolution order: new-style func-impl infer, then the
// backend infer map, else raise. Uses the same clone/copy-back protocol as
// CppInferShape for primitives renamed by ascend vm passes.
AbstractBasePtr CppInferShapeAndType(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
  MS_EXCEPTION_IF_NULL(prim);
  PrimitivePtr prim_clone = prim;
  MS_EXCEPTION_IF_NULL(prim_clone);
  std::string me_name;
  std::string ori_name;
  bool ir_change = false;
  GetPrimitiveChangeInfo(prim, &me_name, &ir_change);
  if (!me_name.empty()) {
    // Infer against a clone carrying the me_name; the original stays untouched.
    prim_clone = prim->Clone();
    ori_name = prim->name();
    prim_clone->set_name(me_name);
  }

  AbstractBasePtr ret;
  if (auto abstract_optional = abstract::InferAbstractByFuncImpl(prim_clone, args_spec_list);
      abstract_optional.has_value()) {
    ret = abstract_optional.value();
  } else if (auto found = abstract::GetBackendPrimitiveInferImpl(prim_clone); found.has_value()) {
    auto infer = found.value();
    MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-abstract implement!");
    // Regroup flattened tuple/dynamic inputs before inferring.
    auto infer_spec_list = RectifyAbstract(prim_clone, args_spec_list);
    ret = infer.InferShapeAndType(nullptr, prim_clone, infer_spec_list);
  } else {
    MS_LOG(EXCEPTION)
      << "Get infer shape function failed, the operator is not support dynamic shape yet, primitive name:"
      << prim->name() << " primitive type:" << prim->type_name();
  }

  if (prim_clone != prim) {
    // Copy clone state back to the original primitive, restoring its name.
    *prim = *prim_clone;
    prim->set_name(ori_name);
  }
  return ret;
}
1302 
GenerateKernelBuildInfo(const std::vector<AnfNodePtr> & node_list)1303 kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
1304   std::vector<std::string> inputs_device_format;
1305   std::vector<std::string> outputs_device_format;
1306   std::vector<TypeId> inputs_device_type;
1307   std::vector<TypeId> outputs_device_type;
1308   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
1309   for (size_t idx = 0; idx < node_list.size(); ++idx) {
1310     auto cnode = utils::cast<CNodePtr>(node_list[idx]);
1311     MS_EXCEPTION_IF_NULL(cnode);
1312     size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
1313     for (size_t input_index = 0; input_index < input_num; ++input_index) {
1314       (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
1315       (void)inputs_device_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
1316     }
1317     size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
1318     for (size_t output_index = 0; output_index < output_num; ++output_index) {
1319       (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
1320       (void)outputs_device_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(cnode, output_index));
1321     }
1322   }
1323   builder.SetInputsFormat(inputs_device_format);
1324   builder.SetOutputsFormat(outputs_device_format);
1325   builder.SetInputsDeviceType(inputs_device_type);
1326   builder.SetOutputsDeviceType(outputs_device_type);
1327   return builder.Build();
1328 }
1329 
GetNodeOutputUsedNum(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)1330 std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
1331   MS_EXCEPTION_IF_NULL(node);
1332   auto manager = kernel_graph.manager();
1333   MS_EXCEPTION_IF_NULL(manager);
1334   auto output_num = AnfAlgo::GetOutputTensorNum(node);
1335   std::vector<int64_t> output_used_num(output_num, 0);
1336   if (output_num == 1) {
1337     output_used_num[0] = SizeToLong(manager->node_users()[node].size());
1338   } else {
1339     for (auto out_getitem : manager->node_users()[node]) {
1340       MS_EXCEPTION_IF_NULL(out_getitem.first);
1341       if (!common::AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
1342         continue;
1343       }
1344       auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
1345       MS_EXCEPTION_IF_NULL(out_getitem_ptr);
1346       auto getitem_input2 = out_getitem_ptr->input(kInputNodeOutputIndexInTupleGetItem);
1347       auto output_idx = LongToSize(GetValue<int64_t>(GetValueNode(getitem_input2)));
1348       output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
1349     }
1350   }
1351   return output_used_num;
1352 }
1353 
GetNodeOutputTotalUsedNum(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)1354 int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
1355   auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
1356   return std::accumulate(output_used_num.begin(), output_used_num.end(), int64_t(0));
1357 }
1358 
GetCustomOpAttrIndex(const PrimitivePtr & primitive,mindspore::HashSet<size_t> * indexes)1359 void GetCustomOpAttrIndex(const PrimitivePtr &primitive, mindspore::HashSet<size_t> *indexes) {
1360   if (primitive == nullptr || primitive->name() != prim::kPrimCustom->name()) {
1361     return;
1362   }
1363   MS_EXCEPTION_IF_NULL(indexes);
1364   auto input_names = primitive->GetAttr(kAttrInputNames);
1365   auto attr_names = primitive->GetAttr(kAttrAttrNames);
1366   if (input_names == nullptr || attr_names == nullptr) {
1367     return;
1368   }
1369   auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
1370   auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
1371   for (size_t i = 0; i < input_names_vec.size(); ++i) {
1372     if (std::find(attr_names_vec.begin(), attr_names_vec.end(), input_names_vec[i]) != attr_names_vec.end()) {
1373       (void)indexes->insert(i);
1374     }
1375   }
1376 }
1377 
GetInputNodeIndex(const AnfNodePtr & input,const CNodePtr & user_node)1378 size_t GetInputNodeIndex(const AnfNodePtr &input, const CNodePtr &user_node) {
1379   MS_EXCEPTION_IF_NULL(input);
1380   MS_EXCEPTION_IF_NULL(user_node);
1381 
1382   AnfNodePtrList input_list = user_node->inputs();
1383   auto pos = std::find(input_list.begin(), input_list.end(), input);
1384   if (pos == input_list.end()) {
1385     MS_LOG(EXCEPTION) << input->fullname_with_scope() << " is not the input of " << user_node->fullname_with_scope();
1386   }
1387 
1388   // The first input is Primitive and needs to be skipped.
1389   return std::distance(input_list.begin() + kSizeOne, pos);
1390 }
1391 
SplitTupleInputs(const FuncGraphPtr & graph,const AnfNodePtr & tuple_input,std::vector<AnfNodePtr> * plant_inputs)1392 int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
1393                          std::vector<AnfNodePtr> *plant_inputs) {
1394   MS_EXCEPTION_IF_NULL(tuple_input);
1395   if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
1396     auto abs = tuple_input->abstract();
1397     MS_EXCEPTION_IF_NULL(abs);
1398     MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
1399     return -1;
1400   }
1401   MS_EXCEPTION_IF_NULL(plant_inputs);
1402   auto input_size = AnfAlgo::GetOutputElementNum(tuple_input);
1403   if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
1404     auto make_tuple = tuple_input->cast<CNodePtr>();
1405     MS_EXCEPTION_IF_NULL(make_tuple);
1406     size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
1407     for (size_t j = 0; j < tuple_input_num; ++j) {
1408       // using for graph kernel
1409       auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
1410       MS_EXCEPTION_IF_NULL(dyn_input_node);
1411       // Handle tuple nested scenes.
1412       if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
1413         input_size += LongToSize(SplitTupleInputs(graph, dyn_input_node, plant_inputs));
1414         continue;
1415       }
1416       (void)plant_inputs->emplace_back(dyn_input_node);
1417     }
1418     return input_size;
1419   }
1420   for (size_t index = 0; index < input_size; ++index) {
1421     auto dynamic_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
1422     (void)plant_inputs->emplace_back(dynamic_input_node);
1423   }
1424   return input_size;
1425 }
1426 
IsNotSequenceOfTensor(const abstract::AbstractBasePtr & abs)1427 static bool IsNotSequenceOfTensor(const abstract::AbstractBasePtr &abs) {
1428   if (abs->isa<abstract::AbstractTensor>()) {
1429     return false;
1430   }
1431 
1432   if (abs->isa<abstract::AbstractSequence>()) {
1433     auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
1434     MS_EXCEPTION_IF_NULL(seq_abs);
1435     if (seq_abs->size() == 0) {
1436       return true;
1437     }
1438 
1439     return IsNotSequenceOfTensor(seq_abs->elements()[0]);
1440   }
1441 
1442   return true;
1443 }
1444 
GenPrintAttrDynInputSizes(const CNodePtr & cnode)1445 std::vector<int64_t> GenPrintAttrDynInputSizes(const CNodePtr &cnode) {
1446   int64_t num_inputs = 0;
1447   std::vector<AnfNodePtr> node_inputs = cnode->inputs();
1448   for (size_t node_inputs_index = 1; node_inputs_index < node_inputs.size(); ++node_inputs_index) {
1449     auto &input = node_inputs[node_inputs_index];
1450     MS_EXCEPTION_IF_NULL(input);
1451     if (IsValueNode<UMonad>(input) || IsValueNode<IOMonad>(input) || HasAbstractMonad(input)) {
1452       continue;
1453     }
1454     num_inputs++;
1455   }
1456   // the first input of print is a placeholder
1457   return std::vector<int64_t>{-1, num_inputs - 1, -1};
1458 }
1459 
InputArgTypeIsDynamicType(const mindspore::ops::OP_DTYPE input_arg_dtype)1460 bool InputArgTypeIsDynamicType(const mindspore::ops::OP_DTYPE input_arg_dtype) {
1461   if (input_arg_dtype >= mindspore::ops::DT_TUPLE_BOOL && input_arg_dtype <= mindspore::ops::DT_LIST_ANY) {
1462     return true;
1463   }
1464   return false;
1465 }
1466 
UseEmptyNodeReplaceNone(const FuncGraphPtr & graph,const std::string & cnode_name,const size_t input_idx,std::vector<int64_t> * dyn_input_sizes,std::vector<AnfNodePtr> * plant_inputs)1467 void UseEmptyNodeReplaceNone(const FuncGraphPtr &graph, const std::string &cnode_name, const size_t input_idx,
1468                              std::vector<int64_t> *dyn_input_sizes, std::vector<AnfNodePtr> *plant_inputs) {
1469   MS_EXCEPTION_IF_NULL(dyn_input_sizes);
1470   MS_EXCEPTION_IF_NULL(plant_inputs);
1471   if (OpInputDtypeMap.at(cnode_name).find(input_idx) != OpInputDtypeMap.at(cnode_name).end()) {
1472     // create empty tensor
1473     auto tensor_type = OpInputDtypeMap.at(cnode_name).at(input_idx);
1474     std::vector<int64_t> tensor_shape = {0};
1475     auto empty_tensor = std::make_shared<tensor::Tensor>(tensor_type, tensor_shape);
1476     // create node
1477     auto empty_node = opt::CreateValueNodeWithKernelInfo(graph, empty_tensor);
1478     ValueNodePtr empty_value_node = empty_node->cast<ValueNodePtr>();
1479     // empty node size is 1
1480     dyn_input_sizes->emplace_back(1);
1481     plant_inputs->emplace_back(empty_value_node);
1482   } else {
1483     MS_LOG(EXCEPTION) << "Invalid input index. The [" << input_idx << "] in op [" << cnode_name
1484                       << "] is not in OpInputDtypeMap, cannot use new node replace None.";
1485   }
1486 }
1487 
// Collect the flattened ("plant") inputs of `cnode_ptr` into `plant_inputs`
// and the per-input dynamic sizes into `dyn_input_sizes`:
//   - tuple inputs are split into their elements (size = element count),
//   - plain inputs are kept as-is (size = -1),
//   - for ops registered in OpInputDtypeMap, a None input whose op-def type is
//     dynamic is replaced by an empty tensor node,
//   - tuple inputs of Print are skipped here entirely.
void GetPlantInputsAndSize(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr, std::vector<AnfNodePtr> *plant_inputs,
                           std::vector<int64_t> *dyn_input_sizes) {
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  auto cnode_name = common::AnfAlgo::GetCNodeName(cnode_ptr);
  plant_inputs->push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
  size_t input_num = cnode_ptr->size() - 1;  // input 0 is the primitive node
  bool cnode_is_print = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPrint);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, i);
    MS_EXCEPTION_IF_NULL(input_node);
    bool output_is_tuple = common::AnfAlgo::IsTupleOutput(input_node);
    if (output_is_tuple && cnode_is_print) {
      // Print's tuple inputs are dropped here; its dyn_input_sizes is rebuilt
      // later by GenPrintAttrDynInputSizes.
      continue;
    } else if (output_is_tuple) {
      int64_t dyn_input_size;
      if (IsNotSequenceOfTensor(input_node->abstract())) {
        dyn_input_size = 0;
      } else {
        dyn_input_size = SplitTupleInputs(graph, input_node, plant_inputs);
      }
      if (dyn_input_size == 0) {
        // Empty or non-tensor sequence: keep the node itself as a plain input.
        dyn_input_sizes->push_back(-1);
        plant_inputs->push_back(input_node);
      } else {
        (void)dyn_input_sizes->emplace_back(dyn_input_size);
      }
    } else if (OpInputDtypeMap.find(cnode_name) != OpInputDtypeMap.end()) {
      // Only op in OpInputDtypeMap can be replace None input.
      auto opdef_ptr = mindspore::ops::GetOpDef(cnode_name);
      MS_EXCEPTION_IF_NULL(opdef_ptr);
      auto input_args = (opdef_ptr)->args_;
      if (i >= input_args.size()) {
        MS_LOG(EXCEPTION) << "The [" << i << "] in op [" << cnode_name << "] is out of op_def args range";
      }
      // When input[i] is None and input[i] type in op_yaml is dynamic type, do replace
      if (common::AnfAlgo::IsNoneInput(cnode_ptr, i) && InputArgTypeIsDynamicType(input_args[i].arg_dtype_)) {
        UseEmptyNodeReplaceNone(graph, cnode_name, i, dyn_input_sizes, plant_inputs);
      } else {
        dyn_input_sizes->push_back(-1);
        plant_inputs->push_back(input_node);
      }
    } else {
      // Ordinary non-tuple input: pass through unchanged.
      dyn_input_sizes->push_back(-1);
      plant_inputs->push_back(input_node);
    }
  }
}
1535 
// Convert a cnode whose inputs contain tuples into an equivalent cnode with
// flattened inputs plus a kAttrDynInputSizes attribute. Returns the new cnode,
// or nullptr when no conversion applies: Call/Partial/BpropCut nodes, nodes
// with dynamic-length tuple inputs, or nodes where no dynamic input was found.
AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  MS_EXCEPTION_IF_NULL(graph);
  if (common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
      common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial) ||
      common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimBpropCut)) {
    return nullptr;
  }

  if (common::AnfAlgo::HasDynamicTupleInput(cnode_ptr)) {
    // Dynamic-length tuples cannot be statically flattened.
    MS_LOG(INFO) << "Node " << cnode_ptr->fullname_with_scope()
                 << " has dynamic tuple input, can't convert. Node debug string:" << cnode_ptr->DebugString();
    return nullptr;
  }
  std::vector<AnfNodePtr> plant_inputs;
  std::vector<int64_t> dyn_input_sizes;
  GetPlantInputsAndSize(graph, cnode_ptr, &plant_inputs, &dyn_input_sizes);

  // If there is dynamic input, set the dyn_input_sizes as an attribute and update the inputs.
  if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
    auto new_cnode = NewCNode(plant_inputs, graph, {cnode_ptr});
    MS_EXCEPTION_IF_NULL(new_cnode);
    // Carry over abstract, scope and attrs so the new node is equivalent to
    // the original except for the flattened inputs.
    new_cnode->set_abstract(cnode_ptr->abstract());
    new_cnode->set_scope(cnode_ptr->scope());
    new_cnode->set_primal_attrs(cnode_ptr->primal_attrs());
    new_cnode->set_attrs(cnode_ptr->attrs());
    bool cnode_is_print = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPrint);
    if (cnode_is_print) {
      // Print uses a specially shaped dyn_input_sizes (placeholder handling).
      dyn_input_sizes = GenPrintAttrDynInputSizes(new_cnode);
    }
    common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_cnode);
    auto kernel_graph = graph->cast<KernelGraphPtr>();
    if (kernel_graph != nullptr) {
      // Keep the front-end/back-end node mapping pointing at the new node.
      kernel_graph->FrontBackendlMapUpdate(cnode_ptr, new_cnode);
    }
    return new_cnode;
  }
  return nullptr;
}
1575 
InferOp(const CNodePtr & node,void * args)1576 void InferOp(const CNodePtr &node, void *args) { dynamic_shape::InferOp(node, args); }
1577 
// Global handler used by LaunchPy; null until a handler is registered.
LaunchHandler launch_py_handler{nullptr};
// Register the python launch handler. NOTE(review): presumably called once
// during framework initialization by the python adapter — registration timing
// is not visible from this file.
void set_launch_handler(const LaunchHandler &handler) { launch_py_handler = handler; }
1580 
LaunchPy(const PrimitivePtr & primitive,const std::vector<abstract::AbstractBase * > & args_abs_list)1581 abstract::AbstractBasePtr LaunchPy(const PrimitivePtr &primitive,
1582                                    const std::vector<abstract::AbstractBase *> &args_abs_list) {
1583   MS_EXCEPTION_IF_NULL(launch_py_handler);
1584   return launch_py_handler(primitive, args_abs_list);
1585 }
1586 
InferAbstract(const PrimitivePtr & primitive,const std::vector<AnfNodePtr> & input_list)1587 AbstractBasePtr InferAbstract(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &input_list) {
1588   MS_EXCEPTION_IF_NULL(primitive);
1589   const auto &op_name = primitive->name();
1590   std::vector<AbstractBasePtr> input_args;
1591   std::for_each(input_list.begin(), input_list.end(),
1592                 [&input_args](const auto &input) { input_args.emplace_back(input->abstract()); });
1593   auto shape_optional = abstract::InferAbstractByFuncImpl(primitive, input_args);
1594   if (shape_optional.has_value()) {
1595     return shape_optional.value();
1596   }
1597 
1598   auto infer_impl = abstract::GetBackendPrimitiveInferImpl(primitive);
1599   if (infer_impl.has_value()) {
1600     auto infer = infer_impl.value();
1601     if (infer.IsImplInferShapeAndType()) {
1602       return infer.InferShapeAndType(nullptr, primitive, input_args);
1603     }
1604   }
1605   MS_LOG(EXCEPTION) << "The InferAbstract function of [" << op_name << "] is not defined.";
1606 }
1607 
// Create a ValueNode holding `value` with its abstract set. When `graph` is a
// kernel graph, also attach kernel info and a default-format kernel build info
// (device type derived from the value) so the node needs no kernel selection,
// and register the node in the graph.
AnfNodePtr CreateValueNodeWithKernelInfo(const FuncGraphPtr &graph, const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  auto value_node = NewValueNode(value);
  MS_EXCEPTION_IF_NULL(value_node);
  auto value_abs = value->ToAbstract();
  value_node->set_abstract(value_abs);

  MS_EXCEPTION_IF_NULL(graph);
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(graph);
  if (kernel_graph != nullptr) {
    // In kernel graph case, a new value node should set kernel_info and kernel_build_info here for no-kernel-selecting.
    auto kernel_info = std::make_shared<device::KernelInfo>();
    value_node->set_kernel_info(kernel_info);
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsFormat({kOpFormat_DEFAULT});
    if (value->isa<tensor::Tensor>()) {
      // Tensor values carry their own dtype.
      auto tensor = value->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      builder.SetOutputsDeviceType({tensor->data_type()});
    } else {
      MS_EXCEPTION_IF_NULL(value->type());
      auto type_id = value->type()->type_id();
      if (value->isa<ValueSequence>()) {
        // For sequences, take the element type; an empty sequence has no
        // element to inspect, so int64 is used as the fallback.
        auto value_sequence = value->cast<ValueSequencePtr>()->value();
        if (value_sequence.empty()) {
          type_id = kNumberTypeInt64;
        } else {
          MS_EXCEPTION_IF_NULL(value_sequence[0]->type());
          type_id = value_sequence[0]->type()->type_id();
        }
      }
      builder.SetOutputsDeviceType({type_id});
    }
    // Kernel object type (tensor/tuple/scalar) follows the abstract.
    auto object_type = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAbstractObjectType(value_abs));
    builder.SetOutputsKernelObjectType({object_type});
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), value_node.get());

    kernel_graph->AddValueNodeToGraph(value_node);
  }

  return value_node;
}
1650 }  // namespace opt
1651 }  // namespace mindspore
1652