• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/ms_context.h"
18 #include <string>
19 #include <thread>
20 #include <atomic>
21 #include <fstream>
22 #include <algorithm>
23 #include <utility>
24 #include "utils/ms_utils.h"
25 #include "include/common/utils/utils.h"
26 #include "utils/convert_utils_base.h"
27 #include "utils/phase.h"
28 
29 #if defined(_WIN32) || defined(_WIN64)
30 #include <windows.h>
31 #else
32 #include <dlfcn.h>
33 #endif
34 
35 namespace mindspore {
36 namespace {
37 std::map<std::string, MsBackendPolicy> kPolicyMap = {{"ge", kMsBackendGePrior},     {"bisheng", kMsBackendBishengPrior},
38                                                      {"vm", kMsBackendVmOnly},      {"ms", kMsBackendMsPrior},
39                                                      {"ge_only", kMsBackendGeOnly}, {"vm_prior", kMsBackendVmPrior}};
40 
41 constexpr auto kDeviceTargetSize2 = 2;
42 }  // namespace
43 std::atomic<bool> thread_1_must_end(false);
44 
45 MsContext::DeviceSeter MsContext::seter_ = nullptr;
46 MsContext::LoadPluginError MsContext::load_plugin_error_ = nullptr;
47 std::shared_ptr<MsContext> MsContext::inst_context_ = nullptr;
48 
49 std::map<MsCtxParam, std::string> kUnresetParamCheckList = {
50   {MsCtxParam::MS_CTX_DEVICE_ID, "device_id"},
51   {MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "variable_memory_max_size"},
52   {MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY, "max_device_memory"},
53   {MsCtxParam::MS_CTX_MEMPOOL_BLOCK_SIZE, "mempool_block_size"}};
54 
MsContext(const std::string & policy,const std::string & target)55 MsContext::MsContext(const std::string &policy, const std::string &target) {
56 #ifndef ENABLE_SECURITY
57   set_param<int>(MS_CTX_SAVE_GRAPHS_FLAG, 0);
58   set_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH, ".");
59   set_param<std::string>(MS_CTX_COMPILE_CACHE_PATH, "");
60 #else
61   // Need set a default value for arrays even if running in the security mode.
62   int_params_[MS_CTX_SAVE_GRAPHS_FLAG - MS_CTX_TYPE_BOOL_BEGIN] = 0;
63   string_params_[MS_CTX_SAVE_GRAPHS_PATH - MS_CTX_TYPE_STRING_BEGIN] = ".";
64 #endif
65   InitBoolTypeDefaultValue();
66   InitStringTypeDefaultValue();
67   InitDigitalTypeDefaultValue();
68   MsContext::SetDeviceId();
69   string_params_[MS_CTX_DEVICE_TARGET - MS_CTX_TYPE_STRING_BEGIN] = target;
70   set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice);
71 
72   backend_policy_ = kPolicyMap[policy];
73   ascend_soc_version_ = "";
74 
75   params_read_status_ = std::vector<bool>(
76     static_cast<size_t>(MsCtxParam::NUM_BOOL_PARAMS + MsCtxParam::NUM_UINT32_PARAMS + MsCtxParam::NUM_INT_PARAMS +
77                         MsCtxParam::NUM_FLOAT_PARAMS + MsCtxParam::NUM_STRING_PARAMS),
78     false);
79   params_write_status_ = std::vector<bool>(
80     static_cast<size_t>(MsCtxParam::NUM_BOOL_PARAMS + MsCtxParam::NUM_UINT32_PARAMS + MsCtxParam::NUM_INT_PARAMS +
81                         MsCtxParam::NUM_FLOAT_PARAMS + MsCtxParam::NUM_STRING_PARAMS),
82     false);
83 
84   SetAscendConfig();
85 }
86 
GetInstance()87 std::shared_ptr<MsContext> MsContext::GetInstance() {
88   static std::once_flag inst_context_init_flag_ = {};
89   std::call_once(inst_context_init_flag_, [&]() {
90     if (inst_context_ == nullptr) {
91       MS_LOG(DEBUG) << "Create new mindspore context";
92       inst_context_ = std::make_shared<MsContext>("vm", kCPUDevice);
93     }
94   });
95   MS_EXCEPTION_IF_NULL(inst_context_);
96   return inst_context_;
97 }
98 
SetDeviceId()99 void MsContext::SetDeviceId() {
100   auto env_device = common::GetEnv("DEVICE_ID");
101   if (!env_device.empty()) {
102     try {
103       uint32_t device_id = UlongToUint(std::stoul(env_device));
104       set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
105     } catch (std::invalid_argument &e) {
106       MS_LOG(WARNING) << "Invalid DEVICE_ID env:" << env_device << ". Please set DEVICE_ID to 0-7";
107       set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
108     }
109   } else {
110     set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
111   }
112 }
113 
Refresh()114 void MsContext::Refresh() {
115   RefreshExecutionMode();
116   RefreshMemoryOffload();
117 }
118 
RefreshExecutionMode()119 void MsContext::RefreshExecutionMode() {
120   const std::string &target = get_param<std::string>(MS_CTX_DEVICE_TARGET);
121   if (target == kAscendDevice) {
122     if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
123       set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
124     } else if (IsKByKExecutorMode()) {
125       set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
126     }
127   }
128 }
129 
RefreshMemoryOffload()130 void MsContext::RefreshMemoryOffload() {
131   const bool enable_mem_offload = get_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD);
132   if (!enable_mem_offload) {
133     return;
134   }
135   const std::string &target = get_param<std::string>(MS_CTX_DEVICE_TARGET);
136   if (target == kCPUDevice) {
137     MS_LOG(WARNING) << "Memory offload is not available on CPU device.";
138     set_param(MS_CTX_ENABLE_MEM_OFFLOAD, false);
139     return;
140   }
141   if (target == kAscendDevice && get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && !IsKByKExecutorMode()) {
142     MS_LOG(WARNING) << "Run graph mode with kernel by kernel because memory offload is ON.";
143     set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
144     return;
145   }
146   if (get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) == kOptimizeO1) {
147     MS_LOG(WARNING) << "Memory offload is not available when memory_optimize_level is set to O1.";
148     set_param(MS_CTX_ENABLE_MEM_OFFLOAD, false);
149     return;
150   }
151   MS_LOG(INFO) << "Set memory pool block size to max device memory size for memory offload.";
152   set_param_inner(MS_CTX_MEMPOOL_BLOCK_SIZE, get_param<float>(MS_CTX_MAX_DEVICE_MEMORY));
153 }
154 
set_backend_policy(const std::string & policy)155 bool MsContext::set_backend_policy(const std::string &policy) {
156   auto iter = kPolicyMap.find(policy);
157   if (iter == kPolicyMap.end()) {
158     MS_LOG(ERROR) << "invalid backend policy name: " << policy;
159     return false;
160   }
161   backend_policy_ = iter->second;
162   MS_LOG(INFO) << "ms set context backend policy:" << policy;
163   return true;
164 }
165 
backend_policy() const166 std::string MsContext::backend_policy() const {
167   auto res = std::find_if(
168     kPolicyMap.begin(), kPolicyMap.end(),
169     [&, this](const std::pair<std::string, MsBackendPolicy> &item) { return item.second == backend_policy_; });
170   if (res != kPolicyMap.end()) {
171     return res->first;
172   }
173   return "unknown";
174 }
175 
set_ascend_soc_name(const std::string & soc_name)176 void MsContext::set_ascend_soc_name(const std::string &soc_name) { ascend_soc_name_ = soc_name; }
177 
ascend_soc_name() const178 std::string MsContext::ascend_soc_name() const { return ascend_soc_name_; }
179 
set_ascend_soc_version(const std::string & soc_version)180 bool MsContext::set_ascend_soc_version(const std::string &soc_version) {
181   ascend_soc_version_ = soc_version;
182   return true;
183 }
184 
ascend_soc_version() const185 std::string MsContext::ascend_soc_version() const { return ascend_soc_version_; }
186 
enable_dump_ir() const187 bool MsContext::enable_dump_ir() const {
188 #ifdef ENABLE_DUMP_IR
189   return true;
190 #else
191   return false;
192 #endif
193 }
194 
InitFuncMap()195 std::map<std::string, MsContext::InitDeviceTargetAndPolicy> &MsContext::InitFuncMap() {
196   static std::map<std::string, InitDeviceTargetAndPolicy> init_func_map = {};
197   return init_func_map;
198 }
199 
PluginPathMap()200 std::map<std::string, std::string> &MsContext::PluginPathMap() {
201   static std::map<std::string, std::string> plugin_path_map = {};
202   return plugin_path_map;
203 }
204 
RegisterInitFunc(const std::string & name,MsContext::InitDeviceTargetAndPolicy func)205 void MsContext::RegisterInitFunc(const std::string &name, MsContext::InitDeviceTargetAndPolicy func) {
206   (void)InitFuncMap().emplace(name, func);
207   if (GetInstance() != nullptr) {
208     GetInstance()->SetDefaultDeviceTarget();
209   }
210   std::string plugin_path;
211 #if !defined(_WIN32) && !defined(_WIN64)
212   Dl_info dl_info;
213   if (dladdr(reinterpret_cast<void *>(func), &dl_info) == 0) {
214     MS_LOG(EXCEPTION) << "Get dladdr error for " << name;
215   }
216   plugin_path = dl_info.dli_fname;
217 #else
218   HMODULE h_module = nullptr;
219   if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT | GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
220                         (LPCSTR)func, &h_module) == 0) {
221     MS_LOG(EXCEPTION) << "Get GetModuleHandleEx failed for " << name;
222   }
223   char sz_path[MAX_PATH];
224   if (GetModuleFileName(h_module, sz_path, sizeof(sz_path)) == 0) {
225     MS_LOG(EXCEPTION) << "Get GetModuleFileName failed for " << name;
226   }
227   plugin_path = std::string(sz_path);
228 #endif
229   (void)PluginPathMap().emplace(name, plugin_path);
230 }
231 
ResisterLoadPluginErrorFunc(MsContext::LoadPluginError func)232 void MsContext::ResisterLoadPluginErrorFunc(MsContext::LoadPluginError func) { load_plugin_error_ = func; }
233 
IsAscendPluginLoaded() const234 bool MsContext::IsAscendPluginLoaded() const {
235 #ifdef WITH_BACKEND
236   return InitFuncMap().find("Ascend") != InitFuncMap().end();
237 #else
238   // for ut test
239   return true;
240 #endif
241 }
242 
SetDefaultDeviceTarget()243 void MsContext::SetDefaultDeviceTarget() {
244   auto cpu_iter = InitFuncMap().find(kCPUDevice);
245   if (cpu_iter == InitFuncMap().end()) {
246     return;
247   }
248   if (InitFuncMap().size() == 1) {
249     // when only cpu in map
250     cpu_iter->second(inst_context_.get());
251   } else if (InitFuncMap().size() == kDeviceTargetSize2) {
252     // when cpu and another in map
253     for (auto [name, func] : InitFuncMap()) {
254       if (name != kCPUDevice) {
255         inst_context_ = std::make_shared<MsContext>("ms", name);
256         func(inst_context_.get());
257       }
258     }
259   } else {
260     cpu_iter->second(inst_context_.get());
261   }
262   default_device_target_ = true;
263 }
264 
SetDeviceTargetFromInner(const std::string & device_target)265 void MsContext::SetDeviceTargetFromInner(const std::string &device_target) {
266   if (seter_ != nullptr) {
267     if (!InitFuncMap().empty()) {
268       if (auto iter = InitFuncMap().find(device_target); iter == InitFuncMap().end()) {
269         CheckEnv(device_target);
270         std::string device_list = "[";
271         for (auto citer = InitFuncMap().cbegin(); citer != InitFuncMap().cend(); ++citer) {
272           if (device_list == "[") {
273             device_list += "\'" + citer->first + "\'";
274           } else {
275             device_list += ", \'" + citer->first + "\'";
276           }
277         }
278         device_list += "]";
279         if (load_plugin_error_ != nullptr) {
280           auto load_plugin_error_str = load_plugin_error_();
281           if (!load_plugin_error_str.empty()) {
282             MS_EXCEPTION(RuntimeError) << "Unsupported device target " << device_target
283                                        << ". This process only supports one of the " << device_list
284                                        << ". Please check whether the " << device_target
285                                        << " environment is installed and configured correctly, and check whether "
286                                           "current mindspore wheel package was built with \"-e "
287                                        << device_target
288                                        << "\". For details, please refer to \"Device load error message\"." << std::endl
289                                        << "#umsg#Device load error message:#umsg#" << load_plugin_error_str;
290           }
291         }
292         MS_EXCEPTION(RuntimeError) << "Unsupported device target " << device_target
293                                    << ". This process only supports one of the " << device_list
294                                    << ". Please check whether the " << device_target
295                                    << " environment is installed and configured correctly, and check whether "
296                                       "current mindspore wheel package was built with \"-e "
297                                    << device_target << "\".";
298       } else {
299         iter->second(this);
300         SetEnv(device_target);
301       }
302     }
303     MS_LOG(INFO) << "ms set context device target:" << device_target;
304     seter_(device_target);
305   }
306   if (device_target == "Ascend" && !CheckWriteStatus(MS_CTX_MEMORY_OPTIMIZE_LEVEL)) {
307     MS_LOG(INFO) << "Set memory_optimize_level to O1 as default on ascend";
308     int_params_[MS_CTX_MEMORY_OPTIMIZE_LEVEL - MS_CTX_TYPE_INT_BEGIN] = kOptimizeO1;
309   } else if (!CheckWriteStatus(MS_CTX_MEMORY_OPTIMIZE_LEVEL)) {
310     MS_LOG(INFO) << "Set memory_optimize_level to O0 as default on other device";
311     int_params_[MS_CTX_MEMORY_OPTIMIZE_LEVEL - MS_CTX_TYPE_INT_BEGIN] = kOptimizeO0;
312   }
313   string_params_[MS_CTX_DEVICE_TARGET - MS_CTX_TYPE_STRING_BEGIN] = device_target;
314 }
315 
SetDeviceTargetFromUser(const std::string & device_target)316 void MsContext::SetDeviceTargetFromUser(const std::string &device_target) {
317   SetDeviceTargetFromInner(device_target);
318   default_device_target_ = false;
319 }
320 
IsDefaultDeviceTarget() const321 bool MsContext::IsDefaultDeviceTarget() const { return default_device_target_; }
322 
RegisterSetEnv(const EnvFunc & func)323 void MsContext::RegisterSetEnv(const EnvFunc &func) { set_env_ = func; }
RegisterCheckEnv(const EnvFunc & func)324 void MsContext::RegisterCheckEnv(const EnvFunc &func) { check_env_ = func; }
325 
SetEnv(const std::string & device)326 void MsContext::SetEnv(const std::string &device) {
327   if (set_env_ == nullptr) {
328     return;
329   }
330 
331   if (auto iter = PluginPathMap().find(device); iter != PluginPathMap().end()) {
332     const auto &library_path = iter->second;
333     set_env_(device, library_path);
334   }
335 }
336 
CheckEnv(const std::string & device)337 void MsContext::CheckEnv(const std::string &device) {
338   if (check_env_ == nullptr) {
339     return;
340   }
341 
342   check_env_(device, "");
343 }
344 
GetSaveGraphsPath() const345 std::string MsContext::GetSaveGraphsPath() const {
346   std::string path = common::GetEnv("MS_DEV_SAVE_GRAPHS_PATH");
347   if (!path.empty()) {
348     return path;
349   } else {
350     return MsContext::GetInstance()->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
351   }
352 }
353 
GetSaveGraphsLevel() const354 int MsContext::GetSaveGraphsLevel() const {
355   static std::string save_env = common::GetEnv("MS_DEV_SAVE_GRAPHS");
356   if (save_env.size() == 1) {
357     int save_graphs_by_env = -1;
358     try {
359       save_graphs_by_env = std::stoi(save_env);
360     } catch (const std::invalid_argument &ia) {
361       MS_LOG(EXCEPTION) << "Invalid argument: " << ia.what() << " when parse " << save_env;
362     }
363     if (save_graphs_by_env < 0 || save_graphs_by_env > kFully) {
364       MS_LOG(EXCEPTION) << "Dump level can only be from 0 to 3";
365     }
366     return save_graphs_by_env;
367   } else if (save_env.size() > 1) {
368     MS_LOG(EXCEPTION) << "MS_DEV_SAVE_GRAPHS should be a single number with one digit.";
369   }
370   return MsContext::GetInstance()->get_param<int>(MS_CTX_SAVE_GRAPHS_FLAG);
371 }
372 
CanDump(const DumpLevel & level) const373 bool MsContext::CanDump(const DumpLevel &level) const { return GetSaveGraphsLevel() >= level; }
374 
MarkReadStatus(MsCtxParam param) const375 void MsContext::MarkReadStatus(MsCtxParam param) const {
376 #if !(defined(ENABLE_TEST) || defined(ENABLE_TESTCASES) || defined(BUILD_LITE))
377   // unit tests will set device_id many times in one process
378   if (static_cast<size_t>(param) < params_read_status_.size()) {
379     params_read_status_[static_cast<size_t>(param)] = true;
380   }
381 #endif
382 }
383 
MarkWriteStatus(MsCtxParam param) const384 void MsContext::MarkWriteStatus(MsCtxParam param) const {
385   if (static_cast<size_t>(param) < params_write_status_.size()) {
386     params_write_status_[static_cast<size_t>(param)] = true;
387   }
388 }
389 
390 template <typename T>
CheckReadStatus(MsCtxParam param,const T & value) const391 void MsContext::CheckReadStatus(MsCtxParam param, const T &value) const {
392 #if !(defined(ENABLE_TEST) || defined(ENABLE_TESTCASES) || defined(BUILD_LITE))
393   // unit tests will set device_id many times in one process
394   if (static_cast<size_t>(param) >= params_read_status_.size()) {
395     return;
396   }
397   auto iter = kUnresetParamCheckList.find(param);
398   if (iter == kUnresetParamCheckList.end()) {
399     return;
400   }
401   auto origin_status = params_read_status_;
402   T origin_value = get_param<T>(param);
403   params_read_status_ = origin_status;
404   if (params_read_status_[static_cast<size_t>(param)] && value != origin_value) {
405     MS_EXCEPTION(TypeError) << "For 'set_context', the parameter " << iter->second
406                             << " can not be set repeatedly, origin value [" << origin_value << "] has been in effect."
407                             << " Maybe 'mindspore.communication.init()' has been called before 'set_context()'.";
408   }
409 #endif
410 }
411 
CheckWriteStatus(MsCtxParam param) const412 bool MsContext::CheckWriteStatus(MsCtxParam param) const {
413   if (static_cast<size_t>(param) >= params_write_status_.size()) {
414     return false;
415   }
416   return params_write_status_[static_cast<size_t>(param)];
417 }
418 
419 // Reset ms context. Only called in child process after fork occurs.
ChildAfterFork()420 void MsContext::ChildAfterFork() {
421   MS_LOG(DEBUG) << "Reset context after fork.";
422   // configs can be modified again.
423   params_read_status_ = std::vector<bool>(
424     static_cast<size_t>(MsCtxParam::NUM_BOOL_PARAMS + MsCtxParam::NUM_UINT32_PARAMS + MsCtxParam::NUM_INT_PARAMS +
425                         MsCtxParam::NUM_FLOAT_PARAMS + MsCtxParam::NUM_STRING_PARAMS),
426     false);
427   params_write_status_ = std::vector<bool>(
428     static_cast<size_t>(MsCtxParam::NUM_BOOL_PARAMS + MsCtxParam::NUM_UINT32_PARAMS + MsCtxParam::NUM_INT_PARAMS +
429                         MsCtxParam::NUM_FLOAT_PARAMS + MsCtxParam::NUM_STRING_PARAMS),
430     false);
431   std::string device_target_ = get_param<std::string>(MS_CTX_DEVICE_TARGET);
432   if (device_target_ != kCPUDevice) {
433     // set device_target to 'CPU' as default.
434     MS_LOG(INFO) << "Process " << getpid() << " config changed: 'device_target' is reset to 'CPU'.";
435     SetDeviceTargetFromUser("CPU");
436   }
437 }
438 
EnableAoeOnline() const439 bool MsContext::EnableAoeOnline() const {
440   std::string aoe_tune_mode = MsContext::GetInstance()->get_param<std::string>(MS_CTX_AOE_TUNE_MODE);
441   return aoe_tune_mode == "online";
442 }
443 
EnableAoeOffline() const444 bool MsContext::EnableAoeOffline() const {
445   std::string aoe_tune_mode = MsContext::GetInstance()->get_param<std::string>(MS_CTX_AOE_TUNE_MODE);
446   return aoe_tune_mode == "offline";
447 }
448 
449 namespace {
PrintJitLevelAndExecMode(bool is_jit_level_changed,const std::string & jit_level,const std::string & exec_mode)450 void PrintJitLevelAndExecMode(bool is_jit_level_changed, const std::string &jit_level, const std::string &exec_mode) {
451   if (!is_jit_level_changed) {
452     return;
453   }
454 
455   MS_LOG(INFO) << "The jit_level is: " << jit_level << ", and " << exec_mode;
456   static std::string is_enable_runtime_cfg = common::GetEnv("MS_DEV_RUNTIME_CONF");
457   if (!is_enable_runtime_cfg.empty()) {
458     std::cout << "[MS_RUNTIME_PROF]The jit_level is: " << jit_level << ", and " << exec_mode << std::endl;
459   }
460 }
461 }  // namespace
462 
SetJitLevel(const std::string & jit_level) const463 void MsContext::SetJitLevel(const std::string &jit_level) const {
464   if (jit_level.empty()) {
465     return;
466   }
467   std::map<std::string, std::string> jit_config = PhaseManager::GetInstance().jit_config();
468   jit_config["jit_level"] = jit_level;
469   PhaseManager::GetInstance().set_jit_config(jit_config);
470 }
471 
GetJitLevel() const472 std::string MsContext::GetJitLevel() const {
473   static bool first_call = true;
474   const auto &jit_config = PhaseManager::GetInstance().jit_config();
475   std::string jit_level = "";
476   auto iter = jit_config.find("jit_level");
477   if (iter != jit_config.end()) {
478     jit_level = iter->second;
479   }
480 
481   auto global_jit_level = get_param<std::string>(MS_CTX_JIT_LEVEL);
482   auto device_target = get_param<std::string>(MS_CTX_DEVICE_TARGET);
483   auto mode = get_param<int>(MS_CTX_EXECUTION_MODE);
484   if (jit_level.empty()) {
485     if (!global_jit_level.empty()) {
486       jit_level = global_jit_level;
487     } else if (device_target == kAscendDevice && mode == kGraphMode) {
488       jit_level = ascend_soc_version() == kAscendVersion910 ? kAttrJitLevelO2 : kAttrJitLevelO0;
489     } else {
490       jit_level = kAttrJitLevelO0;
491     }
492   }
493 
494   if (mode == kPynativeMode && jit_level == kAttrJitLevelO2) {
495     if (first_call) {
496       MS_LOG(WARNING) << "Pynative mode can not set jit_level to O2, use O0 instead.";
497     }
498     jit_level = kAttrJitLevelO0;
499   }
500 
501   // If use rank table startup method, set jit level to O2.
502   if (!common::UseDynamicCluster() && !common::GetEnv("RANK_TABLE_FILE").empty() && jit_level != kAttrJitLevelO2) {
503     if (first_call) {
504       MS_LOG(WARNING) << "Set jit level to O2 for rank table startup method.";
505     }
506     jit_level = kAttrJitLevelO2;
507   }
508   first_call = false;
509 
510   return jit_level;
511 }
512 
IsKByKExecutorMode() const513 bool MsContext::IsKByKExecutorMode() const {
514   // Get jit level.
515   std::string jit_level = GetJitLevel();
516   static std::string jit_level_log = "";
517   bool is_jit_level_changed = false;
518   auto mode = get_param<int>(MS_CTX_EXECUTION_MODE);
519   if (jit_level_log != jit_level) {
520     is_jit_level_changed = true;
521     jit_level_log = jit_level;
522   }
523 
524   if (get_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD)) {
525     PrintJitLevelAndExecMode(is_jit_level_changed, jit_level, "enable kernelbykernel executor by mem offload.");
526     return true;
527   }
528 
529   if (mode == kPynativeMode) {
530     if (jit_level == kAttrJitLevelO2) {
531       PrintJitLevelAndExecMode(is_jit_level_changed, jit_level, "enable graph_sink executor in the PYNATIVE mode.");
532       return false;
533     }
534     PrintJitLevelAndExecMode(is_jit_level_changed, jit_level, "enable kernelbykernel executor in the PYNATIVE mode.");
535     return true;
536   }
537 
538   if (mode == kGraphMode) {
539     if (jit_level == kAttrJitLevelO0 || jit_level == kAttrJitLevelO1) {
540       PrintJitLevelAndExecMode(is_jit_level_changed, jit_level, "enable kernelbykernel executor in the GRAPH mode.");
541       return true;
542     }
543     PrintJitLevelAndExecMode(is_jit_level_changed, jit_level, "enable graph_sink executor in the GRAPH mode.");
544     return false;
545   }
546 
547   MS_LOG(ERROR) << "No valid executor mode.";
548   return false;
549 }
550 
SetAscendConfig()551 void MsContext::SetAscendConfig() {
552   set_param<std::string>(MS_CTX_PRECISION_MODE, "");
553   set_param<std::string>(MS_CTX_ENABLE_JIT_COMPILE, "");
554   set_param<std::string>(MS_CTX_ATOMIC_CLEAN_POLICY, "");
555   set_param<std::string>(MS_CTX_MATMUL_ALLOW_HF32, "");
556   set_param<std::string>(MS_CTX_CONV_ALLOW_HF32, "");
557   set_param<std::string>(MS_CTX_OP_PRECISION_MODE, "");
558   set_param<std::string>(MS_CTX_HOST_SCHEDULING_MAX_THRESHOLD, "");
559   set_param<std::string>(MS_CTX_GE_OPTIONS, "");
560 }
561 
InitBoolTypeDefaultValue()562 void MsContext::InitBoolTypeDefaultValue() {
563   set_param<bool>(MS_CTX_ENABLE_DUMP, false);
564   set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
565   set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
566   set_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION, true);
567   set_param<bool>(MS_CTX_ENABLE_TASK_SINK, true);
568   set_param<bool>(MS_CTX_IR_FUSION_FLAG, true);
569   set_param<bool>(MS_CTX_ENABLE_HCCL, false);
570   set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true);
571   set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false);
572   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
573   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false);
574   set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true);
575   set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
576   set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
577   set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
578   set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
579   set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
580   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
581   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
582   set_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD, false);
583   set_param<bool>(MS_CTX_ENABLE_PROF_MEM, false);
584   set_param<bool>(MS_CTX_ENABLE_RECOVERY, false);
585   set_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS, false);
586   set_param<bool>(MS_CTX_DISABLE_FORMAT_TRANSFORM, false);
587   set_param<bool>(MS_CTX_RECOMPUTE_COMM_OVERLAP, false);
588   set_param<bool>(MS_CTX_GRAD_COMM_OVERLAP, false);
589   set_param<bool>(MS_CTX_ENABLE_OPT_SHARD_COMM_OPT, false);
590   set_param<bool>(MS_CTX_ENABLE_TASK_OPT, false);
591   set_param<bool>(MS_CTX_ENABLE_GRAD_COMM_OPT, false);
592   set_param<bool>(MS_CTX_INTERLEAVED_MATMUL_COMM, false);
593   set_param<bool>(MS_CTX_INTERLEAVED_LAYERNORM_COMM, false);
594   set_param<bool>(MS_CTX_BIAS_ADD_COMM_SWAP, false);
595   set_param<bool>(MS_CTX_ENABLE_BEGIN_END_INLINE_OPT, false);
596   set_param<bool>(MS_CTX_ENABLE_CONCAT_ELIMINATE_OPT, false);
597   set_param<bool>(MS_CTX_ENABLE_FUSED_CAST_ADD_OPT, false);
598   set_param<bool>(MS_CTX_ENABLE_PROFILING, false);
599   set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
600   set_param<bool>(MS_CTX_CONV_ALLOW_TF32, true);
601   set_param<bool>(MS_CTX_MATMUL_ALLOW_TF32, false);
602   set_param<bool>(MS_CTX_NEED_CKPT, false);
603   set_param<bool>(MS_CTX_RECOMPUTE_ALLGATHER_OVERLAP_FAGRAD, false);
604   set_param<bool>(MS_CTX_ENABLE_FLASH_ATTENTION_LOAD_BALANCE, false);
605 }
606 
InitStringTypeDefaultValue()607 void MsContext::InitStringTypeDefaultValue() {
608   set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, "python");
609   set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, "");
610   set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
611   set_param<std::string>(MS_CTX_DETERMINISTIC, "OFF");
612   set_param<std::string>(MS_CTX_ENV_CONFIG_PATH, "");
613   set_param<std::string>(MS_CTX_TUNE_MODE, "NO_TUNE");
614   set_param<std::string>(MS_CTX_AOE_TUNE_MODE, "");
615   set_param<std::string>(MS_CTX_AOE_JOB_TYPE, "2");
616   set_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS, "");
617   set_param<std::string>(MS_CTX_HOST_SCHEDULING_MAX_THRESHOLD, "");
618   set_param<std::string>(MS_CTX_ENABLE_EXCEPTION_DUMP, "2");
619   set_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE, "0");
620   set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0");
621   set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
622   set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
623   set_param<std::string>(MS_CTX_CONV_FPROP_ALGO, "normal");
624   set_param<std::string>(MS_CTX_CONV_DGRAD_ALGO, "normal");
625   set_param<std::string>(MS_CTX_CONV_WGRAD_ALGO, "normal");
626   set_param<std::string>(MS_CTX_JIT_LEVEL, "");
627   set_param<std::string>(MS_CTX_INFER_BOOST, "off");
628   set_param<std::string>(MS_CTX_PROF_MEM_OUTPUT_PATH, "");
629 }
630 
InitDigitalTypeDefaultValue()631 void MsContext::InitDigitalTypeDefaultValue() {
632   set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode);
633   set_param<int>(MS_CTX_JIT_SYNTAX_LEVEL, kLax);
634   set_param<int>(MS_CTX_CUR_STEP_NUM, 0);
635   set_param<int>(MS_CTX_SAVE_CKPT_STEPS, 0);
636   set_param<int>(MS_CTX_LAST_TRIGGERED_STEP, 0);
637   set_param<int>(MS_CTX_COMPUTE_COMMUNICATE_FUSION_LEVEL, 0);
638   set_param<int>(MS_CTX_ENABLE_COMPILE_CACHE, -1);
639   set_param<int>(MS_CTX_DEBUG_LEVEL, kLevelRelease);
640   set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);
641   set_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE, kDefaultMempoolBlockSize);
642   //
643   uint32_t kDefaultInterOpParallelThreads = 0;
644   uint32_t kDefaultRuntimeNumThreads = 30;
645   uint32_t cpu_core_num = std::thread::hardware_concurrency();
646   uint32_t runtime_num_threads_default = std::min(cpu_core_num, kDefaultRuntimeNumThreads);
647   uint32_t inter_op_parallel_num_default = std::min(cpu_core_num, kDefaultInterOpParallelThreads);
648   set_param<uint32_t>(MS_CTX_RUNTIME_NUM_THREADS, runtime_num_threads_default);
649   set_param<uint32_t>(MS_CTX_INTER_OP_PARALLEL_NUM, inter_op_parallel_num_default);
650   //
651   set_param<uint32_t>(MS_CTX_TSD_REF, 0);
652   set_param<uint32_t>(MS_CTX_GE_REF, 0);
653   set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT);
654   set_param<uint32_t>(MS_CTX_OP_TIMEOUT, kOpTimeout);
655 }
656 
// Splits `str` on `delim` and inserts every non-empty piece into `*output_list`.
// Existing elements of the set are kept; `output_list` must not be null.
inline void SplitString(const std::string &str, char delim, std::set<std::string> *output_list) {
  std::string::size_type piece_begin = 0;
  while (piece_begin < str.size()) {
    std::string::size_type piece_end = str.find(delim, piece_begin);
    if (piece_end == std::string::npos) {
      piece_end = str.size();
    }
    // Consecutive delimiters produce an empty piece, which is skipped.
    if (piece_end > piece_begin) {
      output_list->emplace(str, piece_begin, piece_end - piece_begin);
    }
    piece_begin = piece_end + 1;
  }
}
666 
// Concatenates the elements of `kernel_list` in set order, appending ", " after every element
// (including the last one). An empty set yields an empty string.
inline std::string SetToString(const std::set<std::string> &kernel_list) {
  std::string joined;
  for (auto it = kernel_list.cbegin(); it != kernel_list.cend(); ++it) {
    joined += *it;
    joined += ", ";
  }
  return joined;
}
674 
SetMsInternalEnableCustomKernelList()675 void MsContext::SetMsInternalEnableCustomKernelList() {
676   const std::string kDefaultEnabledOpList =
677     "MatMul,RmsNorm,Add,Sub,FlashAttentionScore,PagedAttention,AddRmsNorm,AddLayerNorm,MatMulAllReduce,"
678     "InferenceMatmulSplit";
679   auto internal_op_boost_env = common::GetEnv("MS_ENABLE_INTERNAL_BOOST");
680   bool is_enable_internal_op = true;
681   if (internal_op_boost_env == "off") {
682     is_enable_internal_op = false;
683   }
684 
685   std::set<std::string> enable_fusion_list;
686   if (is_enable_internal_op) {
687     SplitString(kDefaultEnabledOpList, ',', &enable_fusion_list);
688   }
689 
690   std::string env = common::GetEnv("MS_INTERNAL_ENABLE_CUSTOM_KERNEL_LIST");
691   if (!env.empty()) {
692     SplitString(env, ',', &enable_fusion_list);
693   }
694 
695   std::set<std::string> disable_fusion_list;
696   env = common::GetEnv("MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST");
697   if (!env.empty()) {
698     SplitString(env, ',', &disable_fusion_list);
699   }
700 
701   ms_internal_enable_custom_kernel_list_.clear();
702   for (const auto &fusion_name : enable_fusion_list) {
703     if (disable_fusion_list.find(fusion_name) == disable_fusion_list.end()) {
704       ms_internal_enable_custom_kernel_list_.emplace(fusion_name);
705     }
706   }
707 
708   MS_LOG(INFO) << "Enable internal kernel list: " << SetToString(ms_internal_enable_custom_kernel_list_);
709 }
710 
IsEnableInferBoost()711 bool MsContext::IsEnableInferBoost() {
712   if (enable_infer_boost_.has_value()) {
713     return enable_infer_boost_.value();
714   }
715 
716   const auto &jit_config = PhaseManager::GetInstance().jit_config();
717   auto iter = jit_config.find("infer_boost");
718   if (iter != jit_config.end() && iter->second == "on") {
719     enable_infer_boost_ = true;
720     MS_LOG(INFO) << "MSContext enable ms infer boost from JitConfig";
721     SetMsInternalEnableCustomKernelList();
722     return enable_infer_boost_.value();
723   }
724 
725   auto global_infer_boost = get_param<std::string>(MS_CTX_INFER_BOOST);
726   if (global_infer_boost == "on") {
727     enable_infer_boost_ = true;
728     MS_LOG(INFO) << "MSContext enable ms infer boost from Global Context JitConfig";
729     SetMsInternalEnableCustomKernelList();
730     return enable_infer_boost_.value();
731   }
732 
733   if (common::GetEnv("MS_ENABLE_INTERNAL_KERNELS") == "on") {
734     enable_infer_boost_ = true;
735     MS_LOG(INFO) << "MSContext enable ms infer boost from Env";
736     SetMsInternalEnableCustomKernelList();
737   } else {
738     enable_infer_boost_ = false;
739   }
740 
741   return enable_infer_boost_.value();
742 }
743 
ms_internal_enable_custom_kernel_list() const744 const std::set<std::string> &MsContext::ms_internal_enable_custom_kernel_list() const {
745   return ms_internal_enable_custom_kernel_list_;
746 }
747 
748 template MS_CORE_API void MsContext::CheckReadStatus<bool>(MsCtxParam, const bool &) const;
749 template MS_CORE_API void MsContext::CheckReadStatus<uint32_t>(MsCtxParam, const uint32_t &) const;
750 template MS_CORE_API void MsContext::CheckReadStatus<int>(MsCtxParam, const int &) const;
751 template MS_CORE_API void MsContext::CheckReadStatus<float>(MsCtxParam, const float &) const;
752 template MS_CORE_API void MsContext::CheckReadStatus<std::string>(MsCtxParam, const std::string &) const;
753 }  // namespace mindspore
754