• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "utils/context/context_extends.h"
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include "pybind11/pybind11.h"
#include "utils/ms_utils.h"
#include "utils/convert_utils_base.h"
#ifndef NO_DLIB
#include "acl/acl_tdt.h"
#include "runtime/dev.h"
#include "toolchain/plog.h"
#include "common/util/error_manager/error_manager.h"
#endif
#ifdef ENABLE_GE
#include "transform/graph_ir/df_graph_manager.h"
#endif
#include "profiler/device/profiling.h"
35 
36 namespace py = pybind11;
37 
38 namespace mindspore {
39 namespace context {
40 #ifdef ENABLE_GE
41 using mindspore::transform::DfGraphManager;
42 #endif
43 
// Substring searched for in the Ascend ErrorManager message after rtSetDevice / rtDeviceReset
// fail: messages containing it are not echoed to the ERROR log.
constexpr char kUnknowErrorString[] = "Unknown error occurred";
45 
46 #ifndef NO_DLIB
47 // Open tdt dataset
OpenTsd(const std::shared_ptr<MsContext> & ms_context_ptr)48 bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
49   if (ms_context_ptr == nullptr) {
50     MS_LOG(EXCEPTION) << "nullptr";
51   }
52 
53   if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
54     return true;
55   }
56 
57   if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
58     MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened.";
59     ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
60     return true;
61   }
62 
63   auto role = common::GetEnv("MS_ROLE");
64   if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
65     return true;
66   }
67 
68   uint32_t rank_size = 1;
69   uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
70 
71   auto rank_size_env = common::GetEnv("RANK_SIZE");
72   if (rank_size_env.empty()) {
73     MS_LOG(INFO) << "Should config rank size.";
74     rank_size = 1;
75   } else {
76     int rank_env = std::stoi(rank_size_env);
77     if (rank_env <= 0) {
78       MS_LOG(EXCEPTION) << "Error rank size " << rank_env << ".";
79     }
80     rank_size = IntToUint(rank_env);
81   }
82 
83   int log_ret = DlogReportInitialize();
84   if (log_ret != 0) {
85     MS_LOG(WARNING) << "Init slog failed, ret = " << log_ret;
86   }
87 
88   MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
89   auto ret = rtSetDevice(static_cast<int32_t>(device_id));
90   if (ret != RT_ERROR_NONE) {
91     const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
92     if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
93       MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
94     }
95     MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
96   }
97   ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
98 #ifdef ENABLE_TDTQUE
99   auto thread_crt = [](const std::string &path, const acltdtChannelHandle *acl_handle) {
100     return std::thread(TensorPrint(path, acl_handle));
101   };
102   ms_context_ptr->CreateTensorPrintThread(thread_crt);
103 #endif
104   return true;
105 }
106 
CloseTsd(const std::shared_ptr<MsContext> & ms_context_ptr,bool force)107 bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
108   if (ms_context_ptr == nullptr) {
109     MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
110   }
111   if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
112     return true;
113   }
114   ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
115   if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
116     ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
117 
118 #ifdef ENABLE_TDTQUE
119     py::gil_scoped_release gil_release;
120     ms_context_ptr->DestroyTensorPrintThread();
121 #endif
122     uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
123     auto ret = rtDeviceReset(static_cast<int32_t>(device_id));
124     if (ret != RT_ERROR_NONE) {
125       const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
126       if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
127         MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
128       }
129       MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
130       return false;
131     }
132     ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
133     MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]";
134     (void)DlogReportFinalize();
135   } else {
136     MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = "
137                   << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
138   }
139   return true;
140 }
141 #else
// NO_DLIB builds: there is no device runtime to open or close, so both calls are no-ops that
// always report success. Parameter names are omitted because they are unused.
bool OpenTsd(const std::shared_ptr<MsContext> & /* ms_context_ptr */) { return true; }
bool CloseTsd(const std::shared_ptr<MsContext> & /* ms_context_ptr */, bool /* force */) { return true; }
144 #endif
145 
SetDisableReuseMemoryFlag(std::map<std::string,std::string> * ge_options)146 void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) {
147   auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY");
148   if (!env_disable_reuse_memory.empty()) {
149     (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory;
150   } else {
151     (*ge_options)["ge.exec.disableReuseMemory"] = "0";
152     MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. Now set to default value 0";
153   }
154 }
155 
GetGeOptions(const std::shared_ptr<MsContext> & ms_context_ptr,std::map<std::string,std::string> * ge_options)156 void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std::string, std::string> *ge_options) {
157   if (ms_context_ptr == nullptr) {
158     MS_LOG(EXCEPTION) << "nullptr";
159   }
160 #ifdef ENABLE_GE
161   (*ge_options)["device_id"] = "0";
162   (*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP));
163   (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH);
164   (*ge_options)["ge.exec.dumpMode"] = "output";
165   MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP))
166                << " and save dump path is " << ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH) << ".";
167 
168   auto profiler_manager = profiler::ProfilerManager::GetInstance();
169   if (profiler_manager == nullptr) {
170     MS_LOG(EXCEPTION) << "Profiler manager is nullptr";
171   }
172   (*ge_options)["ge.exec.profilingMode"] = std::to_string(profiler_manager->GetProfilingEnableFlag());
173   if (profiler_manager->GetProfilingEnableFlag()) {
174     (*ge_options)["ge.exec.profilingOptions"] = profiler_manager->GetProfilingOptions();
175   }
176 
177   (*ge_options)["rank_table_file"] = "";
178   auto env_ddk_version = common::GetEnv("DDK_VERSION");
179   if (!env_ddk_version.empty()) {
180     (*ge_options)["ge.DDK_version"] = env_ddk_version;
181   } else {
182     (*ge_options)["ge.DDK_version"] = "1.60.T17.B830";
183   }
184   (*ge_options)["graphType"] = "1";
185 
186   if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
187     (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
188   }
189 
190   if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
191     (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
192   }
193 
194 #if ENABLE_TRAIN == 1
195   (*ge_options)["ge.graphRunMode"] = "1";
196 #endif
197   SetDisableReuseMemoryFlag(ge_options);
198   SetHcclOptions(ms_context_ptr, ge_options);
199 
200   auto env_job_id = common::GetEnv("JOB_ID");
201   if (!env_job_id.empty()) {
202     (*ge_options)["ge.exec.jobId"] = env_job_id;
203   } else {
204     (*ge_options)["ge.exec.jobId"] = "0";
205     MS_LOG(WARNING) << "JOB_ID is not set in ENV. Now set to default value 0";
206   }
207 
208   auto env_fe_flag = common::GetEnv("FE_FLAG");
209   if (!env_fe_flag.empty()) {
210     (*ge_options)["ge.feFlag"] = env_fe_flag;
211     MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
212   }
213 
214   auto env_aicpu_flag = common::GetEnv("AICPU_FLAG");
215   if (!env_aicpu_flag.empty()) {
216     (*ge_options)["ge.aicpuFlag"] = env_aicpu_flag;
217     MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
218   }
219 
220   auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH");
221   if (!proto_lib_path.empty()) {
222     char real_path[PATH_MAX] = {0};
223     if (realpath(proto_lib_path.c_str(), real_path)) {
224       proto_lib_path = real_path;
225       (*ge_options)["ge.opsProtoLibPath"] = proto_lib_path;
226     }
227   } else {
228     MS_LOG(WARNING) << "Set proto lib path failed!";
229   }
230 
231   (*ge_options)["ge.exec.precision_mode"] = "force_fp16";
232 
233   // Disable the global variable acc, only enable it while adding training graph in pipeline
234   (*ge_options)["ge.exec.variable_acc"] = "0";
235 #endif
236 }
237 
SetHcclOptions(const std::shared_ptr<MsContext> & ms_context_ptr,std::map<std::string,std::string> * ge_options)238 void SetHcclOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std::string, std::string> *ge_options) {
239   if (ms_context_ptr == nullptr) {
240     MS_LOG(EXCEPTION) << "nullptr";
241   }
242   auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
243   auto env_rank_id = common::GetEnv("RANK_ID");
244   auto env_device_id = std::to_string(ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
245   if (!(env_table_file.empty() || env_rank_id.empty())) {
246     MS_LOG(INFO) << "Initialize Ge for distribute parameter";
247     MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
248     auto env_hccl_flag = common::GetEnv("HCCL_FLAG");
249     if (!env_hccl_flag.empty()) {
250       (*ge_options)["ge.exec.hcclFlag"] = env_hccl_flag;
251     }
252     (*ge_options)["ge.exec.isUseHcom"] = "1";
253     (*ge_options)["ge.exec.deviceId"] = env_device_id;
254     (*ge_options)["ge.exec.rankId"] = env_rank_id;
255     (*ge_options)["ge.exec.podName"] = env_rank_id;
256     (*ge_options)["ge.exec.rankTableFile"] = env_table_file;
257     (*ge_options)["ge.graphRunMode"] = "1";
258   } else {
259     // device id is still needed for non-distribute case
260     (*ge_options)["ge.exec.deviceId"] = env_device_id;
261     MS_LOG(INFO) << "No hccl mode. "
262                     "If use hccl, make sure [RANK_TABLE_FILE,RANK_ID,DEVICE_ID,DEPLOY_MODE] all be set in ENV.";
263   }
264 
265   auto env_deploy_mode = common::GetEnv("DEPLOY_MODE");
266   if (!env_deploy_mode.empty()) {
267     (*ge_options)["ge.exec.deployMode"] = env_deploy_mode;
268   } else {
269     (*ge_options)["ge.exec.deployMode"] = "0";
270     MS_LOG(WARNING) << "DEPLOY_MODE is not set in ENV. Now set to default value 0";
271   }
272 }
273 
InitGe(const std::shared_ptr<MsContext> & ms_context_ptr)274 bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
275   if (ms_context_ptr == nullptr) {
276     MS_LOG(EXCEPTION) << "nullptr";
277   }
278 #ifdef ENABLE_GE
279   if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
280     return true;
281   }
282 
283   if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF)) {
284     ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
285     return true;
286   }
287 
288   std::map<std::string, std::string> ge_options;
289   GetGeOptions(ms_context_ptr, &ge_options);
290   {
291     // Release GIL before calling into (potentially long-running) C++ code
292     py::gil_scoped_release release;
293     if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) {
294       MS_LOG(EXCEPTION) << "Initialize GE failed!";
295     }
296   }
297   ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
298   MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
299 #endif
300   return true;
301 }
302 
PynativeInitGe(const std::shared_ptr<MsContext> & ms_context_ptr)303 bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
304   if (ms_context_ptr == nullptr) {
305     MS_LOG(EXCEPTION) << "nullptr";
306   }
307   if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT) ||
308       ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
309     return true;
310   }
311 
312   (void)OpenTsd(ms_context_ptr);
313   (void)InitGe(ms_context_ptr);
314   ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true);
315   return true;
316 }
317 
FinalizeGe(const std::shared_ptr<MsContext> & ms_context_ptr,bool force)318 bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
319   if (ms_context_ptr == nullptr) {
320     MS_LOG(EXCEPTION) << "nullptr";
321   }
322 #ifdef ENABLE_GE
323   if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
324     return true;
325   }
326   ms_context_ptr->decrease_param<uint32_t>(MS_CTX_GE_REF);
327   if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
328     ms_context_ptr->set_param<uint32_t>(MS_CTX_GE_REF, 0);
329     try {
330       DfGraphManager::GetInstance().DeleteGraphRunner();
331       DfGraphManager::GetInstance().DeleteGeSession();
332     } catch (const std::exception &e) {
333       MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what();
334     } catch (...) {
335       std::string exName(abi::__cxa_current_exception_type()->name());
336       MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName;
337     }
338     if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
339       MS_LOG(WARNING) << "Finalize GE failed!";
340     }
341     ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
342   } else {
343     MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
344                  << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
345   }
346 #endif
347   return true;
348 }
349 
IsTsdOpened(const std::shared_ptr<MsContext> & ms_context_ptr)350 bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
351   if (ms_context_ptr == nullptr) {
352     MS_LOG(EXCEPTION) << "nullptr";
353   }
354   return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
355 }
356 
IsGeInited(const std::shared_ptr<MsContext> & ms_context_ptr)357 bool IsGeInited(const std::shared_ptr<MsContext> &ms_context_ptr) {
358   if (ms_context_ptr == nullptr) {
359     MS_LOG(EXCEPTION) << "nullptr";
360   }
361   return ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) > 0;
362 }
363 
364 // Register for device type.
365 struct DeviceTypeSetRegister {
DeviceTypeSetRegistermindspore::context::DeviceTypeSetRegister366   DeviceTypeSetRegister() {
367     MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
368 #ifdef ENABLE_GE
369       device_type_seter.reset(new (std::nothrow) MsContext("ge", kAscendDevice));
370 #elif defined(ENABLE_D)
371       device_type_seter.reset(new (std::nothrow) MsContext("ms", kAscendDevice));
372 #elif defined(ENABLE_GPU)
373       device_type_seter.reset(new (std::nothrow) MsContext("ms", kGPUDevice));
374 #else
375       device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
376 #endif
377     });
378   }
379   DeviceTypeSetRegister(const DeviceTypeSetRegister &) = delete;
380   DeviceTypeSetRegister &operator=(const DeviceTypeSetRegister &) = delete;
381   ~DeviceTypeSetRegister() = default;
382 } device_type_set_regsiter;
383 }  // namespace context
384 }  // namespace mindspore
385