• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_MANAGER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_MANAGER_H_
19 
#include <any>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "runtime/hardware/device_context.h"
#include "include/backend/visible.h"
#include "include/common/pybind_api/api_register.h"
32 
33 namespace mindspore {
namespace plugin_loader {
// Static helpers for locating, opening and closing device plugin shared libraries.
// NOTE(review): the behavior notes below are inferred from the signatures only; confirm against the implementation.
class PluginLoader {
 public:
  // Reports the plugin path through `file_path`; the bool result presumably signals success.
  static bool GetPluginPath(std::string *file_path);

  // Presumably opens `plugin_file` as a dynamic library, records its handle in `all_handles` and writes any
  // failure details to `err_msg`. `gpu_env` defaults to false. The bool result presumably signals success.
  static bool LoadDynamicLib(const std::string &plugin_file, std::map<std::string, void *> *all_handles,
                             std::stringstream *err_msg, bool gpu_env = false);

  // Presumably closes the library `handle` that is known under the name `dl_name`.
  static void CloseDynamicLib(const std::string &dl_name, void *handle);

 private:
  // Derives a library name from `plugin_file` (assumed from the name; verify in the implementation).
  static std::string GetDynamicLibName(const std::string &plugin_file);
};
}  // namespace plugin_loader
46 
47 namespace device {
48 using DeviceContextCreator = std::function<std::shared_ptr<DeviceContext>(const DeviceContextKey &)>;
49 
50 // This callback registers stateless functions to _c_expression. It is set by different device contexts.
51 using RegisterStatelessFuncCb = std::function<void(py::module *m)>;
52 
53 const DeviceContext *FetchRealDeviceContext(const AnfNodePtr &node, const DeviceContext *device_context);
54 
55 class BACKEND_EXPORT DeviceContextManager {
56  public:
57   ~DeviceContextManager() = default;
58   static DeviceContextManager &GetInstance();
59   void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator);
60   DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key);
61   // Return the device context of the specified device target.
62   // The difference between this method and 'GetOrCreateDeviceContext' is this method only query device context by
63   // device target(without device id) since MindSpore only supports 'single process, single device'.
64   DeviceContextPtr GetDeviceContext(const std::string &device_target);
65   void UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key);
66   void ClearDeviceContexts();
67   void ChildAfterFork();
68   void WaitTaskFinishOnDevice() const;
69   void SyncAllStreams() const;
70   void UnloadPlugin();
71   std::string GetErrorMsg() const;
72   void BindDeviceCtx() const;
73 
74   // For different device backends, some methods are stateless. They have to be registered to `DeviceContextManager`.
75   void SetRegisterDeviceStatelessFuncCb(const std::string &backend, const RegisterStatelessFuncCb &register_func_cb);
76   void RegisterDeviceStatelessFunc(py::module *m);
77 
78  private:
79   DeviceContextManager() = default;
80   DISABLE_COPY_AND_ASSIGN(DeviceContextManager);
81   void LoadPlugin();
82   bool SelectGpuPlugin(const std::string &cuda_home, const std::set<std::string> &file_names);
83 
84   std::map<std::string, void *> plugin_maps_;
85   bool load_init_;
86   std::string plugin_path_;
87 
88   // The string converted from DeviceContextKey -> DeviceContextPtr.
89   std::map<std::string, DeviceContextPtr> device_contexts_;
90   // The name of device -> vector of DeviceContextPtr.
91   std::map<std::string, DeviceContextPtr> backend_to_device_context_;
92   // The name of device -> DeviceContextCreator.
93   std::map<std::string, DeviceContextCreator> device_context_creators_;
94   // record error message of dlopen, print when create device_context failed.
95   std::stringstream dlopen_error_msg_;
96 
97   // Backend name->register stateless functions callback.
98   std::map<std::string, RegisterStatelessFuncCb> register_func_cbs_;
99 };
100 
101 class BACKEND_EXPORT DeviceContextRegister {
102  public:
DeviceContextRegister(const std::string & device_name,DeviceContextCreator && runtime_creator)103   DeviceContextRegister(const std::string &device_name, DeviceContextCreator &&runtime_creator) {
104     DeviceContextManager::GetInstance().Register(device_name, std::move(runtime_creator));
105   }
106   ~DeviceContextRegister() = default;
107 };
108 
// Defines an internal-linkage (namespace-scope `static const`) DeviceContextRegister. Its constructor registers a
// creator that builds DEVICE_CONTEXT_CLASS from a DeviceContextKey under the name DEVICE_NAME.
// DEVICE_NAME is token-pasted into the variable name, so it must be a single identifier (e.g. a named string
// constant), not a string literal. The expansion already ends with ';'.
#define MS_REGISTER_DEVICE(DEVICE_NAME, DEVICE_CONTEXT_CLASS)     \
  static const DeviceContextRegister g_device_##DEVICE_NAME##_reg( \
    DEVICE_NAME, [](const DeviceContextKey &key) {                 \
      return std::make_shared<DEVICE_CONTEXT_CLASS>(key);          \
    });
114 
115 class BACKEND_EXPORT StatelessFuncCbRegister {
116  public:
StatelessFuncCbRegister(const std::string & device_name,const RegisterStatelessFuncCb & func)117   StatelessFuncCbRegister(const std::string &device_name, const RegisterStatelessFuncCb &func) {
118     DeviceContextManager::GetInstance().SetRegisterDeviceStatelessFuncCb(device_name, func);
119   }
120   ~StatelessFuncCbRegister() = default;
121 };
122 
// Defines an internal-linkage StatelessFuncCbRegister that records FUNC_OBJ as DEVICE_NAME's stateless-function
// registration callback. DEVICE_NAME is token-pasted into the variable name, so it must be a single identifier.
// Unlike MS_REGISTER_DEVICE, the expansion does not end with ';' — the caller supplies it.
#define REGISTER_DEV_STATELESS_FUNC_CB(DEVICE_NAME, FUNC_OBJ) \
  static const StatelessFuncCbRegister g_##DEVICE_NAME##_stateless_func_cb_reg{DEVICE_NAME, FUNC_OBJ}
125 }  // namespace device
126 }  // namespace mindspore
127 #endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_MANAGER_H_
128