1 // Copyright 2022 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use std::ffi::c_void; 6 use std::ffi::OsString; 7 use std::io; 8 use std::ptr; 9 10 use winapi::shared::minwindef::ULONG; 11 use winapi::um::winnt::PVOID; 12 13 use super::unicode_string_to_os_string; 14 15 // Required for Windows API FFI bindings, as the names of the FFI structs and 16 // functions get called out by the linter. 17 #[allow(non_upper_case_globals)] 18 #[allow(non_camel_case_types)] 19 #[allow(non_snake_case)] 20 #[allow(dead_code)] 21 mod dll_notification_sys { 22 use std::io; 23 24 use winapi::shared::minwindef::ULONG; 25 use winapi::shared::ntdef::NTSTATUS; 26 use winapi::shared::ntdef::PCUNICODE_STRING; 27 use winapi::shared::ntstatus::STATUS_SUCCESS; 28 use winapi::um::libloaderapi::GetModuleHandleA; 29 use winapi::um::libloaderapi::GetProcAddress; 30 use winapi::um::winnt::CHAR; 31 use winapi::um::winnt::PVOID; 32 33 #[repr(C)] 34 pub union _LDR_DLL_NOTIFICATION_DATA { 35 pub Loaded: LDR_DLL_LOADED_NOTIFICATION_DATA, 36 pub Unloaded: LDR_DLL_UNLOADED_NOTIFICATION_DATA, 37 } 38 pub type LDR_DLL_NOTIFICATION_DATA = _LDR_DLL_NOTIFICATION_DATA; 39 pub type PLDR_DLL_NOTIFICATION_DATA = *mut LDR_DLL_NOTIFICATION_DATA; 40 41 #[repr(C)] 42 #[derive(Debug, Copy, Clone)] 43 pub struct _LDR_DLL_LOADED_NOTIFICATION_DATA { 44 pub Flags: ULONG, // Reserved. 45 pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module. 46 pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module. 47 pub DllBase: PVOID, // A pointer to the base address for the DLL in memory. 48 pub SizeOfImage: ULONG, // The size of the DLL image, in bytes. 49 } 50 pub type LDR_DLL_LOADED_NOTIFICATION_DATA = _LDR_DLL_LOADED_NOTIFICATION_DATA; 51 pub type PLDR_DLL_LOADED_NOTIFICATION_DATA = *mut LDR_DLL_LOADED_NOTIFICATION_DATA; 52 53 #[repr(C)] 54 #[derive(Debug, Copy, Clone)] 55 pub struct _LDR_DLL_UNLOADED_NOTIFICATION_DATA { 56 pub Flags: ULONG, // Reserved. 57 pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module. 58 pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module. 59 pub DllBase: PVOID, // A pointer to the base address for the DLL in memory. 60 pub SizeOfImage: ULONG, // The size of the DLL image, in bytes. 61 } 62 pub type LDR_DLL_UNLOADED_NOTIFICATION_DATA = _LDR_DLL_UNLOADED_NOTIFICATION_DATA; 63 pub type PLDR_DLL_UNLOADED_NOTIFICATION_DATA = *mut LDR_DLL_UNLOADED_NOTIFICATION_DATA; 64 65 pub const LDR_DLL_NOTIFICATION_REASON_LOADED: ULONG = 1; 66 pub const LDR_DLL_NOTIFICATION_REASON_UNLOADED: ULONG = 2; 67 68 const NTDLL: &'static [u8] = b"ntdll\0"; 69 const LDR_REGISTER_DLL_NOTIFICATION: &'static [u8] = b"LdrRegisterDllNotification\0"; 70 const LDR_UNREGISTER_DLL_NOTIFICATION: &'static [u8] = b"LdrUnregisterDllNotification\0"; 71 72 pub type LdrDllNotification = unsafe extern "C" fn( 73 NotificationReason: ULONG, 74 NotificationData: PLDR_DLL_NOTIFICATION_DATA, 75 Context: PVOID, 76 ); 77 78 pub type FnLdrRegisterDllNotification = 79 unsafe extern "C" fn(ULONG, LdrDllNotification, PVOID, *mut PVOID) -> NTSTATUS; 80 pub type FnLdrUnregisterDllNotification = unsafe extern "C" fn(PVOID) -> NTSTATUS; 81 82 extern "C" { RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG83 pub fn RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG; 84 } 85 86 /// Wrapper for the NTDLL `LdrRegisterDllNotification` function. Dynamically 87 /// gets the address of the function and invokes the function with the given 88 /// arguments. 89 /// 90 /// # Safety 91 /// Unsafe as this function does not verify its arguments; the caller is 92 /// expected to verify the safety as if invoking the underlying C function. LdrRegisterDllNotification( Flags: ULONG, NotificationFunction: LdrDllNotification, Context: PVOID, Cookie: *mut PVOID, ) -> io::Result<()>93 pub unsafe fn LdrRegisterDllNotification( 94 Flags: ULONG, 95 NotificationFunction: LdrDllNotification, 96 Context: PVOID, 97 Cookie: *mut PVOID, 98 ) -> io::Result<()> { 99 let proc_addr = GetProcAddress( 100 /* hModule= */ 101 GetModuleHandleA( 102 /* lpModuleName= */ NTDLL.as_ptr() as *const u8 as *const CHAR, 103 ), 104 /* lpProcName= */ 105 LDR_REGISTER_DLL_NOTIFICATION.as_ptr() as *const u8 as *const CHAR, 106 ); 107 if proc_addr.is_null() { 108 return Err(std::io::Error::last_os_error()); 109 } 110 let ldr_register_dll_notification: FnLdrRegisterDllNotification = 111 std::mem::transmute(proc_addr); 112 let ret = ldr_register_dll_notification(Flags, NotificationFunction, Context, Cookie); 113 if ret != STATUS_SUCCESS { 114 return Err(io::Error::from_raw_os_error( 115 RtlNtStatusToDosError(/* Status= */ ret) as i32, 116 )); 117 }; 118 Ok(()) 119 } 120 121 /// Wrapper for the NTDLL `LdrUnregisterDllNotification` function. Dynamically 122 /// gets the address of the function and invokes the function with the given 123 /// arguments. 124 /// 125 /// # Safety 126 /// Unsafe as this function does not verify its arguments; the caller is 127 /// expected to verify the safety as if invoking the underlying C function. LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()>128 pub unsafe fn LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()> { 129 let proc_addr = GetProcAddress( 130 /* hModule= */ 131 GetModuleHandleA( 132 /* lpModuleName= */ NTDLL.as_ptr() as *const u8 as *const CHAR, 133 ), 134 /* lpProcName= */ 135 LDR_UNREGISTER_DLL_NOTIFICATION.as_ptr() as *const u8 as *const CHAR, 136 ); 137 if proc_addr.is_null() { 138 return Err(std::io::Error::last_os_error()); 139 } 140 let ldr_unregister_dll_notification: FnLdrUnregisterDllNotification = 141 std::mem::transmute(proc_addr); 142 let ret = ldr_unregister_dll_notification(Cookie); 143 if ret != STATUS_SUCCESS { 144 return Err(io::Error::from_raw_os_error( 145 RtlNtStatusToDosError(/* Status= */ ret) as i32, 146 )); 147 }; 148 Ok(()) 149 } 150 } 151 152 use dll_notification_sys::*; 153 154 #[derive(Debug)] 155 pub struct DllNotificationData { 156 pub full_dll_name: OsString, 157 pub base_dll_name: OsString, 158 } 159 160 /// Callback context wrapper for DLL load notification functions. 161 /// 162 /// This struct provides a wrapper for invoking a function-like type any time a 163 /// DLL is loaded in the current process. This is done in a type-safe way, 164 /// provided that users of this struct observe some safety invariants. 165 /// 166 /// # Safety 167 /// The struct instance must not be used once it has been registered as a 168 /// notification target. The callback function assumes that it has a mutable 169 /// reference to the struct instance. Only once the callback is unregistered is 170 /// it safe to re-use the struct instance. 171 struct CallbackContext<F1, F2> 172 where 173 F1: FnMut(DllNotificationData), 174 F2: FnMut(DllNotificationData), 175 { 176 loaded_callback: F1, 177 unloaded_callback: F2, 178 } 179 180 impl<F1, F2> CallbackContext<F1, F2> 181 where 182 F1: FnMut(DllNotificationData), 183 F2: FnMut(DllNotificationData), 184 { 185 /// Create a new `CallbackContext` with the two callback functions. Takes 186 /// two callbacks, a `loaded_callback` which is called when a DLL is 187 /// loaded, and `unloaded_callback` which is called when a DLL is unloaded. new(loaded_callback: F1, unloaded_callback: F2) -> Self188 pub fn new(loaded_callback: F1, unloaded_callback: F2) -> Self { 189 CallbackContext { 190 loaded_callback, 191 unloaded_callback, 192 } 193 } 194 195 /// Provides a notification function that can be passed to the 196 /// `LdrRegisterDllNotification` function. get_notification_function(&self) -> LdrDllNotification197 pub fn get_notification_function(&self) -> LdrDllNotification { 198 Self::notification_function 199 } 200 201 /// A notification function with C linkage. This function assumes that it 202 /// has exclusive access to the instance of the struct passed through the 203 /// `context` parameter. notification_function( notification_reason: ULONG, notification_data: PLDR_DLL_NOTIFICATION_DATA, context: PVOID, )204 extern "C" fn notification_function( 205 notification_reason: ULONG, 206 notification_data: PLDR_DLL_NOTIFICATION_DATA, 207 context: PVOID, 208 ) { 209 // Safe because the DLLWatcher guarantees that the CallbackContext 210 // instance is not null and that we have exclusive access to it. 211 let callback_context = 212 unsafe { (context as *mut Self).as_mut() }.expect("context was null"); 213 214 assert!(!notification_data.is_null()); 215 216 match notification_reason { 217 LDR_DLL_NOTIFICATION_REASON_LOADED => { 218 // Safe because we know that the LDR_DLL_NOTIFICATION_DATA union 219 // contains the LDR_DLL_LOADED_NOTIFICATION_DATA because we got 220 // LDR_DLL_NOTIFICATION_REASON_LOADED as the notification 221 // reason. 222 let loaded = unsafe { &mut (*notification_data).Loaded }; 223 224 assert!(!loaded.BaseDllName.is_null()); 225 226 // Safe because we assert that the pointer is not null and 227 // expect that the OS has provided a valid UNICODE_STRING 228 // struct. 229 let base_dll_name = unsafe { unicode_string_to_os_string(&*loaded.BaseDllName) }; 230 231 assert!(!loaded.FullDllName.is_null()); 232 233 // Safe because we assert that the pointer is not null and 234 // expect that the OS has provided a valid UNICODE_STRING 235 // struct. 236 let full_dll_name = unsafe { unicode_string_to_os_string(&*loaded.FullDllName) }; 237 238 (callback_context.loaded_callback)(DllNotificationData { 239 base_dll_name, 240 full_dll_name, 241 }); 242 } 243 LDR_DLL_NOTIFICATION_REASON_UNLOADED => { 244 // Safe because we know that the LDR_DLL_NOTIFICATION_DATA union 245 // contains the LDR_DLL_UNLOADED_NOTIFICATION_DATA because we got 246 // LDR_DLL_NOTIFICATION_REASON_UNLOADED as the notification 247 // reason. 248 let unloaded = unsafe { &mut (*notification_data).Unloaded }; 249 250 assert!(!unloaded.BaseDllName.is_null()); 251 252 // Safe because we assert that the pointer is not null and 253 // expect that the OS has provided a valid UNICODE_STRING 254 // struct. 255 let base_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.BaseDllName) }; 256 257 assert!(!unloaded.FullDllName.is_null()); 258 259 // Safe because we assert that the pointer is not null and 260 // expect that the OS has provided a valid UNICODE_STRING 261 // struct. 262 let full_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.FullDllName) }; 263 264 (callback_context.unloaded_callback)(DllNotificationData { 265 base_dll_name, 266 full_dll_name, 267 }) 268 } 269 n => panic!("invalid value \"{}\" for dll notification reason", n), 270 } 271 } 272 } 273 274 /// DLL watcher for monitoring DLL loads/unloads. 275 /// 276 /// Provides a method to invoke a function-like type any time a DLL 277 /// is loaded or unloaded in the current process. 278 pub struct DllWatcher<F1, F2> 279 where 280 F1: FnMut(DllNotificationData), 281 F2: FnMut(DllNotificationData), 282 { 283 context: Box<CallbackContext<F1, F2>>, 284 cookie: Option<ptr::NonNull<c_void>>, 285 } 286 287 impl<F1, F2> DllWatcher<F1, F2> 288 where 289 F1: FnMut(DllNotificationData), 290 F2: FnMut(DllNotificationData), 291 { 292 /// Create a new `DllWatcher` with the two callback functions. Takes two 293 /// callbacks, a `loaded_callback` which is called when a DLL is loaded, 294 /// and `unloaded_callback` which is called when a DLL is unloaded. new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self>295 pub fn new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self> { 296 let mut watcher = Self { 297 context: Box::new(CallbackContext::new(loaded_callback, unloaded_callback)), 298 cookie: None, 299 }; 300 let mut cookie: PVOID = ptr::null_mut(); 301 // Safe because we guarantee that the notification function that we 302 // register will have exclusive access to the context. 303 unsafe { 304 LdrRegisterDllNotification( 305 /* Flags= */ 0, 306 /* NotificationFunction= */ watcher.context.get_notification_function(), 307 /* Context= */ 308 &mut *watcher.context as *mut CallbackContext<F1, F2> as PVOID, 309 /* Cookie= */ &mut cookie as *mut PVOID, 310 )? 311 }; 312 watcher.cookie = ptr::NonNull::new(cookie); 313 Ok(watcher) 314 } 315 unregister_dll_notification(&mut self) -> io::Result<()>316 fn unregister_dll_notification(&mut self) -> io::Result<()> { 317 match self.cookie { 318 Some(c) => { 319 // Safe because we guarantee that `Cookie` was previously initialized. 320 unsafe { 321 LdrUnregisterDllNotification(/* Cookie= */ c.as_ptr() as PVOID)? 322 } 323 self.cookie = None; 324 } 325 None => {} 326 } 327 Ok(()) 328 } 329 } 330 331 impl<F1, F2> Drop for DllWatcher<F1, F2> 332 where 333 F1: FnMut(DllNotificationData), 334 F2: FnMut(DllNotificationData), 335 { drop(&mut self)336 fn drop(&mut self) { 337 self.unregister_dll_notification() 338 .expect("error unregistering dll notification"); 339 } 340 } 341 342 #[cfg(test)] 343 mod tests { 344 use std::collections::HashSet; 345 use std::ffi::CString; 346 use std::io; 347 348 use winapi::shared::minwindef::FALSE; 349 use winapi::shared::minwindef::TRUE; 350 use winapi::um::handleapi::CloseHandle; 351 use winapi::um::libloaderapi::FreeLibrary; 352 use winapi::um::libloaderapi::LoadLibraryA; 353 use winapi::um::synchapi::CreateEventA; 354 use winapi::um::synchapi::SetEvent; 355 use winapi::um::synchapi::WaitForSingleObject; 356 use winapi::um::winbase::WAIT_OBJECT_0; 357 358 use super::*; 359 360 // Arbitrarily chosen DLLs for load/unload test. Chosen because they're 361 // hopefully esoteric enough that they're probably not already loaded in 362 // the process so we can test load/unload notifications. 363 // 364 // Using a single DLL can lead to flakiness; since the tests are run in the 365 // same process, it can be hard to rely on the OS to clean up the DLL loaded 366 // by one test before the other test runs. Using a different DLL makes the 367 // tests more independent. 368 const TEST_DLL_NAME_1: &'static str = "Imagehlp.dll"; 369 const TEST_DLL_NAME_2: &'static str = "dbghelp.dll"; 370 371 #[test] load_dll()372 fn load_dll() { 373 let test_dll_name = CString::new(TEST_DLL_NAME_1).expect("failed to create CString"); 374 let mut loaded_dlls: HashSet<OsString> = HashSet::new(); 375 let h_module = { 376 let _watcher = DllWatcher::new( 377 |data| { 378 loaded_dlls.insert(data.base_dll_name); 379 }, 380 |_data| (), 381 ) 382 .expect("failed to create DllWatcher"); 383 // Safe because we pass a valid C string in to the function. 384 unsafe { LoadLibraryA(test_dll_name.as_ptr()) } 385 }; 386 assert!( 387 !h_module.is_null(), 388 "failed to load {}: {}", 389 TEST_DLL_NAME_1, 390 io::Error::last_os_error() 391 ); 392 assert!( 393 loaded_dlls.len() >= 1, 394 "no DLL loads recorded by DLL watcher" 395 ); 396 assert!( 397 loaded_dlls.contains::<OsString>(&(TEST_DLL_NAME_1.to_owned().into())), 398 "{} load wasn't recorded by DLL watcher", 399 TEST_DLL_NAME_1 400 ); 401 // Safe because we initialized h_module with a LoadLibraryA call. 402 let success = unsafe { FreeLibrary(h_module) } > 0; 403 assert!( 404 success, 405 "failed to free {}: {}", 406 TEST_DLL_NAME_1, 407 io::Error::last_os_error(), 408 ) 409 } 410 411 #[test] unload_dll()412 fn unload_dll() { 413 let mut unloaded_dlls: HashSet<OsString> = HashSet::new(); 414 // Safe as no pointers are passed. The handle may leak if the test fails. 415 let event = 416 unsafe { CreateEventA(std::ptr::null_mut(), TRUE, FALSE, std::ptr::null_mut()) }; 417 assert!( 418 !event.is_null(), 419 "failed to create event; event was NULL: {}", 420 io::Error::last_os_error() 421 ); 422 { 423 let test_dll_name = CString::new(TEST_DLL_NAME_2).expect("failed to create CString"); 424 let _watcher = DllWatcher::new( 425 |_data| (), 426 |data| { 427 unloaded_dlls.insert(data.base_dll_name); 428 // Safe as we assert that the event is valid above. 429 unsafe { SetEvent(event) }; 430 }, 431 ) 432 .expect("failed to create DllWatcher"); 433 // Safe because we pass a valid C string in to the function. 434 let h_module = unsafe { LoadLibraryA(test_dll_name.as_ptr()) }; 435 assert!( 436 !h_module.is_null(), 437 "failed to load {}: {}", 438 TEST_DLL_NAME_2, 439 io::Error::last_os_error() 440 ); 441 // Safe because we initialized h_module with a LoadLibraryA call. 442 let success = unsafe { FreeLibrary(h_module) } > 0; 443 assert!( 444 success, 445 "failed to free {}: {}", 446 TEST_DLL_NAME_2, 447 io::Error::last_os_error(), 448 ) 449 }; 450 // Safe as we assert that the event is valid above. 451 assert_eq!(unsafe { WaitForSingleObject(event, 5000) }, WAIT_OBJECT_0); 452 assert!( 453 unloaded_dlls.len() >= 1, 454 "no DLL unloads recorded by DLL watcher" 455 ); 456 assert!( 457 unloaded_dlls.contains::<OsString>(&(TEST_DLL_NAME_2.to_owned().into())), 458 "{} unload wasn't recorded by DLL watcher", 459 TEST_DLL_NAME_2 460 ); 461 // Safe as we assert that the event is valid above. 462 unsafe { CloseHandle(event) }; 463 } 464 } 465