• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15 
16 #include "store_session.h"
17 #include <chrono>
18 #include <stack>
19 #include <thread>
20 #include "logger.h"
21 #include "rdb_errno.h"
22 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM)
23 #include "shared_block.h"
24 #include "rdb_security_manager.h"
25 #endif
26 #include "sqlite_database_utils.h"
27 #include "sqlite_utils.h"
28 #include "base_transaction.h"
29 
30 namespace OHOS::NativeRdb {
StoreSession(SqliteConnectionPool & connectionPool)31 StoreSession::StoreSession(SqliteConnectionPool &connectionPool)
32     : connectionPool(connectionPool), readConnection(nullptr), connection(nullptr),
33       readConnectionUseCount(0), connectionUseCount(0), isInStepQuery(false)
34 {
35 }
36 
~StoreSession()37 StoreSession::~StoreSession()
38 {
39 }
40 
AcquireConnection(bool isReadOnly)41 void StoreSession::AcquireConnection(bool isReadOnly)
42 {
43     if (isReadOnly) {
44         if (readConnection == nullptr) {
45             readConnection = connectionPool.AcquireConnection(true);
46         }
47         readConnectionUseCount += 1;
48         return;
49     }
50     if (connection == nullptr) {
51         connection = connectionPool.AcquireConnection(false);
52     }
53     connectionUseCount += 1;
54     return;
55 }
56 
ReleaseConnection(bool isReadOnly)57 void StoreSession::ReleaseConnection(bool isReadOnly)
58 {
59     if (isReadOnly) {
60         if ((readConnection == nullptr) || (readConnectionUseCount <= 0)) {
61             LOG_ERROR("SQLiteSession ReleaseConnection repeated release");
62             return;
63         }
64         if (--readConnectionUseCount == 0) {
65             connectionPool.ReleaseConnection(readConnection);
66             readConnection = nullptr;
67         }
68         return;
69     }
70     if ((connection == nullptr) || (connectionUseCount <= 0)) {
71         LOG_ERROR("SQLiteSession ReleaseConnection repeated release");
72         return;
73     }
74 
75     if (--connectionUseCount == 0) {
76         connectionPool.ReleaseConnection(connection);
77         connection = nullptr;
78     }
79 }
80 
PrepareAndGetInfo(const std::string & sql,bool & outIsReadOnly,int & numParameters,std::vector<std::string> & columnNames)81 int StoreSession::PrepareAndGetInfo(
82     const std::string &sql, bool &outIsReadOnly, int &numParameters, std::vector<std::string> &columnNames)
83 {
84     // Obtains the type of SQL statement.
85     int type = SqliteUtils::GetSqlStatementType(sql);
86     if (SqliteUtils::IsSpecial(type)) {
87         return E_TRANSACTION_IN_EXECUTE;
88     }
89     bool assumeReadOnly = SqliteUtils::IsSqlReadOnly(type);
90     AcquireConnection(assumeReadOnly);
91     auto con = assumeReadOnly ? readConnection : connection;
92     int errCode = con->PrepareAndGetInfo(sql, outIsReadOnly, numParameters, columnNames);
93     if (errCode != 0) {
94         ReleaseConnection(assumeReadOnly);
95         return errCode;
96     }
97 
98     ReleaseConnection(assumeReadOnly);
99     return E_OK;
100 }
101 
BeginExecuteSql(const std::string & sql,bool & isReadOnly)102 int StoreSession::BeginExecuteSql(const std::string &sql, bool &isReadOnly)
103 {
104     int type = SqliteUtils::GetSqlStatementType(sql);
105     if (SqliteUtils::IsSpecial(type)) {
106         return E_TRANSACTION_IN_EXECUTE;
107     }
108 
109     bool assumeReadOnly = SqliteUtils::IsSqlReadOnly(type);
110     AcquireConnection(assumeReadOnly);
111     SqliteConnection *con = assumeReadOnly ? readConnection : connection;
112     int errCode = con->Prepare(sql, isReadOnly);
113     if (errCode != 0) {
114         ReleaseConnection(assumeReadOnly);
115         return errCode;
116     }
117 
118     if (isReadOnly == con->IsWriteConnection()) {
119         ReleaseConnection(assumeReadOnly);
120         AcquireConnection(isReadOnly);
121         if (!isReadOnly && !con->IsWriteConnection()) {
122             LOG_ERROR("StoreSession BeginExecute: read connection can not execute write operation");
123             ReleaseConnection(isReadOnly);
124             return E_EXECUTE_WRITE_IN_READ_CONNECTION;
125         }
126         return E_OK;
127     }
128     isReadOnly = assumeReadOnly;
129     return E_OK;
130 }
ExecuteSql(const std::string & sql,const std::vector<ValueObject> & bindArgs)131 int StoreSession::ExecuteSql(const std::string &sql, const std::vector<ValueObject> &bindArgs)
132 {
133     bool isReadOnly = false;
134     int errCode = BeginExecuteSql(sql, isReadOnly);
135     if (errCode != 0) {
136         return errCode;
137     }
138     SqliteConnection *con = isReadOnly ? readConnection : connection;
139     errCode = con->ExecuteSql(sql, bindArgs);
140     ReleaseConnection(isReadOnly);
141     return errCode;
142 }
143 
ExecuteForChangedRowCount(int & changedRows,const std::string & sql,const std::vector<ValueObject> & bindArgs)144 int StoreSession::ExecuteForChangedRowCount(
145     int &changedRows, const std::string &sql, const std::vector<ValueObject> &bindArgs)
146 {
147     bool isReadOnly = false;
148     int errCode = BeginExecuteSql(sql, isReadOnly);
149     if (errCode != 0) {
150         return errCode;
151     }
152     auto con = isReadOnly ? readConnection : connection;
153     errCode = con->ExecuteForChangedRowCount(changedRows, sql, bindArgs);
154     ReleaseConnection(isReadOnly);
155     return errCode;
156 }
157 
ExecuteForLastInsertedRowId(int64_t & outRowId,const std::string & sql,const std::vector<ValueObject> & bindArgs)158 int StoreSession::ExecuteForLastInsertedRowId(
159     int64_t &outRowId, const std::string &sql, const std::vector<ValueObject> &bindArgs)
160 {
161     bool isReadOnly = false;
162     int errCode = BeginExecuteSql(sql, isReadOnly);
163     if (errCode != 0) {
164         LOG_ERROR("rdbStore BeginExecuteSql failed");
165         return errCode;
166     }
167     auto con = isReadOnly ? readConnection : connection;
168     errCode = con->ExecuteForLastInsertedRowId(outRowId, sql, bindArgs);
169     if (errCode != E_OK) {
170         LOG_ERROR("rdbStore ExecuteForLastInsertedRowId FAILED");
171     }
172     ReleaseConnection(isReadOnly);
173     return errCode;
174 }
175 
ExecuteGetLong(int64_t & outValue,const std::string & sql,const std::vector<ValueObject> & bindArgs)176 int StoreSession::ExecuteGetLong(int64_t &outValue, const std::string &sql, const std::vector<ValueObject> &bindArgs)
177 {
178     bool isReadOnly = false;
179     int errCode = BeginExecuteSql(sql, isReadOnly);
180     if (errCode != 0) {
181         return errCode;
182     }
183     auto con = isReadOnly ? readConnection : connection;
184     errCode = con->ExecuteGetLong(outValue, sql, bindArgs);
185     ReleaseConnection(isReadOnly);
186     return errCode;
187 }
188 
ExecuteGetString(std::string & outValue,const std::string & sql,const std::vector<ValueObject> & bindArgs)189 int StoreSession::ExecuteGetString(
190     std::string &outValue, const std::string &sql, const std::vector<ValueObject> &bindArgs)
191 {
192     bool isReadOnly = false;
193     int errCode = BeginExecuteSql(sql, isReadOnly);
194     if (errCode != 0) {
195         return errCode;
196     }
197     auto con = isReadOnly ? readConnection : connection;
198     std::string sqlstr = sql;
199     int type = SqliteDatabaseUtils::GetSqlStatementType(sqlstr);
200     if (type == STATEMENT_PRAGMA) {
201         ReleaseConnection(isReadOnly);
202         AcquireConnection(false);
203         con = connection;
204     }
205 
206     errCode = con->ExecuteGetString(outValue, sql, bindArgs);
207     ReleaseConnection(isReadOnly);
208     return errCode;
209 }
210 
Backup(const std::string databasePath,const std::vector<uint8_t> destEncryptKey,bool isEncrypt)211 int StoreSession::Backup(const std::string databasePath, const std::vector<uint8_t> destEncryptKey, bool isEncrypt)
212 {
213     std::vector<ValueObject> bindArgs;
214     bindArgs.push_back(ValueObject(databasePath));
215     if (destEncryptKey.size() != 0 && !isEncrypt) {
216         bindArgs.push_back(ValueObject(destEncryptKey));
217         ExecuteSql(CIPHER_DEFAULT_ATTACH_HMAC_ALGO);
218 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM)
219     } else if (isEncrypt) {
220         std::vector<uint8_t> key;
221         RdbPassword rdbPwd;
222         rdbPwd = RdbSecurityManager::GetInstance().GetRdbPassword(RdbSecurityManager::KeyFileType::PUB_KEY_FILE);
223         key = std::vector<uint8_t>(rdbPwd.GetData(), rdbPwd.GetData() + rdbPwd.GetSize());
224         bindArgs.push_back(ValueObject(key));
225         ExecuteSql(CIPHER_DEFAULT_ATTACH_HMAC_ALGO);
226 #endif
227     } else {
228         std::string str = "";
229         bindArgs.push_back(ValueObject(str));
230     }
231 
232     int errCode = ExecuteSql(ATTACH_BACKUP_SQL, bindArgs);
233     if (errCode != E_OK) {
234         LOG_ERROR("ExecuteSql ATTACH_BACKUP_SQL error %{public}d", errCode);
235         return errCode;
236     }
237     int64_t count;
238     errCode = ExecuteGetLong(count, EXPORT_SQL, std::vector<ValueObject>());
239     if (errCode != E_OK) {
240         LOG_ERROR("ExecuteSql EXPORT_SQL error %{public}d", errCode);
241         return errCode;
242     }
243 
244     errCode = ExecuteSql(DETACH_BACKUP_SQL, std::vector<ValueObject>());
245     if (errCode != E_OK) {
246         LOG_ERROR("ExecuteSql DETACH_BACKUP_SQL error %{public}d", errCode);
247         return errCode;
248     }
249     return E_OK;
250 }
251 
252 // Checks whether this thread holds a database connection.
IsHoldingConnection() const253 bool StoreSession::IsHoldingConnection() const
254 {
255     if (connection == nullptr && readConnection == nullptr) {
256         return false;
257     } else {
258         return true;
259     }
260 }
261 
CheckNoTransaction() const262 int StoreSession::CheckNoTransaction() const
263 {
264     int errorCode = 0;
265     if (connectionPool.getTransactionStack().empty()) {
266         errorCode = E_STORE_SESSION_NO_CURRENT_TRANSACTION;
267         return errorCode;
268     }
269     return E_OK;
270 }
271 
GiveConnectionTemporarily(int64_t milliseconds)272 int StoreSession::GiveConnectionTemporarily(int64_t milliseconds)
273 {
274     int errorCode = CheckNoTransaction();
275     if (errorCode != E_OK) {
276         return errorCode;
277     }
278     BaseTransaction transaction = connectionPool.getTransactionStack().top();
279     if (transaction.IsMarkedSuccessful() || connectionPool.getTransactionStack().size() > 1) {
280         errorCode = E_STORE_SESSION_NOT_GIVE_CONNECTION_TEMPORARILY;
281         return errorCode;
282     }
283 
284     MarkAsCommit();
285     EndTransaction();
286     if (milliseconds > 0) {
287         std::this_thread::sleep_for(std::chrono::milliseconds(milliseconds));
288     }
289     BeginTransaction();
290     return E_OK;
291 }
292 
Attach(const std::string & alias,const std::string & pathName,const std::vector<uint8_t> destEncryptKey,bool isEncrypt)293 int StoreSession::Attach(
294     const std::string &alias, const std::string &pathName, const std::vector<uint8_t> destEncryptKey, bool isEncrypt)
295 {
296     std::string journalMode;
297     int errCode = ExecuteGetString(journalMode, "PRAGMA journal_mode", std::vector<ValueObject>());
298     if (errCode != E_OK) {
299         LOG_ERROR("RdbStoreImpl CheckAttach fail to get journal mode : %{public}d", errCode);
300         return errCode;
301     }
302     journalMode = SqliteUtils::StrToUpper(journalMode);
303     if (journalMode == "WAL") {
304         LOG_ERROR("RdbStoreImpl attach is not supported in WAL mode");
305         return E_NOT_SUPPORTED_ATTACH_IN_WAL_MODE;
306     }
307 
308     std::vector<ValueObject> bindArgs;
309     bindArgs.push_back(ValueObject(pathName));
310     bindArgs.push_back(ValueObject(alias));
311     if (destEncryptKey.size() != 0 && !isEncrypt) {
312         bindArgs.push_back(ValueObject(destEncryptKey));
313         ExecuteSql(CIPHER_DEFAULT_ATTACH_HMAC_ALGO);
314 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM)
315     } else if (isEncrypt) {
316         std::vector<uint8_t> key;
317         RdbPassword rdbPwd;
318         rdbPwd = RdbSecurityManager::GetInstance().GetRdbPassword(RdbSecurityManager::KeyFileType::PUB_KEY_FILE);
319         key = std::vector<uint8_t>(rdbPwd.GetData(), rdbPwd.GetData() + rdbPwd.GetSize());
320         bindArgs.push_back(ValueObject(key));
321         ExecuteSql(CIPHER_DEFAULT_ATTACH_HMAC_ALGO);
322 #endif
323     } else {
324         std::string str = "";
325         bindArgs.push_back(ValueObject(str));
326     }
327     errCode = ExecuteSql(ATTACH_SQL, bindArgs);
328     if (errCode != E_OK) {
329         LOG_ERROR("ExecuteSql ATTACH_SQL error %{public}d", errCode);
330         return errCode;
331     }
332 
333     return E_OK;
334 }
335 
BeginTransaction(TransactionObserver * transactionObserver)336 int StoreSession::BeginTransaction(TransactionObserver *transactionObserver)
337 {
338     if (connectionPool.getTransactionStack().empty()) {
339         AcquireConnection(false);
340 
341         int errCode = connection->ExecuteSql("BEGIN EXCLUSIVE;");
342         if (errCode != E_OK) {
343             ReleaseConnection(false);
344             return errCode;
345         }
346     }
347 
348     if (transactionObserver != nullptr) {
349         transactionObserver->OnBegin();
350     }
351 
352     BaseTransaction transaction(connectionPool.getTransactionStack().size());
353     connectionPool.getTransactionStack().push(transaction);
354 
355     return E_OK;
356 }
357 
MarkAsCommitWithObserver(TransactionObserver * transactionObserver)358 int StoreSession::MarkAsCommitWithObserver(TransactionObserver *transactionObserver)
359 {
360     if (connectionPool.getTransactionStack().empty()) {
361         return E_NO_TRANSACTION_IN_SESSION;
362     }
363     connectionPool.getTransactionStack().top().SetMarkedSuccessful(true);
364     return E_OK;
365 }
366 
EndTransactionWithObserver(TransactionObserver * transactionObserver)367 int StoreSession::EndTransactionWithObserver(TransactionObserver *transactionObserver)
368 {
369     if (connectionPool.getTransactionStack().empty()) {
370         return E_NO_TRANSACTION_IN_SESSION;
371     }
372 
373     BaseTransaction transaction = connectionPool.getTransactionStack().top();
374     bool isSucceed = transaction.IsAllBeforeSuccessful() && transaction.IsMarkedSuccessful();
375     connectionPool.getTransactionStack().pop();
376 
377     if (transactionObserver != nullptr) {
378         if (isSucceed) {
379             transactionObserver->OnCommit();
380         } else {
381             transactionObserver->OnRollback();
382         }
383     }
384 
385     if (!connectionPool.getTransactionStack().empty()) {
386         if (transactionObserver != nullptr) {
387             transactionObserver->OnRollback();
388         }
389 
390         if (!isSucceed) {
391             connectionPool.getTransactionStack().top().SetAllBeforeSuccessful(false);
392         }
393     } else {
394         int errCode;
395         if (connection == nullptr) {
396             LOG_ERROR("connection is null");
397             return E_ERROR;
398         }
399         if (isSucceed) {
400             errCode = connection->ExecuteSql("COMMIT;");
401         } else {
402             errCode = connection->ExecuteSql("ROLLBACK;");
403         }
404 
405         ReleaseConnection(false);
406         return errCode;
407     }
408 
409     return E_OK;
410 }
411 
MarkAsCommit()412 int StoreSession::MarkAsCommit()
413 {
414     if (connectionPool.getTransactionStack().empty()) {
415         return E_NO_TRANSACTION_IN_SESSION;
416     }
417     connectionPool.getTransactionStack().top().SetMarkedSuccessful(true);
418     return E_OK;
419 }
420 
EndTransaction()421 int StoreSession::EndTransaction()
422 {
423     if (connectionPool.getTransactionStack().empty()) {
424         return E_NO_TRANSACTION_IN_SESSION;
425     }
426 
427     BaseTransaction transaction = connectionPool.getTransactionStack().top();
428     bool isSucceed = transaction.IsAllBeforeSuccessful() && transaction.IsMarkedSuccessful();
429     connectionPool.getTransactionStack().pop();
430     if (!connectionPool.getTransactionStack().empty()) {
431         if (!isSucceed) {
432             connectionPool.getTransactionStack().top().SetAllBeforeSuccessful(false);
433         }
434     } else {
435         if (connection == nullptr) {
436             LOG_ERROR("connection is null");
437             return E_ERROR;
438         }
439         int errCode = connection->ExecuteSql(isSucceed ? "COMMIT;" : "ROLLBACK;");
440         ReleaseConnection(false);
441         return errCode;
442     }
443 
444     return E_OK;
445 }
IsInTransaction() const446 bool StoreSession::IsInTransaction() const
447 {
448     return !connectionPool.getTransactionStack().empty();
449 }
450 
BeginStepQuery(int & errCode,const std::string & sql,const std::vector<std::string> & selectionArgs)451 std::shared_ptr<SqliteStatement> StoreSession::BeginStepQuery(
452     int &errCode, const std::string &sql, const std::vector<std::string> &selectionArgs)
453 {
454     if (isInStepQuery == true) {
455         LOG_ERROR("StoreSession BeginStepQuery fail : begin more step query in one session !");
456         errCode = E_MORE_STEP_QUERY_IN_ONE_SESSION;
457         return nullptr; // fail,already in
458     }
459 
460     if (SqliteUtils::GetSqlStatementType(sql) != SqliteUtils::STATEMENT_SELECT) {
461         LOG_ERROR("StoreSession BeginStepQuery fail : not select sql !");
462         errCode = E_EXECUTE_IN_STEP_QUERY;
463         return nullptr;
464     }
465 
466     AcquireConnection(true);
467     std::shared_ptr<SqliteStatement> statement = readConnection->BeginStepQuery(errCode, sql, selectionArgs);
468     if (statement == nullptr) {
469         ReleaseConnection(true);
470         return nullptr;
471     }
472     isInStepQuery = true;
473     return statement;
474 }
475 
EndStepQuery()476 int StoreSession::EndStepQuery()
477 {
478     if (isInStepQuery == false) {
479         return E_OK;
480     }
481 
482     int errCode = readConnection->EndStepQuery();
483     isInStepQuery = false;
484     ReleaseConnection(true);
485     return errCode;
486 }
487 
488 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM)
ExecuteForSharedBlock(int & rowNum,std::string sql,const std::vector<ValueObject> & bindArgs,AppDataFwk::SharedBlock * sharedBlock,int startPos,int requiredPos,bool isCountAllRows)489 int StoreSession::ExecuteForSharedBlock(int &rowNum, std::string sql, const std::vector<ValueObject> &bindArgs,
490     AppDataFwk::SharedBlock *sharedBlock, int startPos, int requiredPos, bool isCountAllRows)
491 {
492     bool isReadOnly = false;
493     int errCode = BeginExecuteSql(sql, isReadOnly);
494     if (errCode != E_OK) {
495         return errCode;
496     }
497     SqliteConnection *con = isReadOnly ? readConnection : connection;
498     errCode =
499         con->ExecuteForSharedBlock(rowNum, sql, bindArgs, sharedBlock, startPos, requiredPos, isCountAllRows);
500     ReleaseConnection(isReadOnly);
501     return errCode;
502 }
503 #endif
504 
BeginTransaction()505 int StoreSession::BeginTransaction()
506 {
507     AcquireConnection(false);
508 
509     BaseTransaction transaction(connectionPool.getTransactionStack().size());
510     int errCode = connection->ExecuteSql(transaction.getTransactionStr());
511     if (errCode != E_OK) {
512         LOG_DEBUG("storeSession BeginTransaction Failed");
513         ReleaseConnection(false);
514         return errCode;
515     }
516     connectionPool.getTransactionStack().push(transaction);
517     ReleaseConnection(false);
518     return E_OK;
519 }
520 
Commit()521 int StoreSession::Commit()
522 {
523     if (connectionPool.getTransactionStack().empty()) {
524         return E_OK;
525     }
526     BaseTransaction transaction = connectionPool.getTransactionStack().top();
527     std::string sqlStr = transaction.getCommitStr();
528     if (sqlStr.size() <= 1) {
529         connectionPool.getTransactionStack().pop();
530         return E_OK;
531     }
532 
533     AcquireConnection(false);
534     int errCode = connection->ExecuteSql(sqlStr);
535     ReleaseConnection(false);
536     if (errCode != E_OK) {
537         // if error the transaction is leaving for rollback
538         return errCode;
539     }
540     connectionPool.getTransactionStack().pop();
541     return E_OK;
542 }
543 
RollBack()544 int StoreSession::RollBack()
545 {
546     std::stack<BaseTransaction> transactionStack = connectionPool.getTransactionStack();
547     if (transactionStack.empty()) {
548         return E_NO_TRANSACTION_IN_SESSION;
549     }
550     BaseTransaction transaction = transactionStack.top();
551     transactionStack.pop();
552     if (transaction.getType() != TransType::ROLLBACK_SELF && !transactionStack.empty()) {
553         transactionStack.top().setChildFailure(true);
554     }
555     AcquireConnection(false);
556     int errCode = connection->ExecuteSql(transaction.getRollbackStr());
557     ReleaseConnection(false);
558     if (errCode != E_OK) {
559         LOG_ERROR("storeSession RollBack Fail");
560     }
561 
562     return errCode;
563 }
564 
GetConnectionUseCount()565 int StoreSession::GetConnectionUseCount()
566 {
567     return connectionUseCount;
568 }
569 } // namespace OHOS::NativeRdb
570