• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::ffi::c_void;
6 use std::ffi::OsString;
7 use std::io;
8 use std::ptr;
9 
10 use winapi::shared::minwindef::ULONG;
11 use winapi::um::winnt::PVOID;
12 
13 use super::unicode_string_to_os_string;
14 
15 // Required for Windows API FFI bindings, as the names of the FFI structs and
16 // functions get called out by the linter.
17 #[allow(non_upper_case_globals)]
18 #[allow(non_camel_case_types)]
19 #[allow(non_snake_case)]
20 #[allow(dead_code)]
21 mod dll_notification_sys {
22     use std::io;
23 
24     use winapi::shared::minwindef::ULONG;
25     use winapi::shared::ntdef::NTSTATUS;
26     use winapi::shared::ntdef::PCUNICODE_STRING;
27     use winapi::shared::ntstatus::STATUS_SUCCESS;
28     use winapi::um::libloaderapi::GetModuleHandleA;
29     use winapi::um::libloaderapi::GetProcAddress;
30     use winapi::um::winnt::CHAR;
31     use winapi::um::winnt::PVOID;
32 
33     #[repr(C)]
34     pub union _LDR_DLL_NOTIFICATION_DATA {
35         pub Loaded: LDR_DLL_LOADED_NOTIFICATION_DATA,
36         pub Unloaded: LDR_DLL_UNLOADED_NOTIFICATION_DATA,
37     }
38     pub type LDR_DLL_NOTIFICATION_DATA = _LDR_DLL_NOTIFICATION_DATA;
39     pub type PLDR_DLL_NOTIFICATION_DATA = *mut LDR_DLL_NOTIFICATION_DATA;
40 
41     #[repr(C)]
42     #[derive(Debug, Copy, Clone)]
43     pub struct _LDR_DLL_LOADED_NOTIFICATION_DATA {
44         pub Flags: ULONG,                  // Reserved.
45         pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module.
46         pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module.
47         pub DllBase: PVOID,                // A pointer to the base address for the DLL in memory.
48         pub SizeOfImage: ULONG,            // The size of the DLL image, in bytes.
49     }
50     pub type LDR_DLL_LOADED_NOTIFICATION_DATA = _LDR_DLL_LOADED_NOTIFICATION_DATA;
51     pub type PLDR_DLL_LOADED_NOTIFICATION_DATA = *mut LDR_DLL_LOADED_NOTIFICATION_DATA;
52 
53     #[repr(C)]
54     #[derive(Debug, Copy, Clone)]
55     pub struct _LDR_DLL_UNLOADED_NOTIFICATION_DATA {
56         pub Flags: ULONG,                  // Reserved.
57         pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module.
58         pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module.
59         pub DllBase: PVOID,                // A pointer to the base address for the DLL in memory.
60         pub SizeOfImage: ULONG,            // The size of the DLL image, in bytes.
61     }
62     pub type LDR_DLL_UNLOADED_NOTIFICATION_DATA = _LDR_DLL_UNLOADED_NOTIFICATION_DATA;
63     pub type PLDR_DLL_UNLOADED_NOTIFICATION_DATA = *mut LDR_DLL_UNLOADED_NOTIFICATION_DATA;
64 
65     pub const LDR_DLL_NOTIFICATION_REASON_LOADED: ULONG = 1;
66     pub const LDR_DLL_NOTIFICATION_REASON_UNLOADED: ULONG = 2;
67 
68     const NTDLL: &'static [u8] = b"ntdll\0";
69     const LDR_REGISTER_DLL_NOTIFICATION: &'static [u8] = b"LdrRegisterDllNotification\0";
70     const LDR_UNREGISTER_DLL_NOTIFICATION: &'static [u8] = b"LdrUnregisterDllNotification\0";
71 
72     pub type LdrDllNotification = unsafe extern "C" fn(
73         NotificationReason: ULONG,
74         NotificationData: PLDR_DLL_NOTIFICATION_DATA,
75         Context: PVOID,
76     );
77 
78     pub type FnLdrRegisterDllNotification =
79         unsafe extern "C" fn(ULONG, LdrDllNotification, PVOID, *mut PVOID) -> NTSTATUS;
80     pub type FnLdrUnregisterDllNotification = unsafe extern "C" fn(PVOID) -> NTSTATUS;
81 
82     extern "C" {
RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG83         pub fn RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG;
84     }
85 
86     /// Wrapper for the NTDLL `LdrRegisterDllNotification` function. Dynamically
87     /// gets the address of the function and invokes the function with the given
88     /// arguments.
89     ///
90     /// # Safety
91     /// Unsafe as this function does not verify its arguments; the caller is
92     /// expected to verify the safety as if invoking the underlying C function.
LdrRegisterDllNotification( Flags: ULONG, NotificationFunction: LdrDllNotification, Context: PVOID, Cookie: *mut PVOID, ) -> io::Result<()>93     pub unsafe fn LdrRegisterDllNotification(
94         Flags: ULONG,
95         NotificationFunction: LdrDllNotification,
96         Context: PVOID,
97         Cookie: *mut PVOID,
98     ) -> io::Result<()> {
99         let proc_addr = GetProcAddress(
100             /* hModule= */
101             GetModuleHandleA(
102                 /* lpModuleName= */ NTDLL.as_ptr() as *const u8 as *const CHAR,
103             ),
104             /* lpProcName= */
105             LDR_REGISTER_DLL_NOTIFICATION.as_ptr() as *const u8 as *const CHAR,
106         );
107         if proc_addr.is_null() {
108             return Err(std::io::Error::last_os_error());
109         }
110         let ldr_register_dll_notification: FnLdrRegisterDllNotification =
111             std::mem::transmute(proc_addr);
112         let ret = ldr_register_dll_notification(Flags, NotificationFunction, Context, Cookie);
113         if ret != STATUS_SUCCESS {
114             return Err(io::Error::from_raw_os_error(
115                 RtlNtStatusToDosError(/* Status= */ ret) as i32,
116             ));
117         };
118         Ok(())
119     }
120 
121     /// Wrapper for the NTDLL `LdrUnregisterDllNotification` function. Dynamically
122     /// gets the address of the function and invokes the function with the given
123     /// arguments.
124     ///
125     /// # Safety
126     /// Unsafe as this function does not verify its arguments; the caller is
127     /// expected to verify the safety as if invoking the underlying C function.
LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()>128     pub unsafe fn LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()> {
129         let proc_addr = GetProcAddress(
130             /* hModule= */
131             GetModuleHandleA(
132                 /* lpModuleName= */ NTDLL.as_ptr() as *const u8 as *const CHAR,
133             ),
134             /* lpProcName= */
135             LDR_UNREGISTER_DLL_NOTIFICATION.as_ptr() as *const u8 as *const CHAR,
136         );
137         if proc_addr.is_null() {
138             return Err(std::io::Error::last_os_error());
139         }
140         let ldr_unregister_dll_notification: FnLdrUnregisterDllNotification =
141             std::mem::transmute(proc_addr);
142         let ret = ldr_unregister_dll_notification(Cookie);
143         if ret != STATUS_SUCCESS {
144             return Err(io::Error::from_raw_os_error(
145                 RtlNtStatusToDosError(/* Status= */ ret) as i32,
146             ));
147         };
148         Ok(())
149     }
150 }
151 
152 use dll_notification_sys::*;
153 
/// Names of a DLL reported by a load or unload notification.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DllNotificationData {
    /// The full path name of the DLL module.
    pub full_dll_name: OsString,
    /// The base file name of the DLL module.
    pub base_dll_name: OsString,
}
159 
/// Callback context wrapper for DLL load notification functions.
///
/// Holds the callback invoked when a DLL is loaded into the current process
/// and the callback invoked when one is unloaded. The `FnMut` bounds are
/// placed on the `impl` blocks that need them rather than on the struct
/// itself.
///
/// # Safety
/// Once a pointer to an instance has been registered as a notification
/// context, the instance must not be used or relocated by anyone else: the
/// notification function assumes it holds the only mutable reference to it.
/// Only after the notification has been unregistered is it safe to use the
/// instance again.
struct CallbackContext<F1, F2> {
    /// Called for DLL load notifications.
    loaded_callback: F1,
    /// Called for DLL unload notifications.
    unloaded_callback: F2,
}
179 
180 impl<F1, F2> CallbackContext<F1, F2>
181 where
182     F1: FnMut(DllNotificationData),
183     F2: FnMut(DllNotificationData),
184 {
185     /// Create a new `CallbackContext` with the two callback functions. Takes
186     /// two callbacks, a `loaded_callback` which is called when a DLL is
187     /// loaded, and `unloaded_callback` which is called when a DLL is unloaded.
new(loaded_callback: F1, unloaded_callback: F2) -> Self188     pub fn new(loaded_callback: F1, unloaded_callback: F2) -> Self {
189         CallbackContext {
190             loaded_callback,
191             unloaded_callback,
192         }
193     }
194 
195     /// Provides a notification function that can be passed to the
196     /// `LdrRegisterDllNotification` function.
get_notification_function(&self) -> LdrDllNotification197     pub fn get_notification_function(&self) -> LdrDllNotification {
198         Self::notification_function
199     }
200 
201     /// A notification function with C linkage. This function assumes that it
202     /// has exclusive access to the instance of the struct passed through the
203     /// `context` parameter.
notification_function( notification_reason: ULONG, notification_data: PLDR_DLL_NOTIFICATION_DATA, context: PVOID, )204     extern "C" fn notification_function(
205         notification_reason: ULONG,
206         notification_data: PLDR_DLL_NOTIFICATION_DATA,
207         context: PVOID,
208     ) {
209         // Safe because the DLLWatcher guarantees that the CallbackContext
210         // instance is not null and that we have exclusive access to it.
211         let callback_context =
212             unsafe { (context as *mut Self).as_mut() }.expect("context was null");
213 
214         assert!(!notification_data.is_null());
215 
216         match notification_reason {
217             LDR_DLL_NOTIFICATION_REASON_LOADED => {
218                 // Safe because we know that the LDR_DLL_NOTIFICATION_DATA union
219                 // contains the LDR_DLL_LOADED_NOTIFICATION_DATA because we got
220                 // LDR_DLL_NOTIFICATION_REASON_LOADED as the notification
221                 // reason.
222                 let loaded = unsafe { &mut (*notification_data).Loaded };
223 
224                 assert!(!loaded.BaseDllName.is_null());
225 
226                 // Safe because we assert that the pointer is not null and
227                 // expect that the OS has provided a valid UNICODE_STRING
228                 // struct.
229                 let base_dll_name = unsafe { unicode_string_to_os_string(&*loaded.BaseDllName) };
230 
231                 assert!(!loaded.FullDllName.is_null());
232 
233                 // Safe because we assert that the pointer is not null and
234                 // expect that the OS has provided a valid UNICODE_STRING
235                 // struct.
236                 let full_dll_name = unsafe { unicode_string_to_os_string(&*loaded.FullDllName) };
237 
238                 (callback_context.loaded_callback)(DllNotificationData {
239                     base_dll_name,
240                     full_dll_name,
241                 });
242             }
243             LDR_DLL_NOTIFICATION_REASON_UNLOADED => {
244                 // Safe because we know that the LDR_DLL_NOTIFICATION_DATA union
245                 // contains the LDR_DLL_UNLOADED_NOTIFICATION_DATA because we got
246                 // LDR_DLL_NOTIFICATION_REASON_UNLOADED as the notification
247                 // reason.
248                 let unloaded = unsafe { &mut (*notification_data).Unloaded };
249 
250                 assert!(!unloaded.BaseDllName.is_null());
251 
252                 // Safe because we assert that the pointer is not null and
253                 // expect that the OS has provided a valid UNICODE_STRING
254                 // struct.
255                 let base_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.BaseDllName) };
256 
257                 assert!(!unloaded.FullDllName.is_null());
258 
259                 // Safe because we assert that the pointer is not null and
260                 // expect that the OS has provided a valid UNICODE_STRING
261                 // struct.
262                 let full_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.FullDllName) };
263 
264                 (callback_context.unloaded_callback)(DllNotificationData {
265                     base_dll_name,
266                     full_dll_name,
267                 })
268             }
269             n => panic!("invalid value \"{}\" for dll notification reason", n),
270         }
271     }
272 }
273 
274 /// DLL watcher for monitoring DLL loads/unloads.
275 ///
276 /// Provides a method to invoke a function-like type any time a DLL
277 /// is loaded or unloaded in the current process.
278 pub struct DllWatcher<F1, F2>
279 where
280     F1: FnMut(DllNotificationData),
281     F2: FnMut(DllNotificationData),
282 {
283     context: Box<CallbackContext<F1, F2>>,
284     cookie: Option<ptr::NonNull<c_void>>,
285 }
286 
287 impl<F1, F2> DllWatcher<F1, F2>
288 where
289     F1: FnMut(DllNotificationData),
290     F2: FnMut(DllNotificationData),
291 {
292     /// Create a new `DllWatcher` with the two callback functions. Takes two
293     /// callbacks, a `loaded_callback` which is called when a DLL is loaded,
294     /// and `unloaded_callback` which is called when a DLL is unloaded.
new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self>295     pub fn new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self> {
296         let mut watcher = Self {
297             context: Box::new(CallbackContext::new(loaded_callback, unloaded_callback)),
298             cookie: None,
299         };
300         let mut cookie: PVOID = ptr::null_mut();
301         // Safe because we guarantee that the notification function that we
302         // register will have exclusive access to the context.
303         unsafe {
304             LdrRegisterDllNotification(
305                 /* Flags= */ 0,
306                 /* NotificationFunction= */ watcher.context.get_notification_function(),
307                 /* Context= */
308                 &mut *watcher.context as *mut CallbackContext<F1, F2> as PVOID,
309                 /* Cookie= */ &mut cookie as *mut PVOID,
310             )?
311         };
312         watcher.cookie = ptr::NonNull::new(cookie);
313         Ok(watcher)
314     }
315 
unregister_dll_notification(&mut self) -> io::Result<()>316     fn unregister_dll_notification(&mut self) -> io::Result<()> {
317         match self.cookie {
318             Some(c) => {
319                 // Safe because we guarantee that `Cookie` was previously initialized.
320                 unsafe {
321                     LdrUnregisterDllNotification(/* Cookie= */ c.as_ptr() as PVOID)?
322                 }
323                 self.cookie = None;
324             }
325             None => {}
326         }
327         Ok(())
328     }
329 }
330 
331 impl<F1, F2> Drop for DllWatcher<F1, F2>
332 where
333     F1: FnMut(DllNotificationData),
334     F2: FnMut(DllNotificationData),
335 {
drop(&mut self)336     fn drop(&mut self) {
337         self.unregister_dll_notification()
338             .expect("error unregistering dll notification");
339     }
340 }
341 
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::ffi::CString;
    use std::ffi::OsStr;
    use std::io;

    use winapi::shared::minwindef::FALSE;
    use winapi::shared::minwindef::TRUE;
    use winapi::um::handleapi::CloseHandle;
    use winapi::um::libloaderapi::FreeLibrary;
    use winapi::um::libloaderapi::LoadLibraryA;
    use winapi::um::synchapi::CreateEventA;
    use winapi::um::synchapi::SetEvent;
    use winapi::um::synchapi::WaitForSingleObject;
    use winapi::um::winbase::WAIT_OBJECT_0;

    use super::*;

    // Arbitrarily chosen DLLs for load/unload test. Chosen because they're
    // hopefully esoteric enough that they're probably not already loaded in
    // the process so we can test load/unload notifications.
    //
    // Using a single DLL can lead to flakiness; since the tests are run in the
    // same process, it can be hard to rely on the OS to clean up the DLL loaded
    // by one test before the other test runs. Using a different DLL makes the
    // tests more independent.
    const TEST_DLL_NAME_1: &str = "Imagehlp.dll";
    const TEST_DLL_NAME_2: &str = "dbghelp.dll";

    #[test]
    fn load_dll() {
        let test_dll_name = CString::new(TEST_DLL_NAME_1).expect("failed to create CString");
        let mut loaded_dlls: HashSet<OsString> = HashSet::new();
        let h_module = {
            let _watcher = DllWatcher::new(
                |data| {
                    loaded_dlls.insert(data.base_dll_name);
                },
                |_data| (),
            )
            .expect("failed to create DllWatcher");
            // Safe because we pass a valid C string in to the function.
            unsafe { LoadLibraryA(test_dll_name.as_ptr()) }
        };
        assert!(
            !h_module.is_null(),
            "failed to load {}: {}",
            TEST_DLL_NAME_1,
            io::Error::last_os_error()
        );
        assert!(!loaded_dlls.is_empty(), "no DLL loads recorded by DLL watcher");
        assert!(
            loaded_dlls.contains(OsStr::new(TEST_DLL_NAME_1)),
            "{} load wasn't recorded by DLL watcher",
            TEST_DLL_NAME_1
        );
        // FreeLibrary returns nonzero on success.
        // Safe because we initialized h_module with a LoadLibraryA call.
        let success = unsafe { FreeLibrary(h_module) } != FALSE;
        assert!(
            success,
            "failed to free {}: {}",
            TEST_DLL_NAME_1,
            io::Error::last_os_error(),
        )
    }

    #[test]
    fn unload_dll() {
        let mut unloaded_dlls: HashSet<OsString> = HashSet::new();
        // Safe as no pointers are passed. The handle may leak if the test fails.
        let event =
            unsafe { CreateEventA(std::ptr::null_mut(), TRUE, FALSE, std::ptr::null_mut()) };
        assert!(
            !event.is_null(),
            "failed to create event; event was NULL: {}",
            io::Error::last_os_error()
        );
        {
            let test_dll_name = CString::new(TEST_DLL_NAME_2).expect("failed to create CString");
            let _watcher = DllWatcher::new(
                |_data| (),
                |data| {
                    unloaded_dlls.insert(data.base_dll_name);
                    // Safe as we assert that the event is valid above.
                    unsafe { SetEvent(event) };
                },
            )
            .expect("failed to create DllWatcher");
            // Safe because we pass a valid C string in to the function.
            let h_module = unsafe { LoadLibraryA(test_dll_name.as_ptr()) };
            assert!(
                !h_module.is_null(),
                "failed to load {}: {}",
                TEST_DLL_NAME_2,
                io::Error::last_os_error()
            );
            // FreeLibrary returns nonzero on success.
            // Safe because we initialized h_module with a LoadLibraryA call.
            let success = unsafe { FreeLibrary(h_module) } != FALSE;
            assert!(
                success,
                "failed to free {}: {}",
                TEST_DLL_NAME_2,
                io::Error::last_os_error(),
            )
        };
        // Safe as we assert that the event is valid above.
        assert_eq!(unsafe { WaitForSingleObject(event, 5000) }, WAIT_OBJECT_0);
        assert!(!unloaded_dlls.is_empty(), "no DLL unloads recorded by DLL watcher");
        assert!(
            unloaded_dlls.contains(OsStr::new(TEST_DLL_NAME_2)),
            "{} unload wasn't recorded by DLL watcher",
            TEST_DLL_NAME_2
        );
        // Safe as we assert that the event is valid above.
        unsafe { CloseHandle(event) };
    }
}
465