• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Commit, Data Change and Rollback Notification Callbacks
2 #![allow(non_camel_case_types)]
3 
4 use std::os::raw::{c_char, c_int, c_void};
5 use std::panic::catch_unwind;
6 use std::ptr;
7 
8 use crate::ffi;
9 
10 use crate::{Connection, InnerConnection};
11 
12 #[cfg(feature = "preupdate_hook")]
13 pub use preupdate_hook::*;
14 
15 #[cfg(feature = "preupdate_hook")]
16 mod preupdate_hook;
17 
18 /// Action Codes
19 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
20 #[repr(i32)]
21 #[non_exhaustive]
22 #[allow(clippy::upper_case_acronyms)]
23 pub enum Action {
24     /// Unsupported / unexpected action
25     UNKNOWN = -1,
26     /// DELETE command
27     SQLITE_DELETE = ffi::SQLITE_DELETE,
28     /// INSERT command
29     SQLITE_INSERT = ffi::SQLITE_INSERT,
30     /// UPDATE command
31     SQLITE_UPDATE = ffi::SQLITE_UPDATE,
32 }
33 
34 impl From<i32> for Action {
35     #[inline]
from(code: i32) -> Action36     fn from(code: i32) -> Action {
37         match code {
38             ffi::SQLITE_DELETE => Action::SQLITE_DELETE,
39             ffi::SQLITE_INSERT => Action::SQLITE_INSERT,
40             ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE,
41             _ => Action::UNKNOWN,
42         }
43     }
44 }
45 
46 /// The context received by an authorizer hook.
47 ///
48 /// See <https://sqlite.org/c3ref/set_authorizer.html> for more info.
49 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
50 pub struct AuthContext<'c> {
51     /// The action to be authorized.
52     pub action: AuthAction<'c>,
53 
54     /// The database name, if applicable.
55     pub database_name: Option<&'c str>,
56 
57     /// The inner-most trigger or view responsible for the access attempt.
58     /// `None` if the access attempt was made by top-level SQL code.
59     pub accessor: Option<&'c str>,
60 }
61 
62 /// Actions and arguments found within a statement during
63 /// preparation.
64 ///
65 /// See <https://sqlite.org/c3ref/c_alter_table.html> for more info.
66 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
67 #[non_exhaustive]
68 #[allow(missing_docs)]
69 pub enum AuthAction<'c> {
70     /// This variant is not normally produced by SQLite. You may encounter it
71     // if you're using a different version than what's supported by this library.
72     Unknown {
73         /// The unknown authorization action code.
74         code: i32,
75         /// The third arg to the authorizer callback.
76         arg1: Option<&'c str>,
77         /// The fourth arg to the authorizer callback.
78         arg2: Option<&'c str>,
79     },
80     CreateIndex {
81         index_name: &'c str,
82         table_name: &'c str,
83     },
84     CreateTable {
85         table_name: &'c str,
86     },
87     CreateTempIndex {
88         index_name: &'c str,
89         table_name: &'c str,
90     },
91     CreateTempTable {
92         table_name: &'c str,
93     },
94     CreateTempTrigger {
95         trigger_name: &'c str,
96         table_name: &'c str,
97     },
98     CreateTempView {
99         view_name: &'c str,
100     },
101     CreateTrigger {
102         trigger_name: &'c str,
103         table_name: &'c str,
104     },
105     CreateView {
106         view_name: &'c str,
107     },
108     Delete {
109         table_name: &'c str,
110     },
111     DropIndex {
112         index_name: &'c str,
113         table_name: &'c str,
114     },
115     DropTable {
116         table_name: &'c str,
117     },
118     DropTempIndex {
119         index_name: &'c str,
120         table_name: &'c str,
121     },
122     DropTempTable {
123         table_name: &'c str,
124     },
125     DropTempTrigger {
126         trigger_name: &'c str,
127         table_name: &'c str,
128     },
129     DropTempView {
130         view_name: &'c str,
131     },
132     DropTrigger {
133         trigger_name: &'c str,
134         table_name: &'c str,
135     },
136     DropView {
137         view_name: &'c str,
138     },
139     Insert {
140         table_name: &'c str,
141     },
142     Pragma {
143         pragma_name: &'c str,
144         /// The pragma value, if present (e.g., `PRAGMA name = value;`).
145         pragma_value: Option<&'c str>,
146     },
147     Read {
148         table_name: &'c str,
149         column_name: &'c str,
150     },
151     Select,
152     Transaction {
153         operation: TransactionOperation,
154     },
155     Update {
156         table_name: &'c str,
157         column_name: &'c str,
158     },
159     Attach {
160         filename: &'c str,
161     },
162     Detach {
163         database_name: &'c str,
164     },
165     AlterTable {
166         database_name: &'c str,
167         table_name: &'c str,
168     },
169     Reindex {
170         index_name: &'c str,
171     },
172     Analyze {
173         table_name: &'c str,
174     },
175     CreateVtable {
176         table_name: &'c str,
177         module_name: &'c str,
178     },
179     DropVtable {
180         table_name: &'c str,
181         module_name: &'c str,
182     },
183     Function {
184         function_name: &'c str,
185     },
186     Savepoint {
187         operation: TransactionOperation,
188         savepoint_name: &'c str,
189     },
190     Recursive,
191 }
192 
193 impl<'c> AuthAction<'c> {
from_raw(code: i32, arg1: Option<&'c str>, arg2: Option<&'c str>) -> Self194     fn from_raw(code: i32, arg1: Option<&'c str>, arg2: Option<&'c str>) -> Self {
195         match (code, arg1, arg2) {
196             (ffi::SQLITE_CREATE_INDEX, Some(index_name), Some(table_name)) => Self::CreateIndex {
197                 index_name,
198                 table_name,
199             },
200             (ffi::SQLITE_CREATE_TABLE, Some(table_name), _) => Self::CreateTable { table_name },
201             (ffi::SQLITE_CREATE_TEMP_INDEX, Some(index_name), Some(table_name)) => {
202                 Self::CreateTempIndex {
203                     index_name,
204                     table_name,
205                 }
206             }
207             (ffi::SQLITE_CREATE_TEMP_TABLE, Some(table_name), _) => {
208                 Self::CreateTempTable { table_name }
209             }
210             (ffi::SQLITE_CREATE_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
211                 Self::CreateTempTrigger {
212                     trigger_name,
213                     table_name,
214                 }
215             }
216             (ffi::SQLITE_CREATE_TEMP_VIEW, Some(view_name), _) => {
217                 Self::CreateTempView { view_name }
218             }
219             (ffi::SQLITE_CREATE_TRIGGER, Some(trigger_name), Some(table_name)) => {
220                 Self::CreateTrigger {
221                     trigger_name,
222                     table_name,
223                 }
224             }
225             (ffi::SQLITE_CREATE_VIEW, Some(view_name), _) => Self::CreateView { view_name },
226             (ffi::SQLITE_DELETE, Some(table_name), None) => Self::Delete { table_name },
227             (ffi::SQLITE_DROP_INDEX, Some(index_name), Some(table_name)) => Self::DropIndex {
228                 index_name,
229                 table_name,
230             },
231             (ffi::SQLITE_DROP_TABLE, Some(table_name), _) => Self::DropTable { table_name },
232             (ffi::SQLITE_DROP_TEMP_INDEX, Some(index_name), Some(table_name)) => {
233                 Self::DropTempIndex {
234                     index_name,
235                     table_name,
236                 }
237             }
238             (ffi::SQLITE_DROP_TEMP_TABLE, Some(table_name), _) => {
239                 Self::DropTempTable { table_name }
240             }
241             (ffi::SQLITE_DROP_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
242                 Self::DropTempTrigger {
243                     trigger_name,
244                     table_name,
245                 }
246             }
247             (ffi::SQLITE_DROP_TEMP_VIEW, Some(view_name), _) => Self::DropTempView { view_name },
248             (ffi::SQLITE_DROP_TRIGGER, Some(trigger_name), Some(table_name)) => Self::DropTrigger {
249                 trigger_name,
250                 table_name,
251             },
252             (ffi::SQLITE_DROP_VIEW, Some(view_name), _) => Self::DropView { view_name },
253             (ffi::SQLITE_INSERT, Some(table_name), _) => Self::Insert { table_name },
254             (ffi::SQLITE_PRAGMA, Some(pragma_name), pragma_value) => Self::Pragma {
255                 pragma_name,
256                 pragma_value,
257             },
258             (ffi::SQLITE_READ, Some(table_name), Some(column_name)) => Self::Read {
259                 table_name,
260                 column_name,
261             },
262             (ffi::SQLITE_SELECT, ..) => Self::Select,
263             (ffi::SQLITE_TRANSACTION, Some(operation_str), _) => Self::Transaction {
264                 operation: TransactionOperation::from_str(operation_str),
265             },
266             (ffi::SQLITE_UPDATE, Some(table_name), Some(column_name)) => Self::Update {
267                 table_name,
268                 column_name,
269             },
270             (ffi::SQLITE_ATTACH, Some(filename), _) => Self::Attach { filename },
271             (ffi::SQLITE_DETACH, Some(database_name), _) => Self::Detach { database_name },
272             (ffi::SQLITE_ALTER_TABLE, Some(database_name), Some(table_name)) => Self::AlterTable {
273                 database_name,
274                 table_name,
275             },
276             (ffi::SQLITE_REINDEX, Some(index_name), _) => Self::Reindex { index_name },
277             (ffi::SQLITE_ANALYZE, Some(table_name), _) => Self::Analyze { table_name },
278             (ffi::SQLITE_CREATE_VTABLE, Some(table_name), Some(module_name)) => {
279                 Self::CreateVtable {
280                     table_name,
281                     module_name,
282                 }
283             }
284             (ffi::SQLITE_DROP_VTABLE, Some(table_name), Some(module_name)) => Self::DropVtable {
285                 table_name,
286                 module_name,
287             },
288             (ffi::SQLITE_FUNCTION, _, Some(function_name)) => Self::Function { function_name },
289             (ffi::SQLITE_SAVEPOINT, Some(operation_str), Some(savepoint_name)) => Self::Savepoint {
290                 operation: TransactionOperation::from_str(operation_str),
291                 savepoint_name,
292             },
293             (ffi::SQLITE_RECURSIVE, ..) => Self::Recursive,
294             (code, arg1, arg2) => Self::Unknown { code, arg1, arg2 },
295         }
296     }
297 }
298 
299 pub(crate) type BoxedAuthorizer =
300     Box<dyn for<'c> FnMut(AuthContext<'c>) -> Authorization + Send + 'static>;
301 
/// A transaction operation.
///
/// Decoded from the operation string SQLite passes with
/// [`AuthAction::Transaction`] and [`AuthAction::Savepoint`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TransactionOperation {
    /// An operation string this library does not recognise.
    Unknown,
    /// The `"BEGIN"` operation.
    Begin,
    /// The `"RELEASE"` operation.
    Release,
    /// The `"ROLLBACK"` operation.
    Rollback,
    /// The `"COMMIT"` operation.
    Commit,
}

