1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/context/context_extends.h"
18 #include <map>
19 #include <string>
20 #include <memory>
21 #include <thread>
22 #include "pybind11/pybind11.h"
23 #include "utils/ms_utils.h"
24 #include "utils/convert_utils_base.h"
25 #ifndef NO_DLIB
26 #include "acl/acl_tdt.h"
27 #include "runtime/dev.h"
28 #include "toolchain/plog.h"
29 #include "common/util/error_manager/error_manager.h"
30 #endif
31 #ifdef ENABLE_GE
32 #include "transform/graph_ir/df_graph_manager.h"
33 #endif
34 #include "profiler/device/profiling.h"
35
36 namespace py = pybind11;
37
38 namespace mindspore {
39 namespace context {
40 #ifdef ENABLE_GE
41 using mindspore::transform::DfGraphManager;
42 #endif
43
// Substring of ErrorManager messages that carry no useful detail; such messages are not echoed to the log.
constexpr char kUnknowErrorString[] = "Unknown error occurred";
45
46 #ifndef NO_DLIB
47 // Open tdt dataset
OpenTsd(const std::shared_ptr<MsContext> & ms_context_ptr)48 bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
49 if (ms_context_ptr == nullptr) {
50 MS_LOG(EXCEPTION) << "nullptr";
51 }
52
53 if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
54 return true;
55 }
56
57 if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
58 MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened.";
59 ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
60 return true;
61 }
62
63 auto role = common::GetEnv("MS_ROLE");
64 if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
65 return true;
66 }
67
68 uint32_t rank_size = 1;
69 uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
70
71 auto rank_size_env = common::GetEnv("RANK_SIZE");
72 if (rank_size_env.empty()) {
73 MS_LOG(INFO) << "Should config rank size.";
74 rank_size = 1;
75 } else {
76 int rank_env = std::stoi(rank_size_env);
77 if (rank_env <= 0) {
78 MS_LOG(EXCEPTION) << "Error rank size " << rank_env << ".";
79 }
80 rank_size = IntToUint(rank_env);
81 }
82
83 int log_ret = DlogReportInitialize();
84 if (log_ret != 0) {
85 MS_LOG(WARNING) << "Init slog failed, ret = " << log_ret;
86 }
87
88 MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
89 auto ret = rtSetDevice(static_cast<int32_t>(device_id));
90 if (ret != RT_ERROR_NONE) {
91 const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
92 if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
93 MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
94 }
95 MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
96 }
97 ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
98 #ifdef ENABLE_TDTQUE
99 auto thread_crt = [](const std::string &path, const acltdtChannelHandle *acl_handle) {
100 return std::thread(TensorPrint(path, acl_handle));
101 };
102 ms_context_ptr->CreateTensorPrintThread(thread_crt);
103 #endif
104 return true;
105 }
106
CloseTsd(const std::shared_ptr<MsContext> & ms_context_ptr,bool force)107 bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
108 if (ms_context_ptr == nullptr) {
109 MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
110 }
111 if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
112 return true;
113 }
114 ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
115 if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
116 ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
117
118 #ifdef ENABLE_TDTQUE
119 py::gil_scoped_release gil_release;
120 ms_context_ptr->DestroyTensorPrintThread();
121 #endif
122 uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
123 auto ret = rtDeviceReset(static_cast<int32_t>(device_id));
124 if (ret != RT_ERROR_NONE) {
125 const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
126 if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
127 MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
128 }
129 MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
130 return false;
131 }
132 ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
133 MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]";
134 (void)DlogReportFinalize();
135 } else {
136 MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = "
137 << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
138 }
139 return true;
140 }
141 #else
// NO_DLIB build: there is no device runtime to open, so this always reports success.
bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
  (void)ms_context_ptr;
  return true;
}
// NO_DLIB build: there is no device runtime to close, so this always reports success.
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool) {
  (void)ms_context_ptr;
  return true;
}
144 #endif
145
SetDisableReuseMemoryFlag(std::map<std::string,std::string> * ge_options)146 void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) {
147 auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY");
148 if (!env_disable_reuse_memory.empty()) {
149 (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory;
150 } else {
151 (*ge_options)["ge.exec.disableReuseMemory"] = "0";
152 MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. Now set to default value 0";
153 }
154 }
155
GetGeOptions(const std::shared_ptr<MsContext> & ms_context_ptr,std::map<std::string,std::string> * ge_options)156 void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std::string, std::string> *ge_options) {
157 if (ms_context_ptr == nullptr) {
158 MS_LOG(EXCEPTION) << "nullptr";
159 }
160 #ifdef ENABLE_GE
161 (*ge_options)["device_id"] = "0";
162 (*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP));
163 (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH);
164 (*ge_options)["ge.exec.dumpMode"] = "output";
165 MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP))
166 << " and save dump path is " << ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH) << ".";
167
168 auto profiler_manager = profiler::ProfilerManager::GetInstance();
169 if (profiler_manager == nullptr) {
170 MS_LOG(EXCEPTION) << "Profiler manager is nullptr";
171 }
172 (*ge_options)["ge.exec.profilingMode"] = std::to_string(profiler_manager->GetProfilingEnableFlag());
173 if (profiler_manager->GetProfilingEnableFlag()) {
174 (*ge_options)["ge.exec.profilingOptions"] = profiler_manager->GetProfilingOptions();
175 }
176
177 (*ge_options)["rank_table_file"] = "";
178 auto env_ddk_version = common::GetEnv("DDK_VERSION");
179 if (!env_ddk_version.empty()) {
180 (*ge_options)["ge.DDK_version"] = env_ddk_version;
181 } else {
182 (*ge_options)["ge.DDK_version"] = "1.60.T17.B830";
183 }
184 (*ge_options)["graphType"] = "1";
185
186 if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
187 (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
188 }
189
190 if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
191 (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
192 }
193
194 #if ENABLE_TRAIN == 1
195 (*ge_options)["ge.graphRunMode"] = "1";
196 #endif
197 SetDisableReuseMemoryFlag(ge_options);
198 SetHcclOptions(ms_context_ptr, ge_options);
199
200 auto env_job_id = common::GetEnv("JOB_ID");
201 if (!env_job_id.empty()) {
202 (*ge_options)["ge.exec.jobId"] = env_job_id;
203 } else {
204 (*ge_options)["ge.exec.jobId"] = "0";
205 MS_LOG(WARNING) << "JOB_ID is not set in ENV. Now set to default value 0";
206 }
207
208 auto env_fe_flag = common::GetEnv("FE_FLAG");
209 if (!env_fe_flag.empty()) {
210 (*ge_options)["ge.feFlag"] = env_fe_flag;
211 MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
212 }
213
214 auto env_aicpu_flag = common::GetEnv("AICPU_FLAG");
215 if (!env_aicpu_flag.empty()) {
216 (*ge_options)["ge.aicpuFlag"] = env_aicpu_flag;
217 MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
218 }
219
220 auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH");
221 if (!proto_lib_path.empty()) {
222 char real_path[PATH_MAX] = {0};
223 if (realpath(proto_lib_path.c_str(), real_path)) {
224 proto_lib_path = real_path;
225 (*ge_options)["ge.opsProtoLibPath"] = proto_lib_path;
226 }
227 } else {
228 MS_LOG(WARNING) << "Set proto lib path failed!";
229 }
230
231 (*ge_options)["ge.exec.precision_mode"] = "force_fp16";
232
233 // Disable the global variable acc, only enable it while adding training graph in pipeline
234 (*ge_options)["ge.exec.variable_acc"] = "0";
235 #endif
236 }
237
SetHcclOptions(const std::shared_ptr<MsContext> & ms_context_ptr,std::map<std::string,std::string> * ge_options)238 void SetHcclOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std::string, std::string> *ge_options) {
239 if (ms_context_ptr == nullptr) {
240 MS_LOG(EXCEPTION) << "nullptr";
241 }
242 auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
243 auto env_rank_id = common::GetEnv("RANK_ID");
244 auto env_device_id = std::to_string(ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
245 if (!(env_table_file.empty() || env_rank_id.empty())) {
246 MS_LOG(INFO) << "Initialize Ge for distribute parameter";
247 MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
248 auto env_hccl_flag = common::GetEnv("HCCL_FLAG");
249 if (!env_hccl_flag.empty()) {
250 (*ge_options)["ge.exec.hcclFlag"] = env_hccl_flag;
251 }
252 (*ge_options)["ge.exec.isUseHcom"] = "1";
253 (*ge_options)["ge.exec.deviceId"] = env_device_id;
254 (*ge_options)["ge.exec.rankId"] = env_rank_id;
255 (*ge_options)["ge.exec.podName"] = env_rank_id;
256 (*ge_options)["ge.exec.rankTableFile"] = env_table_file;
257 (*ge_options)["ge.graphRunMode"] = "1";
258 } else {
259 // device id is still needed for non-distribute case
260 (*ge_options)["ge.exec.deviceId"] = env_device_id;
261 MS_LOG(INFO) << "No hccl mode. "
262 "If use hccl, make sure [RANK_TABLE_FILE,RANK_ID,DEVICE_ID,DEPLOY_MODE] all be set in ENV.";
263 }
264
265 auto env_deploy_mode = common::GetEnv("DEPLOY_MODE");
266 if (!env_deploy_mode.empty()) {
267 (*ge_options)["ge.exec.deployMode"] = env_deploy_mode;
268 } else {
269 (*ge_options)["ge.exec.deployMode"] = "0";
270 MS_LOG(WARNING) << "DEPLOY_MODE is not set in ENV. Now set to default value 0";
271 }
272 }
273
InitGe(const std::shared_ptr<MsContext> & ms_context_ptr)274 bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
275 if (ms_context_ptr == nullptr) {
276 MS_LOG(EXCEPTION) << "nullptr";
277 }
278 #ifdef ENABLE_GE
279 if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
280 return true;
281 }
282
283 if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF)) {
284 ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
285 return true;
286 }
287
288 std::map<std::string, std::string> ge_options;
289 GetGeOptions(ms_context_ptr, &ge_options);
290 {
291 // Release GIL before calling into (potentially long-running) C++ code
292 py::gil_scoped_release release;
293 if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) {
294 MS_LOG(EXCEPTION) << "Initialize GE failed!";
295 }
296 }
297 ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
298 MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
299 #endif
300 return true;
301 }
302
PynativeInitGe(const std::shared_ptr<MsContext> & ms_context_ptr)303 bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
304 if (ms_context_ptr == nullptr) {
305 MS_LOG(EXCEPTION) << "nullptr";
306 }
307 if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT) ||
308 ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
309 return true;
310 }
311
312 (void)OpenTsd(ms_context_ptr);
313 (void)InitGe(ms_context_ptr);
314 ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true);
315 return true;
316 }
317
FinalizeGe(const std::shared_ptr<MsContext> & ms_context_ptr,bool force)318 bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
319 if (ms_context_ptr == nullptr) {
320 MS_LOG(EXCEPTION) << "nullptr";
321 }
322 #ifdef ENABLE_GE
323 if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
324 return true;
325 }
326 ms_context_ptr->decrease_param<uint32_t>(MS_CTX_GE_REF);
327 if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
328 ms_context_ptr->set_param<uint32_t>(MS_CTX_GE_REF, 0);
329 try {
330 DfGraphManager::GetInstance().DeleteGraphRunner();
331 DfGraphManager::GetInstance().DeleteGeSession();
332 } catch (const std::exception &e) {
333 MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what();
334 } catch (...) {
335 std::string exName(abi::__cxa_current_exception_type()->name());
336 MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName;
337 }
338 if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
339 MS_LOG(WARNING) << "Finalize GE failed!";
340 }
341 ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
342 } else {
343 MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
344 << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
345 }
346 #endif
347 return true;
348 }
349
IsTsdOpened(const std::shared_ptr<MsContext> & ms_context_ptr)350 bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
351 if (ms_context_ptr == nullptr) {
352 MS_LOG(EXCEPTION) << "nullptr";
353 }
354 return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
355 }
356
IsGeInited(const std::shared_ptr<MsContext> & ms_context_ptr)357 bool IsGeInited(const std::shared_ptr<MsContext> &ms_context_ptr) {
358 if (ms_context_ptr == nullptr) {
359 MS_LOG(EXCEPTION) << "nullptr";
360 }
361 return ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) > 0;
362 }
363
364 // Register for device type.
365 struct DeviceTypeSetRegister {
DeviceTypeSetRegistermindspore::context::DeviceTypeSetRegister366 DeviceTypeSetRegister() {
367 MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
368 #ifdef ENABLE_GE
369 device_type_seter.reset(new (std::nothrow) MsContext("ge", kAscendDevice));
370 #elif defined(ENABLE_D)
371 device_type_seter.reset(new (std::nothrow) MsContext("ms", kAscendDevice));
372 #elif defined(ENABLE_GPU)
373 device_type_seter.reset(new (std::nothrow) MsContext("ms", kGPUDevice));
374 #else
375 device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
376 #endif
377 });
378 }
379 DeviceTypeSetRegister(const DeviceTypeSetRegister &) = delete;
380 DeviceTypeSetRegister &operator=(const DeviceTypeSetRegister &) = delete;
381 ~DeviceTypeSetRegister() = default;
382 } device_type_set_regsiter;
383 } // namespace context
384 } // namespace mindspore
385