• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "backend_manager.h"
17 
18 #include <algorithm>
19 #include "cpp_type.h"
20 
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
~BackendManager()23 BackendManager::~BackendManager()
24 {
25     m_backends.clear();
26     m_backendNames.clear();
27     m_backendIDs.clear();
28     m_backendIDGroup.clear();
29 }
30 
GetAllBackendsID()31 const std::vector<size_t>& BackendManager::GetAllBackendsID()
32 {
33     const std::lock_guard<std::mutex> lock(m_mtx);
34     return m_backendIDs;
35 }
36 
GetBackend(size_t backendID)37 std::shared_ptr<Backend> BackendManager::GetBackend(size_t backendID)
38 {
39     const std::lock_guard<std::mutex> lock(m_mtx);
40     if (m_backends.empty()) {
41         LOGE("[BackendManager] GetBackend failed, there is no registered backend can be used.");
42         return nullptr;
43     }
44 
45     auto iter = m_backends.begin();
46     if (backendID == static_cast<size_t>(0)) {
47         LOGI("[BackendManager] the backendID is 0, default return 1st backend.");
48         return iter->second;
49     }
50 
51     iter = m_backends.find(backendID);
52     if (iter == m_backends.end()) {
53         LOGE("[BackendManager] GetBackend failed, not find backendId=%{public}zu", backendID);
54         return nullptr;
55     }
56 
57     return iter->second;
58 }
59 
GetBackendName(size_t backendID)60 const std::string& BackendManager::GetBackendName(size_t backendID)
61 {
62     const std::lock_guard<std::mutex> lock(m_mtx);
63     if (m_backendNames.empty()) {
64         LOGE("[BackendManager] GetBackendName failed, there is no registered backend can be used.");
65         return m_emptyBackendName;
66     }
67 
68     auto iter = m_backendNames.begin();
69     if (backendID == static_cast<size_t>(0)) {
70         LOGI("[BackendManager] the backendID is 0, default return 1st backend.");
71     } else {
72         iter = m_backendNames.find(backendID);
73     }
74 
75     if (iter == m_backendNames.end()) {
76         LOGE("[BackendManager] GetBackendName failed, backendID %{public}zu is not registered.", backendID);
77         return m_emptyBackendName;
78     }
79 
80     return iter->second;
81 }
82 
RegisterBackend(const std::string & backendName,std::function<std::shared_ptr<Backend> ()> creator)83 OH_NN_ReturnCode BackendManager::RegisterBackend(
84     const std::string& backendName, std::function<std::shared_ptr<Backend>()> creator)
85 {
86     auto regBackend = creator();
87     if (regBackend == nullptr) {
88         LOGE("[BackendManager] RegisterBackend failed, fail to create backend.");
89         return OH_NN_FAILED;
90     }
91 
92     if (!IsValidBackend(regBackend)) {
93         LOGE("[BackendManager] RegisterBackend failed, backend is not available.");
94         return OH_NN_UNAVAILABLE_DEVICE;
95     }
96 
97     size_t backendID = regBackend->GetBackendID();
98 
99     const std::lock_guard<std::mutex> lock(m_mtx);
100     auto iter = std::find(m_backendIDs.begin(), m_backendIDs.end(), backendID);
101     if (iter != m_backendIDs.end()) {
102         LOGE("[BackendManager] RegisterBackend failed, backend already exists, cannot register again. "
103              "backendID=%{public}zu", backendID);
104         return OH_NN_FAILED;
105     }
106 
107     std::string tmpBackendName;
108     auto ret = regBackend->GetBackendName(tmpBackendName);
109     if (ret != OH_NN_SUCCESS) {
110         LOGE("[BackendManager] RegisterBackend failed, fail to get backend name.");
111         return OH_NN_FAILED;
112     }
113     m_backends.emplace(backendID, regBackend);
114     m_backendIDs.emplace_back(backendID);
115     m_backendNames.emplace(backendID, tmpBackendName);
116     if (m_backendIDGroup.find(backendName) == m_backendIDGroup.end()) {
117         std::vector<size_t> backendIDsTmp {backendID};
118         m_backendIDGroup.emplace(backendName, backendIDsTmp);
119     } else {
120         m_backendIDGroup[backendName].emplace_back(backendID);
121     }
122     return OH_NN_SUCCESS;
123 }
124 
RemoveBackend(const std::string & backendName)125 void BackendManager::RemoveBackend(const std::string& backendName)
126 {
127     LOGI("[RemoveBackend] start remove backend for %{public}s.", backendName.c_str());
128     const std::lock_guard<std::mutex> lock(m_mtx);
129     if (m_backendIDGroup.find(backendName) == m_backendIDGroup.end()) {
130         LOGI("[RemoveBackend] No need to remove backend for %{public}s.", backendName.c_str());
131         return;
132     }
133 
134     auto backendIDs = m_backendIDGroup[backendName];
135     for (auto backendID : backendIDs) {
136         if (m_backends.find(backendID) != m_backends.end()) {
137             m_backends.erase(backendID);
138         }
139         auto iter = std::find(m_backendIDs.begin(), m_backendIDs.end(), backendID);
140         if (iter != m_backendIDs.end()) {
141             m_backendIDs.erase(iter);
142         }
143         if (m_backendNames.find(backendID) != m_backendNames.end()) {
144             m_backendNames.erase(backendID);
145         }
146         LOGI("[RemoveBackend] remove backendID[%{public}zu] for %{public}s success.", backendID, backendName.c_str());
147     }
148     m_backendIDGroup.erase(backendName);
149 }
150 
IsValidBackend(std::shared_ptr<Backend> backend) const151 bool BackendManager::IsValidBackend(std::shared_ptr<Backend> backend) const
152 {
153     DeviceStatus status = UNKNOWN;
154 
155     OH_NN_ReturnCode ret = backend->GetBackendStatus(status);
156     if (ret != OH_NN_SUCCESS || status == UNKNOWN || status == OFFLINE) {
157         return false;
158     }
159 
160     return true;
161 }
} // NeuralNetworkRuntime
163 } // OHOS
164