• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "sqlite_relational_utils.h"
17 #include "db_common.h"
18 #include "db_errno.h"
19 #include "cloud/cloud_db_types.h"
20 #include "sqlite_utils.h"
21 #include "cloud/cloud_storage_utils.h"
22 #include "res_finalizer.h"
23 #include "runtime_context.h"
24 #include "cloud/cloud_db_constant.h"
25 
26 namespace DistributedDB {
GetDataValueByType(sqlite3_stmt * statement,int cid,DataValue & value)27 int SQLiteRelationalUtils::GetDataValueByType(sqlite3_stmt *statement, int cid, DataValue &value)
28 {
29     if (statement == nullptr || cid < 0 || cid >= sqlite3_column_count(statement)) {
30         return -E_INVALID_ARGS;
31     }
32 
33     int errCode = E_OK;
34     int storageType = sqlite3_column_type(statement, cid);
35     switch (storageType) {
36         case SQLITE_INTEGER:
37             value = static_cast<int64_t>(sqlite3_column_int64(statement, cid));
38             break;
39         case SQLITE_FLOAT:
40             value = sqlite3_column_double(statement, cid);
41             break;
42         case SQLITE_BLOB: {
43             std::vector<uint8_t> blobValue;
44             errCode = SQLiteUtils::GetColumnBlobValue(statement, cid, blobValue);
45             if (errCode != E_OK) {
46                 return errCode;
47             }
48             auto blob = new (std::nothrow) Blob;
49             if (blob == nullptr) {
50                 return -E_OUT_OF_MEMORY;
51             }
52             blob->WriteBlob(blobValue.data(), static_cast<uint32_t>(blobValue.size()));
53             errCode = value.Set(blob);
54             if (errCode != E_OK) {
55                 delete blob;
56                 blob = nullptr;
57             }
58             break;
59         }
60         case SQLITE_NULL:
61             break;
62         case SQLITE3_TEXT: {
63             std::string str;
64             (void)SQLiteUtils::GetColumnTextValue(statement, cid, str);
65             value = str;
66             if (value.GetType() != StorageType::STORAGE_TYPE_TEXT) {
67                 errCode = -E_OUT_OF_MEMORY;
68             }
69             break;
70         }
71         default:
72             break;
73     }
74     return errCode;
75 }
76 
GetSelectValues(sqlite3_stmt * stmt)77 std::vector<DataValue> SQLiteRelationalUtils::GetSelectValues(sqlite3_stmt *stmt)
78 {
79     std::vector<DataValue> values;
80     for (int cid = 0, colCount = sqlite3_column_count(stmt); cid < colCount; ++cid) {
81         DataValue value;
82         (void)GetDataValueByType(stmt, cid, value);
83         values.emplace_back(std::move(value));
84     }
85     return values;
86 }
87 
GetCloudValueByType(sqlite3_stmt * statement,int type,int cid,Type & cloudValue)88 int SQLiteRelationalUtils::GetCloudValueByType(sqlite3_stmt *statement, int type, int cid, Type &cloudValue)
89 {
90     if (statement == nullptr || cid < 0 || cid >= sqlite3_column_count(statement)) {
91         return -E_INVALID_ARGS;
92     }
93     switch (sqlite3_column_type(statement, cid)) {
94         case SQLITE_INTEGER: {
95             if (type == TYPE_INDEX<bool>) {
96                 cloudValue = static_cast<bool>(sqlite3_column_int(statement, cid));
97                 break;
98             }
99             cloudValue = static_cast<int64_t>(sqlite3_column_int64(statement, cid));
100             break;
101         }
102         case SQLITE_FLOAT: {
103             cloudValue = sqlite3_column_double(statement, cid);
104             break;
105         }
106         case SQLITE_BLOB: {
107             std::vector<uint8_t> blobValue;
108             int errCode = SQLiteUtils::GetColumnBlobValue(statement, cid, blobValue);
109             if (errCode != E_OK) {
110                 return errCode;
111             }
112             cloudValue = blobValue;
113             break;
114         }
115         case SQLITE3_TEXT: {
116             bool isBlob = (type == TYPE_INDEX<Bytes> || type == TYPE_INDEX<Asset> || type == TYPE_INDEX<Assets>);
117             if (isBlob) {
118                 std::vector<uint8_t> blobValue;
119                 int errCode = SQLiteUtils::GetColumnBlobValue(statement, cid, blobValue);
120                 if (errCode != E_OK) {
121                     return errCode;
122                 }
123                 cloudValue = blobValue;
124                 break;
125             }
126             std::string str;
127             (void)SQLiteUtils::GetColumnTextValue(statement, cid, str);
128             cloudValue = str;
129             break;
130         }
131         default: {
132             cloudValue = Nil();
133         }
134     }
135     return E_OK;
136 }
137 
CalCloudValueLen(Type & cloudValue,uint32_t & totalSize)138 void SQLiteRelationalUtils::CalCloudValueLen(Type &cloudValue, uint32_t &totalSize)
139 {
140     switch (cloudValue.index()) {
141         case TYPE_INDEX<int64_t>:
142             totalSize += sizeof(int64_t);
143             break;
144         case TYPE_INDEX<double>:
145             totalSize += sizeof(double);
146             break;
147         case TYPE_INDEX<std::string>:
148             totalSize += std::get<std::string>(cloudValue).size();
149             break;
150         case TYPE_INDEX<bool>:
151             totalSize += sizeof(int32_t);
152             break;
153         case TYPE_INDEX<Bytes>:
154         case TYPE_INDEX<Asset>:
155         case TYPE_INDEX<Assets>:
156             totalSize += std::get<Bytes>(cloudValue).size();
157             break;
158         default: {
159             break;
160         }
161     }
162 }
163 
BindStatementByType(sqlite3_stmt * statement,int cid,Type & typeVal)164 int SQLiteRelationalUtils::BindStatementByType(sqlite3_stmt *statement, int cid, Type &typeVal)
165 {
166     int errCode = E_OK;
167     switch (typeVal.index()) {
168         case TYPE_INDEX<int64_t>: {
169             int64_t value = 0;
170             (void)CloudStorageUtils::GetValueFromType(typeVal, value);
171             errCode = SQLiteUtils::BindInt64ToStatement(statement, cid, value);
172             break;
173         }
174         case TYPE_INDEX<bool>: {
175             bool value = false;
176             (void)CloudStorageUtils::GetValueFromType<bool>(typeVal, value);
177             errCode = SQLiteUtils::BindInt64ToStatement(statement, cid, value);
178             break;
179         }
180         case TYPE_INDEX<double>: {
181             double value = 0.0;
182             (void)CloudStorageUtils::GetValueFromType<double>(typeVal, value);
183             errCode = SQLiteUtils::MapSQLiteErrno(sqlite3_bind_double(statement, cid, value));
184             break;
185         }
186         case TYPE_INDEX<std::string>: {
187             std::string value;
188             (void)CloudStorageUtils::GetValueFromType<std::string>(typeVal, value);
189             errCode = SQLiteUtils::BindTextToStatement(statement, cid, value);
190             break;
191         }
192         default: {
193             errCode = BindExtendStatementByType(statement, cid, typeVal);
194             break;
195         }
196     }
197     return errCode;
198 }
199 
BindExtendStatementByType(sqlite3_stmt * statement,int cid,Type & typeVal)200 int SQLiteRelationalUtils::BindExtendStatementByType(sqlite3_stmt *statement, int cid, Type &typeVal)
201 {
202     int errCode = E_OK;
203     switch (typeVal.index()) {
204         case TYPE_INDEX<Bytes>: {
205             Bytes value;
206             (void)CloudStorageUtils::GetValueFromType<Bytes>(typeVal, value);
207             errCode = SQLiteUtils::BindBlobToStatement(statement, cid, value);
208             break;
209         }
210         case TYPE_INDEX<Asset>: {
211             Asset value;
212             (void)CloudStorageUtils::GetValueFromType<Asset>(typeVal, value);
213             Bytes val;
214             errCode = RuntimeContext::GetInstance()->AssetToBlob(value, val);
215             if (errCode != E_OK) {
216                 break;
217             }
218             errCode = SQLiteUtils::BindBlobToStatement(statement, cid, val);
219             break;
220         }
221         case TYPE_INDEX<Assets>: {
222             Assets value;
223             (void)CloudStorageUtils::GetValueFromType<Assets>(typeVal, value);
224             Bytes val;
225             errCode = RuntimeContext::GetInstance()->AssetsToBlob(value, val);
226             if (errCode != E_OK) {
227                 break;
228             }
229             errCode = SQLiteUtils::BindBlobToStatement(statement, cid, val);
230             break;
231         }
232         default: {
233             errCode = SQLiteUtils::MapSQLiteErrno(sqlite3_bind_null(statement, cid));
234             break;
235         }
236     }
237     return errCode;
238 }
239 
GetSelectVBucket(sqlite3_stmt * stmt,VBucket & bucket)240 int SQLiteRelationalUtils::GetSelectVBucket(sqlite3_stmt *stmt, VBucket &bucket)
241 {
242     if (stmt == nullptr) {
243         return -E_INVALID_ARGS;
244     }
245     for (int cid = 0, colCount = sqlite3_column_count(stmt); cid < colCount; ++cid) {
246         Type typeVal;
247         int errCode = GetTypeValByStatement(stmt, cid, typeVal);
248         if (errCode != E_OK) {
249             LOGE("get typeVal from stmt failed");
250             return errCode;
251         }
252         const char *colName = sqlite3_column_name(stmt, cid);
253         bucket.insert_or_assign(colName, std::move(typeVal));
254     }
255     return E_OK;
256 }
257 
GetDbFileName(sqlite3 * db,std::string & fileName)258 bool SQLiteRelationalUtils::GetDbFileName(sqlite3 *db, std::string &fileName)
259 {
260     if (db == nullptr) {
261         return false;
262     }
263 
264     auto dbFilePath = sqlite3_db_filename(db, nullptr);
265     if (dbFilePath == nullptr) {
266         return false;
267     }
268     fileName = std::string(dbFilePath);
269     return true;
270 }
271 
GetTypeValByStatement(sqlite3_stmt * stmt,int cid,Type & typeVal)272 int SQLiteRelationalUtils::GetTypeValByStatement(sqlite3_stmt *stmt, int cid, Type &typeVal)
273 {
274     if (stmt == nullptr || cid < 0 || cid >= sqlite3_column_count(stmt)) {
275         return -E_INVALID_ARGS;
276     }
277     int errCode = E_OK;
278     switch (sqlite3_column_type(stmt, cid)) {
279         case SQLITE_INTEGER: {
280             const char *declType = sqlite3_column_decltype(stmt, cid);
281             if (declType == nullptr) { // LCOV_EXCL_BR_LINE
282                 typeVal = static_cast<int64_t>(sqlite3_column_int64(stmt, cid));
283                 break;
284             }
285             if (strcasecmp(declType, SchemaConstant::KEYWORD_TYPE_BOOL.c_str()) == 0 ||
286                 strcasecmp(declType, SchemaConstant::KEYWORD_TYPE_BOOLEAN.c_str()) == 0) { // LCOV_EXCL_BR_LINE
287                 typeVal = static_cast<bool>(sqlite3_column_int(stmt, cid));
288                 break;
289             }
290             typeVal = static_cast<int64_t>(sqlite3_column_int64(stmt, cid));
291             break;
292         }
293         case SQLITE_FLOAT: {
294             typeVal = sqlite3_column_double(stmt, cid);
295             break;
296         }
297         case SQLITE_BLOB: {
298             errCode = GetBlobByStatement(stmt, cid, typeVal);
299             break;
300         }
301         case SQLITE3_TEXT: {
302             errCode = GetBlobByStatement(stmt, cid, typeVal);
303             if (errCode != E_OK || typeVal.index() != TYPE_INDEX<Nil>) { // LCOV_EXCL_BR_LINE
304                 break;
305             }
306             std::string str;
307             (void)SQLiteUtils::GetColumnTextValue(stmt, cid, str);
308             typeVal = str;
309             break;
310         }
311         default: {
312             typeVal = Nil();
313         }
314     }
315     return errCode;
316 }
317 
GetBlobByStatement(sqlite3_stmt * stmt,int cid,Type & typeVal)318 int SQLiteRelationalUtils::GetBlobByStatement(sqlite3_stmt *stmt, int cid, Type &typeVal)
319 {
320     const char *declType = sqlite3_column_decltype(stmt, cid);
321     int errCode = E_OK;
322     if (declType != nullptr && strcasecmp(declType, CloudDbConstant::ASSET) == 0) { // LCOV_EXCL_BR_LINE
323         std::vector<uint8_t> blobValue;
324         errCode = SQLiteUtils::GetColumnBlobValue(stmt, cid, blobValue);
325         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
326             return errCode;
327         }
328         Asset asset;
329         errCode = RuntimeContext::GetInstance()->BlobToAsset(blobValue, asset);
330         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
331             return errCode;
332         }
333         typeVal = asset;
334     } else if (declType != nullptr && strcasecmp(declType, CloudDbConstant::ASSETS) == 0) {
335         std::vector<uint8_t> blobValue;
336         errCode = SQLiteUtils::GetColumnBlobValue(stmt, cid, blobValue);
337         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
338             return errCode;
339         }
340         Assets assets;
341         errCode = RuntimeContext::GetInstance()->BlobToAssets(blobValue, assets);
342         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
343             return errCode;
344         }
345         typeVal = assets;
346     } else if (sqlite3_column_type(stmt, cid) == SQLITE_BLOB) {
347         std::vector<uint8_t> blobValue;
348         errCode = SQLiteUtils::GetColumnBlobValue(stmt, cid, blobValue);
349         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
350             return errCode;
351         }
352         typeVal = blobValue;
353     }
354     return E_OK;
355 }
356 
SelectServerObserver(sqlite3 * db,const std::string & tableName,bool isChanged)357 int SQLiteRelationalUtils::SelectServerObserver(sqlite3 *db, const std::string &tableName, bool isChanged)
358 {
359     if (db == nullptr || tableName.empty()) {
360         return -E_INVALID_ARGS;
361     }
362     std::string sql;
363     if (isChanged) {
364         sql = "SELECT server_observer('" + tableName + "', 1);";
365     } else {
366         sql = "SELECT server_observer('" + tableName + "', 0);";
367     }
368     sqlite3_stmt *stmt = nullptr;
369     int errCode = SQLiteUtils::GetStatement(db, sql, stmt);
370     if (errCode != E_OK) {
371         LOGE("get select server observer stmt failed. %d", errCode);
372         return errCode;
373     }
374     errCode = SQLiteUtils::StepWithRetry(stmt, false);
375     int ret = E_OK;
376     SQLiteUtils::ResetStatement(stmt, true, ret);
377     if (errCode != SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
378         LOGE("select server observer failed. %d", errCode);
379         return SQLiteUtils::MapSQLiteErrno(errCode);
380     }
381     return ret == E_OK ? E_OK : ret;
382 }
383 
AddUpgradeSqlToList(const TableInfo & tableInfo,const std::vector<std::pair<std::string,std::string>> & fieldList,std::vector<std::string> & sqlList)384 void SQLiteRelationalUtils::AddUpgradeSqlToList(const TableInfo &tableInfo,
385     const std::vector<std::pair<std::string, std::string>> &fieldList, std::vector<std::string> &sqlList)
386 {
387     for (const auto &[colName, colType] : fieldList) {
388         auto it = tableInfo.GetFields().find(colName);
389         if (it != tableInfo.GetFields().end()) {
390             continue;
391         }
392         sqlList.push_back("alter table " + tableInfo.GetTableName() + " add " + colName +
393             " " + colType + ";");
394     }
395 }
396 
AnalysisTrackerTable(sqlite3 * db,const TrackerTable & trackerTable,TableInfo & tableInfo)397 int SQLiteRelationalUtils::AnalysisTrackerTable(sqlite3 *db, const TrackerTable &trackerTable, TableInfo &tableInfo)
398 {
399     int errCode = SQLiteUtils::AnalysisSchema(db, trackerTable.GetTableName(), tableInfo, true);
400     if (errCode != E_OK) {
401         LOGE("analysis table schema failed %d.", errCode);
402         return errCode;
403     }
404     tableInfo.SetTrackerTable(trackerTable);
405     errCode = tableInfo.CheckTrackerTable();
406     if (errCode != E_OK) {
407         LOGE("check tracker table schema failed %d.", errCode);
408     }
409     return errCode;
410 }
411 
QueryCount(sqlite3 * db,const std::string & tableName,int64_t & count)412 int SQLiteRelationalUtils::QueryCount(sqlite3 *db, const std::string &tableName, int64_t &count)
413 {
414     std::string sql = "SELECT COUNT(1) FROM " + tableName ;
415     sqlite3_stmt *stmt = nullptr;
416     int errCode = SQLiteUtils::GetStatement(db, sql, stmt);
417     if (errCode != E_OK) {
418         LOGE("Query count failed. %d", errCode);
419         return errCode;
420     }
421     errCode = SQLiteUtils::StepWithRetry(stmt, false);
422     if (errCode == SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
423         count = static_cast<int64_t>(sqlite3_column_int64(stmt, 0));
424         errCode = E_OK;
425     } else {
426         LOGE("Failed to get the count. %d", errCode);
427     }
428     SQLiteUtils::ResetStatement(stmt, true, errCode);
429     return errCode;
430 }
431 
GetCursor(sqlite3 * db,const std::string & tableName,uint64_t & cursor)432 int SQLiteRelationalUtils::GetCursor(sqlite3 *db, const std::string &tableName, uint64_t &cursor)
433 {
434     cursor = DBConstant::INVALID_CURSOR;
435     std::string sql = "SELECT value FROM " + std::string(DBConstant::RELATIONAL_PREFIX) + "metadata where key = ?;";
436     sqlite3_stmt *stmt = nullptr;
437     int errCode = SQLiteUtils::GetStatement(db, sql, stmt);
438     if (errCode != E_OK) {
439         LOGE("[Storage Executor] Get cursor of table[%s length[%u]] failed=%d",
440             DBCommon::StringMiddleMasking(tableName).c_str(), tableName.length(), errCode);
441         return errCode;
442     }
443     ResFinalizer finalizer([stmt]() {
444         sqlite3_stmt *statement = stmt;
445         int ret = E_OK;
446         SQLiteUtils::ResetStatement(statement, true, ret);
447         if (ret != E_OK) {
448             LOGW("Reset stmt failed %d when get cursor", ret);
449         }
450     });
451     Key key;
452     DBCommon::StringToVector(DBCommon::GetCursorKey(tableName), key);
453     errCode = SQLiteUtils::BindBlobToStatement(stmt, 1, key, false); // first arg.
454     if (errCode != E_OK) {
455         LOGE("[Storage Executor] Bind failed when get cursor of table[%s length[%u]] failed=%d",
456             DBCommon::StringMiddleMasking(tableName).c_str(), tableName.length(), errCode);
457         return errCode;
458     }
459     errCode = SQLiteUtils::StepWithRetry(stmt, false);
460     if (errCode == SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
461         int64_t tmpCursor = static_cast<int64_t>(sqlite3_column_int64(stmt, 0));
462         if (tmpCursor >= 0) {
463             cursor = static_cast<uint64_t>(tmpCursor);
464         }
465     }
466     return cursor == DBConstant::INVALID_CURSOR ? errCode : E_OK;
467 }
468 
GetFieldsNeedContain(const TableInfo & tableInfo,const std::vector<FieldInfo> & syncFields,std::set<std::string> & fieldsNeedContain,std::set<std::string> & fieldsNotDecrease,std::set<std::string> & requiredNotNullFields)469 void GetFieldsNeedContain(const TableInfo &tableInfo, const std::vector<FieldInfo> &syncFields,
470     std::set<std::string> &fieldsNeedContain, std::set<std::string> &fieldsNotDecrease,
471     std::set<std::string> &requiredNotNullFields)
472 {
473     // should not decrease distributed field
474     for (const auto &field : tableInfo.GetSyncField()) {
475         fieldsNotDecrease.insert(field);
476     }
477     const std::vector<CompositeFields> &uniqueDefines = tableInfo.GetUniqueDefine();
478     for (const auto &compositeFields : uniqueDefines) {
479         for (const auto &fieldName : compositeFields) {
480             if (tableInfo.IsPrimaryKey(fieldName)) {
481                 continue;
482             }
483             fieldsNeedContain.insert(fieldName);
484         }
485     }
486     const FieldInfoMap &fieldInfoMap = tableInfo.GetFields();
487     for (const auto &entry : fieldInfoMap) {
488         const FieldInfo &fieldInfo = entry.second;
489         if (fieldInfo.IsNotNull() && fieldInfo.GetDefaultValue().empty()) {
490             requiredNotNullFields.insert(fieldInfo.GetFieldName());
491         }
492     }
493 }
494 
CheckRequireFieldsInMap(const std::set<std::string> & fieldsNeedContain,const std::map<std::string,bool> & fieldsMap)495 bool CheckRequireFieldsInMap(const std::set<std::string> &fieldsNeedContain,
496     const std::map<std::string, bool> &fieldsMap)
497 {
498     for (auto &fieldNeedContain : fieldsNeedContain) {
499         if (fieldsMap.find(fieldNeedContain) == fieldsMap.end()) {
500             LOGE("Required column[%s [%zu]] not found", DBCommon::StringMiddleMasking(fieldNeedContain).c_str(),
501                 fieldNeedContain.size());
502             return false;
503         }
504         if (!fieldsMap.at(fieldNeedContain)) {
505             LOGE("The isP2pSync of required column[%s [%zu]] is false",
506                 DBCommon::StringMiddleMasking(fieldNeedContain).c_str(), fieldNeedContain.size());
507             return false;
508         }
509     }
510     return true;
511 }
512 
IsMarkUniqueColumnInvalid(const TableInfo & tableInfo,const std::vector<DistributedField> & originFields)513 bool IsMarkUniqueColumnInvalid(const TableInfo &tableInfo, const std::vector<DistributedField> &originFields)
514 {
515     int count = 0;
516     for (const auto &field : originFields) {
517         if (field.isSpecified && tableInfo.IsUniqueField(field.colName)) {
518             count++;
519             if (count > 1) {
520                 return true;
521             }
522         }
523     }
524     return false;
525 }
526 
IsDistributedPkInvalid(const TableInfo & tableInfo,const std::set<std::string,CaseInsensitiveComparator> & distributedPk,const std::vector<DistributedField> & originFields,bool isForceUpgrade)527 bool IsDistributedPkInvalid(const TableInfo &tableInfo,
528     const std::set<std::string, CaseInsensitiveComparator> &distributedPk,
529     const std::vector<DistributedField> &originFields, bool isForceUpgrade)
530 {
531     auto lastDistributedPk = DBCommon::TransformToCaseInsensitive(tableInfo.GetSyncDistributedPk());
532     if (!isForceUpgrade && !lastDistributedPk.empty() && distributedPk != lastDistributedPk) {
533         LOGE("distributed pk has change last %zu now %zu", lastDistributedPk.size(), distributedPk.size());
534         return true;
535     }
536     // check pk is same or local pk is auto increase and set pk is unique
537     if (tableInfo.IsNoPkTable()) {
538         return false;
539     }
540     if (!distributedPk.empty() && distributedPk.size() != tableInfo.GetPrimaryKey().size()) {
541         return true;
542     }
543     if (tableInfo.GetAutoIncrement()) {
544         if (distributedPk.empty()) {
545             return false;
546         }
547         auto uniqueAndPkDefine = tableInfo.GetUniqueAndPkDefine();
548         if (IsMarkUniqueColumnInvalid(tableInfo, originFields)) {
549             LOGE("Mark more than one unique column specified in auto increment table: %s, tableName len: %zu",
550                 DBCommon::StringMiddleMasking(tableInfo.GetTableName()).c_str(), tableInfo.GetTableName().length());
551             return true;
552         }
553         auto find = std::any_of(uniqueAndPkDefine.begin(), uniqueAndPkDefine.end(), [&distributedPk](const auto &item) {
554             // unique index field count should be same
555             return item.size() == distributedPk.size() && distributedPk == DBCommon::TransformToCaseInsensitive(item);
556         });
557         bool isMissMatch = !find;
558         if (isMissMatch) {
559             LOGE("Miss match distributed pk size %zu in auto increment table %s", distributedPk.size(),
560                 DBCommon::StringMiddleMasking(tableInfo.GetTableName()).c_str());
561         }
562         return isMissMatch;
563     }
564     for (const auto &field : originFields) {
565         bool isLocalPk = tableInfo.IsPrimaryKey(field.colName);
566         if (field.isSpecified && !isLocalPk) {
567             LOGE("Column[%s [%zu]] is not primary key but mark specified",
568                 DBCommon::StringMiddleMasking(field.colName).c_str(), field.colName.size());
569             return true;
570         }
571         if (isLocalPk && !field.isP2pSync) {
572             LOGE("Column[%s [%zu]] is primary key but set isP2pSync false",
573                 DBCommon::StringMiddleMasking(field.colName).c_str(), field.colName.size());
574             return true;
575         }
576     }
577     return false;
578 }
579 
IsDistributedSchemaSupport(const TableInfo & tableInfo,const std::vector<DistributedField> & fields)580 bool IsDistributedSchemaSupport(const TableInfo &tableInfo, const std::vector<DistributedField> &fields)
581 {
582     if (!tableInfo.GetAutoIncrement()) {
583         return true;
584     }
585     bool isSyncPk = false;
586     bool isSyncOtherSpecified = false;
587     for (const auto &item : fields) {
588         if (tableInfo.IsPrimaryKey(item.colName) && item.isP2pSync) {
589             isSyncPk = true;
590         } else if (item.isSpecified && item.isP2pSync) {
591             isSyncOtherSpecified = true;
592         }
593     }
594     if (isSyncPk && isSyncOtherSpecified) {
595         LOGE("Not support sync with auto increment pk and other specified col");
596         return false;
597     }
598     return true;
599 }
600 
CheckDistributedSchemaFields(const TableInfo & tableInfo,const std::vector<FieldInfo> & syncFields,const std::vector<DistributedField> & fields,bool isForceUpgrade)601 int CheckDistributedSchemaFields(const TableInfo &tableInfo, const std::vector<FieldInfo> &syncFields,
602     const std::vector<DistributedField> &fields, bool isForceUpgrade)
603 {
604     if (fields.empty()) {
605         LOGE("fields cannot be empty");
606         return -E_SCHEMA_MISMATCH;
607     }
608     if (!IsDistributedSchemaSupport(tableInfo, fields)) {
609         return -E_NOT_SUPPORT;
610     }
611     std::set<std::string, CaseInsensitiveComparator> distributedPk;
612     bool isNoPrimaryKeyTable = tableInfo.IsNoPkTable();
613     for (const auto &field : fields) {
614         if (!tableInfo.IsFieldExist(field.colName)) {
615             LOGE("Column[%s [%zu]] not found in table", DBCommon::StringMiddleMasking(field.colName).c_str(),
616                  field.colName.size());
617             return -E_SCHEMA_MISMATCH;
618         }
619         if (isNoPrimaryKeyTable && field.isSpecified) {
620             return -E_SCHEMA_MISMATCH;
621         }
622         if (field.isSpecified && field.isP2pSync) {
623             distributedPk.insert(field.colName);
624         }
625     }
626     if (IsDistributedPkInvalid(tableInfo, distributedPk, fields, isForceUpgrade)) {
627         return -E_SCHEMA_MISMATCH;
628     }
629     std::set<std::string> fieldsNeedContain;
630     std::set<std::string> fieldsNotDecrease;
631     std::set<std::string> requiredNotNullFields;
632     GetFieldsNeedContain(tableInfo, syncFields, fieldsNeedContain, fieldsNotDecrease, requiredNotNullFields);
633     std::map<std::string, bool> fieldsMap;
634     for (auto &field : fields) {
635         fieldsMap.insert({field.colName, field.isP2pSync});
636     }
637     if (!CheckRequireFieldsInMap(fieldsNeedContain, fieldsMap)) {
638         LOGE("The required fields are not found in fieldsMap");
639         return -E_SCHEMA_MISMATCH;
640     }
641     if (!isForceUpgrade && !CheckRequireFieldsInMap(fieldsNotDecrease, fieldsMap)) {
642         LOGE("The fields should not decrease");
643         return -E_DISTRIBUTED_FIELD_DECREASE;
644     }
645     if (!CheckRequireFieldsInMap(requiredNotNullFields, fieldsMap)) {
646         LOGE("The required not-null fields are not found in fieldsMap");
647         return -E_SCHEMA_MISMATCH;
648     }
649     return E_OK;
650 }
651 
CheckDistributedSchemaValid(const RelationalSchemaObject & schemaObj,const DistributedSchema & schema,bool isForceUpgrade,SQLiteSingleVerRelationalStorageExecutor * executor)652 int SQLiteRelationalUtils::CheckDistributedSchemaValid(const RelationalSchemaObject &schemaObj,
653     const DistributedSchema &schema, bool isForceUpgrade, SQLiteSingleVerRelationalStorageExecutor *executor)
654 {
655     if (executor == nullptr) {
656         LOGE("[RDBUtils][CheckDistributedSchemaValid] executor is null");
657         return -E_INVALID_ARGS;
658     }
659     sqlite3 *db;
660     int errCode = executor->GetDbHandle(db);
661     if (errCode != E_OK) {
662         LOGE("[RDBUtils][CheckDistributedSchemaValid] sqlite handle failed %d", errCode);
663         return errCode;
664     }
665     for (const auto &table : schema.tables) {
666         if (table.tableName.empty()) {
667             LOGE("[RDBUtils][CheckDistributedSchemaValid] Table name cannot be empty");
668             return -E_SCHEMA_MISMATCH;
669         }
670         TableInfo tableInfo;
671         errCode = SQLiteUtils::AnalysisSchema(db, table.tableName, tableInfo);
672         if (errCode != E_OK) {
673             LOGE("[RDBUtils][CheckDistributedSchemaValid] analyze table %s failed %d",
674                 DBCommon::StringMiddleMasking(table.tableName).c_str(), errCode);
675             return errCode == -E_NOT_FOUND ? -E_SCHEMA_MISMATCH : errCode;
676         }
677         tableInfo.SetDistributedTable(schemaObj.GetDistributedTable(table.tableName));
678         errCode = CheckDistributedSchemaFields(tableInfo, schemaObj.GetSyncFieldInfo(table.tableName, false),
679             table.fields, isForceUpgrade);
680         if (errCode != E_OK) {
681             LOGE("[CheckDistributedSchema] Check fields of [%s [%zu]] fail",
682                 DBCommon::StringMiddleMasking(table.tableName).c_str(), table.tableName.size());
683             return errCode;
684         }
685     }
686     return E_OK;
687 }
688 
/**
 * Returns a copy of `schema` in which each table name appears once. The last definition of a
 * name wins, and the surviving tables keep the relative order of those last definitions.
 * Each surviving table is itself de-duplicated by column name.
 */
