• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "backend/common/graph_kernel/graph_kernel_flags.h"

#include <cstring>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "nlohmann/json.hpp"
#include "utils/ms_context.h"
#include "include/common/utils/utils.h"
27 
28 namespace mindspore::graphkernel {
29 namespace {
// Hint message appended to every warning about a malformed flag string.
constexpr auto kLogValidFlag =
  "Valid flag format is \"--key=value\", flags are separated by spaces(e.g. \"--key1=value1 --key2=value2\"). bool "
  "flag's value can be implicit, the \"--key\" means \"--key=true\".";
33 
34 // Split string to tokens
// Split `str` into tokens separated by any character of `delim`.
// Mirrors strtok semantics: consecutive delimiters are collapsed and no empty
// tokens are produced (e.g. "a  b" with delim " " yields {"a", "b"}).
// Replaces the previous strtok_r/strtok_s implementation, which required a
// mutable char-buffer copy and a platform #ifdef.
std::vector<std::string> GetTokens(const std::string &str, const std::string &delim) {
  std::vector<std::string> tokens;
  size_t begin = str.find_first_not_of(delim);
  while (begin != std::string::npos) {
    size_t end = str.find_first_of(delim, begin);
    if (end == std::string::npos) {
      // Last token runs to the end of the string.
      (void)tokens.emplace_back(str.substr(begin));
      break;
    }
    (void)tokens.emplace_back(str.substr(begin, end - begin));
    begin = str.find_first_not_of(delim, end);
  }
  return tokens;
}
55 
// Parse flag string to key-value pair.
// Flag format: "--key=value", bool flag's value can be implicit, the "--key" means "--key=true".
// Returns an empty pair for anything malformed (no "--" prefix, empty value
// after '=', or more than one '=').
std::pair<std::string, std::string> ParseFlag(const std::string &flag) {
  constexpr size_t kPrefixLen = 2;
  // Must begin with "--" and have a non-empty remainder after it.
  if (flag.rfind("--", 0) != 0 || flag.size() <= kPrefixLen) {
    return std::pair<std::string, std::string>();
  }
  const size_t key_begin = kPrefixLen;
  // The key must be at least one character, so search for '=' from key_begin + 1.
  const auto eq_pos = flag.find('=', key_begin + 1);
  if (eq_pos == std::string::npos) {
    // No '=': bool-style flag, value left empty (interpreted as "true" later).
    return std::make_pair(flag.substr(key_begin), "");
  }
  const bool has_value = (eq_pos + 1 < flag.size());
  const bool single_eq = (flag.find('=', eq_pos + 1) == std::string::npos);
  if (has_value && single_eq) {
    // Normal "--key=value" format.
    return std::make_pair(flag.substr(key_begin, eq_pos - key_begin), flag.substr(eq_pos + 1));
  }
  // Empty value ("--key=") or a second '=' is rejected.
  return std::pair<std::string, std::string>();
}
78 
ParseFlags(const std::string & flags)79 std::map<std::string, std::string> ParseFlags(const std::string &flags) {
80   std::map<std::string, std::string> flag_map;
81   auto tokens = GetTokens(flags, " ");
82   for (const auto &token : tokens) {
83     auto flag = ParseFlag(token);
84     if (flag.first != "") {
85       if (!flag_map.insert(flag).second) {
86         MS_LOG(WARNING) << "For 'context.set_context', the flag '" << flag.first
87                         << "' in the parameter 'graph_kernel_flags' is repeated.";
88       }
89     } else {
90       MS_LOG(WARNING) << "For 'context.set_context', the flag '" << token
91                       << "' in the parameter 'graph_kernel_flags' is invalid. " << kLogValidFlag;
92     }
93   }
94   return flag_map;
95 }
96 
97 class FlagRegister {
98  public:
FlagRegister(std::map<std::string,std::string> * flag_map)99   explicit FlagRegister(std::map<std::string, std::string> *flag_map) : flag_map_(*flag_map) {}
100   ~FlagRegister() = default;
101 
102   template <typename T>
AddFlag(const std::string & flag_name,T * flag_var,T default_value) const103   void AddFlag(const std::string &flag_name, T *flag_var, T default_value) const {
104     *flag_var = std::move(default_value);
105     AddFlag(flag_name, flag_var);
106   }
107 
108   template <typename T>
AddFlag(const std::string & flag_name,T * flag_var) const109   void AddFlag(const std::string &flag_name, T *flag_var) const {
110     const auto iter = flag_map_.find(flag_name);
111     if (iter != flag_map_.end()) {
112       T var;
113       bool ret = ParseValue(iter->second, &var);
114       if (ret) {
115         *flag_var = std::move(var);
116       } else {
117         if (iter->second.empty()) {
118           MS_LOG(WARNING) << "For 'context.set_context', the flag --" << iter->first
119                           << " in the parameter 'graph_kernel_flags' is invalid. " << kLogValidFlag;
120         } else {
121           MS_LOG(WARNING) << "For 'context.set_context', the flag --" << iter->first << "=" << iter->second
122                           << " in the parameter 'graph_kernel_flags' is invalid. " << kLogValidFlag;
123         }
124       }
125       (void)flag_map_.erase(iter);
126     }
127   }
128 
129  private:
ParseValue(const std::string & s,std::vector<std::string> * result) const130   bool ParseValue(const std::string &s, std::vector<std::string> *result) const {
131     *result = GetTokens(s, ",");
132     return !result->empty();
133   }
134 
ParseValue(const std::string & s,bool * result) const135   bool ParseValue(const std::string &s, bool *result) const {
136     *result = (s.empty() || s == "true" || s == "True" || s == "on" || s == "1");
137     return *result || s == "false" || s == "False" || s == "off" || s == "0";
138   }
139 
140   template <typename T>
ParseValue(const std::string & s,T * result) const141   bool ParseValue(const std::string &s, T *result) const {
142     if (s.empty()) {
143       return false;
144     }
145     std::istringstream iss(s);
146     iss >> (*result);
147     return iss.eof();
148   }
149 
150   template <typename T>
ParseValue(const std::string & s,std::vector<T> * result) const151   bool ParseValue(const std::string &s, std::vector<T> *result) const {
152     result->clear();
153     auto tokens = GetTokens(s, ",");
154     if (tokens.empty()) {
155       return false;
156     }
157     for (const auto &tok : tokens) {
158       T temp;
159       if (!ParseValue(tok, &temp)) {
160         result->clear();
161         return false;
162       }
163       result->emplace_back(temp);
164     }
165     return true;
166   }
167 
168   std::map<std::string, std::string> &flag_map_;
169 };
170 }  // namespace
171 
IsEnableKernelPacket() const172 bool GraphKernelFlags::IsEnableKernelPacket() const {
173   // Default disable kernelpacket now.
174   // todo: default enable when jit_level is O1.
175   return common::GetEnv("MS_DEV_ENABLE_KERNEL_PACKET") == "on";
176 }
177 
GetInstance()178 const GraphKernelFlags &GraphKernelFlags::GetInstance() {
179   static std::unique_ptr<GraphKernelFlags> flags(nullptr);
180   auto config = GetGraphKernelConfig();
181   if (flags == nullptr || config.first != flags->flags_cache_ || config.second != flags->enable_graph_kernel_) {
182     flags.reset(new GraphKernelFlags(config.first, config.second));
183     flags->Refresh();
184   }
185   return *flags;
186 }
187 
SaveJitConfig(const std::map<std::string,std::string> & jit_config)188 void GraphKernelFlags::SaveJitConfig(const std::map<std::string, std::string> &jit_config) {
189   auto &configs = GetJitConfig();
190   configs.clear();
191   auto level_iter = jit_config.find(kAttrJitLevel);
192   if (level_iter != jit_config.end()) {
193     configs[kAttrJitLevel] = level_iter->second;
194     MS_LOG(DEBUG) << "Save jit_level from jit config, level: " << level_iter->second;
195   }
196   auto flags_iter = jit_config.find("graph_kernel_flags");
197   if (flags_iter != jit_config.end()) {
198     configs["graph_kernel_flags"] = flags_iter->second;
199     MS_LOG(DEBUG) << "Save graph_kernel_flags from jit config, flags: " << flags_iter->second;
200   }
201 }
202 
// Compute the effective configuration as (flags string, enable_graph_kernel).
// Priority for the flags string: MS_DEV_GRAPH_KERNEL_FLAGS env var, then
// JitConfig, then (full build only) the ms_context parameter.
std::pair<std::string, bool> GraphKernelFlags::GetGraphKernelConfig() {
#ifdef MSLITE_ENABLE_GRAPH_KERNEL
  // Lite build: only env var and jit config are consulted; the enable switch
  // of the returned pair is always false here.
  std::string flags = common::GetEnv("MS_DEV_GRAPH_KERNEL_FLAGS");
  if (flags != "") {
    return std::make_pair(flags, false);
  }
  const auto &jit_config = GetJitConfig();
  if (jit_config.find("graph_kernel_flags") != jit_config.end()) {
    flags = jit_config.at("graph_kernel_flags");
  }
  return std::make_pair(flags, false);
#else
  const auto &jit_config = GetJitConfig();
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);

  auto jit_level_iter = jit_config.find(kAttrJitLevel);
  auto jit_level = (jit_level_iter != jit_config.end() ? jit_level_iter->second : "");
  bool enable_gk = context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL);
  auto device_target = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  // On Ascend, jit_level O1 implicitly enables graph kernel even when the
  // context switch is off.
  if (!enable_gk && device_target == kAscendDevice) {
    enable_gk = (jit_level == kAttrJitLevelO1);
  }
  // use environ flags in priority
  auto flags_env = std::getenv("MS_DEV_GRAPH_KERNEL_FLAGS");
  if (flags_env != nullptr) {
    return std::make_pair(std::string(flags_env), enable_gk);
  }
  // get flags string from context or jitconfig; JitConfig wins when both are set.
  auto flags = context->get_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS);
  auto iter = jit_config.find("graph_kernel_flags");
  if (iter != jit_config.end()) {
    // Warn about the conflict only once per process.
    static bool print_warning_once = true;
    if (!flags.empty() && print_warning_once) {
      print_warning_once = false;
      MS_LOG(WARNING) << "The 'graph_kernel_flags' in 'mindspore.context' and 'JitConfig' is set in the same time, "
                         "only the JitConfig's setting is efficient.";
    }
    flags = iter->second;
  }
  return std::make_pair(flags, enable_gk);
