1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/backend/optimizer/helper.h"
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <algorithm>
23 #include <map>
24 #include <set>
25 #include <deque>
26 #include <vector>
27 #include "kernel/kernel_build_info.h"
28 #include "mindspore/core/ops/sequence_ops.h"
29 #include "mindspore/core/ops/nn_ops.h"
30 #include "mindspore/core/ops/array_ops.h"
31 #include "mindspore/core/ops/framework_ops.h"
32 #include "utils/hash_set.h"
33 #include "include/common/utils/utils.h"
34 #include "base/base_ref.h"
35 #include "include/backend/anf_runtime_algorithm.h"
36 #include "include/common/utils/anfalgo.h"
37 #include "utils/log_adapter.h"
38 #include "utils/ms_utils.h"
39 #include "include/common/utils/convert_utils.h"
40 #include "include/backend/kernel_info.h"
41 #include "utils/ms_context.h"
42 #include "utils/trace_base.h"
43 #include "backend/common/pass/const_input_to_attr.h"
44 #include "backend/operator/ops_backend_infer_function.h"
45 #include "frontend/operator/ops_front_infer_function.h"
46 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
47 #include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
48 #include "include/common/profiler.h"
49 #include "abstract/ops/primitive_infer_map.h"
50
51 namespace mindspore {
52 namespace opt {
53 namespace {
// Byte widths of 32-bit / 64-bit scalar types, used when copying tuple
// values into freshly created tensors (see CreateTensorWithValueTuple /
// CreateShapeValueNode below).
constexpr size_t kType32Len = 4;
constexpr size_t kType64Len = 8;
// A nop node simply forwards its first real input (cnode input index 1).
constexpr auto kNopNodeRealInputIndex = 1;
// Maps an op name to {input index -> required data type}.
// NOTE(review): presumably consumed by a pass later in this file to force
// specific input dtypes for ops such as GroupedMatmul — confirm at use site.
const std::map<std::string, std::map<size_t, TypeId>> OpInputDtypeMap = {{prim::kPrimGroupedMatmul->name(),
                                                                          {{2, TypeId::kNumberTypeFloat16},
                                                                           {3, TypeId::kNumberTypeUInt64},
                                                                           {4, TypeId::kNumberTypeFloat32},
                                                                           {5, TypeId::kNumberTypeFloat16},
                                                                           {6, TypeId::kNumberTypeFloat16}}}};
63
UpdateDumpFlagAndDebugInfo(const CNodePtr & node,const std::vector<AnfNodePtr> & orig_nodes)64 void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
65 MS_EXCEPTION_IF_NULL(node);
66 std::vector<AnfNodePtr> orig_real_cnodes;
67 for (auto &orig_node : orig_nodes) {
68 MS_EXCEPTION_IF_NULL(orig_node);
69 if (AnfUtils::IsRealCNodeKernel(orig_node)) {
70 auto orig_cnode = orig_node->cast<CNodePtr>();
71 MS_EXCEPTION_IF_NULL(orig_cnode);
72 if (common::AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
73 common::AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
74 }
75 orig_real_cnodes.push_back(orig_node);
76 }
77 }
78
79 node->AddFusedDebugInfoList(orig_real_cnodes);
80 }
81 } // namespace
82
IsDepend(const FuncGraph & graph,const AnfNodePtr & node,const std::vector<AnfNodePtr> & nodes)83 bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
84 mindspore::HashSet<AnfNodePtr> visited_nodes;
85 return IsDepend(graph, node, nodes, &visited_nodes);
86 }
87
IsDepend(const FuncGraph & graph,const AnfNodePtr & node,const std::vector<AnfNodePtr> & nodes,mindspore::HashSet<AnfNodePtr> * visited_nodes)88 bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes,
89 mindspore::HashSet<AnfNodePtr> *visited_nodes) {
90 MS_EXCEPTION_IF_NULL(node);
91 MS_EXCEPTION_IF_NULL(visited_nodes);
92 FuncGraphManagerPtr manager = graph.manager();
93 MS_EXCEPTION_IF_NULL(manager);
94
95 std::deque<AnfNodePtr> todo{node};
96 while (!todo.empty()) {
97 AnfNodePtr nd = todo.front();
98 todo.pop_front();
99 if (visited_nodes->count(nd) > 0 || !manager->all_nodes().contains(nd)) {
100 continue;
101 }
102 (void)visited_nodes->insert(nd);
103
104 if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
105 return true;
106 }
107 if (nd->isa<CNode>()) {
108 auto cnode = nd->cast<CNodePtr>();
109 MS_EXCEPTION_IF_NULL(cnode);
110 auto inputs = cnode->inputs();
111 (void)todo.insert(todo.cend(), inputs.cbegin(), inputs.cend());
112 }
113 }
114 return false;
115 }
116
UnVisited(const BaseRef & n)117 bool UnVisited(const BaseRef &n) {
118 if (utils::isa<AnfNodePtr>(n)) {
119 AnfNodePtr in = utils::cast<AnfNodePtr>(n);
120 MS_EXCEPTION_IF_NULL(in);
121 if (IsValueNode<Primitive>(in)) {
122 auto value_node = in->cast<ValueNodePtr>();
123 MS_EXCEPTION_IF_NULL(value_node);
124 auto value = value_node->value();
125 MS_EXCEPTION_IF_NULL(value);
126 auto prim_py = value->cast<PrimitivePtr>();
127 MS_EXCEPTION_IF_NULL(prim_py);
128 return !prim_py->HasAttr(kAttrVisited);
129 } else if (IsValueNode<FuncGraph>(in)) {
130 auto func_graph = GetValueNode<FuncGraphPtr>(in);
131 MS_EXCEPTION_IF_NULL(func_graph);
132 return !func_graph->has_flag(kAttrVisited);
133 }
134 return false;
135 }
136 return false;
137 }
138
NewCNode(const std::vector<AnfNodePtr> & inputs,const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & orig_nodes)139 CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
140 const std::vector<AnfNodePtr> &orig_nodes) {
141 MS_EXCEPTION_IF_NULL(fg);
142 auto node = fg->NewCNode(inputs);
143 MS_EXCEPTION_IF_NULL(node);
144 UpdateDumpFlagAndDebugInfo(node, orig_nodes);
145 return node;
146 }
147
// Clones `cnode` into the kernel graph `fg`, then inherits the dump flag and
// fused debug info from the nodes it replaces.
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
  MS_EXCEPTION_IF_NULL(fg);
  auto new_cnode = fg->NewCNode(cnode);
  MS_EXCEPTION_IF_NULL(new_cnode);
  UpdateDumpFlagAndDebugInfo(new_cnode, orig_nodes);
  return new_cnode;
}
155
CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr & node,size_t input_size)156 CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size) {
157 MS_EXCEPTION_IF_NULL(node);
158 if (!node->isa<CNode>()) {
159 MS_LOG(INTERNAL_EXCEPTION) << "The node is expected to be a cnode";
160 }
161 auto cnode = node->cast<CNodePtr>();
162 CheckCNodeInputSize(cnode, input_size);
163 return cnode;
164 }
165
CheckCNodeInputSize(const CNodePtr & cnode,size_t input_tensor_size)166 void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) {
167 MS_EXCEPTION_IF_NULL(cnode);
168 auto real_input_tensor_num = common::AnfAlgo::GetInputTensorNum(cnode);
169 if (real_input_tensor_num != input_tensor_size) {
170 MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num
171 << "] of node [" + cnode->DebugString() + "] is not equal to " << input_tensor_size
172 << trace::DumpSourceLines(cnode);
173 }
174 }
175
HasSymmetricalKernelInfo(const AnfNodePtr & node_x,const AnfNodePtr & node_y)176 bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) {
177 MS_EXCEPTION_IF_NULL(node_x);
178 MS_EXCEPTION_IF_NULL(node_y);
179 return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) &&
180 AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0));
181 }
182
// Rewrites the pattern TransOp(Depend(TransOp(x), attach)) into
// Depend(x, attach): the outer and inner trans ops cancel, but the control
// dependency on `attach` must be preserved.
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
  MS_EXCEPTION_IF_NULL(transop_cnode);
  auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum);
  auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputTensorNum);
  // `transed_node` is the value before the inner trans op was applied.
  auto transed_node = prev_transop_cnode->input(1);
  MS_EXCEPTION_IF_NULL(transed_node);

  AnfNodePtr new_depend = func_graph->NewCNode(
    {NewValueNode(prim::kPrimDepend), transed_node, depend_cnode->input(kDependAttachNodeIndex)});
  MS_EXCEPTION_IF_NULL(new_depend);
  // The replacement produces the un-transformed value.
  new_depend->set_abstract(transed_node->abstract());
  return new_depend;
}
202
Visited(const BaseRef & n)203 bool Visited(const BaseRef &n) {
204 if (utils::isa<AnfNodePtr>(n)) {
205 AnfNodePtr in = utils::cast<AnfNodePtr>(n);
206 MS_EXCEPTION_IF_NULL(in);
207 if (IsValueNode<Primitive>(in)) {
208 auto value_node = in->cast<ValueNodePtr>();
209 MS_EXCEPTION_IF_NULL(value_node);
210 auto value = value_node->value();
211 MS_EXCEPTION_IF_NULL(value);
212 auto prim_py = value->cast<PrimitivePtr>();
213 MS_EXCEPTION_IF_NULL(prim_py);
214 return prim_py->HasAttr(kAttrVisited);
215 } else if (IsValueNode<FuncGraph>(in)) {
216 auto func_graph = GetValueNode<FuncGraphPtr>(in);
217 MS_EXCEPTION_IF_NULL(func_graph);
218 return func_graph->has_flag(kAttrVisited);
219 }
220 return false;
221 }
222 return false;
223 }
224
CreateMultipleOutputsOfAnfNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_num,std::vector<AnfNodePtr> * outputs)225 void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
226 std::vector<AnfNodePtr> *outputs) {
227 MS_EXCEPTION_IF_NULL(func_graph);
228 MS_EXCEPTION_IF_NULL(node);
229 MS_EXCEPTION_IF_NULL(outputs);
230 auto type_ptr = node->Type();
231 for (size_t i = 0; i < output_num; i++) {
232 int64_t temp = SizeToLong(i);
233 auto idx = NewValueNode(temp);
234 MS_EXCEPTION_IF_NULL(idx);
235 auto imm = std::make_shared<Int64Imm>(temp);
236 auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
237 idx->set_abstract(abstract_scalar);
238 auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
239 MS_EXCEPTION_IF_NULL(tuple_getitem);
240 tuple_getitem->set_scope(node->scope());
241 common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(type_ptr, i)},
242 {common::AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get());
243 (*outputs).push_back(tuple_getitem);
244 }
245 }
246
247 template <typename T>
CreateTensorWithValueTuple(const ValueTuplePtr & value_tuple_ptr,const TypePtr & type_ptr,size_t data_length)248 tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
249 size_t data_length) {
250 MS_EXCEPTION_IF_NULL(value_tuple_ptr);
251 MS_EXCEPTION_IF_NULL(type_ptr);
252 std::vector<T> values;
253 for (const auto &v : value_tuple_ptr->value()) {
254 MS_EXCEPTION_IF_NULL(v);
255 if (v->isa<Scalar>()) {
256 ScalarPtr scalar = v->cast<ScalarPtr>();
257 values.push_back(GetValue<T>(scalar));
258 } else {
259 MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
260 return nullptr;
261 }
262 }
263 std::vector<int64_t> tensor_shape = {SizeToLong(values.size())};
264 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
265 MS_EXCEPTION_IF_NULL(tensor);
266 tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
267 tensor->set_device_info(device_info);
268 auto data_ptr = tensor->data_c();
269 MS_EXCEPTION_IF_NULL(data_ptr);
270 auto elem_num = values.size() * data_length;
271 auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
272 if (ret_code != EOK) {
273 MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
274 }
275 return tensor;
276 }
277
CreateEmptyTupleTensor(const ValueTuplePtr & value_tuple)278 tensor::TensorPtr CreateEmptyTupleTensor(const ValueTuplePtr &value_tuple) {
279 std::vector<int64_t> tensor_shape = {0};
280 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kInt64->type_id(), tensor_shape);
281 MS_EXCEPTION_IF_NULL(tensor);
282 tensor::DeviceInfo device_info{kOpFormat_DEFAULT, kInt64};
283 tensor->set_device_info(device_info);
284 tensor->set_user_data(kTensorValueIsEmpty, value_tuple);
285 return tensor;
286 }
287
// Converts a value node holding a scalar / tuple / list into a tensor value
// node.  When a kernel graph is supplied the new node is registered with it
// (value node list + front/backend map); otherwise a plain value node is
// produced.  Returns nullptr when the tuple could not be converted.
AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(input_node);
  auto value_node = input_node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);

  tensor::TensorPtr tensor_ptr;
  if (value->isa<Scalar>()) {
    tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
  } else if (value->isa<ValueTuple>()) {
    tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
  } else if (value->isa<ValueList>()) {
    // Lists are normalized to tuples before conversion.
    tensor_ptr = CreateTupleTensor(std::make_shared<ValueTuple>(value->cast<ValueListPtr>()->value()));
  } else {
    MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple";
  }
  if (tensor_ptr == nullptr) {
    MS_LOG(DEBUG) << "Create tensor failed";
    return nullptr;
  }

  auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
  MS_EXCEPTION_IF_NULL(tensor_input);
  tensor_input->set_abstract(tensor_ptr->ToAbstract());
  if (kernel_graph == nullptr) {
    tensor_input = MakeValueNode(tensor_input);
  } else {
    tensor_input = kernel_graph->NewValueNode(tensor_input);
    kernel_graph->AddValueNodeToGraph(tensor_input);
    kernel_graph->FrontBackendlMapUpdate(input_node, tensor_input);
  }
  tensor_input->set_scope(input_node->scope());
  return tensor_input;
}
321
CreateTupleTensor(const ValueTuplePtr & value_tuple)322 tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
323 MS_EXCEPTION_IF_NULL(value_tuple);
324 tensor::TensorPtr tensor = nullptr;
325 if (value_tuple->value().empty()) {
326 tensor = CreateEmptyTupleTensor(value_tuple);
327 return tensor;
328 }
329 ValuePtr v = *(value_tuple->value().begin());
330 MS_EXCEPTION_IF_NULL(v);
331 // Currently we only deal with the scalar tuple
332 if (!v->isa<Scalar>()) {
333 MS_LOG(DEBUG) << "The value " << v << "of tuple is not a scalar";
334 return nullptr;
335 }
336 ScalarPtr scalar = v->cast<ScalarPtr>();
337 MS_EXCEPTION_IF_NULL(scalar);
338 if (scalar->isa<Int32Imm>()) {
339 tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, kInt32, sizeof(int32_t));
340 } else if (scalar->isa<Int64Imm>()) {
341 tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, kInt64, sizeof(int64_t));
342 } else if (scalar->isa<FloatImm>()) {
343 tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, sizeof(float));
344 } else {
345 auto type = scalar->type();
346 auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
347 MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
348 return nullptr;
349 }
350 return tensor;
351 }
352
// Wraps `node` in a TensorMove cnode that copies its abstract and scope.
// The empty kAttrDatadumpOriginalNames marks the node as pass-generated.
AnfNodePtr CreateTensorMoveOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto tensor_move_prim = std::make_shared<Primitive>(kTensorMoveOpName);
  auto tensor_move = graph->NewCNode({NewValueNode(tensor_move_prim), node});
  MS_EXCEPTION_IF_NULL(tensor_move);
  tensor_move->set_abstract(node->abstract());
  tensor_move->set_scope(node->scope());
  common::AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), tensor_move);
  return tensor_move;
}
365
// Inserts a TensorMove after each user of `node` that is (after resolving nop
// and ref-output indirections) a graph output, so the output gets its own
// memory.  Returns the TensorMove nodes that were inserted.
std::vector<AnfNodePtr> InsertTensorMoveForGraphOutput(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);

  std::vector<AnfNodePtr> ret;
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto &node_users = manager->node_users();
  auto iter = node_users.find(node);
  if (iter == node_users.end()) {
    // `node` has no users, nothing to do.
    return ret;
  }
  for (auto &item : iter->second) {
    MS_EXCEPTION_IF_NULL(item.first);
    auto next_node = item.first->cast<CNodePtr>();
    bool find = false;
    // NOTE(review): recomputed per user on purpose — the manager->Replace()
    // below can change the graph output set for subsequent iterations.
    auto graph_outputs_pair =
      common::AnfAlgo::GetAllOutputIndexByReturnTypes(graph->output(), {prim::kPrimTupleGetItem});
    for (auto output_pair : graph_outputs_pair) {
      // Walk through nop and ref indirections until the real producing
      // kernel of this graph output is found.
      while (AnfUtils::IsRealCNodeKernel(output_pair.first)) {
        auto output_kernel = output_pair.first;
        MS_EXCEPTION_IF_NULL(output_kernel);
        auto cnode = output_kernel->cast<CNodePtr>();
        // nop node
        if (common::AnfAlgo::IsNopNode(cnode)) {
          output_pair = common::AnfAlgo::VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, true);
          continue;
        }
        // ref node
        if (kernel_graph->IsInRefOutputMap(output_pair)) {
          output_pair = kernel_graph->GetRefCorrespondOutput(output_pair);
          continue;
        }
        break;
      }
      MS_EXCEPTION_IF_NULL(output_pair.first);
      if (next_node == output_pair.first->cast<CNodePtr>()) {
        find = true;
        break;
      }
    }
    if (!find) {
      continue;
    }
    // This user is a graph output: splice a TensorMove in its place.
    auto tensor_move = CreateTensorMoveOp(graph, next_node);
    auto kernel_info = std::make_shared<device::KernelInfo>();
    MS_EXCEPTION_IF_NULL(tensor_move);
    tensor_move->set_kernel_info(kernel_info);
    (void)manager->Replace(next_node, tensor_move);
    ret.push_back(tensor_move);
    MS_LOG(DEBUG) << "Insert Output TensorMove for op " << node->fullname_with_scope();
  }
  return ret;
}
422
IsAllNopNode(const session::KernelGraph * const graph)423 bool IsAllNopNode(const session::KernelGraph *const graph) {
424 MS_EXCEPTION_IF_NULL(graph);
425 auto execution_order = graph->execution_order();
426 for (auto &cnode : execution_order) {
427 MS_EXCEPTION_IF_NULL(cnode);
428 if (!common::AnfAlgo::IsNopNode(cnode)) {
429 return false;
430 }
431 }
432 return true;
433 }
434
NeedHideNode(const std::vector<AnfNodePtr> & outputs,const AnfNodePtr & node,bool need_keep_output_nop_node)435 bool NeedHideNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool need_keep_output_nop_node) {
436 MS_EXCEPTION_IF_NULL(node);
437 // if node is not a nop node, keep it in execution order
438 if (!common::AnfAlgo::IsNopNode(node)) {
439 return false;
440 }
441 // if node is nop node and the graph is dynamic graph, check if the nop node is graph's output.
442 if (need_keep_output_nop_node) {
443 auto iter = find(outputs.begin(), outputs.end(), node);
444 if (iter != outputs.end()) {
445 return false;
446 }
447 }
448 return true;
449 }
450
HideNopNode(session::KernelGraph * const graph)451 void HideNopNode(session::KernelGraph *const graph) {
452 MS_EXCEPTION_IF_NULL(graph);
453 if (IsAllNopNode(graph) == true) {
454 return;
455 }
456 auto execution_order = graph->execution_order();
457 auto outputs = common::AnfAlgo::GetAllOutput(graph->output());
458 // If the graph has flag kFlagEnableZeroCopyInGraph, it means in subgraph sink mode, the inputs and outputs memory of
459 // graph should not be allocated, and the node should not be skipped.
460 bool need_keep_output_nop_node = (graph->is_dynamic_shape() || graph->has_flag(kFlagEnableZeroCopyInGraph));
461 MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
462 std::vector<CNodePtr> new_nodes;
463 for (auto &cnode : execution_order) {
464 MS_EXCEPTION_IF_NULL(cnode);
465 if (NeedHideNode(outputs, cnode, need_keep_output_nop_node)) {
466 common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
467 common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpExecution, MakeValue(true), cnode);
468 } else {
469 new_nodes.push_back(cnode);
470 }
471 }
472 graph->set_execution_order(new_nodes);
473 MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size();
474 }
475
// Iteratively removes nop nodes from the graph: hideable nop nodes are
// dropped from the execution order (and marked skippable), and any cnode
// input that points at a nop node is rewired to the nop node's real input.
// Repeats until a full pass makes no change, so chains of nop nodes collapse.
void RemoveNopNode(session::KernelGraph *const graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // A graph made entirely of nop nodes is left untouched.
  if (IsAllNopNode(graph) == true) {
    return;
  }
  bool changed = true;
  while (changed) {
    changed = false;
    std::vector<CNodePtr> new_nodes;
    auto outputs = graph->outputs();
    bool is_dynamic_graph = graph->is_dynamic_shape();
    for (auto &cnode : graph->execution_order()) {
      MS_EXCEPTION_IF_NULL(cnode);
      // ignore nop node itself
      if (NeedHideNode(outputs, cnode, is_dynamic_graph)) {
        common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
        common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpExecution, MakeValue(true), cnode);
        continue;
      }
      // Replace the input which is nop node
      std::vector<AnfNodePtr> new_inputs;
      new_inputs.push_back(cnode->input(0));
      bool need_update = false;
      for (size_t i = 1; i < cnode->size(); ++i) {
        auto input = cnode->input(i);
        MS_EXCEPTION_IF_NULL(input);
        auto cinput = input->cast<CNodePtr>();
        if (cinput == nullptr || !common::AnfAlgo::IsNopNode(cinput)) {
          new_inputs.push_back(input);
          continue;
        }
        // A nop node with exactly {prim, real_input} is bypassed; any other
        // arity is kept as-is.
        constexpr auto kInputSize = 2;
        if (cinput->size() == kInputSize) {
          new_inputs.push_back(cinput->input(1));
          need_update = true;
          // Rewiring may expose further nop chains; schedule another pass.
          changed = true;
        } else {
          new_inputs.push_back(input);
        }
      }
      if (need_update) {
        cnode->set_inputs(new_inputs);
      }
      // push into new execution list
      new_nodes.push_back(cnode);
    }
    graph->set_execution_order(new_nodes);
  }
}
525
GetRealNodeNum(const FuncGraphPtr & graph,const AnfNodePtr & node)526 size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
527 auto out_list = GetRealNodeUsedList(graph, node);
528 MS_EXCEPTION_IF_NULL(out_list);
529 return out_list->size();
530 }
531
GetRealNodeUsedList(const FuncGraphPtr & graph,const AnfNodePtr & node)532 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
533 const AnfNodePtr &node) {
534 auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
535 MS_EXCEPTION_IF_NULL(graph);
536 auto manager = graph->manager();
537 MS_EXCEPTION_IF_NULL(manager);
538 auto iter = manager->node_users().find(node);
539 if (iter == manager->node_users().end()) {
540 return output_node_list;
541 }
542 auto output_info_list = iter->second;
543 for (const auto &output_info : output_info_list) {
544 auto cnode_name = common::AnfAlgo::GetCNodeName(output_info.first);
545 if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
546 (cnode_name == prim::kPrimUpdateState->name())) {
547 continue;
548 }
549 output_node_list->push_back(output_info);
550 }
551 return output_node_list;
552 }
553
// Like GetRealNodeUsedList, but keeps only the users that consume output
// `output_index` of `node`.  The consumed index is resolved per user:
// through the TupleGetItem index, or by tracing the user's input back to
// `node` and taking the producing output index.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
                                                                                        const AnfNodePtr &node,
                                                                                        size_t output_index) {
  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "node has no output in manager";
  }
  auto output_info_list = iter->second;
  for (const auto &output_info : output_info_list) {
    auto cnode_name = common::AnfAlgo::GetCNodeName(output_info.first);
    // Depend attach edges and UpdateState users do not consume the value.
    if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
        (cnode_name == prim::kPrimUpdateState->name())) {
      continue;
    }
    size_t used_output_index;
    if (cnode_name == prim::kPrimTupleGetItem->name()) {
      // A TupleGetItem user selects the output given by its index operand.
      used_output_index = common::AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
    } else if (common::AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
      // `node` itself is a TupleGetItem: it has a single output, so any user
      // consumes exactly `output_index`.
      used_output_index = output_index;
    } else {
      // General case: trace the user's input edge (input index is 1-based,
      // hence the -1) back to its producer and take that output index.
      auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(output_info.first, IntToSize(output_info.second - 1));
      if (kernel_with_index.first.get() != node.get()) {
        MS_LOG(INTERNAL_EXCEPTION) << "Get used node failed for op[" << common::AnfAlgo::GetCNodeName(node) << "]";
      }
      used_output_index = kernel_with_index.second;
    }
    if (used_output_index == output_index) {
      output_node_list->push_back(output_info);
    }
  }
  return output_node_list;
}
590
IsUsedByOthers(const FuncGraphPtr & graph,const AnfNodePtr & node)591 bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
592 MS_EXCEPTION_IF_NULL(graph);
593 MS_EXCEPTION_IF_NULL(node);
594 auto output_node_list = GetRealNodeUsedList(graph, node);
595 MS_EXCEPTION_IF_NULL(output_node_list);
596 return output_node_list->size() > 1;
597 }
598
IsNotRealUsedByOthers(const FuncGraphPtr & graph,const AnfNodePtr & node)599 bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
600 MS_EXCEPTION_IF_NULL(graph);
601 MS_EXCEPTION_IF_NULL(node);
602 auto output_node_list = GetRealNodeUsedList(graph, node);
603 MS_EXCEPTION_IF_NULL(output_node_list);
604 if (output_node_list->empty()) {
605 return true;
606 }
607 for (const auto &output : *output_node_list) {
608 auto out_node = output.first;
609 auto name = common::AnfAlgo::GetCNodeName(out_node);
610 if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
611 name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) {
612 auto result = IsNotRealUsedByOthers(graph, out_node);
613 if (!result) {
614 return result;
615 }
616 continue;
617 }
618 return false;
619 }
620 return true;
621 }
622
CreatTupleGetItemNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_idx)623 CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
624 MS_EXCEPTION_IF_NULL(func_graph);
625 auto idx = NewValueNode(SizeToLong(output_idx));
626 MS_EXCEPTION_IF_NULL(idx);
627 auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
628 auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
629 idx->set_abstract(abstract_scalar);
630 CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
631 MS_EXCEPTION_IF_NULL(tuple_getitem);
632 tuple_getitem->set_scope(node->scope());
633 auto abs = node->abstract()->cast<abstract::AbstractTuplePtr>();
634 MS_EXCEPTION_IF_NULL(abs);
635 auto abs_i = abs->elements()[output_idx];
636 MS_EXCEPTION_IF_NULL(abs_i);
637 tuple_getitem->set_abstract(abs_i);
638 return tuple_getitem;
639 }
640
CreateMakeTupleNode(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & tuple_inputs)641 CNodePtr CreateMakeTupleNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &tuple_inputs) {
642 MS_EXCEPTION_IF_NULL(func_graph);
643 std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
644 AbstractBasePtrList make_tuple_abstract;
645 std::for_each(tuple_inputs.cbegin(), tuple_inputs.cend(),
646 [&make_tuple_inputs, &make_tuple_abstract](const AnfNodePtr &node) {
647 MS_EXCEPTION_IF_NULL(node);
648 (void)make_tuple_inputs.emplace_back(node);
649 (void)make_tuple_abstract.emplace_back(node->abstract());
650 });
651 auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
652 MS_EXCEPTION_IF_NULL(make_tuple);
653 make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(make_tuple_abstract));
654 return make_tuple;
655 }
656
CreateShapeValueNode(const FuncGraphPtr & func_graph,const std::vector<int64_t> & shape,bool to_tensor)657 ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape, bool to_tensor) {
658 MS_EXCEPTION_IF_NULL(func_graph);
659 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
660 MS_EXCEPTION_IF_NULL(kernel_graph);
661 ValuePtr shape_value = nullptr;
662 AbstractBasePtr abstract = nullptr;
663 if (to_tensor) {
664 // create Tensor
665 int64_t shape_dim = SizeToLong(shape.size());
666 std::vector<int64_t> shape_vec_shape = {shape_dim};
667 auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
668 MS_EXCEPTION_IF_NULL(shape_tensor);
669 auto data_ptr = shape_tensor->data_c();
670 MS_EXCEPTION_IF_NULL(data_ptr);
671 auto elem_num = shape.size() * kType64Len;
672 auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
673 if (ret_code != EOK) {
674 MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
675 }
676 shape_value = shape_tensor;
677 abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
678 } else {
679 // create ValueTuple
680 std::vector<ValuePtr> dim_values{};
681 abstract::AbstractBasePtrList abs{};
682 for (const auto &dim : shape) {
683 dim_values.push_back(MakeValue(dim));
684 abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
685 }
686 shape_value = std::make_shared<ValueTuple>(dim_values);
687 abstract = std::make_shared<abstract::AbstractTuple>(abs);
688 }
689 MS_EXCEPTION_IF_NULL(shape_value);
690 MS_EXCEPTION_IF_NULL(abstract);
691 auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
692 MS_EXCEPTION_IF_NULL(shape_value_node);
693 kernel_graph->AddValueNodeToGraph(shape_value_node);
694 return shape_value_node;
695 }
696
// Inserts a Cast to `dst_type`: when `is_input` is true the cast wraps input
// `input_index` of `node`, otherwise it wraps `node`'s own output.  The new
// cast inherits node scope/abstract and gets the kAttrDstType attribute.
CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, const CNodePtr &node, const bool is_input,
                     const size_t input_index) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  std::vector<AnfNodePtr> cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name()))};
  BaseShapePtr detail_shape;
  if (is_input) {
    auto node_input = common::AnfAlgo::GetInputNode(node, input_index);
    (void)cast_inputs.emplace_back(node_input);
    detail_shape = AnfAlgo::GetOutputDetailShape(node_input, 0);
  } else {
    (void)cast_inputs.emplace_back(node);
    detail_shape = AnfAlgo::GetOutputDetailShape(node, 0);
  }
  CNodePtr cast = NewCNode(cast_inputs, func_graph, {node});
  MS_EXCEPTION_IF_NULL(cast);
  cast->set_scope(node->scope());
  cast->set_abstract(node->abstract());
  common::AnfAlgo::SetNodeAttr(kAttrDstType, MakeValue(static_cast<size_t>(dst_type)), cast);
  common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {detail_shape}, cast.get());
  return cast;
}
719
// Creates a cnode from `new_node_inputs` that inherits kernel info, scope,
// abstract, and the first output's inferred type/shape from `node`.
AnfNodePtr CreateNodeBase(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &new_node_inputs,
                          const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto created = graph->NewCNode(new_node_inputs);
  MS_EXCEPTION_IF_NULL(created);

  created->set_kernel_info(std::make_shared<device::KernelInfo>());
  created->set_scope(node->scope());
  created->set_abstract(node->abstract());

  common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(node, 0)},
                                              {common::AnfAlgo::GetOutputInferShape(node, 0)}, created.get());
  return created;
}
737
// Equality predicate used by the pattern-match engine.
// - Two Primitive value nodes compare by primitive name.
// - Two other value nodes compare by value; tensor values use content equality (ValueEqual).
// - Everything else falls back to BaseRef::operator==.
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    MS_EXCEPTION_IF_NULL(a_node);
    MS_EXCEPTION_IF_NULL(b_node);
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      // Primitives are equal when their names match, regardless of object identity.
      auto a_value_node = a_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(a_value_node);
      auto a_value = a_value_node->value();
      MS_EXCEPTION_IF_NULL(a_value);
      auto a_prim = a_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(a_prim);

      auto b_value_node = b_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(b_value_node);
      auto b_value = b_value_node->value();
      MS_EXCEPTION_IF_NULL(b_value);
      auto b_prim = b_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(b_prim);

      return a_prim->name() == b_prim->name();
    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
      if (a_value_node_ptr == nullptr) {
        MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail.";
      }
      auto a_value_ptr = a_value_node_ptr->value();
      if (a_value_ptr == nullptr) {
        MS_LOG(INTERNAL_EXCEPTION) << "Value ptr is nullptr.";
      }

      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
      if (b_value_node_ptr == nullptr) {
        MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail.";
      }
      auto b_value_ptr = b_value_node_ptr->value();
      if (b_value_ptr == nullptr) {
        MS_LOG(INTERNAL_EXCEPTION) << "Value ptr is nullptr.";
      }
      if (a_value_ptr->isa<tensor::Tensor>() && b_value_ptr->isa<tensor::Tensor>()) {
        // Tensors compare by content, not pointer identity.
        auto a_tensor_ptr = a_value_ptr->cast<tensor::TensorPtr>();
        auto b_tensor_ptr = b_value_ptr->cast<tensor::TensorPtr>();
        if (a_tensor_ptr == nullptr || b_tensor_ptr == nullptr) {
          MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail.";
        }
        return a_tensor_ptr->ValueEqual(*b_tensor_ptr);
      } else {
        return (*a_value_ptr) == (*b_value_ptr);
      }
    }
    // Two AnfNodes that are not both value nodes: fall through to BaseRef comparison.
    MS_LOG(DEBUG) << "check AnfNodePtr equal";
  }
  if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
    MS_LOG(DEBUG) << "check GraphPtr equal";
  }
  return a == b;
}
796
CNodeTypeEqual(const BaseRef & a,const BaseRef & b)797 bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
798 // To matchCNode and Kernel's type
799 if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
800 return true;
801 }
802 return a.type() == b.type();
803 }
804
805 namespace {
// Convert a scalar or Value s-expression into a ValueNode.
// For a Primitive value, a pattern Var is also registered in `primitive_vars`; if the same
// primitive object is already registered, a fresh Primitive clone is created first
// (NOTE(review): presumably so repeated occurrences of one primitive get distinct Vars — confirm).
// Returns nullptr when the s-expression kind cannot be converted.
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp, PrimitiveVarMap *primitive_vars) {
  if (utils::isa<int>(sexp)) {
    return NewValueNode(utils::cast<int>(sexp));
  }
  if (utils::isa<int64_t>(sexp)) {
    return NewValueNode(utils::cast<int64_t>(sexp));
  }
  if (utils::isa<float>(sexp)) {
    return NewValueNode(utils::cast<float>(sexp));
  }
  if (utils::isa<bool>(sexp)) {
    return NewValueNode(utils::cast<bool>(sexp));
  }
  if (utils::isa<ValuePtr>(sexp)) {
    auto value = utils::cast<ValuePtr>(sexp);
    if (utils::isa<PrimitivePtr>(sexp)) {
      auto prim = utils::cast<PrimitivePtr>(sexp);
      if (primitive_vars->find(prim) != primitive_vars->end()) {
        prim = std::make_shared<Primitive>(prim->name());
        value = prim;
      }
      (*primitive_vars)[prim] = std::make_shared<Var>(prim);
    }
    return NewValueNode(value);
  }
  return nullptr;
}
833
CreateCNodeWithGraph(const std::vector<AnfNodePtr> & input_nodes,const BaseRef & graph)834 CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
835 if (utils::isa<FuncGraphPtr>(graph)) {
836 return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
837 }
838 if (utils::isa<VarPtr>(graph)) {
839 return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
840 }
841 return nullptr;
842 }
843
// Wrap a Var s-expression into a VarNode, optionally bound to a FuncGraph.
// NOTE(review): both branches cast `sexp` (not `graph`) to VarPtr; in the FuncGraphPtr branch
// this relies on `sexp` actually holding a Var — confirm against callers.
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  if (utils::isa<VarPtr>(graph)) {
    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
    // Graph is itself a Var: the VarNode is created without a concrete FuncGraph.
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  }
  if (utils::isa<FuncGraphPtr>(graph)) {
    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  }
  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  return nullptr;
}
856
// Convert a VectorRef s-expression into a CNode by recursively converting each element.
// When `multigraph` is set and `graph` is a Var, each child is built against a fresh "G"
// graph Var so the resulting pattern can match nodes across different sub-graphs.
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                            bool multigraph) {
  MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  std::vector<AnfNodePtr> input_nodes;
  const auto &tuple = utils::cast<VectorRef>(sexp);
  if (multigraph && utils::isa<VarPtr>(graph)) {
    for (auto &x : tuple) {
      AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
      input_nodes.push_back(node);
    }
    VarPtr var_ptr = utils::cast<VarPtr>(graph);
    return std::make_shared<CNode>(input_nodes, var_ptr);
  }

  for (auto &x : tuple) {
    AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
    input_nodes.push_back(node);
  }
  return CreateCNodeWithGraph(input_nodes, graph);
}
877
// Rebuild one (possibly nested) tuple abstract from the flat `input_abstract` list, guided by
// the structural attr `value`: a ValueSequence mirrors the original tuple nesting, and the leaf
// value -1 consumes exactly one flat abstract. Returns the rebuilt abstract together with the
// number of flat abstracts it consumed (so the caller can advance its cursor).
// `list_start_vec` marks flat indexes where original *list* inputs began; a rebuilt non-nested
// sequence starting at such an index inherits the PyExecuteOutputUserData of its first element.
std::pair<AbstractBasePtr, size_t> RectifyAbstractFromStructuralAttr(const ValuePtr &value,
                                                                     const AbstractBasePtrList &input_abstract,
                                                                     const std::vector<size_t> &list_start_vec,
                                                                     size_t input_index) {
  MS_EXCEPTION_IF_NULL(value);
  auto begin_iter = input_abstract.begin() + input_index;
  if (value->isa<ValueSequence>()) {
    size_t offset = 0;
    std::vector<AbstractBasePtr> abs_list;
    auto seq_value = value->cast_ptr<ValueSequence>();
    for (size_t i = 0; i < seq_value->size(); ++i) {
      // Recurse on each element; `offset` tracks how many flat abstracts are consumed so far.
      auto [abs, offset_inner] =
        RectifyAbstractFromStructuralAttr((*seq_value)[i], input_abstract, list_start_vec, input_index + offset);
      MS_EXCEPTION_IF_NULL(abs);
      if (abs->isa<abstract::AbstractSequence>() &&
          std::find(list_start_vec.begin(), list_start_vec.end(), input_index + offset) != list_start_vec.end()) {
        auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
        const auto &elements = abs_seq->elements();
        bool is_nested = std::any_of(elements.begin(), elements.end(),
                                     [](const AbstractBasePtr &abs) { return abs->isa<abstract::AbstractSequence>(); });
        if (!is_nested) {
          // A flat list input: propagate the PyExecute user data from its first element onto
          // the rebuilt sequence abstract.
          const auto &first_abs_in_list = input_abstract[input_index + offset];
          MS_EXCEPTION_IF_NULL(first_abs_in_list);
          if (!first_abs_in_list->has_user_data<kernel::PyExecuteOutputUserData>()) {
            MS_LOG(INTERNAL_EXCEPTION) << "List input abstract PyExecuteOutputUserData not found.";
          }
          const auto &list_user_data = first_abs_in_list->user_data<kernel::PyExecuteOutputUserData>();
          abs->set_user_data<kernel::PyExecuteOutputUserData>(list_user_data);
        }
      }
      (void)abs_list.emplace_back(abs);
      offset += offset_inner;
    }
    (void)std::for_each(begin_iter, begin_iter + offset, [](AbstractBasePtr abs) -> void {
      MS_LOG(DEBUG) << "The convert abs is :" << abs->ToString();
    });
    return std::make_pair(std::make_shared<abstract::AbstractTuple>(abs_list), offset);
  }

  // Leaf case: the structural attr must be the "not dynamic" flag -1, consuming one abstract.
  const auto num_value = GetValue<int64_t>(value);

  constexpr auto kNotDynamicFlag = -1;
  if (num_value == kNotDynamicFlag) {
    return std::make_pair(*begin_iter, 1);
  } else {
    MS_LOG(EXCEPTION) << "The attr of structural must all value -1 but got " << num_value;
  }
}
926
RectifyEmptyTupleAbstract(const ValuePtr & structural)927 AbstractBasePtr RectifyEmptyTupleAbstract(const ValuePtr &structural) {
928 MS_EXCEPTION_IF_NULL(structural);
929 if (!structural->isa<ValueTuple>()) {
930 MS_LOG(EXCEPTION) << "input abstract is out of range.";
931 }
932
933 auto value_tuple = structural->cast_ptr<ValueTuple>();
934 std::vector<AbstractBasePtr> abs_list;
935 MS_EXCEPTION_IF_NULL(value_tuple);
936 for (size_t i = 0; i < value_tuple->size(); ++i) {
937 auto item = (*value_tuple)[i];
938 (void)abs_list.emplace_back(RectifyEmptyTupleAbstract(item));
939 }
940
941 return std::make_shared<abstract::AbstractTuple>(abs_list);
942 }
943
// Rebuild the abstract list of a node whose tuple inputs were flattened, guided by the
// kAttrTupleInputStructural value: each item describes how many flat input abstracts form one
// original (possibly nested tuple) input. `list_start` holds flat indexes where original list
// inputs began (forwarded to RectifyAbstractFromStructuralAttr for PyExecute user-data fixup).
// Returns `input_abstract` unchanged when there is no structural attr.
AbstractBasePtrList RectifyAbstractFromTupleInputStructural(const ValuePtr &tuple_structural,
                                                            const AbstractBasePtrList &input_abstract,
                                                            const ValuePtrList &list_start) {
  if (tuple_structural == nullptr) {
    return input_abstract;
  }
  auto tuple_structural_value = tuple_structural->cast_ptr<ValueSequence>();
  MS_EXCEPTION_IF_NULL(tuple_structural_value);
  // The list-start indexes are loop-invariant; convert them once instead of per item.
  std::vector<size_t> list_start_vec;
  (void)std::transform(list_start.begin(), list_start.end(), std::back_inserter(list_start_vec),
                       [](const ValuePtr &val) { return GetValue<size_t>(val); });
  AbstractBasePtrList rectifyed_abs_list;
  size_t input_index = 0;
  for (size_t i = 0; i < tuple_structural_value->size(); ++i) {
    auto item = (*tuple_structural_value)[i];
    MS_EXCEPTION_IF_NULL(item);
    if (input_abstract.size() <= input_index) {
      // The Ori Node : Oper(a, b, ()) ==> Oper(a, b) with structural --> (-1, -1 , ())
      // The abstract size will be smaller than the attr of tuple input structural.
      (void)rectifyed_abs_list.emplace_back(RectifyEmptyTupleAbstract(item));
      // Bug fix: without this `continue`, the call below would form an out-of-range iterator
      // (input_abstract.begin() + input_index) and dereference past the end.
      continue;
    }
    auto [abs, offset] = RectifyAbstractFromStructuralAttr(item, input_abstract, list_start_vec, input_index);
    input_index += offset;
    (void)rectifyed_abs_list.emplace_back(abs);
    MS_LOG(DEBUG) << "Rectify abs :" << item->ToString() << ", from structural " << abs->ToString();
  }

  return rectifyed_abs_list;
}
973
// Group the flat input abstracts according to kAttrDynInputSizes: an entry of -1 keeps one
// abstract as-is; a positive n folds the next n abstracts into a single AbstractTuple (one
// dynamic input). Out-of-range indexes are tolerated (with a warning) only for PyExecute.
// Returns `input_abstract` unchanged when the primitive has no dynamic-input attr.
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
                                                    const AbstractBasePtrList &input_abstract) {
  MS_EXCEPTION_IF_NULL(prim);
  auto dyn_input_list = prim->GetAttr(kAttrDynInputSizes);
  if (dyn_input_list == nullptr) {
    return input_abstract;
  }
  AbstractBasePtrList rectifyed_abs_list;
  const int kNotDynamicFlag = -1;
  auto dynamic_input_index = GetValue<std::vector<int64_t>>(dyn_input_list);
  size_t input_index = 0;
  for (auto item : dynamic_input_index) {
    if (item == kNotDynamicFlag) {
      if (input_index >= input_abstract.size()) {
        // PyExecute is allowed to have fewer abstracts than the attr describes.
        if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
          MS_LOG(WARNING) << "For primitive \'PyExecute\', index " << input_index
                          << " is out of range in input abstract " << input_abstract.size();
          continue;
        }
        MS_LOG(EXCEPTION) << "For primitive \'" << prim->name() << "\', index " << input_index
                          << " is out of range in input abstract " << input_abstract.size();
      }
      (void)rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
    } else {
      if (item < 0) {
        MS_LOG(EXCEPTION) << "The dynamic input size check error the index should be -1 or positive number but got "
                          << item;
      }
      // Fold the next `item` abstracts into one tuple abstract for this dynamic input.
      AbstractBasePtrList dynamic_inputs_abs;
      for (auto index = item; index > 0; --index) {
        if (input_index >= input_abstract.size()) {
          if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
            MS_LOG(WARNING) << "For primitive \'PyExecute\', index " << input_index
                            << " is out of range in input abstract " << input_abstract.size();
            continue;
          }
          MS_LOG(EXCEPTION) << "For primitive \'" << prim->name() << "\', index " << input_index
                            << " is out of range in input abstract " << input_abstract.size();
        }
        (void)dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
      }
      (void)rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
    }
  }
  return rectifyed_abs_list;
}
1020
// Normalize the input abstract list for inference: prefer the tuple-input structural attr
// (with optional list-start indexes), otherwise fall back to dynamic-input grouping.
AbstractBasePtrList RectifyAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &input_abstract) {
  auto input_structural = prim->GetAttr(kAttrTupleInputStructural);
  if (input_structural != nullptr) {
    if (prim->HasAttr(kAttrListStartIndex)) {
      auto list_start_index = prim->GetAttr(kAttrListStartIndex);
      MS_EXCEPTION_IF_NULL(list_start_index);
      auto list_start_index_value = list_start_index->cast_ptr<ValueSequence>();
      MS_EXCEPTION_IF_NULL(list_start_index_value);
      return RectifyAbstractFromTupleInputStructural(input_structural, input_abstract, list_start_index_value->value());
    }
    return RectifyAbstractFromTupleInputStructural(input_structural, input_abstract, {});
  }
  return RectifyAbstractFromDynamicInput(prim, input_abstract);
}
1035
// Infer the output shape for `prim_clone`, preferring the new-style func-impl shape inference
// and falling back to the registered backend infer implementation. The dtype part of
// `orig_abs` is preserved and only its shape replaced — except for dynamic-sequence nodes,
// where the whole abstract is re-inferred. Throws if no infer function is registered.
inline AbstractBasePtr InferShapeWithCheck(const PrimitivePtr &prim, const PrimitivePtr &prim_clone,
                                           const AbstractBasePtrList &infer_spec_list, const AbstractBasePtr &orig_abs,
                                           const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(prim_clone);
  MS_EXCEPTION_IF_NULL(orig_abs);
  AbstractBasePtr out_abs;
  if (auto shape_optional = abstract::InferShapeByFuncImpl(prim_clone, infer_spec_list); shape_optional.has_value()) {
    out_abs = orig_abs->Clone();
    out_abs->set_shape(shape_optional.value());
  } else if (auto found = abstract::GetBackendPrimitiveInferImpl(prim_clone); found.has_value()) {
    auto infer = found.value();
    MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-shape implement for backend!");
    MS_EXCEPTION_IF_NULL(cnode);
    if (common::AnfAlgo::IsDynamicSequence(cnode)) {
      // Dynamic sequence: element count/types may change, so re-infer the full abstract.
      out_abs = infer.InferShapeAndType(nullptr, prim_clone, infer_spec_list);
    } else {
      out_abs = orig_abs->Clone();
      auto shape = infer.InferShape(prim_clone, infer_spec_list);
      if (shape == nullptr) {
        MS_LOG(EXCEPTION) << "Infer shape with backend function failed";
      }
      out_abs->set_shape(shape);
    }
  } else {
    MS_EXCEPTION_IF_NULL(prim);
    MS_LOG(EXCEPTION) << "Get infer functions failed, the operator is not support dynamic shape yet, primitive name:"
                      << prim->name() << " primitive type:" << prim->type_name();
  }
  return out_abs;
}
1066 } // namespace
1067
// Convert a pattern s-expression into an AnfNode:
// VectorRef -> CNode, VarPtr -> VarNode (or primitive value node, registering it in
// `primitive_vars`), AnfNodePtr passes through, scalars/values -> ValueNode.
// Throws when the s-expression cannot be converted.
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  MS_EXCEPTION_IF_NULL(primitive_vars);
  if (utils::isa<VectorRef>(sexp)) {
    return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  }
  if (utils::isa<VarPtr>(sexp)) {
    auto var_ptr = utils::cast<VarPtr>(sexp);
    MS_EXCEPTION_IF_NULL(var_ptr);
    if (var_ptr->primitive()) {
      // A Var bound to a primitive matches by that primitive; remember the association.
      (*primitive_vars)[var_ptr->primitive()] = var_ptr;
      return NewValueNode(var_ptr->primitive());
    }
    return CreateVarNodeWithSexp(sexp, graph);
  }
  if (utils::isa<AnfNodePtr>(sexp)) {
    return utils::cast<AnfNodePtr>(sexp);
  }
  auto value_node = CreateValueNodeWithSexp(sexp, primitive_vars);
  if (value_node == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Sexp cannot converted, sexp: " + sexp.ToString();
  }
  return value_node;
}
1092
IsSameNode(const EquivPtr & equiv1,const EquivPtr & equiv2,const VarPtr & var_node)1093 bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
1094 MS_EXCEPTION_IF_NULL(equiv1);
1095 MS_EXCEPTION_IF_NULL(equiv2);
1096 MS_EXCEPTION_IF_NULL(var_node);
1097 auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
1098 MS_EXCEPTION_IF_NULL(equiv1_node);
1099 auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
1100 MS_EXCEPTION_IF_NULL(equiv2_node);
1101 return *equiv1_node == *equiv2_node;
1102 }
1103
GetAnfNodeByVar(const EquivPtr & equiv,const VarPtr & var_node)1104 AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
1105 MS_EXCEPTION_IF_NULL(equiv);
1106 MS_EXCEPTION_IF_NULL(var_node);
1107 auto iter = (*equiv).find(var_node);
1108 if (iter == (*equiv).cend()) {
1109 MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
1110 return nullptr;
1111 }
1112 auto res = utils::cast<AnfNodePtr>(iter->second);
1113 if (res == nullptr) {
1114 MS_LOG(INTERNAL_EXCEPTION) << "Cast fail! Maybe var is not a anf node";
1115 }
1116 return res;
1117 }
1118
GetGetitemIndex(const AnfNodePtr & getitem)1119 int64_t GetGetitemIndex(const AnfNodePtr &getitem) {
1120 if (!getitem->isa<CNode>() || IsPrimitive(getitem, prim::kPrimTupleGetItem)) {
1121 MS_LOG(INTERNAL_EXCEPTION) << "Expect TupleGetItem, but got " << getitem->DebugString();
1122 }
1123 auto vnode = GetValueNode(getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
1124 return GetValue<int64_t>(vnode);
1125 }
1126
CompareTupleGetitem(const AnfNodePtr & n1,const AnfNodePtr & n2)1127 bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
1128 MS_EXCEPTION_IF_NULL(n1);
1129 MS_EXCEPTION_IF_NULL(n2);
1130 auto n1_cnode = n1->cast<CNodePtr>();
1131 auto n2_cnode = n2->cast<CNodePtr>();
1132 MS_EXCEPTION_IF_NULL(n1_cnode);
1133 MS_EXCEPTION_IF_NULL(n2_cnode);
1134 auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem);
1135 MS_EXCEPTION_IF_NULL(index_input1);
1136 auto value_node1 = index_input1->cast<ValueNodePtr>();
1137 MS_EXCEPTION_IF_NULL(value_node1);
1138 auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem);
1139 MS_EXCEPTION_IF_NULL(index_input2);
1140 auto value_node2 = index_input2->cast<ValueNodePtr>();
1141 MS_EXCEPTION_IF_NULL(value_node2);
1142 return GetValue<int64_t>(value_node1->value()) < GetValue<int64_t>(value_node2->value());
1143 }
1144
GetBoolAttr(const AnfNodePtr & node,const std::string & attr_name)1145 bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
1146 MS_EXCEPTION_IF_NULL(node);
1147 if (!node->isa<CNode>()) {
1148 MS_LOG(INFO) << "node is not a cnode";
1149 return false;
1150 }
1151 auto cnode = node->cast<CNodePtr>();
1152 MS_EXCEPTION_IF_NULL(cnode);
1153 return common::AnfAlgo::HasNodeAttr(attr_name, cnode) && common::AnfAlgo::GetNodeAttr<bool>(node, attr_name);
1154 }
1155
CheckSupportDataType(const AnfNodePtr & node,const std::set<TypeId> & supported_data_type_set)1156 bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
1157 MS_EXCEPTION_IF_NULL(node);
1158 TypeId data_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
1159 if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
1160 return true;
1161 }
1162 MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
1163 return false;
1164 }
1165
// Clone `value_node` into a fresh ValueNode carrying a new KernelInfo and a default
// kernel-build-info (DEFAULT_FORMAT outputs, device dtype left as kTypeUnknown so the
// kernel-selection pass fills it in later).
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  MS_EXCEPTION_IF_NULL(new_value_node);
  new_value_node->set_abstract(value_node->abstract());
  // create kernel_info for new value node
  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_value_node->set_kernel_info(kernel_info);
  // create kernel_build_info for new value node
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
  // set the format of value_node to DEFAULT_FORMAT
  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  // set value node initial device data type = infer data type
  std::vector<TypeId> types;
  size_t output_num = AnfAlgo::GetOutputTensorNum(value_node);
  for (size_t index = 0; index < output_num; ++index) {
    types.push_back(kTypeUnknown);
  }
  kernel_build_info_builder->SetOutputsDeviceType(types);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  return new_value_node;
}
1189
// Re-point every Depend/UpdateState user of `old_node` to `new_node`, preserving the
// control/side-effect edges when a node is being replaced.
void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
  MS_EXCEPTION_IF_NULL(old_node);
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // Find users of old_node that are Depend or UpdateState and rewire their edge.
  auto node_users = manager->node_users()[old_node];
  for (const auto &node_index : node_users) {
    AnfNodePtr output = node_index.first;
    MS_EXCEPTION_IF_NULL(output);
    if (common::AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
        common::AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
      auto depend = output->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(depend);
      // node_index.second is the input slot of `old_node` inside the user.
      manager->SetEdge(depend, node_index.second, new_node);
    }
  }
}
1208
// Read the "primitive was renamed/changed by an ascend vm pass" markers from `prim`:
// *me_name receives kAttrMeOpName (original ME op name) if present, *ir_change receives
// kAttrIRChange. Callers use these to rectify abstracts before re-running inference.
void GetPrimitiveChangeInfo(const PrimitivePtr &prim, std::string *me_name, bool *ir_change) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(me_name);
  MS_EXCEPTION_IF_NULL(ir_change);
  if (prim->HasAttr(kAttrMeOpName)) {
    *me_name = GetValue<std::string>(prim->GetAttr(kAttrMeOpName));
  }
  if (prim->HasAttr(kAttrIRChange)) {
    *ir_change = GetValue<bool>(prim->GetAttr(kAttrIRChange));
  }
  if (*ir_change || !me_name->empty()) {
    MS_LOG(DEBUG) << "Note: primitive(" << prim->ToString() << ", me_name:" << *me_name
                  << ", ori_name: " << prim->name() << ", ir_change" << *ir_change << ") "
                  << "has been changed in ascend vm pass, it should been rectify abstract before infer or provide a "
                     "new infer func";
  }
}
1226
// Re-infer only the shape of `cnode` and update its abstract in place.
// If the primitive was renamed by a vm pass (kAttrMeOpName), inference runs on a clone under
// the original ME name, and the clone's state is copied back afterwards with the backend name
// restored — callers keep seeing the same primitive object.
void CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(cnode);
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelInferInner,
                                     prim->name(), true);
  auto old_abs = cnode->abstract();
  MS_EXCEPTION_IF_NULL(old_abs);

  if (IS_OUTPUT_ON(mindspore::kDebug)) {
    MS_LOG(DEBUG) << "Infer name = " << cnode->fullname_with_scope();
    for (size_t i = 0; i < args_spec_list.size(); i++) {
      MS_LOG(DEBUG) << "Infer name '" << cnode->fullname_with_scope() << "', The input[" << i
                    << "] abs is : " << args_spec_list[i]->ToString();
    }
  }

  PrimitivePtr prim_clone = prim;
  MS_EXCEPTION_IF_NULL(prim_clone);
  std::string me_name;
  std::string ori_name;
  bool ir_change = false;
  GetPrimitiveChangeInfo(prim, &me_name, &ir_change);
  if (!me_name.empty()) {
    // Infer on a clone carrying the original ME name so the right infer impl is found.
    prim_clone = prim->Clone();
    ori_name = prim->name();
    prim_clone->set_name(me_name);
  }

  auto infer_spec_list = RectifyAbstract(prim_clone, args_spec_list);
  AbstractBasePtr out_abs = InferShapeWithCheck(prim, prim_clone, infer_spec_list, old_abs, cnode);

  if (prim_clone != prim) {
    // Copy any state the infer added onto the clone back, then restore the backend name.
    *prim = *prim_clone;
    prim->set_name(ori_name);
  }
  cnode->set_abstract(out_abs);
  MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << old_abs << " to "
                << out_abs;
}
1266
// Infer the full abstract (shape and type) for `prim` applied to `args_spec_list`, preferring
// the new-style func-impl inference and falling back to the registered backend infer impl.
// Handles vm-pass-renamed primitives the same way as CppInferShape (clone, infer, copy back).
// Throws when no infer function is registered for the primitive.
AbstractBasePtr CppInferShapeAndType(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
  MS_EXCEPTION_IF_NULL(prim);
  PrimitivePtr prim_clone = prim;
  MS_EXCEPTION_IF_NULL(prim_clone);
  std::string me_name;
  std::string ori_name;
  bool ir_change = false;
  GetPrimitiveChangeInfo(prim, &me_name, &ir_change);
  if (!me_name.empty()) {
    // Infer on a clone carrying the original ME name so the right infer impl is found.
    prim_clone = prim->Clone();
    ori_name = prim->name();
    prim_clone->set_name(me_name);
  }

  AbstractBasePtr ret;
  if (auto abstract_optional = abstract::InferAbstractByFuncImpl(prim_clone, args_spec_list);
      abstract_optional.has_value()) {
    ret = abstract_optional.value();
  } else if (auto found = abstract::GetBackendPrimitiveInferImpl(prim_clone); found.has_value()) {
    auto infer = found.value();
    MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-abstract implement!");
    auto infer_spec_list = RectifyAbstract(prim_clone, args_spec_list);
    ret = infer.InferShapeAndType(nullptr, prim_clone, infer_spec_list);
  } else {
    MS_LOG(EXCEPTION)
      << "Get infer shape function failed, the operator is not support dynamic shape yet, primitive name:"
      << prim->name() << " primitive type:" << prim->type_name();
  }

  if (prim_clone != prim) {
    // Copy any state the infer added onto the clone back, then restore the backend name.
    *prim = *prim_clone;
    prim->set_name(ori_name);
  }
  return ret;
}
1302
// Build one KernelBuildInfo covering the concatenated inputs/outputs of all CNodes in
// `node_list` (e.g. a fused kernel): every port uses DEFAULT_FORMAT and the inferred dtype.
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
  std::vector<std::string> inputs_device_format;
  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  for (size_t idx = 0; idx < node_list.size(); ++idx) {
    auto cnode = utils::cast<CNodePtr>(node_list[idx]);
    MS_EXCEPTION_IF_NULL(cnode);
    size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
      (void)inputs_device_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
      (void)outputs_device_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(cnode, output_index));
    }
  }
  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}
