• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "sqlite_utils.h"
17 
18 #include <climits>
19 #include <cstring>
20 #include <chrono>
21 #include <thread>
22 #include <mutex>
23 #include <map>
24 #include <algorithm>
25 
26 #include "sqlite_import.h"
27 #include "securec.h"
28 #include "db_constant.h"
29 #include "db_common.h"
30 #include "db_errno.h"
31 #include "log_print.h"
32 #include "value_object.h"
33 #include "schema_utils.h"
34 #include "schema_constant.h"
35 #include "time_helper.h"
36 #include "platform_specific.h"
37 #include "sqlite_relational_utils.h"
38 
39 namespace DistributedDB {
namespace {
    // Busy timeout handed to sqlite, in milliseconds.
    constexpr int BUSY_TIMEOUT_MS = 3000;
    // Length argument for sqlite3_result_error: -1 lets sqlite measure the NUL-terminated message itself.
    constexpr int USING_STR_LEN = -1;

    // PRAGMA prefixes that callers complete with a cipher name or an iteration count.
    const std::string CIPHER_CONFIG_SQL = "PRAGMA codec_cipher=";
    const std::string KDF_ITER_CONFIG_SQL = "PRAGMA codec_kdf_iter=";

    // Complete PRAGMA statements.
    const std::string USER_VERSION_SQL = "PRAGMA user_version;";
    const std::string WAL_MODE_SQL = "PRAGMA journal_mode=WAL;";
    const std::string SHA1_ALGO_SQL = "PRAGMA codec_hmac_algo=SHA1;";
    const std::string SHA256_ALGO_REKEY_SQL = "PRAGMA codec_rekey_hmac_algo=SHA256;";
}
50 
51 struct ValueParseCache {
52     ValueObject valueParsed;
53     std::vector<uint8_t> valueOriginal;
54 };
55 
56 namespace {
// A delete record carries no value payload: either there is no blob at all or its length is not positive.
// (sqlite never reports a negative length in practice; the check is defensive.)
inline bool IsDeleteRecord(const uint8_t *valueBlob, int valueBlobLen)
{
    const bool hasPayload = (valueBlob != nullptr) && (valueBlobLen > 0);
    return !hasPayload;
}
61 
// Auxdata slot used for the parsed-value cache. It is the same cache id that sqlite uses for json_extract, which our
// json_extract_by_path substitutes. Because the id is negative, different operations within the same statement can
// share the cached entry.
constexpr int VALUE_CACHE_ID{-429938};
65 
ValueParseCacheFree(ValueParseCache * inCache)66 void ValueParseCacheFree(ValueParseCache *inCache)
67 {
68     if (inCache != nullptr) {
69         delete inCache;
70     }
71 }
72 
73 // We don't use cache array since we only cache value column of sqlite table, see sqlite implementation for compare.
ParseValueThenCacheOrGetFromCache(sqlite3_context * ctx,const uint8_t * valueBlob,uint32_t valueBlobLen,uint32_t offset)74 const ValueObject *ParseValueThenCacheOrGetFromCache(sqlite3_context *ctx, const uint8_t *valueBlob,
75     uint32_t valueBlobLen, uint32_t offset)
76 {
77     // Note: All parameter had already been check inside JsonExtractByPath, only called by JsonExtractByPath
78     auto cached = static_cast<ValueParseCache *>(sqlite3_get_auxdata(ctx, VALUE_CACHE_ID));
79     if (cached != nullptr) { // A previous cache exist
80         if (cached->valueOriginal.size() == valueBlobLen) {
81             if (std::memcmp(cached->valueOriginal.data(), valueBlob, valueBlobLen) == 0) {
82                 // Cache match
83                 return &(cached->valueParsed);
84             }
85         }
86     }
87     // No cache or cache mismatch
88     auto newCache = new (std::nothrow) ValueParseCache;
89     if (newCache == nullptr) {
90         sqlite3_result_error(ctx, "[ParseValueCache] OOM.", USING_STR_LEN);
91         LOGE("[ParseValueCache] OOM.");
92         return nullptr;
93     }
94     int errCode = newCache->valueParsed.Parse(valueBlob, valueBlob + valueBlobLen, offset);
95     if (errCode != E_OK) {
96         sqlite3_result_error(ctx, "[ParseValueCache] Parse fail.", USING_STR_LEN);
97         LOGE("[ParseValueCache] Parse fail, errCode=%d.", errCode);
98         delete newCache;
99         newCache = nullptr;
100         return nullptr;
101     }
102     newCache->valueOriginal.assign(valueBlob, valueBlob + valueBlobLen);
103     sqlite3_set_auxdata(ctx, VALUE_CACHE_ID, newCache, reinterpret_cast<void(*)(void*)>(ValueParseCacheFree));
104     // If sqlite3_set_auxdata fail, it will immediately call ValueParseCacheFree to delete newCache;
105     // Next time sqlite3_set_auxdata will call ValueParseCacheFree to delete newCache of this time;
106     // At the end, newCache will be eventually deleted when call sqlite3_reset or sqlite3_finalize;
107     // Since sqlite3_set_auxdata may fail, we have to call sqlite3_get_auxdata other than return newCache directly.
108     auto cacheInAuxdata = static_cast<ValueParseCache *>(sqlite3_get_auxdata(ctx, VALUE_CACHE_ID));
109     if (cacheInAuxdata == nullptr) {
110         return nullptr;
111     }
112     return &(cacheInAuxdata->valueParsed);
113 }
114 }
115 
JsonExtractByPath(sqlite3_context * ctx,int argc,sqlite3_value ** argv)116 void SQLiteUtils::JsonExtractByPath(sqlite3_context *ctx, int argc, sqlite3_value **argv)
117 {
118     if (ctx == nullptr || argc != 3 || argv == nullptr) { // 3 parameters, which are value, path and offset
119         LOGE("[JsonExtract] Invalid parameter, argc=%d.", argc);
120         return;
121     }
122     auto valueBlob = static_cast<const uint8_t *>(sqlite3_value_blob(argv[0]));
123     int valueBlobLen = sqlite3_value_bytes(argv[0]);
124     if (IsDeleteRecord(valueBlob, valueBlobLen)) {
125         // Currently delete records are filtered out of query and create-index sql, so not allowed here.
126         sqlite3_result_error(ctx, "[JsonExtract] Delete record not allowed.", USING_STR_LEN);
127         LOGE("[JsonExtract] Delete record not allowed.");
128         return;
129     }
130     auto path = reinterpret_cast<const char *>(sqlite3_value_text(argv[1]));
131     int offset = sqlite3_value_int(argv[2]); // index 2 is the third parameter
132     if ((path == nullptr) || (offset < 0)) {
133         sqlite3_result_error(ctx, "[JsonExtract] Path nullptr or offset invalid.", USING_STR_LEN);
134         LOGE("[JsonExtract] Path nullptr or offset=%d invalid.", offset);
135         return;
136     }
137     FieldPath outPath;
138     int errCode = SchemaUtils::ParseAndCheckFieldPath(path, outPath);
139     if (errCode != E_OK) {
140         sqlite3_result_error(ctx, "[JsonExtract] Path illegal.", USING_STR_LEN);
141         LOGE("[JsonExtract] Path illegal.");
142         return;
143     }
144     // Parameter Check Done Here
145     const ValueObject *valueObj = ParseValueThenCacheOrGetFromCache(ctx, valueBlob, static_cast<uint32_t>(valueBlobLen),
146         static_cast<uint32_t>(offset));
147     if (valueObj == nullptr) {
148         return; // Necessary had been printed in ParseValueThenCacheOrGetFromCache
149     }
150     JsonExtractInnerFunc(ctx, *valueObj, outPath);
151 }
152 
153 namespace {
IsExtractableType(FieldType inType)154 inline bool IsExtractableType(FieldType inType)
155 {
156     return (inType != FieldType::LEAF_FIELD_NULL && inType != FieldType::LEAF_FIELD_ARRAY &&
157         inType != FieldType::LEAF_FIELD_OBJECT && inType != FieldType::INTERNAL_FIELD_OBJECT);
158 }
159 }
160 
JsonExtractInnerFunc(sqlite3_context * ctx,const ValueObject & inValue,const FieldPath & inPath)161 void SQLiteUtils::JsonExtractInnerFunc(sqlite3_context *ctx, const ValueObject &inValue, const FieldPath &inPath)
162 {
163     FieldType outType = FieldType::LEAF_FIELD_NULL; // Default type null for invalid-path(path not exist)
164     int errCode = inValue.GetFieldTypeByFieldPath(inPath, outType);
165     if (errCode != E_OK && errCode != -E_INVALID_PATH) {
166         sqlite3_result_error(ctx, "[JsonExtract] GetFieldType fail.", USING_STR_LEN);
167         LOGE("[JsonExtract] GetFieldType fail, errCode=%d.", errCode);
168         return;
169     }
170     FieldValue outValue;
171     if (IsExtractableType(outType)) {
172         errCode = inValue.GetFieldValueByFieldPath(inPath, outValue);
173         if (errCode != E_OK) {
174             sqlite3_result_error(ctx, "[JsonExtract] GetFieldValue fail.", USING_STR_LEN);
175             LOGE("[JsonExtract] GetFieldValue fail, errCode=%d.", errCode);
176             return;
177         }
178     }
179     // FieldType null, array, object do not have value, all these FieldValue will be regarded as null in JsonReturn.
180     ExtractReturn(ctx, outType, outValue);
181 }
182 
183 // NOTE!!! This function is performance sensitive !!! Carefully not to allocate memory often!!!
FlatBufferExtractByPath(sqlite3_context * ctx,int argc,sqlite3_value ** argv)184 void SQLiteUtils::FlatBufferExtractByPath(sqlite3_context *ctx, int argc, sqlite3_value **argv)
185 {
186     if (ctx == nullptr || argc != 3 || argv == nullptr) { // 3 parameters, which are value, path and offset
187         LOGE("[FlatBufferExtract] Invalid parameter, argc=%d.", argc);
188         return;
189     }
190     auto schema = static_cast<SchemaObject *>(sqlite3_user_data(ctx));
191     if (schema == nullptr || !schema->IsSchemaValid() ||
192         (schema->GetSchemaType() != SchemaType::FLATBUFFER)) { // LCOV_EXCL_BR_LINE
193         sqlite3_result_error(ctx, "[FlatBufferExtract] No SchemaObject or invalid.", USING_STR_LEN);
194         LOGE("[FlatBufferExtract] No SchemaObject or invalid.");
195         return;
196     }
197     // Get information from argv
198     auto valueBlob = static_cast<const uint8_t *>(sqlite3_value_blob(argv[0]));
199     int valueBlobLen = sqlite3_value_bytes(argv[0]);
200     if (IsDeleteRecord(valueBlob, valueBlobLen)) { // LCOV_EXCL_BR_LINE
201         // Currently delete records are filtered out of query and create-index sql, so not allowed here.
202         sqlite3_result_error(ctx, "[FlatBufferExtract] Delete record not allowed.", USING_STR_LEN);
203         LOGE("[FlatBufferExtract] Delete record not allowed.");
204         return;
205     }
206     auto path = reinterpret_cast<const char *>(sqlite3_value_text(argv[1]));
207     int offset = sqlite3_value_int(argv[2]); // index 2 is the third parameter
208     if ((path == nullptr) || (offset < 0) ||
209         (static_cast<uint32_t>(offset) != schema->GetSkipSize())) { // LCOV_EXCL_BR_LINE
210         sqlite3_result_error(ctx, "[FlatBufferExtract] Path null or offset invalid.", USING_STR_LEN);
211         LOGE("[FlatBufferExtract] Path null or offset=%d(skipsize=%u) invalid.", offset, schema->GetSkipSize());
212         return;
213     }
214     FlatBufferExtractInnerFunc(ctx, *schema, RawValue { valueBlob, valueBlobLen }, path);
215 }
216 
217 namespace {
constexpr uint32_t FLATBUFFER_MAX_CACHE_SIZE = 102400; // 100 KBytes

// Destructor handed to sqlite3_set_auxdata for the flatbuffer copy cache (a std::vector<uint8_t>).
// It takes void * so that its type exactly matches sqlite's `void (*)(void *)` destructor type. Invoking a function
// through a pointer to a different function type is undefined behavior.
void FlatBufferCacheFree(void *inCache)
{
    // delete on a null pointer is a no-op, so no explicit null check is needed.
    delete static_cast<std::vector<uint8_t> *>(inCache);
}
226 }
227 
FlatBufferExtractInnerFunc(sqlite3_context * ctx,const SchemaObject & schema,const RawValue & inValue,RawString inPath)228 void SQLiteUtils::FlatBufferExtractInnerFunc(sqlite3_context *ctx, const SchemaObject &schema, const RawValue &inValue,
229     RawString inPath)
230 {
231     // All parameter had already been check inside FlatBufferExtractByPath, only called by FlatBufferExtractByPath
232     if (schema.GetSkipSize() % SchemaConstant::SECURE_BYTE_ALIGN == 0) { // LCOV_EXCL_BR_LINE
233         TypeValue outExtract;
234         int errCode = schema.ExtractValue(ValueSource::FROM_DBFILE, inPath, inValue, outExtract, nullptr);
235         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
236             sqlite3_result_error(ctx, "[FlatBufferExtract] ExtractValue fail.", USING_STR_LEN);
237             LOGE("[FlatBufferExtract] ExtractValue fail, errCode=%d.", errCode);
238             return;
239         }
240         ExtractReturn(ctx, outExtract.first, outExtract.second);
241         return;
242     }
243     // Not byte-align secure, we have to make a cache for copy. Check whether cache had already exist.
244     auto cached = static_cast<std::vector<uint8_t> *>(sqlite3_get_auxdata(ctx, VALUE_CACHE_ID)); // Share the same id
245     if (cached == nullptr) { // LCOV_EXCL_BR_LINE
246         // Make the cache
247         auto newCache = new (std::nothrow) std::vector<uint8_t>;
248         if (newCache == nullptr) {
249             sqlite3_result_error(ctx, "[FlatBufferExtract] OOM.", USING_STR_LEN);
250             LOGE("[FlatBufferExtract] OOM.");
251             return;
252         }
253         newCache->resize(FLATBUFFER_MAX_CACHE_SIZE);
254         sqlite3_set_auxdata(ctx, VALUE_CACHE_ID, newCache, reinterpret_cast<void(*)(void*)>(FlatBufferCacheFree));
255         // If sqlite3_set_auxdata fail, it will immediately call FlatBufferCacheFree to delete newCache;
256         // Next time sqlite3_set_auxdata will call FlatBufferCacheFree to delete newCache of this time;
257         // At the end, newCache will be eventually deleted when call sqlite3_reset or sqlite3_finalize;
258         // Since sqlite3_set_auxdata may fail, we have to call sqlite3_get_auxdata other than return newCache directly.
259         // See sqlite.org for more information.
260         cached = static_cast<std::vector<uint8_t> *>(sqlite3_get_auxdata(ctx, VALUE_CACHE_ID));
261     }
262     if (cached == nullptr) { // LCOV_EXCL_BR_LINE
263         LOGW("[FlatBufferExtract] Something wrong with Auxdata, but it is no matter without cache.");
264     }
265     TypeValue outExtract;
266     int errCode = schema.ExtractValue(ValueSource::FROM_DBFILE, inPath, inValue, outExtract, cached);
267     if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
268         sqlite3_result_error(ctx, "[FlatBufferExtract] ExtractValue fail.", USING_STR_LEN);
269         LOGE("[FlatBufferExtract] ExtractValue fail, errCode=%d.", errCode);
270         return;
271     }
272     ExtractReturn(ctx, outExtract.first, outExtract.second);
273 }
274 
ExtractReturn(sqlite3_context * ctx,FieldType type,const FieldValue & value)275 void SQLiteUtils::ExtractReturn(sqlite3_context *ctx, FieldType type, const FieldValue &value)
276 {
277     if (ctx == nullptr) {
278         return;
279     }
280     switch (type) {
281         case FieldType::LEAF_FIELD_BOOL:
282             sqlite3_result_int(ctx, (value.boolValue ? 1 : 0));
283             break;
284         case FieldType::LEAF_FIELD_INTEGER:
285             sqlite3_result_int(ctx, value.integerValue);
286             break;
287         case FieldType::LEAF_FIELD_LONG:
288             sqlite3_result_int64(ctx, value.longValue);
289             break;
290         case FieldType::LEAF_FIELD_DOUBLE:
291             sqlite3_result_double(ctx, value.doubleValue);
292             break;
293         case FieldType::LEAF_FIELD_STRING:
294             // The SQLITE_TRANSIENT value means that the content will likely change in the near future and
295             // that SQLite should make its own private copy of the content before returning.
296             sqlite3_result_text(ctx, value.stringValue.c_str(), -1, SQLITE_TRANSIENT); // -1 mean use the string length
297             break;
298         default:
299             // All other type regard as null
300             sqlite3_result_null(ctx);
301     }
302     return;
303 }
304 
CalcHashFunc(sqlite3_context * ctx,sqlite3_value ** argv)305 static void CalcHashFunc(sqlite3_context *ctx, sqlite3_value **argv)
306 {
307     auto keyBlob = static_cast<const uint8_t *>(sqlite3_value_blob(argv[0]));
308     if (keyBlob == nullptr) {
309         sqlite3_result_error(ctx, "Parameters is invalid.", USING_STR_LEN);
310         LOGE("Parameters is invalid.");
311         return;
312     }
313     int blobLen = sqlite3_value_bytes(argv[0]);
314     std::vector<uint8_t> value(keyBlob, keyBlob + blobLen);
315     std::vector<uint8_t> hashValue;
316     int errCode = DBCommon::CalcValueHash(value, hashValue);
317     if (errCode != E_OK) {
318         sqlite3_result_error(ctx, "Get hash value error.", USING_STR_LEN);
319         LOGE("Get hash value error.");
320         return;
321     }
322     sqlite3_result_blob(ctx, hashValue.data(), hashValue.size(), SQLITE_TRANSIENT);
323 }
324 
CalcHashKey(sqlite3_context * ctx,int argc,sqlite3_value ** argv)325 void SQLiteUtils::CalcHashKey(sqlite3_context *ctx, int argc, sqlite3_value **argv)
326 {
327     // 1 means that the function only needs one parameter, namely key
328     if (ctx == nullptr || argc != 1 || argv == nullptr) {
329         LOGE("Parameter does not meet restrictions.");
330         return;
331     }
332     CalcHashFunc(ctx, argv);
333 }
334 
CalcHash(sqlite3_context * ctx,int argc,sqlite3_value ** argv)335 void SQLiteUtils::CalcHash(sqlite3_context *ctx, int argc, sqlite3_value **argv)
336 {
337     if (ctx == nullptr || argc != 2 || argv == nullptr) { // 2 is params count
338         LOGE("Parameter does not meet restrictions.");
339         return;
340     }
341     CalcHashFunc(ctx, argv);
342 }
343 
344 
GetDbSize(const std::string & dir,const std::string & dbName,uint64_t & size)345 int SQLiteUtils::GetDbSize(const std::string &dir, const std::string &dbName, uint64_t &size)
346 {
347     std::string dataDir = dir + "/" + dbName + DBConstant::DB_EXTENSION;
348     uint64_t localDbSize = 0;
349     int errCode = OS::CalFileSize(dataDir, localDbSize);
350     if (errCode != E_OK) {
351         LOGD("Failed to get the db file size, errCode:%d", errCode);
352         return errCode;
353     }
354 
355     std::string shmFileName = dataDir + "-shm";
356     uint64_t localshmFileSize = 0;
357     errCode = OS::CalFileSize(shmFileName, localshmFileSize);
358     if (errCode != E_OK) {
359         localshmFileSize = 0;
360     }
361 
362     std::string walFileName = dataDir + "-wal";
363     uint64_t localWalFileSize = 0;
364     errCode = OS::CalFileSize(walFileName, localWalFileSize);
365     if (errCode != E_OK) {
366         localWalFileSize = 0;
367     }
368 
369     // 64-bit system is Suffice. Computer storage is less than uint64_t max
370     size += (localDbSize + localshmFileSize + localWalFileSize);
371     return E_OK;
372 }
373 
SetDataBaseProperty(sqlite3 * db,const OpenDbProperties & properties,bool setWal,const std::vector<std::string> & sqls)374 int SQLiteUtils::SetDataBaseProperty(sqlite3 *db, const OpenDbProperties &properties, bool setWal,
375     const std::vector<std::string> &sqls)
376 {
377     // Set the default busy handler to retry automatically before returning SQLITE_BUSY.
378     int errCode = SetBusyTimeout(db, BUSY_TIMEOUT_MS);
379     if (errCode != E_OK) {
380         return errCode;
381     }
382     if (!properties.isMemDb) {
383         errCode = SQLiteUtils::SetKey(db, properties.cipherType, properties.passwd, setWal,
384             properties.iterTimes);
385         if (errCode != E_OK) {
386             LOGD("SQLiteUtils::SetKey fail!!![%d]", errCode);
387             return errCode;
388         }
389     }
390 
391     for (const auto &sql : sqls) {
392         errCode = SQLiteUtils::ExecuteRawSQL(db, sql);
393         if (errCode != E_OK) {
394             LOGE("[SQLite] execute sql failed: %d", errCode);
395             return errCode;
396         }
397     }
398     // Create table if not exist according the sqls.
399     if (properties.createIfNecessary) {
400         for (const auto &sql : properties.sqls) {
401             errCode = SQLiteUtils::ExecuteRawSQL(db, sql);
402             if (errCode != E_OK) {
403                 LOGE("[SQLite] execute preset sqls failed");
404                 return errCode;
405             }
406         }
407     }
408     return E_OK;
409 }
410 
411 #ifndef OMIT_ENCRYPT
SetCipherSettings(sqlite3 * db,CipherType type,uint32_t iterTimes)412 int SQLiteUtils::SetCipherSettings(sqlite3 *db, CipherType type, uint32_t iterTimes)
413 {
414     if (db == nullptr) {
415         return -E_INVALID_DB;
416     }
417     std::string cipherName = GetCipherName(type);
418     if (cipherName.empty()) {
419         return -E_INVALID_ARGS;
420     }
421     std::string cipherConfig = CIPHER_CONFIG_SQL + cipherName + ";";
422     int errCode = SQLiteUtils::ExecuteRawSQL(db, cipherConfig);
423     if (errCode != E_OK) {
424         LOGE("[SQLiteUtils][SetCipherSettings] config cipher failed:%d", errCode);
425         return errCode;
426     }
427     errCode = SQLiteUtils::ExecuteRawSQL(db, KDF_ITER_CONFIG_SQL + std::to_string(iterTimes));
428     if (errCode != E_OK) {
429         LOGE("[SQLiteUtils][SetCipherSettings] config iter failed:%d", errCode);
430     }
431     return errCode;
432 }
433 
GetCipherName(CipherType type)434 std::string SQLiteUtils::GetCipherName(CipherType type)
435 {
436     if (type == CipherType::AES_256_GCM || type == CipherType::DEFAULT) {
437         return "'aes-256-gcm'";
438     }
439     return "";
440 }
441 #endif
442 
DropTriggerByName(sqlite3 * db,const std::string & name)443 int SQLiteUtils::DropTriggerByName(sqlite3 *db, const std::string &name)
444 {
445     const std::string dropTriggerSql = "DROP TRIGGER " + name + ";";
446     int errCode = SQLiteUtils::ExecuteRawSQL(db, dropTriggerSql);
447     if (errCode != E_OK) {
448         LOGE("Remove trigger failed. %d", errCode);
449     }
450     return errCode;
451 }
452 
ExpandedSql(sqlite3_stmt * stmt,std::string & basicString)453 int SQLiteUtils::ExpandedSql(sqlite3_stmt *stmt, std::string &basicString)
454 {
455     if (stmt == nullptr) {
456         return -E_INVALID_ARGS;
457     }
458     char *eSql = sqlite3_expanded_sql(stmt);
459     if (eSql == nullptr) {
460         LOGE("expand statement to sql failed.");
461         return -E_INVALID_DATA;
462     }
463     basicString = std::string(eSql);
464     sqlite3_free(eSql);
465     return E_OK;
466 }
467 
ExecuteCheckPoint(sqlite3 * db)468 void SQLiteUtils::ExecuteCheckPoint(sqlite3 *db)
469 {
470     if (db == nullptr) {
471         return;
472     }
473 
474     int chkResult = sqlite3_wal_checkpoint_v2(db, nullptr, SQLITE_CHECKPOINT_TRUNCATE, nullptr, nullptr);
475     LOGI("SQLite checkpoint result:%d", chkResult);
476 }
477 
CheckTableEmpty(sqlite3 * db,const std::string & tableName,bool & isEmpty)478 int SQLiteUtils::CheckTableEmpty(sqlite3 *db, const std::string &tableName, bool &isEmpty)
479 {
480     if (db == nullptr) {
481         return -E_INVALID_ARGS;
482     }
483 
484     std::string cntSql = "SELECT min(rowid) FROM '" + tableName + "';";
485     sqlite3_stmt *stmt = nullptr;
486     int errCode = SQLiteUtils::GetStatement(db, cntSql, stmt);
487     if (errCode != E_OK) {
488         return errCode;
489     }
490 
491     errCode = SQLiteUtils::StepWithRetry(stmt, false);
492     if (errCode == SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
493         isEmpty = (sqlite3_column_type(stmt, 0) == SQLITE_NULL);
494         errCode = E_OK;
495     }
496 
497     int ret = E_OK;
498     SQLiteUtils::ResetStatement(stmt, true, ret);
499     return SQLiteUtils::MapSQLiteErrno(errCode != E_OK ? errCode : ret);
500 }
501 
SetPersistWalMode(sqlite3 * db)502 int SQLiteUtils::SetPersistWalMode(sqlite3 *db)
503 {
504     if (db == nullptr) {
505         return -E_INVALID_ARGS;
506     }
507     int opCode = 1;
508     int errCode = sqlite3_file_control(db, "main", SQLITE_FCNTL_PERSIST_WAL, &opCode);
509     if (errCode != SQLITE_OK) {
510         LOGE("Set persist wal mode failed. %d", errCode);
511     }
512     return SQLiteUtils::MapSQLiteErrno(errCode);
513 }
514 
GetLastRowId(sqlite3 * db)515 int64_t SQLiteUtils::GetLastRowId(sqlite3 *db)
516 {
517     if (db == nullptr) {
518         return -1;
519     }
520     return sqlite3_last_insert_rowid(db);
521 }
522 
GetLastErrorMsg()523 std::string SQLiteUtils::GetLastErrorMsg()
524 {
525     std::lock_guard<std::mutex> autoLock(logMutex_);
526     return lastErrorMsg_;
527 }
528 
SetAuthorizer(sqlite3 * db,int (* xAuth)(void *,int,const char *,const char *,const char *,const char *))529 int SQLiteUtils::SetAuthorizer(sqlite3 *db,
530     int (*xAuth)(void*, int, const char*, const char*, const char*, const char*))
531 {
532     return SQLiteUtils::MapSQLiteErrno(sqlite3_set_authorizer(db, xAuth, nullptr));
533 }
534 
GetSelectCols(sqlite3_stmt * stmt,std::vector<std::string> & colNames)535 void SQLiteUtils::GetSelectCols(sqlite3_stmt *stmt, std::vector<std::string> &colNames)
536 {
537     colNames.clear();
538     for (int i = 0; i < sqlite3_column_count(stmt); ++i) {
539         const char *name = sqlite3_column_name(stmt, i);
540         colNames.emplace_back(name == nullptr ? std::string() : std::string(name));
541     }
542 }
543 
SetKeyInner(sqlite3 * db,CipherType type,const CipherPassword & passwd,uint32_t iterTimes)544 int SQLiteUtils::SetKeyInner(sqlite3 *db, CipherType type, const CipherPassword &passwd, uint32_t iterTimes)
545 {
546 #ifndef OMIT_ENCRYPT
547     int errCode = sqlite3_key(db, static_cast<const void *>(passwd.GetData()), static_cast<int>(passwd.GetSize()));
548     if (errCode != SQLITE_OK) {
549         LOGE("[SQLiteUtils][SetKeyInner] config key failed:(%d)", errCode);
550         return SQLiteUtils::MapSQLiteErrno(errCode);
551     }
552 
553     errCode = SQLiteUtils::SetCipherSettings(db, type, iterTimes);
554     if (errCode != E_OK) {
555         LOGE("[SQLiteUtils][SetKeyInner] set cipher settings failed:%d", errCode);
556     }
557     return errCode;
558 #else
559     return -E_NOT_SUPPORT;
560 #endif
561 }
562 
BindDataValueByType(sqlite3_stmt * statement,const std::optional<DataValue> & data,int cid)563 int SQLiteUtils::BindDataValueByType(sqlite3_stmt *statement, const std::optional<DataValue> &data, int cid)
564 {
565     int errCode = E_OK;
566     StorageType type = data.value_or(DataValue()).GetType();
567     switch (type) {
568         case StorageType::STORAGE_TYPE_INTEGER: {
569             int64_t intData = 0;
570             (void)data.value().GetInt64(intData);
571             errCode = SQLiteUtils::MapSQLiteErrno(sqlite3_bind_int64(statement, cid, intData));
572             break;
573         }
574 
575         case StorageType::STORAGE_TYPE_REAL: {
576             double doubleData = 0;
577             (void)data.value().GetDouble(doubleData);
578             errCode = SQLiteUtils::MapSQLiteErrno(sqlite3_bind_double(statement, cid, doubleData));
579             break;
580         }
581 
582         case StorageType::STORAGE_TYPE_TEXT: {
583             std::string strData;
584             (void)data.value().GetText(strData);
585             errCode = SQLiteUtils::BindTextToStatement(statement, cid, strData);
586             break;
587         }
588 
589         case StorageType::STORAGE_TYPE_BLOB: {
590             Blob blob;
591             (void)data.value().GetBlob(blob);
592             std::vector<uint8_t> blobData(blob.GetData(), blob.GetData() + blob.GetSize());
593             errCode = SQLiteUtils::BindBlobToStatement(statement, cid, blobData, true);
594             break;
595         }
596 
597         case StorageType::STORAGE_TYPE_NULL: {
598             errCode = SQLiteUtils::MapSQLiteErrno(sqlite3_bind_null(statement, cid));
599             break;
600         }
601 
602         default:
603             break;
604     }
605     return errCode;
606 }
607 
UpdateCipherShaAlgo(sqlite3 * db,bool setWal,CipherType type,const CipherPassword & passwd,uint32_t iterTimes)608 int SQLiteUtils::UpdateCipherShaAlgo(sqlite3 *db, bool setWal, CipherType type, const CipherPassword &passwd,
609     uint32_t iterTimes)
610 {
611     int errCode = SetKeyInner(db, type, passwd, iterTimes);
612     if (errCode != E_OK) {
613         return errCode;
614     }
615     // set sha1 algo for old version
616     errCode = SQLiteUtils::ExecuteRawSQL(db, SHA1_ALGO_SQL);
617     if (errCode != E_OK) {
618         LOGE("[SQLiteUtils][UpdateCipherShaAlgo] set sha algo failed:%d", errCode);
619         return errCode;
620     }
621     // try to get user version
622     errCode = SQLiteUtils::ExecuteRawSQL(db, USER_VERSION_SQL);
623     if (errCode != E_OK) {
624         LOGE("[SQLiteUtils][UpdateCipherShaAlgo] verify version failed:%d", errCode);
625         if (errno == EKEYREVOKED) {
626             return -E_EKEYREVOKED;
627         }
628         return errCode;
629     }
630     // try to update rekey sha algo by rekey operation
631     errCode = SQLiteUtils::ExecuteRawSQL(db, SHA256_ALGO_REKEY_SQL);
632     if (errCode != E_OK) {
633         LOGE("[SQLiteUtils][UpdateCipherShaAlgo] set rekey sha algo failed:%d", errCode);
634         return errCode;
635     }
636     if (setWal) {
637         errCode = SQLiteUtils::ExecuteRawSQL(db, WAL_MODE_SQL);
638         if (errCode != E_OK) {
639             LOGE("[SQLite][UpdateCipherShaAlgo] execute wal sql failed: %d", errCode);
640             return errCode;
641         }
642     }
643     return Rekey(db, passwd);
644 }
645 
StepNext(sqlite3_stmt * stmt,bool isMemDb)646 int SQLiteUtils::StepNext(sqlite3_stmt *stmt, bool isMemDb)
647 {
648     if (stmt == nullptr) {
649         return -E_INVALID_ARGS;
650     }
651     int errCode = SQLiteUtils::StepWithRetry(stmt, isMemDb);
652     if (errCode == SQLiteUtils::MapSQLiteErrno(SQLITE_DONE)) {
653         errCode = -E_FINISHED;
654     } else if (errCode == SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
655         errCode = E_OK;
656     }
657     return errCode;
658 }
659 
IsStmtReadOnly(sqlite3_stmt * statement)660 bool SQLiteUtils::IsStmtReadOnly(sqlite3_stmt *statement)
661 {
662     if (statement == nullptr) {
663         return false;
664     }
665     int isReadOnly = sqlite3_stmt_readonly(statement);
666     return static_cast<bool>(isReadOnly);
667 }
668 
UpdateLocalDataModifyTime(sqlite3 * db,const std::string & virtualTime,const std::string & modifyTime)669 int SQLiteUtils::UpdateLocalDataModifyTime(sqlite3 *db, const std::string &virtualTime, const std::string &modifyTime)
670 {
671     if (db == nullptr) {
672         return -E_INVALID_DB;
673     }
674     bool isCreate = false;
675     std::string tableName = DBConstant::KV_SYNC_TABLE_NAME;
676     auto errCode = SQLiteUtils::CheckTableExists(db, tableName, isCreate);
677     if (errCode != E_OK) {
678         LOGE("[SQLiteUtils] Check table exist failed %d when update modify time", errCode);
679         return errCode;
680     }
681     if (!isCreate) {
682         LOGW("[SQLiteUtils] non exist table when update time");
683         return E_OK;
684     }
685     std::string updateTimeSql = "UPDATE " + tableName + " SET timestamp = _rowid_ + " + virtualTime +
686                                 ", w_timestamp = _rowid_ + " + virtualTime + ", modify_time = _rowid_ + " + modifyTime +
687                                 " WHERE flag & 0x02 != 0;";
688     errCode = SQLiteUtils::ExecuteRawSQL(db, updateTimeSql);
689     if (errCode != E_OK) {
690         LOGE("[SQLiteUtils] Update modify time failed %d", errCode);
691     }
692     return errCode;
693 }
694 
UpdateLocalDataCloudFlag(sqlite3 * db)695 int SQLiteUtils::UpdateLocalDataCloudFlag(sqlite3 *db)
696 {
697     if (db == nullptr) {
698         return -E_INVALID_DB;
699     }
700     bool isCreate = false;
701     std::string logTableName = "naturalbase_kv_aux_sync_data_log";
702     auto errCode = SQLiteUtils::CheckTableExists(db, logTableName, isCreate);
703     if (errCode != E_OK) {
704         LOGE("[SQLiteUtils] Check log table exist failed %d when update cloud flag", errCode);
705         return errCode;
706     }
707     if (!isCreate) {
708         LOGW("[SQLiteUtils] non exist log table when update cloud flag");
709         return E_OK;
710     }
711     std::string tableName = DBConstant::KV_SYNC_TABLE_NAME;
712     errCode = SQLiteUtils::CheckTableExists(db, tableName, isCreate);
713     if (errCode != E_OK) {
714         LOGE("[SQLiteUtils] Check table exist failed %d when update cloud flag", errCode);
715         return errCode;
716     }
717     if (!isCreate) {
718         LOGW("[SQLiteUtils] non exist table when update cloud flag");
719         return E_OK;
720     }
721     std::string updateCloudFlagSql = "UPDATE " + logTableName + " SET cloud_flag = cloud_flag & ~0x400 "
722                                      "WHERE hash_key IN (SELECT hash_key FROM " + tableName +
723                                      " WHERE flag & 0x02 != 0);";
724     errCode = SQLiteUtils::ExecuteRawSQL(db, updateCloudFlagSql);
725     if (errCode != E_OK) {
726         LOGE("[SQLiteUtils] Update cloud flag failed %d", errCode);
727     }
728     return errCode;
729 }
730 } // namespace DistributedDB
731