1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "transform/acl_ir/acl_helper.h"
18 #include <set>
19 #include <map>
20 #include <unordered_map>
21 #include <string>
22 #include "include/api/data_type.h"
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "include/common/utils/utils.h"
26 #include "include/transform/graph_ir/types.h"
27 #include "ops/nn_ops.h"
28 #include "ops/array_ops.h"
29 #include "ops/conv_pool_ops.h"
30 #include "ops/structure_ops.h"
31 #include "ops/ascend_op_name.h"
32 #include "ops/image_op_name.h"
33 #include "ops/math_op_name.h"
34 #include "runtime/device/ms_device_shape_transfer.h"
35 #include "plugin/device/ascend/hal/common/ascend_utils.h"
36 #include "transform/acl_ir/acl_adapter_info.h"
37 #include "transform/acl_ir/ge_adapter_info.h"
38 #include "ops/op_utils.h"
39
40 namespace mindspore {
41 namespace transform {
42 namespace {
43 #define GET_DEFAULT_FORMAT(shape) (shape.size() == kDim4 ? kOpFormat_NCHW : kOpFormat_DEFAULT)
44 static const std::set<std::string> kDefaultOutputNode = {
45 // Dynamic output shape kernel.
46 kUniqueOpName, kMaskedSelectOpName, kNonMaxSuppressionV3OpName,
47 // Dropout
48 kDropoutGenMaskOpName, kDropoutGenMaskV3OpName, kStatelessDropOutGenMaskOpName, kDropoutDoMaskOpName,
49 kDropoutDoMaskV3OpName, kDropoutOpName, kDropoutGradOpName, kDropout2DOpName, kDropout3DOpName,
50 // Special Op
51 kAffineGridOpName, kRangeOpName, kBernoulliOpName};
52
53 static const std::set<std::string> kHcomOps = {
54 kHcomOpTypeAllReduce, kHcomOpTypeReduce, kHcomOpTypeAllGather, kHcomOpTypeBroadcast, kHcomOpTypeSend,
55 kHcomOpTypeReceive, kHcomOpTypeReduceScatter, kHcomOpTypeAllToAllV, kHcomOpTypeBarrier, kHcomOpTypeScatter,
56 kHcomOpTypeGather, kHcomOpTypeBatchSendRecv, kHcomOpTypeAlltoAllV};
57
58 static const HashMap<GeDataType, TypeId> kGeTypeToMsType = {{GeDataType::DT_BOOL, kNumberTypeBool},
59 {GeDataType::DT_INT8, kNumberTypeInt8},
60 {GeDataType::DT_INT16, kNumberTypeInt16},
61 {GeDataType::DT_INT32, kNumberTypeInt32},
62 {GeDataType::DT_INT64, kNumberTypeInt64},
63 {GeDataType::DT_UINT8, kNumberTypeUInt8},
64 {GeDataType::DT_UINT16, kNumberTypeUInt16},
65 {GeDataType::DT_UINT32, kNumberTypeUInt32},
66 {GeDataType::DT_UINT64, kNumberTypeUInt64},
67 {GeDataType::DT_FLOAT16, kNumberTypeFloat16},
68 {GeDataType::DT_FLOAT, kNumberTypeFloat32},
69 {GeDataType::DT_DOUBLE, kNumberTypeFloat64},
70 {GeDataType::DT_STRING, kObjectTypeString},
71 {GeDataType::DT_COMPLEX64, kNumberTypeComplex64},
72 {GeDataType::DT_COMPLEX128, kNumberTypeComplex128},
73 {GeDataType::DT_BF16, kNumberTypeBFloat16}};
74
ConvertGeType(GeDataType type)75 TypeId ConvertGeType(GeDataType type) {
76 if (kGeTypeToMsType.count(type) != 0) {
77 return kGeTypeToMsType.at(type);
78 }
79 return kTypeUnknown;
80 }
81
GLogIsDebug()82 bool GLogIsDebug() {
83 const std::string &glog = common::GetEnv("GLOG_v");
84 auto is_debug = !glog.empty() && glog[0] == '0';
85
86 auto submodule = common::GetEnv("MS_SUBMODULE_LOG_v");
87 bool is_submodule_debug = false;
88 constexpr std::string_view kKernelSub = "KERNEL";
89 constexpr size_t kKernelPos = 7;
90 if (!submodule.empty() && submodule.find(kKernelSub) != std::string::npos) {
91 auto start_pos = submodule.find(kKernelSub) + kKernelPos;
92 is_submodule_debug = submodule[start_pos] == '0';
93 }
94 return is_debug || is_submodule_debug;
95 }
96
SetParameterFormat(const AnfNodePtr & node,const std::string & format,std::string * old_foramt)97 void SetParameterFormat(const AnfNodePtr &node, const std::string &format, std::string *old_foramt) {
98 MS_EXCEPTION_IF_NULL(node);
99 if (!node->isa<Parameter>()) {
100 if (IsPrimitiveCNode(node, prim::kPrimCast)) {
101 auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, 0);
102 if (kernel_with_index.first->isa<Parameter>()) {
103 SetParameterFormat(kernel_with_index.first, format, old_foramt);
104 } else {
105 return;
106 }
107 auto kernel_info = std::dynamic_pointer_cast<device::KernelInfo>(node->kernel_info_ptr());
108 MS_EXCEPTION_IF_NULL(kernel_info);
109 auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
110 MS_EXCEPTION_IF_NULL(build_info);
111 build_info->SetInputsFormat({format});
112 build_info->SetOutputsFormat({format});
113 kernel_info->set_select_kernel_build_info(build_info);
114 }
115 return;
116 }
117 const auto &output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(node);
118 std::vector<std::string> output_formats{output_with_indexs.size(), format};
119 auto kernel_info = std::dynamic_pointer_cast<device::KernelInfo>(node->kernel_info_ptr());
120 if (kernel_info == nullptr) {
121 kernel_info = std::make_shared<device::KernelInfo>();
122 node->set_kernel_info(kernel_info);
123 }
124 MS_EXCEPTION_IF_NULL(kernel_info);
125
126 auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
127 if (build_info == nullptr) {
128 auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
129 build_info = builder->Build();
130 }
131 MS_EXCEPTION_IF_NULL(build_info);
132 build_info->SetOutputsFormat(output_formats);
133 kernel_info->set_select_kernel_build_info(build_info);
134 *old_foramt = format;
135 }
136
NeedNDInput(const CNodePtr & cnode,const AnfNodePtr & input_node,const std::string & new_format,std::string * input_format,bool * input_special_flag)137 bool NeedNDInput(const CNodePtr &cnode, const AnfNodePtr &input_node, const std::string &new_format,
138 std::string *input_format, bool *input_special_flag) {
139 if (AclHelper::IsNopNode(cnode) && !AclHelper::CheckDefaultSupportFormat(*input_format)) {
140 *input_special_flag = true;
141 return true;
142 }
143
144 auto input_cnode = input_node->cast<CNodePtr>();
145 if (input_cnode != nullptr && common::AnfAlgo::HasNodeAttr(kAttrAclSpecialFormat, input_cnode)) {
146 return true;
147 }
148
149 if (!AclHelper::CheckDefaultSupportFormat(*input_format) || AclHelper::CheckDefaultSupportFormat(new_format)) {
150 return false;
151 }
152
153 SetParameterFormat(input_node, new_format, input_format);
154 return false;
155 }
156
NeedNDOutput(const CNodePtr & cnode,const size_t input_num,const size_t output_num,const std::vector<std::string> & input_formats)157 bool NeedNDOutput(const CNodePtr &cnode, const size_t input_num, const size_t output_num,
158 const std::vector<std::string> &input_formats) {
159 auto name = GetCNodeFuncName(cnode);
160 if (kDefaultOutputNode.count(name) != 0) {
161 return true;
162 }
163
164 if (input_num != output_num) {
165 if (output_num != 1 || input_formats.empty() ||
166 !std::all_of(input_formats.begin(), input_formats.end(),
167 [&input_formats](const std::string &format) { return format == input_formats[0]; })) {
168 return true;
169 }
170 }
171
172 for (size_t i = 0; i < output_num; ++i) {
173 const auto &shape = common::AnfAlgo::GetOutputInferShape(cnode, i);
174 if (shape.size() <= 1) {
175 return true;
176 }
177 }
178
179 return false;
180 }
181
GetInputBuildInfo(const AnfNodePtr & node,const size_t input_num,const AclAdapterInfo & acl_info,const GeAdapterInfoPtr & ge_info,std::vector<std::string> * input_formats,std::vector<std::string> * input_reshape_types)182 void GetInputBuildInfo(const AnfNodePtr &node, const size_t input_num, const AclAdapterInfo &acl_info,
183 const GeAdapterInfoPtr &ge_info, std::vector<std::string> *input_formats,
184 std::vector<std::string> *input_reshape_types) {
185 auto input_info = acl_info.inputs();
186 static bool default_format = device::ascend::GetFormatMode() == "1";
187 std::vector<size_t> special_inputs;
188 for (size_t i = 0; i < input_num; ++i) {
189 auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
190 bool input_special_flag = false;
191 std::string input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
192 auto prev_shape = common::AnfAlgo::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
193 auto cnode = node->cast<CNodePtr>();
194 auto new_format = input_format;
195 if (!default_format && acl_info.input_selector().count(i) != 0) {
196 auto func = acl_info.input_selector().at(i);
197 auto prev_dtype = common::AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
198 new_format = func(prev_dtype, {prev_shape});
199 }
200 input_format = NeedNDInput(cnode, kernel_with_index.first, new_format, &input_format, &input_special_flag)
201 ? GET_DEFAULT_FORMAT(prev_shape)
202 : input_format;
203
204 (void)input_formats->emplace_back(input_format);
205 if (input_special_flag) {
206 (void)special_inputs.emplace_back(i);
207 }
208
209 if (i >= input_info.size()) {
210 continue;
211 }
212 // Get reshape type.
213 auto ge_idx = ge_info->GetGeInputByMsInputIndex(i).index;
214 if (ge_idx >= input_info.size()) {
215 continue;
216 }
217 auto special_info = input_info.at(ge_idx);
218 if (!special_info.reshape_type.empty()) {
219 input_reshape_types->at(i) = special_info.reshape_type;
220 }
221 }
222 if (!special_inputs.empty()) {
223 common::AnfAlgo::SetNodeAttr(kAttrAclSpecialInputFormat, MakeValue(special_inputs), node);
224 }
225 }
226
GetOutputBuildInfo(const AnfNodePtr & node,const size_t output_num,const AclAdapterInfo & acl_info,const std::vector<std::string> & input_formats,std::vector<std::string> * output_formats)227 void GetOutputBuildInfo(const AnfNodePtr &node, const size_t output_num, const AclAdapterInfo &acl_info,
228 const std::vector<std::string> &input_formats, std::vector<std::string> *output_formats) {
229 // First use output func.
230 auto input_num = common::AnfAlgo::GetInputTensorNum(node);
231 static bool default_format = device::ascend::GetFormatMode() == "1";
232 if (!default_format && acl_info.output_selector() != nullptr) {
233 auto data_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
234 std::vector<ShapeVector> input_shapes;
235 for (size_t i = 0; i < input_num; ++i) {
236 (void)input_shapes.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferShape(node, i));
237 }
238 auto func = acl_info.output_selector();
239 for (size_t i = 0; i < output_num; ++i) {
240 const auto &format = func(data_type, input_shapes);
241 (void)output_formats->emplace_back(format);
242 }
243 return;
244 }
245
246 // Second use output format.
247 if (!acl_info.no_special_outputs()) {
248 for (size_t i = 0; i < output_num; ++i) {
249 (void)output_formats->emplace_back(acl_info.output_format(i, input_formats));
250 }
251 return;
252 }
253
254 for (size_t i = 0; i < output_num; ++i) {
255 auto shape = common::AnfAlgo::GetOutputInferShape(node, i);
256 (void)output_formats->emplace_back(GET_DEFAULT_FORMAT(shape));
257 }
258 }
259
SetOutputIdentityFlag(const AnfNodePtr & node,const std::vector<std::string> & output_formats)260 void SetOutputIdentityFlag(const AnfNodePtr &node, const std::vector<std::string> &output_formats) {
261 if (device::ascend::GetFormatMode() == "1" && AclHelper::NeedIdentityFlag(output_formats)) {
262 common::AnfAlgo::SetNodeAttr(kAttrAclSpecialFormat, MakeValue(true), node);
263 }
264 }
265
RefreshRefFormat(const std::unordered_map<size_t,size_t> & ref_map,const std::vector<std::string> & input_formats,std::vector<std::string> * output_formats)266 void RefreshRefFormat(const std::unordered_map<size_t, size_t> &ref_map, const std::vector<std::string> &input_formats,
267 std::vector<std::string> *output_formats) {
268 if (ref_map.empty()) {
269 return;
270 }
271
272 for (auto [out_idx, in_idx] : ref_map) {
273 if (out_idx >= output_formats->size()) {
274 MS_LOG(EXCEPTION) << "Error output index:" << out_idx << " for refresh!";
275 }
276 if (in_idx >= input_formats.size()) {
277 MS_LOG(EXCEPTION) << "Error input index:" << in_idx << " for refresh!";
278 }
279 output_formats->at(out_idx) = input_formats[in_idx];
280 }
281 }
282 } // namespace
283
IsPrintDebugString()284 bool AclHelper::IsPrintDebugString() {
285 static bool is_debug = GLogIsDebug();
286 return is_debug;
287 }
288
CheckDefaultSupportFormat(const string & format)289 bool AclHelper::CheckDefaultSupportFormat(const string &format) {
290 static std::set<std::string> default_support = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW,
291 kOpFormat_NHWC, kOpFormat_NDHWC, kOpFormat_NCDHW};
292 return default_support.find(format) != default_support.end();
293 }
294
GetMoreDataTypeSupported(TypeId data_type,const std::string & op_type)295 bool AclHelper::GetMoreDataTypeSupported(TypeId data_type, const std::string &op_type) {
296 if (!AclAdapterManager::GetInstance().CheckAclAdapter(op_type)) {
297 return false;
298 }
299 auto acl_info = AclAdapterManager::GetInstance().GetOpInfo(op_type);
300 if (acl_info.precision_mode() == FORCE_FP32) {
301 if (data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat) {
302 return false;
303 }
304 return true;
305 }
306 if (!acl_info.extra_supported_datatype().empty()) {
307 if (std::any_of(acl_info.extra_supported_datatype().begin(), acl_info.extra_supported_datatype().end(),
308 [data_type](GeDataType ge_type) { return ConvertGeType(ge_type) == data_type; })) {
309 return true;
310 }
311 }
312 return false;
313 }
314
GetKernelInfoByInputs(const CNodePtr & cnode,const std::shared_ptr<GeAdapterInfo> & info)315 KernelType AclHelper::GetKernelInfoByInputs(const CNodePtr &cnode, const std::shared_ptr<GeAdapterInfo> &info) {
316 MS_EXCEPTION_IF_NULL(cnode);
317 MS_EXCEPTION_IF_NULL(info);
318 auto input_supported_dtypes = info->input_supported_dtypes();
319 size_t num_real_inputs = common::AnfAlgo::GetInputTensorNum(cnode);
320 size_t ms_real_idx = 0; // index of actual input argument
321 auto value_depend_indices = ops::GetInputDependValueList(common::AnfAlgo::GetCNodePrimitive(cnode));
322
323 std::vector<int64_t> dyn_input_sizes = {};
324 if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode)) {
325 dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrDynInputSizes);
326 }
327
328 for (size_t ms_proto_idx = 0; ms_proto_idx < info->GetNumInputsOfMsOpProto(); ++ms_proto_idx) {
329 MS_LOG(DEBUG) << "ms_proto_idx=" << ms_proto_idx << ", ms_real_idx=" << ms_real_idx
330 << ", num_real_inputs=" << num_real_inputs;
331 // skip attribute converted input
332 if (NeedCheckAttrToInput(cnode, info->attr_input_map(), ms_proto_idx)) {
333 MS_LOG(DEBUG) << "Op prototype input idx:" << ms_proto_idx << " is attr to input, skip check";
334 continue;
335 }
336
337 if (ms_real_idx >= num_real_inputs) {
338 break;
339 }
340
341 auto opt_ge_input_info = info->GetOptGeInputByMsInputIndex(ms_proto_idx);
342 // skip input which will be converted to attribute, or some extra inputs defined by mindspore, such as AvgPoolGrad
343 if (!opt_ge_input_info.has_value()) {
344 MS_LOG(DEBUG) << "Unsupported op prototype input idx:" << ms_proto_idx
345 << " of node:" << cnode->fullname_with_scope();
346 ms_real_idx += 1;
347 continue;
348 }
349
350 auto &ge_input_info = opt_ge_input_info.value();
351 auto base_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, ms_real_idx);
352 bool is_value_depend = value_depend_indices.find(static_cast<int64_t>(ms_real_idx)) != value_depend_indices.end();
353 if (is_value_depend) {
354 // if the input is value_depend, verification is performed in the launch and type conversion if necessary
355 MS_LOG(DEBUG) << "When input is value_depend, skip it." << cnode->fullname_with_scope();
356 ms_real_idx += 1;
357 continue;
358 }
359
360 if (!std::any_of(
361 input_supported_dtypes[ms_proto_idx].begin(), input_supported_dtypes[ms_proto_idx].end(),
362 [base_type, ge_input_info](const ::ge::DataType ge_type) { return ConvertGeType(ge_type) == base_type; })) {
363 if (base_type == kMetaTypeNone && ge_input_info.type == Ms2GeParamInfo::OPTIONAL) {
364 MS_LOG(DEBUG) << "Input is a placeholder, continue!";
365 ms_real_idx += 1;
366 continue;
367 }
368 if (GetMoreDataTypeSupported(base_type, info->op_type())) {
369 MS_LOG(DEBUG) << "More data type is supported, continue!";
370 ms_real_idx += 1;
371 continue;
372 }
373 MS_LOG(DEBUG) << "Unsupported input dtype:" << TypeIdLabel(base_type)
374 << " in ACL, node:" << cnode->fullname_with_scope();
375 return UNKNOWN_KERNEL_TYPE;
376 }
377
378 if (ge_input_info.type == Ms2GeParamInfo::DYNAMIC) {
379 if (dyn_input_sizes.empty()) {
380 auto input_node = common::AnfAlgo::GetPrevNodeOutput(cnode, ms_real_idx);
381 auto abstract = input_node.first->abstract();
382 MS_EXCEPTION_IF_NULL(abstract);
383 if (abstract->isa<abstract::AbstractTuple>() || abstract->isa<abstract::AbstractList>()) {
384 ms_real_idx += 1;
385 continue;
386 }
387 }
388 if (ms_proto_idx >= dyn_input_sizes.size()) {
389 MS_LOG(EXCEPTION) << "Attribute " << kAttrDynInputSizes << " of " << cnode->fullname_with_scope() << " is "
390 << dyn_input_sizes << ", of which size is less than " << ms_proto_idx;
391 }
392 ms_real_idx += dyn_input_sizes[ms_proto_idx];
393 } else {
394 ms_real_idx += 1;
395 }
396 }
397
398 return ACL_KERNEL;
399 }
400
GetKernelInfoByOutputs(const AnfNodePtr & node,const std::shared_ptr<GeAdapterInfo> & info)401 KernelType AclHelper::GetKernelInfoByOutputs(const AnfNodePtr &node, const std::shared_ptr<GeAdapterInfo> &info) {
402 MS_EXCEPTION_IF_NULL(node);
403 MS_EXCEPTION_IF_NULL(info);
404 auto output_supported_dtypes = info->output_supported_dtypes();
405 auto output_flags = info->GetOutputMappingFlags();
406 size_t output_num = ((output_flags & GeTensorInfo::kDynamicParam) == 0) ? info->GetNumOutputsOfMsOpProto()
407 : AnfAlgo::GetOutputTensorNum(node);
408
409 auto is_support = [&node, &output_supported_dtypes](size_t i) {
410 auto base_type = common::AnfAlgo::GetOutputInferDataType(node, i);
411 if (!std::any_of(output_supported_dtypes[i].begin(), output_supported_dtypes[i].end(),
412 [base_type](const ::ge::DataType ge_type) { return ConvertGeType(ge_type) == base_type; })) {
413 MS_LOG(DEBUG) << "Unsupported output dtype:" << TypeIdLabel(base_type)
414 << " in ACL, node:" << node->fullname_with_scope();
415 return false;
416 }
417 return true;
418 };
419
420 // operator has dynamic output
421 if ((info->GetOutputMappingFlags() & GeTensorInfo::kDynamicParam) != 0) {
422 if (info->GetNumOutputsOfMsOpProto() == 1) {
423 return is_support(0) ? ACL_KERNEL : UNKNOWN_KERNEL_TYPE;
424 } else {
425 MS_LOG(EXCEPTION)
426 << "Now not support operator containing dynamic output mixed with other outputs, the failed not is "
427 << node->fullname_with_scope();
428 }
429 }
430
431 // operator does not have dynamic output
432 for (size_t i = 0; i < output_num; ++i) {
433 if (!is_support(i)) {
434 return UNKNOWN_KERNEL_TYPE;
435 }
436 }
437
438 return ACL_KERNEL;
439 }
440
GetKernelInfoFromGe(const AnfNodePtr & node,ErrorAclType * err_type)441 KernelType AclHelper::GetKernelInfoFromGe(const AnfNodePtr &node, ErrorAclType *err_type) {
442 MS_EXCEPTION_IF_NULL(node);
443 auto cnode = node->cast<CNodePtr>();
444 MS_EXCEPTION_IF_NULL(cnode);
445
446 std::string name = GetCNodeFuncName(cnode);
447 if (common::AnfAlgo::IsCommunicationOp(node)) {
448 *err_type = kNormalOp;
449 return HCCL_KERNEL;
450 }
451
452 auto info = GeAdapterManager::GetInstance().GetInfo(name, true);
453 if (info == nullptr) {
454 *err_type = kUnknownOp;
455 MS_LOG(DEBUG) << "Unsupported op type on acl, node name: " << node->fullname_with_scope();
456 return UNKNOWN_KERNEL_TYPE;
457 }
458
459 // check whether all inputs are matched
460 if (GetKernelInfoByInputs(cnode, info) == UNKNOWN_KERNEL_TYPE) {
461 *err_type = kInValidType;
462 return UNKNOWN_KERNEL_TYPE;
463 }
464
465 *err_type = kNormalOp;
466 return ACL_KERNEL;
467 }
468
IsInputDtypeSupport(const std::string & kernel_name,TypeId base_type,size_t idx)469 bool AclHelper::IsInputDtypeSupport(const std::string &kernel_name, TypeId base_type, size_t idx) {
470 auto info = GeAdapterManager::GetInstance().GetInfo(kernel_name, true);
471 MS_EXCEPTION_IF_NULL(info);
472 auto input_supported_dtypes = info->input_supported_dtypes();
473 if (idx >= info->GetNumInputsOfMsOpProto()) {
474 // this branch represent input_attr_map, didn't need check
475 return true;
476 }
477 if (!std::any_of(input_supported_dtypes[idx].begin(), input_supported_dtypes[idx].end(),
478 [base_type](const ::ge::DataType ge_type) { return ConvertGeType(ge_type) == base_type; })) {
479 return false;
480 }
481 return true;
482 }
483
GetValidKernelBuildInfo(const AnfNodePtr & node,std::vector<std::string> * input_formats,std::vector<std::string> * output_formats,std::vector<std::string> * input_reshape_types,std::vector<std::string> * output_reshape_types)484 void AclHelper::GetValidKernelBuildInfo(const AnfNodePtr &node, std::vector<std::string> *input_formats,
485 std::vector<std::string> *output_formats,
486 std::vector<std::string> *input_reshape_types,
487 std::vector<std::string> *output_reshape_types) {
488 MS_EXCEPTION_IF_NULL(node);
489 MS_EXCEPTION_IF_NULL(input_formats);
490 MS_EXCEPTION_IF_NULL(output_formats);
491 MS_EXCEPTION_IF_NULL(input_reshape_types);
492 MS_EXCEPTION_IF_NULL(output_reshape_types);
493 auto cnode = node->cast<CNodePtr>();
494 MS_EXCEPTION_IF_NULL(cnode);
495 std::string name = GetCNodeFuncName(cnode);
496 auto info = GeAdapterManager::GetInstance().GetInfo(name, true);
497 auto op_type = info->op_type();
498
499 input_formats->clear();
500 output_formats->clear();
501 input_reshape_types->clear();
502 output_reshape_types->clear();
503 size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
504 size_t output_num = AnfUtils::GetOutputTensorNum(node);
505 input_reshape_types->assign(input_num, "");
506 output_reshape_types->assign(output_num, "");
507
508 if (!AclAdapterManager::GetInstance().CheckAclAdapter(op_type)) {
509 std::vector<size_t> special_inputs;
510 for (size_t i = 0; i < input_num; ++i) {
511 auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
512 bool input_special_flag = false;
513 auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
514 auto prev_shape = common::AnfAlgo::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
515 input_format = NeedNDInput(cnode, kernel_with_index.first, input_format, &input_format, &input_special_flag)
516 ? GET_DEFAULT_FORMAT(prev_shape)
517 : input_format;
518 (void)input_formats->emplace_back(input_format);
519 if (input_special_flag) {
520 (void)special_inputs.emplace_back(i);
521 }
522 }
523 // Input and output number same's op forward.
524 if (NeedNDOutput(cnode, input_num, output_num, *input_formats)) {
525 for (size_t i = 0; i < output_num; ++i) {
526 auto shape = common::AnfAlgo::GetOutputInferShape(node, i);
527 (void)output_formats->emplace_back(GET_DEFAULT_FORMAT(shape));
528 }
529 } else {
530 if (output_num == 1) {
531 output_formats->emplace_back(input_formats->at(0));
532 } else {
533 output_formats->assign(input_formats->begin(), input_formats->end());
534 }
535 SetOutputIdentityFlag(node, *output_formats);
536 }
537
538 if (!special_inputs.empty()) {
539 common::AnfAlgo::SetNodeAttr(kAttrAclSpecialInputFormat, MakeValue(special_inputs), node);
540 }
541 RefreshRefFormat(info->GetRefMappingInfo(), *input_formats, output_formats);
542 return;
543 }
544
545 auto acl_info = AclAdapterManager::GetInstance().GetOpInfo(op_type);
546 GetInputBuildInfo(node, input_num, acl_info, info, input_formats, input_reshape_types);
547 GetOutputBuildInfo(node, output_num, acl_info, *input_formats, output_formats);
548 SetOutputIdentityFlag(node, *output_formats);
549 RefreshRefFormat(info->GetRefMappingInfo(), *input_formats, output_formats);
550 }
551
PaddingOriShape(const std::string & name,size_t idx,const std::string & format,ShapeVector * shape)552 void AclHelper::PaddingOriShape(const std::string &name, size_t idx, const std::string &format, ShapeVector *shape) {
553 MS_EXCEPTION_IF_NULL(shape);
554 auto info = GeAdapterManager::GetInstance().GetInfo(name, true);
555 auto op_type = info->op_type();
556 if (!AclAdapterManager::GetInstance().CheckAclAdapter(op_type)) {
557 return;
558 }
559 auto acl_info = AclAdapterManager::GetInstance().GetOpInfo(op_type);
560 auto info_list = acl_info.inputs();
561 if (info_list.empty() || idx >= info_list.size()) {
562 return;
563 }
564 auto ge_idx = info->GetGeInputByMsInputIndex(idx).index;
565 auto special_iter = info_list.find(ge_idx);
566 if (special_iter == info_list.end() || special_iter->second.ori_format.empty()) {
567 return;
568 }
569 if (!special_iter->second.ori_format.empty() && format == kOpFormat_NCHW && shape->size() < kDim4) {
570 *shape = trans::PaddingShape(*shape, kOpFormat_NCHW, special_iter->second.reshape_type);
571 }
572 }
573
ConvertOriginShapeAndFormat(const std::string & name,size_t idx,const std::string & dev_format,ShapeVector * shape)574 std::string AclHelper::ConvertOriginShapeAndFormat(const std::string &name, size_t idx, const std::string &dev_format,
575 ShapeVector *shape) {
576 MS_EXCEPTION_IF_NULL(shape);
577 auto info = GeAdapterManager::GetInstance().GetInfo(name, true);
578 auto op_type = info->op_type();
579 std::string ret_format = (shape->size() == kDim4) ? kOpFormat_NCHW : kOpFormat_DEFAULT;
580 // case0: normal
581 if (!AclAdapterManager::GetInstance().CheckAclAdapter(op_type)) {
582 return ret_format;
583 }
584 // case1: 3d operator
585 auto acl_info = AclAdapterManager::GetInstance().GetOpInfo(op_type);
586 if (acl_info.is_3d()) {
587 *shape = trans::PaddingShape(*shape, kOpFormat_NCDHW);
588 return kOpFormat_NCDHW;
589 }
590 if (acl_info.is_need_pad_no_shape() && shape->empty()) {
591 shape->push_back(1);
592 }
593 // case2: no special config
594 auto info_list = acl_info.inputs();
595 if (info_list.empty() || idx >= info_list.size()) {
596 return ret_format;
597 }
598 auto ge_idx = info->GetGeInputByMsInputIndex(idx).index;
599 auto special_iter = info_list.find(ge_idx);
600 if (special_iter == info_list.end() || special_iter->second.ori_format.empty()) {
601 return ret_format;
602 }
603 // case3: if config input ori format or dev_format is special
604 if (!special_iter->second.ori_format.empty() || !CheckDefaultSupportFormat(dev_format)) {
605 if (special_iter->second.ori_format[0] == kOpFormat_ND) {
606 return kOpFormat_ND;
607 }
608 if (ret_format == kOpFormat_DEFAULT && shape->size() < kDim4) {
609 *shape = trans::PaddingShape(*shape, kOpFormat_NCHW, special_iter->second.reshape_type);
610 ret_format = kOpFormat_NCHW;
611 }
612 }
613 return ret_format;
614 }
615
NeedCheckAttrToInput(const CNodePtr & node,const mindspore::HashMap<size_t,std::string> & attr_input_map,size_t index)616 bool AclHelper::NeedCheckAttrToInput(const CNodePtr &node,
617 const mindspore::HashMap<size_t, std::string> &attr_input_map, size_t index) {
618 MS_EXCEPTION_IF_NULL(node);
619 if (attr_input_map.count(index) == 0) {
620 return false;
621 }
622
623 const auto &attr_name = attr_input_map.at(index);
624 if (common::AnfAlgo::HasNodeAttr(attr_name, node)) {
625 return true;
626 }
627 return false;
628 }
629
GetFormatFromAttr(const PrimitivePtr & primitive)630 std::string AclHelper::GetFormatFromAttr(const PrimitivePtr &primitive) {
631 MS_EXCEPTION_IF_NULL(primitive);
632 auto &attrs = primitive->attrs();
633 std::string format;
634 if (attrs.count("format") != 0) {
635 auto attr_value = attrs.at("format");
636 if (attr_value->isa<StringImm>()) {
637 format = GetValue<std::string>(attr_value);
638 } else {
639 MS_LOG(DEBUG) << "The attr format is not a valid value.";
640 }
641 }
642 return format;
643 }
644
GetDefaultFormatFlagFromAttr(const PrimitivePtr & primitive,bool is_input)645 bool AclHelper::GetDefaultFormatFlagFromAttr(const PrimitivePtr &primitive, bool is_input) {
646 MS_EXCEPTION_IF_NULL(primitive);
647 bool is_default = true;
648 auto key = is_input ? kAttrInputDefaultFormat : kAttrOutputDefaultFormat;
649 auto attrs = primitive->attrs();
650 if (attrs.count(key) != 0) {
651 auto attr_value = attrs.at(key);
652 if (attr_value->isa<BoolImm>()) {
653 is_default = GetValue<bool>(attr_value);
654 } else {
655 MS_LOG(DEBUG) << "The attr: " << key << " is not a valid value.";
656 }
657 }
658 return is_default;
659 }
660
GetFracZGroupFromAttr(const PrimitivePtr & primitive)661 int64_t AclHelper::GetFracZGroupFromAttr(const PrimitivePtr &primitive) {
662 MS_EXCEPTION_IF_NULL(primitive);
663 auto attrs = primitive->attrs();
664 int64_t fracz_group = 1;
665 if (attrs.count(kAttrFracZGroup) != 0) {
666 auto attr_value = attrs.at(kAttrFracZGroup);
667 if (attr_value->isa<Int64Imm>()) {
668 fracz_group = GetValue<int64_t>(attr_value);
669 } else {
670 MS_LOG(DEBUG) << "The FracZGroup attr is not a valid value.";
671 }
672 }
673 return fracz_group;
674 }
675
IsNopNode(const CNodePtr & node)676 bool AclHelper::IsNopNode(const CNodePtr &node) {
677 MS_EXCEPTION_IF_NULL(node);
678 static mindspore::HashSet<std::string> nop_nodes = {prim::kPrimReshape->name(), prim::kPrimExpandDims->name(),
679 prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
680 prim::kPrimFlattenGrad->name()};
681 auto op_name = common::AnfAlgo::GetCNodeName(node);
682 return (nop_nodes.find(op_name) != nop_nodes.end());
683 }
684
NeedIdentityFlag(const std::vector<std::string> & formats)685 bool AclHelper::NeedIdentityFlag(const std::vector<std::string> &formats) {
686 return std::any_of(formats.begin(), formats.end(),
687 [](const auto &format) { return !AclHelper::CheckDefaultSupportFormat(format); });
688 }
689 } // namespace transform
690 } // namespace mindspore
691