1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "kernel/framework_utils.h"
18 #include <algorithm>
19 #include <map>
20 #include <set>
21 #include <utility>
22 #include "include/backend/anf_runtime_algorithm.h"
23 #include "include/common/utils/anfalgo.h"
24 #include "include/common/utils/convert_utils.h"
25 #include "kernel/common_utils.h"
26 #include "kernel/format_utils.h"
27 #include "kernel/oplib/oplib.h"
28 #include "mindapi/base/type_id.h"
29 #include "mindspore/ccsrc/include/common/debug/common.h"
30 #include "ops/array_op_name.h"
31 #include "ops/conv_pool_op_name.h"
32 #include "ops/framework_ops.h"
33 #include "ops/math_op_name.h"
34 #include "ops/random_op_name.h"
35 #include "ops/image_op_name.h"
36 #include "ops/nn_op_name.h"
37 #include "ops/nn_ops.h"
38 #include "ops/sequence_ops.h"
39 #include "utils/file_utils.h"
40 #include "utils/ms_context.h"
41 #include "utils/trace_base.h"
42
43 namespace mindspore {
44 namespace kernel {
45 namespace {
46 constexpr char kAxis[] = "axis";
47 constexpr char kOperatorOriginFormat[] = "operator_origin_format";
48 constexpr char kKernelObjectTypeNotSupportedStr[] = "KernelObjectTypeNotSupported";
49
// Extract a shape object from an abstract known to describe a Tensor/MapTensor
// (use its built shape) or a Scalar (empty shape). Any other abstract class is
// rejected; CSR/COO tensor abstracts are converted to AbstractTensor earlier.
abstract::BaseShapePtr GetValidShapeFromAbstract(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTensor>() || abs->isa<abstract::AbstractMapTensor>()) {
    return abs->BuildShape();
  }
  if (abs->isa<abstract::AbstractScalar>()) {
    // A scalar has a rank-0 (empty) shape.
    return std::make_shared<abstract::Shape>(ShapeVector{});
  }
  MS_INTERNAL_EXCEPTION(TypeError) << "The abstract must be a Scalar or Tensor, but got " << abs->ToString();
}
63
// Return the idx-th element abstract of a tuple abstract (from a clone, so the
// original is not shared). For a non-tuple abstract only idx == 0 is legal and
// the abstract itself is returned.
abstract::AbstractBasePtr GetChildAbstract(const abstract::AbstractBasePtr &cur_abstract, size_t idx) {
  MS_EXCEPTION_IF_NULL(cur_abstract);
  if (!cur_abstract->isa<abstract::AbstractTuple>()) {
    MS_EXCEPTION_IF_CHECK_FAIL(
      (idx == 0), "Cannot get " + std::to_string(idx) + " child abstract from " + cur_abstract->ToString());
    return cur_abstract;
  }
  auto abs_tuple = cur_abstract->Clone()->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(abs_tuple);
  const auto &elements = abs_tuple->elements();
  MS_EXCEPTION_IF_CHECK_FAIL((idx < elements.size()), "Index is out of range, idx:" + std::to_string(idx) +
                                                        " size:" + std::to_string(elements.size()) +
                                                        " abs:" + abs_tuple->ToString());
  return elements.at(idx);
}
82
CreateKernelTensor(const abstract::AbstractBasePtr & cur_abstract,const TypeId & real_type,size_t idx,const std::string & format_str,bool prev_node_has_getitem=false)83 KernelTensorPtr CreateKernelTensor(const abstract::AbstractBasePtr &cur_abstract, const TypeId &real_type, size_t idx,
84 const std::string &format_str, bool prev_node_has_getitem = false) {
85 MS_EXCEPTION_IF_NULL(cur_abstract);
86 abstract::AbstractBasePtr tag_abstract = nullptr;
87 abstract::AbstractBasePtr new_abstract = nullptr;
88 if (prev_node_has_getitem) {
89 tag_abstract = cur_abstract;
90 } else {
91 tag_abstract = GetChildAbstract(cur_abstract, idx);
92 }
93 TypePtr tag_type_ptr = TypeIdToType(real_type);
94
95 if (tag_abstract->isa<abstract::AbstractTensor>()) {
96 auto abstract_shape_ptr = GetValidShapeFromAbstract(tag_abstract);
97 new_abstract = std::make_shared<abstract::AbstractTensor>(tag_type_ptr, abstract_shape_ptr);
98 } else {
99 new_abstract = tag_abstract->Clone();
100 }
101 KernelTensorPtr res_tensor =
102 std::make_shared<KernelTensor>(new_abstract->GetShape(), new_abstract->GetType(), new_abstract->GetValue());
103 res_tensor->set_format(GetFormatFromStrToEnum(format_str));
104 return res_tensor;
105 }
106
AdditionalAttrProcess(const ops::PrimitiveCPtr & primc,const CNodePtr & cnode)107 void AdditionalAttrProcess(const ops::PrimitiveCPtr &primc, const CNodePtr &cnode) {
108 MS_EXCEPTION_IF_NULL(primc);
109 MS_EXCEPTION_IF_NULL(cnode);
110 mindspore::HashMap<std::string, ValuePtr> additional_attrs;
111 additional_attrs[kOperatorOriginFormat] = MakeValue(AnfAlgo::GetOriginDataFormat(cnode));
112 (void)primc->SetAttrs(additional_attrs);
113 }
114
CheckRealTupleFromCNode(const std::vector<mindspore::kernel::KernelObjectType> & input_obj_types,const size_t input_idx)115 bool CheckRealTupleFromCNode(const std::vector<mindspore::kernel::KernelObjectType> &input_obj_types,
116 const size_t input_idx) {
117 // if input_obj_types is empty, regard it as a Tensor by default.
118 if (input_obj_types.size() > input_idx && input_obj_types[input_idx] == KernelObjectType::TUPLE) {
119 return true;
120 }
121 return false;
122 }
123
using InOutKernelTensors = std::pair<std::vector<KernelTensorPtr>, std::vector<KernelTensorPtr>>;
// Build the (inputs, outputs) KernelTensor lists for a cnode whose kernel has
// already been selected: device types/formats come from the selected build
// info, shapes/values from the producer/node abstracts.
inline InOutKernelTensors AbstractInOutFromCNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  // Makeup input KernelTensors, meta_types can be tensor, scalar, tuple, list.
  std::vector<KernelTensorPtr> input_tensors;
  auto real_input_types = AnfAlgo::GetAllInputDeviceTypes(cnode);
  size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
  for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
    const auto &[prev_node, output_idx] = common::AnfAlgo::GetPrevNodeOutput(cnode, input_idx);
    bool prev_node_has_getitem = common::AnfAlgo::IsPrevNodeHasTupleGetItem(cnode, input_idx);
    auto prev_abstract = prev_node->abstract();
    auto real_input_type = real_input_types[input_idx];
    // Special case: for a PyExecute producer the device type is replaced with
    // the inferred type of its first output.
    if (IsPrimitiveCNode(prev_node, prim::kPrimPyExecute)) {
      real_input_type = common::AnfAlgo::GetOutputInferDataType(prev_node, 0);
      MS_LOG(DEBUG) << "need changed type node:" << cnode->DebugString()
                    << "Real input type :" << TypeIdToType(real_input_type)->ToString();
    }
    auto format_str = AnfAlgo::GetInputFormat(cnode, input_idx);
    // No TupleGetItem between producer and this node (or producer is a dynamic
    // sequence) => use the producer abstract as a whole instead of a child.
    auto input_tensor = CreateKernelTensor(prev_abstract, real_input_type, output_idx, format_str,
                                           ((!prev_node_has_getitem) || common::AnfAlgo::IsDynamicSequence(prev_node)));
    input_tensors.push_back(input_tensor);
  }

  // Makeup output tensors.
  std::vector<KernelTensorPtr> output_tensors;
  auto real_output_types = AnfAlgo::GetAllOutputDeviceTypes(cnode);
  auto cur_abstract = cnode->abstract();
  MS_EXCEPTION_IF_NULL(cur_abstract);
  size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  auto build_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
  auto output_obj_types = build_info->GetAllOutputKernelObjectTypes();
  for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
    // TUPLE object type means this output stays a real (un-unfolded) tuple.
    bool is_real_tuple_output = CheckRealTupleFromCNode(output_obj_types, output_idx);
    auto real_output_type = real_output_types[output_idx];
    // Same PyExecute special case as for inputs above.
    if (IsPrimitiveCNode(cnode, prim::kPrimPyExecute)) {
      real_output_type = common::AnfAlgo::GetOutputInferDataType(cnode, 0);
      MS_LOG(DEBUG) << "need changed type node:" << cnode->DebugString()
                    << "Real output type :" << TypeIdToType(real_output_type)->ToString()
                    << " is dynamic len:" << common::AnfAlgo::IsDynamicSequence(cnode);
    }
    auto format_str = AnfAlgo::GetOutputFormat(cnode, output_idx);
    auto output_tensor = CreateKernelTensor(cur_abstract, real_output_type, output_idx, format_str,
                                            is_real_tuple_output || common::AnfAlgo::IsDynamicSequence(cnode));
    output_tensors.push_back(output_tensor);
  }
  return std::make_pair(input_tensors, output_tensors);
}
171
IsObjectTypeStrictlyMatched(const std::vector<TypeId> & object_types,const std::vector<DataType> & kernel_data_types)172 bool IsObjectTypeStrictlyMatched(const std::vector<TypeId> &object_types,
173 const std::vector<DataType> &kernel_data_types) {
174 if (object_types.size() != kernel_data_types.size()) {
175 return false;
176 }
177
178 for (size_t i = 0; i < object_types.size(); i++) {
179 // For optional input, the real input object type can be a None.
180 if ((object_types[i] != kernel_data_types[i].object_type) &&
181 !(object_types[i] == kMetaTypeNone && kernel_data_types[i].is_optional)) {
182 return false;
183 }
184 }
185
186 return true;
187 }
188
IsObjectTypeWeaklyMatched(const std::vector<TypeId> & object_types,const std::vector<DataType> & kernel_data_types,bool all_same,size_t element_num)189 bool IsObjectTypeWeaklyMatched(const std::vector<TypeId> &object_types, const std::vector<DataType> &kernel_data_types,
190 bool all_same, size_t element_num) {
191 // 1. The size equal can trigger the kernel object backoff(For example Reshape op).
192 if (object_types.size() == kernel_data_types.size()) {
193 return true;
194 }
195
196 // 2. AllSame is the tupleUnfold type(For example Split/Addn op).
197 if (all_same) {
198 return true;
199 }
200
201 // 3. Multiple outputs are expanded in the kernel attr(For example BatchNorm op).
202 if (kernel_data_types.size() == element_num) {
203 return true;
204 }
205
206 return false;
207 }
208 } // namespace
209
GetInOutDataTypesFromKernelAttr(const KernelAttr & kernel_attr)210 std::pair<std::vector<DataType>, std::vector<DataType>> GetInOutDataTypesFromKernelAttr(const KernelAttr &kernel_attr) {
211 size_t input_attr_size = kernel_attr.GetInputSize();
212 std::vector<DataType> input_data_types;
213 for (size_t i = 0; i < input_attr_size; ++i) {
214 input_data_types.push_back(kernel_attr.GetInputAttr(i));
215 }
216
217 size_t output_attr_size = kernel_attr.GetOutputSize();
218 std::vector<DataType> output_data_types;
219 for (size_t i = 0; i < output_attr_size; ++i) {
220 output_data_types.push_back(kernel_attr.GetOutputAttr(i));
221 }
222
223 return std::make_pair(input_data_types, output_data_types);
224 }
// Root directory for the kernel compiler cache; delegates to the user-defined cache path.
std::string GetCompilerCachePath() { return Common::GetUserDefineCachePath(); }
226
CheckCache(const std::string & kernel_name)227 bool CheckCache(const std::string &kernel_name) {
228 // check cache.
229 KernelMeta *bin_map = KernelMeta::GetInstance();
230 if (bin_map == nullptr) {
231 MS_LOG(DEBUG) << "Kernel cache is invalid, kernel_name: " << kernel_name;
232 return false;
233 }
234 std::string kernel_json = bin_map->Search(kernel_name);
235 bool ret = (!kernel_json.empty());
236 if (ret) {
237 MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
238 } else {
239 MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered.";
240 }
241 return ret;
242 }
243
SearchCache(const std::string & kernel_name,const std::string & processor)244 KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
245 // search cache.
246 KernelMeta *bin_map = KernelMeta::GetInstance();
247 if (bin_map == nullptr) {
248 MS_LOG(DEBUG) << "kernel cache is invalid, kernel_name: " << kernel_name;
249 return nullptr;
250 }
251
252 std::string kernel_json = bin_map->Search(kernel_name);
253 if (!kernel_json.empty()) {
254 KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
255 // just a tmp solution.
256 if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
257 MS_LOG(ERROR) << "Read cache json and bin file failed[" << kernel_json << "].";
258 return nullptr;
259 } else {
260 return kernel_pack;
261 }
262 } else {
263 MS_LOG(INFO) << "The cache kernel not found[" << kernel_name << "].";
264 return nullptr;
265 }
266 }
267
InsertCache(const std::string & kernel_name,const std::string & processor)268 KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
269 MS_LOG(INFO) << "Insert cache for kernel:" << kernel_name << ", processr:" << processor;
270 KernelMeta *bin_map = KernelMeta::GetInstance();
271 if (bin_map == nullptr) {
272 MS_LOG(DEBUG) << "Kernel cache is invalid, kernel name :" << kernel_name;
273 return nullptr;
274 }
275 std::string kernel_json = bin_map->kernel_meta_path();
276 (void)kernel_json.append(kernel_name).append(kJsonSuffix);
277 KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
278 if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
279 MS_LOG(ERROR) << "Read json and bin file failed[" << kernel_json << "].";
280 return nullptr;
281 }
282 if (bin_map->Insert(kernel_name, kernel_json)) {
283 MS_LOG(INFO) << "Kernel insert cache success[" << kernel_json << "], kernel name[" << kernel_name << "].";
284 }
285 return kernel_pack;
286 }
287
Initialize(const std::string & backend)288 void KernelMeta::Initialize(const std::string &backend) {
289 auto config_path = GetCompilerCachePath();
290 kernel_meta_path_ = config_path + backend + std::string(kKernelMetaSuffix);
291 (void)(FileUtils::CreateNotExistDirs(kernel_meta_path_, true));
292 initialized_ = true;
293 }
294
Search(const std::string & kernel_name) const295 std::string KernelMeta::Search(const std::string &kernel_name) const {
296 if (!initialized_) {
297 return "";
298 }
299
300 auto iter = kernel_meta_map_.find(kernel_name);
301 if (iter == kernel_meta_map_.end()) {
302 return "";
303 } else {
304 return iter->second;
305 }
306 }
307
Insert(const std::string & kernel_name,const std::string & kernel_json)308 bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
309 if (!initialized_) {
310 return false;
311 }
312 kernel_meta_map_[kernel_name] = kernel_json;
313 return true;
314 }
315
SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> & inputs,size_t real_input_num,size_t builder_idex,const std::vector<int64_t> & dyn_input_sizes,const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> & builder)316 bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
317 size_t builder_idex, const std::vector<int64_t> &dyn_input_sizes,
318 const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
319 MS_EXCEPTION_IF_NULL(builder);
320
321 std::vector<TypeId> inputs_device_type;
322 std::vector<std::string> inputs_format;
323 std::vector<KernelObjectType> inputs_object_type;
324 size_t dyn_input_idx = 0;
325 size_t kernel_info_index = 0;
326 MS_EXCEPTION_IF_NULL(inputs[0]);
327 size_t kernel_info_cnt = inputs[0]->dtypes().size();
328
329 for (const auto &input : inputs) {
330 MS_EXCEPTION_IF_NULL(input);
331 std::string param_type = input->param_type();
332 std::vector<std::string> dtypes = input->dtypes();
333 std::vector<std::string> formats = input->formats();
334 std::vector<std::string> object_types = input->object_types();
335 if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt ||
336 object_types.size() != kernel_info_cnt) {
337 MS_LOG(DEBUG) << "Set input kernel builder info failed, dtyps size, formats size and object_types size are not "
338 "same. dtypes size: "
339 << dtypes.size() << ", formats size : " << formats.size()
340 << ", object_types size: " << object_types.size();
341 return false;
342 }
343
344 if (param_type == "dynamic") {
345 if (dyn_input_sizes.empty()) {
346 MS_LOG(DEBUG) << "Set input kernel builder info failed, dyn_input_sizes's size is 0 when param_type is dynamic";
347 return false;
348 }
349
350 for (int64_t t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
351 kernel_info_index++;
352 auto type_id = DtypeToTypeId(dtypes[builder_idex]);
353 inputs_device_type.push_back(type_id);
354 inputs_format.push_back(formats[builder_idex]);
355 inputs_object_type.push_back(StringToKernelObjectType(object_types[builder_idex]));
356 }
357 } else if (param_type == "required") {
358 kernel_info_index++;
359 auto type_id = DtypeToTypeId(dtypes[builder_idex]);
360 inputs_device_type.push_back(type_id);
361 inputs_format.push_back(formats[builder_idex]);
362 inputs_object_type.push_back(StringToKernelObjectType(object_types[builder_idex]));
363 } else {
364 if (kernel_info_index < real_input_num) {
365 MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
366 kernel_info_index++;
367 auto type_id = DtypeToTypeId(dtypes[builder_idex]);
368 inputs_device_type.push_back(type_id);
369 inputs_format.push_back(formats[builder_idex]);
370 inputs_object_type.push_back(StringToKernelObjectType(object_types[builder_idex]));
371 }
372 }
373 dyn_input_idx++;
374 }
375
376 builder->SetInputsDeviceType(inputs_device_type);
377 builder->SetInputsFormat(inputs_format);
378 builder->SetInputsKernelObjectType(inputs_object_type);
379
380 return true;
381 }
382
// Fill the builder with the builder_idex-th output kernel info candidate
// (device types, formats, kernel object types) described by `outputs`.
// `real_output_num` is the actual number of outputs on the node; a "dynamic"
// output expands to all of them (and must be the only entry), "required" to
// exactly one, optional to one while real slots remain.
// Returns false when the per-output candidate lists are inconsistent in size.
bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
                                const size_t &real_output_num,
                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  // not now but in the next we need to support dynamic output case
  MS_EXCEPTION_IF_NULL(builder);

  size_t output_idx = 0;
  std::vector<TypeId> outputs_device_type;
  std::vector<std::string> outputs_format;
  std::vector<KernelObjectType> outputs_object_type;
  MS_EXCEPTION_IF_NULL(outputs[0]);
  // Every OpIOInfo entry is expected to carry the same number of candidates.
  size_t kernel_info_cnt = outputs[0]->dtypes().size();

  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    // All real output slots are already filled; remaining entries are skipped.
    if (output_idx >= real_output_num) {
      MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
      continue;
    }
    // Number of node output slots this OpIOInfo entry expands to.
    size_t output_num = 0;
    if (output->param_type() == "dynamic") {
      if (outputs.size() > 1) {
        MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
      }
      output_num = real_output_num;
    } else if (output->param_type() == "required") {
      output_num = 1;
    } else {
      // Optional output: only counted while a real slot remains.
      if (output_idx < real_output_num) {
        MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
        output_num = 1;
      }
    }

    for (size_t i = 0; i < output_num; i++) {
      std::vector<std::string> dtypes = output->dtypes();
      std::vector<std::string> formats = output->formats();
      std::vector<std::string> object_types = output->object_types();
      if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt ||
          object_types.size() != kernel_info_cnt) {
        MS_LOG(DEBUG)
          << "Set output kernel builder info failed, dtyps size, formats size and object_types size are not "
             "same. dtypes size: "
          << dtypes.size() << ", formats size : " << formats.size() << ", object_types size: " << object_types.size();
        return false;
      }
      auto type_id = DtypeToTypeId(dtypes[builder_idex]);
      outputs_device_type.push_back(type_id);
      outputs_format.push_back(formats[builder_idex]);
      outputs_object_type.push_back(StringToKernelObjectType(object_types[builder_idex]));
      output_idx++;
    }
  }

  builder->SetOutputsFormat(outputs_format);
  builder->SetOutputsDeviceType(outputs_device_type);
  builder->SetOutputsKernelObjectType(outputs_object_type);
  return true;
}
442
// Write the given formats/device types into the node's selected kernel build
// info, creating the kernel info and build info holders on demand.
void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std::vector<TypeId> &input_types,
                        const std::vector<std::string> &output_formats, const std::vector<TypeId> &output_types,
                        const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  // Lazily create the kernel info holder.
  if (kernel_node->kernel_info() == nullptr) {
    kernel_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  }
  // Lazily create an empty build info to update below.
  if (!kernel_node->kernel_info()->has_build_info()) {
    AnfAlgo::SetSelectKernelBuildInfo(std::make_shared<kernel::KernelBuildInfo>(), kernel_node.get());
  }
  auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
  build_info->SetInputsFormat(input_formats);
  build_info->SetInputsDeviceType(input_types);
  build_info->SetOutputsFormat(output_formats);
  build_info->SetOutputsDeviceType(output_types);
}
459
SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> & builder,Processor processor,const std::shared_ptr<const OpInfo> & op_info_ptr)460 void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
461 const std::shared_ptr<const OpInfo> &op_info_ptr) {
462 MS_EXCEPTION_IF_NULL(builder);
463 MS_EXCEPTION_IF_NULL(op_info_ptr);
464 builder->SetProcessor(processor);
465 auto imply_type = op_info_ptr->imply_type();
466 switch (imply_type) {
467 case kImplyAKG:
468 builder->SetKernelType(AKG_KERNEL);
469 break;
470 case kImplyTBE:
471 builder->SetKernelType(TBE_KERNEL);
472 break;
473 case kImplyGPU:
474 builder->SetKernelType(GPU_KERNEL);
475 break;
476 case kImplyCPU:
477 builder->SetKernelType(CPU_KERNEL);
478 break;
479 case kImplyAICPU:
480 builder->SetKernelType(AICPU_KERNEL);
481 break;
482 case kImplyBISHENG:
483 builder->SetKernelType(BISHENG_KERNEL);
484 break;
485 default:
486 MS_LOG(EXCEPTION) << "Unknown Imply Type.";
487 break;
488 }
489 }
490
// Parse op registration metadata into a list of candidate KernelBuildInfos for
// kernel_node. One build info is produced per candidate index j (the j-th
// dtype/format combination declared in the OpIOInfo lists).
// Returns false when the metadata is inconsistent with the node's real
// input/output counts; throws on null entries.
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(op_info_ptr);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = AnfAlgo::GetInputElementNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputElementNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  std::vector<int64_t> dyn_input_sizes;
  auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = common::AnfAlgo::GetCNodeName(kernel_node);
  // Per-input element counts for dynamic inputs, recorded by earlier passes.
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr("dyn_input_sizes"));
  }
  // Without dynamic inputs, the registered input list must cover every real input.
  if (dyn_input_sizes.empty() && inputs.size() < real_input_num) {
    MS_LOG(WARNING) << "The size of inputs in OpIOInfo should be great than real input. Inputs size in OpIOInfo:"
                    << inputs.size() << ", real input num: " << real_input_num
                    << ", node: " << kernel_node->fullname_with_scope();
    return false;
  }
  if (inputs.size() > 0) {
    // Case 1: the op declares inputs; candidate count is taken from inputs[0].
    if (inputs[0] == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "Inputs[0] is nullptr. Op name: " << op_name;
    }
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);

      if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed. Op name: " << op_name;
        return false;
      }

      if (outputs.size() > 0) {
        if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
          return false;
        }
      }

      kernel_info_list->push_back(builder->Build());
    }
  } else if (outputs.size() > 0) {
    // Case 2: outputs only; candidate count is taken from outputs[0].
    if (outputs[0] == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "Outputs[0] is nullptr. Op name: " << op_name;
    }
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);

      if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
        return false;
      }

      kernel_info_list->push_back(builder->Build());
    }
  } else {
    // Case 3: neither inputs nor outputs declared; only AICPU gets a bare build info.
    if (processor == AICPU) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}
