• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "access_token_db.h"

#include <algorithm>
#include <atomic>
#include <cinttypes>
#include <mutex>
#include <utility>

#include "accesstoken_common_log.h"
#include "access_token_error.h"
#include "access_token_open_callback.h"
#include "rdb_helper.h"
#include "time_util.h"
#include "token_field_const.h"
28 
29 namespace OHOS {
30 namespace Security {
31 namespace AccessToken {
namespace {
// File name of the access token database; InitRdb() appends it to DATABASE_PATH.
constexpr char DATABASE_NAME[] = "access_token.db";
// Service name handed to RdbStoreConfig::SetServiceName() in InitRdb().
constexpr char ACCESSTOKEN_SERVICE_NAME[] = "accesstoken_service";
// Serializes creation of the AccessTokenDb singleton in GetInstance().
std::recursive_mutex g_instanceMutex;
} // namespace
37 
GetInstance()38 AccessTokenDb& AccessTokenDb::GetInstance()
39 {
40     static AccessTokenDb* instance = nullptr;
41     if (instance == nullptr) {
42         std::lock_guard<std::recursive_mutex> lock(g_instanceMutex);
43         if (instance == nullptr) {
44             AccessTokenDb* tmp = new AccessTokenDb();
45             instance = std::move(tmp);
46         }
47     }
48     return *instance;
49 }
50 
AccessTokenDb()51 AccessTokenDb::AccessTokenDb()
52 {
53     InitRdb();
54 }
55 
RestoreAndInsertIfCorrupt(const int32_t resultCode,int64_t & outInsertNum,const std::string & tableName,const std::vector<NativeRdb::ValuesBucket> & buckets,const std::shared_ptr<NativeRdb::RdbStore> & db)56 int32_t AccessTokenDb::RestoreAndInsertIfCorrupt(const int32_t resultCode, int64_t& outInsertNum,
57     const std::string& tableName, const std::vector<NativeRdb::ValuesBucket>& buckets,
58     const std::shared_ptr<NativeRdb::RdbStore>& db)
59 {
60     if (resultCode != NativeRdb::E_SQLITE_CORRUPT) {
61         return resultCode;
62     }
63 
64     LOGW(ATM_DOMAIN, ATM_TAG, "Detech database corrupt, restore from backup!");
65     int32_t res = db->Restore("");
66     if (res != NativeRdb::E_OK) {
67         LOGE(ATM_DOMAIN, ATM_TAG, "Db restore failed, res is %{public}d.", res);
68         return res;
69     }
70     LOGI(ATM_DOMAIN, ATM_TAG, "Database restore success, try insert again!");
71 
72     res = db->BatchInsert(outInsertNum, tableName, buckets);
73     if (res != NativeRdb::E_OK) {
74         LOGE(ATM_DOMAIN, ATM_TAG, "Failed to batch insert into table %{public}s again, res is %{public}d.",
75             tableName.c_str(), res);
76         return res;
77     }
78 
79     return 0;
80 }
81 
InitRdb()82 void AccessTokenDb::InitRdb()
83 {
84     std::string dbPath = std::string(DATABASE_PATH) + std::string(DATABASE_NAME);
85     NativeRdb::RdbStoreConfig config(dbPath);
86     config.SetSecurityLevel(NativeRdb::SecurityLevel::S3);
87     config.SetAllowRebuild(true);
88     config.SetHaMode(NativeRdb::HAMode::MAIN_REPLICA); // Real-time dual-write backup database
89     config.SetServiceName(std::string(ACCESSTOKEN_SERVICE_NAME));
90     AccessTokenOpenCallback callback;
91     int32_t res = NativeRdb::E_OK;
92     // pragma user_version will done by rdb, they store path and db_ as pair in RdbStoreManager
93     db_ = NativeRdb::RdbHelper::GetRdbStore(config, DATABASE_VERSION_5, callback, res);
94     if ((res != NativeRdb::E_OK) || (db_ == nullptr)) {
95         LOGE(ATM_DOMAIN, ATM_TAG, "Failed to init rdb, res is %{public}d.", res);
96     }
97 }
98 
GetRdb()99 std::shared_ptr<NativeRdb::RdbStore> AccessTokenDb::GetRdb()
100 {
101     std::lock_guard<std::mutex> lock(dbLock_);
102     if (db_ == nullptr) {
103         InitRdb();
104     }
105     return db_;
106 }
107 
AddValues(const AtmDataType type,const std::vector<GenericValues> & addValues)108 int32_t AccessTokenDb::AddValues(const AtmDataType type, const std::vector<GenericValues>& addValues)
109 {
110     std::string tableName;
111     AccessTokenDbUtil::GetTableNameByType(type, tableName);
112     if (tableName.empty()) {
113         LOGE(ATM_DOMAIN, ATM_TAG, "Table name is empty.");
114         return AccessTokenError::ERR_PARAM_INVALID;
115     }
116 
117     // if nothing to insert, no need to call BatchInsert
118     if (addValues.empty()) {
119         return 0;
120     }
121 
122     std::shared_ptr<NativeRdb::RdbStore> db = GetRdb();
123     if (db == nullptr) {
124         LOGE(ATM_DOMAIN, ATM_TAG, "db is nullptr.");
125         return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
126     }
127 
128     // fill buckets with addValues
129     int64_t outInsertNum = 0;
130     std::vector<NativeRdb::ValuesBucket> buckets;
131     AccessTokenDbUtil::ToRdbValueBuckets(addValues, buckets);
132 
133     int32_t res = db->BatchInsert(outInsertNum, tableName, buckets);
134     if (res != NativeRdb::E_OK) {
135         LOGE(ATM_DOMAIN, ATM_TAG, "Failed to batch insert into table %{public}s, res is %{public}d.",
136             tableName.c_str(), res);
137         int32_t result = RestoreAndInsertIfCorrupt(res, outInsertNum, tableName, buckets, db);
138         if (result != NativeRdb::E_OK) {
139             return result;
140         }
141     }
142     if (outInsertNum <= 0) { // rdb bug, adapt it
143         LOGE(ATM_DOMAIN, ATM_TAG, "Insert count %{public}" PRId64 " abnormal.", outInsertNum);
144         return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
145     }
146 
147     LOGI(ATM_DOMAIN, ATM_TAG, "Batch insert %{public}" PRId64 " records to table %{public}s.", outInsertNum,
148         tableName.c_str());
149 
150     return 0;
151 }
152 
RestoreAndDeleteIfCorrupt(const int32_t resultCode,int32_t & deletedRows,const NativeRdb::RdbPredicates & predicates,const std::shared_ptr<NativeRdb::RdbStore> & db)153 int32_t AccessTokenDb::RestoreAndDeleteIfCorrupt(const int32_t resultCode, int32_t& deletedRows,
154     const NativeRdb::RdbPredicates& predicates, const std::shared_ptr<NativeRdb::RdbStore>& db)
155 {
156     if (resultCode != NativeRdb::E_SQLITE_CORRUPT) {
157         return resultCode;
158     }
159 
160     LOGW(ATM_DOMAIN, ATM_TAG, "Detech database corrupt, restore from backup!");
161     int32_t res = db->Restore("");
162     if (res != NativeRdb::E_OK) {
163         LOGE(ATM_DOMAIN, ATM_TAG, "Db restore failed, res is %{public}d.", res);
164         return res;
165     }
166     LOGI(ATM_DOMAIN, ATM_TAG, "Database restore success, try delete again!");
167 
168     res = db->Delete(deletedRows, predicates);
169     if (res != NativeRdb::E_OK) {
170         LOGE(ATM_DOMAIN, ATM_TAG, "Failed to delete record from table %{public}s again, res is %{public}d.",
171             predicates.GetTableName().c_str(), res);
172         return res;
173     }
174 
175     return 0;
176 }
177 
RemoveValues(const AtmDataType type,const GenericValues & conditionValue)178 int32_t AccessTokenDb::RemoveValues(const AtmDataType type, const GenericValues& conditionValue)
179 {
180     std::string tableName;
181     AccessTokenDbUtil::GetTableNameByType(type, tableName);
182     if (tableName.empty()) {
183         LOGE(ATM_DOMAIN, ATM_TAG, "Table name is empty.");
184         return AccessTokenError::ERR_PARAM_INVALID;
185     }
186 
187     std::shared_ptr<NativeRdb::RdbStore> db = GetRdb();
188     if (db == nullptr) {
189         LOGE(ATM_DOMAIN, ATM_TAG, "db is nullptr.");
190         return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
191     }
192 
193     int32_t deletedRows = 0;
194     NativeRdb::RdbPredicates predicates(tableName);
195     AccessTokenDbUtil::ToRdbPredicates(conditionValue, predicates);
196 
197     int32_t res = db->Delete(deletedRows, predicates);
198     if (res != NativeRdb::E_OK) {
199         LOGE(ATM_DOMAIN, ATM_TAG, "Failed to delete record from table %{public}s, res is %{public}d.",
200             tableName.c_str(), res);
201         int32_t result = RestoreAndDeleteIfCorrupt(res, deletedRows, predicates, db);
202         if (result != NativeRdb::E_OK) {
203             return result;
204         }
205     }
206 
207     LOGI(ATM_DOMAIN, ATM_TAG, "Delete %{public}d records from table %{public}s.", deletedRows, tableName.c_str());
208 
209     return 0;
210 }
211 
RestoreAndUpdateIfCorrupt(const int32_t resultCode,int32_t & changedRows,const NativeRdb::ValuesBucket & bucket,const NativeRdb::RdbPredicates & predicates,const std::shared_ptr<NativeRdb::RdbStore> & db)212 int32_t AccessTokenDb::RestoreAndUpdateIfCorrupt(const int32_t resultCode, int32_t& changedRows,
213     const NativeRdb::ValuesBucket& bucket, const NativeRdb::RdbPredicates& predicates,
214     const std::shared_ptr<NativeRdb::RdbStore>& db)
215 {
216     if (resultCode != NativeRdb::E_SQLITE_CORRUPT) {
217         return resultCode;
218     }
219 
220     LOGW(ATM_DOMAIN, ATM_TAG, "Detech database corrupt, restore from backup!");
221     int32_t res = db->Restore("");
222     if (res != NativeRdb::E_OK) {
223         LOGE(ATM_DOMAIN, ATM_TAG, "Db restore failed, res is %{public}d.", res);
224         return res;
225     }
226     LOGI(ATM_DOMAIN, ATM_TAG, "Database restore success, try update again!");
227 
228     res = db->Update(changedRows, bucket, predicates);
229     if (res != NativeRdb::E_OK) {
230         LOGE(ATM_DOMAIN, ATM_TAG, "Failed to update record from table %{public}s again, res is %{public}d.",
231             predicates.GetTableName().c_str(), res);
232         return res;
233     }
234 
235     return 0;
236 }
237 
Modify(const AtmDataType type,const GenericValues & modifyValue,const GenericValues & conditionValue)238 int32_t AccessTokenDb::Modify(const AtmDataType type, const GenericValues& modifyValue,
239     const GenericValues& conditionValue)
240 {
241     int64_t beginTime = TimeUtil::GetCurrentTimestamp();
242     std::string tableName;
243     AccessTokenDbUtil::GetTableNameByType(type, tableName);
244     if (tableName.empty()) {
245         return AccessTokenError::ERR_PARAM_INVALID;
246     }
247 
248     NativeRdb::ValuesBucket bucket;
249 
250     AccessTokenDbUtil::ToRdbValueBucket(modifyValue, bucket);
251     if (bucket.IsEmpty()) {
252         return AccessTokenError::ERR_PARAM_INVALID;
253     }
254 
255     NativeRdb::RdbPredicates predicates(tableName);
256     AccessTokenDbUtil::ToRdbPredicates(conditionValue, predicates);
257 
258     int32_t changedRows = 0;
259     {
260         OHOS::Utils::UniqueWriteGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
261         auto db = GetRdb();
262         if (db == nullptr) {
263             LOGE(ATM_DOMAIN, ATM_TAG, "db is nullptr.");
264             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
265         }
266 
267         int32_t res = db->Update(changedRows, bucket, predicates);
268         if (res != NativeRdb::E_OK) {
269             LOGE(ATM_DOMAIN, ATM_TAG, "Failed to update record from table %{public}s, res is %{public}d.",
270                 tableName.c_str(), res);
271             int32_t result = RestoreAndUpdateIfCorrupt(res, changedRows, bucket, predicates, db);
272             if (result != NativeRdb::E_OK) {
273                 return result;
274             }
275         }
276     }
277 
278     int64_t endTime = TimeUtil::GetCurrentTimestamp();
279     LOGI(ATM_DOMAIN, ATM_TAG, "Modify cost %{public}" PRId64
280         ", update %{public}d records from table %{public}s.", endTime - beginTime, changedRows, tableName.c_str());
281 
282     return 0;
283 }
284 
RestoreAndQueryIfCorrupt(const NativeRdb::RdbPredicates & predicates,const std::vector<std::string> & columns,std::shared_ptr<NativeRdb::AbsSharedResultSet> & queryResultSet,const std::shared_ptr<NativeRdb::RdbStore> & db)285 int32_t AccessTokenDb::RestoreAndQueryIfCorrupt(const NativeRdb::RdbPredicates& predicates,
286     const std::vector<std::string>& columns, std::shared_ptr<NativeRdb::AbsSharedResultSet>& queryResultSet,
287     const std::shared_ptr<NativeRdb::RdbStore>& db)
288 {
289     int32_t count = 0;
290     int32_t res = queryResultSet->GetRowCount(count);
291     if (res != NativeRdb::E_OK) {
292         if (res == NativeRdb::E_SQLITE_CORRUPT) {
293             queryResultSet->Close();
294             queryResultSet = nullptr;
295 
296             LOGW(ATM_DOMAIN, ATM_TAG, "Detech database corrupt, restore from backup!");
297             res = db->Restore("");
298             if (res != NativeRdb::E_OK) {
299                 LOGE(ATM_DOMAIN, ATM_TAG, "Db restore failed, res is %{public}d.", res);
300                 return res;
301             }
302             LOGI(ATM_DOMAIN, ATM_TAG, "Database restore success, try query again!");
303 
304             queryResultSet = db->Query(predicates, columns);
305             if (queryResultSet == nullptr) {
306                 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to find records from table %{public}s again.",
307                     predicates.GetTableName().c_str());
308                 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
309             }
310         } else {
311             LOGE(ATM_DOMAIN, ATM_TAG, "Failed to get result count.");
312             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
313         }
314     }
315 
316     return 0;
317 }
318 
Find(AtmDataType type,const GenericValues & conditionValue,std::vector<GenericValues> & results)319 int32_t AccessTokenDb::Find(AtmDataType type, const GenericValues& conditionValue,
320     std::vector<GenericValues>& results)
321 {
322     int64_t beginTime = TimeUtil::GetCurrentTimestamp();
323     std::string tableName;
324     AccessTokenDbUtil::GetTableNameByType(type, tableName);
325     if (tableName.empty()) {
326         return AccessTokenError::ERR_PARAM_INVALID;
327     }
328 
329     NativeRdb::RdbPredicates predicates(tableName);
330     AccessTokenDbUtil::ToRdbPredicates(conditionValue, predicates);
331 
332     std::vector<std::string> columns; // empty columns means query all columns
333     int count = 0;
334     {
335         OHOS::Utils::UniqueReadGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
336         auto db = GetRdb();
337         if (db == nullptr) {
338             LOGE(ATM_DOMAIN, ATM_TAG, "db is nullptr.");
339             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
340         }
341 
342         auto queryResultSet = db->Query(predicates, columns);
343         if (queryResultSet == nullptr) {
344             LOGE(ATM_DOMAIN, ATM_TAG, "Failed to find records from table %{public}s.",
345                 tableName.c_str());
346             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
347         }
348 
349         int32_t res = RestoreAndQueryIfCorrupt(predicates, columns, queryResultSet, db);
350         if (res != 0) {
351             return res;
352         }
353 
354         while (queryResultSet->GoToNextRow() == NativeRdb::E_OK) {
355             GenericValues value;
356             AccessTokenDbUtil::ResultToGenericValues(queryResultSet, value);
357             if (value.GetAllKeys().empty()) {
358                 continue;
359             }
360 
361             results.emplace_back(value);
362             count++;
363         }
364     }
365 
366     int64_t endTime = TimeUtil::GetCurrentTimestamp();
367     LOGI(ATM_DOMAIN, ATM_TAG, "Find cost %{public}" PRId64
368         ", query %{public}d records from table %{public}s.", endTime - beginTime, count, tableName.c_str());
369 
370     return 0;
371 }
372 
RestoreAndCommitIfCorrupt(const int32_t resultCode,const std::shared_ptr<NativeRdb::RdbStore> & db)373 int32_t AccessTokenDb::RestoreAndCommitIfCorrupt(const int32_t resultCode,
374     const std::shared_ptr<NativeRdb::RdbStore>& db)
375 {
376     if (resultCode != NativeRdb::E_SQLITE_CORRUPT) {
377         return resultCode;
378     }
379 
380     LOGW(ATM_DOMAIN, ATM_TAG, "Detech database corrupt, restore from backup!");
381     int32_t res = db->Restore("");
382     if (res != NativeRdb::E_OK) {
383         LOGE(ATM_DOMAIN, ATM_TAG, "Db restore failed, res is %{public}d.", res);
384         return res;
385     }
386     LOGI(ATM_DOMAIN, ATM_TAG, "Database restore success, try commit again!");
387 
388     res = db->Commit();
389     if (res != NativeRdb::E_OK) {
390         LOGE(ATM_DOMAIN, ATM_TAG, "Failed to Commit again, res is %{public}d.", res);
391         return res;
392     }
393 
394     return NativeRdb::E_OK;
395 }
396 
DeleteAndInsertValues(const std::vector<AtmDataType> & delDataTypes,const std::vector<GenericValues> & delValues,const std::vector<AtmDataType> & addDataTypes,const std::vector<std::vector<GenericValues>> & addValues)397 int32_t AccessTokenDb::DeleteAndInsertValues(
398     const std::vector<AtmDataType>& delDataTypes, const std::vector<GenericValues>& delValues,
399     const std::vector<AtmDataType>& addDataTypes, const std::vector<std::vector<GenericValues>>& addValues)
400 {
401     int64_t beginTime = TimeUtil::GetCurrentTimestamp();
402 
403     {
404         OHOS::Utils::UniqueWriteGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
405         std::shared_ptr<NativeRdb::RdbStore> db = GetRdb();
406         if (db == nullptr) {
407             LOGE(ATM_DOMAIN, ATM_TAG, "db is nullptr.");
408             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
409         }
410 
411         db->BeginTransaction();
412 
413         int32_t res = 0;
414         size_t count = delDataTypes.size();
415         for (size_t i = 0; i < count; ++i) {
416             res = RemoveValues(delDataTypes[i], delValues[i]);
417             if (res != 0) {
418                 db->RollBack();
419                 return res;
420             }
421         }
422 
423         count = addDataTypes.size();
424         for (size_t i = 0; i < count; ++i) {
425             res = AddValues(addDataTypes[i], addValues[i]);
426             if (res != 0) {
427                 db->RollBack();
428                 return res;
429             }
430         }
431 
432         res = db->Commit();
433         if (res != NativeRdb::E_OK) {
434             LOGE(ATM_DOMAIN, ATM_TAG, "Failed to commit, res is %{public}d.", res);
435             int32_t result = RestoreAndCommitIfCorrupt(res, db);
436             if (result != NativeRdb::E_OK) {
437                 return result;
438             }
439         }
440     }
441 
442     int64_t endTime = TimeUtil::GetCurrentTimestamp();
443     LOGI(ATM_DOMAIN, ATM_TAG, "DeleteAndInsertNative cost %{public}" PRId64 ".", endTime - beginTime);
444 
445     return 0;
446 }
447 } // namespace AccessToken
448 } // namespace Security
449 } // namespace OHOS
450