1 /** 2 * Copyright 2021-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_MANAGER_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_MANAGER_H_ 19 20 #include <set> 21 #include <any> 22 #include <map> 23 #include <string> 24 #include <memory> 25 #include <utility> 26 #include <functional> 27 #include <mutex> 28 #include <vector> 29 #include "runtime/hardware/device_context.h" 30 #include "include/backend/visible.h" 31 #include "include/common/pybind_api/api_register.h" 32 33 namespace mindspore { 34 namespace plugin_loader { 35 class PluginLoader { 36 public: 37 static bool LoadDynamicLib(const std::string &plugin_file, std::map<std::string, void *> *all_handles, 38 std::stringstream *err_msg, const bool gpu_env = false); 39 static void CloseDynamicLib(const std::string &dl_name, void *handle); 40 static bool GetPluginPath(std::string *file_path); 41 42 private: 43 static std::string GetDynamicLibName(const std::string &plugin_file); 44 }; 45 } // namespace plugin_loader 46 47 namespace device { 48 using DeviceContextCreator = std::function<std::shared_ptr<DeviceContext>(const DeviceContextKey &)>; 49 50 // This callback registers stateless functions to _c_expression. It is set by different device contexts. 51 using RegisterStatelessFuncCb = std::function<void(py::module *m)>; 52 53 const DeviceContext *FetchRealDeviceContext(const AnfNodePtr &node, const DeviceContext *device_context); 54 55 class BACKEND_EXPORT DeviceContextManager { 56 public: 57 ~DeviceContextManager() = default; 58 static DeviceContextManager &GetInstance(); 59 void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator); 60 DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key); 61 // Return the device context of the specified device target. 62 // The difference between this method and 'GetOrCreateDeviceContext' is this method only query device context by 63 // device target(without device id) since MindSpore only supports 'single process, single device'. 64 DeviceContextPtr GetDeviceContext(const std::string &device_target); 65 void UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key); 66 void ClearDeviceContexts(); 67 void ChildAfterFork(); 68 void WaitTaskFinishOnDevice() const; 69 void SyncAllStreams() const; 70 void UnloadPlugin(); 71 std::string GetErrorMsg() const; 72 void BindDeviceCtx() const; 73 74 // For different device backends, some methods are stateless. They have to be registered to `DeviceContextManager`. 75 void SetRegisterDeviceStatelessFuncCb(const std::string &backend, const RegisterStatelessFuncCb ®ister_func_cb); 76 void RegisterDeviceStatelessFunc(py::module *m); 77 78 private: 79 DeviceContextManager() = default; 80 DISABLE_COPY_AND_ASSIGN(DeviceContextManager); 81 void LoadPlugin(); 82 bool SelectGpuPlugin(const std::string &cuda_home, const std::set<std::string> &file_names); 83 84 std::map<std::string, void *> plugin_maps_; 85 bool load_init_; 86 std::string plugin_path_; 87 88 // The string converted from DeviceContextKey -> DeviceContextPtr. 89 std::map<std::string, DeviceContextPtr> device_contexts_; 90 // The name of device -> vector of DeviceContextPtr. 91 std::map<std::string, DeviceContextPtr> backend_to_device_context_; 92 // The name of device -> DeviceContextCreator. 93 std::map<std::string, DeviceContextCreator> device_context_creators_; 94 // record error message of dlopen, print when create device_context failed. 95 std::stringstream dlopen_error_msg_; 96 97 // Backend name->register stateless functions callback. 98 std::map<std::string, RegisterStatelessFuncCb> register_func_cbs_; 99 }; 100 101 class BACKEND_EXPORT DeviceContextRegister { 102 public: DeviceContextRegister(const std::string & device_name,DeviceContextCreator && runtime_creator)103 DeviceContextRegister(const std::string &device_name, DeviceContextCreator &&runtime_creator) { 104 DeviceContextManager::GetInstance().Register(device_name, std::move(runtime_creator)); 105 } 106 ~DeviceContextRegister() = default; 107 }; 108 109 #define MS_REGISTER_DEVICE(DEVICE_NAME, DEVICE_CONTEXT_CLASS) \ 110 static const DeviceContextRegister g_device_##DEVICE_NAME##_reg( \ 111 DEVICE_NAME, [](const DeviceContextKey &device_context_key) { \ 112 return std::make_shared<DEVICE_CONTEXT_CLASS>(device_context_key); \ 113 }); 114 115 class BACKEND_EXPORT StatelessFuncCbRegister { 116 public: StatelessFuncCbRegister(const std::string & device_name,const RegisterStatelessFuncCb & func)117 StatelessFuncCbRegister(const std::string &device_name, const RegisterStatelessFuncCb &func) { 118 DeviceContextManager::GetInstance().SetRegisterDeviceStatelessFuncCb(device_name, func); 119 } 120 ~StatelessFuncCbRegister() = default; 121 }; 122 123 #define REGISTER_DEV_STATELESS_FUNC_CB(DEVICE_NAME, FUNC_OBJ) \ 124 static const StatelessFuncCbRegister g_##DEVICE_NAME##_stateless_func_cb_reg(DEVICE_NAME, FUNC_OBJ) 125 } // namespace device 126 } // namespace mindspore 127 #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_MANAGER_H_ 128