• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #define LOG_TAG "TransDB"
16 #include "trans_db.h"
17 
18 #include "cache_result_set.h"
19 #include "logger.h"
20 #include "rdb_sql_statistic.h"
21 #include "rdb_trace.h"
22 #include "sqlite_sql_builder.h"
23 #include "sqlite_utils.h"
24 #include "step_result_set.h"
25 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM) && !defined(ANDROID_PLATFORM) && !defined(IOS_PLATFORM)
26 #include "sqlite_shared_result_set.h"
27 #endif
28 namespace OHOS::NativeRdb {
29 using namespace OHOS::Rdb;
30 using namespace DistributedRdb;
TransDB(std::shared_ptr<Connection> conn,const std::string & path)31 TransDB::TransDB(std::shared_ptr<Connection> conn, const std::string &path) : conn_(conn), path_(path)
32 {
33     maxArgs_ = conn->GetMaxVariable();
34 }
35 
Insert(const std::string & table,const Row & row,Resolution resolution)36 std::pair<int, int64_t> TransDB::Insert(const std::string &table, const Row &row, Resolution resolution)
37 {
38     DISTRIBUTED_DATA_HITRACE(std::string(__FUNCTION__));
39     auto conflictClause = SqliteUtils::GetConflictClause(static_cast<int>(resolution));
40     if (table.empty() || row.IsEmpty() || conflictClause == nullptr) {
41         return { E_INVALID_ARGS, -1 };
42     }
43 
44     std::string sql("INSERT");
45     sql.append(conflictClause).append(" INTO ").append(table).append("(");
46     std::vector<ValueObject> args;
47     args.reserve(row.values_.size());
48     const char *split = "";
49     for (const auto &[key, val] : row.values_) {
50         sql.append(split).append(key);
51         if (val.GetType() == ValueObject::TYPE_ASSETS && resolution == ConflictResolution::ON_CONFLICT_REPLACE) {
52             return { E_INVALID_ARGS, -1 };
53         }
54         SqliteSqlBuilder::UpdateAssetStatus(val, AssetValue::STATUS_INSERT);
55         args.push_back(val); // columnValue
56         split = ",";
57     }
58 
59     sql.append(") VALUES (");
60     if (!args.empty()) {
61         sql.append(SqliteSqlBuilder::GetSqlArgs(args.size()));
62     }
63 
64     sql.append(")");
65     int64_t rowid = -1;
66     auto [errCode, statement] = GetStatement(sql);
67     if (statement == nullptr) {
68         return { errCode, rowid };
69     }
70     errCode = statement->Execute(args);
71     if (errCode != E_OK) {
72         return { errCode, rowid };
73     }
74     rowid = statement->Changes() > 0 ? statement->LastInsertRowId() : -1;
75     return { errCode, rowid };
76 }
77 
BatchInsert(const std::string & table,const RefRows & rows)78 std::pair<int, int64_t> TransDB::BatchInsert(const std::string &table, const RefRows &rows)
79 {
80     if (rows.RowSize() == 0) {
81         return { E_OK, 0 };
82     }
83 
84     auto batchInfo = SqliteSqlBuilder::GenerateSqls(table, rows, maxArgs_);
85     if (table.empty() || batchInfo.empty()) {
86         LOG_ERROR("empty,table=%{public}s,rows:%{public}zu,max:%{public}d.", SqliteUtils::Anonymous(table).c_str(),
87             rows.RowSize(), maxArgs_);
88         return { E_INVALID_ARGS, -1 };
89     }
90 
91     for (const auto &[sql, batchArgs] : batchInfo) {
92         auto [errCode, statement] = GetStatement(sql);
93         if (statement == nullptr) {
94             return { errCode, -1 };
95         }
96         for (const auto &args : batchArgs) {
97             errCode = statement->Execute(args);
98             if (errCode == E_OK) {
99                 continue;
100             }
101             LOG_ERROR("failed(0x%{public}x) db:%{public}s table:%{public}s args:%{public}zu", errCode,
102                 SqliteUtils::Anonymous(path_).c_str(), SqliteUtils::Anonymous(table).c_str(), args.size());
103             return { errCode, -1 };
104         }
105     }
106     return { E_OK, int64_t(rows.RowSize()) };
107 }
108 
BatchInsert(const std::string & table,const ValuesBuckets & rows,const std::vector<std::string> & returningFields,Resolution resolution)109 std::pair<int32_t, Results> TransDB::BatchInsert(const std::string &table, const ValuesBuckets &rows,
110     const std::vector<std::string> &returningFields, Resolution resolution)
111 {
112     if (rows.RowSize() == 0) {
113         return { E_OK, 0 };
114     }
115 
116     auto sqlArgs = SqliteSqlBuilder::GenerateSqls(table, rows, maxArgs_, resolution);
117     if (sqlArgs.size() != 1 || sqlArgs.front().second.size() != 1) {
118         auto [fields, values] = rows.GetFieldsAndValues();
119         LOG_ERROR("invalid args, table=%{public}s, rows:%{public}zu, fields:%{public}zu, max:%{public}d.",
120             SqliteUtils::Anonymous(table).c_str(), rows.RowSize(), fields != nullptr ? fields->size() : 0, maxArgs_);
121         return { E_INVALID_ARGS, -1 };
122     }
123     auto &[sql, bindArgs] = sqlArgs.front();
124     SqliteSqlBuilder::AppendReturning(sql, returningFields);
125     auto [errCode, statement] = GetStatement(sql);
126     if (statement == nullptr) {
127         LOG_ERROR("statement is nullptr, errCode:0x%{public}x, args:%{public}zu, table:%{public}s.", errCode,
128             bindArgs.size(), SqliteUtils::Anonymous(table).c_str());
129         return { errCode, -1 };
130     }
131     auto args = std::ref(bindArgs.front());
132     errCode = statement->Execute(args);
133     if (errCode != E_OK) {
134         LOG_ERROR("failed,errCode:%{public}d,table:%{public}s,args:%{public}zu,resolution:%{public}d.", errCode,
135             SqliteUtils::Anonymous(table).c_str(), args.get().size(), static_cast<int32_t>(resolution));
136     }
137     return { errCode, GenerateResult(errCode, statement) };
138 }
139 
Update(const Row & row,const AbsRdbPredicates & predicates,const std::vector<std::string> & returningFields,Resolution resolution)140 std::pair<int32_t, Results> TransDB::Update(const Row &row, const AbsRdbPredicates &predicates,
141     const std::vector<std::string> &returningFields, Resolution resolution)
142 {
143     DISTRIBUTED_DATA_HITRACE(std::string(__FUNCTION__));
144     auto clause = SqliteUtils::GetConflictClause(static_cast<int>(resolution));
145     auto table = predicates.GetTableName();
146     if (table.empty() || row.IsEmpty() || clause == nullptr) {
147         return { E_INVALID_ARGS, 0 };
148     }
149 
150     std::string sql("UPDATE");
151     sql.append(clause).append(" ").append(table).append(" SET ");
152     std::vector<ValueObject> totalArgs;
153     auto args = predicates.GetBindArgs();
154     totalArgs.reserve(row.values_.size() + args.size());
155     const char *split = "";
156     for (auto &[key, val] : row.values_) {
157         sql.append(split);
158         if (val.GetType() == ValueObject::TYPE_ASSETS) {
159             sql.append(key).append("=merge_assets(").append(key).append(", ?)");
160         } else if (val.GetType() == ValueObject::TYPE_ASSET) {
161             sql.append(key).append("=merge_asset(").append(key).append(", ?)");
162         } else {
163             sql.append(key).append("=?");
164         }
165         totalArgs.push_back(val);
166         split = ",";
167     }
168     auto where = predicates.GetWhereClause();
169     if (!where.empty()) {
170         sql.append(" WHERE ").append(where);
171     }
172     SqliteSqlBuilder::AppendReturning(sql, returningFields);
173     totalArgs.insert(totalArgs.end(), args.begin(), args.end());
174     auto [errCode, statement] = GetStatement(sql);
175     if (errCode != E_OK || statement == nullptr) {
176         return { errCode != E_OK ? errCode : E_ERROR, -1 };
177     }
178 
179     errCode = statement->Execute(totalArgs);
180     if (errCode != E_OK) {
181         LOG_ERROR("failed,errCode:%{public}d,table:%{public}s,returningFields:%{public}zu,resolution:%{public}d.",
182             errCode, SqliteUtils::Anonymous(table).c_str(), returningFields.size(), static_cast<int32_t>(resolution));
183     }
184     return { errCode, GenerateResult(errCode, statement) };
185 }
186 
Delete(const AbsRdbPredicates & predicates,const std::vector<std::string> & returningFields)187 std::pair<int32_t, Results> TransDB::Delete(
188     const AbsRdbPredicates &predicates, const std::vector<std::string> &returningFields)
189 {
190     DISTRIBUTED_DATA_HITRACE(std::string(__FUNCTION__));
191     auto table = predicates.GetTableName();
192     if (table.empty()) {
193         return { E_INVALID_ARGS, -1 };
194     }
195 
196     std::string sql;
197     sql.append("DELETE FROM ").append(table);
198     auto whereClause = predicates.GetWhereClause();
199     if (!whereClause.empty()) {
200         sql.append(" WHERE ").append(whereClause);
201     }
202     SqliteSqlBuilder::AppendReturning(sql, returningFields);
203     auto [errCode, statement] = GetStatement(sql);
204     if (errCode != E_OK || statement == nullptr) {
205         return { errCode != E_OK ? errCode : E_ERROR, -1 };
206     }
207     errCode = statement->Execute(predicates.GetBindArgs());
208     if (errCode != E_OK) {
209         LOG_ERROR("failed,errCode:%{public}d,table:%{public}s,returningFields:%{public}zu.", errCode,
210             SqliteUtils::Anonymous(table).c_str(), returningFields.size());
211     }
212     return { errCode, GenerateResult(errCode, statement) };
213 }
214 
QuerySql(const std::string & sql,const Values & args)215 std::shared_ptr<AbsSharedResultSet> TransDB::QuerySql(const std::string &sql, const Values &args)
216 {
217 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM) && !defined(ANDROID_PLATFORM) && !defined(IOS_PLATFORM)
218     auto start = std::chrono::steady_clock::now();
219     return std::make_shared<SqliteSharedResultSet>(start, conn_.lock(), sql, args, path_);
220 #else
221     (void)sql;
222     (void)args;
223     return nullptr;
224 #endif
225 }
226 
QueryByStep(const std::string & sql,const Values & args,bool preCount)227 std::shared_ptr<ResultSet> TransDB::QueryByStep(const std::string &sql, const Values &args, bool preCount)
228 {
229     auto start = std::chrono::steady_clock::now();
230     return std::make_shared<StepResultSet>(start, conn_.lock(), sql, args, true, true);
231 }
232 
Execute(const std::string & sql,const Values & args,int64_t trxId)233 std::pair<int32_t, ValueObject> TransDB::Execute(const std::string &sql, const Values &args, int64_t trxId)
234 {
235     (void)trxId;
236     ValueObject object;
237     int sqlType = SqliteUtils::GetSqlStatementType(sql);
238     if (!SqliteUtils::IsSupportSqlForExecute(sqlType) && !SqliteUtils::IsSpecial(sqlType)) {
239         LOG_ERROR("Not support the sql:app self can check the SQL, sqlType:%{public}d", sqlType);
240         return { E_INVALID_ARGS, object };
241     }
242 
243     auto [errCode, statement] = GetStatement(sql);
244     if (errCode != E_OK) {
245         return { errCode, object };
246     }
247 
248     errCode = statement->Execute(args);
249     if (errCode != E_OK) {
250         LOG_ERROR("failed,app self can check the SQL, error:0x%{public}x.", errCode);
251         return { errCode, object };
252     }
253 
254     if (sqlType == SqliteUtils::STATEMENT_INSERT) {
255         int64_t outValue = statement->Changes() > 0 ? statement->LastInsertRowId() : -1;
256         return { errCode, ValueObject(outValue) };
257     }
258 
259     if (sqlType == SqliteUtils::STATEMENT_UPDATE) {
260         int outValue = statement->Changes();
261         return { errCode, ValueObject(outValue) };
262     }
263 
264     if (sqlType == SqliteUtils::STATEMENT_PRAGMA) {
265         if (statement->GetColumnCount() == 1) {
266             return statement->GetColumn(0);
267         }
268     }
269 
270     if (sqlType == SqliteUtils::STATEMENT_DDL) {
271         HandleSchemaDDL(statement);
272     }
273     return { errCode, object };
274 }
275 
ExecuteExt(const std::string & sql,const Values & args)276 std::pair<int32_t, Results> TransDB::ExecuteExt(const std::string &sql, const Values &args)
277 {
278     ValueObject object;
279     int sqlType = SqliteUtils::GetSqlStatementType(sql);
280     if (!SqliteUtils::IsSupportSqlForExecute(sqlType) && !SqliteUtils::IsSpecial(sqlType)) {
281         LOG_ERROR("Not support the sql:app self can check the SQL");
282         return { E_INVALID_ARGS, -1 };
283     }
284     auto [errCode, statement] = GetStatement(sql);
285     if (errCode != E_OK || statement == nullptr) {
286         return { errCode != E_OK ? errCode : E_ERROR, -1 };
287     }
288 
289     errCode = statement->Execute(args);
290     auto result = GenerateResult(
291         errCode, statement, sqlType == SqliteUtils::STATEMENT_INSERT || sqlType == SqliteUtils::STATEMENT_UPDATE);
292     if (errCode != E_OK) {
293         LOG_ERROR("failed,app self can check the SQL, error:0x%{public}x.", errCode);
294         return { errCode, result };
295     }
296 
297     if (sqlType == SqliteUtils::STATEMENT_DDL) {
298         HandleSchemaDDL(statement);
299     }
300     return { errCode, result };
301 }
302 
HandleSchemaDDL(std::shared_ptr<Statement> statement)303 void TransDB::HandleSchemaDDL(std::shared_ptr<Statement> statement)
304 {
305     if (statement == nullptr) {
306         return;
307     }
308     statement->Reset();
309     statement->Prepare("PRAGMA schema_version");
310     auto [err, version] = statement->ExecuteForValue();
311     if (vSchema_ < static_cast<int64_t>(version)) {
312         LOG_INFO("db:%{public}s exe DDL schema<%{public}" PRIi64 "->%{public}" PRIi64 ">",
313             SqliteUtils::Anonymous(path_).c_str(), vSchema_, static_cast<int64_t>(version));
314         vSchema_ = version;
315     }
316 }
317 
GetVersion(int & version)318 int TransDB::GetVersion(int &version)
319 {
320     return E_NOT_SUPPORT;
321 }
322 
SetVersion(int version)323 int TransDB::SetVersion(int version)
324 {
325     return E_NOT_SUPPORT;
326 }
327 
Sync(const SyncOption & option,const std::vector<std::string> & tables,const AsyncDetail & async)328 int TransDB::Sync(const SyncOption &option, const std::vector<std::string> &tables, const AsyncDetail &async)
329 {
330     if (option.mode != TIME_FIRST || tables.empty()) {
331         return E_INVALID_ARGS;
332     }
333     return RdbStore::Sync(option, tables, async);
334 }
335 
GetStatement(const std::string & sql) const336 std::pair<int32_t, std::shared_ptr<Statement>> TransDB::GetStatement(const std::string &sql) const
337 {
338     auto connection = conn_.lock();
339     if (connection == nullptr) {
340         return { E_ALREADY_CLOSED, nullptr };
341     }
342     return connection->CreateStatement(sql, connection);
343 }
344 
GenerateResult(int32_t code,std::shared_ptr<Statement> statement,bool isDML)345 Results TransDB::GenerateResult(int32_t code, std::shared_ptr<Statement> statement, bool isDML)
346 {
347     Results result{ -1 };
348     if (statement == nullptr) {
349         return result;
350     }
351     // There are no data changes in other scenarios
352     if (code == E_OK) {
353         result.results = GetValues(statement);
354         result.changed = isDML ? statement->Changes() : 0;
355     }
356     if (code == E_SQLITE_CONSTRAINT) {
357         result.changed = statement->Changes();
358     }
359     if (isDML && result.changed <= 0) {
360         result.results = std::make_shared<CacheResultSet>();
361     }
362     return result;
363 }
364 
GetValues(std::shared_ptr<Statement> statement)365 std::shared_ptr<ResultSet> TransDB::GetValues(std::shared_ptr<Statement> statement)
366 {
367     if (statement == nullptr) {
368         return nullptr;
369     }
370     auto [code, rows] = statement->GetRows(MAX_RETURNING_ROWS);
371     auto size = rows.size();
372     std::shared_ptr<ResultSet> result = std::make_shared<CacheResultSet>(std::move(rows));
373     // The correct number of changed rows can only be obtained after completing the step
374     while (code == E_OK && size == MAX_RETURNING_ROWS) {
375         std::tie(code, rows) = statement->GetRows(MAX_RETURNING_ROWS);
376         size = rows.size();
377     }
378     return result;
379 }
380 } // namespace OHOS::NativeRdb