1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #define LOG_TAG "TransDB"
16 #include "trans_db.h"
17
18 #include "cache_result_set.h"
19 #include "logger.h"
20 #include "rdb_sql_statistic.h"
21 #include "rdb_trace.h"
22 #include "sqlite_sql_builder.h"
23 #include "sqlite_utils.h"
24 #include "step_result_set.h"
25 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM) && !defined(ANDROID_PLATFORM) && !defined(IOS_PLATFORM)
26 #include "sqlite_shared_result_set.h"
27 #endif
28 namespace OHOS::NativeRdb {
29 using namespace OHOS::Rdb;
30 using namespace DistributedRdb;
TransDB(std::shared_ptr<Connection> conn,const std::string & path)31 TransDB::TransDB(std::shared_ptr<Connection> conn, const std::string &path) : conn_(conn), path_(path)
32 {
33 maxArgs_ = conn->GetMaxVariable();
34 }
35
Insert(const std::string & table,const Row & row,Resolution resolution)36 std::pair<int, int64_t> TransDB::Insert(const std::string &table, const Row &row, Resolution resolution)
37 {
38 DISTRIBUTED_DATA_HITRACE(std::string(__FUNCTION__));
39 auto conflictClause = SqliteUtils::GetConflictClause(static_cast<int>(resolution));
40 if (table.empty() || row.IsEmpty() || conflictClause == nullptr) {
41 return { E_INVALID_ARGS, -1 };
42 }
43
44 std::string sql("INSERT");
45 sql.append(conflictClause).append(" INTO ").append(table).append("(");
46 std::vector<ValueObject> args;
47 args.reserve(row.values_.size());
48 const char *split = "";
49 for (const auto &[key, val] : row.values_) {
50 sql.append(split).append(key);
51 if (val.GetType() == ValueObject::TYPE_ASSETS && resolution == ConflictResolution::ON_CONFLICT_REPLACE) {
52 return { E_INVALID_ARGS, -1 };
53 }
54 SqliteSqlBuilder::UpdateAssetStatus(val, AssetValue::STATUS_INSERT);
55 args.push_back(val); // columnValue
56 split = ",";
57 }
58
59 sql.append(") VALUES (");
60 if (!args.empty()) {
61 sql.append(SqliteSqlBuilder::GetSqlArgs(args.size()));
62 }
63
64 sql.append(")");
65 int64_t rowid = -1;
66 auto [errCode, statement] = GetStatement(sql);
67 if (statement == nullptr) {
68 return { errCode, rowid };
69 }
70 errCode = statement->Execute(args);
71 if (errCode != E_OK) {
72 return { errCode, rowid };
73 }
74 rowid = statement->Changes() > 0 ? statement->LastInsertRowId() : -1;
75 return { errCode, rowid };
76 }
77
BatchInsert(const std::string & table,const RefRows & rows)78 std::pair<int, int64_t> TransDB::BatchInsert(const std::string &table, const RefRows &rows)
79 {
80 if (rows.RowSize() == 0) {
81 return { E_OK, 0 };
82 }
83
84 auto batchInfo = SqliteSqlBuilder::GenerateSqls(table, rows, maxArgs_);
85 if (table.empty() || batchInfo.empty()) {
86 LOG_ERROR("empty,table=%{public}s,rows:%{public}zu,max:%{public}d.", SqliteUtils::Anonymous(table).c_str(),
87 rows.RowSize(), maxArgs_);
88 return { E_INVALID_ARGS, -1 };
89 }
90
91 for (const auto &[sql, batchArgs] : batchInfo) {
92 auto [errCode, statement] = GetStatement(sql);
93 if (statement == nullptr) {
94 return { errCode, -1 };
95 }
96 for (const auto &args : batchArgs) {
97 errCode = statement->Execute(args);
98 if (errCode == E_OK) {
99 continue;
100 }
101 LOG_ERROR("failed(0x%{public}x) db:%{public}s table:%{public}s args:%{public}zu", errCode,
102 SqliteUtils::Anonymous(path_).c_str(), SqliteUtils::Anonymous(table).c_str(), args.size());
103 return { errCode, -1 };
104 }
105 }
106 return { E_OK, int64_t(rows.RowSize()) };
107 }
108
BatchInsert(const std::string & table,const ValuesBuckets & rows,const std::vector<std::string> & returningFields,Resolution resolution)109 std::pair<int32_t, Results> TransDB::BatchInsert(const std::string &table, const ValuesBuckets &rows,
110 const std::vector<std::string> &returningFields, Resolution resolution)
111 {
112 if (rows.RowSize() == 0) {
113 return { E_OK, 0 };
114 }
115
116 auto sqlArgs = SqliteSqlBuilder::GenerateSqls(table, rows, maxArgs_, resolution);
117 if (sqlArgs.size() != 1 || sqlArgs.front().second.size() != 1) {
118 auto [fields, values] = rows.GetFieldsAndValues();
119 LOG_ERROR("invalid args, table=%{public}s, rows:%{public}zu, fields:%{public}zu, max:%{public}d.",
120 SqliteUtils::Anonymous(table).c_str(), rows.RowSize(), fields != nullptr ? fields->size() : 0, maxArgs_);
121 return { E_INVALID_ARGS, -1 };
122 }
123 auto &[sql, bindArgs] = sqlArgs.front();
124 SqliteSqlBuilder::AppendReturning(sql, returningFields);
125 auto [errCode, statement] = GetStatement(sql);
126 if (statement == nullptr) {
127 LOG_ERROR("statement is nullptr, errCode:0x%{public}x, args:%{public}zu, table:%{public}s.", errCode,
128 bindArgs.size(), SqliteUtils::Anonymous(table).c_str());
129 return { errCode, -1 };
130 }
131 auto args = std::ref(bindArgs.front());
132 errCode = statement->Execute(args);
133 if (errCode != E_OK) {
134 LOG_ERROR("failed,errCode:%{public}d,table:%{public}s,args:%{public}zu,resolution:%{public}d.", errCode,
135 SqliteUtils::Anonymous(table).c_str(), args.get().size(), static_cast<int32_t>(resolution));
136 }
137 return { errCode, GenerateResult(errCode, statement) };
138 }
139
Update(const Row & row,const AbsRdbPredicates & predicates,const std::vector<std::string> & returningFields,Resolution resolution)140 std::pair<int32_t, Results> TransDB::Update(const Row &row, const AbsRdbPredicates &predicates,
141 const std::vector<std::string> &returningFields, Resolution resolution)
142 {
143 DISTRIBUTED_DATA_HITRACE(std::string(__FUNCTION__));
144 auto clause = SqliteUtils::GetConflictClause(static_cast<int>(resolution));
145 auto table = predicates.GetTableName();
146 if (table.empty() || row.IsEmpty() || clause == nullptr) {
147 return { E_INVALID_ARGS, 0 };
148 }
149
150 std::string sql("UPDATE");
151 sql.append(clause).append(" ").append(table).append(" SET ");
152 std::vector<ValueObject> totalArgs;
153 auto args = predicates.GetBindArgs();
154 totalArgs.reserve(row.values_.size() + args.size());
155 const char *split = "";
156 for (auto &[key, val] : row.values_) {
157 sql.append(split);
158 if (val.GetType() == ValueObject::TYPE_ASSETS) {
159 sql.append(key).append("=merge_assets(").append(key).append(", ?)");
160 } else if (val.GetType() == ValueObject::TYPE_ASSET) {
161 sql.append(key).append("=merge_asset(").append(key).append(", ?)");
162 } else {
163 sql.append(key).append("=?");
164 }
165 totalArgs.push_back(val);
166 split = ",";
167 }
168 auto where = predicates.GetWhereClause();
169 if (!where.empty()) {
170 sql.append(" WHERE ").append(where);
171 }
172 SqliteSqlBuilder::AppendReturning(sql, returningFields);
173 totalArgs.insert(totalArgs.end(), args.begin(), args.end());
174 auto [errCode, statement] = GetStatement(sql);
175 if (errCode != E_OK || statement == nullptr) {
176 return { errCode != E_OK ? errCode : E_ERROR, -1 };
177 }
178
179 errCode = statement->Execute(totalArgs);
180 if (errCode != E_OK) {
181 LOG_ERROR("failed,errCode:%{public}d,table:%{public}s,returningFields:%{public}zu,resolution:%{public}d.",
182 errCode, SqliteUtils::Anonymous(table).c_str(), returningFields.size(), static_cast<int32_t>(resolution));
183 }
184 return { errCode, GenerateResult(errCode, statement) };
185 }
186
Delete(const AbsRdbPredicates & predicates,const std::vector<std::string> & returningFields)187 std::pair<int32_t, Results> TransDB::Delete(
188 const AbsRdbPredicates &predicates, const std::vector<std::string> &returningFields)
189 {
190 DISTRIBUTED_DATA_HITRACE(std::string(__FUNCTION__));
191 auto table = predicates.GetTableName();
192 if (table.empty()) {
193 return { E_INVALID_ARGS, -1 };
194 }
195
196 std::string sql;
197 sql.append("DELETE FROM ").append(table);
198 auto whereClause = predicates.GetWhereClause();
199 if (!whereClause.empty()) {
200 sql.append(" WHERE ").append(whereClause);
201 }
202 SqliteSqlBuilder::AppendReturning(sql, returningFields);
203 auto [errCode, statement] = GetStatement(sql);
204 if (errCode != E_OK || statement == nullptr) {
205 return { errCode != E_OK ? errCode : E_ERROR, -1 };
206 }
207 errCode = statement->Execute(predicates.GetBindArgs());
208 if (errCode != E_OK) {
209 LOG_ERROR("failed,errCode:%{public}d,table:%{public}s,returningFields:%{public}zu.", errCode,
210 SqliteUtils::Anonymous(table).c_str(), returningFields.size());
211 }
212 return { errCode, GenerateResult(errCode, statement) };
213 }
214
QuerySql(const std::string & sql,const Values & args)215 std::shared_ptr<AbsSharedResultSet> TransDB::QuerySql(const std::string &sql, const Values &args)
216 {
217 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM) && !defined(ANDROID_PLATFORM) && !defined(IOS_PLATFORM)
218 auto start = std::chrono::steady_clock::now();
219 return std::make_shared<SqliteSharedResultSet>(start, conn_.lock(), sql, args, path_);
220 #else
221 (void)sql;
222 (void)args;
223 return nullptr;
224 #endif
225 }
226
QueryByStep(const std::string & sql,const Values & args,bool preCount)227 std::shared_ptr<ResultSet> TransDB::QueryByStep(const std::string &sql, const Values &args, bool preCount)
228 {
229 auto start = std::chrono::steady_clock::now();
230 return std::make_shared<StepResultSet>(start, conn_.lock(), sql, args, true, true);
231 }
232
Execute(const std::string & sql,const Values & args,int64_t trxId)233 std::pair<int32_t, ValueObject> TransDB::Execute(const std::string &sql, const Values &args, int64_t trxId)
234 {
235 (void)trxId;
236 ValueObject object;
237 int sqlType = SqliteUtils::GetSqlStatementType(sql);
238 if (!SqliteUtils::IsSupportSqlForExecute(sqlType) && !SqliteUtils::IsSpecial(sqlType)) {
239 LOG_ERROR("Not support the sql:app self can check the SQL, sqlType:%{public}d", sqlType);
240 return { E_INVALID_ARGS, object };
241 }
242
243 auto [errCode, statement] = GetStatement(sql);
244 if (errCode != E_OK) {
245 return { errCode, object };
246 }
247
248 errCode = statement->Execute(args);
249 if (errCode != E_OK) {
250 LOG_ERROR("failed,app self can check the SQL, error:0x%{public}x.", errCode);
251 return { errCode, object };
252 }
253
254 if (sqlType == SqliteUtils::STATEMENT_INSERT) {
255 int64_t outValue = statement->Changes() > 0 ? statement->LastInsertRowId() : -1;
256 return { errCode, ValueObject(outValue) };
257 }
258
259 if (sqlType == SqliteUtils::STATEMENT_UPDATE) {
260 int outValue = statement->Changes();
261 return { errCode, ValueObject(outValue) };
262 }
263
264 if (sqlType == SqliteUtils::STATEMENT_PRAGMA) {
265 if (statement->GetColumnCount() == 1) {
266 return statement->GetColumn(0);
267 }
268 }
269
270 if (sqlType == SqliteUtils::STATEMENT_DDL) {
271 HandleSchemaDDL(statement);
272 }
273 return { errCode, object };
274 }
275
ExecuteExt(const std::string & sql,const Values & args)276 std::pair<int32_t, Results> TransDB::ExecuteExt(const std::string &sql, const Values &args)
277 {
278 ValueObject object;
279 int sqlType = SqliteUtils::GetSqlStatementType(sql);
280 if (!SqliteUtils::IsSupportSqlForExecute(sqlType) && !SqliteUtils::IsSpecial(sqlType)) {
281 LOG_ERROR("Not support the sql:app self can check the SQL");
282 return { E_INVALID_ARGS, -1 };
283 }
284 auto [errCode, statement] = GetStatement(sql);
285 if (errCode != E_OK || statement == nullptr) {
286 return { errCode != E_OK ? errCode : E_ERROR, -1 };
287 }
288
289 errCode = statement->Execute(args);
290 auto result = GenerateResult(
291 errCode, statement, sqlType == SqliteUtils::STATEMENT_INSERT || sqlType == SqliteUtils::STATEMENT_UPDATE);
292 if (errCode != E_OK) {
293 LOG_ERROR("failed,app self can check the SQL, error:0x%{public}x.", errCode);
294 return { errCode, result };
295 }
296
297 if (sqlType == SqliteUtils::STATEMENT_DDL) {
298 HandleSchemaDDL(statement);
299 }
300 return { errCode, result };
301 }
302
HandleSchemaDDL(std::shared_ptr<Statement> statement)303 void TransDB::HandleSchemaDDL(std::shared_ptr<Statement> statement)
304 {
305 if (statement == nullptr) {
306 return;
307 }
308 statement->Reset();
309 statement->Prepare("PRAGMA schema_version");
310 auto [err, version] = statement->ExecuteForValue();
311 if (vSchema_ < static_cast<int64_t>(version)) {
312 LOG_INFO("db:%{public}s exe DDL schema<%{public}" PRIi64 "->%{public}" PRIi64 ">",
313 SqliteUtils::Anonymous(path_).c_str(), vSchema_, static_cast<int64_t>(version));
314 vSchema_ = version;
315 }
316 }
317
GetVersion(int & version)318 int TransDB::GetVersion(int &version)
319 {
320 return E_NOT_SUPPORT;
321 }
322
SetVersion(int version)323 int TransDB::SetVersion(int version)
324 {
325 return E_NOT_SUPPORT;
326 }
327
Sync(const SyncOption & option,const std::vector<std::string> & tables,const AsyncDetail & async)328 int TransDB::Sync(const SyncOption &option, const std::vector<std::string> &tables, const AsyncDetail &async)
329 {
330 if (option.mode != TIME_FIRST || tables.empty()) {
331 return E_INVALID_ARGS;
332 }
333 return RdbStore::Sync(option, tables, async);
334 }
335
GetStatement(const std::string & sql) const336 std::pair<int32_t, std::shared_ptr<Statement>> TransDB::GetStatement(const std::string &sql) const
337 {
338 auto connection = conn_.lock();
339 if (connection == nullptr) {
340 return { E_ALREADY_CLOSED, nullptr };
341 }
342 return connection->CreateStatement(sql, connection);
343 }
344
GenerateResult(int32_t code,std::shared_ptr<Statement> statement,bool isDML)345 Results TransDB::GenerateResult(int32_t code, std::shared_ptr<Statement> statement, bool isDML)
346 {
347 Results result{ -1 };
348 if (statement == nullptr) {
349 return result;
350 }
351 // There are no data changes in other scenarios
352 if (code == E_OK) {
353 result.results = GetValues(statement);
354 result.changed = isDML ? statement->Changes() : 0;
355 }
356 if (code == E_SQLITE_CONSTRAINT) {
357 result.changed = statement->Changes();
358 }
359 if (isDML && result.changed <= 0) {
360 result.results = std::make_shared<CacheResultSet>();
361 }
362 return result;
363 }
364
GetValues(std::shared_ptr<Statement> statement)365 std::shared_ptr<ResultSet> TransDB::GetValues(std::shared_ptr<Statement> statement)
366 {
367 if (statement == nullptr) {
368 return nullptr;
369 }
370 auto [code, rows] = statement->GetRows(MAX_RETURNING_ROWS);
371 auto size = rows.size();
372 std::shared_ptr<ResultSet> result = std::make_shared<CacheResultSet>(std::move(rows));
373 // The correct number of changed rows can only be obtained after completing the step
374 while (code == E_OK && size == MAX_RETURNING_ROWS) {
375 std::tie(code, rows) = statement->GetRows(MAX_RETURNING_ROWS);
376 size = rows.size();
377 }
378 return result;
379 }
380 } // namespace OHOS::NativeRdb