DistributedSchema SQLiteRelationalUtils::FilterRepeatDefine(const DistributedSchema &schema)
{
    DistributedSchema res;
    res.version = schema.version;
    std::set<std::string> seenNames;
    std::vector<DistributedTable> keptReversed;
    // Scan back to front so the first time a name is met is its last definition.
    for (auto it = schema.tables.rbegin(); it != schema.tables.rend(); ++it) {
        if (!seenNames.insert(it->tableName).second) {
            continue;
        }
        keptReversed.push_back(FilterRepeatDefine(*it));
    }
    // Emit the survivors in their original (front to back) order.
    for (auto it = keptReversed.rbegin(); it != keptReversed.rend(); ++it) {
        res.tables.push_back(std::move(*it));
    }
    return res;
}
707 
/**
 * Returns a copy of `table` in which each column name appears once. The last definition of a
 * column wins, and the surviving fields keep the relative order of those last definitions.
 */
DistributedTable SQLiteRelationalUtils::FilterRepeatDefine(const DistributedTable &table)
{
    DistributedTable res;
    res.tableName = table.tableName;
    std::set<std::string> seenColumns;
    std::vector<DistributedField> keptReversed;
    // Scan back to front so the first time a column is met is its last definition.
    for (auto it = table.fields.rbegin(); it != table.fields.rend(); ++it) {
        if (!seenColumns.insert(it->colName).second) {
            continue;
        }
        keptReversed.push_back(*it);
    }
    // Emit the survivors in their original (front to back) order.
    for (auto it = keptReversed.rbegin(); it != keptReversed.rend(); ++it) {
        res.fields.push_back(std::move(*it));
    }
    return res;
}
726 
GetLogData(sqlite3_stmt * logStatement,LogInfo & logInfo)727 int SQLiteRelationalUtils::GetLogData(sqlite3_stmt *logStatement, LogInfo &logInfo)
728 {
729     logInfo.dataKey = sqlite3_column_int64(logStatement, 0);  // 0 means dataKey index
730 
731     std::vector<uint8_t> dev;
732     int errCode = SQLiteUtils::GetColumnBlobValue(logStatement, 1, dev);  // 1 means dev index
733     if (errCode != E_OK) {
734         LOGE("[SQLiteRDBUtils] Get dev failed %d", errCode);
735         return errCode;
736     }
737     logInfo.device = std::string(dev.begin(), dev.end());
738 
739     std::vector<uint8_t> oriDev;
740     errCode = SQLiteUtils::GetColumnBlobValue(logStatement, 2, oriDev);  // 2 means ori_dev index
741     if (errCode != E_OK) {
742         LOGE("[SQLiteRDBUtils] Get ori dev failed %d", errCode);
743         return errCode;
744     }
745     logInfo.originDev = std::string(oriDev.begin(), oriDev.end());
746     logInfo.timestamp = static_cast<uint64_t>(sqlite3_column_int64(logStatement, 3));  // 3 means timestamp index
747     logInfo.wTimestamp = static_cast<uint64_t>(sqlite3_column_int64(logStatement, 4));  // 4 means w_timestamp index
748     logInfo.flag = static_cast<uint64_t>(sqlite3_column_int64(logStatement, 5));  // 5 means flag index
749     logInfo.flag &= (~DataItem::LOCAL_FLAG);
750     logInfo.flag &= (~DataItem::UPDATE_FLAG);
751     errCode = SQLiteUtils::GetColumnBlobValue(logStatement, 6, logInfo.hashKey);  // 6 means hashKey index
752     if (errCode != E_OK) {
753         LOGE("[SQLiteRDBUtils] Get hashKey failed %d", errCode);
754     }
755     return errCode;
756 }
757 
GetLogInfoPre(sqlite3_stmt * queryStmt,DistributedTableMode mode,const DataItem & dataItem,LogInfo & logInfoGet)758 int SQLiteRelationalUtils::GetLogInfoPre(sqlite3_stmt *queryStmt, DistributedTableMode mode,
759     const DataItem &dataItem, LogInfo &logInfoGet)
760 {
761     if (queryStmt == nullptr) {
762         return -E_INVALID_ARGS;
763     }
764     int errCode = SQLiteUtils::BindBlobToStatement(queryStmt, 1, dataItem.hashKey);  // 1 means hash key index.
765     if (errCode != E_OK) {
766         LOGE("[SQLiteRDBUtils] Bind hashKey failed %d", errCode);
767         return errCode;
768     }
769     if (mode != DistributedTableMode::COLLABORATION) {
770         errCode = SQLiteUtils::BindTextToStatement(queryStmt, 2, dataItem.dev);  // 2 means device index.
771         if (errCode != E_OK) {
772             LOGE("[SQLiteRDBUtils] Bind dev failed %d", errCode);
773             return errCode;
774         }
775     }
776 
777     errCode = SQLiteUtils::StepWithRetry(queryStmt, false); // rdb not exist mem db
778     if (errCode != SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
779         errCode = -E_NOT_FOUND;
780     } else {
781         errCode = SQLiteRelationalUtils::GetLogData(queryStmt, logInfoGet);
782     }
783     return errCode;
784 }
785 } // namespace DistributedDB