//! Commit, Data Change and Rollback Notification Callbacks, plus progress handler and authorizer hooks
2 #![allow(non_camel_case_types)]
3
4 use std::os::raw::{c_char, c_int, c_void};
5 use std::panic::catch_unwind;
6 use std::ptr;
7
8 use crate::ffi;
9
10 use crate::{Connection, InnerConnection};
11
12 #[cfg(feature = "preupdate_hook")]
13 pub use preupdate_hook::*;
14
15 #[cfg(feature = "preupdate_hook")]
16 mod preupdate_hook;
17
18 /// Action Codes
19 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
20 #[repr(i32)]
21 #[non_exhaustive]
22 #[allow(clippy::upper_case_acronyms)]
23 pub enum Action {
24 /// Unsupported / unexpected action
25 UNKNOWN = -1,
26 /// DELETE command
27 SQLITE_DELETE = ffi::SQLITE_DELETE,
28 /// INSERT command
29 SQLITE_INSERT = ffi::SQLITE_INSERT,
30 /// UPDATE command
31 SQLITE_UPDATE = ffi::SQLITE_UPDATE,
32 }
33
34 impl From<i32> for Action {
35 #[inline]
from(code: i32) -> Action36 fn from(code: i32) -> Action {
37 match code {
38 ffi::SQLITE_DELETE => Action::SQLITE_DELETE,
39 ffi::SQLITE_INSERT => Action::SQLITE_INSERT,
40 ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE,
41 _ => Action::UNKNOWN,
42 }
43 }
44 }
45
/// The context received by an authorizer hook.
///
/// See <https://sqlite.org/c3ref/set_authorizer.html> for more info.
///
/// The `'c` lifetime is chosen per invocation (the callback bound in
/// [`Connection::authorizer`] is higher-ranked over it), so the borrowed
/// strings cannot be retained beyond a single call of the hook.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct AuthContext<'c> {
    /// The action to be authorized.
    pub action: AuthAction<'c>,

    /// The database name, if applicable.
    pub database_name: Option<&'c str>,

    /// The inner-most trigger or view responsible for the access attempt.
    /// `None` if the access attempt was made by top-level SQL code.
    pub accessor: Option<&'c str>,
}
61
62 /// Actions and arguments found within a statement during
63 /// preparation.
64 ///
65 /// See <https://sqlite.org/c3ref/c_alter_table.html> for more info.
66 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
67 #[non_exhaustive]
68 #[allow(missing_docs)]
69 pub enum AuthAction<'c> {
70 /// This variant is not normally produced by SQLite. You may encounter it
71 // if you're using a different version than what's supported by this library.
72 Unknown {
73 /// The unknown authorization action code.
74 code: i32,
75 /// The third arg to the authorizer callback.
76 arg1: Option<&'c str>,
77 /// The fourth arg to the authorizer callback.
78 arg2: Option<&'c str>,
79 },
80 CreateIndex {
81 index_name: &'c str,
82 table_name: &'c str,
83 },
84 CreateTable {
85 table_name: &'c str,
86 },
87 CreateTempIndex {
88 index_name: &'c str,
89 table_name: &'c str,
90 },
91 CreateTempTable {
92 table_name: &'c str,
93 },
94 CreateTempTrigger {
95 trigger_name: &'c str,
96 table_name: &'c str,
97 },
98 CreateTempView {
99 view_name: &'c str,
100 },
101 CreateTrigger {
102 trigger_name: &'c str,
103 table_name: &'c str,
104 },
105 CreateView {
106 view_name: &'c str,
107 },
108 Delete {
109 table_name: &'c str,
110 },
111 DropIndex {
112 index_name: &'c str,
113 table_name: &'c str,
114 },
115 DropTable {
116 table_name: &'c str,
117 },
118 DropTempIndex {
119 index_name: &'c str,
120 table_name: &'c str,
121 },
122 DropTempTable {
123 table_name: &'c str,
124 },
125 DropTempTrigger {
126 trigger_name: &'c str,
127 table_name: &'c str,
128 },
129 DropTempView {
130 view_name: &'c str,
131 },
132 DropTrigger {
133 trigger_name: &'c str,
134 table_name: &'c str,
135 },
136 DropView {
137 view_name: &'c str,
138 },
139 Insert {
140 table_name: &'c str,
141 },
142 Pragma {
143 pragma_name: &'c str,
144 /// The pragma value, if present (e.g., `PRAGMA name = value;`).
145 pragma_value: Option<&'c str>,
146 },
147 Read {
148 table_name: &'c str,
149 column_name: &'c str,
150 },
151 Select,
152 Transaction {
153 operation: TransactionOperation,
154 },
155 Update {
156 table_name: &'c str,
157 column_name: &'c str,
158 },
159 Attach {
160 filename: &'c str,
161 },
162 Detach {
163 database_name: &'c str,
164 },
165 AlterTable {
166 database_name: &'c str,
167 table_name: &'c str,
168 },
169 Reindex {
170 index_name: &'c str,
171 },
172 Analyze {
173 table_name: &'c str,
174 },
175 CreateVtable {
176 table_name: &'c str,
177 module_name: &'c str,
178 },
179 DropVtable {
180 table_name: &'c str,
181 module_name: &'c str,
182 },
183 Function {
184 function_name: &'c str,
185 },
186 Savepoint {
187 operation: TransactionOperation,
188 savepoint_name: &'c str,
189 },
190 Recursive,
191 }
192
impl<'c> AuthAction<'c> {
    /// Builds an [`AuthAction`] from the action code and the two string
    /// arguments SQLite passed to the authorizer callback.
    ///
    /// Each arm matches an action code together with the argument shape expected
    /// for it (`None` stands for a null pointer from SQLite). No code appears in
    /// more than one arm, so if an argument that an arm requires is missing, or
    /// the code is not recognized, the final catch-all arm produces
    /// [`AuthAction::Unknown`] carrying the raw inputs.
    fn from_raw(code: i32, arg1: Option<&'c str>, arg2: Option<&'c str>) -> Self {
        match (code, arg1, arg2) {
            (ffi::SQLITE_CREATE_INDEX, Some(index_name), Some(table_name)) => Self::CreateIndex {
                index_name,
                table_name,
            },
            (ffi::SQLITE_CREATE_TABLE, Some(table_name), _) => Self::CreateTable { table_name },
            (ffi::SQLITE_CREATE_TEMP_INDEX, Some(index_name), Some(table_name)) => {
                Self::CreateTempIndex {
                    index_name,
                    table_name,
                }
            }
            (ffi::SQLITE_CREATE_TEMP_TABLE, Some(table_name), _) => {
                Self::CreateTempTable { table_name }
            }
            (ffi::SQLITE_CREATE_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
                Self::CreateTempTrigger {
                    trigger_name,
                    table_name,
                }
            }
            (ffi::SQLITE_CREATE_TEMP_VIEW, Some(view_name), _) => {
                Self::CreateTempView { view_name }
            }
            (ffi::SQLITE_CREATE_TRIGGER, Some(trigger_name), Some(table_name)) => {
                Self::CreateTrigger {
                    trigger_name,
                    table_name,
                }
            }
            (ffi::SQLITE_CREATE_VIEW, Some(view_name), _) => Self::CreateView { view_name },
            // Unlike the other single-name arms, this one also requires arg2 to be null.
            (ffi::SQLITE_DELETE, Some(table_name), None) => Self::Delete { table_name },
            (ffi::SQLITE_DROP_INDEX, Some(index_name), Some(table_name)) => Self::DropIndex {
                index_name,
                table_name,
            },
            (ffi::SQLITE_DROP_TABLE, Some(table_name), _) => Self::DropTable { table_name },
            (ffi::SQLITE_DROP_TEMP_INDEX, Some(index_name), Some(table_name)) => {
                Self::DropTempIndex {
                    index_name,
                    table_name,
                }
            }
            (ffi::SQLITE_DROP_TEMP_TABLE, Some(table_name), _) => {
                Self::DropTempTable { table_name }
            }
            (ffi::SQLITE_DROP_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
                Self::DropTempTrigger {
                    trigger_name,
                    table_name,
                }
            }
            (ffi::SQLITE_DROP_TEMP_VIEW, Some(view_name), _) => Self::DropTempView { view_name },
            (ffi::SQLITE_DROP_TRIGGER, Some(trigger_name), Some(table_name)) => Self::DropTrigger {
                trigger_name,
                table_name,
            },
            (ffi::SQLITE_DROP_VIEW, Some(view_name), _) => Self::DropView { view_name },
            (ffi::SQLITE_INSERT, Some(table_name), _) => Self::Insert { table_name },
            // The pragma value is optional, so arg2 is forwarded as-is.
            (ffi::SQLITE_PRAGMA, Some(pragma_name), pragma_value) => Self::Pragma {
                pragma_name,
                pragma_value,
            },
            (ffi::SQLITE_READ, Some(table_name), Some(column_name)) => Self::Read {
                table_name,
                column_name,
            },
            (ffi::SQLITE_SELECT, ..) => Self::Select,
            (ffi::SQLITE_TRANSACTION, Some(operation_str), _) => Self::Transaction {
                operation: TransactionOperation::from_str(operation_str),
            },
            (ffi::SQLITE_UPDATE, Some(table_name), Some(column_name)) => Self::Update {
                table_name,
                column_name,
            },
            (ffi::SQLITE_ATTACH, Some(filename), _) => Self::Attach { filename },
            (ffi::SQLITE_DETACH, Some(database_name), _) => Self::Detach { database_name },
            (ffi::SQLITE_ALTER_TABLE, Some(database_name), Some(table_name)) => Self::AlterTable {
                database_name,
                table_name,
            },
            (ffi::SQLITE_REINDEX, Some(index_name), _) => Self::Reindex { index_name },
            (ffi::SQLITE_ANALYZE, Some(table_name), _) => Self::Analyze { table_name },
            (ffi::SQLITE_CREATE_VTABLE, Some(table_name), Some(module_name)) => {
                Self::CreateVtable {
                    table_name,
                    module_name,
                }
            }
            (ffi::SQLITE_DROP_VTABLE, Some(table_name), Some(module_name)) => Self::DropVtable {
                table_name,
                module_name,
            },
            // The function name arrives in the *second* argument; arg1 is ignored.
            (ffi::SQLITE_FUNCTION, _, Some(function_name)) => Self::Function { function_name },
            (ffi::SQLITE_SAVEPOINT, Some(operation_str), Some(savepoint_name)) => Self::Savepoint {
                operation: TransactionOperation::from_str(operation_str),
                savepoint_name,
            },
            (ffi::SQLITE_RECURSIVE, ..) => Self::Recursive,
            // Unrecognized code, or a known code with an unexpected argument shape.
            (code, arg1, arg2) => Self::Unknown { code, arg1, arg2 },
        }
    }
}
298
/// Type-erased authorizer closure as stored on the connection.
///
/// The callback is higher-ranked over the [`AuthContext`] lifetime, so it can
/// be invoked with borrowed strings that live only for one call.
pub(crate) type BoxedAuthorizer =
    Box<dyn for<'c> FnMut(AuthContext<'c>) -> Authorization + Send + 'static>;
