• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "runtime/device/kernel_runtime_manager.h"

#include <mutex>
#include <string>
#include <utility>

#include "utils/log_adapter.h"
19 
20 namespace mindspore {
21 namespace device {
ClearRuntimeResource()22 void KernelRuntimeManager::ClearRuntimeResource() {
23   std::lock_guard<std::mutex> guard(lock_);
24   for (auto &iter : runtime_map_) {
25     MS_LOG(INFO) << "Release device " << iter.first;
26     MS_EXCEPTION_IF_NULL(iter.second);
27     iter.second->ReleaseDeviceRes();
28   }
29   runtime_map_.clear();
30 }
31 
ClearGraphResource(uint32_t graph_id)32 void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) {
33   std::lock_guard<std::mutex> guard(lock_);
34   for (auto &iter : runtime_map_) {
35     MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource";
36     if (!iter.second) {
37       MS_LOG(ERROR) << "Kernel runtime is nullptr";
38       continue;
39     }
40     iter.second->ClearGraphRuntimeResource(graph_id);
41   }
42 }
43 
Instance()44 KernelRuntimeManager &KernelRuntimeManager::Instance() {
45   static KernelRuntimeManager instance{};
46   return instance;
47 }
48 
Register(const std::string & device_name,KernelRuntimeCreator && runtime_creator) const49 void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) const {
50   if (runtime_creators_.find(device_name) == runtime_creators_.end()) {
51     (void)runtime_creators_.emplace(device_name, runtime_creator);
52   }
53 }
54 
Clear()55 void KernelRuntimeManager::Clear() { runtime_creators_.clear(); }
56 
GetDeviceKey(const std::string & device_name,uint32_t device_id) const57 std::string KernelRuntimeManager::GetDeviceKey(const std::string &device_name, uint32_t device_id) const {
58   std::string device_key = device_name + "_" + std::to_string(device_id);
59   return device_key;
60 }
61 
GetSingleKernelRuntime(const std::string & device_name,uint32_t device_id)62 KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) {
63   auto runtime_key = GetDeviceKey(device_name, device_id);
64   auto runtime_iter = runtime_map_.find(runtime_key);
65   if (runtime_iter != runtime_map_.end()) {
66     return runtime_iter->second.get();
67   } else if (!runtime_map_.empty()) {
68     auto cur_runtime_key = runtime_map_.begin()->first;
69     auto find_pos = cur_runtime_key.rfind('_');
70     if (find_pos != std::string::npos) {
71       if (cur_runtime_key.size() > find_pos + 1) {
72         auto cur_device_id = cur_runtime_key.substr(find_pos + 1);
73         MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id
74                           << ", set device id: " << device_id << " failed";
75       } else {
76         MS_LOG(EXCEPTION) << "Can't change device id in runtime, current runtime_key size error, set device id: "
77                           << device_id << " failed";
78       }
79     }
80   }
81   return GetKernelRuntime(device_name, device_id);
82 }
83 
GetKernelRuntime(const std::string & device_name,uint32_t device_id)84 KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) {
85   std::string runtime_key = GetDeviceKey(device_name, device_id);
86   std::lock_guard<std::mutex> guard(lock_);
87   auto runtime_iter = runtime_map_.find(runtime_key);
88   if (runtime_iter != runtime_map_.end()) {
89     return runtime_iter->second.get();
90   }
91   std::shared_ptr<KernelRuntime> kernel_runtime;
92   auto creator_iter = runtime_creators_.find(device_name);
93   if (creator_iter != runtime_creators_.end()) {
94     MS_EXCEPTION_IF_NULL(creator_iter->second);
95     kernel_runtime = (creator_iter->second)();
96     MS_EXCEPTION_IF_NULL(kernel_runtime);
97     kernel_runtime->set_device_id(device_id);
98     runtime_map_[runtime_key] = kernel_runtime;
99   } else {
100     MS_LOG(EXCEPTION) << "No kernel runtime creator for " << device_name << " with device id " << device_id;
101   }
102 
103   return kernel_runtime.get();
104 }
105 
GetCurrentKernelRuntime()106 KernelRuntime *KernelRuntimeManager::GetCurrentKernelRuntime() {
107   auto ms_context = MsContext::GetInstance();
108   MS_EXCEPTION_IF_NULL(ms_context);
109   uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
110   std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
111   return GetKernelRuntime(device_name, device_id);
112 }
113 
ReleaseKernelRuntime(const std::string & device_name,uint32_t device_id)114 void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id) {
115   std::string runtime_key = GetDeviceKey(device_name, device_id);
116   std::lock_guard<std::mutex> guard(lock_);
117   auto runtime_iter = runtime_map_.find(runtime_key);
118   if (runtime_iter == runtime_map_.end()) {
119     return;
120   }
121   auto runtime = runtime_iter->second.get();
122   if (runtime == nullptr) {
123     return;
124   }
125   runtime->ReleaseDeviceRes();
126   (void)runtime_map_.erase(runtime_iter);
127 }
128 
WaitTaskFinishOnDevice() const129 void KernelRuntimeManager::WaitTaskFinishOnDevice() const {
130   for (const auto &iter : runtime_map_) {
131     auto kernel_runtime = iter.second;
132     try {
133       if (kernel_runtime != nullptr && !kernel_runtime->SyncStream()) {
134         MS_LOG(ERROR) << "SyncStream failed";
135         return;
136       }
137     } catch (const std::exception &ex) {
138       MS_LOG(ERROR) << "SyncStream failed, exception:" << ex.what();
139       return;
140     }
141   }
142 }
143 }  // namespace device
144 }  // namespace mindspore
145