#endif
}
246 
// Validate that the current build / platform supports graph kernel fusion.
// On an unsupported combination: warn and force opt_level to OptLevel_0,
// which presumably turns the feature off (IsEnableGraphKernel is declared
// elsewhere — confirm it keys off opt_level).
// NOTE(review): mutates members through const_cast from a const method —
// pre-existing pattern, kept as-is.
void GraphKernelFlags::CheckSupport() const {
#ifndef MSLITE_ENABLE_GRAPH_KERNEL
  if (IsEnableGraphKernel()) {
    auto context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context);
#ifndef USE_LLVM
    // AKG on CPU requires an LLVM-enabled build.
    auto is_cpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice);
    if (is_cpu && const_cast<GraphKernelFlags *>(this)->kernel_generator == "AKG") {
      MS_LOG(WARNING)
        << "Graph Kernel Fusion is not supported without LLVM on cpu platform, and it will be turned off now. Please "
           "refer to https://www.mindspore.cn/install and install the required version of LLVM.";
      const_cast<GraphKernelFlags *>(this)->opt_level = OptLevel_0;
      return;
    }
#endif
    auto is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
    if (is_ascend) {
#ifndef ENABLE_DVM
      // Ascend without the DVM prebuilt binary (tracked via git lfs) cannot run fusion.
      MS_LOG(WARNING) << "Graph Kernel Fusion is not supported without the prebuild binary file tracked by git lfs, "
                         "and it will be turned off now. Please perform the following steps:\n\n"
                         "1. Install git lfs, refer https://github.com/git-lfs/git-lfs/wiki/installation\n"
                         "2. After installing git lfs, do not forget executing the following command:\n"
                         "   git lfs install\n"
                         "3. Re-clone the source codes, the files tracked by git lfs will be downloaded automatically\n"
                         "4. Re-compile the source codes\n";
      const_cast<GraphKernelFlags *>(this)->opt_level = OptLevel_0;
      return;
#else
      // DVM is only supported on ascend910b/ascend910c SoCs (empty soc_version is tolerated).
      if (const_cast<GraphKernelFlags *>(this)->kernel_generator == "DVM") {
        auto const &soc_version = context->ascend_soc_version();
        if (!soc_version.empty() && soc_version != "ascend910b" && soc_version != "ascend910c") {
          MS_LOG(WARNING) << "DVM does not support " << soc_version << ".";
          const_cast<GraphKernelFlags *>(this)->opt_level = OptLevel_0;
          return;
        }
      }
#endif
    }
  }