564
SaveJsonInfo(const std::string & json_name,const std::string & info,const std::string & base_path)565 void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
566 std::string path = base_path + json_name + kInfoSuffix;
567 auto realpath = Common::CreatePrefixPath(path, true);
568 if (!realpath.has_value()) {
569 MS_LOG(ERROR) << "Get real path failed, path=" << path;
570 return;
571 }
572 ChangeFileMode(realpath.value(), S_IWUSR);
573 std::ofstream filewrite(realpath.value());
574 if (!filewrite.is_open()) {
575 MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!";
576 return;
577 }
578 filewrite << info << std::endl;
579 filewrite.close();
580 ChangeFileMode(realpath.value(), S_IRUSR);
581 }
582
GetProcessor(const string & processor)583 Processor GetProcessor(const string &processor) {
584 if (processor == kProcessorAiCore) {
585 return Processor::AICORE;
586 }
587 if (processor == kProcessorAiCpu) {
588 return Processor::AICPU;
589 }
590 if (processor == kProcessorCuda) {
591 return Processor::CUDA;
592 }
593 MS_LOG(DEBUG) << "Unknown processor type.";
594 return Processor::UNKNOWN;
595 }
596
GetProcessor(const AnfNodePtr & anf_node)597 std::string GetProcessor(const AnfNodePtr &anf_node) {
598 MS_EXCEPTION_IF_NULL(anf_node);
599 std::string device;
600 switch (AnfAlgo::GetProcessor(anf_node)) {
601 case Processor::AICORE:
602 device = kProcessorAiCore;
603 break;
604
605 case Processor::AICPU:
606 device = kProcessorAiCpu;
607 break;
608
609 case Processor::CUDA:
610 device = kProcessorCuda;
611 break;
612
613 default:
614 MS_LOG(DEBUG) << "Unknown processor type.";
615 break;
616 }
617 return device;
618 }
619
GetOutputIndex(const std::vector<AnfNodePtr> & node_list,const std::vector<AnfNodePtr> & input_list,const std::vector<AnfNodePtr> & output_list)620 std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
621 const std::vector<AnfNodePtr> &input_list,
622 const std::vector<AnfNodePtr> &output_list) {
623 std::vector<std::pair<AnfNodePtr, size_t>> output_index;
624 for (size_t i = 0; i < output_list.size(); ++i) {
625 auto const &output = output_list[i];
626 MS_EXCEPTION_IF_NULL(output);
627 bool found = false;
628 auto pree_node = common::AnfAlgo::VisitKernel(output, 0);
629 auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
630 if (pos != std::end(node_list)) {
631 output_index.push_back(pree_node);
632 continue;
633 }
634 auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
635 if (ret != std::end(input_list)) {
636 output_index.push_back(std::make_pair(pree_node.first, 0));
637 found = true;
638 }
639 if (!found) {
640 MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
641 << output->func_graph()->ToString() << "] found no related kernel info.";
642 }
643 }
644 return output_index;
645 }
646
GetValidKernelNodes(const FuncGraphPtr & func_graph,std::vector<AnfNodePtr> * node_list)647 void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
648 MS_EXCEPTION_IF_NULL(node_list);
649 MS_EXCEPTION_IF_NULL(func_graph);
650 std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
651 for (auto const &node : node_lists) {
652 if (!AnfUtils::IsRealKernel(node) || !node->isa<CNode>()) {
653 continue;
654 }
655 auto cnode = node->cast<CNodePtr>();
656 MS_EXCEPTION_IF_NULL(cnode);
657 if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
658 node_list->push_back(node);
659 }
660 }
661 }
662
GetValidKernelNodes(const FuncGraphPtr & func_graph,std::vector<AnfNodePtr> * node_list,std::vector<AnfNodePtr> * input_list,std::vector<AnfNodePtr> * output_list)663 void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
664 std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
665 MS_EXCEPTION_IF_NULL(func_graph);
666 MS_EXCEPTION_IF_NULL(node_list);
667 MS_EXCEPTION_IF_NULL(input_list);
668
669 GetValidKernelNodes(func_graph, node_list);
670
671 auto parameters = func_graph->parameters();
672 (void)input_list->insert(input_list->cbegin(), parameters.begin(), parameters.end());
673
674 GetFuncGraphOutputNodes(func_graph, output_list);
675 }
676
GetFuncGraphOutputNodes(const FuncGraphPtr & func_graph,std::vector<AnfNodePtr> * output_list)677 void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
678 MS_EXCEPTION_IF_NULL(func_graph);
679 MS_EXCEPTION_IF_NULL(output_list);
680 auto func_output = func_graph->output();
681 MS_EXCEPTION_IF_NULL(func_output);
682 if (func_output->isa<CNode>()) {
683 // multi output.
684 auto cnode = func_output->cast<CNodePtr>();
685 MS_EXCEPTION_IF_NULL(cnode);
686 auto input0 = cnode->input(kAnfPrimitiveIndex);
687 MS_EXCEPTION_IF_NULL(input0);
688 if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
689 for (size_t input_idx = 1; input_idx < cnode->size(); ++input_idx) {
690 auto input_node = cnode->input(input_idx);
691 MS_EXCEPTION_IF_NULL(input_node);
692 if (input_node->isa<CNode>() && common::AnfAlgo::GetInputTensorNum(input_node) == 0) {
693 continue;
694 }
695 output_list->push_back(common::AnfAlgo::VisitKernel(input_node, 0).first);
696 }
697 } else {
698 // single output.
699 output_list->push_back(common::AnfAlgo::VisitKernel(func_output, 0).first);
700 }
701 } else {
702 // single output.
703 output_list->push_back(common::AnfAlgo::VisitKernel(func_output, 0).first);
704 }
705 }
706
IsWeightBoundary(const AnfNodePtr & node)707 bool IsWeightBoundary(const AnfNodePtr &node) {
708 if (node->isa<ValueNode>()) {
709 return true;
710 }
711 if (node->isa<Parameter>() && common::AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
712 return true;
713 }
714 return false;
715 }
716
GetReduceAttrAxis(const CNodePtr & cnode)717 std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
718 if (common::AnfAlgo::GetInputTensorNum(cnode) != 1 || AnfAlgo::GetOutputElementNum(cnode) != 1) {
719 MS_LOG(INTERNAL_EXCEPTION) << "The reduce node [" << cnode->DebugString()
720 << "] is not single input or single output." << trace::DumpSourceLines(cnode);
721 }
722 auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
723 MS_EXCEPTION_IF_NULL(primitive);
724 auto axis_attr = primitive->GetAttr(kAxis);
725 if (axis_attr == nullptr) {
726 MS_LOG(ERROR) << "This node doesn't have axis attr. Node info [" << cnode->DebugString() << "]";
727 return {};
728 }
729 std::vector<int64_t> axis_list;
730 if (axis_attr->isa<Int64Imm>()) {
731 (void)axis_list.emplace_back(GetValue<int64_t>(axis_attr));
732 } else {
733 axis_list = GetValue<std::vector<int64_t>>(axis_attr);
734 }
735 return axis_list;
736 }
737
GetProcessorFromContext()738 Processor GetProcessorFromContext() {
739 kernel::Processor processor = kernel::Processor::UNKNOWN;
740 auto context_ptr = MsContext::GetInstance();
741 MS_EXCEPTION_IF_NULL(context_ptr);
742 auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
743 if (device_info == kGPUDevice) {
744 processor = kernel::Processor::CUDA;
745 } else if (device_info == kAscendDevice) {
746 processor = kernel::Processor::AICORE;
747 } else if (device_info == kCPUDevice) {
748 processor = kernel::Processor::CPU;
749 }
750 return processor;
751 }
752
GetStrProcessorFromContext()753 std::string GetStrProcessorFromContext() {
754 auto processor = GetProcessorFromContext();
755 string str_processor = kernel::kProcessorUnknown;
756 if (processor == kernel::Processor::CUDA) {
757 str_processor = kernel::kProcessorCuda;
758 } else if (processor == kernel::Processor::AICORE) {
759 str_processor = kernel::kProcessorAiCore;
760 } else if (processor == kernel::Processor::CPU) {
761 str_processor = kernel::kProcessorCpu;
762 }
763 return str_processor;
764 }
765
GetShapeSize(const ShapeVector & shape,const TypePtr & type_ptr,int64_t * size_i)766 bool GetShapeSize(const ShapeVector &shape, const TypePtr &type_ptr, int64_t *size_i) {
767 MS_EXCEPTION_IF_NULL(type_ptr);
768 size_t type_byte = GetTypeByte(type_ptr);
769 if (type_byte == 0) {
770 return false;
771 }
772 for (size_t j = 0; j < shape.size(); j++) {
773 if (shape[j] <= 0) {
774 MS_LOG(DEBUG) << "shape[" << shape << "] has invalid value(less equal 0), set size to 0";
775 size_i[0] = 0;
776 return true;
777 }
778 size_i[0] = LongMulWithOverflowCheck(size_i[0], shape[j]);
779 }
780 size_i[0] = LongMulWithOverflowCheck(size_i[0], SizeToInt(type_byte));
781 return true;
782 }
783
IsDynamicParamKernel(const std::string & op_name)784 bool IsDynamicParamKernel(const std::string &op_name) {
785 const auto &op_info = kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kImplyCPU);
786 constexpr auto kParamDynamic = "dynamic";
787
788 if (op_info == nullptr) {
789 return false;
790 }
791
792 const auto &input_io_info = op_info->inputs_ptr();
793 if (input_io_info.size() != 1 || input_io_info[0]->param_type() != kParamDynamic) {
794 return false;
795 }
796
797 const auto &output_io_info = op_info->outputs_ptr();
798 if (output_io_info.size() != 1 || output_io_info[0]->param_type() != kParamDynamic) {
799 return false;
800 }
801
802 return true;
803 }
804
// Selects registered kernel attrs whose kernel object types match the node.
// Phase 1 requires a strict (exact) object-type match on both inputs and
// outputs; only if nothing matches strictly does phase 2 fall back to weak
// (fuzzy) matching, which also considers the node's element counts and the
// attr's "all same" flag. Returns true when *selected_kernel_attrs is
// non-empty afterwards.
bool SelectKernelByObjectType(const CNodePtr &kernel_node, const std::vector<KernelAttr> &registered_kernel_attrs,
                              std::vector<KernelAttr> *selected_kernel_attrs) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(selected_kernel_attrs);
  const auto &inputs_object_types = AnfAlgo::GetAllInputObjectType(kernel_node);
  const auto &output_object_types = AnfAlgo::GetAllOutputObjectType(kernel_node);

  // 1. Try match all object type firstly.
  for (auto &cur_kernel_attr : registered_kernel_attrs) {
    const auto &[input_data_types, output_data_types] = GetInOutDataTypesFromKernelAttr(cur_kernel_attr);
    if (IsObjectTypeStrictlyMatched(inputs_object_types, input_data_types) &&
        IsObjectTypeStrictlyMatched(output_object_types, output_data_types)) {
      (void)selected_kernel_attrs->emplace_back(cur_kernel_attr);
    }
  }
  if (!selected_kernel_attrs->empty()) {
    return true;
  }

  // 2. Precise matching failed, try fuzzy one again.
  auto input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  auto output_num = AnfAlgo::GetOutputElementNum(kernel_node);
  for (auto &cur_kernel_attr : registered_kernel_attrs) {
    const auto &[input_data_types, output_data_types] = GetInOutDataTypesFromKernelAttr(cur_kernel_attr);
    auto all_same = cur_kernel_attr.GetAllSame();
    if (IsObjectTypeWeaklyMatched(inputs_object_types, input_data_types, all_same, input_num) &&
        IsObjectTypeWeaklyMatched(output_object_types, output_data_types, all_same, output_num)) {
      (void)selected_kernel_attrs->emplace_back(cur_kernel_attr);
    }
  }

  return (!selected_kernel_attrs->empty());
}
838
KernelObjectTypeNotSupportWarning(const CNodePtr & kernel_node)839 std::pair<std::string, ExceptionType> KernelObjectTypeNotSupportWarning(const CNodePtr &kernel_node) {
840 MS_EXCEPTION_IF_NULL(kernel_node);
841 auto GetObjectTypeStr = [](const std::vector<TypeId> &object_types) {
842 std::vector<std::string> object_type_strs;
843 (void)std::transform(object_types.begin(), object_types.end(), std::back_inserter(object_type_strs), TypeIdLabel);
844 return std::accumulate(object_type_strs.begin(), object_type_strs.end(), std::string(),
845 [](const std::string &x, const std::string &y) { return x.empty() ? y : x + ", " + y; });
846 };
847 const std::string warn_str = std::string(kKernelObjectTypeNotSupportedStr) + ": unsupported kernel object type for " +
848 kernel_node->fullname_with_scope() + " with inputs (" +
849 GetObjectTypeStr(AnfAlgo::GetAllInputObjectType(kernel_node)) + "), outputs (" +
850 GetObjectTypeStr(AnfAlgo::GetAllOutputObjectType(kernel_node)) + ").";
851 return {warn_str, TypeError};
852 }
853
IsKernelObjectTypeNotSupportedError(const std::string & error_str)854 bool IsKernelObjectTypeNotSupportedError(const std::string &error_str) {
855 return error_str.find(kKernelObjectTypeNotSupportedStr) != std::string::npos;
856 }
857
StringToKernelObjectType(const std::string & object_type)858 KernelObjectType StringToKernelObjectType(const std::string &object_type) {
859 static const std::unordered_map<std::string, KernelObjectType> object_type_maps = {
860 {"unknown", KernelObjectType::UNKNOWN_TYPE},
861 {"tensor", KernelObjectType::TENSOR},
862 {"scalar", KernelObjectType::SCALAR},
863 {"tuple", KernelObjectType::TUPLE},
864 {"tuple_unfold", KernelObjectType::TUPLE_UNFOLD},
865 };
866 auto iter = object_type_maps.find(object_type);
867 if (iter == object_type_maps.end()) {
868 MS_LOG(EXCEPTION) << "Illegal input object type: " << object_type;
869 }
870 return iter->second;
871 }
872
UnfoldKernelBuildInfo(const CNodePtr & kernel_node)873 void UnfoldKernelBuildInfo(const CNodePtr &kernel_node) {
874 MS_EXCEPTION_IF_NULL(kernel_node);
875 auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
876 auto input_num = kernel_build_info->GetInputNum();
877 auto output_num = kernel_build_info->GetOutputNum();
878 if (input_num == 0 && output_num == 0) {
879 return;
880 }
881 const auto &input_kernel_object_types = kernel_build_info->GetAllInputKernelObjectTypes();
882 const auto &output_kernel_object_types = kernel_build_info->GetAllOutputKernelObjectTypes();
883 const auto &input_dtypes = kernel_build_info->GetAllInputDeviceTypes();
884 const auto &output_dtypes = kernel_build_info->GetAllOutputDeviceTypes();
885 const auto &input_formats = kernel_build_info->GetAllInputFormats();
886 const auto &output_formats = kernel_build_info->GetAllOutputFormats();
887
888 std::vector<TypeId> unfold_input_dtypes;
889 std::vector<TypeId> unfold_output_dtypes;
890 std::vector<std::string> unfold_input_formats;
891 std::vector<std::string> unfold_output_formats;
892 auto Append = [&](bool in_or_out, size_t index) {
893 if (in_or_out) {
894 MS_EXCEPTION_IF_CHECK_FAIL((input_num > index), "Input index is out of range.");
895 unfold_input_dtypes.push_back(input_dtypes[index]);
896 unfold_input_formats.push_back(input_formats[index]);
897 } else {
898 MS_EXCEPTION_IF_CHECK_FAIL((output_num > index), "Output index is out of range.");
899 unfold_output_dtypes.push_back(output_dtypes[index]);
900 unfold_output_formats.push_back(output_formats[index]);
901 }
902 };
903 auto RepeatAppend = [&](bool in_or_out, size_t index, size_t times) {
904 while (times > 0) {
905 Append(in_or_out, index);
906 times--;
907 }
908 };
909
910 for (size_t i = 0; i < input_kernel_object_types.size(); ++i) {
911 if (input_kernel_object_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
912 auto input_node = common::AnfAlgo::GetInputNode(kernel_node, i);
913 auto unfold_num = GetOutputNum(input_node);
914 MS_LOG(DEBUG) << kernel_node->fullname_with_scope() << " input idnex:" << i << " unfold num:" << unfold_num;
915 RepeatAppend(true, i, unfold_num);
916 } else {
917 Append(true, i);
918 }
919 }
920
921 for (size_t i = 0; i < output_kernel_object_types.size(); ++i) {
922 if (output_kernel_object_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
923 auto unfold_num = GetOutputNum(kernel_node);
924 MS_LOG(DEBUG) << kernel_node->fullname_with_scope() << " output idnex:" << i << " unfold num:" << unfold_num;
925 // Multiple outputs are expanded in the kernel attr(For example BatchNorm op).
926 if (output_num == unfold_num) {
927 for (size_t j = 0; j < unfold_num; ++j) {
928 Append(false, j);
929 }
930 } else {
931 RepeatAppend(false, i, unfold_num);
932 }
933 } else {
934 Append(false, i);
935 }
936 }
937
938 SetKernelBuildInfo(unfold_input_formats, unfold_input_dtypes, unfold_output_formats, unfold_output_dtypes,
939 kernel_node);
940 }
941
// Returns the unfolded element count of `node`'s tuple output, or -1 when the
// node should not be treated as an unfoldable tuple: non-tuple output, sparse
// bprop-cut, a selected kernel whose output object type is not TUPLE_UNFOLD,
// or an empty tuple.
int64_t CalOutputTupleSize(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimBpropCut);
  // NOTE(review): node->abstract() is dereferenced without a null check —
  // presumably the abstract is always set by this stage; confirm with callers.
  bool skip = (is_bprop_cut && node->abstract()->isa<abstract::AbstractSparseTensor>());
  if (skip || !common::AnfAlgo::IsTupleOutput(node)) {
    return -1;
  }
  // Look through TupleGetItem to the node that actually owns the build info.
  const auto &real_node = common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
  auto build_info = AnfAlgo::GetSelectKernelBuildInfo(real_node);
  if (build_info != nullptr) {
    // Only TUPLE_UNFOLD outputs are expandable; other object types report -1.
    auto output_object = AnfAlgo::GetOutputKernelObjectType(real_node, 0);
    if (output_object != kernel::KernelObjectType::TUPLE_UNFOLD) {
      return -1;
    }
  }
  auto output_size = static_cast<int64_t>(AnfAlgo::GetOutputElementNum(node));
  if (node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
    // For MakeTuple, count elements explicitly so nested tuples are flattened
    // by recursion rather than relying on GetOutputElementNum.
    output_size = 0;
    auto make_tuple = node->cast<CNodePtr>();
    size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
    for (size_t j = 0; j < tuple_input_num; ++j) {
      // using for graph kernel
      auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
      // Handle tuple nested scenes.
      if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
        output_size += CalOutputTupleSize(dyn_input_node);
      } else {
        output_size++;
      }
    }
  }
  // An empty tuple is reported as -1 (not unfoldable).
  return output_size == 0 ? -1 : output_size;
}
975
SetDynamicInputSizeAttr(const CNodePtr & cnode)976 void SetDynamicInputSizeAttr(const CNodePtr &cnode) {
977 MS_EXCEPTION_IF_NULL(cnode);
978 if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
979 common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial)) {
980 return;
981 }
982 std::vector<int64_t> dyn_input_sizes;
983 auto input_obj_types = AnfAlgo::GetInputKernelObjectTypes(cnode);
984 size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
985 for (size_t i = 0; i < input_num; ++i) {
986 if (i < input_obj_types.size() && input_obj_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
987 auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
988 dyn_input_sizes.push_back(CalOutputTupleSize(input_node));
989 } else {
990 dyn_input_sizes.push_back(-1);
991 }
992 }
993 if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
994 common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
995 }
996 }
997
AbstractArgsFromCNode(const CNodePtr & cnode)998 KernelArgs AbstractArgsFromCNode(const CNodePtr &cnode) {
999 MS_EXCEPTION_IF_NULL(cnode);
1000 auto [input_tensors, output_tensors] = AbstractInOutFromCNode(cnode);
1001 KernelArgs args = {input_tensors, output_tensors};
1002 return args;
1003 }
1004
CreateOperatorByCNode(const CNodePtr & cnode)1005 BaseOperatorPtr CreateOperatorByCNode(const CNodePtr &cnode) {
1006 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1007 if (prim == nullptr) {
1008 return nullptr;
1009 }
1010 auto kernel_name = prim->name();
1011 MS_LOG(DEBUG) << "Create operator " << kernel_name;
1012 auto ori_kernel_name = kernel_name;
1013 if (prim->HasAttr(kAttrMeOpName)) {
1014 ori_kernel_name = GetValue<std::string>(prim->GetAttr(kAttrMeOpName));
1015 }
1016 AdditionalAttrProcess(prim, cnode);
1017
1018 static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
1019 auto it = operator_fns.find(ori_kernel_name);
1020 if (it == operator_fns.end()) {
1021 MS_LOG(DEBUG) << "Cannot create BaseOperator for " << ori_kernel_name;
1022 return nullptr;
1023 }
1024 auto base_operator = it->second(prim);
1025 return base_operator;
1026 }
1027
GetArgsFromCNode(const CNodePtr & cnode)1028 std::shared_ptr<KernelArgs> GetArgsFromCNode(const CNodePtr &cnode) {
1029 MS_EXCEPTION_IF_NULL(cnode);
1030 auto args = cnode->user_data<KernelArgs>();
1031 return args;
1032 }
1033
GetDependValueByConstTensor(const AnfNodePtr & input_node,const std::string & cnode_name,size_t i)1034 tensor::TensorPtr GetDependValueByConstTensor(const AnfNodePtr &input_node, const std::string &cnode_name, size_t i) {
1035 MS_EXCEPTION_IF_NULL(input_node);
1036 auto value_node = input_node->cast<ValueNodePtr>();
1037 MS_EXCEPTION_IF_NULL(value_node);
1038 auto value = value_node->value();
1039 MS_EXCEPTION_IF_NULL(value);
1040 if (value->isa<tensor::Tensor>()) {
1041 return value->cast<tensor::TensorPtr>();
1042 } else if (value->isa<Scalar>()) {
1043 return ScalarToTensor(value->cast<ScalarPtr>());
1044 }
1045 MS_EXCEPTION(ValueError) << "The CNode " << cnode_name << "'s input[" << i << "], must be tensor or scalar, but got "
1046 << value->ToString();
1047 }
1048
SetInputsByConstInputs(const CNodePtr & node,std::map<uint32_t,tensor::TensorPtr> * inputs_tensor_map)1049 void SetInputsByConstInputs(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *inputs_tensor_map) {
1050 std::set<int64_t> depend_list = abstract::GetValueDependArgIndices(node);
1051 auto input_size = common::AnfAlgo::GetInputTensorNum(node);
1052 auto cnode_name = node->fullname_with_scope();
1053 for (size_t i = 0; i < input_size; i++) {
1054 if (depend_list.find(i) != depend_list.end()) {
1055 auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i, false);
1056 auto real_input = input_node_with_index.first;
1057 if (!real_input->isa<ValueNode>()) {
1058 continue;
1059 }
1060 auto out_tensor = GetDependValueByConstTensor(real_input, cnode_name, i);
1061 MS_EXCEPTION_IF_NULL(inputs_tensor_map);
1062 auto ret2 = inputs_tensor_map->try_emplace(i, out_tensor);
1063 if (!ret2.second) {
1064 MS_LOG(INTERNAL_EXCEPTION) << "Insert map failed.";
1065 }
1066 }
1067 }
1068 }
1069
SetInputsByDependMap(const std::map<uint32_t,tensor::TensorPtr> & depend_tensor_map,std::vector<KernelTensorPtr> * inputs,bool is_stored_in_device)1070 void SetInputsByDependMap(const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
1071 std::vector<KernelTensorPtr> *inputs, bool is_stored_in_device) {
1072 MS_EXCEPTION_IF_NULL(inputs);
1073 for (const auto &[i, tensor] : depend_tensor_map) {
1074 if (i >= inputs->size()) {
1075 MS_LOG(EXCEPTION) << "Type to store the data to KernelTensor, expect less than" << inputs->size() << " but got "
1076 << i;
1077 }
1078 MS_EXCEPTION_IF_NULL(inputs->at(i));
1079 MS_EXCEPTION_IF_NULL(tensor);
1080 auto address = std::make_shared<kernel::Address>(tensor->data_c(), tensor->Size());
1081 if (is_stored_in_device) {
1082 // Store the data address in device one for cpu.
1083 inputs->at(i)->SetData(address);
1084 continue;
1085 }
1086 inputs->at(i)->SetHostData(address);
1087 }
1088 }
1089
SetArgsToCNode(const CNodePtr & cnode,const KernelArgs & args)1090 void SetArgsToCNode(const CNodePtr &cnode, const KernelArgs &args) {
1091 MS_EXCEPTION_IF_NULL(cnode);
1092 auto dst = cnode->user_data<KernelArgs>();
1093 if (dst == nullptr) {
1094 dst = std::make_shared<KernelArgs>();
1095 cnode->set_user_data<KernelArgs>(dst);
1096 }
1097 dst->inputs = args.inputs;
1098 dst->outputs = args.outputs;
1099 dst->depend_tensor_map = args.depend_tensor_map;
1100 }
1101
// For kernels that compute their output shape at runtime, re-reads the shapes
// from the output kernel tensors after launch and writes them back into the
// node's inferred types/shapes. No-op for statically-shaped kernels.
void UpdateNodeShape(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  if (!kernel_mod->IsNeedUpdateOutputShapeAndSize()) {
    return;
  }

  auto output_tensor = AnfAlgo::GetOrCreateAllOutputKernelTensors(cnode);
  auto input_tensor = AnfAlgo::GetOrCreateAllInputKernelTensors(cnode);
  kernel_mod->UpdateOutputShapeAndSize(input_tensor, output_tensor);
  if (output_tensor.empty()) {
    return;
  }
  std::vector<TypeId> type_ids;
  std::vector<ShapeVector> shapes;
  size_t output_num = output_tensor.size();
  for (size_t i = 0; i < output_num; ++i) {
    MS_EXCEPTION_IF_NULL(output_tensor[i]);
    auto out_shape = output_tensor[i]->GetShapeVector();
    // A negative dim means the kernel failed to resolve a concrete shape;
    // abort the whole update instead of writing an invalid shape to the node.
    if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t dim) { return dim < 0; })) {
      MS_LOG(ERROR) << "Retrieve invalid output shape " << out_shape;
      return;
    }
    (void)shapes.emplace_back(std::move(out_shape));
    (void)type_ids.emplace_back(output_tensor[i]->dtype_id());
  }
  common::AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, cnode.get(), true);
}
1131
1132 // In compile stage, run resize when kernel is not dynamic shape or has no value depend list.
CheckResizeCondition(const CNodePtr & node)1133 bool CheckResizeCondition(const CNodePtr &node) {
1134 MS_EXCEPTION_IF_NULL(node);
1135 MS_EXCEPTION_IF_NULL(node->input(0));
1136 if (!AnfAlgo::NodeValueIsFuncGraph(node->input(0))) {
1137 if (common::AnfAlgo::IsDynamicShape(node)) {
1138 MS_LOG(DEBUG) << "Skip resize for " << node->DebugString() << ", for reason is dynamic shape";
1139 return false;
1140 }
1141 if (common::AnfAlgo::IsDynamicValue(node)) {
1142 MS_LOG(DEBUG) << "Skip resize for " << node->DebugString() << ", for reason is dynamic value";
1143 return false;
1144 }
1145 }
1146
1147 return true;
1148 }
1149 } // namespace kernel
1150 } // namespace mindspore
1151