impl TransactionOperation {
    /// Maps SQLite's operation string onto a variant; unrecognised strings
    /// become [`TransactionOperation::Unknown`].
    fn from_str(op_str: &str) -> Self {
        match op_str {
            "BEGIN" => Self::Begin,
            "RELEASE" => Self::Release,
            "ROLLBACK" => Self::Rollback,
            // `COMMIT` / `END` statements are authorized as `SQLITE_TRANSACTION`
            // with the operation string "COMMIT"; previously this fell into `Unknown`.
            "COMMIT" => Self::Commit,
            _ => Self::Unknown,
        }
    }
}
323 
/// Verdict returned by an [`authorizer`](Connection::authorizer) callback.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Authorization {
    /// Let the action proceed.
    Allow,
    /// Refuse access without raising an error.
    Ignore,
    /// Refuse access and raise an error.
    Deny,
}
335 
336 impl Authorization {
into_raw(self) -> c_int337     fn into_raw(self) -> c_int {
338         match self {
339             Self::Allow => ffi::SQLITE_OK,
340             Self::Ignore => ffi::SQLITE_IGNORE,
341             Self::Deny => ffi::SQLITE_DENY,
342         }
343     }
344 }
345 
346 impl Connection {
347     /// Register a callback function to be invoked whenever
348     /// a transaction is committed.
349     ///
350     /// The callback returns `true` to rollback.
351     #[inline]
commit_hook<F>(&self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,352     pub fn commit_hook<F>(&self, hook: Option<F>)
353     where
354         F: FnMut() -> bool + Send + 'static,
355     {
356         self.db.borrow_mut().commit_hook(hook);
357     }
358 
359     /// Register a callback function to be invoked whenever
360     /// a transaction is committed.
361     #[inline]
rollback_hook<F>(&self, hook: Option<F>) where F: FnMut() + Send + 'static,362     pub fn rollback_hook<F>(&self, hook: Option<F>)
363     where
364         F: FnMut() + Send + 'static,
365     {
366         self.db.borrow_mut().rollback_hook(hook);
367     }
368 
369     /// Register a callback function to be invoked whenever
370     /// a row is updated, inserted or deleted in a rowid table.
371     ///
372     /// The callback parameters are:
373     ///
374     /// - the type of database update (`SQLITE_INSERT`, `SQLITE_UPDATE` or
375     ///   `SQLITE_DELETE`),
376     /// - the name of the database ("main", "temp", ...),
377     /// - the name of the table that is updated,
378     /// - the ROWID of the row that is updated.
379     #[inline]
update_hook<F>(&self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,380     pub fn update_hook<F>(&self, hook: Option<F>)
381     where
382         F: FnMut(Action, &str, &str, i64) + Send + 'static,
383     {
384         self.db.borrow_mut().update_hook(hook);
385     }
386 
387     /// Register a query progress callback.
388     ///
389     /// The parameter `num_ops` is the approximate number of virtual machine
390     /// instructions that are evaluated between successive invocations of the
391     /// `handler`. If `num_ops` is less than one then the progress handler
392     /// is disabled.
393     ///
394     /// If the progress callback returns `true`, the operation is interrupted.
progress_handler<F>(&self, num_ops: c_int, handler: Option<F>) where F: FnMut() -> bool + Send + 'static,395     pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>)
396     where
397         F: FnMut() -> bool + Send + 'static,
398     {
399         self.db.borrow_mut().progress_handler(num_ops, handler);
400     }
401 
402     /// Register an authorizer callback that's invoked
403     /// as a statement is being prepared.
404     #[inline]
authorizer<'c, F>(&self, hook: Option<F>) where F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,405     pub fn authorizer<'c, F>(&self, hook: Option<F>)
406     where
407         F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
408     {
409         self.db.borrow_mut().authorizer(hook);
410     }
411 }
412 
413 impl InnerConnection {
414     #[inline]
remove_hooks(&mut self)415     pub fn remove_hooks(&mut self) {
416         self.update_hook(None::<fn(Action, &str, &str, i64)>);
417         self.commit_hook(None::<fn() -> bool>);
418         self.rollback_hook(None::<fn()>);
419         self.progress_handler(0, None::<fn() -> bool>);
420         self.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
421     }
422 
423     /// ```compile_fail
424     /// use rusqlite::{Connection, Result};
425     /// fn main() -> Result<()> {
426     ///     let db = Connection::open_in_memory()?;
427     ///     {
428     ///         let mut called = std::sync::atomic::AtomicBool::new(false);
429     ///         db.commit_hook(Some(|| {
430     ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
431     ///             true
432     ///         }));
433     ///     }
434     ///     assert!(db
435     ///         .execute_batch(
436     ///             "BEGIN;
437     ///         CREATE TABLE foo (t TEXT);
438     ///         COMMIT;",
439     ///         )
440     ///         .is_err());
441     ///     Ok(())
442     /// }
443     /// ```
commit_hook<F>(&mut self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,444     fn commit_hook<F>(&mut self, hook: Option<F>)
445     where
446         F: FnMut() -> bool + Send + 'static,
447     {
448         unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
449         where
450             F: FnMut() -> bool,
451         {
452             let r = catch_unwind(|| {
453                 let boxed_hook: *mut F = p_arg.cast::<F>();
454                 (*boxed_hook)()
455             });
456             c_int::from(r.unwrap_or_default())
457         }
458 
459         // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with
460         // `sqlite3_commit_hook`. so we keep the `xDestroy` function in
461         // `InnerConnection.free_boxed_hook`.
462         let free_commit_hook = if hook.is_some() {
463             Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
464         } else {
465             None
466         };
467 
468         let previous_hook = match hook {
469             Some(hook) => {
470                 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
471                 unsafe {
472                     ffi::sqlite3_commit_hook(
473                         self.db(),
474                         Some(call_boxed_closure::<F>),
475                         boxed_hook.cast(),
476                     )
477                 }
478             }
479             _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) },
480         };
481         if !previous_hook.is_null() {
482             if let Some(free_boxed_hook) = self.free_commit_hook {
483                 unsafe { free_boxed_hook(previous_hook) };
484             }
485         }
486         self.free_commit_hook = free_commit_hook;
487     }
488 
489     /// ```compile_fail
490     /// use rusqlite::{Connection, Result};
491     /// fn main() -> Result<()> {
492     ///     let db = Connection::open_in_memory()?;
493     ///     {
494     ///         let mut called = std::sync::atomic::AtomicBool::new(false);
495     ///         db.rollback_hook(Some(|| {
496     ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
497     ///         }));
498     ///     }
499     ///     assert!(db
500     ///         .execute_batch(
501     ///             "BEGIN;
502     ///         CREATE TABLE foo (t TEXT);
503     ///         ROLLBACK;",
504     ///         )
505     ///         .is_err());
506     ///     Ok(())
507     /// }
508     /// ```
rollback_hook<F>(&mut self, hook: Option<F>) where F: FnMut() + Send + 'static,509     fn rollback_hook<F>(&mut self, hook: Option<F>)
510     where
511         F: FnMut() + Send + 'static,
512     {
513         unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void)
514         where
515             F: FnMut(),
516         {
517             drop(catch_unwind(|| {
518                 let boxed_hook: *mut F = p_arg.cast::<F>();
519                 (*boxed_hook)();
520             }));
521         }
522 
523         let free_rollback_hook = if hook.is_some() {
524             Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
525         } else {
526             None
527         };
528 
529         let previous_hook = match hook {
530             Some(hook) => {
531                 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
532                 unsafe {
533                     ffi::sqlite3_rollback_hook(
534                         self.db(),
535                         Some(call_boxed_closure::<F>),
536                         boxed_hook.cast(),
537                     )
538                 }
539             }
540             _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) },
541         };
542         if !previous_hook.is_null() {
543             if let Some(free_boxed_hook) = self.free_rollback_hook {
544                 unsafe { free_boxed_hook(previous_hook) };
545             }
546         }
547         self.free_rollback_hook = free_rollback_hook;
548     }
549 
550     /// ```compile_fail
551     /// use rusqlite::{Connection, Result};
552     /// fn main() -> Result<()> {
553     ///     let db = Connection::open_in_memory()?;
554     ///     {
555     ///         let mut called = std::sync::atomic::AtomicBool::new(false);
556     ///         db.update_hook(Some(|_, _: &str, _: &str, _| {
557     ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
558     ///         }));
559     ///     }
560     ///     db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;")
561     /// }
562     /// ```
update_hook<F>(&mut self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,563     fn update_hook<F>(&mut self, hook: Option<F>)
564     where
565         F: FnMut(Action, &str, &str, i64) + Send + 'static,
566     {
567         unsafe extern "C" fn call_boxed_closure<F>(
568             p_arg: *mut c_void,
569             action_code: c_int,
570             p_db_name: *const c_char,
571             p_table_name: *const c_char,
572             row_id: i64,
573         ) where
574             F: FnMut(Action, &str, &str, i64),
575         {
576             let action = Action::from(action_code);
577             drop(catch_unwind(|| {
578                 let boxed_hook: *mut F = p_arg.cast::<F>();
579                 (*boxed_hook)(
580                     action,
581                     expect_utf8(p_db_name, "database name"),
582                     expect_utf8(p_table_name, "table name"),
583                     row_id,
584                 );
585             }));
586         }
587 
588         let free_update_hook = if hook.is_some() {
589             Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
590         } else {
591             None
592         };
593 
594         let previous_hook = match hook {
595             Some(hook) => {
596                 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
597                 unsafe {
598                     ffi::sqlite3_update_hook(
599                         self.db(),
600                         Some(call_boxed_closure::<F>),
601                         boxed_hook.cast(),
602                     )
603                 }
604             }
605             _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) },
606         };
607         if !previous_hook.is_null() {
608             if let Some(free_boxed_hook) = self.free_update_hook {
609                 unsafe { free_boxed_hook(previous_hook) };
610             }
611         }
612         self.free_update_hook = free_update_hook;
613     }
614 
615     /// ```compile_fail
616     /// use rusqlite::{Connection, Result};
617     /// fn main() -> Result<()> {
618     ///     let db = Connection::open_in_memory()?;
619     ///     {
620     ///         let mut called = std::sync::atomic::AtomicBool::new(false);
621     ///         db.progress_handler(
622     ///             1,
623     ///             Some(|| {
624     ///                 called.store(true, std::sync::atomic::Ordering::Relaxed);
625     ///                 true
626     ///             }),
627     ///         );
628     ///     }
629     ///     assert!(db
630     ///         .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
631     ///         .is_err());
632     ///     Ok(())
633     /// }
634     /// ```
progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>) where F: FnMut() -> bool + Send + 'static,635     fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>)
636     where
637         F: FnMut() -> bool + Send + 'static,
638     {
639         unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
640         where
641             F: FnMut() -> bool,
642         {
643             let r = catch_unwind(|| {
644                 let boxed_handler: *mut F = p_arg.cast::<F>();
645                 (*boxed_handler)()
646             });
647             c_int::from(r.unwrap_or_default())
648         }
649 
650         if let Some(handler) = handler {
651             let boxed_handler = Box::new(handler);
652             unsafe {
653                 ffi::sqlite3_progress_handler(
654                     self.db(),
655                     num_ops,
656                     Some(call_boxed_closure::<F>),
657                     &*boxed_handler as *const F as *mut _,
658                 );
659             }
660             self.progress_handler = Some(boxed_handler);
661         } else {
662             unsafe { ffi::sqlite3_progress_handler(self.db(), num_ops, None, ptr::null_mut()) }
663             self.progress_handler = None;
664         };
665     }
666 
667     /// ```compile_fail
668     /// use rusqlite::{Connection, Result};
669     /// fn main() -> Result<()> {
670     ///     let db = Connection::open_in_memory()?;
671     ///     {
672     ///         let mut called = std::sync::atomic::AtomicBool::new(false);
673     ///         db.authorizer(Some(|_: rusqlite::hooks::AuthContext<'_>| {
674     ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
675     ///             rusqlite::hooks::Authorization::Deny
676     ///         }));
677     ///     }
678     ///     assert!(db
679     ///         .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
680     ///         .is_err());
681     ///     Ok(())
682     /// }
683     /// ```
authorizer<'c, F>(&'c mut self, authorizer: Option<F>) where F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,684     fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>)
685     where
686         F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
687     {
688         unsafe extern "C" fn call_boxed_closure<'c, F>(
689             p_arg: *mut c_void,
690             action_code: c_int,
691             param1: *const c_char,
692             param2: *const c_char,
693             db_name: *const c_char,
694             trigger_or_view_name: *const c_char,
695         ) -> c_int
696         where
697             F: FnMut(AuthContext<'c>) -> Authorization + Send + 'static,
698         {
699             catch_unwind(|| {
700                 let action = AuthAction::from_raw(
701                     action_code,
702                     expect_optional_utf8(param1, "authorizer param 1"),
703                     expect_optional_utf8(param2, "authorizer param 2"),
704                 );
705                 let auth_ctx = AuthContext {
706                     action,
707                     database_name: expect_optional_utf8(db_name, "database name"),
708                     accessor: expect_optional_utf8(
709                         trigger_or_view_name,
710                         "accessor (inner-most trigger or view)",
711                     ),
712                 };
713                 let boxed_hook: *mut F = p_arg.cast::<F>();
714                 (*boxed_hook)(auth_ctx)
715             })
716             .map_or_else(|_| ffi::SQLITE_ERROR, Authorization::into_raw)
717         }
718 
719         let callback_fn = authorizer
720             .as_ref()
721             .map(|_| call_boxed_closure::<'c, F> as unsafe extern "C" fn(_, _, _, _, _, _) -> _);
722         let boxed_authorizer = authorizer.map(Box::new);
723 
724         match unsafe {
725             ffi::sqlite3_set_authorizer(
726                 self.db(),
727                 callback_fn,
728                 boxed_authorizer
729                     .as_ref()
730                     .map_or_else(ptr::null_mut, |f| &**f as *const F as *mut _),
731             )
732         } {
733             ffi::SQLITE_OK => {
734                 self.authorizer = boxed_authorizer.map(|ba| ba as _);
735             }
736             err_code => {
737                 // The only error that `sqlite3_set_authorizer` returns is `SQLITE_MISUSE`
738                 // when compiled with `ENABLE_API_ARMOR` and the db pointer is invalid.
739                 // This library does not allow constructing a null db ptr, so if this branch
740                 // is hit, something very bad has happened. Panicking instead of returning
741                 // `Result` keeps this hook's API consistent with the others.
742                 panic!("unexpectedly failed to set_authorizer: {}", unsafe {
743                     crate::error::error_from_handle(self.db(), err_code)
744                 });
745             }
746         }
747     }
748 }
749 
/// Reclaims and drops a closure that was previously leaked with `Box::into_raw`.
///
/// # Safety
///
/// `p` must have been produced by `Box::<F>::into_raw` (cast to `*mut c_void`)
/// and must not be used again after this call.
unsafe fn free_boxed_hook<F>(p: *mut c_void) {
    let reclaimed: Box<F> = Box::from_raw(p as *mut F);
    drop(reclaimed);
}
753 
/// Borrows a C string argument that must be present.
///
/// # Safety
///
/// `p_str` must be null or point to a NUL-terminated string that stays valid
/// for the caller-chosen lifetime `'a`.
///
/// # Panics
///
/// Panics if `p_str` is null or the string is not valid UTF-8; `description`
/// names the argument in the panic message.
unsafe fn expect_utf8<'a>(p_str: *const c_char, description: &'static str) -> &'a str {
    match expect_optional_utf8(p_str, description) {
        Some(text) => text,
        None => panic!("received empty {description}"),
    }
}

