• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "minddata/dataset/plugin/plugin_loader.h"

#include <algorithm>
#include <iterator>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/plugin/shared_lib_util.h"
25 
26 namespace mindspore {
27 namespace dataset {
GetInstance()28 PluginLoader *PluginLoader::GetInstance() noexcept {
29   static PluginLoader pl;
30   return &pl;
31 }
32 
~PluginLoader()33 PluginLoader::~PluginLoader() {
34   std::vector<std::string> keys;
35   // get the keys from map, this is to avoid concurrent iteration and delete
36   std::transform(plugins_.begin(), plugins_.end(), std::back_inserter(keys), [](const auto &p) { return p.first; });
37   for (std::string &key : keys) {
38     Status rc = UnloadPlugin(key);
39     MSLOG_IF(MsLogLevel::kError, rc.IsError(), mindspore::NoExceptionType, nullptr) << rc.ToString();
40   }
41 }
42 
43 // LoadPlugin() is NOT thread-safe. It is supposed to be called when Ops are being built. E.g. PluginOp should call this
44 // within constructor instead of in its Compute() which is parallel.
LoadPlugin(const std::string & filename,plugin::PluginManagerBase ** singleton_plugin)45 Status PluginLoader::LoadPlugin(const std::string &filename, plugin::PluginManagerBase **singleton_plugin) {
46   RETURN_UNEXPECTED_IF_NULL(singleton_plugin);
47   auto itr = plugins_.find(filename);
48   // return ok if this module is already loaded
49   if (itr != plugins_.end()) {
50     *singleton_plugin = itr->second.first;
51     return Status::OK();
52   }
53   // Open the .so file
54   void *handle = SharedLibUtil::Load(filename);
55   CHECK_FAIL_RETURN_UNEXPECTED(handle != nullptr,
56                                "[Internal ERROR] Fail to load:" + filename + ".\n" + SharedLibUtil::ErrMsg());
57 
58   // Load GetInstance function ptr from the so file, so needs to be compiled with -fPIC
59   void *func_handle = SharedLibUtil::FindSym(handle, "GetInstance");
60   CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr,
61                                "[Internal ERROR] Fail to find GetInstance()\n" + SharedLibUtil::ErrMsg());
62 
63   // cast the returned function ptr of type void* to the type of GetInstance
64   plugin::PluginManagerBase *(*get_instance)(plugin::MindDataManagerBase *) =
65     reinterpret_cast<plugin::PluginManagerBase *(*)(plugin::MindDataManagerBase *)>(func_handle);
66   RETURN_UNEXPECTED_IF_NULL(get_instance);
67 
68   *singleton_plugin = get_instance(nullptr);  // call function ptr to get instance
69   RETURN_UNEXPECTED_IF_NULL(*singleton_plugin);
70 
71   std::string v1 = (*singleton_plugin)->GetPluginVersion();
72   std::string v2(plugin::kSharedIncludeVersion);
73   if (v1 != v2) {
74     std::string err_msg = "[Internal ERROR] expected:" + v2 + ", received:" + v1 + " please recompile.";
75     if (SharedLibUtil::Close(handle) != 0) {
76       err_msg += ("\ndlclose() error, err_msg:" + SharedLibUtil::ErrMsg() + ".");
77     }
78     RETURN_STATUS_UNEXPECTED(err_msg);
79   }
80 
81   const std::map<std::string, std::set<std::string>> module_names = (*singleton_plugin)->GetModuleNames();
82   for (auto &p : module_names) {
83     std::string msg = "Plugin " + p.first + " has module:";
84     MS_LOG(DEBUG) << std::accumulate(p.second.begin(), p.second.end(), msg,
85                                      [](const std::string &msg, const std::string &nm) { return msg + " " + nm; });
86   }
87 
88   // save the name and handle
89   std::pair<plugin::PluginManagerBase *, void *> plugin_new = std::make_pair(*singleton_plugin, handle);
90   plugins_.insert({filename, plugin_new});
91   return Status::OK();
92 }
93 
UnloadPlugin(const std::string & filename)94 Status PluginLoader::UnloadPlugin(const std::string &filename) {
95   auto itr = plugins_.find(filename);
96   RETURN_OK_IF_TRUE(itr == plugins_.end());  // return true if this plugin was never loaded or already removed
97 
98   void *func_handle = SharedLibUtil::FindSym(itr->second.second, "DestroyInstance");
99   CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr,
100                                "[Internal ERROR] Fail to find DestroyInstance()\n" + SharedLibUtil::ErrMsg());
101 
102   void (*destroy_instance)() = reinterpret_cast<void (*)()>(func_handle);
103   RETURN_UNEXPECTED_IF_NULL(destroy_instance);
104 
105   destroy_instance();
106   CHECK_FAIL_RETURN_UNEXPECTED(SharedLibUtil::Close(itr->second.second) == 0,
107                                "[Internal ERROR] dlclose() error: " + SharedLibUtil::ErrMsg());
108 
109   plugins_.erase(filename);
110   return Status::OK();
111 }
112 }  // namespace dataset
113 }  // namespace mindspore
114