1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/session/anf_runtime_algorithm.h"
17 #include <memory>
18 #include <algorithm>
19 #include <map>
20 #include <set>
21 #include <functional>
22 #include <numeric>
23 #include "ir/anf.h"
24 #include "ir/func_graph.h"
25 #include "base/core_ops.h"
26 #include "utils/utils.h"
27 #include "utils/shape_utils.h"
28 #include "runtime/device/kernel_info.h"
29 #include "runtime/device/device_address.h"
30 #include "backend/optimizer/common/helper.h"
31 #include "backend/kernel_compiler/kernel.h"
32 #include "backend/kernel_compiler/kernel_build_info.h"
33 #include "common/trans.h"
34 #include "abstract/param_validator.h"
35 #include "pipeline/jit/static_analysis/static_analysis.h"
36 #include "utils/trace_base.h"
37 #include "ir/anf_utils.h"
38 
39 namespace mindspore {
40 namespace session {
41 using abstract::AbstractTensor;
42 using abstract::AbstractTuple;
43 using device::KernelInfo;
44 using device::ascend::AscendDeviceAddress;
45 using kernel::KernelBuildInfoPtr;
46 using kernel::KernelMod;
47 using kernel::KernelModPtr;
48 namespace {
49 constexpr size_t kNopNodeInputSize = 2;
50 constexpr size_t kNopNodeRealInputIndex = 1;
51 constexpr size_t kReturnDataIndex = 1;
52 
53 const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
54 
55 bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
56   PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
57   return (prim && prim_set.find(prim) != prim_set.end());
58 }
59 
60 bool IsRealKernelCNode(const CNodePtr &cnode) {
61 #ifndef ENABLE_SECURITY
62   static const PrimitiveSet virtual_prims = {
63     prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
64     prim::kPrimMakeTuple,    prim::kPrimStateSetItem,  prim::kPrimTupleGetItem,  prim::kPrimReturn,
65     prim::kPrimPartial,      prim::kPrimDepend,        prim::kPrimUpdateState,   prim::kPrimLoad};
66 #else
67   static const PrimitiveSet virtual_prims = {prim::kPrimMakeTuple,   prim::kPrimStateSetItem, prim::kPrimTupleGetItem,
68                                              prim::kPrimReturn,      prim::kPrimPartial,      prim::kPrimDepend,
69                                              prim::kPrimUpdateState, prim::kPrimLoad};
70 #endif
71   MS_EXCEPTION_IF_NULL(cnode);
72   if (cnode->inputs().empty()) {
73     MS_LOG(EXCEPTION) << "Illegal null input of cnode: " << cnode->DebugString();
74   }
75   const auto &input = cnode->inputs().at(0);
76   bool is_virtual_node = IsOneOfPrimitive(input, virtual_prims);
77   return !is_virtual_node;
78 }
79 
80 std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
81   MS_EXCEPTION_IF_NULL(shape);
82   std::vector<size_t> shape_size_t;
83   if (AnfUtils::IsShapeDynamic(shape)) {
84     if (std::all_of(shape->max_shape().begin(), shape->max_shape().end(), [](int64_t s) { return s >= 0; })) {
85       std::transform(shape->max_shape().begin(), shape->max_shape().end(), std::back_inserter(shape_size_t),
86                      LongToSize);
87     } else {
88       MS_LOG(EXCEPTION) << "Invalid Max Shape";
89     }
90   } else {
91     std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), LongToSize);
92   }
93   return shape_size_t;
94 }
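// Illustrative sketch (hypothetical shapes): for a dynamic shape {-1, 3, 224, 224} whose max_shape is
// {8, 3, 224, 224}, TransShapeToSizet falls back to the max_shape and returns {8, 3, 224, 224}; if any
// max_shape entry were negative, it would raise "Invalid Max Shape" instead.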
95 
96 enum class ShapeType { kMaxShape, kMinShape };
97 
98 void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
99                               std::vector<session::KernelWithIndex> *inputs) {
100   MS_EXCEPTION_IF_NULL(node);
101   if (node->isa<ValueNode>() || node->isa<Parameter>()) {
102     return inputs->push_back(std::make_pair(node, 0));
103   }
104 
105   // Skip control node
106   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) ||
107       AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
108     return GetRealOutputRecursively(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
109   }
110 
111   // Bypass TupleGetItem
112   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
113     auto tuple_get_item = node->cast<CNodePtr>();
114     MS_EXCEPTION_IF_NULL(tuple_get_item);
115     auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item);
116     auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item);
117 
118     // Conceal MakeTuple + TupleGetItem pair.
119     if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
120       auto make_tuple = input->cast<CNodePtr>();
121       MS_EXCEPTION_IF_NULL(make_tuple);
122       auto real_input = AnfAlgo::GetInputNode(make_tuple, index);
123       return GetRealOutputRecursively(real_input, 0, inputs);
124     }
125 
126     // Skip TupleGetItem.
127     return GetRealOutputRecursively(input, index, inputs);
128   }
129 
130   // Flatten MakeTuple inputs.
131   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
132     auto make_tuple = node->cast<CNodePtr>();
133     MS_EXCEPTION_IF_NULL(make_tuple);
134     size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple);
135     for (size_t input_index = 0; input_index < input_num; ++input_index) {
136       auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index);
137       GetRealOutputRecursively(input_node, 0, inputs);
138     }
139     return;
140   }
141 
142   return inputs->push_back(std::make_pair(node, output_index));
143 }
144 
145 // Ops whose dynamic-shape input order differs from that of the fixed-shape ops.
146 static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = {
147   {prim::kPrimConv2DBackpropInput->name(), {{{0, 2}, {1, 1}, {2, 0}}, {{0, 2}, {1, 1}, {2, 0}}}},
148   {prim::kPrimConv2DBackpropFilter->name(), {{{0, 1}, {1, 2}, {2, 0}}, {{1, 0}, {2, 1}, {0, 2}}}}};
149 
150 // pair: ms input order to tbe input order, and tbe input order to ms input order
151 static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_node_list = {
152   {prim::kPrimConv2DBackpropInput->name(), {{{0, 1}, {1, 0}}, {{0, 1}, {1, 0}}}},
153   {kFusionOpConv2DBackpropInputReluGradV2Name, {{{0, 1}, {1, 0}, {2, 2}}, {{0, 1}, {1, 0}, {2, 2}}}},
154   {kFusionOpConv2DBackpropInputAddNReluGradV2Name,
155    {{{0, 1}, {1, 0}, {2, 2}, {3, 3}}, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}}},
156   {prim::kPrimConv2DBackpropFilter->name(), {{{0, 1}, {1, 0}}, {{0, 1}, {1, 0}}}},
157   {prim::kPrimLogSoftmaxGrad->name(), {{{0, 1}, {1, 0}}, {{0, 1}, {1, 0}}}},
158   {prim::kPrimLayerNormGrad->name(),
159    {{{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}, {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}},
160   {prim::kPrimLayerNormBetaGammaBackprop->name(), {{{0, 1}, {1, 0}, {2, 2}, {3, 3}}, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}}},
161   {prim::kPrimLayerNormXBackprop->name(),
162    {{{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}, {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}},
163   {prim::kPrimLayerNormXBackpropV2->name(),
164    {{{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}, {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}},
165   {prim::kPrimMinimumGrad->name(), {{{0, 2}, {1, 0}, {2, 1}}, {{2, 0}, {0, 1}, {1, 2}}}},
166   {prim::kPrimMaximumGrad->name(), {{{0, 2}, {1, 0}, {2, 1}}, {{2, 0}, {0, 1}, {1, 2}}}},
167   {prim::kPrimApplyCenteredRMSProp->name(),
168    {{{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 4}},
169     {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {5, 4}, {6, 5}, {7, 6}, {8, 7}, {4, 8}}}}};
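// Reading example for the index maps above (illustrative): for MinimumGrad the first map sends
// MindSpore input positions to TBE positions (0->2, 1->0, 2->1) and the second map is its inverse
// (2->0, 0->1, 1->2), matching the "ms input order to tbe input order" convention noted above.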
170 }  // namespace
171 
172 AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) {
173   return kg->NewValueNode(kUMonad->ToAbstract(), kUMonad);
174 }
175 
176 // Convert: a = former(xxx)
177 //          b = latter(x, xxx)
178 // To:      a = former(xxx)
179 //          d1 = Depend(x, a)
180 //          b = latter(d1, xxx)
181 //          ...
182 //          out = Depend(out, latter)
183 void AnfRuntimeAlgorithm::KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter) {
184   MS_EXCEPTION_IF_NULL(kg);
185   MS_EXCEPTION_IF_NULL(latter);
186   if (latter->isa<CNode>()) {
187     auto latter_cnode = latter->cast<CNodePtr>();
188     MS_EXCEPTION_IF_NULL(latter_cnode);
189     constexpr size_t inputsize = 2;
190     constexpr size_t kFirstDataInputIndex = 1;
191     if (latter_cnode->inputs().size() < inputsize) {
192       return;
193     }
194     auto latter_input = latter_cnode->input(kFirstDataInputIndex);
195     auto depend1 = kg->NewCNode({NewValueNode(prim::kPrimDepend), latter_input, former});
196     MS_EXCEPTION_IF_NULL(depend1);
197     depend1->set_abstract(latter_input->abstract());
198     latter_cnode->set_input(kFirstDataInputIndex, depend1);
199 
200     auto return_node = kg->get_return();
201     MS_EXCEPTION_IF_NULL(return_node);
202     auto depend2 = kg->NewCNode(
203       {NewValueNode(prim::kPrimDepend), return_node->cast<CNodePtr>()->input(kFirstDataInputIndex), latter});
204     MS_EXCEPTION_IF_NULL(depend2);
205     depend2->set_abstract(return_node->cast<CNodePtr>()->input(kFirstDataInputIndex)->abstract());
206     kg->set_output(depend2);
207     MS_LOG(DEBUG) << "former: " << former->DebugString() << ", latter: " << latter->DebugString()
208                   << ", depend1: " << depend1->DebugString() << ", depend2: " << depend2->DebugString();
209   }
210 }
211 
212 AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
213   MS_EXCEPTION_IF_NULL(tuple_get_item);
214   if (tuple_get_item->size() != kTupleGetItemInputSize) {
215     MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
216   }
217   return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
218 }
219 
220 size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
221   MS_EXCEPTION_IF_NULL(tuple_get_item);
222   if (tuple_get_item->size() != kTupleGetItemInputSize) {
223     MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
224   }
225   auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
226   MS_EXCEPTION_IF_NULL(output_index_value_node);
227   auto value_node = output_index_value_node->cast<ValueNodePtr>();
228   MS_EXCEPTION_IF_NULL(value_node);
229   return LongToSize(GetValue<int64_t>(value_node->value()));
230 }
231 
232 KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
233   MS_EXCEPTION_IF_NULL(anf_node);
234   if (anf_node->isa<ValueNode>()) {
235     return std::make_pair(anf_node, 0);
236   } else if (anf_node->isa<Parameter>()) {
237     return std::make_pair(anf_node, 0);
238   } else if (anf_node->isa<CNode>()) {
239     auto cnode = anf_node->cast<CNodePtr>();
240     MS_EXCEPTION_IF_NULL(cnode);
241     auto input0 = cnode->input(0);
242     MS_EXCEPTION_IF_NULL(input0);
243     if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
244       if (AnfAlgo::GetInputTensorNum(cnode) == 0) {
245         return std::make_pair(nullptr, 0);
246       }
247       auto node = cnode->input(index + IntToSize(1));
248       MS_EXCEPTION_IF_NULL(node);
249       return VisitKernel(node, 0);
250     } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
251       if (cnode->inputs().size() != kTupleGetItemInputSize) {
252         MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
253       }
254       auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
255       MS_EXCEPTION_IF_NULL(input2);
256       auto value_node = input2->cast<ValueNodePtr>();
257       MS_EXCEPTION_IF_NULL(value_node);
258       auto item_idx = GetValue<int64_t>(value_node->value());
259       return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), LongToSize(item_idx));
260     } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
261       return VisitKernel(cnode->input(kUpdateStateRealInput), 0);
262     } else if (IsOneOfPrimitive(input0, follow_first_input_prims)) {
263       return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
264     } else {
265       return std::make_pair(anf_node, index);
266     }
267   } else {
268     MS_LOG(EXCEPTION) << "The input is invalid";
269   }
270 }
271 
272 KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
273                                                                bool skip_nop_node,
274                                                                const std::vector<PrimitivePtr> &return_types) {
275   MS_EXCEPTION_IF_NULL(anf_node);
276   if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
277         return CheckPrimitiveType(anf_node, prim_type);
278       })) {
279     return KernelWithIndex(anf_node, index);
280   }
281   if (!anf_node->isa<CNode>()) {
282     return KernelWithIndex(anf_node, 0);
283   }
284   auto cnode = anf_node->cast<CNodePtr>();
285   MS_EXCEPTION_IF_NULL(cnode);
286   if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
287     auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode),
288                                                          GetTupleGetItemOutIndex(cnode), skip_nop_node, return_types);
289     if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
290       MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
291       auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
292       MS_EXCEPTION_IF_NULL(make_tuple);
293       const std::vector<AnfNodePtr> &make_tuple_inputs = make_tuple->inputs();
294       size_t make_tuple_input_index = item_with_index_tmp.second + 1;
295       if (make_tuple_input_index >= make_tuple_inputs.size()) {
296         MS_LOG(EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size()
297                           << "].";
298       }
299       return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, skip_nop_node, return_types);
300     }
301     return item_with_index_tmp;
302   }
303   if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
304     return VisitKernelWithReturnType(cnode->input(kUpdateStateStateInput), index, skip_nop_node, return_types);
305   }
306   if (AnfAlgo::IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
307     return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, skip_nop_node, return_types);
308   }
309   if (opt::IsNopNode(cnode) && skip_nop_node) {
310     if (cnode->size() != kNopNodeInputSize) {
311       MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString() << " trace: " << trace::DumpSourceLines(cnode);
312     }
313     return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, skip_nop_node, return_types);
314   }
315   return KernelWithIndex(anf_node, index);
316 }
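// Illustrative sketch (hypothetical nodes %a, %b): given
//   %t = MakeTuple(%a, %b)
//   %x = TupleGetItem(%t, 1)
// VisitKernelWithReturnType(%x, 0) first resolves the TupleGetItem to {%t, 1}, detects the MakeTuple,
// and recurses into its second element, so the result is {%b, 0}. Visiting %t directly with
// prim::kPrimMakeTuple in return_types returns {%t, 0} immediately, which is what GetAllOutput below
// relies on to expand tuples itself.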
317 
318 std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
319                                                           const std::vector<PrimitivePtr> &return_types) {
320   std::vector<AnfNodePtr> ret;
321   auto return_prim_type = return_types;
322   // If a MakeTuple node is visited, it should be returned directly so it can be expanded here.
323   return_prim_type.push_back(prim::kPrimMakeTuple);
324   auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
325   if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
326     MS_EXCEPTION_IF_NULL(item_with_index.first);
327     auto make_tuple = item_with_index.first->cast<CNodePtr>();
328     MS_EXCEPTION_IF_NULL(make_tuple);
329     for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
330       auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types);
331       (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
332     }
333     return ret;
334   }
335   ret.push_back(item_with_index.first);
336   return ret;
337 }
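// Illustrative sketch (hypothetical nodes): for an output node MakeTuple(%a, MakeTuple(%b, %c)),
// GetAllOutput flattens the nested MakeTuple nodes and returns {%a, %b, %c}.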
338 
339 std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) {
340   std::vector<KernelWithIndex> ret;
341   std::vector<KernelWithIndex> ret_empty;
342 
343   // The MakeTuple node needs to be expanded and handled recursively.
344   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
345     auto make_tuple = node->cast<CNodePtr>();
346     MS_EXCEPTION_IF_NULL(make_tuple);
347     for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
348       auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i));
349       (void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
350     }
351     return ret;
352   }
353 
354   // The Depend node needs to be resolved to its real input node.
355   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
356     auto depend_node = node->cast<CNodePtr>();
357     MS_EXCEPTION_IF_NULL(depend_node);
358     auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend));
359     (void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
360     return ret;
361   }
362 
363   // A value node needs all of its elements collected.
364   if (node->isa<ValueNode>()) {
365     auto value = node->cast<ValueNodePtr>()->value();
366     MS_EXCEPTION_IF_NULL(value);
367     if (value->isa<None>()) {
368       return ret;
369     } else if (value->isa<ValueTuple>()) {
370       auto value_tuple = value->cast<ValueTuplePtr>();
371       auto value_tuple_size = CountValueNum(value_tuple);
372       for (size_t i = 0; i < value_tuple_size; ++i) {
373         (void)ret.emplace_back(node, i);
374       }
375     } else {
376       (void)ret.emplace_back(node, 0);
377     }
378     return ret;
379   }
380 
381   const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimMakeTuple};
382   size_t outputs_num = 1;
383   if (IsRealCNodeKernel(node)) {
384     outputs_num = AnfAlgo::GetOutputTensorNum(node);
385   }
386   // The output may be a tuple, so all the outputs of the node need to be visited.
387   for (size_t i = 0; i < outputs_num; ++i) {
388     auto output_with_index = AnfAlgo::VisitKernelWithReturnType(node, i, false, return_types);
389     MS_EXCEPTION_IF_NULL(output_with_index.first);
390 
391     // The Depend and MakeTuple nodes need to be handled recursively.
392     if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimDepend) ||
393         AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
394       auto output_vector = GetAllOutputWithIndex(output_with_index.first);
395       (void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
396       continue;
397     }
398 
399     // Ignore the output of front call node.
400     if (output_with_index.first->isa<CNode>()) {
401       auto cnode = output_with_index.first->cast<CNodePtr>();
402       MS_EXCEPTION_IF_NULL(cnode);
403       auto inputs = cnode->inputs();
404       if (inputs[0]->isa<CNode>()) {
405         MS_LOG(INFO) << "The output is call node: " << output_with_index.first->DebugString();
406         return ret_empty;
407       }
408     }
409 
410     // The InitDataSetQueue node has no output.
411     if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimInitDataSetQueue)) {
412       return ret_empty;
413     }
414 
415     MS_LOG(INFO) << "Output node: " << output_with_index.first->fullname_with_scope()
416                  << " with output index: " << output_with_index.second;
417     ret.push_back(output_with_index);
418   }
419 
420   return ret;
421 }
422 
423 AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) {
424   MS_EXCEPTION_IF_NULL(node);
425   return node->input(kAnfPrimitiveIndex);
426 }
427 
428 PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) {
429   MS_EXCEPTION_IF_NULL(node);
430   auto cnode = node->cast<CNodePtr>();
431   MS_EXCEPTION_IF_NULL(cnode);
432   auto attr_input = GetCNodePrimitiveNode(cnode);
433   MS_EXCEPTION_IF_NULL(attr_input);
434   auto value_node = attr_input->cast<ValueNodePtr>();
435   MS_EXCEPTION_IF_NULL(value_node);
436   auto value = value_node->value();
437   MS_EXCEPTION_IF_NULL(value);
438   auto primitive = value->cast<PrimitivePtr>();
439   return primitive;
440 }
441 
442 bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
443   MS_EXCEPTION_IF_NULL(node);
444   if (!node->isa<CNode>()) {
445     return false;
446   }
447   auto cnode = node->cast<CNodePtr>();
448   MS_EXCEPTION_IF_NULL(cnode);
449   return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
450 }
451 
452 FuncGraphPtr AnfRuntimeAlgorithm::GetCNodeFuncGraphPtr(const AnfNodePtr &node) {
453   MS_EXCEPTION_IF_NULL(node);
454   auto cnode = node->cast<CNodePtr>();
455   MS_EXCEPTION_IF_NULL(cnode);
456   auto attr_input = cnode->input(kAnfPrimitiveIndex);
457   MS_EXCEPTION_IF_NULL(attr_input);
458   auto value_node = attr_input->cast<ValueNodePtr>();
459   MS_EXCEPTION_IF_NULL(value_node);
460   auto value = value_node->value();
461   MS_EXCEPTION_IF_NULL(value);
462   return value->cast<FuncGraphPtr>();
463 }
464 
465 std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
466   MS_EXCEPTION_IF_NULL(node);
467   if (node->isa<CNode>()) {
468     auto primitive = AnfAlgo::GetCNodePrimitive(node);
469     if (primitive != nullptr) {
470       return primitive->name();
471     }
472     auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
473     MS_EXCEPTION_IF_NULL(func_graph);
474     if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
475       std::string fg_name = "GraphKernel_";
476       fg_name += GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
477       return fg_name;
478     }
479     return func_graph->ToString();
480   }
481   MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
482 }
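// Illustrative sketch: for a cnode whose attr input is a func_graph carrying the
// FUNC_GRAPH_ATTR_GRAPH_KERNEL attribute (say "Square"), GetCNodeName returns "GraphKernel_Square";
// for an ordinary primitive cnode it simply returns the primitive name.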
483 
484 std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
485   MS_EXCEPTION_IF_NULL(node);
486   return node->DebugString();
487 }
488 
489 void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
490   MS_EXCEPTION_IF_NULL(node);
491   if (!node->isa<CNode>()) {
492     MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString()
493                       << " trace: " << trace::DumpSourceLines(node);
494   }
495   // single op cnode.
496   auto primitive = AnfAlgo::GetCNodePrimitive(node);
497   if (primitive != nullptr) {
498     primitive->set_attr(key, value);
499     return;
500   }
501   // graph kernel cnode.
502   auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
503   MS_EXCEPTION_IF_NULL(fg);
504   fg->set_attr(key, value);
505 }
506 
507 void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
508   CopyNodeAttr(key, key, from, to);
509 }
510 
511 void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
512                                        const AnfNodePtr &to) {
513   MS_EXCEPTION_IF_NULL(from);
514   MS_EXCEPTION_IF_NULL(to);
515   if (!from->isa<CNode>() || !to->isa<CNode>()) {
516     MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ", to_node is "
517                       << to->DebugString() << " trace: " << trace::DumpSourceLines(from);
518   }
519   auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
520   MS_EXCEPTION_IF_NULL(from_primitive);
521   auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
522   MS_EXCEPTION_IF_NULL(to_primitive);
523   to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key));
524 }
525 
526 void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
527   MS_EXCEPTION_IF_NULL(from);
528   MS_EXCEPTION_IF_NULL(to);
529   if (!from->isa<CNode>() || !to->isa<CNode>()) {
530     MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ", to_node is "
531                       << to->DebugString() << " trace: " << trace::DumpSourceLines(from);
532   }
533   auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
534   MS_EXCEPTION_IF_NULL(from_primitive);
535   auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
536   MS_EXCEPTION_IF_NULL(to_primitive);
537   (void)to_primitive->SetAttrs(from_primitive->attrs());
538 }
539 
540 void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) {
541   MS_EXCEPTION_IF_NULL(node);
542   if (!node->isa<CNode>()) {
543     MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString()
544                       << " trace: " << trace::DumpSourceLines(node);
545   }
546   // single op cnode.
547   auto primitive = AnfAlgo::GetCNodePrimitive(node);
548   if (primitive != nullptr) {
549     primitive->EraseAttr(key);
550     return;
551   }
552   // graph kernel cnode.
553   auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
554   MS_EXCEPTION_IF_NULL(fg);
555   fg->erase_flag(key);
556 }
557 
558 bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) {
559   MS_EXCEPTION_IF_NULL(node);
560   if (!node->isa<CNode>()) {
561     MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString();
562     return false;
563   }
564   // single op cnode.
565   auto primitive = AnfAlgo::GetCNodePrimitive(node);
566   if (primitive != nullptr) {
567     return primitive->HasAttr(key);
568   }
569   // graph kernel cnode.
570   auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
571   MS_EXCEPTION_IF_NULL(fg);
572   return fg->has_attr(key);
573 }
574 
575 size_t AnfRuntimeAlgorithm::GetInputNum(const CNodePtr &cnode) {
576   MS_EXCEPTION_IF_NULL(cnode);
577   size_t input_num = cnode->size();
578   if (input_num == 0) {
579     MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero";
580   }
581   return input_num - 1;
582 }
583 
584 size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
585   MS_EXCEPTION_IF_NULL(node);
586   auto cnode = node->cast<CNodePtr>();
587   if (cnode == nullptr) {
588     MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString()
589                       << " trace: " << trace::DumpSourceLines(node);
590   }
591   ssize_t input_tensor_num = cnode->input_tensor_num();
592   if (input_tensor_num >= 0) {
593     return static_cast<size_t>(input_tensor_num);
594   }
595   size_t input_num = cnode->inputs().size();
596   if (input_num == 0) {
597     MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"
598                       << " trace: " << trace::DumpSourceLines(node);
599   }
600   // Exclude inputs[0].
601   --input_num;
602 
603   // Exclude monad inputs for real cnodes.
604   if (input_num > 0 && IsRealKernelCNode(cnode)) {
605     auto &inputs = cnode->inputs();
606     // Search monad inputs, backward.
607     for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
608       if (!HasAbstractMonad(*iter)) {
609         // Stop count if we encounter a non-monad input.
610         break;
611       }
612       --input_num;
613     }
614   }
615   cnode->set_input_tensor_num(static_cast<ssize_t>(input_num));
616   return input_num;
617 }
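// Illustrative sketch (hypothetical cnode): for a real kernel cnode with inputs
// (Primitive, %x, %y, UMonad), GetInputTensorNum skips inputs[0] and the trailing monad input and
// returns 2; the result is cached through set_input_tensor_num for subsequent calls.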
618 
619 size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
620   MS_EXCEPTION_IF_NULL(node);
621   TypePtr type = node->Type();
622   if (type == nullptr) {
623     return 0;
624   }
625   if (type->isa<Tuple>()) {
626     auto tuple_type = type->cast<TuplePtr>();
627     MS_EXCEPTION_IF_NULL(tuple_type);
628     return tuple_type->size();
629   }
630   if (type->isa<TypeNone>()) {
631     return 0;
632   }
633   return 1;
634 }
635 
636 size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index) {
637   MS_EXCEPTION_IF_NULL(node);
638   if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
639     MS_EXCEPTION(ArgumentError) << "Output index [" << output_index << "] is larger than the output size ["
640                                 << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
641   }
642   TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
643   if (output_type_id == kTypeUnknown) {
644     output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
645   }
646   size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
647   std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
648   auto format = AnfAlgo::GetOutputFormat(node, output_index);
649   if (shape.empty() && format != kOpFormat_DEFAULT) {
650     shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index));
651     shape = trans::TransShapeToDevice(shape, format, node, output_index);
652   }
653   // A scalar's output shape is an empty vector.
654   size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
655   return tensor_size;
656 }
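// Illustrative sketch (hypothetical values): a float16 output with device shape {32, 4, 224, 224, 16}
// occupies 2 * 32 * 4 * 224 * 224 * 16 bytes, since the byte size of the type seeds the accumulate call.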
657 
658 std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) {
659   MS_EXCEPTION_IF_NULL(node);
660   if (!AnfAlgo::IsRealKernel(node)) {
661     MS_LOG(EXCEPTION) << "Not real kernel:"
662                       << "#node [" << node->DebugString() << "]"
663                       << " trace: " << trace::DumpSourceLines(node);
664   }
665   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
666   MS_EXCEPTION_IF_NULL(kernel_info);
667   auto build_info = kernel_info->select_kernel_build_info();
668   MS_EXCEPTION_IF_NULL(build_info);
669   auto format = build_info->GetAllOutputFormats();
670   return format;
671 }
672 
673 std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePtr &node) {
674   MS_EXCEPTION_IF_NULL(node);
675   if (!AnfAlgo::IsRealKernel(node)) {
676     MS_LOG(EXCEPTION) << "Not real kernel:"
677                       << "#node [" << node->DebugString() << "]"
678                       << " trace: " << trace::DumpSourceLines(node);
679   }
680   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
681   MS_EXCEPTION_IF_NULL(kernel_info);
682   auto build_info = kernel_info->select_kernel_build_info();
683   MS_EXCEPTION_IF_NULL(build_info);
684   auto format = build_info->GetAllInputFormats();
685   return format;
686 }
687 
688 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllInputDeviceTypes(const AnfNodePtr &node) {
689   MS_EXCEPTION_IF_NULL(node);
690   if (!AnfAlgo::IsRealKernel(node)) {
691     MS_LOG(EXCEPTION) << "Not real kernel:"
692                       << "#node [" << node->DebugString() << "]"
693                       << " trace: " << trace::DumpSourceLines(node);
694   }
695   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
696   MS_EXCEPTION_IF_NULL(kernel_info);
697   auto build_info = kernel_info->select_kernel_build_info();
698   MS_EXCEPTION_IF_NULL(build_info);
699   auto types = build_info->GetAllInputDeviceTypes();
700   return types;
701 }
702 
703 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputDeviceTypes(const AnfNodePtr &node) {
704   MS_EXCEPTION_IF_NULL(node);
705   if (!AnfAlgo::IsRealKernel(node)) {
706     MS_LOG(EXCEPTION) << "Not real kernel:"
707                       << "#node [" << node->DebugString() << "]"
708                       << " trace: " << trace::DumpSourceLines(node);
709   }
710   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
711   MS_EXCEPTION_IF_NULL(kernel_info);
712   auto build_info = kernel_info->select_kernel_build_info();
713   MS_EXCEPTION_IF_NULL(build_info);
714   auto types = build_info->GetAllOutputDeviceTypes();
715   return types;
716 }
717 
718 std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
719   MS_EXCEPTION_IF_NULL(node);
720   if (!AnfAlgo::IsRealKernel(node)) {
721     MS_LOG(EXCEPTION) << "Not real kernel:"
722                       << "#node [" << node->DebugString() << "]"
723                       << " trace: " << trace::DumpSourceLines(node);
724   }
725   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
726   MS_EXCEPTION_IF_NULL(kernel_info);
727   auto build_info = kernel_info->select_kernel_build_info();
728   MS_EXCEPTION_IF_NULL(build_info);
729   auto format = build_info->GetOriginDataFormat();
730   return format;
731 }
732 
733 std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
734   MS_EXCEPTION_IF_NULL(node);
735   if (output_idx > GetOutputTensorNum(node)) {
736     MS_LOG(EXCEPTION) << "Output index:" << output_idx
737                       << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
738                       << node->DebugString() << "]"
739                       << " trace: " << trace::DumpSourceLines(node);
740   }
741   if (!AnfAlgo::IsRealKernel(node)) {
742     return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
743   }
744   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
745   MS_EXCEPTION_IF_NULL(kernel_info);
746   auto build_info = kernel_info->select_kernel_build_info();
747   MS_EXCEPTION_IF_NULL(build_info);
748   auto format = build_info->GetOutputFormat(output_idx);
749   if (format == kernel::KernelBuildInfo::kInvalidFormat) {
750     MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
751                       << " has an invalid output format"
752                       << " trace: " << trace::DumpSourceLines(node);
753   }
754   return format;
755 }
756 
757 std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
758   MS_EXCEPTION_IF_NULL(node);
759   if (input_idx > GetInputTensorNum(node)) {
760     MS_LOG(EXCEPTION) << "Input index: " << input_idx
761                       << " is out of the node's input range: " << GetInputTensorNum(node) << " #node ["
762                       << node->DebugString() << "]"
763                       << " trace: " << trace::DumpSourceLines(node);
764   }
765   if (!IsRealKernel(node)) {
766     return GetPrevNodeOutputFormat(node, input_idx);
767   }
768   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
769   MS_EXCEPTION_IF_NULL(kernel_info);
770   auto build_info = kernel_info->select_kernel_build_info();
771   MS_EXCEPTION_IF_NULL(build_info);
772   auto format = build_info->GetInputFormat(input_idx);
773   if (format == kernel::KernelBuildInfo::kInvalidFormat) {
774     MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
775                       << " has an invalid input format"
776                       << " trace: " << trace::DumpSourceLines(node);
777   }
778   return format;
779 }
780 
781 KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx,
782                                                        bool visit_nop_node) {
783   MS_EXCEPTION_IF_NULL(anf_node);
784   if (!anf_node->isa<CNode>()) {
785     MS_LOG(EXCEPTION) << anf_node->DebugString() << " is not a CNode."
786                       << " trace: " << trace::DumpSourceLines(anf_node);
787   }
788   if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
789     return VisitKernelWithReturnType(anf_node, 0, visit_nop_node);
790   }
791   auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
792   MS_EXCEPTION_IF_NULL(input_node);
793   return VisitKernelWithReturnType(input_node, 0, visit_nop_node);
794 }
795 
796 std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
797   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
798   return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
799 }
800 
801 std::string AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
802   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
803   return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
804 }
805 
806 std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node,
807                                                              const abstract::BaseShapePtr &base_shape,
808                                                              size_t output_idx) {
809   MS_EXCEPTION_IF_NULL(node);
810   MS_EXCEPTION_IF_NULL(base_shape);
811   if (base_shape->isa<abstract::Shape>()) {
812     if (output_idx == 0) {
813       return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
814     }
815     MS_LOG(EXCEPTION) << "The node " << node->DebugString() << " is a single output node but got index [" << output_idx
816                       << "]."
817                       << " trace: " << trace::DumpSourceLines(node);
818   } else if (base_shape->isa<abstract::TupleShape>()) {
819     auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
820     MS_EXCEPTION_IF_NULL(tuple_shape);
821     if (output_idx >= tuple_shape->size()) {
822       MS_LOG(EXCEPTION) << "Output index " << output_idx << " is larger than output number " << tuple_shape->size()
823                         << " node:" << node->DebugString() << "."
824                         << " trace: " << trace::DumpSourceLines(node);
825     }
826     auto b_shp = (*tuple_shape)[output_idx];
827     if (b_shp->isa<abstract::Shape>()) {
828       return TransShapeToSizet(b_shp->cast<abstract::ShapePtr>());
829     } else if (b_shp->isa<abstract::NoShape>()) {
830       return std::vector<size_t>();
831     } else {
832       MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
833                         << " should be a NoShape, ArrayShape or TupleShape, but it is " << base_shape->ToString()
834                         << " node: " << node->DebugString() << "."
835                         << " trace: " << trace::DumpSourceLines(node);
836     }
837   } else if (base_shape->isa<abstract::NoShape>()) {
838     return std::vector<size_t>();
839   }
840   MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape, ArrayShape or TupleShape, but it is "
841                     << base_shape->ToString() << " node : " << node->DebugString()
842                     << " trace: " << trace::DumpSourceLines(node);
843 }
844 
845 std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
846   MS_EXCEPTION_IF_NULL(node);
847   return GetOutputInferShape(node, node->Shape(), output_idx);
848 }
849 
850 std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
851   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
852   return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
853 }
854 
855 std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node,
856                                                                           const size_t output_idx,
857                                                                           const std::string &format) {
858   auto output_shape = GetOutputDetailShape(node, output_idx);
859   std::vector<int64_t> infer_shape;
860   if (output_shape->isa<abstract::Shape>()) {
861     auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
862     MS_EXCEPTION_IF_NULL(shape_ptr);
863     infer_shape = shape_ptr->shape();
864   }
865   if (infer_shape.empty()) {
866     return infer_shape;
867   }
868 
869   // If the format is default_format or NC1KHKWHWC0, device shape = original shape.
870   if (trans::IsNeedPadding(format, infer_shape.size())) {
871     infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx));
872   }
873   return trans::TransShapeToDevice(infer_shape, format, node, output_idx);
874 }
875 
876 std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
877   auto format = GetOutputFormat(node, output_idx);
878   auto infer_shape = GetOutputInferShape(node, output_idx);
879   if (infer_shape.empty()) {
880     return infer_shape;
881   }
882   // If the format is default_format or NC1KHKWHWC0, device shape = original shape.
883   if (trans::IsNeedPadding(format, infer_shape.size())) {
884     infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx));
885   }
886   return trans::TransShapeToDevice(infer_shape, format, node, output_idx);
887 }
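// Illustrative sketch (assuming the usual Ascend 5D layout with C0 = 16): an inferred NCHW shape
// {32, 64, 224, 224} whose selected output format is NC1HWC0 is transformed by
// trans::TransShapeToDevice into {32, 64 / 16, 224, 224, 16} = {32, 4, 224, 224, 16}.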
888 
889 std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
890   auto format = GetInputFormat(node, input_idx);
891   auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
892   if (infer_shape.empty()) {
893     return infer_shape;
894   }
895   // If the format is default_format or NC1KHKWHWC0, device shape = original shape.
896   if (trans::IsNeedPadding(format, infer_shape.size())) {
897     infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx));
898   }
899   return trans::TransShapeToDevice(infer_shape, format, node, input_idx, false);
900 }
901 
902 std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
903   MS_EXCEPTION_IF_NULL(node);
904   if (input_idx > GetInputTensorNum(node)) {
905     MS_LOG(EXCEPTION) << "The index:" << input_idx
906                       << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
907                       << node->DebugString() << "]"
908                       << " trace: " << trace::DumpSourceLines(node);
909   }
910   if (!IsRealKernel(node)) {
911     return GetPrevNodeOutputReshapeType(node, input_idx);
912   }
913   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
914   MS_EXCEPTION_IF_NULL(kernel_info);
915   auto build_info = kernel_info->select_kernel_build_info();
916   MS_EXCEPTION_IF_NULL(build_info);
917   if (build_info->IsInputDefaultPadding()) {
918     return "";
919   }
920   return build_info->GetInputReshapeType(input_idx);
921 }
922 
923 std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
924   MS_EXCEPTION_IF_NULL(node);
925   if (output_idx > GetOutputTensorNum(node)) {
926     MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
927                       << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"
928                       << " trace: " << trace::DumpSourceLines(node);
929   }
930   if (!IsRealKernel(node)) {
931     return GetPrevNodeOutputReshapeType(node, output_idx);
932   }
933   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
934   MS_EXCEPTION_IF_NULL(kernel_info);
935   auto build_info = kernel_info->select_kernel_build_info();
936   MS_EXCEPTION_IF_NULL(build_info);
937   if (build_info->IsOutputDefaultPadding()) {
938     return "";
939   }
940   return build_info->GetOutputReshapeType(output_idx);
941 }
942 
943 TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
944   auto type_ptr = type;
945   MS_EXCEPTION_IF_NULL(type_ptr);
946   if (type_ptr->isa<Tuple>()) {
947     auto tuple_ptr = type_ptr->cast<TuplePtr>();
948     MS_EXCEPTION_IF_NULL(tuple_ptr);
949     if (output_idx >= tuple_ptr->size()) {
950       MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
951     }
952     type_ptr = (*tuple_ptr)[output_idx];
953     MS_EXCEPTION_IF_NULL(type_ptr);
954   }
955 
956   if (type_ptr->isa<TensorType>()) {
957     auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
958     MS_EXCEPTION_IF_NULL(tensor_ptr);
959     TypePtr elem = tensor_ptr->element();
960     MS_EXCEPTION_IF_NULL(elem);
961     return elem->type_id();
962   }
963 
964   return type_ptr->type_id();
965 }
966 
967 TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
968   MS_EXCEPTION_IF_NULL(node);
969   return GetOutputInferDataType(node->Type(), output_idx);
970 }
971 
972 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
973   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
974   return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
975 }
976 
977 TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
978   MS_EXCEPTION_IF_NULL(node);
979   if (output_idx > GetOutputTensorNum(node)) {
980     MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
981                       << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"
982                       << " trace: " << trace::DumpSourceLines(node);
983   }
984   if (!IsRealKernel(node)) {
985     return GetPrevNodeOutputDeviceDataType(node, output_idx);
986   }
987   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
988   MS_EXCEPTION_IF_NULL(kernel_info);
989   auto build_info = kernel_info->select_kernel_build_info();
990   MS_EXCEPTION_IF_NULL(build_info);
991   auto dtype = build_info->GetOutputDeviceType(output_idx);
992   if (dtype == TypeId::kNumberTypeEnd) {
993     MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
994                       << " has an invalid dtype"
995                       << " trace: " << trace::DumpSourceLines(node);
996   }
997   return dtype;
998 }
999 
1000 TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
1001   MS_EXCEPTION_IF_NULL(node);
1002   if (input_idx > GetInputTensorNum(node)) {
1003     MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
1004                       << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"
1005                       << " trace: " << trace::DumpSourceLines(node);
1006   }
1007   if (!IsRealKernel(node)) {
1008     return GetPrevNodeOutputDeviceDataType(node, 0);
1009   }
1010   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1011   MS_EXCEPTION_IF_NULL(kernel_info);
1012   auto build_info = kernel_info->select_kernel_build_info();
1013   MS_EXCEPTION_IF_NULL(build_info);
1014   auto dtype = build_info->GetInputDeviceType(input_idx);
1015   if (dtype == TypeId::kNumberTypeEnd) {
1016     MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
1017                       << " has an invalid dtype"
1018                       << " trace: " << trace::DumpSourceLines(node);
1019   }
1020   return dtype;
1021 }
1022 
1023 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
1024   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
1025   return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
1026 }
1027 
1028 // get output device addr of anf_node
1029 const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx,
1030                                                         bool visit_nop_node) {
1031   MS_EXCEPTION_IF_NULL(node);
1032   if (opt::IsNopNode(node) && visit_nop_node) {
1033     auto cnode = node->cast<CNodePtr>();
1034     MS_EXCEPTION_IF_NULL(cnode);
1035     if (cnode->size() == kNopNodeInputSize) {
1036       return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
1037     } else {
1038       MS_LOG(EXCEPTION) << node->DebugString() << " is an invalid nop node"
1039                         << " trace: " << trace::DumpSourceLines(node);
1040     }
1041   }
1042   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1043   MS_EXCEPTION_IF_NULL(kernel_info);
1044   auto addr = kernel_info->GetOutputAddr(output_idx);
1045   if (addr == nullptr) {
1046     MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
1047                       << " output addr does not exist"
1048                       << " trace: " << trace::DumpSourceLines(node);
1049   }
1050   return addr;
1051 }
1052 
1053 DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx,
1054                                                            bool visit_nop_node) {
1055   MS_EXCEPTION_IF_NULL(node);
1056   if (opt::IsNopNode(node) && visit_nop_node) {
1057     auto cnode = node->cast<CNodePtr>();
1058     MS_EXCEPTION_IF_NULL(cnode);
1059     if (cnode->inputs().size() == kNopNodeInputSize) {
1060       return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
1061     } else {
1062       MS_LOG(EXCEPTION) << node->DebugString() << " is an invalid nop node."
1063                         << " trace: " << trace::DumpSourceLines(node);
1064     }
1065   }
1066   // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
1067   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1068   MS_EXCEPTION_IF_NULL(kernel_info);
1069   auto addr = kernel_info->GetMutableOutputAddr(output_idx);
1070   if (addr == nullptr) {
1071     MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() << " output addr does not exist"
1072                       << " trace: " << trace::DumpSourceLines(node);
1073   }
1074   return addr;
1075 }
1076 
1077 // check whether the output device addr of anf_node exists
1078 bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node) {
1079   MS_EXCEPTION_IF_NULL(node);
1080   if (opt::IsNopNode(node) && visit_nop_node) {
1081     auto cnode = node->cast<CNodePtr>();
1082     MS_EXCEPTION_IF_NULL(cnode);
1083     if (cnode->inputs().size() > 1) {
1084       auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(cnode, 0);
1085       return OutputAddrExist(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
1086     }
1087     return false;
1088   }
1089   // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
1090   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1091   MS_EXCEPTION_IF_NULL(kernel_info);
1092   return kernel_info->OutputAddrExist(output_idx);
1093 }
1094 
1095 bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx) {
1096   MS_EXCEPTION_IF_NULL(node);
1097   // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
1098   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1099   MS_EXCEPTION_IF_NULL(kernel_info);
1100   return kernel_info->WorkspaceAddrExist(output_idx);
1101 }
1102 
1103 const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
1104                                                                 bool visit_nop_node) {
1105   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
1106   return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
1107 }
1108 
1109 DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
1110                                                                    bool visit_nop_node) {
1111   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
1112   return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
1113 }
1114 
1115 // set output device addr of anf_node
1116 void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
1117   MS_EXCEPTION_IF_NULL(node);
1118   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1119   MS_EXCEPTION_IF_NULL(kernel_info);
1120   if (!kernel_info->SetOutputAddr(addr, output_idx)) {
1121     MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " set output addr " << output_idx << " failed."
1122                       << " trace: " << trace::DumpSourceLines(node);
1123   }
1124 }
1125 
1126 // set workspace device addr of anf_node
SetWorkspaceAddr(const DeviceAddressPtr & addr,size_t output_idx,AnfNode * node)1127 void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
1128   MS_EXCEPTION_IF_NULL(node);
1129   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1130   MS_EXCEPTION_IF_NULL(kernel_info);
1131   if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
1132     MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " set workspace addr " << output_idx << " failed."
1133                       << " trace: " << trace::DumpSourceLines(node);
1134   }
1135 }
1136 
1137 // get workspace device addr of anf_node
GetWorkspaceAddr(const AnfNodePtr & node,size_t output_idx)1138 DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
1139   MS_EXCEPTION_IF_NULL(node);
1140   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1141   MS_EXCEPTION_IF_NULL(kernel_info);
1142   auto addr = kernel_info->GetWorkspaceAddr(output_idx);
1143   if (addr == nullptr) {
1144     MS_LOG(EXCEPTION) << "Output index " << output_idx << " of node " << node->DebugString()
1145                       << " workspace addr does not exist"
1146                       << " trace: " << trace::DumpSourceLines(node);
1147   }
1148   return addr;
1149 }
1150 
1151 // get workspace device mutable addr of anf_node
GetMutableWorkspaceAddr(const AnfNodePtr & node,size_t index)1152 DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) {
1153   MS_EXCEPTION_IF_NULL(node);
1154   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1155   MS_EXCEPTION_IF_NULL(kernel_info);
1156   auto addr = kernel_info->GetMutableWorkspaceAddr(index);
1157   if (addr == nullptr) {
1158     MS_LOG(EXCEPTION) << "Index " << index << " of node " << node->DebugString() << " workspace addr does not exist"
1159                       << " trace: " << trace::DumpSourceLines(node);
1160   }
1161   return addr;
1162 }
1163 
GetOutputDetailShape(const AnfNodePtr & node,size_t output_idx)1164 abstract::BaseShapePtr AnfRuntimeAlgorithm::GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx) {
1165   MS_EXCEPTION_IF_NULL(node);
1166   auto base_shape = node->Shape();
1167   MS_EXCEPTION_IF_NULL(base_shape);
1168   if (base_shape->isa<abstract::Shape>()) {
1169     if (output_idx == 0) {
1170       return base_shape;
1171     }
1172     MS_LOG(EXCEPTION) << "The node " << node->DebugString() << " is a single output node but got index ["
1173                       << output_idx << "]."
1174                       << " trace: " << trace::DumpSourceLines(node);
1175   } else if (base_shape->isa<abstract::TupleShape>()) {
1176     auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
1177     MS_EXCEPTION_IF_NULL(tuple_shape);
1178     if (output_idx >= tuple_shape->size()) {
1179       MS_LOG(EXCEPTION) << "Output index " << output_idx << " is larger than output number " << tuple_shape->size()
1180                         << " node:" << node->DebugString() << "."
1181                         << " trace: " << trace::DumpSourceLines(node);
1182     }
1183     auto b_shp = (*tuple_shape)[output_idx];
1184     if (b_shp->isa<abstract::Shape>() || b_shp->isa<abstract::NoShape>()) {
1185       return b_shp;
1186     } else {
1187       MS_LOG(EXCEPTION) << "The output type of ApplyKernel at index " << output_idx
1188                         << " should be a NoShape, ArrayShape or TupleShape, but it is " << base_shape->ToString()
1189                         << ", node: " << node->DebugString() << "."
1190                         << " trace: " << trace::DumpSourceLines(node);
1191     }
1192   } else if (base_shape->isa<abstract::NoShape>()) {
1193     return base_shape;
1194   }
1195   MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape, ArrayShape or TupleShape, but it is "
1196                     << base_shape->ToString() << ", node: " << node->DebugString()
1197                     << " trace: " << trace::DumpSourceLines(node);
1198 }
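
// Illustrative sketch (assumption, with a made-up shape): for a node whose inferred shape is
// TupleShape{Shape[4, 16], Shape[4]}, GetOutputDetailShape(node, 1) returns the Shape[4] element,
// GetOutputDetailShape(node, 0) returns Shape[4, 16], and any index >= 2 triggers the exception above.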
1199 
GetPrevNodeOutputDetailShape(const AnfNodePtr & node,size_t input_idx)1200 abstract::BaseShapePtr AnfRuntimeAlgorithm::GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx) {
1201   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
1202   return AnfRuntimeAlgorithm::GetOutputDetailShape(kernel_with_index.first, kernel_with_index.second);
1203 }
1204 
1205 // set infer shapes and types of anf node
SetOutputTypeAndDetailShape(const std::vector<TypeId> & types,const std::vector<abstract::BaseShapePtr> & shapes,AnfNode * node)1206 void AnfRuntimeAlgorithm::SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
1207                                                       const std::vector<abstract::BaseShapePtr> &shapes,
1208                                                       AnfNode *node) {
1209   MS_EXCEPTION_IF_NULL(node);
1210   auto node_ptr = node->cast<AnfNodePtr>();
1211   MS_EXCEPTION_IF_NULL(node_ptr);
1212   std::string node_name = "";
1213   if (node_ptr->isa<CNode>()) {
1214     node_name = GetCNodeName(node_ptr);
1215   }
1216   if (types.size() != shapes.size()) {
1217     MS_LOG(EXCEPTION) << "Types size " << types.size() << " should be the same as shapes size " << shapes.size()
1218                       << " trace: " << trace::DumpSourceLines(node);
1219   }
1220   if (shapes.empty() && node_name != prim::kPrimMakeTuple->name()) {
1221     node->set_abstract(std::make_shared<abstract::AbstractNone>());
1222   } else if (shapes.size() == 1 && node_name != prim::kPrimMakeTuple->name()) {
1223     // single output handle
1224     auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shapes[0]);
1225     node->set_abstract(abstract);
1226   } else {
1227     // multiple output handle
1228     std::vector<AbstractBasePtr> abstract_list;
1229     for (size_t i = 0; i < types.size(); ++i) {
1230       auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shapes[i]);
1231       abstract_list.emplace_back(abstract);
1232     }
1233     auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
1234     node->set_abstract(abstract_tuple);
1235   }
1236 }
1237 
1238 // set infer shapes and types of anf node
SetOutputInferTypeAndShape(const std::vector<TypeId> & types,const std::vector<std::vector<size_t>> & shapes,AnfNode * node)1239 void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
1240                                                      const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
1241   MS_EXCEPTION_IF_NULL(node);
1242   auto node_ptr = node->cast<AnfNodePtr>();
1243   MS_EXCEPTION_IF_NULL(node_ptr);
1244   std::string node_name = "";
1245   if (node_ptr->isa<CNode>()) {
1246     node_name = GetCNodeName(node_ptr);
1247   }
1248   if (types.size() != shapes.size()) {
1249     MS_LOG(EXCEPTION) << "Types size " << types.size() << " should be the same as shapes size " << shapes.size()
1250                       << " trace: " << trace::DumpSourceLines(node);
1251   }
1252   auto abstract_ptr = node_ptr->abstract();
1253   if (shapes.empty() && node_name != prim::kPrimMakeTuple->name()) {
1254     node->set_abstract(std::make_shared<abstract::AbstractNone>());
1255   } else if (shapes.size() == 1 && node_name != prim::kPrimMakeTuple->name()) {
1256     // single output handle
1257     ShapeVector shape_int;
1258     std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToLong);
1259     abstract::AbstractTensorPtr abstract = nullptr;
1260     if (abstract_ptr != nullptr) {
1261       auto max_shape0 = GetOutputMaxShape(node_ptr, 0);
1262       auto min_shape0 = GetOutputMinShape(node_ptr, 0);
1263       abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]),
1264                                                   std::make_shared<abstract::Shape>(shape_int, min_shape0, max_shape0));
1265     } else {
1266       abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
1267     }
1268     node->set_abstract(abstract);
1269   } else {
1270     // multiple output handle
1271     std::vector<AbstractBasePtr> abstract_list;
1272     for (size_t i = 0; i < types.size(); ++i) {
1273       ShapeVector shape_int;
1274       std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToLong);
1275       abstract::AbstractTensorPtr abstract = nullptr;
1276       if (abstract_ptr != nullptr) {
1277         auto max_shape = GetOutputMaxShape(node_ptr, i);
1278         auto min_shape = GetOutputMinShape(node_ptr, i);
1279         abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[i]),
1280                                                     std::make_shared<abstract::Shape>(shape_int, min_shape, max_shape));
1281       } else {
1282         abstract =
1283           std::make_shared<AbstractTensor>(TypeIdToType(types[i]), std::make_shared<abstract::Shape>(shape_int));
1284       }
1285       abstract_list.emplace_back(abstract);
1286     }
1287     auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
1288     node->set_abstract(abstract_tuple);
1289   }
1290 }
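
// Illustrative usage sketch (assumption; `node` is a hypothetical AnfNodePtr):
//   // one float32 output of shape [2, 3] -> AbstractTensor
//   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{2, 3}}, node.get());
//   // two outputs -> AbstractTuple of AbstractTensors
//   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeInt32}, {{2, 3}, {2}}, node.get());
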
1291 // copy an abstract of a node to another node
CopyAbstract(const AnfNodePtr & from_node,AnfNode * to_node)1292 void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
1293   MS_EXCEPTION_IF_NULL(from_node);
1294   MS_EXCEPTION_IF_NULL(to_node);
1295   to_node->set_abstract(from_node->abstract());
1296 }
1297 
GetOpPattern(const AnfNodePtr & node)1298 kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
1299   MS_EXCEPTION_IF_NULL(node);
1300   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1301   MS_EXCEPTION_IF_NULL(kernel_info);
1302   // select_kernel_build_info() has checked whether return pointer is null
1303   auto build_info = kernel_info->select_kernel_build_info();
1304   MS_EXCEPTION_IF_NULL(build_info);
1305   return build_info->op_pattern();
1306 }
1307 
1308 // get KernelBuildType of node, such as ATT,RT,FWK and so on
GetKernelType(const AnfNodePtr & node)1309 KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
1310   MS_EXCEPTION_IF_NULL(node);
1311   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1312   MS_EXCEPTION_IF_NULL(kernel_info);
1313   // select_kernel_build_info() has checked whether return pointer is null
1314   auto build_info = kernel_info->select_kernel_build_info();
1315   MS_EXCEPTION_IF_NULL(build_info);
1316   return build_info->kernel_type();
1317 }
1318 
SetFusionType(const AnfNodePtr & node,const kernel::FusionType & type)1319 void AnfRuntimeAlgorithm::SetFusionType(const AnfNodePtr &node, const kernel::FusionType &type) {
1320   MS_EXCEPTION_IF_NULL(node);
1321   auto builder =
1322     std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
1323   MS_EXCEPTION_IF_NULL(builder);
1324   builder->SetFusionType(type);
1325   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
1326 }
1327 
SetOutputDataDesc(const AnfNodePtr & node,const std::vector<nlohmann::json> & desc)1328 void AnfRuntimeAlgorithm::SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc) {
1329   MS_EXCEPTION_IF_NULL(node);
1330   auto builder =
1331     std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
1332   MS_EXCEPTION_IF_NULL(builder);
1333   builder->SetOutputDataDesc(desc);
1334   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
1335 }
1336 
GetOutputDataDesc(const AnfNodePtr & node)1337 std::vector<nlohmann::json> AnfRuntimeAlgorithm::GetOutputDataDesc(const AnfNodePtr &node) {
1338   MS_EXCEPTION_IF_NULL(node);
1339   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1340   if (kernel_info == nullptr) {
1341     return {};
1342   }
1343   auto build_info = kernel_info->select_kernel_build_info();
1344   if (build_info == nullptr) {
1345     return {};
1346   }
1347   return build_info->output_data_desc();
1348 }
1349 
GetProcessor(const AnfNodePtr & node)1350 kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
1351   MS_EXCEPTION_IF_NULL(node);
1352   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1353   MS_EXCEPTION_IF_NULL(kernel_info);
1354   auto build_info = kernel_info->select_kernel_build_info();
1355   MS_EXCEPTION_IF_NULL(build_info);
1356   return build_info->processor();
1357 }
1358 
GetFusionType(const AnfNodePtr & node)1359 kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
1360   MS_EXCEPTION_IF_NULL(node);
1361   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1362   MS_EXCEPTION_IF_NULL(kernel_info);
1363   auto build_info = kernel_info->select_kernel_build_info();
1364   if (build_info == nullptr) {
1365     return kernel::FusionType::UNKNOWN_FUSION_TYPE;
1366   }
1367   return build_info->fusion_type();
1368 }
1369 
1370 // set select kernel_build_info
SetSelectKernelBuildInfo(const KernelBuildInfoPtr & select_kernel_build_info,AnfNode * node)1371 void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
1372   MS_EXCEPTION_IF_NULL(node);
1373   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1374   MS_EXCEPTION_IF_NULL(kernel_info);
1375   return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
1376 }
1377 
1378 // get select kernel_build_info
GetSelectKernelBuildInfo(const AnfNodePtr & node)1379 KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
1380   MS_EXCEPTION_IF_NULL(node);
1381   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1382   MS_EXCEPTION_IF_NULL(kernel_info);
1383   return kernel_info->GetMutableSelectKernelBuildInfo();
1384 }
1385 
1386 // get kernel mod
GetKernelMod(const AnfNodePtr & node)1387 KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
1388   MS_EXCEPTION_IF_NULL(node);
1389   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1390   MS_EXCEPTION_IF_NULL(kernel_info);
1391   return kernel_info->MutableKernelMod();
1392 }
1393 
1394 // set kernel mod
SetKernelMod(const KernelModPtr & kernel_mod,AnfNode * node)1395 void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
1396   MS_EXCEPTION_IF_NULL(node);
1397   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1398   MS_EXCEPTION_IF_NULL(kernel_info);
1399   kernel_info->set_kernel_mod(kernel_mod);
1400 }
1401 
IsRealKernel(const AnfNodePtr & node)1402 bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
1403   MS_EXCEPTION_IF_NULL(node);
1404   // parameter and value nodes are real kernels too
1405   if (!node->isa<CNode>()) {
1406     return true;
1407   }
1408   auto cnode = node->cast<CNodePtr>();
1409   MS_EXCEPTION_IF_NULL(cnode);
1410   if (cnode->inputs().empty()) {
1411     MS_LOG(EXCEPTION) << "Illegal null input of cnode: " << node->DebugString()
1412                       << " trace: " << trace::DumpSourceLines(node);
1413   }
1414   return IsRealKernelCNode(cnode);
1415 }
1416 
IsRealCNodeKernel(const AnfNodePtr & node)1417 bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
1418   MS_EXCEPTION_IF_NULL(node);
1419   // parameter and value nodes are not real cnode kernels
1420   if (!node->isa<CNode>()) {
1421     return false;
1422   }
1423   // return considered as a real node
1424   if (CheckPrimitiveType(node, prim::kPrimReturn)) {
1425     return true;
1426   }
1427   return IsRealKernel(node);
1428 }
1429 
IsGraphKernel(const AnfNodePtr & node)1430 bool AnfRuntimeAlgorithm::IsGraphKernel(const AnfNodePtr &node) {
1431   MS_EXCEPTION_IF_NULL(node);
1432   // graph kernel should be a real cnode kernel.
1433   if (!IsRealCNodeKernel(node)) {
1434     return false;
1435   }
1436 
1437   auto cnode = node->cast<CNodePtr>();
1438   MS_EXCEPTION_IF_NULL(cnode);
1439   auto input = cnode->input(kAnfPrimitiveIndex);
1440   // graph kernel should have func_graph as its first input.
1441   if (!IsValueNode<FuncGraph>(input)) {
1442     return false;
1443   }
1444 
1445   auto func_graph = GetValueNode<FuncGraphPtr>(input);
1446   MS_EXCEPTION_IF_NULL(func_graph);
1447   return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
1448 }
1449 
IsNodeInGraphKernel(const AnfNodePtr & node)1450 bool AnfRuntimeAlgorithm::IsNodeInGraphKernel(const AnfNodePtr &node) {
1451   MS_EXCEPTION_IF_NULL(node);
1452   return node->func_graph() != nullptr && node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
1453 }
1454 
GetOutputOfGraphkernel(const KernelWithIndex & kernel_with_index)1455 AnfNodePtr AnfRuntimeAlgorithm::GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index) {
1456   auto func_graph = GetCNodeFuncGraph(kernel_with_index.first);
1457   if (func_graph == nullptr) {
1458     return kernel_with_index.first;
1459   }
1460   auto output = func_graph->output();
1461   if (CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
1462     return output->cast<CNodePtr>()->input(kernel_with_index.second + 1);
1463   }
1464   return output;
1465 }
1466 
IsParameterWeight(const ParameterPtr & node)1467 bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
1468   MS_EXCEPTION_IF_NULL(node);
1469   return node->has_default();
1470 }
1471 
IsLabelIndexInNode(const AnfNodePtr & node,size_t label_index)1472 bool AnfRuntimeAlgorithm::IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index) {
1473   MS_EXCEPTION_IF_NULL(node);
1474   if (!node->isa<CNode>()) {
1475     return false;
1476   }
1477   auto cnode = node->cast<CNodePtr>();
1478   MS_EXCEPTION_IF_NULL(cnode);
1479   if (AnfAlgo::GetCNodeName(cnode) == kLabelGotoOpName &&
1480       (AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) == label_index)) {
1481     return true;
1482   } else if (AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
1483     auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList);
1484     if (std::find(label_list.begin(), label_list.end(), label_index) != label_list.end()) {
1485       return true;
1486     }
1487   }
1488   return false;
1489 }
1490 
SetStreamId(uint32_t stream_id,AnfNode * node)1491 void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
1492   MS_EXCEPTION_IF_NULL(node);
1493   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1494   MS_EXCEPTION_IF_NULL(kernel_info);
1495   kernel_info->set_stream_id(stream_id);
1496 }
1497 
GetStreamId(const AnfNodePtr & node)1498 uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
1499   MS_EXCEPTION_IF_NULL(node);
1500   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1501   MS_EXCEPTION_IF_NULL(kernel_info);
1502   return kernel_info->stream_id();
1503 }
1504 
SetStreamDistinctionLabel(uint32_t stream_label,AnfNode * node)1505 void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
1506   MS_EXCEPTION_IF_NULL(node);
1507   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1508   MS_EXCEPTION_IF_NULL(kernel_info);
1509   kernel_info->set_stream_distinction_label(stream_label);
1510 }
1511 
GetStreamDistinctionLabel(const AnfNode * node)1512 uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
1513   MS_EXCEPTION_IF_NULL(node);
1514   auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1515   MS_EXCEPTION_IF_NULL(kernel_info);
1516   return kernel_info->stream_distinction_label();
1517 }
1518 
SetGraphId(uint32_t graph_id,AnfNode * node)1519 void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
1520   MS_EXCEPTION_IF_NULL(node);
1521   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1522   MS_EXCEPTION_IF_NULL(kernel_info);
1523   kernel_info->set_graph_id(graph_id);
1524 }
1525 
GetGraphId(const AnfNode * node)1526 uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
1527   MS_EXCEPTION_IF_NULL(node);
1528   auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1529   MS_EXCEPTION_IF_NULL(kernel_info);
1530   return kernel_info->graph_id();
1531 }
1532 
IsTupleOutput(const AnfNodePtr & anf)1533 bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
1534   MS_EXCEPTION_IF_NULL(anf);
1535   TypePtr type = anf->Type();
1536   if (type == nullptr) {
1537     return false;
1538   }
1539   MS_EXCEPTION_IF_NULL(type);
1540   return type->isa<Tuple>();
1541 }
1542 
GetInputNode(const CNodePtr & node,size_t index)1543 AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
1544   MS_EXCEPTION_IF_NULL(node);
1545   auto get_input_index = index + 1;
1546   if (get_input_index >= node->inputs().size()) {
1547     MS_LOG(EXCEPTION) << "Input index " << get_input_index << " is out of range, the node input size is just "
1548                       << node->inputs().size() << " trace: " << trace::DumpSourceLines(node);
1549   }
1550   // input 0 is primitive node
1551   return node->input(get_input_index);
1552 }
1553 
IsFeatureMapOutput(const AnfNodePtr & node)1554 bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
1555   MS_EXCEPTION_IF_NULL(node);
1556   if (node->isa<ValueNode>()) {
1557     return false;
1558   }
1559   if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
1560     return IsFeatureMapOutput(node->cast<CNodePtr>()->input(1));
1561   }
1562   auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1563   MS_EXCEPTION_IF_NULL(kernel_info);
1564   return kernel_info->is_feature_map();
1565 }
1566 
IsFeatureMapInput(const AnfNodePtr & node,size_t input_index)1567 bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
1568   MS_EXCEPTION_IF_NULL(node);
1569   if (!node->isa<CNode>()) {
1570     MS_LOG(EXCEPTION) << "Cannot check whether the input of a parameter or a value node is a feature map"
1571                       << " trace: " << trace::DumpSourceLines(node);
1572   }
1573   auto cnode = node->cast<CNodePtr>();
1574   MS_EXCEPTION_IF_NULL(cnode);
1575   auto input_node = cnode->input(input_index + 1);
1576   return IsFeatureMapOutput(input_node);
1577 }
1578 
GetRealInputIndex(const mindspore::AnfNodePtr & anf_node,const size_t cur_index)1579 size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
1580   MS_EXCEPTION_IF_NULL(anf_node);
1581   size_t ret = cur_index;
1582   auto node_name = AnfAlgo::GetCNodeName(anf_node);
1583   if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
1584     if (AnfAlgo::IsNodeDynamicShape(anf_node) || AnfAlgo::IsDynamicShape(anf_node)) {
1585       auto find_dynamic = spec_dynamic_node_list.find(node_name);
1586       if (find_dynamic != spec_dynamic_node_list.end()) {
1587         auto dyn_index_converter = find_dynamic->second;
1588         ret = dyn_index_converter.first[cur_index];
1589         MS_LOG(DEBUG) << "Real input index change to " << ret << ", node name:" << node_name;
1590         return ret;
1591       }
1592     }
1593     auto find = spec_node_list.find(node_name);
1594     if (find != spec_node_list.end()) {
1595       auto index_converter = find->second;
1596       ret = index_converter.first[cur_index];
1597       MS_LOG(DEBUG) << "Real input index change to " << ret << ", node name:" << node_name;
1598     }
1599   }
1600   return ret;
1601 }
1602 
GetOriginalInputIndex(const mindspore::AnfNodePtr & anf_node,const size_t cur_index)1603 size_t AnfRuntimeAlgorithm::GetOriginalInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
1604   MS_EXCEPTION_IF_NULL(anf_node);
1605   size_t ret = cur_index;
1606   auto node_name = AnfAlgo::GetCNodeName(anf_node);
1607   if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
1608     if (AnfAlgo::IsNodeDynamicShape(anf_node) || AnfAlgo::IsDynamicShape(anf_node)) {
1609       auto find_dynamic = spec_dynamic_node_list.find(node_name);
1610       if (find_dynamic != spec_dynamic_node_list.end()) {
1611         auto dyn_index_converter = find_dynamic->second;
1612         ret = dyn_index_converter.second[cur_index];
1613         MS_LOG(DEBUG) << "Get original input index " << ret << ", node name:" << node_name;
1614         return ret;
1615       }
1616     }
1617     auto find = spec_node_list.find(node_name);
1618     if (find != spec_node_list.end()) {
1619       auto index_converter = find->second;
1620       ret = index_converter.second[cur_index];
1621       MS_LOG(DEBUG) << "Get original input index " << ret << ", node name:" << node_name;
1622     }
1623   }
1624   return ret;
1625 }
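
// Illustrative note (assumption; the concrete entries live in spec_node_list / spec_dynamic_node_list):
// for a TBE kernel whose IR input order differs from the kernel prototype, the converter pair stored
// for that op name maps indices in both directions, e.g. GetRealInputIndex(node, 1) may return 2 and
// GetOriginalInputIndex(node, 2) then maps back to 1.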
1626 
SetNodeInput(const CNodePtr & node,const AnfNodePtr & input_node,size_t index)1627 void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
1628   MS_EXCEPTION_IF_NULL(node);
1629   MS_EXCEPTION_IF_NULL(input_node);
1630   node->set_input(index + 1, input_node);
1631 }
1632 
IsInplaceNode(const mindspore::AnfNodePtr & kernel,const string & type)1633 bool AnfRuntimeAlgorithm::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &type) {
1634   MS_EXCEPTION_IF_NULL(kernel);
1635   auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
1636   if (!primitive) {
1637     return false;
1638   }
1639 
1640   auto inplace_attr = primitive->GetAttr(type);
1641   if (inplace_attr == nullptr) {
1642     return false;
1643   }
1644 
1645   return true;
1646 }
1647 
IsCommunicationOp(const AnfNodePtr & node)1648 bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
1649   static const std::set<std::string> kCommunicationOpNames = {kAllReduceOpName,     kAllGatherOpName, kBroadcastOpName,
1650                                                               kReduceScatterOpName, kHcomSendOpName,  kReceiveOpName,
1651                                                               kAllToAllVOpName};
1652   MS_EXCEPTION_IF_NULL(node);
1653   if (!node->isa<CNode>()) {
1654     return false;
1655   }
1656   auto kernel_name = AnfAlgo::GetCNodeName(node);
1657   return (kCommunicationOpNames.find(kernel_name) != kCommunicationOpNames.end());
1658 }
1659 
IsFusedCommunicationOp(const AnfNodePtr & node)1660 bool AnfRuntimeAlgorithm::IsFusedCommunicationOp(const AnfNodePtr &node) {
1661   if (!IsCommunicationOp(node)) {
1662     return false;
1663   }
1664   auto primitive = AnfAlgo::GetCNodePrimitive(node);
1665   MS_EXCEPTION_IF_NULL(primitive);
1666   ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
1667   if (attr_fusion == nullptr) {
1668     return false;
1669   }
1670   auto fusion = GetValue<int64_t>(attr_fusion);
1671   if (fusion == 0) {
1672     return false;
1673   }
1674   return true;
1675 }
1676 
IsGetNext(const NotNull<AnfNodePtr> & node)1677 bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
1678   auto kernel_name = AnfAlgo::GetCNodeName(node);
1679   return kernel_name == kGetNextOpName;
1680 }
1681 
GetValueNodeFuncGraph(const AnfNodePtr & node)1682 FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
1683   MS_EXCEPTION_IF_NULL(node);
1684   auto value_node = node->cast<ValueNodePtr>();
1685   if (value_node == nullptr) {
1686     return nullptr;
1687   }
1688   auto value = value_node->value();
1689   if (value == nullptr) {
1690     return nullptr;
1691   }
1692   auto func_graph = value->cast<FuncGraphPtr>();
1693   return func_graph;
1694 }
1695 
GetCallSwitchKernelGraph(const CNodePtr & cnode)1696 std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) {
1697   MS_EXCEPTION_IF_NULL(cnode);
1698   if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
1699         AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) {
1700     MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << " is not a call, switch or switch_layer node."
1701                       << " trace: " << trace::DumpSourceLines(cnode);
1702   }
1703   auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr {
1704     auto partial = cnode->input(input_index);
1705     MS_EXCEPTION_IF_NULL(partial);
1706     if (IsValueNode<KernelGraph>(partial)) {
1707       return GetValueNode<KernelGraphPtr>(partial);
1708     }
1709     auto partial_cnode = partial->cast<CNodePtr>();
1710     MS_EXCEPTION_IF_NULL(partial_cnode);
1711     auto graph_node = partial_cnode->input(kCallKernelGraphIndex);
1712     MS_EXCEPTION_IF_NULL(graph_node);
1713     auto graph_value_node = graph_node->cast<ValueNodePtr>();
1714     MS_EXCEPTION_IF_NULL(graph_value_node);
1715     auto graph_value = graph_value_node->value();
1716     MS_EXCEPTION_IF_NULL(graph_value);
1717     auto child_graph = graph_value->cast<KernelGraphPtr>();
1718     return child_graph;
1719   };
1720   if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
1721     auto input1 = cnode->input(kCallKernelGraphIndex);
1722     MS_EXCEPTION_IF_NULL(input1);
1723     auto value_node = input1->cast<ValueNodePtr>();
1724     MS_EXCEPTION_IF_NULL(value_node);
1725     auto kernel_graph = value_node->value();
1726     MS_EXCEPTION_IF_NULL(kernel_graph);
1727     return {kernel_graph->cast<KernelGraphPtr>()};
1728   } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
1729     return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex),
1730             get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)};
1731   } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
1732     std::vector<KernelGraphPtr> child_graphs;
1733     for (size_t idx = kMakeTupleInSwitchLayerIndex; idx < cnode->inputs().size(); idx++) {
1734       auto child_graph = get_switch_kernel_graph(idx);
1735       child_graphs.emplace_back(child_graph);
1736     }
1737     return child_graphs;
1738   }
1739   return {};
1740 }
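
// Illustrative sketch (assumption, with hypothetical graphs g, g_true, g_false, g0, g1):
//   Call(g)                                                    -> {g}
//   Switch(cond, Partial(g_true, ...), Partial(g_false, ...))  -> {g_true, g_false}
//   SwitchLayer(index, Partial(g0, ...), Partial(g1, ...))     -> {g0, g1}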
1741 
IsSwitchCall(const CNodePtr & call_node)1742 bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
1743   MS_EXCEPTION_IF_NULL(call_node);
1744   if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
1745     MS_LOG(EXCEPTION) << "Call node should be a 'call', but is a " << call_node->DebugString()
1746                       << " trace: " << trace::DumpSourceLines(call_node);
1747   }
1748   auto input1 = call_node->input(1);
1749   MS_EXCEPTION_IF_NULL(input1);
1750   if (input1->isa<ValueNode>()) {
1751     return false;
1752   } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
1753     return true;
1754   }
1755   MS_LOG(EXCEPTION) << "Unexpected input1 of call node, input1: " << input1->DebugString()
1756                     << " trace: " << trace::DumpSourceLines(call_node);
1757 }
1758 
IsScalarInput(const CNodePtr & cnode,size_t index)1759 bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
1760   auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
1761   if (shape.empty()) {
1762     return true;
1763   }
1764   return shape.size() == kShape1dDims && shape[0] == 1;
1765 }
1766 
IsScalarOutput(const CNodePtr & cnode,size_t index)1767 bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
1768   auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
1769   if (shape.empty()) {
1770     return true;
1771   }
1772   return shape.size() == kShape1dDims && shape[0] == 1;
1773 }
1774 
1775 namespace {
FindDelayExecPosition(const std::vector<CNodePtr> & nodes,size_t current_index,std::set<size_t> * invalid_position,std::map<size_t,std::vector<CNodePtr>> * insert_nodes)1776 void FindDelayExecPosition(const std::vector<CNodePtr> &nodes, size_t current_index, std::set<size_t> *invalid_position,
1777                            std::map<size_t, std::vector<CNodePtr>> *insert_nodes) {
1778   MS_EXCEPTION_IF_NULL(invalid_position);
1779   MS_EXCEPTION_IF_NULL(insert_nodes);
1780   if (current_index >= nodes.size()) {
1781     return;
1782   }
1783   auto &node = nodes[current_index];
1784   for (size_t j = current_index + 1; j < nodes.size(); ++j) {
1785     auto &child = nodes[j];
1786     auto input_size = child->inputs().size() - 1;
1787     for (size_t k = 0; k < input_size; ++k) {
1788       auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
1789       if (kernel_index.first != node) {
1790         continue;
1791       }
1792       if (AnfAlgo::GetCNodeName(child) == kApplyMomentumOpName) {
1793         return;
1794       }
1795       (void)invalid_position->insert(current_index);
1796       auto iter = insert_nodes->find(j);
1797       if (iter != insert_nodes->end()) {
1798         iter->second.emplace_back(node);
1799       } else {
1800         (*insert_nodes)[j] = {node};
1801       }
1802       return;
1803     }
1804   }
1805 }
1806 
DelayExecNode(const std::vector<CNodePtr> & nodes,const std::string & node_name,bool only_seed)1807 std::vector<CNodePtr> DelayExecNode(const std::vector<CNodePtr> &nodes, const std::string &node_name, bool only_seed) {
1808   std::map<size_t, std::vector<CNodePtr>> insert_nodes;
1809   std::set<size_t> invalid_position;
1810   for (size_t i = 0; i < nodes.size(); ++i) {
1811     auto &node = nodes[i];
1812     if (AnfAlgo::GetCNodeName(node) != node_name) {
1813       continue;
1814     }
1815     if (only_seed) {
1816       bool is_seed = true;
1817       auto input_size = node->inputs().size() - 1;
1818       for (size_t k = 0; k < input_size; ++k) {
1819         auto input = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, k), 0, true).first;
1820         if (input != nullptr && input->isa<CNode>()) {
1821           is_seed = false;
1822           break;
1823         }
1824       }
1825       if (!is_seed) {
1826         continue;
1827       }
1828     }
1829     FindDelayExecPosition(nodes, i, &invalid_position, &insert_nodes);
1830   }
1831   std::vector<CNodePtr> result;
1832   for (size_t i = 0; i < nodes.size(); ++i) {
1833     auto iter = insert_nodes.find(i);
1834     if (iter != insert_nodes.end()) {
1835       (void)result.insert(result.end(), iter->second.rbegin(), iter->second.rend());
1836     }
1837     if (invalid_position.find(i) != invalid_position.end()) {
1838       continue;
1839     }
1840     result.emplace_back(nodes[i]);
1841   }
1842   return result;
1843 }
1844 }  // namespace
1845 
ReorderExecList(NotNull<std::vector<CNodePtr> * > node_list)1846 void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
1847   std::vector<CNodePtr> result;
1848   std::copy(node_list->begin(), node_list->end(), std::back_inserter(result));
1849   result = DelayExecNode(result, "TransData", true);
1850   result = DelayExecNode(result, "Cast", true);
1851   result = DelayExecNode(result, "AdamApplyOneWithDecay", false);
1852   result = DelayExecNode(result, "AdamApplyOne", false);
1853   node_list->clear();
1854   std::copy(result.begin(), result.end(), std::back_inserter(*node_list));
1855 }
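
// Illustrative sketch (assumption, hypothetical node names): given an execution list
//   [TransData_0, MatMul_0, Cast_0, Add_0]
// where TransData_0 has no CNode inputs (a seed) and Add_0 is its first consumer, DelayExecNode
// moves it to just before that consumer, so the reordered list becomes
//   [MatMul_0, Cast_0, TransData_0, Add_0].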
1856 
ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> * > node_list)1857 void AnfRuntimeAlgorithm::ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> *> node_list) {
1858   std::vector<CNodePtr> ordinary_node_list;
1859   std::vector<CNodePtr> posterior_node_list;
1860 
1861   for (const auto &node : *node_list) {
1862     MS_EXCEPTION_IF_NULL(node);
1863     if (kPosteriorOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kPosteriorOperatorSet.end()) {
1864       posterior_node_list.emplace_back(node);
1865     } else {
1866       ordinary_node_list.emplace_back(node);
1867     }
1868   }
1869   node_list->clear();
1870   std::copy(ordinary_node_list.begin(), ordinary_node_list.end(), std::back_inserter(*node_list));
1871   std::copy(posterior_node_list.begin(), posterior_node_list.end(), std::back_inserter(*node_list));
1872 }
1873 
GetCNodeOutputPrecision(const AnfNodePtr & node)1874 TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) {
1875   MS_EXCEPTION_IF_NULL(node);
1876   auto prim = AnfAlgo::GetCNodePrimitive(node);
1877   if (prim == nullptr) {
1878     return kTypeUnknown;
1879   }
1880 
1881   TypeId except_type = kTypeUnknown;
1882   if (prim->GetAttr(kAttrOutputPrecision) != nullptr) {
1883     auto output_type_str = GetValue<std::string>(prim->GetAttr(kAttrOutputPrecision));
1884     if (output_type_str == "float16") {
1885       except_type = kNumberTypeFloat16;
1886     } else if (output_type_str == "float32") {
1887       except_type = kNumberTypeFloat32;
1888     } else {
1889       MS_LOG(EXCEPTION) << "The fixed precision must be float16 or float32, but got " << output_type_str
1890                         << " trace: " << trace::DumpSourceLines(node);
1891     }
1892   }
1893 
1894   return except_type;
1895 }
1896 
GetPrevNodeOutputPrecision(const AnfNodePtr & node,size_t input_idx)1897 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx) {
1898   MS_EXCEPTION_IF_NULL(node);
1899   if (!node->isa<CNode>()) {
1900     MS_LOG(EXCEPTION) << node->DebugString() << " is not a CNode."
1901                       << " trace: " << trace::DumpSourceLines(node);
1902   }
1903   auto cnode = node->cast<CNodePtr>();
1904   MS_EXCEPTION_IF_NULL(cnode);
1905   if (input_idx + 1 >= cnode->inputs().size()) {
1906     MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode)
1907                       << " trace: " << trace::DumpSourceLines(node);
1908   }
1909   auto input_node = cnode->input(input_idx + 1);
1910   MS_EXCEPTION_IF_NULL(input_node);
1911   auto kernel_with_index = VisitKernel(input_node, 0);
1912   if (!kernel_with_index.first->isa<CNode>()) {
1913     return kTypeUnknown;
1914   }
1915   return GetCNodeOutputPrecision(kernel_with_index.first);
1916 }
1917 
IsCondControlKernel(const CNodePtr & node)1918 bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) {
1919   MS_EXCEPTION_IF_NULL(node);
1920   if (node->inputs().empty()) {
1921     MS_LOG(EXCEPTION) << "Illegal null input of cnode."
1922                       << " trace: " << trace::DumpSourceLines(node);
1923   }
1924   auto input = node->input(kAnfPrimitiveIndex);
1925   return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch);
1926 }
1927 
IsIndependentNode(const CNodePtr & node)1928 bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
1929   MS_EXCEPTION_IF_NULL(node);
1930   if (AnfAlgo::GetKernelType(node) != AICPU_KERNEL) {
1931     return false;
1932   }
1933 
1934   if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
1935     MS_LOG(INFO) << "GetNext should not be an independent node";
1936     return false;
1937   }
1938 
1939   // aicpu stack ops are not independent nodes.
1940   if (AnfAlgo::GetCNodeName(node) == kStackInitOpName || AnfAlgo::GetCNodeName(node) == kStackDestroyOpName ||
1941       AnfAlgo::GetCNodeName(node) == kStackPopOpName || AnfAlgo::GetCNodeName(node) == kStackPushOpName) {
1942     MS_LOG(INFO) << "AICPU stack ops should not be independent nodes";
1943     return false;
1944   }
1945 
1946   size_t input_nums = AnfAlgo::GetInputTensorNum(node);
1947   if (input_nums == 0) {
1948     return true;
1949   }
1950 
1951   auto inputs = node->inputs();
1952   for (size_t i = 1; i < inputs.size(); i++) {
1953     if (!inputs[i]->isa<ValueNode>()) {
1954       return false;
1955     }
1956   }
1957   return true;
1958 }
1959 
GetBooleanAttr(const AnfNodePtr & node,const std::string & attr)1960 bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::string &attr) {
1961   MS_EXCEPTION_IF_NULL(node);
1962   if (!node->isa<CNode>()) {
1963     return false;
1964   }
1965   auto cnode = node->cast<CNodePtr>();
1966   MS_EXCEPTION_IF_NULL(cnode);
1967   auto has_attr = AnfAlgo::HasNodeAttr(attr, cnode);
1968   if (!has_attr) {
1969     return false;
1970   }
1971   return AnfAlgo::GetNodeAttr<bool>(node, attr);
1972 }
1973 
HasDynamicShapeFlag(const PrimitivePtr & prim)1974 bool AnfRuntimeAlgorithm::HasDynamicShapeFlag(const PrimitivePtr &prim) {
1975   auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool {
1976     MS_EXCEPTION_IF_NULL(primitive);
1977     if (!primitive->HasAttr(attr_name)) {
1978       return false;
1979     }
1980     return GetValue<bool>(primitive->GetAttr(attr_name));
1981   };
1982   return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape) ||
1983          get_bool_attr(prim, kAttrIsDynamicShape);
1984 }
1985 
IsDynamicShape(const AnfNodePtr & node)1986 bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
1987   return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) ||
1988          GetBooleanAttr(node, kAttrIsDynamicShape);
1989 }
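
// Illustrative usage sketch (assumption; `cnode` is a hypothetical CNodePtr):
//   AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode);
//   AnfAlgo::IsDynamicShape(cnode);  // true: any of the three dynamic-shape attrs marks the node dynamic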
1990 
GetRealDynamicShape(const std::vector<size_t> & shape,NotNull<std::vector<int64_t> * > dynamic_shape)1991 void AnfRuntimeAlgorithm::GetRealDynamicShape(const std::vector<size_t> &shape,
1992                                               NotNull<std::vector<int64_t> *> dynamic_shape) {
1993   for (auto size : shape) {
1994     if (size == SIZE_MAX) {
1995       dynamic_shape->push_back(-1);
1996     } else {
1997       dynamic_shape->push_back(SizeToLong(size));
1998     }
1999   }
2000 }
2001 
GetShapeFromSequeueShape(const abstract::SequeueShapePtr & sequeue_shape_ptr,size_t index,ShapeType type)2002 std::vector<int64_t> GetShapeFromSequeueShape(const abstract::SequeueShapePtr &sequeue_shape_ptr, size_t index,
2003                                               ShapeType type) {
2004   MS_EXCEPTION_IF_NULL(sequeue_shape_ptr);
2005   auto shape_list = sequeue_shape_ptr->shape();
2006   if (index >= shape_list.size()) {
2007     MS_LOG(EXCEPTION) << "Output index " << index << " is out of range of shape list size " << shape_list.size();
2008   }
2009 
2010   auto shape = shape_list[index];
2011   MS_EXCEPTION_IF_NULL(shape);
2012   if (shape->isa<abstract::Shape>()) {
2013     auto shape_ptr = shape->cast<abstract::ShapePtr>();
2014     if (type == ShapeType::kMaxShape) {
2015       return shape_ptr->max_shape().empty() ? shape_ptr->shape() : shape_ptr->max_shape();
2016     } else {
2017       return shape_ptr->min_shape().empty() ? shape_ptr->shape() : shape_ptr->min_shape();
2018     }
2019   } else {
2020     MS_LOG(EXCEPTION) << "Invalid Shape Type In Shape List";
2021   }
2022 }
2023 
GetInputMaxShape(const AnfNodePtr & anf_node,size_t index)2024 std::vector<int64_t> AnfRuntimeAlgorithm::GetInputMaxShape(const AnfNodePtr &anf_node, size_t index) {
2025   auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
2026   return GetOutputMaxShape(input_node_with_index.first, input_node_with_index.second);
2027 }
2028 
GetInputMinShape(const AnfNodePtr & anf_node,size_t index)2029 std::vector<int64_t> AnfRuntimeAlgorithm::GetInputMinShape(const AnfNodePtr &anf_node, size_t index) {
2030   auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
2031   return GetOutputMinShape(input_node_with_index.first, input_node_with_index.second);
2032 }
2033 
GetOutputMaxShape(const AnfNodePtr & anf_node,size_t index)2034 std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index) {
2035   MS_EXCEPTION_IF_NULL(anf_node);
2036   auto shape = anf_node->Shape();
2037   MS_EXCEPTION_IF_NULL(shape);
2038   if (shape->isa<abstract::Shape>()) {
2039     auto shape_ptr = shape->cast<abstract::ShapePtr>();
2040     return shape_ptr->max_shape().empty() ? shape_ptr->shape() : shape_ptr->max_shape();
2041   } else if (shape->isa<abstract::SequeueShape>()) {
2042     auto sequeue_shape_ptr = shape->cast<abstract::SequeueShapePtr>();
2043     return GetShapeFromSequeueShape(sequeue_shape_ptr, index, ShapeType::kMaxShape);
2044   } else if (shape->isa<abstract::NoShape>()) {
2045     return {};
2046   } else {
2047     MS_LOG(EXCEPTION) << "Invalid Shape Type"
2048                       << " trace: " << trace::DumpSourceLines(anf_node);
2049   }
2050 }
2051 
GetOutputMinShape(const AnfNodePtr & anf_node,size_t index)2052 std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputMinShape(const AnfNodePtr &anf_node, size_t index) {
2053   MS_EXCEPTION_IF_NULL(anf_node);
2054   auto shape = anf_node->Shape();
2055   MS_EXCEPTION_IF_NULL(shape);
2056   if (shape->isa<abstract::Shape>()) {
2057     auto shape_ptr = shape->cast<abstract::ShapePtr>();
2058     return shape_ptr->min_shape().empty() ? shape_ptr->shape() : shape_ptr->min_shape();
2059   } else if (shape->isa<abstract::SequeueShape>()) {
2060     auto sequeue_shape_ptr = shape->cast<abstract::SequeueShapePtr>();
2061     return GetShapeFromSequeueShape(sequeue_shape_ptr, index, ShapeType::kMinShape);
2062   } else if (shape->isa<abstract::NoShape>()) {
2063     return {};
2064   } else {
2065     MS_LOG(EXCEPTION) << "Invalid Shape Type"
2066                       << " trace: " << trace::DumpSourceLines(anf_node);
2067   }
2068 }
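
// Illustrative sketch (assumption, made-up values): for a dynamic output whose abstract shape is
// Shape(shape={-1, 32}, min_shape={1, 32}, max_shape={128, 32}), GetOutputMaxShape returns {128, 32}
// and GetOutputMinShape returns {1, 32}; when min/max are empty, the static shape itself is returned.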
2069 
IsNodeInputDynamicShape(const CNodePtr & anf_node_ptr)2070 bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr) {
2071   MS_EXCEPTION_IF_NULL(anf_node_ptr);
2072   auto input_num = AnfAlgo::GetInputTensorNum(anf_node_ptr);
2073   for (size_t i = 0; i < input_num; ++i) {
2074     auto input_with_index = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, i);
2075     auto input = input_with_index.first;
2076     auto index = input_with_index.second;
2077     MS_EXCEPTION_IF_NULL(input);
2078     auto base_shape = input->Shape();
2079     if (base_shape == nullptr) {
2080       MS_LOG(INFO) << "Invalid shape ptr, node:" << input->fullname_with_scope();
2081       continue;
2082     }
2083     if (base_shape->isa<abstract::Shape>()) {
2084       if (AnfUtils::IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
2085         return true;
2086       }
2087     } else if (base_shape->isa<abstract::TupleShape>()) {
2088       auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
2089       MS_EXCEPTION_IF_NULL(tuple_shape);
2090 
2091       if (index >= tuple_shape->size()) {
2092         MS_LOG(INFO) << "Node:" << anf_node_ptr->fullname_with_scope() << " invalid index:" << index
2093                      << " and tuple_shape size:" << tuple_shape->size();
2094         continue;
2095       }
2096       auto b_shp = (*tuple_shape)[index];
2097       if (!b_shp->isa<abstract::Shape>()) {
2098         continue;
2099       }
2100       if (AnfUtils::IsShapeDynamic(b_shp->cast<abstract::ShapePtr>())) {
2101         return true;
2102       }
2103     }
2104   }
2105   return false;
2106 }
2107 
IsNodeDynamicShape(const AnfNodePtr & node)2108 bool AnfRuntimeAlgorithm::IsNodeDynamicShape(const AnfNodePtr &node) {
2109   MS_EXCEPTION_IF_NULL(node);
2110   if (!node->isa<CNode>()) {
2111     MS_LOG(DEBUG) << "Node is not a cnode";
2112     return false;
2113   }
2114   auto cnode = node->cast<CNodePtr>();
2115   auto in_dynamic = IsNodeInputDynamicShape(cnode);
2116   auto out_dynamic = AnfUtils::IsNodeOutputDynamicShape(cnode);
2117   if (in_dynamic && !AnfAlgo::HasNodeAttr(kAttrInputIsDynamicShape, cnode)) {
2118     AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cnode);
2119     MS_LOG(INFO) << "Set Input Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
2120   }
2121   if (out_dynamic && !AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicShape, cnode)) {
2122     AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cnode);
2123     MS_LOG(INFO) << "Set Output Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
2124   }
2125   return in_dynamic || out_dynamic;
2126 }
2127 
GetInputRealDeviceShapeIfExist(const AnfNodePtr & anf_node,size_t index)2128 std::vector<size_t> AnfRuntimeAlgorithm::GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index) {
2129   auto device_shape = GetInputDeviceShape(anf_node, index);
2130   // Initialize GPUKernel with max shape to fit 'InitDynamicOutputKernelRef()' for memory reuse.
2131   if (AnfUtils::IsShapeDynamic(device_shape)) {
2132     auto max_shape = GetInputMaxShape(anf_node, index);
2133     std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
2134     auto format = GetInputFormat(anf_node, index);
2135     (void)trans::TransShapeToDevice(device_shape, format, anf_node, index, false);
2136   }
2137   return device_shape;
2138 }
2139 
GetOutputRealDeviceShapeIfExist(const AnfNodePtr & anf_node,size_t index)2140 std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index) {
2141   auto device_shape = GetOutputDeviceShape(anf_node, index);
2142   // Initialize GPUKernel with max shape to fit 'InitDynamicOutputKernelRef()' for memory reuse.
2143   if (AnfUtils::IsShapeDynamic(device_shape)) {
2144     auto max_shape = GetOutputMaxShape(anf_node, index);
2145     std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
2146     auto format = GetOutputFormat(anf_node, index);
2147     (void)trans::TransShapeToDevice(device_shape, format, anf_node, index);
2148   }
2149   return device_shape;
2150 }
2151 
GetAllVisitedCNode(const CNodePtr & anf_node,std::vector<AnfNodePtr> * used_kernels,std::set<AnfNodePtr> * visited)2152 void AnfRuntimeAlgorithm::GetAllVisitedCNode(const CNodePtr &anf_node, std::vector<AnfNodePtr> *used_kernels,
2153                                              std::set<AnfNodePtr> *visited) {
2154   MS_EXCEPTION_IF_NULL(anf_node);
2155   MS_EXCEPTION_IF_NULL(used_kernels);
2156   MS_EXCEPTION_IF_NULL(visited);
2157   if (visited->find(anf_node) != visited->end()) {
2158     MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
2159     return;
2160   }
2161   visited->insert(anf_node);
2162   auto input_size = anf_node->inputs().size() - 1;
2163   for (size_t i = 0; i < input_size; ++i) {
2164     auto input = AnfAlgo::GetInputNode(anf_node, i);
2165     if (!input->isa<CNode>()) {
2166       continue;
2167     }
2168     auto input_cnode = input->cast<CNodePtr>();
2169     if (!IsRealKernelCNode(input_cnode) || opt::IsNopNode(input_cnode)) {
2170       GetAllVisitedCNode(input_cnode, used_kernels, visited);
2171     } else {
2172       used_kernels->push_back(input);
2173     }
2174   }
2175 }
2176 
GetAllFatherRealNode(const AnfNodePtr & anf_node,std::vector<AnfNodePtr> * result,std::set<AnfNodePtr> * visited)2177 void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
2178                                                std::set<AnfNodePtr> *visited) {
2179   MS_EXCEPTION_IF_NULL(anf_node);
2180   MS_EXCEPTION_IF_NULL(result);
2181   MS_EXCEPTION_IF_NULL(visited);
2182   if (visited->find(anf_node) != visited->end()) {
2183     MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
2184     return;
2185   }
2186   visited->insert(anf_node);
2187   if (AnfAlgo::IsRealKernel(anf_node)) {
2188     result->emplace_back(anf_node);
2189     return;
2190   }
2191   if (!anf_node->isa<CNode>()) {
2192     return;
2193   }
2194   auto cnode = anf_node->cast<CNodePtr>();
2195   MS_EXCEPTION_IF_NULL(cnode);
2196   if (cnode->inputs().empty()) {
2197     MS_LOG(EXCEPTION) << "Illegal null input of cnode: " << anf_node->DebugString();
2198   }
2199   auto input0 = cnode->input(0);
2200   if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
2201     for (size_t i = 1; i < cnode->inputs().size(); ++i) {
2202       GetAllFatherRealNode(cnode->input(i), result, visited);
2203     }
2204   } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
2205     if (cnode->inputs().size() != kTupleGetItemInputSize) {
2206       MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
2207     }
2208     GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
2209   } else if (IsPrimitive(input0, prim::kPrimDepend)) {
2210     if (cnode->inputs().size() != kDependInputSize) {
2211       MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
2212     }
2213     GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
2214     GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
2215   }
2216 }
2217 
InferShape(const CNodePtr & node,std::map<uint32_t,tensor::TensorPtr> * depend_tensors)2218 void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors) {
2219   MS_EXCEPTION_IF_NULL(node);
2220   MS_LOG(INFO) << "InferShape start, node:" << node->DebugString();
2221   auto inputs = node->inputs();
2222   if (inputs.empty()) {
2223     MS_LOG(EXCEPTION) << "Invalid inputs";
2224   }
2225   AbstractBasePtrList args_spec_list;
2226   auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
2227   auto input_size = AnfAlgo::GetInputTensorNum(node);
2228   for (size_t i = 0; i < input_size; ++i) {
2229     auto input_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
2230     auto real_input = input_with_index.first;
2231     auto cnode_input = node->input(i + 1);
2232     MS_EXCEPTION_IF_NULL(cnode_input);
2233     MS_EXCEPTION_IF_NULL(real_input);
2234     if (depend_tensors != nullptr) {
2235       auto iter_tensor = depend_tensors->find(i);
2236       if (iter_tensor != depend_tensors->end()) {
2237         auto tensor_ptr = iter_tensor->second;
2238         MS_EXCEPTION_IF_NULL(tensor_ptr);
2239         // sync data from device to host
2240         tensor_ptr->data_sync();
2241         auto real_abs = real_input->abstract();
2242         if (real_abs->isa<abstract::AbstractTensor>()) {
2243           real_input->abstract()->set_value(tensor_ptr);
2244         } else if (real_abs->isa<abstract::AbstractTuple>()) {
2245           auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
2246           auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
2247           MS_EXCEPTION_IF_NULL(abstract_tuple);
2248           auto tuple_elements = abstract_tuple->elements()[tuple_get_item_index];
2249           tuple_elements->set_value(tensor_ptr);
2250         }
2251       }
2252     }
2253     if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
2254       auto base_shape = real_input->Shape();
2255       if (!base_shape->isa<abstract::TupleShape>()) {
2256         MS_LOG(EXCEPTION) << "Node:" << node->DebugString()
2257                           << " input is a tuple_get_item but real input node shape is not a TupleShape";
2258       }
2259       auto abs = real_input->abstract()->cast<abstract::AbstractTuplePtr>();
2260       MS_EXCEPTION_IF_NULL(abs);
2261       auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
2262       auto abs_i = abs->elements()[tuple_get_item_index];
2263       (void)args_spec_list.emplace_back(abs_i);
2264     } else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) {
2265       (void)args_spec_list.emplace_back(cnode_input->abstract());
2266     } else {
2267       (void)args_spec_list.emplace_back(real_input->abstract());
2268     }
2269   }
2270   auto eval_result = opt::CppInferShape(primitive, args_spec_list);
2271   node->set_abstract(eval_result);
2272 }
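
// Illustrative note (assumption): `depend_tensors` carries the host tensors of value-depended inputs
// (e.g. the shape operand of a dynamic Reshape); they are synced from device first so that
// opt::CppInferShape can re-derive the output abstract from actual input values.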
2273 
InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> & root_graph)2274 void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph) {
2275   auto return_node = root_graph->get_return();
2276   MS_EXCEPTION_IF_NULL(return_node);
2277   if (return_node->size() <= kReturnDataIndex) {
2278     return;
2279   }
2280   auto make_tuple = root_graph->NewCNode(
2281     {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
2282   root_graph->set_output(make_tuple);
2283 }
2284 
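// Collect all UpdateState nodes that use `node`, according to the manager's node-user map. A hedged usage sketch
// (call pattern assumed for illustration, not taken from this file):
//   auto update_states = AnfAlgo::GetUpdateStateUsers(kernel_graph->manager(), load_node);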
2285 AnfNodeIndexSet AnfRuntimeAlgorithm::GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
2286   AnfNodeIndexSet update_states;
2287   for (auto &user : manager->node_users()[node]) {
2288     if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) {
2289       update_states.insert(user);
2290     }
2291   }
2292   return update_states;
2293 }
2294 
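// Expand each input of `node` into the real output kernels (node + output index pairs) that produce it, recursing
// through virtual nodes via GetRealOutputRecursively and appending the results to `inputs`.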
2295 void AnfRuntimeAlgorithm::GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> *inputs) {
2296   size_t input_num = AnfAlgo::GetInputTensorNum(node);
2297   for (size_t input_index = 0; input_index < input_num; ++input_index) {
2298     auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index);
2299     GetRealOutputRecursively(input_node, 0, inputs);
2300   }
2301 }
2302 
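// Return true when the two shapes differ in rank or in any dimension, i.e. when broadcasting would be needed to make
// them match. Illustrative values (not taken from this file):
//   IsTensorBroadcast({2, 3}, {2, 3}) -> false
//   IsTensorBroadcast({2, 3}, {2, 1}) -> true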
2303 bool AnfRuntimeAlgorithm::IsTensorBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
2304   if (lhs.size() != rhs.size()) {
2305     return true;
2306   }
2307   for (size_t i = 0; i < lhs.size(); i++) {
2308     if (lhs[i] != rhs[i]) {
2309       return true;
2310     }
2311   }
2312   return false;
2313 }
2314 
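// Return true when `node` is a non-empty CNode whose primitive input (index kAnfPrimitiveIndex) is one of the
// primitives in `prim_set`. A hedged example (primitive set chosen for illustration only):
//   static const PrimitiveSet depend_like_prims = {prim::kPrimDepend, prim::kPrimLoad};
//   bool is_depend_like = AnfAlgo::IsOneOfPrimitiveCNode(node, depend_like_prims);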
2315 bool AnfRuntimeAlgorithm::IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
2316   MS_EXCEPTION_IF_NULL(node);
2317   auto cnode = node->cast<CNodePtr>();
2318   if (cnode == nullptr || cnode->size() == 0) {
2319     return false;
2320   }
2321   return IsOneOfPrimitive(cnode->inputs().at(kAnfPrimitiveIndex), prim_set);
2322 }
2323 
2324 bool AnfRuntimeAlgorithm::IsControlOpExecInBackend(const AnfNodePtr &node) {
2325   if (!node->isa<CNode>()) {
2326     return false;
2327   }
2328   // Operators in the control_ops_exec_in_backend set are compiled into the kernel graph rather than being cut
2329   // into single ops and executed in the VM.
2330   static std::set<std::string> control_ops_exec_in_backend = {kBpropCutOpName};
2331   return control_ops_exec_in_backend.find(AnfAlgo::GetCNodeName(node)) != control_ops_exec_in_backend.end();
2332 }
2333 
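// Return true when any real input of `node` carries a monad abstract (monads are the inputs used to order
// side-effecting operations).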
2334 bool AnfRuntimeAlgorithm::IsNodeInputContainMonad(const AnfNodePtr &node) {
2335   auto input_size = GetInputTensorNum(node);
2336   for (size_t i = 0; i < input_size; ++i) {
2337     auto input_with_index = GetPrevNodeOutput(node, i);
2338     if (HasAbstractMonad(input_with_index.first)) {
2339       return true;
2340     }
2341   }
2342   return false;
2343 }
2344 
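// Pre-cache device addresses for every kernel in the graph's execution order. Skipped entirely in graph mode with
// task sink enabled. Kernels carrying the "nop_op" attr simply forward their input addresses to their outputs,
// AtomicAddrClean kernels are routed to CacheAddrForAtomicClean, and all other kernels go through CacheAddrForKernel.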
2345 void AnfRuntimeAlgorithm::CacheAddrForGraph(const KernelGraphPtr &kernel_graph) {
2346   MS_EXCEPTION_IF_NULL(kernel_graph);
2347   auto ms_context = MsContext::GetInstance();
2348   MS_EXCEPTION_IF_NULL(ms_context);
2349   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
2350       ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
2351     return;
2352   }
2353   auto nodes = kernel_graph->execution_order();
2354   for (auto &kernel : nodes) {
2355     // Skip transpose kernels carrying the "nop_op" attr, which are not hidden or removed in the PyNative infer
2356     // scenario. Such transpose kernels are generated in TransDataSplit to support specific Transdata cases and are
2357     // not supposed to be executed. This hard-coded handling should be removed once the new Transdata scheme lands.
2358     if (HasNodeAttr("nop_op", kernel)) {
2359       for (size_t idx = 0; idx < GetOutputTensorNum(kernel); idx += 1) {
2360         auto real_input = GetRealInputIndex(kernel, idx);
2361         auto device_address = GetPrevNodeMutableOutputAddr(kernel, real_input);
2362         SetOutputAddr(device_address, idx, kernel.get());
2363       }
2364       continue;
2365     }
2366     auto kernel_mod = GetKernelMod(kernel);
2367     MS_EXCEPTION_IF_NULL(kernel_mod);
2368     if (GetCNodeName(kernel) == kAtomicAddrCleanOpName) {
2369       CacheAddrForAtomicClean(kernel, kernel_mod);
2370       continue;
2371     }
2372     CacheAddrForKernel(kernel, kernel_mod);
2373   }
2374 }
2375 
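// Resolve the input, workspace and output device addresses of `node` and store them on its KernelMod, so the kernel
// can later be launched without re-resolving addresses. Placeholder (None) inputs of DynamicRNN and DynamicGRUV2 are
// skipped. See CacheAddrForGraph above for the call site.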
2376 void AnfRuntimeAlgorithm::CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod) {
2377   MS_EXCEPTION_IF_NULL(node);
2378   MS_EXCEPTION_IF_NULL(kernel_mod);
2379   std::vector<AddressPtr> kernel_inputs;
2380   std::vector<AddressPtr> kernel_workspaces;
2381   std::vector<AddressPtr> kernel_outputs;
2382   auto cnode = node->cast<CNodePtr>();
2383   MS_EXCEPTION_IF_NULL(cnode);
2384   auto ms_context = MsContext::GetInstance();
2385   MS_EXCEPTION_IF_NULL(ms_context);
2386   auto visit_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
2387   size_t input_num = GetInputTensorNum(node);
2388   for (size_t i = 0; i < input_num; ++i) {
2389     auto op_name = GetCNodeName(cnode);
2390     constexpr auto none_placeholder_index = 3;
2391     if (op_name == kDynamicRNNOpName && i == none_placeholder_index) {
2392       continue;
2393     }
2394     if (op_name == kDynamicGRUV2OpName) {
2395       auto none_index = GetNodeAttr<std::vector<int64_t>>(cnode, "placeholder_index");
2396       auto item = std::find(none_index.begin(), none_index.end(), i);
2397       if (item != none_index.end()) {
2398         continue;
2399       }
2400     }
2401     auto real_input = GetRealInputIndex(node, i);
2402     auto device_address = GetPrevNodeOutputAddr(node, real_input, visit_nop_node);
2403     MS_EXCEPTION_IF_NULL(device_address);
2404     kernel::AddressPtr input = std::make_shared<kernel::Address>();
2405     MS_EXCEPTION_IF_NULL(input);
2406     input->addr = const_cast<void *>(device_address->GetPtr());
2407     MS_EXCEPTION_IF_NULL(input->addr);
2408     input->size = device_address->GetSize();
2409     kernel_inputs.emplace_back(input);
2410   }
2411   for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
2412     auto device_address = GetOutputAddr(node, i, visit_nop_node);
    MS_EXCEPTION_IF_NULL(device_address);
2413     kernel::AddressPtr output = std::make_shared<kernel::Address>();
2414     MS_EXCEPTION_IF_NULL(output);
2415     output->addr = const_cast<void *>(device_address->GetPtr());
2416     MS_EXCEPTION_IF_NULL(output->addr);
2417     output->size = device_address->GetSize();
2418     kernel_outputs.emplace_back(output);
2419   }
2420   for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
2421     auto device_address = GetWorkspaceAddr(node, i);
    MS_EXCEPTION_IF_NULL(device_address);
2422     kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
2423     MS_EXCEPTION_IF_NULL(workspace);
2424     workspace->addr = const_cast<void *>(device_address->GetPtr());
2425     MS_EXCEPTION_IF_NULL(workspace->addr);
2426     workspace->size = device_address->GetSize();
2427     kernel_workspaces.emplace_back(workspace);
2428   }
2429   kernel_mod->set_inputs_addr(kernel_inputs);
2430   kernel_mod->set_workspaces_addr(kernel_workspaces);
2431   kernel_mod->set_outputs_addr(kernel_outputs);
2432 }
2433 
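// Cache the addresses an AtomicAddrClean kernel has to zero out: the output and workspace addresses of its single
// real input node, selected by the kAttrAtomicOutputIndexs and kAttrAtomicWorkspaceIndexs attributes, are collected
// as the clean kernel's inputs.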
2434 void AnfRuntimeAlgorithm::CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod) {
2435   MS_EXCEPTION_IF_NULL(node);
2436   MS_EXCEPTION_IF_NULL(kernel_mod);
2437   std::vector<AddressPtr> kernel_inputs;
2438   auto cnode = node->cast<CNodePtr>();
2439   MS_EXCEPTION_IF_NULL(cnode);
2440   if (cnode->inputs().size() != kIndex2) {
2441     MS_LOG(EXCEPTION) << "Atomic addr clean node must have 2 inputs, but got " << cnode->inputs().size();
2442   }
2443   MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
2444   auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_node);
2445   // set clean output address
2446   if (HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
2447 #if defined(__APPLE__)
2448     auto clean_output_indexes = GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicOutputIndexs);
2449 #else
2450     auto clean_output_indexes = GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
2451 #endif
2452     for (auto index : clean_output_indexes) {
2453       auto device_address = GetOutputAddr(pre_node, index);
      MS_EXCEPTION_IF_NULL(device_address);
2454       kernel::AddressPtr input = std::make_shared<kernel::Address>();
2455       MS_EXCEPTION_IF_NULL(input);
2456       input->addr = const_cast<void *>(device_address->GetPtr());
2457       MS_EXCEPTION_IF_NULL(input->addr);
2458       input->size = device_address->GetSize();
2459       kernel_inputs.emplace_back(input);
2460     }
2461     MS_LOG(DEBUG) << "AtomicAddrClean clean output size: " << clean_output_indexes.size();
2462   }
2463   // set clean workspace address
2464   if (HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
2465 #if defined(__APPLE__)
2466     auto clean_workspaces_indexes = GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicWorkspaceIndexs);
2467 #else
2468     auto clean_workspaces_indexes = GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
2469 #endif
2470     for (const auto &index : clean_workspaces_indexes) {
2471       auto device_address = GetWorkspaceAddr(pre_node, index);
      MS_EXCEPTION_IF_NULL(device_address);
2472       kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
2473       MS_EXCEPTION_IF_NULL(workspace);
2474       workspace->addr = const_cast<void *>(device_address->GetPtr());
2475       MS_EXCEPTION_IF_NULL(workspace->addr);
2476       workspace->size = device_address->GetSize();
2477       kernel_inputs.emplace_back(workspace);
2478     }
2479   }
2480   kernel_mod->set_inputs_addr(kernel_inputs);
2481 }
2482 
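// Bounds-checked accessors for the per-output info cached in OpRuntimeInfo (output format, type id and tensor size);
// each raises an exception when `index` is out of range.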
2483 std::string OpRuntimeInfo::output_format(size_t index) const {
2484   if (index >= output_format_.size()) {
2485     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_format:" << output_format_.size();
2486   }
2487   return output_format_[index];
2488 }
2489 
2490 TypeId OpRuntimeInfo::output_type(size_t index) const {
2491   if (index >= output_type_.size()) {
2492     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_type:" << output_type_.size();
2493   }
2494   return output_type_[index];
2495 }
2496 
2497 size_t OpRuntimeInfo::output_tensor_size(size_t index) const {
2498   if (index >= output_tensor_size_.size()) {
2499     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_tensor_size:" << output_tensor_size_.size();
2500   }
2501   return output_tensor_size_[index];
2502 }
2503 }  // namespace session
2504 }  // namespace mindspore
2505