1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "runtime/device/kernel_runtime_manager.h"

#include <utility>

#include "utils/log_adapter.h"
19
20 namespace mindspore {
21 namespace device {
ClearRuntimeResource()22 void KernelRuntimeManager::ClearRuntimeResource() {
23 std::lock_guard<std::mutex> guard(lock_);
24 for (auto &iter : runtime_map_) {
25 MS_LOG(INFO) << "Release device " << iter.first;
26 MS_EXCEPTION_IF_NULL(iter.second);
27 iter.second->ReleaseDeviceRes();
28 }
29 runtime_map_.clear();
30 }
31
ClearGraphResource(uint32_t graph_id)32 void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) {
33 std::lock_guard<std::mutex> guard(lock_);
34 for (auto &iter : runtime_map_) {
35 MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource";
36 if (!iter.second) {
37 MS_LOG(ERROR) << "Kernel runtime is nullptr";
38 continue;
39 }
40 iter.second->ClearGraphRuntimeResource(graph_id);
41 }
42 }
43
Instance()44 KernelRuntimeManager &KernelRuntimeManager::Instance() {
45 static KernelRuntimeManager instance{};
46 return instance;
47 }
48
Register(const std::string & device_name,KernelRuntimeCreator && runtime_creator) const49 void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) const {
50 if (runtime_creators_.find(device_name) == runtime_creators_.end()) {
51 (void)runtime_creators_.emplace(device_name, runtime_creator);
52 }
53 }
54
Clear()55 void KernelRuntimeManager::Clear() { runtime_creators_.clear(); }
56
GetDeviceKey(const std::string & device_name,uint32_t device_id) const57 std::string KernelRuntimeManager::GetDeviceKey(const std::string &device_name, uint32_t device_id) const {
58 std::string device_key = device_name + "_" + std::to_string(device_id);
59 return device_key;
60 }
61
GetSingleKernelRuntime(const std::string & device_name,uint32_t device_id)62 KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) {
63 auto runtime_key = GetDeviceKey(device_name, device_id);
64 auto runtime_iter = runtime_map_.find(runtime_key);
65 if (runtime_iter != runtime_map_.end()) {
66 return runtime_iter->second.get();
67 } else if (!runtime_map_.empty()) {
68 auto cur_runtime_key = runtime_map_.begin()->first;
69 auto find_pos = cur_runtime_key.rfind('_');
70 if (find_pos != std::string::npos) {
71 if (cur_runtime_key.size() > find_pos + 1) {
72 auto cur_device_id = cur_runtime_key.substr(find_pos + 1);
73 MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id
74 << ", set device id: " << device_id << " failed";
75 } else {
76 MS_LOG(EXCEPTION) << "Can't change device id in runtime, current runtime_key size error, set device id: "
77 << device_id << " failed";
78 }
79 }
80 }
81 return GetKernelRuntime(device_name, device_id);
82 }
83
GetKernelRuntime(const std::string & device_name,uint32_t device_id)84 KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) {
85 std::string runtime_key = GetDeviceKey(device_name, device_id);
86 std::lock_guard<std::mutex> guard(lock_);
87 auto runtime_iter = runtime_map_.find(runtime_key);
88 if (runtime_iter != runtime_map_.end()) {
89 return runtime_iter->second.get();
90 }
91 std::shared_ptr<KernelRuntime> kernel_runtime;
92 auto creator_iter = runtime_creators_.find(device_name);
93 if (creator_iter != runtime_creators_.end()) {
94 MS_EXCEPTION_IF_NULL(creator_iter->second);
95 kernel_runtime = (creator_iter->second)();
96 MS_EXCEPTION_IF_NULL(kernel_runtime);
97 kernel_runtime->set_device_id(device_id);
98 runtime_map_[runtime_key] = kernel_runtime;
99 } else {
100 MS_LOG(EXCEPTION) << "No kernel runtime creator for " << device_name << " with device id " << device_id;
101 }
102
103 return kernel_runtime.get();
104 }
105
GetCurrentKernelRuntime()106 KernelRuntime *KernelRuntimeManager::GetCurrentKernelRuntime() {
107 auto ms_context = MsContext::GetInstance();
108 MS_EXCEPTION_IF_NULL(ms_context);
109 uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
110 std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
111 return GetKernelRuntime(device_name, device_id);
112 }
113
ReleaseKernelRuntime(const std::string & device_name,uint32_t device_id)114 void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id) {
115 std::string runtime_key = GetDeviceKey(device_name, device_id);
116 std::lock_guard<std::mutex> guard(lock_);
117 auto runtime_iter = runtime_map_.find(runtime_key);
118 if (runtime_iter == runtime_map_.end()) {
119 return;
120 }
121 auto runtime = runtime_iter->second.get();
122 if (runtime == nullptr) {
123 return;
124 }
125 runtime->ReleaseDeviceRes();
126 (void)runtime_map_.erase(runtime_iter);
127 }
128
WaitTaskFinishOnDevice() const129 void KernelRuntimeManager::WaitTaskFinishOnDevice() const {
130 for (const auto &iter : runtime_map_) {
131 auto kernel_runtime = iter.second;
132 try {
133 if (kernel_runtime != nullptr && !kernel_runtime->SyncStream()) {
134 MS_LOG(ERROR) << "SyncStream failed";
135 return;
136 }
137 } catch (const std::exception &ex) {
138 MS_LOG(ERROR) << "SyncStream failed, exception:" << ex.what();
139 return;
140 }
141 }
142 }
143 } // namespace device
144 } // namespace mindspore
145