1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "sqlite_relational_utils.h"
17 #include "db_common.h"
18 #include "db_errno.h"
19 #include "cloud/cloud_db_types.h"
20 #include "sqlite_utils.h"
21 #include "cloud/cloud_storage_utils.h"
22 #include "res_finalizer.h"
23 #include "runtime_context.h"
24 #include "cloud/cloud_db_constant.h"
25
26 namespace DistributedDB {
GetDataValueByType(sqlite3_stmt * statement,int cid,DataValue & value)27 int SQLiteRelationalUtils::GetDataValueByType(sqlite3_stmt *statement, int cid, DataValue &value)
28 {
29 if (statement == nullptr || cid < 0 || cid >= sqlite3_column_count(statement)) {
30 return -E_INVALID_ARGS;
31 }
32
33 int errCode = E_OK;
34 int storageType = sqlite3_column_type(statement, cid);
35 switch (storageType) {
36 case SQLITE_INTEGER:
37 value = static_cast<int64_t>(sqlite3_column_int64(statement, cid));
38 break;
39 case SQLITE_FLOAT:
40 value = sqlite3_column_double(statement, cid);
41 break;
42 case SQLITE_BLOB: {
43 std::vector<uint8_t> blobValue;
44 errCode = SQLiteUtils::GetColumnBlobValue(statement, cid, blobValue);
45 if (errCode != E_OK) {
46 return errCode;
47 }
48 auto blob = new (std::nothrow) Blob;
49 if (blob == nullptr) {
50 return -E_OUT_OF_MEMORY;
51 }
52 blob->WriteBlob(blobValue.data(), static_cast<uint32_t>(blobValue.size()));
53 errCode = value.Set(blob);
54 if (errCode != E_OK) {
55 delete blob;
56 blob = nullptr;
57 }
58 break;
59 }
60 case SQLITE_NULL:
61 break;
62 case SQLITE3_TEXT: {
63 std::string str;
64 (void)SQLiteUtils::GetColumnTextValue(statement, cid, str);
65 value = str;
66 if (value.GetType() != StorageType::STORAGE_TYPE_TEXT) {
67 errCode = -E_OUT_OF_MEMORY;
68 }
69 break;
70 }
71 default:
72 break;
73 }
74 return errCode;
75 }
76
GetSelectValues(sqlite3_stmt * stmt)77 std::vector<DataValue> SQLiteRelationalUtils::GetSelectValues(sqlite3_stmt *stmt)
78 {
79 std::vector<DataValue> values;
80 for (int cid = 0, colCount = sqlite3_column_count(stmt); cid < colCount; ++cid) {
81 DataValue value;
82 (void)GetDataValueByType(stmt, cid, value);
83 values.emplace_back(std::move(value));
84 }
85 return values;
86 }
87
GetCloudValueByType(sqlite3_stmt * statement,int type,int cid,Type & cloudValue)88 int SQLiteRelationalUtils::GetCloudValueByType(sqlite3_stmt *statement, int type, int cid, Type &cloudValue)
89 {
90 if (statement == nullptr || cid < 0 || cid >= sqlite3_column_count(statement)) {
91 return -E_INVALID_ARGS;
92 }
93 switch (sqlite3_column_type(statement, cid)) {
94 case SQLITE_INTEGER: {
95 if (type == TYPE_INDEX<bool>) {
96 cloudValue = static_cast<bool>(sqlite3_column_int(statement, cid));
97 break;
98 }
99 cloudValue = static_cast<int64_t>(sqlite3_column_int64(statement, cid));
100 break;
101 }
102 case SQLITE_FLOAT: {
103 cloudValue = sqlite3_column_double(statement, cid);
104 break;
105 }
106 case SQLITE_BLOB: {
107 std::vector<uint8_t> blobValue;
108 int errCode = SQLiteUtils::GetColumnBlobValue(statement, cid, blobValue);
109 if (errCode != E_OK) {
110 return errCode;
111 }
112 cloudValue = blobValue;
113 break;
114 }
115 case SQLITE3_TEXT: {
116 bool isBlob = (type == TYPE_INDEX<Bytes> || type == TYPE_INDEX<Asset> || type == TYPE_INDEX<Assets>);
117 if (isBlob) {
118 std::vector<uint8_t> blobValue;
119 int errCode = SQLiteUtils::GetColumnBlobValue(statement, cid, blobValue);
120 if (errCode != E_OK) {
121 return errCode;
122 }
123 cloudValue = blobValue;
124 break;
125 }
126 std::string str;
127 (void)SQLiteUtils::GetColumnTextValue(statement, cid, str);
128 cloudValue = str;
129 break;
130 }
131 default: {
132 cloudValue = Nil();
133 }
134 }
135 return E_OK;
136 }
137
CalCloudValueLen(Type & cloudValue,uint32_t & totalSize)138 void SQLiteRelationalUtils::CalCloudValueLen(Type &cloudValue, uint32_t &totalSize)
139 {
140 switch (cloudValue.index()) {
141 case TYPE_INDEX<int64_t>:
142 totalSize += sizeof(int64_t);
143 break;
144 case TYPE_INDEX<double>:
145 totalSize += sizeof(double);
146 break;
147 case TYPE_INDEX<std::string>:
148 totalSize += std::get<std::string>(cloudValue).size();
149 break;
150 case TYPE_INDEX<bool>:
151 totalSize += sizeof(int32_t);
152 break;
153 case TYPE_INDEX<Bytes>:
154 case TYPE_INDEX<Asset>:
155 case TYPE_INDEX<Assets>:
156 totalSize += std::get<Bytes>(cloudValue).size();
157 break;
158 default: {
159 break;
160 }
161 }
162 }
163
BindStatementByType(sqlite3_stmt * statement,int cid,Type & typeVal)164 int SQLiteRelationalUtils::BindStatementByType(sqlite3_stmt *statement, int cid, Type &typeVal)
165 {
166 int errCode = E_OK;
167 switch (typeVal.index()) {
168 case TYPE_INDEX<int64_t>: {
169 int64_t value = 0;
170 (void)CloudStorageUtils::GetValueFromType(typeVal, value);
171 errCode = SQLiteUtils::BindInt64ToStatement(statement, cid, value);
172 break;
173 }
174 case TYPE_INDEX<bool>: {
175 bool value = false;
176 (void)CloudStorageUtils::GetValueFromType<bool>(typeVal, value);
177 errCode = SQLiteUtils::BindInt64ToStatement(statement, cid, value);
178 break;
179 }
180 case TYPE_INDEX<double>: {
181 double value = 0.0;
182 (void)CloudStorageUtils::GetValueFromType<double>(typeVal, value);
183 errCode = SQLiteUtils::MapSQLiteErrno(sqlite3_bind_double(statement, cid, value));
184 break;
185 }
186 case TYPE_INDEX<std::string>: {
187 std::string value;
188 (void)CloudStorageUtils::GetValueFromType<std::string>(typeVal, value);
189 errCode = SQLiteUtils::BindTextToStatement(statement, cid, value);
190 break;
191 }
192 default: {
193 errCode = BindExtendStatementByType(statement, cid, typeVal);
194 break;
195 }
196 }
197 return errCode;
198 }
199
BindExtendStatementByType(sqlite3_stmt * statement,int cid,Type & typeVal)200 int SQLiteRelationalUtils::BindExtendStatementByType(sqlite3_stmt *statement, int cid, Type &typeVal)
201 {
202 int errCode = E_OK;
203 switch (typeVal.index()) {
204 case TYPE_INDEX<Bytes>: {
205 Bytes value;
206 (void)CloudStorageUtils::GetValueFromType<Bytes>(typeVal, value);
207 errCode = SQLiteUtils::BindBlobToStatement(statement, cid, value);
208 break;
209 }
210 case TYPE_INDEX<Asset>: {
211 Asset value;
212 (void)CloudStorageUtils::GetValueFromType<Asset>(typeVal, value);
213 Bytes val;
214 errCode = RuntimeContext::GetInstance()->AssetToBlob(value, val);
215 if (errCode != E_OK) {
216 break;
217 }
218 errCode = SQLiteUtils::BindBlobToStatement(statement, cid, val);
219 break;
220 }
221 case TYPE_INDEX<Assets>: {
222 Assets value;
223 (void)CloudStorageUtils::GetValueFromType<Assets>(typeVal, value);
224 Bytes val;
225 errCode = RuntimeContext::GetInstance()->AssetsToBlob(value, val);
226 if (errCode != E_OK) {
227 break;
228 }
229 errCode = SQLiteUtils::BindBlobToStatement(statement, cid, val);
230 break;
231 }
232 default: {
233 errCode = SQLiteUtils::MapSQLiteErrno(sqlite3_bind_null(statement, cid));
234 break;
235 }
236 }
237 return errCode;
238 }
239
GetSelectVBucket(sqlite3_stmt * stmt,VBucket & bucket)240 int SQLiteRelationalUtils::GetSelectVBucket(sqlite3_stmt *stmt, VBucket &bucket)
241 {
242 if (stmt == nullptr) {
243 return -E_INVALID_ARGS;
244 }
245 for (int cid = 0, colCount = sqlite3_column_count(stmt); cid < colCount; ++cid) {
246 Type typeVal;
247 int errCode = GetTypeValByStatement(stmt, cid, typeVal);
248 if (errCode != E_OK) {
249 LOGE("get typeVal from stmt failed");
250 return errCode;
251 }
252 const char *colName = sqlite3_column_name(stmt, cid);
253 bucket.insert_or_assign(colName, std::move(typeVal));
254 }
255 return E_OK;
256 }
257
GetDbFileName(sqlite3 * db,std::string & fileName)258 bool SQLiteRelationalUtils::GetDbFileName(sqlite3 *db, std::string &fileName)
259 {
260 if (db == nullptr) {
261 return false;
262 }
263
264 auto dbFilePath = sqlite3_db_filename(db, nullptr);
265 if (dbFilePath == nullptr) {
266 return false;
267 }
268 fileName = std::string(dbFilePath);
269 return true;
270 }
271
GetTypeValByStatement(sqlite3_stmt * stmt,int cid,Type & typeVal)272 int SQLiteRelationalUtils::GetTypeValByStatement(sqlite3_stmt *stmt, int cid, Type &typeVal)
273 {
274 if (stmt == nullptr || cid < 0 || cid >= sqlite3_column_count(stmt)) {
275 return -E_INVALID_ARGS;
276 }
277 int errCode = E_OK;
278 switch (sqlite3_column_type(stmt, cid)) {
279 case SQLITE_INTEGER: {
280 const char *declType = sqlite3_column_decltype(stmt, cid);
281 if (declType == nullptr) { // LCOV_EXCL_BR_LINE
282 typeVal = static_cast<int64_t>(sqlite3_column_int64(stmt, cid));
283 break;
284 }
285 if (strcasecmp(declType, SchemaConstant::KEYWORD_TYPE_BOOL.c_str()) == 0 ||
286 strcasecmp(declType, SchemaConstant::KEYWORD_TYPE_BOOLEAN.c_str()) == 0) { // LCOV_EXCL_BR_LINE
287 typeVal = static_cast<bool>(sqlite3_column_int(stmt, cid));
288 break;
289 }
290 typeVal = static_cast<int64_t>(sqlite3_column_int64(stmt, cid));
291 break;
292 }
293 case SQLITE_FLOAT: {
294 typeVal = sqlite3_column_double(stmt, cid);
295 break;
296 }
297 case SQLITE_BLOB: {
298 errCode = GetBlobByStatement(stmt, cid, typeVal);
299 break;
300 }
301 case SQLITE3_TEXT: {
302 errCode = GetBlobByStatement(stmt, cid, typeVal);
303 if (errCode != E_OK || typeVal.index() != TYPE_INDEX<Nil>) { // LCOV_EXCL_BR_LINE
304 break;
305 }
306 std::string str;
307 (void)SQLiteUtils::GetColumnTextValue(stmt, cid, str);
308 typeVal = str;
309 break;
310 }
311 default: {
312 typeVal = Nil();
313 }
314 }
315 return errCode;
316 }
317
GetBlobByStatement(sqlite3_stmt * stmt,int cid,Type & typeVal)318 int SQLiteRelationalUtils::GetBlobByStatement(sqlite3_stmt *stmt, int cid, Type &typeVal)
319 {
320 const char *declType = sqlite3_column_decltype(stmt, cid);
321 int errCode = E_OK;
322 if (declType != nullptr && strcasecmp(declType, CloudDbConstant::ASSET) == 0) { // LCOV_EXCL_BR_LINE
323 std::vector<uint8_t> blobValue;
324 errCode = SQLiteUtils::GetColumnBlobValue(stmt, cid, blobValue);
325 if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
326 return errCode;
327 }
328 Asset asset;
329 errCode = RuntimeContext::GetInstance()->BlobToAsset(blobValue, asset);
330 if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
331 return errCode;
332 }
333 typeVal = asset;
334 } else if (declType != nullptr && strcasecmp(declType, CloudDbConstant::ASSETS) == 0) {
335 std::vector<uint8_t> blobValue;
336 errCode = SQLiteUtils::GetColumnBlobValue(stmt, cid, blobValue);
337 if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
338 return errCode;
339 }
340 Assets assets;
341 errCode = RuntimeContext::GetInstance()->BlobToAssets(blobValue, assets);
342 if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
343 return errCode;
344 }
345 typeVal = assets;
346 } else if (sqlite3_column_type(stmt, cid) == SQLITE_BLOB) {
347 std::vector<uint8_t> blobValue;
348 errCode = SQLiteUtils::GetColumnBlobValue(stmt, cid, blobValue);
349 if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
350 return errCode;
351 }
352 typeVal = blobValue;
353 }
354 return E_OK;
355 }
356
SelectServerObserver(sqlite3 * db,const std::string & tableName,bool isChanged)357 int SQLiteRelationalUtils::SelectServerObserver(sqlite3 *db, const std::string &tableName, bool isChanged)
358 {
359 if (db == nullptr || tableName.empty()) {
360 return -E_INVALID_ARGS;
361 }
362 std::string sql;
363 if (isChanged) {
364 sql = "SELECT server_observer('" + tableName + "', 1);";
365 } else {
366 sql = "SELECT server_observer('" + tableName + "', 0);";
367 }
368 sqlite3_stmt *stmt = nullptr;
369 int errCode = SQLiteUtils::GetStatement(db, sql, stmt);
370 if (errCode != E_OK) {
371 LOGE("get select server observer stmt failed. %d", errCode);
372 return errCode;
373 }
374 errCode = SQLiteUtils::StepWithRetry(stmt, false);
375 int ret = E_OK;
376 SQLiteUtils::ResetStatement(stmt, true, ret);
377 if (errCode != SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
378 LOGE("select server observer failed. %d", errCode);
379 return SQLiteUtils::MapSQLiteErrno(errCode);
380 }
381 return ret == E_OK ? E_OK : ret;
382 }
383
AddUpgradeSqlToList(const TableInfo & tableInfo,const std::vector<std::pair<std::string,std::string>> & fieldList,std::vector<std::string> & sqlList)384 void SQLiteRelationalUtils::AddUpgradeSqlToList(const TableInfo &tableInfo,
385 const std::vector<std::pair<std::string, std::string>> &fieldList, std::vector<std::string> &sqlList)
386 {
387 for (const auto &[colName, colType] : fieldList) {
388 auto it = tableInfo.GetFields().find(colName);
389 if (it != tableInfo.GetFields().end()) {
390 continue;
391 }
392 sqlList.push_back("alter table " + tableInfo.GetTableName() + " add " + colName +
393 " " + colType + ";");
394 }
395 }
396
AnalysisTrackerTable(sqlite3 * db,const TrackerTable & trackerTable,TableInfo & tableInfo)397 int SQLiteRelationalUtils::AnalysisTrackerTable(sqlite3 *db, const TrackerTable &trackerTable, TableInfo &tableInfo)
398 {
399 int errCode = SQLiteUtils::AnalysisSchema(db, trackerTable.GetTableName(), tableInfo, true);
400 if (errCode != E_OK) {
401 LOGE("analysis table schema failed %d.", errCode);
402 return errCode;
403 }
404 tableInfo.SetTrackerTable(trackerTable);
405 errCode = tableInfo.CheckTrackerTable();
406 if (errCode != E_OK) {
407 LOGE("check tracker table schema failed %d.", errCode);
408 }
409 return errCode;
410 }
411
QueryCount(sqlite3 * db,const std::string & tableName,int64_t & count)412 int SQLiteRelationalUtils::QueryCount(sqlite3 *db, const std::string &tableName, int64_t &count)
413 {
414 std::string sql = "SELECT COUNT(1) FROM " + tableName ;
415 sqlite3_stmt *stmt = nullptr;
416 int errCode = SQLiteUtils::GetStatement(db, sql, stmt);
417 if (errCode != E_OK) {
418 LOGE("Query count failed. %d", errCode);
419 return errCode;
420 }
421 errCode = SQLiteUtils::StepWithRetry(stmt, false);
422 if (errCode == SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
423 count = static_cast<int64_t>(sqlite3_column_int64(stmt, 0));
424 errCode = E_OK;
425 } else {
426 LOGE("Failed to get the count. %d", errCode);
427 }
428 SQLiteUtils::ResetStatement(stmt, true, errCode);
429 return errCode;
430 }
431
GetCursor(sqlite3 * db,const std::string & tableName,uint64_t & cursor)432 int SQLiteRelationalUtils::GetCursor(sqlite3 *db, const std::string &tableName, uint64_t &cursor)
433 {
434 cursor = DBConstant::INVALID_CURSOR;
435 std::string sql = "SELECT value FROM " + std::string(DBConstant::RELATIONAL_PREFIX) + "metadata where key = ?;";
436 sqlite3_stmt *stmt = nullptr;
437 int errCode = SQLiteUtils::GetStatement(db, sql, stmt);
438 if (errCode != E_OK) {
439 LOGE("[Storage Executor] Get cursor of table[%s length[%u]] failed=%d",
440 DBCommon::StringMiddleMasking(tableName).c_str(), tableName.length(), errCode);
441 return errCode;
442 }
443 ResFinalizer finalizer([stmt]() {
444 sqlite3_stmt *statement = stmt;
445 int ret = E_OK;
446 SQLiteUtils::ResetStatement(statement, true, ret);
447 if (ret != E_OK) {
448 LOGW("Reset stmt failed %d when get cursor", ret);
449 }
450 });
451 Key key;
452 DBCommon::StringToVector(DBCommon::GetCursorKey(tableName), key);
453 errCode = SQLiteUtils::BindBlobToStatement(stmt, 1, key, false); // first arg.
454 if (errCode != E_OK) {
455 LOGE("[Storage Executor] Bind failed when get cursor of table[%s length[%u]] failed=%d",
456 DBCommon::StringMiddleMasking(tableName).c_str(), tableName.length(), errCode);
457 return errCode;
458 }
459 errCode = SQLiteUtils::StepWithRetry(stmt, false);
460 if (errCode == SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
461 int64_t tmpCursor = static_cast<int64_t>(sqlite3_column_int64(stmt, 0));
462 if (tmpCursor >= 0) {
463 cursor = static_cast<uint64_t>(tmpCursor);
464 }
465 }
466 return cursor == DBConstant::INVALID_CURSOR ? errCode : E_OK;
467 }
468
GetFieldsNeedContain(const TableInfo & tableInfo,const std::vector<FieldInfo> & syncFields,std::set<std::string> & fieldsNeedContain,std::set<std::string> & fieldsNotDecrease,std::set<std::string> & requiredNotNullFields)469 void GetFieldsNeedContain(const TableInfo &tableInfo, const std::vector<FieldInfo> &syncFields,
470 std::set<std::string> &fieldsNeedContain, std::set<std::string> &fieldsNotDecrease,
471 std::set<std::string> &requiredNotNullFields)
472 {
473 // should not decrease distributed field
474 for (const auto &field : tableInfo.GetSyncField()) {
475 fieldsNotDecrease.insert(field);
476 }
477 const std::vector<CompositeFields> &uniqueDefines = tableInfo.GetUniqueDefine();
478 for (const auto &compositeFields : uniqueDefines) {
479 for (const auto &fieldName : compositeFields) {
480 if (tableInfo.IsPrimaryKey(fieldName)) {
481 continue;
482 }
483 fieldsNeedContain.insert(fieldName);
484 }
485 }
486 const FieldInfoMap &fieldInfoMap = tableInfo.GetFields();
487 for (const auto &entry : fieldInfoMap) {
488 const FieldInfo &fieldInfo = entry.second;
489 if (fieldInfo.IsNotNull() && fieldInfo.GetDefaultValue().empty()) {
490 requiredNotNullFields.insert(fieldInfo.GetFieldName());
491 }
492 }
493 }
494
CheckRequireFieldsInMap(const std::set<std::string> & fieldsNeedContain,const std::map<std::string,bool> & fieldsMap)495 bool CheckRequireFieldsInMap(const std::set<std::string> &fieldsNeedContain,
496 const std::map<std::string, bool> &fieldsMap)
497 {
498 for (auto &fieldNeedContain : fieldsNeedContain) {
499 if (fieldsMap.find(fieldNeedContain) == fieldsMap.end()) {
500 LOGE("Required column[%s [%zu]] not found", DBCommon::StringMiddleMasking(fieldNeedContain).c_str(),
501 fieldNeedContain.size());
502 return false;
503 }
504 if (!fieldsMap.at(fieldNeedContain)) {
505 LOGE("The isP2pSync of required column[%s [%zu]] is false",
506 DBCommon::StringMiddleMasking(fieldNeedContain).c_str(), fieldNeedContain.size());
507 return false;
508 }
509 }
510 return true;
511 }
512
IsMarkUniqueColumnInvalid(const TableInfo & tableInfo,const std::vector<DistributedField> & originFields)513 bool IsMarkUniqueColumnInvalid(const TableInfo &tableInfo, const std::vector<DistributedField> &originFields)
514 {
515 int count = 0;
516 for (const auto &field : originFields) {
517 if (field.isSpecified && tableInfo.IsUniqueField(field.colName)) {
518 count++;
519 if (count > 1) {
520 return true;
521 }
522 }
523 }
524 return false;
525 }
526
IsDistributedPkInvalid(const TableInfo & tableInfo,const std::set<std::string,CaseInsensitiveComparator> & distributedPk,const std::vector<DistributedField> & originFields,bool isForceUpgrade)527 bool IsDistributedPkInvalid(const TableInfo &tableInfo,
528 const std::set<std::string, CaseInsensitiveComparator> &distributedPk,
529 const std::vector<DistributedField> &originFields, bool isForceUpgrade)
530 {
531 auto lastDistributedPk = DBCommon::TransformToCaseInsensitive(tableInfo.GetSyncDistributedPk());
532 if (!isForceUpgrade && !lastDistributedPk.empty() && distributedPk != lastDistributedPk) {
533 LOGE("distributed pk has change last %zu now %zu", lastDistributedPk.size(), distributedPk.size());
534 return true;
535 }
536 // check pk is same or local pk is auto increase and set pk is unique
537 if (tableInfo.IsNoPkTable()) {
538 return false;
539 }
540 if (!distributedPk.empty() && distributedPk.size() != tableInfo.GetPrimaryKey().size()) {
541 return true;
542 }
543 if (tableInfo.GetAutoIncrement()) {
544 if (distributedPk.empty()) {
545 return false;
546 }
547 auto uniqueAndPkDefine = tableInfo.GetUniqueAndPkDefine();
548 if (IsMarkUniqueColumnInvalid(tableInfo, originFields)) {
549 LOGE("Mark more than one unique column specified in auto increment table: %s, tableName len: %zu",
550 DBCommon::StringMiddleMasking(tableInfo.GetTableName()).c_str(), tableInfo.GetTableName().length());
551 return true;
552 }
553 auto find = std::any_of(uniqueAndPkDefine.begin(), uniqueAndPkDefine.end(), [&distributedPk](const auto &item) {
554 // unique index field count should be same
555 return item.size() == distributedPk.size() && distributedPk == DBCommon::TransformToCaseInsensitive(item);
556 });
557 bool isMissMatch = !find;
558 if (isMissMatch) {
559 LOGE("Miss match distributed pk size %zu in auto increment table %s", distributedPk.size(),
560 DBCommon::StringMiddleMasking(tableInfo.GetTableName()).c_str());
561 }
562 return isMissMatch;
563 }
564 for (const auto &field : originFields) {
565 bool isLocalPk = tableInfo.IsPrimaryKey(field.colName);
566 if (field.isSpecified && !isLocalPk) {
567 LOGE("Column[%s [%zu]] is not primary key but mark specified",
568 DBCommon::StringMiddleMasking(field.colName).c_str(), field.colName.size());
569 return true;
570 }
571 if (isLocalPk && !field.isP2pSync) {
572 LOGE("Column[%s [%zu]] is primary key but set isP2pSync false",
573 DBCommon::StringMiddleMasking(field.colName).c_str(), field.colName.size());
574 return true;
575 }
576 }
577 return false;
578 }
579
IsDistributedSchemaSupport(const TableInfo & tableInfo,const std::vector<DistributedField> & fields)580 bool IsDistributedSchemaSupport(const TableInfo &tableInfo, const std::vector<DistributedField> &fields)
581 {
582 if (!tableInfo.GetAutoIncrement()) {
583 return true;
584 }
585 bool isSyncPk = false;
586 bool isSyncOtherSpecified = false;
587 for (const auto &item : fields) {
588 if (tableInfo.IsPrimaryKey(item.colName) && item.isP2pSync) {
589 isSyncPk = true;
590 } else if (item.isSpecified && item.isP2pSync) {
591 isSyncOtherSpecified = true;
592 }
593 }
594 if (isSyncPk && isSyncOtherSpecified) {
595 LOGE("Not support sync with auto increment pk and other specified col");
596 return false;
597 }
598 return true;
599 }
600
CheckDistributedSchemaFields(const TableInfo & tableInfo,const std::vector<FieldInfo> & syncFields,const std::vector<DistributedField> & fields,bool isForceUpgrade)601 int CheckDistributedSchemaFields(const TableInfo &tableInfo, const std::vector<FieldInfo> &syncFields,
602 const std::vector<DistributedField> &fields, bool isForceUpgrade)
603 {
604 if (fields.empty()) {
605 LOGE("fields cannot be empty");
606 return -E_SCHEMA_MISMATCH;
607 }
608 if (!IsDistributedSchemaSupport(tableInfo, fields)) {
609 return -E_NOT_SUPPORT;
610 }
611 std::set<std::string, CaseInsensitiveComparator> distributedPk;
612 bool isNoPrimaryKeyTable = tableInfo.IsNoPkTable();
613 for (const auto &field : fields) {
614 if (!tableInfo.IsFieldExist(field.colName)) {
615 LOGE("Column[%s [%zu]] not found in table", DBCommon::StringMiddleMasking(field.colName).c_str(),
616 field.colName.size());
617 return -E_SCHEMA_MISMATCH;
618 }
619 if (isNoPrimaryKeyTable && field.isSpecified) {
620 return -E_SCHEMA_MISMATCH;
621 }
622 if (field.isSpecified && field.isP2pSync) {
623 distributedPk.insert(field.colName);
624 }
625 }
626 if (IsDistributedPkInvalid(tableInfo, distributedPk, fields, isForceUpgrade)) {
627 return -E_SCHEMA_MISMATCH;
628 }
629 std::set<std::string> fieldsNeedContain;
630 std::set<std::string> fieldsNotDecrease;
631 std::set<std::string> requiredNotNullFields;
632 GetFieldsNeedContain(tableInfo, syncFields, fieldsNeedContain, fieldsNotDecrease, requiredNotNullFields);
633 std::map<std::string, bool> fieldsMap;
634 for (auto &field : fields) {
635 fieldsMap.insert({field.colName, field.isP2pSync});
636 }
637 if (!CheckRequireFieldsInMap(fieldsNeedContain, fieldsMap)) {
638 LOGE("The required fields are not found in fieldsMap");
639 return -E_SCHEMA_MISMATCH;
640 }
641 if (!isForceUpgrade && !CheckRequireFieldsInMap(fieldsNotDecrease, fieldsMap)) {
642 LOGE("The fields should not decrease");
643 return -E_DISTRIBUTED_FIELD_DECREASE;
644 }
645 if (!CheckRequireFieldsInMap(requiredNotNullFields, fieldsMap)) {
646 LOGE("The required not-null fields are not found in fieldsMap");
647 return -E_SCHEMA_MISMATCH;
648 }
649 return E_OK;
650 }
651
CheckDistributedSchemaValid(const RelationalSchemaObject & schemaObj,const DistributedSchema & schema,bool isForceUpgrade,SQLiteSingleVerRelationalStorageExecutor * executor)652 int SQLiteRelationalUtils::CheckDistributedSchemaValid(const RelationalSchemaObject &schemaObj,
653 const DistributedSchema &schema, bool isForceUpgrade, SQLiteSingleVerRelationalStorageExecutor *executor)
654 {
655 if (executor == nullptr) {
656 LOGE("[RDBUtils][CheckDistributedSchemaValid] executor is null");
657 return -E_INVALID_ARGS;
658 }
659 sqlite3 *db;
660 int errCode = executor->GetDbHandle(db);
661 if (errCode != E_OK) {
662 LOGE("[RDBUtils][CheckDistributedSchemaValid] sqlite handle failed %d", errCode);
663 return errCode;
664 }
665 for (const auto &table : schema.tables) {
666 if (table.tableName.empty()) {
667 LOGE("[RDBUtils][CheckDistributedSchemaValid] Table name cannot be empty");
668 return -E_SCHEMA_MISMATCH;
669 }
670 TableInfo tableInfo;
671 errCode = SQLiteUtils::AnalysisSchema(db, table.tableName, tableInfo);
672 if (errCode != E_OK) {
673 LOGE("[RDBUtils][CheckDistributedSchemaValid] analyze table %s failed %d",
674 DBCommon::StringMiddleMasking(table.tableName).c_str(), errCode);
675 return errCode == -E_NOT_FOUND ? -E_SCHEMA_MISMATCH : errCode;
676 }
677 tableInfo.SetDistributedTable(schemaObj.GetDistributedTable(table.tableName));
678 errCode = CheckDistributedSchemaFields(tableInfo, schemaObj.GetSyncFieldInfo(table.tableName, false),
679 table.fields, isForceUpgrade);
680 if (errCode != E_OK) {
681 LOGE("[CheckDistributedSchema] Check fields of [%s [%zu]] fail",
682 DBCommon::StringMiddleMasking(table.tableName).c_str(), table.tableName.size());
683 return errCode;
684 }
685 }
686 return E_OK;
687 }
688
/**
 * Returns a copy of `schema` in which each table name appears once.
 *
 * When a table is defined several times the last definition is kept, and the surviving tables
 * keep their relative order. Each kept table is itself de-duplicated by the DistributedTable
 * overload of FilterRepeatDefine.
 *
 * @param schema Schema that may contain repeated table definitions.
 * @return The filtered schema, with the same version.
 */
