1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "kernel/framework_utils.h"
18 #include <algorithm>
19 #include <map>
20 #include <set>
21 #include <utility>
22 #include "include/backend/anf_runtime_algorithm.h"
23 #include "include/common/utils/anfalgo.h"
24 #include "include/common/utils/convert_utils.h"
25 #include "kernel/common_utils.h"
26 #include "kernel/format_utils.h"
27 #include "kernel/oplib/oplib.h"
28 #include "mindapi/base/type_id.h"
29 #include "mindspore/ccsrc/include/common/debug/common.h"
30 #include "ops/array_op_name.h"
31 #include "ops/conv_pool_op_name.h"
32 #include "ops/framework_ops.h"
33 #include "ops/math_op_name.h"
34 #include "ops/random_op_name.h"
35 #include "ops/image_op_name.h"
36 #include "ops/nn_op_name.h"
37 #include "ops/nn_ops.h"
38 #include "ops/sequence_ops.h"
39 #include "utils/file_utils.h"
40 #include "utils/ms_context.h"
41 #include "utils/trace_base.h"
42 
43 namespace mindspore {
44 namespace kernel {
45 namespace {
46 constexpr char kAxis[] = "axis";
47 constexpr char kOperatorOriginFormat[] = "operator_origin_format";
48 constexpr char kKernelObjectTypeNotSupportedStr[] = "KernelObjectTypeNotSupported";
49 
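// Returns the shape of a Tensor/MapTensor abstract, or an empty shape for a Scalar abstract; other abstracts raise an exception.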
50 abstract::BaseShapePtr GetValidShapeFromAbstract(const abstract::AbstractBasePtr &abs) {
51   MS_EXCEPTION_IF_NULL(abs);
52   // Other abstract classes, such as AbstractCSRTensor and AbstractCOOTensor, are converted to AbstractTensor at an earlier stage.
53   abstract::BaseShapePtr res_shape;
54   if (abs->isa<abstract::AbstractTensor>() || abs->isa<abstract::AbstractMapTensor>()) {
55     res_shape = abs->BuildShape();
56   } else if (abs->isa<abstract::AbstractScalar>()) {
57     res_shape = std::make_shared<abstract::Shape>(ShapeVector{});
58   } else {
59     MS_INTERNAL_EXCEPTION(TypeError) << "The abstract must be a Scalar or Tensor, but got " << abs->ToString();
60   }
61   return res_shape;
62 }
63 
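// Returns the idx-th element abstract when cur_abstract is a tuple; otherwise idx must be 0 and cur_abstract itself is returned.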
64 abstract::AbstractBasePtr GetChildAbstract(const abstract::AbstractBasePtr &cur_abstract, size_t idx) {
65   MS_EXCEPTION_IF_NULL(cur_abstract);
66   abstract::AbstractBasePtr child_abs = cur_abstract;
67   if (cur_abstract->isa<abstract::AbstractTuple>()) {
68     auto abs_tuple = cur_abstract->Clone()->cast<abstract::AbstractTuplePtr>();
69     MS_EXCEPTION_IF_NULL(abs_tuple);
70     auto abs_element = abs_tuple->elements();
71     MS_EXCEPTION_IF_CHECK_FAIL((idx < abs_element.size()), "Index is out of range, idx:" + std::to_string(idx) +
72                                                              " size:" + std::to_string(abs_element.size()) +
73                                                              " abs:" + abs_tuple->ToString());
74     child_abs = abs_element.at(idx);
75   } else {
76     MS_EXCEPTION_IF_CHECK_FAIL(
77       (idx == 0), "Cannot get " + std::to_string(idx) + " child abstract from " + cur_abstract->ToString());
78   }
79 
80   return child_abs;
81 }
82 
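// Creates a KernelTensor for one input/output from its abstract, using the given device data type and format string.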
83 KernelTensorPtr CreateKernelTensor(const abstract::AbstractBasePtr &cur_abstract, const TypeId &real_type, size_t idx,
84                                    const std::string &format_str, bool prev_node_has_getitem = false) {
85   MS_EXCEPTION_IF_NULL(cur_abstract);
86   abstract::AbstractBasePtr tag_abstract = nullptr;
87   abstract::AbstractBasePtr new_abstract = nullptr;
88   if (prev_node_has_getitem) {
89     tag_abstract = cur_abstract;
90   } else {
91     tag_abstract = GetChildAbstract(cur_abstract, idx);
92   }
93   TypePtr tag_type_ptr = TypeIdToType(real_type);
94 
95   if (tag_abstract->isa<abstract::AbstractTensor>()) {
96     auto abstract_shape_ptr = GetValidShapeFromAbstract(tag_abstract);
97     new_abstract = std::make_shared<abstract::AbstractTensor>(tag_type_ptr, abstract_shape_ptr);
98   } else {
99     new_abstract = tag_abstract->Clone();
100   }
101   KernelTensorPtr res_tensor =
102     std::make_shared<KernelTensor>(new_abstract->GetShape(), new_abstract->GetType(), new_abstract->GetValue());
103   res_tensor->set_format(GetFormatFromStrToEnum(format_str));
104   return res_tensor;
105 }
106 
107 void AdditionalAttrProcess(const ops::PrimitiveCPtr &primc, const CNodePtr &cnode) {
108   MS_EXCEPTION_IF_NULL(primc);
109   MS_EXCEPTION_IF_NULL(cnode);
110   mindspore::HashMap<std::string, ValuePtr> additional_attrs;
111   additional_attrs[kOperatorOriginFormat] = MakeValue(AnfAlgo::GetOriginDataFormat(cnode));
112   (void)primc->SetAttrs(additional_attrs);
113 }
114 
115 bool CheckRealTupleFromCNode(const std::vector<mindspore::kernel::KernelObjectType> &input_obj_types,
116                              const size_t input_idx) {
117   // If input_obj_types is empty, regard the input as a Tensor by default.
118   if (input_obj_types.size() > input_idx && input_obj_types[input_idx] == KernelObjectType::TUPLE) {
119     return true;
120   }
121   return false;
122 }
123 
124 using InOutKernelTensors = std::pair<std::vector<KernelTensorPtr>, std::vector<KernelTensorPtr>>;
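// Builds the input and output KernelTensor lists of a CNode from its abstracts, device types and formats.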
125 inline InOutKernelTensors AbstractInOutFromCNode(const CNodePtr &cnode) {
126   MS_EXCEPTION_IF_NULL(cnode);
127   // Make up the input KernelTensors; the meta types can be tensor, scalar, tuple or list.
128   std::vector<KernelTensorPtr> input_tensors;
129   auto real_input_types = AnfAlgo::GetAllInputDeviceTypes(cnode);
130   size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
131   for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
132     const auto &[prev_node, output_idx] = common::AnfAlgo::GetPrevNodeOutput(cnode, input_idx);
133     bool prev_node_has_getitem = common::AnfAlgo::IsPrevNodeHasTupleGetItem(cnode, input_idx);
134     auto prev_abstract = prev_node->abstract();
135     auto real_input_type = real_input_types[input_idx];
136     if (IsPrimitiveCNode(prev_node, prim::kPrimPyExecute)) {
137       real_input_type = common::AnfAlgo::GetOutputInferDataType(prev_node, 0);
138       MS_LOG(DEBUG) << "Need to change the type of node: " << cnode->DebugString()
139                     << ", real input type: " << TypeIdToType(real_input_type)->ToString();
140     }
141     auto format_str = AnfAlgo::GetInputFormat(cnode, input_idx);
142     auto input_tensor = CreateKernelTensor(prev_abstract, real_input_type, output_idx, format_str,
143                                            ((!prev_node_has_getitem) || common::AnfAlgo::IsDynamicSequence(prev_node)));
144     input_tensors.push_back(input_tensor);
145   }
146 
147   // Make up the output KernelTensors.
148   std::vector<KernelTensorPtr> output_tensors;
149   auto real_output_types = AnfAlgo::GetAllOutputDeviceTypes(cnode);
150   auto cur_abstract = cnode->abstract();
151   MS_EXCEPTION_IF_NULL(cur_abstract);
152   size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
153   auto build_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
154   auto output_obj_types = build_info->GetAllOutputKernelObjectTypes();
155   for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
156     bool is_real_tuple_output = CheckRealTupleFromCNode(output_obj_types, output_idx);
157     auto real_output_type = real_output_types[output_idx];
158     if (IsPrimitiveCNode(cnode, prim::kPrimPyExecute)) {
159       real_output_type = common::AnfAlgo::GetOutputInferDataType(cnode, 0);
160       MS_LOG(DEBUG) << "Need to change the type of node: " << cnode->DebugString()
161                     << ", real output type: " << TypeIdToType(real_output_type)->ToString()
162                     << ", is dynamic len: " << common::AnfAlgo::IsDynamicSequence(cnode);
163     }
164     auto format_str = AnfAlgo::GetOutputFormat(cnode, output_idx);
165     auto output_tensor = CreateKernelTensor(cur_abstract, real_output_type, output_idx, format_str,
166                                             is_real_tuple_output || common::AnfAlgo::IsDynamicSequence(cnode));
167     output_tensors.push_back(output_tensor);
168   }
169   return std::make_pair(input_tensors, output_tensors);
170 }
171 
172 bool IsObjectTypeStrictlyMatched(const std::vector<TypeId> &object_types,
173                                  const std::vector<DataType> &kernel_data_types) {
174   if (object_types.size() != kernel_data_types.size()) {
175     return false;
176   }
177 
178   for (size_t i = 0; i < object_types.size(); i++) {
179     // For an optional input, the real input object type can be None.
180     if ((object_types[i] != kernel_data_types[i].object_type) &&
181         !(object_types[i] == kMetaTypeNone && kernel_data_types[i].is_optional)) {
182       return false;
183     }
184   }
185 
186   return true;
187 }
188 
189 bool IsObjectTypeWeaklyMatched(const std::vector<TypeId> &object_types, const std::vector<DataType> &kernel_data_types,
190                                bool all_same, size_t element_num) {
191   // 1. An equal size can trigger the kernel object backoff (for example, the Reshape op).
192   if (object_types.size() == kernel_data_types.size()) {
193     return true;
194   }
195 
196   // 2. AllSame indicates the TupleUnfold type (for example, the Split/AddN ops).
197   if (all_same) {
198     return true;
199   }
200 
201   // 3. Multiple outputs are expanded in the kernel attr (for example, the BatchNorm op).
202   if (kernel_data_types.size() == element_num) {
203     return true;
204   }
205 
206   return false;
207 }
208 }  // namespace
209 
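// Splits a KernelAttr into its input DataType list and output DataType list.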
210 std::pair<std::vector<DataType>, std::vector<DataType>> GetInOutDataTypesFromKernelAttr(const KernelAttr &kernel_attr) {
211   size_t input_attr_size = kernel_attr.GetInputSize();
212   std::vector<DataType> input_data_types;
213   for (size_t i = 0; i < input_attr_size; ++i) {
214     input_data_types.push_back(kernel_attr.GetInputAttr(i));
215   }
216 
217   size_t output_attr_size = kernel_attr.GetOutputSize();
218   std::vector<DataType> output_data_types;
219   for (size_t i = 0; i < output_attr_size; ++i) {
220     output_data_types.push_back(kernel_attr.GetOutputAttr(i));
221   }
222 
223   return std::make_pair(input_data_types, output_data_types);
224 }
225 std::string GetCompilerCachePath() { return Common::GetUserDefineCachePath(); }
226 
227 bool CheckCache(const std::string &kernel_name) {
228   // check cache.
229   KernelMeta *bin_map = KernelMeta::GetInstance();
230   if (bin_map == nullptr) {
231     MS_LOG(DEBUG) << "Kernel cache is invalid, kernel_name: " << kernel_name;
232     return false;
233   }
234   std::string kernel_json = bin_map->Search(kernel_name);
235   bool ret = (!kernel_json.empty());
236   if (ret) {
237     MS_LOG(INFO) << "Kernel name: " << kernel_name << " has been registered.";
238   } else {
239     MS_LOG(INFO) << "Kernel name: " << kernel_name << " will be registered.";
240   }
241   return ret;
242 }
243 
244 KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
245   // search cache.
246   KernelMeta *bin_map = KernelMeta::GetInstance();
247   if (bin_map == nullptr) {
248     MS_LOG(DEBUG) << "Kernel cache is invalid, kernel_name: " << kernel_name;
249     return nullptr;
250   }
251 
252   std::string kernel_json = bin_map->Search(kernel_name);
253   if (!kernel_json.empty()) {
254     KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
255     // Just a temporary solution.
256     if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
257       MS_LOG(ERROR) << "Read cache json and bin file failed[" << kernel_json << "].";
258       return nullptr;
259     } else {
260       return kernel_pack;
261     }
262   } else {
263     MS_LOG(INFO) << "The cached kernel was not found [" << kernel_name << "].";
264     return nullptr;
265   }
266 }
267 
268 KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
269   MS_LOG(INFO) << "Insert cache for kernel: " << kernel_name << ", processor: " << processor;
270   KernelMeta *bin_map = KernelMeta::GetInstance();
271   if (bin_map == nullptr) {
272     MS_LOG(DEBUG) << "Kernel cache is invalid, kernel name :" << kernel_name;
273     return nullptr;
274   }
275   std::string kernel_json = bin_map->kernel_meta_path();
276   (void)kernel_json.append(kernel_name).append(kJsonSuffix);
277   KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
278   if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
279     MS_LOG(ERROR) << "Read json and bin file failed[" << kernel_json << "].";
280     return nullptr;
281   }
282   if (bin_map->Insert(kernel_name, kernel_json)) {
283     MS_LOG(INFO) << "Kernel cache inserted successfully [" << kernel_json << "], kernel name [" << kernel_name << "].";
284   }
285   return kernel_pack;
286 }
287 
288 void KernelMeta::Initialize(const std::string &backend) {
289   auto config_path = GetCompilerCachePath();
290   kernel_meta_path_ = config_path + backend + std::string(kKernelMetaSuffix);
291   (void)(FileUtils::CreateNotExistDirs(kernel_meta_path_, true));
292   initialized_ = true;
293 }
294 
295 std::string KernelMeta::Search(const std::string &kernel_name) const {
296   if (!initialized_) {
297     return "";
298   }
299 
300   auto iter = kernel_meta_map_.find(kernel_name);
301   if (iter == kernel_meta_map_.end()) {
302     return "";
303   } else {
304     return iter->second;
305   }
306 }
307 
308 bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
309   if (!initialized_) {
310     return false;
311   }
312   kernel_meta_map_[kernel_name] = kernel_json;
313   return true;
314 }
315 
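// Fills the builder's input device types, formats and kernel object types from the OpIOInfo list,
// handling dynamic, required and optional inputs.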
316 bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
317                                size_t builder_idex, const std::vector<int64_t> &dyn_input_sizes,
318                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
319   MS_EXCEPTION_IF_NULL(builder);
320 
321   std::vector<TypeId> inputs_device_type;
322   std::vector<std::string> inputs_format;
323   std::vector<KernelObjectType> inputs_object_type;
324   size_t dyn_input_idx = 0;
325   size_t kernel_info_index = 0;
326   MS_EXCEPTION_IF_NULL(inputs[0]);
327   size_t kernel_info_cnt = inputs[0]->dtypes().size();
328 
329   for (const auto &input : inputs) {
330     MS_EXCEPTION_IF_NULL(input);
331     std::string param_type = input->param_type();
332     std::vector<std::string> dtypes = input->dtypes();
333     std::vector<std::string> formats = input->formats();
334     std::vector<std::string> object_types = input->object_types();
335     if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt ||
336         object_types.size() != kernel_info_cnt) {
337       MS_LOG(DEBUG) << "Set input kernel builder info failed, dtypes size, formats size and object_types size are "
338                        "not the same. dtypes size: "
339                     << dtypes.size() << ", formats size: " << formats.size()
340                     << ", object_types size: " << object_types.size();
341       return false;
342     }
343 
344     if (param_type == "dynamic") {
345       if (dyn_input_sizes.empty()) {
346         MS_LOG(DEBUG) << "Set input kernel builder info failed, dyn_input_sizes is empty when param_type is dynamic";
347         return false;
348       }
349 
350       for (int64_t t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
351         kernel_info_index++;
352         auto type_id = DtypeToTypeId(dtypes[builder_idex]);
353         inputs_device_type.push_back(type_id);
354         inputs_format.push_back(formats[builder_idex]);
355         inputs_object_type.push_back(StringToKernelObjectType(object_types[builder_idex]));
356       }
357     } else if (param_type == "required") {
358       kernel_info_index++;
359       auto type_id = DtypeToTypeId(dtypes[builder_idex]);
360       inputs_device_type.push_back(type_id);
361       inputs_format.push_back(formats[builder_idex]);
362       inputs_object_type.push_back(StringToKernelObjectType(object_types[builder_idex]));
363     } else {
364       if (kernel_info_index < real_input_num) {
365         MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
366         kernel_info_index++;
367         auto type_id = DtypeToTypeId(dtypes[builder_idex]);
368         inputs_device_type.push_back(type_id);
369         inputs_format.push_back(formats[builder_idex]);
370         inputs_object_type.push_back(StringToKernelObjectType(object_types[builder_idex]));
371       }
372     }
373     dyn_input_idx++;
374   }
375 
376   builder->SetInputsDeviceType(inputs_device_type);
377   builder->SetInputsFormat(inputs_format);
378   builder->SetInputsKernelObjectType(inputs_object_type);
379 
380   return true;
381 }
382 
383 bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
384                                 const size_t &real_output_num,
385                                 const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
386   // Dynamic outputs are not fully supported yet; support needs to be added in the future.
387   MS_EXCEPTION_IF_NULL(builder);
388 
389   size_t output_idx = 0;
390   std::vector<TypeId> outputs_device_type;
391   std::vector<std::string> outputs_format;
392   std::vector<KernelObjectType> outputs_object_type;
393   MS_EXCEPTION_IF_NULL(outputs[0]);
394   size_t kernel_info_cnt = outputs[0]->dtypes().size();
395 
396   for (const auto &output : outputs) {
397     MS_EXCEPTION_IF_NULL(output);
398     if (output_idx >= real_output_num) {
399       MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
400       continue;
401     }
402     size_t output_num = 0;
403     if (output->param_type() == "dynamic") {
404       if (outputs.size() > 1) {
405         MS_EXCEPTION(ArgumentError) << "Dynamic output does not support multiple outputs!";
406       }
407       output_num = real_output_num;
408     } else if (output->param_type() == "required") {
409       output_num = 1;
410     } else {
411       if (output_idx < real_output_num) {
412         MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
413         output_num = 1;
414       }
415     }
416 
417     for (size_t i = 0; i < output_num; i++) {
418       std::vector<std::string> dtypes = output->dtypes();
419       std::vector<std::string> formats = output->formats();
420       std::vector<std::string> object_types = output->object_types();
421       if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt ||
422           object_types.size() != kernel_info_cnt) {
423         MS_LOG(DEBUG)
424           << "Set output kernel builder info failed, dtypes size, formats size and object_types size are not "
425              "the same. dtypes size: "
426           << dtypes.size() << ", formats size: " << formats.size() << ", object_types size: " << object_types.size();
427         return false;
428       }
429       auto type_id = DtypeToTypeId(dtypes[builder_idex]);
430       outputs_device_type.push_back(type_id);
431       outputs_format.push_back(formats[builder_idex]);
432       outputs_object_type.push_back(StringToKernelObjectType(object_types[builder_idex]));
433       output_idx++;
434     }
435   }
436 
437   builder->SetOutputsFormat(outputs_format);
438   builder->SetOutputsDeviceType(outputs_device_type);
439   builder->SetOutputsKernelObjectType(outputs_object_type);
440   return true;
441 }
442 
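// Creates the kernel info and build info of kernel_node if missing, then sets its input/output formats and device types.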
443 void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std::vector<TypeId> &input_types,
444                         const std::vector<std::string> &output_formats, const std::vector<TypeId> &output_types,
445                         const CNodePtr &kernel_node) {
446   MS_EXCEPTION_IF_NULL(kernel_node);
447   if (kernel_node->kernel_info() == nullptr) {
448     kernel_node->set_kernel_info(std::make_shared<device::KernelInfo>());
449   }
450   if (!kernel_node->kernel_info()->has_build_info()) {
451     AnfAlgo::SetSelectKernelBuildInfo(std::make_shared<kernel::KernelBuildInfo>(), kernel_node.get());
452   }
453   auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
454   build_info->SetInputsFormat(input_formats);
455   build_info->SetInputsDeviceType(input_types);
456   build_info->SetOutputsFormat(output_formats);
457   build_info->SetOutputsDeviceType(output_types);
458 }
459 
460 void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
461                         const std::shared_ptr<const OpInfo> &op_info_ptr) {
462   MS_EXCEPTION_IF_NULL(builder);
463   MS_EXCEPTION_IF_NULL(op_info_ptr);
464   builder->SetProcessor(processor);
465   auto imply_type = op_info_ptr->imply_type();
466   switch (imply_type) {
467     case kImplyAKG:
468       builder->SetKernelType(AKG_KERNEL);
469       break;
470     case kImplyTBE:
471       builder->SetKernelType(TBE_KERNEL);
472       break;
473     case kImplyGPU:
474       builder->SetKernelType(GPU_KERNEL);
475       break;
476     case kImplyCPU:
477       builder->SetKernelType(CPU_KERNEL);
478       break;
479     case kImplyAICPU:
480       builder->SetKernelType(AICPU_KERNEL);
481       break;
482     case kImplyBISHENG:
483       builder->SetKernelType(BISHENG_KERNEL);
484       break;
485     default:
486       MS_LOG(EXCEPTION) << "Unknown Imply Type.";
487       break;
488   }
489 }
490 
491 bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
492                    std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
493   MS_EXCEPTION_IF_NULL(kernel_node);
494   MS_EXCEPTION_IF_NULL(op_info_ptr);
495   MS_EXCEPTION_IF_NULL(kernel_info_list);
496   size_t real_input_num = AnfAlgo::GetInputElementNum(kernel_node);
497   size_t real_output_num = AnfAlgo::GetOutputElementNum(kernel_node);
498   std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
499   std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
500   std::vector<int64_t> dyn_input_sizes;
501   auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel_node);
502   MS_EXCEPTION_IF_NULL(primitive);
503   auto op_name = common::AnfAlgo::GetCNodeName(kernel_node);
504   if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
505     dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr("dyn_input_sizes"));
506   }
507   if (dyn_input_sizes.empty() && inputs.size() < real_input_num) {
508     MS_LOG(WARNING) << "The size of inputs in OpIOInfo should not be less than the real input num. Inputs size in OpIOInfo: "
509                     << inputs.size() << ", real input num: " << real_input_num
510                     << ", node: " << kernel_node->fullname_with_scope();
511     return false;
512   }
513   if (inputs.size() > 0) {
514     if (inputs[0] == nullptr) {
515       MS_LOG(INTERNAL_EXCEPTION) << "Inputs[0] is nullptr. Op name: " << op_name;
516     }
517     size_t kernel_info_cnt = inputs[0]->dtypes().size();
518     for (size_t j = 0; j < kernel_info_cnt; j++) {
519       auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
520       MS_EXCEPTION_IF_NULL(builder);
521       SetKernelBuildInfo(builder, processor, op_info_ptr);
522 
523       if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
524         MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed. Op name: " << op_name;
525         return false;
526       }
527 
528       if (outputs.size() > 0) {
529         if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
530           MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
531           return false;
532         }
533       }
534 
535       kernel_info_list->push_back(builder->Build());
536     }
537   } else if (outputs.size() > 0) {
538     if (outputs[0] == nullptr) {
539       MS_LOG(INTERNAL_EXCEPTION) << "Outputs[0] is nullptr. Op name: " << op_name;
540     }
541     size_t kernel_info_cnt = outputs[0]->dtypes().size();
542     for (size_t j = 0; j < kernel_info_cnt; j++) {
543       auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
544       MS_EXCEPTION_IF_NULL(builder);
545       SetKernelBuildInfo(builder, processor, op_info_ptr);
546 
547       if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
548         MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
549         return false;
550       }
551 
552       kernel_info_list->push_back(builder->Build());
553     }
554   } else {
555     if (processor == AICPU) {
556       auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
557       MS_EXCEPTION_IF_NULL(builder);
558       SetKernelBuildInfo(builder, processor, op_info_ptr);
559       kernel_info_list->push_back(builder->Build());
560     }
561   }
562   return true;
563 }
564 
565 void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
566   std::string path = base_path + json_name + kInfoSuffix;
567   auto realpath = Common::CreatePrefixPath(path, true);
568   if (!realpath.has_value()) {
569     MS_LOG(ERROR) << "Get real path failed, path=" << path;
570     return;
571   }
572   ChangeFileMode(realpath.value(), S_IWUSR);
573   std::ofstream filewrite(realpath.value());
574   if (!filewrite.is_open()) {
575     MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!";
576     return;
577   }
578   filewrite << info << std::endl;
579   filewrite.close();
580   ChangeFileMode(realpath.value(), S_IRUSR);
581 }
582 
583 Processor GetProcessor(const string &processor) {
584   if (processor == kProcessorAiCore) {
585     return Processor::AICORE;
586   }
587   if (processor == kProcessorAiCpu) {
588     return Processor::AICPU;
589   }
590   if (processor == kProcessorCuda) {
591     return Processor::CUDA;
592   }
593   MS_LOG(DEBUG) << "Unknown processor type.";
594   return Processor::UNKNOWN;
595 }
596 
597 std::string GetProcessor(const AnfNodePtr &anf_node) {
598   MS_EXCEPTION_IF_NULL(anf_node);
599   std::string device;
600   switch (AnfAlgo::GetProcessor(anf_node)) {
601     case Processor::AICORE:
602       device = kProcessorAiCore;
603       break;
604 
605     case Processor::AICPU:
606       device = kProcessorAiCpu;
607       break;
608 
609     case Processor::CUDA:
610       device = kProcessorCuda;
611       break;
612 
613     default:
614       MS_LOG(DEBUG) << "Unknown processor type.";
615       break;
616   }
617   return device;
618 }
619 
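// For every graph output, finds the kernel node (or graph input) it comes from together with its output index.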
620 std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
621                                                           const std::vector<AnfNodePtr> &input_list,
622                                                           const std::vector<AnfNodePtr> &output_list) {
623   std::vector<std::pair<AnfNodePtr, size_t>> output_index;
624   for (size_t i = 0; i < output_list.size(); ++i) {
625     auto const &output = output_list[i];
626     MS_EXCEPTION_IF_NULL(output);
627     bool found = false;
628     auto prev_node = common::AnfAlgo::VisitKernel(output, 0);
629     auto pos = std::find(std::begin(node_list), std::end(node_list), prev_node.first);
630     if (pos != std::end(node_list)) {
631       output_index.push_back(prev_node);
632       continue;
633     }
634     auto ret = std::find(std::begin(input_list), std::end(input_list), prev_node.first);
635     if (ret != std::end(input_list)) {
636       output_index.push_back(std::make_pair(prev_node.first, 0));
637       found = true;
638     }
639     if (!found) {
640       MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
641                                   << output->func_graph()->ToString() << "] found no related kernel info.";
642     }
643   }
644   return output_index;
645 }
646 
647 void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
648   MS_EXCEPTION_IF_NULL(node_list);
649   MS_EXCEPTION_IF_NULL(func_graph);
650   std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
651   for (auto const &node : node_lists) {
652     if (!AnfUtils::IsRealKernel(node) || !node->isa<CNode>()) {
653       continue;
654     }
655     auto cnode = node->cast<CNodePtr>();
656     MS_EXCEPTION_IF_NULL(cnode);
657     if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
658       node_list->push_back(node);
659     }
660   }
661 }
662 
663 void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
664                          std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
665   MS_EXCEPTION_IF_NULL(func_graph);
666   MS_EXCEPTION_IF_NULL(node_list);
667   MS_EXCEPTION_IF_NULL(input_list);
668 
669   GetValidKernelNodes(func_graph, node_list);
670 
671   auto parameters = func_graph->parameters();
672   (void)input_list->insert(input_list->cbegin(), parameters.begin(), parameters.end());
673 
674   GetFuncGraphOutputNodes(func_graph, output_list);
675 }
676 
677 void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
678   MS_EXCEPTION_IF_NULL(func_graph);
679   MS_EXCEPTION_IF_NULL(output_list);
680   auto func_output = func_graph->output();
681   MS_EXCEPTION_IF_NULL(func_output);
682   if (func_output->isa<CNode>()) {
683     // multi output.
684     auto cnode = func_output->cast<CNodePtr>();
685     MS_EXCEPTION_IF_NULL(cnode);
686     auto input0 = cnode->input(kAnfPrimitiveIndex);
687     MS_EXCEPTION_IF_NULL(input0);
688     if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
689       for (size_t input_idx = 1; input_idx < cnode->size(); ++input_idx) {
690         auto input_node = cnode->input(input_idx);
691         MS_EXCEPTION_IF_NULL(input_node);
692         if (input_node->isa<CNode>() && common::AnfAlgo::GetInputTensorNum(input_node) == 0) {
693           continue;
694         }
695         output_list->push_back(common::AnfAlgo::VisitKernel(input_node, 0).first);
696       }
697     } else {
698       // single output.
699       output_list->push_back(common::AnfAlgo::VisitKernel(func_output, 0).first);
700     }
701   } else {
702     // single output.
703     output_list->push_back(common::AnfAlgo::VisitKernel(func_output, 0).first);
704   }
705 }
706 
707 bool IsWeightBoundary(const AnfNodePtr &node) {
708   if (node->isa<ValueNode>()) {
709     return true;
710   }
711   if (node->isa<Parameter>() && common::AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
712     return true;
713   }
714   return false;
715 }
716 
717 std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
718   if (common::AnfAlgo::GetInputTensorNum(cnode) != 1 || AnfAlgo::GetOutputElementNum(cnode) != 1) {
719     MS_LOG(INTERNAL_EXCEPTION) << "The reduce node [" << cnode->DebugString()
720                                << "] does not have a single input and a single output." << trace::DumpSourceLines(cnode);
721   }
722   auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
723   MS_EXCEPTION_IF_NULL(primitive);
724   auto axis_attr = primitive->GetAttr(kAxis);
725   if (axis_attr == nullptr) {
726     MS_LOG(ERROR) << "This node doesn't have axis attr. Node info [" << cnode->DebugString() << "]";
727     return {};
728   }
729   std::vector<int64_t> axis_list;
730   if (axis_attr->isa<Int64Imm>()) {
731     (void)axis_list.emplace_back(GetValue<int64_t>(axis_attr));
732   } else {
733     axis_list = GetValue<std::vector<int64_t>>(axis_attr);
734   }
735   return axis_list;
736 }
737 
738 Processor GetProcessorFromContext() {
739   kernel::Processor processor = kernel::Processor::UNKNOWN;
740   auto context_ptr = MsContext::GetInstance();
741   MS_EXCEPTION_IF_NULL(context_ptr);
742   auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
743   if (device_info == kGPUDevice) {
744     processor = kernel::Processor::CUDA;
745   } else if (device_info == kAscendDevice) {
746     processor = kernel::Processor::AICORE;
747   } else if (device_info == kCPUDevice) {
748     processor = kernel::Processor::CPU;
749   }
750   return processor;
751 }
752 
753 std::string GetStrProcessorFromContext() {
754   auto processor = GetProcessorFromContext();
755   string str_processor = kernel::kProcessorUnknown;
756   if (processor == kernel::Processor::CUDA) {
757     str_processor = kernel::kProcessorCuda;
758   } else if (processor == kernel::Processor::AICORE) {
759     str_processor = kernel::kProcessorAiCore;
760   } else if (processor == kernel::Processor::CPU) {
761     str_processor = kernel::kProcessorCpu;
762   }
763   return str_processor;
764 }
765 
766 bool GetShapeSize(const ShapeVector &shape, const TypePtr &type_ptr, int64_t *size_i) {
767   MS_EXCEPTION_IF_NULL(type_ptr);
768   size_t type_byte = GetTypeByte(type_ptr);
769   if (type_byte == 0) {
770     return false;
771   }
772   for (size_t j = 0; j < shape.size(); j++) {
773     if (shape[j] <= 0) {
774       MS_LOG(DEBUG) << "shape[" << shape << "] has an invalid value (less than or equal to 0), set size to 0";
775       size_i[0] = 0;
776       return true;
777     }
778     size_i[0] = LongMulWithOverflowCheck(size_i[0], shape[j]);
779   }
780   size_i[0] = LongMulWithOverflowCheck(size_i[0], SizeToInt(type_byte));
781   return true;
782 }
783 
784 bool IsDynamicParamKernel(const std::string &op_name) {
785   const auto &op_info = kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kImplyCPU);
786   constexpr auto kParamDynamic = "dynamic";
787 
788   if (op_info == nullptr) {
789     return false;
790   }
791 
792   const auto &input_io_info = op_info->inputs_ptr();
793   if (input_io_info.size() != 1 || input_io_info[0]->param_type() != kParamDynamic) {
794     return false;
795   }
796 
797   const auto &output_io_info = op_info->outputs_ptr();
798   if (output_io_info.size() != 1 || output_io_info[0]->param_type() != kParamDynamic) {
799     return false;
800   }
801 
802   return true;
803 }
804 
805 bool SelectKernelByObjectType(const CNodePtr &kernel_node, const std::vector<KernelAttr> &registered_kernel_attrs,
806                               std::vector<KernelAttr> *selected_kernel_attrs) {
807   MS_EXCEPTION_IF_NULL(kernel_node);
808   MS_EXCEPTION_IF_NULL(selected_kernel_attrs);
809   const auto &inputs_object_types = AnfAlgo::GetAllInputObjectType(kernel_node);
810   const auto &output_object_types = AnfAlgo::GetAllOutputObjectType(kernel_node);
811 
812   // 1. Try to strictly match all object types first.
813   for (auto &cur_kernel_attr : registered_kernel_attrs) {
814     const auto &[input_data_types, output_data_types] = GetInOutDataTypesFromKernelAttr(cur_kernel_attr);
815     if (IsObjectTypeStrictlyMatched(inputs_object_types, input_data_types) &&
816         IsObjectTypeStrictlyMatched(output_object_types, output_data_types)) {
817       (void)selected_kernel_attrs->emplace_back(cur_kernel_attr);
818     }
819   }
820   if (!selected_kernel_attrs->empty()) {
821     return true;
822   }
823 
824   // 2. If precise matching fails, try fuzzy matching.
825   auto input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
826   auto output_num = AnfAlgo::GetOutputElementNum(kernel_node);
827   for (auto &cur_kernel_attr : registered_kernel_attrs) {
828     const auto &[input_data_types, output_data_types] = GetInOutDataTypesFromKernelAttr(cur_kernel_attr);
829     auto all_same = cur_kernel_attr.GetAllSame();
830     if (IsObjectTypeWeaklyMatched(inputs_object_types, input_data_types, all_same, input_num) &&
831         IsObjectTypeWeaklyMatched(output_object_types, output_data_types, all_same, output_num)) {
832       (void)selected_kernel_attrs->emplace_back(cur_kernel_attr);
833     }
834   }
835 
836   return (!selected_kernel_attrs->empty());
837 }
838 
839 std::pair<std::string, ExceptionType> KernelObjectTypeNotSupportWarning(const CNodePtr &kernel_node) {
840   MS_EXCEPTION_IF_NULL(kernel_node);
841   auto GetObjectTypeStr = [](const std::vector<TypeId> &object_types) {
842     std::vector<std::string> object_type_strs;
843     (void)std::transform(object_types.begin(), object_types.end(), std::back_inserter(object_type_strs), TypeIdLabel);
844     return std::accumulate(object_type_strs.begin(), object_type_strs.end(), std::string(),
845                            [](const std::string &x, const std::string &y) { return x.empty() ? y : x + ", " + y; });
846   };
847   const std::string warn_str = std::string(kKernelObjectTypeNotSupportedStr) + ": unsupported kernel object type for " +
848                                kernel_node->fullname_with_scope() + " with inputs (" +
849                                GetObjectTypeStr(AnfAlgo::GetAllInputObjectType(kernel_node)) + "), outputs (" +
850                                GetObjectTypeStr(AnfAlgo::GetAllOutputObjectType(kernel_node)) + ").";
851   return {warn_str, TypeError};
852 }
853 
854 bool IsKernelObjectTypeNotSupportedError(const std::string &error_str) {
855   return error_str.find(kKernelObjectTypeNotSupportedStr) != std::string::npos;
856 }
857 
858 KernelObjectType StringToKernelObjectType(const std::string &object_type) {
859   static const std::unordered_map<std::string, KernelObjectType> object_type_maps = {
860     {"unknown", KernelObjectType::UNKNOWN_TYPE},
861     {"tensor", KernelObjectType::TENSOR},
862     {"scalar", KernelObjectType::SCALAR},
863     {"tuple", KernelObjectType::TUPLE},
864     {"tuple_unfold", KernelObjectType::TUPLE_UNFOLD},
865   };
866   auto iter = object_type_maps.find(object_type);
867   if (iter == object_type_maps.end()) {
868     MS_LOG(EXCEPTION) << "Illegal input object type: " << object_type;
869   }
870   return iter->second;
871 }
872 
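// Expands TUPLE_UNFOLD inputs and outputs of the selected kernel build info into per-element device types and formats.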
873 void UnfoldKernelBuildInfo(const CNodePtr &kernel_node) {
874   MS_EXCEPTION_IF_NULL(kernel_node);
875   auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
876   auto input_num = kernel_build_info->GetInputNum();
877   auto output_num = kernel_build_info->GetOutputNum();
878   if (input_num == 0 && output_num == 0) {
879     return;
880   }
881   const auto &input_kernel_object_types = kernel_build_info->GetAllInputKernelObjectTypes();
882   const auto &output_kernel_object_types = kernel_build_info->GetAllOutputKernelObjectTypes();
883   const auto &input_dtypes = kernel_build_info->GetAllInputDeviceTypes();
884   const auto &output_dtypes = kernel_build_info->GetAllOutputDeviceTypes();
885   const auto &input_formats = kernel_build_info->GetAllInputFormats();
886   const auto &output_formats = kernel_build_info->GetAllOutputFormats();
887 
888   std::vector<TypeId> unfold_input_dtypes;
889   std::vector<TypeId> unfold_output_dtypes;
890   std::vector<std::string> unfold_input_formats;
891   std::vector<std::string> unfold_output_formats;
892   auto Append = [&](bool in_or_out, size_t index) {
893     if (in_or_out) {
894       MS_EXCEPTION_IF_CHECK_FAIL((input_num > index), "Input index is out of range.");
895       unfold_input_dtypes.push_back(input_dtypes[index]);
896       unfold_input_formats.push_back(input_formats[index]);
897     } else {
898       MS_EXCEPTION_IF_CHECK_FAIL((output_num > index), "Output index is out of range.");
899       unfold_output_dtypes.push_back(output_dtypes[index]);
900       unfold_output_formats.push_back(output_formats[index]);
901     }
902   };
903   auto RepeatAppend = [&](bool in_or_out, size_t index, size_t times) {
904     while (times > 0) {
905       Append(in_or_out, index);
906       times--;
907     }
908   };
909 
910   for (size_t i = 0; i < input_kernel_object_types.size(); ++i) {
911     if (input_kernel_object_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
912       auto input_node = common::AnfAlgo::GetInputNode(kernel_node, i);
913       auto unfold_num = GetOutputNum(input_node);
914       MS_LOG(DEBUG) << kernel_node->fullname_with_scope() << " input index:" << i << " unfold num:" << unfold_num;
915       RepeatAppend(true, i, unfold_num);
916     } else {
917       Append(true, i);
918     }
919   }
920 
921   for (size_t i = 0; i < output_kernel_object_types.size(); ++i) {
922     if (output_kernel_object_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
923       auto unfold_num = GetOutputNum(kernel_node);
924       MS_LOG(DEBUG) << kernel_node->fullname_with_scope() << " output index:" << i << " unfold num:" << unfold_num;
925       // Multiple outputs are expanded in the kernel attr (for example, the BatchNorm op).
926       if (output_num == unfold_num) {
927         for (size_t j = 0; j < unfold_num; ++j) {
928           Append(false, j);
929         }
930       } else {
931         RepeatAppend(false, i, unfold_num);
932       }
933     } else {
934       Append(false, i);
935     }
936   }
937 
938   SetKernelBuildInfo(unfold_input_formats, unfold_input_dtypes, unfold_output_formats, unfold_output_dtypes,
939                      kernel_node);
940 }
941 
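// Returns the unfolded output tuple size of a node, or -1 when the output is not a TupleUnfold output.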
942 int64_t CalOutputTupleSize(const AnfNodePtr &node) {
943   MS_EXCEPTION_IF_NULL(node);
944   bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimBpropCut);
945   bool skip = (is_bprop_cut && node->abstract()->isa<abstract::AbstractSparseTensor>());
946   if (skip || !common::AnfAlgo::IsTupleOutput(node)) {
947     return -1;
948   }
949   const auto &real_node = common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
950   auto build_info = AnfAlgo::GetSelectKernelBuildInfo(real_node);
951   if (build_info != nullptr) {
952     auto output_object = AnfAlgo::GetOutputKernelObjectType(real_node, 0);
953     if (output_object != kernel::KernelObjectType::TUPLE_UNFOLD) {
954       return -1;
955     }
956   }
957   auto output_size = static_cast<int64_t>(AnfAlgo::GetOutputElementNum(node));
958   if (node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
959     output_size = 0;
960     auto make_tuple = node->cast<CNodePtr>();
961     size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
962     for (size_t j = 0; j < tuple_input_num; ++j) {
963       // Used for graph kernel.
964       auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
965       // Handle tuple nested scenes.
966       if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
967         output_size += CalOutputTupleSize(dyn_input_node);
968       } else {
969         output_size++;
970       }
971     }
972   }
973   return output_size == 0 ? -1 : output_size;
974 }
975 
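// Sets the kAttrDynInputSizes attribute on the node, recording the unfolded tuple size of each TUPLE_UNFOLD input (-1 for others).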
976 void SetDynamicInputSizeAttr(const CNodePtr &cnode) {
977   MS_EXCEPTION_IF_NULL(cnode);
978   if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
979       common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial)) {
980     return;
981   }
982   std::vector<int64_t> dyn_input_sizes;
983   auto input_obj_types = AnfAlgo::GetInputKernelObjectTypes(cnode);
984   size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
985   for (size_t i = 0; i < input_num; ++i) {
986     if (i < input_obj_types.size() && input_obj_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
987       auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
988       dyn_input_sizes.push_back(CalOutputTupleSize(input_node));
989     } else {
990       dyn_input_sizes.push_back(-1);
991     }
992   }
993   if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
994     common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
995   }
996 }
997 
998 KernelArgs AbstractArgsFromCNode(const CNodePtr &cnode) {
999   MS_EXCEPTION_IF_NULL(cnode);
1000   auto [input_tensors, output_tensors] = AbstractInOutFromCNode(cnode);
1001   KernelArgs args = {input_tensors, output_tensors};
1002   return args;
1003 }
1004 
1005 BaseOperatorPtr CreateOperatorByCNode(const CNodePtr &cnode) {
1006   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1007   if (prim == nullptr) {
1008     return nullptr;
1009   }
1010   auto kernel_name = prim->name();
1011   MS_LOG(DEBUG) << "Create operator " << kernel_name;
1012   auto ori_kernel_name = kernel_name;
1013   if (prim->HasAttr(kAttrMeOpName)) {
1014     ori_kernel_name = GetValue<std::string>(prim->GetAttr(kAttrMeOpName));
1015   }
1016   AdditionalAttrProcess(prim, cnode);
1017 
1018   static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
1019   auto it = operator_fns.find(ori_kernel_name);
1020   if (it == operator_fns.end()) {
1021     MS_LOG(DEBUG) << "Cannot create BaseOperator for " << ori_kernel_name;
1022     return nullptr;
1023   }
1024   auto base_operator = it->second(prim);
1025   return base_operator;
1026 }
1027 
1028 std::shared_ptr<KernelArgs> GetArgsFromCNode(const CNodePtr &cnode) {
1029   MS_EXCEPTION_IF_NULL(cnode);
1030   auto args = cnode->user_data<KernelArgs>();
1031   return args;
1032 }
1033 
1034 tensor::TensorPtr GetDependValueByConstTensor(const AnfNodePtr &input_node, const std::string &cnode_name, size_t i) {
1035   MS_EXCEPTION_IF_NULL(input_node);
1036   auto value_node = input_node->cast<ValueNodePtr>();
1037   MS_EXCEPTION_IF_NULL(value_node);
1038   auto value = value_node->value();
1039   MS_EXCEPTION_IF_NULL(value);
1040   if (value->isa<tensor::Tensor>()) {
1041     return value->cast<tensor::TensorPtr>();
1042   } else if (value->isa<Scalar>()) {
1043     return ScalarToTensor(value->cast<ScalarPtr>());
1044   }
1045   MS_EXCEPTION(ValueError) << "The CNode " << cnode_name << "'s input[" << i << "], must be tensor or scalar, but got "
1046                            << value->ToString();
1047 }
1048 
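// Collects the constant tensors of the node's value-depend inputs into inputs_tensor_map, keyed by input index.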
1049 void SetInputsByConstInputs(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *inputs_tensor_map) {
1050   std::set<int64_t> depend_list = abstract::GetValueDependArgIndices(node);
1051   auto input_size = common::AnfAlgo::GetInputTensorNum(node);
1052   auto cnode_name = node->fullname_with_scope();
1053   for (size_t i = 0; i < input_size; i++) {
1054     if (depend_list.find(i) != depend_list.end()) {
1055       auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i, false);
1056       auto real_input = input_node_with_index.first;
1057       if (!real_input->isa<ValueNode>()) {
1058         continue;
1059       }
1060       auto out_tensor = GetDependValueByConstTensor(real_input, cnode_name, i);
1061       MS_EXCEPTION_IF_NULL(inputs_tensor_map);
1062       auto ret2 = inputs_tensor_map->try_emplace(i, out_tensor);
1063       if (!ret2.second) {
1064         MS_LOG(INTERNAL_EXCEPTION) << "Insert map failed.";
1065       }
1066     }
1067   }
1068 }
1069 
1070 void SetInputsByDependMap(const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
1071                           std::vector<KernelTensorPtr> *inputs, bool is_stored_in_device) {
1072   MS_EXCEPTION_IF_NULL(inputs);
1073   for (const auto &[i, tensor] : depend_tensor_map) {
1074     if (i >= inputs->size()) {
1075       MS_LOG(EXCEPTION) << "Invalid index to store data into KernelTensor, expect it to be less than " << inputs->size() << " but got "
1076                         << i;
1077     }
1078     MS_EXCEPTION_IF_NULL(inputs->at(i));
1079     MS_EXCEPTION_IF_NULL(tensor);
1080     auto address = std::make_shared<kernel::Address>(tensor->data_c(), tensor->Size());
1081     if (is_stored_in_device) {
1082       // For CPU, store the data address as the device address.
1083       inputs->at(i)->SetData(address);
1084       continue;
1085     }
1086     inputs->at(i)->SetHostData(address);
1087   }
1088 }
1089 
1090 void SetArgsToCNode(const CNodePtr &cnode, const KernelArgs &args) {
1091   MS_EXCEPTION_IF_NULL(cnode);
1092   auto dst = cnode->user_data<KernelArgs>();
1093   if (dst == nullptr) {
1094     dst = std::make_shared<KernelArgs>();
1095     cnode->set_user_data<KernelArgs>(dst);
1096   }
1097   dst->inputs = args.inputs;
1098   dst->outputs = args.outputs;
1099   dst->depend_tensor_map = args.depend_tensor_map;
1100 }
1101 
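// After a kernel updates its output shape and size, writes the new output shapes and types back to the node's abstract.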
1102 void UpdateNodeShape(const CNodePtr &cnode) {
1103   MS_EXCEPTION_IF_NULL(cnode);
1104   auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
1105   MS_EXCEPTION_IF_NULL(kernel_mod);
1106   if (!kernel_mod->IsNeedUpdateOutputShapeAndSize()) {
1107     return;
1108   }
1109 
1110   auto output_tensor = AnfAlgo::GetOrCreateAllOutputKernelTensors(cnode);
1111   auto input_tensor = AnfAlgo::GetOrCreateAllInputKernelTensors(cnode);
1112   kernel_mod->UpdateOutputShapeAndSize(input_tensor, output_tensor);
1113   if (output_tensor.empty()) {
1114     return;
1115   }
1116   std::vector<TypeId> type_ids;
1117   std::vector<ShapeVector> shapes;
1118   size_t output_num = output_tensor.size();
1119   for (size_t i = 0; i < output_num; ++i) {
1120     MS_EXCEPTION_IF_NULL(output_tensor[i]);
1121     auto out_shape = output_tensor[i]->GetShapeVector();
1122     if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t dim) { return dim < 0; })) {
1123       MS_LOG(ERROR) << "Retrieve invalid output shape " << out_shape;
1124       return;
1125     }
1126     (void)shapes.emplace_back(std::move(out_shape));
1127     (void)type_ids.emplace_back(output_tensor[i]->dtype_id());
1128   }
1129   common::AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, cnode.get(), true);
1130 }
1131 
1132 // In the compile stage, run resize only when the kernel is neither dynamic shape nor dynamic value.
1133 bool CheckResizeCondition(const CNodePtr &node) {
1134   MS_EXCEPTION_IF_NULL(node);
1135   MS_EXCEPTION_IF_NULL(node->input(0));
1136   if (!AnfAlgo::NodeValueIsFuncGraph(node->input(0))) {
1137     if (common::AnfAlgo::IsDynamicShape(node)) {
1138       MS_LOG(DEBUG) << "Skip resize for " << node->DebugString() << " because it is dynamic shape";
1139       return false;
1140     }
1141     if (common::AnfAlgo::IsDynamicValue(node)) {
1142       MS_LOG(DEBUG) << "Skip resize for " << node->DebugString() << " because it is dynamic value";
1143       return false;
1144     }
1145   }
1146 
1147   return true;
1148 }
1149 }  // namespace kernel
1150 }  // namespace mindspore
1151