• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "store_session.h"
17 #include <chrono>
18 #include <stack>
19 #include <thread>
20 #include "logger.h"
21 #include "rdb_errno.h"
22 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM)
23 #include "shared_block.h"
24 #endif
25 #include "sqlite_database_utils.h"
26 #include "sqlite_utils.h"
27 #include "base_transaction.h"
28 
29 namespace OHOS::NativeRdb {
StoreSession(SqliteConnectionPool & connectionPool)30 StoreSession::StoreSession(SqliteConnectionPool &connectionPool)
31     : connectionPool(connectionPool), readConnection(nullptr), connection(nullptr),
32       readConnectionUseCount(0), connectionUseCount(0), isInStepQuery(false)
33 {
34 }
35 
~StoreSession()36 StoreSession::~StoreSession()
37 {
38 }
39 
AcquireConnection(bool isReadOnly)40 void StoreSession::AcquireConnection(bool isReadOnly)
41 {
42     if (isReadOnly) {
43         if (readConnection == nullptr) {
44             readConnection = connectionPool.AcquireConnection(true);
45         }
46         readConnectionUseCount += 1;
47         return;
48     }
49     if (connection == nullptr) {
50         connection = connectionPool.AcquireConnection(false);
51     }
52     connectionUseCount += 1;
53     return;
54 }
55 
ReleaseConnection(bool isReadOnly)56 void StoreSession::ReleaseConnection(bool isReadOnly)
57 {
58     if (isReadOnly) {
59         if ((readConnection == nullptr) || (readConnectionUseCount <= 0)) {
60             LOG_ERROR("SQLiteSession ReleaseConnection repeated release");
61             return;
62         }
63         if (--readConnectionUseCount == 0) {
64             connectionPool.ReleaseConnection(readConnection);
65             readConnection = nullptr;
66         }
67         return;
68     }
69     if ((connection == nullptr) || (connectionUseCount <= 0)) {
70         LOG_ERROR("SQLiteSession ReleaseConnection repeated release");
71         return;
72     }
73 
74     if (--connectionUseCount == 0) {
75         connectionPool.ReleaseConnection(connection);
76         connection = nullptr;
77     }
78 }
79 
PrepareAndGetInfo(const std::string & sql,bool & outIsReadOnly,int & numParameters,std::vector<std::string> & columnNames)80 int StoreSession::PrepareAndGetInfo(
81     const std::string &sql, bool &outIsReadOnly, int &numParameters, std::vector<std::string> &columnNames)
82 {
83     // Obtains the type of SQL statement.
84     int type = SqliteUtils::GetSqlStatementType(sql);
85     if (SqliteUtils::IsSpecial(type)) {
86         return E_TRANSACTION_IN_EXECUTE;
87     }
88     bool assumeReadOnly = SqliteUtils::IsSqlReadOnly(type);
89     AcquireConnection(assumeReadOnly);
90     auto con = assumeReadOnly ? readConnection : connection;
91     int errCode = con->PrepareAndGetInfo(sql, outIsReadOnly, numParameters, columnNames);
92     if (errCode != 0) {
93         ReleaseConnection(assumeReadOnly);
94         return errCode;
95     }
96 
97     ReleaseConnection(assumeReadOnly);
98     return E_OK;
99 }
100 
BeginExecuteSql(const std::string & sql,bool & isReadOnly)101 int StoreSession::BeginExecuteSql(const std::string &sql, bool &isReadOnly)
102 {
103     int type = SqliteUtils::GetSqlStatementType(sql);
104     if (SqliteUtils::IsSpecial(type)) {
105         return E_TRANSACTION_IN_EXECUTE;
106     }
107 
108     bool assumeReadOnly = SqliteUtils::IsSqlReadOnly(type);
109     AcquireConnection(assumeReadOnly);
110     SqliteConnection *con = assumeReadOnly ? readConnection : connection;
111     int errCode = con->Prepare(sql, isReadOnly);
112     if (errCode != 0) {
113         ReleaseConnection(assumeReadOnly);
114         return errCode;
115     }
116 
117     if (isReadOnly == con->IsWriteConnection()) {
118         ReleaseConnection(assumeReadOnly);
119         AcquireConnection(isReadOnly);
120         if (!isReadOnly && !con->IsWriteConnection()) {
121             LOG_ERROR("StoreSession BeginExecute: read connection can not execute write operation");
122             ReleaseConnection(isReadOnly);
123             return E_EXECUTE_WRITE_IN_READ_CONNECTION;
124         }
125         return E_OK;
126     }
127     isReadOnly = assumeReadOnly;
128     return E_OK;
129 }
ExecuteSql(const std::string & sql,const std::vector<ValueObject> & bindArgs)130 int StoreSession::ExecuteSql(const std::string &sql, const std::vector<ValueObject> &bindArgs)
131 {
132     bool isReadOnly = false;
133     int errCode = BeginExecuteSql(sql, isReadOnly);
134     if (errCode != 0) {
135         return errCode;
136     }
137     SqliteConnection *con = isReadOnly ? readConnection : connection;
138     errCode = con->ExecuteSql(sql, bindArgs);
139     ReleaseConnection(isReadOnly);
140     return errCode;
141 }
142 
ExecuteForChangedRowCount(int & changedRows,const std::string & sql,const std::vector<ValueObject> & bindArgs)143 int StoreSession::ExecuteForChangedRowCount(
144     int &changedRows, const std::string &sql, const std::vector<ValueObject> &bindArgs)
145 {
146     bool isReadOnly = false;
147     int errCode = BeginExecuteSql(sql, isReadOnly);
148     if (errCode != 0) {
149         return errCode;
150     }
151     auto con = isReadOnly ? readConnection : connection;
152     errCode = con->ExecuteForChangedRowCount(changedRows, sql, bindArgs);
153     ReleaseConnection(isReadOnly);
154     return errCode;
155 }
156 
ExecuteForLastInsertedRowId(int64_t & outRowId,const std::string & sql,const std::vector<ValueObject> & bindArgs)157 int StoreSession::ExecuteForLastInsertedRowId(
158     int64_t &outRowId, const std::string &sql, const std::vector<ValueObject> &bindArgs)
159 {
160     bool isReadOnly = false;
161     int errCode = BeginExecuteSql(sql, isReadOnly);
162     if (errCode != 0) {
163         LOG_ERROR("rdbStore BeginExecuteSql failed");
164         return errCode;
165     }
166     auto con = isReadOnly ? readConnection : connection;
167     errCode = con->ExecuteForLastInsertedRowId(outRowId, sql, bindArgs);
168     if (errCode != E_OK) {
169         LOG_ERROR("rdbStore ExecuteForLastInsertedRowId FAILED");
170     }
171     ReleaseConnection(isReadOnly);
172     return errCode;
173 }
174 
ExecuteGetLong(int64_t & outValue,const std::string & sql,const std::vector<ValueObject> & bindArgs)175 int StoreSession::ExecuteGetLong(int64_t &outValue, const std::string &sql, const std::vector<ValueObject> &bindArgs)
176 {
177     bool isReadOnly = false;
178     int errCode = BeginExecuteSql(sql, isReadOnly);
179     if (errCode != 0) {
180         return errCode;
181     }
182     auto con = isReadOnly ? readConnection : connection;
183     errCode = con->ExecuteGetLong(outValue, sql, bindArgs);
184     ReleaseConnection(isReadOnly);
185     return errCode;
186 }
187 
ExecuteGetString(std::string & outValue,const std::string & sql,const std::vector<ValueObject> & bindArgs)188 int StoreSession::ExecuteGetString(
189     std::string &outValue, const std::string &sql, const std::vector<ValueObject> &bindArgs)
190 {
191     bool isReadOnly = false;
192     int errCode = BeginExecuteSql(sql, isReadOnly);
193     if (errCode != 0) {
194         return errCode;
195     }
196     auto con = isReadOnly ? readConnection : connection;
197     std::string sqlstr = sql;
198     int type = SqliteDatabaseUtils::GetSqlStatementType(sqlstr);
199     if (type == STATEMENT_PRAGMA) {
200         ReleaseConnection(isReadOnly);
201         AcquireConnection(false);
202         con = connection;
203     }
204 
205     errCode = con->ExecuteGetString(outValue, sql, bindArgs);
206     ReleaseConnection(isReadOnly);
207     return errCode;
208 }
209 
Backup(const std::string databasePath,const std::vector<uint8_t> destEncryptKey)210 int StoreSession::Backup(const std::string databasePath, const std::vector<uint8_t> destEncryptKey)
211 {
212     std::vector<ValueObject> bindArgs;
213     bindArgs.push_back(ValueObject(databasePath));
214     if (destEncryptKey.size() != 0) {
215         bindArgs.push_back(ValueObject(destEncryptKey));
216     } else {
217         std::string str = "";
218         bindArgs.push_back(ValueObject(str));
219     }
220 
221     int errCode = ExecuteSql(ATTACH_BACKUP_SQL, bindArgs);
222     if (errCode != E_OK) {
223         LOG_ERROR("ExecuteSql ATTACH_BACKUP_SQL error %{public}d", errCode);
224         return errCode;
225     }
226     int64_t count;
227     errCode = ExecuteGetLong(count, EXPORT_SQL, std::vector<ValueObject>());
228     if (errCode != E_OK) {
229         LOG_ERROR("ExecuteSql EXPORT_SQL error %{public}d", errCode);
230         return errCode;
231     }
232 
233     errCode = ExecuteSql(DETACH_BACKUP_SQL, std::vector<ValueObject>());
234     if (errCode != E_OK) {
235         LOG_ERROR("ExecuteSql DETACH_BACKUP_SQL error %{public}d", errCode);
236         return errCode;
237     }
238     return E_OK;
239 }
240 
241 // Checks whether this thread holds a database connection.
IsHoldingConnection() const242 bool StoreSession::IsHoldingConnection() const
243 {
244     if (connection == nullptr && readConnection == nullptr) {
245         return false;
246     } else {
247         return true;
248     }
249 }
250 
CheckNoTransaction() const251 int StoreSession::CheckNoTransaction() const
252 {
253     int errorCode = 0;
254     if (connectionPool.getTransactionStack().empty()) {
255         errorCode = E_STORE_SESSION_NO_CURRENT_TRANSACTION;
256         return errorCode;
257     }
258     return E_OK;
259 }
260 
GiveConnectionTemporarily(int64_t milliseconds)261 int StoreSession::GiveConnectionTemporarily(int64_t milliseconds)
262 {
263     int errorCode = CheckNoTransaction();
264     if (errorCode != E_OK) {
265         return errorCode;
266     }
267     BaseTransaction transaction = connectionPool.getTransactionStack().top();
268     if (transaction.IsMarkedSuccessful() || connectionPool.getTransactionStack().size() > 1) {
269         errorCode = E_STORE_SESSION_NOT_GIVE_CONNECTION_TEMPORARILY;
270         return errorCode;
271     }
272 
273     MarkAsCommit();
274     EndTransaction();
275     if (milliseconds > 0) {
276         std::this_thread::sleep_for(std::chrono::milliseconds(milliseconds));
277     }
278     BeginTransaction();
279     return E_OK;
280 }
281 
Attach(const std::string & alias,const std::string & pathName,const std::vector<uint8_t> destEncryptKey)282 int StoreSession::Attach(
283     const std::string &alias, const std::string &pathName, const std::vector<uint8_t> destEncryptKey)
284 {
285     std::string journalMode;
286     int errCode = ExecuteGetString(journalMode, "PRAGMA journal_mode", std::vector<ValueObject>());
287     if (errCode != E_OK) {
288         LOG_ERROR("RdbStoreImpl CheckAttach fail to get journal mode : %{public}d", errCode);
289         return errCode;
290     }
291     journalMode = SqliteUtils::StrToUpper(journalMode);
292     if (journalMode == "WAL") {
293         LOG_ERROR("RdbStoreImpl attach is not supported in WAL mode");
294         return E_NOT_SUPPORTED_ATTACH_IN_WAL_MODE;
295     }
296 
297     std::vector<ValueObject> bindArgs;
298     bindArgs.push_back(ValueObject(pathName));
299     bindArgs.push_back(ValueObject(alias));
300     if (destEncryptKey.size() != 0) {
301         bindArgs.push_back(ValueObject(destEncryptKey));
302     } else {
303         std::string str = "";
304         bindArgs.push_back(ValueObject(str));
305     }
306     errCode = ExecuteSql(ATTACH_SQL, bindArgs);
307     if (errCode != E_OK) {
308         LOG_ERROR("ExecuteSql ATTACH_SQL error %{public}d", errCode);
309         return errCode;
310     }
311 
312     return E_OK;
313 }
314 
BeginTransaction(TransactionObserver * transactionObserver)315 int StoreSession::BeginTransaction(TransactionObserver *transactionObserver)
316 {
317     if (connectionPool.getTransactionStack().empty()) {
318         AcquireConnection(false);
319 
320         int errCode = connection->ExecuteSql("BEGIN EXCLUSIVE;");
321         if (errCode != E_OK) {
322             ReleaseConnection(false);
323             return errCode;
324         }
325     }
326 
327     if (transactionObserver != nullptr) {
328         transactionObserver->OnBegin();
329     }
330 
331     BaseTransaction transaction(connectionPool.getTransactionStack().size());
332     connectionPool.getTransactionStack().push(transaction);
333 
334     return E_OK;
335 }
336 
MarkAsCommitWithObserver(TransactionObserver * transactionObserver)337 int StoreSession::MarkAsCommitWithObserver(TransactionObserver *transactionObserver)
338 {
339     if (connectionPool.getTransactionStack().empty()) {
340         return E_NO_TRANSACTION_IN_SESSION;
341     }
342     connectionPool.getTransactionStack().top().SetMarkedSuccessful(true);
343     return E_OK;
344 }
345 
EndTransactionWithObserver(TransactionObserver * transactionObserver)346 int StoreSession::EndTransactionWithObserver(TransactionObserver *transactionObserver)
347 {
348     if (connectionPool.getTransactionStack().empty()) {
349         return E_NO_TRANSACTION_IN_SESSION;
350     }
351 
352     BaseTransaction transaction = connectionPool.getTransactionStack().top();
353     bool isSucceed = transaction.IsAllBeforeSuccessful() && transaction.IsMarkedSuccessful();
354     connectionPool.getTransactionStack().pop();
355 
356     if (transactionObserver != nullptr) {
357         if (isSucceed) {
358             transactionObserver->OnCommit();
359         } else {
360             transactionObserver->OnRollback();
361         }
362     }
363 
364     if (!connectionPool.getTransactionStack().empty()) {
365         if (transactionObserver != nullptr) {
366             transactionObserver->OnRollback();
367         }
368 
369         if (!isSucceed) {
370             connectionPool.getTransactionStack().top().SetAllBeforeSuccessful(false);
371         }
372     } else {
373         int errCode;
374         if (connection == nullptr) {
375             LOG_ERROR("connection is null");
376             return E_ERROR;
377         }
378         if (isSucceed) {
379             errCode = connection->ExecuteSql("COMMIT;");
380         } else {
381             errCode = connection->ExecuteSql("ROLLBACK;");
382         }
383 
384         ReleaseConnection(false);
385         return errCode;
386     }
387 
388     return E_OK;
389 }
390 
MarkAsCommit()391 int StoreSession::MarkAsCommit()
392 {
393     if (connectionPool.getTransactionStack().empty()) {
394         return E_NO_TRANSACTION_IN_SESSION;
395     }
396     connectionPool.getTransactionStack().top().SetMarkedSuccessful(true);
397     return E_OK;
398 }
399 
EndTransaction()400 int StoreSession::EndTransaction()
401 {
402     if (connectionPool.getTransactionStack().empty()) {
403         return E_NO_TRANSACTION_IN_SESSION;
404     }
405 
406     BaseTransaction transaction = connectionPool.getTransactionStack().top();
407     bool isSucceed = transaction.IsAllBeforeSuccessful() && transaction.IsMarkedSuccessful();
408     connectionPool.getTransactionStack().pop();
409     if (!connectionPool.getTransactionStack().empty()) {
410         if (!isSucceed) {
411             connectionPool.getTransactionStack().top().SetAllBeforeSuccessful(false);
412         }
413     } else {
414         if (connection == nullptr) {
415             LOG_ERROR("connection is null");
416             return E_ERROR;
417         }
418         int errCode = connection->ExecuteSql(isSucceed ? "COMMIT;" : "ROLLBACK;");
419         ReleaseConnection(false);
420         return errCode;
421     }
422 
423     return E_OK;
424 }
IsInTransaction() const425 bool StoreSession::IsInTransaction() const
426 {
427     return !connectionPool.getTransactionStack().empty();
428 }
429 
BeginStepQuery(int & errCode,const std::string & sql,const std::vector<std::string> & selectionArgs)430 std::shared_ptr<SqliteStatement> StoreSession::BeginStepQuery(
431     int &errCode, const std::string &sql, const std::vector<std::string> &selectionArgs)
432 {
433     if (isInStepQuery == true) {
434         LOG_ERROR("StoreSession BeginStepQuery fail : begin more step query in one session !");
435         errCode = E_MORE_STEP_QUERY_IN_ONE_SESSION;
436         return nullptr; // fail,already in
437     }
438 
439     if (SqliteUtils::GetSqlStatementType(sql) != SqliteUtils::STATEMENT_SELECT) {
440         LOG_ERROR("StoreSession BeginStepQuery fail : not select sql !");
441         errCode = E_EXECUTE_IN_STEP_QUERY;
442         return nullptr;
443     }
444 
445     AcquireConnection(true);
446     std::shared_ptr<SqliteStatement> statement = readConnection->BeginStepQuery(errCode, sql, selectionArgs);
447     if (statement == nullptr) {
448         ReleaseConnection(true);
449         return nullptr;
450     }
451     isInStepQuery = true;
452     return statement;
453 }
454 
EndStepQuery()455 int StoreSession::EndStepQuery()
456 {
457     if (isInStepQuery == false) {
458         return E_OK;
459     }
460 
461     int errCode = readConnection->EndStepQuery();
462     isInStepQuery = false;
463     ReleaseConnection(true);
464     return errCode;
465 }
466 
467 #if !defined(WINDOWS_PLATFORM) && !defined(MAC_PLATFORM)
ExecuteForSharedBlock(int & rowNum,std::string sql,const std::vector<ValueObject> & bindArgs,AppDataFwk::SharedBlock * sharedBlock,int startPos,int requiredPos,bool isCountAllRows)468 int StoreSession::ExecuteForSharedBlock(int &rowNum, std::string sql, const std::vector<ValueObject> &bindArgs,
469     AppDataFwk::SharedBlock *sharedBlock, int startPos, int requiredPos, bool isCountAllRows)
470 {
471     bool isReadOnly = false;
472     int errCode = BeginExecuteSql(sql, isReadOnly);
473     if (errCode != E_OK) {
474         return errCode;
475     }
476     SqliteConnection *con = isReadOnly ? readConnection : connection;
477     errCode =
478         con->ExecuteForSharedBlock(rowNum, sql, bindArgs, sharedBlock, startPos, requiredPos, isCountAllRows);
479     ReleaseConnection(isReadOnly);
480     return errCode;
481 }
482 #endif
483 
BeginTransaction()484 int StoreSession::BeginTransaction()
485 {
486     AcquireConnection(false);
487 
488     BaseTransaction transaction(connectionPool.getTransactionStack().size());
489     int errCode = connection->ExecuteSql(transaction.getTransactionStr());
490     if (errCode != E_OK) {
491         LOG_DEBUG("storeSession BeginTransaction Failed");
492         ReleaseConnection(false);
493         return errCode;
494     }
495     connectionPool.getTransactionStack().push(transaction);
496     ReleaseConnection(false);
497     return E_OK;
498 }
499 
Commit()500 int StoreSession::Commit()
501 {
502     if (connectionPool.getTransactionStack().empty()) {
503         return E_OK;
504     }
505     BaseTransaction transaction = connectionPool.getTransactionStack().top();
506     std::string sqlStr = transaction.getCommitStr();
507     if (sqlStr.size() <= 1) {
508         connectionPool.getTransactionStack().pop();
509         return E_OK;
510     }
511 
512     AcquireConnection(false);
513     int errCode = connection->ExecuteSql(sqlStr);
514     ReleaseConnection(false);
515     if (errCode != E_OK) {
516         // if error the transaction is leaving for rollback
517         return errCode;
518     }
519     connectionPool.getTransactionStack().pop();
520     return E_OK;
521 }
522 
RollBack()523 int StoreSession::RollBack()
524 {
525     std::stack<BaseTransaction> transactionStack = connectionPool.getTransactionStack();
526     if (transactionStack.empty()) {
527         return E_NO_TRANSACTION_IN_SESSION;
528     }
529     BaseTransaction transaction = transactionStack.top();
530     transactionStack.pop();
531     if (transaction.getType() != TransType::ROLLBACK_SELF && !transactionStack.empty()) {
532         transactionStack.top().setChildFailure(true);
533     }
534     AcquireConnection(false);
535     int errCode = connection->ExecuteSql(transaction.getRollbackStr());
536     ReleaseConnection(false);
537     if (errCode != E_OK) {
538         LOG_ERROR("storeSession RollBack Fail");
539     }
540 
541     return errCode;
542 }
543 
GetConnectionUseCount()544 int StoreSession::GetConnectionUseCount()
545 {
546     return connectionUseCount;
547 }
548 } // namespace OHOS::NativeRdb
549