• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "context_pool.h"
17 
18 #include <fcntl.h>
19 #include <mutex>
20 #include <set>
21 #include <singleton.h>
22 #include <unordered_map>
23 
24 #include "system_ability_definition.h"
25 
26 #include "iam_check.h"
27 #include "iam_logger.h"
28 #include "iam_para2str.h"
29 #include "system_ability_listener.h"
30 
31 #define LOG_TAG "USER_AUTH_SA"
32 
33 namespace OHOS {
34 namespace UserIam {
35 namespace UserAuth {
36 namespace {
37 const uint32_t MAX_CONTEXT_NUM = 100;
GenerateRand(uint8_t * data,size_t len)38 bool GenerateRand(uint8_t *data, size_t len)
39 {
40     FILE *fp = fopen("/dev/random", "rb");
41     if (fp == nullptr) {
42         IAM_LOGE("fopen read file fail");
43         return false;
44     }
45     size_t readLen = fread(data, sizeof(uint8_t), len, fp);
46     (void)fclose(fp);
47     if (readLen != len) {
48         IAM_LOGE("fread file failed");
49         return false;
50     }
51     return true;
52 }
53 }
54 class ContextPoolImpl final : public ContextPool, public Singleton<ContextPoolImpl> {
55 public:
56     bool Insert(const std::shared_ptr<Context> &context) override;
57     bool Delete(uint64_t contextId) override;
58     void CancelAll() const override;
59     std::weak_ptr<Context> Select(uint64_t contextId) const override;
60     std::vector<std::weak_ptr<Context>> Select(ContextType contextType) const override;
61     std::shared_ptr<ScheduleNode> SelectScheduleNodeByScheduleId(uint64_t scheduleId) override;
62     bool RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
63     bool DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
64     void StartSubscribeOsAccountSaStatus() override;
65 
66 private:
67     void CheckPreemptContext(const std::shared_ptr<Context> &context);
68     mutable std::recursive_mutex poolMutex_;
69     std::unordered_map<uint64_t, std::shared_ptr<Context>> contextMap_;
70     std::set<std::shared_ptr<ContextPoolListener>> listenerSet_;
71     sptr<SystemAbilityListener> accountSaStatusListener_ {nullptr};
72 };
73 
CheckPreemptContext(const std::shared_ptr<Context> & context)74 void ContextPoolImpl::CheckPreemptContext(const std::shared_ptr<Context> &context)
75 {
76     if (context->GetContextType() != ContextType::CONTEXT_SIMPLE_AUTH) {
77         return;
78     }
79     for (auto iter = contextMap_.begin(); iter != contextMap_.end(); iter++) {
80         if (iter->second == nullptr) {
81             IAM_LOGE("context is nullptr");
82             break;
83         }
84         if (iter->second->GetCallerName() == context->GetCallerName() &&
85             iter->second->GetAuthType() == context->GetAuthType() &&
86             iter->second->GetUserId() == context->GetUserId()) {
87             IAM_LOGE("contextId:%{public}hx is preempted, newContextId:%{public}hx, mapSize:%{public}zu,"
88                 "callerName:%{public}s, userId:%{public}d, authType:%{public}d", static_cast<uint16_t>(iter->first),
89                 static_cast<uint16_t>(context->GetContextId()), contextMap_.size(), context->GetCallerName().c_str(),
90                 context->GetUserId(), context->GetAuthType());
91             iter->second->Stop();
92             break;
93         }
94     }
95 }
96 
Insert(const std::shared_ptr<Context> & context)97 bool ContextPoolImpl::Insert(const std::shared_ptr<Context> &context)
98 {
99     if (context == nullptr) {
100         IAM_LOGE("context is nullptr");
101         return false;
102     }
103     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
104     if (contextMap_.size() >= MAX_CONTEXT_NUM) {
105         IAM_LOGE("context pool is full");
106         return false;
107     }
108     CheckPreemptContext(context);
109     uint64_t contextId = context->GetContextId();
110     auto result = contextMap_.try_emplace(contextId, context);
111     if (!result.second) {
112         return false;
113     }
114     for (const auto &listener : listenerSet_) {
115         if (listener != nullptr) {
116             listener->OnContextPoolInsert(context);
117         }
118     }
119     return true;
120 }
121 
Delete(uint64_t contextId)122 bool ContextPoolImpl::Delete(uint64_t contextId)
123 {
124     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
125     auto iter = contextMap_.find(contextId);
126     if (iter == contextMap_.end()) {
127         IAM_LOGE("context not found");
128         return false;
129     }
130     auto tempContext = iter->second;
131     contextMap_.erase(iter);
132     for (const auto &listener : listenerSet_) {
133         if (listener != nullptr) {
134             listener->OnContextPoolDelete(tempContext);
135         }
136     }
137     return true;
138 }
139 
CancelAll() const140 void ContextPoolImpl::CancelAll() const
141 {
142     IAM_LOGI("start");
143     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
144     for (const auto &context : contextMap_) {
145         if (context.second == nullptr) {
146             continue;
147         }
148         IAM_LOGI("cancel context %{public}s", GET_MASKED_STRING(context.second->GetContextId()).c_str());
149         if (!context.second->Stop()) {
150             IAM_LOGE("cancel context %{public}s fail", GET_MASKED_STRING(context.second->GetContextId()).c_str());
151         }
152     }
153 }
154 
Select(uint64_t contextId) const155 std::weak_ptr<Context> ContextPoolImpl::Select(uint64_t contextId) const
156 {
157     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
158     std::weak_ptr<Context> result;
159     auto iter = contextMap_.find(contextId);
160     if (iter != contextMap_.end()) {
161         result = iter->second;
162     }
163     return result;
164 }
165 
Select(ContextType contextType) const166 std::vector<std::weak_ptr<Context>> ContextPoolImpl::Select(ContextType contextType) const
167 {
168     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
169     std::vector<std::weak_ptr<Context>> result;
170     for (const auto &context : contextMap_) {
171         if (context.second == nullptr) {
172             continue;
173         }
174         if (context.second->GetContextType() == contextType) {
175             result.emplace_back(context.second);
176         }
177     }
178     return result;
179 }
180 
SelectScheduleNodeByScheduleId(uint64_t scheduleId)181 std::shared_ptr<ScheduleNode> ContextPoolImpl::SelectScheduleNodeByScheduleId(uint64_t scheduleId)
182 {
183     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
184     for (const auto &context : contextMap_) {
185         if (context.second == nullptr) {
186             continue;
187         }
188         auto node = context.second->GetScheduleNode(scheduleId);
189         if (node != nullptr) {
190             return node;
191         }
192     }
193 
194     IAM_LOGE("not found");
195     return nullptr;
196 }
197 
RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)198 bool ContextPoolImpl::RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
199 {
200     if (listener == nullptr) {
201         IAM_LOGE("listener is nullptr");
202         return false;
203     }
204     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
205     listenerSet_.insert(listener);
206     return true;
207 }
208 
DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)209 bool ContextPoolImpl::DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
210 {
211     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
212     return listenerSet_.erase(listener) == 1;
213 }
214 
StartSubscribeOsAccountSaStatus()215 void ContextPoolImpl::StartSubscribeOsAccountSaStatus()
216 {
217     IAM_LOGI("start");
218     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
219     if (accountSaStatusListener_ != nullptr) {
220         return;
221     }
222     accountSaStatusListener_ = SystemAbilityListener::Subscribe(
223         "OsAccountService", SUBSYS_ACCOUNT_SYS_ABILITY_ID_BEGIN,
224         []() {},
225         []() { ContextPool::Instance().CancelAll(); });
226     IF_FALSE_LOGE_AND_RETURN(accountSaStatusListener_ != nullptr);
227 }
228 
Instance()229 ContextPool &ContextPool::Instance()
230 {
231     return ContextPoolImpl::GetInstance();
232 }
233 
GetNewContextId()234 uint64_t ContextPool::GetNewContextId()
235 {
236     static constexpr uint32_t MAX_TRY_TIMES = 10;
237     static std::mutex mutex;
238     std::lock_guard<std::mutex> lock(mutex);
239     uint64_t contextId = 0;
240     unsigned char *contextIdPtr = static_cast<unsigned char *>(static_cast<void *>(&contextId));
241     for (uint32_t i = 0; i < MAX_TRY_TIMES; i++) {
242         bool genRandRet = GenerateRand(contextIdPtr, sizeof(uint64_t));
243         if (!genRandRet) {
244             IAM_LOGE("generate rand fail");
245             return 0;
246         }
247         if (contextId == 0 || contextId == REUSE_AUTH_RESULT_CONTEXT_ID ||
248             ContextPool::Instance().Select(contextId).lock() != nullptr) {
249             IAM_LOGE("invalid or duplicate context id");
250             continue;
251         }
252         break;
253     }
254     return contextId;
255 }
256 } // namespace UserAuth
257 } // namespace UserIam
258 } // namespace OHOS
259