1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #define LOG_TAG "RdStatement"
16 #include "rd_statement.h"
17
#include <cctype>
#include <chrono>
#include <cinttypes>
#include <iomanip>
#include <limits>
#include <sstream>
22
23 #include "logger.h"
24 #include "raw_data_parser.h"
25 #include "rd_connection.h"
26 #include "rd_utils.h"
27 #include "rdb_errno.h"
28 #include "rdb_fault_hiview_reporter.h"
29 #include "sqlite_global_config.h"
30 #include "sqlite_utils.h"
31
32 namespace OHOS {
33 namespace NativeRdb {
34 using namespace OHOS::Rdb;
35 using Reportor = RdbFaultHiViewReporter;
RdStatement()36 RdStatement::RdStatement()
37 {
38 }
39
~RdStatement()40 RdStatement::~RdStatement()
41 {
42 Finalize();
43 }
44
45 constexpr size_t PRAGMA_VERSION_SQL_LEN = __builtin_strlen(GlobalExpr::PRAGMA_VERSION);
46
/**
 * Skips space characters (' ' only) starting at curIdx and checks whether the next character is symbol.
 *
 * @param str     text being parsed
 * @param symbol  character expected after the optional spaces
 * @param curIdx  parse position; advanced past symbol on success, left unchanged otherwise
 * @return true if symbol was found and consumed
 */
static bool TryEatSymbol(const std::string &str, char symbol, size_t &curIdx)
{
    const size_t pos = str.find_first_not_of(' ', curIdx);
    if (pos == std::string::npos || str[pos] != symbol) {
        return false;
    }
    curIdx = pos + 1;
    return true;
}
63
/**
 * Parses a non-negative decimal integer starting at curIdx, after skipping leading spaces (' ').
 * Parsing stops at the first non-digit character.
 *
 * @param str        text being parsed
 * @param outNumber  receives the parsed value; written only on success
 * @param curIdx     parse position; on success advanced past the last digit, otherwise unchanged
 * @return true if at least one digit was found and the value fits in an int, false otherwise
 */
static bool TryEatNumber(const std::string &str, int &outNumber, size_t &curIdx)
{
    constexpr int64_t decimalBase = 10;
    size_t idx = curIdx;
    // Spaces are only allowed before the first digit; a space after it ends the number.
    while (idx < str.length() && str[idx] == ' ') {
        idx++;
    }
    const size_t digitBegin = idx;
    int64_t value = 0;
    // isdigit() requires an argument representable as unsigned char; plain char may be negative.
    while (idx < str.length() && isdigit(static_cast<unsigned char>(str[idx])) != 0) {
        value = value * decimalBase + (str[idx] - '0');
        if (value > std::numeric_limits<int>::max()) {
            // Out-of-range numbers are rejected instead of relying on atoi(), whose overflow behavior is undefined.
            return false;
        }
        idx++;
    }
    if (idx == digitBegin) {
        return false;
    }
    outNumber = static_cast<int>(value);
    curIdx = idx;
    return true;
}
90
/**
 * Checks that nothing but space characters (' ') remains in str from curIdx onward.
 * A curIdx at or past the end of str counts as "nothing remains". Tabs and newlines count as content.
 *
 * @param str     text being parsed
 * @param curIdx  position from which the remainder is examined
 * @return true if the remainder is empty or all spaces
 */
static bool EndWithNull(const std::string &str, size_t curIdx)
{
    return str.find_first_not_of(' ', curIdx) == std::string::npos;
}
103
Prepare(GRD_DB * db,const std::string & newSql)104 int RdStatement::Prepare(GRD_DB *db, const std::string &newSql)
105 {
106 if (newSql.find(GlobalExpr::PRAGMA_VERSION) == 0) {
107 // Indicates that sql is start with pragma version
108 if (newSql.length() == PRAGMA_VERSION_SQL_LEN) {
109 // Indicates that sql is to get version
110 sql_ = newSql;
111 readOnly_ = true;
112 return E_OK;
113 }
114 size_t curIdx = PRAGMA_VERSION_SQL_LEN;
115 int version = 0;
116 if ((!TryEatSymbol(newSql, '=', curIdx)) || (!TryEatNumber(newSql, version, curIdx)) ||
117 (!EndWithNull(newSql, curIdx) && !TryEatSymbol(newSql, ';', curIdx))) {
118 return E_INCORRECT_SQL;
119 }
120
121 readOnly_ = false;
122 sql_ = newSql;
123 return setPragmas_["user_version"](version);
124 }
125 if (sql_.compare(newSql) == 0) {
126 return E_OK;
127 }
128 GRD_SqlStmt *tmpStmt = nullptr;
129 int ret = RdUtils::RdSqlPrepare(db, newSql.c_str(), newSql.length(), &tmpStmt, nullptr);
130 if (ret != E_OK) {
131 if (ret == E_SQLITE_CORRUPT && config_ != nullptr) {
132 Reportor::ReportCorruptedOnce(Reportor::Create(*config_, ret));
133 }
134 if (tmpStmt != nullptr) {
135 (void)RdUtils::RdSqlFinalize(tmpStmt);
136 }
137 LOG_ERROR("Prepare sql for stmt ret is %{public}d", ret);
138 return ret;
139 }
140 Finalize(); // Finalize original stmt
141 sql_ = newSql;
142 stmtHandle_ = tmpStmt;
143 columnCount_ = RdUtils::RdSqlColCnt(tmpStmt);
144 readOnly_ = SqliteUtils::GetSqlStatementType(newSql) == SqliteUtils::STATEMENT_SELECT;
145 return E_OK;
146 }
147
Finalize()148 int RdStatement::Finalize()
149 {
150 if (stmtHandle_ == nullptr) {
151 return E_OK;
152 }
153 int ret = RdUtils::RdSqlFinalize(stmtHandle_);
154 if (ret != E_OK) {
155 LOG_ERROR("Finalize ret is %{public}d", ret);
156 return ret;
157 }
158 stmtHandle_ = nullptr;
159 sql_ = "";
160 columnCount_ = 0;
161 readOnly_ = false;
162 config_ = nullptr;
163 return E_OK;
164 }
165
InnerBindBlobTypeArgs(const ValueObject & arg,uint32_t index) const166 int RdStatement::InnerBindBlobTypeArgs(const ValueObject &arg, uint32_t index) const
167 {
168 int ret = E_OK;
169 switch (arg.GetType()) {
170 case ValueObjectType::TYPE_BLOB: {
171 std::vector<uint8_t> blob;
172 arg.GetBlob(blob);
173 ret = RdUtils::RdSqlBindBlob(
174 stmtHandle_, index, static_cast<const void *>(blob.data()), blob.size(), nullptr);
175 break;
176 }
177 case ValueObjectType::TYPE_BOOL: {
178 bool boolVal = false;
179 arg.GetBool(boolVal);
180 ret = RdUtils::RdSqlBindInt64(stmtHandle_, index, boolVal ? 1 : 0);
181 break;
182 }
183 case ValueObjectType::TYPE_ASSET: {
184 ValueObject::Asset asset;
185 arg.GetAsset(asset);
186 auto rawData = RawDataParser::PackageRawData(asset);
187 ret = RdUtils::RdSqlBindBlob(
188 stmtHandle_, index, static_cast<const void *>(rawData.data()), rawData.size(), nullptr);
189 break;
190 }
191 case ValueObjectType::TYPE_ASSETS: {
192 ValueObject::Assets assets;
193 arg.GetAssets(assets);
194 auto rawData = RawDataParser::PackageRawData(assets);
195 ret = RdUtils::RdSqlBindBlob(
196 stmtHandle_, index, static_cast<const void *>(rawData.data()), rawData.size(), nullptr);
197 break;
198 }
199 case ValueObjectType::TYPE_VECS: {
200 ValueObject::FloatVector vectors;
201 arg.GetVecs(vectors);
202 ret = RdUtils::RdSqlBindFloatVector(
203 stmtHandle_, index, static_cast<float *>(vectors.data()), vectors.size(), nullptr);
204 break;
205 }
206 default: {
207 std::string str;
208 arg.GetString(str);
209 ret = RdUtils::RdSqlBindText(stmtHandle_, index, str.c_str(), str.length(), nullptr);
210 break;
211 }
212 }
213 return ret;
214 }
215
PreGetColCount()216 int RdStatement::PreGetColCount()
217 {
218 if (!isStepInPrepare_ && readOnly_) {
219 isStepInPrepare_ = true;
220 int ret = Step();
221 if (ret != E_OK && ret != E_NO_MORE_ROWS) {
222 isStepInPrepare_ = false;
223 return ret;
224 }
225 GetProperties();
226 if (ret == E_NO_MORE_ROWS) {
227 Reset();
228 }
229 }
230 return E_OK;
231 }
232
IsValid(int index) const233 int RdStatement::IsValid(int index) const
234 {
235 if (stmtHandle_ == nullptr) {
236 LOG_ERROR("Statement already close.");
237 return E_ALREADY_CLOSED;
238 }
239 if (index < 0 || index >= columnCount_) {
240 LOG_ERROR("Index (%{public}d) >= columnCount (%{public}d)", index, columnCount_);
241 return E_COLUMN_OUT_RANGE;
242 }
243 return E_OK;
244 }
245
Prepare(const std::string & sql)246 int32_t RdStatement::Prepare(const std::string &sql)
247 {
248 if (dbHandle_ == nullptr) {
249 return E_ERROR;
250 }
251 return Prepare(dbHandle_, sql);
252 }
253
Bind(const std::vector<ValueObject> & args)254 int32_t RdStatement::Bind(const std::vector<ValueObject> &args)
255 {
256 std::vector<std::reference_wrapper<ValueObject>> refArgs;
257 for (auto &object : args) {
258 refArgs.emplace_back(std::ref(const_cast<ValueObject &>(object)));
259 }
260 return Bind(refArgs);
261 }
262
Bind(const std::vector<std::reference_wrapper<ValueObject>> & args)263 int32_t RdStatement::Bind(const std::vector<std::reference_wrapper<ValueObject>> &args)
264 {
265 uint32_t index = 1;
266 int ret = E_OK;
267 for (auto &arg : args) {
268 switch (arg.get().GetType()) {
269 case ValueObjectType::TYPE_NULL: {
270 ret = RdUtils::RdSqlBindNull(stmtHandle_, index);
271 break;
272 }
273 case ValueObjectType::TYPE_INT: {
274 int64_t value = 0;
275 arg.get().GetLong(value);
276 ret = RdUtils::RdSqlBindInt64(stmtHandle_, index, value);
277 break;
278 }
279 case ValueObjectType::TYPE_DOUBLE: {
280 double doubleVal = 0;
281 arg.get().GetDouble(doubleVal);
282 ret = RdUtils::RdSqlBindDouble(stmtHandle_, index, doubleVal);
283 break;
284 }
285 default: {
286 ret = InnerBindBlobTypeArgs(arg, index);
287 break;
288 }
289 }
290 if (ret != E_OK) {
291 LOG_ERROR("Bind ret is %{public}d", ret);
292 return ret;
293 }
294 index++;
295 }
296 return PreGetColCount();
297 }
298
Count()299 std::pair<int32_t, int32_t> RdStatement::Count()
300 {
301 return { E_NOT_SUPPORT, INVALID_COUNT };
302 }
303
Step()304 int32_t RdStatement::Step()
305 {
306 if (stmtHandle_ == nullptr) {
307 return E_OK;
308 }
309 if (isStepInPrepare_ && stepCnt_ == 1) {
310 stepCnt_++;
311 return E_OK;
312 }
313 int ret = RdUtils::RdSqlStep(stmtHandle_);
314 if (ret == E_SQLITE_CORRUPT && config_ != nullptr) {
315 Reportor::ReportCorruptedOnce(Reportor::Create(*config_, ret));
316 }
317 stepCnt_++;
318 return ret;
319 }
320
Reset()321 int32_t RdStatement::Reset()
322 {
323 if (stmtHandle_ == nullptr) {
324 return E_OK;
325 }
326 stepCnt_ = 0;
327 isStepInPrepare_ = false;
328 return RdUtils::RdSqlReset(stmtHandle_);
329 }
330
Execute(const std::vector<ValueObject> & args)331 int32_t RdStatement::Execute(const std::vector<ValueObject> &args)
332 {
333 std::vector<std::reference_wrapper<ValueObject>> refArgs;
334 for (auto &object : args) {
335 refArgs.emplace_back(std::ref(const_cast<ValueObject &>(object)));
336 }
337 return Execute(refArgs);
338 }
339
Execute(const std::vector<std::reference_wrapper<ValueObject>> & args)340 int32_t RdStatement::Execute(const std::vector<std::reference_wrapper<ValueObject>> &args)
341 {
342 if (!readOnly_ && strcmp(sql_.c_str(), GlobalExpr::PRAGMA_VERSION) == 0) {
343 // It has already set version in prepare procedure
344 // Current modification is only temporary for unification between rd and sqlite,
345 // rd kernal will support pragma in later version
346 return E_OK;
347 }
348 int ret = Bind(args);
349 if (ret != E_OK) {
350 LOG_ERROR("RdConnection unable to prepare and bind stmt : err %{public}d", ret);
351 return ret;
352 }
353 ret = Step();
354 if (ret != E_OK && ret != E_NO_MORE_ROWS) {
355 LOG_ERROR("RdConnection Execute : err %{public}d", ret);
356 }
357 return ret;
358 }
359
ExecuteForValue(const std::vector<ValueObject> & args)360 std::pair<int, ValueObject> RdStatement::ExecuteForValue(const std::vector<ValueObject> &args)
361 {
362 int ret = E_OK;
363 if (readOnly_ && strcmp(sql_.c_str(), GlobalExpr::PRAGMA_VERSION) == 0) {
364 int version = 0;
365 ret = getPragmas_["user_version"](version);
366 if (ret != E_OK) {
367 LOG_ERROR("RdConnection unable to GetVersion : err %{public}d", ret);
368 return { ret, ValueObject() };
369 }
370 return { ret, ValueObject(version) };
371 }
372 ret = Bind(args);
373 if (ret != E_OK) {
374 LOG_ERROR("RdConnection unable to prepare and bind stmt : err %{public}d", ret);
375 return { ret, ValueObject() };
376 }
377 ret = Step();
378 if (ret != E_OK && ret != E_NO_MORE_ROWS) {
379 LOG_ERROR("RdConnection Execute : err %{public}d", ret);
380 return { ret, ValueObject() };
381 }
382 return GetColumn(0);
383 }
384
Changes() const385 int32_t RdStatement::Changes() const
386 {
387 return 0;
388 }
389
LastInsertRowId() const390 int64_t RdStatement::LastInsertRowId() const
391 {
392 return 0;
393 }
394
GetColumnCount() const395 int32_t RdStatement::GetColumnCount() const
396 {
397 return columnCount_;
398 }
399
GetColumnName(int32_t index) const400 std::pair<int32_t, std::string> RdStatement::GetColumnName(int32_t index) const
401 {
402 int ret = IsValid(index);
403 if (ret != E_OK) {
404 return { ret, "" };
405 }
406 const char *name = RdUtils::RdSqlColName(stmtHandle_, index);
407 if (name == nullptr) {
408 LOG_ERROR("column_name is null.");
409 return { E_ERROR, "" };
410 }
411 return { E_OK, name };
412 }
413
GetColumnType(int32_t index) const414 std::pair<int32_t, int32_t> RdStatement::GetColumnType(int32_t index) const
415 {
416 int ret = IsValid(index);
417 if (ret != E_OK) {
418 return { ret, static_cast<int32_t>(ColumnType::TYPE_NULL) };
419 }
420 ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
421 switch (type) {
422 case ColumnType::TYPE_INTEGER:
423 case ColumnType::TYPE_FLOAT:
424 case ColumnType::TYPE_NULL:
425 case ColumnType::TYPE_STRING:
426 case ColumnType::TYPE_BLOB:
427 case ColumnType::TYPE_FLOAT32_ARRAY:
428 break;
429 default:
430 LOG_ERROR("Invalid type %{public}d.", type);
431 return { E_ERROR, static_cast<int32_t>(ColumnType::TYPE_NULL) };
432 }
433 return { ret, static_cast<int32_t>(type) };
434 }
435
GetSize(int32_t index) const436 std::pair<int32_t, size_t> RdStatement::GetSize(int32_t index) const
437 {
438 int ret = IsValid(index);
439 if (ret != E_OK) {
440 return { ret, 0 };
441 }
442 ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
443 if (type == ColumnType::TYPE_BLOB || type == ColumnType::TYPE_STRING || type == ColumnType::TYPE_NULL ||
444 type == ColumnType::TYPE_FLOAT32_ARRAY) {
445 return { E_OK, static_cast<size_t>(RdUtils::RdSqlColBytes(stmtHandle_, index)) };
446 }
447 return { E_INVALID_COLUMN_TYPE, 0 };
448 }
449
GetColumn(int32_t index) const450 std::pair<int32_t, ValueObject> RdStatement::GetColumn(int32_t index) const
451 {
452 ValueObject object;
453 int ret = IsValid(index);
454 if (ret != E_OK) {
455 return { ret, object };
456 }
457
458 ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
459 switch (type) {
460 case ColumnType::TYPE_FLOAT:
461 object = RdUtils::RdSqlColDouble(stmtHandle_, index);
462 break;
463 case ColumnType::TYPE_INTEGER:
464 object = static_cast<int64_t>(RdUtils::RdSqlColInt64(stmtHandle_, index));
465 break;
466 case ColumnType::TYPE_STRING:
467 object = reinterpret_cast<const char *>(RdUtils::RdSqlColText(stmtHandle_, index));
468 break;
469 case ColumnType::TYPE_NULL:
470 break;
471 case ColumnType::TYPE_FLOAT32_ARRAY: {
472 uint32_t dim = 0;
473 auto vectors = reinterpret_cast<const float *>(RdUtils::RdSqlColumnFloatVector(stmtHandle_, index, &dim));
474 std::vector<float> vecData;
475 if (dim > 0 || vectors != nullptr) {
476 vecData.resize(dim);
477 vecData.assign(vectors, vectors + dim);
478 }
479 object = std::move(vecData);
480 break;
481 }
482 case ColumnType::TYPE_BLOB: {
483 int size = RdUtils::RdSqlColBytes(stmtHandle_, index);
484 auto blob = static_cast<const uint8_t *>(RdUtils::RdSqlColBlob(stmtHandle_, index));
485 std::vector<uint8_t> rawData;
486 if (size > 0 || blob != nullptr) {
487 rawData.resize(size);
488 rawData.assign(blob, blob + size);
489 }
490 object = std::move(rawData);
491 break;
492 }
493 default:
494 break;
495 }
496 return { ret, std::move(object) };
497 }
498
ReadOnly() const499 bool RdStatement::ReadOnly() const
500 {
501 return readOnly_;
502 }
503
SupportBlockInfo() const504 bool RdStatement::SupportBlockInfo() const
505 {
506 return false;
507 }
508
FillBlockInfo(SharedBlockInfo * info) const509 int32_t RdStatement::FillBlockInfo(SharedBlockInfo *info) const
510 {
511 return E_NOT_SUPPORT;
512 }
513
GetProperties()514 void RdStatement::GetProperties()
515 {
516 columnCount_ = RdUtils::RdSqlColCnt(stmtHandle_);
517 }
518 } // namespace NativeRdb
519 } // namespace OHOS
520