• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "runtime/device/kernel_runtime_manager.h"

#include <utility>

#include "utils/log_adapter.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
#include "backend/session/pynative_task_manager.h"
23 
24 namespace mindspore {
25 namespace device {
ClearRuntimeResource()26 void KernelRuntimeManager::ClearRuntimeResource() {
27   // Just remove PyNative tasks before runtime resource release.
28   session::PynativeTaskManager::GetInstance().Reset();
29 #if ((defined ENABLE_CPU) && (!defined _WIN32))
30   if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
31     ps::ps_cache_instance.SyncEmbeddingTable();
32   }
33 #endif
34   std::lock_guard<std::mutex> guard(lock_);
35   for (auto &iter : runtime_map_) {
36     MS_LOG(INFO) << "Release device " << iter.first;
37     MS_EXCEPTION_IF_NULL(iter.second);
38     iter.second->ReleaseDeviceRes();
39   }
40   runtime_map_.clear();
41 }
42 
ClearGraphResource(uint32_t graph_id)43 void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) {
44   std::lock_guard<std::mutex> guard(lock_);
45   for (auto &iter : runtime_map_) {
46     MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource";
47     if (!iter.second) {
48       MS_LOG(ERROR) << "Kernel runtime is nullptr";
49       continue;
50     }
51     iter.second->ClearGraphRuntimeResource(graph_id);
52   }
53 }
54 
Instance()55 KernelRuntimeManager &KernelRuntimeManager::Instance() {
56   static KernelRuntimeManager instance{};
57   return instance;
58 }
59 
Register(const std::string & device_name,KernelRuntimeCreator && runtime_creator)60 void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) {
61   if (runtime_creators_.find(device_name) == runtime_creators_.end()) {
62     (void)runtime_creators_.emplace(device_name, runtime_creator);
63   }
64 }
65 
GetDeviceKey(const std::string & device_name,uint32_t device_id)66 std::string KernelRuntimeManager::GetDeviceKey(const std::string &device_name, uint32_t device_id) {
67   std::string device_key = device_name + "_" + std::to_string(device_id);
68   return device_key;
69 }
70 
GetSingleKernelRuntime(const std::string & device_name,uint32_t device_id)71 KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) {
72   auto runtime_key = GetDeviceKey(device_name, device_id);
73   auto runtime_iter = runtime_map_.find(runtime_key);
74   if (runtime_iter != runtime_map_.end()) {
75     return runtime_iter->second.get();
76   } else if (!runtime_map_.empty()) {
77     auto cur_runtime_key = runtime_map_.begin()->first;
78     auto find_pos = cur_runtime_key.rfind('_');
79     if (find_pos != std::string::npos) {
80       if (cur_runtime_key.size() > find_pos + 1) {
81         auto cur_device_id = cur_runtime_key.substr(find_pos + 1);
82         MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id
83                           << ", set device id: " << device_id << " failed";
84       } else {
85         MS_LOG(EXCEPTION) << "Can't change device id in runtime, current runtime_key size error, set device id: "
86                           << device_id << " failed";
87       }
88     }
89   }
90   return GetKernelRuntime(device_name, device_id);
91 }
92 
GetKernelRuntime(const std::string & device_name,uint32_t device_id)93 KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) {
94   std::string runtime_key = GetDeviceKey(device_name, device_id);
95   std::lock_guard<std::mutex> guard(lock_);
96   auto runtime_iter = runtime_map_.find(runtime_key);
97   if (runtime_iter != runtime_map_.end()) {
98     return runtime_iter->second.get();
99   }
100   std::shared_ptr<KernelRuntime> kernel_runtime;
101   auto creator_iter = runtime_creators_.find(device_name);
102   if (creator_iter != runtime_creators_.end()) {
103     MS_EXCEPTION_IF_NULL(creator_iter->second);
104     kernel_runtime = (creator_iter->second)();
105     MS_EXCEPTION_IF_NULL(kernel_runtime);
106     kernel_runtime->set_device_id(device_id);
107     runtime_map_[runtime_key] = kernel_runtime;
108   } else {
109     MS_LOG(EXCEPTION) << "No kernel runtime creator for " << device_name << " with device id " << device_id;
110   }
111 
112   return kernel_runtime.get();
113 }
114 
GetCurrentKernelRuntime()115 KernelRuntime *KernelRuntimeManager::GetCurrentKernelRuntime() {
116   auto ms_context = MsContext::GetInstance();
117   MS_EXCEPTION_IF_NULL(ms_context);
118   uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
119   std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
120   return GetKernelRuntime(device_name, device_id);
121 }
122 
ReleaseKernelRuntime(const std::string & device_name,uint32_t device_id)123 void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id) {
124   session::PynativeTaskManager::GetInstance().Reset();
125   std::string runtime_key = GetDeviceKey(device_name, device_id);
126   std::lock_guard<std::mutex> guard(lock_);
127   auto runtime_iter = runtime_map_.find(runtime_key);
128   if (runtime_iter == runtime_map_.end()) {
129     return;
130   }
131   auto runtime = runtime_iter->second.get();
132   if (runtime == nullptr) {
133     return;
134   }
135 #if ((defined ENABLE_CPU) && (!defined _WIN32))
136   if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
137     ps::ps_cache_instance.SyncEmbeddingTable();
138   }
139 #endif
140   runtime->ReleaseDeviceRes();
141   runtime_map_.erase(runtime_iter);
142 }
143 }  // namespace device
144 }  // namespace mindspore
145