1 /*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "context_pool.h"
17
#include <cerrno>
#include <fcntl.h>
#include <mutex>
#include <set>
#include <singleton.h>
#include <unistd.h>
#include <unordered_map>
#include <vector>
23
24 #include "iam_logger.h"
25 #include "iam_para2str.h"
26 #include "iam_check.h"
27
28 #define LOG_TAG "USER_AUTH_SA"
29
30 namespace OHOS {
31 namespace UserIam {
32 namespace UserAuth {
33 namespace {
34 const uint32_t MAX_CONTEXT_NUM = 100;
GenerateRand(uint8_t * data,size_t len)35 bool GenerateRand(uint8_t *data, size_t len)
36 {
37 int fd = open("/dev/random", O_RDONLY);
38 if (fd < 0) {
39 IAM_LOGE("open read file fail");
40 return false;
41 }
42 ssize_t readLen = read(fd, data, len);
43 close(fd);
44 if (readLen < 0) {
45 IAM_LOGE("read file failed");
46 return false;
47 }
48 return static_cast<size_t>(readLen) == len;
49 }
50 }
51 class ContextPoolImpl final : public ContextPool, public Singleton<ContextPoolImpl> {
52 public:
53 bool Insert(const std::shared_ptr<Context> &context) override;
54 bool Delete(uint64_t contextId) override;
55 void CancelAll() const override;
56 std::weak_ptr<Context> Select(uint64_t contextId) const override;
57 std::vector<std::weak_ptr<Context>> Select(ContextType contextType) const override;
58 std::shared_ptr<ScheduleNode> SelectScheduleNodeByScheduleId(uint64_t scheduleId) override;
59 bool RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
60 bool DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
61
62 private:
63 mutable std::recursive_mutex poolMutex_;
64 std::unordered_map<uint64_t, std::shared_ptr<Context>> contextMap_;
65 std::set<std::shared_ptr<ContextPoolListener>> listenerSet_;
66 };
67
Insert(const std::shared_ptr<Context> & context)68 bool ContextPoolImpl::Insert(const std::shared_ptr<Context> &context)
69 {
70 if (context == nullptr) {
71 IAM_LOGE("context is nullptr");
72 return false;
73 }
74 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
75 if (contextMap_.size() >= MAX_CONTEXT_NUM) {
76 IAM_LOGE("context pool is full");
77 return false;
78 }
79 uint64_t contextId = context->GetContextId();
80 auto result = contextMap_.try_emplace(contextId, context);
81 if (!result.second) {
82 return false;
83 }
84 for (const auto &listener : listenerSet_) {
85 if (listener != nullptr) {
86 listener->OnContextPoolInsert(context);
87 }
88 }
89 return true;
90 }
91
Delete(uint64_t contextId)92 bool ContextPoolImpl::Delete(uint64_t contextId)
93 {
94 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
95 auto iter = contextMap_.find(contextId);
96 if (iter == contextMap_.end()) {
97 IAM_LOGE("context not found");
98 return false;
99 }
100 auto tempContext = iter->second;
101 contextMap_.erase(iter);
102 for (const auto &listener : listenerSet_) {
103 if (listener != nullptr) {
104 listener->OnContextPoolDelete(tempContext);
105 }
106 }
107 return true;
108 }
109
CancelAll() const110 void ContextPoolImpl::CancelAll() const
111 {
112 IAM_LOGI("start");
113 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
114 for (const auto &context : contextMap_) {
115 if (context.second == nullptr) {
116 continue;
117 }
118 IAM_LOGI("cancel context %{public}s", GET_MASKED_STRING(context.second->GetContextId()).c_str());
119 if (!context.second->Stop()) {
120 IAM_LOGE("cancel context %{public}s fail", GET_MASKED_STRING(context.second->GetContextId()).c_str());
121 }
122 }
123 }
124
Select(uint64_t contextId) const125 std::weak_ptr<Context> ContextPoolImpl::Select(uint64_t contextId) const
126 {
127 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
128 std::weak_ptr<Context> result;
129 auto iter = contextMap_.find(contextId);
130 if (iter != contextMap_.end()) {
131 result = iter->second;
132 }
133 return result;
134 }
135
Select(ContextType contextType) const136 std::vector<std::weak_ptr<Context>> ContextPoolImpl::Select(ContextType contextType) const
137 {
138 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
139 std::vector<std::weak_ptr<Context>> result;
140 for (const auto &context : contextMap_) {
141 if (context.second == nullptr) {
142 continue;
143 }
144 if (context.second->GetContextType() == contextType) {
145 result.emplace_back(context.second);
146 }
147 }
148 return result;
149 }
150
SelectScheduleNodeByScheduleId(uint64_t scheduleId)151 std::shared_ptr<ScheduleNode> ContextPoolImpl::SelectScheduleNodeByScheduleId(uint64_t scheduleId)
152 {
153 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
154 for (const auto &context : contextMap_) {
155 if (context.second == nullptr) {
156 continue;
157 }
158 auto node = context.second->GetScheduleNode(scheduleId);
159 if (node != nullptr) {
160 return node;
161 }
162 }
163
164 IAM_LOGE("not found");
165 return nullptr;
166 }
167
RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)168 bool ContextPoolImpl::RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
169 {
170 if (listener == nullptr) {
171 IAM_LOGE("listener is nullptr");
172 return false;
173 }
174 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
175 listenerSet_.insert(listener);
176 return true;
177 }
178
DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)179 bool ContextPoolImpl::DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
180 {
181 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
182 return listenerSet_.erase(listener) == 1;
183 }
184
Instance()185 ContextPool &ContextPool::Instance()
186 {
187 return ContextPoolImpl::GetInstance();
188 }
189
GetNewContextId()190 uint64_t ContextPool::GetNewContextId()
191 {
192 static constexpr uint32_t MAX_TRY_TIMES = 10;
193 static std::mutex mutex;
194 std::lock_guard<std::mutex> lock(mutex);
195 uint64_t contextId = 0;
196 unsigned char *contextIdPtr = static_cast<unsigned char *>(static_cast<void *>(&contextId));
197 for (uint32_t i = 0; i < MAX_TRY_TIMES; i++) {
198 bool genRandRet = GenerateRand(contextIdPtr, sizeof(uint64_t));
199 if (!genRandRet) {
200 IAM_LOGE("generate rand fail");
201 return 0;
202 }
203 if (contextId == 0 || contextId == REUSE_AUTH_RESULT_CONTEXT_ID ||
204 ContextPool::Instance().Select(contextId).lock() != nullptr) {
205 IAM_LOGE("invalid or duplicate context id");
206 continue;
207 }
208 break;
209 }
210 return contextId;
211 }
212 } // namespace UserAuth
213 } // namespace UserIam
214 } // namespace OHOS
215