• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "kernel/common_utils.h"
18 #include <algorithm>
19 #include <bitset>
20 #include <cmath>
21 #include <fstream>
22 #include <iostream>
23 #include <map>
24 #include <set>
25 #include <thread>
26 #include <tuple>
27 #include <unordered_map>
28 #include <utility>
29 #include <vector>
30 #include "include/backend/anf_runtime_algorithm.h"
31 #include "include/common/utils/anfalgo.h"
32 #include "ir/graph_utils.h"
33 #include "kernel/oplib/oplib.h"
34 #include "kernel/format_utils.h"
35 #include "mindapi/base/type_id.h"
36 #include "mindspore/ccsrc/include/common/debug/common.h"
37 #include "nlohmann/json.hpp"
38 #include "ops/array_op_name.h"
39 #include "ops/conv_pool_op_name.h"
40 #include "ops/framework_ops.h"
41 #include "ops/math_op_name.h"
42 #include "ops/nn_ops.h"
43 #include "ops/sequence_ops.h"
44 #include "utils/anf_utils.h"
45 
namespace mindspore {
namespace kernel {
namespace {
// Dtype string literal for 32-bit signed integers.
constexpr char kTypeInt32[] = "Int32";
// Divisor used to express a 1-byte size as sizeof(int) / kQuad in the dtype-size table below.
constexpr auto kQuad = 4;
// Starting index when walking a kernel attr's flattened (dynamic-expanded) input list.
constexpr size_t kInputFirstIndex = 0;
}  // namespace
53 
GetOutputNum(const AnfNodePtr & node)54 size_t GetOutputNum(const AnfNodePtr &node) {
55   MS_EXCEPTION_IF_NULL(node);
56   const auto &type = node->Type();
57   if (type == nullptr) {
58     MS_LOG(EXCEPTION) << "Failed to get type in node:" << node->fullname_with_scope();
59   } else if (type->isa<Tuple>()) {
60     auto tuple_type = type->cast<TuplePtr>();
61     MS_EXCEPTION_IF_NULL(tuple_type);
62     if (tuple_type->dynamic_len()) {
63       return 1;
64     }
65     const auto &sub_types = tuple_type->elements();
66     return static_cast<size_t>(std::count_if(sub_types.begin(), sub_types.end(), [](const TypePtr &sub_type) {
67       return sub_type != nullptr && (!sub_type->isa<MonadType>());
68     }));
69   } else if (type->isa<List>()) {
70     auto list_type = type->cast<ListPtr>();
71     MS_EXCEPTION_IF_NULL(list_type);
72     if (list_type->dynamic_len()) {
73       return 1;
74     }
75     const auto &sub_types = list_type->elements();
76     return static_cast<size_t>(std::count_if(sub_types.begin(), sub_types.end(), [](const TypePtr &sub_type) {
77       return sub_type != nullptr && (!sub_type->isa<MonadType>());
78     }));
79   } else if (type->isa<CSRTensorType>()) {
80     return 5;
81   } else if (type->isa<COOTensorType>()) {
82     return 4;
83   }
84   return 1;
85 }
86 
CalDiagOffset(int diag_index,int max_diag_len,int inner_rows,int inner_cols,const std::pair<MatrixDiag::Alignment,MatrixDiag::Alignment> & alignment)87 int CalDiagOffset(int diag_index, int max_diag_len, int inner_rows, int inner_cols,
88                   const std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> &alignment) {
89   bool right_align_super_diagonal = (alignment.first == MatrixDiag::RIGHT);
90   bool right_align_sub_diagonal = (alignment.second == MatrixDiag::RIGHT);
91   const bool right_align =
92     (diag_index >= 0 && right_align_super_diagonal) || (diag_index <= 0 && right_align_sub_diagonal);
93   const int diag_len = std::min(inner_rows + std::min(0, diag_index), inner_cols - std::max(0, diag_index));
94   const int offset = (right_align) ? (max_diag_len - diag_len) : 0;
95   return offset;
96 }
97 
DtypeToTypeId(const std::string & dtypes)98 TypeId DtypeToTypeId(const std::string &dtypes) {
99   if (dtypes == "float") {
100     return TypeId::kNumberTypeFloat32;
101   }
102   if (dtypes.empty()) {
103     return TypeId::kMetaTypeNone;
104   }
105   return StringToTypeId(dtypes);
106 }
107 
Dtype2ShortType(const std::string & dtype)108 std::string Dtype2ShortType(const std::string &dtype) {
109   static const std::unordered_map<std::string, std::string> dtype_shortdtype_map = {
110     {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"},  {"int8", "i8"},    {"int16", "i16"},
111     {"int32", "i32"},   {"int64", "i64"},   {"uint8", "u8"},     {"uint16", "u16"}, {"uint32", "u32"},
112     {"uint64", "u64"},  {"bool", "bool"},   {"bfloat16", "bf16"}};
113 
114   auto iter = dtype_shortdtype_map.find(dtype);
115   if (iter != dtype_shortdtype_map.end()) {
116     return iter->second;
117   } else {
118     MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
119   }
120 }
121 
GetDtypeNbyte(const std::string & dtype)122 size_t GetDtypeNbyte(const std::string &dtype) {
123   static const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
124     {"float16", sizeof(float) / 2},   {"float32", sizeof(float)},     {"float64", sizeof(float) * 2},
125     {"int8", sizeof(int) / kQuad},    {"int16", sizeof(int) / 2},     {"int32", sizeof(int)},
126     {"int64", sizeof(int) * 2},       {"uint8", sizeof(int) / kQuad}, {"uint16", sizeof(int) / 2},
127     {"uint32", sizeof(int)},          {"uint64", sizeof(int) * 2},    {"bool", sizeof(char)},
128     {"complex64", sizeof(float) * 2}, {"bfloat16", sizeof(float) / 2}};
129 
130   auto iter = dtype_nbyte_map.find(dtype);
131   if (iter != dtype_nbyte_map.end()) {
132     return iter->second;
133   } else {
134     MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
135   }
136 }
137 
IsSameShape(const ShapeVector & shape_a,const ShapeVector & shape_b)138 bool IsSameShape(const ShapeVector &shape_a, const ShapeVector &shape_b) { return shape_a == shape_b; }
139 
CheckShapesSame(const ShapeArray & shape_array)140 bool CheckShapesSame(const ShapeArray &shape_array) {
141   auto first_shape = shape_array[0];
142   return std::all_of(shape_array.begin() + 1, shape_array.end(),
143                      [&first_shape](const ShapeVector &shape) { return IsSameShape(shape, first_shape); });
144 }
145 
GetProcessorStr(const AnfNodePtr & anf_node)146 std::string GetProcessorStr(const AnfNodePtr &anf_node) {
147   MS_EXCEPTION_IF_NULL(anf_node);
148   std::string processor = kProcessorUnknown;
149   auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
150   MS_EXCEPTION_IF_NULL(kernel_info);
151   auto build_info = kernel_info->select_kernel_build_info();
152   // we may call this before kernel select.
153   if (build_info == nullptr) {
154     return processor;
155   }
156   switch (build_info->processor()) {
157     case Processor::AICORE:
158       processor = kProcessorAiCore;
159       break;
160 
161     case Processor::AICPU:
162       processor = kProcessorAiCpu;
163       break;
164 
165     case Processor::CUDA:
166       processor = kProcessorCuda;
167       break;
168 
169     case Processor::CPU:
170       processor = kProcessorCpu;
171       break;
172 
173     default:
174       MS_LOG(DEBUG) << "Unknown processor type.";
175       break;
176   }
177 
178   return processor;
179 }
180 
GetOutputObjectTypeListFromKernelAttr(const kernel::KernelAttr & kernel_attr)181 std::vector<TypeId> GetOutputObjectTypeListFromKernelAttr(const kernel::KernelAttr &kernel_attr) {
182   size_t output_attr_size = kernel_attr.GetOutputSize();
183   std::vector<TypeId> res;
184   for (size_t i = 0; i < output_attr_size; ++i) {
185     res.push_back(kernel_attr.GetOutputAttr(i).object_type);
186   }
187   return res;
188 }
189 
GetInputObjectTypeListFromKernelAttr(const kernel::KernelAttr & kernel_attr)190 std::vector<TypeId> GetInputObjectTypeListFromKernelAttr(const kernel::KernelAttr &kernel_attr) {
191   size_t input_attr_size = kernel_attr.GetInputSize();
192   std::vector<TypeId> res;
193   for (size_t i = 0; i < input_attr_size; ++i) {
194     res.push_back(kernel_attr.GetInputAttr(i).object_type);
195   }
196   return res;
197 }
198 
TypeIdToKernelObjectType(const TypeId & type_id)199 KernelObjectType TypeIdToKernelObjectType(const TypeId &type_id) {
200   std::unordered_map<TypeId, KernelObjectType> trans_map{{kObjectTypeTuple, KernelObjectType::TUPLE},
201                                                          {kObjectTypeNumber, KernelObjectType::SCALAR},
202                                                          {kObjectTypeTensorType, KernelObjectType::TENSOR}};
203   if (trans_map.find(type_id) == trans_map.end()) {
204     MS_LOG(DEBUG) << "Unsupported type id " << TypeIdToString(type_id)
205                   << ", that cannot converted to corresponding kernel object type.";
206     return KernelObjectType::UNKNOWN_TYPE;
207   }
208   return trans_map[type_id];
209 }
210 
TypeIdToKernelObjectType(const std::vector<TypeId> & type_ids)211 std::vector<KernelObjectType> TypeIdToKernelObjectType(const std::vector<TypeId> &type_ids) {
212   std::vector<KernelObjectType> ret;
213   (void)std::transform(type_ids.begin(), type_ids.end(), std::back_inserter(ret),
214                        [](const TypeId &type_id) { return kernel::TypeIdToKernelObjectType(type_id); });
215   return ret;
216 }
217 
TypeIdToKernelObjectTypeForTupleUnfold(const TypeId & type_id)218 KernelObjectType TypeIdToKernelObjectTypeForTupleUnfold(const TypeId &type_id) {
219   std::unordered_map<TypeId, KernelObjectType> trans_map{{kObjectTypeTuple, KernelObjectType::TUPLE_UNFOLD},
220                                                          {kObjectTypeNumber, KernelObjectType::SCALAR},
221                                                          {kObjectTypeTensorType, KernelObjectType::TENSOR}};
222   if (trans_map.find(type_id) == trans_map.end()) {
223     MS_LOG(DEBUG) << "Unsupported type id " << TypeIdToString(type_id)
224                   << ", that cannot converted to corresponding kernel object type.";
225     return KernelObjectType::UNKNOWN_TYPE;
226   }
227   return trans_map[type_id];
228 }
229 
TypeIdToKernelObjectTypeForTupleUnfold(const std::vector<TypeId> & type_ids)230 std::vector<KernelObjectType> TypeIdToKernelObjectTypeForTupleUnfold(const std::vector<TypeId> &type_ids) {
231   std::vector<KernelObjectType> ret;
232   (void)std::transform(type_ids.begin(), type_ids.end(), std::back_inserter(ret),
233                        [](const TypeId &type_id) { return kernel::TypeIdToKernelObjectTypeForTupleUnfold(type_id); });
234   return ret;
235 }
236 
KernelObjectTypeToTypeId(const KernelObjectType & object_type)237 TypeId KernelObjectTypeToTypeId(const KernelObjectType &object_type) {
238   std::unordered_map<KernelObjectType, TypeId> trans_map{{KernelObjectType::TUPLE, kObjectTypeTuple},
239                                                          {KernelObjectType::TUPLE_UNFOLD, kObjectTypeTuple},
240                                                          {KernelObjectType::SCALAR, kObjectTypeNumber},
241                                                          {KernelObjectType::TENSOR, kObjectTypeTensorType}};
242   if (trans_map.find(object_type) == trans_map.end()) {
243     MS_LOG(DEBUG) << "Unsupported kernel object type " << object_type
244                   << ", that cannot converted to corresponding type id.";
245     return kTypeUnknown;
246   }
247   return trans_map[object_type];
248 }
249 
250 // The allsame/skip_check and the unequal size scenario don't support object type backoff and use the object_types,
251 // other scenes support the object type backoff and use the selected_object_types.
CalKernelObjectTypes(const std::vector<TypeId> & object_types,const std::vector<TypeId> & selected_object_types,bool all_same,bool skip_check)252 std::vector<KernelObjectType> CalKernelObjectTypes(const std::vector<TypeId> &object_types,
253                                                    const std::vector<TypeId> &selected_object_types, bool all_same,
254                                                    bool skip_check) {
255   std::vector<KernelObjectType> ret;
256   //  Use the selected_object_types in the equal size scenario.
257   if (object_types.size() == selected_object_types.size()) {
258     for (size_t i = 0; i < selected_object_types.size(); ++i) {
259       // Allsame/skip_check doesn't support the backoff.
260       bool not_backoff = ((all_same || skip_check) && (selected_object_types[i] != object_types[i]));
261       if (not_backoff) {
262         (void)ret.emplace_back(TypeIdToKernelObjectTypeForTupleUnfold(object_types[i]));
263       } else {
264         (void)ret.emplace_back(TypeIdToKernelObjectType(selected_object_types[i]));
265       }
266     }
267     return ret;
268   }
269 
270   // Use the object_types in the unequal size scenario, and convert tuple to tupleUnflod.
271   for (size_t i = 0; i < object_types.size(); ++i) {
272     (void)ret.emplace_back(TypeIdToKernelObjectTypeForTupleUnfold(object_types[i]));
273   }
274   return ret;
275 }
276 
CalInputKernelObjectTypes(const AnfNodePtr & kernel_node,const kernel::KernelAttr & selected_kernel_attr)277 std::vector<KernelObjectType> CalInputKernelObjectTypes(const AnfNodePtr &kernel_node,
278                                                         const kernel::KernelAttr &selected_kernel_attr) {
279   MS_EXCEPTION_IF_NULL(kernel_node);
280   auto selected_input_object_types = GetInputObjectTypeListFromKernelAttr(selected_kernel_attr);
281   auto input_object_types = AnfAlgo::GetAllInputObjectType(kernel_node);
282   return CalKernelObjectTypes(input_object_types, selected_input_object_types, selected_kernel_attr.GetAllSame(),
283                               selected_kernel_attr.GetSkipCheck());
284 }
285 
CalOutputKernelObjectTypes(const AnfNodePtr & kernel_node,const kernel::KernelAttr & selected_kernel_attr)286 std::vector<KernelObjectType> CalOutputKernelObjectTypes(const AnfNodePtr &kernel_node,
287                                                          const kernel::KernelAttr &selected_kernel_attr) {
288   MS_EXCEPTION_IF_NULL(kernel_node);
289   auto selected_output_object_types = GetOutputObjectTypeListFromKernelAttr(selected_kernel_attr);
290   auto output_object_types = AnfAlgo::GetAllOutputObjectType(kernel_node);
291   return CalKernelObjectTypes(output_object_types, selected_output_object_types, selected_kernel_attr.GetAllSame(),
292                               selected_kernel_attr.GetSkipCheck());
293 }
294 
CalOutputElementObjectTypes(const AnfNodePtr & kernel_node,const kernel::KernelAttr & selected_kernel_attr)295 std::vector<KernelObjectType> CalOutputElementObjectTypes(const AnfNodePtr &kernel_node,
296                                                           const kernel::KernelAttr &selected_kernel_attr) {
297   MS_EXCEPTION_IF_NULL(kernel_node);
298   auto selected_output_object_types = GetOutputObjectTypeListFromKernelAttr(selected_kernel_attr);
299   MS_LOG(DEBUG) << "Output object type:" << selected_output_object_types << " for node:" << kernel_node->DebugString()
300                 << " select attr:" << kernel::FetchPrintInfoByKernelAttr(selected_kernel_attr);
301   auto element_num = GetOutputNum(kernel_node);
302   if (selected_kernel_attr.GetAllSame() && selected_output_object_types.size() == 1) {
303     return std::vector<KernelObjectType>(element_num, TypeIdToKernelObjectType(selected_output_object_types[0]));
304   }
305   MS_EXCEPTION_IF_CHECK_FAIL(element_num == selected_output_object_types.size(),
306                              "Check multi-output kernel attr size failed.");
307   return TypeIdToKernelObjectType(selected_output_object_types);
308 }
309 
FetchPrintInfoByKernelAttr(KernelAttr selected_kernel_attr)310 std::string FetchPrintInfoByKernelAttr(KernelAttr selected_kernel_attr) {
311   std::string attr_info = "input[";
312   (void)std::for_each(std::begin(selected_kernel_attr.input_type()), std::end(selected_kernel_attr.input_type()),
313                       [&attr_info](auto &input_type) {
314                         attr_info += TypeIdToString(input_type.object_type) + " " + TypeIdToString(input_type.dtype) +
315                                      " " + input_type.format + ",";
316                       });
317   attr_info += "] output[";
318   (void)std::for_each(std::begin(selected_kernel_attr.output_type()), std::end(selected_kernel_attr.output_type()),
319                       [&attr_info](auto &output_type) {
320                         attr_info += TypeIdToString(output_type.object_type) + " " + TypeIdToString(output_type.dtype) +
321                                      " " + output_type.format + ",";
322                       });
323   attr_info += "]";
324   return attr_info;
325 }
326 
SetKernelObjectTypeBuildInfo(const AnfNodePtr & kernel_node,const std::vector<KernelObjectType> & input_kernel_object_types,const std::vector<KernelObjectType> & output_kernel_object_types)327 void SetKernelObjectTypeBuildInfo(const AnfNodePtr &kernel_node,
328                                   const std::vector<KernelObjectType> &input_kernel_object_types,
329                                   const std::vector<KernelObjectType> &output_kernel_object_types) {
330   MS_EXCEPTION_IF_NULL(kernel_node);
331   if (kernel_node->kernel_info() == nullptr) {
332     kernel_node->set_kernel_info(std::make_shared<device::KernelInfo>());
333   }
334   if (!kernel_node->kernel_info()->has_build_info()) {
335     AnfAlgo::SetSelectKernelBuildInfo(std::make_shared<kernel::KernelBuildInfo>(), kernel_node.get());
336   }
337 
338   MS_LOG(DEBUG) << kernel_node->fullname_with_scope() << " input kernel object type is: " << input_kernel_object_types
339                 << ", output kernel object type is: " << output_kernel_object_types;
340   auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
341   kernel_build_info->SetOutputsKernelObjectType(output_kernel_object_types);
342   kernel_build_info->SetInputsKernelObjectType(input_kernel_object_types);
343 }
344 
SetKernelObjectTypeBuildInfo(const AnfNodePtr & kernel_node,const std::vector<KernelObjectType> & input_kernel_object_types,const std::vector<KernelObjectType> & output_kernel_object_types,const std::vector<KernelObjectType> & output_elements_kernel_object_types)345 void SetKernelObjectTypeBuildInfo(const AnfNodePtr &kernel_node,
346                                   const std::vector<KernelObjectType> &input_kernel_object_types,
347                                   const std::vector<KernelObjectType> &output_kernel_object_types,
348                                   const std::vector<KernelObjectType> &output_elements_kernel_object_types) {
349   MS_EXCEPTION_IF_NULL(kernel_node);
350   if (kernel_node->kernel_info() == nullptr) {
351     kernel_node->set_kernel_info(std::make_shared<device::KernelInfo>());
352   }
353   if (!kernel_node->kernel_info()->has_build_info()) {
354     AnfAlgo::SetSelectKernelBuildInfo(std::make_shared<kernel::KernelBuildInfo>(), kernel_node.get());
355   }
356 
357   MS_LOG(DEBUG) << kernel_node->fullname_with_scope() << " input kernel object type is: " << input_kernel_object_types
358                 << ", output kernel object type is: " << output_kernel_object_types
359                 << ", output elements kernel object type is: " << output_elements_kernel_object_types;
360   auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
361   kernel_build_info->SetOutputsKernelObjectType(output_kernel_object_types);
362   kernel_build_info->SetInputsKernelObjectType(input_kernel_object_types);
363   kernel_build_info->SetOutputElementsKernelObjectType(output_elements_kernel_object_types);
364 }
365 
HasOutputElementsKernelObjectType(const std::vector<KernelObjectType> & output_kernel_object_types)366 bool HasOutputElementsKernelObjectType(const std::vector<KernelObjectType> &output_kernel_object_types) {
367   return output_kernel_object_types.size() == 1 &&
368          output_kernel_object_types[0] == kernel::KernelObjectType::TUPLE_UNFOLD;
369 }
370 
SetKernelObjectTypeWithSelectedAttr(const CNodePtr & kernel_node,const kernel::KernelAttr & selected_kernel_attr)371 void SetKernelObjectTypeWithSelectedAttr(const CNodePtr &kernel_node, const kernel::KernelAttr &selected_kernel_attr) {
372   MS_EXCEPTION_IF_NULL(kernel_node);
373   std::vector<KernelObjectType> input_kernel_object_types;
374   if (common::AnfAlgo::HasNodeAttr(kInputRealTuple, kernel_node)) {
375     input_kernel_object_types = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllInputObjectType(kernel_node));
376   } else {
377     input_kernel_object_types = CalInputKernelObjectTypes(kernel_node, selected_kernel_attr);
378   }
379 
380   std::vector<KernelObjectType> output_kernel_object_types;
381   if (common::AnfAlgo::HasNodeAttr(kOutputRealTuple, kernel_node)) {
382     output_kernel_object_types = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllOutputObjectType(kernel_node));
383   } else {
384     output_kernel_object_types = CalOutputKernelObjectTypes(kernel_node, selected_kernel_attr);
385   }
386 
387   std::vector<KernelObjectType> output_element_object_types;
388   if (HasOutputElementsKernelObjectType(output_kernel_object_types)) {
389     output_element_object_types = CalOutputElementObjectTypes(kernel_node, selected_kernel_attr);
390   }
391   MS_LOG(DEBUG) << "Set kernel object type:" << output_kernel_object_types
392                 << " for node:" << kernel_node->fullname_with_scope();
393   SetKernelObjectTypeBuildInfo(kernel_node, input_kernel_object_types, output_kernel_object_types,
394                                output_element_object_types);
395 }
396 
AddInputAttr(const TypeId & object_type,const TypeId & ms_type,const std::string & format)397 KernelAttr &KernelAttr::AddInputAttr(const TypeId &object_type, const TypeId &ms_type, const std::string &format) {
398   (void)input_type_.emplace_back(DataType(ms_type, format, object_type));
399   return *this;
400 }
401 
AddOptionalInputAttr(const TypeId & object_type,const TypeId & ms_type,const std::string & format)402 KernelAttr &KernelAttr::AddOptionalInputAttr(const TypeId &object_type, const TypeId &ms_type,
403                                              const std::string &format) {
404   (void)input_type_.emplace_back(DataType(ms_type, format, object_type, true));
405   return *this;
406 }
407 
AddOutputAttr(const TypeId & object_type,const TypeId & ms_type,const std::string & format)408 KernelAttr &KernelAttr::AddOutputAttr(const TypeId &object_type, const TypeId &ms_type, const std::string &format) {
409   (void)output_type_.emplace_back(DataType(ms_type, format, object_type));
410   return *this;
411 }
412 
AddInputAttr(const TypeId & ms_type,const std::string & format)413 KernelAttr &KernelAttr::AddInputAttr(const TypeId &ms_type, const std::string &format) {
414   (void)input_type_.emplace_back(DataType(ms_type, format));
415   return *this;
416 }
417 
AddOptionalInputAttr(const TypeId & ms_type,const std::string & format)418 KernelAttr &KernelAttr::AddOptionalInputAttr(const TypeId &ms_type, const std::string &format) {
419   (void)input_type_.emplace_back(DataType(ms_type, format, kObjectTypeTensorType, true));
420   return *this;
421 }
422 
AddOutputAttr(const TypeId & ms_type,const std::string & format)423 KernelAttr &KernelAttr::AddOutputAttr(const TypeId &ms_type, const std::string &format) {
424   (void)output_type_.emplace_back(DataType(ms_type, format));
425   return *this;
426 }
427 
AddAllSameAttr(bool all_same,size_t all_same_input_num,bool group_allsame)428 KernelAttr &KernelAttr::AddAllSameAttr(bool all_same, size_t all_same_input_num, bool group_allsame) {
429   all_same_ = all_same;
430   is_group_allsame_ = group_allsame;
431   if (all_same_input_num < 1) {
432     MS_LOG(EXCEPTION) << "Allsame attr must >= 1, but get " << all_same_input_num;
433   }
434   all_same_input_num_ = all_same_input_num;
435   return *this;
436 }
437 
AddSkipCheckAttr(bool skip_check)438 KernelAttr &KernelAttr::AddSkipCheckAttr(bool skip_check) {
439   skip_check_ = skip_check;
440   return *this;
441 }
442 
AddRealTuple(const bool & is_real_tuple)443 KernelAttr &KernelAttr::AddRealTuple(const bool &is_real_tuple) {
444   is_real_tuple_ = is_real_tuple;
445   return *this;
446 }
447 
AddOutInRef(size_t output_index,size_t input_index)448 KernelAttr &KernelAttr::AddOutInRef(size_t output_index, size_t input_index) {
449   out_in_ref_map_[output_index] = input_index;
450   return *this;
451 }
452 
AddAllOutInRef(bool all_out_in_ref)453 KernelAttr &KernelAttr::AddAllOutInRef(bool all_out_in_ref) {
454   all_out_in_ref_ = all_out_in_ref;
455   return *this;
456 }
457 
SetInputAttr(const size_t index,const TypeId & ms_type,const std::string & format)458 void KernelAttr::SetInputAttr(const size_t index, const TypeId &ms_type, const std::string &format) {
459   if (index >= input_type_.size()) {
460     MS_LOG(EXCEPTION) << "Invalid index for input: " << index << ", out of range.";
461   }
462   input_type_[index] = DataType(ms_type, format);
463 }
464 
SetOutputAttr(const size_t index,const TypeId & ms_type,const std::string & format)465 void KernelAttr::SetOutputAttr(const size_t index, const TypeId &ms_type, const std::string &format) {
466   if (index >= output_type_.size()) {
467     MS_LOG(EXCEPTION) << "Invalid index for output: " << index << ", out of range.";
468   }
469   output_type_[index] = DataType(ms_type, format);
470 }
471 
SetInputAttrList(const std::vector<DataType> & addr_list)472 void KernelAttr::SetInputAttrList(const std::vector<DataType> &addr_list) {
473   input_type_.assign(addr_list.begin(), addr_list.end());
474 }
475 
SetOutputAttrList(const std::vector<DataType> & addr_list)476 void KernelAttr::SetOutputAttrList(const std::vector<DataType> &addr_list) {
477   output_type_.assign(addr_list.begin(), addr_list.end());
478 }
479 
operator <<(std::ostream & os,KernelAttr kernel_attr)480 std::ostream &operator<<(std::ostream &os, KernelAttr kernel_attr) {
481   std::stringstream ss;
482   ss << "[Kernel Attr] all same: " << kernel_attr.GetAllSame();
483   if (kernel_attr.GetSkipCheck()) {
484     ss << ", skip check: true";
485   }
486   size_t input_num = kernel_attr.GetInputSize();
487   if (input_num > 0) {
488     ss << ", input(";
489     for (size_t i = 0; i < input_num; ++i) {
490       ss << TypeIdLabel(kernel_attr.GetInputAttr(i).dtype);
491       if (kernel_attr.GetInputAttr(i).is_optional) {
492         ss << "|None";
493       }
494       if (i != input_num - 1) {
495         ss << ",";
496       }
497     }
498     ss << ") ";
499   }
500   size_t output_num = kernel_attr.GetOutputSize();
501   if (output_num > 0) {
502     ss << ", output(";
503     for (size_t i = 0; i < output_num; ++i) {
504       ss << TypeIdLabel(kernel_attr.GetOutputAttr(i).dtype);
505       if (i != output_num - 1) {
506         ss << ",";
507       }
508     }
509     ss << ").";
510   }
511 
512   return os << ss.str();
513 }
514 
MatchMultiDynamicKernelAttr(const KernelAttr & kernel_attr,const std::vector<int64_t> & dyn_input_sizes,const std::vector<KernelAttr> & kernel_attr_list)515 std::pair<bool, size_t> MatchMultiDynamicKernelAttr(const KernelAttr &kernel_attr,
516                                                     const std::vector<int64_t> &dyn_input_sizes,
517                                                     const std::vector<KernelAttr> &kernel_attr_list) {
518   auto output_num = kernel_attr.GetOutputSize();
519   for (size_t index = 0; index < kernel_attr_list.size(); ++index) {
520     // support multi dynamic inputs.
521     const auto &cur_kernel_attr = kernel_attr_list[index];
522     auto cur_input_num = cur_kernel_attr.GetInputSize();
523     if (dyn_input_sizes.size() != cur_input_num) {
524       MS_LOG(EXCEPTION) << "Kernel attr's input num: " << cur_input_num
525                         << ", is not equal to dynamic input size: " << dyn_input_sizes.size();
526     }
527     bool mis_match = false;
528     size_t input_index = kInputFirstIndex;
529     for (size_t i = 0; i < cur_input_num; ++i) {
530       int64_t dyn_input_size = dyn_input_sizes[i];
531       if (dyn_input_size < 0) {
532         dyn_input_size = 1;
533       }
534       auto dtype = cur_kernel_attr.GetInputAttr(i).dtype;
535       for (size_t j = 0; j < LongToSize(dyn_input_size); ++j) {
536         if (kernel_attr.GetInputAttr(input_index).dtype != dtype) {
537           mis_match = true;
538           break;
539         }
540         ++input_index;
541       }
542       if (mis_match) {
543         break;
544       }
545     }
546     if (mis_match) {
547       continue;
548     }
549 
550     // only support one dynamic output. TODO: support multi dynamic output.
551     for (size_t i = 0; i < output_num; ++i) {
552       auto dtype = cur_kernel_attr.GetOutputAttr(i).dtype;
553       if (kernel_attr.GetInputAttr(i).dtype != dtype) {
554         mis_match = true;
555         break;
556       }
557     }
558     if (!mis_match) {
559       return std::make_pair(true, index);
560     }
561   }
562   return std::make_pair(false, 0);
563 }
564 
// Check the node's input dtypes against an (possibly allsame) candidate attr.
// Returns true on MISMATCH, false when every input is compatible.
// Layout: the candidate declares cur_all_same_input_num "allsame" inputs, each of
// which repeats each_attr_input_num times in the node's flattened input list; the
// remaining cur_standalone_input_num inputs match one-to-one after the repeated part.
// kTypeUnknown in input_types acts as a wildcard throughout.
bool CheckAttrForAllSameInput(const size_t input_num, const std::vector<mindspore::TypeId> &input_types,
                              const KernelAttr &cur_kernel_attr) {
  auto cur_input_num = cur_kernel_attr.GetInputSize();
  bool is_group_allsame = cur_kernel_attr.GetGroupAllSame();
  size_t cur_all_same_input_num = cur_kernel_attr.GetAllSameInputNum();  // default 0; else >=1 when allsame=true
  size_t cur_standalone_input_num = cur_input_num - cur_all_same_input_num;
  // Repetition count per allsame input (guard the divisor when there are none).
  size_t each_attr_input_num =
    (input_num - cur_standalone_input_num) / (cur_all_same_input_num == 0 ? 1 : cur_all_same_input_num);
  // deal with allsame inputs
  if (is_group_allsame) {
    // Grouped layout: the allsame inputs repeat as whole groups
    // (a0,a1,...,a0,a1,...), so position = j + i * group_size.
    for (size_t i = 0; i < each_attr_input_num; ++i) {
      for (size_t j = 0; j < cur_all_same_input_num; ++j) {
        auto dtype = cur_kernel_attr.GetInputAttr(j).dtype;
        auto start = j + i * cur_all_same_input_num;
        if (input_types[start] != dtype && input_types[start] != kTypeUnknown) {
          return true;
        }
      }
    }
  } else {
    // Interleaved layout: each allsame input repeats contiguously
    // (a0,a0,...,a1,a1,...), so position = j + i * repeat_count.
    for (size_t i = 0; i < cur_all_same_input_num; ++i) {
      for (size_t j = 0; j < each_attr_input_num; ++j) {
        auto dtype = cur_kernel_attr.GetInputAttr(i).dtype;
        auto start = j + i * each_attr_input_num;
        if (input_types[start] != dtype && input_types[start] != kTypeUnknown) {
          return true;
        }
      }
    }
  }

  // deal with the rest except allsame inputs
  // NOTE(review): the standalone scan starts at attr index cur_all_same_input_num,
  // which assumes allsame inputs are declared first in the attr — confirm with callers.
  for (size_t i = cur_all_same_input_num; i < cur_standalone_input_num; ++i) {
    auto dtype = cur_kernel_attr.GetInputAttr(i).dtype;
    auto start = each_attr_input_num * cur_all_same_input_num + i;
    // An optional input is allowed to be absent (kMetaTypeNone).
    if (!(cur_kernel_attr.GetInputAttr(i).is_optional && input_types[start] == kMetaTypeNone) &&
        (input_types[start] != dtype && input_types[start] != kTypeUnknown)) {
      return true;
    }
  }
  return false;
}
607 
MatchKernelAttr(const KernelAttr & kernel_attr,const std::vector<KernelAttr> & kernel_attr_list)608 std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr,
609                                         const std::vector<KernelAttr> &kernel_attr_list) {
610   // kernel_attr should not be all same. If so, then return false.
611   if (kernel_attr.GetAllSame()) {
612     return std::make_pair(false, 0);
613   }
614   auto input_num = kernel_attr.GetInputSize();
615   auto output_num = kernel_attr.GetOutputSize();
616 
617   for (size_t index = 0; index < kernel_attr_list.size(); ++index) {
618     const auto &cur_kernel_attr = kernel_attr_list[index];
619     auto cur_input_num = cur_kernel_attr.GetInputSize();
620     auto cur_output_num = cur_kernel_attr.GetOutputSize();
621     if (!cur_kernel_attr.GetAllSame() && (input_num != cur_input_num || output_num != cur_output_num)) {
622       continue;
623     }
624     std::vector<mindspore::TypeId> input_types;
625     (void)std::transform(kernel_attr.input_type().begin(), kernel_attr.input_type().end(),
626                          std::back_inserter(input_types), [](const DataType &Dtype) { return Dtype.dtype; });
627 
628     bool mis_match = CheckAttrForAllSameInput(input_num, input_types, cur_kernel_attr);
629     if (mis_match) {
630       continue;
631     }
632 
633     for (size_t i = 0; i < output_num; ++i) {
634       auto dtype = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).dtype;
635       if (kernel_attr.GetOutputAttr(i).dtype != dtype && kernel_attr.GetOutputAttr(i).dtype != kTypeUnknown) {
636         mis_match = true;
637         break;
638       }
639     }
640     if (!mis_match) {
641       return std::make_pair(true, index);
642     }
643   }
644 
645   return std::make_pair(false, 0);
646 }
647 
MatchKernelAttrStrict(const KernelAttr & kernel_attr,const std::vector<KernelAttr> & kernel_attr_list)648 std::pair<bool, size_t> MatchKernelAttrStrict(const KernelAttr &kernel_attr,
649                                               const std::vector<KernelAttr> &kernel_attr_list) {
650   auto input_num = kernel_attr.GetInputSize();
651   auto output_num = kernel_attr.GetOutputSize();
652   auto AttrMatched = [](const DataType &attr, const DataType &compared_attr) {
653     return (attr.dtype != compared_attr.dtype && attr.dtype != kTypeUnknown) ||
654            (!AnfAlgo::IsEquivalentFormat(attr.format, compared_attr.format)) ||
655            (attr.object_type != compared_attr.object_type && attr.object_type != kTypeUnknown);
656   };
657   for (size_t index = 0; index < kernel_attr_list.size(); ++index) {
658     const auto &cur_kernel_attr = kernel_attr_list[index];
659     // Attr skip indicates that any attr is supported.
660     if (cur_kernel_attr.GetSkipCheck()) {
661       return std::make_pair(true, index);
662     }
663     auto cur_input_num = cur_kernel_attr.GetInputSize();
664     auto cur_output_num = cur_kernel_attr.GetOutputSize();
665     // The num must be equal when not all same.
666     if (!cur_kernel_attr.GetAllSame() && (input_num != cur_input_num || output_num != cur_output_num)) {
667       continue;
668     }
669 
670     bool mis_match = false;
671     // Check the input attrs.
672     for (size_t i = 0; i < cur_input_num; ++i) {
673       MS_EXCEPTION_IF_CHECK_FAIL((kernel_attr.GetInputSize() > i), "The input num is out of range.");
674       auto &input_attr = kernel_attr.GetInputAttr(i);
675       auto &cur_input_attr = cur_kernel_attr.GetInputAttr(i);
676       if (AttrMatched(input_attr, cur_input_attr)) {
677         mis_match = true;
678         break;
679       }
680     }
681 
682     if (mis_match) {
683       continue;
684     }
685 
686     // Check the output attrs.
687     for (size_t i = 0; i < cur_output_num; ++i) {
688       MS_EXCEPTION_IF_CHECK_FAIL((kernel_attr.GetOutputSize() > i), "The output num is out of range.");
689       auto &output_attr = kernel_attr.GetOutputAttr(i);
690       auto &cur_output_attr = cur_kernel_attr.GetOutputAttr(i);
691       if (AttrMatched(output_attr, cur_output_attr)) {
692         mis_match = true;
693         break;
694       }
695     }
696 
697     if (!mis_match) {
698       return std::make_pair(true, index);
699     }
700   }
701 
702   return std::make_pair(false, 0);
703 }
704 
IsFoldKernelBuildInfo(const KernelBuildInfoPtr & kernel_build_info)705 bool IsFoldKernelBuildInfo(const KernelBuildInfoPtr &kernel_build_info) {
706   MS_EXCEPTION_IF_NULL(kernel_build_info);
707   auto inputs_object_type = kernel_build_info->GetAllInputKernelObjectTypes();
708   if (std::find(inputs_object_type.begin(), inputs_object_type.end(), KernelObjectType::TUPLE) !=
709       inputs_object_type.end()) {
710     return true;
711   }
712 
713   auto outputs_object_type = kernel_build_info->GetAllOutputKernelObjectTypes();
714   if (std::find(outputs_object_type.begin(), outputs_object_type.end(), KernelObjectType::TUPLE) !=
715       outputs_object_type.end()) {
716     return true;
717   }
718 
719   return false;
720 }
721 
GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr & build_info)722 KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info) {
723   MS_EXCEPTION_IF_NULL(build_info);
724   KernelAttr kernel_attr;
725   for (size_t i = 0; i < build_info->GetInputNum(); ++i) {
726     (void)kernel_attr.AddInputAttr(KernelObjectTypeToTypeId(build_info->GetInputKernelObjectType(i)),
727                                    build_info->GetInputDeviceType(i), build_info->GetInputFormat(i));
728   }
729   for (size_t j = 0; j < build_info->GetOutputNum(); ++j) {
730     (void)kernel_attr.AddOutputAttr(KernelObjectTypeToTypeId(build_info->GetOutputKernelObjectType(j)),
731                                     build_info->GetOutputDeviceType(j), build_info->GetOutputFormat(j));
732   }
733   return kernel_attr;
734 }
735 
GetKernelAttrFromNode(const AnfNodePtr & kernel_node)736 KernelAttr GetKernelAttrFromNode(const AnfNodePtr &kernel_node) {
737   MS_EXCEPTION_IF_NULL(kernel_node);
738   auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
739   return GetKernelAttrFromBuildInfo(build_info);
740 }
741 
GetKernelAttrFromTensors(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)742 KernelAttr GetKernelAttrFromTensors(const std::vector<KernelTensor *> &inputs,
743                                     const std::vector<KernelTensor *> &outputs) {
744   KernelAttr kernel_attr;
745   for (auto tensor : inputs) {
746     (void)kernel_attr.AddInputAttr(tensor->dtype_id(), GetFormatFromEnumToStr(tensor->format()));
747   }
748   for (auto tensor : outputs) {
749     (void)kernel_attr.AddOutputAttr(tensor->dtype_id(), GetFormatFromEnumToStr(tensor->format()));
750   }
751   return kernel_attr;
752 }
753 
// Selects the kernel attr (from apply_kernel_attrs) that matches the node's
// selected build info, and copies that attr's out-in ref map into the node's
// KernelInfo. Throws if no candidate attr matches the build info.
// No-op when the candidate list is empty.
void SetCpuRefMapToKernelInfo(const CNodePtr &apply_kernel, const std::vector<KernelAttr> &apply_kernel_attrs) {
  MS_EXCEPTION_IF_NULL(apply_kernel);
  // Local copy: slot 0 may be overwritten below in the skip-check branch.
  auto kernel_attrs = apply_kernel_attrs;
  if (kernel_attrs.empty()) {
    return;
  }

  auto build_info = AnfAlgo::GetSelectKernelBuildInfo(apply_kernel);
  MS_EXCEPTION_IF_NULL(build_info);
  auto kernel_attr = GetKernelAttrFromBuildInfo(build_info);
  // Dynamic input sizes (if the node has the attr) select the multi-dynamic matcher.
  std::vector<int64_t> dyn_input_sizes = {};
  if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, apply_kernel)) {
    dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(apply_kernel, kAttrDynInputSizes);
  }
  std::pair<bool, int64_t> match_result;

  if (kernel_attrs[0].GetSkipCheck()) {
    // If kernel skips attr check, we need to synchronize the ref map in case it's discarded.
    SyncOutInRef(kernel_attrs[0], &kernel_attr);
    kernel_attrs[0] = kernel_attr;
    match_result = {true, 0};
  } else if (dyn_input_sizes.empty() || kernel_attrs[0].GetAllSame()) {
    match_result = MatchKernelAttr(kernel_attr, kernel_attrs);
  } else {
    match_result = MatchMultiDynamicKernelAttr(kernel_attr, dyn_input_sizes, kernel_attrs);
  }

  auto [is_match, index] = match_result;
  if (!is_match) {
    // No candidate supports the selected dtypes/formats: fail loudly with node context.
    constexpr auto recursive_level = 2;
    MS_LOG(EXCEPTION) << apply_kernel->fullname_with_scope()
                      << " does not support this kernel data type: " << build_info->ToString()
                      << ", node debug name: " << apply_kernel->DebugString(recursive_level);
  }

  auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  // Only write a ref map when the matched attr actually declares one.
  const auto &matched_kernel_attr = kernel_attrs[index];
  if (!matched_kernel_attr.GetOutInRefMap().empty() || matched_kernel_attr.GetAllOutInRef()) {
    kernel_info->set_ref_map(matched_kernel_attr.GetAllOutInRef(), matched_kernel_attr.GetOutInRefMap());
  }
}
796 
SyncOutInRef(const KernelAttr & from_kernel_attr,KernelAttr * to_kernel_attr)797 void SyncOutInRef(const KernelAttr &from_kernel_attr, KernelAttr *to_kernel_attr) {
798   const auto &out_in_ref = from_kernel_attr.GetOutInRefMap();
799   bool all_out_in_ref = from_kernel_attr.GetAllOutInRef();
800   for (const auto &ref : out_in_ref) {
801     (void)to_kernel_attr->AddOutInRef(ref.first, ref.second);
802   }
803   (void)to_kernel_attr->AddAllOutInRef(all_out_in_ref);
804 }
805 
806 namespace math {
void SinCosf(float x, float *sinv, float *cosv) {
  // Single-precision sine and cosine of the same angle, written through the out params.
  // std::sin/std::cos float overloads are exactly sinf/cosf.
  *sinv = std::sin(x);
  *cosv = std::cos(x);
}
811 }  // namespace math
812 }  // namespace kernel
813 }  // namespace mindspore
814