1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/backend/anf_runtime_algorithm.h"
17
18 #include <memory>
19 #include <algorithm>
20 #include <map>
21 #include <set>
22 #include <functional>
23 #include "ops/ascend_op_name.h"
24 #include "ops/math_op_name.h"
25 #include "ops/lite_op_name.h"
26 #include "ops/structure_ops.h"
27 #include "ops/sequence_ops.h"
28 #include "ops/framework_ops.h"
29 #include "ir/anf.h"
30 #include "utils/log_adapter.h"
31 #include "ir/func_graph_cloner.h"
32 #include "utils/shape_utils.h"
33 #include "include/common/utils/utils.h"
34 #include "include/common/utils/parallel_context.h"
35 #include "include/common/utils/anfalgo.h"
36 #include "include/common/debug/anf_dump_utils.h"
37 #include "include/backend/kernel_info.h"
38 #include "include/backend/device_address.h"
39 #include "include/backend/optimizer/helper.h"
40 #include "kernel/kernel.h"
41 #include "kernel/kernel_build_info.h"
42 #include "runtime/device/ms_device_shape_transfer.h"
43 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
44 #include "abstract/ops/primitive_infer_map.h"
45 #include "utils/trace_base.h"
46 #include "utils/anf_utils.h"
47 #include "utils/ms_context.h"
48 #ifndef BUILD_LITE
49 #include "pybind_api/ir/base_ref_py.h"
50 #endif
51
52 namespace mindspore::session {
53 using abstract::AbstractTensor;
54 using abstract::AbstractTuple;
55 using device::KernelInfo;
56 using kernel::KernelBuildInfoPtr;
57 using kernel::KernelMod;
58 using kernel::KernelModPtr;
59 constexpr char kDisableKernelBackoff[] = "MS_DISABLE_KERNEL_BACKOFF";
60
61 namespace {
62 constexpr size_t kReturnDataIndex = 1;
63 constexpr size_t kSwitchTrueBranchIndex = 2;
64 constexpr auto kPatternUnknown = "";
65
PrintKernelFormatAndType(const std::string & fmt,const TypeId & type,const std::vector<int64_t> & shape)66 std::string PrintKernelFormatAndType(const std::string &fmt, const TypeId &type, const std::vector<int64_t> &shape) {
67 std::ostringstream buffer;
68 buffer << "<" << TypeIdLabel(type);
69 if (!fmt.empty()) {
70 buffer << "x" << fmt << shape;
71 }
72 buffer << ">";
73 return buffer.str();
74 }
75
// Registered once at static-initialization time through the file-scope
// `callback_register` instance below. Each handler renders a node's
// kernel-selection information (input/output device format, device type,
// device shape, or kernel object types) as a short string for debug dumps.
[[maybe_unused]] struct AnfDumpHandlerRegister {
  AnfDumpHandlerRegister() {
    // Inputs: one "<TypexFORMATshape>" entry per input tensor, comma separated.
    AnfDumpHandler::SetPrintInputTypeShapeFormatHandler([](const std::shared_ptr<AnfNode> &node) -> std::string {
      if (node == nullptr) {
        return "";
      }
      std::ostringstream buffer;
      size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
      for (size_t i = 0; i < input_num; ++i) {
        if (i != 0) {
          buffer << ", ";
        }
        auto format = AnfAlgo::GetInputFormat(node, i);
        auto type = AnfAlgo::GetInputDeviceDataType(node, i);
        auto shape = AnfAlgo::GetInputDeviceShape(node, i);
        buffer << PrintKernelFormatAndType(format, type, shape);
      }
      return buffer.str();
    });
    // Outputs: same rendering as inputs. Parameter nodes are always queried at
    // output index 0 regardless of i.
    AnfDumpHandler::SetPrintOutputTypeShapeFormatHandler([](const std::shared_ptr<AnfNode> &node) -> std::string {
      if (node == nullptr) {
        return "";
      }
      std::ostringstream buffer;
      size_t output_num = AnfAlgo::GetOutputTensorNum(node);
      for (size_t i = 0; i < output_num; ++i) {
        if (i != 0) {
          buffer << ", ";
        }
        auto format = AnfAlgo::GetOutputFormat(node, (node->isa<Parameter>() ? 0 : i));
        auto type = AnfAlgo::GetOutputDeviceDataType(node, (node->isa<Parameter>() ? 0 : i));
        auto shape = AnfAlgo::GetOutputDeviceShape(node, (node->isa<Parameter>() ? 0 : i));
        buffer << PrintKernelFormatAndType(format, type, shape);
      }
      return buffer.str();
    });
    // Input kernel object types as a comma-separated label list.
    AnfDumpHandler::SetPrintInputKernelObjectTypesHandler([](const std::shared_ptr<AnfNode> &node) -> std::string {
      if (node == nullptr) {
        return "";
      }
      auto input_obj_types = AnfAlgo::GetInputKernelObjectTypes(node);
      return std::accumulate(
        input_obj_types.begin(), input_obj_types.end(), std::string(), [](std::string &a, const KernelObjectType &b) {
          return a.empty() ? kernel::KernelObjectTypeLabel(b) : a + ", " + kernel::KernelObjectTypeLabel(b);
        });
    });
    // Output kernel object types, same comma-separated rendering.
    AnfDumpHandler::SetPrintOutputKernelObjectTypesHandler([](const std::shared_ptr<AnfNode> &node) -> std::string {
      if (node == nullptr) {
        return "";
      }
      auto output_obj_types = AnfAlgo::GetOutputKernelObjectTypes(node);
      return std::accumulate(
        output_obj_types.begin(), output_obj_types.end(), std::string(), [](std::string &a, const KernelObjectType &b) {
          return a.empty() ? kernel::KernelObjectTypeLabel(b) : a + ", " + kernel::KernelObjectTypeLabel(b);
        });
    });
  }
} callback_register;
134
GetForwardOutputTensor(const AnfNodePtr & node)135 tensor::BaseTensorPtr GetForwardOutputTensor(const AnfNodePtr &node) {
136 MS_EXCEPTION_IF_NULL(node);
137 if (node->isa<ValueNode>()) {
138 auto value_node = node->cast<ValueNodePtr>();
139 MS_EXCEPTION_IF_NULL(value_node);
140 auto value = value_node->value();
141 MS_EXCEPTION_IF_NULL(value);
142 if (value->isa<tensor::BaseTensor>()) {
143 auto tensor = value->cast<tensor::BaseTensorPtr>();
144 MS_EXCEPTION_IF_NULL(tensor);
145 // If output used as sens, output will create(clone) a fake tensor with device address is nullptr for memory
146 // usage. It has is_forward_output flag, which will be used for tensor input mask, and affect single op graph
147 // cache.
148 if (tensor->is_forward_output() && tensor->device_address() != nullptr) {
149 return tensor;
150 }
151 }
152 }
153 return nullptr;
154 }
155
GetOutputTensorNumByKernelInfo(const AnfNodePtr & node)156 size_t GetOutputTensorNumByKernelInfo(const AnfNodePtr &node) {
157 MS_EXCEPTION_IF_NULL(node);
158 MS_EXCEPTION_IF_NULL(node->kernel_info());
159 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
160 MS_EXCEPTION_IF_NULL(kernel_info);
161 const auto &build_info = kernel_info->GetMutableSelectKernelBuildInfo();
162 MS_EXCEPTION_IF_NULL(build_info);
163 return build_info->GetAllOutputDeviceTypes().size();
164 }
165
ContainScalarOut(const AbstractBasePtr & abs)166 bool ContainScalarOut(const AbstractBasePtr &abs) {
167 // Check the output abstract of node whether is scalar.
168 if ((abs != nullptr) && (abs->isa<abstract::AbstractScalar>())) {
169 return true;
170 }
171 // Check the output abstracts of node whether have scalar.
172 if ((abs != nullptr) && (abs->isa<abstract::AbstractSequence>())) {
173 auto abs_seq = abs->cast_ptr<abstract::AbstractSequence>();
174 MS_EXCEPTION_IF_NULL(abs_seq);
175 if (abs_seq->dynamic_len()) {
176 const auto &element_abs = abs_seq->dynamic_len_element_abs();
177 return (element_abs == nullptr) || (element_abs->isa<abstract::AbstractScalar>());
178 }
179 const auto &elements = abs_seq->elements();
180 bool has_scalar_out = std::any_of(elements.begin(), elements.end(),
181 [](const AbstractBasePtr &element) { return ContainScalarOut(element); });
182 return has_scalar_out;
183 }
184 return false;
185 }
186 } // namespace
187
MakeMonadValueNode(const KernelGraphPtr & kg)188 AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) {
189 MS_EXCEPTION_IF_NULL(kg);
190 return kg->NewValueNode(kUMonad->ToAbstract(), kUMonad);
191 }
192
193 // Convert: a = former(xxx)
194 // b = latter(x, xxx)
195 // To: a = former(xxx)
196 // d1 = Depend(x, a)
197 // b = latter(d1, xxx)
198 // ...
199 // out = Depend(out, latter)
void AnfRuntimeAlgorithm::KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter) {
  MS_EXCEPTION_IF_NULL(kg);
  MS_EXCEPTION_IF_NULL(former);
  MS_EXCEPTION_IF_NULL(latter);
  if (latter->isa<CNode>()) {
    auto latter_cnode = latter->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(latter_cnode);
    constexpr size_t inputsize = 2;
    constexpr size_t kFirstDataInputIndex = 1;
    // A CNode with fewer than two inputs has no data input to hook a Depend on.
    if (latter_cnode->size() < inputsize) {
      return;
    }
    // depend1 = Depend(latter's first data input, former): routes `latter`'s
    // first input through a Depend on `former`, forcing `former` to run first.
    auto latter_input = latter_cnode->input(kFirstDataInputIndex);
    auto depend1 = kg->NewCNode({NewValueNode(prim::kPrimDepend), latter_input, former});
    MS_EXCEPTION_IF_NULL(depend1);
    MS_EXCEPTION_IF_NULL(latter_input);
    depend1->set_abstract(latter_input->abstract());
    latter_cnode->set_input(kFirstDataInputIndex, depend1);

    // depend2 = Depend(current graph output, latter): re-attaches the graph
    // output through a Depend on `latter` so `latter` stays reachable.
    auto return_node = kg->get_return();
    MS_EXCEPTION_IF_NULL(return_node);
    auto depend2 = kg->NewCNode(
      {NewValueNode(prim::kPrimDepend), return_node->cast<CNodePtr>()->input(kFirstDataInputIndex), latter});
    MS_EXCEPTION_IF_NULL(depend2);
    depend2->set_abstract(return_node->cast<CNodePtr>()->input(kFirstDataInputIndex)->abstract());
    kg->set_output(depend2);
    MS_LOG(DEBUG) << "former: " << former->DebugString() << ", latter: " << latter->DebugString()
                  << ", depend1: " << depend1->DebugString() << ", depend2: " << depend2->DebugString();
  }
}
230
GetOutputTensorNum(const AnfNodePtr & node)231 size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
232 MS_EXCEPTION_IF_NULL(node);
233 size_t res;
234 TypePtr type = node->Type();
235 if (type == nullptr) {
236 res = 0;
237 } else if (type->isa<Tuple>() || type->isa<List>()) {
238 const auto &kernel_info = node->kernel_info();
239 if (kernel_info == nullptr || (!kernel_info->has_build_info())) {
240 return 1;
241 }
242 res = GetOutputTensorNumByKernelInfo(node);
243 } else if (type->isa<TypeNone>()) {
244 res = 0;
245 } else if (type->isa<CSRTensorType>()) {
246 // Currently, CSRTensor only supports 2-D matrix (shape has 2 values). 5 outputs = 3 Tensors + 2 shape values.
247 constexpr size_t kCSRTensorOutputNum = 5;
248 res = kCSRTensorOutputNum;
249 } else if (type->isa<COOTensorType>()) {
250 // Currently, COOTensor only supports 2-D matrix (shape has 2 values). 4 outputs = 2 Tensors + 2 shape values.
251 constexpr size_t kCOOTensorOutputNum = 4;
252 res = kCOOTensorOutputNum;
253 } else if (AnfUtils::NeedJumpMonadOutput(node) && type->isa<MonadType>()) {
254 // Some nodes could have monad outputs like RpcRecv. We need to jump these outputs.
255 res = 0;
256 } else {
257 res = 1;
258 }
259 return res;
260 }
261
GetOutputNumWithoutKernelInfo(const AnfNodePtr & node)262 size_t AnfRuntimeAlgorithm::GetOutputNumWithoutKernelInfo(const AnfNodePtr &node) {
263 MS_EXCEPTION_IF_NULL(node);
264 const auto &kernel_info = node->kernel_info();
265 if (kernel_info != nullptr) {
266 MS_LOG(EXCEPTION) << "Kernel info is not null for node:" << node->DebugString();
267 }
268
269 size_t res;
270 TypePtr type = node->Type();
271 if (type == nullptr) {
272 res = 0;
273 } else if (type->isa<Tuple>() || type->isa<List>()) {
274 res = 1;
275 } else if (type->isa<TypeNone>()) {
276 res = 0;
277 } else if (type->isa<CSRTensorType>()) {
278 // Currently, CSRTensor only supports 2-D matrix (shape has 2 values). 5 outputs = 3 Tensors + 2 shape values.
279 constexpr size_t kCSRTensorOutputNum = 5;
280 res = kCSRTensorOutputNum;
281 } else if (type->isa<COOTensorType>()) {
282 // Currently, COOTensor only supports 2-D matrix (shape has 2 values). 4 outputs = 2 Tensors + 2 shape values.
283 constexpr size_t kCOOTensorOutputNum = 4;
284 res = kCOOTensorOutputNum;
285 } else if (AnfUtils::NeedJumpMonadOutput(node) && type->isa<MonadType>()) {
286 // Some nodes could have monad outputs like RpcRecv. We need to jump these outputs.
287 res = 0;
288 } else {
289 res = 1;
290 }
291 return res;
292 }
293
294 namespace {
IsTupleHasDynamicSequence(const abstract::AbstractBasePtr & abstract)295 bool IsTupleHasDynamicSequence(const abstract::AbstractBasePtr &abstract) {
296 MS_EXCEPTION_IF_NULL(abstract);
297 if (!abstract->isa<abstract::AbstractSequence>()) {
298 return false;
299 }
300 const auto &sequence_abs = abstract->cast<abstract::AbstractSequencePtr>();
301 MS_EXCEPTION_IF_NULL(sequence_abs);
302 if (sequence_abs->dynamic_len() || sequence_abs->dynamic_len_element_abs() != nullptr) {
303 return true;
304 }
305 if (std::any_of(sequence_abs->elements().begin(), sequence_abs->elements().end(),
306 [](const abstract::AbstractBasePtr &abs) { return IsTupleHasDynamicSequence(abs); })) {
307 return true;
308 }
309 return false;
310 }
311 } // namespace
312
GetOutputElementNum(const AnfNodePtr & node)313 size_t AnfRuntimeAlgorithm::GetOutputElementNum(const AnfNodePtr &node) {
314 if (node->abstract() != nullptr && IsTupleHasDynamicSequence(node->abstract())) {
315 return common::AnfAlgo::GetOutputNumByAbstract(node->abstract());
316 }
317 return AnfUtils::GetOutputTensorNum(node);
318 }
319
GetOutputTensorMemSizeImpl(const AnfNodePtr & node,size_t output_index,const ShapeVector & real_shape)320 size_t GetOutputTensorMemSizeImpl(const AnfNodePtr &node, size_t output_index, const ShapeVector &real_shape) {
321 MS_EXCEPTION_IF_NULL(node);
322 if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
323 MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
324 << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
325 }
326 TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
327 if (output_type_id == kTypeUnknown) {
328 output_type_id = common::AnfAlgo::GetOutputInferDataType(node, output_index);
329 }
330 size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
331 auto shape = real_shape;
332 auto format = AnfAlgo::GetOutputFormat(node, output_index);
333 auto dtype = AnfAlgo::GetOutputDeviceDataType(node, output_index);
334 if (shape.empty() && format != kOpFormat_DEFAULT) {
335 shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index), node);
336 shape = trans::TransShapeToDevice(shape, format, node, output_index, dtype);
337 }
338 // scalar's output shape is a empty vector
339 size_t tensor_size = type_size * SizeOf(shape);
340 return tensor_size;
341 }
342
GetOutputTensorMemSize(const AnfNodePtr & node,size_t output_index,const ShapeVector & real_shape)343 size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index,
344 const ShapeVector &real_shape) {
345 if (IsDynamic(real_shape)) {
346 MS_LOG(EXCEPTION) << "The shape is " << real_shape << " dynamic shape , can not get OutputTensorMemSize";
347 }
348 return GetOutputTensorMemSizeImpl(node, output_index, real_shape);
349 }
350
GetOutputTensorMemSize(const AnfNodePtr & node,size_t output_index)351 size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index) {
352 MS_EXCEPTION_IF_NULL(node);
353 auto shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
354 if (IsDynamic(shape)) {
355 auto max_shape = common::AnfAlgo::GetOutputMaxShape(node, output_index);
356 if (!max_shape.empty()) {
357 shape = max_shape;
358 MS_LOG(DEBUG) << "shape[" << shape << "] is dynamic, using max_shape[" << max_shape << "] instead.";
359 } else {
360 shape = {1};
361 MS_LOG(DEBUG) << "shape[" << shape << "] is dynamic, set default to {1}";
362 }
363 }
364 return GetOutputTensorMemSizeImpl(node, output_index, shape);
365 }
366
GetAllOutputFormats(const AnfNodePtr & node)367 std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) {
368 MS_EXCEPTION_IF_NULL(node);
369 if (!AnfUtils::IsRealKernel(node)) {
370 MS_LOG(EXCEPTION) << "Not real kernel:"
371 << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
372 }
373 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
374 MS_EXCEPTION_IF_NULL(kernel_info);
375 auto build_info = kernel_info->select_kernel_build_info();
376 MS_EXCEPTION_IF_NULL(build_info);
377 auto format = build_info->GetAllOutputFormats();
378 return format;
379 }
380
GetAllInputFormats(const AnfNodePtr & node)381 std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePtr &node) {
382 MS_EXCEPTION_IF_NULL(node);
383 if (!AnfUtils::IsRealKernel(node)) {
384 MS_LOG(EXCEPTION) << "Not real kernel:"
385 << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
386 }
387 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
388 MS_EXCEPTION_IF_NULL(kernel_info);
389 auto build_info = kernel_info->select_kernel_build_info();
390 MS_EXCEPTION_IF_NULL(build_info);
391 auto format = build_info->GetAllInputFormats();
392 return format;
393 }
394
GetAllInputDeviceTypes(const AnfNodePtr & node)395 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllInputDeviceTypes(const AnfNodePtr &node) {
396 MS_EXCEPTION_IF_NULL(node);
397 if (!AnfUtils::IsRealKernel(node)) {
398 MS_LOG(EXCEPTION) << "Not real kernel:"
399 << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
400 }
401 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
402 MS_EXCEPTION_IF_NULL(kernel_info);
403 auto build_info = kernel_info->select_kernel_build_info();
404 MS_EXCEPTION_IF_NULL(build_info);
405 auto types = build_info->GetAllInputDeviceTypes();
406 return types;
407 }
408
GetAllOutputDeviceTypes(const AnfNodePtr & node)409 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputDeviceTypes(const AnfNodePtr &node) {
410 MS_EXCEPTION_IF_NULL(node);
411 if (!AnfUtils::IsRealKernel(node)) {
412 MS_LOG(EXCEPTION) << "Not real kernel:"
413 << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
414 }
415 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
416 MS_EXCEPTION_IF_NULL(kernel_info);
417 auto build_info = kernel_info->select_kernel_build_info();
418 MS_EXCEPTION_IF_NULL(build_info);
419 auto types = build_info->GetAllOutputDeviceTypes();
420 return types;
421 }
422
GetOriginDataFormat(const AnfNodePtr & node)423 std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
424 MS_EXCEPTION_IF_NULL(node);
425 if (!AnfUtils::IsRealKernel(node)) {
426 MS_LOG(EXCEPTION) << "Not real kernel:"
427 << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
428 }
429 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
430 if (kernel_info == nullptr) {
431 return kOpFormat_DEFAULT;
432 }
433 auto build_info = kernel_info->select_kernel_build_info();
434 if (build_info == nullptr) {
435 return kOpFormat_DEFAULT;
436 }
437 auto format = build_info->GetOriginDataFormat();
438 return format;
439 }
440
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  // Device format of output `output_idx`, read from the selected kernel build
  // info; virtual (non-real) kernels delegate to the producing real node.
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): the range check uses '>' rather than '>=', so
  // output_idx == GetOutputElementNum(node) is accepted — confirm intended.
  if (output_idx > AnfAlgo::GetOutputElementNum(node) && (!common::AnfAlgo::IsDynamicSequence(node))) {
    MS_LOG(EXCEPTION) << "Output index:" << output_idx
                      << " is out of the node output range :" << AnfAlgo::GetOutputElementNum(node) << " #node ["
                      << node->DebugString() << "]" << trace::DumpSourceLines(node);
  }
  // Sparse-tensor outputs always report the default format.
  if (common::AnfAlgo::CheckAbsSparseTensor(node)) {
    return kOpFormat_DEFAULT;
  }
  if (!AnfUtils::IsRealKernel(node)) {
    return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  std::string format;
  // If the output is TUPLE, output format list's size is 1. So we use the first element as the output format.
  // This scenario could happen before 'insert_type_transform_op' pass.
  auto output_obj_types = build_info->GetAllOutputKernelObjectTypes();
  if (!output_obj_types.empty() && output_obj_types[kIndex0] == KernelObjectType::TUPLE) {
    MS_LOG(DEBUG) << "TUPLE only has one output. So use index 0 output format.";
    format = build_info->GetOutputFormat(kIndex0);
  } else {
    format = build_info->GetOutputFormat(output_idx);
  }
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid output format" << trace::DumpSourceLines(node);
  }
  return format;
}
474
GetInputFormat(const AnfNodePtr & node,size_t input_idx)475 std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
476 MS_EXCEPTION_IF_NULL(node);
477 if (input_idx > common::AnfAlgo::GetInputTensorNum(node)) {
478 MS_LOG(EXCEPTION) << "Input index :" << input_idx
479 << " is out of the number node Input range :" << common::AnfAlgo::GetInputTensorNum(node)
480 << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
481 }
482 if (!AnfUtils::IsRealKernel(node)) {
483 return GetPrevNodeOutputFormat(node, input_idx);
484 }
485 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
486 MS_EXCEPTION_IF_NULL(kernel_info);
487 auto build_info = kernel_info->select_kernel_build_info();
488 MS_EXCEPTION_IF_NULL(build_info);
489 auto format = build_info->GetInputFormat(input_idx);
490 if (format == kernel::KernelBuildInfo::kInvalidFormat) {
491 MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
492 << " input index:" << input_idx << " has a invalid input format\n"
493 << trace::DumpSourceLines(node);
494 }
495 return format;
496 }
497
IsEquivalentFormat(const std::string & src_format,const std::string & dst_format)498 bool AnfRuntimeAlgorithm::IsEquivalentFormat(const std::string &src_format, const std::string &dst_format) {
499 if (src_format == dst_format) {
500 return true;
501 }
502
503 // Equivalent default format.
504 if (((src_format == kOpFormat_DEFAULT) || (src_format == kOpFormat_NCHW) || (src_format == kOpFormat_ND)) &&
505 ((dst_format == kOpFormat_DEFAULT) || (dst_format == kOpFormat_NCHW) || (dst_format == kOpFormat_ND))) {
506 return true;
507 }
508
509 return false;
510 }
511
GetPrevNodeOutputFormat(const AnfNodePtr & anf_node,size_t input_idx)512 std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
513 KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
514 return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
515 }
516
GetPrevNodeOutputReshapeType(const AnfNodePtr & node,size_t input_idx)517 std::string AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
518 KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_idx);
519 return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
520 }
521
GetInputKernelObjectTypes(const AnfNodePtr & node)522 std::vector<KernelObjectType> AnfRuntimeAlgorithm::GetInputKernelObjectTypes(const AnfNodePtr &node) {
523 MS_EXCEPTION_IF_NULL(node);
524 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
525 MS_EXCEPTION_IF_NULL(kernel_info);
526 auto build_info = kernel_info->select_kernel_build_info();
527 if (build_info == nullptr) {
528 MS_LOG(EXCEPTION) << "Empty build info for node:" << node->fullname_with_scope()
529 << ", debug name:" << node->DebugString();
530 }
531 return build_info->GetAllInputKernelObjectTypes();
532 }
533
GetInputKernelObjectType(const AnfNodePtr & node,size_t input_idx)534 KernelObjectType AnfRuntimeAlgorithm::GetInputKernelObjectType(const AnfNodePtr &node, size_t input_idx) {
535 MS_EXCEPTION_IF_NULL(node);
536 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
537 MS_EXCEPTION_IF_NULL(kernel_info);
538 auto build_info = kernel_info->select_kernel_build_info();
539 if (build_info == nullptr) {
540 MS_LOG(EXCEPTION) << "Empty build info for node:" << node->fullname_with_scope()
541 << ", debug name:" << node->DebugString();
542 }
543 const auto &input_kernel_obj_types = build_info->GetAllInputKernelObjectTypes();
544 if (input_idx >= input_kernel_obj_types.size()) {
545 MS_LOG(EXCEPTION) << "Input index " << input_idx << ", but the node input kernel object types size just "
546 << input_kernel_obj_types.size() << ". node: " << node->fullname_with_scope()
547 << ", debug name:" << node->DebugString() << "." << trace::DumpSourceLines(node);
548 }
549 return input_kernel_obj_types[input_idx];
550 }
551
GetOutputKernelObjectTypes(const AnfNodePtr & node)552 std::vector<KernelObjectType> AnfRuntimeAlgorithm::GetOutputKernelObjectTypes(const AnfNodePtr &node) {
553 MS_EXCEPTION_IF_NULL(node);
554 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
555 MS_EXCEPTION_IF_NULL(kernel_info);
556 auto build_info = kernel_info->select_kernel_build_info();
557 if (build_info == nullptr) {
558 return {};
559 }
560 return build_info->GetAllOutputKernelObjectTypes();
561 }
562
GetOutputKernelObjectType(const AnfNodePtr & node,size_t output_idx)563 KernelObjectType AnfRuntimeAlgorithm::GetOutputKernelObjectType(const AnfNodePtr &node, size_t output_idx) {
564 MS_EXCEPTION_IF_NULL(node);
565 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
566 MS_EXCEPTION_IF_NULL(kernel_info);
567 auto build_info = kernel_info->select_kernel_build_info();
568 if (build_info == nullptr) {
569 MS_LOG(EXCEPTION) << "Empty build info for node:" << node->fullname_with_scope()
570 << ", debug name:" << node->DebugString();
571 }
572 const auto &output_kernel_obj_types = build_info->GetAllOutputKernelObjectTypes();
573 if (output_idx >= output_kernel_obj_types.size()) {
574 MS_LOG(EXCEPTION) << "Output index " << output_idx << ", but the node output kernel object types size just "
575 << output_kernel_obj_types.size() << ". node: " << node->fullname_with_scope()
576 << ", debug name:" << node->DebugString() << "." << trace::DumpSourceLines(node);
577 }
578 return output_kernel_obj_types[output_idx];
579 }
580
GetOutputElementsKernelObjectTypes(const AnfNodePtr & node)581 std::vector<KernelObjectType> AnfRuntimeAlgorithm::GetOutputElementsKernelObjectTypes(const AnfNodePtr &node) {
582 MS_EXCEPTION_IF_NULL(node);
583 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
584 MS_EXCEPTION_IF_NULL(kernel_info);
585 auto build_info = kernel_info->select_kernel_build_info();
586 if (build_info == nullptr) {
587 MS_LOG(EXCEPTION) << "Empty build info for node:" << node->fullname_with_scope()
588 << ", debug name:" << node->DebugString();
589 }
590 return build_info->GetAllOutputElementsKernelObjectTypes();
591 }
592
GetValid(const AnfNodePtr & node)593 bool AnfRuntimeAlgorithm::GetValid(const AnfNodePtr &node) {
594 MS_EXCEPTION_IF_NULL(node);
595 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
596 MS_EXCEPTION_IF_NULL(kernel_info);
597 auto build_info = kernel_info->select_kernel_build_info();
598 if (build_info == nullptr) {
599 MS_LOG(EXCEPTION) << "Empty build info for node:" << node->fullname_with_scope()
600 << ", debug name:" << node->DebugString();
601 }
602 return build_info->valid();
603 }
604
IsRealSquenceOutput(const AnfNodePtr & node)605 bool AnfRuntimeAlgorithm::IsRealSquenceOutput(const AnfNodePtr &node) {
606 MS_EXCEPTION_IF_NULL(node);
607 std::vector<KernelObjectType> objects = GetOutputKernelObjectTypes(node);
608 bool is_real_tuple = false;
609 if (objects.empty()) {
610 return false;
611 } else {
612 is_real_tuple = (objects[0] == KernelObjectType::TUPLE);
613 }
614 return is_real_tuple;
615 }
616
GetOutputDeviceShapeForTbeBuild(const AnfNodePtr & node,size_t output_idx,const std::string & format)617 std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node, size_t output_idx,
618 const std::string &format) {
619 auto output_shape = AnfAlgo::GetOutputDetailShape(node, output_idx);
620 std::vector<int64_t> infer_shape;
621 if (output_shape->isa<abstract::Shape>()) {
622 auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
623 MS_EXCEPTION_IF_NULL(shape_ptr);
624 infer_shape = shape_ptr->shape();
625 }
626 if (infer_shape.empty()) {
627 return infer_shape;
628 }
629
630 // if format is default_format or NC1KHKWHWC0,device shape = original shape
631 if (trans::IsNeedPadding(format, infer_shape)) {
632 infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node);
633 }
634 auto dtype = GetOutputDeviceDataType(node, output_idx);
635 return trans::TransShapeToDevice(infer_shape, format, node, output_idx, dtype);
636 }
637
IsShapesDynamic(const std::vector<ShapeVector> & shapes)638 bool AnfRuntimeAlgorithm::IsShapesDynamic(const std::vector<ShapeVector> &shapes) {
639 return std::any_of(shapes.cbegin(), shapes.cend(), [](const auto &shape) { return IsDynamic(shape); });
640 }
641
GetOutputDeviceShape(const AnfNodePtr & node,size_t output_idx)642 ShapeVector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
643 auto format = GetOutputFormat(node, output_idx);
644 auto infer_shape = common::AnfAlgo::GetOutputInferShape(node, output_idx, IsRealSquenceOutput(node));
645 if (infer_shape.empty()) {
646 return infer_shape;
647 }
648
649 // if format is default_format or NC1KHKWHWC0,device shape = original shape
650 if (trans::IsNeedPadding(format, infer_shape)) {
651 infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node);
652 }
653 auto dtype = GetOutputDeviceDataType(node, output_idx);
654 return trans::TransShapeToDevice(infer_shape, format, node, output_idx, dtype);
655 }
656
GetOutputDeviceShape(const AnfNodePtr & node,size_t output_idx,ShapeVector real_shape)657 ShapeVector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx,
658 ShapeVector real_shape) {
659 auto format = GetOutputFormat(node, output_idx);
660 if (real_shape.empty()) {
661 return real_shape;
662 }
663
664 // if format is default_format or NC1KHKWHWC0,device shape = original shape
665 if (trans::IsNeedPadding(format, real_shape)) {
666 real_shape = trans::PaddingShape(real_shape, format, GetOutputReshapeType(node, output_idx), node);
667 }
668 auto dtype = GetOutputDeviceDataType(node, output_idx);
669 return trans::TransShapeToDevice(real_shape, format, node, output_idx, dtype);
670 }
671
GetInputDeviceShapeForTbeBuild(const AnfNodePtr & node,size_t input_idx,const std::string & format)672 std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShapeForTbeBuild(const AnfNodePtr &node, size_t input_idx,
673 const std::string &format) {
674 auto output_shape = AnfAlgo::GetPrevNodeOutputDetailShape(node, input_idx);
675 std::vector<int64_t> infer_shape;
676 if (output_shape->isa<abstract::Shape>()) {
677 auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
678 MS_EXCEPTION_IF_NULL(shape_ptr);
679 infer_shape = shape_ptr->shape();
680 }
681 if (infer_shape.empty()) {
682 return infer_shape;
683 }
684
685 // if format is default_format or NC1KHKWHWC0,device shape = original shape
686 if (trans::IsNeedPadding(format, infer_shape)) {
687 infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node);
688 }
689 auto dtype = GetInputDeviceDataType(node, input_idx);
690 return trans::TransShapeToDevice(infer_shape, format, node, input_idx, dtype, false);
691 }
692
// Returns the device-side shape of input `input_idx`: takes the producer output's
// inferred shape, pads it if the selected input format requires padding, and
// transforms it into the device layout for the input's device dtype.
std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
  auto format = GetInputFormat(node, input_idx);
  auto infer_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, input_idx);
  // A scalar (empty) shape needs neither padding nor layout conversion.
  if (infer_shape.empty()) {
    return infer_shape;
  }
  // if format is default_format or NC1KHKWHWC0,device shape = original shape
  if (trans::IsNeedPadding(format, infer_shape)) {
    infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node);
  }
  auto dtype = GetInputDeviceDataType(node, input_idx);
  // Last argument `false`: do not raise on unsupported format here.
  return trans::TransShapeToDevice(infer_shape, format, node, input_idx, dtype, false);
}
706
// Returns the reshape-type string for input `input_idx` of `node`, or "" when the
// selected kernel build info is missing or uses default input padding. For virtual
// (non-real) kernels the query is forwarded to the producing node's output.
std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): valid indices are [0, input_num), but the guard uses `>` so
  // input_idx == input_num passes — presumably tolerated downstream; verify.
  if (input_idx > common::AnfAlgo::GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index:" << input_idx
                      << " is out of range of the node's input size : " << common::AnfAlgo::GetInputTensorNum(node)
                      << "#node[" << node->DebugString() << "]" << trace::DumpSourceLines(node);
  }
  if (!AnfUtils::IsRealKernel(node)) {
    return GetPrevNodeOutputReshapeType(node, input_idx);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  if (build_info == nullptr || build_info->IsInputDefaultPadding()) {
    return "";
  }
  return build_info->GetInputReshapeType(input_idx);
}
725
// Returns the reshape-type string for output `output_idx` of `node`, or "" when the
// selected kernel build info is missing or uses default output padding. For virtual
// (non-real) kernels the query is forwarded to the producing node's output.
std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (!AnfUtils::IsRealKernel(node)) {
    return GetPrevNodeOutputReshapeType(node, output_idx);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  if (build_info == nullptr || build_info->IsOutputDefaultPadding()) {
    return "";
  }
  return build_info->GetOutputReshapeType(output_idx);
}
739
GetAllInputReshapeType(const AnfNodePtr & node)740 std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputReshapeType(const AnfNodePtr &node) {
741 MS_EXCEPTION_IF_NULL(node);
742 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
743 MS_EXCEPTION_IF_NULL(kernel_info);
744 auto build_info = kernel_info->select_kernel_build_info();
745 if (build_info == nullptr || build_info->IsInputDefaultPadding()) {
746 return {};
747 }
748 return build_info->GetAllInputReshapeType();
749 }
750
GetAllOutputReshapeType(const AnfNodePtr & node)751 std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputReshapeType(const AnfNodePtr &node) {
752 MS_EXCEPTION_IF_NULL(node);
753 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
754 MS_EXCEPTION_IF_NULL(kernel_info);
755 auto build_info = kernel_info->select_kernel_build_info();
756 if (build_info == nullptr || build_info->IsOutputDefaultPadding()) {
757 return {};
758 }
759 return build_info->GetAllOutputReshapeType();
760 }
761
// Returns the device data type selected for output `output_idx` of `node`.
// Sparse-tensor abstracts and virtual kernels are special-cased; raises when the
// selected build info records an invalid dtype.
TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): the guard uses `>`, so output_idx == element-num passes — verify intended.
  if (output_idx > AnfAlgo::GetOutputElementNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << AnfAlgo::GetOutputElementNum(node) << "#node [ " << node->DebugString() << "]"
                      << trace::DumpSourceLines(node);
  }
  // Sparse tensors keep their per-element type in the abstract, not the build info.
  if (common::AnfAlgo::CheckAbsSparseTensor(node)) {
    return common::AnfAlgo::GetSparseTypeIdAt(node, output_idx);
  }
  if (!AnfUtils::IsRealKernel(node)) {
    return GetPrevNodeOutputDeviceDataType(node, output_idx);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);

  // If node has only one output and it is Tuple, in build_info, it only has one same dtype, so set output_dix as zero.
  if (build_info->GetOutputNum() == 1 && build_info->GetOutputKernelObjectType(0) == kernel::KernelObjectType::TUPLE) {
    output_idx = 0;
  }

  auto dtype = build_info->GetOutputDeviceType(output_idx);
  if (dtype == TypeId::kNumberTypeEnd) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "] has a invalid dtype" << trace::DumpSourceLines(node);
  }
  return dtype;
}
791
// Returns the device data type selected for input `input_idx` of `node`; raises when
// the selected build info records an invalid dtype.
TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): guard uses `>` so input_idx == input-num passes — verify intended.
  if (input_idx > common::AnfAlgo::GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
                      << common::AnfAlgo::GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"
                      << trace::DumpSourceLines(node);
  }
  if (!AnfUtils::IsRealKernel(node)) {
    // NOTE(review): forwards index 0 rather than input_idx, unlike GetInputReshapeType
    // above — presumably deliberate for virtual kernels; confirm before changing.
    return GetPrevNodeOutputDeviceDataType(node, 0);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto dtype = build_info->GetInputDeviceType(input_idx);
  if (dtype == TypeId::kNumberTypeEnd) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid dtype." << trace::DumpSourceLines(node);
  }
  return dtype;
}
813
GetPrevNodeOutputDeviceDataType(const AnfNodePtr & anf_node,size_t input_idx)814 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
815 KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
816 return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
817 }
818
// get output device addr of anf_node
// Returns the (non-owning) device address of output `output_idx`. Forward-output
// tensors short-circuit to the tensor's own device address; nop nodes may be skipped
// by delegating to their single real input. Raises when no address exists.
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node) {
  MS_EXCEPTION_IF_NULL(node);
  auto tensor = GetForwardOutputTensor(node);
  if (tensor != nullptr) {
    return dynamic_cast<const DeviceAddress *>(tensor->device_address().get());
  }

  // Nop nodes reuse their input's address when skipping is requested or required.
  if (common::AnfAlgo::IsNopNode(node) && (skip_nop_node || common::AnfAlgo::IsNeedSkipNopOpAddr(node))) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist." << trace::DumpSourceLines(node);
  }
  return addr;
}
841
// Shared-ownership variant of GetOutputAddr: returns the mutable device address of
// output `output_idx`, with the same forward-output-tensor and nop-node handling.
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx,
                                                           bool skip_nop_node) {
  MS_EXCEPTION_IF_NULL(node);
  auto tensor = GetForwardOutputTensor(node);
  if (tensor != nullptr) {
    return std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
  }

  // Nop nodes reuse their input's address when skipping is requested or required.
  if (common::AnfAlgo::IsNopNode(node) && (skip_nop_node || common::AnfAlgo::IsNeedSkipNopOpAddr(node))) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
  }
  // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() << " node:" << node
                      << " output addr is not exist." << trace::DumpSourceLines(node);
  }
  return addr;
}
865
// get output device addr of anf_node
// Returns true when output `output_idx` of `node` already has a device address,
// following skipped nop nodes through to their real input when applicable.
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node) {
  MS_EXCEPTION_IF_NULL(node);
  if (common::AnfAlgo::IsNopNode(node) && (skip_nop_node || common::AnfAlgo::IsNeedSkipNopOpAddr(node))) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // A nop node with at least one data input delegates to its producer's output.
    if (cnode->size() > 1) {
      auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, 0);
      return OutputAddrExist(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
    }
    return false;
  }
  // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
  auto kernel_info_ptr = node->kernel_info();
  if (kernel_info_ptr == nullptr) {
    return false;
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel_info_ptr);
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->OutputAddrExist(output_idx);
}
887
WorkspaceAddrExist(const AnfNodePtr & node,size_t output_idx)888 bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx) {
889 MS_EXCEPTION_IF_NULL(node);
890 // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
891 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
892 MS_EXCEPTION_IF_NULL(kernel_info);
893 return kernel_info->WorkspaceAddrExist(output_idx);
894 }
895
GetPrevNodeOutputAddr(const AnfNodePtr & anf_node,size_t input_idx,bool skip_nop_node)896 const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
897 bool skip_nop_node) {
898 KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
899 return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
900 }
901
GetPrevNodeMutableOutputAddr(const AnfNodePtr & anf_node,size_t input_idx,bool skip_nop_node)902 DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
903 bool skip_nop_node) {
904 KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
905 return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
906 }
907
// Extracts (shape, type, value) for output `output_idx` of `node` from its abstract.
// ValueNodes materialize an abstract from their value when missing and carry the
// value itself; other nodes yield a null value. When the selected device dtype of a
// tensor output differs from the inferred element type (e.g. an inserted Cast), the
// returned type is rewritten to the device dtype.
std::tuple<abstract::BaseShapePtr, TypePtr, ValuePtr> AnfRuntimeAlgorithm::GetAbstractInfo(const AnfNodePtr &node,
                                                                                           size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  abstract::BaseShapePtr shape;
  TypePtr type;
  ValuePtr value;

  // Create output kernel tensor if not exists.
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    value = value_node->value();
    auto abs = node->abstract();
    // Lazily derive and cache the abstract from the stored value.
    if (abs == nullptr) {
      MS_EXCEPTION_IF_NULL(value);
      abs = value->ToAbstract();
      value_node->set_abstract(abs);
    }
    MS_EXCEPTION_IF_NULL(abs);
    shape = abs->GetShape();
    type = abs->GetType();
  } else {
    const auto &abs = AnfAlgo::GetNodeAbstractByIndex(node, output_idx);
    MS_EXCEPTION_IF_NULL(abs);
    shape = abs->GetShape();
    type = abs->GetType();
    value = nullptr;
  }

  // Insert cast pass will change the device type for some reason like CPU do not support fp16 actually,
  // so the output infer type and device type will be different, we change the output tensor to the real device type.
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<TensorType>()) {
    auto real_device_type = AnfAlgo::GetOutputDeviceDataType(node, output_idx);
    // Clone so the node's own abstract type is never mutated.
    auto abs_tensor_type = type->Clone()->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(abs_tensor_type);
    auto abs_element = abs_tensor_type->element();
    if (abs_element != nullptr) {
      auto abs_tensor_element_type = abs_element->type_id();
      if (real_device_type != kTypeUnknown && real_device_type != abs_tensor_element_type) {
        MS_LOG(INFO) << "For kernel " << node->DebugString() << ", the infer type of output[" << output_idx << "] is "
                     << TypeIdToString(abs_tensor_element_type) << ", but the device type is "
                     << TypeIdToString(real_device_type)
                     << ". Maybe there has insert cast pass which changed the device type."
                     << " So we change the tensor type from " << TypeIdToString(abs_tensor_element_type) << " to "
                     << TypeIdToString(real_device_type);
        abs_tensor_type->set_element(TypeIdToType(real_device_type));
        // Use new tensor type with device data type.
        type = abs_tensor_type;
      }
    }
  }

  return std::make_tuple(shape, type, value);
}
963
ExistOutputKernelTensor(const AnfNodePtr & node,size_t output_idx)964 bool AnfRuntimeAlgorithm::ExistOutputKernelTensor(const AnfNodePtr &node, size_t output_idx) {
965 MS_EXCEPTION_IF_NULL(node);
966 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
967 MS_EXCEPTION_IF_NULL(kernel_info);
968
969 return kernel_info->OutputAddrExist(output_idx) || kernel_info->OutputKernelTensorExist(output_idx);
970 }
971
// Returns the kernel tensor of output `output_idx`, preferring the one embedded in
// the output's device address over a standalone kernel tensor. Raises when neither
// exists (use GetOrCreateOutputKernelTensor to create on demand).
const KernelTensorPtr &AnfRuntimeAlgorithm::GetOutputKernelTensor(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);

  // Get output kernel tensor in device address if exists.
  if (kernel_info->OutputAddrExist(output_idx)) {
    return kernel_info->GetOutputAddr(output_idx)->kernel_tensor();
  }

  // Get output kernel tensor if exists.
  if (kernel_info->OutputKernelTensorExist(output_idx)) {
    return kernel_info->GetOutputKernelTensor(output_idx);
  }

  MS_LOG(EXCEPTION) << "Can not find kernel tensor for node : " << node->DebugString()
                    << ", output index: " << output_idx;
}
990
// Like GetOutputKernelTensor, but creates and registers a kernel tensor from the
// node's abstract info (shape/type/value plus output format) when none exists yet.
const KernelTensorPtr &AnfRuntimeAlgorithm::GetOrCreateOutputKernelTensor(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);

  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);

  // Get output kernel tensor in device address if exists.
  if (kernel_info->OutputAddrExist(output_idx)) {
    // NOTE(review): returning `kt` is only safe if kernel_tensor() yields a
    // long-lived reference (not a by-value temporary) — verify its signature.
    const auto &kt = kernel_info->GetOutputAddr(output_idx)->kernel_tensor();
    return kt;
  }

  // Get output kernel tensor if exists.
  if (kernel_info->OutputKernelTensorExist(output_idx)) {
    return kernel_info->GetOutputKernelTensor(output_idx);
  }

  auto [shape, type, value] = GetAbstractInfo(node, output_idx);
  auto kernel_tensor = std::make_shared<KernelTensor>(shape, type, value);
  // Handle the format diff between host and device, need set format before Resize KernelMod.
  kernel_tensor->SetStringFormat(GetOutputFormat(node, output_idx));
  kernel_info->SetOutputKernelTensor(kernel_tensor, output_idx);

  return kernel_info->GetOutputKernelTensor(output_idx);
}
1016
GetPrevNodeOutputKernelTensor(const AnfNodePtr & node,size_t input_idx)1017 const KernelTensorPtr &AnfRuntimeAlgorithm::GetPrevNodeOutputKernelTensor(const AnfNodePtr &node, size_t input_idx) {
1018 KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_idx, false);
1019 return GetOutputKernelTensor(kernel_with_index.first, kernel_with_index.second);
1020 }
1021
GetOrCreatePrevNodeOutputKernelTensor(const AnfNodePtr & node,size_t input_idx)1022 const KernelTensorPtr &AnfRuntimeAlgorithm::GetOrCreatePrevNodeOutputKernelTensor(const AnfNodePtr &node,
1023 size_t input_idx) {
1024 KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_idx, false);
1025 return GetOrCreateOutputKernelTensor(kernel_with_index.first, kernel_with_index.second);
1026 }
1027
GetOrCreateAllInputKernelTensors(const AnfNodePtr & node)1028 std::vector<KernelTensor *> AnfRuntimeAlgorithm::GetOrCreateAllInputKernelTensors(const AnfNodePtr &node) {
1029 MS_EXCEPTION_IF_NULL(node);
1030 size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
1031 std::vector<KernelTensor *> input_kernel_tensors(input_num);
1032 for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
1033 input_kernel_tensors[input_idx] = GetOrCreatePrevNodeOutputKernelTensor(node, input_idx).get();
1034 }
1035 return input_kernel_tensors;
1036 }
1037
GetOrCreateAllOutputKernelTensors(const AnfNodePtr & node)1038 std::vector<KernelTensor *> AnfRuntimeAlgorithm::GetOrCreateAllOutputKernelTensors(const AnfNodePtr &node) {
1039 MS_EXCEPTION_IF_NULL(node);
1040 size_t output_num = AnfAlgo::GetOutputTensorNum(node);
1041 std::vector<KernelTensor *> output_kernel_tensors(output_num);
1042 for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
1043 output_kernel_tensors[output_idx] = GetOrCreateOutputKernelTensor(node, output_idx).get();
1044 }
1045 return output_kernel_tensors;
1046 }
1047
// Builds a fresh kernel tensor for `node_with_index` carrying the given device-side
// info (pointer, size, format, dtype, host shape, device identity, user data).
// Shape/type are cloned from an existing kernel tensor when one exists, otherwise
// derived from the node's abstract via GetAbstractInfo.
KernelTensorPtr AnfRuntimeAlgorithm::CreateOutputKernelTensorWithDeviceInfo(
  const AnfWithOutIndex &node_with_index, void *const device_ptr, size_t size, const string &format, TypeId dtype_id,
  const ShapeVector &host_shape, const std::string &device_name, uint32_t device_id, const UserDataPtr &user_data) {
  abstract::BaseShapePtr shape;
  TypePtr type;
  ValuePtr value;
  if (ExistOutputKernelTensor(node_with_index.first, node_with_index.second)) {
    const auto &kernel_tensor = GetOutputKernelTensor(node_with_index.first, node_with_index.second);
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    MS_EXCEPTION_IF_NULL(kernel_tensor->GetShape());
    MS_EXCEPTION_IF_NULL(kernel_tensor->GetType());
    // Clone so the new tensor does not alias the existing tensor's shape/type objects.
    shape = kernel_tensor->GetShape()->Clone();
    type = kernel_tensor->GetType()->Clone();
    value = kernel_tensor->GetValueTrack();
  } else {
    std::tie(shape, type, value) = GetAbstractInfo(node_with_index.first, node_with_index.second);
  }

  MS_EXCEPTION_IF_NULL(shape);
  MS_EXCEPTION_IF_NULL(type);
  MS_LOG(DEBUG) << "Create output kernel tensor for node: " << node_with_index.first->fullname_with_scope()
                << ", output index: " << node_with_index.second << ", Shape: " << shape->ToString()
                << ", Type: " << type->ToString() << ", Value: " << (value ? value->ToString() : "nullptr")
                << ", host shape: " << host_shape;

  return std::make_shared<kernel::KernelTensor>(shape, type, value, device_ptr, size, format, dtype_id, host_shape,
                                                device_name, device_id, user_data);
}
1076
GetNodeInputSizeList(const AnfNodePtr & node)1077 std::vector<size_t> AnfRuntimeAlgorithm::GetNodeInputSizeList(const AnfNodePtr &node) {
1078 std::vector<KernelTensor *> input_kernel_tensors = AnfAlgo::GetOrCreateAllInputKernelTensors(node);
1079 size_t input_num = input_kernel_tensors.size();
1080 std::vector<size_t> input_size_list(input_num, 0);
1081 for (size_t i = 0; i < input_num; i++) {
1082 MS_EXCEPTION_IF_NULL(input_kernel_tensors[i]);
1083 input_size_list[i] = input_kernel_tensors[i]->size();
1084 }
1085
1086 return input_size_list;
1087 }
1088
GetOutputAddressNum(const AnfNodePtr & node)1089 size_t AnfRuntimeAlgorithm::GetOutputAddressNum(const AnfNodePtr &node) {
1090 MS_EXCEPTION_IF_NULL(node);
1091 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1092 MS_EXCEPTION_IF_NULL(kernel_info);
1093 auto build_info = kernel_info->select_kernel_build_info();
1094 MS_EXCEPTION_IF_NULL(build_info);
1095 return build_info->GetOutputNumWithoutMonad();
1096 }
1097
1098 // set output device addr of anf_node
SetOutputAddr(const DeviceAddressPtr & addr,size_t output_idx,AnfNode * node)1099 void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
1100 MS_EXCEPTION_IF_NULL(node);
1101 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1102 MS_EXCEPTION_IF_NULL(kernel_info);
1103 if (!kernel_info->SetOutputAddr(addr, output_idx)) {
1104 MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set output index:" << output_idx << " fail."
1105 << trace::DumpSourceLines(node);
1106 }
1107 }
1108
1109 // set workspace device addr of anf_node
SetWorkspaceAddr(const DeviceAddressPtr & addr,size_t output_idx,AnfNode * node)1110 void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
1111 MS_EXCEPTION_IF_NULL(node);
1112 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1113 MS_EXCEPTION_IF_NULL(kernel_info);
1114 if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
1115 MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set output index:" << output_idx << " fail."
1116 << trace::DumpSourceLines(node);
1117 }
1118 }
1119
// get workspace device addr of anf_node
// Returns the raw workspace address at `output_idx`; raises when none exists.
DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetWorkspaceAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << "] workspace addr is not exist." << trace::DumpSourceLines(node);
  }
  return addr;
}
1132
// get workspace device mutable addr of anf_node
// Shared-ownership variant of GetWorkspaceAddr; raises when no address exists at `index`.
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableWorkspaceAddr(index);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Index " << index << " of node " << node->DebugString() << "] workspace addr is not exist."
                      << trace::DumpSourceLines(node);
  }
  return addr;
}
1145
GetOpPattern(const AnfNodePtr & node)1146 kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
1147 MS_EXCEPTION_IF_NULL(node);
1148 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1149 MS_EXCEPTION_IF_NULL(kernel_info);
1150 // select_kernel_build_info() has checked whether return pointer is null
1151 auto build_info = kernel_info->select_kernel_build_info();
1152 MS_EXCEPTION_IF_NULL(build_info);
1153 return build_info->op_pattern();
1154 }
1155
// get KernelBuildType of node, such as ATT,RT,FWK and so on
// Falls back to UNKNOWN_KERNEL_TYPE (with a debug log) when the node has no
// selected kernel build info yet.
KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  // select_kernel_build_info() has checked whether return pointer is null
  auto build_info = kernel_info->select_kernel_build_info();
  if (build_info == nullptr) {
    MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << " has no kernel build info, using UNKNOWN_KERNEL_TYPE";
    return KernelType::UNKNOWN_KERNEL_TYPE;
  }
  return build_info->kernel_type();
}
1169
SetFusionType(const AnfNodePtr & node,const std::string & type)1170 void AnfRuntimeAlgorithm::SetFusionType(const AnfNodePtr &node, const std::string &type) {
1171 MS_EXCEPTION_IF_NULL(node);
1172 auto builder =
1173 std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
1174 MS_EXCEPTION_IF_NULL(builder);
1175 builder->SetFusionType(type);
1176 AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
1177 }
1178
SetCoreType(const AnfNodePtr & node,const std::string & core_type)1179 void AnfRuntimeAlgorithm::SetCoreType(const AnfNodePtr &node, const std::string &core_type) {
1180 MS_EXCEPTION_IF_NULL(node);
1181 auto builder =
1182 std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
1183 MS_EXCEPTION_IF_NULL(builder);
1184 builder->SetCoreType(core_type);
1185 AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
1186 }
1187
GetCoreType(const AnfNodePtr & node)1188 std::string AnfRuntimeAlgorithm::GetCoreType(const AnfNodePtr &node) {
1189 MS_EXCEPTION_IF_NULL(node);
1190 if (!AnfUtils::IsRealKernel(node)) {
1191 return "";
1192 }
1193 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1194 MS_EXCEPTION_IF_NULL(kernel_info);
1195 auto build_info = kernel_info->select_kernel_build_info();
1196 MS_EXCEPTION_IF_NULL(build_info);
1197 return build_info->core_type();
1198 }
1199
GetOpType(const AnfNodePtr & node)1200 kernel::OpType AnfRuntimeAlgorithm::GetOpType(const AnfNodePtr &node) {
1201 MS_EXCEPTION_IF_NULL(node);
1202 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1203 MS_EXCEPTION_IF_NULL(kernel_info);
1204 auto build_info = kernel_info->select_kernel_build_info();
1205 MS_EXCEPTION_IF_NULL(build_info);
1206 return build_info->op_type();
1207 }
1208
SetOutputDataDesc(const AnfNodePtr & node,const std::vector<nlohmann::json> & desc)1209 void AnfRuntimeAlgorithm::SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc) {
1210 MS_EXCEPTION_IF_NULL(node);
1211 auto builder =
1212 std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
1213 MS_EXCEPTION_IF_NULL(builder);
1214 builder->SetOutputDataDesc(desc);
1215 AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
1216 }
1217
GetOutputDataDesc(const AnfNodePtr & node)1218 std::vector<nlohmann::json> AnfRuntimeAlgorithm::GetOutputDataDesc(const AnfNodePtr &node) {
1219 MS_EXCEPTION_IF_NULL(node);
1220 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1221 if (kernel_info == nullptr) {
1222 return {};
1223 }
1224 auto build_info = kernel_info->select_kernel_build_info();
1225 if (build_info == nullptr) {
1226 return {};
1227 }
1228 return build_info->output_data_desc();
1229 }
1230
GetProcessor(const AnfNodePtr & node)1231 kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
1232 MS_EXCEPTION_IF_NULL(node);
1233 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1234 MS_EXCEPTION_IF_NULL(kernel_info);
1235 auto build_info = kernel_info->select_kernel_build_info();
1236 MS_EXCEPTION_IF_NULL(build_info);
1237 return build_info->processor();
1238 }
1239
GetFusionType(const AnfNodePtr & node)1240 std::string AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
1241 MS_EXCEPTION_IF_NULL(node);
1242 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1243 MS_EXCEPTION_IF_NULL(kernel_info);
1244 auto build_info = kernel_info->select_kernel_build_info();
1245 if (build_info == nullptr) {
1246 return kPatternUnknown;
1247 }
1248 return build_info->fusion_type();
1249 }
1250
// set select kernel_build_info
// Installs `select_kernel_build_info` on `node`. When the node already has build
// info, its existing input/output kernel-object types are carried over into the new
// build info if the new one leaves them unset.
void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  if (kernel_info->has_build_info() && (select_kernel_build_info != nullptr)) {
    auto ori_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
    auto input_object_types = ori_kernel_build_info->GetAllInputKernelObjectTypes();
    auto output_object_types = ori_kernel_build_info->GetAllOutputKernelObjectTypes();
    // Preserve previously-selected object types that the new build info omits.
    if (!input_object_types.empty() && select_kernel_build_info->GetAllInputKernelObjectTypes().empty()) {
      select_kernel_build_info->SetInputsKernelObjectType(input_object_types);
    }
    if (!output_object_types.empty() && select_kernel_build_info->GetAllOutputKernelObjectTypes().empty()) {
      MS_LOG(DEBUG) << "set kernel object type:" << output_object_types << " for node:" << node->fullname_with_scope();
      select_kernel_build_info->SetOutputsKernelObjectType(output_object_types);
    }
  }
  return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
}
1270
1271 // get select kernel_build_info
GetSelectKernelBuildInfo(const AnfNodePtr & node)1272 KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
1273 MS_EXCEPTION_IF_NULL(node);
1274 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1275 MS_EXCEPTION_IF_NULL(kernel_info);
1276 return kernel_info->GetMutableSelectKernelBuildInfo();
1277 }
1278
1279 // get kernelMode
GetKernelMod(const AnfNodePtr & node)1280 KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
1281 MS_EXCEPTION_IF_NULL(node);
1282 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1283 MS_EXCEPTION_IF_NULL(kernel_info);
1284 return kernel_info->MutableKernelMod();
1285 }
1286
1287 // set kernel mod
SetKernelMod(const KernelModPtr & kernel_mod,AnfNode * node)1288 void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
1289 MS_EXCEPTION_IF_NULL(node);
1290 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1291 MS_EXCEPTION_IF_NULL(kernel_info);
1292 kernel_info->set_kernel_mod(kernel_mod);
1293 }
1294
SetStreamId(uint32_t stream_id,AnfNode * node)1295 void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
1296 MS_EXCEPTION_IF_NULL(node);
1297 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1298 MS_EXCEPTION_IF_NULL(kernel_info);
1299 kernel_info->set_stream_id(stream_id);
1300 }
1301
GetStreamId(const AnfNodePtr & node)1302 uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
1303 MS_EXCEPTION_IF_NULL(node);
1304 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1305 MS_EXCEPTION_IF_NULL(kernel_info);
1306 return kernel_info->stream_id();
1307 }
1308
SetStreamDistinctionLabel(uint32_t stream_label,AnfNode * node)1309 void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
1310 MS_EXCEPTION_IF_NULL(node);
1311 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1312 MS_EXCEPTION_IF_NULL(kernel_info);
1313 kernel_info->set_stream_distinction_label(stream_label);
1314 }
1315
GetStreamDistinctionLabel(const AnfNode * node)1316 uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
1317 MS_EXCEPTION_IF_NULL(node);
1318 auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1319 MS_EXCEPTION_IF_NULL(kernel_info);
1320 return kernel_info->stream_distinction_label();
1321 }
1322
SetGraphId(uint32_t graph_id,AnfNode * node)1323 void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
1324 MS_EXCEPTION_IF_NULL(node);
1325 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1326 MS_EXCEPTION_IF_NULL(kernel_info);
1327 kernel_info->set_graph_id(graph_id);
1328 }
1329
GetGraphId(const AnfNode * node)1330 uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
1331 MS_EXCEPTION_IF_NULL(node);
1332 auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1333 MS_EXCEPTION_IF_NULL(kernel_info);
1334 return kernel_info->graph_id();
1335 }
1336
IsFeatureMapOutput(const AnfNodePtr & node)1337 bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
1338 MS_EXCEPTION_IF_NULL(node);
1339 if (node->isa<ValueNode>()) {
1340 auto value_node = node->cast<ValueNodePtr>();
1341 MS_EXCEPTION_IF_NULL(value_node);
1342 ValuePtr value = value_node->value();
1343 std::vector<tensor::BaseTensorPtr> tensors;
1344 TensorValueToTensor(value, &tensors);
1345 auto ret = false;
1346 if (!tensors.empty()) {
1347 auto all_tensor_have_address = true;
1348 for (const auto &tensor : tensors) {
1349 MS_EXCEPTION_IF_NULL(tensor);
1350 if (tensor->device_address() == nullptr) {
1351 all_tensor_have_address = false;
1352 break;
1353 }
1354 }
1355 ret = all_tensor_have_address;
1356 }
1357 return ret;
1358 }
1359 if (IsPrimitiveCNode(node, prim::kPrimLoad) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
1360 return IsFeatureMapOutput(node->cast<CNodePtr>()->input(1));
1361 }
1362 auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1363 // If node is a call node which not have kernel info
1364 if (kernel_info == nullptr) {
1365 return false;
1366 }
1367 return kernel_info->is_feature_map();
1368 }
1369
IsFeatureMapInput(const AnfNodePtr & node,size_t input_index)1370 bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
1371 MS_EXCEPTION_IF_NULL(node);
1372 if (!node->isa<CNode>()) {
1373 MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map."
1374 << trace::DumpSourceLines(node);
1375 }
1376 auto cnode = node->cast<CNodePtr>();
1377 MS_EXCEPTION_IF_NULL(cnode);
1378 auto input_node = cnode->input(input_index + 1);
1379 return IsFeatureMapOutput(input_node);
1380 }
1381
// Map a kernel-side input index to the graph-side input index.
// In this implementation the mapping is the identity; the node is only
// validated for non-null.
size_t AnfRuntimeAlgorithm::GetInputGraphIdxByKernelIdx(const mindspore::AnfNodePtr &anf_node,
                                                        size_t input_index_in_kernel) {
  MS_EXCEPTION_IF_NULL(anf_node);
  return input_index_in_kernel;
}
1387
// Map a graph-side input index to the kernel-side input index.
// Identity mapping, mirroring GetInputGraphIdxByKernelIdx; the node is only
// validated for non-null.
size_t AnfRuntimeAlgorithm::GetInputKernelIdxByGraphIdx(const mindspore::AnfNodePtr &anf_node,
                                                        size_t input_index_in_graph) {
  MS_EXCEPTION_IF_NULL(anf_node);
  return input_index_in_graph;
}
1393
// Collect the child kernel graphs referenced by a Call / Switch / SwitchLayer node.
// - Call: the single graph at input kPartialGraphIndex.
// - Switch: {true-branch graph, false-branch graph}.
// - SwitchLayer: one graph per branch input starting at kSwitchLayerBranchesIndex.
// Raises an exception for any other node kind.
std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!(common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
        common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
        common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) {
    MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch or switch_layer node."
                      << trace::DumpSourceLines(cnode);
  }
  // A branch input is either a ValueNode<KernelGraph> directly, or a Partial
  // cnode whose input kPartialGraphIndex holds the graph value node.
  auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr {
    auto partial = cnode->input(input_index);
    MS_EXCEPTION_IF_NULL(partial);
    if (IsValueNode<KernelGraph>(partial)) {
      return GetValueNode<KernelGraphPtr>(partial);
    }
    auto partial_cnode = partial->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(partial_cnode);
    auto graph_node = partial_cnode->input(kPartialGraphIndex);
    MS_EXCEPTION_IF_NULL(graph_node);
    auto graph_value_node = graph_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(graph_value_node);
    auto graph_value = graph_value_node->value();
    MS_EXCEPTION_IF_NULL(graph_value);
    auto child_graph = graph_value->cast<KernelGraphPtr>();
    return child_graph;
  };
  if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
    auto input1 = cnode->input(kPartialGraphIndex);
    MS_EXCEPTION_IF_NULL(input1);
    auto value_node = input1->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto kernel_graph = value_node->value();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    return {kernel_graph->cast<KernelGraphPtr>()};
  } else if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
    return {get_switch_kernel_graph(kSwitchTrueBranchIndex), get_switch_kernel_graph(kSwitchFalseBranchIndex)};
  } else if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
    std::vector<KernelGraphPtr> child_graphs;
    // Branch graphs run from kSwitchLayerBranchesIndex to the end of the inputs.
    for (size_t idx = kSwitchLayerBranchesIndex; idx < cnode->size(); idx++) {
      auto child_graph = get_switch_kernel_graph(idx);
      child_graphs.emplace_back(child_graph);
    }
    return child_graphs;
  }
  return {};
}
1439
GetValueNodeKernelGraph(const AnfNodePtr & node)1440 KernelGraphPtr AnfRuntimeAlgorithm::GetValueNodeKernelGraph(const AnfNodePtr &node) {
1441 MS_EXCEPTION_IF_NULL(node);
1442 auto value_node = node->cast<ValueNodePtr>();
1443 if (value_node == nullptr) {
1444 return nullptr;
1445 }
1446 auto value = value_node->value();
1447 if (value == nullptr) {
1448 return nullptr;
1449 }
1450 auto kernel_graph = value->cast<KernelGraphPtr>();
1451 return kernel_graph;
1452 }
1453
IsIndependentNode(const CNodePtr & node)1454 bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
1455 MS_EXCEPTION_IF_NULL(node);
1456 if (AnfAlgo::GetKernelType(node) != AICPU_KERNEL) {
1457 return false;
1458 }
1459
1460 if (common::AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
1461 MS_LOG(INFO) << "GetNext should not be independent node";
1462 return false;
1463 }
1464
1465 // aicpu stack ops are not independent nodes.
1466 if (common::AnfAlgo::GetCNodeName(node) == kStackInitOpName ||
1467 common::AnfAlgo::GetCNodeName(node) == kStackDestroyOpName ||
1468 common::AnfAlgo::GetCNodeName(node) == kStackPopOpName ||
1469 common::AnfAlgo::GetCNodeName(node) == kStackPushOpName) {
1470 MS_LOG(INFO) << "AICPU stack ops should not be independent node";
1471 return false;
1472 }
1473
1474 size_t input_nums = common::AnfAlgo::GetInputTensorNum(node);
1475 if (input_nums == 0) {
1476 return true;
1477 }
1478
1479 auto inputs = node->inputs();
1480 for (size_t i = 1; i < inputs.size(); i++) {
1481 if (!inputs[i]->isa<ValueNode>()) {
1482 return false;
1483 }
1484 }
1485 return true;
1486 }
1487
FetchKernelGraph(const AnfNode * node)1488 KernelGraphPtr AnfRuntimeAlgorithm::FetchKernelGraph(const AnfNode *node) {
1489 MS_EXCEPTION_IF_NULL(node);
1490 const auto &func_graph = node->func_graph();
1491 if (func_graph == nullptr) {
1492 return nullptr;
1493 } else {
1494 return func_graph->cast<KernelGraphPtr>();
1495 }
1496 }
1497
// Resolve the front-end node corresponding to a backend node. Internal
// parameters are looked up first; otherwise the graph's backend->front map is
// used; when neither yields a node the backend node itself is returned
// (PyNative forward graphs have no front node).
AnfNodePtr AnfRuntimeAlgorithm::FetchFrontNodeByBackendNode(const AnfNodePtr &backend_node, const KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(backend_node);
  const auto &internal_front = graph.GetFrontNodeByInternalParameter(backend_node);
  if (internal_front.first != nullptr) {
    return internal_front.first;
  }
  auto front_node = graph.GetFrontAnfByBackendAnf(backend_node);
  return front_node == nullptr ? backend_node : front_node;
}
1512
namespace {
// Host kernel with inputs on host
// Returns true when the device-to-host data sync for depend tensors can be
// skipped: `node` is a host kernel and every depended input's output address
// is already on CPU, so there is nothing to copy.
bool SkipDataSync(const CNodePtr &node, const std::map<uint32_t, tensor::TensorPtr> &depend_tensors) {
  if (!common::AnfAlgo::IsHostKernel(node)) {
    return false;
  }
  auto input_size = common::AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_size; ++i) {
    auto input_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
    auto real_input = input_with_index.first;
    auto iter_tensor = depend_tensors.find(i);
    if (iter_tensor != depend_tensors.end()) {
      auto output_addr = AnfAlgo::GetOutputAddr(real_input, 0);
      MS_EXCEPTION_IF_NULL(output_addr);
      // Any depended input living off-CPU still needs the sync.
      if (output_addr->GetDeviceType() != device::DeviceType::kCPU) {
        return false;
      }
    }
  }
  return true;
}
}  // namespace
1535
InferShape(const CNodePtr & node,std::map<uint32_t,tensor::TensorPtr> * depend_tensors)1536 void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors) {
1537 MS_EXCEPTION_IF_NULL(node);
1538 MS_LOG(INFO) << "InferShape start, node:" << node->DebugString();
1539 auto inputs = node->inputs();
1540 if (inputs.empty()) {
1541 MS_LOG(EXCEPTION) << "Inputs should not be empty! Cnode: " << node->DebugString() << "."
1542 << trace::DumpSourceLines(node);
1543 }
1544 AbstractBasePtrList args_spec_list;
1545 auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
1546 auto input_size = common::AnfAlgo::GetInputTensorNum(node);
1547 for (size_t i = 0; i < input_size; ++i) {
1548 auto input_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
1549 auto real_input = input_with_index.first;
1550 MS_EXCEPTION_IF_NULL(real_input);
1551 auto cnode_input = node->input(i + 1);
1552 MS_EXCEPTION_IF_NULL(cnode_input);
1553 if (depend_tensors != nullptr) {
1554 auto iter_tensor = depend_tensors->find(i);
1555 if (iter_tensor != depend_tensors->cend()) {
1556 auto tensor_ptr = iter_tensor->second;
1557 MS_EXCEPTION_IF_NULL(tensor_ptr);
1558 if (!SkipDataSync(node, *depend_tensors)) {
1559 // sync data from device to host
1560 tensor_ptr->data_sync();
1561 }
1562 // cppcheck-suppress unreadVariable
1563 auto lock = AnfUtils::GetAbstractLock(real_input.get());
1564 auto real_abs = real_input->abstract();
1565 if (real_abs->isa<abstract::AbstractTensor>()) {
1566 real_abs->set_value(tensor_ptr);
1567 } else if (real_abs->isa<abstract::AbstractTuple>() && (!common::AnfAlgo::IsDynamicSequence(real_input))) {
1568 auto tuple_get_item_index = common::AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
1569 auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
1570 MS_EXCEPTION_IF_NULL(abstract_tuple);
1571 auto tuple_elements = abstract_tuple->elements()[tuple_get_item_index];
1572 tuple_elements->set_value(tensor_ptr);
1573 }
1574 }
1575 }
1576 common::AnfAlgo::AddArgList(&args_spec_list, real_input, input_with_index.second);
1577 }
1578 auto eval_result = opt::CppInferShapeAndType(primitive, args_spec_list);
1579 node->set_abstract(eval_result);
1580 }
1581
InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> & root_graph)1582 void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph) {
1583 auto return_node = root_graph->get_return();
1584 MS_EXCEPTION_IF_NULL(return_node);
1585 if (return_node->size() <= kReturnDataIndex) {
1586 return;
1587 }
1588 auto make_tuple = root_graph->NewCNode(
1589 {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
1590 MS_EXCEPTION_IF_NULL(root_graph->output());
1591 MS_EXCEPTION_IF_NULL(make_tuple);
1592 abstract::AbstractBasePtrList abs_list{root_graph->output()->abstract()};
1593 make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
1594 root_graph->set_output(make_tuple);
1595 }
1596
// Rebuild the graph's ref map (in-place output/input aliasing) keeping only
// entries whose key node is still in the execution order, then install it as
// the valid ref_out_in_map. Graphs whose memory is managed by GE are skipped.
void AnfRuntimeAlgorithm::UpdateGraphValidRefPair(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  if (graph->memory_managed_by_ge()) {
    return;
  }

  const auto &origin_ref_map = graph->GetRefMap();
  std::map<AnfWithOutIndex, AnfWithOutIndex> new_ref_map;
  for (const auto &node : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(node);
    auto output_num = AnfAlgo::GetOutputTensorNum(node);
    if (output_num == 0) {
      MS_LOG(DEBUG) << "This kernel has no output size.";
      continue;
    }
    for (size_t i = 0; i < output_num; ++i) {
      session::AnfWithOutIndex out_pair(node, i);
      auto iter = origin_ref_map.find(out_pair);
      if (iter != origin_ref_map.end()) {
        // try_emplace keeps the first mapping if the same (node, index) key
        // somehow appears twice; the duplicate is only logged.
        auto ret = new_ref_map.try_emplace(iter->first, iter->second);
        if (!ret.second) {
          MS_LOG(WARNING) << "Duplicate ref_map key, node:" << node->fullname_with_scope() << " index:" << i;
        }
      }
    }
  }
  graph->set_ref_out_in_map(new_ref_map);
}
1626
IsDynamicShapeSkipExecute(bool skip_mode,const ShapeVector & axes_shape)1627 bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(bool skip_mode, const ShapeVector &axes_shape) {
1628 // Skip run ReduceSum when axis is a Empty Tensor
1629 if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; }) && skip_mode) {
1630 return true;
1631 }
1632 return false;
1633 }
1634
IsDynamicShapeSkipExecute(const CNodePtr & cnode)1635 bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(const CNodePtr &cnode) {
1636 // Skip run ReduceSum when axis is a Empty Tensor
1637 MS_EXCEPTION_IF_NULL(cnode);
1638 auto op_name = common::AnfAlgo::GetCNodeName(cnode);
1639 if ((op_name != kReduceSumOpName) && (op_name != kReduceSumDOpName)) {
1640 return false;
1641 }
1642
1643 bool skip_mode = false;
1644 if (common::AnfAlgo::HasNodeAttr(kAttrSkipMode, cnode)) {
1645 skip_mode = common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrSkipMode);
1646 }
1647
1648 if (!skip_mode) {
1649 return false;
1650 }
1651
1652 const size_t axes_index = 1;
1653 if (cnode->size() <= axes_index + 1) {
1654 return false;
1655 }
1656 auto input_axes = cnode->input(axes_index + 1);
1657 // cppcheck-suppress unreadVariable
1658 auto lock = AnfUtils::GetAbstractLock(input_axes.get());
1659 auto abs = input_axes->abstract();
1660 MS_EXCEPTION_IF_NULL(abs);
1661 auto axes_abs = abs->Clone();
1662 MS_EXCEPTION_IF_NULL(axes_abs);
1663 auto axes_shape = AnfAlgo::GetInputDeviceShape(cnode, axes_index);
1664 if (axes_abs->isa<abstract::AbstractTensor>()) {
1665 if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; })) {
1666 return true;
1667 }
1668 }
1669 return false;
1670 }
1671
IsNeedUpdateShapeAndTypeAfterLaunch(const AnfNodePtr & node)1672 bool AnfRuntimeAlgorithm::IsNeedUpdateShapeAndTypeAfterLaunch(const AnfNodePtr &node) {
1673 MS_EXCEPTION_IF_NULL(node);
1674 auto graph = FetchKernelGraph(node.get());
1675 // The graph run mode does not have kernelmod.
1676 if ((graph == nullptr) || graph->is_graph_run_mode()) {
1677 return true;
1678 }
1679
1680 auto kernel_mod = GetKernelMod(node);
1681 if (kernel_mod == nullptr) {
1682 return true;
1683 }
1684 return kernel_mod->IsNeedUpdateOutputShapeAndSize();
1685 }
1686
HasComputedDependInputNode(const CNodePtr & kernel)1687 bool AnfRuntimeAlgorithm::HasComputedDependInputNode(const CNodePtr &kernel) {
1688 MS_EXCEPTION_IF_NULL(kernel);
1689 auto real_input_num = common::AnfAlgo::GetInputTensorNum(kernel);
1690
1691 for (size_t i = 0; i < real_input_num; i++) {
1692 const auto &input_node = common::AnfAlgo::GetInputNode(kernel, i);
1693 MS_EXCEPTION_IF_NULL(input_node);
1694 auto real_input_node = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
1695 MS_EXCEPTION_IF_NULL(real_input_node.first);
1696 if (!real_input_node.first->isa<CNode>()) {
1697 continue;
1698 }
1699
1700 auto kernel_mod = AnfAlgo::GetKernelMod(real_input_node.first);
1701 if (kernel_mod && kernel_mod->IsNeedUpdateOutputShapeAndSize()) {
1702 return true;
1703 }
1704 }
1705 return false;
1706 }
1707
// Refresh the byte size of each output device address of `kernel` from its
// currently inferred output shape; used for kernels whose output size is only
// known after execution.
void AnfRuntimeAlgorithm::UpdateOutputAddrSize(device::KernelInfo const *kernel_info, const CNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel_info);
  MS_EXCEPTION_IF_NULL(kernel);
  auto &output_addresses = kernel_info->output_address_list();
  for (size_t i = 0; i < output_addresses.size(); ++i) {
    auto output_address = output_addresses[i].get();
    MS_EXCEPTION_IF_NULL(output_address);
    auto output_addr_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
    MS_LOG(DEBUG) << "output size:" << output_addr_size << " index:" << i
                  << " for kernel:" << kernel->fullname_with_scope()
                  << " abstract:" << (kernel->abstract() == nullptr ? "null" : kernel->abstract()->ToString());
    // Only touch the address when the size actually changed.
    if (output_addr_size != output_address->GetSize()) {
      output_address->SetSize(output_addr_size);
    }
  }
}
1724
// Propagate each kernel's out->in ref relations (an output aliasing an input's
// memory) up to graph-level ref pairs. Monad inputs are skipped because they
// carry no data.
void AnfRuntimeAlgorithm::AddOutInRefToGraph(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &cnode : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(cnode);
    auto kernel_info = dynamic_cast<device::KernelInfo *>(cnode->kernel_info());
    MS_EXCEPTION_IF_NULL(kernel_info);
    for (const auto &ref : kernel_info->out_in_ref_map()) {
      size_t output_index = ref.first;
      size_t input_index = ref.second;
      auto final_pair = std::make_pair(cnode, output_index);
      // Resolve the real producing node/output behind the input edge.
      auto origin_pair = common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(cnode, input_index), 0);
      MS_LOG(INFO) << "The reference relation output " << final_pair.first->fullname_with_scope()
                   << ", output index: " << final_pair.second << " to input "
                   << origin_pair.first->fullname_with_scope() << ", output index: " << origin_pair.second;
      // Add to graph only if the input is not a monad.
      if (!HasAbstractUMonad(origin_pair.first) && !HasAbstractIOMonad(origin_pair.first)) {
        graph->AddRefCorrespondPairs(final_pair, origin_pair);
      }
    }
  }
}
1746
HasOriginFormat(const AnfNodePtr & anf_node)1747 bool AnfRuntimeAlgorithm::HasOriginFormat(const AnfNodePtr &anf_node) {
1748 MS_EXCEPTION_IF_NULL(anf_node);
1749 return anf_node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrOriginFormat, anf_node->cast<CNodePtr>());
1750 }
1751
GetOriginFormat(const AnfNodePtr & anf_node)1752 std::string AnfRuntimeAlgorithm::GetOriginFormat(const AnfNodePtr &anf_node) {
1753 MS_EXCEPTION_IF_NULL(anf_node);
1754 if (anf_node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrOriginFormat, anf_node->cast<CNodePtr>())) {
1755 return common::AnfAlgo::GetNodeAttr<std::string>(anf_node, kAttrOriginFormat);
1756 }
1757 return {};
1758 }
1759
NodeValueIsFuncGraph(const AnfNodePtr & node)1760 bool AnfRuntimeAlgorithm::NodeValueIsFuncGraph(const AnfNodePtr &node) {
1761 MS_EXCEPTION_IF_NULL(node);
1762 auto value_node = node->cast<ValueNodePtr>();
1763 MS_EXCEPTION_IF_NULL(value_node);
1764 auto value = value_node->value().get();
1765 MS_EXCEPTION_IF_NULL(value);
1766 return value->isa<FuncGraph>();
1767 }
1768
IsNodeSupportKernelSelectBackoff(const AnfNodePtr & node,const KernelGraphPtr & graph)1769 bool AnfRuntimeAlgorithm::IsNodeSupportKernelSelectBackoff(const AnfNodePtr &node, const KernelGraphPtr &graph) {
1770 MS_EXCEPTION_IF_NULL(node);
1771 static std::string disable_kernel_backoff;
1772 static bool first_get_backoff_env = true;
1773 if (first_get_backoff_env) {
1774 disable_kernel_backoff = common::GetEnv(kDisableKernelBackoff);
1775 first_get_backoff_env = false;
1776 }
1777 if (disable_kernel_backoff == "1" && (!common::AnfAlgo::IsTypeTransformOp(common::AnfAlgo::GetCNodeName(node)))) {
1778 MS_LOG(INFO) << "MS_DISABLE_KERNEL_BACKOFF has been set to turn off the kernel backoff ability.";
1779 return false;
1780 }
1781
1782 if (graph == nullptr) {
1783 return false;
1784 }
1785 if (graph->is_from_single_op() || graph->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
1786 MS_LOG(INFO) << "The pynative single op does not support the kernel backoff ability for graph:"
1787 << graph->graph_id();
1788 return false;
1789 }
1790 return true;
1791 }
1792
SetKernelSelectBackoffInfo(const CNodePtr & node,const std::pair<std::string,ExceptionType> & failure_info)1793 void AnfRuntimeAlgorithm::SetKernelSelectBackoffInfo(const CNodePtr &node,
1794 const std::pair<std::string, ExceptionType> &failure_info) {
1795 MS_EXCEPTION_IF_NULL(node);
1796 common::AnfAlgo::SetNodeAttr(kAttrKernelBackoffWithFailureInfo, MakeValue(failure_info.first), node);
1797 common::AnfAlgo::SetNodeAttr(kAttrKernelBackoffWithFailureType, MakeValue(static_cast<int32_t>(failure_info.second)),
1798 node);
1799 }
1800
GetKernelSelectBackoffInfo(const AnfNodePtr & node)1801 std::pair<std::string, ExceptionType> AnfRuntimeAlgorithm::GetKernelSelectBackoffInfo(const AnfNodePtr &node) {
1802 MS_EXCEPTION_IF_NULL(node);
1803 if (!IsKernelSelectBackoffOp(node)) {
1804 return {"", NoExceptionType};
1805 }
1806
1807 auto cnode = node->cast<CNodePtr>();
1808 MS_EXCEPTION_IF_NULL(cnode);
1809 auto failure_info = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrKernelBackoffWithFailureInfo);
1810 auto failure_type =
1811 static_cast<ExceptionType>(common::AnfAlgo::GetNodeAttr<int32_t>(node, kAttrKernelBackoffWithFailureType));
1812 return {failure_info, failure_type};
1813 }
1814
IsKernelSelectBackoffOp(const AnfNodePtr & node)1815 bool AnfRuntimeAlgorithm::IsKernelSelectBackoffOp(const AnfNodePtr &node) {
1816 MS_EXCEPTION_IF_NULL(node);
1817 if (!node->isa<CNode>()) {
1818 return false;
1819 }
1820
1821 auto cnode = node->cast<CNodePtr>();
1822 MS_EXCEPTION_IF_NULL(cnode);
1823 if (common::AnfAlgo::HasNodeAttr(kAttrKernelBackoffWithFailureInfo, cnode) &&
1824 common::AnfAlgo::HasNodeAttr(kAttrKernelBackoffWithFailureType, cnode)) {
1825 return true;
1826 }
1827 return false;
1828 }
1829
// Determine the device target for `node`. Priority order:
// node user-data target -> graph target for non-CNodes -> CPU for backed-off
// kernels -> kAttrPrimitiveTarget attr -> the graph's device target.
std::string AnfRuntimeAlgorithm::FetchDeviceTarget(const AnfNodePtr &node, const KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  // The parameter also may be have the user data to express device target.
  auto ud_target = node->user_data<std::string>(kAttrPrimitiveTarget);
  if (ud_target != nullptr) {
    return *ud_target;
  }

  if (!node->isa<CNode>()) {
    return device::GetDeviceNameByType(graph->device_target());
  }

  // Only the CPU support kernel backoff.
  if (AnfAlgo::IsKernelSelectBackoffOp(node)) {
    return kCPUDevice;
  }

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (common::AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
    return common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrPrimitiveTarget);
  }

  return device::GetDeviceNameByType(graph->device_target());
}
1856
// For every graph input parameter, derive a device-target affinity from its
// real users and, when that differs from the graph's own target, record it as
// user data on the parameter so later passes place it on the right device.
void AnfRuntimeAlgorithm::SetParameterDeviceTarget(const KernelGraphPtr graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  // A manager is needed for node_users(); build one on demand.
  if (manager == nullptr) {
    manager = MakeManager({graph});
    graph->set_manager(manager);
  }

  const auto &graph_device_target = device::GetDeviceNameByType(graph->device_target());
  for (auto &input_node : graph->input_nodes()) {
    const auto &iter = manager->node_users().find(input_node);
    if (iter == manager->node_users().end()) {
      continue;
    }

    std::string device_target_affinity = graph_device_target;
    for (const auto &user_node : iter->second) {
      // Only real kernels express a meaningful device target.
      if (!AnfUtils::IsRealCNodeKernel(user_node.first)) {
        continue;
      }
      device_target_affinity = FetchDeviceTarget(user_node.first, graph.get());
      // If there is node with the same device target as the graph, then select the device target of graph affinity.
      if (device_target_affinity == graph_device_target) {
        break;
      }
    }

    // Set the device target for parameter when it is different with the graph.
    if (device_target_affinity != graph_device_target) {
      MS_LOG(INFO) << "Set the affinity device target for parameter:" << input_node->fullname_with_scope()
                   << " in graph:" << graph->graph_id() << " from graph device target:" << graph_device_target
                   << " to real device target:" << device_target_affinity;
      input_node->set_user_data(kAttrPrimitiveTarget, std::make_shared<std::string>(device_target_affinity));
    }
  }
}
1893
// Map an abstract to the coarse object TypeId used by kernel selection.
// Returns kTypeUnknown for null or unrecognized abstracts. The isa<> checks
// are ordered; do not reorder without considering abstract subtyping.
TypeId AnfRuntimeAlgorithm::GetAbstractObjectType(const AbstractBasePtr &abstract) {
  if (abstract == nullptr) {
    return kTypeUnknown;
  }
  if (abstract->isa<AbstractTensor>()) {
    return kObjectTypeTensorType;
  }
  if (abstract->isa<AbstractTuple>()) {
    return kObjectTypeTuple;
  }
  if (abstract->isa<abstract::AbstractList>()) {
    return kObjectTypeList;
  }
  if (abstract->isa<abstract::AbstractScalar>()) {
    // scalar input may not converted to tensor
    return kObjectTypeNumber;
  }
  if (abstract->isa<abstract::AbstractNone>()) {
    return kMetaTypeNone;
  }

  return kTypeUnknown;
}
1917
GetOutputObjectType(const AnfNodePtr & node,size_t output_idx)1918 TypeId AnfRuntimeAlgorithm::GetOutputObjectType(const AnfNodePtr &node, size_t output_idx) {
1919 MS_EXCEPTION_IF_NULL(node);
1920 auto abstract = node->abstract();
1921 if (abstract->isa<AbstractTuple>()) {
1922 auto tuple_abs = abstract->cast<abstract::AbstractTuplePtr>();
1923 auto items = tuple_abs->elements();
1924 MS_EXCEPTION_IF_CHECK_FAIL(output_idx < items.size(), "invalid output_idx");
1925 return AnfAlgo::GetAbstractObjectType(items[output_idx]);
1926 }
1927 if (output_idx != 0) {
1928 MS_LOG(EXCEPTION) << node->DebugString() << "invalid output_idx" << trace::DumpSourceLines(node);
1929 }
1930 return AnfAlgo::GetAbstractObjectType(abstract);
1931 }
1932
GetInputObjectType(const CNodePtr & node,size_t input_idx)1933 TypeId AnfRuntimeAlgorithm::GetInputObjectType(const CNodePtr &node, size_t input_idx) {
1934 MS_EXCEPTION_IF_NULL(node);
1935 auto input_node = common::AnfAlgo::GetInputNode(node, input_idx);
1936 const std::vector<PrimitivePtr> need_handled_prims = {prim::kPrimMakeTuple, prim::kPrimTupleGetItem};
1937 auto real_input_node = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false, need_handled_prims).first;
1938 return AnfAlgo::GetAbstractObjectType(real_input_node->abstract());
1939 }
1940
GetAllInputObjectType(const AnfNodePtr & node)1941 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllInputObjectType(const AnfNodePtr &node) {
1942 MS_EXCEPTION_IF_NULL(node);
1943 if (!node->isa<CNode>()) {
1944 MS_LOG(EXCEPTION) << node->DebugString() << "anf_node is not CNode." << trace::DumpSourceLines(node);
1945 }
1946 auto cnode = node->cast<CNodePtr>();
1947 std::vector<TypeId> obj_types;
1948 auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
1949 for (size_t index = 0; index < input_num; ++index) {
1950 obj_types.push_back(AnfAlgo::GetInputObjectType(cnode, index));
1951 }
1952 return obj_types;
1953 }
1954
// Object type of the node's output abstract as a one-element vector.
// Returns an empty vector when the node has zero output elements and a
// non-sequence abstract (i.e. genuinely produces nothing).
std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputObjectType(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (AnfAlgo::GetOutputElementNum(node) == 0 && node->abstract() != nullptr &&
      !node->abstract()->isa<abstract::AbstractSequence>()) {
    return {};
  }
  return {AnfAlgo::GetAbstractObjectType(node->abstract())};
}
1963
// Detailed (BaseShape) shape of the node's output_idx-th output.
// Handles plain shapes (single output), tuple shapes (indexed element, or the
// whole tuple for real-sequence outputs), no-shape, and dynamic sequences.
// Throws for out-of-range indices or unsupported shape kinds.
abstract::BaseShapePtr AnfRuntimeAlgorithm::GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>()) {
    // A plain shape means a single output; only index 0 is valid.
    if (output_idx == 0) {
      return base_shape;
    }
    MS_LOG(EXCEPTION) << "The node " << node->DebugString() << "is a single output node but got index [" << output_idx
                      << "]." << trace::DumpSourceLines(node);
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    // A real sequence output is addressed as a whole, not per element.
    if (IsRealSquenceOutput(node)) {
      return tuple_shape;
    }
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
                        << " node:" << node->DebugString() << "." << trace::DumpSourceLines(node);
    }
    auto b_shp = (*tuple_shape)[output_idx];
    if (b_shp->isa<abstract::Shape>() || b_shp->isa<abstract::NoShape>() || b_shp->isa<abstract::TupleShape>() ||
        b_shp->isa<abstract::DynamicSequenceShape>()) {
      return b_shp;
    } else {
      MS_LOG(EXCEPTION) << "The output type of node index:" << output_idx
                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
                        << "node :" << node->DebugString() << "." << trace::DumpSourceLines(node);
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    return base_shape;
  } else if (base_shape->isa<abstract::DynamicSequenceShape>()) {
    // Dynamic-length sequences resolve their element shape separately.
    return common::AnfAlgo::GetDynamicSequenceShape(node, output_idx);
  }
  MS_LOG(EXCEPTION) << "The output type of node should be a NoShape , ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString() << " node : " << node->DebugString() << trace::DumpSourceLines(node);
}
2001
// Detailed shape of the producer output that feeds input `input_idx` of `node`.
abstract::BaseShapePtr AnfRuntimeAlgorithm::GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx) {
  KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_idx);
  return AnfAlgo::GetOutputDetailShape(kernel_with_index.first, kernel_with_index.second);
}
2006
GetAllOutputInferDataTypes(const AnfNodePtr & node)2007 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputInferDataTypes(const AnfNodePtr &node) {
2008 MS_EXCEPTION_IF_NULL(node);
2009 std::vector<TypeId> outputs;
2010 auto out_nums = AnfAlgo::GetOutputElementNum(node);
2011 for (size_t i = 0; i < out_nums; i++) {
2012 auto type = common::AnfAlgo::GetOutputInferDataType(node, i);
2013 outputs.push_back(type);
2014 }
2015 return outputs;
2016 }
2017
// if input node is MakeTuple, find the PrevNodeNum recursively;
// The monad node in the end is not included in the num;
size_t AnfRuntimeAlgorithm::GetInputElementNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  size_t element_num = 0;
  size_t input_num = cnode->size() - 1;
  // Walk inputs from last to first so trailing monad inputs can be skipped
  // until the first real (non-monad) input is seen.
  bool cal_monad_flag = false;
  for (size_t i = input_num; i > 0; --i) {
    auto input_node = common::AnfAlgo::GetInputNode(cnode, i - 1);
    if (!cal_monad_flag && HasAbstractMonad(input_node)) {
      // Trailing monads are not counted.
      continue;
    } else if (common::AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
      // Flatten nested MakeTuple inputs recursively.
      element_num += GetInputElementNum(input_node);
      cal_monad_flag = true;
    } else if (common::AnfAlgo::IsTupleOutput(input_node)) {
      // A tuple-output input contributes all of its output elements.
      element_num += AnfAlgo::GetOutputElementNum(input_node);
      cal_monad_flag = true;
    } else {
      ++element_num;
      cal_monad_flag = true;
    }
  }

  return element_num;
}
2045
SetDynamicAttrToPrim(const PrimitivePtr & prim)2046 void AnfRuntimeAlgorithm::SetDynamicAttrToPrim(const PrimitivePtr &prim) {
2047 (void)prim->AddAttr(kAttrMutableKernel, MakeValue(true));
2048 (void)prim->AddAttr(kAttrInputIsDynamicShape, MakeValue(true));
2049 (void)prim->AddAttr(kAttrOutputIsDynamicShape, MakeValue(true));
2050 }
2051
IsScalarConvertToTensor(const AnfNodePtr & input_node,const CNodePtr & node)2052 bool AnfRuntimeAlgorithm::IsScalarConvertToTensor(const AnfNodePtr &input_node, const CNodePtr &node) {
2053 MS_EXCEPTION_IF_NULL(input_node);
2054 MS_EXCEPTION_IF_NULL(node);
2055 if (!input_node->isa<ValueNode>()) {
2056 return false;
2057 }
2058
2059 auto value_node = input_node->cast<ValueNodePtr>();
2060 MS_EXCEPTION_IF_NULL(value_node);
2061 auto value = value_node->value();
2062 MS_EXCEPTION_IF_NULL(value);
2063 if (!value->isa<Scalar>()) {
2064 return false;
2065 }
2066
2067 const auto &abs = node->abstract();
2068 if (ContainScalarOut(abs)) {
2069 MS_LOG(INFO) << "The input scalar value node:" << input_node->fullname_with_scope()
2070 << " of cnode:" << node->fullname_with_scope() << " doesn't need convert to tensor.";
2071 return false;
2072 }
2073 return true;
2074 }
2075
IsSequenceOutputOfScalar(const AnfNodePtr & node)2076 bool AnfRuntimeAlgorithm::IsSequenceOutputOfScalar(const AnfNodePtr &node) {
2077 MS_EXCEPTION_IF_NULL(node);
2078 const auto &abs = node->abstract();
2079 if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
2080 return false;
2081 }
2082 // Check all elements in tuple/list are scalar.
2083 auto abs_seq = abs->cast_ptr<abstract::AbstractSequence>();
2084 MS_EXCEPTION_IF_NULL(abs_seq);
2085 if (abs_seq->dynamic_len()) {
2086 const auto &element_abs = abs_seq->dynamic_len_element_abs();
2087 return (element_abs == nullptr) || (element_abs->isa<abstract::AbstractScalar>());
2088 }
2089 const auto &elements = abs_seq->elements();
2090
2091 return std::all_of(elements.begin(), elements.end(), [](const AbstractBasePtr &element) {
2092 return (element != nullptr) && (element->isa<abstract::AbstractScalar>()) &&
2093 (element->BuildValue() == nullptr || (!element->BuildValue()->isa<StringImm>()));
2094 });
2095 }
2096
IsSummaryNode(const AnfNodePtr & node)2097 bool AnfRuntimeAlgorithm::IsSummaryNode(const AnfNodePtr &node) {
2098 return (IsPrimitiveCNode(node, prim::kPrimScalarSummary) || IsPrimitiveCNode(node, prim::kPrimTensorSummary) ||
2099 IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary));
2100 }
2101
2102 namespace {
CheckValidTensorTuple(const std::vector<ValuePtr> & values)2103 bool CheckValidTensorTuple(const std::vector<ValuePtr> &values) {
2104 if (values.empty() || values[0] == nullptr || (!values[0]->isa<tensor::Tensor>())) {
2105 return false;
2106 }
2107 const auto &const_tensor = values[0]->cast<tensor::TensorPtr>();
2108 MS_EXCEPTION_IF_NULL(const_tensor);
2109 const auto &const_shape = const_tensor->shape();
2110 const auto &const_type_id = const_tensor->data_type();
2111 size_t const_size = const_tensor->Size();
2112 for (size_t i = 1; i < values.size(); ++i) {
2113 if (values[i] == nullptr || (!values[i]->isa<tensor::Tensor>())) {
2114 MS_LOG(ERROR) << "Invalid value:" << (values[i] == nullptr ? "nullptr" : values[i]->ToString()) << " index:" << i
2115 << " in value tuple";
2116 return false;
2117 }
2118 const auto &tensor = values[i]->cast<tensor::TensorPtr>();
2119 MS_EXCEPTION_IF_NULL(tensor);
2120 const auto &shape = tensor->shape();
2121 const auto &type_id = tensor->data_type();
2122 size_t size = tensor->Size();
2123 if (shape != const_shape || type_id != const_type_id || size != const_size) {
2124 return false;
2125 }
2126 }
2127 return true;
2128 }
2129
2130 // Return a new tensor with type like single_value.
SetScalarToTensor(const std::vector<ValuePtr> & values,const tensor::TensorPtr & tensor)2131 void SetScalarToTensor(const std::vector<ValuePtr> &values, const tensor::TensorPtr &tensor) {
2132 MS_EXCEPTION_IF_NULL(tensor);
2133 const auto &tensor_type_id = tensor->data_type();
2134 const auto dst_ptr = tensor->data_c();
2135 MS_EXCEPTION_IF_NULL(dst_ptr);
2136 MS_LOG(DEBUG) << "Set scalar tuple to tensor, dst size:" << tensor->data().nbytes();
2137 for (size_t i = 0; i < values.size(); ++i) {
2138 // Check mem size.
2139 if (SizeToLong(abstract::TypeIdSize(tensor_type_id) * (i + 1)) > tensor->data().nbytes()) {
2140 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Value size:" << values.size()
2141 << " type:" << tensor_type_id << " out of range:" << tensor->data().nbytes();
2142 }
2143 const auto &value = values[i];
2144 MS_EXCEPTION_IF_NULL(value);
2145 // Check value type.
2146 if (value->type()->type_id() != tensor_type_id) {
2147 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid value type:" << value->type()->type_id()
2148 << " for value:" << value->ToString() << " dst type:" << tensor_type_id;
2149 }
2150 if (tensor_type_id == TypeId::kNumberTypeInt8) {
2151 (reinterpret_cast<int8_t *>(dst_ptr))[i] = GetValue<int8_t>(value);
2152 } else if (tensor_type_id == TypeId::kNumberTypeInt16) {
2153 (reinterpret_cast<int16_t *>(dst_ptr))[i] = GetValue<int16_t>(value);
2154 } else if (tensor_type_id == TypeId::kNumberTypeInt32 || tensor_type_id == kNumberTypeInt) {
2155 (reinterpret_cast<int32_t *>(dst_ptr))[i] = GetValue<int32_t>(value);
2156 } else if (tensor_type_id == TypeId::kNumberTypeInt64) {
2157 (reinterpret_cast<int64_t *>(dst_ptr))[i] = GetValue<int64_t>(value);
2158 } else if (tensor_type_id == TypeId::kNumberTypeBool) {
2159 (reinterpret_cast<bool *>(dst_ptr))[i] = GetValue<bool>(value);
2160 } else if (tensor_type_id == TypeId::kNumberTypeFloat32 || tensor_type_id == TypeId::kNumberTypeFloat) {
2161 (reinterpret_cast<float *>(dst_ptr))[i] = GetValue<float>(value);
2162 } else if (tensor_type_id == TypeId::kNumberTypeFloat64) {
2163 (reinterpret_cast<double *>(dst_ptr))[i] = GetValue<double>(value);
2164 } else if (tensor_type_id == TypeId::kNumberTypeUInt8) {
2165 (reinterpret_cast<uint8_t *>(dst_ptr))[i] = GetValue<uint8_t>(value);
2166 } else if (tensor_type_id == TypeId::kNumberTypeUInt16) {
2167 (reinterpret_cast<uint16_t *>(dst_ptr))[i] = GetValue<uint16_t>(value);
2168 } else if (tensor_type_id == TypeId::kNumberTypeUInt || tensor_type_id == TypeId::kNumberTypeUInt32) {
2169 (reinterpret_cast<uint32_t *>(dst_ptr))[i] = GetValue<uint32_t>(value);
2170 } else if (tensor_type_id == TypeId::kNumberTypeUInt64) {
2171 (reinterpret_cast<uint64_t *>(dst_ptr))[i] = GetValue<uint64_t>(value);
2172 } else {
2173 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid tuple type:" << tensor_type_id
2174 << " for scalar to tensor.";
2175 }
2176 }
2177 }
2178 } // namespace
2179
CreateMapTensor(const DeviceAddressPtr & output_device_address)2180 tensor::TensorPtr AnfRuntimeAlgorithm::CreateMapTensor(const DeviceAddressPtr &output_device_address) {
2181 MS_EXCEPTION_IF_NULL(output_device_address);
2182 const auto &user_data = output_device_address->user_data();
2183 MS_EXCEPTION_IF_NULL(user_data);
2184 const auto &user_data_type = user_data->get<UserDataType>(kUserDataType);
2185 MS_EXCEPTION_IF_NULL(user_data_type);
2186 if (*user_data_type == UserDataType::kUserTypeHashTable) {
2187 auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
2188 auto key_type = user_data->get<TypeId>(kHashTableKeyType);
2189 auto value_type = user_data->get<TypeId>(kHashTableValueType);
2190 auto default_value = user_data->get<Value>(kHashTableDefaultValue);
2191 MS_EXCEPTION_IF_NULL(shape_vector);
2192 MS_EXCEPTION_IF_NULL(key_type);
2193 MS_EXCEPTION_IF_NULL(value_type);
2194 MS_EXCEPTION_IF_NULL(default_value);
2195 auto map_tensor = std::make_shared<tensor::MapTensor>(*key_type, *value_type, *shape_vector, default_value);
2196 map_tensor->set_device_address(output_device_address);
2197 return map_tensor;
2198 }
2199 MS_LOG(WARNING) << "Invalid user data type:" << *user_data_type;
2200 return nullptr;
2201 }
2202
CreateMapTensor(const AnfNodePtr & output_node,size_t output_index)2203 tensor::TensorPtr AnfRuntimeAlgorithm::CreateMapTensor(const AnfNodePtr &output_node, size_t output_index) {
2204 const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
2205 return CreateMapTensor(device_tensor);
2206 }
2207
2208 // In dynamic sequence, since the number of members is not determined in compile time, the entire sequence needs
2209 // to be placed in single tensor, and the shape of the tuple needs to be recorded in the tensor, so that the shape
2210 // of the tensor can be accurately restored during the dynamic shape derivation process in runtime.
SequenceToTensor(const ValuePtr & value)2211 tensor::TensorPtr AnfRuntimeAlgorithm::SequenceToTensor(const ValuePtr &value) {
2212 MS_EXCEPTION_IF_NULL(value);
2213 if (!value->isa<ValueSequence>()) {
2214 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid sequence value:" << value->ToString();
2215 }
2216
2217 const auto &sequence_value = value->cast<ValueSequencePtr>();
2218 const auto &values = sequence_value->value();
2219 if (values.empty()) {
2220 auto tensor = std::make_shared<tensor::Tensor>();
2221 abstract::BaseShapePtr base_shape = nullptr;
2222 if (value->isa<ValueTuple>()) {
2223 base_shape = std::make_shared<abstract::TupleShape>(abstract::BaseShapePtrList());
2224 } else {
2225 base_shape = std::make_shared<abstract::ListShape>(abstract::BaseShapePtrList());
2226 }
2227 tensor->set_base_shape(base_shape);
2228 return tensor;
2229 }
2230 if (values[0] == nullptr || ((!values[0]->isa<Scalar>()) && (!values[0]->isa<tensor::BaseTensor>()))) {
2231 MS_LOG(WARNING) << "Empty sequence in sequence value:" << value->ToString();
2232 return std::make_shared<tensor::Tensor>();
2233 }
2234
2235 ShapeVector shape_vector{SizeToLong(values.size())};
2236 if (values[0]->isa<tensor::BaseTensor>()) {
2237 MS_LOG(DEBUG) << "Check dynamic tuple tensor";
2238 if (!CheckValidTensorTuple(values)) {
2239 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid dynamic sequence tuple:"
2240 << value->ToString();
2241 }
2242 const auto &tensor = values[0]->cast<tensor::BaseTensorPtr>();
2243 MS_EXCEPTION_IF_NULL(tensor);
2244 size_t size = tensor->Size();
2245 const auto &type_id = tensor->data_type();
2246 auto single_shape_vector = tensor->shape();
2247 const auto &single_shape = std::make_shared<abstract::Shape>(single_shape_vector);
2248 (void)shape_vector.insert(shape_vector.end(), single_shape_vector.begin(), single_shape_vector.end());
2249 const auto &shape = std::make_shared<abstract::Shape>(shape_vector);
2250 auto new_tensor = std::make_shared<tensor::Tensor>(type_id, shape_vector);
2251 MS_EXCEPTION_IF_NULL(new_tensor);
2252 const auto dst_ptr = new_tensor->data_c();
2253 MS_EXCEPTION_IF_NULL(dst_ptr);
2254 MS_LOG(DEBUG) << "Copy start, dst size:" << new_tensor->data().nbytes();
2255 for (size_t i = 0; i < values.size(); ++i) {
2256 const auto &sub_value = values[i];
2257 MS_EXCEPTION_IF_NULL(sub_value);
2258 const auto &src_tensor = sub_value->cast<tensor::TensorPtr>();
2259 MS_EXCEPTION_IF_NULL(src_tensor);
2260 MS_EXCEPTION_IF_NULL(src_tensor->data_c());
2261 auto ret = memcpy_s((reinterpret_cast<char *>(dst_ptr)) + i * size,
2262 static_cast<size_t>(new_tensor->data().nbytes()), src_tensor->data_c(), size);
2263 if (ret != EOK) {
2264 MS_LOG(INTERNAL_EXCEPTION)
2265 << "#dmsg#Runtime error info:#dmsg#Failed to copy data into tensor, memcpy_s errorno: " << ret;
2266 }
2267 }
2268 const auto &element_shapes = std::vector<abstract::BaseShapePtr>(values.size(), single_shape);
2269 new_tensor->set_base_shape(std::make_shared<abstract::TupleShape>(element_shapes));
2270 MS_LOG(DEBUG) << "merge tensor from:" << value->ToString() << " to:" << new_tensor->ToString() << " tensor addr"
2271 << new_tensor;
2272 return new_tensor;
2273 }
2274
2275 // Create the tensor.
2276 auto tensor = std::make_shared<tensor::Tensor>(values[0]->type()->type_id(), shape_vector);
2277 MS_EXCEPTION_IF_NULL(tensor);
2278 SetScalarToTensor(values, tensor);
2279 // Build the tuple shape and set into tensor.
2280 const auto &element_shape = std::make_shared<abstract::Shape>(ShapeVector({}));
2281 const auto &element_shapes = std::vector<abstract::BaseShapePtr>(values.size(), element_shape);
2282 tensor->set_base_shape(std::make_shared<abstract::TupleShape>(element_shapes));
2283 return tensor;
2284 }
2285
FlattenDynamicInputArg(const BaseRef & arg,const AnfNodePtr & node,std::vector<tensor::TensorPtr> * flatten_tensors)2286 void AnfRuntimeAlgorithm::FlattenDynamicInputArg(const BaseRef &arg, const AnfNodePtr &node,
2287 std::vector<tensor::TensorPtr> *flatten_tensors) {
2288 MS_EXCEPTION_IF_NULL(node);
2289 MS_EXCEPTION_IF_NULL(flatten_tensors);
2290 MS_LOG(DEBUG) << "Dynamic sequence node:" << node->fullname_with_scope() << " abs:" << node->abstract()->ToString();
2291 if (!utils::isa<ValuePtr>(arg)) {
2292 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid input for dynamic sequence node:"
2293 << node->DebugString();
2294 }
2295 auto value = utils::cast<ValuePtr>(arg);
2296 MS_EXCEPTION_IF_NULL(value);
2297 if (!value->isa<ValueSequence>()) {
2298 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid value:" << value->ToString()
2299 << " for dynamic sequence node:" << node->DebugString();
2300 }
2301 const auto &tensor = AnfAlgo::SequenceToTensor(value);
2302 flatten_tensors->emplace_back(tensor);
2303 }
2304
// Recursively flatten one input argument into a flat list of tensors.
// Scalars become 0-d tensors, monads become placeholder bool tensors, and
// sequences/dicts/sparse tensors are expanded element by element.
// Note: the order of the isa checks below matters (e.g. tensor::Tensor before
// tensor::BaseTensor) — do not reorder.
void AnfRuntimeAlgorithm::FlattenInputArg(const BaseRef &arg, const AnfNodePtr &node,
                                          std::vector<tensor::TensorPtr> *flatten_tensors) {
  MS_EXCEPTION_IF_NULL(flatten_tensors);
  // Dynamic sequences keep the whole sequence in one tensor; handled separately.
  if (node != nullptr && node->abstract() != nullptr && common::AnfAlgo::IsDynamicSequence(node)) {
    FlattenDynamicInputArg(arg, node, flatten_tensors);
    return;
  }

#ifndef BUILD_LITE
  // Unwrap a Python object reference and cast it straight to a tensor.
  if (utils::isa<PyObjectRef>(arg)) {
    auto value = utils::cast<PyObjectRef>(arg).object_;
    flatten_tensors->push_back(py::cast<tensor::TensorPtr>(value));
    return;
  }
#endif

  if (utils::isa<tensor::Tensor>(arg)) {
    (void)flatten_tensors->emplace_back(utils::cast<tensor::TensorPtr>(arg));
  } else if (utils::isa<tensor::BaseTensor>(arg)) {
    // Promote a base tensor to a full Tensor copy.
    (void)flatten_tensors->emplace_back(std::make_shared<tensor::Tensor>(*utils::cast<tensor::BaseTensorPtr>(arg)));
  } else if (utils::isa<Scalar>(arg)) {
    (void)flatten_tensors->emplace_back(ScalarToTensor(utils::cast<ScalarPtr>(arg)));
  } else if (utils::isa<Monad>(arg)) {
    // If value is a monad, replace it with an unused tensor.
    flatten_tensors->push_back(std::make_shared<tensor::Tensor>(int64_t(0), kBool));
  } else if (utils::isa<ValueSequencePtr>(arg)) {
    // Flatten each element of a tuple/list recursively.
    auto value_sequence = utils::cast<ValueSequencePtr>(arg);
    MS_EXCEPTION_IF_NULL(value_sequence);
    auto sequence_value = value_sequence->value();
    for (auto &value : sequence_value) {
      FlattenInputArg(value, node, flatten_tensors);
    }
  } else if (utils::isa<ValueDictionaryPtr>(arg)) {
    // Only the dictionary values are flattened; keys are dropped.
    auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
    MS_EXCEPTION_IF_NULL(value_dict);
    auto dict_value = value_dict->value();
    for (auto &iter : dict_value) {
      FlattenInputArg(iter.second, node, flatten_tensors);
    }
  } else if (utils::isa<tensor::COOTensorPtr>(arg)) {
    // Sparse COO tensor: append each constituent tensor (indices/values/...).
    auto coo_tensor = utils::cast<tensor::COOTensorPtr>(arg);
    MS_EXCEPTION_IF_NULL(coo_tensor);
    for (size_t i = 0; i < coo_tensor->GetTensorLength(); ++i) {
      (void)flatten_tensors->emplace_back(coo_tensor->GetTensorAt(i));
    }
  } else if (utils::isa<tensor::CSRTensorPtr>(arg)) {
    // Sparse CSR tensor: append each constituent tensor.
    auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(arg);
    MS_EXCEPTION_IF_NULL(csr_tensor);
    for (size_t i = 0; i < csr_tensor->GetTensorLength(); ++i) {
      (void)flatten_tensors->emplace_back(csr_tensor->GetTensorAt(i));
    }
  } else if (utils::isa<VectorRefPtr>(arg)) {
    // VectorRef: flatten each contained BaseRef recursively.
    const auto &args_new = utils::cast<VectorRef>(arg);
    for (const auto &arg_new : args_new) {
      FlattenInputArg(arg_new, node, flatten_tensors);
    }
  } else {
    MS_LOG(INTERNAL_EXCEPTION)
      << "#dmsg#Runtime error info:#dmsg#The value input to flatten tensor not supported for type " << arg.ToString();
  }
}
2366
UpdateValueNodeShape(const AnfNodePtr & node)2367 void AnfRuntimeAlgorithm::UpdateValueNodeShape(const AnfNodePtr &node) {
2368 MS_EXCEPTION_IF_NULL(node);
2369 if (!node->isa<ValueNode>()) {
2370 return;
2371 }
2372 const auto &value_node = node->cast<ValueNodePtr>();
2373 MS_EXCEPTION_IF_NULL(value_node);
2374 const auto &value = value_node->value();
2375 MS_EXCEPTION_IF_NULL(value);
2376 if (!value->isa<ValueSequence>()) {
2377 return;
2378 }
2379 const auto &value_sequence = value->cast<ValueSequencePtr>();
2380 MS_EXCEPTION_IF_NULL(value_sequence);
2381 std::vector<abstract::AbstractBasePtr> abstract_list;
2382 for (const auto &sub_value : value_sequence->value()) {
2383 MS_EXCEPTION_IF_NULL(sub_value);
2384 if (sub_value->isa<Scalar>()) {
2385 auto abstract = std::make_shared<abstract::AbstractScalar>(sub_value->type());
2386 (void)abstract_list.emplace_back(abstract);
2387 } else if (sub_value->isa<tensor::Tensor>()) {
2388 const auto &tensor = sub_value->cast<tensor::TensorPtr>();
2389 MS_EXCEPTION_IF_NULL(tensor);
2390 auto abstract = std::make_shared<abstract::AbstractTensor>(tensor->Dtype(), tensor->shape());
2391 (void)abstract_list.emplace_back(abstract);
2392 } else {
2393 MS_LOG(EXCEPTION) << "Invalid value:" << sub_value->ToString()
2394 << " in dynamic sequence value node:" << node->DebugString();
2395 }
2396 }
2397 auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
2398 MS_LOG(INFO) << "Set abstract for node:" << node->DebugString() << "from:" << node->abstract()->ToString()
2399 << " to:" << abstract_tuple->ToString();
2400 node->set_abstract(abstract_tuple);
2401 }
2402
HasSelectKernelBuildInfo(const AnfNodePtr & node)2403 bool AnfRuntimeAlgorithm::HasSelectKernelBuildInfo(const AnfNodePtr &node) {
2404 MS_EXCEPTION_IF_NULL(node);
2405 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
2406 if (kernel_info == nullptr) {
2407 return false;
2408 }
2409 auto build_info = kernel_info->select_kernel_build_info();
2410 if (build_info == nullptr) {
2411 return false;
2412 }
2413 return true;
2414 }
2415
NeedEraseCache(const PrimitivePtr & prim)2416 bool AnfRuntimeAlgorithm::NeedEraseCache(const PrimitivePtr &prim) {
2417 MS_EXCEPTION_IF_NULL(prim);
2418 if (!prim->HasAttr(kRandomCache)) {
2419 return false;
2420 }
2421 auto random_cache_value = prim->GetAttr(kRandomCache);
2422 MS_EXCEPTION_IF_NULL(random_cache_value);
2423 return !GetValue<bool>(random_cache_value);
2424 }
2425
// Fetch the abstract of the node's index-th output. For non-sequence outputs,
// dynamic sequences and kernels that emit a real tuple object, the node's whole
// abstract is returned (index must be 0); otherwise the element at `index`
// inside the sequence abstract is returned.
abstract::AbstractBasePtr AnfRuntimeAlgorithm::GetNodeAbstractByIndex(const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &abstract = node->abstract();
  if (abstract == nullptr) {
    // No abstract attached yet; propagate the nullptr to the caller.
    return abstract;
  }

  // Return output abstract directly for : 1.not sequence type, 2.dynamic sequence type, 3.real tuple/list type.
  if (!abstract->isa<abstract::AbstractSequence>() || common::AnfAlgo::IsDynamicSequence(node) ||
      (node->isa<CNode>() && !mindspore::AnfAlgo::GetOutputKernelObjectTypes(node).empty() &&
       (mindspore::session::AnfRuntimeAlgorithm::GetOutputKernelObjectType(node, 0) ==
        kernel::KernelObjectType::TUPLE))) {
    // Whole abstract covers the single (possibly tuple-typed) output, so only index 0 is valid.
    MS_EXCEPTION_IF_CHECK_FAIL((index == 0), "Cannot get " + std::to_string(index) + " child abstract from " +
                                               abstract->ToString() + " in node:" + node->fullname_with_scope());
    return abstract;
  }

  // Return element abstract by index for tuple type.
  const auto &abstract_tuple = abstract->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  const auto &elements = abstract_tuple->elements();
  if (elements.size() <= index) {
    // Index exceeds the top-level element count; resolve into nested abstracts.
    const auto sub_abstract = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
    return sub_abstract;
  }
  return elements[index];
}
2453
CreateTypeIdValueNodeToKernelGraph(const FuncGraphPtr & func_graph,TypeId data_type)2454 ValueNodePtr AnfRuntimeAlgorithm::CreateTypeIdValueNodeToKernelGraph(const FuncGraphPtr &func_graph, TypeId data_type) {
2455 auto type_id_value_node = NewValueNode(static_cast<int64_t>(data_type));
2456 auto type_id_value = std::make_shared<Int64Imm>(static_cast<int64_t>(data_type));
2457 type_id_value_node->set_abstract(type_id_value->ToAbstract());
2458 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
2459 MS_EXCEPTION_IF_NULL(kernel_graph);
2460 type_id_value_node = kernel_graph->NewValueNode(type_id_value_node);
2461 kernel_graph->AddValueNodeToGraph(type_id_value_node);
2462 return type_id_value_node;
2463 }
2464
CreateTypeIdValueNodeToFuncGraph(const FuncGraphPtr & func_graph,TypeId data_type)2465 ValueNodePtr AnfRuntimeAlgorithm::CreateTypeIdValueNodeToFuncGraph(const FuncGraphPtr &func_graph, TypeId data_type) {
2466 auto type_id_value_node = NewValueNode(static_cast<int64_t>(data_type));
2467 auto type_id_value = std::make_shared<Int64Imm>(static_cast<int64_t>(data_type));
2468 type_id_value_node->set_abstract(type_id_value->ToAbstract());
2469 return type_id_value_node;
2470 }
2471 } // namespace mindspore::session
2472