1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "context_pool.h"
17
#include <mutex>
#include <set>
#include <singleton.h>
#include <unordered_map>
#include <vector>
22
23 #include <openssl/rand.h>
24
25 #include "iam_logger.h"
26 #include "iam_para2str.h"
27
28 #define LOG_LABEL UserIam::Common::LABEL_USER_AUTH_SA
29
30 namespace OHOS {
31 namespace UserIam {
32 namespace UserAuth {
33 class ContextPoolImpl final : public ContextPool, public Singleton<ContextPoolImpl> {
34 public:
35 bool Insert(const std::shared_ptr<Context> &context) override;
36 bool Delete(uint64_t contextId) override;
37 void CancelAll() const override;
38 std::weak_ptr<Context> Select(uint64_t contextId) const override;
39 std::vector<std::weak_ptr<Context>> Select(ContextType contextType) const override;
40 std::shared_ptr<ScheduleNode> SelectScheduleNodeByScheduleId(uint64_t scheduleId) override;
41 bool RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
42 bool DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
43
44 private:
45 mutable std::recursive_mutex poolMutex_;
46 std::unordered_map<uint64_t, std::shared_ptr<Context>> contextMap_;
47 std::set<std::shared_ptr<ContextPoolListener>> listenerSet_;
48 };
49
Insert(const std::shared_ptr<Context> & context)50 bool ContextPoolImpl::Insert(const std::shared_ptr<Context> &context)
51 {
52 if (context == nullptr) {
53 IAM_LOGE("context is nullptr");
54 return false;
55 }
56 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
57 uint64_t contextId = context->GetContextId();
58 auto result = contextMap_.try_emplace(contextId, context);
59 if (!result.second) {
60 return false;
61 }
62 for (const auto &listener : listenerSet_) {
63 if (listener != nullptr) {
64 listener->OnContextPoolInsert(context);
65 }
66 }
67 return true;
68 }
69
Delete(uint64_t contextId)70 bool ContextPoolImpl::Delete(uint64_t contextId)
71 {
72 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
73 auto iter = contextMap_.find(contextId);
74 if (iter == contextMap_.end()) {
75 IAM_LOGE("context not found");
76 return false;
77 }
78 auto tempContext = iter->second;
79 contextMap_.erase(iter);
80 for (const auto &listener : listenerSet_) {
81 if (listener != nullptr) {
82 listener->OnContextPoolDelete(tempContext);
83 }
84 }
85 return true;
86 }
87
CancelAll() const88 void ContextPoolImpl::CancelAll() const
89 {
90 IAM_LOGI("start");
91 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
92 for (const auto &context : contextMap_) {
93 if (context.second == nullptr) {
94 continue;
95 }
96 IAM_LOGI("cancel context %{public}s", GET_MASKED_STRING(context.second->GetContextId()).c_str());
97 if (!context.second->Stop()) {
98 IAM_LOGE("cancel context %{public}s fail", GET_MASKED_STRING(context.second->GetContextId()).c_str());
99 }
100 }
101 }
102
Select(uint64_t contextId) const103 std::weak_ptr<Context> ContextPoolImpl::Select(uint64_t contextId) const
104 {
105 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
106 std::weak_ptr<Context> result;
107 auto iter = contextMap_.find(contextId);
108 if (iter != contextMap_.end()) {
109 result = iter->second;
110 }
111 return result;
112 }
113
Select(ContextType contextType) const114 std::vector<std::weak_ptr<Context>> ContextPoolImpl::Select(ContextType contextType) const
115 {
116 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
117 std::vector<std::weak_ptr<Context>> result;
118 for (const auto &context : contextMap_) {
119 if (context.second == nullptr) {
120 continue;
121 }
122 if (context.second->GetContextType() == contextType) {
123 result.emplace_back(context.second);
124 }
125 }
126 return result;
127 }
128
SelectScheduleNodeByScheduleId(uint64_t scheduleId)129 std::shared_ptr<ScheduleNode> ContextPoolImpl::SelectScheduleNodeByScheduleId(uint64_t scheduleId)
130 {
131 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
132 for (const auto &context : contextMap_) {
133 if (context.second == nullptr) {
134 continue;
135 }
136 auto node = context.second->GetScheduleNode(scheduleId);
137 if (node != nullptr) {
138 return node;
139 }
140 }
141 return nullptr;
142 }
143
RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)144 bool ContextPoolImpl::RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
145 {
146 if (listener == nullptr) {
147 IAM_LOGE("listener is nullptr");
148 return false;
149 }
150 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
151 listenerSet_.insert(listener);
152 return true;
153 }
154
DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)155 bool ContextPoolImpl::DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
156 {
157 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
158 return listenerSet_.erase(listener) == 1;
159 }
160
Instance()161 ContextPool &ContextPool::Instance()
162 {
163 return ContextPoolImpl::GetInstance();
164 }
165
GetNewContextId()166 uint64_t ContextPool::GetNewContextId()
167 {
168 static constexpr uint32_t MAX_TRY_TIMES = 10;
169 static std::mutex mutex;
170 std::lock_guard<std::mutex> lock(mutex);
171 uint64_t contextId = 0;
172 unsigned char *contextIdPtr = static_cast<unsigned char *>(static_cast<void *>(&contextId));
173 for (uint32_t i = 0; i < MAX_TRY_TIMES; i++) {
174 RAND_bytes(contextIdPtr, sizeof(uint64_t));
175 if (contextId == 0 || ContextPool::Instance().Select(contextId).lock() != nullptr) {
176 IAM_LOGE("invalid or duplicate context id");
177 continue;
178 }
179 }
180 return contextId;
181 }
182 } // namespace UserAuth
183 } // namespace UserIam
184 } // namespace OHOS
185