1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/session/anf_runtime_algorithm.h"
17 #include <memory>
18 #include <algorithm>
19 #include <map>
20 #include <set>
21 #include <functional>
22 #include <numeric>
23 #include "ir/anf.h"
24 #include "ir/func_graph.h"
25 #include "base/core_ops.h"
26 #include "utils/utils.h"
27 #include "utils/shape_utils.h"
28 #include "runtime/device/kernel_info.h"
29 #include "runtime/device/device_address.h"
30 #include "backend/optimizer/common/helper.h"
31 #include "backend/kernel_compiler/kernel.h"
32 #include "backend/kernel_compiler/kernel_build_info.h"
33 #include "common/trans.h"
34 #include "abstract/param_validator.h"
35 #include "pipeline/jit/static_analysis/static_analysis.h"
36 #include "utils/trace_base.h"
37 #include "ir/anf_utils.h"
38
39 namespace mindspore {
40 namespace session {
41 using abstract::AbstractTensor;
42 using abstract::AbstractTuple;
43 using device::KernelInfo;
44 using device::ascend::AscendDeviceAddress;
45 using kernel::KernelBuildInfoPtr;
46 using kernel::KernelMod;
47 using kernel::KernelModPtr;
48 namespace {
49 constexpr size_t kNopNodeInputSize = 2;
50 constexpr size_t kNopNodeRealInputIndex = 1;
51 constexpr size_t kReturnDataIndex = 1;
52
53 const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
54
IsOneOfPrimitive(const AnfNodePtr & node,const PrimitiveSet & prim_set)55 bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
56 PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
57 return (prim && prim_set.find(prim) != prim_set.end());
58 }
59
IsRealKernelCNode(const CNodePtr & cnode)60 bool IsRealKernelCNode(const CNodePtr &cnode) {
61 #ifndef ENABLE_SECURITY
62 static const PrimitiveSet virtual_prims = {
63 prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
64 prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem, prim::kPrimReturn,
65 prim::kPrimPartial, prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
66 #else
67 static const PrimitiveSet virtual_prims = {prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem,
68 prim::kPrimReturn, prim::kPrimPartial, prim::kPrimDepend,
69 prim::kPrimUpdateState, prim::kPrimLoad};
70 #endif
71 MS_EXCEPTION_IF_NULL(cnode);
72 if (cnode->inputs().empty()) {
73 MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << cnode->DebugString();
74 }
75 const auto &input = cnode->inputs().at(0);
76 bool is_virtual_node = IsOneOfPrimitive(input, virtual_prims);
77 return !is_virtual_node;
78 }
79
TransShapeToSizet(const abstract::ShapePtr & shape)80 std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
81 MS_EXCEPTION_IF_NULL(shape);
82 std::vector<size_t> shape_size_t;
83 if (AnfUtils::IsShapeDynamic(shape)) {
84 if (std::all_of(shape->max_shape().begin(), shape->max_shape().end(), [](int64_t s) { return s >= 0; })) {
85 std::transform(shape->max_shape().begin(), shape->max_shape().end(), std::back_inserter(shape_size_t),
86 LongToSize);
87 } else {
88 MS_LOG(EXCEPTION) << "Invalid Max Shape";
89 }
90 } else {
91 std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), LongToSize);
92 }
93 return shape_size_t;
94 }
95
// Selector for querying either the maximum or the minimum shape of a dynamic-shape node.
enum class ShapeType { kMaxShape, kMinShape };
97
// Resolve `node` (at `output_index`) to the real producer node(s) of its value,
// appending (node, index) pairs to `inputs`. Transparent wrappers are unwrapped
// recursively: Depend/Load/UpdateState follow their first real input,
// TupleGetItem is resolved to the selected element, and MakeTuple is flattened
// into its individual inputs.
void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
                              std::vector<session::KernelWithIndex> *inputs) {
  MS_EXCEPTION_IF_NULL(node);
  // Value nodes and parameters are terminal producers; they contribute output 0.
  if (node->isa<ValueNode>() || node->isa<Parameter>()) {
    return inputs->push_back(std::make_pair(node, 0));
  }

  // Skip control node
  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) ||
      AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
    return GetRealOutputRecursively(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
  }

  // Bypass TupleGetItem
  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
    auto tuple_get_item = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_get_item);
    auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item);
    auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item);

    // Conceal MakeTuple + TupleGetItem pair: jump straight to the selected
    // MakeTuple element instead of reporting the MakeTuple itself.
    if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
      auto make_tuple = input->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(make_tuple);
      auto real_input = AnfAlgo::GetInputNode(make_tuple, index);
      return GetRealOutputRecursively(real_input, 0, inputs);
    }

    // Skip TupleGetItem: recurse into its input using the element index.
    return GetRealOutputRecursively(input, index, inputs);
  }

  // Flatten MakeTuple inputs: each element is resolved independently.
  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
    auto make_tuple = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index);
      GetRealOutputRecursively(input_node, 0, inputs);
    }
    return;
  }

  // Anything else is treated as a real producer at the requested output index.
  return inputs->push_back(std::make_pair(node, output_index));
}
144
// ops map that dynamic input order is differ from the fixed shape ops.
// Each entry maps an op name to a pair of index maps:
// (ms input order -> tbe input order, tbe input order -> ms input order).
static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = {
  {prim::kPrimConv2DBackpropInput->name(), {{{0, 2}, {1, 1}, {2, 0}}, {{0, 2}, {1, 1}, {2, 0}}}},
  {prim::kPrimConv2DBackpropFilter->name(), {{{0, 1}, {1, 2}, {2, 0}}, {{1, 0}, {2, 1}, {0, 2}}}}};

