• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/backend/anf_runtime_algorithm.h"
17 
18 #include <memory>
19 #include <algorithm>
20 #include <map>
21 #include <set>
22 #include <functional>
23 #include "ops/ascend_op_name.h"
24 #include "ops/math_op_name.h"
25 #include "ops/lite_op_name.h"
26 #include "ops/structure_ops.h"
27 #include "ops/sequence_ops.h"
28 #include "ops/framework_ops.h"
29 #include "ir/anf.h"
30 #include "utils/log_adapter.h"
31 #include "ir/func_graph_cloner.h"
32 #include "utils/shape_utils.h"
33 #include "include/common/utils/utils.h"
34 #include "include/common/utils/parallel_context.h"
35 #include "include/common/utils/anfalgo.h"
36 #include "include/common/debug/anf_dump_utils.h"
37 #include "include/backend/kernel_info.h"
38 #include "include/backend/device_address.h"
39 #include "include/backend/optimizer/helper.h"
40 #include "kernel/kernel.h"
41 #include "kernel/kernel_build_info.h"
42 #include "runtime/device/ms_device_shape_transfer.h"
43 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
44 #include "abstract/ops/primitive_infer_map.h"
45 #include "utils/trace_base.h"
46 #include "utils/anf_utils.h"
47 #include "utils/ms_context.h"
48 #ifndef BUILD_LITE
49 #include "pybind_api/ir/base_ref_py.h"
50 #endif
51 
52 namespace mindspore::session {
53 using abstract::AbstractTensor;
54 using abstract::AbstractTuple;
55 using device::KernelInfo;
56 using kernel::KernelBuildInfoPtr;
57 using kernel::KernelMod;
58 using kernel::KernelModPtr;
59 constexpr char kDisableKernelBackoff[] = "MS_DISABLE_KERNEL_BACKOFF";
60 
61 namespace {
62 constexpr size_t kReturnDataIndex = 1;
63 constexpr size_t kSwitchTrueBranchIndex = 2;
64 constexpr auto kPatternUnknown = "";
65 
PrintKernelFormatAndType(const std::string & fmt,const TypeId & type,const std::vector<int64_t> & shape)66 std::string PrintKernelFormatAndType(const std::string &fmt, const TypeId &type, const std::vector<int64_t> &shape) {
67   std::ostringstream buffer;
68   buffer << "<" << TypeIdLabel(type);
69   if (!fmt.empty()) {
70     buffer << "x" << fmt << shape;
71   }
72   buffer << ">";
73   return buffer.str();
74 }
75 
// Registers print handlers with AnfDumpHandler at static-initialization time,
// so IR dump utilities can render each node's device type/shape/format and
// kernel object types. The struct instance `callback_register` exists only to
// run the constructor once at program start.
[[maybe_unused]] struct AnfDumpHandlerRegister {
  AnfDumpHandlerRegister() {
    // Handler: "<type x format shape>" for every input of a node, comma-separated.
    AnfDumpHandler::SetPrintInputTypeShapeFormatHandler([](const std::shared_ptr<AnfNode> &node) -> std::string {
      if (node == nullptr) {
        return "";
      }
      std::ostringstream buffer;
      size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
      for (size_t i = 0; i < input_num; ++i) {
        if (i != 0) {
          buffer << ", ";
        }
        auto format = AnfAlgo::GetInputFormat(node, i);
        auto type = AnfAlgo::GetInputDeviceDataType(node, i);
        auto shape = AnfAlgo::GetInputDeviceShape(node, i);
        buffer << PrintKernelFormatAndType(format, type, shape);
      }
      return buffer.str();
    });
    // Handler: same rendering for every output. Parameters only have a single
    // output slot, hence index 0 is forced for them.
    AnfDumpHandler::SetPrintOutputTypeShapeFormatHandler([](const std::shared_ptr<AnfNode> &node) -> std::string {
      if (node == nullptr) {
        return "";
      }
      std::ostringstream buffer;
      size_t output_num = AnfAlgo::GetOutputTensorNum(node);
      for (size_t i = 0; i < output_num; ++i) {
        if (i != 0) {
          buffer << ", ";
        }
        auto format = AnfAlgo::GetOutputFormat(node, (node->isa<Parameter>() ? 0 : i));
        auto type = AnfAlgo::GetOutputDeviceDataType(node, (node->isa<Parameter>() ? 0 : i));
        auto shape = AnfAlgo::GetOutputDeviceShape(node, (node->isa<Parameter>() ? 0 : i));
        buffer << PrintKernelFormatAndType(format, type, shape);
      }
      return buffer.str();
    });
    // Handler: comma-separated labels of the input kernel object types.
    AnfDumpHandler::SetPrintInputKernelObjectTypesHandler([](const std::shared_ptr<AnfNode> &node) -> std::string {
      if (node == nullptr) {
        return "";
      }
      auto input_obj_types = AnfAlgo::GetInputKernelObjectTypes(node);
      return std::accumulate(
        input_obj_types.begin(), input_obj_types.end(), std::string(), [](std::string &a, const KernelObjectType &b) {
          return a.empty() ? kernel::KernelObjectTypeLabel(b) : a + ", " + kernel::KernelObjectTypeLabel(b);
        });
    });
    // Handler: comma-separated labels of the output kernel object types.
    AnfDumpHandler::SetPrintOutputKernelObjectTypesHandler([](const std::shared_ptr<AnfNode> &node) -> std::string {
      if (node == nullptr) {
        return "";
      }
      auto output_obj_types = AnfAlgo::GetOutputKernelObjectTypes(node);
      return std::accumulate(
        output_obj_types.begin(), output_obj_types.end(), std::string(), [](std::string &a, const KernelObjectType &b) {
          return a.empty() ? kernel::KernelObjectTypeLabel(b) : a + ", " + kernel::KernelObjectTypeLabel(b);
        });
    });
  }
} callback_register;
134 
GetForwardOutputTensor(const AnfNodePtr & node)135 tensor::BaseTensorPtr GetForwardOutputTensor(const AnfNodePtr &node) {
136   MS_EXCEPTION_IF_NULL(node);
137   if (node->isa<ValueNode>()) {
138     auto value_node = node->cast<ValueNodePtr>();
139     MS_EXCEPTION_IF_NULL(value_node);
140     auto value = value_node->value();
141     MS_EXCEPTION_IF_NULL(value);
142     if (value->isa<tensor::BaseTensor>()) {
143       auto tensor = value->cast<tensor::BaseTensorPtr>();
144       MS_EXCEPTION_IF_NULL(tensor);
145       // If output used as sens, output will create(clone) a fake tensor with device address is nullptr for memory
146       // usage. It has is_forward_output flag, which will be used for tensor input mask, and affect single op graph
147       // cache.
148       if (tensor->is_forward_output() && tensor->device_address() != nullptr) {
149         return tensor;
150       }
151     }
152   }
153   return nullptr;
154 }
155 
GetOutputTensorNumByKernelInfo(const AnfNodePtr & node)156 size_t GetOutputTensorNumByKernelInfo(const AnfNodePtr &node) {
157   MS_EXCEPTION_IF_NULL(node);
158   MS_EXCEPTION_IF_NULL(node->kernel_info());
159   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
160   MS_EXCEPTION_IF_NULL(kernel_info);
161   const auto &build_info = kernel_info->GetMutableSelectKernelBuildInfo();
162   MS_EXCEPTION_IF_NULL(build_info);
163   return build_info->GetAllOutputDeviceTypes().size();
164 }
165 
ContainScalarOut(const AbstractBasePtr & abs)166 bool ContainScalarOut(const AbstractBasePtr &abs) {
167   // Check the output abstract of node whether is scalar.
168   if ((abs != nullptr) && (abs->isa<abstract::AbstractScalar>())) {
169     return true;
170   }
171   // Check the output abstracts of node whether have scalar.
172   if ((abs != nullptr) && (abs->isa<abstract::AbstractSequence>())) {
173     auto abs_seq = abs->cast_ptr<abstract::AbstractSequence>();
174     MS_EXCEPTION_IF_NULL(abs_seq);
175     if (abs_seq->dynamic_len()) {
176       const auto &element_abs = abs_seq->dynamic_len_element_abs();
177       return (element_abs == nullptr) || (element_abs->isa<abstract::AbstractScalar>());
178     }
179     const auto &elements = abs_seq->elements();
180     bool has_scalar_out = std::any_of(elements.begin(), elements.end(),
181                                       [](const AbstractBasePtr &element) { return ContainScalarOut(element); });
182     return has_scalar_out;
183   }
184   return false;
185 }
186 }  // namespace
187 
// Create a ValueNode holding the universal monad (kUMonad), with matching
// abstract, inside kernel graph `kg`.
AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) {
  MS_EXCEPTION_IF_NULL(kg);
  return kg->NewValueNode(kUMonad->ToAbstract(), kUMonad);
}
192 
193 // Convert: a = former(xxx)
194 //          b = latter(x, xxx)
195 // To:      a = former(xxx)
196 //          d1 = Depend(x, a)
197 //          b = latter(d1, xxx)
198 //          ...
199 //          out = Depend(out, latter)
// Force `former` to execute before `latter` by inserting Depend nodes, as
// sketched in the comment above. No-op when `latter` is not a CNode or has no
// data input to hook onto.
void AnfRuntimeAlgorithm::KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter) {
  MS_EXCEPTION_IF_NULL(kg);
  MS_EXCEPTION_IF_NULL(former);
  MS_EXCEPTION_IF_NULL(latter);
  if (latter->isa<CNode>()) {
    auto latter_cnode = latter->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(latter_cnode);
    constexpr size_t inputsize = 2;
    constexpr size_t kFirstDataInputIndex = 1;
    if (latter_cnode->size() < inputsize) {
      return;
    }
    // d1 = Depend(latter's first data input, former): makes `latter` depend on
    // `former` through its first data input.
    auto latter_input = latter_cnode->input(kFirstDataInputIndex);
    auto depend1 = kg->NewCNode({NewValueNode(prim::kPrimDepend), latter_input, former});
    MS_EXCEPTION_IF_NULL(depend1);
    MS_EXCEPTION_IF_NULL(latter_input);
    depend1->set_abstract(latter_input->abstract());
    latter_cnode->set_input(kFirstDataInputIndex, depend1);

    // d2 = Depend(graph output, latter): keeps `latter` alive and ordered
    // before the graph return by routing the graph output through it.
    auto return_node = kg->get_return();
    MS_EXCEPTION_IF_NULL(return_node);
    auto depend2 = kg->NewCNode(
      {NewValueNode(prim::kPrimDepend), return_node->cast<CNodePtr>()->input(kFirstDataInputIndex), latter});
    MS_EXCEPTION_IF_NULL(depend2);
    depend2->set_abstract(return_node->cast<CNodePtr>()->input(kFirstDataInputIndex)->abstract());
    kg->set_output(depend2);
    MS_LOG(DEBUG) << "former: " << former->DebugString() << ", latter: " << latter->DebugString()
                  << ", depend1: " << depend1->DebugString() << ", depend2: " << depend2->DebugString();
  }
}
230 
GetOutputTensorNum(const AnfNodePtr & node)231 size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
232   MS_EXCEPTION_IF_NULL(node);
233   size_t res;
234   TypePtr type = node->Type();
235   if (type == nullptr) {
236     res = 0;
237   } else if (type->isa<Tuple>() || type->isa<List>()) {
238     const auto &kernel_info = node->kernel_info();
239     if (kernel_info == nullptr || (!kernel_info->has_build_info())) {
240       return 1;
241     }
242     res = GetOutputTensorNumByKernelInfo(node);
243   } else if (type->isa<TypeNone>()) {
244     res = 0;
245   } else if (type->isa<CSRTensorType>()) {
246     // Currently, CSRTensor only supports 2-D matrix (shape has 2 values). 5 outputs = 3 Tensors + 2 shape values.
247     constexpr size_t kCSRTensorOutputNum = 5;
248     res = kCSRTensorOutputNum;
249   } else if (type->isa<COOTensorType>()) {
250     // Currently, COOTensor only supports 2-D matrix (shape has 2 values). 4 outputs = 2 Tensors + 2 shape values.
251     constexpr size_t kCOOTensorOutputNum = 4;
252     res = kCOOTensorOutputNum;
253   } else if (AnfUtils::NeedJumpMonadOutput(node) && type->isa<MonadType>()) {
254     // Some nodes could have monad outputs like RpcRecv. We need to jump these outputs.
255     res = 0;
256   } else {
257     res = 1;
258   }
259   return res;
260 }
261 
GetOutputNumWithoutKernelInfo(const AnfNodePtr & node)262 size_t AnfRuntimeAlgorithm::GetOutputNumWithoutKernelInfo(const AnfNodePtr &node) {
263   MS_EXCEPTION_IF_NULL(node);
264   const auto &kernel_info = node->kernel_info();
265   if (kernel_info != nullptr) {
266     MS_LOG(EXCEPTION) << "Kernel info is not null for node:" << node->DebugString();
267   }
268 
269   size_t res;
270   TypePtr type = node->Type();
271   if (type == nullptr) {
272     res = 0;
273   } else if (type->isa<Tuple>() || type->isa<List>()) {
274     res = 1;
275   } else if (type->isa<TypeNone>()) {
276     res = 0;
277   } else if (type->isa<CSRTensorType>()) {
278     // Currently, CSRTensor only supports 2-D matrix (shape has 2 values). 5 outputs = 3 Tensors + 2 shape values.
279     constexpr size_t kCSRTensorOutputNum = 5;
280     res = kCSRTensorOutputNum;
281   } else if (type->isa<COOTensorType>()) {
282     // Currently, COOTensor only supports 2-D matrix (shape has 2 values). 4 outputs = 2 Tensors + 2 shape values.
283     constexpr size_t kCOOTensorOutputNum = 4;
284     res = kCOOTensorOutputNum;
285   } else if (AnfUtils::NeedJumpMonadOutput(node) && type->isa<MonadType>()) {
286     // Some nodes could have monad outputs like RpcRecv. We need to jump these outputs.
287     res = 0;
288   } else {
289     res = 1;
290   }
291   return res;
292 }
293 
294 namespace {
IsTupleHasDynamicSequence(const abstract::AbstractBasePtr & abstract)295 bool IsTupleHasDynamicSequence(const abstract::AbstractBasePtr &abstract) {
296   MS_EXCEPTION_IF_NULL(abstract);
297   if (!abstract->isa<abstract::AbstractSequence>()) {
298     return false;
299   }
300   const auto &sequence_abs = abstract->cast<abstract::AbstractSequencePtr>();
301   MS_EXCEPTION_IF_NULL(sequence_abs);
302   if (sequence_abs->dynamic_len() || sequence_abs->dynamic_len_element_abs() != nullptr) {
303     return true;
304   }
305   if (std::any_of(sequence_abs->elements().begin(), sequence_abs->elements().end(),
306                   [](const abstract::AbstractBasePtr &abs) { return IsTupleHasDynamicSequence(abs); })) {
307     return true;
308   }
309   return false;
310 }
311 }  // namespace
312 
GetOutputElementNum(const AnfNodePtr & node)313 size_t AnfRuntimeAlgorithm::GetOutputElementNum(const AnfNodePtr &node) {
314   if (node->abstract() != nullptr && IsTupleHasDynamicSequence(node->abstract())) {
315     return common::AnfAlgo::GetOutputNumByAbstract(node->abstract());
316   }
317   return AnfUtils::GetOutputTensorNum(node);
318 }
319 
// Compute the byte size of output `output_index` of `node` for the static
// shape `real_shape` (element count of the shape times element size).
size_t GetOutputTensorMemSizeImpl(const AnfNodePtr &node, size_t output_index, const ShapeVector &real_shape) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
    MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
                                << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
  }
  // Prefer the selected device data type for sizing; fall back to the inferred
  // type when no kernel has been selected yet (device type is kTypeUnknown).
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  if (output_type_id == kTypeUnknown) {
    output_type_id = common::AnfAlgo::GetOutputInferDataType(node, output_index);
  }
  size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
  auto shape = real_shape;
  auto format = AnfAlgo::GetOutputFormat(node, output_index);
  auto dtype = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  // NOTE(review): padding + device layout transform is applied only when the
  // shape is EMPTY (a scalar) and the format is non-default — confirm this is
  // intended and not meant to read `!shape.empty()`.
  if (shape.empty() && format != kOpFormat_DEFAULT) {
    shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index), node);
    shape = trans::TransShapeToDevice(shape, format, node, output_index, dtype);
  }
  // scalar's output shape is a empty vector
  size_t tensor_size = type_size * SizeOf(shape);
  return tensor_size;
}
342 
// Byte size of output `output_index` for an explicitly provided shape.
// Raises when `real_shape` is dynamic (contains unknown dimensions).
size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index,
                                                   const ShapeVector &real_shape) {
  if (IsDynamic(real_shape)) {
    MS_LOG(EXCEPTION) << "The shape is " << real_shape << " dynamic shape , can not get OutputTensorMemSize";
  }
  return GetOutputTensorMemSizeImpl(node, output_index, real_shape);
}
350 
GetOutputTensorMemSize(const AnfNodePtr & node,size_t output_index)351 size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index) {
352   MS_EXCEPTION_IF_NULL(node);
353   auto shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
354   if (IsDynamic(shape)) {
355     auto max_shape = common::AnfAlgo::GetOutputMaxShape(node, output_index);
356     if (!max_shape.empty()) {
357       shape = max_shape;
358       MS_LOG(DEBUG) << "shape[" << shape << "] is dynamic, using max_shape[" << max_shape << "] instead.";
359     } else {
360       shape = {1};
361       MS_LOG(DEBUG) << "shape[" << shape << "] is dynamic, set default to {1}";
362     }
363   }
364   return GetOutputTensorMemSizeImpl(node, output_index, shape);
365 }
366 
GetAllOutputFormats(const AnfNodePtr & node)367 std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) {
368   MS_EXCEPTION_IF_NULL(node);
369   if (!AnfUtils::IsRealKernel(node)) {
370     MS_LOG(EXCEPTION) << "Not real kernel:"
371                       << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
372   }
373   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
374   MS_EXCEPTION_IF_NULL(kernel_info);
375   auto build_info = kernel_info->select_kernel_build_info();
376   MS_EXCEPTION_IF_NULL(build_info);
377   auto format = build_info->GetAllOutputFormats();
378   return format;
379 }
380 
GetAllInputFormats(const AnfNodePtr & node)381 std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePtr &node) {
382   MS_EXCEPTION_IF_NULL(node);
383   if (!AnfUtils::IsRealKernel(node)) {
384     MS_LOG(EXCEPTION) << "Not real kernel:"
385                       << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
386   }
387   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
388   MS_EXCEPTION_IF_NULL(kernel_info);
389   auto build_info = kernel_info->select_kernel_build_info();
390   MS_EXCEPTION_IF_NULL(build_info);
391   auto format = build_info->GetAllInputFormats();
392   return format;
393 }
394 
GetAllInputDeviceTypes(const AnfNodePtr & node)395 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllInputDeviceTypes(const AnfNodePtr &node) {
396   MS_EXCEPTION_IF_NULL(node);
397   if (!AnfUtils::IsRealKernel(node)) {
398     MS_LOG(EXCEPTION) << "Not real kernel:"
399                       << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
400   }
401   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
402   MS_EXCEPTION_IF_NULL(kernel_info);
403   auto build_info = kernel_info->select_kernel_build_info();
404   MS_EXCEPTION_IF_NULL(build_info);
405   auto types = build_info->GetAllInputDeviceTypes();
406   return types;
407 }
408 
GetAllOutputDeviceTypes(const AnfNodePtr & node)409 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputDeviceTypes(const AnfNodePtr &node) {
410   MS_EXCEPTION_IF_NULL(node);
411   if (!AnfUtils::IsRealKernel(node)) {
412     MS_LOG(EXCEPTION) << "Not real kernel:"
413                       << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
414   }
415   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
416   MS_EXCEPTION_IF_NULL(kernel_info);
417   auto build_info = kernel_info->select_kernel_build_info();
418   MS_EXCEPTION_IF_NULL(build_info);
419   auto types = build_info->GetAllOutputDeviceTypes();
420   return types;
421 }
422 
GetOriginDataFormat(const AnfNodePtr & node)423 std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
424   MS_EXCEPTION_IF_NULL(node);
425   if (!AnfUtils::IsRealKernel(node)) {
426     MS_LOG(EXCEPTION) << "Not real kernel:"
427                       << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
428   }
429   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
430   if (kernel_info == nullptr) {
431     return kOpFormat_DEFAULT;
432   }
433   auto build_info = kernel_info->select_kernel_build_info();
434   if (build_info == nullptr) {
435     return kOpFormat_DEFAULT;
436   }
437   auto format = build_info->GetOriginDataFormat();
438   return format;
439 }
440 
// Device format of output `output_idx` of `node`.
// - Sparse-tensor outputs always use the default format.
// - Virtual (non-real) kernels forward the producing node's output format.
// - TUPLE outputs carry a single format entry, so index 0 is used.
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): the bound check uses `>` (not `>=`), so output_idx equal to
  // the element count is accepted — presumably deliberate; confirm.
  if (output_idx > AnfAlgo::GetOutputElementNum(node) && (!common::AnfAlgo::IsDynamicSequence(node))) {
    MS_LOG(EXCEPTION) << "Output index:" << output_idx
                      << " is out of the node output range :" << AnfAlgo::GetOutputElementNum(node) << " #node ["
                      << node->DebugString() << "]" << trace::DumpSourceLines(node);
  }
  if (common::AnfAlgo::CheckAbsSparseTensor(node)) {
    return kOpFormat_DEFAULT;
  }
  if (!AnfUtils::IsRealKernel(node)) {
    return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  std::string format;
  // If the output is TUPLE, output format list's size is 1. So we use the first element as the output format.
  // This scenario could happen before 'insert_type_transform_op' pass.
  auto output_obj_types = build_info->GetAllOutputKernelObjectTypes();
  if (!output_obj_types.empty() && output_obj_types[kIndex0] == KernelObjectType::TUPLE) {
    MS_LOG(DEBUG) << "TUPLE only has one output. So use index 0 output format.";
    format = build_info->GetOutputFormat(kIndex0);
  } else {
    format = build_info->GetOutputFormat(output_idx);
  }
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid output format" << trace::DumpSourceLines(node);
  }
  return format;
}
474 
GetInputFormat(const AnfNodePtr & node,size_t input_idx)475 std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
476   MS_EXCEPTION_IF_NULL(node);
477   if (input_idx > common::AnfAlgo::GetInputTensorNum(node)) {
478     MS_LOG(EXCEPTION) << "Input index :" << input_idx
479                       << " is out of the number node Input range :" << common::AnfAlgo::GetInputTensorNum(node)
480                       << "#node [" << node->DebugString() << "]" << trace::DumpSourceLines(node);
481   }
482   if (!AnfUtils::IsRealKernel(node)) {
483     return GetPrevNodeOutputFormat(node, input_idx);
484   }
485   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
486   MS_EXCEPTION_IF_NULL(kernel_info);
487   auto build_info = kernel_info->select_kernel_build_info();
488   MS_EXCEPTION_IF_NULL(build_info);
489   auto format = build_info->GetInputFormat(input_idx);
490   if (format == kernel::KernelBuildInfo::kInvalidFormat) {
491     MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
492                       << " input index:" << input_idx << " has a invalid input format\n"
493                       << trace::DumpSourceLines(node);
494   }
495   return format;
496 }
497 
IsEquivalentFormat(const std::string & src_format,const std::string & dst_format)498 bool AnfRuntimeAlgorithm::IsEquivalentFormat(const std::string &src_format, const std::string &dst_format) {
499   if (src_format == dst_format) {
500     return true;
501   }
502 
503   // Equivalent default format.
504   if (((src_format == kOpFormat_DEFAULT) || (src_format == kOpFormat_NCHW) || (src_format == kOpFormat_ND)) &&
505       ((dst_format == kOpFormat_DEFAULT) || (dst_format == kOpFormat_NCHW) || (dst_format == kOpFormat_ND))) {
506     return true;
507   }
508 
509   return false;
510 }
511 
// Format of the output that feeds input `input_idx` of `anf_node`: resolve
// the producing (node, output-index) pair and ask for its output format.
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
  KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
}
516 
// Reshape type of the output that feeds input `input_idx` of `node`, resolved
// through the producing (node, output-index) pair.
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_idx);
  return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
}
521 
GetInputKernelObjectTypes(const AnfNodePtr & node)522 std::vector<KernelObjectType> AnfRuntimeAlgorithm::GetInputKernelObjectTypes(const AnfNodePtr &node) {
523   MS_EXCEPTION_IF_NULL(node);
524   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
525   MS_EXCEPTION_IF_NULL(kernel_info);
526   auto build_info = kernel_info->select_kernel_build_info();
527   if (build_info == nullptr) {
528     MS_LOG(EXCEPTION) << "Empty build info for node:" << node->fullname_with_scope()
529                       << ", debug name:" << node->DebugString();
530   }
531   return build_info->GetAllInputKernelObjectTypes();
532 }
533 
GetInputKernelObjectType(const AnfNodePtr & node,size_t input_idx)534 KernelObjectType AnfRuntimeAlgorithm::GetInputKernelObjectType(const AnfNodePtr &node, size_t input_idx) {
535   MS_EXCEPTION_IF_NULL(node);
536   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
537   MS_EXCEPTION_IF_NULL(kernel_info);
538   auto build_info = kernel_info->select_kernel_build_info();
539   if (build_info == nullptr) {
540     MS_LOG(EXCEPTION) << "Empty build info for node:" << node->fullname_with_scope()
541                       << ", debug name:" << node->DebugString();
542   }
543   const auto &input_kernel_obj_types = build_info->GetAllInputKernelObjectTypes();
544   if (input_idx >= input_kernel_obj_types.size()) {
545     MS_LOG(EXCEPTION) << "Input index " << input_idx << ", but the node input kernel object types size just "
546                       << input_kernel_obj_types.size() << ". node: " << node->fullname_with_scope()
547                       << ", debug name:" << node->DebugString() << "." << trace::DumpSourceLines(node);
548   }
549   return input_kernel_obj_types[input_idx];
550 }
551 
GetOutputKernelObjectTypes(const AnfNodePtr & node)552 std::vector<KernelObjectType> AnfRuntimeAlgorithm::GetOutputKernelObjectTypes(const AnfNodePtr &node) {
553   MS_EXCEPTION_IF_NULL(node);
554   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
555   MS_EXCEPTION_IF_NULL(kernel_info);
556   auto build_info = kernel_info->select_kernel_build_info();
557   if (build_info == nullptr) {
558     return {};
559   }
560   return build_info->GetAllOutputKernelObjectTypes();
561 }
562 
GetOutputKernelObjectType(const AnfNodePtr & node,size_t output_idx)563 KernelObjectType AnfRuntimeAlgorithm::GetOutputKernelObjectType(const AnfNodePtr &node, size_t output_idx) {
564   MS_EXCEPTION_IF_NULL(node);
565   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
566   MS_EXCEPTION_IF_NULL(kernel_info);
567   auto build_info = kernel_info->select_kernel_build_info();
568   if (build_info == nullptr) {
569     MS_LOG(EXCEPTION) << "Empty build info for node:" << node->fullname_with_scope()
570                       << ", debug name:" << node->DebugString();
571   }
572   const auto &output_kernel_obj_types = build_info->GetAllOutputKernelObjectTypes();
573   if (output_idx >= output_kernel_obj_types.size()) {
574     MS_LOG(EXCEPTION) << "Output index " << output_idx << ", but the node output kernel object types size just "
575                       << output_kernel_obj_types.size() << ". node: " << node->fullname_with_scope()
576                       << ", debug name:" << node->DebugString() << "." << trace::DumpSourceLines(node);
577   }
578   return output_kernel_obj_types[output_idx];
579 }
580 
GetOutputElementsKernelObjectTypes(const AnfNodePtr & node)581 std::vector<KernelObjectType> AnfRuntimeAlgorithm::GetOutputElementsKernelObjectTypes(const AnfNodePtr &node) {
582   MS_EXCEPTION_IF_NULL(node);
583   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
584   MS_EXCEPTION_IF_NULL(kernel_info);
585   auto build_info = kernel_info->select_kernel_build_info();
586   if (build_info == nullptr) {
587     MS_LOG(EXCEPTION) << "Empty build info for node:" << node->fullname_with_scope()
588                       << ", debug name:" << node->DebugString();
589   }
590   return build_info->GetAllOutputElementsKernelObjectTypes();
591 }
592 
GetValid(const AnfNodePtr & node)593 bool AnfRuntimeAlgorithm::GetValid(const AnfNodePtr &node) {
594   MS_EXCEPTION_IF_NULL(node);
595   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
596   MS_EXCEPTION_IF_NULL(kernel_info);
597   auto build_info = kernel_info->select_kernel_build_info();
598   if (build_info == nullptr) {
599     MS_LOG(EXCEPTION) << "Empty build info for node:" << node->fullname_with_scope()
600                       << ", debug name:" << node->DebugString();
601   }
602   return build_info->valid();
603 }
604 
IsRealSquenceOutput(const AnfNodePtr & node)605 bool AnfRuntimeAlgorithm::IsRealSquenceOutput(const AnfNodePtr &node) {
606   MS_EXCEPTION_IF_NULL(node);
607   std::vector<KernelObjectType> objects = GetOutputKernelObjectTypes(node);
608   bool is_real_tuple = false;
609   if (objects.empty()) {
610     return false;
611   } else {
612     is_real_tuple = (objects[0] == KernelObjectType::TUPLE);
613   }
614   return is_real_tuple;
615 }
616 
// Device shape of output `output_idx` under the given `format`, derived from
// the node's detailed output shape; used when building TBE kernels. Returns
// the shape unchanged (possibly empty) when it is empty.
std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node, size_t output_idx,
                                                                          const std::string &format) {
  auto output_shape = AnfAlgo::GetOutputDetailShape(node, output_idx);
  std::vector<int64_t> infer_shape;
  if (output_shape->isa<abstract::Shape>()) {
    auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape_ptr);
    infer_shape = shape_ptr->shape();
  }
  if (infer_shape.empty()) {
    return infer_shape;
  }

  // if format is default_format or NC1KHKWHWC0,device shape = original shape
  if (trans::IsNeedPadding(format, infer_shape)) {
    infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node);
  }
  auto dtype = GetOutputDeviceDataType(node, output_idx);
  return trans::TransShapeToDevice(infer_shape, format, node, output_idx, dtype);
}
637 
IsShapesDynamic(const std::vector<ShapeVector> & shapes)638 bool AnfRuntimeAlgorithm::IsShapesDynamic(const std::vector<ShapeVector> &shapes) {
639   return std::any_of(shapes.cbegin(), shapes.cend(), [](const auto &shape) { return IsDynamic(shape); });
640 }
641 
// Device shape of output `output_idx`, computed from the inferred shape with
// padding and layout transformation applied when the device format needs it.
// An empty inferred shape (scalar) is returned unchanged.
ShapeVector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
  auto format = GetOutputFormat(node, output_idx);
  auto infer_shape = common::AnfAlgo::GetOutputInferShape(node, output_idx, IsRealSquenceOutput(node));
  if (infer_shape.empty()) {
    return infer_shape;
  }

  // if format is default_format or NC1KHKWHWC0,device shape = original shape
  if (trans::IsNeedPadding(format, infer_shape)) {
    infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node);
  }
  auto dtype = GetOutputDeviceDataType(node, output_idx);
  return trans::TransShapeToDevice(infer_shape, format, node, output_idx, dtype);
}
656 
GetOutputDeviceShape(const AnfNodePtr & node,size_t output_idx,ShapeVector real_shape)657 ShapeVector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx,
658                                                       ShapeVector real_shape) {
659   auto format = GetOutputFormat(node, output_idx);
660   if (real_shape.empty()) {
661     return real_shape;
662   }
663 
664   // if format is default_format or NC1KHKWHWC0,device shape = original shape
665   if (trans::IsNeedPadding(format, real_shape)) {
666     real_shape = trans::PaddingShape(real_shape, format, GetOutputReshapeType(node, output_idx), node);
667   }
668   auto dtype = GetOutputDeviceDataType(node, output_idx);
669   return trans::TransShapeToDevice(real_shape, format, node, output_idx, dtype);
670 }
671 
GetInputDeviceShapeForTbeBuild(const AnfNodePtr & node,size_t input_idx,const std::string & format)672 std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShapeForTbeBuild(const AnfNodePtr &node, size_t input_idx,
673                                                                          const std::string &format) {
674   auto output_shape = AnfAlgo::GetPrevNodeOutputDetailShape(node, input_idx);
675   std::vector<int64_t> infer_shape;
676   if (output_shape->isa<abstract::Shape>()) {
677     auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
678     MS_EXCEPTION_IF_NULL(shape_ptr);
679     infer_shape = shape_ptr->shape();
680   }
681   if (infer_shape.empty()) {
682     return infer_shape;
683   }
684 
685   // if format is default_format or NC1KHKWHWC0,device shape = original shape
686   if (trans::IsNeedPadding(format, infer_shape)) {
687     infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node);
688   }
689   auto dtype = GetInputDeviceDataType(node, input_idx);
690   return trans::TransShapeToDevice(infer_shape, format, node, input_idx, dtype, false);
691 }
692 
GetInputDeviceShape(const AnfNodePtr & node,size_t input_idx)693 std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
694   auto format = GetInputFormat(node, input_idx);
695   auto infer_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, input_idx);
696   if (infer_shape.empty()) {
697     return infer_shape;
698   }
699   // if format is default_format or NC1KHKWHWC0,device shape = original shape
700   if (trans::IsNeedPadding(format, infer_shape)) {
701     infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node);
702   }
703   auto dtype = GetInputDeviceDataType(node, input_idx);
704   return trans::TransShapeToDevice(infer_shape, format, node, input_idx, dtype, false);
705 }
706 
GetInputReshapeType(const AnfNodePtr & node,size_t input_idx)707 std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
708   MS_EXCEPTION_IF_NULL(node);
709   if (input_idx > common::AnfAlgo::GetInputTensorNum(node)) {
710     MS_LOG(EXCEPTION) << "The index:" << input_idx
711                       << " is out of range of the node's input size : " << common::AnfAlgo::GetInputTensorNum(node)
712                       << "#node[" << node->DebugString() << "]" << trace::DumpSourceLines(node);
713   }
714   if (!AnfUtils::IsRealKernel(node)) {
715     return GetPrevNodeOutputReshapeType(node, input_idx);
716   }
717   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
718   MS_EXCEPTION_IF_NULL(kernel_info);
719   auto build_info = kernel_info->select_kernel_build_info();
720   if (build_info == nullptr || build_info->IsInputDefaultPadding()) {
721     return "";
722   }
723   return build_info->GetInputReshapeType(input_idx);
724 }
725 
GetOutputReshapeType(const AnfNodePtr & node,size_t output_idx)726 std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
727   MS_EXCEPTION_IF_NULL(node);
728   if (!AnfUtils::IsRealKernel(node)) {
729     return GetPrevNodeOutputReshapeType(node, output_idx);
730   }
731   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
732   MS_EXCEPTION_IF_NULL(kernel_info);
733   auto build_info = kernel_info->select_kernel_build_info();
734   if (build_info == nullptr || build_info->IsOutputDefaultPadding()) {
735     return "";
736   }
737   return build_info->GetOutputReshapeType(output_idx);
738 }
739 
GetAllInputReshapeType(const AnfNodePtr & node)740 std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputReshapeType(const AnfNodePtr &node) {
741   MS_EXCEPTION_IF_NULL(node);
742   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
743   MS_EXCEPTION_IF_NULL(kernel_info);
744   auto build_info = kernel_info->select_kernel_build_info();
745   if (build_info == nullptr || build_info->IsInputDefaultPadding()) {
746     return {};
747   }
748   return build_info->GetAllInputReshapeType();
749 }
750 
GetAllOutputReshapeType(const AnfNodePtr & node)751 std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputReshapeType(const AnfNodePtr &node) {
752   MS_EXCEPTION_IF_NULL(node);
753   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
754   MS_EXCEPTION_IF_NULL(kernel_info);
755   auto build_info = kernel_info->select_kernel_build_info();
756   if (build_info == nullptr || build_info->IsOutputDefaultPadding()) {
757     return {};
758   }
759   return build_info->GetAllOutputReshapeType();
760 }
761 
GetOutputDeviceDataType(const AnfNodePtr & node,size_t output_idx)762 TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
763   MS_EXCEPTION_IF_NULL(node);
764   if (output_idx > AnfAlgo::GetOutputElementNum(node)) {
765     MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
766                       << AnfAlgo::GetOutputElementNum(node) << "#node [ " << node->DebugString() << "]"
767                       << trace::DumpSourceLines(node);
768   }
769   if (common::AnfAlgo::CheckAbsSparseTensor(node)) {
770     return common::AnfAlgo::GetSparseTypeIdAt(node, output_idx);
771   }
772   if (!AnfUtils::IsRealKernel(node)) {
773     return GetPrevNodeOutputDeviceDataType(node, output_idx);
774   }
775   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
776   MS_EXCEPTION_IF_NULL(kernel_info);
777   auto build_info = kernel_info->select_kernel_build_info();
778   MS_EXCEPTION_IF_NULL(build_info);
779 
780   // If node has only one output and it is Tuple, in build_info, it only has one same dtype, so set output_dix as zero.
781   if (build_info->GetOutputNum() == 1 && build_info->GetOutputKernelObjectType(0) == kernel::KernelObjectType::TUPLE) {
782     output_idx = 0;
783   }
784 
785   auto dtype = build_info->GetOutputDeviceType(output_idx);
786   if (dtype == TypeId::kNumberTypeEnd) {
787     MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "] has a invalid dtype" << trace::DumpSourceLines(node);
788   }
789   return dtype;
790 }
791 
GetInputDeviceDataType(const AnfNodePtr & node,size_t input_idx)792 TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
793   MS_EXCEPTION_IF_NULL(node);
794   if (input_idx > common::AnfAlgo::GetInputTensorNum(node)) {
795     MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
796                       << common::AnfAlgo::GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"
797                       << trace::DumpSourceLines(node);
798   }
799   if (!AnfUtils::IsRealKernel(node)) {
800     return GetPrevNodeOutputDeviceDataType(node, 0);
801   }
802   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
803   MS_EXCEPTION_IF_NULL(kernel_info);
804   auto build_info = kernel_info->select_kernel_build_info();
805   MS_EXCEPTION_IF_NULL(build_info);
806   auto dtype = build_info->GetInputDeviceType(input_idx);
807   if (dtype == TypeId::kNumberTypeEnd) {
808     MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
809                       << " has a invalid dtype." << trace::DumpSourceLines(node);
810   }
811   return dtype;
812 }
813 
GetPrevNodeOutputDeviceDataType(const AnfNodePtr & anf_node,size_t input_idx)814 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
815   KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
816   return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
817 }
818 
819 // get output device addr of anf_node
GetOutputAddr(const AnfNodePtr & node,size_t output_idx,bool skip_nop_node)820 const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node) {
821   MS_EXCEPTION_IF_NULL(node);
822   auto tensor = GetForwardOutputTensor(node);
823   if (tensor != nullptr) {
824     return dynamic_cast<const DeviceAddress *>(tensor->device_address().get());
825   }
826 
827   if (common::AnfAlgo::IsNopNode(node) && (skip_nop_node || common::AnfAlgo::IsNeedSkipNopOpAddr(node))) {
828     auto cnode = node->cast<CNodePtr>();
829     MS_EXCEPTION_IF_NULL(cnode);
830     return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
831   }
832   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
833   MS_EXCEPTION_IF_NULL(kernel_info);
834   auto addr = kernel_info->GetOutputAddr(output_idx);
835   if (addr == nullptr) {
836     MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
837                       << " output addr is not exist." << trace::DumpSourceLines(node);
838   }
839   return addr;
840 }
841 
GetMutableOutputAddr(const AnfNodePtr & node,size_t output_idx,bool skip_nop_node)842 DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx,
843                                                            bool skip_nop_node) {
844   MS_EXCEPTION_IF_NULL(node);
845   auto tensor = GetForwardOutputTensor(node);
846   if (tensor != nullptr) {
847     return std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
848   }
849 
850   if (common::AnfAlgo::IsNopNode(node) && (skip_nop_node || common::AnfAlgo::IsNeedSkipNopOpAddr(node))) {
851     auto cnode = node->cast<CNodePtr>();
852     MS_EXCEPTION_IF_NULL(cnode);
853     return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
854   }
855   // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
856   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
857   MS_EXCEPTION_IF_NULL(kernel_info);
858   auto addr = kernel_info->GetMutableOutputAddr(output_idx);
859   if (addr == nullptr) {
860     MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() << " node:" << node
861                       << " output addr is not exist." << trace::DumpSourceLines(node);
862   }
863   return addr;
864 }
865 
866 // get output device addr of anf_node
OutputAddrExist(const AnfNodePtr & node,size_t output_idx,bool skip_nop_node)867 bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node) {
868   MS_EXCEPTION_IF_NULL(node);
869   if (common::AnfAlgo::IsNopNode(node) && (skip_nop_node || common::AnfAlgo::IsNeedSkipNopOpAddr(node))) {
870     auto cnode = node->cast<CNodePtr>();
871     MS_EXCEPTION_IF_NULL(cnode);
872     if (cnode->size() > 1) {
873       auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, 0);
874       return OutputAddrExist(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
875     }
876     return false;
877   }
878   // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
879   auto kernel_info_ptr = node->kernel_info();
880   if (kernel_info_ptr == nullptr) {
881     return false;
882   }
883   auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel_info_ptr);
884   MS_EXCEPTION_IF_NULL(kernel_info);
885   return kernel_info->OutputAddrExist(output_idx);
886 }
887 
WorkspaceAddrExist(const AnfNodePtr & node,size_t output_idx)888 bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx) {
889   MS_EXCEPTION_IF_NULL(node);
890   // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
891   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
892   MS_EXCEPTION_IF_NULL(kernel_info);
893   return kernel_info->WorkspaceAddrExist(output_idx);
894 }
895 
GetPrevNodeOutputAddr(const AnfNodePtr & anf_node,size_t input_idx,bool skip_nop_node)896 const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
897                                                                 bool skip_nop_node) {
898   KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
899   return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
900 }
901 
GetPrevNodeMutableOutputAddr(const AnfNodePtr & anf_node,size_t input_idx,bool skip_nop_node)902 DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
903                                                                    bool skip_nop_node) {
904   KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
905   return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
906 }
907 
// Collect the (shape, type, value) triple describing output `output_idx` of `node`,
// for constructing a KernelTensor.
// - For a ValueNode: the triple comes from the node's value; the abstract is built from
//   the value and cached on the node if absent (a deliberate side effect).
// - For any other node: shape/type come from the per-output abstract; value is nullptr.
// If the output's device type differs from the inferred tensor element type (e.g. because
// an insert-cast pass changed it), the returned TensorType is rebuilt with the device type.
std::tuple<abstract::BaseShapePtr, TypePtr, ValuePtr> AnfRuntimeAlgorithm::GetAbstractInfo(const AnfNodePtr &node,
                                                                                           size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  abstract::BaseShapePtr shape;
  TypePtr type;
  ValuePtr value;

  // Create output kernel tensor if not exists.
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    value = value_node->value();
    auto abs = node->abstract();
    if (abs == nullptr) {
      // Lazily derive the abstract from the value and cache it on the node.
      MS_EXCEPTION_IF_NULL(value);
      abs = value->ToAbstract();
      value_node->set_abstract(abs);
    }
    MS_EXCEPTION_IF_NULL(abs);
    shape = abs->GetShape();
    type = abs->GetType();
  } else {
    const auto &abs = AnfAlgo::GetNodeAbstractByIndex(node, output_idx);
    MS_EXCEPTION_IF_NULL(abs);
    shape = abs->GetShape();
    type = abs->GetType();
    value = nullptr;
  }

  // Insert cast pass will change the device type for some reason like CPU do not support fp16 actually,
  // so the output infer type and device type will be different, we change the output tensor to the real device type.
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<TensorType>()) {
    auto real_device_type = AnfAlgo::GetOutputDeviceDataType(node, output_idx);
    // Clone before mutation so the node's cached abstract type is not altered in place.
    auto abs_tensor_type = type->Clone()->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(abs_tensor_type);
    auto abs_element = abs_tensor_type->element();
    if (abs_element != nullptr) {
      auto abs_tensor_element_type = abs_element->type_id();
      if (real_device_type != kTypeUnknown && real_device_type != abs_tensor_element_type) {
        MS_LOG(INFO) << "For kernel " << node->DebugString() << ", the infer type of output[" << output_idx << "] is "
                     << TypeIdToString(abs_tensor_element_type) << ", but the device type is "
                     << TypeIdToString(real_device_type)
                     << ". Maybe there has insert cast pass which changed the device type."
                     << " So we change the tensor type from " << TypeIdToString(abs_tensor_element_type) << " to "
                     << TypeIdToString(real_device_type);
        abs_tensor_type->set_element(TypeIdToType(real_device_type));
        // Use new tensor type with device data type.
        type = abs_tensor_type;
      }
    }
  }

  return std::make_tuple(shape, type, value);
}
963 
ExistOutputKernelTensor(const AnfNodePtr & node,size_t output_idx)964 bool AnfRuntimeAlgorithm::ExistOutputKernelTensor(const AnfNodePtr &node, size_t output_idx) {
965   MS_EXCEPTION_IF_NULL(node);
966   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
967   MS_EXCEPTION_IF_NULL(kernel_info);
968 
969   return kernel_info->OutputAddrExist(output_idx) || kernel_info->OutputKernelTensorExist(output_idx);
970 }
971 
GetOutputKernelTensor(const AnfNodePtr & node,size_t output_idx)972 const KernelTensorPtr &AnfRuntimeAlgorithm::GetOutputKernelTensor(const AnfNodePtr &node, size_t output_idx) {
973   MS_EXCEPTION_IF_NULL(node);
974   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
975   MS_EXCEPTION_IF_NULL(kernel_info);
976 
977   // Get output kernel tensor in device address if exists.
978   if (kernel_info->OutputAddrExist(output_idx)) {
979     return kernel_info->GetOutputAddr(output_idx)->kernel_tensor();
980   }
981 
982   // Get output kernel tensor if exists.
983   if (kernel_info->OutputKernelTensorExist(output_idx)) {
984     return kernel_info->GetOutputKernelTensor(output_idx);
985   }
986 
987   MS_LOG(EXCEPTION) << "Can not find kernel tensor for node : " << node->DebugString()
988                     << ", output index: " << output_idx;
989 }
990 
GetOrCreateOutputKernelTensor(const AnfNodePtr & node,size_t output_idx)991 const KernelTensorPtr &AnfRuntimeAlgorithm::GetOrCreateOutputKernelTensor(const AnfNodePtr &node, size_t output_idx) {
992   MS_EXCEPTION_IF_NULL(node);
993 
994   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
995   MS_EXCEPTION_IF_NULL(kernel_info);
996 
997   // Get output kernel tensor in device address if exists.
998   if (kernel_info->OutputAddrExist(output_idx)) {
999     const auto &kt = kernel_info->GetOutputAddr(output_idx)->kernel_tensor();
1000     return kt;
1001   }
1002 
1003   // Get output kernel tensor if exists.
1004   if (kernel_info->OutputKernelTensorExist(output_idx)) {
1005     return kernel_info->GetOutputKernelTensor(output_idx);
1006   }
1007 
1008   auto [shape, type, value] = GetAbstractInfo(node, output_idx);
1009   auto kernel_tensor = std::make_shared<KernelTensor>(shape, type, value);
1010   // Handle the format diff between host and device, need set format before Resize KernelMod.
1011   kernel_tensor->SetStringFormat(GetOutputFormat(node, output_idx));
1012   kernel_info->SetOutputKernelTensor(kernel_tensor, output_idx);
1013 
1014   return kernel_info->GetOutputKernelTensor(output_idx);
1015 }
1016 
GetPrevNodeOutputKernelTensor(const AnfNodePtr & node,size_t input_idx)1017 const KernelTensorPtr &AnfRuntimeAlgorithm::GetPrevNodeOutputKernelTensor(const AnfNodePtr &node, size_t input_idx) {
1018   KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_idx, false);
1019   return GetOutputKernelTensor(kernel_with_index.first, kernel_with_index.second);
1020 }
1021 
GetOrCreatePrevNodeOutputKernelTensor(const AnfNodePtr & node,size_t input_idx)1022 const KernelTensorPtr &AnfRuntimeAlgorithm::GetOrCreatePrevNodeOutputKernelTensor(const AnfNodePtr &node,
1023                                                                                   size_t input_idx) {
1024   KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_idx, false);
1025   return GetOrCreateOutputKernelTensor(kernel_with_index.first, kernel_with_index.second);
1026 }
1027 
GetOrCreateAllInputKernelTensors(const AnfNodePtr & node)1028 std::vector<KernelTensor *> AnfRuntimeAlgorithm::GetOrCreateAllInputKernelTensors(const AnfNodePtr &node) {
1029   MS_EXCEPTION_IF_NULL(node);
1030   size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
1031   std::vector<KernelTensor *> input_kernel_tensors(input_num);
1032   for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
1033     input_kernel_tensors[input_idx] = GetOrCreatePrevNodeOutputKernelTensor(node, input_idx).get();
1034   }
1035   return input_kernel_tensors;
1036 }
1037 
GetOrCreateAllOutputKernelTensors(const AnfNodePtr & node)1038 std::vector<KernelTensor *> AnfRuntimeAlgorithm::GetOrCreateAllOutputKernelTensors(const AnfNodePtr &node) {
1039   MS_EXCEPTION_IF_NULL(node);
1040   size_t output_num = AnfAlgo::GetOutputTensorNum(node);
1041   std::vector<KernelTensor *> output_kernel_tensors(output_num);
1042   for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
1043     output_kernel_tensors[output_idx] = GetOrCreateOutputKernelTensor(node, output_idx).get();
1044   }
1045   return output_kernel_tensors;
1046 }
1047 
CreateOutputKernelTensorWithDeviceInfo(const AnfWithOutIndex & node_with_index,void * const device_ptr,size_t size,const string & format,TypeId dtype_id,const ShapeVector & host_shape,const std::string & device_name,uint32_t device_id,const UserDataPtr & user_data)1048 KernelTensorPtr AnfRuntimeAlgorithm::CreateOutputKernelTensorWithDeviceInfo(
1049   const AnfWithOutIndex &node_with_index, void *const device_ptr, size_t size, const string &format, TypeId dtype_id,
1050   const ShapeVector &host_shape, const std::string &device_name, uint32_t device_id, const UserDataPtr &user_data) {
1051   abstract::BaseShapePtr shape;
1052   TypePtr type;
1053   ValuePtr value;
1054   if (ExistOutputKernelTensor(node_with_index.first, node_with_index.second)) {
1055     const auto &kernel_tensor = GetOutputKernelTensor(node_with_index.first, node_with_index.second);
1056     MS_EXCEPTION_IF_NULL(kernel_tensor);
1057     MS_EXCEPTION_IF_NULL(kernel_tensor->GetShape());
1058     MS_EXCEPTION_IF_NULL(kernel_tensor->GetType());
1059     shape = kernel_tensor->GetShape()->Clone();
1060     type = kernel_tensor->GetType()->Clone();
1061     value = kernel_tensor->GetValueTrack();
1062   } else {
1063     std::tie(shape, type, value) = GetAbstractInfo(node_with_index.first, node_with_index.second);
1064   }
1065 
1066   MS_EXCEPTION_IF_NULL(shape);
1067   MS_EXCEPTION_IF_NULL(type);
1068   MS_LOG(DEBUG) << "Create output kernel tensor for node: " << node_with_index.first->fullname_with_scope()
1069                 << ", output index: " << node_with_index.second << ", Shape: " << shape->ToString()
1070                 << ", Type: " << type->ToString() << ", Value: " << (value ? value->ToString() : "nullptr")
1071                 << ", host shape: " << host_shape;
1072 
1073   return std::make_shared<kernel::KernelTensor>(shape, type, value, device_ptr, size, format, dtype_id, host_shape,
1074                                                 device_name, device_id, user_data);
1075 }
1076 
GetNodeInputSizeList(const AnfNodePtr & node)1077 std::vector<size_t> AnfRuntimeAlgorithm::GetNodeInputSizeList(const AnfNodePtr &node) {
1078   std::vector<KernelTensor *> input_kernel_tensors = AnfAlgo::GetOrCreateAllInputKernelTensors(node);
1079   size_t input_num = input_kernel_tensors.size();
1080   std::vector<size_t> input_size_list(input_num, 0);
1081   for (size_t i = 0; i < input_num; i++) {
1082     MS_EXCEPTION_IF_NULL(input_kernel_tensors[i]);
1083     input_size_list[i] = input_kernel_tensors[i]->size();
1084   }
1085 
1086   return input_size_list;
1087 }
1088 
GetOutputAddressNum(const AnfNodePtr & node)1089 size_t AnfRuntimeAlgorithm::GetOutputAddressNum(const AnfNodePtr &node) {
1090   MS_EXCEPTION_IF_NULL(node);
1091   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1092   MS_EXCEPTION_IF_NULL(kernel_info);
1093   auto build_info = kernel_info->select_kernel_build_info();
1094   MS_EXCEPTION_IF_NULL(build_info);
1095   return build_info->GetOutputNumWithoutMonad();
1096 }
1097 
1098 // set output device addr of anf_node
SetOutputAddr(const DeviceAddressPtr & addr,size_t output_idx,AnfNode * node)1099 void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
1100   MS_EXCEPTION_IF_NULL(node);
1101   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1102   MS_EXCEPTION_IF_NULL(kernel_info);
1103   if (!kernel_info->SetOutputAddr(addr, output_idx)) {
1104     MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set output index:" << output_idx << " fail."
1105                       << trace::DumpSourceLines(node);
1106   }
1107 }
1108 
1109 // set workspace device addr of anf_node
SetWorkspaceAddr(const DeviceAddressPtr & addr,size_t output_idx,AnfNode * node)1110 void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
1111   MS_EXCEPTION_IF_NULL(node);
1112   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1113   MS_EXCEPTION_IF_NULL(kernel_info);
1114   if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
1115     MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set output index:" << output_idx << " fail."
1116                       << trace::DumpSourceLines(node);
1117   }
1118 }
1119 
1120 // get workspace device addr of anf_node
GetWorkspaceAddr(const AnfNodePtr & node,size_t output_idx)1121 DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
1122   MS_EXCEPTION_IF_NULL(node);
1123   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1124   MS_EXCEPTION_IF_NULL(kernel_info);
1125   auto addr = kernel_info->GetWorkspaceAddr(output_idx);
1126   if (addr == nullptr) {
1127     MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
1128                       << "] workspace addr is not exist." << trace::DumpSourceLines(node);
1129   }
1130   return addr;
1131 }
1132 
1133 // get workspace device mutable addr of anf_node
GetMutableWorkspaceAddr(const AnfNodePtr & node,size_t index)1134 DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) {
1135   MS_EXCEPTION_IF_NULL(node);
1136   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1137   MS_EXCEPTION_IF_NULL(kernel_info);
1138   auto addr = kernel_info->GetMutableWorkspaceAddr(index);
1139   if (addr == nullptr) {
1140     MS_LOG(EXCEPTION) << "Index " << index << " of node " << node->DebugString() << "] workspace addr is not exist."
1141                       << trace::DumpSourceLines(node);
1142   }
1143   return addr;
1144 }
1145 
GetOpPattern(const AnfNodePtr & node)1146 kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
1147   MS_EXCEPTION_IF_NULL(node);
1148   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1149   MS_EXCEPTION_IF_NULL(kernel_info);
1150   // select_kernel_build_info() has checked whether return pointer is null
1151   auto build_info = kernel_info->select_kernel_build_info();
1152   MS_EXCEPTION_IF_NULL(build_info);
1153   return build_info->op_pattern();
1154 }
1155 
1156 // get KernelBuildType of node, such as ATT,RT,FWK and so on
GetKernelType(const AnfNodePtr & node)1157 KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
1158   MS_EXCEPTION_IF_NULL(node);
1159   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1160   MS_EXCEPTION_IF_NULL(kernel_info);
1161   // select_kernel_build_info() has checked whether return pointer is null
1162   auto build_info = kernel_info->select_kernel_build_info();
1163   if (build_info == nullptr) {
1164     MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << " has no kernel build info, using UNKNOWN_KERNEL_TYPE";
1165     return KernelType::UNKNOWN_KERNEL_TYPE;
1166   }
1167   return build_info->kernel_type();
1168 }
1169 
SetFusionType(const AnfNodePtr & node,const std::string & type)1170 void AnfRuntimeAlgorithm::SetFusionType(const AnfNodePtr &node, const std::string &type) {
1171   MS_EXCEPTION_IF_NULL(node);
1172   auto builder =
1173     std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
1174   MS_EXCEPTION_IF_NULL(builder);
1175   builder->SetFusionType(type);
1176   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
1177 }
1178 
SetCoreType(const AnfNodePtr & node,const std::string & core_type)1179 void AnfRuntimeAlgorithm::SetCoreType(const AnfNodePtr &node, const std::string &core_type) {
1180   MS_EXCEPTION_IF_NULL(node);
1181   auto builder =
1182     std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
1183   MS_EXCEPTION_IF_NULL(builder);
1184   builder->SetCoreType(core_type);
1185   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
1186 }
1187 
GetCoreType(const AnfNodePtr & node)1188 std::string AnfRuntimeAlgorithm::GetCoreType(const AnfNodePtr &node) {
1189   MS_EXCEPTION_IF_NULL(node);
1190   if (!AnfUtils::IsRealKernel(node)) {
1191     return "";
1192   }
1193   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1194   MS_EXCEPTION_IF_NULL(kernel_info);
1195   auto build_info = kernel_info->select_kernel_build_info();
1196   MS_EXCEPTION_IF_NULL(build_info);
1197   return build_info->core_type();
1198 }
1199 
GetOpType(const AnfNodePtr & node)1200 kernel::OpType AnfRuntimeAlgorithm::GetOpType(const AnfNodePtr &node) {
1201   MS_EXCEPTION_IF_NULL(node);
1202   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1203   MS_EXCEPTION_IF_NULL(kernel_info);
1204   auto build_info = kernel_info->select_kernel_build_info();
1205   MS_EXCEPTION_IF_NULL(build_info);
1206   return build_info->op_type();
1207 }
1208 
SetOutputDataDesc(const AnfNodePtr & node,const std::vector<nlohmann::json> & desc)1209 void AnfRuntimeAlgorithm::SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc) {
1210   MS_EXCEPTION_IF_NULL(node);
1211   auto builder =
1212     std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
1213   MS_EXCEPTION_IF_NULL(builder);
1214   builder->SetOutputDataDesc(desc);
1215   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
1216 }
1217 
GetOutputDataDesc(const AnfNodePtr & node)1218 std::vector<nlohmann::json> AnfRuntimeAlgorithm::GetOutputDataDesc(const AnfNodePtr &node) {
1219   MS_EXCEPTION_IF_NULL(node);
1220   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1221   if (kernel_info == nullptr) {
1222     return {};
1223   }
1224   auto build_info = kernel_info->select_kernel_build_info();
1225   if (build_info == nullptr) {
1226     return {};
1227   }
1228   return build_info->output_data_desc();
1229 }
1230 
GetProcessor(const AnfNodePtr & node)1231 kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
1232   MS_EXCEPTION_IF_NULL(node);
1233   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1234   MS_EXCEPTION_IF_NULL(kernel_info);
1235   auto build_info = kernel_info->select_kernel_build_info();
1236   MS_EXCEPTION_IF_NULL(build_info);
1237   return build_info->processor();
1238 }
1239 
GetFusionType(const AnfNodePtr & node)1240 std::string AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
1241   MS_EXCEPTION_IF_NULL(node);
1242   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1243   MS_EXCEPTION_IF_NULL(kernel_info);
1244   auto build_info = kernel_info->select_kernel_build_info();
1245   if (build_info == nullptr) {
1246     return kPatternUnknown;
1247   }
1248   return build_info->fusion_type();
1249 }
1250 
1251 // set select kernel_build_info
SetSelectKernelBuildInfo(const KernelBuildInfoPtr & select_kernel_build_info,AnfNode * node)1252 void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
1253   MS_EXCEPTION_IF_NULL(node);
1254   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1255   MS_EXCEPTION_IF_NULL(kernel_info);
1256   if (kernel_info->has_build_info() && (select_kernel_build_info != nullptr)) {
1257     auto ori_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
1258     auto input_object_types = ori_kernel_build_info->GetAllInputKernelObjectTypes();
1259     auto output_object_types = ori_kernel_build_info->GetAllOutputKernelObjectTypes();
1260     if (!input_object_types.empty() && select_kernel_build_info->GetAllInputKernelObjectTypes().empty()) {
1261       select_kernel_build_info->SetInputsKernelObjectType(input_object_types);
1262     }
1263     if (!output_object_types.empty() && select_kernel_build_info->GetAllOutputKernelObjectTypes().empty()) {
1264       MS_LOG(DEBUG) << "set kernel object type:" << output_object_types << " for node:" << node->fullname_with_scope();
1265       select_kernel_build_info->SetOutputsKernelObjectType(output_object_types);
1266     }
1267   }
1268   return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
1269 }
1270 
1271 // get select kernel_build_info
GetSelectKernelBuildInfo(const AnfNodePtr & node)1272 KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
1273   MS_EXCEPTION_IF_NULL(node);
1274   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1275   MS_EXCEPTION_IF_NULL(kernel_info);
1276   return kernel_info->GetMutableSelectKernelBuildInfo();
1277 }
1278 
1279 // get kernelMode
GetKernelMod(const AnfNodePtr & node)1280 KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
1281   MS_EXCEPTION_IF_NULL(node);
1282   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1283   MS_EXCEPTION_IF_NULL(kernel_info);
1284   return kernel_info->MutableKernelMod();
1285 }
1286 
1287 // set kernel mod
SetKernelMod(const KernelModPtr & kernel_mod,AnfNode * node)1288 void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
1289   MS_EXCEPTION_IF_NULL(node);
1290   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1291   MS_EXCEPTION_IF_NULL(kernel_info);
1292   kernel_info->set_kernel_mod(kernel_mod);
1293 }
1294 
SetStreamId(uint32_t stream_id,AnfNode * node)1295 void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
1296   MS_EXCEPTION_IF_NULL(node);
1297   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1298   MS_EXCEPTION_IF_NULL(kernel_info);
1299   kernel_info->set_stream_id(stream_id);
1300 }
1301 
GetStreamId(const AnfNodePtr & node)1302 uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
1303   MS_EXCEPTION_IF_NULL(node);
1304   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1305   MS_EXCEPTION_IF_NULL(kernel_info);
1306   return kernel_info->stream_id();
1307 }
1308 
SetStreamDistinctionLabel(uint32_t stream_label,AnfNode * node)1309 void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
1310   MS_EXCEPTION_IF_NULL(node);
1311   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1312   MS_EXCEPTION_IF_NULL(kernel_info);
1313   kernel_info->set_stream_distinction_label(stream_label);
1314 }
1315 
GetStreamDistinctionLabel(const AnfNode * node)1316 uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
1317   MS_EXCEPTION_IF_NULL(node);
1318   auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1319   MS_EXCEPTION_IF_NULL(kernel_info);
1320   return kernel_info->stream_distinction_label();
1321 }
1322 
SetGraphId(uint32_t graph_id,AnfNode * node)1323 void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
1324   MS_EXCEPTION_IF_NULL(node);
1325   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
1326   MS_EXCEPTION_IF_NULL(kernel_info);
1327   kernel_info->set_graph_id(graph_id);
1328 }
1329 
GetGraphId(const AnfNode * node)1330 uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
1331   MS_EXCEPTION_IF_NULL(node);
1332   auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1333   MS_EXCEPTION_IF_NULL(kernel_info);
1334   return kernel_info->graph_id();
1335 }
1336 
IsFeatureMapOutput(const AnfNodePtr & node)1337 bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
1338   MS_EXCEPTION_IF_NULL(node);
1339   if (node->isa<ValueNode>()) {
1340     auto value_node = node->cast<ValueNodePtr>();
1341     MS_EXCEPTION_IF_NULL(value_node);
1342     ValuePtr value = value_node->value();
1343     std::vector<tensor::BaseTensorPtr> tensors;
1344     TensorValueToTensor(value, &tensors);
1345     auto ret = false;
1346     if (!tensors.empty()) {
1347       auto all_tensor_have_address = true;
1348       for (const auto &tensor : tensors) {
1349         MS_EXCEPTION_IF_NULL(tensor);
1350         if (tensor->device_address() == nullptr) {
1351           all_tensor_have_address = false;
1352           break;
1353         }
1354       }
1355       ret = all_tensor_have_address;
1356     }
1357     return ret;
1358   }
1359   if (IsPrimitiveCNode(node, prim::kPrimLoad) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
1360     return IsFeatureMapOutput(node->cast<CNodePtr>()->input(1));
1361   }
1362   auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
1363   // If node is a call node which not have kernel info
1364   if (kernel_info == nullptr) {
1365     return false;
1366   }
1367   return kernel_info->is_feature_map();
1368 }
1369 
IsFeatureMapInput(const AnfNodePtr & node,size_t input_index)1370 bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
1371   MS_EXCEPTION_IF_NULL(node);
1372   if (!node->isa<CNode>()) {
1373     MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map."
1374                       << trace::DumpSourceLines(node);
1375   }
1376   auto cnode = node->cast<CNodePtr>();
1377   MS_EXCEPTION_IF_NULL(cnode);
1378   auto input_node = cnode->input(input_index + 1);
1379   return IsFeatureMapOutput(input_node);
1380 }
1381 
GetInputGraphIdxByKernelIdx(const mindspore::AnfNodePtr & anf_node,size_t input_index_in_kernel)1382 size_t AnfRuntimeAlgorithm::GetInputGraphIdxByKernelIdx(const mindspore::AnfNodePtr &anf_node,
1383                                                         size_t input_index_in_kernel) {
1384   MS_EXCEPTION_IF_NULL(anf_node);
1385   return input_index_in_kernel;
1386 }
1387 
GetInputKernelIdxByGraphIdx(const mindspore::AnfNodePtr & anf_node,size_t input_index_in_graph)1388 size_t AnfRuntimeAlgorithm::GetInputKernelIdxByGraphIdx(const mindspore::AnfNodePtr &anf_node,
1389                                                         size_t input_index_in_graph) {
1390   MS_EXCEPTION_IF_NULL(anf_node);
1391   return input_index_in_graph;
1392 }
1393 
GetCallSwitchKernelGraph(const CNodePtr & cnode)1394 std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) {
1395   MS_EXCEPTION_IF_NULL(cnode);
1396   if (!(common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
1397         common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
1398         common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) {
1399     MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch or switch_layer node."
1400                       << trace::DumpSourceLines(cnode);
1401   }
1402   auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr {
1403     auto partial = cnode->input(input_index);
1404     MS_EXCEPTION_IF_NULL(partial);
1405     if (IsValueNode<KernelGraph>(partial)) {
1406       return GetValueNode<KernelGraphPtr>(partial);
1407     }
1408     auto partial_cnode = partial->cast<CNodePtr>();
1409     MS_EXCEPTION_IF_NULL(partial_cnode);
1410     auto graph_node = partial_cnode->input(kPartialGraphIndex);
1411     MS_EXCEPTION_IF_NULL(graph_node);
1412     auto graph_value_node = graph_node->cast<ValueNodePtr>();
1413     MS_EXCEPTION_IF_NULL(graph_value_node);
1414     auto graph_value = graph_value_node->value();
1415     MS_EXCEPTION_IF_NULL(graph_value);
1416     auto child_graph = graph_value->cast<KernelGraphPtr>();
1417     return child_graph;
1418   };
1419   if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
1420     auto input1 = cnode->input(kPartialGraphIndex);
1421     MS_EXCEPTION_IF_NULL(input1);
1422     auto value_node = input1->cast<ValueNodePtr>();
1423     MS_EXCEPTION_IF_NULL(value_node);
1424     auto kernel_graph = value_node->value();
1425     MS_EXCEPTION_IF_NULL(kernel_graph);
1426     return {kernel_graph->cast<KernelGraphPtr>()};
1427   } else if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
1428     return {get_switch_kernel_graph(kSwitchTrueBranchIndex), get_switch_kernel_graph(kSwitchFalseBranchIndex)};
1429   } else if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
1430     std::vector<KernelGraphPtr> child_graphs;
1431     for (size_t idx = kSwitchLayerBranchesIndex; idx < cnode->size(); idx++) {
1432       auto child_graph = get_switch_kernel_graph(idx);
1433       child_graphs.emplace_back(child_graph);
1434     }
1435     return child_graphs;
1436   }
1437   return {};
1438 }
1439 
GetValueNodeKernelGraph(const AnfNodePtr & node)1440 KernelGraphPtr AnfRuntimeAlgorithm::GetValueNodeKernelGraph(const AnfNodePtr &node) {
1441   MS_EXCEPTION_IF_NULL(node);
1442   auto value_node = node->cast<ValueNodePtr>();
1443   if (value_node == nullptr) {
1444     return nullptr;
1445   }
1446   auto value = value_node->value();
1447   if (value == nullptr) {
1448     return nullptr;
1449   }
1450   auto kernel_graph = value->cast<KernelGraphPtr>();
1451   return kernel_graph;
1452 }
1453 
IsIndependentNode(const CNodePtr & node)1454 bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
1455   MS_EXCEPTION_IF_NULL(node);
1456   if (AnfAlgo::GetKernelType(node) != AICPU_KERNEL) {
1457     return false;
1458   }
1459 
1460   if (common::AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
1461     MS_LOG(INFO) << "GetNext should not be independent node";
1462     return false;
1463   }
1464 
1465   // aicpu stack ops are not independent nodes.
1466   if (common::AnfAlgo::GetCNodeName(node) == kStackInitOpName ||
1467       common::AnfAlgo::GetCNodeName(node) == kStackDestroyOpName ||
1468       common::AnfAlgo::GetCNodeName(node) == kStackPopOpName ||
1469       common::AnfAlgo::GetCNodeName(node) == kStackPushOpName) {
1470     MS_LOG(INFO) << "AICPU stack ops should not be independent node";
1471     return false;
1472   }
1473 
1474   size_t input_nums = common::AnfAlgo::GetInputTensorNum(node);
1475   if (input_nums == 0) {
1476     return true;
1477   }
1478 
1479   auto inputs = node->inputs();
1480   for (size_t i = 1; i < inputs.size(); i++) {
1481     if (!inputs[i]->isa<ValueNode>()) {
1482       return false;
1483     }
1484   }
1485   return true;
1486 }
1487 
FetchKernelGraph(const AnfNode * node)1488 KernelGraphPtr AnfRuntimeAlgorithm::FetchKernelGraph(const AnfNode *node) {
1489   MS_EXCEPTION_IF_NULL(node);
1490   const auto &func_graph = node->func_graph();
1491   if (func_graph == nullptr) {
1492     return nullptr;
1493   } else {
1494     return func_graph->cast<KernelGraphPtr>();
1495   }
1496 }
1497 
// Map a backend node to its front-end counterpart: the internal-parameter
// mapping wins, then the graph's backend->front map; falls back to the
// backend node itself (PyNative forward graphs have no front node).
AnfNodePtr AnfRuntimeAlgorithm::FetchFrontNodeByBackendNode(const AnfNodePtr &backend_node, const KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(backend_node);
  const auto &internal_front = graph.GetFrontNodeByInternalParameter(backend_node);
  if (internal_front.first != nullptr) {
    return internal_front.first;
  }
  auto front_node = graph.GetFrontAnfByBackendAnf(backend_node);
  return front_node != nullptr ? front_node : backend_node;
}
1512 
1513 namespace {
1514 // Host kernel with inputs on host
SkipDataSync(const CNodePtr & node,const std::map<uint32_t,tensor::TensorPtr> & depend_tensors)1515 bool SkipDataSync(const CNodePtr &node, const std::map<uint32_t, tensor::TensorPtr> &depend_tensors) {
1516   if (!common::AnfAlgo::IsHostKernel(node)) {
1517     return false;
1518   }
1519   auto input_size = common::AnfAlgo::GetInputTensorNum(node);
1520   for (size_t i = 0; i < input_size; ++i) {
1521     auto input_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
1522     auto real_input = input_with_index.first;
1523     auto iter_tensor = depend_tensors.find(i);
1524     if (iter_tensor != depend_tensors.end()) {
1525       auto output_addr = AnfAlgo::GetOutputAddr(real_input, 0);
1526       MS_EXCEPTION_IF_NULL(output_addr);
1527       if (output_addr->GetDeviceType() != device::DeviceType::kCPU) {
1528         return false;
1529       }
1530     }
1531   }
1532   return true;
1533 }
1534 }  // namespace
1535 
InferShape(const CNodePtr & node,std::map<uint32_t,tensor::TensorPtr> * depend_tensors)1536 void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors) {
1537   MS_EXCEPTION_IF_NULL(node);
1538   MS_LOG(INFO) << "InferShape start, node:" << node->DebugString();
1539   auto inputs = node->inputs();
1540   if (inputs.empty()) {
1541     MS_LOG(EXCEPTION) << "Inputs should not be empty! Cnode: " << node->DebugString() << "."
1542                       << trace::DumpSourceLines(node);
1543   }
1544   AbstractBasePtrList args_spec_list;
1545   auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
1546   auto input_size = common::AnfAlgo::GetInputTensorNum(node);
1547   for (size_t i = 0; i < input_size; ++i) {
1548     auto input_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
1549     auto real_input = input_with_index.first;
1550     MS_EXCEPTION_IF_NULL(real_input);
1551     auto cnode_input = node->input(i + 1);
1552     MS_EXCEPTION_IF_NULL(cnode_input);
1553     if (depend_tensors != nullptr) {
1554       auto iter_tensor = depend_tensors->find(i);
1555       if (iter_tensor != depend_tensors->cend()) {
1556         auto tensor_ptr = iter_tensor->second;
1557         MS_EXCEPTION_IF_NULL(tensor_ptr);
1558         if (!SkipDataSync(node, *depend_tensors)) {
1559           // sync data from device to host
1560           tensor_ptr->data_sync();
1561         }
1562         // cppcheck-suppress unreadVariable
1563         auto lock = AnfUtils::GetAbstractLock(real_input.get());
1564         auto real_abs = real_input->abstract();
1565         if (real_abs->isa<abstract::AbstractTensor>()) {
1566           real_abs->set_value(tensor_ptr);
1567         } else if (real_abs->isa<abstract::AbstractTuple>() && (!common::AnfAlgo::IsDynamicSequence(real_input))) {
1568           auto tuple_get_item_index = common::AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
1569           auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
1570           MS_EXCEPTION_IF_NULL(abstract_tuple);
1571           auto tuple_elements = abstract_tuple->elements()[tuple_get_item_index];
1572           tuple_elements->set_value(tensor_ptr);
1573         }
1574       }
1575     }
1576     common::AnfAlgo::AddArgList(&args_spec_list, real_input, input_with_index.second);
1577   }
1578   auto eval_result = opt::CppInferShapeAndType(primitive, args_spec_list);
1579   node->set_abstract(eval_result);
1580 }
1581 
InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> & root_graph)1582 void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph) {
1583   auto return_node = root_graph->get_return();
1584   MS_EXCEPTION_IF_NULL(return_node);
1585   if (return_node->size() <= kReturnDataIndex) {
1586     return;
1587   }
1588   auto make_tuple = root_graph->NewCNode(
1589     {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
1590   MS_EXCEPTION_IF_NULL(root_graph->output());
1591   MS_EXCEPTION_IF_NULL(make_tuple);
1592   abstract::AbstractBasePtrList abs_list{root_graph->output()->abstract()};
1593   make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
1594   root_graph->set_output(make_tuple);
1595 }
1596 
UpdateGraphValidRefPair(const KernelGraphPtr & graph)1597 void AnfRuntimeAlgorithm::UpdateGraphValidRefPair(const KernelGraphPtr &graph) {
1598   MS_EXCEPTION_IF_NULL(graph);
1599 
1600   if (graph->memory_managed_by_ge()) {
1601     return;
1602   }
1603 
1604   const auto &origin_ref_map = graph->GetRefMap();
1605   std::map<AnfWithOutIndex, AnfWithOutIndex> new_ref_map;
1606   for (const auto &node : graph->execution_order()) {
1607     MS_EXCEPTION_IF_NULL(node);
1608     auto output_num = AnfAlgo::GetOutputTensorNum(node);
1609     if (output_num == 0) {
1610       MS_LOG(DEBUG) << "This kernel has no output size.";
1611       continue;
1612     }
1613     for (size_t i = 0; i < output_num; ++i) {
1614       session::AnfWithOutIndex out_pair(node, i);
1615       auto iter = origin_ref_map.find(out_pair);
1616       if (iter != origin_ref_map.end()) {
1617         auto ret = new_ref_map.try_emplace(iter->first, iter->second);
1618         if (!ret.second) {
1619           MS_LOG(WARNING) << "Duplicate ref_map key, node:" << node->fullname_with_scope() << " index:" << i;
1620         }
1621       }
1622     }
1623   }
1624   graph->set_ref_out_in_map(new_ref_map);
1625 }
1626 
IsDynamicShapeSkipExecute(bool skip_mode,const ShapeVector & axes_shape)1627 bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(bool skip_mode, const ShapeVector &axes_shape) {
1628   // Skip run ReduceSum when axis is a Empty Tensor
1629   if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; }) && skip_mode) {
1630     return true;
1631   }
1632   return false;
1633 }
1634 
IsDynamicShapeSkipExecute(const CNodePtr & cnode)1635 bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(const CNodePtr &cnode) {
1636   // Skip run ReduceSum when axis is a Empty Tensor
1637   MS_EXCEPTION_IF_NULL(cnode);
1638   auto op_name = common::AnfAlgo::GetCNodeName(cnode);
1639   if ((op_name != kReduceSumOpName) && (op_name != kReduceSumDOpName)) {
1640     return false;
1641   }
1642 
1643   bool skip_mode = false;
1644   if (common::AnfAlgo::HasNodeAttr(kAttrSkipMode, cnode)) {
1645     skip_mode = common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrSkipMode);
1646   }
1647 
1648   if (!skip_mode) {
1649     return false;
1650   }
1651 
1652   const size_t axes_index = 1;
1653   if (cnode->size() <= axes_index + 1) {
1654     return false;
1655   }
1656   auto input_axes = cnode->input(axes_index + 1);
1657   // cppcheck-suppress unreadVariable
1658   auto lock = AnfUtils::GetAbstractLock(input_axes.get());
1659   auto abs = input_axes->abstract();
1660   MS_EXCEPTION_IF_NULL(abs);
1661   auto axes_abs = abs->Clone();
1662   MS_EXCEPTION_IF_NULL(axes_abs);
1663   auto axes_shape = AnfAlgo::GetInputDeviceShape(cnode, axes_index);
1664   if (axes_abs->isa<abstract::AbstractTensor>()) {
1665     if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; })) {
1666       return true;
1667     }
1668   }
1669   return false;
1670 }
1671 
IsNeedUpdateShapeAndTypeAfterLaunch(const AnfNodePtr & node)1672 bool AnfRuntimeAlgorithm::IsNeedUpdateShapeAndTypeAfterLaunch(const AnfNodePtr &node) {
1673   MS_EXCEPTION_IF_NULL(node);
1674   auto graph = FetchKernelGraph(node.get());
1675   // The graph run mode does not have kernelmod.
1676   if ((graph == nullptr) || graph->is_graph_run_mode()) {
1677     return true;
1678   }
1679 
1680   auto kernel_mod = GetKernelMod(node);
1681   if (kernel_mod == nullptr) {
1682     return true;
1683   }
1684   return kernel_mod->IsNeedUpdateOutputShapeAndSize();
1685 }
1686 
HasComputedDependInputNode(const CNodePtr & kernel)1687 bool AnfRuntimeAlgorithm::HasComputedDependInputNode(const CNodePtr &kernel) {
1688   MS_EXCEPTION_IF_NULL(kernel);
1689   auto real_input_num = common::AnfAlgo::GetInputTensorNum(kernel);
1690 
1691   for (size_t i = 0; i < real_input_num; i++) {
1692     const auto &input_node = common::AnfAlgo::GetInputNode(kernel, i);
1693     MS_EXCEPTION_IF_NULL(input_node);
1694     auto real_input_node = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
1695     MS_EXCEPTION_IF_NULL(real_input_node.first);
1696     if (!real_input_node.first->isa<CNode>()) {
1697       continue;
1698     }
1699 
1700     auto kernel_mod = AnfAlgo::GetKernelMod(real_input_node.first);
1701     if (kernel_mod && kernel_mod->IsNeedUpdateOutputShapeAndSize()) {
1702       return true;
1703     }
1704   }
1705   return false;
1706 }
1707 
UpdateOutputAddrSize(device::KernelInfo const * kernel_info,const CNodePtr & kernel)1708 void AnfRuntimeAlgorithm::UpdateOutputAddrSize(device::KernelInfo const *kernel_info, const CNodePtr &kernel) {
1709   MS_EXCEPTION_IF_NULL(kernel_info);
1710   MS_EXCEPTION_IF_NULL(kernel);
1711   auto &output_addresses = kernel_info->output_address_list();
1712   for (size_t i = 0; i < output_addresses.size(); ++i) {
1713     auto output_address = output_addresses[i].get();
1714     MS_EXCEPTION_IF_NULL(output_address);
1715     auto output_addr_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
1716     MS_LOG(DEBUG) << "output size:" << output_addr_size << " index:" << i
1717                   << " for kernel:" << kernel->fullname_with_scope()
1718                   << " abstract:" << (kernel->abstract() == nullptr ? "null" : kernel->abstract()->ToString());
1719     if (output_addr_size != output_address->GetSize()) {
1720       output_address->SetSize(output_addr_size);
1721     }
1722   }
1723 }
1724 
AddOutInRefToGraph(const KernelGraphPtr & graph)1725 void AnfRuntimeAlgorithm::AddOutInRefToGraph(const KernelGraphPtr &graph) {
1726   MS_EXCEPTION_IF_NULL(graph);
1727   for (const auto &cnode : graph->execution_order()) {
1728     MS_EXCEPTION_IF_NULL(cnode);
1729     auto kernel_info = dynamic_cast<device::KernelInfo *>(cnode->kernel_info());
1730     MS_EXCEPTION_IF_NULL(kernel_info);
1731     for (const auto &ref : kernel_info->out_in_ref_map()) {
1732       size_t output_index = ref.first;
1733       size_t input_index = ref.second;
1734       auto final_pair = std::make_pair(cnode, output_index);
1735       auto origin_pair = common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(cnode, input_index), 0);
1736       MS_LOG(INFO) << "The reference relation output " << final_pair.first->fullname_with_scope()
1737                    << ", output index: " << final_pair.second << " to input "
1738                    << origin_pair.first->fullname_with_scope() << ", output index: " << origin_pair.second;
1739       // Add to graph only if the input is not a monad.
1740       if (!HasAbstractUMonad(origin_pair.first) && !HasAbstractIOMonad(origin_pair.first)) {
1741         graph->AddRefCorrespondPairs(final_pair, origin_pair);
1742       }
1743     }
1744   }
1745 }
1746 
HasOriginFormat(const AnfNodePtr & anf_node)1747 bool AnfRuntimeAlgorithm::HasOriginFormat(const AnfNodePtr &anf_node) {
1748   MS_EXCEPTION_IF_NULL(anf_node);
1749   return anf_node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrOriginFormat, anf_node->cast<CNodePtr>());
1750 }
1751 
GetOriginFormat(const AnfNodePtr & anf_node)1752 std::string AnfRuntimeAlgorithm::GetOriginFormat(const AnfNodePtr &anf_node) {
1753   MS_EXCEPTION_IF_NULL(anf_node);
1754   if (anf_node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrOriginFormat, anf_node->cast<CNodePtr>())) {
1755     return common::AnfAlgo::GetNodeAttr<std::string>(anf_node, kAttrOriginFormat);
1756   }
1757   return {};
1758 }
1759 
NodeValueIsFuncGraph(const AnfNodePtr & node)1760 bool AnfRuntimeAlgorithm::NodeValueIsFuncGraph(const AnfNodePtr &node) {
1761   MS_EXCEPTION_IF_NULL(node);
1762   auto value_node = node->cast<ValueNodePtr>();
1763   MS_EXCEPTION_IF_NULL(value_node);
1764   auto value = value_node->value().get();
1765   MS_EXCEPTION_IF_NULL(value);
1766   return value->isa<FuncGraph>();
1767 }
1768 
IsNodeSupportKernelSelectBackoff(const AnfNodePtr & node,const KernelGraphPtr & graph)1769 bool AnfRuntimeAlgorithm::IsNodeSupportKernelSelectBackoff(const AnfNodePtr &node, const KernelGraphPtr &graph) {
1770   MS_EXCEPTION_IF_NULL(node);
1771   static std::string disable_kernel_backoff;
1772   static bool first_get_backoff_env = true;
1773   if (first_get_backoff_env) {
1774     disable_kernel_backoff = common::GetEnv(kDisableKernelBackoff);
1775     first_get_backoff_env = false;
1776   }
1777   if (disable_kernel_backoff == "1" && (!common::AnfAlgo::IsTypeTransformOp(common::AnfAlgo::GetCNodeName(node)))) {
1778     MS_LOG(INFO) << "MS_DISABLE_KERNEL_BACKOFF has been set to turn off the kernel backoff ability.";
1779     return false;
1780   }
1781 
1782   if (graph == nullptr) {
1783     return false;
1784   }
1785   if (graph->is_from_single_op() || graph->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
1786     MS_LOG(INFO) << "The pynative single op does not support the kernel backoff ability for graph:"
1787                  << graph->graph_id();
1788     return false;
1789   }
1790   return true;
1791 }
1792 
SetKernelSelectBackoffInfo(const CNodePtr & node,const std::pair<std::string,ExceptionType> & failure_info)1793 void AnfRuntimeAlgorithm::SetKernelSelectBackoffInfo(const CNodePtr &node,
1794                                                      const std::pair<std::string, ExceptionType> &failure_info) {
1795   MS_EXCEPTION_IF_NULL(node);
1796   common::AnfAlgo::SetNodeAttr(kAttrKernelBackoffWithFailureInfo, MakeValue(failure_info.first), node);
1797   common::AnfAlgo::SetNodeAttr(kAttrKernelBackoffWithFailureType, MakeValue(static_cast<int32_t>(failure_info.second)),
1798                                node);
1799 }
1800 
GetKernelSelectBackoffInfo(const AnfNodePtr & node)1801 std::pair<std::string, ExceptionType> AnfRuntimeAlgorithm::GetKernelSelectBackoffInfo(const AnfNodePtr &node) {
1802   MS_EXCEPTION_IF_NULL(node);
1803   if (!IsKernelSelectBackoffOp(node)) {
1804     return {"", NoExceptionType};
1805   }
1806 
1807   auto cnode = node->cast<CNodePtr>();
1808   MS_EXCEPTION_IF_NULL(cnode);
1809   auto failure_info = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrKernelBackoffWithFailureInfo);
1810   auto failure_type =
1811     static_cast<ExceptionType>(common::AnfAlgo::GetNodeAttr<int32_t>(node, kAttrKernelBackoffWithFailureType));
1812   return {failure_info, failure_type};
1813 }
1814 
IsKernelSelectBackoffOp(const AnfNodePtr & node)1815 bool AnfRuntimeAlgorithm::IsKernelSelectBackoffOp(const AnfNodePtr &node) {
1816   MS_EXCEPTION_IF_NULL(node);
1817   if (!node->isa<CNode>()) {
1818     return false;
1819   }
1820 
1821   auto cnode = node->cast<CNodePtr>();
1822   MS_EXCEPTION_IF_NULL(cnode);
1823   if (common::AnfAlgo::HasNodeAttr(kAttrKernelBackoffWithFailureInfo, cnode) &&
1824       common::AnfAlgo::HasNodeAttr(kAttrKernelBackoffWithFailureType, cnode)) {
1825     return true;
1826   }
1827   return false;
1828 }
1829 
FetchDeviceTarget(const AnfNodePtr & node,const KernelGraph * graph)1830 std::string AnfRuntimeAlgorithm::FetchDeviceTarget(const AnfNodePtr &node, const KernelGraph *graph) {
1831   MS_EXCEPTION_IF_NULL(node);
1832   MS_EXCEPTION_IF_NULL(graph);
1833   // The parameter also may be have the user data to express device target.
1834   auto ud_target = node->user_data<std::string>(kAttrPrimitiveTarget);
1835   if (ud_target != nullptr) {
1836     return *ud_target;
1837   }
1838 
1839   if (!node->isa<CNode>()) {
1840     return device::GetDeviceNameByType(graph->device_target());
1841   }
1842 
1843   // Only the CPU support kernel backoff.
1844   if (AnfAlgo::IsKernelSelectBackoffOp(node)) {
1845     return kCPUDevice;
1846   }
1847 
1848   auto cnode = node->cast<CNodePtr>();
1849   MS_EXCEPTION_IF_NULL(cnode);
1850   if (common::AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
1851     return common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrPrimitiveTarget);
1852   }
1853 
1854   return device::GetDeviceNameByType(graph->device_target());
1855 }
1856 
SetParameterDeviceTarget(const KernelGraphPtr graph)1857 void AnfRuntimeAlgorithm::SetParameterDeviceTarget(const KernelGraphPtr graph) {
1858   MS_EXCEPTION_IF_NULL(graph);
1859   auto manager = graph->manager();
1860   if (manager == nullptr) {
1861     manager = MakeManager({graph});
1862     graph->set_manager(manager);
1863   }
1864 
1865   const auto &graph_device_target = device::GetDeviceNameByType(graph->device_target());
1866   for (auto &input_node : graph->input_nodes()) {
1867     const auto &iter = manager->node_users().find(input_node);
1868     if (iter == manager->node_users().end()) {
1869       continue;
1870     }
1871 
1872     std::string device_target_affinity = graph_device_target;
1873     for (const auto &user_node : iter->second) {
1874       if (!AnfUtils::IsRealCNodeKernel(user_node.first)) {
1875         continue;
1876       }
1877       device_target_affinity = FetchDeviceTarget(user_node.first, graph.get());
1878       // If there is node with the same device target as the graph, then select the device target of graph affinity.
1879       if (device_target_affinity == graph_device_target) {
1880         break;
1881       }
1882     }
1883 
1884     // Set the device target for parameter when it is different with the graph.
1885     if (device_target_affinity != graph_device_target) {
1886       MS_LOG(INFO) << "Set the affinity device target for parameter:" << input_node->fullname_with_scope()
1887                    << " in graph:" << graph->graph_id() << " from graph device target:" << graph_device_target
1888                    << " to real device target:" << device_target_affinity;
1889       input_node->set_user_data(kAttrPrimitiveTarget, std::make_shared<std::string>(device_target_affinity));
1890     }
1891   }
1892 }
1893 
GetAbstractObjectType(const AbstractBasePtr & abstract)1894 TypeId AnfRuntimeAlgorithm::GetAbstractObjectType(const AbstractBasePtr &abstract) {
1895   if (abstract == nullptr) {
1896     return kTypeUnknown;
1897   }
1898   if (abstract->isa<AbstractTensor>()) {
1899     return kObjectTypeTensorType;
1900   }
1901   if (abstract->isa<AbstractTuple>()) {
1902     return kObjectTypeTuple;
1903   }
1904   if (abstract->isa<abstract::AbstractList>()) {
1905     return kObjectTypeList;
1906   }
1907   if (abstract->isa<abstract::AbstractScalar>()) {
1908     // scalar input may not converted to tensor
1909     return kObjectTypeNumber;
1910   }
1911   if (abstract->isa<abstract::AbstractNone>()) {
1912     return kMetaTypeNone;
1913   }
1914 
1915   return kTypeUnknown;
1916 }
1917 
GetOutputObjectType(const AnfNodePtr & node,size_t output_idx)1918 TypeId AnfRuntimeAlgorithm::GetOutputObjectType(const AnfNodePtr &node, size_t output_idx) {
1919   MS_EXCEPTION_IF_NULL(node);
1920   auto abstract = node->abstract();
1921   if (abstract->isa<AbstractTuple>()) {
1922     auto tuple_abs = abstract->cast<abstract::AbstractTuplePtr>();
1923     auto items = tuple_abs->elements();
1924     MS_EXCEPTION_IF_CHECK_FAIL(output_idx < items.size(), "invalid output_idx");
1925     return AnfAlgo::GetAbstractObjectType(items[output_idx]);
1926   }
1927   if (output_idx != 0) {
1928     MS_LOG(EXCEPTION) << node->DebugString() << "invalid output_idx" << trace::DumpSourceLines(node);
1929   }
1930   return AnfAlgo::GetAbstractObjectType(abstract);
1931 }
1932 
GetInputObjectType(const CNodePtr & node,size_t input_idx)1933 TypeId AnfRuntimeAlgorithm::GetInputObjectType(const CNodePtr &node, size_t input_idx) {
1934   MS_EXCEPTION_IF_NULL(node);
1935   auto input_node = common::AnfAlgo::GetInputNode(node, input_idx);
1936   const std::vector<PrimitivePtr> need_handled_prims = {prim::kPrimMakeTuple, prim::kPrimTupleGetItem};
1937   auto real_input_node = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false, need_handled_prims).first;
1938   return AnfAlgo::GetAbstractObjectType(real_input_node->abstract());
1939 }
1940 
GetAllInputObjectType(const AnfNodePtr & node)1941 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllInputObjectType(const AnfNodePtr &node) {
1942   MS_EXCEPTION_IF_NULL(node);
1943   if (!node->isa<CNode>()) {
1944     MS_LOG(EXCEPTION) << node->DebugString() << "anf_node is not CNode." << trace::DumpSourceLines(node);
1945   }
1946   auto cnode = node->cast<CNodePtr>();
1947   std::vector<TypeId> obj_types;
1948   auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
1949   for (size_t index = 0; index < input_num; ++index) {
1950     obj_types.push_back(AnfAlgo::GetInputObjectType(cnode, index));
1951   }
1952   return obj_types;
1953 }
1954 
GetAllOutputObjectType(const AnfNodePtr & node)1955 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputObjectType(const AnfNodePtr &node) {
1956   MS_EXCEPTION_IF_NULL(node);
1957   if (AnfAlgo::GetOutputElementNum(node) == 0 && node->abstract() != nullptr &&
1958       !node->abstract()->isa<abstract::AbstractSequence>()) {
1959     return {};
1960   }
1961   return {AnfAlgo::GetAbstractObjectType(node->abstract())};
1962 }
1963 
// Return the detailed (BaseShape) shape of output `output_idx` of `node`.
// Handles: single-output nodes (abstract::Shape), tuple outputs (per-element
// lookup unless the node is a real sequence output), shapeless outputs
// (NoShape) and dynamic-length sequences. Throws for any other shape kind
// or out-of-range index.
abstract::BaseShapePtr AnfRuntimeAlgorithm::GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>()) {
    // Single tensor output: only index 0 is valid.
    if (output_idx == 0) {
      return base_shape;
    }
    MS_LOG(EXCEPTION) << "The node " << node->DebugString() << "is a single output node but got index [" << output_idx
                      << "]." << trace::DumpSourceLines(node);
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    // A real sequence output is consumed as a whole tuple, so return the
    // tuple shape itself rather than indexing into it.
    if (IsRealSquenceOutput(node)) {
      return tuple_shape;
    }
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
                        << " node:" << node->DebugString() << "." << trace::DumpSourceLines(node);
    }
    auto b_shp = (*tuple_shape)[output_idx];
    // Tuple elements may themselves be shapes, shapeless values, nested tuples
    // or dynamic sequences; anything else is malformed.
    if (b_shp->isa<abstract::Shape>() || b_shp->isa<abstract::NoShape>() || b_shp->isa<abstract::TupleShape>() ||
        b_shp->isa<abstract::DynamicSequenceShape>()) {
      return b_shp;
    } else {
      MS_LOG(EXCEPTION) << "The output type of node index:" << output_idx
                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
                        << "node :" << node->DebugString() << "." << trace::DumpSourceLines(node);
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    return base_shape;
  } else if (base_shape->isa<abstract::DynamicSequenceShape>()) {
    // Dynamic-length sequences record their element shape separately.
    return common::AnfAlgo::GetDynamicSequenceShape(node, output_idx);
  }
  MS_LOG(EXCEPTION) << "The output type of node should be a NoShape , ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString() << " node : " << node->DebugString() << trace::DumpSourceLines(node);
}
2001 
GetPrevNodeOutputDetailShape(const AnfNodePtr & node,size_t input_idx)2002 abstract::BaseShapePtr AnfRuntimeAlgorithm::GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx) {
2003   KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_idx);
2004   return AnfAlgo::GetOutputDetailShape(kernel_with_index.first, kernel_with_index.second);
2005 }
2006 
GetAllOutputInferDataTypes(const AnfNodePtr & node)2007 std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputInferDataTypes(const AnfNodePtr &node) {
2008   MS_EXCEPTION_IF_NULL(node);
2009   std::vector<TypeId> outputs;
2010   auto out_nums = AnfAlgo::GetOutputElementNum(node);
2011   for (size_t i = 0; i < out_nums; i++) {
2012     auto type = common::AnfAlgo::GetOutputInferDataType(node, i);
2013     outputs.push_back(type);
2014   }
2015   return outputs;
2016 }
2017 
// if input node is MakeTuple, find the PrevNodeNum recursively;
// The monad node in the end is not included in the num;
size_t AnfRuntimeAlgorithm::GetInputElementNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  size_t element_num = 0;
  size_t input_num = cnode->size() - 1;
  bool cal_monad_flag = false;
  // Iterate inputs from the last to the first so that trailing monad inputs can
  // be skipped; once any non-monad input is seen, monads are counted normally.
  for (size_t i = input_num; i > 0; --i) {
    auto input_node = common::AnfAlgo::GetInputNode(cnode, i - 1);
    if (!cal_monad_flag && HasAbstractMonad(input_node)) {
      // Trailing monad input: excluded from the element count.
      continue;
    } else if (common::AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
      // MakeTuple input: count its flattened elements recursively.
      element_num += GetInputElementNum(input_node);
      cal_monad_flag = true;
    } else if (common::AnfAlgo::IsTupleOutput(input_node)) {
      // Tuple-output input: count every element of the tuple.
      element_num += AnfAlgo::GetOutputElementNum(input_node);
      cal_monad_flag = true;
    } else {
      ++element_num;
      cal_monad_flag = true;
    }
  }

  return element_num;
}
2045 
SetDynamicAttrToPrim(const PrimitivePtr & prim)2046 void AnfRuntimeAlgorithm::SetDynamicAttrToPrim(const PrimitivePtr &prim) {
2047   (void)prim->AddAttr(kAttrMutableKernel, MakeValue(true));
2048   (void)prim->AddAttr(kAttrInputIsDynamicShape, MakeValue(true));
2049   (void)prim->AddAttr(kAttrOutputIsDynamicShape, MakeValue(true));
2050 }
2051 
IsScalarConvertToTensor(const AnfNodePtr & input_node,const CNodePtr & node)2052 bool AnfRuntimeAlgorithm::IsScalarConvertToTensor(const AnfNodePtr &input_node, const CNodePtr &node) {
2053   MS_EXCEPTION_IF_NULL(input_node);
2054   MS_EXCEPTION_IF_NULL(node);
2055   if (!input_node->isa<ValueNode>()) {
2056     return false;
2057   }
2058 
2059   auto value_node = input_node->cast<ValueNodePtr>();
2060   MS_EXCEPTION_IF_NULL(value_node);
2061   auto value = value_node->value();
2062   MS_EXCEPTION_IF_NULL(value);
2063   if (!value->isa<Scalar>()) {
2064     return false;
2065   }
2066 
2067   const auto &abs = node->abstract();
2068   if (ContainScalarOut(abs)) {
2069     MS_LOG(INFO) << "The input scalar value node:" << input_node->fullname_with_scope()
2070                  << " of cnode:" << node->fullname_with_scope() << " doesn't need convert to tensor.";
2071     return false;
2072   }
2073   return true;
2074 }
2075 
IsSequenceOutputOfScalar(const AnfNodePtr & node)2076 bool AnfRuntimeAlgorithm::IsSequenceOutputOfScalar(const AnfNodePtr &node) {
2077   MS_EXCEPTION_IF_NULL(node);
2078   const auto &abs = node->abstract();
2079   if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
2080     return false;
2081   }
2082   // Check all elements in tuple/list are scalar.
2083   auto abs_seq = abs->cast_ptr<abstract::AbstractSequence>();
2084   MS_EXCEPTION_IF_NULL(abs_seq);
2085   if (abs_seq->dynamic_len()) {
2086     const auto &element_abs = abs_seq->dynamic_len_element_abs();
2087     return (element_abs == nullptr) || (element_abs->isa<abstract::AbstractScalar>());
2088   }
2089   const auto &elements = abs_seq->elements();
2090 
2091   return std::all_of(elements.begin(), elements.end(), [](const AbstractBasePtr &element) {
2092     return (element != nullptr) && (element->isa<abstract::AbstractScalar>()) &&
2093            (element->BuildValue() == nullptr || (!element->BuildValue()->isa<StringImm>()));
2094   });
2095 }
2096 
IsSummaryNode(const AnfNodePtr & node)2097 bool AnfRuntimeAlgorithm::IsSummaryNode(const AnfNodePtr &node) {
2098   return (IsPrimitiveCNode(node, prim::kPrimScalarSummary) || IsPrimitiveCNode(node, prim::kPrimTensorSummary) ||
2099           IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary));
2100 }
2101 
2102 namespace {
CheckValidTensorTuple(const std::vector<ValuePtr> & values)2103 bool CheckValidTensorTuple(const std::vector<ValuePtr> &values) {
2104   if (values.empty() || values[0] == nullptr || (!values[0]->isa<tensor::Tensor>())) {
2105     return false;
2106   }
2107   const auto &const_tensor = values[0]->cast<tensor::TensorPtr>();
2108   MS_EXCEPTION_IF_NULL(const_tensor);
2109   const auto &const_shape = const_tensor->shape();
2110   const auto &const_type_id = const_tensor->data_type();
2111   size_t const_size = const_tensor->Size();
2112   for (size_t i = 1; i < values.size(); ++i) {
2113     if (values[i] == nullptr || (!values[i]->isa<tensor::Tensor>())) {
2114       MS_LOG(ERROR) << "Invalid value:" << (values[i] == nullptr ? "nullptr" : values[i]->ToString()) << " index:" << i
2115                     << " in value tuple";
2116       return false;
2117     }
2118     const auto &tensor = values[i]->cast<tensor::TensorPtr>();
2119     MS_EXCEPTION_IF_NULL(tensor);
2120     const auto &shape = tensor->shape();
2121     const auto &type_id = tensor->data_type();
2122     size_t size = tensor->Size();
2123     if (shape != const_shape || type_id != const_type_id || size != const_size) {
2124       return false;
2125     }
2126   }
2127   return true;
2128 }
2129 
2130 // Return a new tensor with type like single_value.
SetScalarToTensor(const std::vector<ValuePtr> & values,const tensor::TensorPtr & tensor)2131 void SetScalarToTensor(const std::vector<ValuePtr> &values, const tensor::TensorPtr &tensor) {
2132   MS_EXCEPTION_IF_NULL(tensor);
2133   const auto &tensor_type_id = tensor->data_type();
2134   const auto dst_ptr = tensor->data_c();
2135   MS_EXCEPTION_IF_NULL(dst_ptr);
2136   MS_LOG(DEBUG) << "Set scalar tuple to tensor, dst size:" << tensor->data().nbytes();
2137   for (size_t i = 0; i < values.size(); ++i) {
2138     // Check mem size.
2139     if (SizeToLong(abstract::TypeIdSize(tensor_type_id) * (i + 1)) > tensor->data().nbytes()) {
2140       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Value size:" << values.size()
2141                                  << " type:" << tensor_type_id << " out of range:" << tensor->data().nbytes();
2142     }
2143     const auto &value = values[i];
2144     MS_EXCEPTION_IF_NULL(value);
2145     // Check value type.
2146     if (value->type()->type_id() != tensor_type_id) {
2147       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid value type:" << value->type()->type_id()
2148                                  << " for value:" << value->ToString() << " dst type:" << tensor_type_id;
2149     }
2150     if (tensor_type_id == TypeId::kNumberTypeInt8) {
2151       (reinterpret_cast<int8_t *>(dst_ptr))[i] = GetValue<int8_t>(value);
2152     } else if (tensor_type_id == TypeId::kNumberTypeInt16) {
2153       (reinterpret_cast<int16_t *>(dst_ptr))[i] = GetValue<int16_t>(value);
2154     } else if (tensor_type_id == TypeId::kNumberTypeInt32 || tensor_type_id == kNumberTypeInt) {
2155       (reinterpret_cast<int32_t *>(dst_ptr))[i] = GetValue<int32_t>(value);
2156     } else if (tensor_type_id == TypeId::kNumberTypeInt64) {
2157       (reinterpret_cast<int64_t *>(dst_ptr))[i] = GetValue<int64_t>(value);
2158     } else if (tensor_type_id == TypeId::kNumberTypeBool) {
2159       (reinterpret_cast<bool *>(dst_ptr))[i] = GetValue<bool>(value);
2160     } else if (tensor_type_id == TypeId::kNumberTypeFloat32 || tensor_type_id == TypeId::kNumberTypeFloat) {
2161       (reinterpret_cast<float *>(dst_ptr))[i] = GetValue<float>(value);
2162     } else if (tensor_type_id == TypeId::kNumberTypeFloat64) {
2163       (reinterpret_cast<double *>(dst_ptr))[i] = GetValue<double>(value);
2164     } else if (tensor_type_id == TypeId::kNumberTypeUInt8) {
2165       (reinterpret_cast<uint8_t *>(dst_ptr))[i] = GetValue<uint8_t>(value);
2166     } else if (tensor_type_id == TypeId::kNumberTypeUInt16) {
2167       (reinterpret_cast<uint16_t *>(dst_ptr))[i] = GetValue<uint16_t>(value);
2168     } else if (tensor_type_id == TypeId::kNumberTypeUInt || tensor_type_id == TypeId::kNumberTypeUInt32) {
2169       (reinterpret_cast<uint32_t *>(dst_ptr))[i] = GetValue<uint32_t>(value);
2170     } else if (tensor_type_id == TypeId::kNumberTypeUInt64) {
2171       (reinterpret_cast<uint64_t *>(dst_ptr))[i] = GetValue<uint64_t>(value);
2172     } else {
2173       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid tuple type:" << tensor_type_id
2174                                  << " for scalar to tensor.";
2175     }
2176   }
2177 }
2178 }  // namespace
2179 
CreateMapTensor(const DeviceAddressPtr & output_device_address)2180 tensor::TensorPtr AnfRuntimeAlgorithm::CreateMapTensor(const DeviceAddressPtr &output_device_address) {
2181   MS_EXCEPTION_IF_NULL(output_device_address);
2182   const auto &user_data = output_device_address->user_data();
2183   MS_EXCEPTION_IF_NULL(user_data);
2184   const auto &user_data_type = user_data->get<UserDataType>(kUserDataType);
2185   MS_EXCEPTION_IF_NULL(user_data_type);
2186   if (*user_data_type == UserDataType::kUserTypeHashTable) {
2187     auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
2188     auto key_type = user_data->get<TypeId>(kHashTableKeyType);
2189     auto value_type = user_data->get<TypeId>(kHashTableValueType);
2190     auto default_value = user_data->get<Value>(kHashTableDefaultValue);
2191     MS_EXCEPTION_IF_NULL(shape_vector);
2192     MS_EXCEPTION_IF_NULL(key_type);
2193     MS_EXCEPTION_IF_NULL(value_type);
2194     MS_EXCEPTION_IF_NULL(default_value);
2195     auto map_tensor = std::make_shared<tensor::MapTensor>(*key_type, *value_type, *shape_vector, default_value);
2196     map_tensor->set_device_address(output_device_address);
2197     return map_tensor;
2198   }
2199   MS_LOG(WARNING) << "Invalid user data type:" << *user_data_type;
2200   return nullptr;
2201 }
2202 
CreateMapTensor(const AnfNodePtr & output_node,size_t output_index)2203 tensor::TensorPtr AnfRuntimeAlgorithm::CreateMapTensor(const AnfNodePtr &output_node, size_t output_index) {
2204   const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
2205   return CreateMapTensor(device_tensor);
2206 }
2207 
// In dynamic sequence, since the number of members is not determined in compile time, the entire sequence needs
// to be placed in single tensor, and the shape of the tuple needs to be recorded in the tensor, so that the shape
// of the tensor can be accurately restored during the dynamic shape derivation process in runtime.
tensor::TensorPtr AnfRuntimeAlgorithm::SequenceToTensor(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (!value->isa<ValueSequence>()) {
    MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid sequence value:" << value->ToString();
  }

  const auto &sequence_value = value->cast<ValueSequencePtr>();
  const auto &values = sequence_value->value();
  if (values.empty()) {
    // Empty sequence: return an empty tensor that still records whether the
    // origin was a tuple or a list via an empty base shape.
    auto tensor = std::make_shared<tensor::Tensor>();
    abstract::BaseShapePtr base_shape = nullptr;
    if (value->isa<ValueTuple>()) {
      base_shape = std::make_shared<abstract::TupleShape>(abstract::BaseShapePtrList());
    } else {
      base_shape = std::make_shared<abstract::ListShape>(abstract::BaseShapePtrList());
    }
    tensor->set_base_shape(base_shape);
    return tensor;
  }
  // Only sequences of scalars or tensors can be packed; anything else degrades
  // to an empty tensor with a warning.
  if (values[0] == nullptr || ((!values[0]->isa<Scalar>()) && (!values[0]->isa<tensor::BaseTensor>()))) {
    MS_LOG(WARNING) << "Empty sequence in sequence value:" << value->ToString();
    return std::make_shared<tensor::Tensor>();
  }

  // The leading dimension of the packed tensor is the sequence length.
  ShapeVector shape_vector{SizeToLong(values.size())};
  if (values[0]->isa<tensor::BaseTensor>()) {
    MS_LOG(DEBUG) << "Check dynamic tuple tensor";
    // All member tensors must agree in shape/dtype/size before merging.
    if (!CheckValidTensorTuple(values)) {
      MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid dynamic sequence tuple:"
                                 << value->ToString();
    }
    const auto &tensor = values[0]->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    size_t size = tensor->Size();
    const auto &type_id = tensor->data_type();
    auto single_shape_vector = tensor->shape();
    const auto &single_shape = std::make_shared<abstract::Shape>(single_shape_vector);
    // Packed shape = [N] + element shape.
    (void)shape_vector.insert(shape_vector.end(), single_shape_vector.begin(), single_shape_vector.end());
    const auto &shape = std::make_shared<abstract::Shape>(shape_vector);
    auto new_tensor = std::make_shared<tensor::Tensor>(type_id, shape_vector);
    MS_EXCEPTION_IF_NULL(new_tensor);
    const auto dst_ptr = new_tensor->data_c();
    MS_EXCEPTION_IF_NULL(dst_ptr);
    MS_LOG(DEBUG) << "Copy start, dst size:" << new_tensor->data().nbytes();
    // Copy each member tensor into its slot; `size` is uniform per
    // CheckValidTensorTuple, so slot i starts at byte offset i * size.
    for (size_t i = 0; i < values.size(); ++i) {
      const auto &sub_value = values[i];
      MS_EXCEPTION_IF_NULL(sub_value);
      const auto &src_tensor = sub_value->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(src_tensor);
      MS_EXCEPTION_IF_NULL(src_tensor->data_c());
      auto ret = memcpy_s((reinterpret_cast<char *>(dst_ptr)) + i * size,
                          static_cast<size_t>(new_tensor->data().nbytes()), src_tensor->data_c(), size);
      if (ret != EOK) {
        MS_LOG(INTERNAL_EXCEPTION)
          << "#dmsg#Runtime error info:#dmsg#Failed to copy data into tensor, memcpy_s errorno: " << ret;
      }
    }
    // Record the per-element shapes so the tuple can be reconstructed at runtime.
    const auto &element_shapes = std::vector<abstract::BaseShapePtr>(values.size(), single_shape);
    new_tensor->set_base_shape(std::make_shared<abstract::TupleShape>(element_shapes));
    MS_LOG(DEBUG) << "merge tensor from:" << value->ToString() << " to:" << new_tensor->ToString() << " tensor addr"
                  << new_tensor;
    return new_tensor;
  }

  // Create the tensor.
  auto tensor = std::make_shared<tensor::Tensor>(values[0]->type()->type_id(), shape_vector);
  MS_EXCEPTION_IF_NULL(tensor);
  SetScalarToTensor(values, tensor);
  // Build the tuple shape and set into tensor.
  const auto &element_shape = std::make_shared<abstract::Shape>(ShapeVector({}));
  const auto &element_shapes = std::vector<abstract::BaseShapePtr>(values.size(), element_shape);
  tensor->set_base_shape(std::make_shared<abstract::TupleShape>(element_shapes));
  return tensor;
}
2285 
FlattenDynamicInputArg(const BaseRef & arg,const AnfNodePtr & node,std::vector<tensor::TensorPtr> * flatten_tensors)2286 void AnfRuntimeAlgorithm::FlattenDynamicInputArg(const BaseRef &arg, const AnfNodePtr &node,
2287                                                  std::vector<tensor::TensorPtr> *flatten_tensors) {
2288   MS_EXCEPTION_IF_NULL(node);
2289   MS_EXCEPTION_IF_NULL(flatten_tensors);
2290   MS_LOG(DEBUG) << "Dynamic sequence node:" << node->fullname_with_scope() << " abs:" << node->abstract()->ToString();
2291   if (!utils::isa<ValuePtr>(arg)) {
2292     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid input for dynamic sequence node:"
2293                                << node->DebugString();
2294   }
2295   auto value = utils::cast<ValuePtr>(arg);
2296   MS_EXCEPTION_IF_NULL(value);
2297   if (!value->isa<ValueSequence>()) {
2298     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid value:" << value->ToString()
2299                                << " for dynamic sequence node:" << node->DebugString();
2300   }
2301   const auto &tensor = AnfAlgo::SequenceToTensor(value);
2302   flatten_tensors->emplace_back(tensor);
2303 }
2304 
// Recursively flatten one input argument into a list of tensors.
// Dispatches on the runtime type of `arg`: tensors are appended directly,
// scalars/monads are wrapped, and sequences/dicts/sparse tensors/VectorRefs
// are flattened element by element. Throws for unsupported types.
void AnfRuntimeAlgorithm::FlattenInputArg(const BaseRef &arg, const AnfNodePtr &node,
                                          std::vector<tensor::TensorPtr> *flatten_tensors) {
  MS_EXCEPTION_IF_NULL(flatten_tensors);
  // Dynamic sequences are packed into a single tensor rather than flattened.
  if (node != nullptr && node->abstract() != nullptr && common::AnfAlgo::IsDynamicSequence(node)) {
    FlattenDynamicInputArg(arg, node, flatten_tensors);
    return;
  }

#ifndef BUILD_LITE
  // Python object wrapper: unwrap to the underlying tensor.
  if (utils::isa<PyObjectRef>(arg)) {
    auto value = utils::cast<PyObjectRef>(arg).object_;
    flatten_tensors->push_back(py::cast<tensor::TensorPtr>(value));
    return;
  }
#endif

  if (utils::isa<tensor::Tensor>(arg)) {
    (void)flatten_tensors->emplace_back(utils::cast<tensor::TensorPtr>(arg));
  } else if (utils::isa<tensor::BaseTensor>(arg)) {
    // Promote a BaseTensor to a full Tensor copy.
    (void)flatten_tensors->emplace_back(std::make_shared<tensor::Tensor>(*utils::cast<tensor::BaseTensorPtr>(arg)));
  } else if (utils::isa<Scalar>(arg)) {
    (void)flatten_tensors->emplace_back(ScalarToTensor(utils::cast<ScalarPtr>(arg)));
  } else if (utils::isa<Monad>(arg)) {
    // If value is a monad, replace it with an unused tensor.
    flatten_tensors->push_back(std::make_shared<tensor::Tensor>(int64_t(0), kBool));
  } else if (utils::isa<ValueSequencePtr>(arg)) {
    // Flatten tuple/list values element by element.
    auto value_sequence = utils::cast<ValueSequencePtr>(arg);
    MS_EXCEPTION_IF_NULL(value_sequence);
    auto sequence_value = value_sequence->value();
    for (auto &value : sequence_value) {
      FlattenInputArg(value, node, flatten_tensors);
    }
  } else if (utils::isa<ValueDictionaryPtr>(arg)) {
    // Dictionaries: flatten the values only (keys carry no tensor data).
    auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
    MS_EXCEPTION_IF_NULL(value_dict);
    auto dict_value = value_dict->value();
    for (auto &iter : dict_value) {
      FlattenInputArg(iter.second, node, flatten_tensors);
    }
  } else if (utils::isa<tensor::COOTensorPtr>(arg)) {
    // Sparse COO tensor: append each constituent tensor (indices/values/shape).
    auto coo_tensor = utils::cast<tensor::COOTensorPtr>(arg);
    MS_EXCEPTION_IF_NULL(coo_tensor);
    for (size_t i = 0; i < coo_tensor->GetTensorLength(); ++i) {
      (void)flatten_tensors->emplace_back(coo_tensor->GetTensorAt(i));
    }
  } else if (utils::isa<tensor::CSRTensorPtr>(arg)) {
    // Sparse CSR tensor: append each constituent tensor.
    auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(arg);
    MS_EXCEPTION_IF_NULL(csr_tensor);
    for (size_t i = 0; i < csr_tensor->GetTensorLength(); ++i) {
      (void)flatten_tensors->emplace_back(csr_tensor->GetTensorAt(i));
    }
  } else if (utils::isa<VectorRefPtr>(arg)) {
    // Nested VectorRef: recurse into each element.
    const auto &args_new = utils::cast<VectorRef>(arg);
    for (const auto &arg_new : args_new) {
      FlattenInputArg(arg_new, node, flatten_tensors);
    }
  } else {
    MS_LOG(INTERNAL_EXCEPTION)
      << "#dmsg#Runtime error info:#dmsg#The value input to flatten tensor not supported for type " << arg.ToString();
  }
}
2366 
UpdateValueNodeShape(const AnfNodePtr & node)2367 void AnfRuntimeAlgorithm::UpdateValueNodeShape(const AnfNodePtr &node) {
2368   MS_EXCEPTION_IF_NULL(node);
2369   if (!node->isa<ValueNode>()) {
2370     return;
2371   }
2372   const auto &value_node = node->cast<ValueNodePtr>();
2373   MS_EXCEPTION_IF_NULL(value_node);
2374   const auto &value = value_node->value();
2375   MS_EXCEPTION_IF_NULL(value);
2376   if (!value->isa<ValueSequence>()) {
2377     return;
2378   }
2379   const auto &value_sequence = value->cast<ValueSequencePtr>();
2380   MS_EXCEPTION_IF_NULL(value_sequence);
2381   std::vector<abstract::AbstractBasePtr> abstract_list;
2382   for (const auto &sub_value : value_sequence->value()) {
2383     MS_EXCEPTION_IF_NULL(sub_value);
2384     if (sub_value->isa<Scalar>()) {
2385       auto abstract = std::make_shared<abstract::AbstractScalar>(sub_value->type());
2386       (void)abstract_list.emplace_back(abstract);
2387     } else if (sub_value->isa<tensor::Tensor>()) {
2388       const auto &tensor = sub_value->cast<tensor::TensorPtr>();
2389       MS_EXCEPTION_IF_NULL(tensor);
2390       auto abstract = std::make_shared<abstract::AbstractTensor>(tensor->Dtype(), tensor->shape());
2391       (void)abstract_list.emplace_back(abstract);
2392     } else {
2393       MS_LOG(EXCEPTION) << "Invalid value:" << sub_value->ToString()
2394                         << " in dynamic sequence value node:" << node->DebugString();
2395     }
2396   }
2397   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
2398   MS_LOG(INFO) << "Set abstract for node:" << node->DebugString() << "from:" << node->abstract()->ToString()
2399                << " to:" << abstract_tuple->ToString();
2400   node->set_abstract(abstract_tuple);
2401 }
2402 
HasSelectKernelBuildInfo(const AnfNodePtr & node)2403 bool AnfRuntimeAlgorithm::HasSelectKernelBuildInfo(const AnfNodePtr &node) {
2404   MS_EXCEPTION_IF_NULL(node);
2405   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
2406   if (kernel_info == nullptr) {
2407     return false;
2408   }
2409   auto build_info = kernel_info->select_kernel_build_info();
2410   if (build_info == nullptr) {
2411     return false;
2412   }
2413   return true;
2414 }
2415 
NeedEraseCache(const PrimitivePtr & prim)2416 bool AnfRuntimeAlgorithm::NeedEraseCache(const PrimitivePtr &prim) {
2417   MS_EXCEPTION_IF_NULL(prim);
2418   if (!prim->HasAttr(kRandomCache)) {
2419     return false;
2420   }
2421   auto random_cache_value = prim->GetAttr(kRandomCache);
2422   MS_EXCEPTION_IF_NULL(random_cache_value);
2423   return !GetValue<bool>(random_cache_value);
2424 }
2425 
// Return the abstract of output `index` of `node`.
// For non-sequence abstracts, dynamic sequences and real-tuple kernel outputs
// the whole abstract is returned (index must be 0); otherwise the sequence
// element at `index` is returned, falling back to a recursive flattened lookup
// when `index` exceeds the top-level element count.
abstract::AbstractBasePtr AnfRuntimeAlgorithm::GetNodeAbstractByIndex(const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &abstract = node->abstract();
  if (abstract == nullptr) {
    return abstract;
  }

  // Return output abstract directly for : 1.not sequence type, 2.dynamic sequence type, 3.real tuple/list type.
  if (!abstract->isa<abstract::AbstractSequence>() || common::AnfAlgo::IsDynamicSequence(node) ||
      (node->isa<CNode>() && !mindspore::AnfAlgo::GetOutputKernelObjectTypes(node).empty() &&
       (mindspore::session::AnfRuntimeAlgorithm::GetOutputKernelObjectType(node, 0) ==
        kernel::KernelObjectType::TUPLE))) {
    MS_EXCEPTION_IF_CHECK_FAIL((index == 0), "Cannot get " + std::to_string(index) + " child abstract from " +
                                               abstract->ToString() + " in node:" + node->fullname_with_scope());
    return abstract;
  }

  // Return element abstract by index for tuple type.
  const auto &abstract_tuple = abstract->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  const auto &elements = abstract_tuple->elements();
  if (elements.size() <= index) {
    // Index beyond the top-level elements: resolve through the flattened
    // (nested) abstract structure instead.
    const auto sub_abstract = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
    return sub_abstract;
  }
  return elements[index];
}
2453 
CreateTypeIdValueNodeToKernelGraph(const FuncGraphPtr & func_graph,TypeId data_type)2454 ValueNodePtr AnfRuntimeAlgorithm::CreateTypeIdValueNodeToKernelGraph(const FuncGraphPtr &func_graph, TypeId data_type) {
2455   auto type_id_value_node = NewValueNode(static_cast<int64_t>(data_type));
2456   auto type_id_value = std::make_shared<Int64Imm>(static_cast<int64_t>(data_type));
2457   type_id_value_node->set_abstract(type_id_value->ToAbstract());
2458   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
2459   MS_EXCEPTION_IF_NULL(kernel_graph);
2460   type_id_value_node = kernel_graph->NewValueNode(type_id_value_node);
2461   kernel_graph->AddValueNodeToGraph(type_id_value_node);
2462   return type_id_value_node;
2463 }
2464 
CreateTypeIdValueNodeToFuncGraph(const FuncGraphPtr & func_graph,TypeId data_type)2465 ValueNodePtr AnfRuntimeAlgorithm::CreateTypeIdValueNodeToFuncGraph(const FuncGraphPtr &func_graph, TypeId data_type) {
2466   auto type_id_value_node = NewValueNode(static_cast<int64_t>(data_type));
2467   auto type_id_value = std::make_shared<Int64Imm>(static_cast<int64_t>(data_type));
2468   type_id_value_node->set_abstract(type_id_value->ToAbstract());
2469   return type_id_value_node;
2470 }
2471 }  // namespace mindspore::session
2472