• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/common/helper.h"
18 #include <string>
19 #include <utility>
20 #include <unordered_set>
21 #include <algorithm>
22 #include <map>
23 #include <set>
24 #include <deque>
25 #include "utils/utils.h"
26 #include "base/base_ref.h"
27 #include "backend/session/anf_runtime_algorithm.h"
28 #include "base/core_ops.h"
29 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
30 #include "frontend/operator/ops.h"
31 #include "utils/ms_utils.h"
32 #include "runtime/device/kernel_info.h"
33 #include "utils/ms_context.h"
34 #include "backend/optimizer/common/const_input_to_attr_registry.h"
35 #include "abstract/primitive_infer_map.h"
36 
37 namespace mindspore {
38 namespace opt {
// Byte widths of 32-bit and 64-bit scalar elements.
constexpr size_t kType32Len = sizeof(int32_t);
constexpr size_t kType64Len = sizeof(int64_t);
41 
Convert2Int(const std::vector<size_t> & v)42 std::vector<int64_t> Convert2Int(const std::vector<size_t> &v) {
43   std::vector<int64_t> result;
44   (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
45   return result;
46 }
47 
Convert2Long(const std::vector<size_t> & v)48 std::vector<int64_t> Convert2Long(const std::vector<size_t> &v) {
49   std::vector<int64_t> result;
50   (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToLong);
51   return result;
52 }
53 
IsDepend(const FuncGraph & graph,const AnfNodePtr & node,const std::vector<AnfNodePtr> & nodes)54 bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
55   MS_EXCEPTION_IF_NULL(node);
56   FuncGraphManagerPtr manager = graph.manager();
57   MS_EXCEPTION_IF_NULL(manager);
58 
59   std::unordered_set<AnfNodePtr> seen_node;
60   std::deque<AnfNodePtr> todo{node};
61   while (!todo.empty()) {
62     AnfNodePtr nd = todo.front();
63     todo.pop_front();
64     if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) {
65       continue;
66     }
67     (void)seen_node.insert(nd);
68 
69     if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
70       return true;
71     }
72     if (nd->isa<CNode>()) {
73       auto cnode = nd->cast<CNodePtr>();
74       MS_EXCEPTION_IF_NULL(cnode);
75       auto inputs = cnode->inputs();
76       (void)todo.insert(todo.end(), inputs.begin(), inputs.end());
77     }
78   }
79   return false;
80 }
81 
UnVisited(const BaseRef & n)82 bool UnVisited(const BaseRef &n) {
83   if (utils::isa<AnfNodePtr>(n)) {
84     AnfNodePtr in = utils::cast<AnfNodePtr>(n);
85     MS_EXCEPTION_IF_NULL(in);
86     if (IsValueNode<Primitive>(in)) {
87       auto value_node = in->cast<ValueNodePtr>();
88       MS_EXCEPTION_IF_NULL(value_node);
89       auto value = value_node->value();
90       MS_EXCEPTION_IF_NULL(value);
91       auto prim_py = value->cast<PrimitivePtr>();
92       MS_EXCEPTION_IF_NULL(prim_py);
93       return !prim_py->HasAttr(kAttrVisited);
94     } else if (IsValueNode<FuncGraph>(in)) {
95       auto func_graph = GetValueNode<FuncGraphPtr>(in);
96       MS_EXCEPTION_IF_NULL(func_graph);
97       return !func_graph->has_flag(kAttrVisited);
98     }
99     return false;
100   }
101   return false;
102 }
103 
CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr & node,size_t input_size)104 CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size) {
105   MS_EXCEPTION_IF_NULL(node);
106   if (!node->isa<CNode>()) {
107     MS_LOG(EXCEPTION) << "The node is expected to be a cnode";
108   }
109   auto cnode = node->cast<CNodePtr>();
110   CheckCNodeInputSize(cnode, input_size);
111   return cnode;
112 }
113 
CheckCNodeInputSize(const CNodePtr & cnode,size_t input_tensor_size)114 void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) {
115   MS_EXCEPTION_IF_NULL(cnode);
116   auto real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
117   if (real_input_tensor_num != input_tensor_size) {
118     MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num
119                       << "] of node " + cnode->DebugString() + " is not equal to " << input_tensor_size;
120   }
121 }
122 
HasSymmetricalKernelInfo(const AnfNodePtr & node_x,const AnfNodePtr & node_y)123 bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) {
124   MS_EXCEPTION_IF_NULL(node_x);
125   MS_EXCEPTION_IF_NULL(node_y);
126   return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) &&
127           AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0));
128 }
129 
/**
 * Rewrites the chain `node -> Depend -> inner node -> x` into a fresh `Depend(x, attach)`.
 *
 * `node` and the Depend's input(1) must both be CNodes with kTransOpInputTensorNum inputs. `node`'s input(1) must
 * be a CNode with kDependInputTensorNum inputs. The new Depend keeps the original Depend's input at
 * kDependAttachNodeIndex and takes x's abstract.
 *
 * @param func_graph Graph in which the replacement Depend is created; must not be null.
 * @param node Outer node of the chain.
 * @return The newly created Depend node (the caller decides how to splice it in).
 */
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);

  auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
  MS_EXCEPTION_IF_NULL(transop_cnode);
  auto depend = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum);
  auto inner_transop = CheckAnfNodeIfCNodeAndInputSize(depend->input(1), kTransOpInputTensorNum);
  auto transed_node = inner_transop->input(1);
  MS_EXCEPTION_IF_NULL(transed_node);

  const std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), transed_node,
                                              depend->input(kDependAttachNodeIndex)};
  AnfNodePtr replace_depend = func_graph->NewCNode(depend_inputs);
  MS_EXCEPTION_IF_NULL(replace_depend);
  replace_depend->set_abstract(transed_node->abstract());
  return replace_depend;
}
148 
Visited(const BaseRef & n)149 bool Visited(const BaseRef &n) {
150   if (utils::isa<AnfNodePtr>(n)) {
151     AnfNodePtr in = utils::cast<AnfNodePtr>(n);
152     MS_EXCEPTION_IF_NULL(in);
153     if (IsValueNode<Primitive>(in)) {
154       auto value_node = in->cast<ValueNodePtr>();
155       MS_EXCEPTION_IF_NULL(value_node);
156       auto value = value_node->value();
157       MS_EXCEPTION_IF_NULL(value);
158       auto prim_py = value->cast<PrimitivePtr>();
159       MS_EXCEPTION_IF_NULL(prim_py);
160       return prim_py->HasAttr(kAttrVisited);
161     } else if (IsValueNode<FuncGraph>(in)) {
162       auto func_graph = GetValueNode<FuncGraphPtr>(in);
163       MS_EXCEPTION_IF_NULL(func_graph);
164       return func_graph->has_flag(kAttrVisited);
165     }
166     return false;
167   }
168   return false;
169 }
170 
CreateMultipleOutputsOfAnfNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_num,std::vector<AnfNodePtr> * outputs)171 void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
172                                     std::vector<AnfNodePtr> *outputs) {
173   MS_EXCEPTION_IF_NULL(func_graph);
174   MS_EXCEPTION_IF_NULL(node);
175   MS_EXCEPTION_IF_NULL(outputs);
176   auto type_ptr = node->Type();
177   auto shape_ptr = node->Shape();
178   for (size_t i = 0; i < output_num; i++) {
179     int64_t temp = SizeToLong(i);
180     auto idx = NewValueNode(temp);
181     MS_EXCEPTION_IF_NULL(idx);
182     auto imm = std::make_shared<Int64Imm>(temp);
183     auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
184     idx->set_abstract(abstract_scalar);
185     auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
186     MS_EXCEPTION_IF_NULL(tuple_getitem);
187     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(type_ptr, i)},
188                                         {AnfAlgo::GetOutputInferShape(node, shape_ptr, i)}, tuple_getitem.get());
189     (*outputs).push_back(tuple_getitem);
190   }
191 }
192 
193 template <typename T>
CreateTensorWithValueTuple(const ValueTuplePtr & value_tuple_ptr,const TypePtr & type_ptr,size_t data_length)194 tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
195                                              size_t data_length) {
196   MS_EXCEPTION_IF_NULL(value_tuple_ptr);
197   MS_EXCEPTION_IF_NULL(type_ptr);
198   std::vector<T> values;
199   for (const auto &v : value_tuple_ptr->value()) {
200     MS_EXCEPTION_IF_NULL(v);
201     if (v->isa<Scalar>()) {
202       ScalarPtr scalar = v->cast<ScalarPtr>();
203       values.push_back(GetValue<T>(scalar));
204     } else {
205       MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
206       return nullptr;
207     }
208   }
209   std::vector<int64_t> tensor_shape = {SizeToLong(values.size())};
210   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
211   MS_EXCEPTION_IF_NULL(tensor);
212   tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
213   tensor->set_device_info(device_info);
214   auto data_ptr = tensor->data_c();
215   MS_EXCEPTION_IF_NULL(data_ptr);
216   auto elem_num = values.size() * data_length;
217   auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
218   if (ret_code != 0) {
219     MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
220   }
221   return tensor;
222 }
223 
CreateTupleTensor(const ValueTuplePtr & value_tuple)224 tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
225   MS_EXCEPTION_IF_NULL(value_tuple);
226   tensor::TensorPtr tensor = nullptr;
227   if (value_tuple->value().empty()) {
228     MS_LOG(WARNING) << "The value tuple is empty.";
229     return nullptr;
230   }
231   ValuePtr v = *(value_tuple->value().begin());
232   MS_EXCEPTION_IF_NULL(v);
233   // Currently we only deal with the scalar tuple
234   if (!v->isa<Scalar>()) {
235     MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
236     return nullptr;
237   }
238   ScalarPtr scalar = v->cast<ScalarPtr>();
239   MS_EXCEPTION_IF_NULL(scalar);
240   if (scalar->isa<Int32Imm>()) {
241     tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, kInt32, sizeof(int32_t));
242   } else if (scalar->isa<Int64Imm>()) {
243     tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, kInt64, sizeof(int64_t));
244   } else if (scalar->isa<FloatImm>()) {
245     tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, sizeof(float));
246   } else {
247     auto type = scalar->type();
248     auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
249     MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
250     return nullptr;
251   }
252   return tensor;
253 }
254 
IsNopNode(const AnfNodePtr & node)255 bool IsNopNode(const AnfNodePtr &node) {
256   auto context_ptr = MsContext::GetInstance();
257   MS_EXCEPTION_IF_NULL(context_ptr);
258   auto target = GetCNodeTarget(node);
259   if (target == kCPUDevice) {
260     return false;
261   }
262   if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
263       context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
264     return false;
265   }
266 
267   static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
268                                                       prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
269                                                       kFlattenGradOpName,         prim::kPrimReformat->name()};
270   if (node == nullptr || !node->isa<CNode>()) {
271     return false;
272   }
273   CNodePtr cnode = node->cast<CNodePtr>();
274   MS_EXCEPTION_IF_NULL(cnode);
275   if (cnode->inputs().empty()) {
276     return false;
277   }
278   auto input0 = cnode->input(0);
279   MS_EXCEPTION_IF_NULL(input0);
280   if (!input0->isa<ValueNode>()) {
281     return false;
282   }
283   bool is_nop_node = false;
284   if (AnfAlgo::HasNodeAttr("nop_op", cnode)) {
285     is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, "nop_op");
286   }
287   if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
288     return false;
289   }
290   return true;
291 }
292 
IsAllNopNode(const session::KernelGraph * const graph)293 bool IsAllNopNode(const session::KernelGraph *const graph) {
294   MS_EXCEPTION_IF_NULL(graph);
295   auto execution_order = graph->execution_order();
296   for (auto &cnode : execution_order) {
297     MS_EXCEPTION_IF_NULL(cnode);
298     if (!IsNopNode(cnode)) {
299       return false;
300     }
301   }
302   return true;
303 }
304 
CheckNopNodeIsOutputNode(const std::vector<AnfNodePtr> & outputs,const AnfNodePtr & node,bool is_dynamic_graph)305 bool CheckNopNodeIsOutputNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool is_dynamic_graph) {
306   MS_EXCEPTION_IF_NULL(node);
307   // if node is not a nop node, keep it in execution order
308   if (!IsNopNode(node)) {
309     return true;
310   }
311   // if node is nop node and the graph is dynamic graph, check if the nop node is graph's output.
312   if (is_dynamic_graph) {
313     auto iter = find(outputs.begin(), outputs.end(), node);
314     if (iter != outputs.end()) {
315       return true;
316     }
317   }
318   return false;
319 }
320 
HideNopNode(session::KernelGraph * const graph)321 void HideNopNode(session::KernelGraph *const graph) {
322   MS_EXCEPTION_IF_NULL(graph);
323   if (IsAllNopNode(graph) == true) {
324     return;
325   }
326   auto execution_order = graph->execution_order();
327   auto outputs = graph->outputs();
328   bool is_dynamic_graph = graph->is_dynamic_shape();
329   MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
330   std::vector<CNodePtr> new_nodes;
331   for (auto &cnode : execution_order) {
332     MS_EXCEPTION_IF_NULL(cnode);
333     if (CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) {
334       new_nodes.push_back(cnode);
335     }
336   }
337   graph->set_execution_order(new_nodes);
338   MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size();
339 }
340 
RemoveNopNode(session::KernelGraph * const graph)341 void RemoveNopNode(session::KernelGraph *const graph) {
342   MS_EXCEPTION_IF_NULL(graph);
343   if (IsAllNopNode(graph) == true) {
344     return;
345   }
346   bool changed = true;
347   while (changed) {
348     changed = false;
349     std::vector<CNodePtr> new_nodes;
350     auto outputs = graph->outputs();
351     bool is_dynamic_graph = graph->is_dynamic_shape();
352     for (auto &cnode : graph->execution_order()) {
353       MS_EXCEPTION_IF_NULL(cnode);
354       // ignore nop node itself
355       if (!CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) {
356         continue;
357       }
358       // Replace the input which is nop node
359       std::vector<AnfNodePtr> new_inputs;
360       new_inputs.push_back(cnode->input(0));
361       bool need_update = false;
362       for (size_t i = 1; i < cnode->inputs().size(); ++i) {
363         auto input = cnode->input(i);
364         MS_EXCEPTION_IF_NULL(input);
365         auto cinput = input->cast<CNodePtr>();
366         if (cinput == nullptr || !IsNopNode(cinput)) {
367           new_inputs.push_back(input);
368           continue;
369         }
370         constexpr auto kInputSize = 2;
371         if (cinput->inputs().size() == kInputSize) {
372           new_inputs.push_back(cinput->input(1));
373           need_update = true;
374           changed = true;
375         } else {
376           new_inputs.push_back(input);
377         }
378       }
379       if (need_update) {
380         cnode->set_inputs(new_inputs);
381       }
382       // push into new execution list
383       new_nodes.push_back(cnode);
384     }
385     graph->set_execution_order(new_nodes);
386   }
387 }
388 
GetRealNodeNum(const FuncGraphPtr & graph,const AnfNodePtr & node)389 size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
390   auto out_list = GetRealNodeUsedList(graph, node);
391   MS_EXCEPTION_IF_NULL(out_list);
392   return out_list->size();
393 }
394 
GetRealNodeUsedList(const FuncGraphPtr & graph,const AnfNodePtr & node)395 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
396                                                                              const AnfNodePtr &node) {
397   auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
398   MS_EXCEPTION_IF_NULL(graph);
399   auto manager = graph->manager();
400   MS_EXCEPTION_IF_NULL(manager);
401   auto iter = manager->node_users().find(node);
402   if (iter == manager->node_users().end()) {
403     return output_node_list;
404   }
405   auto output_info_list = iter->second;
406   for (const auto &output_info : output_info_list) {
407     auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
408     if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
409         (cnode_name == prim::kPrimUpdateState->name())) {
410       continue;
411     }
412     output_node_list->push_back(output_info);
413   }
414   return output_node_list;
415 }
416 
GetRealNodeUsedListByOutputIdx(const FuncGraphPtr & graph,const AnfNodePtr & node,size_t output_index)417 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
418                                                                                         const AnfNodePtr &node,
419                                                                                         size_t output_index) {
420   auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
421   MS_EXCEPTION_IF_NULL(graph);
422   auto manager = graph->manager();
423   MS_EXCEPTION_IF_NULL(manager);
424   auto iter = manager->node_users().find(node);
425   if (iter == manager->node_users().end()) {
426     MS_LOG(EXCEPTION) << "node has no output in manager";
427   }
428   auto output_info_list = iter->second;
429   for (const auto &output_info : output_info_list) {
430     auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
431     if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
432         (cnode_name == prim::kPrimUpdateState->name())) {
433       continue;
434     }
435     size_t used_output_index;
436     if (cnode_name == prim::kPrimTupleGetItem->name()) {
437       used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
438     } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
439       used_output_index = output_index;
440     } else {
441       auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, IntToSize(output_info.second - 1));
442       if (kernel_with_index.first.get() != node.get()) {
443         MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]";
444       }
445       used_output_index = kernel_with_index.second;
446     }
447     if (used_output_index == output_index) {
448       output_node_list->push_back(output_info);
449     }
450   }
451   return output_node_list;
452 }
453 
IsUsedByOthers(const FuncGraphPtr & graph,const AnfNodePtr & node)454 bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
455   MS_EXCEPTION_IF_NULL(graph);
456   MS_EXCEPTION_IF_NULL(node);
457   auto output_node_list = GetRealNodeUsedList(graph, node);
458   MS_EXCEPTION_IF_NULL(output_node_list);
459   return output_node_list->size() > 1;
460 }
461 
IsNotRealUsedByOthers(const FuncGraphPtr & graph,const AnfNodePtr & node)462 bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
463   MS_EXCEPTION_IF_NULL(graph);
464   MS_EXCEPTION_IF_NULL(node);
465   auto output_node_list = GetRealNodeUsedList(graph, node);
466   MS_EXCEPTION_IF_NULL(output_node_list);
467   if (output_node_list->empty()) {
468     return true;
469   }
470   for (const auto &output : *output_node_list) {
471     auto out_node = output.first;
472     auto name = AnfAlgo::GetCNodeName(out_node);
473     if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
474         name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) {
475       auto result = IsNotRealUsedByOthers(graph, out_node);
476       if (!result) {
477         return result;
478       }
479       continue;
480     }
481     return false;
482   }
483   return true;
484 }
485 
CreatTupleGetItemNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_idx)486 CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
487   MS_EXCEPTION_IF_NULL(func_graph);
488   auto idx = NewValueNode(SizeToLong(output_idx));
489   MS_EXCEPTION_IF_NULL(idx);
490   auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
491   auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
492   idx->set_abstract(abstract_scalar);
493   CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
494   MS_EXCEPTION_IF_NULL(tuple_getitem);
495   tuple_getitem->set_scope(node->scope());
496   auto abs = node->abstract()->cast<abstract::AbstractTuplePtr>();
497   MS_EXCEPTION_IF_NULL(abs);
498   auto abs_i = abs->elements()[output_idx];
499   MS_EXCEPTION_IF_NULL(abs_i);
500   tuple_getitem->set_abstract(abs_i);
501   return tuple_getitem;
502 }
503 
CreateShapeValueNode(const FuncGraphPtr & func_graph,const std::vector<int64_t> & shape,bool to_tensor)504 ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape, bool to_tensor) {
505   MS_EXCEPTION_IF_NULL(func_graph);
506   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
507   MS_EXCEPTION_IF_NULL(kernel_graph);
508   ValuePtr shape_value = nullptr;
509   AbstractBasePtr abstract = nullptr;
510   if (to_tensor) {
511     // create Tensor
512     int64_t shape_dim = SizeToLong(shape.size());
513     std::vector<int64_t> shape_vec_shape = {shape_dim};
514     auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
515     MS_EXCEPTION_IF_NULL(shape_tensor);
516     auto data_ptr = shape_tensor->data_c();
517     MS_EXCEPTION_IF_NULL(data_ptr);
518     auto elem_num = shape.size() * kType64Len;
519     auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
520     if (ret_code != 0) {
521       MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
522       return nullptr;
523     }
524     shape_value = shape_tensor;
525     abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
526   } else {
527     // create ValueTuple
528     std::vector<ValuePtr> dim_values{};
529     abstract::AbstractBasePtrList abs{};
530     for (const auto &dim : shape) {
531       dim_values.push_back(MakeValue(dim));
532       abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
533     }
534     shape_value = std::make_shared<ValueTuple>(dim_values);
535     abstract = std::make_shared<abstract::AbstractTuple>(abs);
536   }
537   MS_EXCEPTION_IF_NULL(shape_value);
538   MS_EXCEPTION_IF_NULL(abstract);
539   auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
540   MS_EXCEPTION_IF_NULL(shape_value_node);
541   kernel_graph->AddValueNodeToGraph(shape_value_node);
542   return shape_value_node;
543 }
544 
ConstInputToAttr(const CNodePtr & cnode,const std::unordered_set<size_t> & input_attrs)545 void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
546   MS_EXCEPTION_IF_NULL(cnode);
547   std::vector<AnfNodePtr> new_inputs;
548   auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
549   MS_EXCEPTION_IF_NULL(primitive);
550   primitive = primitive->Clone();
551   auto input_names = primitive->GetAttr(kAttrInputNames);
552   if (input_names == nullptr) {
553     MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
554     return;
555   }
556   auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
557   auto inputs = cnode->inputs();
558   new_inputs.push_back(inputs[0]);
559   bool need_update = false;
560   for (size_t i = 0; i < inputs.size() - 1; ++i) {
561     auto input_node = inputs[i + 1];
562     if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
563       input_node = AnfAlgo::VisitKernel(input_node, 0).first;
564     }
565     MS_EXCEPTION_IF_NULL(input_node);
566     if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
567       auto value_node = input_node->cast<ValueNodePtr>();
568       MS_EXCEPTION_IF_NULL(value_node);
569       MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
570       if (i >= input_names_vec.size()) {
571         MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
572       }
573       primitive->set_attr(input_names_vec[i], value_node->value());
574       need_update = true;
575     } else {
576       new_inputs.push_back(inputs[i + 1]);
577     }
578   }
579   if (need_update) {
580     // Update cnode's inputs
581     new_inputs[0] = NewValueNode(primitive);
582     cnode->set_inputs(new_inputs);
583   }
584 }
585 
AnfEqual(const BaseRef & a,const BaseRef & b)586 bool AnfEqual(const BaseRef &a, const BaseRef &b) {
587   if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
588     auto a_node = utils::cast<AnfNodePtr>(a);
589     auto b_node = utils::cast<AnfNodePtr>(b);
590     MS_EXCEPTION_IF_NULL(a_node);
591     MS_EXCEPTION_IF_NULL(b_node);
592     if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
593       auto a_value_node = a_node->cast<ValueNodePtr>();
594       MS_EXCEPTION_IF_NULL(a_value_node);
595       auto a_value = a_value_node->value();
596       MS_EXCEPTION_IF_NULL(a_value);
597       auto a_prim = a_value->cast<PrimitivePtr>();
598       MS_EXCEPTION_IF_NULL(a_prim);
599 
600       auto b_value_node = b_node->cast<ValueNodePtr>();
601       MS_EXCEPTION_IF_NULL(b_value_node);
602       auto b_value = b_value_node->value();
603       MS_EXCEPTION_IF_NULL(b_value);
604       auto b_prim = b_value->cast<PrimitivePtr>();
605       MS_EXCEPTION_IF_NULL(b_prim);
606 
607       return a_prim->name() == b_prim->name();
608     } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
609       auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
610       if (a_value_node_ptr == nullptr) {
611         MS_LOG(EXCEPTION) << "cast value node ptr fail";
612       }
613       auto a_value_ptr = a_value_node_ptr->value();
614       if (a_value_ptr == nullptr) {
615         MS_LOG(EXCEPTION) << "value ptr is nullptr";
616       }
617 
618       auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
619       if (b_value_node_ptr == nullptr) {
620         MS_LOG(EXCEPTION) << "cast value node ptr fail";
621       }
622       auto b_value_ptr = b_value_node_ptr->value();
623       if (b_value_ptr == nullptr) {
624         MS_LOG(EXCEPTION) << "value ptr is nullptr";
625       }
626 
627       return (*a_value_ptr) == (*b_value_ptr);
628     }
629     MS_LOG(DEBUG) << "check AnfNodePtr equal";
630   }
631   if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
632     MS_LOG(DEBUG) << "check GraphPtr equal";
633   }
634   return a == b;
635 }
636 
CNodeTypeEqual(const BaseRef & a,const BaseRef & b)637 bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
638   // To matchCNode and Kernel's type
639   if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
640     return true;
641   }
642   return a.type() == b.type();
643 }
644 
645 namespace {
CreateValueNodeWithSexp(const BaseRef & sexp)646 ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
647   if (utils::isa<int>(sexp)) {
648     return NewValueNode(utils::cast<int>(sexp));
649   }
650   if (utils::isa<int64_t>(sexp)) {
651     return NewValueNode(utils::cast<int64_t>(sexp));
652   }
653   if (utils::isa<float>(sexp)) {
654     return NewValueNode(utils::cast<float>(sexp));
655   }
656   if (utils::isa<bool>(sexp)) {
657     return NewValueNode(utils::cast<bool>(sexp));
658   }
659   if (utils::isa<ValuePtr>(sexp)) {
660     return NewValueNode(utils::cast<ValuePtr>(sexp));
661   }
662   return nullptr;
663 }
664 
CreateCNodeWithGraph(const std::vector<AnfNodePtr> & input_nodes,const BaseRef & graph)665 CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
666   if (utils::isa<FuncGraphPtr>(graph)) {
667     return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
668   }
669   if (utils::isa<VarPtr>(graph)) {
670     return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
671   }
672   return nullptr;
673 }
674 
CreateVarNodeWithSexp(const BaseRef & sexp,const BaseRef & graph)675 VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
676   if (utils::isa<VarPtr>(graph)) {
677     MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
678     return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
679   }
680   if (utils::isa<FuncGraphPtr>(graph)) {
681     MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
682     return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
683   }
684   MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
685   return nullptr;
686 }
687 
HandleSexpVector(const BaseRef & sexp,const BaseRef & graph,PrimitiveVarMap * primitive_vars,bool multigraph)688 AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
689                             bool multigraph) {
690   MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
691   std::vector<AnfNodePtr> input_nodes;
692   const auto &tuple = utils::cast<VectorRef>(sexp);
693   if (multigraph && utils::isa<VarPtr>(graph)) {
694     for (auto &x : tuple) {
695       AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
696       input_nodes.push_back(node);
697     }
698     VarPtr var_ptr = utils::cast<VarPtr>(graph);
699     return std::make_shared<CNode>(input_nodes, var_ptr);
700   }
701 
702   for (auto &x : tuple) {
703     AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
704     input_nodes.push_back(node);
705   }
706   return CreateCNodeWithGraph(input_nodes, graph);
707 }
708 
709 // rectify absttract if the input has been converted to the attr
RectifyAbstractFromRegAttr(const PrimitivePtr & primitive,const AbstractBasePtrList & input_abstract)710 AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
711                                                const AbstractBasePtrList &input_abstract) {
712   MS_EXCEPTION_IF_NULL(primitive);
713   opt::ConstInputToAttrInfoRegister reg;
714   if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), &reg)) {
715     return input_abstract;
716   }
717   if (AnfAlgo::HasDynamicShapeFlag(primitive)) {
718     return input_abstract;
719   }
720   auto ms_context = MsContext::GetInstance();
721   MS_EXCEPTION_IF_NULL(ms_context);
722   auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
723   if (device == kGPUDevice) {
724     if (DynamicShapeConstInputToAttrGPU.find(primitive->name()) != DynamicShapeConstInputToAttrGPU.end()) {
725       return input_abstract;
726     }
727   } else if (DynamicShapeConstInputToAttr.find(primitive->name()) != DynamicShapeConstInputToAttr.end()) {
728     return input_abstract;
729   }
730   auto convert_input_list = reg.GetConstInputAttrInfo();
731   auto input_names = primitive->GetAttr(kAttrInputNames);
732   if (input_names == nullptr) {
733     return input_abstract;
734   }
735   auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
736   AbstractBasePtrList rectify_abs_list;
737   size_t ori_index = 0;
738   rectify_abs_list.resize(input_names_vec.size());
739   for (size_t index = 0; index < rectify_abs_list.size(); ++index) {
740     // if convert input list find the index it means the input has been converted to the attr
741     if (convert_input_list.find(index) != convert_input_list.end()) {
742       AbstractBasePtr rectify_abs = nullptr;
743       auto input_name = input_names_vec[index];
744       auto attr = primitive->GetAttr(input_name);
745       if (attr != nullptr) {
746         rectify_abs = attr->ToAbstract();
747       } else {
748         MS_LOG(DEBUG) << "the node prim name :" << primitive->name() << "input index :" << index
749                       << " input name :" << input_name << "has not been converted to the attr";
750         rectify_abs = input_abstract[ori_index++];
751       }
752       rectify_abs_list[index] = rectify_abs;
753       continue;
754     }
755     if (ori_index > input_abstract.size()) {
756       MS_LOG(EXCEPTION) << "index is out of range input abstract size " << input_abstract.size()
757                         << " get index :" << ori_index;
758     }
759     rectify_abs_list[index] = input_abstract[ori_index++];
760   }
761   return rectify_abs_list;
762 }
763 
RectifyAbstractFromDynamicInput(const PrimitivePtr & primitive,const AbstractBasePtrList & input_abstract)764 AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
765                                                     const AbstractBasePtrList &input_abstract) {
766   auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
767   if (dynamic_inputs_list == nullptr) {
768     return input_abstract;
769   }
770   AbstractBasePtrList rectifyed_abs_list;
771   const int kNotDynamicFlag = -1;
772   auto dynamic_inputs_index = GetValue<std::vector<int64_t>>(dynamic_inputs_list);
773   size_t input_index = 0;
774   for (auto item : dynamic_inputs_index) {
775     if (item == kNotDynamicFlag) {
776       if (input_index >= input_abstract.size()) {
777         MS_LOG(EXCEPTION) << " index " << input_index << " is out of range in input abstract " << input_abstract.size();
778       }
779       (void)rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
780     } else {
781       if (item < 0) {
782         MS_LOG(EXCEPTION) << " the dynamic input size check error the index should be -1 or positive number but got "
783                           << item;
784       }
785       AbstractBasePtrList dynamic_inputs_abs;
786       for (auto index = item; index > 0; --index) {
787         if (input_index >= input_abstract.size()) {
788           MS_LOG(EXCEPTION) << " index " << input_index << " is out of range in input abstract "
789                             << input_abstract.size();
790         }
791         (void)dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
792       }
793       (void)rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
794     }
795   }
796   return rectifyed_abs_list;
797 }
798 
RectifyAbstract(const PrimitivePtr & primitive,const AbstractBasePtrList & input_abstract)799 AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) {
800   auto rectify_abs_list = RectifyAbstractFromRegAttr(primitive, input_abstract);
801   return RectifyAbstractFromDynamicInput(primitive, rectify_abs_list);
802 }
803 }  // namespace
804 
SexpToNode(const BaseRef & sexp,const BaseRef & graph,PrimitiveVarMap * primitive_vars,bool multigraph)805 AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
806   MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
807   MS_EXCEPTION_IF_NULL(primitive_vars);
808   if (utils::isa<VectorRef>(sexp)) {
809     return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
810   }
811   if (utils::isa<VarPtr>(sexp)) {
812     auto var_ptr = utils::cast<VarPtr>(sexp);
813     MS_EXCEPTION_IF_NULL(var_ptr);
814     if (var_ptr->primitive()) {
815       (*primitive_vars)[var_ptr->primitive()] = var_ptr;
816       return NewValueNode(var_ptr->primitive());
817     }
818     return CreateVarNodeWithSexp(sexp, graph);
819   }
820   if (utils::isa<AnfNodePtr>(sexp)) {
821     return utils::cast<AnfNodePtr>(sexp);
822   }
823   auto value_node = CreateValueNodeWithSexp(sexp);
824   if (value_node == nullptr) {
825     MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
826   }
827   return value_node;
828 }
829 
IsSameNode(const EquivPtr & equiv1,const EquivPtr & equiv2,const VarPtr & var_node)830 bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
831   MS_EXCEPTION_IF_NULL(equiv1);
832   MS_EXCEPTION_IF_NULL(equiv2);
833   MS_EXCEPTION_IF_NULL(var_node);
834   auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
835   MS_EXCEPTION_IF_NULL(equiv1_node);
836   auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
837   MS_EXCEPTION_IF_NULL(equiv2_node);
838   return *equiv1_node == *equiv2_node;
839 }
840 
GetAnfNodeByVar(const EquivPtr & equiv,const VarPtr & var_node)841 AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
842   MS_EXCEPTION_IF_NULL(equiv);
843   MS_EXCEPTION_IF_NULL(var_node);
844   auto iter = (*equiv).find(var_node);
845   if (iter == (*equiv).end()) {
846     MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
847     return nullptr;
848   }
849   auto res = utils::cast<AnfNodePtr>(iter->second);
850   if (res == nullptr) {
851     MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node";
852   }
853   return res;
854 }
855 
CompareTupleGetitem(const AnfNodePtr & n1,const AnfNodePtr & n2)856 bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
857   MS_EXCEPTION_IF_NULL(n1);
858   MS_EXCEPTION_IF_NULL(n2);
859   auto n1_cnode = n1->cast<CNodePtr>();
860   auto n2_cnode = n2->cast<CNodePtr>();
861   MS_EXCEPTION_IF_NULL(n1_cnode);
862   MS_EXCEPTION_IF_NULL(n2_cnode);
863   auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem);
864   MS_EXCEPTION_IF_NULL(index_input1);
865   auto value_node1 = index_input1->cast<ValueNodePtr>();
866   MS_EXCEPTION_IF_NULL(value_node1);
867   auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem);
868   MS_EXCEPTION_IF_NULL(index_input2);
869   auto value_node2 = index_input2->cast<ValueNodePtr>();
870   MS_EXCEPTION_IF_NULL(value_node2);
871   return GetValue<int64_t>(value_node1->value()) < GetValue<int64_t>(value_node2->value());
872 }
873 
GetBoolAttr(const AnfNodePtr & node,const std::string & attr_name)874 bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
875   MS_EXCEPTION_IF_NULL(node);
876   if (!node->isa<CNode>()) {
877     MS_LOG(INFO) << "node is not a cnode";
878     return false;
879   }
880   auto cnode = node->cast<CNodePtr>();
881   MS_EXCEPTION_IF_NULL(cnode);
882   return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
883 }
884 
CheckSupportDataType(const AnfNodePtr & node,const std::set<TypeId> & supported_data_type_set)885 bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
886   MS_EXCEPTION_IF_NULL(node);
887   TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
888   if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
889     return true;
890   }
891   MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
892   return false;
893 }
894 
MakeValueNode(const ValueNodePtr & value_node)895 ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
896   MS_EXCEPTION_IF_NULL(value_node);
897   ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
898   MS_EXCEPTION_IF_NULL(new_value_node);
899   new_value_node->set_abstract(value_node->abstract());
900   // create kernel_info fo new value node
901   auto kernel_info = std::make_shared<device::KernelInfo>();
902   new_value_node->set_kernel_info(kernel_info);
903   // create kernel_build_info for new value node
904   auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
905   MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
906   // set the format of value_node to DEFAULT_FORMAT
907   kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
908   // set value node initial device data type = infer data type
909   std::vector<TypeId> types;
910   size_t output_num = AnfAlgo::GetOutputTensorNum(value_node);
911   for (size_t index = 0; index < output_num; ++index) {
912     types.push_back(kTypeUnknown);
913   }
914   kernel_build_info_builder->SetOutputsDeviceType(types);
915   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
916   return new_value_node;
917 }
918 
TransferDependOrUpdateState(const CNodePtr & old_node,const FuncGraphPtr & graph,const CNodePtr & new_node)919 void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
920   MS_EXCEPTION_IF_NULL(old_node);
921   MS_EXCEPTION_IF_NULL(graph);
922   auto manager = graph->manager();
923   MS_EXCEPTION_IF_NULL(manager);
924   // Find BatchNorm's output which is a Depend or UpdateState.
925   auto node_users = manager->node_users()[old_node];
926   for (const auto &node_index : node_users) {
927     AnfNodePtr output = node_index.first;
928     MS_EXCEPTION_IF_NULL(output);
929     if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
930         AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
931       auto depend = output->cast<CNodePtr>();
932       MS_EXCEPTION_IF_NULL(depend);
933       manager->SetEdge(depend, node_index.second, new_node);
934     }
935   }
936 }
937 
CppInferShape(const PrimitivePtr & prim,const AbstractBasePtrList & args_spec_list)938 AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
939   MS_EXCEPTION_IF_NULL(prim);
940   auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
941   auto ret = prim_eval_implement_map.find(prim);
942   if (ret != prim_eval_implement_map.end()) {
943     // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
944     MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_);
945     auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
946     return ret->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
947   } else {
948     // if the infer function has been not founded in the front infer map find it in the backend infer map instead
949     auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
950     auto ret_backend = prim_backend_eval_impl_map.find(prim);
951     if (ret_backend != prim_backend_eval_impl_map.end()) {
952       MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_);
953       auto infer_spec_list = args_spec_list;
954       if (!ret_backend->second.in_white_list_) {
955         infer_spec_list = RectifyAbstract(prim, args_spec_list);
956       }
957       return ret_backend->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
958     }
959   }
960   MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
961                     << " primitive type:" << prim->type_name();
962 }
963 
GenerateKernelBuildInfo(const std::vector<AnfNodePtr> & node_list)964 kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
965   std::vector<std::string> inputs_device_format;
966   std::vector<std::string> outputs_device_format;
967   std::vector<TypeId> inputs_device_type;
968   std::vector<TypeId> outputs_device_type;
969   std::vector<std::vector<size_t>> outputs_shape;
970   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
971   for (size_t idx = 0; idx < node_list.size(); ++idx) {
972     auto cnode = utils::cast<CNodePtr>(node_list[idx]);
973     MS_EXCEPTION_IF_NULL(cnode);
974     size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
975     for (size_t input_index = 0; input_index < input_num; ++input_index) {
976       (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
977       (void)inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
978     }
979     size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
980     for (size_t output_index = 0; output_index < output_num; ++output_index) {
981       (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
982       (void)outputs_device_type.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
983       (void)outputs_shape.emplace_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
984     }
985   }
986   builder.SetInputsFormat(inputs_device_format);
987   builder.SetOutputsFormat(outputs_device_format);
988   builder.SetInputsDeviceType(inputs_device_type);
989   builder.SetOutputsDeviceType(outputs_device_type);
990   return builder.Build();
991 }
992 
GetNodeOutputUsedNum(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)993 std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
994   MS_EXCEPTION_IF_NULL(node);
995   auto manager = kernel_graph.manager();
996   MS_EXCEPTION_IF_NULL(manager);
997   auto output_num = AnfAlgo::GetOutputTensorNum(node);
998   std::vector<int64_t> output_used_num(output_num, 0);
999   if (output_num == 1) {
1000     output_used_num[0] = SizeToLong(manager->node_users()[node].size());
1001   } else {
1002     for (auto out_getitem : manager->node_users()[node]) {
1003       MS_EXCEPTION_IF_NULL(out_getitem.first);
1004       if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
1005         continue;
1006       }
1007       auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
1008       MS_EXCEPTION_IF_NULL(out_getitem_ptr);
1009       auto getitem_input2 = out_getitem_ptr->input(kInputNodeOutputIndexInTupleGetItem);
1010       auto output_idx = LongToSize(GetValue<int64_t>(GetValueNode(getitem_input2)));
1011       output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
1012     }
1013   }
1014   return output_used_num;
1015 }
1016 
GetNodeOutputTotalUsedNum(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)1017 int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
1018   auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
1019   return std::accumulate(output_used_num.begin(), output_used_num.end(), int64_t(0));
1020 }
1021 }  // namespace opt
1022 }  // namespace mindspore
1023