#endif
}
288 
// Re-parse the cached flags string and apply the values to the members.
void GraphKernelFlags::Refresh() {
  auto flag_map = ParseFlags(flags_cache_);
  RegisterFlags(&flag_map);
  // RegisterFlags erases every flag it recognizes, so anything left over is unknown.
  for (const auto &item : flag_map) {
    MS_LOG(WARNING) << "Unknown flag: " << item.first;
  }
  if (!flag_map.empty()) {
    MS_LOG(WARNING)
      << "For 'context.set_context', the flags listed above in the parameter 'graph_kernel_flags' are invalid. For "
         "valid flags, please refer to the source code file graph_kernel_flags.h at "
         "https://gitee.com/mindspore/mindspore.";
  }
#ifndef MSLITE_ENABLE_GRAPH_KERNEL
  if (IsEnableGraphKernel()) {
    CheckSupport();
  }
#endif
  // If enable graphkernel, Dump flags so that people can check the setting.
  // IsEnableGraphKernel() is deliberately re-evaluated here: CheckSupport() may
  // have lowered opt_level to OptLevel_0, which presumably disables the feature.
  if (IsEnableGraphKernel()) {
    MS_LOG(INFO) << "graph_kernel_flags = \"" << flags_cache_ << "\", all flags: " << DumpAllFlags();
  }
}
311 
// Bind every known flag name to its member variable and parse the values held
// in *flag_map. FlagRegister erases each recognized entry, so the caller can
// report leftovers as unknown flags. Registration order matters: opt_level is
// parsed first because several defaults below depend on it.
void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
  // Snapshot whether these flags were set explicitly BEFORE FlagRegister
  // consumes their entries; the fix-up logic at the bottom needs to know.
  bool has_kernel_generator = (flag_map->find("kernel_generator") != flag_map->end());
  bool has_enable_dynamic_shape_fusion = (flag_map->find("enable_dynamic_shape_fusion") != flag_map->end());
  FlagRegister reg(flag_map);
  bool is_ascend{false};
  bool is_910bc{false};
  auto context_ptr = MsContext::GetInstance();
  if (context_ptr != nullptr) {
    auto const &soc_version = context_ptr->ascend_soc_version();
    is_910bc = (soc_version == "ascend910b") || (soc_version == "ascend910c");
    is_ascend = (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
  }

  // Set opt_level first, some flags' default value depends on it.
  // Default optimization level is level 2 when enable graphkernel
  reg.AddFlag("opt_level", &opt_level, enable_graph_kernel_ ? OptLevel_2 : OptLevel_0);
  if (opt_level > OptLevel_3) {
    MS_LOG(WARNING) << "For 'context.set_context', the flag opt_level in the parameter 'graph_kernel_flags' must be in "
                       "the range [0, 3], but got "
                    << opt_level << ". It will be set to " << OptLevel_3
                    << ". For more details, please refer to 'graph_kernel_flags' at https://www.mindspore.cn.";
    opt_level = OptLevel_3;
  }

  // Boolean flags
  reg.AddFlag("dump_as_text", &dump_as_text);
  reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion, (opt_level == OptLevel_3 && !is_910bc));
  reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level >= OptLevel_2);
  reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3);
  reg.AddFlag("enable_horizontal_fusion", &enable_horizontal_fusion);
  reg.AddFlag("enable_auto_tensor_inplace", &enable_auto_tensor_inplace);
  reg.AddFlag("enable_dynamic_batch", &enable_dynamic_batch);
  reg.AddFlag("enable_low_precision", &enable_low_precision);
  reg.AddFlag("enable_csr_fusion", &enable_csr_fusion);
  reg.AddFlag("enable_debug_mode", &enable_debug_mode);
  reg.AddFlag("enable_lite_conv_tuning", &enable_lite_conv_tuning);
  reg.AddFlag("enable_vectorization", &enable_vectorization);
  reg.AddFlag("enable_dynamic_shape_fusion", &enable_dynamic_shape_fusion);
  reg.AddFlag("enable_parallel_op_combine", &enable_parallel_op_combine);

  // Integer flags
  reg.AddFlag("reduce_fuse_depth", &reduce_fuse_depth);
  reg.AddFlag("online_tuning", &online_tuning);
  reg.AddFlag("cpu_refer_thread_num", &cpu_refer_thread_num);
  reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_ascend ? OpLevel_0 : OpLevel_1);
  reg.AddFlag("parallel_ops_level", &parallel_ops_level);
  reg.AddFlag("recompute_increment_threshold", &recompute_increment_threshold);
  reg.AddFlag("recompute_peak_threshold", &recompute_peak_threshold);
  reg.AddFlag("composite_op_limit_size", &composite_op_limit_size);

  // String flags
  reg.AddFlag("repository_path", &repository_path);
  reg.AddFlag("target_os", &target_os);
  reg.AddFlag("cpu_arch", &cpu_arch);
  reg.AddFlag("cpu_feature", &cpu_feature);
  reg.AddFlag("cpu_type", &cpu_type);
  reg.AddFlag("kernel_generator", &kernel_generator);

  // String list flags
  reg.AddFlag("enable_expand_ops", &enable_expand_ops);
  reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only);
  reg.AddFlag("disable_expand_ops", &disable_expand_ops);
  reg.AddFlag("enable_cluster_ops", &enable_cluster_ops);
  reg.AddFlag("enable_cluster_ops_only", &enable_cluster_ops_only);
  reg.AddFlag("disable_cluster_ops", &disable_cluster_ops);
  reg.AddFlag("enable_simplify_exprs_only", &enable_simplify_exprs_only);
  reg.AddFlag("disable_simplify_exprs", &disable_simplify_exprs);
  reg.AddFlag("enable_pass", &enable_pass);
  reg.AddFlag("disable_pass", &disable_pass);
  reg.AddFlag("enable_cce_lib", &enable_cce_lib);
  reg.AddFlag("enable_cce_lib_ops", &enable_cce_lib_ops);
  reg.AddFlag("enable_cce_lib_ops_only", &enable_cce_lib_ops_only);
  reg.AddFlag("disable_cce_lib_ops", &disable_cce_lib_ops);
  reg.AddFlag("enable_packet_ops_only", &enable_packet_ops_only);
  reg.AddFlag("disable_packet_ops", &disable_packet_ops);

  // Dynamic shape fusion on non-Ascend backends forces the AKG_V2 generator
  // and skips the remaining fix-ups.
  if (enable_dynamic_shape_fusion && !is_ascend) {
    kernel_generator = "AKG_V2";
    return;
  }

  // On Ascend (full build), default the generator to DVM unless the user chose one.
  if (is_ascend && !has_kernel_generator) {
#ifndef MSLITE_ENABLE_GRAPH_KERNEL
    kernel_generator = "DVM";
#endif
  }
  // DVM implies dynamic shape fusion unless the user set it explicitly.
  if (kernel_generator == "DVM" && !has_enable_dynamic_shape_fusion) {
    enable_dynamic_shape_fusion = true;
  }
  if (is_ascend && enable_auto_tensor_inplace) {
    MS_LOG(WARNING)
      << "For Graph Kernel Fusion, the flag '--enable_auto_tensor_inplace' set in 'graph_kernel_flags' is "
         "not supported on Ascend and will be turned off now";
    enable_auto_tensor_inplace = false;
  }
}
408 
DumpAllFlags() const409 std::string GraphKernelFlags::DumpAllFlags() const {
410   nlohmann::json json;
411 
412   json["dump_as_text"] = dump_as_text;
413   json["enable_stitch_fusion"] = enable_stitch_fusion;
414   json["enable_recompute_fusion"] = enable_recompute_fusion;
415   json["enable_parallel_fusion"] = enable_parallel_fusion;
416   json["enable_horizontal_fusion"] = enable_horizontal_fusion;
417   json["enable_auto_tensor_inplace"] = enable_auto_tensor_inplace;
418   json["enable_dynamic_batch"] = enable_dynamic_batch;
419   json["enable_csr_fusion"] = enable_csr_fusion;
420   json["enable_low_precision"] = enable_low_precision;
421   json["enable_debug_mode"] = enable_debug_mode;
422   json["enable_lite_conv_tuning"] = enable_lite_conv_tuning;
423   json["enable_vectorization"] = enable_vectorization;
424   json["enable_dynamic_shape_fusion"] = enable_dynamic_shape_fusion;
425 
426   json["opt_level"] = opt_level;
427   json["fusion_ops_level"] = fusion_ops_level;
428   json["parallel_ops_level"] = parallel_ops_level;
429   json["reduce_fuse_depth"] = reduce_fuse_depth;
430   json["online_tuning"] = online_tuning;
431   json["cpu_refer_thread_num"] = cpu_refer_thread_num;
432   json["recompute_increment_threshold"] = recompute_increment_threshold;
433   json["recompute_peak_threshold"] = recompute_peak_threshold;
434   json["composite_op_limit_size"] = composite_op_limit_size;
435 
436   json["repository_path"] = repository_path;
437   json["target_os"] = target_os;
438   json["cpu_arch"] = cpu_arch;
439   json["cpu_feature"] = cpu_feature;
440   json["cpu_type"] = cpu_type;
441 
442   json["kernel_generator"] = kernel_generator;
443 
444   json["enable_expand_ops"] = enable_expand_ops;
445   json["enable_expand_ops_only"] = enable_expand_ops_only;
446   json["disable_expand_ops"] = disable_expand_ops;
447   json["enable_cluster_ops"] = enable_cluster_ops;
448   json["enable_cluster_ops_only"] = enable_cluster_ops_only;
449   json["disable_cluster_ops"] = disable_cluster_ops;
450   json["enable_simplify_exprs_only"] = enable_simplify_exprs_only;
451   json["disable_simplify_exprs"] = disable_simplify_exprs;
452   json["enable_pass"] = enable_pass;
453   json["disable_pass"] = disable_pass;
454   json["enable_cce_lib"] = enable_cce_lib;
455   json["enable_cce_lib_ops"] = enable_cce_lib_ops_only;
456   json["enable_cce_lib_ops_only"] = enable_cce_lib_ops_only;
457   json["disable_cce_lib_ops"] = disable_cce_lib_ops;
458   json["enable_packet_ops_only"] = enable_packet_ops_only;
459   json["disable_packet_ops"] = disable_packet_ops;
460 
461   return json.dump();
462 }
463 }  // namespace mindspore::graphkernel
464