1329
// Count, per output index of `node`, how many users consume that output.
// Single-output nodes count direct users; multi-output nodes count the users of each
// TupleGetItem (other user kinds of a multi-output node are ignored).
std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto output_num = AnfAlgo::GetOutputTensorNum(node);
  std::vector<int64_t> output_used_num(output_num, 0);
  if (output_num == 1) {
    output_used_num[0] = SizeToLong(manager->node_users()[node].size());
  } else {
    for (auto out_getitem : manager->node_users()[node]) {
      MS_EXCEPTION_IF_NULL(out_getitem.first);
      if (!common::AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
        continue;
      }
      auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(out_getitem_ptr);
      auto getitem_input2 = out_getitem_ptr->input(kInputNodeOutputIndexInTupleGetItem);
      auto output_idx = LongToSize(GetValue<int64_t>(GetValueNode(getitem_input2)));
      output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
    }
  }
  return output_used_num;
}
1353
GetNodeOutputTotalUsedNum(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)1354 int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
1355 auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
1356 return std::accumulate(output_used_num.begin(), output_used_num.end(), int64_t(0));
1357 }
1358
GetCustomOpAttrIndex(const PrimitivePtr & primitive,mindspore::HashSet<size_t> * indexes)1359 void GetCustomOpAttrIndex(const PrimitivePtr &primitive, mindspore::HashSet<size_t> *indexes) {
1360 if (primitive == nullptr || primitive->name() != prim::kPrimCustom->name()) {
1361 return;
1362 }
1363 MS_EXCEPTION_IF_NULL(indexes);
1364 auto input_names = primitive->GetAttr(kAttrInputNames);
1365 auto attr_names = primitive->GetAttr(kAttrAttrNames);
1366 if (input_names == nullptr || attr_names == nullptr) {
1367 return;
1368 }
1369 auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
1370 auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
1371 for (size_t i = 0; i < input_names_vec.size(); ++i) {
1372 if (std::find(attr_names_vec.begin(), attr_names_vec.end(), input_names_vec[i]) != attr_names_vec.end()) {
1373 (void)indexes->insert(i);
1374 }
1375 }
1376 }
1377
GetInputNodeIndex(const AnfNodePtr & input,const CNodePtr & user_node)1378 size_t GetInputNodeIndex(const AnfNodePtr &input, const CNodePtr &user_node) {
1379 MS_EXCEPTION_IF_NULL(input);
1380 MS_EXCEPTION_IF_NULL(user_node);
1381
1382 AnfNodePtrList input_list = user_node->inputs();
1383 auto pos = std::find(input_list.begin(), input_list.end(), input);
1384 if (pos == input_list.end()) {
1385 MS_LOG(EXCEPTION) << input->fullname_with_scope() << " is not the input of " << user_node->fullname_with_scope();
1386 }
1387
1388 // The first input is Primitive and needs to be skipped.
1389 return std::distance(input_list.begin() + kSizeOne, pos);
1390 }
1391
// Flatten a tuple-typed `tuple_input` into `plant_inputs`:
// - A MakeTuple contributes its elements directly, recursing into nested MakeTuples.
// - Any other tuple-typed node gets one TupleGetItem per output element.
// Returns the number of flattened elements, or -1 when `tuple_input` is not tuple-typed.
// NOTE(review): a -1 error return from the recursive call is fed through LongToSize and would
// wrap to a huge value — presumably nested MakeTuple elements are always tuple-typed; confirm.
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
                         std::vector<AnfNodePtr> *plant_inputs) {
  MS_EXCEPTION_IF_NULL(tuple_input);
  if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
    auto abs = tuple_input->abstract();
    MS_EXCEPTION_IF_NULL(abs);
    MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
    return -1;
  }
  MS_EXCEPTION_IF_NULL(plant_inputs);
  auto input_size = AnfAlgo::GetOutputElementNum(tuple_input);
  if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
    auto make_tuple = tuple_input->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
    for (size_t j = 0; j < tuple_input_num; ++j) {
      // using for graph kernel
      auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
      MS_EXCEPTION_IF_NULL(dyn_input_node);
      // Handle tuple nested scenes.
      if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
        input_size += LongToSize(SplitTupleInputs(graph, dyn_input_node, plant_inputs));
        continue;
      }
      (void)plant_inputs->emplace_back(dyn_input_node);
    }
    return input_size;
  }
  // Not a MakeTuple: materialize each element through a TupleGetItem node.
  for (size_t index = 0; index < input_size; ++index) {
    auto dynamic_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
    (void)plant_inputs->emplace_back(dynamic_input_node);
  }
  return input_size;
}
1426
IsNotSequenceOfTensor(const abstract::AbstractBasePtr & abs)1427 static bool IsNotSequenceOfTensor(const abstract::AbstractBasePtr &abs) {
1428 if (abs->isa<abstract::AbstractTensor>()) {
1429 return false;
1430 }
1431
1432 if (abs->isa<abstract::AbstractSequence>()) {
1433 auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
1434 MS_EXCEPTION_IF_NULL(seq_abs);
1435 if (seq_abs->size() == 0) {
1436 return true;
1437 }
1438
1439 return IsNotSequenceOfTensor(seq_abs->elements()[0]);
1440 }
1441
1442 return true;
1443 }
1444
GenPrintAttrDynInputSizes(const CNodePtr & cnode)1445 std::vector<int64_t> GenPrintAttrDynInputSizes(const CNodePtr &cnode) {
1446 int64_t num_inputs = 0;
1447 std::vector<AnfNodePtr> node_inputs = cnode->inputs();
1448 for (size_t node_inputs_index = 1; node_inputs_index < node_inputs.size(); ++node_inputs_index) {
1449 auto &input = node_inputs[node_inputs_index];
1450 MS_EXCEPTION_IF_NULL(input);
1451 if (IsValueNode<UMonad>(input) || IsValueNode<IOMonad>(input) || HasAbstractMonad(input)) {
1452 continue;
1453 }
1454 num_inputs++;
1455 }
1456 // the first input of print is a placeholder
1457 return std::vector<int64_t>{-1, num_inputs - 1, -1};
1458 }
1459
InputArgTypeIsDynamicType(const mindspore::ops::OP_DTYPE input_arg_dtype)1460 bool InputArgTypeIsDynamicType(const mindspore::ops::OP_DTYPE input_arg_dtype) {
1461 if (input_arg_dtype >= mindspore::ops::DT_TUPLE_BOOL && input_arg_dtype <= mindspore::ops::DT_LIST_ANY) {
1462 return true;
1463 }
1464 return false;
1465 }
1466
UseEmptyNodeReplaceNone(const FuncGraphPtr & graph,const std::string & cnode_name,const size_t input_idx,std::vector<int64_t> * dyn_input_sizes,std::vector<AnfNodePtr> * plant_inputs)1467 void UseEmptyNodeReplaceNone(const FuncGraphPtr &graph, const std::string &cnode_name, const size_t input_idx,
1468 std::vector<int64_t> *dyn_input_sizes, std::vector<AnfNodePtr> *plant_inputs) {
1469 MS_EXCEPTION_IF_NULL(dyn_input_sizes);
1470 MS_EXCEPTION_IF_NULL(plant_inputs);
1471 if (OpInputDtypeMap.at(cnode_name).find(input_idx) != OpInputDtypeMap.at(cnode_name).end()) {
1472 // create empty tensor
1473 auto tensor_type = OpInputDtypeMap.at(cnode_name).at(input_idx);
1474 std::vector<int64_t> tensor_shape = {0};
1475 auto empty_tensor = std::make_shared<tensor::Tensor>(tensor_type, tensor_shape);
1476 // create node
1477 auto empty_node = opt::CreateValueNodeWithKernelInfo(graph, empty_tensor);
1478 ValueNodePtr empty_value_node = empty_node->cast<ValueNodePtr>();
1479 // empty node size is 1
1480 dyn_input_sizes->emplace_back(1);
1481 plant_inputs->emplace_back(empty_value_node);
1482 } else {
1483 MS_LOG(EXCEPTION) << "Invalid input index. The [" << input_idx << "] in op [" << cnode_name
1484 << "] is not in OpInputDtypeMap, cannot use new node replace None.";
1485 }
1486 }
1487
// Collects the flattened ("plant") input list of `cnode_ptr` into *plant_inputs and the
// matching dynamic-input sizes into *dyn_input_sizes (-1 marks a regular, non-dynamic
// input). Tuple inputs are split element-wise; for ops registered in OpInputDtypeMap,
// None inputs whose op-yaml arg type is dynamic are replaced by empty placeholder
// tensors.
void GetPlantInputsAndSize(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr, std::vector<AnfNodePtr> *plant_inputs,
                           std::vector<int64_t> *dyn_input_sizes) {
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  auto cnode_name = common::AnfAlgo::GetCNodeName(cnode_ptr);
  plant_inputs->push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
  // Input 0 is the primitive node, so the data-input count is size() - 1.
  size_t input_num = cnode_ptr->size() - 1;
  bool cnode_is_print = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPrint);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, i);
    MS_EXCEPTION_IF_NULL(input_node);
    bool output_is_tuple = common::AnfAlgo::IsTupleOutput(input_node);
    if (output_is_tuple && cnode_is_print) {
      // Print keeps tuple inputs untouched here; its dyn_input_sizes attribute is
      // rebuilt later by GenPrintAttrDynInputSizes.
      continue;
    } else if (output_is_tuple) {
      int64_t dyn_input_size;
      if (IsNotSequenceOfTensor(input_node->abstract())) {
        dyn_input_size = 0;
      } else {
        dyn_input_size = SplitTupleInputs(graph, input_node, plant_inputs);
      }
      if (dyn_input_size == 0) {
        // Non-tensor/empty sequence: keep the node itself as a regular input.
        dyn_input_sizes->push_back(-1);
        plant_inputs->push_back(input_node);
      } else {
        (void)dyn_input_sizes->emplace_back(dyn_input_size);
      }
    } else if (OpInputDtypeMap.find(cnode_name) != OpInputDtypeMap.end()) {
      // Only op in OpInputDtypeMap can be replace None input.
      auto opdef_ptr = mindspore::ops::GetOpDef(cnode_name);
      MS_EXCEPTION_IF_NULL(opdef_ptr);
      auto input_args = (opdef_ptr)->args_;
      if (i >= input_args.size()) {
        MS_LOG(EXCEPTION) << "The [" << i << "] in op [" << cnode_name << "] is out of op_def args range";
      }
      // When input[i] is None and input[i] type in op_yaml is dynamic type, do replace
      if (common::AnfAlgo::IsNoneInput(cnode_ptr, i) && InputArgTypeIsDynamicType(input_args[i].arg_dtype_)) {
        UseEmptyNodeReplaceNone(graph, cnode_name, i, dyn_input_sizes, plant_inputs);
      } else {
        dyn_input_sizes->push_back(-1);
        plant_inputs->push_back(input_node);
      }
    } else {
      // Plain (non-tuple, non-replaceable) input.
      dyn_input_sizes->push_back(-1);
      plant_inputs->push_back(input_node);
    }
  }
}
1535
// Rebuilds `cnode_ptr` with its tuple inputs flattened ("planted") and records the
// kAttrDynInputSizes attribute on the new node. Returns the replacement node, or
// nullptr when no conversion is needed/possible (Call/Partial/BpropCut nodes,
// dynamic-length tuple inputs, or no dynamic input detected).
AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  MS_EXCEPTION_IF_NULL(graph);
  // Control-flow related nodes keep their tuple inputs as-is.
  if (common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
      common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial) ||
      common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimBpropCut)) {
    return nullptr;
  }

  // Dynamic-length tuples cannot be flattened to a fixed input list.
  if (common::AnfAlgo::HasDynamicTupleInput(cnode_ptr)) {
    MS_LOG(INFO) << "Node " << cnode_ptr->fullname_with_scope()
                 << " has dynamic tuple input, can't convert. Node debug string:" << cnode_ptr->DebugString();
    return nullptr;
  }
  std::vector<AnfNodePtr> plant_inputs;
  std::vector<int64_t> dyn_input_sizes;
  GetPlantInputsAndSize(graph, cnode_ptr, &plant_inputs, &dyn_input_sizes);

  // If there is dynamic input, set the dyn_input_sizes as an attribute and update the inputs.
  if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
    auto new_cnode = NewCNode(plant_inputs, graph, {cnode_ptr});
    MS_EXCEPTION_IF_NULL(new_cnode);
    // Carry over abstract/scope/attrs so the new node is a drop-in replacement.
    new_cnode->set_abstract(cnode_ptr->abstract());
    new_cnode->set_scope(cnode_ptr->scope());
    new_cnode->set_primal_attrs(cnode_ptr->primal_attrs());
    new_cnode->set_attrs(cnode_ptr->attrs());
    bool cnode_is_print = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPrint);
    if (cnode_is_print) {
      // Print uses a fixed three-entry dyn_input_sizes layout; recompute it.
      dyn_input_sizes = GenPrintAttrDynInputSizes(new_cnode);
    }
    common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_cnode);
    auto kernel_graph = graph->cast<KernelGraphPtr>();
    if (kernel_graph != nullptr) {
      // Keep the front/backend node mapping consistent after the replacement.
      kernel_graph->FrontBackendlMapUpdate(cnode_ptr, new_cnode);
    }
    return new_cnode;
  }
  return nullptr;
}
1575
InferOp(const CNodePtr & node,void * args)1576 void InferOp(const CNodePtr &node, void *args) { dynamic_shape::InferOp(node, args); }
1577
// Handler invoked by LaunchPy; nullptr until a handler is registered via
// set_launch_handler (presumably done during python runtime initialization —
// TODO confirm against callers).
LaunchHandler launch_py_handler{nullptr};
// Registers the global python launch handler used by LaunchPy.
void set_launch_handler(const LaunchHandler &handler) { launch_py_handler = handler; }
1580
LaunchPy(const PrimitivePtr & primitive,const std::vector<abstract::AbstractBase * > & args_abs_list)1581 abstract::AbstractBasePtr LaunchPy(const PrimitivePtr &primitive,
1582 const std::vector<abstract::AbstractBase *> &args_abs_list) {
1583 MS_EXCEPTION_IF_NULL(launch_py_handler);
1584 return launch_py_handler(primitive, args_abs_list);
1585 }
1586
InferAbstract(const PrimitivePtr & primitive,const std::vector<AnfNodePtr> & input_list)1587 AbstractBasePtr InferAbstract(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &input_list) {
1588 MS_EXCEPTION_IF_NULL(primitive);
1589 const auto &op_name = primitive->name();
1590 std::vector<AbstractBasePtr> input_args;
1591 std::for_each(input_list.begin(), input_list.end(),
1592 [&input_args](const auto &input) { input_args.emplace_back(input->abstract()); });
1593 auto shape_optional = abstract::InferAbstractByFuncImpl(primitive, input_args);
1594 if (shape_optional.has_value()) {
1595 return shape_optional.value();
1596 }
1597
1598 auto infer_impl = abstract::GetBackendPrimitiveInferImpl(primitive);
1599 if (infer_impl.has_value()) {
1600 auto infer = infer_impl.value();
1601 if (infer.IsImplInferShapeAndType()) {
1602 return infer.InferShapeAndType(nullptr, primitive, input_args);
1603 }
1604 }
1605 MS_LOG(EXCEPTION) << "The InferAbstract function of [" << op_name << "] is not defined.";
1606 }
1607
// Creates a ValueNode holding `value` with its abstract set. When `graph` is a
// KernelGraph, also attaches kernel info and a kernel build info (default format,
// device type derived from the value, kernel object type derived from the abstract)
// so the node needs no kernel selection, and registers the node with the graph.
AnfNodePtr CreateValueNodeWithKernelInfo(const FuncGraphPtr &graph, const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  auto value_node = NewValueNode(value);
  MS_EXCEPTION_IF_NULL(value_node);
  auto value_abs = value->ToAbstract();
  value_node->set_abstract(value_abs);

  MS_EXCEPTION_IF_NULL(graph);
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(graph);
  if (kernel_graph != nullptr) {
    // In kernel graph case, a new value node should set kernel_info and kernel_build_info here for no-kernel-selecting.
    auto kernel_info = std::make_shared<device::KernelInfo>();
    value_node->set_kernel_info(kernel_info);
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsFormat({kOpFormat_DEFAULT});
    if (value->isa<tensor::Tensor>()) {
      // Tensor values carry their own dtype.
      auto tensor = value->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      builder.SetOutputsDeviceType({tensor->data_type()});
    } else {
      MS_EXCEPTION_IF_NULL(value->type());
      auto type_id = value->type()->type_id();
      if (value->isa<ValueSequence>()) {
        auto value_sequence = value->cast<ValueSequencePtr>()->value();
        if (value_sequence.empty()) {
          // Empty sequences have no element to read a type from; default to int64.
          type_id = kNumberTypeInt64;
        } else {
          // Use the first element's type — assumes a homogeneous sequence.
          MS_EXCEPTION_IF_NULL(value_sequence[0]->type());
          type_id = value_sequence[0]->type()->type_id();
        }
      }
      builder.SetOutputsDeviceType({type_id});
    }
    auto object_type = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAbstractObjectType(value_abs));
    builder.SetOutputsKernelObjectType({object_type});
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), value_node.get());

    kernel_graph->AddValueNodeToGraph(value_node);
  }

  return value_node;
}
1650 } // namespace opt
1651 } // namespace mindspore
1652