1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "runtime/hardware/device_context_manager.h"
#if defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#endif
#ifdef __linux__
#include <sys/wait.h>
#include <unistd.h>
#endif  // #ifdef __linux__
#include <dirent.h>
#include <algorithm>
#include <cctype>
#include <cerrno>
#include <climits>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "utils/ms_context.h"
#include "utils/dlopen_macro.h"
#include "utils/os.h"
32
33 namespace mindspore {
34 namespace {
// Returns the number of characters in the NUL-terminated string `str`, not counting the terminator.
// Being constexpr, it lets the lengths of string-literal constants be computed at compile time.
constexpr size_t GetStrLen(const char *const str) {
  size_t count = 0;
  while (str[count] != '\0') {
    ++count;
  }
  return count;
}
42
43 constexpr auto kCudaHomeEnv = "CUDA_HOME";
44 constexpr auto kNvccVersionKeyWords = "Cuda compilation tools, release ";
45 constexpr size_t kNvccVersionKeyWordsSize = GetStrLen(kNvccVersionKeyWords);
46 constexpr auto kSuccessKeyWord = "Success";
47 constexpr size_t kSuccessKeyWordSize = GetStrLen(kSuccessKeyWord);
48 constexpr size_t kBufferSize = 999;
49 constexpr auto kGpuPluginName = "libmindspore_gpu";
50
51 #ifdef __linux__
// Scope owner of a POSIX file descriptor: the descriptor is closed when the object is destroyed.
// Copying and moving are disabled so that one descriptor can never be closed twice; negative
// (invalid) descriptors are ignored.
class FdScope {
 public:
  explicit FdScope(int fd) : fd_(fd) {}
  ~FdScope() {
    if (fd_ >= 0) {
      (void)close(fd_);
    }
  }
  FdScope(const FdScope &) = delete;
  FdScope &operator=(const FdScope &) = delete;
  FdScope(FdScope &&) = delete;
  FdScope &operator=(FdScope &&) = delete;

 private:
  int fd_;
};
60
GetNvccRealPath(const std::string & cuda_path)61 std::string GetNvccRealPath(const std::string &cuda_path) {
62 auto nvcc_path = cuda_path + "/bin/nvcc";
63 char real_path_buffer[PATH_MAX];
64 if (realpath(nvcc_path.c_str(), real_path_buffer) == nullptr) {
65 MS_LOG(WARNING) << "Invalid environment variable CUDA_HOME [" << cuda_path << "], can not find nvcc file ["
66 << nvcc_path << "], please check the CUDA_HOME.";
67 return "";
68 }
69 return real_path_buffer;
70 }
71
GetCudaVersionFromNvcc(const std::string & nvcc_path)72 std::string GetCudaVersionFromNvcc(const std::string &nvcc_path) {
73 int pipe_fd[2];
74 if (pipe(pipe_fd) != 0) {
75 MS_LOG(ERROR) << "Create pipe failed, ret = " << errno << ", reason = " << strerror(errno);
76 return "";
77 }
78 FdScope fd0(pipe_fd[0]);
79 FdScope fd1(pipe_fd[1]);
80 pid_t pid = fork();
81 if (pid < 0) {
82 MS_LOG(ERROR) << "Fork child process failed, ret = " << errno << ", reason = " << strerror(errno);
83 return "";
84 } else if (pid == 0) { // child process
85 (void)dup2(pipe_fd[1], STDOUT_FILENO);
86 MS_LOG(DEBUG) << "Start exec " << nvcc_path << " --version";
87 if (execl(nvcc_path.c_str(), "nvcc", "--version", nullptr) == -1) {
88 MS_LOG(ERROR) << "Get cuda version from " << nvcc_path << " failed, ret = " << errno
89 << ", reason = " << strerror(errno);
90 exit(-1);
91 }
92 } else { // parent process
93 MS_LOG(DEBUG) << "Child process NVCC pid = " << pid;
94 int status;
95 std::string buffer(kBufferSize, 0);
96 if (waitpid(pid, &status, 0) == -1) {
97 MS_LOG(ERROR) << "Wait child process failed, ret = " << errno << ", reason = " << strerror(errno);
98 return "";
99 }
100 if (auto read_size = read(pipe_fd[0], buffer.data(), buffer.size()); read_size <= 0) {
101 MS_LOG(WARNING) << "Read from pipe failed, ret = " << errno << ", reason = " << strerror(errno);
102 return "";
103 } else {
104 buffer.resize(read_size);
105 }
106
107 MS_LOG(DEBUG) << "Child process return: " << buffer;
108 auto pos = buffer.find(kNvccVersionKeyWords);
109 if (pos == std::string::npos) {
110 MS_LOG(ERROR) << "Cannot found nvcc version key words [" << kNvccVersionKeyWords << "], nvcc return: " << buffer;
111 return "";
112 }
113 auto tmp_str = buffer.substr(pos + kNvccVersionKeyWordsSize);
114 pos = tmp_str.find_first_of(',');
115 if (pos == std::string::npos) {
116 MS_LOG(ERROR) << "Cannot found nvcc version key word \',\', nvcc return: " << tmp_str;
117 return "";
118 }
119 auto version_str = tmp_str.substr(0, pos);
120 MS_LOG(INFO) << "Get cuda version [" << version_str << "] from env CUDA_HOME.";
121 return version_str;
122 }
123 return ""; // useless code makes static checking tools happy.
124 }
125
126 // only support version str that format is "a.b"
GetIntVersionFromVersionStr(const std::string & version_str,size_t * major,size_t * minor)127 bool GetIntVersionFromVersionStr(const std::string &version_str, size_t *major, size_t *minor) {
128 MS_EXCEPTION_IF_NULL(major);
129 MS_EXCEPTION_IF_NULL(minor);
130 size_t major_num = 0;
131 size_t minor_num = 0;
132 auto dot_pos = version_str.find('.');
133 if (dot_pos == std::string::npos) {
134 return false;
135 }
136 std::string minor_str = version_str.substr(dot_pos + 1);
137 std::string major_str = version_str.substr(0, dot_pos);
138 try {
139 major_num = std::stoull(major_str);
140 minor_num = std::stoull(minor_str);
141 } catch (...) {
142 return false;
143 }
144 *major = major_num;
145 *minor = minor_num;
146 return true;
147 }
148
GetVersionFromFileName(const std::string & file_name,size_t * major,size_t * minor)149 bool GetVersionFromFileName(const std::string &file_name, size_t *major, size_t *minor) {
150 MS_EXCEPTION_IF_NULL(major);
151 MS_EXCEPTION_IF_NULL(minor);
152 auto dot_pos = file_name.find_last_of('.');
153 if (dot_pos == std::string::npos) {
154 return false;
155 }
156 std::string minor_str = file_name.substr(dot_pos + 1);
157 std::string remain_str = file_name.substr(0, dot_pos);
158 dot_pos = remain_str.find_last_of('.');
159 if (dot_pos == std::string::npos) {
160 return false;
161 }
162 std::string major_str = file_name.substr(dot_pos + 1);
163 if (!std::any_of(minor_str.begin(), minor_str.end(), [](char c) { return std::isdigit(c); })) {
164 return false;
165 }
166 if (!std::any_of(major_str.begin(), major_str.end(), [](char c) { return std::isdigit(c); })) {
167 return false;
168 }
169 return GetIntVersionFromVersionStr(major_str + "." + minor_str, major, minor);
170 }
171
// Maps a (major, minor) version to a float key that preserves version order, for comparisons only:
// key = major + minor / (minor + 1). The fractional part grows with minor but stays below 1, so
// 11.8 < 11.10 < 12.0. The previous minor / (digits + 1) formula was not monotonic
// (11.8 -> 15.0 > 12.0, 11.9 -> 15.5 > 11.10 -> 14.33).
// Ordering is exact as long as float precision can separate adjacent keys (minor up to a few hundred).
float VersionToFloat(size_t major, size_t minor) {
  const double minor_value = static_cast<double>(minor);
  return static_cast<float>(static_cast<double>(major) + minor_value / (minor_value + 1.0));
}
175 #endif // #ifdef __linux__
176 } // namespace
177 namespace plugin_loader {
LoadDynamicLib(const std::string & plugin_file,std::map<std::string,void * > * all_handles,std::stringstream * err_msg,const bool gpu_env)178 bool PluginLoader::LoadDynamicLib(const std::string &plugin_file, std::map<std::string, void *> *all_handles,
179 std::stringstream *err_msg, const bool gpu_env) {
180 MS_EXCEPTION_IF_NULL(all_handles);
181 MS_EXCEPTION_IF_NULL(err_msg);
182 void *handle = nullptr;
183 std::string err_msg_str;
184 auto so_name = GetDynamicLibName(plugin_file);
185 #if defined(_WIN32) || defined(_WIN64)
186 handle = LoadLibraryEx(plugin_file.c_str(), NULL, LOAD_WITH_ALTERED_SEARCH_PATH);
187 err_msg_str = std::to_string(GetLastError());
188 #else
189 handle = dlopen(plugin_file.c_str(), RTLD_LAZY | RTLD_LOCAL);
190 err_msg_str = GetDlErrorMsg();
191 #endif
192 if (handle == nullptr) {
193 MS_LOG(INFO) << "Load dynamic library: " << so_name << " failed. " << err_msg_str;
194 *err_msg << "Load dynamic library: " << so_name << " failed. " << err_msg_str << std::endl;
195 return false;
196 }
197 (*all_handles)[so_name] = handle;
198 return true;
199 }
200
CloseDynamicLib(const std::string & dl_name,void * handle)201 void PluginLoader::CloseDynamicLib(const std::string &dl_name, void *handle) {
202 #if defined(_WIN32) || defined(_WIN64)
203 if (!FreeLibrary(static_cast<HMODULE>(handle))) {
204 MS_LOG(EXCEPTION) << "Closing dynamic lib: " + dl_name + " handle failed. Error: " + std::to_string(GetLastError());
205 }
206
207 #else
208 if (dlclose(handle) != 0) {
209 MS_LOG(ERROR) << "Closing dynamic lib: " << dl_name << "failed, error message: " << GetDlErrorMsg();
210 }
211 #endif
212 }
213
GetDynamicLibName(const std::string & plugin_file)214 std::string PluginLoader::GetDynamicLibName(const std::string &plugin_file) {
215 auto p1 = plugin_file.find_last_of(PATH_SEPARATOR) + 1;
216 auto target_so = plugin_file.substr(p1);
217 return target_so;
218 }
219
GetPluginPath(std::string * file_path)220 bool PluginLoader::GetPluginPath(std::string *file_path) {
221 MS_EXCEPTION_IF_NULL(file_path);
222 std::string cur_so_path;
223 #if !defined(_WIN32) && !defined(_WIN64)
224 Dl_info dl_info;
225 if (dladdr(reinterpret_cast<void *>(PluginLoader::GetPluginPath), &dl_info) == 0) {
226 MS_LOG(INFO) << "Get dladdr error";
227 return false;
228 }
229 cur_so_path = dl_info.dli_fname;
230 #else
231 HMODULE hModule = nullptr;
232 if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT | GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
233 (LPCSTR)PluginLoader::GetPluginPath, &hModule) == 0) {
234 MS_LOG(INFO) << "Get GetModuleHandleEx failed.";
235 return false;
236 }
237 char szPath[MAX_PATH];
238 if (GetModuleFileName(hModule, szPath, sizeof(szPath)) == 0) {
239 MS_LOG(INFO) << "Get GetModuleHandleEx failed.";
240 return false;
241 }
242 cur_so_path = std::string(szPath);
243 #endif
244 auto pos = cur_so_path.find_last_of(PATH_SEPARATOR);
245 if (cur_so_path.empty() || pos == std::string::npos) {
246 MS_LOG(INFO) << "Current so path empty or the path [" << cur_so_path << "] is invalid.";
247 return false;
248 }
249 #ifndef _WIN32
250 auto plugin_so_path = cur_so_path.substr(0, pos) + "/plugin";
251 #else
252 auto plugin_so_path = cur_so_path.substr(0, pos);
253 #endif
254 if (plugin_so_path.size() >= PATH_MAX) {
255 MS_LOG(INFO) << "Current path [" << plugin_so_path << "] is invalid.";
256 return false;
257 }
258 char real_path_mem[PATH_MAX] = {0};
259 #if defined(_WIN32) || defined(_WIN64)
260 if (_fullpath(real_path_mem, common::SafeCStr(plugin_so_path), PATH_MAX) == nullptr) {
261 MS_LOG(INFO) << "Plugin path is invalid: [" << plugin_so_path << "], skip!";
262 return false;
263 }
264 #else
265 if (realpath(common::SafeCStr(plugin_so_path), real_path_mem) == nullptr) {
266 MS_LOG(INFO) << "Plugin path is invalid: [" << plugin_so_path << "], skip!";
267 return false;
268 }
269 #endif
270 *file_path = std::string(real_path_mem);
271 return true;
272 }
273 } // namespace plugin_loader
274
275 namespace device {
// Returns the device context that `node` should actually run on.
// The target comes from the node's kAttrPrimitiveTarget user data or, for a CNode, its
// kAttrPrimitiveTarget attribute. When no target is set, or it equals the device name of
// `device_context`, `device_context` itself is returned. Otherwise the context for that target
// (same device id) is fetched or created through DeviceContextManager, initialized, and returned.
const DeviceContext *FetchRealDeviceContext(const AnfNodePtr &node, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);

  std::string target;
  if (const auto ud_target = node->user_data<std::string>(kAttrPrimitiveTarget); ud_target != nullptr) {
    target = *ud_target;
  } else if (node->isa<CNode>()) {
    const auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (common::AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
      target = common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrPrimitiveTarget);
    }
  }

  const auto &origin_key = device_context->device_context_key();
  const bool keep_given_context = target.empty() || target == origin_key.device_name_;
  if (keep_given_context) {
    return device_context;
  }

  auto *const real_device_context =
    DeviceContextManager::GetInstance().GetOrCreateDeviceContext({target, origin_key.device_id_});
  MS_EXCEPTION_IF_NULL(real_device_context);
  real_device_context->Initialize();
  return real_device_context;
}
302
GetInstance()303 DeviceContextManager &DeviceContextManager::GetInstance() {
304 static DeviceContextManager instance{};
305 #ifdef WITH_BACKEND
306 instance.LoadPlugin();
307 #endif
308 return instance;
309 }
310
Register(const std::string & device_name,DeviceContextCreator && device_context_creator)311 void DeviceContextManager::Register(const std::string &device_name, DeviceContextCreator &&device_context_creator) {
312 if (device_context_creators_.find(device_name) == device_context_creators_.end()) {
313 (void)device_context_creators_.emplace(device_name, device_context_creator);
314 }
315 }
316
LoadPlugin()317 void DeviceContextManager::LoadPlugin() {
318 if (load_init_) {
319 return;
320 }
321 load_init_ = true;
322 MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
323 MsContext::GetInstance()->ResisterLoadPluginErrorFunc(
324 []() -> std::string { return DeviceContextManager::GetInstance().GetErrorMsg(); });
325 if (plugin_path_.empty() && !plugin_loader::PluginLoader::GetPluginPath(&plugin_path_)) {
326 MS_LOG(INFO) << "Plugin path is invalid, skip!";
327 load_init_ = true;
328 dlopen_error_msg_ << "Plugin path is invalid, skip!" << std::endl;
329 return;
330 }
331 #ifdef _WIN32
332 auto plugin_file = plugin_path_ + "\\mindspore_gpu.dll";
333 if (access(plugin_file.c_str(), F_OK) != -1) {
334 (void)plugin_loader::PluginLoader::LoadDynamicLib(plugin_file, &plugin_maps_, &dlopen_error_msg_);
335 }
336 #else
337 DIR *dir = opendir(plugin_path_.c_str());
338 if (dir == nullptr) {
339 MS_LOG(ERROR) << "Open plugin dir failed, plugin path:" << plugin_path_;
340 load_init_ = true;
341 dlopen_error_msg_ << "Open plugin dir failed, plugin path:" << plugin_path_ << std::endl;
342 return;
343 }
344 struct dirent *entry;
345 std::map<std::string, std::set<std::string>> multi_version_plugin_map; // key: plugin name, value: so file name
346 while ((entry = readdir(dir)) != nullptr) {
347 auto plugin_file = plugin_path_ + PATH_SEPARATOR + entry->d_name;
348 if (plugin_file.find("libmindspore_") == std::string::npos) {
349 continue;
350 }
351 std::string file_name = entry->d_name;
352 auto dot = file_name.find_first_of(".");
353 if (dot == std::string::npos) {
354 continue;
355 }
356 (void)multi_version_plugin_map[file_name.substr(0, dot)].insert(plugin_file);
357 }
358
359 for (const auto &[plugin_name, file_names] : multi_version_plugin_map) {
360 // if we can confirm the platform is gpu, we should directly dlopen gpu_plugin file instead of trying.
361 if (plugin_name == kGpuPluginName) {
362 std::string cuda_home = common::GetEnv(kCudaHomeEnv);
363 if (cuda_home.empty()) {
364 MS_LOG(INFO) << "Please set env CUDA_HOME to path of cuda, if you want to enable gpu backend.";
365 continue;
366 } else if (SelectGpuPlugin(cuda_home, file_names)) {
367 break;
368 }
369 }
370 for (auto iter = file_names.rbegin(); iter != file_names.rend();) {
371 const auto &file_name = *(iter++);
372 auto ret = plugin_loader::PluginLoader::LoadDynamicLib(file_name, &plugin_maps_, &dlopen_error_msg_);
373 if (ret) {
374 if (iter != file_names.rend()) {
375 MS_LOG(INFO) << "Load " << plugin_name << " plugin file " << file_name
376 << " success, skip loading other version.";
377 }
378 break;
379 }
380 }
381 }
382 (void)closedir(dir);
383 #endif
384 }
385
UnloadPlugin()386 void DeviceContextManager::UnloadPlugin() {
387 if (plugin_maps_.empty()) {
388 return;
389 }
390 device_context_creators_.clear();
391 auto iter = plugin_maps_.begin();
392 while (iter != plugin_maps_.end()) {
393 plugin_loader::PluginLoader::CloseDynamicLib(iter->first, iter->second);
394 (void)iter++;
395 }
396 plugin_maps_.clear();
397 }
398
// Calls Destroy() on every cached device context, logging each key first, then empties both
// backend_to_device_context_ and device_contexts_. A null cached context is reported through
// MS_EXCEPTION_IF_NULL before anything is cleared.
void DeviceContextManager::ClearDeviceContexts() {
  for (auto &iter : device_contexts_) {
    MS_LOG(INFO) << "Release device " << iter.first;
    MS_EXCEPTION_IF_NULL(iter.second);
    iter.second->Destroy();
  }
  // Both maps hold the same context objects; the order of these two clears decides when the
  // contexts' last references (held here) are dropped.
  backend_to_device_context_.clear();
  device_contexts_.clear();
}
408
// Post-fork reset (per its name and log messages, meant to run in the child process): drops every
// cached entry in device_contexts_ without calling Destroy() on the contexts.
// NOTE(review): unlike ClearDeviceContexts, backend_to_device_context_ is not cleared here, so it
// keeps shared ownership of the pre-fork contexts and GetDeviceContext() can still return them —
// confirm whether that is intended (e.g. to keep their destructors from running in the child).
void DeviceContextManager::ChildAfterFork() {
  MS_LOG(DEBUG) << "DeviceContextManager reinitialize after fork.";
  MS_LOG(DEBUG) << "Clear device_contexts_.";
  device_contexts_.clear();
  MS_LOG(DEBUG) << "DeviceContextManager reinitialize after fork done.";
}
415
BindDeviceCtx() const416 void DeviceContextManager::BindDeviceCtx() const {
417 for (auto &iter : device_contexts_) {
418 MS_EXCEPTION_IF_NULL(iter.second);
419 MS_EXCEPTION_IF_NULL(iter.second->device_res_manager_);
420 if (!iter.second->device_res_manager_->BindDeviceToCurrentThread(true)) {
421 MS_LOG(EXCEPTION) << "Bind device failed";
422 }
423 }
424 }
425
SetRegisterDeviceStatelessFuncCb(const std::string & backend,const RegisterStatelessFuncCb & register_func_cb)426 void DeviceContextManager::SetRegisterDeviceStatelessFuncCb(const std::string &backend,
427 const RegisterStatelessFuncCb ®ister_func_cb) {
428 register_func_cbs_[backend] = register_func_cb;
429 }
430
RegisterDeviceStatelessFunc(py::module * m)431 void DeviceContextManager::RegisterDeviceStatelessFunc(py::module *m) {
432 for (const auto &f : register_func_cbs_) {
433 const auto ®ister_cb = f.second;
434 if (register_cb) {
435 register_cb(m);
436 }
437 }
438 }
439
GetOrCreateDeviceContext(const DeviceContextKey & device_context_key)440 DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key) {
441 std::string device_context_key_str = device_context_key.ToString();
442 std::string name = device_context_key.device_name_;
443
444 auto device_context_iter = device_contexts_.find(device_context_key_str);
445 if (device_context_iter != device_contexts_.end()) {
446 return device_context_iter->second.get();
447 }
448
449 std::shared_ptr<DeviceContext> device_context;
450 auto creator_iter = device_context_creators_.find(name);
451 if (creator_iter != device_context_creators_.end()) {
452 device_context = (creator_iter->second)(device_context_key);
453 MS_EXCEPTION_IF_NULL(device_context);
454 device_contexts_[device_context_key_str] = device_context;
455 backend_to_device_context_[name] = device_context;
456 } else {
457 MS_LOG(EXCEPTION) << "Create device context failed, please make sure target device:" << name
458 << " is available, error message of loading plugins: " << std::endl
459 << GetErrorMsg();
460 }
461 return device_context.get();
462 }
463
GetDeviceContext(const std::string & device_target)464 DeviceContextPtr DeviceContextManager::GetDeviceContext(const std::string &device_target) {
465 if (backend_to_device_context_.count(device_target) == 0) {
466 MS_LOG(INFO) << "Device context of device " << device_target << " is not created yet.";
467 return nullptr;
468 }
469 return backend_to_device_context_[device_target];
470 }
471
UpdateDeviceContextKey(const DeviceContextKey & old_key,const DeviceContextKey & new_key)472 void DeviceContextManager::UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key) {
473 std::string old_key_str = old_key.ToString();
474 std::string new_key_str = new_key.ToString();
475
476 auto handle = device_contexts_.extract(old_key_str);
477 if (handle.empty()) {
478 MS_LOG(EXCEPTION) << "Can not find device context for: " << old_key_str;
479 }
480
481 handle.key() = new_key_str;
482 (void)device_contexts_.insert(std::move(handle));
483 }
484
WaitTaskFinishOnDevice() const485 void DeviceContextManager::WaitTaskFinishOnDevice() const {
486 for (const auto &item : device_contexts_) {
487 auto device_context = item.second;
488 try {
489 if (device_context != nullptr && !device_context->device_res_manager_->SyncAllStreams()) {
490 MS_LOG(ERROR) << "SyncStream failed";
491 return;
492 }
493 } catch (const std::exception &ex) {
494 MS_LOG(ERROR) << "SyncStream failed, exception:" << ex.what();
495 return;
496 }
497 }
498 }
499
SyncAllStreams() const500 void DeviceContextManager::SyncAllStreams() const {
501 for (const auto &item : device_contexts_) {
502 auto device_context = item.second;
503 if (device_context != nullptr && !device_context->device_res_manager_->SyncAllStreams()) {
504 MS_LOG(EXCEPTION) << "SyncStream failed, device info: " << device_context->device_context_key().ToString();
505 }
506 }
507 }
508
GetErrorMsg() const509 std::string DeviceContextManager::GetErrorMsg() const { return dlopen_error_msg_.str(); }
510
SelectGpuPlugin(const std::string & cuda_home,const std::set<std::string> & file_names)511 bool DeviceContextManager::SelectGpuPlugin(const std::string &cuda_home, const std::set<std::string> &file_names) {
512 #ifdef __linux__
513 bool ret;
514 if (file_names.size() == 1) {
515 ret = plugin_loader::PluginLoader::LoadDynamicLib(*file_names.begin(), &plugin_maps_, &dlopen_error_msg_, true);
516 } else {
517 auto nvcc_path = GetNvccRealPath(cuda_home);
518 if (nvcc_path.empty()) {
519 return false;
520 }
521 auto cuda_version = GetCudaVersionFromNvcc(nvcc_path);
522 if (cuda_version.empty()) {
523 return false;
524 }
525 size_t target_major = 0;
526 size_t target_minor = 0;
527 if (!GetIntVersionFromVersionStr(cuda_version, &target_major, &target_minor)) {
528 MS_LOG(EXCEPTION) << "Get version num from version string " << cuda_version << " failed.";
529 }
530
531 std::string selected_plugin = "";
532 std::vector<std::pair<size_t, size_t>> all_plugin_version;
533 std::vector<std::string> all_plugin_path;
534 std::for_each(file_names.begin(), file_names.end(),
535 [&selected_plugin, &all_plugin_version, &all_plugin_path, target_major,
536 target_minor](const std::string &file_name) {
537 size_t current_major = 0;
538 size_t current_minor = 0;
539 if (GetVersionFromFileName(file_name, ¤t_major, ¤t_minor)) {
540 all_plugin_version.emplace_back(current_major, current_minor);
541 all_plugin_path.emplace_back(file_name);
542 }
543 if (current_major == target_major && current_minor == target_minor) {
544 selected_plugin = file_name;
545 }
546 });
547
548 if (selected_plugin.empty()) {
549 for (size_t i = 0; i < all_plugin_version.size(); ++i) {
550 if (target_major != all_plugin_version[i].first) {
551 continue;
552 }
553 if (VersionToFloat(target_major, target_minor) >
554 VersionToFloat(all_plugin_version[i].first, all_plugin_version[i].second) &&
555 (i + 1 >= all_plugin_version.size() ||
556 VersionToFloat(target_major, target_minor) <
557 VersionToFloat(all_plugin_version[i + 1].first, all_plugin_version[i + 1].second))) {
558 selected_plugin = all_plugin_path[i];
559 }
560 }
561 }
562
563 if (selected_plugin.empty()) {
564 MS_LOG(WARNING) << "Env CUDA_HOME is " << cuda_home << ", but can not find suitable gpu plugin.";
565 return false;
566 }
567
568 ret = plugin_loader::PluginLoader::LoadDynamicLib(selected_plugin, &plugin_maps_, &dlopen_error_msg_, true);
569 }
570 if (!ret) {
571 MS_LOG(WARNING) << "Env CUDA_HOME is " << cuda_home
572 << ", but dlopen file_name failed, reason: " << dlopen_error_msg_.str();
573 return false;
574 }
575 return true;
576 #endif // #ifdef __linux__
577 return false;
578 }
579 } // namespace device
580 } // namespace mindspore
581