1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "transform/acl_ir/op_api_util.h"
#include <dlfcn.h>
#include <atomic>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "acl/error_codes/rt_error_codes.h"
#include "transform/acl_ir/acl_helper.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/utils.h"
#include "ops/math_op_name.h"
#include "utils/ms_context.h"
#include "transform/symbol/acl_base_symbol.h"
#include "transform/symbol/acl_compiler_symbol.h"
#include "transform/symbol/symbol_utils.h"
30
31 namespace mindspore::transform {
32 namespace {
33 typedef aclError (*AclrtCtxSetSysParamOpt)(aclSysParamOpt, int64_t);
34 typedef HcclResult (*HcclSetConfigFunc)(HcclConfig, HcclConfigValue);
35
36 static const char k910BKey[] = "Ascend910B";
37 static const char k310BKey[] = "Ascend310B";
38 static const char k910CKey[] = "Ascend910C";
39
40 static const std::unordered_map<std::string, aclCubeMathType> kCubeMathType = {
41 {"force_fp16", FORCE_FP16},
42 {"allow_fp32_to_fp16", ALLOW_FP32_DOWN_PRECISION},
43 {"allow_mix_precision", ALLOW_FP32_DOWN_PRECISION},
44 {"must_keep_origin_dtype", KEEP_DTYPE},
45 {"allow_fp32_to_bf16", ALLOW_FP32_DOWN_PRECISION},
46 {"allow_mix_precision_fp16", ALLOW_FP32_DOWN_PRECISION},
47 {"allow_mix_precision_bf16", ALLOW_FP32_DOWN_PRECISION}};
48
49 static const std::unordered_map<uint8_t, aclCubeMathType> kSelectMoreMathType = {
50 {0b01, KEEP_DTYPE}, {0b00, FORCE_FP16}, {0b11, FORCE_HF32}, {0b10, ALLOW_FP32_DOWN_PRECISION}};
51
52 std::mutex set_opt_mutex;
53
SetCompileopt(aclCompileOpt opt,const char * value)54 aclError SetCompileopt(aclCompileOpt opt, const char *value) { return CALL_ASCEND_API(aclSetCompileopt, opt, value); }
55
GetAclFunc(const std::string & lib_path,const std::string & func_name)56 void *GetAclFunc(const std::string &lib_path, const std::string &func_name) {
57 static auto ascend_path = mindspore::transform::GetAscendPath();
58 auto load_path = ascend_path + "/lib64/" + lib_path;
59
60 auto handler = dlopen(load_path.c_str(), RTLD_LAZY);
61 if (handler == nullptr) {
62 MS_LOG(INFO) << "Dlopen " << load_path << " failed!" << dlerror();
63 return nullptr;
64 }
65
66 auto func = dlsym(handler, func_name.c_str());
67 if (func == nullptr) {
68 MS_LOG(INFO) << "Dlsym " << func_name << " from " << load_path << " failed!" << dlerror();
69 }
70 return func;
71 }
72 } // namespace
73
GetCubeMathType(bool use_hf32)74 aclCubeMathType OpApiUtil::GetCubeMathType(bool use_hf32) {
75 static std::string precision_mode = "not_inited";
76 if (precision_mode == "not_inited") {
77 auto ms_context = MsContext::GetInstance();
78 MS_EXCEPTION_IF_NULL(ms_context);
79 precision_mode = ms_context->get_param<std::string>(MS_CTX_PRECISION_MODE);
80 }
81
82 if (!precision_mode.empty() && kCubeMathType.count(precision_mode) != 0) {
83 return kCubeMathType.at(precision_mode);
84 }
85 uint8_t select_mode = (static_cast<uint8_t>(use_hf32) << 1) + AclUtil::KeepOriginDType();
86 if (kSelectMoreMathType.count(select_mode) != 0) {
87 return kSelectMoreMathType.at(select_mode);
88 }
89 return AclUtil::KeepOriginDType() ? KEEP_DTYPE : ALLOW_FP32_DOWN_PRECISION;
90 }
91
GetValidKernelBuildInfo(const AnfNodePtr & node,std::vector<std::string> * input_formats,std::vector<std::string> * output_formats,std::vector<std::string> * input_reshape_types,std::vector<std::string> * output_reshape_types)92 void OpApiUtil::GetValidKernelBuildInfo(const AnfNodePtr &node, std::vector<std::string> *input_formats,
93 std::vector<std::string> *output_formats,
94 std::vector<std::string> *input_reshape_types,
95 std::vector<std::string> *output_reshape_types) {
96 MS_EXCEPTION_IF_NULL(node);
97 MS_EXCEPTION_IF_NULL(input_formats);
98 MS_EXCEPTION_IF_NULL(output_formats);
99 MS_EXCEPTION_IF_NULL(input_reshape_types);
100 MS_EXCEPTION_IF_NULL(output_reshape_types);
101
102 input_formats->clear();
103 output_formats->clear();
104 input_reshape_types->clear();
105 output_reshape_types->clear();
106 size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
107 size_t output_num = AnfUtils::GetOutputTensorNum(node);
108 input_formats->assign(input_num, kOpFormat_DEFAULT);
109 output_formats->assign(output_num, kOpFormat_DEFAULT);
110 input_reshape_types->assign(input_num, "");
111 output_reshape_types->assign(output_num, "");
112
113 std::vector<size_t> special_inputs;
114 std::unordered_set<std::string> matmul_ops = {kMatMulOpName, kMatMulV2OpName, kBatchMatMulOpName};
115 for (size_t i = 0; i < input_num; ++i) {
116 auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
117 std::string input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
118 if (!AclHelper::CheckDefaultSupportFormat(input_format) &&
119 matmul_ops.find(AnfUtils::GetCNodeName(kernel_with_index.first)) == matmul_ops.end()) {
120 (void)special_inputs.emplace_back(i);
121 }
122 }
123 if (!special_inputs.empty()) {
124 common::AnfAlgo::SetNodeAttr(kAttrAclSpecialInputFormat, MakeValue(special_inputs), node);
125 }
126 }
127
KeepOriginDType()128 uint8_t AclUtil::KeepOriginDType() {
129 static std::string version = "";
130 static uint8_t need_keep_dtype = 0;
131 if (version.empty()) {
132 const char *soc_name_c = CALL_ASCEND_API(aclrtGetSocName);
133 if (soc_name_c != nullptr) {
134 version = soc_name_c;
135 }
136 if (version.find(k910BKey) != std::string::npos || version.find(k310BKey) != std::string::npos ||
137 version.find(k910CKey) != std::string::npos) {
138 need_keep_dtype = 1;
139 }
140 }
141 return need_keep_dtype;
142 }
143
SetDeterministic()144 void AclUtil::SetDeterministic() {
145 std::lock_guard<std::mutex> lock(set_opt_mutex);
146 auto ms_context = MsContext::GetInstance();
147 MS_EXCEPTION_IF_NULL(ms_context);
148 bool is_deterministic = ms_context->get_param<std::string>(MS_CTX_DETERMINISTIC) == "ON" ? true : false;
149 // Set acl
150 auto ret = SetCompileopt(aclCompileOpt::ACL_OP_DETERMINISTIC, is_deterministic ? "1" : "0");
151 if (ret != ACL_SUCCESS) {
152 MS_LOG(EXCEPTION) << "Acl set deterministic mode failed! mode is " << is_deterministic << " and error flag is "
153 << ret;
154 }
155 // Set acl sys
156 const std::string rt_sys_opt_lib = "libacl_op_compiler.so";
157 const std::string rt_sys_opt_name = "aclrtCtxSetSysParamOpt";
158 auto rt_sys_opt = GetAclFunc(rt_sys_opt_lib, rt_sys_opt_name);
159 if (rt_sys_opt == nullptr) {
160 MS_LOG(EXCEPTION) << "Get 'aclrtCtxSetSysParamOpt' from " << rt_sys_opt_lib << " failed!";
161 }
162 auto rt_sys_opt_func = reinterpret_cast<AclrtCtxSetSysParamOpt>(rt_sys_opt);
163 ret = rt_sys_opt_func(aclSysParamOpt::ACL_OPT_DETERMINISTIC, is_deterministic ? 1 : 0);
164 if (ret != ACL_SUCCESS) {
165 MS_LOG(EXCEPTION) << "Acl sys set deterministic mode failed! mode is " << is_deterministic << " and error flag is "
166 << ret;
167 }
168 // Set hccl
169 const std::string hccl_lib = "libhccl.so";
170 const std::string hccl_set_config_name = "HcclSetConfig";
171 auto hccl_set_config = GetAclFunc(hccl_lib, hccl_set_config_name);
172 if (hccl_set_config == nullptr) {
173 MS_LOG(EXCEPTION) << "Get 'HcclSetConfig' from " << hccl_lib << " failed!";
174 }
175 auto hccl_set_config_func = reinterpret_cast<HcclSetConfigFunc>(hccl_set_config);
176 HcclConfigValue config = {is_deterministic ? 1 : 0};
177 auto hccl_ret = hccl_set_config_func(HcclConfig::HCCL_DETERMINISTIC, config);
178 if (hccl_ret != HCCL_SUCCESS) {
179 MS_LOG(EXCEPTION) << "Hccl set deterministic mode failed! mode is " << is_deterministic << " and error flag is "
180 << ret;
181 }
182 }
183
SetCompileMode(const int64_t is_dynamic)184 aclError AclUtil::SetCompileMode(const int64_t is_dynamic) {
185 std::lock_guard<std::mutex> lock(set_opt_mutex);
186 static int64_t last_mode = -1;
187 if (is_dynamic != last_mode) {
188 std::string mode = is_dynamic ? "disable" : "enable";
189 auto set_compile_flag = SetCompileopt(aclCompileOpt::ACL_OP_JIT_COMPILE, mode.c_str());
190 last_mode = is_dynamic;
191 return set_compile_flag;
192 }
193
194 return ACL_SUCCESS;
195 }
196
SetPrecisionMode(const std::string & mode)197 aclError AclUtil::SetPrecisionMode(const std::string &mode) {
198 std::lock_guard<std::mutex> lock(set_opt_mutex);
199
200 static int8_t is_global_precision = -1;
201 if (is_global_precision == -1) {
202 auto ms_context = MsContext::GetInstance();
203 MS_EXCEPTION_IF_NULL(ms_context);
204 auto precision_mode = ms_context->get_param<std::string>(MS_CTX_PRECISION_MODE);
205 if (!precision_mode.empty()) {
206 is_global_precision = 1;
207 } else {
208 is_global_precision = 0;
209 }
210 }
211 if (is_global_precision == 1) {
212 return ACL_SUCCESS;
213 }
214
215 static std::string last_mode = (AclUtil::KeepOriginDType() == 1) ? "must_keep_origin_dtype" : "allow_fp32_to_fp16";
216 if (last_mode != mode) {
217 auto ret = SetCompileopt(aclCompileOpt::ACL_PRECISION_MODE, mode.c_str());
218 last_mode = mode;
219 return ret;
220 }
221 return ACL_SUCCESS;
222 }
223
SetOpPrecisionMode()224 void AclUtil::SetOpPrecisionMode() {
225 std::lock_guard<std::mutex> lock(set_opt_mutex);
226 auto ms_context = MsContext::GetInstance();
227 MS_EXCEPTION_IF_NULL(ms_context);
228 auto op_precision_mode = ms_context->get_param<std::string>(MS_CTX_OP_PRECISION_MODE);
229 if (op_precision_mode.empty()) {
230 return;
231 }
232 MS_LOG(DEBUG) << "Set ACL_OP_PRECISION_MODE: " << op_precision_mode;
233 auto ret = SetCompileopt(aclCompileOpt::ACL_OP_PRECISION_MODE, op_precision_mode.c_str());
234 if (ret != ACL_SUCCESS) {
235 MS_LOG(EXCEPTION) << "Acl set op precision mode failed! error flag is " << ret;
236 }
237 }
238 } // namespace mindspore::transform
239