1 /**
2 * Copyright 2021-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/common/graph_kernel/graph_kernel_flags.h"
18
19 #include <map>
20 #include <string>
21 #include <cstring>
22 #include <vector>
23 #include <utility>
24 #include "nlohmann/json.hpp"
25 #include "utils/ms_context.h"
26 #include "include/common/utils/utils.h"
27
28 namespace mindspore::graphkernel {
29 namespace {
30 constexpr auto kLogValidFlag =
31 "Valid flag format is \"--key=value\", flags are separated by spaces(e.g. \"--key1=value1 --key2=value2\"). bool "
32 "flag's value can be implicit, the \"--key\" means \"--key=true\".";
33
// Split a string into tokens.
// Any character in `delim` acts as a separator (same semantics as strtok):
// consecutive separators are collapsed and no empty token is produced.
// Rewritten with std::string searches to drop the strtok_r/strtok_s
// platform split and the mutable char-buffer copy.
std::vector<std::string> GetTokens(const std::string &str, const std::string &delim) {
  std::vector<std::string> tokens;
  size_t begin = str.find_first_not_of(delim);
  while (begin != std::string::npos) {
    auto end = str.find_first_of(delim, begin);
    // When end == npos, substr takes the remainder of the string.
    (void)tokens.emplace_back(str.substr(begin, end - begin));
    begin = str.find_first_not_of(delim, end);
  }
  return tokens;
}
55
// Parse one flag string into a (key, value) pair.
// Accepted formats: "--key=value", or the bare "--key" bool form whose value
// is left empty (interpreted as true by the caller). Anything else — missing
// "--" prefix, empty value after '=', or more than one '=' — yields an empty
// pair to signal an invalid flag.
std::pair<std::string, std::string> ParseFlag(const std::string &flag) {
  constexpr size_t prefix_len = 2;
  // The flag must begin with "--" and contain at least one key character.
  if (flag.size() <= prefix_len || flag.find("--") != 0) {
    return {};
  }
  const size_t key_begin = prefix_len;

  // Search from key_begin + 1 so the key is never empty.
  const auto eq_pos = flag.find('=', key_begin + 1);
  if (eq_pos >= flag.size()) {
    // No '=' found: bool flag with an implicit value.
    return std::make_pair(flag.substr(key_begin), "");
  }

  const bool value_nonempty = (eq_pos + 1 < flag.size());
  const bool single_equals = (flag.find('=', eq_pos + 1) == std::string::npos);
  if (value_nonempty && single_equals) {
    // Regular "--key=value" form.
    return std::make_pair(flag.substr(key_begin, eq_pos - key_begin), flag.substr(eq_pos + 1));
  }
  // Empty value or a second '=' makes the flag invalid.
  return {};
}
78
ParseFlags(const std::string & flags)79 std::map<std::string, std::string> ParseFlags(const std::string &flags) {
80 std::map<std::string, std::string> flag_map;
81 auto tokens = GetTokens(flags, " ");
82 for (const auto &token : tokens) {
83 auto flag = ParseFlag(token);
84 if (flag.first != "") {
85 if (!flag_map.insert(flag).second) {
86 MS_LOG(WARNING) << "For 'context.set_context', the flag '" << flag.first
87 << "' in the parameter 'graph_kernel_flags' is repeated.";
88 }
89 } else {
90 MS_LOG(WARNING) << "For 'context.set_context', the flag '" << token
91 << "' in the parameter 'graph_kernel_flags' is invalid. " << kLogValidFlag;
92 }
93 }
94 return flag_map;
95 }
96
// Binds the parsed "--key=value" map to the flag member variables of
// GraphKernelFlags. Every successful AddFlag lookup erases the matching
// entry from the map, so whatever remains afterwards is an unknown flag.
class FlagRegister {
 public:
  // Holds a reference to (does not own) the parsed flag map; the map must
  // outlive this register.
  explicit FlagRegister(std::map<std::string, std::string> *flag_map) : flag_map_(*flag_map) {}
  ~FlagRegister() = default;

  // Install default_value into *flag_var first, then let the map entry
  // (if present and parseable) override it.
  template <typename T>
  void AddFlag(const std::string &flag_name, T *flag_var, T default_value) const {
    *flag_var = std::move(default_value);
    AddFlag(flag_name, flag_var);
  }

  // If flag_name exists in the map, parse its value into *flag_var and
  // erase the entry. On parse failure *flag_var is left untouched and a
  // warning is logged (an empty value is reported without the "=value" part).
  template <typename T>
  void AddFlag(const std::string &flag_name, T *flag_var) const {
    const auto iter = flag_map_.find(flag_name);
    if (iter != flag_map_.end()) {
      T var;
      bool ret = ParseValue(iter->second, &var);
      if (ret) {
        *flag_var = std::move(var);
      } else {
        if (iter->second.empty()) {
          MS_LOG(WARNING) << "For 'context.set_context', the flag --" << iter->first
                          << " in the parameter 'graph_kernel_flags' is invalid. " << kLogValidFlag;
        } else {
          MS_LOG(WARNING) << "For 'context.set_context', the flag --" << iter->first << "=" << iter->second
                          << " in the parameter 'graph_kernel_flags' is invalid. " << kLogValidFlag;
        }
      }
      // Consume the entry so leftovers can be reported as unknown flags.
      (void)flag_map_.erase(iter);
    }
  }

 private:
  // String-list flag: split on ','; at least one token is required.
  bool ParseValue(const std::string &s, std::vector<std::string> *result) const {
    *result = GetTokens(s, ",");
    return !result->empty();
  }

  // Bool flag: an empty value (bare "--key") counts as true. Returns false
  // when the value is none of the accepted true/false spellings.
  bool ParseValue(const std::string &s, bool *result) const {
    *result = (s.empty() || s == "true" || s == "True" || s == "on" || s == "1");
    return *result || s == "false" || s == "False" || s == "off" || s == "0";
  }

  // Scalar flag (integers, strings, ...): stream extraction must consume
  // the entire token, otherwise the value is rejected.
  template <typename T>
  bool ParseValue(const std::string &s, T *result) const {
    if (s.empty()) {
      return false;
    }
    std::istringstream iss(s);
    iss >> (*result);
    return iss.eof();
  }

  // Typed-list flag: split on ',' and parse each token as T; every token
  // must parse or the whole flag is rejected (result is cleared).
  template <typename T>
  bool ParseValue(const std::string &s, std::vector<T> *result) const {
    result->clear();
    auto tokens = GetTokens(s, ",");
    if (tokens.empty()) {
      return false;
    }
    for (const auto &tok : tokens) {
      T temp;
      if (!ParseValue(tok, &temp)) {
        result->clear();
        return false;
      }
      result->emplace_back(temp);
    }
    return true;
  }

  std::map<std::string, std::string> &flag_map_;  // non-owning reference to the parsed flags
};
170 } // namespace
171
// Kernelpacket is opt-in: enabled only when the environment variable
// MS_DEV_ENABLE_KERNEL_PACKET is exactly "on".
bool GraphKernelFlags::IsEnableKernelPacket() const {
  // Default disable kernelpacket now.
  // todo: default enable when jit_level is O1.
  return common::GetEnv("MS_DEV_ENABLE_KERNEL_PACKET") == "on";
}
177
// Cached singleton accessor. The instance is rebuilt whenever the flags
// string or the enable switch obtained from the current configuration
// differs from what the cached instance was built with.
// NOTE(review): the static unique_ptr reset is not synchronized — this
// presumably runs on a single thread; confirm before concurrent use.
const GraphKernelFlags &GraphKernelFlags::GetInstance() {
  static std::unique_ptr<GraphKernelFlags> flags(nullptr);
  auto config = GetGraphKernelConfig();
  if (flags == nullptr || config.first != flags->flags_cache_ || config.second != flags->enable_graph_kernel_) {
    flags.reset(new GraphKernelFlags(config.first, config.second));
    flags->Refresh();
  }
  return *flags;
}
187
SaveJitConfig(const std::map<std::string,std::string> & jit_config)188 void GraphKernelFlags::SaveJitConfig(const std::map<std::string, std::string> &jit_config) {
189 auto &configs = GetJitConfig();
190 configs.clear();
191 auto level_iter = jit_config.find(kAttrJitLevel);
192 if (level_iter != jit_config.end()) {
193 configs[kAttrJitLevel] = level_iter->second;
194 MS_LOG(DEBUG) << "Save jit_level from jit config, level: " << level_iter->second;
195 }
196 auto flags_iter = jit_config.find("graph_kernel_flags");
197 if (flags_iter != jit_config.end()) {
198 configs["graph_kernel_flags"] = flags_iter->second;
199 MS_LOG(DEBUG) << "Save graph_kernel_flags from jit config, flags: " << flags_iter->second;
200 }
201 }
202
// Resolve the effective configuration as {flags string, enable switch}.
// Priority for the flags string: MS_DEV_GRAPH_KERNEL_FLAGS env var, then
// JitConfig, then (non-lite builds) the ms context parameter.
std::pair<std::string, bool> GraphKernelFlags::GetGraphKernelConfig() {
#ifdef MSLITE_ENABLE_GRAPH_KERNEL
  // Lite build: the enable switch is not decided here, so it is always
  // returned as false; only the flags string is resolved.
  std::string flags = common::GetEnv("MS_DEV_GRAPH_KERNEL_FLAGS");
  if (flags != "") {
    return std::make_pair(flags, false);
  }
  const auto &jit_config = GetJitConfig();
  if (jit_config.find("graph_kernel_flags") != jit_config.end()) {
    flags = jit_config.at("graph_kernel_flags");
  }
  return std::make_pair(flags, false);
#else
  const auto &jit_config = GetJitConfig();
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);

  auto jit_level_iter = jit_config.find(kAttrJitLevel);
  auto jit_level = (jit_level_iter != jit_config.end() ? jit_level_iter->second : "");
  bool enable_gk = context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL);
  auto device_target = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  // On Ascend, jit_level O1 implicitly enables graph kernel even when the
  // context switch is off.
  if (!enable_gk && device_target == kAscendDevice) {
    enable_gk = (jit_level == kAttrJitLevelO1);
  }
  // use environ flags in priority
  auto flags_env = std::getenv("MS_DEV_GRAPH_KERNEL_FLAGS");
  if (flags_env != nullptr) {
    return std::make_pair(std::string(flags_env), enable_gk);
  }
  // get flags string from context or jitconfig
  auto flags = context->get_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS);
  auto iter = jit_config.find("graph_kernel_flags");
  if (iter != jit_config.end()) {
    // JitConfig overrides the context value; warn once when both are set.
    static bool print_warning_once = true;
    if (!flags.empty() && print_warning_once) {
      print_warning_once = false;
      MS_LOG(WARNING) << "The 'graph_kernel_flags' in 'mindspore.context' and 'JitConfig' is set in the same time, "
                         "only the JitConfig's setting is efficient.";
    }
    flags = iter->second;
  }
  return std::make_pair(flags, enable_gk);
#endif
}
246
// Validate that the current build/platform can actually run graph kernel
// fusion; when it cannot, the effective opt_level is forced back to 0
// (fusion off) with a warning explaining why.
// NOTE(review): the const_casts mutate members from a const method —
// this is treated as lazy validation rather than logical mutation.
void GraphKernelFlags::CheckSupport() const {
#ifndef MSLITE_ENABLE_GRAPH_KERNEL
  if (IsEnableGraphKernel()) {
    auto context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context);
#ifndef USE_LLVM
    // The AKG generator on CPU requires LLVM at build time.
    auto is_cpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice);
    if (is_cpu && const_cast<GraphKernelFlags *>(this)->kernel_generator == "AKG") {
      MS_LOG(WARNING)
        << "Graph Kernel Fusion is not supported without LLVM on cpu platform, and it will be turned off now. Please "
           "refer to https://www.mindspore.cn/install and install the required version of LLVM.";
      const_cast<GraphKernelFlags *>(this)->opt_level = OptLevel_0;
      return;
    }
#endif
    auto is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
    if (is_ascend) {
#ifndef ENABLE_DVM
      // Ascend requires the DVM prebuilt binary, which is tracked via git lfs.
      MS_LOG(WARNING) << "Graph Kernel Fusion is not supported without the prebuild binary file tracked by git lfs, "
                         "and it will be turned off now. Please perform the following steps:\n\n"
                         "1. Install git lfs, refer https://github.com/git-lfs/git-lfs/wiki/installation\n"
                         "2. After installing git lfs, do not forget executing the following command:\n"
                         " git lfs install\n"
                         "3. Re-clone the source codes, the files tracked by git lfs will be downloaded automatically\n"
                         "4. Re-compile the source codes\n";
      const_cast<GraphKernelFlags *>(this)->opt_level = OptLevel_0;
      return;
#else
      // DVM only supports the ascend910b/ascend910c SoCs; an empty
      // soc_version is tolerated (cannot be checked).
      if (const_cast<GraphKernelFlags *>(this)->kernel_generator == "DVM") {
        auto const &soc_version = context->ascend_soc_version();
        if (!soc_version.empty() && soc_version != "ascend910b" && soc_version != "ascend910c") {
          MS_LOG(WARNING) << "DVM does not support " << soc_version << ".";
          const_cast<GraphKernelFlags *>(this)->opt_level = OptLevel_0;
          return;
        }
      }
#endif
    }
  }
#endif
}
288
// Re-parse the cached flags string and bind the values into the member
// flags, then report anything unrecognized and validate platform support.
void GraphKernelFlags::Refresh() {
  auto flag_map = ParseFlags(flags_cache_);
  RegisterFlags(&flag_map);
  // RegisterFlags erases every recognized flag, so leftovers are unknown names.
  for (const auto &item : flag_map) {
    MS_LOG(WARNING) << "Unknown flag: " << item.first;
  }
  if (!flag_map.empty()) {
    MS_LOG(WARNING)
      << "For 'context.set_context', the flags listed above in the parameter 'graph_kernel_flags' are invalid. For "
         "valid flags, please refer to the source code file graph_kernel_flags.h at "
         "https://gitee.com/mindspore/mindspore.";
  }
#ifndef MSLITE_ENABLE_GRAPH_KERNEL
  if (IsEnableGraphKernel()) {
    // May downgrade opt_level to 0 on unsupported platforms.
    CheckSupport();
  }
#endif
  // If enable graphkernel, Dump flags so that people can check the setting.
  if (IsEnableGraphKernel()) {
    MS_LOG(INFO) << "graph_kernel_flags = \"" << flags_cache_ << "\", all flags: " << DumpAllFlags();
  }
}
311
// Bind every known flag from *flag_map into the member variables.
// Recognized entries are erased from the map as they are consumed, so the
// caller can report leftovers as unknown flags. Registration order matters:
// opt_level is set first because several defaults depend on it.
void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
  // Record whether the user set these flags explicitly, before the
  // FlagRegister below erases them from the map.
  bool has_kernel_generator = (flag_map->find("kernel_generator") != flag_map->end());
  bool has_enable_dynamic_shape_fusion = (flag_map->find("enable_dynamic_shape_fusion") != flag_map->end());
  FlagRegister reg(flag_map);
  bool is_ascend{false};
  bool is_910bc{false};
  auto context_ptr = MsContext::GetInstance();
  if (context_ptr != nullptr) {
    auto const &soc_version = context_ptr->ascend_soc_version();
    is_910bc = (soc_version == "ascend910b") || (soc_version == "ascend910c");
    is_ascend = (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
  }

  // Set opt_level first, some flags' default value depends on it.
  // Default optimization level is level 2 when enable graphkernel
  reg.AddFlag("opt_level", &opt_level, enable_graph_kernel_ ? OptLevel_2 : OptLevel_0);
  if (opt_level > OptLevel_3) {
    MS_LOG(WARNING) << "For 'context.set_context', the flag opt_level in the parameter 'graph_kernel_flags' must be in "
                       "the range [0, 3], but got "
                    << opt_level << ". It will be set to " << OptLevel_3
                    << ". For more details, please refer to 'graph_kernel_flags' at https://www.mindspore.cn.";
    opt_level = OptLevel_3;
  }

  // Boolean flags
  reg.AddFlag("dump_as_text", &dump_as_text);
  reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion, (opt_level == OptLevel_3 && !is_910bc));
  reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level >= OptLevel_2);
  reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3);
  reg.AddFlag("enable_horizontal_fusion", &enable_horizontal_fusion);
  reg.AddFlag("enable_auto_tensor_inplace", &enable_auto_tensor_inplace);
  reg.AddFlag("enable_dynamic_batch", &enable_dynamic_batch);
  reg.AddFlag("enable_low_precision", &enable_low_precision);
  reg.AddFlag("enable_csr_fusion", &enable_csr_fusion);
  reg.AddFlag("enable_debug_mode", &enable_debug_mode);
  reg.AddFlag("enable_lite_conv_tuning", &enable_lite_conv_tuning);
  reg.AddFlag("enable_vectorization", &enable_vectorization);
  reg.AddFlag("enable_dynamic_shape_fusion", &enable_dynamic_shape_fusion);
  reg.AddFlag("enable_parallel_op_combine", &enable_parallel_op_combine);

  // Integer flags
  reg.AddFlag("reduce_fuse_depth", &reduce_fuse_depth);
  reg.AddFlag("online_tuning", &online_tuning);
  reg.AddFlag("cpu_refer_thread_num", &cpu_refer_thread_num);
  reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_ascend ? OpLevel_0 : OpLevel_1);
  reg.AddFlag("parallel_ops_level", &parallel_ops_level);
  reg.AddFlag("recompute_increment_threshold", &recompute_increment_threshold);
  reg.AddFlag("recompute_peak_threshold", &recompute_peak_threshold);
  reg.AddFlag("composite_op_limit_size", &composite_op_limit_size);

  // String flags
  reg.AddFlag("repository_path", &repository_path);
  reg.AddFlag("target_os", &target_os);
  reg.AddFlag("cpu_arch", &cpu_arch);
  reg.AddFlag("cpu_feature", &cpu_feature);
  reg.AddFlag("cpu_type", &cpu_type);
  reg.AddFlag("kernel_generator", &kernel_generator);

  // String list flags
  reg.AddFlag("enable_expand_ops", &enable_expand_ops);
  reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only);
  reg.AddFlag("disable_expand_ops", &disable_expand_ops);
  reg.AddFlag("enable_cluster_ops", &enable_cluster_ops);
  reg.AddFlag("enable_cluster_ops_only", &enable_cluster_ops_only);
  reg.AddFlag("disable_cluster_ops", &disable_cluster_ops);
  reg.AddFlag("enable_simplify_exprs_only", &enable_simplify_exprs_only);
  reg.AddFlag("disable_simplify_exprs", &disable_simplify_exprs);
  reg.AddFlag("enable_pass", &enable_pass);
  reg.AddFlag("disable_pass", &disable_pass);
  reg.AddFlag("enable_cce_lib", &enable_cce_lib);
  reg.AddFlag("enable_cce_lib_ops", &enable_cce_lib_ops);
  reg.AddFlag("enable_cce_lib_ops_only", &enable_cce_lib_ops_only);
  reg.AddFlag("disable_cce_lib_ops", &disable_cce_lib_ops);
  reg.AddFlag("enable_packet_ops_only", &enable_packet_ops_only);
  reg.AddFlag("disable_packet_ops", &disable_packet_ops);

  // Post-processing: dynamic shape fusion off Ascend forces the AKG_V2 generator.
  if (enable_dynamic_shape_fusion && !is_ascend) {
    kernel_generator = "AKG_V2";
    return;
  }

  // On Ascend (non-lite build) the generator defaults to DVM unless the
  // user chose one explicitly.
  if (is_ascend && !has_kernel_generator) {
#ifndef MSLITE_ENABLE_GRAPH_KERNEL
    kernel_generator = "DVM";
#endif
  }
  // DVM implies dynamic shape fusion unless the user set it explicitly.
  if (kernel_generator == "DVM" && !has_enable_dynamic_shape_fusion) {
    enable_dynamic_shape_fusion = true;
  }
  if (is_ascend && enable_auto_tensor_inplace) {
    MS_LOG(WARNING)
      << "For Graph Kernel Fusion, the flag '--enable_auto_tensor_inplace' set in 'graph_kernel_flags' is "
         "not supported on Ascend and will be turned off now";
    enable_auto_tensor_inplace = false;
  }
}
408
DumpAllFlags() const409 std::string GraphKernelFlags::DumpAllFlags() const {
410 nlohmann::json json;
411
412 json["dump_as_text"] = dump_as_text;
413 json["enable_stitch_fusion"] = enable_stitch_fusion;
414 json["enable_recompute_fusion"] = enable_recompute_fusion;
415 json["enable_parallel_fusion"] = enable_parallel_fusion;
416 json["enable_horizontal_fusion"] = enable_horizontal_fusion;
417 json["enable_auto_tensor_inplace"] = enable_auto_tensor_inplace;
418 json["enable_dynamic_batch"] = enable_dynamic_batch;
419 json["enable_csr_fusion"] = enable_csr_fusion;
420 json["enable_low_precision"] = enable_low_precision;
421 json["enable_debug_mode"] = enable_debug_mode;
422 json["enable_lite_conv_tuning"] = enable_lite_conv_tuning;
423 json["enable_vectorization"] = enable_vectorization;
424 json["enable_dynamic_shape_fusion"] = enable_dynamic_shape_fusion;
425
426 json["opt_level"] = opt_level;
427 json["fusion_ops_level"] = fusion_ops_level;
428 json["parallel_ops_level"] = parallel_ops_level;
429 json["reduce_fuse_depth"] = reduce_fuse_depth;
430 json["online_tuning"] = online_tuning;
431 json["cpu_refer_thread_num"] = cpu_refer_thread_num;
432 json["recompute_increment_threshold"] = recompute_increment_threshold;
433 json["recompute_peak_threshold"] = recompute_peak_threshold;
434 json["composite_op_limit_size"] = composite_op_limit_size;
435
436 json["repository_path"] = repository_path;
437 json["target_os"] = target_os;
438 json["cpu_arch"] = cpu_arch;
439 json["cpu_feature"] = cpu_feature;
440 json["cpu_type"] = cpu_type;
441
442 json["kernel_generator"] = kernel_generator;
443
444 json["enable_expand_ops"] = enable_expand_ops;
445 json["enable_expand_ops_only"] = enable_expand_ops_only;
446 json["disable_expand_ops"] = disable_expand_ops;
447 json["enable_cluster_ops"] = enable_cluster_ops;
448 json["enable_cluster_ops_only"] = enable_cluster_ops_only;
449 json["disable_cluster_ops"] = disable_cluster_ops;
450 json["enable_simplify_exprs_only"] = enable_simplify_exprs_only;
451 json["disable_simplify_exprs"] = disable_simplify_exprs;
452 json["enable_pass"] = enable_pass;
453 json["disable_pass"] = disable_pass;
454 json["enable_cce_lib"] = enable_cce_lib;
455 json["enable_cce_lib_ops"] = enable_cce_lib_ops_only;
456 json["enable_cce_lib_ops_only"] = enable_cce_lib_ops_only;
457 json["disable_cce_lib_ops"] = disable_cce_lib_ops;
458 json["enable_packet_ops_only"] = enable_packet_ops_only;
459 json["disable_packet_ops"] = disable_packet_ops;
460
461 return json.dump();
462 }
463 } // namespace mindspore::graphkernel
464