1 /*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #include <mutex>
16 #include <openssl/sha.h>
17 #include <string>
18 #include <sys/time.h>
19 #include <thread>
20 #include <vector>
21
22 // using the "sqlite3sym.h" in OHOS
23 #ifndef USE_SQLITE_SYMBOLS
24 #include "sqlite3.h"
25 #else
26 #include "sqlite3sym.h"
27 #endif
28
29 namespace {
// Internal status codes of the helpers in this file; failures are returned negated (-E_ERROR).
constexpr int E_OK = 0;
constexpr int E_ERROR = 1;
// Value handed to sqlite3_busy_timeout() for the connections configured in this file.
constexpr int BUSY_TIMEOUT = 2000; // 2s.
33 class ValueHashCalc {
34 public:
ValueHashCalc()35 ValueHashCalc() {};
~ValueHashCalc()36 ~ValueHashCalc()
37 {
38 delete context_;
39 context_ = nullptr;
40 }
41
Initialize()42 int Initialize()
43 {
44 context_ = new (std::nothrow) SHA256_CTX;
45 if (context_ == nullptr) {
46 return -E_ERROR;
47 }
48
49 int errCode = SHA256_Init(context_);
50 if (errCode == 0) {
51 return -E_ERROR;
52 }
53 return E_OK;
54 }
55
Update(const std::vector<uint8_t> & value)56 int Update(const std::vector<uint8_t> &value)
57 {
58 if (context_ == nullptr) {
59 return -E_ERROR;
60 }
61 int errCode = SHA256_Update(context_, value.data(), value.size());
62 if (errCode == 0) {
63 return -E_ERROR;
64 }
65 return E_OK;
66 }
67
GetResult(std::vector<uint8_t> & value)68 int GetResult(std::vector<uint8_t> &value)
69 {
70 if (context_ == nullptr) {
71 return -E_ERROR;
72 }
73
74 value.resize(SHA256_DIGEST_LENGTH);
75 int errCode = SHA256_Final(value.data(), context_);
76 if (errCode == 0) {
77 return -E_ERROR;
78 }
79
80 return E_OK;
81 }
82
83 private:
84 SHA256_CTX *context_ = nullptr;
85 };
86
87
// Number of microseconds in one second.
const uint64_t MULTIPLES_BETWEEN_SECONDS_AND_MICROSECONDS = 1000000;

// Timestamp type handled by TimeHelper (system microseconds scaled by TimeHelper::TO_100_NS).
using Timestamp = uint64_t;
// Signed adjustment that TimeHelper::GetTime() adds to the system time.
using TimeOffset = int64_t;
92
93 class TimeHelper {
94 public:
95 constexpr static int64_t BASE_OFFSET = 10000LL * 365LL * 24LL * 3600LL * 1000LL * 1000LL * 10L; // 10000 year 100ns
96
97 constexpr static int64_t MAX_VALID_TIME = BASE_OFFSET * 2; // 20000 year 100ns
98
99 constexpr static uint64_t TO_100_NS = 10; // 1us to 100ns
100
101 constexpr static Timestamp INVALID_TIMESTAMP = 0;
102
103 // Get current system time
GetSysCurrentTime()104 static Timestamp GetSysCurrentTime()
105 {
106 uint64_t curTime = 0;
107 int errCode = GetCurrentSysTimeInMicrosecond(curTime);
108 if (errCode != E_OK) {
109 return INVALID_TIMESTAMP;
110 }
111
112 std::lock_guard<std::mutex> lock(systemTimeLock_);
113 // If GetSysCurrentTime in 1us, we need increase the currentIncCount_
114 if (curTime == lastSystemTimeUs_) {
115 // if the currentIncCount_ has been increased MAX_INC_COUNT, keep the currentIncCount_
116 if (currentIncCount_ < MAX_INC_COUNT) {
117 currentIncCount_++;
118 }
119 } else {
120 lastSystemTimeUs_ = curTime;
121 currentIncCount_ = 0;
122 }
123 return (curTime * TO_100_NS) + currentIncCount_; // Currently Timestamp is uint64_t
124 }
125
126 // Init the TimeHelper
Initialize(Timestamp maxTimestamp)127 static void Initialize(Timestamp maxTimestamp)
128 {
129 std::lock_guard<std::mutex> lock(lastLocalTimeLock_);
130 if (lastLocalTime_ < maxTimestamp) {
131 lastLocalTime_ = maxTimestamp;
132 }
133 }
134
GetTime(TimeOffset timeOffset)135 static Timestamp GetTime(TimeOffset timeOffset)
136 {
137 Timestamp currentSysTime = GetSysCurrentTime();
138 Timestamp currentLocalTime = currentSysTime + timeOffset;
139 std::lock_guard<std::mutex> lock(lastLocalTimeLock_);
140 if (currentLocalTime <= lastLocalTime_ || currentLocalTime > MAX_VALID_TIME) {
141 lastLocalTime_++;
142 currentLocalTime = lastLocalTime_;
143 } else {
144 lastLocalTime_ = currentLocalTime;
145 }
146 return currentLocalTime;
147 }
148
149 private:
GetCurrentSysTimeInMicrosecond(uint64_t & outTime)150 static int GetCurrentSysTimeInMicrosecond(uint64_t &outTime)
151 {
152 struct timeval rawTime;
153 int errCode = gettimeofday(&rawTime, nullptr);
154 if (errCode < 0) {
155 return -E_ERROR;
156 }
157 outTime = static_cast<uint64_t>(rawTime.tv_sec) * MULTIPLES_BETWEEN_SECONDS_AND_MICROSECONDS +
158 static_cast<uint64_t>(rawTime.tv_usec);
159 return E_OK;
160 }
161
162 static std::mutex systemTimeLock_;
163 static Timestamp lastSystemTimeUs_;
164 static Timestamp currentIncCount_;
165 static const uint64_t MAX_INC_COUNT = 9; // last bit from 0-9
166
167 static Timestamp lastLocalTime_;
168 static std::mutex lastLocalTimeLock_;
169 };
170
171 std::mutex TimeHelper::systemTimeLock_;
172 Timestamp TimeHelper::lastSystemTimeUs_ = 0;
173 Timestamp TimeHelper::currentIncCount_ = 0;
174 Timestamp TimeHelper::lastLocalTime_ = 0;
175 std::mutex TimeHelper::lastLocalTimeLock_;
176
177 struct TransactFunc {
178 void (*xFunc)(sqlite3_context*, int, sqlite3_value**) = nullptr;
179 void (*xStep)(sqlite3_context*, int, sqlite3_value**) = nullptr;
180 void (*xFinal)(sqlite3_context*) = nullptr;
181 void(*xDestroy)(void*) = nullptr;
182 };
183
RegisterFunction(sqlite3 * db,const std::string & funcName,int nArg,void * uData,TransactFunc & func)184 int RegisterFunction(sqlite3 *db, const std::string &funcName, int nArg, void *uData, TransactFunc &func)
185 {
186 if (db == nullptr) {
187 return -E_ERROR;
188 }
189 return sqlite3_create_function_v2(db, funcName.c_str(), nArg, SQLITE_UTF8 | SQLITE_DETERMINISTIC, uData,
190 func.xFunc, func.xStep, func.xFinal, func.xDestroy);
191 }
192
CalcValueHash(const std::vector<uint8_t> & value,std::vector<uint8_t> & hashValue)193 int CalcValueHash(const std::vector<uint8_t> &value, std::vector<uint8_t> &hashValue)
194 {
195 ValueHashCalc hashCalc;
196 int errCode = hashCalc.Initialize();
197 if (errCode != E_OK) {
198 return -E_ERROR;
199 }
200
201 errCode = hashCalc.Update(value);
202 if (errCode != E_OK) {
203 return -E_ERROR;
204 }
205
206 errCode = hashCalc.GetResult(hashValue);
207 if (errCode != E_OK) {
208 return -E_ERROR;
209 }
210
211 return E_OK;
212 }
213
CalcHashKey(sqlite3_context * ctx,int argc,sqlite3_value ** argv)214 void CalcHashKey(sqlite3_context *ctx, int argc, sqlite3_value **argv)
215 {
216 // 1 means that the function only needs one parameter, namely key
217 if (ctx == nullptr || argc != 1 || argv == nullptr) {
218 return;
219 }
220 auto keyBlob = static_cast<const uint8_t *>(sqlite3_value_blob(argv[0]));
221 if (keyBlob == nullptr) {
222 sqlite3_result_error(ctx, "Parameters is invalid.", -1);
223 return;
224 }
225 int blobLen = sqlite3_value_bytes(argv[0]);
226 std::vector<uint8_t> value(keyBlob, keyBlob + blobLen);
227 std::vector<uint8_t> hashValue;
228 int errCode = CalcValueHash(value, hashValue);
229 if (errCode != E_OK) {
230 sqlite3_result_error(ctx, "Get hash value error.", -1);
231 return;
232 }
233 sqlite3_result_blob(ctx, hashValue.data(), hashValue.size(), SQLITE_TRANSIENT);
234 return;
235 }
236
RegisterCalcHash(sqlite3 * db)237 int RegisterCalcHash(sqlite3 *db)
238 {
239 TransactFunc func;
240 func.xFunc = &CalcHashKey;
241 return RegisterFunction(db, "calc_hash", 1, nullptr, func);
242 }
243
GetSysTime(sqlite3_context * ctx,int argc,sqlite3_value ** argv)244 void GetSysTime(sqlite3_context *ctx, int argc, sqlite3_value **argv)
245 {
246 if (ctx == nullptr || argc != 1 || argv == nullptr) { // 1: function need one parameter
247 return;
248 }
249 int timeOffset = static_cast<int64_t>(sqlite3_value_int64(argv[0]));
250 sqlite3_result_int64(ctx, (sqlite3_int64)TimeHelper::GetTime(timeOffset));
251 }
252
RegisterGetSysTime(sqlite3 * db)253 int RegisterGetSysTime(sqlite3 *db)
254 {
255 TransactFunc func;
256 func.xFunc = &GetSysTime;
257 return RegisterFunction(db, "get_sys_time", 1, nullptr, func);
258 }
259
ResetStatement(sqlite3_stmt * & stmt)260 int ResetStatement(sqlite3_stmt *&stmt)
261 {
262 if (stmt == nullptr || sqlite3_finalize(stmt) != SQLITE_OK) {
263 return -E_ERROR;
264 }
265 stmt = nullptr;
266 return E_OK;
267 }
268
GetStatement(sqlite3 * db,const std::string & sql,sqlite3_stmt * & stmt)269 int GetStatement(sqlite3 *db, const std::string &sql, sqlite3_stmt *&stmt)
270 {
271 int errCode = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
272 if (errCode != SQLITE_OK) {
273 (void)ResetStatement(stmt);
274 return -E_ERROR;
275 }
276 return E_OK;
277 }
278
ExecuteRawSQL(sqlite3 * db,const std::string & sql)279 int ExecuteRawSQL(sqlite3 *db, const std::string &sql)
280 {
281 if (db == nullptr) {
282 return -E_ERROR;
283 }
284 char *errMsg = nullptr;
285 int errCode = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &errMsg);
286 if (errCode != SQLITE_OK) {
287 errCode = -E_ERROR;
288 }
289
290 if (errMsg != nullptr) {
291 sqlite3_free(errMsg);
292 errMsg = nullptr;
293 }
294 return errCode;
295 }
296
StepWithRetry(sqlite3_stmt * stmt)297 int StepWithRetry(sqlite3_stmt *stmt)
298 {
299 if (stmt == nullptr) {
300 return -E_ERROR;
301 }
302 int errCode = sqlite3_step(stmt);
303 if (errCode != SQLITE_DONE && errCode != SQLITE_ROW) {
304 return -E_ERROR;
305 }
306 return errCode;
307 }
308
GetColumnTestValue(sqlite3_stmt * stmt,int index,std::string & value)309 int GetColumnTestValue(sqlite3_stmt *stmt, int index, std::string &value)
310 {
311 if (stmt == nullptr) {
312 return -E_ERROR;
313 }
314 const unsigned char *val = sqlite3_column_text(stmt, index);
315 value = (val != nullptr) ? std::string(reinterpret_cast<const char *>(val)) : std::string();
316 return E_OK;
317 }
318
GetCurrentMaxTimestamp(sqlite3 * db,Timestamp & maxTimestamp)319 int GetCurrentMaxTimestamp(sqlite3 *db, Timestamp &maxTimestamp)
320 {
321 if (db == nullptr) {
322 return -E_ERROR;
323 }
324 std::string checkTableSql = "SELECT name FROM sqlite_master WHERE type = 'table' AND " \
325 "name LIKE 'naturalbase_rdb_aux_%_log';";
326 sqlite3_stmt *checkTableStmt = nullptr;
327 int errCode = GetStatement(db, checkTableSql, checkTableStmt);
328 if (errCode != E_OK) {
329 return -E_ERROR;
330 }
331 while ((errCode = StepWithRetry(checkTableStmt)) != SQLITE_DONE) {
332 if (errCode != SQLITE_ROW) {
333 ResetStatement(checkTableStmt);
334 return -E_ERROR;
335 }
336 std::string logTablename;
337 GetColumnTestValue(checkTableStmt, 0, logTablename);
338 if (logTablename.empty()) {
339 continue;
340 }
341
342 std::string getMaxTimestampSql = "SELECT MAX(timestamp) FROM " + logTablename + ";";
343 sqlite3_stmt *getTimeStmt = nullptr;
344 errCode = GetStatement(db, getMaxTimestampSql, getTimeStmt);
345 if (errCode != E_OK) {
346 continue;
347 }
348 errCode = StepWithRetry(getTimeStmt);
349 if (errCode != SQLITE_ROW) {
350 ResetStatement(getTimeStmt);
351 continue;
352 }
353 auto tableMaxTimestamp = static_cast<Timestamp>(sqlite3_column_int64(getTimeStmt, 0));
354 maxTimestamp = (maxTimestamp > tableMaxTimestamp) ? maxTimestamp : tableMaxTimestamp;
355 ResetStatement(getTimeStmt);
356 }
357 ResetStatement(checkTableStmt);
358 return E_OK;
359 }
360
ClearTheLogAfterDropTable(void * db,int actionCode,const char * tblName,const char * useLessParam,const char * schemaName,const char * triggerName)361 int ClearTheLogAfterDropTable(void *db, int actionCode, const char *tblName,
362 const char *useLessParam, const char *schemaName, const char *triggerName)
363 {
364 (void)useLessParam;
365 (void)triggerName;
366 if (actionCode != SQLITE_DROP_TABLE) {
367 return SQLITE_OK;
368 }
369 if (db == nullptr || tblName == nullptr || schemaName == nullptr) {
370 return SQLITE_DENY;
371 }
372 auto filepath = sqlite3_db_filename(static_cast<sqlite3 *>(db), schemaName);
373 if (filepath == nullptr) {
374 return SQLITE_DENY;
375 }
376 auto filename = std::string(filepath);
377 std::thread th([filename, tableName = std::string(tblName), dropTimeStamp = TimeHelper::GetTime(0)] {
378 sqlite3 *db = nullptr;
379 (void)sqlite3_open(filename.c_str(), &db);
380 if (db == nullptr) {
381 return;
382 }
383
384 if (sqlite3_busy_timeout(db, BUSY_TIMEOUT) != SQLITE_OK) {
385 sqlite3_close(db);
386 return;
387 }
388
389 sqlite3_stmt *stmt = nullptr;
390 std::string logTblName = "naturalbase_rdb_aux_" + std::string(tableName) + "_log";
391 std::string sql = "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='" + logTblName + "';";
392 if (sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr) != SQLITE_OK) {
393 (void)sqlite3_finalize(stmt);
394 sqlite3_close(db);
395 return;
396 }
397
398 bool isLogTblExists = false;
399 if (sqlite3_step(stmt) == SQLITE_ROW && static_cast<bool>(sqlite3_column_int(stmt, 0))) {
400 isLogTblExists = true;
401 }
402 (void)sqlite3_finalize(stmt);
403 stmt = nullptr;
404
405 if (isLogTblExists) {
406 RegisterGetSysTime(db);
407 sql = "UPDATE " + logTblName + " SET flag=0x03, timestamp=get_sys_time(0) "
408 "WHERE flag&0x03=0x02 AND timestamp<" + std::to_string(dropTimeStamp);
409 (void)sqlite3_exec(db, sql.c_str(), nullptr, nullptr, nullptr);
410 }
411 sqlite3_close(db);
412 return;
413 });
414 th.detach();
415 return SQLITE_OK;
416 }
417
PostHandle(sqlite3 * db)418 void PostHandle(sqlite3 *db)
419 {
420 Timestamp currentMaxTimestamp = 0;
421 (void)GetCurrentMaxTimestamp(db, currentMaxTimestamp);
422 TimeHelper::Initialize(currentMaxTimestamp);
423 RegisterCalcHash(db);
424 RegisterGetSysTime(db);
425 (void)sqlite3_set_authorizer(db, &ClearTheLogAfterDropTable, db);
426 (void)sqlite3_busy_timeout(db, BUSY_TIMEOUT);
427 std::string recursiveTrigger = "PRAGMA recursive_triggers = ON;";
428 (void)ExecuteRawSQL(db, recursiveTrigger);
429 }
430 }
431
sqlite3_open_relational(const char * filename,sqlite3 ** ppDb)432 SQLITE_API int sqlite3_open_relational(const char *filename, sqlite3 **ppDb)
433 {
434 int err = sqlite3_open(filename, ppDb);
435 if (err != SQLITE_OK) {
436 return err;
437 }
438 PostHandle(*ppDb);
439 return err;
440 }
441
sqlite3_open16_relational(const void * filename,sqlite3 ** ppDb)442 SQLITE_API int sqlite3_open16_relational(const void *filename, sqlite3 **ppDb)
443 {
444 int err = sqlite3_open16(filename, ppDb);
445 if (err != SQLITE_OK) {
446 return err;
447 }
448 PostHandle(*ppDb);
449 return err;
450 }
451
sqlite3_open_v2_relational(const char * filename,sqlite3 ** ppDb,int flags,const char * zVfs)452 SQLITE_API int sqlite3_open_v2_relational(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs)
453 {
454 int err = sqlite3_open_v2(filename, ppDb, flags, zVfs);
455 if (err != SQLITE_OK) {
456 return err;
457 }
458 PostHandle(*ppDb);
459 return err;
460 }
461
// hw export the symbols
#ifdef SQLITE_DISTRIBUTE_RELATIONAL
#if defined(__GNUC__)
#  define EXPORT_SYMBOLS __attribute__ ((visibility ("default")))
#elif defined(_MSC_VER)
#  define EXPORT_SYMBOLS __declspec(dllexport)
#else
#  define EXPORT_SYMBOLS
#endif

// Table of the relational open entry points published through
// sqlite3_export_relational_symbols.
struct sqlite3_api_routines_relational {
    int (*open)(const char *, sqlite3 **);
    int (*open16)(const void *, sqlite3 **);
    int (*open_v2)(const char *, sqlite3 **, int, const char *);
};

// This whole section is only compiled when SQLITE_DISTRIBUTE_RELATIONAL is defined, so the
// table is always filled with the real entry points (no zero-filled fallback is needed).
static const sqlite3_api_routines_relational sqlite3HwApis = {
    sqlite3_open_relational,
    sqlite3_open16_relational,
    sqlite3_open_v2_relational
};

EXPORT_SYMBOLS const sqlite3_api_routines_relational *sqlite3_export_relational_symbols = &sqlite3HwApis;
#endif