• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
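//! This module provides data operations (create, insert, delete, update, query) on database tables.
//! Data values may come from user input: they are bound to prepared-statement placeholders rather
//! than spliced into the SQL text. Table names, column names and the raw condition passed to
//! `delete_with_specific_cond` are spliced in directly and must come from trusted code.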
18 
19 use core::ffi::c_void;
20 use std::cmp::Ordering;
21 
22 use asset_definition::{log_throw_error, Conversion, DataType, ErrCode, Result, Value};
23 use asset_log::logi;
24 
25 use crate::{
26     database::Database,
27     statement::Statement,
28     transaction::Transaction,
29     types::{ColumnInfo, DbMap, QueryOptions, UpgradeColumnInfo, DB_UPGRADE_VERSION, SQLITE_ROW},
30 };
31 
extern "C" {
    /// Foreign function implemented outside this file; `Table` calls it right after stepping an
    /// insert/update/delete statement and returns its result to callers.
    /// NOTE(review): presumably wraps `sqlite3_changes`, i.e. the number of rows modified by the
    /// most recently completed statement on connection `db` — confirm against the C side.
    fn SqliteChanges(db: *mut c_void) -> i32;
}
35 
/// A named table inside a [`Database`], used to build and run SQL statements against that table.
#[repr(C)]
pub(crate) struct Table<'a> {
    /// Table name. It is spliced directly into generated SQL text (never bound as a parameter),
    /// so it must be a trusted identifier.
    pub(crate) table_name: String,
    /// Connection on which this table's statements are prepared and executed.
    pub(crate) db: &'a Database,
}
41 
42 #[inline(always)]
bind_datas(datas: &DbMap, stmt: &Statement, index: &mut i32) -> Result<()>43 fn bind_datas(datas: &DbMap, stmt: &Statement, index: &mut i32) -> Result<()> {
44     for (_, value) in datas.iter() {
45         stmt.bind_data(*index, value)?;
46         *index += 1;
47     }
48     Ok(())
49 }
50 
bind_where_datas(datas: &DbMap, stmt: &Statement, index: &mut i32) -> Result<()>51 fn bind_where_datas(datas: &DbMap, stmt: &Statement, index: &mut i32) -> Result<()> {
52     for (key, value) in datas.iter() {
53         if *key == "SyncType" {
54             stmt.bind_data(*index, value)?;
55             *index += 1;
56         }
57         stmt.bind_data(*index, value)?;
58         *index += 1;
59     }
60     Ok(())
61 }
62 
bind_where_with_specific_condifion(datas: &[Value], stmt: &Statement, index: &mut i32) -> Result<()>63 fn bind_where_with_specific_condifion(datas: &[Value], stmt: &Statement, index: &mut i32) -> Result<()> {
64     for value in datas.iter() {
65         stmt.bind_data(*index, value)?;
66         *index += 1;
67     }
68     Ok(())
69 }
70 
/// Appends `columns` to `sql` as a comma-separated list (e.g. `a,b,c`).
///
/// Appends nothing when `columns` is empty; use `build_sql_columns` when an empty list should
/// mean `*`.
#[inline(always)]
fn build_sql_columns_not_empty(columns: &[&str], sql: &mut String) {
    sql.push_str(&columns.join(","));
}
81 
/// Appends the projection of a `select` to `sql`.
///
/// The column names are joined by `,`. When `columns` is empty, `*` is appended instead, so that
/// all columns are selected.
#[inline(always)]
fn build_sql_columns(columns: &[&str], sql: &mut String) {
    if columns.is_empty() {
        sql.push('*');
    } else {
        sql.push_str(&columns.join(","));
    }
}
90 
91 #[inline(always)]
build_sql_where(conditions: &DbMap, filter: bool, sql: &mut String)92 fn build_sql_where(conditions: &DbMap, filter: bool, sql: &mut String) {
93     if !conditions.is_empty() || filter {
94         sql.push_str(" where ");
95         if filter {
96             sql.push_str("SyncStatus <> 2");
97             if !conditions.is_empty() {
98                 sql.push_str(" and ");
99             }
100         }
101         if !conditions.is_empty() {
102             for (i, column_name) in conditions.keys().enumerate() {
103                 if *column_name == "SyncType" {
104                     sql.push_str("(SyncType & ?) = ?");
105                 } else {
106                     sql.push_str(column_name);
107                     sql.push_str("=?");
108                 }
109                 if i != conditions.len() - 1 {
110                     sql.push_str(" and ")
111                 }
112             }
113         }
114     }
115 }
116 
/// Appends `len` comma-separated `?` placeholders to `sql` (e.g. `?,?,?` for `len == 3`).
///
/// Appends nothing for `len == 0`.
#[inline(always)]
fn build_sql_values(len: usize, sql: &mut String) {
    let placeholders = vec!["?"; len];
    sql.push_str(&placeholders.join(","));
}
126 
from_data_type_to_str(value: &DataType) -> &'static str127 fn from_data_type_to_str(value: &DataType) -> &'static str {
128     match *value {
129         DataType::Bytes => "BLOB",
130         DataType::Number => "INTEGER",
131         DataType::Bool => "INTEGER",
132     }
133 }
134 
from_data_value_to_str_value(value: &Value) -> String135 fn from_data_value_to_str_value(value: &Value) -> String {
136     match *value {
137         Value::Number(i) => format!("{}", i),
138         Value::Bytes(_) => String::from("NOT SUPPORTED"),
139         Value::Bool(b) => format!("{}", b),
140     }
141 }
142 
build_sql_query_options(query_options: Option<&QueryOptions>, sql: &mut String)143 fn build_sql_query_options(query_options: Option<&QueryOptions>, sql: &mut String) {
144     if let Some(option) = query_options {
145         if let Some(order_by) = &option.order_by {
146             if !order_by.is_empty() {
147                 sql.push_str(" order by ");
148                 build_sql_columns_not_empty(order_by, sql);
149             }
150         }
151         if let Some(order) = option.order {
152             let str = if order == Ordering::Greater {
153                 "ASC"
154             } else if order == Ordering::Less {
155                 "DESC"
156             } else {
157                 ""
158             };
159             sql.push_str(format!(" {}", str).as_str());
160         }
161         if let Some(limit) = option.limit {
162             sql.push_str(format!(" limit {}", limit).as_str());
163             if let Some(offset) = option.offset {
164                 sql.push_str(format!(" offset {}", offset).as_str());
165             }
166         } else if let Some(offset) = option.offset {
167             sql.push_str(format!(" limit -1 offset {}", offset).as_str());
168         }
169     }
170 }
171 
build_sql_reverse_condition(condition: &DbMap, reverse_condition: Option<&DbMap>, sql: &mut String)172 fn build_sql_reverse_condition(condition: &DbMap, reverse_condition: Option<&DbMap>, sql: &mut String) {
173     if let Some(conditions) = reverse_condition {
174         if !conditions.is_empty() {
175             if !condition.is_empty() {
176                 sql.push_str(" and ");
177             } else {
178                 sql.push_str(" where ");
179             }
180             for (i, column_name) in conditions.keys().enumerate() {
181                 if *column_name == "SyncType" {
182                     sql.push_str("(SyncType & ?) == 0");
183                 } else {
184                     sql.push_str(column_name);
185                     sql.push_str("<>?");
186                 }
187                 if i != conditions.len() - 1 {
188                     sql.push_str(" and ")
189                 }
190             }
191         }
192     }
193 }
194 
get_column_info(columns: &'static [ColumnInfo], db_column: &str) -> Result<&'static ColumnInfo>195 fn get_column_info(columns: &'static [ColumnInfo], db_column: &str) -> Result<&'static ColumnInfo> {
196     for column in columns.iter() {
197         if column.name.eq(db_column) {
198             return Ok(column);
199         }
200     }
201     log_throw_error!(ErrCode::DataCorrupted, "Database is corrupted.")
202 }
203 
204 impl<'a> Table<'a> {
new(table_name: &str, db: &'a Database) -> Table<'a>205     pub(crate) fn new(table_name: &str, db: &'a Database) -> Table<'a> {
206         Table { table_name: table_name.to_string(), db }
207     }
208 
exist(&self) -> Result<bool>209     pub(crate) fn exist(&self) -> Result<bool> {
210         let sql = format!("select * from sqlite_master where type ='table' and name = '{}'", self.table_name);
211         let stmt = Statement::prepare(sql.as_str(), self.db)?;
212         let ret = stmt.step()?;
213         if ret == SQLITE_ROW {
214             Ok(true)
215         } else {
216             Ok(false)
217         }
218     }
219 
220     #[allow(dead_code)]
delete(&self) -> Result<()>221     pub(crate) fn delete(&self) -> Result<()> {
222         let sql = format!("DROP TABLE {}", self.table_name);
223         self.db.exec(&sql)
224     }
225 
226     /// Create a table with name 'table_name' at specific version.
227     /// The columns is descriptions for each column.
create_with_version(&self, columns: &[ColumnInfo], version: u32) -> Result<()>228     pub(crate) fn create_with_version(&self, columns: &[ColumnInfo], version: u32) -> Result<()> {
229         let is_exist = self.exist()?;
230         if is_exist {
231             return Ok(());
232         }
233         let mut sql = format!("CREATE TABLE IF NOT EXISTS {}(", self.table_name);
234         for i in 0..columns.len() {
235             let column = &columns[i];
236             sql.push_str(column.name);
237             sql.push(' ');
238             sql.push_str(from_data_type_to_str(&column.data_type));
239             if column.is_primary_key {
240                 sql.push_str(" PRIMARY KEY");
241             }
242             if column.not_null {
243                 sql.push_str(" NOT NULL");
244             }
245             if i != columns.len() - 1 {
246                 sql.push(',')
247             };
248         }
249         sql.push_str(");");
250         let mut trans = Transaction::new(self.db);
251         trans.begin()?;
252         if self.db.exec(sql.as_str()).is_ok() && self.db.set_version(version).is_ok() {
253             trans.commit()
254         } else {
255             trans.rollback()
256         }
257     }
258 
259     /// Create a table with name 'table_name'.
260     /// The columns is descriptions for each column.
create(&self, columns: &[ColumnInfo]) -> Result<()>261     pub(crate) fn create(&self, columns: &[ColumnInfo]) -> Result<()> {
262         self.create_with_version(columns, DB_UPGRADE_VERSION)
263     }
264 
upgrade(&self, ver: u32, columns: &[UpgradeColumnInfo]) -> Result<()>265     pub(crate) fn upgrade(&self, ver: u32, columns: &[UpgradeColumnInfo]) -> Result<()> {
266         let is_exist = self.exist()?;
267         if !is_exist {
268             return Ok(());
269         }
270         logi!("upgrade table!");
271         let mut trans = Transaction::new(self.db);
272         trans.begin()?;
273         for item in columns {
274             if self.add_column(&item.base_info, &item.default_value).is_err() {
275                 return trans.rollback();
276             }
277         }
278         if self.db.set_version(ver).is_err() {
279             trans.rollback()
280         } else {
281             trans.commit()
282         }
283     }
284 
285     /// Insert a row into table, and datas is the value to be insert.
286     ///
287     /// # Examples
288     ///
289     /// ```
290     /// // SQL: insert into table_name(id,alias) values (3,'alias1')
291     /// let datas = &DbMap::from([("id", Value::Number(3), ("alias", Value::Bytes(b"alias1"))]);
292     /// let ret = table.insert_row(datas);
293     /// ```
insert_row(&self, datas: &DbMap) -> Result<i32>294     pub(crate) fn insert_row(&self, datas: &DbMap) -> Result<i32> {
295         let mut sql = format!("insert into {} (", self.table_name);
296         for (i, column_name) in datas.keys().enumerate() {
297             sql.push_str(column_name);
298             if i != datas.len() - 1 {
299                 sql.push(',');
300             }
301         }
302 
303         sql.push_str(") values (");
304         build_sql_values(datas.len(), &mut sql);
305         sql.push(')');
306         let stmt = Statement::prepare(&sql, self.db)?;
307         let mut index = 1;
308         bind_datas(datas, &stmt, &mut index)?;
309         stmt.step()?;
310         let count = unsafe { SqliteChanges(self.db.handle as _) };
311         Ok(count)
312     }
313 
314     /// Delete row from table.
315     ///
316     /// # Examples
317     ///
318     /// ```
319     /// // SQL: delete from table_name where id=2
320     /// let condition = &DbMap::from([("id", Value::Number(2)]);
321     /// let ret = table.delete_row(condition, None, false);
322     /// ```
delete_row( &self, condition: &DbMap, reverse_condition: Option<&DbMap>, is_filter_sync: bool, ) -> Result<i32>323     pub(crate) fn delete_row(
324         &self,
325         condition: &DbMap,
326         reverse_condition: Option<&DbMap>,
327         is_filter_sync: bool,
328     ) -> Result<i32> {
329         let mut sql = format!("delete from {}", self.table_name);
330         build_sql_where(condition, is_filter_sync, &mut sql);
331         build_sql_reverse_condition(condition, reverse_condition, &mut sql);
332         let stmt = Statement::prepare(&sql, self.db)?;
333         let mut index = 1;
334         bind_where_datas(condition, &stmt, &mut index)?;
335         if let Some(datas) = reverse_condition {
336             bind_datas(datas, &stmt, &mut index)?;
337         }
338         stmt.step()?;
339         let count = unsafe { SqliteChanges(self.db.handle as _) };
340         Ok(count)
341     }
342 
343     /// Delete row from table with specific condition.
344     ///
345     /// # Examples
346     ///
347     /// ```
348     /// // SQL: delete from table_name where id=2
349     /// let specific_cond = "id".to_string();
350     /// let condition_value = Value::Number(2);
351     /// let ret = table.delete_with_specific_cond(specific_cond, condition_value);
352     /// ```
delete_with_specific_cond(&self, specific_cond: &str, condition_value: &[Value]) -> Result<i32>353     pub(crate) fn delete_with_specific_cond(&self, specific_cond: &str, condition_value: &[Value]) -> Result<i32> {
354         let sql: String = format!("delete from {} where {}", self.table_name, specific_cond);
355         let stmt = Statement::prepare(&sql, self.db)?;
356         let mut index = 1;
357         bind_where_with_specific_condifion(condition_value, &stmt, &mut index)?;
358         stmt.step()?;
359         let count = unsafe { SqliteChanges(self.db.handle as _) };
360         Ok(count)
361     }
362 
363     /// Update a row in table.
364     ///
365     /// # Examples
366     ///
367     /// ```
368     /// // SQL: update table_name set alias='update_value' where id=2
369     /// let condition = &DbMap::from([("id", Value::Number(2)]);
370     /// let datas = &DbMap::from([("alias", Value::Bytes(b"update_value")]);
371     /// let ret = table.update_row(conditions, false, datas);
372     /// ```
update_row(&self, condition: &DbMap, is_filter_sync: bool, datas: &DbMap) -> Result<i32>373     pub(crate) fn update_row(&self, condition: &DbMap, is_filter_sync: bool, datas: &DbMap) -> Result<i32> {
374         let mut sql = format!("update {} set ", self.table_name);
375         for (i, column_name) in datas.keys().enumerate() {
376             sql.push_str(column_name);
377             sql.push_str("=?");
378             if i != datas.len() - 1 {
379                 sql.push(',');
380             }
381         }
382         build_sql_where(condition, is_filter_sync, &mut sql);
383         let stmt = Statement::prepare(&sql, self.db)?;
384         let mut index = 1;
385         bind_datas(datas, &stmt, &mut index)?;
386         bind_where_datas(condition, &stmt, &mut index)?;
387         stmt.step()?;
388         let count = unsafe { SqliteChanges(self.db.handle as _) };
389         Ok(count)
390     }
391 
392     /// Query row from table.
393     /// If length of columns is 0, all table columns are queried. (eg. select * xxx)
394     /// If length of condition is 0, all data in the table is queried.
395     ///
396     /// # Examples
397     ///
398     /// ```
399     /// // SQL: select alias,blobs from table_name
400     /// let result_set = table.query_datas_with_key_value(&vec!["alias", "blobs"], false, &vec![]);
401     /// ```
query_row( &self, columns: &Vec<&'static str>, condition: &DbMap, query_options: Option<&QueryOptions>, is_filter_sync: bool, column_info: &'static [ColumnInfo], ) -> Result<Vec<DbMap>>402     pub(crate) fn query_row(
403         &self,
404         columns: &Vec<&'static str>,
405         condition: &DbMap,
406         query_options: Option<&QueryOptions>,
407         is_filter_sync: bool,
408         column_info: &'static [ColumnInfo],
409     ) -> Result<Vec<DbMap>> {
410         let mut sql = String::from("select ");
411         if !columns.is_empty() {
412             sql.push_str("distinct ");
413         }
414         build_sql_columns(columns, &mut sql);
415         sql.push_str(" from ");
416         sql.push_str(self.table_name.as_str());
417         build_sql_where(condition, is_filter_sync, &mut sql);
418         build_sql_query_options(query_options, &mut sql);
419         let stmt = Statement::prepare(&sql, self.db)?;
420         let mut index = 1;
421         bind_where_datas(condition, &stmt, &mut index)?;
422         let mut result = vec![];
423         while stmt.step()? == SQLITE_ROW {
424             let mut record = DbMap::new();
425             let n = stmt.data_count();
426             for i in 0..n {
427                 let column_name = stmt.query_column_name(i)?;
428                 let column_info = get_column_info(column_info, column_name)?;
429                 match stmt.query_column_auto_type(i)? {
430                     Some(Value::Number(n)) if column_info.data_type == DataType::Bool => {
431                         record.insert(column_info.name, Value::Bool(n != 0))
432                     },
433                     Some(n) if n.data_type() == column_info.data_type => record.insert(column_info.name, n),
434                     Some(_) => {
435                         return log_throw_error!(ErrCode::DataCorrupted, "The data in DB has been tampered with.")
436                     },
437                     None => continue,
438                 };
439             }
440             result.push(record);
441         }
442         Ok(result)
443     }
444 
445     /// Count the number of datas with query condition(can be empty).
446     ///
447     /// # Examples
448     ///
449     /// ```
450     /// // SQL: select count(*) as count from table_name where id=3
451     /// let count = table.count_datas(&DbMap::from([("id", Value::Number(3))]), false);
452     /// ```
count_datas(&self, condition: &DbMap, is_filter_sync: bool) -> Result<u32>453     pub(crate) fn count_datas(&self, condition: &DbMap, is_filter_sync: bool) -> Result<u32> {
454         let mut sql = format!("select count(*) as count from {}", self.table_name);
455         build_sql_where(condition, is_filter_sync, &mut sql);
456         let stmt = Statement::prepare(&sql, self.db)?;
457         let mut index = 1;
458         bind_where_datas(condition, &stmt, &mut index)?;
459         stmt.step()?;
460         let count = stmt.query_column_int(0);
461         Ok(count)
462     }
463 
464     /// Check whether data exists in the database table.
465     ///
466     /// # Examples
467     ///
468     /// ```
469     /// // SQL: select count(*) as count from table_name where id=3 and alias='alias'
470     /// let exits = table
471     ///     .is_data_exists(&DbMap::from([("id", Value::Number(3)), ("alias", Value::Bytes(b"alias"))]), false);
472     /// ```
is_data_exists(&self, cond: &DbMap, is_filter_sync: bool) -> Result<bool>473     pub(crate) fn is_data_exists(&self, cond: &DbMap, is_filter_sync: bool) -> Result<bool> {
474         let ret = self.count_datas(cond, is_filter_sync);
475         match ret {
476             Ok(count) => Ok(count > 0),
477             Err(e) => Err(e),
478         }
479     }
480 
481     /// Add new column tp table.
482     /// 1. Primary key cannot be added.
483     /// 2. Cannot add a non-null column with no default value
484     /// 3. Only the integer and blob types support the default value, and the default value of the blob type is null.
485     ///
486     /// # Examples
487     ///
488     /// ```
489     /// // SQL: alter table table_name add cloumn id integer not null
490     /// let ret = table.add_column(
491     ///     ColumnInfo {
492     ///         name: "id",
493     ///         data_type: DataType::INTEGER,
494     ///         is_primary_key: false,
495     ///         not_null: true,
496     ///     },
497     ///     Some(Value::Number(0)),
498     /// );
499     /// ```
add_column(&self, column: &ColumnInfo, default_value: &Option<Value>) -> Result<()>500     pub(crate) fn add_column(&self, column: &ColumnInfo, default_value: &Option<Value>) -> Result<()> {
501         if column.is_primary_key {
502             return log_throw_error!(ErrCode::InvalidArgument, "The primary key already exists in the table.");
503         }
504         if column.not_null && default_value.is_none() {
505             return log_throw_error!(ErrCode::InvalidArgument, "A default value is required for a non-null column.");
506         }
507         let data_type = from_data_type_to_str(&column.data_type);
508         let mut sql = format!("ALTER TABLE {} ADD COLUMN {} {}", self.table_name, column.name, data_type);
509         if let Some(data) = default_value {
510             sql.push_str(" DEFAULT ");
511             sql.push_str(&from_data_value_to_str_value(data));
512         }
513         if column.not_null {
514             sql.push_str(" NOT NULL");
515         }
516         self.db.exec(sql.as_str())
517     }
518 
replace_row(&self, condition: &DbMap, is_filter_sync: bool, datas: &DbMap) -> Result<()>519     pub(crate) fn replace_row(&self, condition: &DbMap, is_filter_sync: bool, datas: &DbMap) -> Result<()> {
520         let mut trans = Transaction::new(self.db);
521         trans.begin()?;
522         if self.delete_row(condition, None, is_filter_sync).is_ok() && self.insert_row(datas).is_ok() {
523             trans.commit()
524         } else {
525             trans.rollback()
526         }
527     }
528 }
529