1 /*
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "access_token_db.h"

#include <algorithm>
#include <atomic>
#include <cinttypes>
#include <mutex>

#include "accesstoken_common_log.h"
#include "access_token_error.h"
#include "access_token_open_callback.h"
#include "rdb_helper.h"
#include "time_util.h"
#include "token_field_const.h"
28
29 namespace OHOS {
30 namespace Security {
31 namespace AccessToken {
namespace {
// Database file name; InitRdb() appends it to DATABASE_PATH.
constexpr char DATABASE_NAME[] = "access_token.db";
// Service name handed to the RDB store configuration in InitRdb().
constexpr char ACCESSTOKEN_SERVICE_NAME[] = "accesstoken_service";
// Serializes lazy creation of the AccessTokenDb singleton in GetInstance().
std::recursive_mutex g_instanceMutex;
} // namespace
37
GetInstance()38 AccessTokenDb& AccessTokenDb::GetInstance()
39 {
40 static AccessTokenDb* instance = nullptr;
41 if (instance == nullptr) {
42 std::lock_guard<std::recursive_mutex> lock(g_instanceMutex);
43 if (instance == nullptr) {
44 AccessTokenDb* tmp = new AccessTokenDb();
45 instance = std::move(tmp);
46 }
47 }
48 return *instance;
49 }
50
AccessTokenDb()51 AccessTokenDb::AccessTokenDb()
52 {
53 InitRdb();
54 }
55
RestoreAndInsertIfCorrupt(const int32_t resultCode,int64_t & outInsertNum,const std::string & tableName,const std::vector<NativeRdb::ValuesBucket> & buckets,const std::shared_ptr<NativeRdb::RdbStore> & db)56 int32_t AccessTokenDb::RestoreAndInsertIfCorrupt(const int32_t resultCode, int64_t& outInsertNum,
57 const std::string& tableName, const std::vector<NativeRdb::ValuesBucket>& buckets,
58 const std::shared_ptr<NativeRdb::RdbStore>& db)
59 {
60 if (resultCode != NativeRdb::E_SQLITE_CORRUPT) {
61 return resultCode;
62 }
63
64 LOGW(ATM_DOMAIN, ATM_TAG, "Detech database corrupt, restore from backup!");
65 int32_t res = db->Restore("");
66 if (res != NativeRdb::E_OK) {
67 LOGE(ATM_DOMAIN, ATM_TAG, "Db restore failed, res is %{public}d.", res);
68 return res;
69 }
70 LOGI(ATM_DOMAIN, ATM_TAG, "Database restore success, try insert again!");
71
72 res = db->BatchInsert(outInsertNum, tableName, buckets);
73 if (res != NativeRdb::E_OK) {
74 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to batch insert into table %{public}s again, res is %{public}d.",
75 tableName.c_str(), res);
76 return res;
77 }
78
79 return 0;
80 }
81
InitRdb()82 void AccessTokenDb::InitRdb()
83 {
84 std::string dbPath = std::string(DATABASE_PATH) + std::string(DATABASE_NAME);
85 NativeRdb::RdbStoreConfig config(dbPath);
86 config.SetSecurityLevel(NativeRdb::SecurityLevel::S3);
87 config.SetAllowRebuild(true);
88 config.SetHaMode(NativeRdb::HAMode::MAIN_REPLICA); // Real-time dual-write backup database
89 config.SetServiceName(std::string(ACCESSTOKEN_SERVICE_NAME));
90 AccessTokenOpenCallback callback;
91 int32_t res = NativeRdb::E_OK;
92 // pragma user_version will done by rdb, they store path and db_ as pair in RdbStoreManager
93 db_ = NativeRdb::RdbHelper::GetRdbStore(config, DATABASE_VERSION_5, callback, res);
94 if ((res != NativeRdb::E_OK) || (db_ == nullptr)) {
95 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to init rdb, res is %{public}d.", res);
96 }
97 }
98
GetRdb()99 std::shared_ptr<NativeRdb::RdbStore> AccessTokenDb::GetRdb()
100 {
101 std::lock_guard<std::mutex> lock(dbLock_);
102 if (db_ == nullptr) {
103 InitRdb();
104 }
105 return db_;
106 }
107
AddValues(const AtmDataType type,const std::vector<GenericValues> & addValues)108 int32_t AccessTokenDb::AddValues(const AtmDataType type, const std::vector<GenericValues>& addValues)
109 {
110 std::string tableName;
111 AccessTokenDbUtil::GetTableNameByType(type, tableName);
112 if (tableName.empty()) {
113 LOGE(ATM_DOMAIN, ATM_TAG, "Table name is empty.");
114 return AccessTokenError::ERR_PARAM_INVALID;
115 }
116
117 // if nothing to insert, no need to call BatchInsert
118 if (addValues.empty()) {
119 return 0;
120 }
121
122 std::shared_ptr<NativeRdb::RdbStore> db = GetRdb();
123 if (db == nullptr) {
124 LOGE(ATM_DOMAIN, ATM_TAG, "db is nullptr.");
125 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
126 }
127
128 // fill buckets with addValues
129 int64_t outInsertNum = 0;
130 std::vector<NativeRdb::ValuesBucket> buckets;
131 AccessTokenDbUtil::ToRdbValueBuckets(addValues, buckets);
132
133 int32_t res = db->BatchInsert(outInsertNum, tableName, buckets);
134 if (res != NativeRdb::E_OK) {
135 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to batch insert into table %{public}s, res is %{public}d.",
136 tableName.c_str(), res);
137 int32_t result = RestoreAndInsertIfCorrupt(res, outInsertNum, tableName, buckets, db);
138 if (result != NativeRdb::E_OK) {
139 return result;
140 }
141 }
142 if (outInsertNum <= 0) { // rdb bug, adapt it
143 LOGE(ATM_DOMAIN, ATM_TAG, "Insert count %{public}" PRId64 " abnormal.", outInsertNum);
144 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
145 }
146
147 LOGI(ATM_DOMAIN, ATM_TAG, "Batch insert %{public}" PRId64 " records to table %{public}s.", outInsertNum,
148 tableName.c_str());
149
150 return 0;
151 }
152
RestoreAndDeleteIfCorrupt(const int32_t resultCode,int32_t & deletedRows,const NativeRdb::RdbPredicates & predicates,const std::shared_ptr<NativeRdb::RdbStore> & db)153 int32_t AccessTokenDb::RestoreAndDeleteIfCorrupt(const int32_t resultCode, int32_t& deletedRows,
154 const NativeRdb::RdbPredicates& predicates, const std::shared_ptr<NativeRdb::RdbStore>& db)
155 {
156 if (resultCode != NativeRdb::E_SQLITE_CORRUPT) {
157 return resultCode;
158 }
159
160 LOGW(ATM_DOMAIN, ATM_TAG, "Detech database corrupt, restore from backup!");
161 int32_t res = db->Restore("");
162 if (res != NativeRdb::E_OK) {
163 LOGE(ATM_DOMAIN, ATM_TAG, "Db restore failed, res is %{public}d.", res);
164 return res;
165 }
166 LOGI(ATM_DOMAIN, ATM_TAG, "Database restore success, try delete again!");
167
168 res = db->Delete(deletedRows, predicates);
169 if (res != NativeRdb::E_OK) {
170 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to delete record from table %{public}s again, res is %{public}d.",
171 predicates.GetTableName().c_str(), res);
172 return res;
173 }
174
175 return 0;
176 }
177
RemoveValues(const AtmDataType type,const GenericValues & conditionValue)178 int32_t AccessTokenDb::RemoveValues(const AtmDataType type, const GenericValues& conditionValue)
179 {
180 std::string tableName;
181 AccessTokenDbUtil::GetTableNameByType(type, tableName);
182 if (tableName.empty()) {
183 LOGE(ATM_DOMAIN, ATM_TAG, "Table name is empty.");
184 return AccessTokenError::ERR_PARAM_INVALID;
185 }
186
187 std::shared_ptr<NativeRdb::RdbStore> db = GetRdb();
188 if (db == nullptr) {
189 LOGE(ATM_DOMAIN, ATM_TAG, "db is nullptr.");
190 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
191 }
192
193 int32_t deletedRows = 0;
194 NativeRdb::RdbPredicates predicates(tableName);
195 AccessTokenDbUtil::ToRdbPredicates(conditionValue, predicates);
196
197 int32_t res = db->Delete(deletedRows, predicates);
198 if (res != NativeRdb::E_OK) {
199 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to delete record from table %{public}s, res is %{public}d.",
200 tableName.c_str(), res);
201 int32_t result = RestoreAndDeleteIfCorrupt(res, deletedRows, predicates, db);
202 if (result != NativeRdb::E_OK) {
203 return result;
204 }
205 }
206
207 LOGI(ATM_DOMAIN, ATM_TAG, "Delete %{public}d records from table %{public}s.", deletedRows, tableName.c_str());
208
209 return 0;
210 }
211
RestoreAndUpdateIfCorrupt(const int32_t resultCode,int32_t & changedRows,const NativeRdb::ValuesBucket & bucket,const NativeRdb::RdbPredicates & predicates,const std::shared_ptr<NativeRdb::RdbStore> & db)212 int32_t AccessTokenDb::RestoreAndUpdateIfCorrupt(const int32_t resultCode, int32_t& changedRows,
213 const NativeRdb::ValuesBucket& bucket, const NativeRdb::RdbPredicates& predicates,
214 const std::shared_ptr<NativeRdb::RdbStore>& db)
215 {
216 if (resultCode != NativeRdb::E_SQLITE_CORRUPT) {
217 return resultCode;
218 }
219
220 LOGW(ATM_DOMAIN, ATM_TAG, "Detech database corrupt, restore from backup!");
221 int32_t res = db->Restore("");
222 if (res != NativeRdb::E_OK) {
223 LOGE(ATM_DOMAIN, ATM_TAG, "Db restore failed, res is %{public}d.", res);
224 return res;
225 }
226 LOGI(ATM_DOMAIN, ATM_TAG, "Database restore success, try update again!");
227
228 res = db->Update(changedRows, bucket, predicates);
229 if (res != NativeRdb::E_OK) {
230 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to update record from table %{public}s again, res is %{public}d.",
231 predicates.GetTableName().c_str(), res);
232 return res;
233 }
234
235 return 0;
236 }
237
Modify(const AtmDataType type,const GenericValues & modifyValue,const GenericValues & conditionValue)238 int32_t AccessTokenDb::Modify(const AtmDataType type, const GenericValues& modifyValue,
239 const GenericValues& conditionValue)
240 {
241 int64_t beginTime = TimeUtil::GetCurrentTimestamp();
242 std::string tableName;
243 AccessTokenDbUtil::GetTableNameByType(type, tableName);
244 if (tableName.empty()) {
245 return AccessTokenError::ERR_PARAM_INVALID;
246 }
247
248 NativeRdb::ValuesBucket bucket;
249
250 AccessTokenDbUtil::ToRdbValueBucket(modifyValue, bucket);
251 if (bucket.IsEmpty()) {
252 return AccessTokenError::ERR_PARAM_INVALID;
253 }
254
255 NativeRdb::RdbPredicates predicates(tableName);
256 AccessTokenDbUtil::ToRdbPredicates(conditionValue, predicates);
257
258 int32_t changedRows = 0;
259 {
260 OHOS::Utils::UniqueWriteGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
261 auto db = GetRdb();
262 if (db == nullptr) {
263 LOGE(ATM_DOMAIN, ATM_TAG, "db is nullptr.");
264 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
265 }
266
267 int32_t res = db->Update(changedRows, bucket, predicates);
268 if (res != NativeRdb::E_OK) {
269 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to update record from table %{public}s, res is %{public}d.",
270 tableName.c_str(), res);
271 int32_t result = RestoreAndUpdateIfCorrupt(res, changedRows, bucket, predicates, db);
272 if (result != NativeRdb::E_OK) {
273 return result;
274 }
275 }
276 }
277
278 int64_t endTime = TimeUtil::GetCurrentTimestamp();
279 LOGI(ATM_DOMAIN, ATM_TAG, "Modify cost %{public}" PRId64
280 ", update %{public}d records from table %{public}s.", endTime - beginTime, changedRows, tableName.c_str());
281
282 return 0;
283 }
284
RestoreAndQueryIfCorrupt(const NativeRdb::RdbPredicates & predicates,const std::vector<std::string> & columns,std::shared_ptr<NativeRdb::AbsSharedResultSet> & queryResultSet,const std::shared_ptr<NativeRdb::RdbStore> & db)285 int32_t AccessTokenDb::RestoreAndQueryIfCorrupt(const NativeRdb::RdbPredicates& predicates,
286 const std::vector<std::string>& columns, std::shared_ptr<NativeRdb::AbsSharedResultSet>& queryResultSet,
287 const std::shared_ptr<NativeRdb::RdbStore>& db)
288 {
289 int32_t count = 0;
290 int32_t res = queryResultSet->GetRowCount(count);
291 if (res != NativeRdb::E_OK) {
292 if (res == NativeRdb::E_SQLITE_CORRUPT) {
293 queryResultSet->Close();
294 queryResultSet = nullptr;
295
296 LOGW(ATM_DOMAIN, ATM_TAG, "Detech database corrupt, restore from backup!");
297 res = db->Restore("");
298 if (res != NativeRdb::E_OK) {
299 LOGE(ATM_DOMAIN, ATM_TAG, "Db restore failed, res is %{public}d.", res);
300 return res;
301 }
302 LOGI(ATM_DOMAIN, ATM_TAG, "Database restore success, try query again!");
303
304 queryResultSet = db->Query(predicates, columns);
305 if (queryResultSet == nullptr) {
306 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to find records from table %{public}s again.",
307 predicates.GetTableName().c_str());
308 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
309 }
310 } else {
311 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to get result count.");
312 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
313 }
314 }
315
316 return 0;
317 }
318
Find(AtmDataType type,const GenericValues & conditionValue,std::vector<GenericValues> & results)319 int32_t AccessTokenDb::Find(AtmDataType type, const GenericValues& conditionValue,
320 std::vector<GenericValues>& results)
321 {
322 int64_t beginTime = TimeUtil::GetCurrentTimestamp();
323 std::string tableName;
324 AccessTokenDbUtil::GetTableNameByType(type, tableName);
325 if (tableName.empty()) {
326 return AccessTokenError::ERR_PARAM_INVALID;
327 }
328
329 NativeRdb::RdbPredicates predicates(tableName);
330 AccessTokenDbUtil::ToRdbPredicates(conditionValue, predicates);
331
332 std::vector<std::string> columns; // empty columns means query all columns
333 int count = 0;
334 {
335 OHOS::Utils::UniqueReadGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
336 auto db = GetRdb();
337 if (db == nullptr) {
338 LOGE(ATM_DOMAIN, ATM_TAG, "db is nullptr.");
339 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
340 }
341
342 auto queryResultSet = db->Query(predicates, columns);
343 if (queryResultSet == nullptr) {
344 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to find records from table %{public}s.",
345 tableName.c_str());
346 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
347 }
348
349 int32_t res = RestoreAndQueryIfCorrupt(predicates, columns, queryResultSet, db);
350 if (res != 0) {
351 return res;
352 }
353
354 while (queryResultSet->GoToNextRow() == NativeRdb::E_OK) {
355 GenericValues value;
356 AccessTokenDbUtil::ResultToGenericValues(queryResultSet, value);
357 if (value.GetAllKeys().empty()) {
358 continue;
359 }
360
361 results.emplace_back(value);
362 count++;
363 }
364 }
365
366 int64_t endTime = TimeUtil::GetCurrentTimestamp();
367 LOGI(ATM_DOMAIN, ATM_TAG, "Find cost %{public}" PRId64
368 ", query %{public}d records from table %{public}s.", endTime - beginTime, count, tableName.c_str());
369
370 return 0;
371 }
372
RestoreAndCommitIfCorrupt(const int32_t resultCode,const std::shared_ptr<NativeRdb::RdbStore> & db)373 int32_t AccessTokenDb::RestoreAndCommitIfCorrupt(const int32_t resultCode,
374 const std::shared_ptr<NativeRdb::RdbStore>& db)
375 {
376 if (resultCode != NativeRdb::E_SQLITE_CORRUPT) {
377 return resultCode;
378 }
379
380 LOGW(ATM_DOMAIN, ATM_TAG, "Detech database corrupt, restore from backup!");
381 int32_t res = db->Restore("");
382 if (res != NativeRdb::E_OK) {
383 LOGE(ATM_DOMAIN, ATM_TAG, "Db restore failed, res is %{public}d.", res);
384 return res;
385 }
386 LOGI(ATM_DOMAIN, ATM_TAG, "Database restore success, try commit again!");
387
388 res = db->Commit();
389 if (res != NativeRdb::E_OK) {
390 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to Commit again, res is %{public}d.", res);
391 return res;
392 }
393
394 return NativeRdb::E_OK;
395 }
396
DeleteAndInsertValues(const std::vector<AtmDataType> & delDataTypes,const std::vector<GenericValues> & delValues,const std::vector<AtmDataType> & addDataTypes,const std::vector<std::vector<GenericValues>> & addValues)397 int32_t AccessTokenDb::DeleteAndInsertValues(
398 const std::vector<AtmDataType>& delDataTypes, const std::vector<GenericValues>& delValues,
399 const std::vector<AtmDataType>& addDataTypes, const std::vector<std::vector<GenericValues>>& addValues)
400 {
401 int64_t beginTime = TimeUtil::GetCurrentTimestamp();
402
403 {
404 OHOS::Utils::UniqueWriteGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
405 std::shared_ptr<NativeRdb::RdbStore> db = GetRdb();
406 if (db == nullptr) {
407 LOGE(ATM_DOMAIN, ATM_TAG, "db is nullptr.");
408 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
409 }
410
411 db->BeginTransaction();
412
413 int32_t res = 0;
414 size_t count = delDataTypes.size();
415 for (size_t i = 0; i < count; ++i) {
416 res = RemoveValues(delDataTypes[i], delValues[i]);
417 if (res != 0) {
418 db->RollBack();
419 return res;
420 }
421 }
422
423 count = addDataTypes.size();
424 for (size_t i = 0; i < count; ++i) {
425 res = AddValues(addDataTypes[i], addValues[i]);
426 if (res != 0) {
427 db->RollBack();
428 return res;
429 }
430 }
431
432 res = db->Commit();
433 if (res != NativeRdb::E_OK) {
434 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to commit, res is %{public}d.", res);
435 int32_t result = RestoreAndCommitIfCorrupt(res, db);
436 if (result != NativeRdb::E_OK) {
437 return result;
438 }
439 }
440 }
441
442 int64_t endTime = TimeUtil::GetCurrentTimestamp();
443 LOGI(ATM_DOMAIN, ATM_TAG, "DeleteAndInsertNative cost %{public}" PRId64 ".", endTime - beginTime);
444
445 return 0;
446 }
447 } // namespace AccessToken
448 } // namespace Security
449 } // namespace OHOS
450