• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "sqlite_statement.h"

#include <cctype>
#include <cstdlib>
#include <iomanip>
#include <sstream>

#include "logger.h"
#include "raw_data_parser.h"
#include "rdb_errno.h"
#include "sqlite_errno.h"
#include "sqlite_utils.h"
26 
27 namespace OHOS {
28 namespace NativeRdb {
29 using namespace OHOS::Rdb;
30 
// Number of significant digits used when a REAL column value is converted to text
// (see GetColumnString).
constexpr int SET_DATA_PRECISION = 15;
SqliteStatement()33 SqliteStatement::SqliteStatement() : sql(""), stmtHandle(nullptr), readOnly(false), columnCount(0), numParameters(0)
34 {
35 }
~SqliteStatement()36 SqliteStatement::~SqliteStatement()
37 {
38     Finalize();
39 }
40 
Prepare(sqlite3 * dbHandle,const std::string & newSql)41 int SqliteStatement::Prepare(sqlite3 *dbHandle, const std::string &newSql)
42 {
43     if (sql.compare(newSql) == 0) {
44         return E_OK;
45     }
46     // prepare the new sqlite3_stmt
47     sqlite3_stmt *stmt = nullptr;
48     int errCode = sqlite3_prepare_v2(dbHandle, newSql.c_str(), newSql.length(), &stmt, nullptr);
49     if (errCode != SQLITE_OK) {
50         LOG_ERROR("prepare_v2 ret is %{public}d", errCode);
51         if (stmt != nullptr) {
52             sqlite3_finalize(stmt);
53         }
54         return SQLiteError::ErrNo(errCode);
55     }
56     Finalize(); // finalize the old
57     sql = newSql;
58     stmtHandle = stmt;
59     readOnly = (sqlite3_stmt_readonly(stmtHandle) != 0) ? true : false;
60     columnCount = sqlite3_column_count(stmtHandle);
61     numParameters = sqlite3_bind_parameter_count(stmtHandle);
62     return E_OK;
63 }
64 
Finalize()65 int SqliteStatement::Finalize()
66 {
67     if (stmtHandle == nullptr) {
68         return E_OK;
69     }
70 
71     int errCode = sqlite3_finalize(stmtHandle);
72     stmtHandle = nullptr;
73     sql = "";
74     readOnly = false;
75     columnCount = 0;
76     numParameters = 0;
77     if (errCode != SQLITE_OK) {
78         LOG_ERROR("finalize ret is %{public}d", errCode);
79         return SQLiteError::ErrNo(errCode);
80     }
81     return E_OK;
82 }
83 
BindArguments(const std::vector<ValueObject> & bindArgs) const84 int SqliteStatement::BindArguments(const std::vector<ValueObject> &bindArgs) const
85 {
86     int count = static_cast<int>(bindArgs.size());
87     std::vector<ValueObject> abindArgs;
88 
89     if (count == 0) {
90         return E_OK;
91     }
92     // Obtains the bound parameter set.
93     if ((numParameters != 0) && (count <= numParameters)) {
94         for (const auto& i : bindArgs) {
95             abindArgs.push_back(i);
96         }
97 
98         for (int i = count; i < numParameters; i++) { // TD: when count <> numParameters
99             ValueObject val;
100             abindArgs.push_back(val);
101         }
102     }
103 
104     if (count > numParameters) {
105         LOG_ERROR("bind args count(%{public}d) > numParameters(%{public}d)", count, numParameters);
106         return E_INVALID_BIND_ARGS_COUNT;
107     }
108 
109     return InnerBindArguments(abindArgs);
110 }
111 
InnerBindArguments(const std::vector<ValueObject> & bindArgs) const112 int SqliteStatement::InnerBindArguments(const std::vector<ValueObject> &bindArgs) const
113 {
114     int index = 1;
115     int errCode;
116     for (auto arg : bindArgs) {
117         switch (arg.GetType()) {
118             case ValueObjectType::TYPE_NULL: {
119                 errCode = sqlite3_bind_null(stmtHandle, index);
120                 break;
121             }
122             case ValueObjectType::TYPE_INT: {
123                 int64_t value = 0;
124                 arg.GetLong(value);
125                 errCode = sqlite3_bind_int64(stmtHandle, index, value);
126                 break;
127             }
128             case ValueObjectType::TYPE_DOUBLE: {
129                 double doubleVal = 0;
130                 arg.GetDouble(doubleVal);
131                 errCode = sqlite3_bind_double(stmtHandle, index, doubleVal);
132                 break;
133             }
134             case ValueObjectType::TYPE_BLOB: {
135                 std::vector<uint8_t> blob;
136                 arg.GetBlob(blob);
137                 errCode = sqlite3_bind_blob(stmtHandle, index, static_cast<const void *>(blob.data()), blob.size(),
138                     SQLITE_TRANSIENT);
139                 break;
140             }
141             case ValueObjectType::TYPE_BOOL: {
142                 bool boolVal = false;
143                 arg.GetBool(boolVal);
144                 errCode = sqlite3_bind_int64(stmtHandle, index, boolVal ? 1 : 0);
145                 break;
146             }
147             case ValueObjectType::TYPE_ASSET: {
148                 Asset asset;
149                 arg.GetAsset(asset);
150                 auto rawData = RawDataParser::PackageRawData(asset);
151                 errCode = sqlite3_bind_blob(stmtHandle, index, static_cast<const void *>(rawData.data()),
152                     rawData.size(), SQLITE_TRANSIENT);
153                 break;
154             }
155             case ValueObjectType::TYPE_ASSETS: {
156                 Assets assets;
157                 arg.GetAssets(assets);
158                 auto rawData = RawDataParser::PackageRawData(assets);
159                 errCode = sqlite3_bind_blob(stmtHandle, index, static_cast<const void *>(rawData.data()),
160                     rawData.size(), SQLITE_TRANSIENT);
161                 break;
162             }
163             default: {
164                 std::string str;
165                 arg.GetString(str);
166                 errCode = sqlite3_bind_text(stmtHandle, index, str.c_str(), str.length(), SQLITE_TRANSIENT);
167                 break;
168             }
169         }
170 
171         if (errCode != SQLITE_OK) {
172             LOG_ERROR("bind ret is %{public}d", errCode);
173             return SQLiteError::ErrNo(errCode);
174         }
175 
176         index++;
177     }
178 
179     return E_OK;
180 }
181 
ResetStatementAndClearBindings() const182 int SqliteStatement::ResetStatementAndClearBindings() const
183 {
184     if (stmtHandle == nullptr) {
185         return E_OK;
186     }
187 
188     int errCode = sqlite3_reset(stmtHandle);
189     if (errCode != SQLITE_OK) {
190         LOG_ERROR("reset ret is %{public}d", errCode);
191         return SQLiteError::ErrNo(errCode);
192     }
193 
194     errCode = sqlite3_clear_bindings(stmtHandle);
195     if (errCode != SQLITE_OK) {
196         LOG_ERROR("clear_bindings ret is %{public}d", errCode);
197         return SQLiteError::ErrNo(errCode);
198     }
199 
200     return E_OK;
201 }
202 
Step() const203 int SqliteStatement::Step() const
204 {
205     int errCode = sqlite3_step(stmtHandle);
206     return errCode;
207 }
208 
GetColumnCount(int & count) const209 int SqliteStatement::GetColumnCount(int &count) const
210 {
211     if (stmtHandle == nullptr) {
212         LOG_ERROR("invalid statement.");
213         return E_INVALID_STATEMENT;
214     }
215     count = columnCount;
216     return E_OK;
217 }
218 
219 /**
220  * Obtains the number that the statement has.
221  */
GetNumParameters(int & numParams) const222 int SqliteStatement::GetNumParameters(int &numParams) const
223 {
224     if (stmtHandle == nullptr) {
225         LOG_ERROR("invalid statement.");
226         return E_INVALID_STATEMENT;
227     }
228     numParams = numParameters;
229     return E_OK;
230 }
231 
GetColumnName(int index,std::string & columnName) const232 int SqliteStatement::GetColumnName(int index, std::string &columnName) const
233 {
234     int ret = IsValid(index);
235     if (ret != E_OK) {
236         return ret;
237     }
238 
239     const char *name = sqlite3_column_name(stmtHandle, index);
240     if (name == nullptr) {
241         LOG_ERROR("column_name is null.");
242         return E_ERROR;
243     }
244     columnName = std::string(name);
245     return E_OK;
246 }
247 
GetColumnType(int index,int & columnType) const248 int SqliteStatement::GetColumnType(int index, int &columnType) const
249 {
250     int ret = IsValid(index);
251     if (ret != E_OK) {
252         return ret;
253     }
254 
255     int type = sqlite3_column_type(stmtHandle, index);
256     switch (type) {
257         case SQLITE_INTEGER:
258         case SQLITE_FLOAT:
259         case SQLITE_BLOB:
260         case SQLITE_NULL:
261         case SQLITE_TEXT:
262             columnType = type;
263             return E_OK;
264         default:
265             LOG_ERROR("invalid type %{public}d.", type);
266             return E_ERROR;
267     }
268 }
269 
GetColumnBlob(int index,std::vector<uint8_t> & value) const270 int SqliteStatement::GetColumnBlob(int index, std::vector<uint8_t> &value) const
271 {
272     int ret = IsValid(index);
273     if (ret != E_OK) {
274         return ret;
275     }
276 
277     int type = sqlite3_column_type(stmtHandle, index);
278     if (type != SQLITE_BLOB && type != SQLITE_TEXT && type != SQLITE_NULL) {
279         LOG_ERROR("invalid type %{public}d.", type);
280         return E_INVALID_COLUMN_TYPE;
281     }
282 
283     int size = sqlite3_column_bytes(stmtHandle, index);
284     auto blob = static_cast<const uint8_t *>(sqlite3_column_blob(stmtHandle, index));
285     if (size == 0 || blob == nullptr) {
286         value.resize(0);
287     } else {
288         value.resize(size);
289         value.assign(blob, blob + size);
290     }
291 
292     return E_OK;
293 }
294 
GetColumnString(int index,std::string & value) const295 int SqliteStatement::GetColumnString(int index, std::string &value) const
296 {
297     int ret = IsValid(index);
298     if (ret != E_OK) {
299         return ret;
300     }
301 
302     int type = sqlite3_column_type(stmtHandle, index);
303     switch (type) {
304         case SQLITE_TEXT: {
305             auto val = reinterpret_cast<const char *>(sqlite3_column_text(stmtHandle, index));
306             value = (val == nullptr) ? "" : std::string(val, sqlite3_column_bytes(stmtHandle, index));
307             break;
308         }
309         case SQLITE_INTEGER: {
310             int64_t val = sqlite3_column_int64(stmtHandle, index);
311             value = std::to_string(val);
312             break;
313         }
314         case SQLITE_FLOAT: {
315             double val = sqlite3_column_double(stmtHandle, index);
316             std::ostringstream os;
317             if (os << std::setprecision(SET_DATA_PRECISION) << val)
318                 value = os.str();
319             break;
320         }
321         case SQLITE_NULL: {
322             value = "";
323             return E_OK;
324         }
325         case SQLITE_BLOB: {
326             return E_INVALID_COLUMN_TYPE;
327         }
328         default:
329             return E_ERROR;
330     }
331     return E_OK;
332 }
333 
GetColumnLong(int index,int64_t & value) const334 int SqliteStatement::GetColumnLong(int index, int64_t &value) const
335 {
336     int ret = IsValid(index);
337     if (ret != E_OK) {
338         return ret;
339     }
340 
341     char *errStr = nullptr;
342     int type = sqlite3_column_type(stmtHandle, index);
343     if (type == SQLITE_INTEGER) {
344         value = sqlite3_column_int64(stmtHandle, index);
345     } else if (type == SQLITE_TEXT) {
346         auto val = reinterpret_cast<const char *>(sqlite3_column_text(stmtHandle, index));
347         value = (val == nullptr) ? 0 : strtoll(val, &errStr, 0);
348     } else if (type == SQLITE_FLOAT) {
349         double val = sqlite3_column_double(stmtHandle, index);
350         value = static_cast<int64_t>(val);
351     } else if (type == SQLITE_NULL) {
352         value = 0;
353     } else if (type == SQLITE_BLOB) {
354         return E_INVALID_COLUMN_TYPE;
355     } else {
356         return E_ERROR;
357     }
358 
359     return E_OK;
360 }
GetColumnDouble(int index,double & value) const361 int SqliteStatement::GetColumnDouble(int index, double &value) const
362 {
363     int ret = IsValid(index);
364     if (ret != E_OK) {
365         return ret;
366     }
367 
368     char *ptr = nullptr;
369     int type = sqlite3_column_type(stmtHandle, index);
370     if (type == SQLITE_FLOAT) {
371         value = sqlite3_column_double(stmtHandle, index);
372     } else if (type == SQLITE_INTEGER) {
373         int64_t val = sqlite3_column_int64(stmtHandle, index);
374         value = static_cast<double>(val);
375     } else if (type == SQLITE_TEXT) {
376         auto val = reinterpret_cast<const char *>(sqlite3_column_text(stmtHandle, index));
377         value = (val == nullptr) ? 0.0 : std::strtod(val, &ptr);
378     } else if (type == SQLITE_NULL) {
379         value = 0.0;
380     } else if (type == SQLITE_BLOB) {
381         return E_INVALID_COLUMN_TYPE;
382     } else {
383         LOG_ERROR("invalid type %{public}d.", type);
384         return E_ERROR;
385     }
386 
387     return E_OK;
388 }
389 
GetColumn(int index,ValueObject & value) const390 int SqliteStatement::GetColumn(int index, ValueObject &value) const
391 {
392     int ret = IsValid(index);
393     if (ret != E_OK) {
394         return ret;
395     }
396 
397     int type = sqlite3_column_type(stmtHandle, index);
398     switch (type) {
399         case SQLITE_FLOAT:
400             value = sqlite3_column_double(stmtHandle, index);
401             return E_OK;
402         case SQLITE_INTEGER:
403             value = static_cast<int64_t>(sqlite3_column_int64(stmtHandle, index));
404             return E_OK;
405         case SQLITE_TEXT:
406             value = reinterpret_cast<const char *>(sqlite3_column_text(stmtHandle, index));
407             return E_OK;
408         case SQLITE_NULL:
409             return E_OK;
410         default:
411             break;
412     }
413     const char *decl = sqlite3_column_decltype(stmtHandle, index);
414     if (type != SQLITE_BLOB || decl == nullptr) {
415         LOG_ERROR("invalid type %{public}d.", type);
416         return E_ERROR;
417     }
418     int size = sqlite3_column_bytes(stmtHandle, index);
419     auto blob = static_cast<const uint8_t *>(sqlite3_column_blob(stmtHandle, index));
420     std::string declType = decl;
421     if (SqliteUtils::StrToUpper(declType) == ValueObject::DeclType<Asset>()) {
422         Asset asset;
423         RawDataParser::ParserRawData(blob, size, asset);
424         value = std::move(asset);
425         return E_OK;
426     }
427     if (SqliteUtils::StrToUpper(declType) == ValueObject::DeclType<Assets>()) {
428         Assets assets;
429         RawDataParser::ParserRawData(blob, size, assets);
430         value = std::move(assets);
431         return E_OK;
432     }
433     std::vector<uint8_t> rawData;
434     if (size > 0 || blob != nullptr) {
435         rawData.resize(size);
436         rawData.assign(blob, blob + size);
437     }
438     value = std::move(rawData);
439     return E_OK;
440 }
441 
GetSize(int index,size_t & size) const442 int SqliteStatement::GetSize(int index, size_t &size) const
443 {
444     size = 0;
445     if (stmtHandle == nullptr) {
446         return E_INVALID_STATEMENT;
447     }
448 
449     if (index >= columnCount) {
450         return E_INVALID_COLUMN_INDEX;
451     }
452 
453     int type = sqlite3_column_type(stmtHandle, index);
454     if (type == SQLITE_BLOB || type == SQLITE_TEXT || type == SQLITE_NULL) {
455         size = static_cast<size_t>(sqlite3_column_bytes(stmtHandle, index));
456         return E_OK;
457     }
458 
459     return E_INVALID_COLUMN_TYPE;
460 }
461 
IsReadOnly() const462 bool SqliteStatement::IsReadOnly() const
463 {
464     return readOnly;
465 }
466 
IsValid(int index) const467 int SqliteStatement::IsValid(int index) const
468 {
469     if (stmtHandle == nullptr) {
470         LOG_ERROR("invalid statement.");
471         return E_INVALID_STATEMENT;
472     }
473 
474     if (index >= columnCount) {
475         LOG_ERROR("index (%{public}d) >= columnCount (%{public}d)", index, columnCount);
476         return E_INVALID_COLUMN_INDEX;
477     }
478 
479     return E_OK;
480 }
481 } // namespace NativeRdb
482 } // namespace OHOS
483