• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "transform/acl_ir/acl_helper.h"
18 #include <set>
19 #include <map>
20 #include <unordered_map>
21 #include <string>
22 #include "include/api/data_type.h"
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "include/common/utils/utils.h"
26 #include "include/transform/graph_ir/types.h"
27 #include "ops/nn_ops.h"
28 #include "ops/array_ops.h"
29 #include "ops/conv_pool_ops.h"
30 #include "ops/structure_ops.h"
31 #include "ops/ascend_op_name.h"
32 #include "ops/image_op_name.h"
33 #include "ops/math_op_name.h"
34 #include "runtime/device/ms_device_shape_transfer.h"
35 #include "plugin/device/ascend/hal/common/ascend_utils.h"
36 #include "transform/acl_ir/acl_adapter_info.h"
37 #include "transform/acl_ir/ge_adapter_info.h"
38 #include "ops/op_utils.h"
39 
40 namespace mindspore {
41 namespace transform {
42 namespace {
// Pick NCHW for 4-D shapes, otherwise the backend's default format.
#define GET_DEFAULT_FORMAT(shape) (shape.size() == kDim4 ? kOpFormat_NCHW : kOpFormat_DEFAULT)
// Ops whose outputs must always use the default format (see NeedNDOutput).
static const std::set<std::string> kDefaultOutputNode = {
  // Dynamic output shape kernel.
  kUniqueOpName, kMaskedSelectOpName, kNonMaxSuppressionV3OpName,
  // Dropout
  kDropoutGenMaskOpName, kDropoutGenMaskV3OpName, kStatelessDropOutGenMaskOpName, kDropoutDoMaskOpName,
  kDropoutDoMaskV3OpName, kDropoutOpName, kDropoutGradOpName, kDropout2DOpName, kDropout3DOpName,
  // Special Op
  kAffineGridOpName, kRangeOpName, kBernoulliOpName};

// HCCL communication op types (dispatched to HCCL kernels rather than ACL).
static const std::set<std::string> kHcomOps = {
  kHcomOpTypeAllReduce, kHcomOpTypeReduce,        kHcomOpTypeAllGather, kHcomOpTypeBroadcast, kHcomOpTypeSend,
  kHcomOpTypeReceive,   kHcomOpTypeReduceScatter, kHcomOpTypeAllToAllV, kHcomOpTypeBarrier,   kHcomOpTypeScatter,
  kHcomOpTypeGather,    kHcomOpTypeBatchSendRecv, kHcomOpTypeAlltoAllV};

// Mapping from GE (GraphEngine) data types to MindSpore TypeIds; used by
// ConvertGeType below when matching supported dtypes.
static const HashMap<GeDataType, TypeId> kGeTypeToMsType = {{GeDataType::DT_BOOL, kNumberTypeBool},
                                                            {GeDataType::DT_INT8, kNumberTypeInt8},
                                                            {GeDataType::DT_INT16, kNumberTypeInt16},
                                                            {GeDataType::DT_INT32, kNumberTypeInt32},
                                                            {GeDataType::DT_INT64, kNumberTypeInt64},
                                                            {GeDataType::DT_UINT8, kNumberTypeUInt8},
                                                            {GeDataType::DT_UINT16, kNumberTypeUInt16},
                                                            {GeDataType::DT_UINT32, kNumberTypeUInt32},
                                                            {GeDataType::DT_UINT64, kNumberTypeUInt64},
                                                            {GeDataType::DT_FLOAT16, kNumberTypeFloat16},
                                                            {GeDataType::DT_FLOAT, kNumberTypeFloat32},
                                                            {GeDataType::DT_DOUBLE, kNumberTypeFloat64},
                                                            {GeDataType::DT_STRING, kObjectTypeString},
                                                            {GeDataType::DT_COMPLEX64, kNumberTypeComplex64},
                                                            {GeDataType::DT_COMPLEX128, kNumberTypeComplex128},
                                                            {GeDataType::DT_BF16, kNumberTypeBFloat16}};
74 
ConvertGeType(GeDataType type)75 TypeId ConvertGeType(GeDataType type) {
76   if (kGeTypeToMsType.count(type) != 0) {
77     return kGeTypeToMsType.at(type);
78   }
79   return kTypeUnknown;
80 }
81 
GLogIsDebug()82 bool GLogIsDebug() {
83   const std::string &glog = common::GetEnv("GLOG_v");
84   auto is_debug = !glog.empty() && glog[0] == '0';
85 
86   auto submodule = common::GetEnv("MS_SUBMODULE_LOG_v");
87   bool is_submodule_debug = false;
88   constexpr std::string_view kKernelSub = "KERNEL";
89   constexpr size_t kKernelPos = 7;
90   if (!submodule.empty() && submodule.find(kKernelSub) != std::string::npos) {
91     auto start_pos = submodule.find(kKernelSub) + kKernelPos;
92     is_submodule_debug = submodule[start_pos] == '0';
93   }
94   return is_debug || is_submodule_debug;
95 }
96 
// Propagate `format` onto a Parameter node (or onto the Parameter feeding a
// Cast node) by rewriting its kernel build info, and report the applied
// format back through `old_foramt` (sic — name kept as declared).
void SetParameterFormat(const AnfNodePtr &node, const std::string &format, std::string *old_foramt) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<Parameter>()) {
    // Non-parameter node: only a Cast fed directly by a Parameter is handled.
    // The format is first pushed to that Parameter, then mirrored on the Cast
    // itself so both agree.
    if (IsPrimitiveCNode(node, prim::kPrimCast)) {
      auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, 0);
      if (kernel_with_index.first->isa<Parameter>()) {
        SetParameterFormat(kernel_with_index.first, format, old_foramt);
      } else {
        return;
      }
      auto kernel_info = std::dynamic_pointer_cast<device::KernelInfo>(node->kernel_info_ptr());
      MS_EXCEPTION_IF_NULL(kernel_info);
      auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
      MS_EXCEPTION_IF_NULL(build_info);
      build_info->SetInputsFormat({format});
      build_info->SetOutputsFormat({format});
      kernel_info->set_select_kernel_build_info(build_info);
    }
    return;
  }
  // Parameter node: set every output to the requested format, creating the
  // kernel info / build info lazily when they do not exist yet.
  const auto &output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(node);
  // NOTE: brace-init resolves to the vector (count, value) constructor here,
  // i.e. output_with_indexs.size() copies of `format`.
  std::vector<std::string> output_formats{output_with_indexs.size(), format};
  auto kernel_info = std::dynamic_pointer_cast<device::KernelInfo>(node->kernel_info_ptr());
  if (kernel_info == nullptr) {
    kernel_info = std::make_shared<device::KernelInfo>();
    node->set_kernel_info(kernel_info);
  }
  MS_EXCEPTION_IF_NULL(kernel_info);

  auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  if (build_info == nullptr) {
    auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
    build_info = builder->Build();
  }
  MS_EXCEPTION_IF_NULL(build_info);
  build_info->SetOutputsFormat(output_formats);
  kernel_info->set_select_kernel_build_info(build_info);
  // Report the format that was actually applied back to the caller.
  *old_foramt = format;
}
136 
NeedNDInput(const CNodePtr & cnode,const AnfNodePtr & input_node,const std::string & new_format,std::string * input_format,bool * input_special_flag)137 bool NeedNDInput(const CNodePtr &cnode, const AnfNodePtr &input_node, const std::string &new_format,
138                  std::string *input_format, bool *input_special_flag) {
139   if (AclHelper::IsNopNode(cnode) && !AclHelper::CheckDefaultSupportFormat(*input_format)) {
140     *input_special_flag = true;
141     return true;
142   }
143 
144   auto input_cnode = input_node->cast<CNodePtr>();
145   if (input_cnode != nullptr && common::AnfAlgo::HasNodeAttr(kAttrAclSpecialFormat, input_cnode)) {
146     return true;
147   }
148 
149   if (!AclHelper::CheckDefaultSupportFormat(*input_format) || AclHelper::CheckDefaultSupportFormat(new_format)) {
150     return false;
151   }
152 
153   SetParameterFormat(input_node, new_format, input_format);
154   return false;
155 }
156 
NeedNDOutput(const CNodePtr & cnode,const size_t input_num,const size_t output_num,const std::vector<std::string> & input_formats)157 bool NeedNDOutput(const CNodePtr &cnode, const size_t input_num, const size_t output_num,
158                   const std::vector<std::string> &input_formats) {
159   auto name = GetCNodeFuncName(cnode);
160   if (kDefaultOutputNode.count(name) != 0) {
161     return true;
162   }
163 
164   if (input_num != output_num) {
165     if (output_num != 1 || input_formats.empty() ||
166         !std::all_of(input_formats.begin(), input_formats.end(),
167                      [&input_formats](const std::string &format) { return format == input_formats[0]; })) {
168       return true;
169     }
170   }
171 
172   for (size_t i = 0; i < output_num; ++i) {
173     const auto &shape = common::AnfAlgo::GetOutputInferShape(cnode, i);
174     if (shape.size() <= 1) {
175       return true;
176     }
177   }
178 
179   return false;
180 }
181 
// Fill per-input formats (and reshape types) for an op that has an ACL
// adapter entry. Inputs that must fall back to ND are recorded in the
// kAttrAclSpecialInputFormat attribute on the node.
void GetInputBuildInfo(const AnfNodePtr &node, const size_t input_num, const AclAdapterInfo &acl_info,
                       const GeAdapterInfoPtr &ge_info, std::vector<std::string> *input_formats,
                       std::vector<std::string> *input_reshape_types) {
  auto input_info = acl_info.inputs();
  // Format mode "1" disables per-input selectors; cached once per process.
  static bool default_format = device::ascend::GetFormatMode() == "1";
  std::vector<size_t> special_inputs;
  for (size_t i = 0; i < input_num; ++i) {
    auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
    bool input_special_flag = false;
    std::string input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
    auto prev_shape = common::AnfAlgo::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
    auto cnode = node->cast<CNodePtr>();
    auto new_format = input_format;
    // A per-input selector may override the producer's format based on the
    // producer's dtype and shape.
    if (!default_format && acl_info.input_selector().count(i) != 0) {
      auto func = acl_info.input_selector().at(i);
      auto prev_dtype = common::AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
      new_format = func(prev_dtype, {prev_shape});
    }
    // Fall back to the shape-based default when ND input is required;
    // NeedNDInput may also propagate new_format to a Parameter producer.
    input_format = NeedNDInput(cnode, kernel_with_index.first, new_format, &input_format, &input_special_flag)
                     ? GET_DEFAULT_FORMAT(prev_shape)
                     : input_format;

    (void)input_formats->emplace_back(input_format);
    if (input_special_flag) {
      (void)special_inputs.emplace_back(i);
    }

    if (i >= input_info.size()) {
      continue;
    }
    // Get reshape type.
    auto ge_idx = ge_info->GetGeInputByMsInputIndex(i).index;
    if (ge_idx >= input_info.size()) {
      continue;
    }
    auto special_info = input_info.at(ge_idx);
    if (!special_info.reshape_type.empty()) {
      input_reshape_types->at(i) = special_info.reshape_type;
    }
  }
  if (!special_inputs.empty()) {
    common::AnfAlgo::SetNodeAttr(kAttrAclSpecialInputFormat, MakeValue(special_inputs), node);
  }
}
226 
// Choose output formats for an op with an ACL adapter entry, in priority
// order: (1) the adapter's output selector function, (2) the adapter's
// per-output format table, (3) the shape-based default format.
void GetOutputBuildInfo(const AnfNodePtr &node, const size_t output_num, const AclAdapterInfo &acl_info,
                        const std::vector<std::string> &input_formats, std::vector<std::string> *output_formats) {
  // First use output func.
  auto input_num = common::AnfAlgo::GetInputTensorNum(node);
  static bool default_format = device::ascend::GetFormatMode() == "1";
  if (!default_format && acl_info.output_selector() != nullptr) {
    // The selector derives one format from output 0's dtype and all input
    // shapes, applied uniformly to every output.
    auto data_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
    std::vector<ShapeVector> input_shapes;
    for (size_t i = 0; i < input_num; ++i) {
      (void)input_shapes.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferShape(node, i));
    }
    auto func = acl_info.output_selector();
    for (size_t i = 0; i < output_num; ++i) {
      const auto &format = func(data_type, input_shapes);
      (void)output_formats->emplace_back(format);
    }
    return;
  }

  // Second use output format.
  if (!acl_info.no_special_outputs()) {
    for (size_t i = 0; i < output_num; ++i) {
      (void)output_formats->emplace_back(acl_info.output_format(i, input_formats));
    }
    return;
  }

  // Fallback: shape-based default format per output.
  for (size_t i = 0; i < output_num; ++i) {
    auto shape = common::AnfAlgo::GetOutputInferShape(node, i);
    (void)output_formats->emplace_back(GET_DEFAULT_FORMAT(shape));
  }
}
259 
SetOutputIdentityFlag(const AnfNodePtr & node,const std::vector<std::string> & output_formats)260 void SetOutputIdentityFlag(const AnfNodePtr &node, const std::vector<std::string> &output_formats) {
261   if (device::ascend::GetFormatMode() == "1" && AclHelper::NeedIdentityFlag(output_formats)) {
262     common::AnfAlgo::SetNodeAttr(kAttrAclSpecialFormat, MakeValue(true), node);
263   }
264 }
265 
RefreshRefFormat(const std::unordered_map<size_t,size_t> & ref_map,const std::vector<std::string> & input_formats,std::vector<std::string> * output_formats)266 void RefreshRefFormat(const std::unordered_map<size_t, size_t> &ref_map, const std::vector<std::string> &input_formats,
267                       std::vector<std::string> *output_formats) {
268   if (ref_map.empty()) {
269     return;
270   }
271 
272   for (auto [out_idx, in_idx] : ref_map) {
273     if (out_idx >= output_formats->size()) {
274       MS_LOG(EXCEPTION) << "Error output index:" << out_idx << " for refresh!";
275     }
276     if (in_idx >= input_formats.size()) {
277       MS_LOG(EXCEPTION) << "Error input index:" << in_idx << " for refresh!";
278     }
279     output_formats->at(out_idx) = input_formats[in_idx];
280   }
281 }
282 }  // namespace
283 
IsPrintDebugString()284 bool AclHelper::IsPrintDebugString() {
285   static bool is_debug = GLogIsDebug();
286   return is_debug;
287 }
288 
CheckDefaultSupportFormat(const string & format)289 bool AclHelper::CheckDefaultSupportFormat(const string &format) {
290   static std::set<std::string> default_support = {kOpFormat_DEFAULT, kOpFormat_ND,    kOpFormat_NCHW,
291                                                   kOpFormat_NHWC,    kOpFormat_NDHWC, kOpFormat_NCDHW};
292   return default_support.find(format) != default_support.end();
293 }
294 
GetMoreDataTypeSupported(TypeId data_type,const std::string & op_type)295 bool AclHelper::GetMoreDataTypeSupported(TypeId data_type, const std::string &op_type) {
296   if (!AclAdapterManager::GetInstance().CheckAclAdapter(op_type)) {
297     return false;
298   }
299   auto acl_info = AclAdapterManager::GetInstance().GetOpInfo(op_type);
300   if (acl_info.precision_mode() == FORCE_FP32) {
301     if (data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat) {
302       return false;
303     }
304     return true;
305   }
306   if (!acl_info.extra_supported_datatype().empty()) {
307     if (std::any_of(acl_info.extra_supported_datatype().begin(), acl_info.extra_supported_datatype().end(),
308                     [data_type](GeDataType ge_type) { return ConvertGeType(ge_type) == data_type; })) {
309       return true;
310     }
311   }
312   return false;
313 }
314 
// Check every real input of `cnode` against the GE adapter's supported input
// dtypes. Walks the op prototype indices (ms_proto_idx) while tracking the
// corresponding real-argument index (ms_real_idx), which can diverge because
// of attr-converted inputs and dynamic (variadic) inputs.
// Returns ACL_KERNEL when all inputs match, otherwise UNKNOWN_KERNEL_TYPE.
KernelType AclHelper::GetKernelInfoByInputs(const CNodePtr &cnode, const std::shared_ptr<GeAdapterInfo> &info) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(info);
  auto input_supported_dtypes = info->input_supported_dtypes();
  size_t num_real_inputs = common::AnfAlgo::GetInputTensorNum(cnode);
  size_t ms_real_idx = 0;  // index of actual input argument
  auto value_depend_indices = ops::GetInputDependValueList(common::AnfAlgo::GetCNodePrimitive(cnode));

  // Sizes of dynamic (variadic) input groups, when the node declares any.
  std::vector<int64_t> dyn_input_sizes = {};
  if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode)) {
    dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrDynInputSizes);
  }

  for (size_t ms_proto_idx = 0; ms_proto_idx < info->GetNumInputsOfMsOpProto(); ++ms_proto_idx) {
    MS_LOG(DEBUG) << "ms_proto_idx=" << ms_proto_idx << ", ms_real_idx=" << ms_real_idx
                  << ", num_real_inputs=" << num_real_inputs;
    // skip attribute converted input
    if (NeedCheckAttrToInput(cnode, info->attr_input_map(), ms_proto_idx)) {
      MS_LOG(DEBUG) << "Op prototype input idx:" << ms_proto_idx << " is attr to input, skip check";
      continue;
    }

    // All real arguments consumed; remaining prototype inputs are absent.
    if (ms_real_idx >= num_real_inputs) {
      break;
    }

    auto opt_ge_input_info = info->GetOptGeInputByMsInputIndex(ms_proto_idx);
    // skip input which will be converted to attribute, or some extra inputs defined by mindspore, such as AvgPoolGrad
    if (!opt_ge_input_info.has_value()) {
      MS_LOG(DEBUG) << "Unsupported op prototype input idx:" << ms_proto_idx
                    << " of node:" << cnode->fullname_with_scope();
      ms_real_idx += 1;
      continue;
    }

    auto &ge_input_info = opt_ge_input_info.value();
    auto base_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, ms_real_idx);
    bool is_value_depend = value_depend_indices.find(static_cast<int64_t>(ms_real_idx)) != value_depend_indices.end();
    if (is_value_depend) {
      // if the input is value_depend,  verification is performed in the launch and type conversion if necessary
      MS_LOG(DEBUG) << "When input is value_depend, skip it." << cnode->fullname_with_scope();
      ms_real_idx += 1;
      continue;
    }

    // Dtype check against the GE-declared supported types for this input.
    if (!std::any_of(
          input_supported_dtypes[ms_proto_idx].begin(), input_supported_dtypes[ms_proto_idx].end(),
          [base_type, ge_input_info](const ::ge::DataType ge_type) { return ConvertGeType(ge_type) == base_type; })) {
      // kMetaTypeNone on an OPTIONAL input means the argument is a placeholder.
      if (base_type == kMetaTypeNone && ge_input_info.type == Ms2GeParamInfo::OPTIONAL) {
        MS_LOG(DEBUG) << "Input is a placeholder, continue!";
        ms_real_idx += 1;
        continue;
      }
      // The ACL adapter may accept extra dtypes beyond the GE prototype.
      if (GetMoreDataTypeSupported(base_type, info->op_type())) {
        MS_LOG(DEBUG) << "More data type is supported, continue!";
        ms_real_idx += 1;
        continue;
      }
      MS_LOG(DEBUG) << "Unsupported input dtype:" << TypeIdLabel(base_type)
                    << " in ACL, node:" << cnode->fullname_with_scope();
      return UNKNOWN_KERNEL_TYPE;
    }

    // A DYNAMIC prototype input consumes dyn_input_sizes[ms_proto_idx] real
    // arguments (or one tuple/list argument when no sizes attr is present).
    if (ge_input_info.type == Ms2GeParamInfo::DYNAMIC) {
      if (dyn_input_sizes.empty()) {
        auto input_node = common::AnfAlgo::GetPrevNodeOutput(cnode, ms_real_idx);
        auto abstract = input_node.first->abstract();
        MS_EXCEPTION_IF_NULL(abstract);
        if (abstract->isa<abstract::AbstractTuple>() || abstract->isa<abstract::AbstractList>()) {
          ms_real_idx += 1;
          continue;
        }
      }
      if (ms_proto_idx >= dyn_input_sizes.size()) {
        MS_LOG(EXCEPTION) << "Attribute " << kAttrDynInputSizes << " of " << cnode->fullname_with_scope() << " is "
                          << dyn_input_sizes << ", of which size is less than " << ms_proto_idx;
      }
      ms_real_idx += dyn_input_sizes[ms_proto_idx];
    } else {
      ms_real_idx += 1;
    }
  }

  return ACL_KERNEL;
}
400 
// Check every output dtype of `node` against the GE adapter's supported
// output dtypes. Returns ACL_KERNEL when all outputs match, otherwise
// UNKNOWN_KERNEL_TYPE. Ops with a single dynamic output are checked at
// index 0; dynamic outputs mixed with static ones are not supported.
KernelType AclHelper::GetKernelInfoByOutputs(const AnfNodePtr &node, const std::shared_ptr<GeAdapterInfo> &info) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(info);
  auto output_supported_dtypes = info->output_supported_dtypes();
  auto output_flags = info->GetOutputMappingFlags();
  // For a dynamic-output op the real output count comes from the node itself;
  // otherwise from the op prototype.
  size_t output_num = ((output_flags & GeTensorInfo::kDynamicParam) == 0) ? info->GetNumOutputsOfMsOpProto()
                                                                          : AnfAlgo::GetOutputTensorNum(node);

  // True when output i's inferred dtype is among the GE-supported dtypes.
  auto is_support = [&node, &output_supported_dtypes](size_t i) {
    auto base_type = common::AnfAlgo::GetOutputInferDataType(node, i);
    if (!std::any_of(output_supported_dtypes[i].begin(), output_supported_dtypes[i].end(),
                     [base_type](const ::ge::DataType ge_type) { return ConvertGeType(ge_type) == base_type; })) {
      MS_LOG(DEBUG) << "Unsupported output dtype:" << TypeIdLabel(base_type)
                    << " in ACL, node:" << node->fullname_with_scope();
      return false;
    }
    return true;
  };

  // operator has dynamic output
  if ((info->GetOutputMappingFlags() & GeTensorInfo::kDynamicParam) != 0) {
    if (info->GetNumOutputsOfMsOpProto() == 1) {
      return is_support(0) ? ACL_KERNEL : UNKNOWN_KERNEL_TYPE;
    } else {
      MS_LOG(EXCEPTION)
        << "Now not support operator containing dynamic output mixed with other outputs, the failed not is "
        << node->fullname_with_scope();
    }
  }

  // operator does not have dynamic output
  for (size_t i = 0; i < output_num; ++i) {
    if (!is_support(i)) {
      return UNKNOWN_KERNEL_TYPE;
    }
  }

  return ACL_KERNEL;
}
440 
GetKernelInfoFromGe(const AnfNodePtr & node,ErrorAclType * err_type)441 KernelType AclHelper::GetKernelInfoFromGe(const AnfNodePtr &node, ErrorAclType *err_type) {
442   MS_EXCEPTION_IF_NULL(node);
443   auto cnode = node->cast<CNodePtr>();
444   MS_EXCEPTION_IF_NULL(cnode);
445 
446   std::string name = GetCNodeFuncName(cnode);
447   if (common::AnfAlgo::IsCommunicationOp(node)) {
448     *err_type = kNormalOp;
449     return HCCL_KERNEL;
450   }
451 
452   auto info = GeAdapterManager::GetInstance().GetInfo(name, true);
453   if (info == nullptr) {
454     *err_type = kUnknownOp;
455     MS_LOG(DEBUG) << "Unsupported op type on acl, node name: " << node->fullname_with_scope();
456     return UNKNOWN_KERNEL_TYPE;
457   }
458 
459   // check whether all inputs are matched
460   if (GetKernelInfoByInputs(cnode, info) == UNKNOWN_KERNEL_TYPE) {
461     *err_type = kInValidType;
462     return UNKNOWN_KERNEL_TYPE;
463   }
464 
465   *err_type = kNormalOp;
466   return ACL_KERNEL;
467 }
468 
IsInputDtypeSupport(const std::string & kernel_name,TypeId base_type,size_t idx)469 bool AclHelper::IsInputDtypeSupport(const std::string &kernel_name, TypeId base_type, size_t idx) {
470   auto info = GeAdapterManager::GetInstance().GetInfo(kernel_name, true);
471   MS_EXCEPTION_IF_NULL(info);
472   auto input_supported_dtypes = info->input_supported_dtypes();
473   if (idx >= info->GetNumInputsOfMsOpProto()) {
474     // this branch represent input_attr_map, didn't need check
475     return true;
476   }
477   if (!std::any_of(input_supported_dtypes[idx].begin(), input_supported_dtypes[idx].end(),
478                    [base_type](const ::ge::DataType ge_type) { return ConvertGeType(ge_type) == base_type; })) {
479     return false;
480   }
481   return true;
482 }
483 
GetValidKernelBuildInfo(const AnfNodePtr & node,std::vector<std::string> * input_formats,std::vector<std::string> * output_formats,std::vector<std::string> * input_reshape_types,std::vector<std::string> * output_reshape_types)484 void AclHelper::GetValidKernelBuildInfo(const AnfNodePtr &node, std::vector<std::string> *input_formats,
485                                         std::vector<std::string> *output_formats,
486                                         std::vector<std::string> *input_reshape_types,
487                                         std::vector<std::string> *output_reshape_types) {
488   MS_EXCEPTION_IF_NULL(node);
489   MS_EXCEPTION_IF_NULL(input_formats);
490   MS_EXCEPTION_IF_NULL(output_formats);
491   MS_EXCEPTION_IF_NULL(input_reshape_types);
492   MS_EXCEPTION_IF_NULL(output_reshape_types);
493   auto cnode = node->cast<CNodePtr>();
494   MS_EXCEPTION_IF_NULL(cnode);
495   std::string name = GetCNodeFuncName(cnode);
496   auto info = GeAdapterManager::GetInstance().GetInfo(name, true);
497   auto op_type = info->op_type();
498 
499   input_formats->clear();
500   output_formats->clear();
501   input_reshape_types->clear();
502   output_reshape_types->clear();
503   size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
504   size_t output_num = AnfUtils::GetOutputTensorNum(node);
505   input_reshape_types->assign(input_num, "");
506   output_reshape_types->assign(output_num, "");
507 
508   if (!AclAdapterManager::GetInstance().CheckAclAdapter(op_type)) {
509     std::vector<size_t> special_inputs;
510     for (size_t i = 0; i < input_num; ++i) {
511       auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
512       bool input_special_flag = false;
513       auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
514       auto prev_shape = common::AnfAlgo::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
515       input_format = NeedNDInput(cnode, kernel_with_index.first, input_format, &input_format, &input_special_flag)
516                        ? GET_DEFAULT_FORMAT(prev_shape)
517                        : input_format;
518       (void)input_formats->emplace_back(input_format);
519       if (input_special_flag) {
520         (void)special_inputs.emplace_back(i);
521       }
522     }
523     // Input and output number same's op forward.
524     if (NeedNDOutput(cnode, input_num, output_num, *input_formats)) {
525       for (size_t i = 0; i < output_num; ++i) {
526         auto shape = common::AnfAlgo::GetOutputInferShape(node, i);
527         (void)output_formats->emplace_back(GET_DEFAULT_FORMAT(shape));
528       }
529     } else {
530       if (output_num == 1) {
531         output_formats->emplace_back(input_formats->at(0));
532       } else {
533         output_formats->assign(input_formats->begin(), input_formats->end());
534       }
535       SetOutputIdentityFlag(node, *output_formats);
536     }
537 
538     if (!special_inputs.empty()) {
539       common::AnfAlgo::SetNodeAttr(kAttrAclSpecialInputFormat, MakeValue(special_inputs), node);
540     }
541     RefreshRefFormat(info->GetRefMappingInfo(), *input_formats, output_formats);
542     return;
543   }
544 
545   auto acl_info = AclAdapterManager::GetInstance().GetOpInfo(op_type);
546   GetInputBuildInfo(node, input_num, acl_info, info, input_formats, input_reshape_types);
547   GetOutputBuildInfo(node, output_num, acl_info, *input_formats, output_formats);
548   SetOutputIdentityFlag(node, *output_formats);
549   RefreshRefFormat(info->GetRefMappingInfo(), *input_formats, output_formats);
550 }
551 
PaddingOriShape(const std::string & name,size_t idx,const std::string & format,ShapeVector * shape)552 void AclHelper::PaddingOriShape(const std::string &name, size_t idx, const std::string &format, ShapeVector *shape) {
553   MS_EXCEPTION_IF_NULL(shape);
554   auto info = GeAdapterManager::GetInstance().GetInfo(name, true);
555   auto op_type = info->op_type();
556   if (!AclAdapterManager::GetInstance().CheckAclAdapter(op_type)) {
557     return;
558   }
559   auto acl_info = AclAdapterManager::GetInstance().GetOpInfo(op_type);
560   auto info_list = acl_info.inputs();
561   if (info_list.empty() || idx >= info_list.size()) {
562     return;
563   }
564   auto ge_idx = info->GetGeInputByMsInputIndex(idx).index;
565   auto special_iter = info_list.find(ge_idx);
566   if (special_iter == info_list.end() || special_iter->second.ori_format.empty()) {
567     return;
568   }
569   if (!special_iter->second.ori_format.empty() && format == kOpFormat_NCHW && shape->size() < kDim4) {
570     *shape = trans::PaddingShape(*shape, kOpFormat_NCHW, special_iter->second.reshape_type);
571   }
572 }
573 
// Determine the origin format for input `idx` of op `name`, padding `shape`
// in place when needed. Returns the chosen origin format.
std::string AclHelper::ConvertOriginShapeAndFormat(const std::string &name, size_t idx, const std::string &dev_format,
                                                   ShapeVector *shape) {
  MS_EXCEPTION_IF_NULL(shape);
  auto info = GeAdapterManager::GetInstance().GetInfo(name, true);
  auto op_type = info->op_type();
  // Default: NCHW for 4-D shapes, otherwise the backend default.
  std::string ret_format = (shape->size() == kDim4) ? kOpFormat_NCHW : kOpFormat_DEFAULT;
  // case0: normal
  if (!AclAdapterManager::GetInstance().CheckAclAdapter(op_type)) {
    return ret_format;
  }
  // case1: 3d operator
  auto acl_info = AclAdapterManager::GetInstance().GetOpInfo(op_type);
  if (acl_info.is_3d()) {
    *shape = trans::PaddingShape(*shape, kOpFormat_NCDHW);
    return kOpFormat_NCDHW;
  }
  // Scalars are promoted to shape [1] when the adapter requires it.
  if (acl_info.is_need_pad_no_shape() && shape->empty()) {
    shape->push_back(1);
  }
  // case2: no special config
  auto info_list = acl_info.inputs();
  if (info_list.empty() || idx >= info_list.size()) {
    return ret_format;
  }
  auto ge_idx = info->GetGeInputByMsInputIndex(idx).index;
  auto special_iter = info_list.find(ge_idx);
  if (special_iter == info_list.end() || special_iter->second.ori_format.empty()) {
    return ret_format;
  }
  // case3: if config input ori format or dev_format is special
  // NOTE(review): after the guard above, ori_format is non-empty, so the first
  // operand is always true and the dev_format check is dead code — possibly
  // intended as `&&`; confirm before changing.
  if (!special_iter->second.ori_format.empty() || !CheckDefaultSupportFormat(dev_format)) {
    if (special_iter->second.ori_format[0] == kOpFormat_ND) {
      return kOpFormat_ND;
    }
    // Pad sub-4D default-format shapes up to NCHW using the reshape type.
    if (ret_format == kOpFormat_DEFAULT && shape->size() < kDim4) {
      *shape = trans::PaddingShape(*shape, kOpFormat_NCHW, special_iter->second.reshape_type);
      ret_format = kOpFormat_NCHW;
    }
  }
  return ret_format;
}
615 
NeedCheckAttrToInput(const CNodePtr & node,const mindspore::HashMap<size_t,std::string> & attr_input_map,size_t index)616 bool AclHelper::NeedCheckAttrToInput(const CNodePtr &node,
617                                      const mindspore::HashMap<size_t, std::string> &attr_input_map, size_t index) {
618   MS_EXCEPTION_IF_NULL(node);
619   if (attr_input_map.count(index) == 0) {
620     return false;
621   }
622 
623   const auto &attr_name = attr_input_map.at(index);
624   if (common::AnfAlgo::HasNodeAttr(attr_name, node)) {
625     return true;
626   }
627   return false;
628 }
629 
GetFormatFromAttr(const PrimitivePtr & primitive)630 std::string AclHelper::GetFormatFromAttr(const PrimitivePtr &primitive) {
631   MS_EXCEPTION_IF_NULL(primitive);
632   auto &attrs = primitive->attrs();
633   std::string format;
634   if (attrs.count("format") != 0) {
635     auto attr_value = attrs.at("format");
636     if (attr_value->isa<StringImm>()) {
637       format = GetValue<std::string>(attr_value);
638     } else {
639       MS_LOG(DEBUG) << "The attr format is not a valid value.";
640     }
641   }
642   return format;
643 }
644 
GetDefaultFormatFlagFromAttr(const PrimitivePtr & primitive,bool is_input)645 bool AclHelper::GetDefaultFormatFlagFromAttr(const PrimitivePtr &primitive, bool is_input) {
646   MS_EXCEPTION_IF_NULL(primitive);
647   bool is_default = true;
648   auto key = is_input ? kAttrInputDefaultFormat : kAttrOutputDefaultFormat;
649   auto attrs = primitive->attrs();
650   if (attrs.count(key) != 0) {
651     auto attr_value = attrs.at(key);
652     if (attr_value->isa<BoolImm>()) {
653       is_default = GetValue<bool>(attr_value);
654     } else {
655       MS_LOG(DEBUG) << "The attr: " << key << " is not a valid value.";
656     }
657   }
658   return is_default;
659 }
660 
GetFracZGroupFromAttr(const PrimitivePtr & primitive)661 int64_t AclHelper::GetFracZGroupFromAttr(const PrimitivePtr &primitive) {
662   MS_EXCEPTION_IF_NULL(primitive);
663   auto attrs = primitive->attrs();
664   int64_t fracz_group = 1;
665   if (attrs.count(kAttrFracZGroup) != 0) {
666     auto attr_value = attrs.at(kAttrFracZGroup);
667     if (attr_value->isa<Int64Imm>()) {
668       fracz_group = GetValue<int64_t>(attr_value);
669     } else {
670       MS_LOG(DEBUG) << "The FracZGroup attr is not a valid value.";
671     }
672   }
673   return fracz_group;
674 }
675 
IsNopNode(const CNodePtr & node)676 bool AclHelper::IsNopNode(const CNodePtr &node) {
677   MS_EXCEPTION_IF_NULL(node);
678   static mindspore::HashSet<std::string> nop_nodes = {prim::kPrimReshape->name(), prim::kPrimExpandDims->name(),
679                                                       prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
680                                                       prim::kPrimFlattenGrad->name()};
681   auto op_name = common::AnfAlgo::GetCNodeName(node);
682   return (nop_nodes.find(op_name) != nop_nodes.end());
683 }
684 
NeedIdentityFlag(const std::vector<std::string> & formats)685 bool AclHelper::NeedIdentityFlag(const std::vector<std::string> &formats) {
686   return std::any_of(formats.begin(), formats.end(),
687                      [](const auto &format) { return !AclHelper::CheckDefaultSupportFormat(format); });
688 }
689 }  // namespace transform
690 }  // namespace mindspore
691