/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/common_utils.h"
#include <algorithm>
#include <bitset>
#include <cmath>
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "ir/graph_utils.h"
#include "kernel/oplib/oplib.h"
#include "kernel/format_utils.h"
#include "mindapi/base/type_id.h"
#include "mindspore/ccsrc/include/common/debug/common.h"
#include "nlohmann/json.hpp"
#include "ops/array_op_name.h"
#include "ops/conv_pool_op_name.h"
#include "ops/framework_ops.h"
#include "ops/math_op_name.h"
#include "ops/nn_ops.h"
#include "ops/sequence_ops.h"
#include "utils/anf_utils.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr char kTypeInt32[] = "Int32";
constexpr auto kQuad = 4;
constexpr size_t kInputFirstIndex = 0;
}  // namespace

size_t GetOutputNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &type = node->Type();
  if (type == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to get type in node:" << node->fullname_with_scope();
  } else if (type->isa<Tuple>()) {
    auto tuple_type = type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    if (tuple_type->dynamic_len()) {
      return 1;
    }
    const auto &sub_types = tuple_type->elements();
    return static_cast<size_t>(std::count_if(sub_types.begin(), sub_types.end(), [](const TypePtr &sub_type) {
      return sub_type != nullptr && (!sub_type->isa<MonadType>());
    }));
  } else if (type->isa<List>()) {
    auto list_type = type->cast<ListPtr>();
    MS_EXCEPTION_IF_NULL(list_type);
    if (list_type->dynamic_len()) {
      return 1;
    }
    const auto &sub_types = list_type->elements();
    return static_cast<size_t>(std::count_if(sub_types.begin(), sub_types.end(), [](const TypePtr &sub_type) {
      return sub_type != nullptr && (!sub_type->isa<MonadType>());
    }));
  } else if (type->isa<CSRTensorType>()) {
    return 5;
  } else if (type->isa<COOTensorType>()) {
    return 4;
  }
  return 1;
}
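// Example: for a node whose Type() is Tuple[Tensor, Tensor, UMonad], the monad element is
// excluded from the count and GetOutputNum returns 2; a dynamic-length tuple or list always
// counts as a single output.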

int CalDiagOffset(int diag_index, int max_diag_len, int inner_rows, int inner_cols,
                  const std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> &alignment) {
  bool right_align_super_diagonal = (alignment.first == MatrixDiag::RIGHT);
  bool right_align_sub_diagonal = (alignment.second == MatrixDiag::RIGHT);
  const bool right_align =
    (diag_index >= 0 && right_align_super_diagonal) || (diag_index <= 0 && right_align_sub_diagonal);
  const int diag_len = std::min(inner_rows + std::min(0, diag_index), inner_cols - std::max(0, diag_index));
  const int offset = (right_align) ? (max_diag_len - diag_len) : 0;
  return offset;
}
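// Worked example: with inner_rows = 3, inner_cols = 4, diag_index = 2 and max_diag_len = 3,
// the diagonal length is min(3, 4 - 2) = 2; if super-diagonals are RIGHT aligned the returned
// offset is 3 - 2 = 1, otherwise it is 0.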

TypeId DtypeToTypeId(const std::string &dtypes) {
  if (dtypes == "float") {
    return TypeId::kNumberTypeFloat32;
  }
  if (dtypes.empty()) {
    return TypeId::kMetaTypeNone;
  }
  return StringToTypeId(dtypes);
}

std::string Dtype2ShortType(const std::string &dtype) {
  static const std::unordered_map<std::string, std::string> dtype_shortdtype_map = {
    {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"},
    {"int32", "i32"}, {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"},
    {"uint64", "u64"}, {"bool", "bool"}, {"bfloat16", "bf16"}};

  auto iter = dtype_shortdtype_map.find(dtype);
  if (iter != dtype_shortdtype_map.end()) {
    return iter->second;
  } else {
    MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  }
}

size_t GetDtypeNbyte(const std::string &dtype) {
  static const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
    {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
    {"int8", sizeof(int) / kQuad}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
    {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / kQuad}, {"uint16", sizeof(int) / 2},
    {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
    {"complex64", sizeof(float) * 2}, {"bfloat16", sizeof(float) / 2}};

  auto iter = dtype_nbyte_map.find(dtype);
  if (iter != dtype_nbyte_map.end()) {
    return iter->second;
  } else {
    MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  }
}
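// Example (assuming a typical platform where sizeof(int) == sizeof(float) == 4):
//   GetDtypeNbyte("float16") == 2, GetDtypeNbyte("int64") == 8, GetDtypeNbyte("complex64") == 8.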

bool IsSameShape(const ShapeVector &shape_a, const ShapeVector &shape_b) { return shape_a == shape_b; }

bool CheckShapesSame(const ShapeArray &shape_array) {
  auto first_shape = shape_array[0];
  return std::all_of(shape_array.begin() + 1, shape_array.end(),
                     [&first_shape](const ShapeVector &shape) { return IsSameShape(shape, first_shape); });
}

std::string GetProcessorStr(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  std::string processor = kProcessorUnknown;
  auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  // This may be called before kernel selection, so the build info can still be null.
  if (build_info == nullptr) {
    return processor;
  }
  switch (build_info->processor()) {
    case Processor::AICORE:
      processor = kProcessorAiCore;
      break;

    case Processor::AICPU:
      processor = kProcessorAiCpu;
      break;

    case Processor::CUDA:
      processor = kProcessorCuda;
      break;

    case Processor::CPU:
      processor = kProcessorCpu;
      break;

    default:
      MS_LOG(DEBUG) << "Unknown processor type.";
      break;
  }

  return processor;
}

std::vector<TypeId> GetOutputObjectTypeListFromKernelAttr(const kernel::KernelAttr &kernel_attr) {
  size_t output_attr_size = kernel_attr.GetOutputSize();
  std::vector<TypeId> res;
  for (size_t i = 0; i < output_attr_size; ++i) {
    res.push_back(kernel_attr.GetOutputAttr(i).object_type);
  }
  return res;
}

std::vector<TypeId> GetInputObjectTypeListFromKernelAttr(const kernel::KernelAttr &kernel_attr) {
  size_t input_attr_size = kernel_attr.GetInputSize();
  std::vector<TypeId> res;
  for (size_t i = 0; i < input_attr_size; ++i) {
    res.push_back(kernel_attr.GetInputAttr(i).object_type);
  }
  return res;
}

KernelObjectType TypeIdToKernelObjectType(const TypeId &type_id) {
  std::unordered_map<TypeId, KernelObjectType> trans_map{{kObjectTypeTuple, KernelObjectType::TUPLE},
                                                         {kObjectTypeNumber, KernelObjectType::SCALAR},
                                                         {kObjectTypeTensorType, KernelObjectType::TENSOR}};
  if (trans_map.find(type_id) == trans_map.end()) {
    MS_LOG(DEBUG) << "Unsupported type id " << TypeIdToString(type_id)
                  << ", which cannot be converted to a corresponding kernel object type.";
    return KernelObjectType::UNKNOWN_TYPE;
  }
  return trans_map[type_id];
}

std::vector<KernelObjectType> TypeIdToKernelObjectType(const std::vector<TypeId> &type_ids) {
  std::vector<KernelObjectType> ret;
  (void)std::transform(type_ids.begin(), type_ids.end(), std::back_inserter(ret),
                       [](const TypeId &type_id) { return kernel::TypeIdToKernelObjectType(type_id); });
  return ret;
}

KernelObjectType TypeIdToKernelObjectTypeForTupleUnfold(const TypeId &type_id) {
  std::unordered_map<TypeId, KernelObjectType> trans_map{{kObjectTypeTuple, KernelObjectType::TUPLE_UNFOLD},
                                                         {kObjectTypeNumber, KernelObjectType::SCALAR},
                                                         {kObjectTypeTensorType, KernelObjectType::TENSOR}};
  if (trans_map.find(type_id) == trans_map.end()) {
    MS_LOG(DEBUG) << "Unsupported type id " << TypeIdToString(type_id)
                  << ", which cannot be converted to a corresponding kernel object type.";
    return KernelObjectType::UNKNOWN_TYPE;
  }
  return trans_map[type_id];
}

std::vector<KernelObjectType> TypeIdToKernelObjectTypeForTupleUnfold(const std::vector<TypeId> &type_ids) {
  std::vector<KernelObjectType> ret;
  (void)std::transform(type_ids.begin(), type_ids.end(), std::back_inserter(ret),
                       [](const TypeId &type_id) { return kernel::TypeIdToKernelObjectTypeForTupleUnfold(type_id); });
  return ret;
}

TypeId KernelObjectTypeToTypeId(const KernelObjectType &object_type) {
  std::unordered_map<KernelObjectType, TypeId> trans_map{{KernelObjectType::TUPLE, kObjectTypeTuple},
                                                         {KernelObjectType::TUPLE_UNFOLD, kObjectTypeTuple},
                                                         {KernelObjectType::SCALAR, kObjectTypeNumber},
                                                         {KernelObjectType::TENSOR, kObjectTypeTensorType}};
  if (trans_map.find(object_type) == trans_map.end()) {
    MS_LOG(DEBUG) << "Unsupported kernel object type " << object_type
                  << ", which cannot be converted to a corresponding type id.";
    return kTypeUnknown;
  }
  return trans_map[object_type];
}

// The allsame/skip_check cases and the unequal-size scenario do not support object type backoff and use object_types;
// all other scenarios support object type backoff and use selected_object_types.
std::vector<KernelObjectType> CalKernelObjectTypes(const std::vector<TypeId> &object_types,
                                                   const std::vector<TypeId> &selected_object_types, bool all_same,
                                                   bool skip_check) {
  std::vector<KernelObjectType> ret;
  // Use selected_object_types in the equal-size scenario.
  if (object_types.size() == selected_object_types.size()) {
    for (size_t i = 0; i < selected_object_types.size(); ++i) {
      // Allsame/skip_check does not support the backoff.
      bool not_backoff = ((all_same || skip_check) && (selected_object_types[i] != object_types[i]));
      if (not_backoff) {
        (void)ret.emplace_back(TypeIdToKernelObjectTypeForTupleUnfold(object_types[i]));
      } else {
        (void)ret.emplace_back(TypeIdToKernelObjectType(selected_object_types[i]));
      }
    }
    return ret;
  }

  // Use object_types in the unequal-size scenario, and convert tuple to TupleUnfold.
  for (size_t i = 0; i < object_types.size(); ++i) {
    (void)ret.emplace_back(TypeIdToKernelObjectTypeForTupleUnfold(object_types[i]));
  }
  return ret;
}
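// Example of the backoff rule: if a node output has object type kObjectTypeTuple while the
// selected kernel attr declares kObjectTypeTensorType (equal sizes, all_same and skip_check
// both false), the tuple is backed off to KernelObjectType::TENSOR; with all_same or
// skip_check set, the original tuple type is kept and mapped to TUPLE_UNFOLD instead.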

std::vector<KernelObjectType> CalInputKernelObjectTypes(const AnfNodePtr &kernel_node,
                                                        const kernel::KernelAttr &selected_kernel_attr) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto selected_input_object_types = GetInputObjectTypeListFromKernelAttr(selected_kernel_attr);
  auto input_object_types = AnfAlgo::GetAllInputObjectType(kernel_node);
  return CalKernelObjectTypes(input_object_types, selected_input_object_types, selected_kernel_attr.GetAllSame(),
                              selected_kernel_attr.GetSkipCheck());
}

std::vector<KernelObjectType> CalOutputKernelObjectTypes(const AnfNodePtr &kernel_node,
                                                         const kernel::KernelAttr &selected_kernel_attr) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto selected_output_object_types = GetOutputObjectTypeListFromKernelAttr(selected_kernel_attr);
  auto output_object_types = AnfAlgo::GetAllOutputObjectType(kernel_node);
  return CalKernelObjectTypes(output_object_types, selected_output_object_types, selected_kernel_attr.GetAllSame(),
                              selected_kernel_attr.GetSkipCheck());
}

std::vector<KernelObjectType> CalOutputElementObjectTypes(const AnfNodePtr &kernel_node,
                                                          const kernel::KernelAttr &selected_kernel_attr) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto selected_output_object_types = GetOutputObjectTypeListFromKernelAttr(selected_kernel_attr);
  MS_LOG(DEBUG) << "Output object type:" << selected_output_object_types << " for node:" << kernel_node->DebugString()
                << " select attr:" << kernel::FetchPrintInfoByKernelAttr(selected_kernel_attr);
  auto element_num = GetOutputNum(kernel_node);
  if (selected_kernel_attr.GetAllSame() && selected_output_object_types.size() == 1) {
    return std::vector<KernelObjectType>(element_num, TypeIdToKernelObjectType(selected_output_object_types[0]));
  }
  MS_EXCEPTION_IF_CHECK_FAIL(element_num == selected_output_object_types.size(),
                             "Check multi-output kernel attr size failed.");
  return TypeIdToKernelObjectType(selected_output_object_types);
}

std::string FetchPrintInfoByKernelAttr(KernelAttr selected_kernel_attr) {
  std::string attr_info = "input[";
  (void)std::for_each(std::begin(selected_kernel_attr.input_type()), std::end(selected_kernel_attr.input_type()),
                      [&attr_info](auto &input_type) {
                        attr_info += TypeIdToString(input_type.object_type) + " " + TypeIdToString(input_type.dtype) +
                                     " " + input_type.format + ",";
                      });
  attr_info += "] output[";
  (void)std::for_each(std::begin(selected_kernel_attr.output_type()), std::end(selected_kernel_attr.output_type()),
                      [&attr_info](auto &output_type) {
                        attr_info += TypeIdToString(output_type.object_type) + " " + TypeIdToString(output_type.dtype) +
                                     " " + output_type.format + ",";
                      });
  attr_info += "]";
  return attr_info;
}

void SetKernelObjectTypeBuildInfo(const AnfNodePtr &kernel_node,
                                  const std::vector<KernelObjectType> &input_kernel_object_types,
                                  const std::vector<KernelObjectType> &output_kernel_object_types) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  if (kernel_node->kernel_info() == nullptr) {
    kernel_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  }
  if (!kernel_node->kernel_info()->has_build_info()) {
    AnfAlgo::SetSelectKernelBuildInfo(std::make_shared<kernel::KernelBuildInfo>(), kernel_node.get());
  }

  MS_LOG(DEBUG) << kernel_node->fullname_with_scope() << " input kernel object type is: " << input_kernel_object_types
                << ", output kernel object type is: " << output_kernel_object_types;
  auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
  kernel_build_info->SetOutputsKernelObjectType(output_kernel_object_types);
  kernel_build_info->SetInputsKernelObjectType(input_kernel_object_types);
}

void SetKernelObjectTypeBuildInfo(const AnfNodePtr &kernel_node,
                                  const std::vector<KernelObjectType> &input_kernel_object_types,
                                  const std::vector<KernelObjectType> &output_kernel_object_types,
                                  const std::vector<KernelObjectType> &output_elements_kernel_object_types) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  if (kernel_node->kernel_info() == nullptr) {
    kernel_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  }
  if (!kernel_node->kernel_info()->has_build_info()) {
    AnfAlgo::SetSelectKernelBuildInfo(std::make_shared<kernel::KernelBuildInfo>(), kernel_node.get());
  }

  MS_LOG(DEBUG) << kernel_node->fullname_with_scope() << " input kernel object type is: " << input_kernel_object_types
                << ", output kernel object type is: " << output_kernel_object_types
                << ", output elements kernel object type is: " << output_elements_kernel_object_types;
  auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
  kernel_build_info->SetOutputsKernelObjectType(output_kernel_object_types);
  kernel_build_info->SetInputsKernelObjectType(input_kernel_object_types);
  kernel_build_info->SetOutputElementsKernelObjectType(output_elements_kernel_object_types);
}

bool HasOutputElementsKernelObjectType(const std::vector<KernelObjectType> &output_kernel_object_types) {
  return output_kernel_object_types.size() == 1 &&
         output_kernel_object_types[0] == kernel::KernelObjectType::TUPLE_UNFOLD;
}

void SetKernelObjectTypeWithSelectedAttr(const CNodePtr &kernel_node, const kernel::KernelAttr &selected_kernel_attr) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  std::vector<KernelObjectType> input_kernel_object_types;
  if (common::AnfAlgo::HasNodeAttr(kInputRealTuple, kernel_node)) {
    input_kernel_object_types = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllInputObjectType(kernel_node));
  } else {
    input_kernel_object_types = CalInputKernelObjectTypes(kernel_node, selected_kernel_attr);
  }

  std::vector<KernelObjectType> output_kernel_object_types;
  if (common::AnfAlgo::HasNodeAttr(kOutputRealTuple, kernel_node)) {
    output_kernel_object_types = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllOutputObjectType(kernel_node));
  } else {
    output_kernel_object_types = CalOutputKernelObjectTypes(kernel_node, selected_kernel_attr);
  }

  std::vector<KernelObjectType> output_element_object_types;
  if (HasOutputElementsKernelObjectType(output_kernel_object_types)) {
    output_element_object_types = CalOutputElementObjectTypes(kernel_node, selected_kernel_attr);
  }
  MS_LOG(DEBUG) << "Set kernel object type:" << output_kernel_object_types
                << " for node:" << kernel_node->fullname_with_scope();
  SetKernelObjectTypeBuildInfo(kernel_node, input_kernel_object_types, output_kernel_object_types,
                               output_element_object_types);
}

KernelAttr &KernelAttr::AddInputAttr(const TypeId &object_type, const TypeId &ms_type, const std::string &format) {
  (void)input_type_.emplace_back(DataType(ms_type, format, object_type));
  return *this;
}

KernelAttr &KernelAttr::AddOptionalInputAttr(const TypeId &object_type, const TypeId &ms_type,
                                             const std::string &format) {
  (void)input_type_.emplace_back(DataType(ms_type, format, object_type, true));
  return *this;
}

KernelAttr &KernelAttr::AddOutputAttr(const TypeId &object_type, const TypeId &ms_type, const std::string &format) {
  (void)output_type_.emplace_back(DataType(ms_type, format, object_type));
  return *this;
}

KernelAttr &KernelAttr::AddInputAttr(const TypeId &ms_type, const std::string &format) {
  (void)input_type_.emplace_back(DataType(ms_type, format));
  return *this;
}

KernelAttr &KernelAttr::AddOptionalInputAttr(const TypeId &ms_type, const std::string &format) {
  (void)input_type_.emplace_back(DataType(ms_type, format, kObjectTypeTensorType, true));
  return *this;
}

KernelAttr &KernelAttr::AddOutputAttr(const TypeId &ms_type, const std::string &format) {
  (void)output_type_.emplace_back(DataType(ms_type, format));
  return *this;
}

KernelAttr &KernelAttr::AddAllSameAttr(bool all_same, size_t all_same_input_num, bool group_allsame) {
  all_same_ = all_same;
  is_group_allsame_ = group_allsame;
  if (all_same_input_num < 1) {
    MS_LOG(EXCEPTION) << "The allsame input num must be >= 1, but got " << all_same_input_num;
  }
  all_same_input_num_ = all_same_input_num;
  return *this;
}

KernelAttr &KernelAttr::AddSkipCheckAttr(bool skip_check) {
  skip_check_ = skip_check;
  return *this;
}

KernelAttr &KernelAttr::AddRealTuple(const bool &is_real_tuple) {
  is_real_tuple_ = is_real_tuple;
  return *this;
}

KernelAttr &KernelAttr::AddOutInRef(size_t output_index, size_t input_index) {
  out_in_ref_map_[output_index] = input_index;
  return *this;
}

KernelAttr &KernelAttr::AddAllOutInRef(bool all_out_in_ref) {
  all_out_in_ref_ = all_out_in_ref;
  return *this;
}
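// Minimal usage sketch of the builders above (hypothetical kernel registration; the format
// argument is assumed to take its default value from the header declaration):
//   KernelAttr attr = KernelAttr()
//                       .AddInputAttr(kNumberTypeFloat32)
//                       .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
//                       .AddOutputAttr(kNumberTypeFloat32)
//                       .AddOutInRef(0, 0);  // output 0 reuses the memory of input 0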

void KernelAttr::SetInputAttr(const size_t index, const TypeId &ms_type, const std::string &format) {
  if (index >= input_type_.size()) {
    MS_LOG(EXCEPTION) << "Invalid index for input: " << index << ", out of range.";
  }
  input_type_[index] = DataType(ms_type, format);
}

void KernelAttr::SetOutputAttr(const size_t index, const TypeId &ms_type, const std::string &format) {
  if (index >= output_type_.size()) {
    MS_LOG(EXCEPTION) << "Invalid index for output: " << index << ", out of range.";
  }
  output_type_[index] = DataType(ms_type, format);
}

void KernelAttr::SetInputAttrList(const std::vector<DataType> &addr_list) {
  input_type_.assign(addr_list.begin(), addr_list.end());
}

void KernelAttr::SetOutputAttrList(const std::vector<DataType> &addr_list) {
  output_type_.assign(addr_list.begin(), addr_list.end());
}

std::ostream &operator<<(std::ostream &os, KernelAttr kernel_attr) {
  std::stringstream ss;
  ss << "[Kernel Attr] all same: " << kernel_attr.GetAllSame();
  if (kernel_attr.GetSkipCheck()) {
    ss << ", skip check: true";
  }
  size_t input_num = kernel_attr.GetInputSize();
  if (input_num > 0) {
    ss << ", input(";
    for (size_t i = 0; i < input_num; ++i) {
      ss << TypeIdLabel(kernel_attr.GetInputAttr(i).dtype);
      if (kernel_attr.GetInputAttr(i).is_optional) {
        ss << "|None";
      }
      if (i != input_num - 1) {
        ss << ",";
      }
    }
    ss << ") ";
  }
  size_t output_num = kernel_attr.GetOutputSize();
  if (output_num > 0) {
    ss << ", output(";
    for (size_t i = 0; i < output_num; ++i) {
      ss << TypeIdLabel(kernel_attr.GetOutputAttr(i).dtype);
      if (i != output_num - 1) {
        ss << ",";
      }
    }
    ss << ").";
  }

  return os << ss.str();
}
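// For the sketch above, streaming the attr would print something along the lines of
// "[Kernel Attr] all same: 0, input(Float32,Int64) , output(Float32)." (the exact dtype
// labels come from TypeIdLabel).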

std::pair<bool, size_t> MatchMultiDynamicKernelAttr(const KernelAttr &kernel_attr,
                                                    const std::vector<int64_t> &dyn_input_sizes,
                                                    const std::vector<KernelAttr> &kernel_attr_list) {
  auto output_num = kernel_attr.GetOutputSize();
  for (size_t index = 0; index < kernel_attr_list.size(); ++index) {
    // Support multiple dynamic inputs.
    const auto &cur_kernel_attr = kernel_attr_list[index];
    auto cur_input_num = cur_kernel_attr.GetInputSize();
    if (dyn_input_sizes.size() != cur_input_num) {
      MS_LOG(EXCEPTION) << "Kernel attr's input num: " << cur_input_num
                        << ", is not equal to dynamic input size: " << dyn_input_sizes.size();
    }
    bool mis_match = false;
    size_t input_index = kInputFirstIndex;
    for (size_t i = 0; i < cur_input_num; ++i) {
      int64_t dyn_input_size = dyn_input_sizes[i];
      if (dyn_input_size < 0) {
        dyn_input_size = 1;
      }
      auto dtype = cur_kernel_attr.GetInputAttr(i).dtype;
      for (size_t j = 0; j < LongToSize(dyn_input_size); ++j) {
        if (kernel_attr.GetInputAttr(input_index).dtype != dtype) {
          mis_match = true;
          break;
        }
        ++input_index;
      }
      if (mis_match) {
        break;
      }
    }
    if (mis_match) {
      continue;
    }

    // Only one dynamic output is supported for now. TODO: support multiple dynamic outputs.
    for (size_t i = 0; i < output_num; ++i) {
      auto dtype = cur_kernel_attr.GetOutputAttr(i).dtype;
      if (kernel_attr.GetOutputAttr(i).dtype != dtype) {
        mis_match = true;
        break;
      }
    }
    if (!mis_match) {
      return std::make_pair(true, index);
    }
  }
  return std::make_pair(false, 0);
}

bool CheckAttrForAllSameInput(const size_t input_num, const std::vector<mindspore::TypeId> &input_types,
                              const KernelAttr &cur_kernel_attr) {
  auto cur_input_num = cur_kernel_attr.GetInputSize();
  bool is_group_allsame = cur_kernel_attr.GetGroupAllSame();
  size_t cur_all_same_input_num = cur_kernel_attr.GetAllSameInputNum();  // default 0; else >=1 when allsame=true
  size_t cur_standalone_input_num = cur_input_num - cur_all_same_input_num;
  size_t each_attr_input_num =
    (input_num - cur_standalone_input_num) / (cur_all_same_input_num == 0 ? 1 : cur_all_same_input_num);
  // Deal with the allsame inputs.
  if (is_group_allsame) {
    for (size_t i = 0; i < each_attr_input_num; ++i) {
      for (size_t j = 0; j < cur_all_same_input_num; ++j) {
        auto dtype = cur_kernel_attr.GetInputAttr(j).dtype;
        auto start = j + i * cur_all_same_input_num;
        if (input_types[start] != dtype && input_types[start] != kTypeUnknown) {
          return true;
        }
      }
    }
  } else {
    for (size_t i = 0; i < cur_all_same_input_num; ++i) {
      for (size_t j = 0; j < each_attr_input_num; ++j) {
        auto dtype = cur_kernel_attr.GetInputAttr(i).dtype;
        auto start = j + i * each_attr_input_num;
        if (input_types[start] != dtype && input_types[start] != kTypeUnknown) {
          return true;
        }
      }
    }
  }

  // Deal with the remaining inputs other than the allsame ones.
  for (size_t i = cur_all_same_input_num; i < cur_standalone_input_num; ++i) {
    auto dtype = cur_kernel_attr.GetInputAttr(i).dtype;
    auto start = each_attr_input_num * cur_all_same_input_num + i;
    if (!(cur_kernel_attr.GetInputAttr(i).is_optional && input_types[start] == kMetaTypeNone) &&
        (input_types[start] != dtype && input_types[start] != kTypeUnknown)) {
      return true;
    }
  }
  return false;
}

std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr,
                                        const std::vector<KernelAttr> &kernel_attr_list) {
  // The node's kernel_attr should not be allsame; if it is, return false directly.
  if (kernel_attr.GetAllSame()) {
    return std::make_pair(false, 0);
  }
  auto input_num = kernel_attr.GetInputSize();
  auto output_num = kernel_attr.GetOutputSize();

  for (size_t index = 0; index < kernel_attr_list.size(); ++index) {
    const auto &cur_kernel_attr = kernel_attr_list[index];
    auto cur_input_num = cur_kernel_attr.GetInputSize();
    auto cur_output_num = cur_kernel_attr.GetOutputSize();
    if (!cur_kernel_attr.GetAllSame() && (input_num != cur_input_num || output_num != cur_output_num)) {
      continue;
    }
    std::vector<mindspore::TypeId> input_types;
    (void)std::transform(kernel_attr.input_type().begin(), kernel_attr.input_type().end(),
                         std::back_inserter(input_types), [](const DataType &Dtype) { return Dtype.dtype; });

    bool mis_match = CheckAttrForAllSameInput(input_num, input_types, cur_kernel_attr);
    if (mis_match) {
      continue;
    }

    for (size_t i = 0; i < output_num; ++i) {
      auto dtype = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).dtype;
      if (kernel_attr.GetOutputAttr(i).dtype != dtype && kernel_attr.GetOutputAttr(i).dtype != kTypeUnknown) {
        mis_match = true;
        break;
      }
    }
    if (!mis_match) {
      return std::make_pair(true, index);
    }
  }

  return std::make_pair(false, 0);
}
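// Matching sketch (hypothetical values): for a node attr with inputs (Float32, Int32) and
// output (Float32), MatchKernelAttr returns {true, index} for the first candidate in
// kernel_attr_list whose input/output counts and dtypes are compatible, treating
// kTypeUnknown entries in the node attr as wildcards; otherwise it returns {false, 0}.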

std::pair<bool, size_t> MatchKernelAttrStrict(const KernelAttr &kernel_attr,
                                              const std::vector<KernelAttr> &kernel_attr_list) {
  auto input_num = kernel_attr.GetInputSize();
  auto output_num = kernel_attr.GetOutputSize();
  auto AttrMatched = [](const DataType &attr, const DataType &compared_attr) {
    return (attr.dtype != compared_attr.dtype && attr.dtype != kTypeUnknown) ||
           (!AnfAlgo::IsEquivalentFormat(attr.format, compared_attr.format)) ||
           (attr.object_type != compared_attr.object_type && attr.object_type != kTypeUnknown);
  };
  for (size_t index = 0; index < kernel_attr_list.size(); ++index) {
    const auto &cur_kernel_attr = kernel_attr_list[index];
    // Skip-check indicates that any attr is supported.
    if (cur_kernel_attr.GetSkipCheck()) {
      return std::make_pair(true, index);
    }
    auto cur_input_num = cur_kernel_attr.GetInputSize();
    auto cur_output_num = cur_kernel_attr.GetOutputSize();
    // The input/output numbers must be equal when the attr is not allsame.
    if (!cur_kernel_attr.GetAllSame() && (input_num != cur_input_num || output_num != cur_output_num)) {
      continue;
    }

    bool mis_match = false;
    // Check the input attrs.
    for (size_t i = 0; i < cur_input_num; ++i) {
      MS_EXCEPTION_IF_CHECK_FAIL((kernel_attr.GetInputSize() > i), "The input num is out of range.");
      auto &input_attr = kernel_attr.GetInputAttr(i);
      auto &cur_input_attr = cur_kernel_attr.GetInputAttr(i);
      if (AttrMatched(input_attr, cur_input_attr)) {
        mis_match = true;
        break;
      }
    }

    if (mis_match) {
      continue;
    }

    // Check the output attrs.
    for (size_t i = 0; i < cur_output_num; ++i) {
      MS_EXCEPTION_IF_CHECK_FAIL((kernel_attr.GetOutputSize() > i), "The output num is out of range.");
      auto &output_attr = kernel_attr.GetOutputAttr(i);
      auto &cur_output_attr = cur_kernel_attr.GetOutputAttr(i);
      if (AttrMatched(output_attr, cur_output_attr)) {
        mis_match = true;
        break;
      }
    }

    if (!mis_match) {
      return std::make_pair(true, index);
    }
  }

  return std::make_pair(false, 0);
}

bool IsFoldKernelBuildInfo(const KernelBuildInfoPtr &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(kernel_build_info);
  auto inputs_object_type = kernel_build_info->GetAllInputKernelObjectTypes();
  if (std::find(inputs_object_type.begin(), inputs_object_type.end(), KernelObjectType::TUPLE) !=
      inputs_object_type.end()) {
    return true;
  }

  auto outputs_object_type = kernel_build_info->GetAllOutputKernelObjectTypes();
  if (std::find(outputs_object_type.begin(), outputs_object_type.end(), KernelObjectType::TUPLE) !=
      outputs_object_type.end()) {
    return true;
  }

  return false;
}

KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info) {
  MS_EXCEPTION_IF_NULL(build_info);
  KernelAttr kernel_attr;
  for (size_t i = 0; i < build_info->GetInputNum(); ++i) {
    (void)kernel_attr.AddInputAttr(KernelObjectTypeToTypeId(build_info->GetInputKernelObjectType(i)),
                                   build_info->GetInputDeviceType(i), build_info->GetInputFormat(i));
  }
  for (size_t j = 0; j < build_info->GetOutputNum(); ++j) {
    (void)kernel_attr.AddOutputAttr(KernelObjectTypeToTypeId(build_info->GetOutputKernelObjectType(j)),
                                    build_info->GetOutputDeviceType(j), build_info->GetOutputFormat(j));
  }
  return kernel_attr;
}

KernelAttr GetKernelAttrFromNode(const AnfNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
  return GetKernelAttrFromBuildInfo(build_info);
}

KernelAttr GetKernelAttrFromTensors(const std::vector<KernelTensor *> &inputs,
                                    const std::vector<KernelTensor *> &outputs) {
  KernelAttr kernel_attr;
  for (auto tensor : inputs) {
    (void)kernel_attr.AddInputAttr(tensor->dtype_id(), GetFormatFromEnumToStr(tensor->format()));
  }
  for (auto tensor : outputs) {
    (void)kernel_attr.AddOutputAttr(tensor->dtype_id(), GetFormatFromEnumToStr(tensor->format()));
  }
  return kernel_attr;
}

void SetCpuRefMapToKernelInfo(const CNodePtr &apply_kernel, const std::vector<KernelAttr> &apply_kernel_attrs) {
  MS_EXCEPTION_IF_NULL(apply_kernel);
  auto kernel_attrs = apply_kernel_attrs;
  if (kernel_attrs.empty()) {
    return;
  }

  auto build_info = AnfAlgo::GetSelectKernelBuildInfo(apply_kernel);
  MS_EXCEPTION_IF_NULL(build_info);
  auto kernel_attr = GetKernelAttrFromBuildInfo(build_info);
  std::vector<int64_t> dyn_input_sizes = {};
  if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, apply_kernel)) {
    dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(apply_kernel, kAttrDynInputSizes);
  }
  std::pair<bool, int64_t> match_result;

  if (kernel_attrs[0].GetSkipCheck()) {
    // If the kernel skips the attr check, synchronize the ref map in case it would otherwise be discarded.
    SyncOutInRef(kernel_attrs[0], &kernel_attr);
    kernel_attrs[0] = kernel_attr;
    match_result = {true, 0};
  } else if (dyn_input_sizes.empty() || kernel_attrs[0].GetAllSame()) {
    match_result = MatchKernelAttr(kernel_attr, kernel_attrs);
  } else {
    match_result = MatchMultiDynamicKernelAttr(kernel_attr, dyn_input_sizes, kernel_attrs);
  }

  auto [is_match, index] = match_result;
  if (!is_match) {
    constexpr auto recursive_level = 2;
    MS_LOG(EXCEPTION) << apply_kernel->fullname_with_scope()
                      << " does not support this kernel data type: " << build_info->ToString()
                      << ", node debug name: " << apply_kernel->DebugString(recursive_level);
  }

  auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto &matched_kernel_attr = kernel_attrs[index];
  if (!matched_kernel_attr.GetOutInRefMap().empty() || matched_kernel_attr.GetAllOutInRef()) {
    kernel_info->set_ref_map(matched_kernel_attr.GetAllOutInRef(), matched_kernel_attr.GetOutInRefMap());
  }
}

void SyncOutInRef(const KernelAttr &from_kernel_attr, KernelAttr *to_kernel_attr) {
  const auto &out_in_ref = from_kernel_attr.GetOutInRefMap();
  bool all_out_in_ref = from_kernel_attr.GetAllOutInRef();
  for (const auto &ref : out_in_ref) {
    (void)to_kernel_attr->AddOutInRef(ref.first, ref.second);
  }
  (void)to_kernel_attr->AddAllOutInRef(all_out_in_ref);
}

namespace math {
void SinCosf(float x, float *sinv, float *cosv) {
  *sinv = sinf(x);
  *cosv = cosf(x);
}
}  // namespace math
}  // namespace kernel
}  // namespace mindspore