1 /*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "context_pool.h"
17
18 #include <fcntl.h>
19 #include <mutex>
20 #include <set>
21 #include <singleton.h>
22 #include <unordered_map>
23
24 #include "system_ability_definition.h"
25
26 #include "iam_check.h"
27 #include "iam_logger.h"
28 #include "iam_para2str.h"
29 #include "system_ability_listener.h"
30
31 #define LOG_TAG "USER_AUTH_SA"
32
33 namespace OHOS {
34 namespace UserIam {
35 namespace UserAuth {
36 namespace {
// Upper bound on the number of contexts the pool holds at once; Insert rejects new entries at this size.
constexpr uint32_t MAX_CONTEXT_NUM = 100;
GenerateRand(uint8_t * data,size_t len)38 bool GenerateRand(uint8_t *data, size_t len)
39 {
40 FILE *fp = fopen("/dev/random", "rb");
41 if (fp == nullptr) {
42 IAM_LOGE("fopen read file fail");
43 return false;
44 }
45 size_t readLen = fread(data, sizeof(uint8_t), len, fp);
46 (void)fclose(fp);
47 if (readLen != len) {
48 IAM_LOGE("fread file failed");
49 return false;
50 }
51 return true;
52 }
53 }
54 class ContextPoolImpl final : public ContextPool, public Singleton<ContextPoolImpl> {
55 public:
56 bool Insert(const std::shared_ptr<Context> &context) override;
57 bool Delete(uint64_t contextId) override;
58 void CancelAll() const override;
59 std::weak_ptr<Context> Select(uint64_t contextId) const override;
60 std::vector<std::weak_ptr<Context>> Select(ContextType contextType) const override;
61 std::shared_ptr<ScheduleNode> SelectScheduleNodeByScheduleId(uint64_t scheduleId) override;
62 bool RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
63 bool DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
64 void StartSubscribeOsAccountSaStatus() override;
65
66 private:
67 void CheckPreemptContext(const std::shared_ptr<Context> &context);
68 mutable std::recursive_mutex poolMutex_;
69 std::unordered_map<uint64_t, std::shared_ptr<Context>> contextMap_;
70 std::set<std::shared_ptr<ContextPoolListener>> listenerSet_;
71 sptr<SystemAbilityListener> accountSaStatusListener_ {nullptr};
72 };
73
CheckPreemptContext(const std::shared_ptr<Context> & context)74 void ContextPoolImpl::CheckPreemptContext(const std::shared_ptr<Context> &context)
75 {
76 if (context->GetContextType() != ContextType::CONTEXT_SIMPLE_AUTH) {
77 return;
78 }
79 for (auto iter = contextMap_.begin(); iter != contextMap_.end(); iter++) {
80 if (iter->second == nullptr) {
81 IAM_LOGE("context is nullptr");
82 break;
83 }
84 if (iter->second->GetCallerName() == context->GetCallerName() &&
85 iter->second->GetAuthType() == context->GetAuthType() &&
86 iter->second->GetUserId() == context->GetUserId()) {
87 IAM_LOGE("contextId:%{public}hx is preempted, newContextId:%{public}hx, mapSize:%{public}zu,"
88 "callerName:%{public}s, userId:%{public}d, authType:%{public}d", static_cast<uint16_t>(iter->first),
89 static_cast<uint16_t>(context->GetContextId()), contextMap_.size(), context->GetCallerName().c_str(),
90 context->GetUserId(), context->GetAuthType());
91 iter->second->Stop();
92 break;
93 }
94 }
95 }
96
Insert(const std::shared_ptr<Context> & context)97 bool ContextPoolImpl::Insert(const std::shared_ptr<Context> &context)
98 {
99 if (context == nullptr) {
100 IAM_LOGE("context is nullptr");
101 return false;
102 }
103 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
104 if (contextMap_.size() >= MAX_CONTEXT_NUM) {
105 IAM_LOGE("context pool is full");
106 return false;
107 }
108 CheckPreemptContext(context);
109 uint64_t contextId = context->GetContextId();
110 auto result = contextMap_.try_emplace(contextId, context);
111 if (!result.second) {
112 return false;
113 }
114 for (const auto &listener : listenerSet_) {
115 if (listener != nullptr) {
116 listener->OnContextPoolInsert(context);
117 }
118 }
119 return true;
120 }
121
Delete(uint64_t contextId)122 bool ContextPoolImpl::Delete(uint64_t contextId)
123 {
124 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
125 auto iter = contextMap_.find(contextId);
126 if (iter == contextMap_.end()) {
127 IAM_LOGE("context not found");
128 return false;
129 }
130 auto tempContext = iter->second;
131 contextMap_.erase(iter);
132 for (const auto &listener : listenerSet_) {
133 if (listener != nullptr) {
134 listener->OnContextPoolDelete(tempContext);
135 }
136 }
137 return true;
138 }
139
CancelAll() const140 void ContextPoolImpl::CancelAll() const
141 {
142 IAM_LOGI("start");
143 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
144 for (const auto &context : contextMap_) {
145 if (context.second == nullptr) {
146 continue;
147 }
148 IAM_LOGI("cancel context %{public}s", GET_MASKED_STRING(context.second->GetContextId()).c_str());
149 if (!context.second->Stop()) {
150 IAM_LOGE("cancel context %{public}s fail", GET_MASKED_STRING(context.second->GetContextId()).c_str());
151 }
152 }
153 }
154
Select(uint64_t contextId) const155 std::weak_ptr<Context> ContextPoolImpl::Select(uint64_t contextId) const
156 {
157 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
158 std::weak_ptr<Context> result;
159 auto iter = contextMap_.find(contextId);
160 if (iter != contextMap_.end()) {
161 result = iter->second;
162 }
163 return result;
164 }
165
Select(ContextType contextType) const166 std::vector<std::weak_ptr<Context>> ContextPoolImpl::Select(ContextType contextType) const
167 {
168 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
169 std::vector<std::weak_ptr<Context>> result;
170 for (const auto &context : contextMap_) {
171 if (context.second == nullptr) {
172 continue;
173 }
174 if (context.second->GetContextType() == contextType) {
175 result.emplace_back(context.second);
176 }
177 }
178 return result;
179 }
180
SelectScheduleNodeByScheduleId(uint64_t scheduleId)181 std::shared_ptr<ScheduleNode> ContextPoolImpl::SelectScheduleNodeByScheduleId(uint64_t scheduleId)
182 {
183 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
184 for (const auto &context : contextMap_) {
185 if (context.second == nullptr) {
186 continue;
187 }
188 auto node = context.second->GetScheduleNode(scheduleId);
189 if (node != nullptr) {
190 return node;
191 }
192 }
193
194 IAM_LOGE("not found");
195 return nullptr;
196 }
197
RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)198 bool ContextPoolImpl::RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
199 {
200 if (listener == nullptr) {
201 IAM_LOGE("listener is nullptr");
202 return false;
203 }
204 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
205 listenerSet_.insert(listener);
206 return true;
207 }
208
DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)209 bool ContextPoolImpl::DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
210 {
211 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
212 return listenerSet_.erase(listener) == 1;
213 }
214
StartSubscribeOsAccountSaStatus()215 void ContextPoolImpl::StartSubscribeOsAccountSaStatus()
216 {
217 IAM_LOGI("start");
218 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
219 if (accountSaStatusListener_ != nullptr) {
220 return;
221 }
222 accountSaStatusListener_ = SystemAbilityListener::Subscribe(
223 "OsAccountService", SUBSYS_ACCOUNT_SYS_ABILITY_ID_BEGIN,
224 []() {},
225 []() { ContextPool::Instance().CancelAll(); });
226 IF_FALSE_LOGE_AND_RETURN(accountSaStatusListener_ != nullptr);
227 }
228
Instance()229 ContextPool &ContextPool::Instance()
230 {
231 return ContextPoolImpl::GetInstance();
232 }
233
GetNewContextId()234 uint64_t ContextPool::GetNewContextId()
235 {
236 static constexpr uint32_t MAX_TRY_TIMES = 10;
237 static std::mutex mutex;
238 std::lock_guard<std::mutex> lock(mutex);
239 uint64_t contextId = 0;
240 unsigned char *contextIdPtr = static_cast<unsigned char *>(static_cast<void *>(&contextId));
241 for (uint32_t i = 0; i < MAX_TRY_TIMES; i++) {
242 bool genRandRet = GenerateRand(contextIdPtr, sizeof(uint64_t));
243 if (!genRandRet) {
244 IAM_LOGE("generate rand fail");
245 return 0;
246 }
247 if (contextId == 0 || contextId == REUSE_AUTH_RESULT_CONTEXT_ID ||
248 ContextPool::Instance().Select(contextId).lock() != nullptr) {
249 IAM_LOGE("invalid or duplicate context id");
250 continue;
251 }
252 break;
253 }
254 return contextId;
255 }
256 } // namespace UserAuth
257 } // namespace UserIam
258 } // namespace OHOS
259