• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::fmt::Debug;
2 use std::os::raw::{c_char, c_int, c_void};
3 use std::panic::catch_unwind;
4 use std::ptr;
5 
6 use super::expect_utf8;
7 use super::free_boxed_hook;
8 use super::Action;
9 use crate::error::check;
10 use crate::ffi;
11 use crate::inner_connection::InnerConnection;
12 use crate::types::ValueRef;
13 use crate::Connection;
14 use crate::Result;
15 
16 /// The possible cases for when a PreUpdateHook gets triggered. Allows access to the relevant
17 /// functions for each case through the contained values.
18 #[derive(Debug)]
19 pub enum PreUpdateCase {
20     /// Pre-update hook was triggered by an insert.
21     Insert(PreUpdateNewValueAccessor),
22     /// Pre-update hook was triggered by a delete.
23     Delete(PreUpdateOldValueAccessor),
24     /// Pre-update hook was triggered by an update.
25     Update {
26         #[allow(missing_docs)]
27         old_value_accessor: PreUpdateOldValueAccessor,
28         #[allow(missing_docs)]
29         new_value_accessor: PreUpdateNewValueAccessor,
30     },
31     /// This variant is not normally produced by SQLite. You may encounter it
32     /// if you're using a different version than what's supported by this library.
33     Unknown,
34 }
35 
36 impl From<PreUpdateCase> for Action {
from(puc: PreUpdateCase) -> Action37     fn from(puc: PreUpdateCase) -> Action {
38         match puc {
39             PreUpdateCase::Insert(_) => Action::SQLITE_INSERT,
40             PreUpdateCase::Delete(_) => Action::SQLITE_DELETE,
41             PreUpdateCase::Update { .. } => Action::SQLITE_UPDATE,
42             PreUpdateCase::Unknown => Action::UNKNOWN,
43         }
44     }
45 }
46 
47 /// An accessor to access the old values of the row being deleted/updated during the preupdate callback.
48 #[derive(Debug)]
49 pub struct PreUpdateOldValueAccessor {
50     db: *mut ffi::sqlite3,
51     old_row_id: i64,
52 }
53 
54 impl PreUpdateOldValueAccessor {
55     /// Get the amount of columns in the row being deleted/updated.
get_column_count(&self) -> i3256     pub fn get_column_count(&self) -> i32 {
57         unsafe { ffi::sqlite3_preupdate_count(self.db) }
58     }
59 
60     /// Get the depth of the query that triggered the preupdate hook.
61     /// Returns 0 if the preupdate callback was invoked as a result of
62     /// a direct insert, update, or delete operation;
63     /// 1 for inserts, updates, or deletes invoked by top-level triggers;
64     /// 2 for changes resulting from triggers called by top-level triggers; and so forth.
get_query_depth(&self) -> i3265     pub fn get_query_depth(&self) -> i32 {
66         unsafe { ffi::sqlite3_preupdate_depth(self.db) }
67     }
68 
69     /// Get the row id of the row being updated/deleted.
get_old_row_id(&self) -> i6470     pub fn get_old_row_id(&self) -> i64 {
71         self.old_row_id
72     }
73 
74     /// Get the value of the row being updated/deleted at the specified index.
get_old_column_value(&self, i: i32) -> Result<ValueRef>75     pub fn get_old_column_value(&self, i: i32) -> Result<ValueRef> {
76         let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut();
77         unsafe {
78             check(ffi::sqlite3_preupdate_old(self.db, i, &mut p_value))?;
79             Ok(ValueRef::from_value(p_value))
80         }
81     }
82 }
83 
84 /// An accessor to access the new values of the row being inserted/updated
85 /// during the preupdate callback.
86 #[derive(Debug)]
87 pub struct PreUpdateNewValueAccessor {
88     db: *mut ffi::sqlite3,
89     new_row_id: i64,
90 }
91 
92 impl PreUpdateNewValueAccessor {
93     /// Get the amount of columns in the row being inserted/updated.
get_column_count(&self) -> i3294     pub fn get_column_count(&self) -> i32 {
95         unsafe { ffi::sqlite3_preupdate_count(self.db) }
96     }
97 
98     /// Get the depth of the query that triggered the preupdate hook.
99     /// Returns 0 if the preupdate callback was invoked as a result of
100     /// a direct insert, update, or delete operation;
101     /// 1 for inserts, updates, or deletes invoked by top-level triggers;
102     /// 2 for changes resulting from triggers called by top-level triggers; and so forth.
get_query_depth(&self) -> i32103     pub fn get_query_depth(&self) -> i32 {
104         unsafe { ffi::sqlite3_preupdate_depth(self.db) }
105     }
106 
107     /// Get the row id of the row being inserted/updated.
get_new_row_id(&self) -> i64108     pub fn get_new_row_id(&self) -> i64 {
109         self.new_row_id
110     }
111 
112     /// Get the value of the row being updated/deleted at the specified index.
get_new_column_value(&self, i: i32) -> Result<ValueRef>113     pub fn get_new_column_value(&self, i: i32) -> Result<ValueRef> {
114         let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut();
115         unsafe {
116             check(ffi::sqlite3_preupdate_new(self.db, i, &mut p_value))?;
117             Ok(ValueRef::from_value(p_value))
118         }
119     }
120 }
121 
122 impl Connection {
123     /// Register a callback function to be invoked before
124     /// a row is updated, inserted or deleted.
125     ///
126     /// The callback parameters are:
127     ///
128     /// - the name of the database ("main", "temp", ...),
129     /// - the name of the table that is updated,
130     /// - a variant of the PreUpdateCase enum which allows access to extra functions depending
131     ///   on whether it's an update, delete or insert.
132     #[inline]
preupdate_hook<F>(&self, hook: Option<F>) where F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static,133     pub fn preupdate_hook<F>(&self, hook: Option<F>)
134     where
135         F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static,
136     {
137         self.db.borrow_mut().preupdate_hook(hook);
138     }
139 }
140 
141 impl InnerConnection {
142     #[inline]
remove_preupdate_hook(&mut self)143     pub fn remove_preupdate_hook(&mut self) {
144         self.preupdate_hook(None::<fn(Action, &str, &str, &PreUpdateCase)>);
145     }
146 
147     /// ```compile_fail
148     /// use rusqlite::{Connection, Result, hooks::PreUpdateCase};
149     /// fn main() -> Result<()> {
150     ///     let db = Connection::open_in_memory()?;
151     ///     {
152     ///         let mut called = std::sync::atomic::AtomicBool::new(false);
153     ///         db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
154     ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
155     ///         }));
156     ///     }
157     ///     db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;")
158     /// }
159     /// ```
preupdate_hook<F>(&mut self, hook: Option<F>) where F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static,160     fn preupdate_hook<F>(&mut self, hook: Option<F>)
161     where
162         F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static,
163     {
164         unsafe extern "C" fn call_boxed_closure<F>(
165             p_arg: *mut c_void,
166             sqlite: *mut ffi::sqlite3,
167             action_code: c_int,
168             db_name: *const c_char,
169             tbl_name: *const c_char,
170             old_row_id: i64,
171             new_row_id: i64,
172         ) where
173             F: FnMut(Action, &str, &str, &PreUpdateCase),
174         {
175             let action = Action::from(action_code);
176 
177             let preupdate_case = match action {
178                 Action::SQLITE_INSERT => PreUpdateCase::Insert(PreUpdateNewValueAccessor {
179                     db: sqlite,
180                     new_row_id,
181                 }),
182                 Action::SQLITE_DELETE => PreUpdateCase::Delete(PreUpdateOldValueAccessor {
183                     db: sqlite,
184                     old_row_id,
185                 }),
186                 Action::SQLITE_UPDATE => PreUpdateCase::Update {
187                     old_value_accessor: PreUpdateOldValueAccessor {
188                         db: sqlite,
189                         old_row_id,
190                     },
191                     new_value_accessor: PreUpdateNewValueAccessor {
192                         db: sqlite,
193                         new_row_id,
194                     },
195                 },
196                 Action::UNKNOWN => PreUpdateCase::Unknown,
197             };
198 
199             drop(catch_unwind(|| {
200                 let boxed_hook: *mut F = p_arg.cast::<F>();
201                 (*boxed_hook)(
202                     action,
203                     expect_utf8(db_name, "database name"),
204                     expect_utf8(tbl_name, "table name"),
205                     &preupdate_case,
206                 );
207             }));
208         }
209 
210         let free_preupdate_hook = if hook.is_some() {
211             Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
212         } else {
213             None
214         };
215 
216         let previous_hook = match hook {
217             Some(hook) => {
218                 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
219                 unsafe {
220                     ffi::sqlite3_preupdate_hook(
221                         self.db(),
222                         Some(call_boxed_closure::<F>),
223                         boxed_hook.cast(),
224                     )
225                 }
226             }
227             _ => unsafe { ffi::sqlite3_preupdate_hook(self.db(), None, ptr::null_mut()) },
228         };
229         if !previous_hook.is_null() {
230             if let Some(free_boxed_hook) = self.free_preupdate_hook {
231                 unsafe { free_boxed_hook(previous_hook) };
232             }
233         }
234         self.free_preupdate_hook = free_preupdate_hook;
235     }
236 }
237 
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    use super::super::Action;
    use super::PreUpdateCase;
    use crate::{Connection, Result};

    // NOTE: panics inside the hook are caught and discarded by the trampoline,
    // so a failing assertion in a callback does not fail the test directly.
    // Each test therefore records success in a static that is only set after
    // all in-callback assertions have passed, and checks it afterwards.

    #[test]
    fn test_preupdate_hook_insert() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);

        db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            match case {
                PreUpdateCase::Insert(accessor) => {
                    assert_eq!(1, accessor.get_column_count());
                    assert_eq!(1, accessor.get_new_row_id());
                    assert_eq!(0, accessor.get_query_depth());
                    // out of bounds access should return an error
                    assert!(accessor.get_new_column_value(1).is_err());
                    assert_eq!(
                        "lisa",
                        accessor.get_new_column_value(0).unwrap().as_str().unwrap()
                    );
                }
                _ => panic!("wrong preupdate case"),
            }
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_preupdate_hook_delete() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);

        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;

        db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
            assert_eq!(Action::SQLITE_DELETE, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            match case {
                PreUpdateCase::Delete(accessor) => {
                    assert_eq!(1, accessor.get_column_count());
                    assert_eq!(1, accessor.get_old_row_id());
                    assert_eq!(0, accessor.get_query_depth());
                    // out of bounds access should return an error
                    assert!(accessor.get_old_column_value(1).is_err());
                    assert_eq!(
                        "lisa",
                        accessor.get_old_column_value(0).unwrap().as_str().unwrap()
                    );
                }
                _ => panic!("wrong preupdate case"),
            }
            CALLED.store(true, Ordering::Relaxed);
        }));

        db.execute_batch("DELETE from foo")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_preupdate_hook_update() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);

        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;

        db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
            assert_eq!(Action::SQLITE_UPDATE, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            match case {
                PreUpdateCase::Update {
                    old_value_accessor,
                    new_value_accessor,
                } => {
                    assert_eq!(1, old_value_accessor.get_column_count());
                    assert_eq!(1, old_value_accessor.get_old_row_id());
                    assert_eq!(0, old_value_accessor.get_query_depth());
                    // out of bounds access should return an error
                    assert!(old_value_accessor.get_old_column_value(1).is_err());
                    assert_eq!(
                        "lisa",
                        old_value_accessor
                            .get_old_column_value(0)
                            .unwrap()
                            .as_str()
                            .unwrap()
                    );

                    assert_eq!(1, new_value_accessor.get_column_count());
                    assert_eq!(1, new_value_accessor.get_new_row_id());
                    assert_eq!(0, new_value_accessor.get_query_depth());
                    // out of bounds access should return an error
                    assert!(new_value_accessor.get_new_column_value(1).is_err());
                    assert_eq!(
                        "janice",
                        new_value_accessor
                            .get_new_column_value(0)
                            .unwrap()
                            .as_str()
                            .unwrap()
                    );
                }
                _ => panic!("wrong preupdate case"),
            }
            CALLED.store(true, Ordering::Relaxed);
        }));

        db.execute_batch("UPDATE foo SET t = 'janice'")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_preupdate_hook_replace_and_remove() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static FIRST: AtomicUsize = AtomicUsize::new(0);
        static SECOND: AtomicUsize = AtomicUsize::new(0);

        db.execute_batch("CREATE TABLE foo (t TEXT)")?;

        db.preupdate_hook(Some(|_: Action, _: &str, _: &str, _: &PreUpdateCase| {
            FIRST.fetch_add(1, Ordering::Relaxed);
        }));
        // Registering a second hook must release the first closure and route
        // subsequent changes to the second one only.
        db.preupdate_hook(Some(|_: Action, _: &str, _: &str, _: &PreUpdateCase| {
            SECOND.fetch_add(1, Ordering::Relaxed);
        }));
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert_eq!(0, FIRST.load(Ordering::Relaxed));
        assert_eq!(1, SECOND.load(Ordering::Relaxed));

        // Passing `None` unregisters the hook entirely.
        db.preupdate_hook(None::<fn(Action, &str, &str, &PreUpdateCase)>);
        db.execute_batch("INSERT INTO foo VALUES ('janice')")?;
        assert_eq!(0, FIRST.load(Ordering::Relaxed));
        assert_eq!(1, SECOND.load(Ordering::Relaxed));
        Ok(())
    }
}
373