• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include <mutex>
16 #include <openssl/sha.h>
17 #include <string>
18 #include <sys/time.h>
19 #include <thread>
20 #include <vector>
21 
22 // using the "sqlite3sym.h" in OHOS
23 #ifndef USE_SQLITE_SYMBOLS
24 #include "sqlite3.h"
25 #else
26 #include "sqlite3sym.h"
27 #endif
28 
29 namespace {
// Status codes used by the helpers below; failures are reported as the negated value (-E_ERROR).
constexpr int E_OK{0};
constexpr int E_ERROR{1};
// Busy-handler timeout passed to sqlite3_busy_timeout(), in milliseconds (2 s).
constexpr int BUSY_TIMEOUT{2000};
33 class ValueHashCalc {
34 public:
ValueHashCalc()35     ValueHashCalc() {};
~ValueHashCalc()36     ~ValueHashCalc()
37     {
38         delete context_;
39         context_ = nullptr;
40     }
41 
Initialize()42     int Initialize()
43     {
44         context_ = new (std::nothrow) SHA256_CTX;
45         if (context_ == nullptr) {
46             return -E_ERROR;
47         }
48 
49         int errCode = SHA256_Init(context_);
50         if (errCode == 0) {
51             return -E_ERROR;
52         }
53         return E_OK;
54     }
55 
Update(const std::vector<uint8_t> & value)56     int Update(const std::vector<uint8_t> &value)
57     {
58         if (context_ == nullptr) {
59             return -E_ERROR;
60         }
61         int errCode = SHA256_Update(context_, value.data(), value.size());
62         if (errCode == 0) {
63             return -E_ERROR;
64         }
65         return E_OK;
66     }
67 
GetResult(std::vector<uint8_t> & value)68     int GetResult(std::vector<uint8_t> &value)
69     {
70         if (context_ == nullptr) {
71             return -E_ERROR;
72         }
73 
74         value.resize(SHA256_DIGEST_LENGTH);
75         int errCode = SHA256_Final(value.data(), context_);
76         if (errCode == 0) {
77             return -E_ERROR;
78         }
79 
80         return E_OK;
81     }
82 
83 private:
84     SHA256_CTX *context_ = nullptr;
85 };
86 
87 
// Microseconds per second; used to convert gettimeofday() results.
constexpr uint64_t MULTIPLES_BETWEEN_SECONDS_AND_MICROSECONDS = 1000000;

// Unsigned logical time value (TimeHelper produces these in 100 ns ticks).
using Timestamp = uint64_t;
// Signed offset added to a Timestamp by TimeHelper::GetTime().
using TimeOffset = int64_t;
92 
93 class TimeHelper {
94 public:
95     constexpr static int64_t BASE_OFFSET = 10000LL * 365LL * 24LL * 3600LL * 1000LL * 1000LL * 10L; // 10000 year 100ns
96 
97     constexpr static int64_t MAX_VALID_TIME = BASE_OFFSET * 2; // 20000 year 100ns
98 
99     constexpr static uint64_t TO_100_NS = 10; // 1us to 100ns
100 
101     constexpr static Timestamp INVALID_TIMESTAMP = 0;
102 
103     // Get current system time
GetSysCurrentTime()104     static Timestamp GetSysCurrentTime()
105     {
106         uint64_t curTime = 0;
107         int errCode = GetCurrentSysTimeInMicrosecond(curTime);
108         if (errCode != E_OK) {
109             return INVALID_TIMESTAMP;
110         }
111 
112         std::lock_guard<std::mutex> lock(systemTimeLock_);
113         // If GetSysCurrentTime in 1us, we need increase the currentIncCount_
114         if (curTime == lastSystemTimeUs_) {
115             // if the currentIncCount_ has been increased MAX_INC_COUNT, keep the currentIncCount_
116             if (currentIncCount_ < MAX_INC_COUNT) {
117                 currentIncCount_++;
118             }
119         } else {
120             lastSystemTimeUs_ = curTime;
121             currentIncCount_ = 0;
122         }
123         return (curTime * TO_100_NS) + currentIncCount_; // Currently Timestamp is uint64_t
124     }
125 
126     // Init the TimeHelper
Initialize(Timestamp maxTimestamp)127     static void Initialize(Timestamp maxTimestamp)
128     {
129         std::lock_guard<std::mutex> lock(lastLocalTimeLock_);
130         if (lastLocalTime_ < maxTimestamp) {
131             lastLocalTime_ = maxTimestamp;
132         }
133     }
134 
GetTime(TimeOffset timeOffset)135     static Timestamp GetTime(TimeOffset timeOffset)
136     {
137         Timestamp currentSysTime = GetSysCurrentTime();
138         Timestamp currentLocalTime = currentSysTime + timeOffset;
139         std::lock_guard<std::mutex> lock(lastLocalTimeLock_);
140         if (currentLocalTime <= lastLocalTime_ || currentLocalTime > MAX_VALID_TIME) {
141             lastLocalTime_++;
142             currentLocalTime = lastLocalTime_;
143         } else {
144             lastLocalTime_ = currentLocalTime;
145         }
146         return currentLocalTime;
147     }
148 
149 private:
GetCurrentSysTimeInMicrosecond(uint64_t & outTime)150     static int GetCurrentSysTimeInMicrosecond(uint64_t &outTime)
151     {
152         struct timeval rawTime;
153         int errCode = gettimeofday(&rawTime, nullptr);
154         if (errCode < 0) {
155             return -E_ERROR;
156         }
157         outTime = static_cast<uint64_t>(rawTime.tv_sec) * MULTIPLES_BETWEEN_SECONDS_AND_MICROSECONDS +
158             static_cast<uint64_t>(rawTime.tv_usec);
159         return E_OK;
160     }
161 
162     static std::mutex systemTimeLock_;
163     static Timestamp lastSystemTimeUs_;
164     static Timestamp currentIncCount_;
165     static const uint64_t MAX_INC_COUNT = 9; // last bit from 0-9
166 
167     static Timestamp lastLocalTime_;
168     static std::mutex lastLocalTimeLock_;
169 };
170 
171 std::mutex TimeHelper::systemTimeLock_;
172 Timestamp TimeHelper::lastSystemTimeUs_ = 0;
173 Timestamp TimeHelper::currentIncCount_ = 0;
174 Timestamp TimeHelper::lastLocalTime_ = 0;
175 std::mutex TimeHelper::lastLocalTimeLock_;
176 
177 struct TransactFunc {
178     void (*xFunc)(sqlite3_context*, int, sqlite3_value**) = nullptr;
179     void (*xStep)(sqlite3_context*, int, sqlite3_value**) = nullptr;
180     void (*xFinal)(sqlite3_context*) = nullptr;
181     void(*xDestroy)(void*) = nullptr;
182 };
183 
RegisterFunction(sqlite3 * db,const std::string & funcName,int nArg,void * uData,TransactFunc & func)184 int RegisterFunction(sqlite3 *db, const std::string &funcName, int nArg, void *uData, TransactFunc &func)
185 {
186     if (db == nullptr) {
187         return -E_ERROR;
188     }
189     return sqlite3_create_function_v2(db, funcName.c_str(), nArg, SQLITE_UTF8 | SQLITE_DETERMINISTIC, uData,
190         func.xFunc, func.xStep, func.xFinal, func.xDestroy);
191 }
192 
CalcValueHash(const std::vector<uint8_t> & value,std::vector<uint8_t> & hashValue)193 int CalcValueHash(const std::vector<uint8_t> &value, std::vector<uint8_t> &hashValue)
194 {
195     ValueHashCalc hashCalc;
196     int errCode = hashCalc.Initialize();
197     if (errCode != E_OK) {
198         return -E_ERROR;
199     }
200 
201     errCode = hashCalc.Update(value);
202     if (errCode != E_OK) {
203         return -E_ERROR;
204     }
205 
206     errCode = hashCalc.GetResult(hashValue);
207     if (errCode != E_OK) {
208         return -E_ERROR;
209     }
210 
211     return E_OK;
212 }
213 
CalcHashKey(sqlite3_context * ctx,int argc,sqlite3_value ** argv)214 void CalcHashKey(sqlite3_context *ctx, int argc, sqlite3_value **argv)
215 {
216     // 1 means that the function only needs one parameter, namely key
217     if (ctx == nullptr || argc != 1 || argv == nullptr) {
218         return;
219     }
220     auto keyBlob = static_cast<const uint8_t *>(sqlite3_value_blob(argv[0]));
221     if (keyBlob == nullptr) {
222         sqlite3_result_error(ctx, "Parameters is invalid.", -1);
223         return;
224     }
225     int blobLen = sqlite3_value_bytes(argv[0]);
226     std::vector<uint8_t> value(keyBlob, keyBlob + blobLen);
227     std::vector<uint8_t> hashValue;
228     int errCode = CalcValueHash(value, hashValue);
229     if (errCode != E_OK) {
230         sqlite3_result_error(ctx, "Get hash value error.", -1);
231         return;
232     }
233     sqlite3_result_blob(ctx, hashValue.data(), hashValue.size(), SQLITE_TRANSIENT);
234     return;
235 }
236 
RegisterCalcHash(sqlite3 * db)237 int RegisterCalcHash(sqlite3 *db)
238 {
239     TransactFunc func;
240     func.xFunc = &CalcHashKey;
241     return RegisterFunction(db, "calc_hash", 1, nullptr, func);
242 }
243 
GetSysTime(sqlite3_context * ctx,int argc,sqlite3_value ** argv)244 void GetSysTime(sqlite3_context *ctx, int argc, sqlite3_value **argv)
245 {
246     if (ctx == nullptr || argc != 1 || argv == nullptr) { // 1: function need one parameter
247         return;
248     }
249     int timeOffset = static_cast<int64_t>(sqlite3_value_int64(argv[0]));
250     sqlite3_result_int64(ctx, (sqlite3_int64)TimeHelper::GetTime(timeOffset));
251 }
252 
RegisterGetSysTime(sqlite3 * db)253 int RegisterGetSysTime(sqlite3 *db)
254 {
255     TransactFunc func;
256     func.xFunc = &GetSysTime;
257     return RegisterFunction(db, "get_sys_time", 1, nullptr, func);
258 }
259 
ResetStatement(sqlite3_stmt * & stmt)260 int ResetStatement(sqlite3_stmt *&stmt)
261 {
262     if (stmt == nullptr || sqlite3_finalize(stmt) != SQLITE_OK) {
263         return -E_ERROR;
264     }
265     stmt = nullptr;
266     return E_OK;
267 }
268 
GetStatement(sqlite3 * db,const std::string & sql,sqlite3_stmt * & stmt)269 int GetStatement(sqlite3 *db, const std::string &sql, sqlite3_stmt *&stmt)
270 {
271     int errCode = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
272     if (errCode != SQLITE_OK) {
273         (void)ResetStatement(stmt);
274         return -E_ERROR;
275     }
276     return E_OK;
277 }
278 
ExecuteRawSQL(sqlite3 * db,const std::string & sql)279 int ExecuteRawSQL(sqlite3 *db, const std::string &sql)
280 {
281     if (db == nullptr) {
282         return -E_ERROR;
283     }
284     char *errMsg = nullptr;
285     int errCode = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &errMsg);
286     if (errCode != SQLITE_OK) {
287         errCode = -E_ERROR;
288     }
289 
290     if (errMsg != nullptr) {
291         sqlite3_free(errMsg);
292         errMsg = nullptr;
293     }
294     return errCode;
295 }
296 
StepWithRetry(sqlite3_stmt * stmt)297 int StepWithRetry(sqlite3_stmt *stmt)
298 {
299     if (stmt == nullptr) {
300         return -E_ERROR;
301     }
302     int errCode = sqlite3_step(stmt);
303     if (errCode != SQLITE_DONE && errCode != SQLITE_ROW) {
304         return -E_ERROR;
305     }
306     return errCode;
307 }
308 
GetColumnTestValue(sqlite3_stmt * stmt,int index,std::string & value)309 int GetColumnTestValue(sqlite3_stmt *stmt, int index, std::string &value)
310 {
311     if (stmt == nullptr) {
312         return -E_ERROR;
313     }
314     const unsigned char *val = sqlite3_column_text(stmt, index);
315     value = (val != nullptr) ? std::string(reinterpret_cast<const char *>(val)) : std::string();
316     return E_OK;
317 }
318 
GetCurrentMaxTimestamp(sqlite3 * db,Timestamp & maxTimestamp)319 int GetCurrentMaxTimestamp(sqlite3 *db, Timestamp &maxTimestamp)
320 {
321     if (db == nullptr) {
322         return -E_ERROR;
323     }
324     std::string checkTableSql = "SELECT name FROM sqlite_master WHERE type = 'table' AND " \
325         "name LIKE 'naturalbase_rdb_aux_%_log';";
326     sqlite3_stmt *checkTableStmt = nullptr;
327     int errCode = GetStatement(db, checkTableSql, checkTableStmt);
328     if (errCode != E_OK) {
329         return -E_ERROR;
330     }
331     while ((errCode = StepWithRetry(checkTableStmt)) != SQLITE_DONE) {
332         if (errCode != SQLITE_ROW) {
333             ResetStatement(checkTableStmt);
334             return -E_ERROR;
335         }
336         std::string logTablename;
337         GetColumnTestValue(checkTableStmt, 0, logTablename);
338         if (logTablename.empty()) {
339             continue;
340         }
341 
342         std::string getMaxTimestampSql = "SELECT MAX(timestamp) FROM " + logTablename + ";";
343         sqlite3_stmt *getTimeStmt = nullptr;
344         errCode = GetStatement(db, getMaxTimestampSql, getTimeStmt);
345         if (errCode != E_OK) {
346             continue;
347         }
348         errCode = StepWithRetry(getTimeStmt);
349         if (errCode != SQLITE_ROW) {
350             ResetStatement(getTimeStmt);
351             continue;
352         }
353         auto tableMaxTimestamp = static_cast<Timestamp>(sqlite3_column_int64(getTimeStmt, 0));
354         maxTimestamp = (maxTimestamp > tableMaxTimestamp) ? maxTimestamp : tableMaxTimestamp;
355         ResetStatement(getTimeStmt);
356     }
357     ResetStatement(checkTableStmt);
358     return E_OK;
359 }
360 
ClearTheLogAfterDropTable(void * db,int actionCode,const char * tblName,const char * useLessParam,const char * schemaName,const char * triggerName)361 int ClearTheLogAfterDropTable(void *db, int actionCode, const char *tblName,
362     const char *useLessParam, const char *schemaName, const char *triggerName)
363 {
364     (void)useLessParam;
365     (void)triggerName;
366     if (actionCode != SQLITE_DROP_TABLE) {
367         return SQLITE_OK;
368     }
369     if (db == nullptr || tblName == nullptr || schemaName == nullptr) {
370         return SQLITE_DENY;
371     }
372     auto filepath = sqlite3_db_filename(static_cast<sqlite3 *>(db), schemaName);
373     if (filepath == nullptr) {
374         return SQLITE_DENY;
375     }
376     auto filename = std::string(filepath);
377     std::thread th([filename, tableName = std::string(tblName), dropTimeStamp = TimeHelper::GetTime(0)] {
378         sqlite3 *db = nullptr;
379         (void)sqlite3_open(filename.c_str(), &db);
380         if (db == nullptr) {
381             return;
382         }
383 
384         if (sqlite3_busy_timeout(db, BUSY_TIMEOUT) != SQLITE_OK) {
385             sqlite3_close(db);
386             return;
387         }
388 
389         sqlite3_stmt *stmt = nullptr;
390         std::string logTblName = "naturalbase_rdb_aux_" + std::string(tableName) + "_log";
391         std::string sql = "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='" + logTblName + "';";
392         if (sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr) != SQLITE_OK) {
393             (void)sqlite3_finalize(stmt);
394             sqlite3_close(db);
395             return;
396         }
397 
398         bool isLogTblExists = false;
399         if (sqlite3_step(stmt) == SQLITE_ROW && static_cast<bool>(sqlite3_column_int(stmt, 0))) {
400             isLogTblExists = true;
401         }
402         (void)sqlite3_finalize(stmt);
403         stmt = nullptr;
404 
405         if (isLogTblExists) {
406             RegisterGetSysTime(db);
407             sql = "UPDATE " + logTblName + " SET flag=0x03, timestamp=get_sys_time(0) "
408                   "WHERE flag&0x03=0x02 AND timestamp<" + std::to_string(dropTimeStamp);
409             (void)sqlite3_exec(db, sql.c_str(), nullptr, nullptr, nullptr);
410         }
411         sqlite3_close(db);
412         return;
413     });
414     th.detach();
415     return SQLITE_OK;
416 }
417 
PostHandle(sqlite3 * db)418 void PostHandle(sqlite3 *db)
419 {
420     Timestamp currentMaxTimestamp = 0;
421     (void)GetCurrentMaxTimestamp(db, currentMaxTimestamp);
422     TimeHelper::Initialize(currentMaxTimestamp);
423     RegisterCalcHash(db);
424     RegisterGetSysTime(db);
425     (void)sqlite3_set_authorizer(db, &ClearTheLogAfterDropTable, db);
426     (void)sqlite3_busy_timeout(db, BUSY_TIMEOUT);
427     std::string recursiveTrigger = "PRAGMA recursive_triggers = ON;";
428     (void)ExecuteRawSQL(db, recursiveTrigger);
429 }
430 }
431 
sqlite3_open_relational(const char * filename,sqlite3 ** ppDb)432 SQLITE_API int sqlite3_open_relational(const char *filename, sqlite3 **ppDb)
433 {
434     int err = sqlite3_open(filename, ppDb);
435     if (err != SQLITE_OK) {
436         return err;
437     }
438     PostHandle(*ppDb);
439     return err;
440 }
441 
sqlite3_open16_relational(const void * filename,sqlite3 ** ppDb)442 SQLITE_API int sqlite3_open16_relational(const void *filename, sqlite3 **ppDb)
443 {
444     int err = sqlite3_open16(filename, ppDb);
445     if (err != SQLITE_OK) {
446         return err;
447     }
448     PostHandle(*ppDb);
449     return err;
450 }
451 
sqlite3_open_v2_relational(const char * filename,sqlite3 ** ppDb,int flags,const char * zVfs)452 SQLITE_API int sqlite3_open_v2_relational(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs)
453 {
454     int err = sqlite3_open_v2(filename, ppDb, flags, zVfs);
455     if (err != SQLITE_OK) {
456         return err;
457     }
458     PostHandle(*ppDb);
459     return err;
460 }
461 
// hw export the symbols: publish the relational open functions through a single exported table.
#ifdef SQLITE_DISTRIBUTE_RELATIONAL
#if defined(__GNUC__)
#define EXPORT_SYMBOLS __attribute__((visibility("default")))
#elif defined(_MSC_VER)
#define EXPORT_SYMBOLS __declspec(dllexport)
#else
#define EXPORT_SYMBOLS
#endif

// Table of the relational open entry points, exported via sqlite3_export_relational_symbols below.
struct sqlite3_api_routines_relational {
    int (*open)(const char *, sqlite3 **);
    int (*open16)(const void *, sqlite3 **);
    int (*open_v2)(const char *, sqlite3 **, int, const char *);
};

// This section is compiled only when SQLITE_DISTRIBUTE_RELATIONAL is defined, so the table is always
// populated. No nested check (and no null-entry fallback) is needed.
static const sqlite3_api_routines_relational sqlite3HwApis = {
    sqlite3_open_relational,
    sqlite3_open16_relational,
    sqlite3_open_v2_relational
};

EXPORT_SYMBOLS const sqlite3_api_routines_relational *sqlite3_export_relational_symbols = &sqlite3HwApis;
#endif