301
/// A transaction or savepoint operation, as reported to an authorizer through
/// [`AuthAction::Transaction`] and [`AuthAction::Savepoint`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TransactionOperation {
    /// An operation keyword this library does not recognize.
    Unknown,
    /// `BEGIN`.
    Begin,
    /// `COMMIT` (SQLite reports `END TRANSACTION` with this keyword as well).
    Commit,
    /// `RELEASE` (savepoints).
    Release,
    /// `ROLLBACK`.
    Rollback,
}

impl TransactionOperation {
    /// Parses the operation keyword SQLite passes to the authorizer.
    ///
    /// Matching is exact and case-sensitive; anything unrecognized yields
    /// [`TransactionOperation::Unknown`].
    fn from_str(op_str: &str) -> Self {
        match op_str {
            "BEGIN" => Self::Begin,
            // SQLite's authorizer receives "COMMIT" for both COMMIT and END; without this
            // arm every commit was reported as `Unknown`.
            "COMMIT" => Self::Commit,
            "RELEASE" => Self::Release,
            "ROLLBACK" => Self::Rollback,
            _ => Self::Unknown,
        }
    }
}
323
324 /// [`authorizer`](Connection::authorizer) return code
325 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
326 #[non_exhaustive]
327 pub enum Authorization {
328 /// Authorize the action.
329 Allow,
330 /// Don't allow access, but don't trigger an error either.
331 Ignore,
332 /// Trigger an error.
333 Deny,
334 }
335
336 impl Authorization {
into_raw(self) -> c_int337 fn into_raw(self) -> c_int {
338 match self {
339 Self::Allow => ffi::SQLITE_OK,
340 Self::Ignore => ffi::SQLITE_IGNORE,
341 Self::Deny => ffi::SQLITE_DENY,
342 }
343 }
344 }
345
346 impl Connection {
347 /// Register a callback function to be invoked whenever
348 /// a transaction is committed.
349 ///
350 /// The callback returns `true` to rollback.
351 #[inline]
commit_hook<F>(&self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,352 pub fn commit_hook<F>(&self, hook: Option<F>)
353 where
354 F: FnMut() -> bool + Send + 'static,
355 {
356 self.db.borrow_mut().commit_hook(hook);
357 }
358
359 /// Register a callback function to be invoked whenever
360 /// a transaction is committed.
361 #[inline]
rollback_hook<F>(&self, hook: Option<F>) where F: FnMut() + Send + 'static,362 pub fn rollback_hook<F>(&self, hook: Option<F>)
363 where
364 F: FnMut() + Send + 'static,
365 {
366 self.db.borrow_mut().rollback_hook(hook);
367 }
368
369 /// Register a callback function to be invoked whenever
370 /// a row is updated, inserted or deleted in a rowid table.
371 ///
372 /// The callback parameters are:
373 ///
374 /// - the type of database update (`SQLITE_INSERT`, `SQLITE_UPDATE` or
375 /// `SQLITE_DELETE`),
376 /// - the name of the database ("main", "temp", ...),
377 /// - the name of the table that is updated,
378 /// - the ROWID of the row that is updated.
379 #[inline]
update_hook<F>(&self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,380 pub fn update_hook<F>(&self, hook: Option<F>)
381 where
382 F: FnMut(Action, &str, &str, i64) + Send + 'static,
383 {
384 self.db.borrow_mut().update_hook(hook);
385 }
386
387 /// Register a query progress callback.
388 ///
389 /// The parameter `num_ops` is the approximate number of virtual machine
390 /// instructions that are evaluated between successive invocations of the
391 /// `handler`. If `num_ops` is less than one then the progress handler
392 /// is disabled.
393 ///
394 /// If the progress callback returns `true`, the operation is interrupted.
progress_handler<F>(&self, num_ops: c_int, handler: Option<F>) where F: FnMut() -> bool + Send + 'static,395 pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>)
396 where
397 F: FnMut() -> bool + Send + 'static,
398 {
399 self.db.borrow_mut().progress_handler(num_ops, handler);
400 }
401
402 /// Register an authorizer callback that's invoked
403 /// as a statement is being prepared.
404 #[inline]
authorizer<'c, F>(&self, hook: Option<F>) where F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,405 pub fn authorizer<'c, F>(&self, hook: Option<F>)
406 where
407 F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
408 {
409 self.db.borrow_mut().authorizer(hook);
410 }
411 }
412
413 impl InnerConnection {
414 #[inline]
remove_hooks(&mut self)415 pub fn remove_hooks(&mut self) {
416 self.update_hook(None::<fn(Action, &str, &str, i64)>);
417 self.commit_hook(None::<fn() -> bool>);
418 self.rollback_hook(None::<fn()>);
419 self.progress_handler(0, None::<fn() -> bool>);
420 self.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
421 }
422
423 /// ```compile_fail
424 /// use rusqlite::{Connection, Result};
425 /// fn main() -> Result<()> {
426 /// let db = Connection::open_in_memory()?;
427 /// {
428 /// let mut called = std::sync::atomic::AtomicBool::new(false);
429 /// db.commit_hook(Some(|| {
430 /// called.store(true, std::sync::atomic::Ordering::Relaxed);
431 /// true
432 /// }));
433 /// }
434 /// assert!(db
435 /// .execute_batch(
436 /// "BEGIN;
437 /// CREATE TABLE foo (t TEXT);
438 /// COMMIT;",
439 /// )
440 /// .is_err());
441 /// Ok(())
442 /// }
443 /// ```
commit_hook<F>(&mut self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,444 fn commit_hook<F>(&mut self, hook: Option<F>)
445 where
446 F: FnMut() -> bool + Send + 'static,
447 {
448 unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
449 where
450 F: FnMut() -> bool,
451 {
452 let r = catch_unwind(|| {
453 let boxed_hook: *mut F = p_arg.cast::<F>();
454 (*boxed_hook)()
455 });
456 c_int::from(r.unwrap_or_default())
457 }
458
459 // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with
460 // `sqlite3_commit_hook`. so we keep the `xDestroy` function in
461 // `InnerConnection.free_boxed_hook`.
462 let free_commit_hook = if hook.is_some() {
463 Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
464 } else {
465 None
466 };
467
468 let previous_hook = match hook {
469 Some(hook) => {
470 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
471 unsafe {
472 ffi::sqlite3_commit_hook(
473 self.db(),
474 Some(call_boxed_closure::<F>),
475 boxed_hook.cast(),
476 )
477 }
478 }
479 _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) },
480 };
481 if !previous_hook.is_null() {
482 if let Some(free_boxed_hook) = self.free_commit_hook {
483 unsafe { free_boxed_hook(previous_hook) };
484 }
485 }
486 self.free_commit_hook = free_commit_hook;
487 }
488
489 /// ```compile_fail
490 /// use rusqlite::{Connection, Result};
491 /// fn main() -> Result<()> {
492 /// let db = Connection::open_in_memory()?;
493 /// {
494 /// let mut called = std::sync::atomic::AtomicBool::new(false);
495 /// db.rollback_hook(Some(|| {
496 /// called.store(true, std::sync::atomic::Ordering::Relaxed);
497 /// }));
498 /// }
499 /// assert!(db
500 /// .execute_batch(
501 /// "BEGIN;
502 /// CREATE TABLE foo (t TEXT);
503 /// ROLLBACK;",
504 /// )
505 /// .is_err());
506 /// Ok(())
507 /// }
508 /// ```
rollback_hook<F>(&mut self, hook: Option<F>) where F: FnMut() + Send + 'static,509 fn rollback_hook<F>(&mut self, hook: Option<F>)
510 where
511 F: FnMut() + Send + 'static,
512 {
513 unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void)
514 where
515 F: FnMut(),
516 {
517 drop(catch_unwind(|| {
518 let boxed_hook: *mut F = p_arg.cast::<F>();
519 (*boxed_hook)();
520 }));
521 }
522
523 let free_rollback_hook = if hook.is_some() {
524 Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
525 } else {
526 None
527 };
528
529 let previous_hook = match hook {
530 Some(hook) => {
531 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
532 unsafe {
533 ffi::sqlite3_rollback_hook(
534 self.db(),
535 Some(call_boxed_closure::<F>),
536 boxed_hook.cast(),
537 )
538 }
539 }
540 _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) },
541 };
542 if !previous_hook.is_null() {
543 if let Some(free_boxed_hook) = self.free_rollback_hook {
544 unsafe { free_boxed_hook(previous_hook) };
545 }
546 }
547 self.free_rollback_hook = free_rollback_hook;
548 }
549
550 /// ```compile_fail
551 /// use rusqlite::{Connection, Result};
552 /// fn main() -> Result<()> {
553 /// let db = Connection::open_in_memory()?;
554 /// {
555 /// let mut called = std::sync::atomic::AtomicBool::new(false);
556 /// db.update_hook(Some(|_, _: &str, _: &str, _| {
557 /// called.store(true, std::sync::atomic::Ordering::Relaxed);
558 /// }));
559 /// }
560 /// db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;")
561 /// }
562 /// ```
update_hook<F>(&mut self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,563 fn update_hook<F>(&mut self, hook: Option<F>)
564 where
565 F: FnMut(Action, &str, &str, i64) + Send + 'static,
566 {
567 unsafe extern "C" fn call_boxed_closure<F>(
568 p_arg: *mut c_void,
569 action_code: c_int,
570 p_db_name: *const c_char,
571 p_table_name: *const c_char,
572 row_id: i64,
573 ) where
574 F: FnMut(Action, &str, &str, i64),
575 {
576 let action = Action::from(action_code);
577 drop(catch_unwind(|| {
578 let boxed_hook: *mut F = p_arg.cast::<F>();
579 (*boxed_hook)(
580 action,
581 expect_utf8(p_db_name, "database name"),
582 expect_utf8(p_table_name, "table name"),
583 row_id,
584 );
585 }));
586 }
587
588 let free_update_hook = if hook.is_some() {
589 Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
590 } else {
591 None
592 };
593
594 let previous_hook = match hook {
595 Some(hook) => {
596 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
597 unsafe {
598 ffi::sqlite3_update_hook(
599 self.db(),
600 Some(call_boxed_closure::<F>),
601 boxed_hook.cast(),
602 )
603 }
604 }
605 _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) },
606 };
607 if !previous_hook.is_null() {
608 if let Some(free_boxed_hook) = self.free_update_hook {
609 unsafe { free_boxed_hook(previous_hook) };
610 }
611 }
612 self.free_update_hook = free_update_hook;
613 }
614
615 /// ```compile_fail
616 /// use rusqlite::{Connection, Result};
617 /// fn main() -> Result<()> {
618 /// let db = Connection::open_in_memory()?;
619 /// {
620 /// let mut called = std::sync::atomic::AtomicBool::new(false);
621 /// db.progress_handler(
622 /// 1,
623 /// Some(|| {
624 /// called.store(true, std::sync::atomic::Ordering::Relaxed);
625 /// true
626 /// }),
627 /// );
628 /// }
629 /// assert!(db
630 /// .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
631 /// .is_err());
632 /// Ok(())
633 /// }
634 /// ```
progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>) where F: FnMut() -> bool + Send + 'static,635 fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>)
636 where
637 F: FnMut() -> bool + Send + 'static,
638 {
639 unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
640 where
641 F: FnMut() -> bool,
642 {
643 let r = catch_unwind(|| {
644 let boxed_handler: *mut F = p_arg.cast::<F>();
645 (*boxed_handler)()
646 });
647 c_int::from(r.unwrap_or_default())
648 }
649
650 if let Some(handler) = handler {
651 let boxed_handler = Box::new(handler);
652 unsafe {
653 ffi::sqlite3_progress_handler(
654 self.db(),
655 num_ops,
656 Some(call_boxed_closure::<F>),
657 &*boxed_handler as *const F as *mut _,
658 );
659 }
660 self.progress_handler = Some(boxed_handler);
661 } else {
662 unsafe { ffi::sqlite3_progress_handler(self.db(), num_ops, None, ptr::null_mut()) }
663 self.progress_handler = None;
664 };
665 }
666
667 /// ```compile_fail
668 /// use rusqlite::{Connection, Result};
669 /// fn main() -> Result<()> {
670 /// let db = Connection::open_in_memory()?;
671 /// {
672 /// let mut called = std::sync::atomic::AtomicBool::new(false);
673 /// db.authorizer(Some(|_: rusqlite::hooks::AuthContext<'_>| {
674 /// called.store(true, std::sync::atomic::Ordering::Relaxed);
675 /// rusqlite::hooks::Authorization::Deny
676 /// }));
677 /// }
678 /// assert!(db
679 /// .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
680 /// .is_err());
681 /// Ok(())
682 /// }
683 /// ```
authorizer<'c, F>(&'c mut self, authorizer: Option<F>) where F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,684 fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>)
685 where
686 F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
687 {
688 unsafe extern "C" fn call_boxed_closure<'c, F>(
689 p_arg: *mut c_void,
690 action_code: c_int,
691 param1: *const c_char,
692 param2: *const c_char,
693 db_name: *const c_char,
694 trigger_or_view_name: *const c_char,
695 ) -> c_int
696 where
697 F: FnMut(AuthContext<'c>) -> Authorization + Send + 'static,
698 {
699 catch_unwind(|| {
700 let action = AuthAction::from_raw(
701 action_code,
702 expect_optional_utf8(param1, "authorizer param 1"),
703 expect_optional_utf8(param2, "authorizer param 2"),
704 );
705 let auth_ctx = AuthContext {
706 action,
707 database_name: expect_optional_utf8(db_name, "database name"),
708 accessor: expect_optional_utf8(
709 trigger_or_view_name,
710 "accessor (inner-most trigger or view)",
711 ),
712 };
713 let boxed_hook: *mut F = p_arg.cast::<F>();
714 (*boxed_hook)(auth_ctx)
715 })
716 .map_or_else(|_| ffi::SQLITE_ERROR, Authorization::into_raw)
717 }
718
719 let callback_fn = authorizer
720 .as_ref()
721 .map(|_| call_boxed_closure::<'c, F> as unsafe extern "C" fn(_, _, _, _, _, _) -> _);
722 let boxed_authorizer = authorizer.map(Box::new);
723
724 match unsafe {
725 ffi::sqlite3_set_authorizer(
726 self.db(),
727 callback_fn,
728 boxed_authorizer
729 .as_ref()
730 .map_or_else(ptr::null_mut, |f| &**f as *const F as *mut _),
731 )
732 } {
733 ffi::SQLITE_OK => {
734 self.authorizer = boxed_authorizer.map(|ba| ba as _);
735 }
736 err_code => {
737 // The only error that `sqlite3_set_authorizer` returns is `SQLITE_MISUSE`
738 // when compiled with `ENABLE_API_ARMOR` and the db pointer is invalid.
739 // This library does not allow constructing a null db ptr, so if this branch
740 // is hit, something very bad has happened. Panicking instead of returning
741 // `Result` keeps this hook's API consistent with the others.
742 panic!("unexpectedly failed to set_authorizer: {}", unsafe {
743 crate::error::error_from_handle(self.db(), err_code)
744 });
745 }
746 }
747 }
748 }
749
free_boxed_hook<F>(p: *mut c_void)750 unsafe fn free_boxed_hook<F>(p: *mut c_void) {
751 drop(Box::from_raw(p.cast::<F>()));
752 }
753
expect_utf8<'a>(p_str: *const c_char, description: &'static str) -> &'a str754 unsafe fn expect_utf8<'a>(p_str: *const c_char, description: &'static str) -> &'a str {
755 expect_optional_utf8(p_str, description)
756 .unwrap_or_else(|| panic!("received empty {description}"))
757 }
758
expect_optional_utf8<'a>( p_str: *const c_char, description: &'static str, ) -> Option<&'a str>759 unsafe fn expect_optional_utf8<'a>(
760 p_str: *const c_char,
761 description: &'static str,
762 ) -> Option<&'a str> {
763 if p_str.is_null() {
764 return None;
765 }
766 std::ffi::CStr::from_ptr(p_str)
767 .to_str()
768 .unwrap_or_else(|_| panic!("received non-utf8 string as {description}"))
769 .into()
770 }
771
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::Action;
    use crate::{Connection, Result};

    /// A commit hook returning `false` runs and lets the commit succeed.
    #[test]
    fn test_commit_hook() -> Result<()> {
        static CALLED: AtomicBool = AtomicBool::new(false);

        let db = Connection::open_in_memory()?;
        db.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }));

        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    /// A plain `fn` item works as a commit hook; returning `true` vetoes the commit.
    #[test]
    fn test_fn_commit_hook() -> Result<()> {
        fn veto_commit() -> bool {
            true
        }

        let db = Connection::open_in_memory()?;
        db.commit_hook(Some(veto_commit));

        let outcome = db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;");
        assert!(outcome.is_err());
        Ok(())
    }

    /// The rollback hook fires on an explicit `ROLLBACK`.
    #[test]
    fn test_rollback_hook() -> Result<()> {
        static CALLED: AtomicBool = AtomicBool::new(false);

        let db = Connection::open_in_memory()?;
        db.rollback_hook(Some(|| CALLED.store(true, Ordering::Relaxed)));

        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    /// The update hook reports the action, database, table and rowid of an insert.
    #[test]
    fn test_update_hook() -> Result<()> {
        static CALLED: AtomicBool = AtomicBool::new(false);

        let db = Connection::open_in_memory()?;
        db.update_hook(Some(|action, db_name: &str, table: &str, rowid| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db_name);
            assert_eq!("foo", table);
            assert_eq!(1, rowid);
            // Only reached when every assertion held; panics inside hooks are swallowed.
            CALLED.store(true, Ordering::Relaxed);
        }));

        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    /// A progress handler returning `false` is invoked without interrupting.
    #[test]
    fn test_progress_handler() -> Result<()> {
        static CALLED: AtomicBool = AtomicBool::new(false);

        let db = Connection::open_in_memory()?;
        let handler = || {
            CALLED.store(true, Ordering::Relaxed);
            false
        };
        db.progress_handler(1, Some(handler));

        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    /// A progress handler returning `true` interrupts the statement.
    #[test]
    fn test_progress_handler_interrupt() -> Result<()> {
        fn interrupt() -> bool {
            true
        }

        let db = Connection::open_in_memory()?;
        db.progress_handler(1, Some(interrupt));

        let outcome = db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;");
        assert!(outcome.is_err());
        Ok(())
    }

    /// Ignored columns read as NULL, denied actions fail, and removing the
    /// authorizer lifts its restrictions.
    #[test]
    fn test_authorizer() -> Result<()> {
        use super::{AuthAction, AuthContext, Authorization};

        fn authorize(ctx: AuthContext<'_>) -> Authorization {
            match ctx.action {
                AuthAction::Read {
                    column_name: "private",
                    ..
                } => Authorization::Ignore,
                AuthAction::DropTable { .. } => Authorization::Deny,
                AuthAction::Pragma { .. } => panic!("shouldn't be called"),
                _ => Authorization::Allow,
            }
        }

        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo (public TEXT, private TEXT)")
            .unwrap();

        db.authorizer(Some(authorize));
        db.execute_batch(
            "BEGIN TRANSACTION; INSERT INTO foo VALUES ('pub txt', 'priv txt'); COMMIT;",
        )
        .unwrap();
        db.query_row_and_then("SELECT * FROM foo", [], |row| -> Result<()> {
            assert_eq!(row.get::<_, String>("public")?, "pub txt");
            assert!(row.get::<_, Option<String>>("private")?.is_none());
            Ok(())
        })
        .unwrap();
        db.execute_batch("DROP TABLE foo").unwrap_err();

        db.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
        // The pragma would have been rejected by `authorize`, which is now removed.
        db.execute_batch("PRAGMA user_version=1").unwrap();

        Ok(())
    }
}
905