1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "runtime/device/kernel_runtime_manager.h"
#include <utility>
#include "utils/log_adapter.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
#include "backend/session/pynative_task_manager.h"
23
24 namespace mindspore {
25 namespace device {
ClearRuntimeResource()26 void KernelRuntimeManager::ClearRuntimeResource() {
27 // Just remove PyNative tasks before runtime resource release.
28 session::PynativeTaskManager::GetInstance().Reset();
29 #if ((defined ENABLE_CPU) && (!defined _WIN32))
30 if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
31 ps::ps_cache_instance.SyncEmbeddingTable();
32 }
33 #endif
34 std::lock_guard<std::mutex> guard(lock_);
35 for (auto &iter : runtime_map_) {
36 MS_LOG(INFO) << "Release device " << iter.first;
37 MS_EXCEPTION_IF_NULL(iter.second);
38 iter.second->ReleaseDeviceRes();
39 }
40 runtime_map_.clear();
41 }
42
ClearGraphResource(uint32_t graph_id)43 void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) {
44 std::lock_guard<std::mutex> guard(lock_);
45 for (auto &iter : runtime_map_) {
46 MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource";
47 if (!iter.second) {
48 MS_LOG(ERROR) << "Kernel runtime is nullptr";
49 continue;
50 }
51 iter.second->ClearGraphRuntimeResource(graph_id);
52 }
53 }
54
Instance()55 KernelRuntimeManager &KernelRuntimeManager::Instance() {
56 static KernelRuntimeManager instance{};
57 return instance;
58 }
59
Register(const std::string & device_name,KernelRuntimeCreator && runtime_creator)60 void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) {
61 if (runtime_creators_.find(device_name) == runtime_creators_.end()) {
62 (void)runtime_creators_.emplace(device_name, runtime_creator);
63 }
64 }
65
GetDeviceKey(const std::string & device_name,uint32_t device_id)66 std::string KernelRuntimeManager::GetDeviceKey(const std::string &device_name, uint32_t device_id) {
67 std::string device_key = device_name + "_" + std::to_string(device_id);
68 return device_key;
69 }
70
GetSingleKernelRuntime(const std::string & device_name,uint32_t device_id)71 KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) {
72 auto runtime_key = GetDeviceKey(device_name, device_id);
73 auto runtime_iter = runtime_map_.find(runtime_key);
74 if (runtime_iter != runtime_map_.end()) {
75 return runtime_iter->second.get();
76 } else if (!runtime_map_.empty()) {
77 auto cur_runtime_key = runtime_map_.begin()->first;
78 auto find_pos = cur_runtime_key.rfind('_');
79 if (find_pos != std::string::npos) {
80 if (cur_runtime_key.size() > find_pos + 1) {
81 auto cur_device_id = cur_runtime_key.substr(find_pos + 1);
82 MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id
83 << ", set device id: " << device_id << " failed";
84 } else {
85 MS_LOG(EXCEPTION) << "Can't change device id in runtime, current runtime_key size error, set device id: "
86 << device_id << " failed";
87 }
88 }
89 }
90 return GetKernelRuntime(device_name, device_id);
91 }
92
GetKernelRuntime(const std::string & device_name,uint32_t device_id)93 KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) {
94 std::string runtime_key = GetDeviceKey(device_name, device_id);
95 std::lock_guard<std::mutex> guard(lock_);
96 auto runtime_iter = runtime_map_.find(runtime_key);
97 if (runtime_iter != runtime_map_.end()) {
98 return runtime_iter->second.get();
99 }
100 std::shared_ptr<KernelRuntime> kernel_runtime;
101 auto creator_iter = runtime_creators_.find(device_name);
102 if (creator_iter != runtime_creators_.end()) {
103 MS_EXCEPTION_IF_NULL(creator_iter->second);
104 kernel_runtime = (creator_iter->second)();
105 MS_EXCEPTION_IF_NULL(kernel_runtime);
106 kernel_runtime->set_device_id(device_id);
107 runtime_map_[runtime_key] = kernel_runtime;
108 } else {
109 MS_LOG(EXCEPTION) << "No kernel runtime creator for " << device_name << " with device id " << device_id;
110 }
111
112 return kernel_runtime.get();
113 }
114
GetCurrentKernelRuntime()115 KernelRuntime *KernelRuntimeManager::GetCurrentKernelRuntime() {
116 auto ms_context = MsContext::GetInstance();
117 MS_EXCEPTION_IF_NULL(ms_context);
118 uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
119 std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
120 return GetKernelRuntime(device_name, device_id);
121 }
122
ReleaseKernelRuntime(const std::string & device_name,uint32_t device_id)123 void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id) {
124 session::PynativeTaskManager::GetInstance().Reset();
125 std::string runtime_key = GetDeviceKey(device_name, device_id);
126 std::lock_guard<std::mutex> guard(lock_);
127 auto runtime_iter = runtime_map_.find(runtime_key);
128 if (runtime_iter == runtime_map_.end()) {
129 return;
130 }
131 auto runtime = runtime_iter->second.get();
132 if (runtime == nullptr) {
133 return;
134 }
135 #if ((defined ENABLE_CPU) && (!defined _WIN32))
136 if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
137 ps::ps_cache_instance.SyncEmbeddingTable();
138 }
139 #endif
140 runtime->ReleaseDeviceRes();
141 runtime_map_.erase(runtime_iter);
142 }
143 } // namespace device
144 } // namespace mindspore
145