1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "backend_manager.h"
17
18 #include <algorithm>
19 #include "cpp_type.h"
20
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
~BackendManager()23 BackendManager::~BackendManager()
24 {
25 m_backends.clear();
26 m_backendNames.clear();
27 m_backendIDs.clear();
28 m_backendIDGroup.clear();
29 }
30
GetAllBackendsID()31 const std::vector<size_t>& BackendManager::GetAllBackendsID()
32 {
33 const std::lock_guard<std::mutex> lock(m_mtx);
34 return m_backendIDs;
35 }
36
GetBackend(size_t backendID)37 std::shared_ptr<Backend> BackendManager::GetBackend(size_t backendID)
38 {
39 const std::lock_guard<std::mutex> lock(m_mtx);
40 if (m_backends.empty()) {
41 LOGE("[BackendManager] GetBackend failed, there is no registered backend can be used.");
42 return nullptr;
43 }
44
45 auto iter = m_backends.begin();
46 if (backendID == static_cast<size_t>(0)) {
47 LOGI("[BackendManager] the backendID is 0, default return 1st backend.");
48 return iter->second;
49 }
50
51 iter = m_backends.find(backendID);
52 if (iter == m_backends.end()) {
53 LOGE("[BackendManager] GetBackend failed, not find backendId=%{public}zu", backendID);
54 return nullptr;
55 }
56
57 return iter->second;
58 }
59
GetBackendName(size_t backendID)60 const std::string& BackendManager::GetBackendName(size_t backendID)
61 {
62 const std::lock_guard<std::mutex> lock(m_mtx);
63 if (m_backendNames.empty()) {
64 LOGE("[BackendManager] GetBackendName failed, there is no registered backend can be used.");
65 return m_emptyBackendName;
66 }
67
68 auto iter = m_backendNames.begin();
69 if (backendID == static_cast<size_t>(0)) {
70 LOGI("[BackendManager] the backendID is 0, default return 1st backend.");
71 } else {
72 iter = m_backendNames.find(backendID);
73 }
74
75 if (iter == m_backendNames.end()) {
76 LOGE("[BackendManager] GetBackendName failed, backendID %{public}zu is not registered.", backendID);
77 return m_emptyBackendName;
78 }
79
80 return iter->second;
81 }
82
RegisterBackend(const std::string & backendName,std::function<std::shared_ptr<Backend> ()> creator)83 OH_NN_ReturnCode BackendManager::RegisterBackend(
84 const std::string& backendName, std::function<std::shared_ptr<Backend>()> creator)
85 {
86 auto regBackend = creator();
87 if (regBackend == nullptr) {
88 LOGE("[BackendManager] RegisterBackend failed, fail to create backend.");
89 return OH_NN_FAILED;
90 }
91
92 if (!IsValidBackend(regBackend)) {
93 LOGE("[BackendManager] RegisterBackend failed, backend is not available.");
94 return OH_NN_UNAVAILABLE_DEVICE;
95 }
96
97 size_t backendID = regBackend->GetBackendID();
98
99 const std::lock_guard<std::mutex> lock(m_mtx);
100 auto iter = std::find(m_backendIDs.begin(), m_backendIDs.end(), backendID);
101 if (iter != m_backendIDs.end()) {
102 LOGE("[BackendManager] RegisterBackend failed, backend already exists, cannot register again. "
103 "backendID=%{public}zu", backendID);
104 return OH_NN_FAILED;
105 }
106
107 std::string tmpBackendName;
108 auto ret = regBackend->GetBackendName(tmpBackendName);
109 if (ret != OH_NN_SUCCESS) {
110 LOGE("[BackendManager] RegisterBackend failed, fail to get backend name.");
111 return OH_NN_FAILED;
112 }
113 m_backends.emplace(backendID, regBackend);
114 m_backendIDs.emplace_back(backendID);
115 m_backendNames.emplace(backendID, tmpBackendName);
116 if (m_backendIDGroup.find(backendName) == m_backendIDGroup.end()) {
117 std::vector<size_t> backendIDsTmp {backendID};
118 m_backendIDGroup.emplace(backendName, backendIDsTmp);
119 } else {
120 m_backendIDGroup[backendName].emplace_back(backendID);
121 }
122 return OH_NN_SUCCESS;
123 }
124
RemoveBackend(const std::string & backendName)125 void BackendManager::RemoveBackend(const std::string& backendName)
126 {
127 LOGI("[RemoveBackend] start remove backend for %{public}s.", backendName.c_str());
128 const std::lock_guard<std::mutex> lock(m_mtx);
129 if (m_backendIDGroup.find(backendName) == m_backendIDGroup.end()) {
130 LOGI("[RemoveBackend] No need to remove backend for %{public}s.", backendName.c_str());
131 return;
132 }
133
134 auto backendIDs = m_backendIDGroup[backendName];
135 for (auto backendID : backendIDs) {
136 if (m_backends.find(backendID) != m_backends.end()) {
137 m_backends.erase(backendID);
138 }
139 auto iter = std::find(m_backendIDs.begin(), m_backendIDs.end(), backendID);
140 if (iter != m_backendIDs.end()) {
141 m_backendIDs.erase(iter);
142 }
143 if (m_backendNames.find(backendID) != m_backendNames.end()) {
144 m_backendNames.erase(backendID);
145 }
146 LOGI("[RemoveBackend] remove backendID[%{public}zu] for %{public}s success.", backendID, backendName.c_str());
147 }
148 m_backendIDGroup.erase(backendName);
149 }
150
IsValidBackend(std::shared_ptr<Backend> backend) const151 bool BackendManager::IsValidBackend(std::shared_ptr<Backend> backend) const
152 {
153 DeviceStatus status = UNKNOWN;
154
155 OH_NN_ReturnCode ret = backend->GetBackendStatus(status);
156 if (ret != OH_NN_SUCCESS || status == UNKNOWN || status == OFFLINE) {
157 return false;
158 }
159
160 return true;
161 }
} // namespace NeuralNetworkRuntime
} // namespace OHOS
164