• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #define LOG_TAG "RdStatement"
16 #include "rd_statement.h"
17 
#include <chrono>
#include <cinttypes>
#include <iomanip>
#include <limits>
#include <sstream>
22 
23 #include "logger.h"
24 #include "raw_data_parser.h"
25 #include "rd_connection.h"
26 #include "rd_utils.h"
27 #include "rdb_errno.h"
28 #include "rdb_fault_hiview_reporter.h"
29 #include "sqlite_global_config.h"
30 #include "sqlite_utils.h"
31 
32 namespace OHOS {
33 namespace NativeRdb {
34 using namespace OHOS::Rdb;
35 using Reportor = RdbFaultHiViewReporter;
RdStatement()36 RdStatement::RdStatement()
37 {
38 }
39 
~RdStatement()40 RdStatement::~RdStatement()
41 {
42     Finalize();
43 }
44 
45 constexpr size_t PRAGMA_VERSION_SQL_LEN = __builtin_strlen(GlobalExpr::PRAGMA_VERSION);
46 
/**
 * Skips spaces starting at curIdx and checks whether the next character is `symbol`.
 *
 * @param str    Text being scanned.
 * @param symbol Character expected after the optional run of spaces.
 * @param curIdx In: position to start scanning from. Out: position just past the symbol when it was
 *               found; left untouched otherwise.
 * @return true if the symbol was found and consumed, false if another character (or the end of the
 *         text) was reached first.
 */
static bool TryEatSymbol(const std::string &str, char symbol, size_t &curIdx)
{
    const size_t firstNonSpace = str.find_first_not_of(' ', curIdx);
    if (firstNonSpace == std::string::npos || str[firstNonSpace] != symbol) {
        return false;
    }
    curIdx = firstNonSpace + 1;
    return true;
}
63 
/**
 * Skips spaces starting at curIdx and parses the run of decimal digits that follows.
 *
 * Digits are accumulated in place, so no temporary substring is created. A value that does not
 * fit in `int` is rejected rather than converted. Only the ASCII characters '0'..'9' count as
 * digits; no sign is accepted.
 *
 * @param str       Text being scanned.
 * @param outNumber Receives the parsed value on success; left untouched on failure.
 * @param curIdx    In: position to start scanning from. Out: position of the first character after
 *                  the digits on success; left untouched on failure.
 * @return non-zero (true) if at least one digit was parsed and the value fits in `int`,
 *         zero (false) otherwise.
 */
static int TryEatNumber(const std::string &str, int &outNumber, size_t &curIdx)
{
    constexpr int maxValue = std::numeric_limits<int>::max();
    constexpr int base = 10;

    size_t idx = curIdx;
    // Spaces are tolerated only in front of the number, never inside it.
    while (idx < str.length() && str[idx] == ' ') {
        idx++;
    }
    const size_t firstDigit = idx;
    int value = 0;
    while (idx < str.length() && str[idx] >= '0' && str[idx] <= '9') {
        const int digit = str[idx] - '0';
        // Equivalent to (value * base + digit > maxValue) without overflowing the multiplication.
        if (value > (maxValue - digit) / base) {
            return false;
        }
        value = value * base + digit;
        idx++;
    }
    if (idx == firstDigit) {
        // Not a single digit was seen.
        return false;
    }
    outNumber = value;
    curIdx = idx;
    return true;
}
90 
/**
 * Tells whether everything from curIdx to the end of str is spaces (or nothing at all).
 *
 * @param str    Text being scanned.
 * @param curIdx Position to start checking from; a position at or past the end counts as "nothing left".
 * @return non-zero (true) when only spaces remain, zero (false) as soon as any other character is found.
 */
static int EndWithNull(const std::string &str, size_t curIdx)
{
    for (size_t pos = curIdx; pos < str.length(); ++pos) {
        if (str[pos] != ' ') {
            return false;
        }
    }
    return true;
}
103 
Prepare(GRD_DB * db,const std::string & newSql)104 int RdStatement::Prepare(GRD_DB *db, const std::string &newSql)
105 {
106     if (newSql.find(GlobalExpr::PRAGMA_VERSION) == 0) {
107         // Indicates that sql is start with pragma version
108         if (newSql.length() == PRAGMA_VERSION_SQL_LEN) {
109             // Indicates that sql is to get version
110             sql_ = newSql;
111             readOnly_ = true;
112             return E_OK;
113         }
114         size_t curIdx = PRAGMA_VERSION_SQL_LEN;
115         int version = 0;
116         if ((!TryEatSymbol(newSql, '=', curIdx)) || (!TryEatNumber(newSql, version, curIdx)) ||
117             (!EndWithNull(newSql, curIdx) && !TryEatSymbol(newSql, ';', curIdx))) {
118             return E_INCORRECT_SQL;
119         }
120 
121         readOnly_ = false;
122         sql_ = newSql;
123         return setPragmas_["user_version"](version);
124     }
125     if (sql_.compare(newSql) == 0) {
126         return E_OK;
127     }
128     GRD_SqlStmt *tmpStmt = nullptr;
129     int ret = RdUtils::RdSqlPrepare(db, newSql.c_str(), newSql.length(), &tmpStmt, nullptr);
130     if (ret != E_OK) {
131         if (ret == E_SQLITE_CORRUPT && config_ != nullptr) {
132             Reportor::ReportCorruptedOnce(Reportor::Create(*config_, ret));
133         }
134         if (tmpStmt != nullptr) {
135             (void)RdUtils::RdSqlFinalize(tmpStmt);
136         }
137         LOG_ERROR("Prepare sql for stmt ret is %{public}d", ret);
138         return ret;
139     }
140     Finalize(); // Finalize original stmt
141     sql_ = newSql;
142     stmtHandle_ = tmpStmt;
143     columnCount_ = RdUtils::RdSqlColCnt(tmpStmt);
144     readOnly_ = SqliteUtils::GetSqlStatementType(newSql) == SqliteUtils::STATEMENT_SELECT;
145     return E_OK;
146 }
147 
Finalize()148 int RdStatement::Finalize()
149 {
150     if (stmtHandle_ == nullptr) {
151         return E_OK;
152     }
153     int ret = RdUtils::RdSqlFinalize(stmtHandle_);
154     if (ret != E_OK) {
155         LOG_ERROR("Finalize ret is %{public}d", ret);
156         return ret;
157     }
158     stmtHandle_ = nullptr;
159     sql_ = "";
160     columnCount_ = 0;
161     readOnly_ = false;
162     config_ = nullptr;
163     return E_OK;
164 }
165 
InnerBindBlobTypeArgs(const ValueObject & arg,uint32_t index) const166 int RdStatement::InnerBindBlobTypeArgs(const ValueObject &arg, uint32_t index) const
167 {
168     int ret = E_OK;
169     switch (arg.GetType()) {
170         case ValueObjectType::TYPE_BLOB: {
171             std::vector<uint8_t> blob;
172             arg.GetBlob(blob);
173             ret = RdUtils::RdSqlBindBlob(
174                 stmtHandle_, index, static_cast<const void *>(blob.data()), blob.size(), nullptr);
175             break;
176         }
177         case ValueObjectType::TYPE_BOOL: {
178             bool boolVal = false;
179             arg.GetBool(boolVal);
180             ret = RdUtils::RdSqlBindInt64(stmtHandle_, index, boolVal ? 1 : 0);
181             break;
182         }
183         case ValueObjectType::TYPE_ASSET: {
184             ValueObject::Asset asset;
185             arg.GetAsset(asset);
186             auto rawData = RawDataParser::PackageRawData(asset);
187             ret = RdUtils::RdSqlBindBlob(
188                 stmtHandle_, index, static_cast<const void *>(rawData.data()), rawData.size(), nullptr);
189             break;
190         }
191         case ValueObjectType::TYPE_ASSETS: {
192             ValueObject::Assets assets;
193             arg.GetAssets(assets);
194             auto rawData = RawDataParser::PackageRawData(assets);
195             ret = RdUtils::RdSqlBindBlob(
196                 stmtHandle_, index, static_cast<const void *>(rawData.data()), rawData.size(), nullptr);
197             break;
198         }
199         case ValueObjectType::TYPE_VECS: {
200             ValueObject::FloatVector vectors;
201             arg.GetVecs(vectors);
202             ret = RdUtils::RdSqlBindFloatVector(
203                 stmtHandle_, index, static_cast<float *>(vectors.data()), vectors.size(), nullptr);
204             break;
205         }
206         default: {
207             std::string str;
208             arg.GetString(str);
209             ret = RdUtils::RdSqlBindText(stmtHandle_, index, str.c_str(), str.length(), nullptr);
210             break;
211         }
212     }
213     return ret;
214 }
215 
PreGetColCount()216 int RdStatement::PreGetColCount()
217 {
218     if (!isStepInPrepare_ && readOnly_) {
219         isStepInPrepare_ = true;
220         int ret = Step();
221         if (ret != E_OK && ret != E_NO_MORE_ROWS) {
222             isStepInPrepare_ = false;
223             return ret;
224         }
225         GetProperties();
226         if (ret == E_NO_MORE_ROWS) {
227             Reset();
228         }
229     }
230     return E_OK;
231 }
232 
IsValid(int index) const233 int RdStatement::IsValid(int index) const
234 {
235     if (stmtHandle_ == nullptr) {
236         LOG_ERROR("Statement already close.");
237         return E_ALREADY_CLOSED;
238     }
239     if (index < 0 || index >= columnCount_) {
240         LOG_ERROR("Index (%{public}d) >= columnCount (%{public}d)", index, columnCount_);
241         return E_COLUMN_OUT_RANGE;
242     }
243     return E_OK;
244 }
245 
Prepare(const std::string & sql)246 int32_t RdStatement::Prepare(const std::string &sql)
247 {
248     if (dbHandle_ == nullptr) {
249         return E_ERROR;
250     }
251     return Prepare(dbHandle_, sql);
252 }
253 
Bind(const std::vector<ValueObject> & args)254 int32_t RdStatement::Bind(const std::vector<ValueObject> &args)
255 {
256     std::vector<std::reference_wrapper<ValueObject>> refArgs;
257     for (auto &object : args) {
258         refArgs.emplace_back(std::ref(const_cast<ValueObject &>(object)));
259     }
260     return Bind(refArgs);
261 }
262 
Bind(const std::vector<std::reference_wrapper<ValueObject>> & args)263 int32_t RdStatement::Bind(const std::vector<std::reference_wrapper<ValueObject>> &args)
264 {
265     uint32_t index = 1;
266     int ret = E_OK;
267     for (auto &arg : args) {
268         switch (arg.get().GetType()) {
269             case ValueObjectType::TYPE_NULL: {
270                 ret = RdUtils::RdSqlBindNull(stmtHandle_, index);
271                 break;
272             }
273             case ValueObjectType::TYPE_INT: {
274                 int64_t value = 0;
275                 arg.get().GetLong(value);
276                 ret = RdUtils::RdSqlBindInt64(stmtHandle_, index, value);
277                 break;
278             }
279             case ValueObjectType::TYPE_DOUBLE: {
280                 double doubleVal = 0;
281                 arg.get().GetDouble(doubleVal);
282                 ret = RdUtils::RdSqlBindDouble(stmtHandle_, index, doubleVal);
283                 break;
284             }
285             default: {
286                 ret = InnerBindBlobTypeArgs(arg, index);
287                 break;
288             }
289         }
290         if (ret != E_OK) {
291             LOG_ERROR("Bind ret is %{public}d", ret);
292             return ret;
293         }
294         index++;
295     }
296     return PreGetColCount();
297 }
298 
Count()299 std::pair<int32_t, int32_t> RdStatement::Count()
300 {
301     return { E_NOT_SUPPORT, INVALID_COUNT };
302 }
303 
Step()304 int32_t RdStatement::Step()
305 {
306     if (stmtHandle_ == nullptr) {
307         return E_OK;
308     }
309     if (isStepInPrepare_ && stepCnt_ == 1) {
310         stepCnt_++;
311         return E_OK;
312     }
313     int ret = RdUtils::RdSqlStep(stmtHandle_);
314     if (ret == E_SQLITE_CORRUPT && config_ != nullptr) {
315         Reportor::ReportCorruptedOnce(Reportor::Create(*config_, ret));
316     }
317     stepCnt_++;
318     return ret;
319 }
320 
Reset()321 int32_t RdStatement::Reset()
322 {
323     if (stmtHandle_ == nullptr) {
324         return E_OK;
325     }
326     stepCnt_ = 0;
327     isStepInPrepare_ = false;
328     return RdUtils::RdSqlReset(stmtHandle_);
329 }
330 
Execute(const std::vector<ValueObject> & args)331 int32_t RdStatement::Execute(const std::vector<ValueObject> &args)
332 {
333     std::vector<std::reference_wrapper<ValueObject>> refArgs;
334     for (auto &object : args) {
335         refArgs.emplace_back(std::ref(const_cast<ValueObject &>(object)));
336     }
337     return Execute(refArgs);
338 }
339 
Execute(const std::vector<std::reference_wrapper<ValueObject>> & args)340 int32_t RdStatement::Execute(const std::vector<std::reference_wrapper<ValueObject>> &args)
341 {
342     if (!readOnly_ && strcmp(sql_.c_str(), GlobalExpr::PRAGMA_VERSION) == 0) {
343         // It has already set version in prepare procedure
344         // Current modification is only temporary for unification between rd and sqlite,
345         // rd kernal will support pragma in later version
346         return E_OK;
347     }
348     int ret = Bind(args);
349     if (ret != E_OK) {
350         LOG_ERROR("RdConnection unable to prepare and bind stmt : err %{public}d", ret);
351         return ret;
352     }
353     ret = Step();
354     if (ret != E_OK && ret != E_NO_MORE_ROWS) {
355         LOG_ERROR("RdConnection Execute : err %{public}d", ret);
356     }
357     return ret;
358 }
359 
ExecuteForValue(const std::vector<ValueObject> & args)360 std::pair<int, ValueObject> RdStatement::ExecuteForValue(const std::vector<ValueObject> &args)
361 {
362     int ret = E_OK;
363     if (readOnly_ && strcmp(sql_.c_str(), GlobalExpr::PRAGMA_VERSION) == 0) {
364         int version = 0;
365         ret = getPragmas_["user_version"](version);
366         if (ret != E_OK) {
367             LOG_ERROR("RdConnection unable to GetVersion : err %{public}d", ret);
368             return { ret, ValueObject() };
369         }
370         return { ret, ValueObject(version) };
371     }
372     ret = Bind(args);
373     if (ret != E_OK) {
374         LOG_ERROR("RdConnection unable to prepare and bind stmt : err %{public}d", ret);
375         return { ret, ValueObject() };
376     }
377     ret = Step();
378     if (ret != E_OK && ret != E_NO_MORE_ROWS) {
379         LOG_ERROR("RdConnection Execute : err %{public}d", ret);
380         return { ret, ValueObject() };
381     }
382     return GetColumn(0);
383 }
384 
Changes() const385 int32_t RdStatement::Changes() const
386 {
387     return 0;
388 }
389 
LastInsertRowId() const390 int64_t RdStatement::LastInsertRowId() const
391 {
392     return 0;
393 }
394 
GetColumnCount() const395 int32_t RdStatement::GetColumnCount() const
396 {
397     return columnCount_;
398 }
399 
GetColumnName(int32_t index) const400 std::pair<int32_t, std::string> RdStatement::GetColumnName(int32_t index) const
401 {
402     int ret = IsValid(index);
403     if (ret != E_OK) {
404         return { ret, "" };
405     }
406     const char *name = RdUtils::RdSqlColName(stmtHandle_, index);
407     if (name == nullptr) {
408         LOG_ERROR("column_name is null.");
409         return { E_ERROR, "" };
410     }
411     return { E_OK, name };
412 }
413 
GetColumnType(int32_t index) const414 std::pair<int32_t, int32_t> RdStatement::GetColumnType(int32_t index) const
415 {
416     int ret = IsValid(index);
417     if (ret != E_OK) {
418         return { ret, static_cast<int32_t>(ColumnType::TYPE_NULL) };
419     }
420     ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
421     switch (type) {
422         case ColumnType::TYPE_INTEGER:
423         case ColumnType::TYPE_FLOAT:
424         case ColumnType::TYPE_NULL:
425         case ColumnType::TYPE_STRING:
426         case ColumnType::TYPE_BLOB:
427         case ColumnType::TYPE_FLOAT32_ARRAY:
428             break;
429         default:
430             LOG_ERROR("Invalid type %{public}d.", type);
431             return { E_ERROR, static_cast<int32_t>(ColumnType::TYPE_NULL) };
432     }
433     return { ret, static_cast<int32_t>(type) };
434 }
435 
GetSize(int32_t index) const436 std::pair<int32_t, size_t> RdStatement::GetSize(int32_t index) const
437 {
438     int ret = IsValid(index);
439     if (ret != E_OK) {
440         return { ret, 0 };
441     }
442     ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
443     if (type == ColumnType::TYPE_BLOB || type == ColumnType::TYPE_STRING || type == ColumnType::TYPE_NULL ||
444         type == ColumnType::TYPE_FLOAT32_ARRAY) {
445         return { E_OK, static_cast<size_t>(RdUtils::RdSqlColBytes(stmtHandle_, index)) };
446     }
447     return { E_INVALID_COLUMN_TYPE, 0 };
448 }
449 
GetColumn(int32_t index) const450 std::pair<int32_t, ValueObject> RdStatement::GetColumn(int32_t index) const
451 {
452     ValueObject object;
453     int ret = IsValid(index);
454     if (ret != E_OK) {
455         return { ret, object };
456     }
457 
458     ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
459     switch (type) {
460         case ColumnType::TYPE_FLOAT:
461             object = RdUtils::RdSqlColDouble(stmtHandle_, index);
462             break;
463         case ColumnType::TYPE_INTEGER:
464             object = static_cast<int64_t>(RdUtils::RdSqlColInt64(stmtHandle_, index));
465             break;
466         case ColumnType::TYPE_STRING:
467             object = reinterpret_cast<const char *>(RdUtils::RdSqlColText(stmtHandle_, index));
468             break;
469         case ColumnType::TYPE_NULL:
470             break;
471         case ColumnType::TYPE_FLOAT32_ARRAY: {
472             uint32_t dim = 0;
473             auto vectors = reinterpret_cast<const float *>(RdUtils::RdSqlColumnFloatVector(stmtHandle_, index, &dim));
474             std::vector<float> vecData;
475             if (dim > 0 || vectors != nullptr) {
476                 vecData.resize(dim);
477                 vecData.assign(vectors, vectors + dim);
478             }
479             object = std::move(vecData);
480             break;
481         }
482         case ColumnType::TYPE_BLOB: {
483             int size = RdUtils::RdSqlColBytes(stmtHandle_, index);
484             auto blob = static_cast<const uint8_t *>(RdUtils::RdSqlColBlob(stmtHandle_, index));
485             std::vector<uint8_t> rawData;
486             if (size > 0 || blob != nullptr) {
487                 rawData.resize(size);
488                 rawData.assign(blob, blob + size);
489             }
490             object = std::move(rawData);
491             break;
492         }
493         default:
494             break;
495     }
496     return { ret, std::move(object) };
497 }
498 
ReadOnly() const499 bool RdStatement::ReadOnly() const
500 {
501     return readOnly_;
502 }
503 
SupportBlockInfo() const504 bool RdStatement::SupportBlockInfo() const
505 {
506     return false;
507 }
508 
FillBlockInfo(SharedBlockInfo * info) const509 int32_t RdStatement::FillBlockInfo(SharedBlockInfo *info) const
510 {
511     return E_NOT_SUPPORT;
512 }
513 
GetProperties()514 void RdStatement::GetProperties()
515 {
516     columnCount_ = RdUtils::RdSqlColCnt(stmtHandle_);
517 }
518 } // namespace NativeRdb
519 } // namespace OHOS
520