1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "store_session.h"
17 #include <chrono>
18 #include <stack>
19 #include <thread>
20 #include "logger.h"
21 #include "rdb_errno.h"
22 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM)
23 #include "shared_block.h"
24 #include "rdb_security_manager.h"
25 #endif
26 #include "sqlite_database_utils.h"
27 #include "sqlite_utils.h"
28 #include "base_transaction.h"
29
30 namespace OHOS::NativeRdb {
StoreSession(SqliteConnectionPool & connectionPool)31 StoreSession::StoreSession(SqliteConnectionPool &connectionPool)
32 : connectionPool(connectionPool), readConnection(nullptr), connection(nullptr),
33 readConnectionUseCount(0), connectionUseCount(0), isInStepQuery(false)
34 {
35 }
36
~StoreSession()37 StoreSession::~StoreSession()
38 {
39 }
40
AcquireConnection(bool isReadOnly)41 void StoreSession::AcquireConnection(bool isReadOnly)
42 {
43 if (isReadOnly) {
44 if (readConnection == nullptr) {
45 readConnection = connectionPool.AcquireConnection(true);
46 }
47 readConnectionUseCount += 1;
48 return;
49 }
50 if (connection == nullptr) {
51 connection = connectionPool.AcquireConnection(false);
52 }
53 connectionUseCount += 1;
54 return;
55 }
56
ReleaseConnection(bool isReadOnly)57 void StoreSession::ReleaseConnection(bool isReadOnly)
58 {
59 if (isReadOnly) {
60 if ((readConnection == nullptr) || (readConnectionUseCount <= 0)) {
61 LOG_ERROR("SQLiteSession ReleaseConnection repeated release");
62 return;
63 }
64 if (--readConnectionUseCount == 0) {
65 connectionPool.ReleaseConnection(readConnection);
66 readConnection = nullptr;
67 }
68 return;
69 }
70 if ((connection == nullptr) || (connectionUseCount <= 0)) {
71 LOG_ERROR("SQLiteSession ReleaseConnection repeated release");
72 return;
73 }
74
75 if (--connectionUseCount == 0) {
76 connectionPool.ReleaseConnection(connection);
77 connection = nullptr;
78 }
79 }
80
PrepareAndGetInfo(const std::string & sql,bool & outIsReadOnly,int & numParameters,std::vector<std::string> & columnNames)81 int StoreSession::PrepareAndGetInfo(
82 const std::string &sql, bool &outIsReadOnly, int &numParameters, std::vector<std::string> &columnNames)
83 {
84 // Obtains the type of SQL statement.
85 int type = SqliteUtils::GetSqlStatementType(sql);
86 if (SqliteUtils::IsSpecial(type)) {
87 return E_TRANSACTION_IN_EXECUTE;
88 }
89 bool assumeReadOnly = SqliteUtils::IsSqlReadOnly(type);
90 AcquireConnection(assumeReadOnly);
91 auto con = assumeReadOnly ? readConnection : connection;
92 int errCode = con->PrepareAndGetInfo(sql, outIsReadOnly, numParameters, columnNames);
93 if (errCode != 0) {
94 ReleaseConnection(assumeReadOnly);
95 return errCode;
96 }
97
98 ReleaseConnection(assumeReadOnly);
99 return E_OK;
100 }
101
BeginExecuteSql(const std::string & sql,bool & isReadOnly)102 int StoreSession::BeginExecuteSql(const std::string &sql, bool &isReadOnly)
103 {
104 int type = SqliteUtils::GetSqlStatementType(sql);
105 if (SqliteUtils::IsSpecial(type)) {
106 return E_TRANSACTION_IN_EXECUTE;
107 }
108
109 bool assumeReadOnly = SqliteUtils::IsSqlReadOnly(type);
110 AcquireConnection(assumeReadOnly);
111 SqliteConnection *con = assumeReadOnly ? readConnection : connection;
112 int errCode = con->Prepare(sql, isReadOnly);
113 if (errCode != 0) {
114 ReleaseConnection(assumeReadOnly);
115 return errCode;
116 }
117
118 if (isReadOnly == con->IsWriteConnection()) {
119 ReleaseConnection(assumeReadOnly);
120 AcquireConnection(isReadOnly);
121 if (!isReadOnly && !con->IsWriteConnection()) {
122 LOG_ERROR("StoreSession BeginExecute: read connection can not execute write operation");
123 ReleaseConnection(isReadOnly);
124 return E_EXECUTE_WRITE_IN_READ_CONNECTION;
125 }
126 return E_OK;
127 }
128 isReadOnly = assumeReadOnly;
129 return E_OK;
130 }
ExecuteSql(const std::string & sql,const std::vector<ValueObject> & bindArgs)131 int StoreSession::ExecuteSql(const std::string &sql, const std::vector<ValueObject> &bindArgs)
132 {
133 bool isReadOnly = false;
134 int errCode = BeginExecuteSql(sql, isReadOnly);
135 if (errCode != 0) {
136 return errCode;
137 }
138 SqliteConnection *con = isReadOnly ? readConnection : connection;
139 errCode = con->ExecuteSql(sql, bindArgs);
140 ReleaseConnection(isReadOnly);
141 return errCode;
142 }
143
ExecuteForChangedRowCount(int & changedRows,const std::string & sql,const std::vector<ValueObject> & bindArgs)144 int StoreSession::ExecuteForChangedRowCount(
145 int &changedRows, const std::string &sql, const std::vector<ValueObject> &bindArgs)
146 {
147 bool isReadOnly = false;
148 int errCode = BeginExecuteSql(sql, isReadOnly);
149 if (errCode != 0) {
150 return errCode;
151 }
152 auto con = isReadOnly ? readConnection : connection;
153 errCode = con->ExecuteForChangedRowCount(changedRows, sql, bindArgs);
154 ReleaseConnection(isReadOnly);
155 return errCode;
156 }
157
ExecuteForLastInsertedRowId(int64_t & outRowId,const std::string & sql,const std::vector<ValueObject> & bindArgs)158 int StoreSession::ExecuteForLastInsertedRowId(
159 int64_t &outRowId, const std::string &sql, const std::vector<ValueObject> &bindArgs)
160 {
161 bool isReadOnly = false;
162 int errCode = BeginExecuteSql(sql, isReadOnly);
163 if (errCode != 0) {
164 LOG_ERROR("rdbStore BeginExecuteSql failed");
165 return errCode;
166 }
167 auto con = isReadOnly ? readConnection : connection;
168 errCode = con->ExecuteForLastInsertedRowId(outRowId, sql, bindArgs);
169 if (errCode != E_OK) {
170 LOG_ERROR("rdbStore ExecuteForLastInsertedRowId FAILED");
171 }
172 ReleaseConnection(isReadOnly);
173 return errCode;
174 }
175
ExecuteGetLong(int64_t & outValue,const std::string & sql,const std::vector<ValueObject> & bindArgs)176 int StoreSession::ExecuteGetLong(int64_t &outValue, const std::string &sql, const std::vector<ValueObject> &bindArgs)
177 {
178 bool isReadOnly = false;
179 int errCode = BeginExecuteSql(sql, isReadOnly);
180 if (errCode != 0) {
181 return errCode;
182 }
183 auto con = isReadOnly ? readConnection : connection;
184 errCode = con->ExecuteGetLong(outValue, sql, bindArgs);
185 ReleaseConnection(isReadOnly);
186 return errCode;
187 }
188
ExecuteGetString(std::string & outValue,const std::string & sql,const std::vector<ValueObject> & bindArgs)189 int StoreSession::ExecuteGetString(
190 std::string &outValue, const std::string &sql, const std::vector<ValueObject> &bindArgs)
191 {
192 bool isReadOnly = false;
193 int errCode = BeginExecuteSql(sql, isReadOnly);
194 if (errCode != 0) {
195 return errCode;
196 }
197 auto con = isReadOnly ? readConnection : connection;
198 std::string sqlstr = sql;
199 int type = SqliteDatabaseUtils::GetSqlStatementType(sqlstr);
200 if (type == STATEMENT_PRAGMA) {
201 ReleaseConnection(isReadOnly);
202 AcquireConnection(false);
203 con = connection;
204 }
205
206 errCode = con->ExecuteGetString(outValue, sql, bindArgs);
207 ReleaseConnection(isReadOnly);
208 return errCode;
209 }
210
Backup(const std::string databasePath,const std::vector<uint8_t> destEncryptKey,bool isEncrypt)211 int StoreSession::Backup(const std::string databasePath, const std::vector<uint8_t> destEncryptKey, bool isEncrypt)
212 {
213 std::vector<ValueObject> bindArgs;
214 bindArgs.push_back(ValueObject(databasePath));
215 if (destEncryptKey.size() != 0 && !isEncrypt) {
216 bindArgs.push_back(ValueObject(destEncryptKey));
217 ExecuteSql(CIPHER_DEFAULT_ATTACH_HMAC_ALGO);
218 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM)
219 } else if (isEncrypt) {
220 std::vector<uint8_t> key;
221 RdbPassword rdbPwd;
222 rdbPwd = RdbSecurityManager::GetInstance().GetRdbPassword(RdbSecurityManager::KeyFileType::PUB_KEY_FILE);
223 key = std::vector<uint8_t>(rdbPwd.GetData(), rdbPwd.GetData() + rdbPwd.GetSize());
224 bindArgs.push_back(ValueObject(key));
225 ExecuteSql(CIPHER_DEFAULT_ATTACH_HMAC_ALGO);
226 #endif
227 } else {
228 std::string str = "";
229 bindArgs.push_back(ValueObject(str));
230 }
231
232 int errCode = ExecuteSql(ATTACH_BACKUP_SQL, bindArgs);
233 if (errCode != E_OK) {
234 LOG_ERROR("ExecuteSql ATTACH_BACKUP_SQL error %{public}d", errCode);
235 return errCode;
236 }
237 int64_t count;
238 errCode = ExecuteGetLong(count, EXPORT_SQL, std::vector<ValueObject>());
239 if (errCode != E_OK) {
240 LOG_ERROR("ExecuteSql EXPORT_SQL error %{public}d", errCode);
241 return errCode;
242 }
243
244 errCode = ExecuteSql(DETACH_BACKUP_SQL, std::vector<ValueObject>());
245 if (errCode != E_OK) {
246 LOG_ERROR("ExecuteSql DETACH_BACKUP_SQL error %{public}d", errCode);
247 return errCode;
248 }
249 return E_OK;
250 }
251
252 // Checks whether this thread holds a database connection.
IsHoldingConnection() const253 bool StoreSession::IsHoldingConnection() const
254 {
255 if (connection == nullptr && readConnection == nullptr) {
256 return false;
257 } else {
258 return true;
259 }
260 }
261
CheckNoTransaction() const262 int StoreSession::CheckNoTransaction() const
263 {
264 int errorCode = 0;
265 if (connectionPool.getTransactionStack().empty()) {
266 errorCode = E_STORE_SESSION_NO_CURRENT_TRANSACTION;
267 return errorCode;
268 }
269 return E_OK;
270 }
271
GiveConnectionTemporarily(int64_t milliseconds)272 int StoreSession::GiveConnectionTemporarily(int64_t milliseconds)
273 {
274 int errorCode = CheckNoTransaction();
275 if (errorCode != E_OK) {
276 return errorCode;
277 }
278 BaseTransaction transaction = connectionPool.getTransactionStack().top();
279 if (transaction.IsMarkedSuccessful() || connectionPool.getTransactionStack().size() > 1) {
280 errorCode = E_STORE_SESSION_NOT_GIVE_CONNECTION_TEMPORARILY;
281 return errorCode;
282 }
283
284 MarkAsCommit();
285 EndTransaction();
286 if (milliseconds > 0) {
287 std::this_thread::sleep_for(std::chrono::milliseconds(milliseconds));
288 }
289 BeginTransaction();
290 return E_OK;
291 }
292
Attach(const std::string & alias,const std::string & pathName,const std::vector<uint8_t> destEncryptKey,bool isEncrypt)293 int StoreSession::Attach(
294 const std::string &alias, const std::string &pathName, const std::vector<uint8_t> destEncryptKey, bool isEncrypt)
295 {
296 std::string journalMode;
297 int errCode = ExecuteGetString(journalMode, "PRAGMA journal_mode", std::vector<ValueObject>());
298 if (errCode != E_OK) {
299 LOG_ERROR("RdbStoreImpl CheckAttach fail to get journal mode : %{public}d", errCode);
300 return errCode;
301 }
302 journalMode = SqliteUtils::StrToUpper(journalMode);
303 if (journalMode == "WAL") {
304 LOG_ERROR("RdbStoreImpl attach is not supported in WAL mode");
305 return E_NOT_SUPPORTED_ATTACH_IN_WAL_MODE;
306 }
307
308 std::vector<ValueObject> bindArgs;
309 bindArgs.push_back(ValueObject(pathName));
310 bindArgs.push_back(ValueObject(alias));
311 if (destEncryptKey.size() != 0 && !isEncrypt) {
312 bindArgs.push_back(ValueObject(destEncryptKey));
313 ExecuteSql(CIPHER_DEFAULT_ATTACH_HMAC_ALGO);
314 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM)
315 } else if (isEncrypt) {
316 std::vector<uint8_t> key;
317 RdbPassword rdbPwd;
318 rdbPwd = RdbSecurityManager::GetInstance().GetRdbPassword(RdbSecurityManager::KeyFileType::PUB_KEY_FILE);
319 key = std::vector<uint8_t>(rdbPwd.GetData(), rdbPwd.GetData() + rdbPwd.GetSize());
320 bindArgs.push_back(ValueObject(key));
321 ExecuteSql(CIPHER_DEFAULT_ATTACH_HMAC_ALGO);
322 #endif
323 } else {
324 std::string str = "";
325 bindArgs.push_back(ValueObject(str));
326 }
327 errCode = ExecuteSql(ATTACH_SQL, bindArgs);
328 if (errCode != E_OK) {
329 LOG_ERROR("ExecuteSql ATTACH_SQL error %{public}d", errCode);
330 return errCode;
331 }
332
333 return E_OK;
334 }
335
BeginTransaction(TransactionObserver * transactionObserver)336 int StoreSession::BeginTransaction(TransactionObserver *transactionObserver)
337 {
338 if (connectionPool.getTransactionStack().empty()) {
339 AcquireConnection(false);
340
341 int errCode = connection->ExecuteSql("BEGIN EXCLUSIVE;");
342 if (errCode != E_OK) {
343 ReleaseConnection(false);
344 return errCode;
345 }
346 }
347
348 if (transactionObserver != nullptr) {
349 transactionObserver->OnBegin();
350 }
351
352 BaseTransaction transaction(connectionPool.getTransactionStack().size());
353 connectionPool.getTransactionStack().push(transaction);
354
355 return E_OK;
356 }
357
MarkAsCommitWithObserver(TransactionObserver * transactionObserver)358 int StoreSession::MarkAsCommitWithObserver(TransactionObserver *transactionObserver)
359 {
360 if (connectionPool.getTransactionStack().empty()) {
361 return E_NO_TRANSACTION_IN_SESSION;
362 }
363 connectionPool.getTransactionStack().top().SetMarkedSuccessful(true);
364 return E_OK;
365 }
366
EndTransactionWithObserver(TransactionObserver * transactionObserver)367 int StoreSession::EndTransactionWithObserver(TransactionObserver *transactionObserver)
368 {
369 if (connectionPool.getTransactionStack().empty()) {
370 return E_NO_TRANSACTION_IN_SESSION;
371 }
372
373 BaseTransaction transaction = connectionPool.getTransactionStack().top();
374 bool isSucceed = transaction.IsAllBeforeSuccessful() && transaction.IsMarkedSuccessful();
375 connectionPool.getTransactionStack().pop();
376
377 if (transactionObserver != nullptr) {
378 if (isSucceed) {
379 transactionObserver->OnCommit();
380 } else {
381 transactionObserver->OnRollback();
382 }
383 }
384
385 if (!connectionPool.getTransactionStack().empty()) {
386 if (transactionObserver != nullptr) {
387 transactionObserver->OnRollback();
388 }
389
390 if (!isSucceed) {
391 connectionPool.getTransactionStack().top().SetAllBeforeSuccessful(false);
392 }
393 } else {
394 int errCode;
395 if (connection == nullptr) {
396 LOG_ERROR("connection is null");
397 return E_ERROR;
398 }
399 if (isSucceed) {
400 errCode = connection->ExecuteSql("COMMIT;");
401 } else {
402 errCode = connection->ExecuteSql("ROLLBACK;");
403 }
404
405 ReleaseConnection(false);
406 return errCode;
407 }
408
409 return E_OK;
410 }
411
MarkAsCommit()412 int StoreSession::MarkAsCommit()
413 {
414 if (connectionPool.getTransactionStack().empty()) {
415 return E_NO_TRANSACTION_IN_SESSION;
416 }
417 connectionPool.getTransactionStack().top().SetMarkedSuccessful(true);
418 return E_OK;
419 }
420
EndTransaction()421 int StoreSession::EndTransaction()
422 {
423 if (connectionPool.getTransactionStack().empty()) {
424 return E_NO_TRANSACTION_IN_SESSION;
425 }
426
427 BaseTransaction transaction = connectionPool.getTransactionStack().top();
428 bool isSucceed = transaction.IsAllBeforeSuccessful() && transaction.IsMarkedSuccessful();
429 connectionPool.getTransactionStack().pop();
430 if (!connectionPool.getTransactionStack().empty()) {
431 if (!isSucceed) {
432 connectionPool.getTransactionStack().top().SetAllBeforeSuccessful(false);
433 }
434 } else {
435 if (connection == nullptr) {
436 LOG_ERROR("connection is null");
437 return E_ERROR;
438 }
439 int errCode = connection->ExecuteSql(isSucceed ? "COMMIT;" : "ROLLBACK;");
440 ReleaseConnection(false);
441 return errCode;
442 }
443
444 return E_OK;
445 }
IsInTransaction() const446 bool StoreSession::IsInTransaction() const
447 {
448 return !connectionPool.getTransactionStack().empty();
449 }
450
BeginStepQuery(int & errCode,const std::string & sql,const std::vector<std::string> & selectionArgs)451 std::shared_ptr<SqliteStatement> StoreSession::BeginStepQuery(
452 int &errCode, const std::string &sql, const std::vector<std::string> &selectionArgs)
453 {
454 if (isInStepQuery == true) {
455 LOG_ERROR("StoreSession BeginStepQuery fail : begin more step query in one session !");
456 errCode = E_MORE_STEP_QUERY_IN_ONE_SESSION;
457 return nullptr; // fail,already in
458 }
459
460 if (SqliteUtils::GetSqlStatementType(sql) != SqliteUtils::STATEMENT_SELECT) {
461 LOG_ERROR("StoreSession BeginStepQuery fail : not select sql !");
462 errCode = E_EXECUTE_IN_STEP_QUERY;
463 return nullptr;
464 }
465
466 AcquireConnection(true);
467 std::shared_ptr<SqliteStatement> statement = readConnection->BeginStepQuery(errCode, sql, selectionArgs);
468 if (statement == nullptr) {
469 ReleaseConnection(true);
470 return nullptr;
471 }
472 isInStepQuery = true;
473 return statement;
474 }
475
EndStepQuery()476 int StoreSession::EndStepQuery()
477 {
478 if (isInStepQuery == false) {
479 return E_OK;
480 }
481
482 int errCode = readConnection->EndStepQuery();
483 isInStepQuery = false;
484 ReleaseConnection(true);
485 return errCode;
486 }
487
488 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM)
ExecuteForSharedBlock(int & rowNum,std::string sql,const std::vector<ValueObject> & bindArgs,AppDataFwk::SharedBlock * sharedBlock,int startPos,int requiredPos,bool isCountAllRows)489 int StoreSession::ExecuteForSharedBlock(int &rowNum, std::string sql, const std::vector<ValueObject> &bindArgs,
490 AppDataFwk::SharedBlock *sharedBlock, int startPos, int requiredPos, bool isCountAllRows)
491 {
492 bool isReadOnly = false;
493 int errCode = BeginExecuteSql(sql, isReadOnly);
494 if (errCode != E_OK) {
495 return errCode;
496 }
497 SqliteConnection *con = isReadOnly ? readConnection : connection;
498 errCode =
499 con->ExecuteForSharedBlock(rowNum, sql, bindArgs, sharedBlock, startPos, requiredPos, isCountAllRows);
500 ReleaseConnection(isReadOnly);
501 return errCode;
502 }
503 #endif
504
BeginTransaction()505 int StoreSession::BeginTransaction()
506 {
507 AcquireConnection(false);
508
509 BaseTransaction transaction(connectionPool.getTransactionStack().size());
510 int errCode = connection->ExecuteSql(transaction.getTransactionStr());
511 if (errCode != E_OK) {
512 LOG_DEBUG("storeSession BeginTransaction Failed");
513 ReleaseConnection(false);
514 return errCode;
515 }
516 connectionPool.getTransactionStack().push(transaction);
517 ReleaseConnection(false);
518 return E_OK;
519 }
520
Commit()521 int StoreSession::Commit()
522 {
523 if (connectionPool.getTransactionStack().empty()) {
524 return E_OK;
525 }
526 BaseTransaction transaction = connectionPool.getTransactionStack().top();
527 std::string sqlStr = transaction.getCommitStr();
528 if (sqlStr.size() <= 1) {
529 connectionPool.getTransactionStack().pop();
530 return E_OK;
531 }
532
533 AcquireConnection(false);
534 int errCode = connection->ExecuteSql(sqlStr);
535 ReleaseConnection(false);
536 if (errCode != E_OK) {
537 // if error the transaction is leaving for rollback
538 return errCode;
539 }
540 connectionPool.getTransactionStack().pop();
541 return E_OK;
542 }
543
RollBack()544 int StoreSession::RollBack()
545 {
546 std::stack<BaseTransaction> transactionStack = connectionPool.getTransactionStack();
547 if (transactionStack.empty()) {
548 return E_NO_TRANSACTION_IN_SESSION;
549 }
550 BaseTransaction transaction = transactionStack.top();
551 transactionStack.pop();
552 if (transaction.getType() != TransType::ROLLBACK_SELF && !transactionStack.empty()) {
553 transactionStack.top().setChildFailure(true);
554 }
555 AcquireConnection(false);
556 int errCode = connection->ExecuteSql(transaction.getRollbackStr());
557 ReleaseConnection(false);
558 if (errCode != E_OK) {
559 LOG_ERROR("storeSession RollBack Fail");
560 }
561
562 return errCode;
563 }
564
GetConnectionUseCount()565 int StoreSession::GetConnectionUseCount()
566 {
567 return connectionUseCount;
568 }
569 } // namespace OHOS::NativeRdb
570