• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_ACL_IR_ACL_HELPER_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_ACL_IR_ACL_HELPER_H_
19 
20 #include <vector>
21 #include <string>
22 #include <memory>
23 #include "ir/anf.h"
24 #include "ir/tensor.h"
25 #include "kernel/kernel.h"
26 
27 namespace mindspore {
28 namespace transform {
class GeAdapterInfo;

// Result classification written through the `err_type` out-parameter of
// AclHelper::GetKernelInfoFromGe.
// Deliberately an unscoped enum so existing code can keep naming the enumerators
// unqualified (e.g. `kNormalOp`). The former C-style `typedef enum X {...} X;`
// wrapper was redundant in C++, where the enum name already denotes the type.
// Enumerator values (0..4, in declaration order) are unchanged.
// NOTE(review): the precise meaning of each enumerator is only implied by its name
// here; confirm against the implementation before relying on it.
enum ErrorAclType {
  kNormalOp,
  kUnknownOp,
  kInValidType,  // Spelling ("InValid") kept as-is for source compatibility with callers.
  kSpecialOp,
  kInvalidBuildInfo
};
31 
32 class AclHelper {
33  public:
34   // Check is data layout unchanged format.
35   static bool CheckDefaultSupportFormat(const string &format);
36   static bool IsPrintDebugString();
37 
38   // Kernel select by ge_ir.
39   static KernelType GetKernelInfoByInputs(const CNodePtr &cnode, const std::shared_ptr<GeAdapterInfo> &info);
40   static KernelType GetKernelInfoByOutputs(const AnfNodePtr &node, const std::shared_ptr<GeAdapterInfo> &info);
41   static KernelType GetKernelInfoFromGe(const AnfNodePtr &node, ErrorAclType *err_type);
42 
43   // Select kernel's device format.
44   static void GetValidKernelBuildInfo(const AnfNodePtr &node, std::vector<std::string> *input_formats,
45                                       std::vector<std::string> *output_formats,
46                                       std::vector<std::string> *input_reshape_types,
47                                       std::vector<std::string> *output_reshape_types);
48 
49   // Convert mindspore's origin information to acl's origin information.
50   static void PaddingOriShape(const std::string &name, size_t idx, const std::string &format, ShapeVector *shape);
51   static std::string ConvertOriginShapeAndFormat(const std::string &name, size_t idx, const std::string &dev_format,
52                                                  ShapeVector *shape);
53 
54   // Get attribute to input information.
55   static bool NeedCheckAttrToInput(const CNodePtr &node, const mindspore::HashMap<size_t, std::string> &attr_input_map,
56                                    size_t index);
57   // Get special information from kernel's attribute.
58   static std::string GetFormatFromAttr(const PrimitivePtr &primitive);
59   static int64_t GetFracZGroupFromAttr(const PrimitivePtr &primitive);
60   static bool GetDefaultFormatFlagFromAttr(const PrimitivePtr &primitive, bool is_input);
61 
62   // Get kernel's precision mode is FORCE_FP32.
63   static bool GetMoreDataTypeSupported(TypeId data_type, const std::string &op_type);
64 
65   // Check whether is nop op.
66   static bool IsNopNode(const CNodePtr &node);
67   static bool IsInputDtypeSupport(const std::string &kernel_name, TypeId base_type, size_t idx);
68 
69   // Set identity flag.
70   static bool NeedIdentityFlag(const std::vector<std::string> &formats);
71 };
72 }  // namespace transform
73 }  // namespace mindspore
74 
75 #endif  // MINDSPORE_CCSRC_TRANSFORM_ACL_IR_ACL_HELPER_H_
76