• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "transform/acl_ir/op_api_util.h"
17 #include <dlfcn.h>
18 #include <unordered_map>
19 #include <unordered_set>
20 #include "acl/error_codes/rt_error_codes.h"
21 #include "transform/acl_ir/acl_helper.h"
22 #include "include/backend/anf_runtime_algorithm.h"
23 #include "include/common/utils/anfalgo.h"
24 #include "include/common/utils/utils.h"
25 #include "ops/math_op_name.h"
26 #include "utils/ms_context.h"
27 #include "transform/symbol/acl_base_symbol.h"
28 #include "transform/symbol/acl_compiler_symbol.h"
29 #include "transform/symbol/symbol_utils.h"
30 
31 namespace mindspore::transform {
32 namespace {
33 typedef aclError (*AclrtCtxSetSysParamOpt)(aclSysParamOpt, int64_t);
34 typedef HcclResult (*HcclSetConfigFunc)(HcclConfig, HcclConfigValue);
35 
36 static const char k910BKey[] = "Ascend910B";
37 static const char k310BKey[] = "Ascend310B";
38 static const char k910CKey[] = "Ascend910C";
39 
40 static const std::unordered_map<std::string, aclCubeMathType> kCubeMathType = {
41   {"force_fp16", FORCE_FP16},
42   {"allow_fp32_to_fp16", ALLOW_FP32_DOWN_PRECISION},
43   {"allow_mix_precision", ALLOW_FP32_DOWN_PRECISION},
44   {"must_keep_origin_dtype", KEEP_DTYPE},
45   {"allow_fp32_to_bf16", ALLOW_FP32_DOWN_PRECISION},
46   {"allow_mix_precision_fp16", ALLOW_FP32_DOWN_PRECISION},
47   {"allow_mix_precision_bf16", ALLOW_FP32_DOWN_PRECISION}};
48 
49 static const std::unordered_map<uint8_t, aclCubeMathType> kSelectMoreMathType = {
50   {0b01, KEEP_DTYPE}, {0b00, FORCE_FP16}, {0b11, FORCE_HF32}, {0b10, ALLOW_FP32_DOWN_PRECISION}};
51 
52 std::mutex set_opt_mutex;
53 
SetCompileopt(aclCompileOpt opt,const char * value)54 aclError SetCompileopt(aclCompileOpt opt, const char *value) { return CALL_ASCEND_API(aclSetCompileopt, opt, value); }
55 
GetAclFunc(const std::string & lib_path,const std::string & func_name)56 void *GetAclFunc(const std::string &lib_path, const std::string &func_name) {
57   static auto ascend_path = mindspore::transform::GetAscendPath();
58   auto load_path = ascend_path + "/lib64/" + lib_path;
59 
60   auto handler = dlopen(load_path.c_str(), RTLD_LAZY);
61   if (handler == nullptr) {
62     MS_LOG(INFO) << "Dlopen " << load_path << " failed!" << dlerror();
63     return nullptr;
64   }
65 
66   auto func = dlsym(handler, func_name.c_str());
67   if (func == nullptr) {
68     MS_LOG(INFO) << "Dlsym " << func_name << " from " << load_path << " failed!" << dlerror();
69   }
70   return func;
71 }
72 }  // namespace
73 
GetCubeMathType(bool use_hf32)74 aclCubeMathType OpApiUtil::GetCubeMathType(bool use_hf32) {
75   static std::string precision_mode = "not_inited";
76   if (precision_mode == "not_inited") {
77     auto ms_context = MsContext::GetInstance();
78     MS_EXCEPTION_IF_NULL(ms_context);
79     precision_mode = ms_context->get_param<std::string>(MS_CTX_PRECISION_MODE);
80   }
81 
82   if (!precision_mode.empty() && kCubeMathType.count(precision_mode) != 0) {
83     return kCubeMathType.at(precision_mode);
84   }
85   uint8_t select_mode = (static_cast<uint8_t>(use_hf32) << 1) + AclUtil::KeepOriginDType();
86   if (kSelectMoreMathType.count(select_mode) != 0) {
87     return kSelectMoreMathType.at(select_mode);
88   }
89   return AclUtil::KeepOriginDType() ? KEEP_DTYPE : ALLOW_FP32_DOWN_PRECISION;
90 }
91 
GetValidKernelBuildInfo(const AnfNodePtr & node,std::vector<std::string> * input_formats,std::vector<std::string> * output_formats,std::vector<std::string> * input_reshape_types,std::vector<std::string> * output_reshape_types)92 void OpApiUtil::GetValidKernelBuildInfo(const AnfNodePtr &node, std::vector<std::string> *input_formats,
93                                         std::vector<std::string> *output_formats,
94                                         std::vector<std::string> *input_reshape_types,
95                                         std::vector<std::string> *output_reshape_types) {
96   MS_EXCEPTION_IF_NULL(node);
97   MS_EXCEPTION_IF_NULL(input_formats);
98   MS_EXCEPTION_IF_NULL(output_formats);
99   MS_EXCEPTION_IF_NULL(input_reshape_types);
100   MS_EXCEPTION_IF_NULL(output_reshape_types);
101 
102   input_formats->clear();
103   output_formats->clear();
104   input_reshape_types->clear();
105   output_reshape_types->clear();
106   size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
107   size_t output_num = AnfUtils::GetOutputTensorNum(node);
108   input_formats->assign(input_num, kOpFormat_DEFAULT);
109   output_formats->assign(output_num, kOpFormat_DEFAULT);
110   input_reshape_types->assign(input_num, "");
111   output_reshape_types->assign(output_num, "");
112 
113   std::vector<size_t> special_inputs;
114   std::unordered_set<std::string> matmul_ops = {kMatMulOpName, kMatMulV2OpName, kBatchMatMulOpName};
115   for (size_t i = 0; i < input_num; ++i) {
116     auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
117     std::string input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
118     if (!AclHelper::CheckDefaultSupportFormat(input_format) &&
119         matmul_ops.find(AnfUtils::GetCNodeName(kernel_with_index.first)) == matmul_ops.end()) {
120       (void)special_inputs.emplace_back(i);
121     }
122   }
123   if (!special_inputs.empty()) {
124     common::AnfAlgo::SetNodeAttr(kAttrAclSpecialInputFormat, MakeValue(special_inputs), node);
125   }
126 }
127 
KeepOriginDType()128 uint8_t AclUtil::KeepOriginDType() {
129   static std::string version = "";
130   static uint8_t need_keep_dtype = 0;
131   if (version.empty()) {
132     const char *soc_name_c = CALL_ASCEND_API(aclrtGetSocName);
133     if (soc_name_c != nullptr) {
134       version = soc_name_c;
135     }
136     if (version.find(k910BKey) != std::string::npos || version.find(k310BKey) != std::string::npos ||
137         version.find(k910CKey) != std::string::npos) {
138       need_keep_dtype = 1;
139     }
140   }
141   return need_keep_dtype;
142 }
143 
SetDeterministic()144 void AclUtil::SetDeterministic() {
145   std::lock_guard<std::mutex> lock(set_opt_mutex);
146   auto ms_context = MsContext::GetInstance();
147   MS_EXCEPTION_IF_NULL(ms_context);
148   bool is_deterministic = ms_context->get_param<std::string>(MS_CTX_DETERMINISTIC) == "ON" ? true : false;
149   // Set acl
150   auto ret = SetCompileopt(aclCompileOpt::ACL_OP_DETERMINISTIC, is_deterministic ? "1" : "0");
151   if (ret != ACL_SUCCESS) {
152     MS_LOG(EXCEPTION) << "Acl set deterministic mode failed! mode is " << is_deterministic << " and error flag is "
153                       << ret;
154   }
155   // Set acl sys
156   const std::string rt_sys_opt_lib = "libacl_op_compiler.so";
157   const std::string rt_sys_opt_name = "aclrtCtxSetSysParamOpt";
158   auto rt_sys_opt = GetAclFunc(rt_sys_opt_lib, rt_sys_opt_name);
159   if (rt_sys_opt == nullptr) {
160     MS_LOG(EXCEPTION) << "Get 'aclrtCtxSetSysParamOpt' from " << rt_sys_opt_lib << " failed!";
161   }
162   auto rt_sys_opt_func = reinterpret_cast<AclrtCtxSetSysParamOpt>(rt_sys_opt);
163   ret = rt_sys_opt_func(aclSysParamOpt::ACL_OPT_DETERMINISTIC, is_deterministic ? 1 : 0);
164   if (ret != ACL_SUCCESS) {
165     MS_LOG(EXCEPTION) << "Acl sys set deterministic mode failed! mode is " << is_deterministic << " and error flag is "
166                       << ret;
167   }
168   // Set hccl
169   const std::string hccl_lib = "libhccl.so";
170   const std::string hccl_set_config_name = "HcclSetConfig";
171   auto hccl_set_config = GetAclFunc(hccl_lib, hccl_set_config_name);
172   if (hccl_set_config == nullptr) {
173     MS_LOG(EXCEPTION) << "Get 'HcclSetConfig' from " << hccl_lib << " failed!";
174   }
175   auto hccl_set_config_func = reinterpret_cast<HcclSetConfigFunc>(hccl_set_config);
176   HcclConfigValue config = {is_deterministic ? 1 : 0};
177   auto hccl_ret = hccl_set_config_func(HcclConfig::HCCL_DETERMINISTIC, config);
178   if (hccl_ret != HCCL_SUCCESS) {
179     MS_LOG(EXCEPTION) << "Hccl set deterministic mode failed! mode is " << is_deterministic << " and error flag is "
180                       << ret;
181   }
182 }
183 
SetCompileMode(const int64_t is_dynamic)184 aclError AclUtil::SetCompileMode(const int64_t is_dynamic) {
185   std::lock_guard<std::mutex> lock(set_opt_mutex);
186   static int64_t last_mode = -1;
187   if (is_dynamic != last_mode) {
188     std::string mode = is_dynamic ? "disable" : "enable";
189     auto set_compile_flag = SetCompileopt(aclCompileOpt::ACL_OP_JIT_COMPILE, mode.c_str());
190     last_mode = is_dynamic;
191     return set_compile_flag;
192   }
193 
194   return ACL_SUCCESS;
195 }
196 
SetPrecisionMode(const std::string & mode)197 aclError AclUtil::SetPrecisionMode(const std::string &mode) {
198   std::lock_guard<std::mutex> lock(set_opt_mutex);
199 
200   static int8_t is_global_precision = -1;
201   if (is_global_precision == -1) {
202     auto ms_context = MsContext::GetInstance();
203     MS_EXCEPTION_IF_NULL(ms_context);
204     auto precision_mode = ms_context->get_param<std::string>(MS_CTX_PRECISION_MODE);
205     if (!precision_mode.empty()) {
206       is_global_precision = 1;
207     } else {
208       is_global_precision = 0;
209     }
210   }
211   if (is_global_precision == 1) {
212     return ACL_SUCCESS;
213   }
214 
215   static std::string last_mode = (AclUtil::KeepOriginDType() == 1) ? "must_keep_origin_dtype" : "allow_fp32_to_fp16";
216   if (last_mode != mode) {
217     auto ret = SetCompileopt(aclCompileOpt::ACL_PRECISION_MODE, mode.c_str());
218     last_mode = mode;
219     return ret;
220   }
221   return ACL_SUCCESS;
222 }
223 
SetOpPrecisionMode()224 void AclUtil::SetOpPrecisionMode() {
225   std::lock_guard<std::mutex> lock(set_opt_mutex);
226   auto ms_context = MsContext::GetInstance();
227   MS_EXCEPTION_IF_NULL(ms_context);
228   auto op_precision_mode = ms_context->get_param<std::string>(MS_CTX_OP_PRECISION_MODE);
229   if (op_precision_mode.empty()) {
230     return;
231   }
232   MS_LOG(DEBUG) << "Set ACL_OP_PRECISION_MODE: " << op_precision_mode;
233   auto ret = SetCompileopt(aclCompileOpt::ACL_OP_PRECISION_MODE, op_precision_mode.c_str());
234   if (ret != ACL_SUCCESS) {
235     MS_LOG(EXCEPTION) << "Acl set op precision mode failed! error flag is " << ret;
236   }
237 }
238 }  // namespace mindspore::transform
239