1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/optimizer/common/helper.h"
18 #include <string>
19 #include <utility>
20 #include <unordered_set>
21 #include <algorithm>
22 #include <map>
23 #include <set>
24 #include <deque>
25 #include "utils/utils.h"
26 #include "base/base_ref.h"
27 #include "backend/session/anf_runtime_algorithm.h"
28 #include "base/core_ops.h"
29 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
30 #include "frontend/operator/ops.h"
31 #include "utils/ms_utils.h"
32 #include "runtime/device/kernel_info.h"
33 #include "utils/ms_context.h"
34 #include "backend/optimizer/common/const_input_to_attr_registry.h"
35 #include "abstract/primitive_infer_map.h"
36
37 namespace mindspore {
38 namespace opt {
// Byte widths of one 32-bit and one 64-bit element; kType64Len sizes int64 tensor copies below.
constexpr size_t kType32Len = sizeof(int32_t);
constexpr size_t kType64Len = sizeof(int64_t);
41
Convert2Int(const std::vector<size_t> & v)42 std::vector<int64_t> Convert2Int(const std::vector<size_t> &v) {
43 std::vector<int64_t> result;
44 (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
45 return result;
46 }
47
Convert2Long(const std::vector<size_t> & v)48 std::vector<int64_t> Convert2Long(const std::vector<size_t> &v) {
49 std::vector<int64_t> result;
50 (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToLong);
51 return result;
52 }
53
IsDepend(const FuncGraph & graph,const AnfNodePtr & node,const std::vector<AnfNodePtr> & nodes)54 bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
55 MS_EXCEPTION_IF_NULL(node);
56 FuncGraphManagerPtr manager = graph.manager();
57 MS_EXCEPTION_IF_NULL(manager);
58
59 std::unordered_set<AnfNodePtr> seen_node;
60 std::deque<AnfNodePtr> todo{node};
61 while (!todo.empty()) {
62 AnfNodePtr nd = todo.front();
63 todo.pop_front();
64 if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) {
65 continue;
66 }
67 (void)seen_node.insert(nd);
68
69 if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
70 return true;
71 }
72 if (nd->isa<CNode>()) {
73 auto cnode = nd->cast<CNodePtr>();
74 MS_EXCEPTION_IF_NULL(cnode);
75 auto inputs = cnode->inputs();
76 (void)todo.insert(todo.end(), inputs.begin(), inputs.end());
77 }
78 }
79 return false;
80 }
81
UnVisited(const BaseRef & n)82 bool UnVisited(const BaseRef &n) {
83 if (utils::isa<AnfNodePtr>(n)) {
84 AnfNodePtr in = utils::cast<AnfNodePtr>(n);
85 MS_EXCEPTION_IF_NULL(in);
86 if (IsValueNode<Primitive>(in)) {
87 auto value_node = in->cast<ValueNodePtr>();
88 MS_EXCEPTION_IF_NULL(value_node);
89 auto value = value_node->value();
90 MS_EXCEPTION_IF_NULL(value);
91 auto prim_py = value->cast<PrimitivePtr>();
92 MS_EXCEPTION_IF_NULL(prim_py);
93 return !prim_py->HasAttr(kAttrVisited);
94 } else if (IsValueNode<FuncGraph>(in)) {
95 auto func_graph = GetValueNode<FuncGraphPtr>(in);
96 MS_EXCEPTION_IF_NULL(func_graph);
97 return !func_graph->has_flag(kAttrVisited);
98 }
99 return false;
100 }
101 return false;
102 }
103
CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr & node,size_t input_size)104 CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size) {
105 MS_EXCEPTION_IF_NULL(node);
106 if (!node->isa<CNode>()) {
107 MS_LOG(EXCEPTION) << "The node is expected to be a cnode";
108 }
109 auto cnode = node->cast<CNodePtr>();
110 CheckCNodeInputSize(cnode, input_size);
111 return cnode;
112 }
113
CheckCNodeInputSize(const CNodePtr & cnode,size_t input_tensor_size)114 void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) {
115 MS_EXCEPTION_IF_NULL(cnode);
116 auto real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
117 if (real_input_tensor_num != input_tensor_size) {
118 MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num
119 << "] of node " + cnode->DebugString() + " is not equal to " << input_tensor_size;
120 }
121 }
122
HasSymmetricalKernelInfo(const AnfNodePtr & node_x,const AnfNodePtr & node_y)123 bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) {
124 MS_EXCEPTION_IF_NULL(node_x);
125 MS_EXCEPTION_IF_NULL(node_y);
126 return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) &&
127 AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0));
128 }
129
// For the chain node = TransOp(Depend(TransOp(x), attach)), builds and returns a new
// Depend(x, attach) CNode in `func_graph` whose abstract is x's abstract.
// Each level of the chain is validated by CheckAnfNodeIfCNodeAndInputSize, which throws on mismatch.
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);

  auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
  MS_EXCEPTION_IF_NULL(transop_cnode);
  auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum);
  auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputTensorNum);
  auto transed_node = prev_transop_cnode->input(1);
  MS_EXCEPTION_IF_NULL(transed_node);

  const auto attach_node = depend_cnode->input(kDependAttachNodeIndex);
  AnfNodePtr replace_depend = func_graph->NewCNode({NewValueNode(prim::kPrimDepend), transed_node, attach_node});
  MS_EXCEPTION_IF_NULL(replace_depend);
  replace_depend->set_abstract(transed_node->abstract());
  return replace_depend;
}
148
Visited(const BaseRef & n)149 bool Visited(const BaseRef &n) {
150 if (utils::isa<AnfNodePtr>(n)) {
151 AnfNodePtr in = utils::cast<AnfNodePtr>(n);
152 MS_EXCEPTION_IF_NULL(in);
153 if (IsValueNode<Primitive>(in)) {
154 auto value_node = in->cast<ValueNodePtr>();
155 MS_EXCEPTION_IF_NULL(value_node);
156 auto value = value_node->value();
157 MS_EXCEPTION_IF_NULL(value);
158 auto prim_py = value->cast<PrimitivePtr>();
159 MS_EXCEPTION_IF_NULL(prim_py);
160 return prim_py->HasAttr(kAttrVisited);
161 } else if (IsValueNode<FuncGraph>(in)) {
162 auto func_graph = GetValueNode<FuncGraphPtr>(in);
163 MS_EXCEPTION_IF_NULL(func_graph);
164 return func_graph->has_flag(kAttrVisited);
165 }
166 return false;
167 }
168 return false;
169 }
170
CreateMultipleOutputsOfAnfNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_num,std::vector<AnfNodePtr> * outputs)171 void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
172 std::vector<AnfNodePtr> *outputs) {
173 MS_EXCEPTION_IF_NULL(func_graph);
174 MS_EXCEPTION_IF_NULL(node);
175 MS_EXCEPTION_IF_NULL(outputs);
176 auto type_ptr = node->Type();
177 auto shape_ptr = node->Shape();
178 for (size_t i = 0; i < output_num; i++) {
179 int64_t temp = SizeToLong(i);
180 auto idx = NewValueNode(temp);
181 MS_EXCEPTION_IF_NULL(idx);
182 auto imm = std::make_shared<Int64Imm>(temp);
183 auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
184 idx->set_abstract(abstract_scalar);
185 auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
186 MS_EXCEPTION_IF_NULL(tuple_getitem);
187 AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(type_ptr, i)},
188 {AnfAlgo::GetOutputInferShape(node, shape_ptr, i)}, tuple_getitem.get());
189 (*outputs).push_back(tuple_getitem);
190 }
191 }
192
193 template <typename T>
CreateTensorWithValueTuple(const ValueTuplePtr & value_tuple_ptr,const TypePtr & type_ptr,size_t data_length)194 tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
195 size_t data_length) {
196 MS_EXCEPTION_IF_NULL(value_tuple_ptr);
197 MS_EXCEPTION_IF_NULL(type_ptr);
198 std::vector<T> values;
199 for (const auto &v : value_tuple_ptr->value()) {
200 MS_EXCEPTION_IF_NULL(v);
201 if (v->isa<Scalar>()) {
202 ScalarPtr scalar = v->cast<ScalarPtr>();
203 values.push_back(GetValue<T>(scalar));
204 } else {
205 MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
206 return nullptr;
207 }
208 }
209 std::vector<int64_t> tensor_shape = {SizeToLong(values.size())};
210 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
211 MS_EXCEPTION_IF_NULL(tensor);
212 tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
213 tensor->set_device_info(device_info);
214 auto data_ptr = tensor->data_c();
215 MS_EXCEPTION_IF_NULL(data_ptr);
216 auto elem_num = values.size() * data_length;
217 auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
218 if (ret_code != 0) {
219 MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
220 }
221 return tensor;
222 }
223
CreateTupleTensor(const ValueTuplePtr & value_tuple)224 tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
225 MS_EXCEPTION_IF_NULL(value_tuple);
226 tensor::TensorPtr tensor = nullptr;
227 if (value_tuple->value().empty()) {
228 MS_LOG(WARNING) << "The value tuple is empty.";
229 return nullptr;
230 }
231 ValuePtr v = *(value_tuple->value().begin());
232 MS_EXCEPTION_IF_NULL(v);
233 // Currently we only deal with the scalar tuple
234 if (!v->isa<Scalar>()) {
235 MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
236 return nullptr;
237 }
238 ScalarPtr scalar = v->cast<ScalarPtr>();
239 MS_EXCEPTION_IF_NULL(scalar);
240 if (scalar->isa<Int32Imm>()) {
241 tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, kInt32, sizeof(int32_t));
242 } else if (scalar->isa<Int64Imm>()) {
243 tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, kInt64, sizeof(int64_t));
244 } else if (scalar->isa<FloatImm>()) {
245 tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, sizeof(float));
246 } else {
247 auto type = scalar->type();
248 auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
249 MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
250 return nullptr;
251 }
252 return tensor;
253 }
254
IsNopNode(const AnfNodePtr & node)255 bool IsNopNode(const AnfNodePtr &node) {
256 auto context_ptr = MsContext::GetInstance();
257 MS_EXCEPTION_IF_NULL(context_ptr);
258 auto target = GetCNodeTarget(node);
259 if (target == kCPUDevice) {
260 return false;
261 }
262 if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
263 context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
264 return false;
265 }
266
267 static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
268 prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
269 kFlattenGradOpName, prim::kPrimReformat->name()};
270 if (node == nullptr || !node->isa<CNode>()) {
271 return false;
272 }
273 CNodePtr cnode = node->cast<CNodePtr>();
274 MS_EXCEPTION_IF_NULL(cnode);
275 if (cnode->inputs().empty()) {
276 return false;
277 }
278 auto input0 = cnode->input(0);
279 MS_EXCEPTION_IF_NULL(input0);
280 if (!input0->isa<ValueNode>()) {
281 return false;
282 }
283 bool is_nop_node = false;
284 if (AnfAlgo::HasNodeAttr("nop_op", cnode)) {
285 is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, "nop_op");
286 }
287 if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
288 return false;
289 }
290 return true;
291 }
292
IsAllNopNode(const session::KernelGraph * const graph)293 bool IsAllNopNode(const session::KernelGraph *const graph) {
294 MS_EXCEPTION_IF_NULL(graph);
295 auto execution_order = graph->execution_order();
296 for (auto &cnode : execution_order) {
297 MS_EXCEPTION_IF_NULL(cnode);
298 if (!IsNopNode(cnode)) {
299 return false;
300 }
301 }
302 return true;
303 }
304
CheckNopNodeIsOutputNode(const std::vector<AnfNodePtr> & outputs,const AnfNodePtr & node,bool is_dynamic_graph)305 bool CheckNopNodeIsOutputNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool is_dynamic_graph) {
306 MS_EXCEPTION_IF_NULL(node);
307 // if node is not a nop node, keep it in execution order
308 if (!IsNopNode(node)) {
309 return true;
310 }
311 // if node is nop node and the graph is dynamic graph, check if the nop node is graph's output.
312 if (is_dynamic_graph) {
313 auto iter = find(outputs.begin(), outputs.end(), node);
314 if (iter != outputs.end()) {
315 return true;
316 }
317 }
318 return false;
319 }
320
HideNopNode(session::KernelGraph * const graph)321 void HideNopNode(session::KernelGraph *const graph) {
322 MS_EXCEPTION_IF_NULL(graph);
323 if (IsAllNopNode(graph) == true) {
324 return;
325 }
326 auto execution_order = graph->execution_order();
327 auto outputs = graph->outputs();
328 bool is_dynamic_graph = graph->is_dynamic_shape();
329 MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
330 std::vector<CNodePtr> new_nodes;
331 for (auto &cnode : execution_order) {
332 MS_EXCEPTION_IF_NULL(cnode);
333 if (CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) {
334 new_nodes.push_back(cnode);
335 }
336 }
337 graph->set_execution_order(new_nodes);
338 MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size();
339 }
340
RemoveNopNode(session::KernelGraph * const graph)341 void RemoveNopNode(session::KernelGraph *const graph) {
342 MS_EXCEPTION_IF_NULL(graph);
343 if (IsAllNopNode(graph) == true) {
344 return;
345 }
346 bool changed = true;
347 while (changed) {
348 changed = false;
349 std::vector<CNodePtr> new_nodes;
350 auto outputs = graph->outputs();
351 bool is_dynamic_graph = graph->is_dynamic_shape();
352 for (auto &cnode : graph->execution_order()) {
353 MS_EXCEPTION_IF_NULL(cnode);
354 // ignore nop node itself
355 if (!CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) {
356 continue;
357 }
358 // Replace the input which is nop node
359 std::vector<AnfNodePtr> new_inputs;
360 new_inputs.push_back(cnode->input(0));
361 bool need_update = false;
362 for (size_t i = 1; i < cnode->inputs().size(); ++i) {
363 auto input = cnode->input(i);
364 MS_EXCEPTION_IF_NULL(input);
365 auto cinput = input->cast<CNodePtr>();
366 if (cinput == nullptr || !IsNopNode(cinput)) {
367 new_inputs.push_back(input);
368 continue;
369 }
370 constexpr auto kInputSize = 2;
371 if (cinput->inputs().size() == kInputSize) {
372 new_inputs.push_back(cinput->input(1));
373 need_update = true;
374 changed = true;
375 } else {
376 new_inputs.push_back(input);
377 }
378 }
379 if (need_update) {
380 cnode->set_inputs(new_inputs);
381 }
382 // push into new execution list
383 new_nodes.push_back(cnode);
384 }
385 graph->set_execution_order(new_nodes);
386 }
387 }
388
GetRealNodeNum(const FuncGraphPtr & graph,const AnfNodePtr & node)389 size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
390 auto out_list = GetRealNodeUsedList(graph, node);
391 MS_EXCEPTION_IF_NULL(out_list);
392 return out_list->size();
393 }
394
GetRealNodeUsedList(const FuncGraphPtr & graph,const AnfNodePtr & node)395 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
396 const AnfNodePtr &node) {
397 auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
398 MS_EXCEPTION_IF_NULL(graph);
399 auto manager = graph->manager();
400 MS_EXCEPTION_IF_NULL(manager);
401 auto iter = manager->node_users().find(node);
402 if (iter == manager->node_users().end()) {
403 return output_node_list;
404 }
405 auto output_info_list = iter->second;
406 for (const auto &output_info : output_info_list) {
407 auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
408 if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
409 (cnode_name == prim::kPrimUpdateState->name())) {
410 continue;
411 }
412 output_node_list->push_back(output_info);
413 }
414 return output_node_list;
415 }
416
GetRealNodeUsedListByOutputIdx(const FuncGraphPtr & graph,const AnfNodePtr & node,size_t output_index)417 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
418 const AnfNodePtr &node,
419 size_t output_index) {
420 auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
421 MS_EXCEPTION_IF_NULL(graph);
422 auto manager = graph->manager();
423 MS_EXCEPTION_IF_NULL(manager);
424 auto iter = manager->node_users().find(node);
425 if (iter == manager->node_users().end()) {
426 MS_LOG(EXCEPTION) << "node has no output in manager";
427 }
428 auto output_info_list = iter->second;
429 for (const auto &output_info : output_info_list) {
430 auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
431 if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
432 (cnode_name == prim::kPrimUpdateState->name())) {
433 continue;
434 }
435 size_t used_output_index;
436 if (cnode_name == prim::kPrimTupleGetItem->name()) {
437 used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
438 } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
439 used_output_index = output_index;
440 } else {
441 auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, IntToSize(output_info.second - 1));
442 if (kernel_with_index.first.get() != node.get()) {
443 MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]";
444 }
445 used_output_index = kernel_with_index.second;
446 }
447 if (used_output_index == output_index) {
448 output_node_list->push_back(output_info);
449 }
450 }
451 return output_node_list;
452 }
453
IsUsedByOthers(const FuncGraphPtr & graph,const AnfNodePtr & node)454 bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
455 MS_EXCEPTION_IF_NULL(graph);
456 MS_EXCEPTION_IF_NULL(node);
457 auto output_node_list = GetRealNodeUsedList(graph, node);
458 MS_EXCEPTION_IF_NULL(output_node_list);
459 return output_node_list->size() > 1;
460 }
461
IsNotRealUsedByOthers(const FuncGraphPtr & graph,const AnfNodePtr & node)462 bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
463 MS_EXCEPTION_IF_NULL(graph);
464 MS_EXCEPTION_IF_NULL(node);
465 auto output_node_list = GetRealNodeUsedList(graph, node);
466 MS_EXCEPTION_IF_NULL(output_node_list);
467 if (output_node_list->empty()) {
468 return true;
469 }
470 for (const auto &output : *output_node_list) {
471 auto out_node = output.first;
472 auto name = AnfAlgo::GetCNodeName(out_node);
473 if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
474 name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) {
475 auto result = IsNotRealUsedByOthers(graph, out_node);
476 if (!result) {
477 return result;
478 }
479 continue;
480 }
481 return false;
482 }
483 return true;
484 }
485
CreatTupleGetItemNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_idx)486 CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
487 MS_EXCEPTION_IF_NULL(func_graph);
488 auto idx = NewValueNode(SizeToLong(output_idx));
489 MS_EXCEPTION_IF_NULL(idx);
490 auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
491 auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
492 idx->set_abstract(abstract_scalar);
493 CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
494 MS_EXCEPTION_IF_NULL(tuple_getitem);
495 tuple_getitem->set_scope(node->scope());
496 auto abs = node->abstract()->cast<abstract::AbstractTuplePtr>();
497 MS_EXCEPTION_IF_NULL(abs);
498 auto abs_i = abs->elements()[output_idx];
499 MS_EXCEPTION_IF_NULL(abs_i);
500 tuple_getitem->set_abstract(abs_i);
501 return tuple_getitem;
502 }
503
CreateShapeValueNode(const FuncGraphPtr & func_graph,const std::vector<int64_t> & shape,bool to_tensor)504 ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape, bool to_tensor) {
505 MS_EXCEPTION_IF_NULL(func_graph);
506 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
507 MS_EXCEPTION_IF_NULL(kernel_graph);
508 ValuePtr shape_value = nullptr;
509 AbstractBasePtr abstract = nullptr;
510 if (to_tensor) {
511 // create Tensor
512 int64_t shape_dim = SizeToLong(shape.size());
513 std::vector<int64_t> shape_vec_shape = {shape_dim};
514 auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
515 MS_EXCEPTION_IF_NULL(shape_tensor);
516 auto data_ptr = shape_tensor->data_c();
517 MS_EXCEPTION_IF_NULL(data_ptr);
518 auto elem_num = shape.size() * kType64Len;
519 auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
520 if (ret_code != 0) {
521 MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
522 return nullptr;
523 }
524 shape_value = shape_tensor;
525 abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
526 } else {
527 // create ValueTuple
528 std::vector<ValuePtr> dim_values{};
529 abstract::AbstractBasePtrList abs{};
530 for (const auto &dim : shape) {
531 dim_values.push_back(MakeValue(dim));
532 abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
533 }
534 shape_value = std::make_shared<ValueTuple>(dim_values);
535 abstract = std::make_shared<abstract::AbstractTuple>(abs);
536 }
537 MS_EXCEPTION_IF_NULL(shape_value);
538 MS_EXCEPTION_IF_NULL(abstract);
539 auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
540 MS_EXCEPTION_IF_NULL(shape_value_node);
541 kernel_graph->AddValueNodeToGraph(shape_value_node);
542 return shape_value_node;
543 }
544
ConstInputToAttr(const CNodePtr & cnode,const std::unordered_set<size_t> & input_attrs)545 void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
546 MS_EXCEPTION_IF_NULL(cnode);
547 std::vector<AnfNodePtr> new_inputs;
548 auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
549 MS_EXCEPTION_IF_NULL(primitive);
550 primitive = primitive->Clone();
551 auto input_names = primitive->GetAttr(kAttrInputNames);
552 if (input_names == nullptr) {
553 MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
554 return;
555 }
556 auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
557 auto inputs = cnode->inputs();
558 new_inputs.push_back(inputs[0]);
559 bool need_update = false;
560 for (size_t i = 0; i < inputs.size() - 1; ++i) {
561 auto input_node = inputs[i + 1];
562 if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
563 input_node = AnfAlgo::VisitKernel(input_node, 0).first;
564 }
565 MS_EXCEPTION_IF_NULL(input_node);
566 if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
567 auto value_node = input_node->cast<ValueNodePtr>();
568 MS_EXCEPTION_IF_NULL(value_node);
569 MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
570 if (i >= input_names_vec.size()) {
571 MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
572 }
573 primitive->set_attr(input_names_vec[i], value_node->value());
574 need_update = true;
575 } else {
576 new_inputs.push_back(inputs[i + 1]);
577 }
578 }
579 if (need_update) {
580 // Update cnode's inputs
581 new_inputs[0] = NewValueNode(primitive);
582 cnode->set_inputs(new_inputs);
583 }
584 }
585
// Equality predicate over BaseRefs.
// - Both hold ANF value nodes wrapping Primitives: equal iff the primitives' names match
//   (attributes are not compared).
// - Both hold ANF value nodes otherwise: equal iff the held Values compare equal with operator==.
// - Anything else (other ANF nodes, FuncGraphs, non-node BaseRefs): falls back to BaseRef
//   operator== on `a` and `b`.
// Throws if a value node's cast or held value is unexpectedly null.
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    MS_EXCEPTION_IF_NULL(a_node);
    MS_EXCEPTION_IF_NULL(b_node);
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      auto a_value_node = a_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(a_value_node);
      auto a_value = a_value_node->value();
      MS_EXCEPTION_IF_NULL(a_value);
      auto a_prim = a_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(a_prim);

      auto b_value_node = b_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(b_value_node);
      auto b_value = b_value_node->value();
      MS_EXCEPTION_IF_NULL(b_value);
      auto b_prim = b_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(b_prim);

      // Primitives match by name only.
      return a_prim->name() == b_prim->name();
    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
      if (a_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto a_value_ptr = a_value_node_ptr->value();
      if (a_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }

      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
      if (b_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto b_value_ptr = b_value_node_ptr->value();
      if (b_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }

      // Compare the held Values themselves, not the ValueNode wrappers.
      return (*a_value_ptr) == (*b_value_ptr);
    }
    // Non-value ANF nodes fall through to the BaseRef comparison below.
    MS_LOG(DEBUG) << "check AnfNodePtr equal";
  }
  if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
    MS_LOG(DEBUG) << "check GraphPtr equal";
  }
  // NOTE(review): BaseRef::operator== semantics (identity vs. value) are defined elsewhere — confirm.
  return a == b;
}
636
CNodeTypeEqual(const BaseRef & a,const BaseRef & b)637 bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
638 // To matchCNode and Kernel's type
639 if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
640 return true;
641 }
642 return a.type() == b.type();
643 }
644
645 namespace {
CreateValueNodeWithSexp(const BaseRef & sexp)646 ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
647 if (utils::isa<int>(sexp)) {
648 return NewValueNode(utils::cast<int>(sexp));
649 }
650 if (utils::isa<int64_t>(sexp)) {
651 return NewValueNode(utils::cast<int64_t>(sexp));
652 }
653 if (utils::isa<float>(sexp)) {
654 return NewValueNode(utils::cast<float>(sexp));
655 }
656 if (utils::isa<bool>(sexp)) {
657 return NewValueNode(utils::cast<bool>(sexp));
658 }
659 if (utils::isa<ValuePtr>(sexp)) {
660 return NewValueNode(utils::cast<ValuePtr>(sexp));
661 }
662 return nullptr;
663 }
664
CreateCNodeWithGraph(const std::vector<AnfNodePtr> & input_nodes,const BaseRef & graph)665 CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
666 if (utils::isa<FuncGraphPtr>(graph)) {
667 return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
668 }
669 if (utils::isa<VarPtr>(graph)) {
670 return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
671 }
672 return nullptr;
673 }
674
CreateVarNodeWithSexp(const BaseRef & sexp,const BaseRef & graph)675 VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
676 if (utils::isa<VarPtr>(graph)) {
677 MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
678 return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
679 }
680 if (utils::isa<FuncGraphPtr>(graph)) {
681 MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
682 return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
683 }
684 MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
685 return nullptr;
686 }
687
HandleSexpVector(const BaseRef & sexp,const BaseRef & graph,PrimitiveVarMap * primitive_vars,bool multigraph)688 AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
689 bool multigraph) {
690 MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
691 std::vector<AnfNodePtr> input_nodes;
692 const auto &tuple = utils::cast<VectorRef>(sexp);
693 if (multigraph && utils::isa<VarPtr>(graph)) {
694 for (auto &x : tuple) {
695 AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
696 input_nodes.push_back(node);
697 }
698 VarPtr var_ptr = utils::cast<VarPtr>(graph);
699 return std::make_shared<CNode>(input_nodes, var_ptr);
700 }
701
702 for (auto &x : tuple) {
703 AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
704 input_nodes.push_back(node);
705 }
706 return CreateCNodeWithGraph(input_nodes, graph);
707 }
708
709 // rectify absttract if the input has been converted to the attr
RectifyAbstractFromRegAttr(const PrimitivePtr & primitive,const AbstractBasePtrList & input_abstract)710 AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
711 const AbstractBasePtrList &input_abstract) {
712 MS_EXCEPTION_IF_NULL(primitive);
713 opt::ConstInputToAttrInfoRegister reg;
714 if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), ®)) {
715 return input_abstract;
716 }
717 if (AnfAlgo::HasDynamicShapeFlag(primitive)) {
718 return input_abstract;
719 }
720 auto ms_context = MsContext::GetInstance();
721 MS_EXCEPTION_IF_NULL(ms_context);
722 auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
723 if (device == kGPUDevice) {
724 if (DynamicShapeConstInputToAttrGPU.find(primitive->name()) != DynamicShapeConstInputToAttrGPU.end()) {
725 return input_abstract;
726 }
727 } else if (DynamicShapeConstInputToAttr.find(primitive->name()) != DynamicShapeConstInputToAttr.end()) {
728 return input_abstract;
729 }
730 auto convert_input_list = reg.GetConstInputAttrInfo();
731 auto input_names = primitive->GetAttr(kAttrInputNames);
732 if (input_names == nullptr) {
733 return input_abstract;
734 }
735 auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
736 AbstractBasePtrList rectify_abs_list;
737 size_t ori_index = 0;
738 rectify_abs_list.resize(input_names_vec.size());
739 for (size_t index = 0; index < rectify_abs_list.size(); ++index) {
740 // if convert input list find the index it means the input has been converted to the attr
741 if (convert_input_list.find(index) != convert_input_list.end()) {
742 AbstractBasePtr rectify_abs = nullptr;
743 auto input_name = input_names_vec[index];
744 auto attr = primitive->GetAttr(input_name);
745 if (attr != nullptr) {
746 rectify_abs = attr->ToAbstract();
747 } else {
748 MS_LOG(DEBUG) << "the node prim name :" << primitive->name() << "input index :" << index
749 << " input name :" << input_name << "has not been converted to the attr";
750 rectify_abs = input_abstract[ori_index++];
751 }
752 rectify_abs_list[index] = rectify_abs;
753 continue;
754 }
755 if (ori_index > input_abstract.size()) {
756 MS_LOG(EXCEPTION) << "index is out of range input abstract size " << input_abstract.size()
757 << " get index :" << ori_index;
758 }
759 rectify_abs_list[index] = input_abstract[ori_index++];
760 }
761 return rectify_abs_list;
762 }
763
// Regroups a flat abstract list according to the primitive's kAttrDynInputSizes attribute.
// The attribute holds one entry per logical input:
//   - kNotDynamicFlag (-1): the next abstract of `input_abstract` is forwarded unchanged;
//   - k >= 0: the next k abstracts are packed into a single abstract::AbstractTuple.
// When the attribute is absent, `input_abstract` is returned unchanged.
// Raises (MS_LOG(EXCEPTION)) on an entry below -1 or when `input_abstract` runs out of elements.
// Abstracts remaining after the last attribute entry are not included in the result.
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
                                                    const AbstractBasePtrList &input_abstract) {
  auto dyn_sizes_attr = primitive->GetAttr(kAttrDynInputSizes);
  if (dyn_sizes_attr == nullptr) {
    return input_abstract;
  }
  const int kNotDynamicFlag = -1;
  const auto group_sizes = GetValue<std::vector<int64_t>>(dyn_sizes_attr);
  AbstractBasePtrList regrouped;
  size_t cursor = 0;
  for (const int64_t group_size : group_sizes) {
    if (group_size == kNotDynamicFlag) {
      // Non-dynamic input: exactly one flat abstract.
      if (cursor >= input_abstract.size()) {
        MS_LOG(EXCEPTION) << " index " << cursor << " is out of range in input abstract " << input_abstract.size();
      }
      (void)regrouped.emplace_back(input_abstract[cursor++]);
      continue;
    }
    if (group_size < 0) {
      MS_LOG(EXCEPTION) << " the dynamic input size check error the index should be -1 or positive number but got "
                        << group_size;
    }
    // Dynamic input: the next `group_size` flat abstracts become the elements of one tuple.
    AbstractBasePtrList packed;
    for (int64_t remaining = group_size; remaining > 0; --remaining) {
      if (cursor >= input_abstract.size()) {
        MS_LOG(EXCEPTION) << " index " << cursor << " is out of range in input abstract " << input_abstract.size();
      }
      (void)packed.emplace_back(input_abstract[cursor++]);
    }
    (void)regrouped.emplace_back(std::make_shared<abstract::AbstractTuple>(packed));
  }
  return regrouped;
}
798
// Rebuilds the argument abstracts handed to a front-end infer function from a backend argument list.
// It runs two passes in order:
//   1. RectifyAbstractFromRegAttr puts back abstracts for inputs that were converted to attributes.
//   2. RectifyAbstractFromDynamicInput regroups dynamic inputs as described by kAttrDynInputSizes.
AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) {
  const AbstractBasePtrList with_attr_inputs = RectifyAbstractFromRegAttr(primitive, input_abstract);
  return RectifyAbstractFromDynamicInput(primitive, with_attr_inputs);
}
803 } // namespace
804
SexpToNode(const BaseRef & sexp,const BaseRef & graph,PrimitiveVarMap * primitive_vars,bool multigraph)805 AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
806 MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
807 MS_EXCEPTION_IF_NULL(primitive_vars);
808 if (utils::isa<VectorRef>(sexp)) {
809 return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
810 }
811 if (utils::isa<VarPtr>(sexp)) {
812 auto var_ptr = utils::cast<VarPtr>(sexp);
813 MS_EXCEPTION_IF_NULL(var_ptr);
814 if (var_ptr->primitive()) {
815 (*primitive_vars)[var_ptr->primitive()] = var_ptr;
816 return NewValueNode(var_ptr->primitive());
817 }
818 return CreateVarNodeWithSexp(sexp, graph);
819 }
820 if (utils::isa<AnfNodePtr>(sexp)) {
821 return utils::cast<AnfNodePtr>(sexp);
822 }
823 auto value_node = CreateValueNodeWithSexp(sexp);
824 if (value_node == nullptr) {
825 MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
826 }
827 return value_node;
828 }
829
IsSameNode(const EquivPtr & equiv1,const EquivPtr & equiv2,const VarPtr & var_node)830 bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
831 MS_EXCEPTION_IF_NULL(equiv1);
832 MS_EXCEPTION_IF_NULL(equiv2);
833 MS_EXCEPTION_IF_NULL(var_node);
834 auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
835 MS_EXCEPTION_IF_NULL(equiv1_node);
836 auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
837 MS_EXCEPTION_IF_NULL(equiv2_node);
838 return *equiv1_node == *equiv2_node;
839 }
840
GetAnfNodeByVar(const EquivPtr & equiv,const VarPtr & var_node)841 AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
842 MS_EXCEPTION_IF_NULL(equiv);
843 MS_EXCEPTION_IF_NULL(var_node);
844 auto iter = (*equiv).find(var_node);
845 if (iter == (*equiv).end()) {
846 MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
847 return nullptr;
848 }
849 auto res = utils::cast<AnfNodePtr>(iter->second);
850 if (res == nullptr) {
851 MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node";
852 }
853 return res;
854 }
855
CompareTupleGetitem(const AnfNodePtr & n1,const AnfNodePtr & n2)856 bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
857 MS_EXCEPTION_IF_NULL(n1);
858 MS_EXCEPTION_IF_NULL(n2);
859 auto n1_cnode = n1->cast<CNodePtr>();
860 auto n2_cnode = n2->cast<CNodePtr>();
861 MS_EXCEPTION_IF_NULL(n1_cnode);
862 MS_EXCEPTION_IF_NULL(n2_cnode);
863 auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem);
864 MS_EXCEPTION_IF_NULL(index_input1);
865 auto value_node1 = index_input1->cast<ValueNodePtr>();
866 MS_EXCEPTION_IF_NULL(value_node1);
867 auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem);
868 MS_EXCEPTION_IF_NULL(index_input2);
869 auto value_node2 = index_input2->cast<ValueNodePtr>();
870 MS_EXCEPTION_IF_NULL(value_node2);
871 return GetValue<int64_t>(value_node1->value()) < GetValue<int64_t>(value_node2->value());
872 }
873
GetBoolAttr(const AnfNodePtr & node,const std::string & attr_name)874 bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
875 MS_EXCEPTION_IF_NULL(node);
876 if (!node->isa<CNode>()) {
877 MS_LOG(INFO) << "node is not a cnode";
878 return false;
879 }
880 auto cnode = node->cast<CNodePtr>();
881 MS_EXCEPTION_IF_NULL(cnode);
882 return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
883 }
884
CheckSupportDataType(const AnfNodePtr & node,const std::set<TypeId> & supported_data_type_set)885 bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
886 MS_EXCEPTION_IF_NULL(node);
887 TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
888 if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
889 return true;
890 }
891 MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
892 return false;
893 }
894
// Creates a new ValueNode that shares `value_node`'s value and abstract.
// The new node gets a fresh device::KernelInfo and a selected KernelBuildInfo with these settings:
//   - the output format list is {kOpFormat_DEFAULT};
//   - the output device types are kTypeUnknown, one per output of `value_node`.
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  MS_EXCEPTION_IF_NULL(new_value_node);
  new_value_node->set_abstract(value_node->abstract());
  // Create kernel_info for the new value node.
  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_value_node->set_kernel_info(kernel_info);
  // Create kernel_build_info for the new value node.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
  // Set the output format of the value node to DEFAULT_FORMAT.
  // NOTE(review): only one format entry is set even when output_num below is > 1; confirm consumers
  // of the build info tolerate the size mismatch.
  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  // Device data types start as kTypeUnknown for every output; the inferred types are NOT copied here.
  size_t output_num = AnfAlgo::GetOutputTensorNum(value_node);
  std::vector<TypeId> types(output_num, kTypeUnknown);
  kernel_build_info_builder->SetOutputsDeviceType(types);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  return new_value_node;
}
918
TransferDependOrUpdateState(const CNodePtr & old_node,const FuncGraphPtr & graph,const CNodePtr & new_node)919 void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
920 MS_EXCEPTION_IF_NULL(old_node);
921 MS_EXCEPTION_IF_NULL(graph);
922 auto manager = graph->manager();
923 MS_EXCEPTION_IF_NULL(manager);
924 // Find BatchNorm's output which is a Depend or UpdateState.
925 auto node_users = manager->node_users()[old_node];
926 for (const auto &node_index : node_users) {
927 AnfNodePtr output = node_index.first;
928 MS_EXCEPTION_IF_NULL(output);
929 if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
930 AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
931 auto depend = output->cast<CNodePtr>();
932 MS_EXCEPTION_IF_NULL(depend);
933 manager->SetEdge(depend, node_index.second, new_node);
934 }
935 }
936 }
937
CppInferShape(const PrimitivePtr & prim,const AbstractBasePtrList & args_spec_list)938 AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
939 MS_EXCEPTION_IF_NULL(prim);
940 auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
941 auto ret = prim_eval_implement_map.find(prim);
942 if (ret != prim_eval_implement_map.end()) {
943 // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
944 MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_);
945 auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
946 return ret->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
947 } else {
948 // if the infer function has been not founded in the front infer map find it in the backend infer map instead
949 auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
950 auto ret_backend = prim_backend_eval_impl_map.find(prim);
951 if (ret_backend != prim_backend_eval_impl_map.end()) {
952 MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_);
953 auto infer_spec_list = args_spec_list;
954 if (!ret_backend->second.in_white_list_) {
955 infer_spec_list = RectifyAbstract(prim, args_spec_list);
956 }
957 return ret_backend->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
958 }
959 }
960 MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
961 << " primitive type:" << prim->type_name();
962 }
963
GenerateKernelBuildInfo(const std::vector<AnfNodePtr> & node_list)964 kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
965 std::vector<std::string> inputs_device_format;
966 std::vector<std::string> outputs_device_format;
967 std::vector<TypeId> inputs_device_type;
968 std::vector<TypeId> outputs_device_type;
969 std::vector<std::vector<size_t>> outputs_shape;
970 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
971 for (size_t idx = 0; idx < node_list.size(); ++idx) {
972 auto cnode = utils::cast<CNodePtr>(node_list[idx]);
973 MS_EXCEPTION_IF_NULL(cnode);
974 size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
975 for (size_t input_index = 0; input_index < input_num; ++input_index) {
976 (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
977 (void)inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
978 }
979 size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
980 for (size_t output_index = 0; output_index < output_num; ++output_index) {
981 (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
982 (void)outputs_device_type.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
983 (void)outputs_shape.emplace_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
984 }
985 }
986 builder.SetInputsFormat(inputs_device_format);
987 builder.SetOutputsFormat(outputs_device_format);
988 builder.SetInputsDeviceType(inputs_device_type);
989 builder.SetOutputsDeviceType(outputs_device_type);
990 return builder.Build();
991 }
992
GetNodeOutputUsedNum(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)993 std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
994 MS_EXCEPTION_IF_NULL(node);
995 auto manager = kernel_graph.manager();
996 MS_EXCEPTION_IF_NULL(manager);
997 auto output_num = AnfAlgo::GetOutputTensorNum(node);
998 std::vector<int64_t> output_used_num(output_num, 0);
999 if (output_num == 1) {
1000 output_used_num[0] = SizeToLong(manager->node_users()[node].size());
1001 } else {
1002 for (auto out_getitem : manager->node_users()[node]) {
1003 MS_EXCEPTION_IF_NULL(out_getitem.first);
1004 if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
1005 continue;
1006 }
1007 auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
1008 MS_EXCEPTION_IF_NULL(out_getitem_ptr);
1009 auto getitem_input2 = out_getitem_ptr->input(kInputNodeOutputIndexInTupleGetItem);
1010 auto output_idx = LongToSize(GetValue<int64_t>(GetValueNode(getitem_input2)));
1011 output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
1012 }
1013 }
1014 return output_used_num;
1015 }
1016
GetNodeOutputTotalUsedNum(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)1017 int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
1018 auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
1019 return std::accumulate(output_used_num.begin(), output_used_num.end(), int64_t(0));
1020 }
1021 } // namespace opt
1022 } // namespace mindspore
1023