1 use std::fmt::Debug; 2 use std::os::raw::{c_char, c_int, c_void}; 3 use std::panic::catch_unwind; 4 use std::ptr; 5 6 use super::expect_utf8; 7 use super::free_boxed_hook; 8 use super::Action; 9 use crate::error::check; 10 use crate::ffi; 11 use crate::inner_connection::InnerConnection; 12 use crate::types::ValueRef; 13 use crate::Connection; 14 use crate::Result; 15 16 /// The possible cases for when a PreUpdateHook gets triggered. Allows access to the relevant 17 /// functions for each case through the contained values. 18 #[derive(Debug)] 19 pub enum PreUpdateCase { 20 /// Pre-update hook was triggered by an insert. 21 Insert(PreUpdateNewValueAccessor), 22 /// Pre-update hook was triggered by a delete. 23 Delete(PreUpdateOldValueAccessor), 24 /// Pre-update hook was triggered by an update. 25 Update { 26 #[allow(missing_docs)] 27 old_value_accessor: PreUpdateOldValueAccessor, 28 #[allow(missing_docs)] 29 new_value_accessor: PreUpdateNewValueAccessor, 30 }, 31 /// This variant is not normally produced by SQLite. You may encounter it 32 /// if you're using a different version than what's supported by this library. 33 Unknown, 34 } 35 36 impl From<PreUpdateCase> for Action { from(puc: PreUpdateCase) -> Action37 fn from(puc: PreUpdateCase) -> Action { 38 match puc { 39 PreUpdateCase::Insert(_) => Action::SQLITE_INSERT, 40 PreUpdateCase::Delete(_) => Action::SQLITE_DELETE, 41 PreUpdateCase::Update { .. } => Action::SQLITE_UPDATE, 42 PreUpdateCase::Unknown => Action::UNKNOWN, 43 } 44 } 45 } 46 47 /// An accessor to access the old values of the row being deleted/updated during the preupdate callback. 48 #[derive(Debug)] 49 pub struct PreUpdateOldValueAccessor { 50 db: *mut ffi::sqlite3, 51 old_row_id: i64, 52 } 53 54 impl PreUpdateOldValueAccessor { 55 /// Get the amount of columns in the row being deleted/updated. get_column_count(&self) -> i3256 pub fn get_column_count(&self) -> i32 { 57 unsafe { ffi::sqlite3_preupdate_count(self.db) } 58 } 59 60 /// Get the depth of the query that triggered the preupdate hook. 61 /// Returns 0 if the preupdate callback was invoked as a result of 62 /// a direct insert, update, or delete operation; 63 /// 1 for inserts, updates, or deletes invoked by top-level triggers; 64 /// 2 for changes resulting from triggers called by top-level triggers; and so forth. get_query_depth(&self) -> i3265 pub fn get_query_depth(&self) -> i32 { 66 unsafe { ffi::sqlite3_preupdate_depth(self.db) } 67 } 68 69 /// Get the row id of the row being updated/deleted. get_old_row_id(&self) -> i6470 pub fn get_old_row_id(&self) -> i64 { 71 self.old_row_id 72 } 73 74 /// Get the value of the row being updated/deleted at the specified index. get_old_column_value(&self, i: i32) -> Result<ValueRef>75 pub fn get_old_column_value(&self, i: i32) -> Result<ValueRef> { 76 let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); 77 unsafe { 78 check(ffi::sqlite3_preupdate_old(self.db, i, &mut p_value))?; 79 Ok(ValueRef::from_value(p_value)) 80 } 81 } 82 } 83 84 /// An accessor to access the new values of the row being inserted/updated 85 /// during the preupdate callback. 86 #[derive(Debug)] 87 pub struct PreUpdateNewValueAccessor { 88 db: *mut ffi::sqlite3, 89 new_row_id: i64, 90 } 91 92 impl PreUpdateNewValueAccessor { 93 /// Get the amount of columns in the row being inserted/updated. get_column_count(&self) -> i3294 pub fn get_column_count(&self) -> i32 { 95 unsafe { ffi::sqlite3_preupdate_count(self.db) } 96 } 97 98 /// Get the depth of the query that triggered the preupdate hook. 99 /// Returns 0 if the preupdate callback was invoked as a result of 100 /// a direct insert, update, or delete operation; 101 /// 1 for inserts, updates, or deletes invoked by top-level triggers; 102 /// 2 for changes resulting from triggers called by top-level triggers; and so forth. get_query_depth(&self) -> i32103 pub fn get_query_depth(&self) -> i32 { 104 unsafe { ffi::sqlite3_preupdate_depth(self.db) } 105 } 106 107 /// Get the row id of the row being inserted/updated. get_new_row_id(&self) -> i64108 pub fn get_new_row_id(&self) -> i64 { 109 self.new_row_id 110 } 111 112 /// Get the value of the row being updated/deleted at the specified index. get_new_column_value(&self, i: i32) -> Result<ValueRef>113 pub fn get_new_column_value(&self, i: i32) -> Result<ValueRef> { 114 let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); 115 unsafe { 116 check(ffi::sqlite3_preupdate_new(self.db, i, &mut p_value))?; 117 Ok(ValueRef::from_value(p_value)) 118 } 119 } 120 } 121 122 impl Connection { 123 /// Register a callback function to be invoked before 124 /// a row is updated, inserted or deleted. 125 /// 126 /// The callback parameters are: 127 /// 128 /// - the name of the database ("main", "temp", ...), 129 /// - the name of the table that is updated, 130 /// - a variant of the PreUpdateCase enum which allows access to extra functions depending 131 /// on whether it's an update, delete or insert. 132 #[inline] preupdate_hook<F>(&self, hook: Option<F>) where F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static,133 pub fn preupdate_hook<F>(&self, hook: Option<F>) 134 where 135 F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static, 136 { 137 self.db.borrow_mut().preupdate_hook(hook); 138 } 139 } 140 141 impl InnerConnection { 142 #[inline] remove_preupdate_hook(&mut self)143 pub fn remove_preupdate_hook(&mut self) { 144 self.preupdate_hook(None::<fn(Action, &str, &str, &PreUpdateCase)>); 145 } 146 147 /// ```compile_fail 148 /// use rusqlite::{Connection, Result, hooks::PreUpdateCase}; 149 /// fn main() -> Result<()> { 150 /// let db = Connection::open_in_memory()?; 151 /// { 152 /// let mut called = std::sync::atomic::AtomicBool::new(false); 153 /// db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| { 154 /// called.store(true, std::sync::atomic::Ordering::Relaxed); 155 /// })); 156 /// } 157 /// db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;") 158 /// } 159 /// ``` preupdate_hook<F>(&mut self, hook: Option<F>) where F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static,160 fn preupdate_hook<F>(&mut self, hook: Option<F>) 161 where 162 F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static, 163 { 164 unsafe extern "C" fn call_boxed_closure<F>( 165 p_arg: *mut c_void, 166 sqlite: *mut ffi::sqlite3, 167 action_code: c_int, 168 db_name: *const c_char, 169 tbl_name: *const c_char, 170 old_row_id: i64, 171 new_row_id: i64, 172 ) where 173 F: FnMut(Action, &str, &str, &PreUpdateCase), 174 { 175 let action = Action::from(action_code); 176 177 let preupdate_case = match action { 178 Action::SQLITE_INSERT => PreUpdateCase::Insert(PreUpdateNewValueAccessor { 179 db: sqlite, 180 new_row_id, 181 }), 182 Action::SQLITE_DELETE => PreUpdateCase::Delete(PreUpdateOldValueAccessor { 183 db: sqlite, 184 old_row_id, 185 }), 186 Action::SQLITE_UPDATE => PreUpdateCase::Update { 187 old_value_accessor: PreUpdateOldValueAccessor { 188 db: sqlite, 189 old_row_id, 190 }, 191 new_value_accessor: PreUpdateNewValueAccessor { 192 db: sqlite, 193 new_row_id, 194 }, 195 }, 196 Action::UNKNOWN => PreUpdateCase::Unknown, 197 }; 198 199 drop(catch_unwind(|| { 200 let boxed_hook: *mut F = p_arg.cast::<F>(); 201 (*boxed_hook)( 202 action, 203 expect_utf8(db_name, "database name"), 204 expect_utf8(tbl_name, "table name"), 205 &preupdate_case, 206 ); 207 })); 208 } 209 210 let free_preupdate_hook = if hook.is_some() { 211 Some(free_boxed_hook::<F> as unsafe fn(*mut c_void)) 212 } else { 213 None 214 }; 215 216 let previous_hook = match hook { 217 Some(hook) => { 218 let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); 219 unsafe { 220 ffi::sqlite3_preupdate_hook( 221 self.db(), 222 Some(call_boxed_closure::<F>), 223 boxed_hook.cast(), 224 ) 225 } 226 } 227 _ => unsafe { ffi::sqlite3_preupdate_hook(self.db(), None, ptr::null_mut()) }, 228 }; 229 if !previous_hook.is_null() { 230 if let Some(free_boxed_hook) = self.free_preupdate_hook { 231 unsafe { free_boxed_hook(previous_hook) }; 232 } 233 } 234 self.free_preupdate_hook = free_preupdate_hook; 235 } 236 } 237 238 #[cfg(test)] 239 mod test { 240 use std::sync::atomic::{AtomicBool, Ordering}; 241 242 use super::super::Action; 243 use super::PreUpdateCase; 244 use crate::{Connection, Result}; 245 246 #[test] test_preupdate_hook_insert() -> Result<()>247 fn test_preupdate_hook_insert() -> Result<()> { 248 let db = Connection::open_in_memory()?; 249 250 static CALLED: AtomicBool = AtomicBool::new(false); 251 252 db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| { 253 assert_eq!(Action::SQLITE_INSERT, action); 254 assert_eq!("main", db); 255 assert_eq!("foo", tbl); 256 match case { 257 PreUpdateCase::Insert(accessor) => { 258 assert_eq!(1, accessor.get_column_count()); 259 assert_eq!(1, accessor.get_new_row_id()); 260 assert_eq!(0, accessor.get_query_depth()); 261 // out of bounds access should return an error 262 assert!(accessor.get_new_column_value(1).is_err()); 263 assert_eq!( 264 "lisa", 265 accessor.get_new_column_value(0).unwrap().as_str().unwrap() 266 ); 267 assert_eq!(0, accessor.get_query_depth()); 268 } 269 _ => panic!("wrong preupdate case"), 270 } 271 CALLED.store(true, Ordering::Relaxed); 272 })); 273 db.execute_batch("CREATE TABLE foo (t TEXT)")?; 274 db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; 275 assert!(CALLED.load(Ordering::Relaxed)); 276 Ok(()) 277 } 278 279 #[test] test_preupdate_hook_delete() -> Result<()>280 fn test_preupdate_hook_delete() -> Result<()> { 281 let db = Connection::open_in_memory()?; 282 283 static CALLED: AtomicBool = AtomicBool::new(false); 284 285 db.execute_batch("CREATE TABLE foo (t TEXT)")?; 286 db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; 287 288 db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| { 289 assert_eq!(Action::SQLITE_DELETE, action); 290 assert_eq!("main", db); 291 assert_eq!("foo", tbl); 292 match case { 293 PreUpdateCase::Delete(accessor) => { 294 assert_eq!(1, accessor.get_column_count()); 295 assert_eq!(1, accessor.get_old_row_id()); 296 assert_eq!(0, accessor.get_query_depth()); 297 // out of bounds access should return an error 298 assert!(accessor.get_old_column_value(1).is_err()); 299 assert_eq!( 300 "lisa", 301 accessor.get_old_column_value(0).unwrap().as_str().unwrap() 302 ); 303 assert_eq!(0, accessor.get_query_depth()); 304 } 305 _ => panic!("wrong preupdate case"), 306 } 307 CALLED.store(true, Ordering::Relaxed); 308 })); 309 310 db.execute_batch("DELETE from foo")?; 311 assert!(CALLED.load(Ordering::Relaxed)); 312 Ok(()) 313 } 314 315 #[test] test_preupdate_hook_update() -> Result<()>316 fn test_preupdate_hook_update() -> Result<()> { 317 let db = Connection::open_in_memory()?; 318 319 static CALLED: AtomicBool = AtomicBool::new(false); 320 321 db.execute_batch("CREATE TABLE foo (t TEXT)")?; 322 db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; 323 324 db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| { 325 assert_eq!(Action::SQLITE_UPDATE, action); 326 assert_eq!("main", db); 327 assert_eq!("foo", tbl); 328 match case { 329 PreUpdateCase::Update { 330 old_value_accessor, 331 new_value_accessor, 332 } => { 333 assert_eq!(1, old_value_accessor.get_column_count()); 334 assert_eq!(1, old_value_accessor.get_old_row_id()); 335 assert_eq!(0, old_value_accessor.get_query_depth()); 336 // out of bounds access should return an error 337 assert!(old_value_accessor.get_old_column_value(1).is_err()); 338 assert_eq!( 339 "lisa", 340 old_value_accessor 341 .get_old_column_value(0) 342 .unwrap() 343 .as_str() 344 .unwrap() 345 ); 346 assert_eq!(0, old_value_accessor.get_query_depth()); 347 348 assert_eq!(1, new_value_accessor.get_column_count()); 349 assert_eq!(1, new_value_accessor.get_new_row_id()); 350 assert_eq!(0, new_value_accessor.get_query_depth()); 351 // out of bounds access should return an error 352 assert!(new_value_accessor.get_new_column_value(1).is_err()); 353 assert_eq!( 354 "janice", 355 new_value_accessor 356 .get_new_column_value(0) 357 .unwrap() 358 .as_str() 359 .unwrap() 360 ); 361 assert_eq!(0, new_value_accessor.get_query_depth()); 362 } 363 _ => panic!("wrong preupdate case"), 364 } 365 CALLED.store(true, Ordering::Relaxed); 366 })); 367 368 db.execute_batch("UPDATE foo SET t = 'janice'")?; 369 assert!(CALLED.load(Ordering::Relaxed)); 370 Ok(()) 371 } 372 } 373