DistributedSchema SQLiteRelationalUtils::FilterRepeatDefine(const DistributedSchema &schema)
{
    DistributedSchema filtered;
    filtered.version = schema.version;
    std::set<std::string> seenNames;
    std::list<DistributedTable> keptTables;
    // Walk backwards so the last definition of a name is the one seen first.
    auto current = schema.tables.rbegin();
    while (current != schema.tables.rend()) {
        if (seenNames.find(current->tableName) == seenNames.end()) {
            seenNames.insert(current->tableName);
            // Prepending restores the original front-to-back order.
            keptTables.push_front(FilterRepeatDefine(*current));
        }
        current++;
    }
    for (auto &table : keptTables) {
        filtered.tables.push_back(std::move(table));
    }
    return filtered;
}
707
/**
 * Returns a copy of `table` in which each column name appears once.
 *
 * When a column is defined several times the last definition is kept, and the surviving fields
 * keep their relative order.
 *
 * @param table Table definition that may contain repeated columns.
 * @return The filtered table, with the same table name.
 */
DistributedTable SQLiteRelationalUtils::FilterRepeatDefine(const DistributedTable &table)
{
    DistributedTable filtered;
    filtered.tableName = table.tableName;
    std::set<std::string> seenColumns;
    std::list<DistributedField> keptFields;
    // Walk backwards so the last definition of a column is the one seen first.
    auto current = table.fields.rbegin();
    while (current != table.fields.rend()) {
        if (seenColumns.find(current->colName) == seenColumns.end()) {
            seenColumns.insert(current->colName);
            // Prepending restores the original front-to-back order.
            keptFields.push_front(*current);
        }
        current++;
    }
    for (auto &field : keptFields) {
        filtered.fields.push_back(std::move(field));
    }
    return filtered;
}
726
GetLogData(sqlite3_stmt * logStatement,LogInfo & logInfo)727 int SQLiteRelationalUtils::GetLogData(sqlite3_stmt *logStatement, LogInfo &logInfo)
728 {
729 logInfo.dataKey = sqlite3_column_int64(logStatement, 0); // 0 means dataKey index
730
731 std::vector<uint8_t> dev;
732 int errCode = SQLiteUtils::GetColumnBlobValue(logStatement, 1, dev); // 1 means dev index
733 if (errCode != E_OK) {
734 LOGE("[SQLiteRDBUtils] Get dev failed %d", errCode);
735 return errCode;
736 }
737 logInfo.device = std::string(dev.begin(), dev.end());
738
739 std::vector<uint8_t> oriDev;
740 errCode = SQLiteUtils::GetColumnBlobValue(logStatement, 2, oriDev); // 2 means ori_dev index
741 if (errCode != E_OK) {
742 LOGE("[SQLiteRDBUtils] Get ori dev failed %d", errCode);
743 return errCode;
744 }
745 logInfo.originDev = std::string(oriDev.begin(), oriDev.end());
746 logInfo.timestamp = static_cast<uint64_t>(sqlite3_column_int64(logStatement, 3)); // 3 means timestamp index
747 logInfo.wTimestamp = static_cast<uint64_t>(sqlite3_column_int64(logStatement, 4)); // 4 means w_timestamp index
748 logInfo.flag = static_cast<uint64_t>(sqlite3_column_int64(logStatement, 5)); // 5 means flag index
749 logInfo.flag &= (~DataItem::LOCAL_FLAG);
750 logInfo.flag &= (~DataItem::UPDATE_FLAG);
751 errCode = SQLiteUtils::GetColumnBlobValue(logStatement, 6, logInfo.hashKey); // 6 means hashKey index
752 if (errCode != E_OK) {
753 LOGE("[SQLiteRDBUtils] Get hashKey failed %d", errCode);
754 }
755 return errCode;
756 }
757
GetLogInfoPre(sqlite3_stmt * queryStmt,DistributedTableMode mode,const DataItem & dataItem,LogInfo & logInfoGet)758 int SQLiteRelationalUtils::GetLogInfoPre(sqlite3_stmt *queryStmt, DistributedTableMode mode,
759 const DataItem &dataItem, LogInfo &logInfoGet)
760 {
761 if (queryStmt == nullptr) {
762 return -E_INVALID_ARGS;
763 }
764 int errCode = SQLiteUtils::BindBlobToStatement(queryStmt, 1, dataItem.hashKey); // 1 means hash key index.
765 if (errCode != E_OK) {
766 LOGE("[SQLiteRDBUtils] Bind hashKey failed %d", errCode);
767 return errCode;
768 }
769 if (mode != DistributedTableMode::COLLABORATION) {
770 errCode = SQLiteUtils::BindTextToStatement(queryStmt, 2, dataItem.dev); // 2 means device index.
771 if (errCode != E_OK) {
772 LOGE("[SQLiteRDBUtils] Bind dev failed %d", errCode);
773 return errCode;
774 }
775 }
776
777 errCode = SQLiteUtils::StepWithRetry(queryStmt, false); // rdb not exist mem db
778 if (errCode != SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
779 errCode = -E_NOT_FOUND;
780 } else {
781 errCode = SQLiteRelationalUtils::GetLogData(queryStmt, logInfoGet);
782 }
783 return errCode;
784 }
785 } // namespace DistributedDB