/// Borrows an optional C string argument, mapping a null pointer to `None`.
///
/// # Safety
///
/// `p_str` must be null or point to a NUL-terminated string that stays valid
/// for the caller-chosen lifetime `'a`.
///
/// # Panics
///
/// Panics if the string is not valid UTF-8; `description` names the argument
/// in the panic message.
unsafe fn expect_optional_utf8<'a>(
    p_str: *const c_char,
    description: &'static str,
) -> Option<&'a str> {
    if p_str.is_null() {
        None
    } else {
        let c_text = std::ffi::CStr::from_ptr(p_str);
        match c_text.to_str() {
            Ok(text) => Some(text),
            Err(_) => panic!("received non-utf8 string as {description}"),
        }
    }
}
771 
#[cfg(test)]
mod test {
    use super::Action;
    use crate::{Connection, Result};
    use std::sync::atomic::{AtomicBool, Ordering};

    #[test]
    fn test_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_fn_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn hook() -> bool {
            true
        }

        // Returning `true` turns the commit into a rollback, so the batch fails.
        db.commit_hook(Some(hook));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
        Ok(())
    }

    #[test]
    fn test_rollback_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.rollback_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_update_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.update_hook(Some(|action, db: &str, tbl: &str, row_id| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            assert_eq!(1, row_id);
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_remove_update_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.update_hook(Some(|_: Action, _: &str, _: &str, _: i64| {
            CALLED.store(true, Ordering::Relaxed);
        }));
        // Passing `None` must unregister the hook installed above.
        db.update_hook(None::<fn(Action, &str, &str, i64)>);
        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert!(!CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_progress_handler() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.progress_handler(
            1,
            Some(|| {
                CALLED.store(true, Ordering::Relaxed);
                false
            }),
        );
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_progress_handler_interrupt() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn handler() -> bool {
            true
        }

        // Returning `true` interrupts the running statement, so the batch fails.
        db.progress_handler(1, Some(handler));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
        Ok(())
    }

    #[test]
    fn test_authorizer() -> Result<()> {
        use super::{AuthAction, AuthContext, Authorization};

        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo (public TEXT, private TEXT)")
            .unwrap();

        let authorizer = move |ctx: AuthContext<'_>| match ctx.action {
            AuthAction::Read {
                column_name: "private",
                ..
            } => Authorization::Ignore,
            AuthAction::DropTable { .. } => Authorization::Deny,
            AuthAction::Pragma { .. } => panic!("shouldn't be called"),
            _ => Authorization::Allow,
        };

        db.authorizer(Some(authorizer));
        db.execute_batch(
            "BEGIN TRANSACTION; INSERT INTO foo VALUES ('pub txt', 'priv txt'); COMMIT;",
        )
        .unwrap();
        db.query_row_and_then("SELECT * FROM foo", [], |row| -> Result<()> {
            assert_eq!(row.get::<_, String>("public")?, "pub txt");
            // `Ignore` on a column read makes SQLite return NULL for that column.
            assert!(row.get::<_, Option<String>>("private")?.is_none());
            Ok(())
        })
        .unwrap();
        db.execute_batch("DROP TABLE foo").unwrap_err();

        db.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
        db.execute_batch("PRAGMA user_version=1").unwrap(); // Disallowed by first authorizer, but it's now removed.

        Ok(())
    }
}
905