// pair: ms input order to tbe input order, and tbe input order to ms input order
static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_node_list = {
  {prim::kPrimConv2DBackpropInput->name(), {{{0, 1}, {1, 0}}, {{0, 1}, {1, 0}}}},
  {kFusionOpConv2DBackpropInputReluGradV2Name, {{{0, 1}, {1, 0}, {2, 2}}, {{0, 1}, {1, 0}, {2, 2}}}},
  {kFusionOpConv2DBackpropInputAddNReluGradV2Name,
   {{{0, 1}, {1, 0}, {2, 2}, {3, 3}}, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}}},
  {prim::kPrimConv2DBackpropFilter->name(), {{{0, 1}, {1, 0}}, {{0, 1}, {1, 0}}}},
  {prim::kPrimLogSoftmaxGrad->name(), {{{0, 1}, {1, 0}}, {{0, 1}, {1, 0}}}},
  {prim::kPrimLayerNormGrad->name(),
   {{{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}, {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}},
  {prim::kPrimLayerNormBetaGammaBackprop->name(), {{{0, 1}, {1, 0}, {2, 2}, {3, 3}}, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}}},
  {prim::kPrimLayerNormXBackprop->name(),
   {{{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}, {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}},
  {prim::kPrimLayerNormXBackpropV2->name(),
   {{{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}, {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}},
  {prim::kPrimMinimumGrad->name(), {{{0, 2}, {1, 0}, {2, 1}}, {{2, 0}, {0, 1}, {1, 2}}}},
  {prim::kPrimMaximumGrad->name(), {{{0, 2}, {1, 0}, {2, 1}}, {{2, 0}, {0, 1}, {1, 2}}}},
  {prim::kPrimApplyCenteredRMSProp->name(),
   {{{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 4}},
    {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {5, 4}, {6, 5}, {7, 6}, {8, 7}, {4, 8}}}}};
}  // namespace
170 } // namespace
171
MakeMonadValueNode(const KernelGraphPtr & kg)172 AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) {
173 return kg->NewValueNode(kUMonad->ToAbstract(), kUMonad);
174 }
175
// Force an execution-order dependency so that `former` runs before `latter`.
// Convert: a = former(xxx)
//          b = latter(x, xxx)
// To:      a = former(xxx)
//          d1 = Depend(x, a)
//          b = latter(d1, xxx)
//          ...
//          out = Depend(out, latter)
void AnfRuntimeAlgorithm::KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter) {
  MS_EXCEPTION_IF_NULL(kg);
  MS_EXCEPTION_IF_NULL(latter);
  if (latter->isa<CNode>()) {
    auto latter_cnode = latter->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(latter_cnode);
    constexpr size_t inputsize = 2;
    constexpr size_t kFirstDataInputIndex = 1;
    // `latter` must have at least one data input to hang the Depend on.
    if (latter_cnode->inputs().size() < inputsize) {
      return;
    }
    // depend1 makes `latter`'s first data input depend on `former`.
    auto latter_input = latter_cnode->input(kFirstDataInputIndex);
    auto depend1 = kg->NewCNode({NewValueNode(prim::kPrimDepend), latter_input, former});
    MS_EXCEPTION_IF_NULL(depend1);
    depend1->set_abstract(latter_input->abstract());
    latter_cnode->set_input(kFirstDataInputIndex, depend1);

    // depend2 makes the graph output depend on `latter`, so `latter` cannot
    // be pruned or reordered past the graph's return.
    auto return_node = kg->get_return();
    MS_EXCEPTION_IF_NULL(return_node);
    auto depend2 = kg->NewCNode(
      {NewValueNode(prim::kPrimDepend), return_node->cast<CNodePtr>()->input(kFirstDataInputIndex), latter});
    MS_EXCEPTION_IF_NULL(depend2);
    depend2->set_abstract(return_node->cast<CNodePtr>()->input(kFirstDataInputIndex)->abstract());
    kg->set_output(depend2);
    MS_LOG(DEBUG) << "former: " << former->DebugString() << ", latter: " << latter->DebugString()
                  << ", depend1: " << depend1->DebugString() << ", depend2: " << depend2->DebugString();
  }
}
211
GetTupleGetItemRealInput(const CNodePtr & tuple_get_item)212 AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
213 MS_EXCEPTION_IF_NULL(tuple_get_item);
214 if (tuple_get_item->size() != kTupleGetItemInputSize) {
215 MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
216 }
217 return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
218 }
219
GetTupleGetItemOutIndex(const CNodePtr & tuple_get_item)220 size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
221 MS_EXCEPTION_IF_NULL(tuple_get_item);
222 if (tuple_get_item->size() != kTupleGetItemInputSize) {
223 MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
224 }
225 auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
226 MS_EXCEPTION_IF_NULL(output_index_value_node);
227 auto value_node = output_index_value_node->cast<ValueNodePtr>();
228 MS_EXCEPTION_IF_NULL(value_node);
229 return LongToSize(GetValue<int64_t>(value_node->value()));
230 }
231
// Trace `anf_node` back through transparent wrapper nodes (MakeTuple,
// TupleGetItem, UpdateState, Depend/Load) to the node that actually produces
// the value, returning that node together with its output index.
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
      // MakeTuple: select the element corresponding to `index` and restart.
      if (AnfAlgo::GetInputTensorNum(cnode) == 0) {
        return std::make_pair(nullptr, 0);
      }
      auto node = cnode->input(index + IntToSize(1));
      MS_EXCEPTION_IF_NULL(node);
      return VisitKernel(node, 0);
    } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      // TupleGetItem: follow its real input using the stored element index.
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto item_idx = GetValue<int64_t>(value_node->value());
      return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), LongToSize(item_idx));
    } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
      return VisitKernel(cnode->input(kUpdateStateRealInput), 0);
    } else if (IsOneOfPrimitive(input0, follow_first_input_prims)) {
      // Depend/Load: the value flows through the first data input.
      return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
    } else {
      // A real kernel node: stop here.
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
271
// Like VisitKernel, but stops early at any node whose primitive is listed in
// `return_types`, and optionally skips nop nodes (forwarding through their
// single real input) when `skip_nop_node` is set.
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
                                                               bool skip_nop_node,
                                                               const std::vector<PrimitivePtr> &return_types) {
  MS_EXCEPTION_IF_NULL(anf_node);
  // Caller asked for this primitive type to be returned as-is.
  if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
        return CheckPrimitiveType(anf_node, prim_type);
      })) {
    return KernelWithIndex(anf_node, index);
  }
  // Non-cnodes (parameters, value nodes) are terminal at output 0.
  if (!anf_node->isa<CNode>()) {
    return KernelWithIndex(anf_node, 0);
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
    auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode),
                                                         GetTupleGetItemOutIndex(cnode), skip_nop_node, return_types);
    // If resolution stopped at a MakeTuple (because it is in return_types),
    // select the element this TupleGetItem refers to and resolve again.
    if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
      MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
      auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(make_tuple);
      const std::vector<AnfNodePtr> &make_tuple_inputs = make_tuple->inputs();
      size_t make_tuple_input_index = item_with_index_tmp.second + 1;  // +1 skips the primitive input
      if (make_tuple_input_index >= make_tuple_inputs.size()) {
        MS_LOG(EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size()
                          << "].";
      }
      return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, skip_nop_node, return_types);
    }
    return item_with_index_tmp;
  }
  if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
    return VisitKernelWithReturnType(cnode->input(kUpdateStateStateInput), index, skip_nop_node, return_types);
  }
  if (AnfAlgo::IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
    return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, skip_nop_node, return_types);
  }
  if (opt::IsNopNode(cnode) && skip_nop_node) {
    // A nop node must be exactly (primitive, real input); forward through it.
    if (cnode->size() != kNopNodeInputSize) {
      MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString() << " trace: " << trace::DumpSourceLines(cnode);
    }
    return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, skip_nop_node, return_types);
  }
  return KernelWithIndex(anf_node, index);
}
317
// Collect all real output nodes reachable from `node`, flattening MakeTuple
// outputs recursively. Nodes whose primitive is in `return_types` are
// returned without further unwrapping.
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
                                                          const std::vector<PrimitivePtr> &return_types) {
  std::vector<AnfNodePtr> ret;
  auto return_prim_type = return_types;
  // if visited make_tuple should return back
  return_prim_type.push_back(prim::kPrimMakeTuple);
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    auto make_tuple = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    // Recurse into each MakeTuple element (inputs start at 1; input 0 is the primitive).
    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
      auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types);
      (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
    }
    return ret;
  }
  ret.push_back(item_with_index.first);
  return ret;
}
338
// Collect every (node, output index) pair produced by `node`, expanding
// MakeTuple/Depend wrappers and value tuples. Returns an empty vector when the
// output turns out to be a front call node or an InitDataSetQueue node.
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) {
  std::vector<KernelWithIndex> ret;
  std::vector<KernelWithIndex> ret_empty;

  // The makeTuple node need expand and recurse.
  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
    auto make_tuple = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
      auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i));
      (void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
    }
    return ret;
  }

  // The depend node need get the real node.
  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
    auto depend_node = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend_node);
    auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend));
    (void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
    return ret;
  }

  // Value node need get all the elements.
  if (node->isa<ValueNode>()) {
    auto value = node->cast<ValueNodePtr>()->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<None>()) {
      // None produces no outputs.
      return ret;
    } else if (value->isa<ValueTuple>()) {
      // A value tuple contributes one output per contained value.
      auto value_tuple = value->cast<ValueTuplePtr>();
      auto value_tuple_size = CountValueNum(value_tuple);
      for (size_t i = 0; i < value_tuple_size; ++i) {
        (void)ret.emplace_back(node, i);
      }
    } else {
      (void)ret.emplace_back(node, 0);
    }
    return ret;
  }

  const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimMakeTuple};
  // Only real kernels report a meaningful output count; everything else is
  // treated as producing a single output.
  size_t outputs_num = 1;
  if (IsRealCNodeKernel(node)) {
    outputs_num = AnfAlgo::GetOutputTensorNum(node);
  }
  // The output may be the tuple of node, so need visit all the outputs of node.
  for (size_t i = 0; i < outputs_num; ++i) {
    auto output_with_index = AnfAlgo::VisitKernelWithReturnType(node, i, false, return_types);
    MS_EXCEPTION_IF_NULL(output_with_index.first);

    // The depend and makeTuple node need recurse.
    if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimDepend) ||
        AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
      auto output_vector = GetAllOutputWithIndex(output_with_index.first);
      (void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
      continue;
    }

    // Ignore the output of front call node (a cnode whose inputs[0] is itself
    // a cnode, i.e. a call through a computed function).
    if (output_with_index.first->isa<CNode>()) {
      auto cnode = output_with_index.first->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      auto inputs = cnode->inputs();
      if (inputs[0]->isa<CNode>()) {
        MS_LOG(INFO) << "The output is call node: " << output_with_index.first->DebugString();
        return ret_empty;
      }
    }

    // The InitDataSetQueue node has no output.
    if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimInitDataSetQueue)) {
      return ret_empty;
    }

    MS_LOG(INFO) << "Output node: " << output_with_index.first->fullname_with_scope()
                 << " with output index: " << output_with_index.second;
    ret.push_back(output_with_index);
  }

  return ret;
}
422
GetCNodePrimitiveNode(const CNodePtr & node)423 AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) {
424 MS_EXCEPTION_IF_NULL(node);
425 return node->input(kAnfPrimitiveIndex);
426 }
427
GetCNodePrimitive(const AnfNodePtr & node)428 PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) {
429 MS_EXCEPTION_IF_NULL(node);
430 auto cnode = node->cast<CNodePtr>();
431 MS_EXCEPTION_IF_NULL(cnode);
432 auto attr_input = GetCNodePrimitiveNode(cnode);
433 MS_EXCEPTION_IF_NULL(attr_input);
434 auto value_node = attr_input->cast<ValueNodePtr>();
435 MS_EXCEPTION_IF_NULL(value_node);
436 auto value = value_node->value();
437 MS_EXCEPTION_IF_NULL(value);
438 auto primitive = value->cast<PrimitivePtr>();
439 return primitive;
440 }
441
CheckPrimitiveType(const AnfNodePtr & node,const PrimitivePtr & primitive_type)442 bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
443 MS_EXCEPTION_IF_NULL(node);
444 if (!node->isa<CNode>()) {
445 return false;
446 }
447 auto cnode = node->cast<CNodePtr>();
448 MS_EXCEPTION_IF_NULL(cnode);
449 return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
450 }
451
GetCNodeFuncGraphPtr(const AnfNodePtr & node)452 FuncGraphPtr AnfRuntimeAlgorithm::GetCNodeFuncGraphPtr(const AnfNodePtr &node) {
453 MS_EXCEPTION_IF_NULL(node);
454 auto cnode = node->cast<CNodePtr>();
455 MS_EXCEPTION_IF_NULL(cnode);
456 auto attr_input = cnode->input(kAnfPrimitiveIndex);
457 MS_EXCEPTION_IF_NULL(attr_input);
458 auto value_node = attr_input->cast<ValueNodePtr>();
459 MS_EXCEPTION_IF_NULL(value_node);
460 auto value = value_node->value();
461 MS_EXCEPTION_IF_NULL(value);
462 return value->cast<FuncGraphPtr>();
463 }
464
GetCNodeName(const AnfNodePtr & node)465 std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
466 MS_EXCEPTION_IF_NULL(node);
467 if (node->isa<CNode>()) {
468 auto primitive = AnfAlgo::GetCNodePrimitive(node);
469 if (primitive != nullptr) {
470 return primitive->name();
471 }
472 auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
473 MS_EXCEPTION_IF_NULL(func_graph);
474 if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
475 std::string fg_name = "GraphKernel_";
476 fg_name += GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
477 return fg_name;
478 }
479 return func_graph->ToString();
480 }
481 MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
482 }
483
GetNodeDebugString(const AnfNodePtr & node)484 std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
485 MS_EXCEPTION_IF_NULL(node);
486 return node->DebugString();
487 }
488
SetNodeAttr(const std::string & key,const ValuePtr & value,const AnfNodePtr & node)489 void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
490 MS_EXCEPTION_IF_NULL(node);
491 if (!node->isa<CNode>()) {
492 MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString()
493 << " trace: " << trace::DumpSourceLines(node);
494 }
495 // single op cnode.
496 auto primitive = AnfAlgo::GetCNodePrimitive(node);
497 if (primitive != nullptr) {
498 primitive->set_attr(key, value);
499 return;
500 }
501 // graph kernel cnode.
502 auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
503 MS_EXCEPTION_IF_NULL(fg);
504 fg->set_attr(key, value);
505 }
506
// Copy attribute `key` from `from`'s primitive to `to`'s primitive under the
// same name; delegates to the four-argument overload.
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
  CopyNodeAttr(key, key, from, to);
}
510
CopyNodeAttr(const std::string & old_key,const std::string & new_key,const AnfNodePtr & from,const AnfNodePtr & to)511 void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
512 const AnfNodePtr &to) {
513 MS_EXCEPTION_IF_NULL(from);
514 MS_EXCEPTION_IF_NULL(to);
515 if (!from->isa<CNode>() || !to->isa<CNode>()) {
516 MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
517 << to->DebugString() << " trace: " << trace::DumpSourceLines(from);
518 }
519 auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
520 MS_EXCEPTION_IF_NULL(from_primitive);
521 auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
522 MS_EXCEPTION_IF_NULL(to_primitive);
523 to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key));
524 }
525
CopyNodeAttrs(const AnfNodePtr & from,const AnfNodePtr & to)526 void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
527 MS_EXCEPTION_IF_NULL(from);
528 MS_EXCEPTION_IF_NULL(to);
529 if (!from->isa<CNode>() || !to->isa<CNode>()) {
530 MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
531 << from->DebugString() << " trace: " << trace::DumpSourceLines(from);
532 }
533 auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
534 MS_EXCEPTION_IF_NULL(from_primitive);
535 auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
536 MS_EXCEPTION_IF_NULL(to_primitive);
537 (void)to_primitive->SetAttrs(from_primitive->attrs());
538 }
539
EraseNodeAttr(const std::string & key,const AnfNodePtr node)540 void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) {
541 MS_EXCEPTION_IF_NULL(node);
542 if (!node->isa<CNode>()) {
543 MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString()
544 << " trace: " << trace::DumpSourceLines(node);
545 }
546 // single op cnode.
547 auto primitive = AnfAlgo::GetCNodePrimitive(node);
548 if (primitive != nullptr) {
549 primitive->EraseAttr(key);
550 return;
551 }
552 // graph kernel cnode.
553 auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
554 MS_EXCEPTION_IF_NULL(fg);
555 fg->erase_flag(key);
556 }
557
HasNodeAttr(const std::string & key,const CNodePtr & node)558 bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) {
559 MS_EXCEPTION_IF_NULL(node);
560 if (!node->isa<CNode>()) {
561 MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString();
562 return false;
563 }
564 // single op cnode.
565 auto primitive = AnfAlgo::GetCNodePrimitive(node);
566 if (primitive != nullptr) {
567 return primitive->HasAttr(key);
568 }
569 // graph kernel cnode.
570 auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
571 MS_EXCEPTION_IF_NULL(fg);
572 return fg->has_attr(key);
573 }
574
GetInputNum(const CNodePtr & cnode)575 size_t AnfRuntimeAlgorithm::GetInputNum(const CNodePtr &cnode) {
576 MS_EXCEPTION_IF_NULL(cnode);
577 size_t input_num = cnode->size();
578 if (input_num == 0) {
579 MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero";
580 }
581 return input_num - 1;
582 }
583
// Return the number of real tensor inputs of cnode `node`: total inputs minus
// the primitive slot and, for real kernels, minus any trailing monad inputs.
// The result is cached on the cnode via set_input_tensor_num.
size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString()
                      << " trace: " << trace::DumpSourceLines(node);
  }
  // Fast path: a cached value (>= 0) means this was computed before.
  ssize_t input_tensor_num = cnode->input_tensor_num();
  if (input_tensor_num >= 0) {
    return static_cast<size_t>(input_tensor_num);
  }
  size_t input_num = cnode->inputs().size();
  if (input_num == 0) {
    MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  // Exclude inputs[0].
  --input_num;

  // Exclude monad inputs for real cnodes.
  if (input_num > 0 && IsRealKernelCNode(cnode)) {
    auto &inputs = cnode->inputs();
    // Search monad inputs, backward: monads always sit at the tail of the
    // input list, so counting stops at the first non-monad.
    for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
      if (!HasAbstractMonad(*iter)) {
        // Stop count if we encounter a non-monad input.
        break;
      }
      --input_num;
    }
  }
  // Cache the result for subsequent calls.
  cnode->set_input_tensor_num(static_cast<ssize_t>(input_num));
  return input_num;
}
618
GetOutputTensorNum(const AnfNodePtr & node)619 size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
620 MS_EXCEPTION_IF_NULL(node);
621 TypePtr type = node->Type();
622 if (type == nullptr) {
623 return 0;
624 }
625 if (type->isa<Tuple>()) {
626 auto tuple_type = type->cast<TuplePtr>();
627 MS_EXCEPTION_IF_NULL(tuple_type);
628 return tuple_type->size();
629 }
630 if (type->isa<TypeNone>()) {
631 return 0;
632 }
633 return 1;
634 }
635
// Compute the device memory size in bytes of output `output_index` of `node`:
// element-type size times the product of the device shape dimensions.
// A scalar (empty shape) yields just the type size.
size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
    MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
                                << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
  }
  // Prefer the selected device dtype; fall back to the inferred dtype when no
  // kernel has been selected yet.
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  if (output_type_id == kTypeUnknown) {
    output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  }
  size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
  std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
  auto format = AnfAlgo::GetOutputFormat(node, output_index);
  // NOTE(review): only an *empty* shape is padded and transformed here --
  // looks intentional for scalar outputs on non-default formats, but confirm
  // this condition was not meant to be `!shape.empty()`.
  if (shape.empty() && format != kOpFormat_DEFAULT) {
    shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index));
    shape = trans::TransShapeToDevice(shape, format, node, output_index);
  }
  // scalar's output shape is a empty vector
  size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
  return tensor_size;
}
657
GetAllOutputFormats(const AnfNodePtr & node)658 std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) {
659 MS_EXCEPTION_IF_NULL(node);
660 if (!AnfAlgo::IsRealKernel(node)) {
661 MS_LOG(EXCEPTION) << "Not real kernel:"
662 << "#node [" << node->DebugString() << "]"
663 << " trace: " << trace::DumpSourceLines(node);
664 }
665 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
666 MS_EXCEPTION_IF_NULL(kernel_info);
667 auto build_info = kernel_info->select_kernel_build_info();
668 MS_EXCEPTION_IF_NULL(build_info);
669 auto format = build_info->GetAllOutputFormats();
670 return format;
671 }
672
GetAllInputFormats(const AnfNodePtr & node)673 std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePtr &node) {
674 MS_EXCEPTION_IF_NULL(node);
675 if (!AnfAlgo::IsRealKernel(node)) {
676 MS_LOG(EXCEPTION) << "Not real kernel:"
677 << "#node [" << node->DebugString() << "]"
678 << " trace: " << trace::DumpSourceLines(node);
679 }
680 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
681 MS_EXCEPTION_IF_NULL(kernel_info);
682 auto build_info = kernel_info->select_kernel_build_info();
683 MS_EXCEPTION_IF_NULL(build_info);
684 auto format = build_info->GetAllInputFormats();
685 return format;
686 }
687
GetAllInputDeviceTypes(const AnfNodePtr & node)688 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllInputDeviceTypes(const AnfNodePtr &node) {
689 MS_EXCEPTION_IF_NULL(node);
690 if (!AnfAlgo::IsRealKernel(node)) {
691 MS_LOG(EXCEPTION) << "Not real kernel:"
692 << "#node [" << node->DebugString() << "]"
693 << " trace: " << trace::DumpSourceLines(node);
694 }
695 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
696 MS_EXCEPTION_IF_NULL(kernel_info);
697 auto build_info = kernel_info->select_kernel_build_info();
698 MS_EXCEPTION_IF_NULL(build_info);
699 auto types = build_info->GetAllInputDeviceTypes();
700 return types;
701 }
702
GetAllOutputDeviceTypes(const AnfNodePtr & node)703 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputDeviceTypes(const AnfNodePtr &node) {
704 MS_EXCEPTION_IF_NULL(node);
705 if (!AnfAlgo::IsRealKernel(node)) {
706 MS_LOG(EXCEPTION) << "Not real kernel:"
707 << "#node [" << node->DebugString() << "]"
708 << " trace: " << trace::DumpSourceLines(node);
709 }
710 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
711 MS_EXCEPTION_IF_NULL(kernel_info);
712 auto build_info = kernel_info->select_kernel_build_info();
713 MS_EXCEPTION_IF_NULL(build_info);
714 auto types = build_info->GetAllOutputDeviceTypes();
715 return types;
716 }
717
GetOriginDataFormat(const AnfNodePtr & node)718 std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
719 MS_EXCEPTION_IF_NULL(node);
720 if (!AnfAlgo::IsRealKernel(node)) {
721 MS_LOG(EXCEPTION) << "Not real kernel:"
722 << "#node [" << node->DebugString() << "]"
723 << " trace: " << trace::DumpSourceLines(node);
724 }
725 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
726 MS_EXCEPTION_IF_NULL(kernel_info);
727 auto build_info = kernel_info->select_kernel_build_info();
728 MS_EXCEPTION_IF_NULL(build_info);
729 auto format = build_info->GetOriginDataFormat();
730 return format;
731 }
732
// Return the device format of output `output_idx` of `node`.
// Virtual (non-real) kernels carry no build info of their own, so the query is
// forwarded to the real node that produces the value.
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): '>' lets output_idx == GetOutputTensorNum(node) pass this guard;
  // looks like an off-by-one, but confirm no caller relies on it before tightening to '>='.
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Output index:" << output_idx
                      << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
                      << node->DebugString() << "]"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  if (!AnfAlgo::IsRealKernel(node)) {
    // Pass-through node: ask the producer of the corresponding input.
    return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetOutputFormat(output_idx);
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid output format"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  return format;
}
756
// Return the device format of input `input_idx` of `node`.
// Non-real kernels forward the query to the real node producing that input.
std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): '>' lets input_idx == GetInputTensorNum(node) pass this guard — confirm intent.
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Input index :" << input_idx
                      << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
                      << node->DebugString() << "]"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputFormat(node, input_idx);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetInputFormat(input_idx);
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid input format"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  return format;
}
780
// Resolve input `input_idx` of `anf_node` to the real (kernel, output-index) pair that
// produces it, looking through virtual wrappers; when `visit_nop_node` is true, nop
// nodes are looked through as well.
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx,
                                                       bool visit_nop_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."
                      << " trace: " << trace::DumpSourceLines(anf_node);
  }
  // A TupleGetItem is visited directly; the visitor decodes its item index.
  // NOTE(review): input_idx is ignored on this path — presumably intentional, verify callers.
  if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
    return VisitKernelWithReturnType(anf_node, 0, visit_nop_node);
  }
  auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
  MS_EXCEPTION_IF_NULL(input_node);
  return VisitKernelWithReturnType(input_node, 0, visit_nop_node);
}
795
GetPrevNodeOutputFormat(const AnfNodePtr & anf_node,size_t input_idx)796 std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
797 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
798 return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
799 }
800
GetPrevNodeOutputReshapeType(const AnfNodePtr & node,size_t input_idx)801 std::string AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
802 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
803 return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
804 }
805
// Extract the inferred shape of output `output_idx` from `base_shape` (the node's
// abstract shape). Handles three cases: a plain Shape (single output), a TupleShape
// (multi-output), and NoShape (scalar — returned as an empty vector).
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node,
                                                             const abstract::BaseShapePtr &base_shape,
                                                             size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>()) {
    // Single-output node: only index 0 is valid.
    if (output_idx == 0) {
      return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
    }
    MS_LOG(EXCEPTION) << "The node " << node->DebugString() << "is a single output node but got index [" << output_idx
                      << "."
                      << " trace: " << trace::DumpSourceLines(node);
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
                        << " node:" << node->DebugString() << "."
                        << " trace: " << trace::DumpSourceLines(node);
    }
    auto b_shp = (*tuple_shape)[output_idx];
    if (b_shp->isa<abstract::Shape>()) {
      return TransShapeToSizet(b_shp->cast<abstract::ShapePtr>());
    } else if (b_shp->isa<abstract::NoShape>()) {
      // Scalar element inside the tuple.
      return std::vector<size_t>();
    } else {
      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
                        << "node :" << node->DebugString() << "."
                        << " trace: " << trace::DumpSourceLines(node);
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    // Scalar node: empty shape.
    return std::vector<size_t>();
  }
  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString() << " node : " << node->DebugString()
                    << " trace: " << trace::DumpSourceLines(node);
}
844
GetOutputInferShape(const AnfNodePtr & node,size_t output_idx)845 std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
846 MS_EXCEPTION_IF_NULL(node);
847 return GetOutputInferShape(node, node->Shape(), output_idx);
848 }
849
GetPrevNodeOutputInferShape(const AnfNodePtr & node,size_t input_idx)850 std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
851 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
852 return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
853 }
854
GetOutputDeviceShapeForTbeBuild(const AnfNodePtr & node,const size_t output_idx,const std::string & format)855 std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node,
856 const size_t output_idx,
857 const std::string &format) {
858 auto output_shape = GetOutputDetailShape(node, output_idx);
859 std::vector<int64_t> infer_shape;
860 if (output_shape->isa<abstract::Shape>()) {
861 auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
862 MS_EXCEPTION_IF_NULL(shape_ptr);
863 infer_shape = shape_ptr->shape();
864 }
865 if (infer_shape.empty()) {
866 return infer_shape;
867 }
868
869 // if format is default_format or NC1KHKWHWC0,device shape = original shape
870 if (trans::IsNeedPadding(format, infer_shape.size())) {
871 infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx));
872 }
873 return trans::TransShapeToDevice(infer_shape, format, node, output_idx);
874 }
875
GetOutputDeviceShape(const AnfNodePtr & node,size_t output_idx)876 std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
877 auto format = GetOutputFormat(node, output_idx);
878 auto infer_shape = GetOutputInferShape(node, output_idx);
879 if (infer_shape.empty()) {
880 return infer_shape;
881 }
882 // if format is default_format or NC1KHKWHWC0,device shape = original shape
883 if (trans::IsNeedPadding(format, infer_shape.size())) {
884 infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx));
885 }
886 return trans::TransShapeToDevice(infer_shape, format, node, output_idx);
887 }
888
GetInputDeviceShape(const AnfNodePtr & node,size_t input_idx)889 std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
890 auto format = GetInputFormat(node, input_idx);
891 auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
892 if (infer_shape.empty()) {
893 return infer_shape;
894 }
895 // if format is default_format or NC1KHKWHWC0,device shape = original shape
896 if (trans::IsNeedPadding(format, infer_shape.size())) {
897 infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx));
898 }
899 return trans::TransShapeToDevice(infer_shape, format, node, input_idx, false);
900 }
901
// Return the reshape-type hint of input `input_idx`; an empty string means the
// default padding rule applies. Non-real kernels forward to the producing node.
std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): '>' lets input_idx == GetInputTensorNum(node) pass this guard — confirm intent.
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index:" << input_idx
                      << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
                      << node->DebugString() << "]"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputReshapeType(node, input_idx);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  // Default padding ⇒ no per-input reshape type is recorded.
  if (build_info->IsInputDefaultPadding()) {
    return "";
  }
  return build_info->GetInputReshapeType(input_idx);
}
922
// Return the reshape-type hint of output `output_idx`; an empty string means the
// default padding rule applies. Non-real kernels forward to the producing node.
std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): '>' lets output_idx == GetOutputTensorNum(node) pass this guard — confirm intent.
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputReshapeType(node, output_idx);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  // Default padding ⇒ no per-output reshape type is recorded.
  if (build_info->IsOutputDefaultPadding()) {
    return "";
  }
  return build_info->GetOutputReshapeType(output_idx);
}
942
GetOutputInferDataType(const TypePtr & type,size_t output_idx)943 TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
944 auto type_ptr = type;
945 MS_EXCEPTION_IF_NULL(type_ptr);
946 if (type_ptr->isa<Tuple>()) {
947 auto tuple_ptr = type_ptr->cast<TuplePtr>();
948 MS_EXCEPTION_IF_NULL(tuple_ptr);
949 if (output_idx >= tuple_ptr->size()) {
950 MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
951 }
952 type_ptr = (*tuple_ptr)[output_idx];
953 MS_EXCEPTION_IF_NULL(type_ptr);
954 }
955
956 if (type_ptr->isa<TensorType>()) {
957 auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
958 MS_EXCEPTION_IF_NULL(tensor_ptr);
959 TypePtr elem = tensor_ptr->element();
960 MS_EXCEPTION_IF_NULL(elem);
961 return elem->type_id();
962 }
963
964 return type_ptr->type_id();
965 }
966
GetOutputInferDataType(const AnfNodePtr & node,size_t output_idx)967 TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
968 MS_EXCEPTION_IF_NULL(node);
969 return GetOutputInferDataType(node->Type(), output_idx);
970 }
971
GetPrevNodeOutputInferDataType(const AnfNodePtr & node,size_t input_idx)972 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
973 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
974 return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
975 }
976
// Return the device dtype of output `output_idx` of `node` from the selected kernel
// build info; non-real kernels forward to the producing node.
TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): '>' lets output_idx == GetOutputTensorNum(node) pass this guard — confirm intent.
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputDeviceDataType(node, output_idx);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto dtype = build_info->GetOutputDeviceType(output_idx);
  // kNumberTypeEnd is the sentinel for "no dtype recorded".
  if (dtype == TypeId::kNumberTypeEnd) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid dtype"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  return dtype;
}
999
GetInputDeviceDataType(const AnfNodePtr & node,size_t input_idx)1000 TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
1001 MS_EXCEPTION_IF_NULL(node);
1002 if (input_idx > GetInputTensorNum(node)) {
1003 MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
1004 << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"
1005 << " trace: " << trace::DumpSourceLines(node);
1006 }
1007 if (!IsRealKernel(node)) {
1008 return GetPrevNodeOutputDeviceDataType(node, 0);
1009 }
1010 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1011 MS_EXCEPTION_IF_NULL(kernel_info);
1012 auto build_info = kernel_info->select_kernel_build_info();
1013 MS_EXCEPTION_IF_NULL(build_info);
1014 auto dtype = build_info->GetInputDeviceType(input_idx);
1015 if (dtype == TypeId::kNumberTypeEnd) {
1016 MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
1017 << " has a invalid dtype"
1018 << " trace: " << trace::DumpSourceLines(node);
1019 }
1020 return dtype;
1021 }
1022
GetPrevNodeOutputDeviceDataType(const AnfNodePtr & anf_node,size_t input_idx)1023 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
1024 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
1025 return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
1026 }
1027
1028 // get output device addr of anf_node
// Return the (immutable) device address of output `output_idx` of `node`.
// When `visit_nop_node` is set, a nop node forwards to the address of its single
// real input instead of holding one itself.
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx,
                                                        bool visit_nop_node) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node) && visit_nop_node) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // A well-formed nop node has exactly primitive + one data input.
    if (cnode->size() == kNopNodeInputSize) {
      return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"
                        << " trace: " << trace::DumpSourceLines(node);
    }
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  return addr;
}
1052
GetMutableOutputAddr(const AnfNodePtr & node,size_t output_idx,bool visit_nop_node)1053 DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx,
1054 bool visit_nop_node) {
1055 MS_EXCEPTION_IF_NULL(node);
1056 if (opt::IsNopNode(node) && visit_nop_node) {
1057 auto cnode = node->cast<CNodePtr>();
1058 MS_EXCEPTION_IF_NULL(cnode);
1059 if (cnode->inputs().size() == kNopNodeInputSize) {
1060 return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
1061 } else {
1062 MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."
1063 << " trace: " << trace::DumpSourceLines(node);
1064 }
1065 }
1066 // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
1067 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1068 MS_EXCEPTION_IF_NULL(kernel_info);
1069 auto addr = kernel_info->GetMutableOutputAddr(output_idx);
1070 if (addr == nullptr) {
1071 MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString() << " output addr is not exist"
1072 << " trace: " << trace::DumpSourceLines(node);
1073 }
1074 return addr;
1075 }
1076
1077 // get output device addr of anf_node
OutputAddrExist(const AnfNodePtr & node,size_t output_idx,bool visit_nop_node)1078 bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node) {
1079 MS_EXCEPTION_IF_NULL(node);
1080 if (opt::IsNopNode(node) && visit_nop_node) {
1081 auto cnode = node->cast<CNodePtr>();
1082 MS_EXCEPTION_IF_NULL(cnode);
1083 if (cnode->inputs().size() > 1) {
1084 auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(cnode, 0);
1085 return OutputAddrExist(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
1086 }
1087 return false;
1088 }
1089 // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
1090 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1091 MS_EXCEPTION_IF_NULL(kernel_info);
1092 return kernel_info->OutputAddrExist(output_idx);
1093 }
1094
WorkspaceAddrExist(const AnfNodePtr & node,size_t output_idx)1095 bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx) {
1096 MS_EXCEPTION_IF_NULL(node);
1097 // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
1098 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1099 MS_EXCEPTION_IF_NULL(kernel_info);
1100 return kernel_info->WorkspaceAddrExist(output_idx);
1101 }
1102
GetPrevNodeOutputAddr(const AnfNodePtr & anf_node,size_t input_idx,bool visit_nop_node)1103 const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
1104 bool visit_nop_node) {
1105 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
1106 return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
1107 }
1108
GetPrevNodeMutableOutputAddr(const AnfNodePtr & anf_node,size_t input_idx,bool visit_nop_node)1109 DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
1110 bool visit_nop_node) {
1111 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
1112 return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
1113 }
1114
1115 // set output device addr of anf_node
SetOutputAddr(const DeviceAddressPtr & addr,size_t output_idx,AnfNode * node)1116 void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
1117 MS_EXCEPTION_IF_NULL(node);
1118 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1119 MS_EXCEPTION_IF_NULL(kernel_info);
1120 if (!kernel_info->SetOutputAddr(addr, output_idx)) {
1121 MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail."
1122 << " trace: " << trace::DumpSourceLines(node);
1123 }
1124 }
1125
1126 // set workspace device addr of anf_node
SetWorkspaceAddr(const DeviceAddressPtr & addr,size_t output_idx,AnfNode * node)1127 void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
1128 MS_EXCEPTION_IF_NULL(node);
1129 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1130 MS_EXCEPTION_IF_NULL(kernel_info);
1131 if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
1132 MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail。"
1133 << " trace: " << trace::DumpSourceLines(node);
1134 }
1135 }
1136
1137 // get workspace device addr of anf_node
GetWorkspaceAddr(const AnfNodePtr & node,size_t output_idx)1138 DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
1139 MS_EXCEPTION_IF_NULL(node);
1140 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1141 MS_EXCEPTION_IF_NULL(kernel_info);
1142 auto addr = kernel_info->GetWorkspaceAddr(output_idx);
1143 if (addr == nullptr) {
1144 MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
1145 << "] workspace addr is not exist"
1146 << " trace: " << trace::DumpSourceLines(node);
1147 }
1148 return addr;
1149 }
1150
1151 // get workspace device mutable addr of anf_node
GetMutableWorkspaceAddr(const AnfNodePtr & node,size_t index)1152 DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) {
1153 MS_EXCEPTION_IF_NULL(node);
1154 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1155 MS_EXCEPTION_IF_NULL(kernel_info);
1156 auto addr = kernel_info->GetMutableWorkspaceAddr(index);
1157 if (addr == nullptr) {
1158 MS_LOG(EXCEPTION) << "Index " << index << " of node " << node->DebugString() << "] workspace addr is not exist"
1159 << " trace: " << trace::DumpSourceLines(node);
1160 }
1161 return addr;
1162 }
1163
// Return the detailed abstract shape (keeping min/max shape info for dynamic
// shapes) of output `output_idx`, instead of the flattened size_t vector that
// GetOutputInferShape produces.
abstract::BaseShapePtr AnfRuntimeAlgorithm::GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>()) {
    // Single-output node: only index 0 is valid.
    if (output_idx == 0) {
      return base_shape;
    }
    MS_LOG(EXCEPTION) << "The node " << node->DebugString() << "is a single output node but got index [" << output_idx
                      << "."
                      << " trace: " << trace::DumpSourceLines(node);
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
                        << " node:" << node->DebugString() << "."
                        << " trace: " << trace::DumpSourceLines(node);
    }
    auto b_shp = (*tuple_shape)[output_idx];
    if (b_shp->isa<abstract::Shape>() || b_shp->isa<abstract::NoShape>()) {
      return b_shp;
    } else {
      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
                        << "node :" << node->DebugString() << "."
                        << " trace: " << trace::DumpSourceLines(node);
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    // Scalar node.
    return base_shape;
  }
  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString() << " node : " << node->DebugString()
                    << " trace: " << trace::DumpSourceLines(node);
}
1199
GetPrevNodeOutputDetailShape(const AnfNodePtr & node,size_t input_idx)1200 abstract::BaseShapePtr AnfRuntimeAlgorithm::GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx) {
1201 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
1202 return AnfRuntimeAlgorithm::GetOutputDetailShape(kernel_with_index.first, kernel_with_index.second);
1203 }
1204
1205 // set infer shapes and types of anf node
// Rebuild `node`'s abstract from parallel lists of output dtypes and detailed shapes.
// Zero shapes -> AbstractNone; one shape -> a single AbstractTensor; otherwise (or
// for MakeTuple) an AbstractTuple of tensors. `types` and `shapes` must be the same length.
void AnfRuntimeAlgorithm::SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
                                                      const std::vector<abstract::BaseShapePtr> &shapes,
                                                      AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto node_ptr = node->cast<AnfNodePtr>();
  MS_EXCEPTION_IF_NULL(node_ptr);
  std::string node_name = "";
  if (node_ptr->isa<CNode>()) {
    node_name = GetCNodeName(node_ptr);
  }
  if (types.size() != shapes.size()) {
    MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size()
                      << " trace: " << trace::DumpSourceLines(node);
  }
  // MakeTuple always gets a tuple abstract, even for 0 or 1 element.
  if (shapes.empty() && node_name != prim::kPrimMakeTuple->name()) {
    node->set_abstract(std::make_shared<abstract::AbstractNone>());
  } else if (shapes.size() == 1 && node_name != prim::kPrimMakeTuple->name()) {
    // single output handle
    auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shapes[0]);
    node->set_abstract(abstract);
  } else {
    // multiple output handle
    std::vector<AbstractBasePtr> abstract_list;
    for (size_t i = 0; i < types.size(); ++i) {
      auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shapes[i]);
      abstract_list.emplace_back(abstract);
    }
    auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
    node->set_abstract(abstract_tuple);
  }
}
1237
1238 // set infer shapes and types of anf node
SetOutputInferTypeAndShape(const std::vector<TypeId> & types,const std::vector<std::vector<size_t>> & shapes,AnfNode * node)1239 void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
1240 const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
1241 MS_EXCEPTION_IF_NULL(node);
1242 auto node_ptr = node->cast<AnfNodePtr>();
1243 std::string node_name = "";
1244 if (node_ptr->isa<CNode>()) {
1245 node_name = GetCNodeName(node_ptr);
1246 }
1247 MS_EXCEPTION_IF_NULL(node_ptr);
1248 if (types.size() != shapes.size()) {
1249 MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size()
1250 << " trace: " << trace::DumpSourceLines(node);
1251 }
1252 auto abstract_ptr = node_ptr->abstract();
1253 if (shapes.empty() && node_name != prim::kPrimMakeTuple->name()) {
1254 node->set_abstract(std::make_shared<abstract::AbstractNone>());
1255 } else if (shapes.size() == 1 && node_name != prim::kPrimMakeTuple->name()) {
1256 // single output handle
1257 ShapeVector shape_int;
1258 abstract::AbstractTensorPtr abstract = nullptr;
1259 if (abstract_ptr != nullptr) {
1260 auto max_shape0 = GetOutputMaxShape(node_ptr, 0);
1261 auto min_shape0 = GetOutputMinShape(node_ptr, 0);
1262 std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToLong);
1263 abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]),
1264 std::make_shared<abstract::Shape>(shape_int, min_shape0, max_shape0));
1265 } else {
1266 abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
1267 }
1268 node->set_abstract(abstract);
1269 } else {
1270 // multiple output handle
1271 std::vector<AbstractBasePtr> abstract_list;
1272 for (size_t i = 0; i < types.size(); ++i) {
1273 ShapeVector shape_int;
1274 abstract::AbstractTensorPtr abstract = nullptr;
1275 if (abstract_ptr != nullptr) {
1276 auto max_shape = GetOutputMaxShape(node_ptr, i);
1277 auto min_shape = GetOutputMinShape(node_ptr, i);
1278 std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToLong);
1279 abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[i]),
1280 std::make_shared<abstract::Shape>(shape_int, min_shape, max_shape));
1281 } else {
1282 abstract =
1283 std::make_shared<AbstractTensor>(TypeIdToType(types[i]), std::make_shared<abstract::Shape>(shape_int));
1284 }
1285 abstract_list.emplace_back(abstract);
1286 }
1287 auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
1288 node->set_abstract(abstract_tuple);
1289 }
1290 }
1291 // copy an abstract of a node to another node
CopyAbstract(const AnfNodePtr & from_node,AnfNode * to_node)1292 void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
1293 MS_EXCEPTION_IF_NULL(from_node);
1294 MS_EXCEPTION_IF_NULL(to_node);
1295 to_node->set_abstract(from_node->abstract());
1296 }
1297
GetOpPattern(const AnfNodePtr & node)1298 kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
1299 MS_EXCEPTION_IF_NULL(node);
1300 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1301 MS_EXCEPTION_IF_NULL(kernel_info);
1302 // select_kernel_build_info() has checked whether return pointer is null
1303 auto build_info = kernel_info->select_kernel_build_info();
1304 MS_EXCEPTION_IF_NULL(build_info);
1305 return build_info->op_pattern();
1306 }
1307
1308 // get KernelBuildType of node, such as ATT,RT,FWK and so on
GetKernelType(const AnfNodePtr & node)1309 KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
1310 MS_EXCEPTION_IF_NULL(node);
1311 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1312 MS_EXCEPTION_IF_NULL(kernel_info);
1313 // select_kernel_build_info() has checked whether return pointer is null
1314 auto build_info = kernel_info->select_kernel_build_info();
1315 MS_EXCEPTION_IF_NULL(build_info);
1316 return build_info->kernel_type();
1317 }
1318
SetFusionType(const AnfNodePtr & node,const kernel::FusionType & type)1319 void AnfRuntimeAlgorithm::SetFusionType(const AnfNodePtr &node, const kernel::FusionType &type) {
1320 MS_EXCEPTION_IF_NULL(node);
1321 auto builder =
1322 std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
1323 MS_EXCEPTION_IF_NULL(builder);
1324 builder->SetFusionType(type);
1325 AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
1326 }
1327
SetOutputDataDesc(const AnfNodePtr & node,const std::vector<nlohmann::json> & desc)1328 void AnfRuntimeAlgorithm::SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc) {
1329 MS_EXCEPTION_IF_NULL(node);
1330 auto builder =
1331 std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
1332 MS_EXCEPTION_IF_NULL(builder);
1333 builder->SetOutputDataDesc(desc);
1334 AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
1335 }
1336
GetOutputDataDesc(const AnfNodePtr & node)1337 std::vector<nlohmann::json> AnfRuntimeAlgorithm::GetOutputDataDesc(const AnfNodePtr &node) {
1338 MS_EXCEPTION_IF_NULL(node);
1339 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1340 if (kernel_info == nullptr) {
1341 return {};
1342 }
1343 auto build_info = kernel_info->select_kernel_build_info();
1344 if (build_info == nullptr) {
1345 return {};
1346 }
1347 return build_info->output_data_desc();
1348 }
1349
GetProcessor(const AnfNodePtr & node)1350 kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
1351 MS_EXCEPTION_IF_NULL(node);
1352 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1353 MS_EXCEPTION_IF_NULL(kernel_info);
1354 auto build_info = kernel_info->select_kernel_build_info();
1355 MS_EXCEPTION_IF_NULL(build_info);
1356 return build_info->processor();
1357 }
1358
GetFusionType(const AnfNodePtr & node)1359 kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
1360 MS_EXCEPTION_IF_NULL(node);
1361 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1362 MS_EXCEPTION_IF_NULL(kernel_info);
1363 auto build_info = kernel_info->select_kernel_build_info();
1364 if (build_info == nullptr) {
1365 return kernel::FusionType::UNKNOWN_FUSION_TYPE;
1366 }
1367 return build_info->fusion_type();
1368 }
1369
1370 // set select kernel_build_info
SetSelectKernelBuildInfo(const KernelBuildInfoPtr & select_kernel_build_info,AnfNode * node)1371 void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
1372 MS_EXCEPTION_IF_NULL(node);
1373 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1374 MS_EXCEPTION_IF_NULL(kernel_info);
1375 return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
1376 }
1377
1378 // get select kernel_build_info
GetSelectKernelBuildInfo(const AnfNodePtr & node)1379 KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
1380 MS_EXCEPTION_IF_NULL(node);
1381 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1382 MS_EXCEPTION_IF_NULL(kernel_info);
1383 return kernel_info->GetMutableSelectKernelBuildInfo();
1384 }
1385
1386 // get kernelMode
GetKernelMod(const AnfNodePtr & node)1387 KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
1388 MS_EXCEPTION_IF_NULL(node);
1389 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1390 MS_EXCEPTION_IF_NULL(kernel_info);
1391 return kernel_info->MutableKernelMod();
1392 }
1393
1394 // set kernel mod
SetKernelMod(const KernelModPtr & kernel_mod,AnfNode * node)1395 void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
1396 MS_EXCEPTION_IF_NULL(node);
1397 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1398 MS_EXCEPTION_IF_NULL(kernel_info);
1399 kernel_info->set_kernel_mod(kernel_mod);
1400 }
1401
IsRealKernel(const AnfNodePtr & node)1402 bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
1403 MS_EXCEPTION_IF_NULL(node);
1404 // parameter and value node is a real kernel too
1405 if (!node->isa<CNode>()) {
1406 return true;
1407 }
1408 auto cnode = node->cast<CNodePtr>();
1409 MS_EXCEPTION_IF_NULL(cnode);
1410 if (cnode->inputs().empty()) {
1411 MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString()
1412 << " trace: " << trace::DumpSourceLines(node);
1413 }
1414 return IsRealKernelCNode(cnode);
1415 }
1416
IsRealCNodeKernel(const AnfNodePtr & node)1417 bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
1418 MS_EXCEPTION_IF_NULL(node);
1419 // parameter and value node is not a real cnode kernel
1420 if (!node->isa<CNode>()) {
1421 return false;
1422 }
1423 // return considered as a real node
1424 if (CheckPrimitiveType(node, prim::kPrimReturn)) {
1425 return true;
1426 }
1427 return IsRealKernel(node);
1428 }
1429
IsGraphKernel(const AnfNodePtr & node)1430 bool AnfRuntimeAlgorithm::IsGraphKernel(const AnfNodePtr &node) {
1431 MS_EXCEPTION_IF_NULL(node);
1432 // graph kernel should be a real cnode kernel.
1433 if (!IsRealCNodeKernel(node)) {
1434 return false;
1435 }
1436
1437 auto cnode = node->cast<CNodePtr>();
1438 MS_EXCEPTION_IF_NULL(cnode);
1439 auto input = cnode->input(kAnfPrimitiveIndex);
1440 // graph kernel should has func_graph as first input.
1441 if (!IsValueNode<FuncGraph>(input)) {
1442 return false;
1443 }
1444
1445 auto func_graph = GetValueNode<FuncGraphPtr>(input);
1446 MS_EXCEPTION_IF_NULL(func_graph);
1447 return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
1448 }
1449
IsNodeInGraphKernel(const AnfNodePtr & node)1450 bool AnfRuntimeAlgorithm::IsNodeInGraphKernel(const AnfNodePtr &node) {
1451 MS_EXCEPTION_IF_NULL(node);
1452 return node->func_graph() != nullptr && node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
1453 }
1454
GetOutputOfGraphkernel(const KernelWithIndex & kernel_with_index)1455 AnfNodePtr AnfRuntimeAlgorithm::GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index) {
1456 auto func_graph = GetCNodeFuncGraph(kernel_with_index.first);
1457 if (func_graph == nullptr) {
1458 return kernel_with_index.first;
1459 }
1460 auto output = func_graph->output();
1461 if (CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
1462 return output->cast<CNodePtr>()->input(kernel_with_index.second + 1);
1463 }
1464 return output;
1465 }
1466
// A parameter that carries a default value is a weight; parameters without one are graph inputs.
bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->has_default();
}
1471
IsLabelIndexInNode(const AnfNodePtr & node,size_t label_index)1472 bool AnfRuntimeAlgorithm::IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index) {
1473 MS_EXCEPTION_IF_NULL(node);
1474 if (!node->isa<CNode>()) {
1475 return false;
1476 }
1477 auto cnode = node->cast<CNodePtr>();
1478 MS_EXCEPTION_IF_NULL(cnode);
1479 if (AnfAlgo::GetCNodeName(cnode) == kLabelGotoOpName &&
1480 (AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) == label_index)) {
1481 return true;
1482 } else if (AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
1483 auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList);
1484 if (std::find(label_list.begin(), label_list.end(), label_index) != label_list.end()) {
1485 return true;
1486 }
1487 }
1488 return false;
1489 }
1490
SetStreamId(uint32_t stream_id,AnfNode * node)1491 void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
1492 MS_EXCEPTION_IF_NULL(node);
1493 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1494 MS_EXCEPTION_IF_NULL(kernel_info);
1495 kernel_info->set_stream_id(stream_id);
1496 }
1497
GetStreamId(const AnfNodePtr & node)1498 uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
1499 MS_EXCEPTION_IF_NULL(node);
1500 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1501 MS_EXCEPTION_IF_NULL(kernel_info);
1502 return kernel_info->stream_id();
1503 }
1504
SetStreamDistinctionLabel(uint32_t stream_label,AnfNode * node)1505 void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
1506 MS_EXCEPTION_IF_NULL(node);
1507 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1508 MS_EXCEPTION_IF_NULL(kernel_info);
1509 kernel_info->set_stream_distinction_label(stream_label);
1510 }
1511
GetStreamDistinctionLabel(const AnfNode * node)1512 uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
1513 MS_EXCEPTION_IF_NULL(node);
1514 auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1515 MS_EXCEPTION_IF_NULL(kernel_info);
1516 return kernel_info->stream_distinction_label();
1517 }
1518
SetGraphId(uint32_t graph_id,AnfNode * node)1519 void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
1520 MS_EXCEPTION_IF_NULL(node);
1521 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1522 MS_EXCEPTION_IF_NULL(kernel_info);
1523 kernel_info->set_graph_id(graph_id);
1524 }
1525
GetGraphId(const AnfNode * node)1526 uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
1527 MS_EXCEPTION_IF_NULL(node);
1528 auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1529 MS_EXCEPTION_IF_NULL(kernel_info);
1530 return kernel_info->graph_id();
1531 }
1532
IsTupleOutput(const AnfNodePtr & anf)1533 bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
1534 MS_EXCEPTION_IF_NULL(anf);
1535 TypePtr type = anf->Type();
1536 if (type == nullptr) {
1537 return false;
1538 }
1539 MS_EXCEPTION_IF_NULL(type);
1540 return type->isa<Tuple>();
1541 }
1542
GetInputNode(const CNodePtr & node,size_t index)1543 AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
1544 MS_EXCEPTION_IF_NULL(node);
1545 auto get_input_index = index + 1;
1546 if (get_input_index >= node->inputs().size()) {
1547 MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
1548 << node->inputs().size() << " trace: " << trace::DumpSourceLines(node);
1549 }
1550 // input 0 is primitive node
1551 return node->input(get_input_index);
1552 }
1553
// Whether the node produces a feature map (as opposed to a weight/constant).
// Value nodes never do; Load is transparent and defers to its first real input;
// otherwise the answer comes from the kernel info's feature-map flag.
bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>()) {
    return false;
  }
  if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
    // Load(param, u) forwards its first input; recurse through it.
    return IsFeatureMapOutput(node->cast<CNodePtr>()->input(1));
  }
  auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->is_feature_map();
}
1566
// Whether the input_index-th tensor input of the cnode is a feature map.
// Only meaningful for cnodes; parameters/value nodes have no inputs to inspect.
bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // +1 skips the primitive at cnode input 0.
  auto input_node = cnode->input(input_index + 1);
  return IsFeatureMapOutput(input_node);
}
1578
// Map a logical (IR) input index to the real (kernel) input index for TBE kernels
// whose input order differs between IR and kernel implementation.
// The remapping tables (spec_dynamic_node_list / spec_node_list) are defined elsewhere
// in this file; `first` holds the IR->kernel direction. Non-TBE kernels and unlisted
// ops keep the index unchanged.
size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  size_t ret = cur_index;
  auto node_name = AnfAlgo::GetCNodeName(anf_node);
  if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
    // Dynamic-shape ops may use a dedicated remapping table; it takes precedence.
    if (AnfAlgo::IsNodeDynamicShape(anf_node) || AnfAlgo::IsDynamicShape(anf_node)) {
      auto find_dynamic = spec_dynamic_node_list.find(node_name);
      if (find_dynamic != spec_dynamic_node_list.end()) {
        auto dyn_index_converter = find_dynamic->second;
        ret = dyn_index_converter.first[cur_index];
        MS_LOG(DEBUG) << "Real input index change to " << ret << ", node name:" << node_name;
        return ret;
      }
    }
    auto find = spec_node_list.find(node_name);
    if (find != spec_node_list.end()) {
      auto index_converter = find->second;
      ret = index_converter.first[cur_index];
      MS_LOG(DEBUG) << "Real input index change to " << ret << ", node name:" << node_name;
    }
  }
  return ret;
}
1602
// Inverse of GetRealInputIndex: map a real (kernel) input index back to the original
// logical (IR) index. Uses the same remapping tables but their `second` member
// (kernel->IR direction). Non-TBE kernels and unlisted ops keep the index unchanged.
size_t AnfRuntimeAlgorithm::GetOriginalInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  size_t ret = cur_index;
  auto node_name = AnfAlgo::GetCNodeName(anf_node);
  if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
    // Dynamic-shape ops may use a dedicated remapping table; it takes precedence.
    if (AnfAlgo::IsNodeDynamicShape(anf_node) || AnfAlgo::IsDynamicShape(anf_node)) {
      auto find_dynamic = spec_dynamic_node_list.find(node_name);
      if (find_dynamic != spec_dynamic_node_list.end()) {
        auto dyn_index_converter = find_dynamic->second;
        ret = dyn_index_converter.second[cur_index];
        MS_LOG(DEBUG) << "Get original input index " << ret << ", node name:" << node_name;
        return ret;
      }
    }
    auto find = spec_node_list.find(node_name);
    if (find != spec_node_list.end()) {
      auto index_converter = find->second;
      ret = index_converter.second[cur_index];
      MS_LOG(DEBUG) << "Get original input index " << ret << ", node name:" << node_name;
    }
  }
  return ret;
}
1626
// Replace the index-th tensor input of the cnode (cnode slot 0 is the primitive,
// hence the +1 offset).
void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(input_node);
  node->set_input(index + 1, input_node);
}
1632
IsInplaceNode(const mindspore::AnfNodePtr & kernel,const string & type)1633 bool AnfRuntimeAlgorithm::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &type) {
1634 MS_EXCEPTION_IF_NULL(kernel);
1635 auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
1636 if (!primitive) {
1637 return false;
1638 }
1639
1640 auto inplace_attr = primitive->GetAttr(type);
1641 if (inplace_attr == nullptr) {
1642 return false;
1643 }
1644
1645 return true;
1646 }
1647
IsCommunicationOp(const AnfNodePtr & node)1648 bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
1649 static const std::set<std::string> kCommunicationOpNames = {kAllReduceOpName, kAllGatherOpName, kBroadcastOpName,
1650 kReduceScatterOpName, kHcomSendOpName, kReceiveOpName,
1651 kAllToAllVOpName};
1652 MS_EXCEPTION_IF_NULL(node);
1653 if (!node->isa<CNode>()) {
1654 return false;
1655 }
1656 auto kernel_name = AnfAlgo::GetCNodeName(node);
1657 return (kCommunicationOpNames.find(kernel_name) != kCommunicationOpNames.end());
1658 }
1659
IsFusedCommunicationOp(const AnfNodePtr & node)1660 bool AnfRuntimeAlgorithm::IsFusedCommunicationOp(const AnfNodePtr &node) {
1661 if (!IsCommunicationOp(node)) {
1662 return false;
1663 }
1664 auto primitive = AnfAlgo::GetCNodePrimitive(node);
1665 MS_EXCEPTION_IF_NULL(primitive);
1666 ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
1667 if (attr_fusion == nullptr) {
1668 return false;
1669 }
1670 auto fusion = GetValue<int64_t>(attr_fusion);
1671 if (fusion == 0) {
1672 return false;
1673 }
1674 return true;
1675 }
1676
IsGetNext(const NotNull<AnfNodePtr> & node)1677 bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
1678 auto kernel_name = AnfAlgo::GetCNodeName(node);
1679 return kernel_name == kGetNextOpName;
1680 }
1681
GetValueNodeFuncGraph(const AnfNodePtr & node)1682 FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
1683 MS_EXCEPTION_IF_NULL(node);
1684 auto value_node = node->cast<ValueNodePtr>();
1685 if (value_node == nullptr) {
1686 return nullptr;
1687 }
1688 auto value = value_node->value();
1689 if (value == nullptr) {
1690 return nullptr;
1691 }
1692 auto func_graph = value->cast<FuncGraphPtr>();
1693 return func_graph;
1694 }
1695
// Collect the child kernel graphs reachable from a control-flow cnode:
//   Call         -> {the called graph}
//   Switch       -> {true-branch graph, false-branch graph}
//   SwitchLayer  -> one graph per branch in its tuple of branches
// Branches may be given either directly as KernelGraph value nodes or wrapped
// in Partial cnodes whose graph sits at kCallKernelGraphIndex.
std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
        AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) {
    MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch or switch_layer node."
                      << " trace: " << trace::DumpSourceLines(cnode);
  }
  // Resolve one branch input to its kernel graph, unwrapping a Partial if present.
  auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr {
    auto partial = cnode->input(input_index);
    MS_EXCEPTION_IF_NULL(partial);
    if (IsValueNode<KernelGraph>(partial)) {
      // Branch given directly as a graph value node.
      return GetValueNode<KernelGraphPtr>(partial);
    }
    auto partial_cnode = partial->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(partial_cnode);
    auto graph_node = partial_cnode->input(kCallKernelGraphIndex);
    MS_EXCEPTION_IF_NULL(graph_node);
    auto graph_value_node = graph_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(graph_value_node);
    auto graph_value = graph_value_node->value();
    MS_EXCEPTION_IF_NULL(graph_value);
    auto child_graph = graph_value->cast<KernelGraphPtr>();
    return child_graph;
  };
  if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
    auto input1 = cnode->input(kCallKernelGraphIndex);
    MS_EXCEPTION_IF_NULL(input1);
    auto value_node = input1->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto kernel_graph = value_node->value();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    return {kernel_graph->cast<KernelGraphPtr>()};
  } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
    return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex),
            get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)};
  } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
    // SwitchLayer branches start at kMakeTupleInSwitchLayerIndex and run to the end.
    std::vector<KernelGraphPtr> child_graphs;
    for (size_t idx = kMakeTupleInSwitchLayerIndex; idx < cnode->inputs().size(); idx++) {
      auto child_graph = get_switch_kernel_graph(idx);
      child_graphs.emplace_back(child_graph);
    }
    return child_graphs;
  }
  return {};
}
1741
// Whether a Call node's callee comes from a Switch (true) or is a direct graph
// value node (false). Any other callee shape is an error.
bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
  MS_EXCEPTION_IF_NULL(call_node);
  if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
    MS_LOG(EXCEPTION) << "Call node should be a 'call', but is a " << call_node->DebugString()
                      << " trace: " << trace::DumpSourceLines(call_node);
  }
  auto input1 = call_node->input(1);
  MS_EXCEPTION_IF_NULL(input1);
  if (input1->isa<ValueNode>()) {
    // Direct call: callee is a graph value node.
    return false;
  } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
    return true;
  }
  MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString()
                    << " trace: " << trace::DumpSourceLines(call_node);
}
1758
// Whether the index-th input of the cnode is scalar-like: its producer's inferred
// shape is empty (rank 0) or a 1-D shape of size 1.
bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
  if (shape.empty()) {
    return true;
  }
  return shape.size() == kShape1dDims && shape[0] == 1;
}
1766
// Whether the referenced value is scalar-like (empty shape or 1-D of size 1).
// NOTE(review): despite the name, this inspects the PREVIOUS node's output shape,
// exactly like IsScalarInput above — confirm whether the node's own output shape
// was intended here.
bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
  if (shape.empty()) {
    return true;
  }
  return shape.size() == kShape1dDims && shape[0] == 1;
}
1774
1775 namespace {
// For nodes[current_index], find the first later node that consumes its output.
// If such a consumer exists (and is not ApplyMomentum), mark the current position
// invalid (the node will be removed from its slot) and schedule the node for
// re-insertion just before that consumer, so execution is delayed as far as possible.
void FindDelayExecPosition(const std::vector<CNodePtr> &nodes, size_t current_index, std::set<size_t> *invalid_position,
                           std::map<size_t, std::vector<CNodePtr>> *insert_nodes) {
  MS_EXCEPTION_IF_NULL(invalid_position);
  MS_EXCEPTION_IF_NULL(insert_nodes);
  if (current_index >= nodes.size()) {
    return;
  }
  auto &node = nodes[current_index];
  for (size_t j = current_index + 1; j < nodes.size(); ++j) {
    auto &child = nodes[j];
    // Skip input 0 (the primitive).
    auto input_size = child->inputs().size() - 1;
    for (size_t k = 0; k < input_size; ++k) {
      auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
      if (kernel_index.first != node) {
        continue;
      }
      if (AnfAlgo::GetCNodeName(child) == kApplyMomentumOpName) {
        // Do not delay nodes feeding ApplyMomentum; leave the order untouched.
        return;
      }
      (void)invalid_position->insert(current_index);
      auto iter = insert_nodes->find(j);
      if (iter != insert_nodes->end()) {
        iter->second.emplace_back(node);
      } else {
        (*insert_nodes)[j] = {node};
      }
      return;
    }
  }
}
1806
// Reorder the execution list so that every node named node_name is moved to just
// before its first consumer (delayed execution). When only_seed is true, only
// "seed" nodes — those with no CNode inputs — are moved. Returns the reordered list.
std::vector<CNodePtr> DelayExecNode(const std::vector<CNodePtr> &nodes, const std::string &node_name, bool only_seed) {
  // insert_nodes: target position -> nodes to re-insert there.
  // invalid_position: original slots whose node has been moved away.
  std::map<size_t, std::vector<CNodePtr>> insert_nodes;
  std::set<size_t> invalid_position;
  for (size_t i = 0; i < nodes.size(); ++i) {
    auto &node = nodes[i];
    if (AnfAlgo::GetCNodeName(node) != node_name) {
      continue;
    }
    if (only_seed) {
      // A seed has no CNode among its real inputs.
      bool is_seed = true;
      auto input_size = node->inputs().size() - 1;
      for (size_t k = 0; k < input_size; ++k) {
        auto input = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, k), 0, true).first;
        if (input != nullptr && input->isa<CNode>()) {
          is_seed = false;
          break;
        }
      }
      if (!is_seed) {
        continue;
      }
    }
    FindDelayExecPosition(nodes, i, &invalid_position, &insert_nodes);
  }
  // Rebuild the list: splice delayed nodes in before their consumer, skip vacated slots.
  std::vector<CNodePtr> result;
  for (size_t i = 0; i < nodes.size(); ++i) {
    auto iter = insert_nodes.find(i);
    if (iter != insert_nodes.end()) {
      // Reverse order preserves the relative order of multiple delayed nodes.
      (void)result.insert(result.end(), iter->second.rbegin(), iter->second.rend());
    }
    if (invalid_position.find(i) != invalid_position.end()) {
      continue;
    }
    result.emplace_back(nodes[i]);
  }
  return result;
}
1844 } // namespace
1845
ReorderExecList(NotNull<std::vector<CNodePtr> * > node_list)1846 void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
1847 std::vector<CNodePtr> result;
1848 std::copy(node_list->begin(), node_list->end(), std::back_inserter(result));
1849 result = DelayExecNode(result, "TransData", true);
1850 result = DelayExecNode(result, "Cast", true);
1851 result = DelayExecNode(result, "AdamApplyOneWithDecay", false);
1852 result = DelayExecNode(result, "AdamApplyOne", false);
1853 node_list->clear();
1854 std::copy(result.begin(), result.end(), std::back_inserter(*node_list));
1855 }
1856
ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> * > node_list)1857 void AnfRuntimeAlgorithm::ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> *> node_list) {
1858 std::vector<CNodePtr> ordinary_node_list;
1859 std::vector<CNodePtr> posterior_node_list;
1860
1861 for (const auto &node : *node_list) {
1862 MS_EXCEPTION_IF_NULL(node);
1863 if (kPosteriorOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kPosteriorOperatorSet.end()) {
1864 posterior_node_list.emplace_back(node);
1865 } else {
1866 ordinary_node_list.emplace_back(node);
1867 }
1868 }
1869 node_list->clear();
1870 std::copy(ordinary_node_list.begin(), ordinary_node_list.end(), std::back_inserter(*node_list));
1871 std::copy(posterior_node_list.begin(), posterior_node_list.end(), std::back_inserter(*node_list));
1872 }
1873
// Return the fixed output precision requested via the kAttrOutputPrecision attribute
// ("float16" or "float32"), or kTypeUnknown when no such attribute is set or the
// node has no primitive. Any other attribute value is an error.
TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto prim = AnfAlgo::GetCNodePrimitive(node);
  if (prim == nullptr) {
    return kTypeUnknown;
  }

  TypeId except_type = kTypeUnknown;
  if (prim->GetAttr(kAttrOutputPrecision) != nullptr) {
    auto output_type_str = GetValue<std::string>(prim->GetAttr(kAttrOutputPrecision));
    if (output_type_str == "float16") {
      except_type = kNumberTypeFloat16;
    } else if (output_type_str == "float32") {
      except_type = kNumberTypeFloat32;
    } else {
      MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got " << output_type_str
                        << " trace: " << trace::DumpSourceLines(node);
    }
  }

  return except_type;
}
1896
// Return the fixed output precision of the producer of the node's input_idx-th input,
// or kTypeUnknown when the producer is not a cnode. Throws when the node is not a
// cnode or the index is out of range.
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << node->DebugString() << ", input node is not CNode."
                      << " trace: " << trace::DumpSourceLines(node);
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (input_idx + 1 >= cnode->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode)
                      << " trace: " << trace::DumpSourceLines(node);
  }
  // +1 skips the primitive at cnode input 0.
  auto input_node = cnode->input(input_idx + 1);
  MS_EXCEPTION_IF_NULL(input_node);
  // Resolve through transparent nodes (tuple_getitem etc.) to the real producer.
  auto kernel_with_index = VisitKernel(input_node, 0);
  if (!kernel_with_index.first->isa<CNode>()) {
    return kTypeUnknown;
  }
  return GetCNodeOutputPrecision(kernel_with_index.first);
}
1917
IsCondControlKernel(const CNodePtr & node)1918 bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) {
1919 MS_EXCEPTION_IF_NULL(node);
1920 if (node->inputs().empty()) {
1921 MS_LOG(EXCEPTION) << "Illegal null input of cnode."
1922 << " trace: " << trace::DumpSourceLines(node);
1923 }
1924 auto input = node->input(kAnfPrimitiveIndex);
1925 return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch);
1926 }
1927
IsIndependentNode(const CNodePtr & node)1928 bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
1929 MS_EXCEPTION_IF_NULL(node);
1930 if (AnfAlgo::GetKernelType(node) != AICPU_KERNEL) {
1931 return false;
1932 }
1933
1934 if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
1935 MS_LOG(INFO) << "GetNext should not be independent node";
1936 return false;
1937 }
1938
1939 // aicpu stack ops are not independent nodes.
1940 if (AnfAlgo::GetCNodeName(node) == kStackInitOpName || AnfAlgo::GetCNodeName(node) == kStackDestroyOpName ||
1941 AnfAlgo::GetCNodeName(node) == kStackPopOpName || AnfAlgo::GetCNodeName(node) == kStackPushOpName) {
1942 MS_LOG(INFO) << "AICPU stack ops should not be independent node";
1943 return false;
1944 }
1945
1946 size_t input_nums = AnfAlgo::GetInputTensorNum(node);
1947 if (input_nums == 0) {
1948 return true;
1949 }
1950
1951 auto inputs = node->inputs();
1952 for (size_t i = 1; i < inputs.size(); i++) {
1953 if (!inputs[i]->isa<ValueNode>()) {
1954 return false;
1955 }
1956 }
1957 return true;
1958 }
1959
GetBooleanAttr(const AnfNodePtr & node,const std::string & attr)1960 bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::string &attr) {
1961 MS_EXCEPTION_IF_NULL(node);
1962 if (!node->isa<CNode>()) {
1963 return false;
1964 }
1965 auto cnode = node->cast<CNodePtr>();
1966 MS_EXCEPTION_IF_NULL(cnode);
1967 auto has_attr = AnfAlgo::HasNodeAttr(attr, cnode);
1968 if (!has_attr) {
1969 return false;
1970 }
1971 return AnfAlgo::GetNodeAttr<bool>(node, attr);
1972 }
1973
HasDynamicShapeFlag(const PrimitivePtr & prim)1974 bool AnfRuntimeAlgorithm::HasDynamicShapeFlag(const PrimitivePtr &prim) {
1975 auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool {
1976 MS_EXCEPTION_IF_NULL(primitive);
1977 if (!primitive->HasAttr(attr_name)) {
1978 return false;
1979 }
1980 return GetValue<bool>(primitive->GetAttr(attr_name));
1981 };
1982 return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape) ||
1983 get_bool_attr(prim, kAttrIsDynamicShape);
1984 }
1985
// Whether any of the node's dynamic-shape attributes is set to true.
bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
  return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) ||
         GetBooleanAttr(node, kAttrIsDynamicShape);
}
1990
GetRealDynamicShape(const std::vector<size_t> & shape,NotNull<std::vector<int64_t> * > dynamic_shape)1991 void AnfRuntimeAlgorithm::GetRealDynamicShape(const std::vector<size_t> &shape,
1992 NotNull<std::vector<int64_t> *> dynamic_shape) {
1993 for (auto size : shape) {
1994 if (size == SIZE_MAX) {
1995 dynamic_shape->push_back(-1);
1996 } else {
1997 dynamic_shape->push_back(SizeToLong(size));
1998 }
1999 }
2000 }
2001
// Extract the min or max shape of the index-th element of a sequence shape.
// Falls back to the static shape when the requested min/max vector is empty.
// Throws on out-of-range index or a non-Shape element.
std::vector<int64_t> GetShapeFromSequeueShape(const abstract::SequeueShapePtr &sequeue_shape_ptr, size_t index,
                                              ShapeType type) {
  MS_EXCEPTION_IF_NULL(sequeue_shape_ptr);
  auto shape_list = sequeue_shape_ptr->shape();
  if (index >= shape_list.size()) {
    MS_LOG(EXCEPTION) << "Output Index:" << index << " >= " << shape_list.size();
  }

  auto shape = shape_list[index];
  MS_EXCEPTION_IF_NULL(shape);
  if (shape->isa<abstract::Shape>()) {
    auto shape_ptr = shape->cast<abstract::ShapePtr>();
    if (type == ShapeType::kMaxShape) {
      // Empty max_shape means the shape is static; use it directly.
      return shape_ptr->max_shape().empty() ? shape_ptr->shape() : shape_ptr->max_shape();
    } else {
      return shape_ptr->min_shape().empty() ? shape_ptr->shape() : shape_ptr->min_shape();
    }
  } else {
    MS_LOG(EXCEPTION) << "Invalid Shape Type In Shape List";
  }
}
2023
GetInputMaxShape(const AnfNodePtr & anf_node,size_t index)2024 std::vector<int64_t> AnfRuntimeAlgorithm::GetInputMaxShape(const AnfNodePtr &anf_node, size_t index) {
2025 auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
2026 return GetOutputMaxShape(input_node_with_index.first, input_node_with_index.second);
2027 }
2028
GetInputMinShape(const AnfNodePtr & anf_node,size_t index)2029 std::vector<int64_t> AnfRuntimeAlgorithm::GetInputMinShape(const AnfNodePtr &anf_node, size_t index) {
2030 auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
2031 return GetOutputMinShape(input_node_with_index.first, input_node_with_index.second);
2032 }
2033
// Max shape of the node's index-th output. For a plain Shape, an empty max_shape
// means the shape is static and is returned as-is; sequence shapes are indexed;
// NoShape yields an empty vector.
std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto shape = anf_node->Shape();
  MS_EXCEPTION_IF_NULL(shape);
  if (shape->isa<abstract::Shape>()) {
    auto shape_ptr = shape->cast<abstract::ShapePtr>();
    return shape_ptr->max_shape().empty() ? shape_ptr->shape() : shape_ptr->max_shape();
  } else if (shape->isa<abstract::SequeueShape>()) {
    auto sequeue_shape_ptr = shape->cast<abstract::SequeueShapePtr>();
    return GetShapeFromSequeueShape(sequeue_shape_ptr, index, ShapeType::kMaxShape);
  } else if (shape->isa<abstract::NoShape>()) {
    return {};
  } else {
    MS_LOG(EXCEPTION) << "Invalid Shape Type"
                      << " trace: " << trace::DumpSourceLines(anf_node);
  }
}
2051
// Min shape of the node's index-th output. For a plain Shape, an empty min_shape
// means the shape is static and is returned as-is; sequence shapes are indexed;
// NoShape yields an empty vector.
std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputMinShape(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto shape = anf_node->Shape();
  MS_EXCEPTION_IF_NULL(shape);
  if (shape->isa<abstract::Shape>()) {
    auto shape_ptr = shape->cast<abstract::ShapePtr>();
    return shape_ptr->min_shape().empty() ? shape_ptr->shape() : shape_ptr->min_shape();
  } else if (shape->isa<abstract::SequeueShape>()) {
    auto sequeue_shape_ptr = shape->cast<abstract::SequeueShapePtr>();
    return GetShapeFromSequeueShape(sequeue_shape_ptr, index, ShapeType::kMinShape);
  } else if (shape->isa<abstract::NoShape>()) {
    return {};
  } else {
    MS_LOG(EXCEPTION) << "Invalid Shape Type"
                      << " trace: " << trace::DumpSourceLines(anf_node);
  }
}
2069
// Whether any real input of the cnode has a dynamic shape. Inputs are resolved to
// their producing (node, output-index) pair; both plain shapes and tuple-shape
// elements are checked. Inputs with missing/invalid shape info are skipped.
bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr) {
  MS_EXCEPTION_IF_NULL(anf_node_ptr);
  auto input_num = AnfAlgo::GetInputTensorNum(anf_node_ptr);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_with_index = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, i);
    auto input = input_with_index.first;
    auto index = input_with_index.second;
    MS_EXCEPTION_IF_NULL(input);
    auto base_shape = input->Shape();
    if (base_shape == nullptr) {
      // Best-effort: a producer without shape info does not make the node dynamic.
      MS_LOG(INFO) << "Invalid shape ptr, node:" << input->fullname_with_scope();
      continue;
    }
    if (base_shape->isa<abstract::Shape>()) {
      if (AnfUtils::IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
        return true;
      }
    } else if (base_shape->isa<abstract::TupleShape>()) {
      auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
      MS_EXCEPTION_IF_NULL(tuple_shape);

      if (index >= tuple_shape->size()) {
        // Out-of-range tuple index: log and skip rather than fail.
        MS_LOG(INFO) << "Node:" << anf_node_ptr->fullname_with_scope() << "Invalid index:" << index
                     << " and tuple_shape size:" << tuple_shape->size();
        continue;
      }
      auto b_shp = (*tuple_shape)[index];
      if (!b_shp->isa<abstract::Shape>()) {
        continue;
      }
      if (AnfUtils::IsShapeDynamic(b_shp->cast<abstract::ShapePtr>())) {
        return true;
      }
    }
  }
  return false;
}
2107
IsNodeDynamicShape(const AnfNodePtr & node)2108 bool AnfRuntimeAlgorithm::IsNodeDynamicShape(const AnfNodePtr &node) {
2109 MS_EXCEPTION_IF_NULL(node);
2110 if (!node->isa<CNode>()) {
2111 MS_LOG(DEBUG) << "Node is not a cnode";
2112 return false;
2113 }
2114 auto cnode = node->cast<CNodePtr>();
2115 auto in_dynamic = IsNodeInputDynamicShape(cnode);
2116 auto out_dynamic = AnfUtils::IsNodeOutputDynamicShape(cnode);
2117 if (in_dynamic && !AnfAlgo::HasNodeAttr(kAttrInputIsDynamicShape, cnode)) {
2118 AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cnode);
2119 MS_LOG(INFO) << "Set Input Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
2120 }
2121 if (out_dynamic && !AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicShape, cnode)) {
2122 AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cnode);
2123 MS_LOG(INFO) << "Set Output Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
2124 }
2125 return in_dynamic || out_dynamic;
2126 }
2127
// Return the device-format shape of input `index` of `anf_node`, substituting
// the max shape when the input is dynamic so GPU kernels can allocate the
// largest possible buffer up front.
std::vector<size_t> AnfRuntimeAlgorithm::GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index) {
  auto device_shape = GetInputDeviceShape(anf_node, index);
  // Initialize GPUKernel with max shape to fit 'InitDynamicOutputKernelRef()' for memory reuse.
  if (AnfUtils::IsShapeDynamic(device_shape)) {
    auto max_shape = GetInputMaxShape(anf_node, index);
    // Overwrite device_shape element-wise with the (int64) max shape.
    // NOTE(review): assumes max_shape.size() == device_shape.size(); std::transform
    // writes through device_shape.begin() unchecked — confirm the caller guarantees this.
    std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
    auto format = GetInputFormat(anf_node, index);
    // NOTE(review): the return value of TransShapeToDevice is discarded here; verify the
    // call is needed for its side effects, otherwise device_shape stays in host layout.
    (void)trans::TransShapeToDevice(device_shape, format, anf_node, index, false);
  }
  return device_shape;
}
2139
// Return the device-format shape of output `index` of `anf_node`, substituting
// the max shape when the output is dynamic (same scheme as the input variant).
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index) {
  auto device_shape = GetOutputDeviceShape(anf_node, index);
  // Initialize GPUKernel with max shape to fit 'InitDynamicOutputKernelRef()' for memory reuse.
  if (AnfUtils::IsShapeDynamic(device_shape)) {
    auto max_shape = GetOutputMaxShape(anf_node, index);
    // Overwrite device_shape element-wise with the (int64) max shape.
    // NOTE(review): assumes max_shape.size() == device_shape.size() — confirm upstream.
    std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
    auto format = GetOutputFormat(anf_node, index);
    // NOTE(review): return value of TransShapeToDevice discarded, see input variant above.
    (void)trans::TransShapeToDevice(device_shape, format, anf_node, index);
  }
  return device_shape;
}
2151
GetAllVisitedCNode(const CNodePtr & anf_node,std::vector<AnfNodePtr> * used_kernels,std::set<AnfNodePtr> * visited)2152 void AnfRuntimeAlgorithm::GetAllVisitedCNode(const CNodePtr &anf_node, std::vector<AnfNodePtr> *used_kernels,
2153 std::set<AnfNodePtr> *visited) {
2154 MS_EXCEPTION_IF_NULL(anf_node);
2155 MS_EXCEPTION_IF_NULL(used_kernels);
2156 MS_EXCEPTION_IF_NULL(visited);
2157 if (visited->find(anf_node) != visited->end()) {
2158 MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
2159 return;
2160 }
2161 visited->insert(anf_node);
2162 auto input_size = anf_node->inputs().size() - 1;
2163 for (size_t i = 0; i < input_size; ++i) {
2164 auto input = AnfAlgo::GetInputNode(anf_node, i);
2165 if (!input->isa<CNode>()) {
2166 continue;
2167 }
2168 auto input_cnode = input->cast<CNodePtr>();
2169 if (!IsRealKernelCNode(input_cnode) || opt::IsNopNode(input_cnode)) {
2170 GetAllVisitedCNode(input_cnode, used_kernels, visited);
2171 } else {
2172 used_kernels->push_back(input);
2173 }
2174 }
2175 }
2176
GetAllFatherRealNode(const AnfNodePtr & anf_node,std::vector<AnfNodePtr> * result,std::set<AnfNodePtr> * visited)2177 void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
2178 std::set<AnfNodePtr> *visited) {
2179 MS_EXCEPTION_IF_NULL(anf_node);
2180 MS_EXCEPTION_IF_NULL(result);
2181 MS_EXCEPTION_IF_NULL(visited);
2182 if (visited->find(anf_node) != visited->end()) {
2183 MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
2184 return;
2185 }
2186 visited->insert(anf_node);
2187 if (AnfAlgo::IsRealKernel(anf_node)) {
2188 result->emplace_back(anf_node);
2189 return;
2190 }
2191 if (!anf_node->isa<CNode>()) {
2192 return;
2193 }
2194 auto cnode = anf_node->cast<CNodePtr>();
2195 MS_EXCEPTION_IF_NULL(cnode);
2196 if (cnode->inputs().empty()) {
2197 MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
2198 }
2199 auto input0 = cnode->input(0);
2200 if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
2201 for (size_t i = 1; i < cnode->inputs().size(); ++i) {
2202 GetAllFatherRealNode(cnode->input(i), result, visited);
2203 }
2204 } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
2205 if (cnode->inputs().size() != kTupleGetItemInputSize) {
2206 MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
2207 }
2208 GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
2209 } else if (IsPrimitive(input0, prim::kPrimDepend)) {
2210 if (cnode->inputs().size() != kDependInputSize) {
2211 MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
2212 }
2213 GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
2214 GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
2215 }
2216 }
2217
// Re-run C++ shape inference for dynamic-shape `node` and store the result on
// its abstract. When `depend_tensors` is given (input index -> host tensor),
// those tensors are synced from device and injected as known values into the
// corresponding input abstracts so value-dependent shape functions can use them.
void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors) {
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(INFO) << "InferShape start, node:" << node->DebugString();
  auto inputs = node->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "Invalid inputs";
  }
  AbstractBasePtrList args_spec_list;
  auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
  auto input_size = AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_size; ++i) {
    // real_input: the producing node with wrappers (Depend/TupleGetItem) peeled off;
    // cnode_input: the literal i-th operand of `node`.
    auto input_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
    auto real_input = input_with_index.first;
    auto cnode_input = node->input(i + 1);
    MS_EXCEPTION_IF_NULL(cnode_input);
    MS_EXCEPTION_IF_NULL(real_input);
    if (depend_tensors != nullptr) {
      auto iter_tensor = depend_tensors->find(i);
      if (iter_tensor != depend_tensors->end()) {
        auto tensor_ptr = iter_tensor->second;
        MS_EXCEPTION_IF_NULL(tensor_ptr);
        // sync data from device to host
        tensor_ptr->data_sync();
        // Attach the concrete value to the producer's abstract so the infer
        // function can read it (tensor directly, or the tuple element that
        // this TupleGetItem selects).
        auto real_abs = real_input->abstract();
        if (real_abs->isa<abstract::AbstractTensor>()) {
          real_input->abstract()->set_value(tensor_ptr);
        } else if (real_abs->isa<abstract::AbstractTuple>()) {
          // NOTE(review): assumes cnode_input is a TupleGetItem when the real
          // input abstract is a tuple — confirm callers guarantee this.
          auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
          auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
          MS_EXCEPTION_IF_NULL(abstract_tuple);
          auto tuple_elements = abstract_tuple->elements()[tuple_get_item_index];
          tuple_elements->set_value(tensor_ptr);
        }
      }
    }
    // Build the argument abstract list for CppInferShape.
    if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
      auto base_shape = real_input->Shape();
      if (!base_shape->isa<abstract::TupleShape>()) {
        MS_LOG(EXCEPTION) << "Node:" << node->DebugString()
                          << " input is a tuple_get_item but real input node shape is not a TupleShape";
      }
      // Use the selected tuple element's abstract, not the whole tuple.
      auto abs = real_input->abstract()->cast<abstract::AbstractTuplePtr>();
      MS_EXCEPTION_IF_NULL(abs);
      auto tuple_get_item_indexk = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
      auto abs_i = abs->elements()[tuple_get_item_indexk];
      (void)args_spec_list.emplace_back(abs_i);
    } else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) {
      // Reshape keeps its own (already-updated) abstract; use it directly.
      (void)args_spec_list.emplace_back(cnode_input->abstract());
    } else {
      (void)args_spec_list.emplace_back(real_input->abstract());
    }
  }
  auto eval_result = opt::CppInferShape(primitive, args_spec_list);
  node->set_abstract(eval_result);
}
2273
InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> & root_graph)2274 void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph) {
2275 auto return_node = root_graph->get_return();
2276 MS_EXCEPTION_IF_NULL(return_node);
2277 if (return_node->size() <= kReturnDataIndex) {
2278 return;
2279 }
2280 auto make_tuple = root_graph->NewCNode(
2281 {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
2282 root_graph->set_output(make_tuple);
2283 }
2284
GetUpdateStateUsers(const FuncGraphManagerPtr & manager,const AnfNodePtr & node)2285 AnfNodeIndexSet AnfRuntimeAlgorithm::GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
2286 AnfNodeIndexSet update_states;
2287 for (auto &user : manager->node_users()[node]) {
2288 if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) {
2289 update_states.insert(user);
2290 }
2291 }
2292 return update_states;
2293 }
2294
GetRealInputs(const AnfNodePtr & node,std::vector<session::KernelWithIndex> * inputs)2295 void AnfRuntimeAlgorithm::GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> *inputs) {
2296 size_t input_num = AnfAlgo::GetInputTensorNum(node);
2297 for (size_t input_index = 0; input_index < input_num; ++input_index) {
2298 auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index);
2299 GetRealOutputRecursively(input_node, 0, inputs);
2300 }
2301 }
2302
IsTensorBroadcast(const std::vector<size_t> & lhs,const std::vector<size_t> & rhs)2303 bool AnfRuntimeAlgorithm::IsTensorBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
2304 if (lhs.size() != rhs.size()) {
2305 return true;
2306 }
2307 for (size_t i = 0; i < lhs.size(); i++) {
2308 if (lhs[i] != rhs[i]) {
2309 return true;
2310 }
2311 }
2312 return false;
2313 }
2314
IsOneOfPrimitiveCNode(const AnfNodePtr & node,const PrimitiveSet & prim_set)2315 bool AnfRuntimeAlgorithm::IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
2316 MS_EXCEPTION_IF_NULL(node);
2317 auto cnode = node->cast<CNodePtr>();
2318 if (cnode == nullptr || cnode->size() == 0) {
2319 return false;
2320 }
2321 return IsOneOfPrimitive(cnode->inputs().at(kAnfPrimitiveIndex), prim_set);
2322 }
2323
IsControlOpExecInBackend(const AnfNodePtr & node)2324 bool AnfRuntimeAlgorithm::IsControlOpExecInBackend(const AnfNodePtr &node) {
2325 if (!node->isa<CNode>()) {
2326 return false;
2327 }
2328 // Operators in set control_ops_exec_in_backend will be compiled into kernel graph, rather than be cut into single op
2329 // and executed in VM.
2330 static std::set<std::string> control_ops_exec_in_backend = {kBpropCutOpName};
2331 return control_ops_exec_in_backend.find(AnfAlgo::GetCNodeName(node)) != control_ops_exec_in_backend.end();
2332 }
2333
IsNodeInputContainMonad(const AnfNodePtr & node)2334 bool AnfRuntimeAlgorithm::IsNodeInputContainMonad(const AnfNodePtr &node) {
2335 auto input_size = GetInputTensorNum(node);
2336 for (size_t i = 0; i < input_size; ++i) {
2337 auto input_with_index = GetPrevNodeOutput(node, i);
2338 if (HasAbstractMonad(input_with_index.first)) {
2339 return true;
2340 }
2341 }
2342 return false;
2343 }
2344
// Pre-compute and cache device addresses for every kernel in `kernel_graph`'s
// execution order. Skipped entirely in graph mode with task sink enabled.
void AnfRuntimeAlgorithm::CacheAddrForGraph(const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
      ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) == true) {
    return;
  }
  auto nodes = kernel_graph->execution_order();
  for (auto &kernel : nodes) {
    // Skip transpose kernel with "nop_op" attr which is not hidden or removed in PyNative infer scenario. Transpose
    // kernel, which is not supposed to be executed, is generated in TransDataSplit to support specific Transdata.
    // And hard code here should be removed after new Transdata programme is implemented in the foreseeable future.
    if (HasNodeAttr("nop_op", kernel)) {
      // A nop kernel's outputs alias its real predecessor's outputs.
      for (size_t idx = 0; idx < GetOutputTensorNum(kernel); idx += 1) {
        auto real_input = GetRealInputIndex(kernel, idx);
        auto device_address = GetPrevNodeMutableOutputAddr(kernel, real_input);
        SetOutputAddr(device_address, idx, kernel.get());
      }
      continue;
    }
    auto kernel_mod = GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    if (GetCNodeName(kernel) == kAtomicAddrCleanOpName) {
      // Atomic-clean kernels take other kernels' output/workspace addresses as inputs.
      CacheAddrForAtomicClean(kernel, kernel_mod);
      continue;
    }
    CacheAddrForKernel(kernel, kernel_mod);
  }
}
2375
CacheAddrForKernel(const AnfNodePtr & node,kernel::KernelMod * kernel_mod)2376 void AnfRuntimeAlgorithm::CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod) {
2377 MS_EXCEPTION_IF_NULL(node);
2378 MS_EXCEPTION_IF_NULL(kernel_mod);
2379 std::vector<AddressPtr> kernel_inputs;
2380 std::vector<AddressPtr> kernel_workspaces;
2381 std::vector<AddressPtr> kernel_outputs;
2382 auto cnode = node->cast<CNodePtr>();
2383 MS_EXCEPTION_IF_NULL(cnode);
2384 auto ms_context = MsContext::GetInstance();
2385 MS_EXCEPTION_IF_NULL(ms_context);
2386 auto visit_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
2387 size_t input_num = GetInputTensorNum(node);
2388 for (size_t i = 0; i < input_num; ++i) {
2389 auto op_name = GetCNodeName(cnode);
2390 constexpr auto none_placeholder_index = 3;
2391 if (op_name == kDynamicRNNOpName && i == none_placeholder_index) {
2392 continue;
2393 }
2394 if (op_name == kDynamicGRUV2OpName) {
2395 auto none_index = GetNodeAttr<std::vector<int64_t>>(cnode, "placeholder_index");
2396 auto item = std::find(none_index.begin(), none_index.end(), i);
2397 if (item != none_index.end()) {
2398 continue;
2399 }
2400 }
2401 auto real_input = GetRealInputIndex(node, i);
2402 auto device_address = GetPrevNodeOutputAddr(node, real_input, visit_nop_node);
2403 MS_EXCEPTION_IF_NULL(device_address);
2404 kernel::AddressPtr input = std::make_shared<kernel::Address>();
2405 MS_EXCEPTION_IF_NULL(input);
2406 input->addr = const_cast<void *>(device_address->GetPtr());
2407 MS_EXCEPTION_IF_NULL(input->addr);
2408 input->size = device_address->GetSize();
2409 kernel_inputs.emplace_back(input);
2410 }
2411 for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
2412 auto device_address = GetOutputAddr(node, i, visit_nop_node);
2413 kernel::AddressPtr output = std::make_shared<kernel::Address>();
2414 MS_EXCEPTION_IF_NULL(output);
2415 output->addr = const_cast<void *>(device_address->GetPtr());
2416 MS_EXCEPTION_IF_NULL(output->addr);
2417 output->size = device_address->GetSize();
2418 kernel_outputs.emplace_back(output);
2419 }
2420 for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
2421 auto device_address = GetWorkspaceAddr(node, i);
2422 kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
2423 MS_EXCEPTION_IF_NULL(workspace);
2424 workspace->addr = const_cast<void *>(device_address->GetPtr());
2425 MS_EXCEPTION_IF_NULL(workspace->addr);
2426 workspace->size = device_address->GetSize();
2427 kernel_workspaces.emplace_back(workspace);
2428 }
2429 kernel_mod->set_inputs_addr(kernel_inputs);
2430 kernel_mod->set_workspaces_addr(kernel_workspaces);
2431 kernel_mod->set_outputs_addr(kernel_outputs);
2432 }
2433
CacheAddrForAtomicClean(const AnfNodePtr & node,kernel::KernelMod * kernel_mod)2434 void AnfRuntimeAlgorithm::CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod) {
2435 MS_EXCEPTION_IF_NULL(node);
2436 MS_EXCEPTION_IF_NULL(kernel_mod);
2437 std::vector<AddressPtr> kernel_inputs;
2438 auto cnode = node->cast<CNodePtr>();
2439 MS_EXCEPTION_IF_NULL(cnode);
2440 if (cnode->inputs().size() != kIndex2) {
2441 MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
2442 }
2443 MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
2444 auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
2445 // set clean output address
2446 if (HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
2447 #if defined(__APPLE__)
2448 auto clean_output_indexes = GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicOutputIndexs);
2449 #else
2450 auto clean_output_indexes = GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
2451 #endif
2452 for (auto index : clean_output_indexes) {
2453 auto device_address = GetOutputAddr(pre_node, index);
2454 kernel::AddressPtr input = std::make_shared<kernel::Address>();
2455 MS_EXCEPTION_IF_NULL(input);
2456 input->addr = const_cast<void *>(device_address->GetPtr());
2457 MS_EXCEPTION_IF_NULL(input->addr);
2458 input->size = device_address->GetSize();
2459 kernel_inputs.emplace_back(input);
2460 }
2461 MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexes.size();
2462 }
2463 // set clean workspace address
2464 if (HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
2465 #if defined(__APPLE__)
2466 auto clean_workspaces_indexes = GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicWorkspaceIndexs);
2467 #else
2468 auto clean_workspaces_indexes = GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
2469 #endif
2470 for (const auto &index : clean_workspaces_indexes) {
2471 auto device_address = GetWorkspaceAddr(pre_node, index);
2472 kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
2473 MS_EXCEPTION_IF_NULL(workspace);
2474 workspace->addr = const_cast<void *>(device_address->GetPtr());
2475 MS_EXCEPTION_IF_NULL(workspace->addr);
2476 workspace->size = device_address->GetSize();
2477 kernel_inputs.emplace_back(workspace);
2478 }
2479 }
2480 kernel_mod->set_inputs_addr(kernel_inputs);
2481 }
2482
output_format(size_t index) const2483 std::string OpRuntimeInfo::output_format(size_t index) const {
2484 if (index >= output_format_.size()) {
2485 MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_format:" << output_format_.size();
2486 }
2487 return output_format_[index];
2488 }
2489
output_type(size_t index) const2490 TypeId OpRuntimeInfo::output_type(size_t index) const {
2491 if (index >= output_type_.size()) {
2492 MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_type:" << output_type_.size();
2493 }
2494 return output_type_[index];
2495 }
2496
output_tensor_size(size_t index) const2497 size_t OpRuntimeInfo::output_tensor_size(size_t index) const {
2498 if (index >= output_tensor_size_.size()) {
2499 MS_LOG(EXCEPTION) << "Invalid index::" << index << " total output_tensor_size:" << output_tensor_size_.size();
2500 }
2501 return output_tensor_size_[index];
2502 }
2503 } // namespace session
2504 } // namespace mindspore
2505