• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::sync::atomic::AtomicU32;
6 use std::sync::atomic::Ordering::SeqCst;
7 
8 use base::info;
9 use base::Event;
10 use libc::c_void;
11 use winapi::shared::guiddef::IsEqualGUID;
12 use winapi::shared::guiddef::REFIID;
13 use winapi::shared::minwindef::ULONG;
14 use winapi::shared::winerror::E_INVALIDARG;
15 use winapi::shared::winerror::E_NOINTERFACE;
16 use winapi::shared::winerror::NOERROR;
17 use winapi::shared::winerror::S_OK;
18 use winapi::um::mmdeviceapi::*;
19 use winapi::um::objidlbase::IAgileObject;
20 use winapi::um::unknwnbase::IUnknown;
21 use winapi::um::unknwnbase::IUnknownVtbl;
22 use winapi::um::winnt::HRESULT;
23 use winapi::Interface;
24 use wio::com::ComPtr;
25 
/// COM object passed to `ActivateAudioInterfaceAsync` as its
/// `IActivateAudioInterfaceCompletionHandler`.
///
/// COM treats a pointer to this struct as a pointer to `IActivateAudioInterfaceCompletionHandler`,
/// which begins with its `lpVtbl` pointer. That works because the struct is `#[repr(C)]` and
/// `lp_vtbl` (a reference to an `IActivateAudioInterfaceCompletionHandlerVtbl`) is the first
/// field, so the field order must not change.
///
/// The fields after `lp_vtbl` are private state appended to that layout:
/// - `ref_count`: the COM reference count maintained by `AddRef`/`Release`.
/// - `activate_audio_interface_complete_event`: signaled when `ActivateCompleted` is invoked.
#[repr(C)]
pub struct WinAudioActivateAudioInterfaceCompletionHandler {
    // Vtable pointer; must stay the first field so the struct matches the COM interface layout.
    pub lp_vtbl: &'static IActivateAudioInterfaceCompletionHandlerVtbl,
    // COM reference count; starts at 1 in `create_com_ptr`, object is freed when it hits 0.
    ref_count: AtomicU32,
    // Signaled from the `ActivateCompleted` callback.
    activate_audio_interface_complete_event: Event,
}
39 
40 impl WinAudioActivateAudioInterfaceCompletionHandler {
41     /// The ComPtr is a `WinAudioActivateAudioInterfaceCompletionHandler` casted as an
42     /// `IActivateAudioInterfaceCompletionHandler`.
create_com_ptr( activate_audio_interface_complete_event: Event, ) -> ComPtr<IActivateAudioInterfaceCompletionHandler>43     pub fn create_com_ptr(
44         activate_audio_interface_complete_event: Event,
45     ) -> ComPtr<IActivateAudioInterfaceCompletionHandler> {
46         let win_completion_handler = Box::new(WinAudioActivateAudioInterfaceCompletionHandler {
47             lp_vtbl: IWIN_AUDIO_COMPLETION_HANDLER_VTBL,
48             ref_count: AtomicU32::new(1),
49             activate_audio_interface_complete_event,
50         });
51 
52         // This is safe if the value passed into `from_raw` is structured in a way where it can
53         // match `IActivateAudioInterfaceCompletionHandler`.
54         // Since `win_completion_handler.cast_to_com_ptr()` does, this is safe.
55         // SAFETY: We are passing in a valid COM object that implements `IUnknown` into
56         // `from_raw`.
57         unsafe {
58             ComPtr::from_raw(Box::into_raw(win_completion_handler)
59                 as *mut IActivateAudioInterfaceCompletionHandler)
60         }
61     }
62 
63     /// Unsafe if `thing` cannot because casted to
64     /// `WinAudioActivateAudioInterfaceCompletionHandler`. This is safe because `thing` is
65     /// originally a `WinAudioActivateAudioInterfaceCompletionHandler.
increment_counter(&self) -> ULONG66     unsafe fn increment_counter(&self) -> ULONG {
67         self.ref_count.fetch_add(1, SeqCst) + 1
68     }
69 
decrement_counter(&mut self) -> ULONG70     fn decrement_counter(&mut self) -> ULONG {
71         let old_val = self.ref_count.fetch_sub(1, SeqCst);
72         if old_val == 0 {
73             panic!("Attempted to decrement WinAudioActivateInterfaceCompletionHandler ref count when it is already 0.");
74         }
75         old_val - 1
76     }
77 
activate_completed(&self)78     fn activate_completed(&self) {
79         info!("Activate Completed handler called from ActiviateAudioInterfaceAsync.");
80         self.activate_audio_interface_complete_event
81             .signal()
82             .expect("Failed to notify audioclientevent");
83     }
84 }
85 
// Logs when the handler is destroyed. `Release` drops the boxed handler once its reference count
// reaches 0.
impl Drop for WinAudioActivateAudioInterfaceCompletionHandler {
    fn drop(&mut self) {
        info!("IActivateAudioInterfaceCompletionHandler is dropped.");
    }
}
91 
92 /// This is the callback when `ActivateAudioInterfaceAsync` is completed. When this is callback is
93 /// triggered, the IAudioClient will be available.
94 /// More info: https://docs.microsoft.com/en-us/windows/win32/api/mmdeviceapi/nf-mmdeviceapi-iactivateaudiointerfacecompletionhandler-activatecompleted
95 ///
96 /// Safe because we are certain that `completion_handler` can be casted to
97 /// `WinAudioActivateAudioInterfaceHandler`, since that is its original type during construction.
activate_completed( completion_handler: *mut IActivateAudioInterfaceCompletionHandler, _activate_operation: *mut IActivateAudioInterfaceAsyncOperation, ) -> HRESULT98 unsafe extern "system" fn activate_completed(
99     completion_handler: *mut IActivateAudioInterfaceCompletionHandler,
100     _activate_operation: *mut IActivateAudioInterfaceAsyncOperation,
101 ) -> HRESULT {
102     let win_audio_activate_interface =
103         completion_handler as *mut WinAudioActivateAudioInterfaceCompletionHandler;
104     (*win_audio_activate_interface).activate_completed();
105 
106     S_OK
107 }
108 
109 const IWIN_AUDIO_COMPLETION_HANDLER_VTBL: &IActivateAudioInterfaceCompletionHandlerVtbl =
110     // Implementation based on
111     // https://docs.microsoft.com/en-us/office/client-developer/outlook/mapi/implementing-iunknown-in-c-plus-plus
112     &IActivateAudioInterfaceCompletionHandlerVtbl {
113             parent: IUnknownVtbl {
114                 QueryInterface: {
115                     /// Safe because if `this` is not implemented (fails the RIID check) this
116                     /// function will just return. If it valid, it should be
117                     /// able to safely increment the ref counter and set the
118                     /// pointer `ppv_object`.
query_interface( this: *mut IUnknown, riid: REFIID, ppv_object: *mut *mut c_void, ) -> HRESULT119                     unsafe extern "system" fn query_interface(
120                         this: *mut IUnknown,
121                         riid: REFIID,
122                         ppv_object: *mut *mut c_void,
123                     ) -> HRESULT {
124                         if ppv_object.is_null() {
125                             return E_INVALIDARG;
126                         }
127 
128                         *ppv_object = std::ptr::null_mut();
129 
130                         // Check for valid RIID's
131                         if IsEqualGUID(&*riid, &IUnknown::uuidof())
132                             || IsEqualGUID(
133                                 &*riid,
134                                 &IActivateAudioInterfaceCompletionHandler::uuidof(),
135                             )
136                             || IsEqualGUID(&*riid, &IAgileObject::uuidof())
137                         {
138                             *ppv_object = this as *mut c_void;
139                             (*this).AddRef();
140                             return NOERROR;
141                         }
142                         E_NOINTERFACE
143                     }
144                     query_interface
145                 },
146                 AddRef: {
147                     /// Unsafe if `this` cannot because casted to
148                     /// `WinAudioActivateAudioInterfaceCompletionHandler`.
149                     ///
150                     /// This is safe because `this` is
151                     /// originally a `WinAudioActivateAudioInterfaceCompletionHandler.
add_ref(this: *mut IUnknown) -> ULONG152                     unsafe extern "system" fn add_ref(this: *mut IUnknown) -> ULONG {
153                         info!("Adding ref in IActivateAudioInterfaceCompletionHandler.");
154                         let win_audio_completion_handler =
155                             this as *mut WinAudioActivateAudioInterfaceCompletionHandler;
156                         (*win_audio_completion_handler).increment_counter()
157                     }
158                     add_ref
159                 },
160                 Release: {
161                     /// Unsafe if `this` cannot because casted to
162                     /// `WinAudioActivateAudioInterfaceCompletionHandler`. Also would be unsafe
163                     /// if `release` is called more than `add_ref`.
164                     ///
165                     /// This is safe because `this` is
166                     /// originally a `WinAudioActivateAudioInterfaceCompletionHandler and isn't
167                     /// called more than `add_ref`.
release(this: *mut IUnknown) -> ULONG168                     unsafe extern "system" fn release(this: *mut IUnknown) -> ULONG {
169                         info!("Releasing ref in IActivateAudioInterfaceCompletionHandler.");
170                         // Decrementing will free the `this` pointer if it's ref_count becomes 0.
171                         let win_audio_completion_handler =
172                             this as *mut WinAudioActivateAudioInterfaceCompletionHandler;
173                         let ref_count = (*win_audio_completion_handler).decrement_counter();
174                         if ref_count == 0 {
175                             // Delete the pointer
176                             drop(Box::from_raw(
177                                 this as *mut WinAudioActivateAudioInterfaceCompletionHandler,
178                             ));
179                         }
180                         ref_count
181                     }
182                     release
183                 },
184             },
185             ActivateCompleted: activate_completed,
186         };
187 
// SAFETY: `ActivateAudioInterfaceAsync` requires its completion handler to implement
// `IAgileObject` (which `QueryInterface` reports), i.e. COM may invoke it from any thread or
// apartment. The handler's mutable state is the `AtomicU32` `ref_count`, which is only modified
// through atomic operations; `lp_vtbl` is an immutable `&'static` reference.
// NOTE(review): soundness also relies on `base::Event::signal` being safe to call from another
// thread; confirm against `base::Event`.
unsafe impl Send for WinAudioActivateAudioInterfaceCompletionHandler {}
// SAFETY: see the `Send` impl above; concurrent shared access only touches the atomic
// `ref_count`, the immutable vtable reference, and the event.
unsafe impl Sync for WinAudioActivateAudioInterfaceCompletionHandler {}
195 
#[cfg(test)]
mod test {
    use base::EventExt;

    use super::*;

    /// Creates a handler whose single reference (ref_count == 1) is owned by the returned `ComPtr`.
    fn new_completion_handler() -> ComPtr<IActivateAudioInterfaceCompletionHandler> {
        WinAudioActivateAudioInterfaceCompletionHandler::create_com_ptr(
            Event::new_auto_reset().unwrap(),
        )
    }

    #[test]
    fn test_query_interface_valid() {
        let completion_handler = new_completion_handler();
        let valid_iids = [
            IUnknown::uuidof(),
            IActivateAudioInterfaceCompletionHandler::uuidof(),
            IAgileObject::uuidof(),
        ];

        for valid_iid in &valid_iids {
            let (res, interface) = query_interface(&completion_handler, valid_iid);
            assert_eq!(res, NOERROR);
            // The handler hands out its own pointer for every supported interface.
            assert_eq!(interface, completion_handler.as_raw() as *mut c_void);
            // Drop the reference `QueryInterface` added so the `ComPtr` holds the last one.
            release(&completion_handler);
        }
    }

    #[test]
    fn test_query_interface_invalid() {
        let completion_handler = new_completion_handler();
        let unsupported_iid = IMMDeviceCollection::uuidof();

        let (res, interface) = query_interface(&completion_handler, &unsupported_iid);
        assert_eq!(res, E_NOINTERFACE);
        // On failure the out-pointer must be cleared.
        assert!(interface.is_null());
    }

    #[test]
    fn test_add_ref() {
        // ref_count = 1
        let completion_handler = new_completion_handler();
        // ref_count = 2
        let ref_count = add_ref(&completion_handler);
        assert_eq!(ref_count, 2);
        // ref_count = 1
        release(&completion_handler);
        // ref_count = 0 since ComPtr drops
    }

    #[test]
    fn test_release() {
        // ref_count = 1
        let completion_handler = new_completion_handler();
        // ref_count = 2
        let ref_count = add_ref(&completion_handler);
        assert_eq!(ref_count, 2);
        // ref_count = 1
        let ref_count = release(&completion_handler);
        assert_eq!(ref_count, 1);
        // ref_count = 0 since ComPtr drops
    }

    /// Calls `QueryInterface` on `completion_handler` for `riid`, returning the `HRESULT` and the
    /// value written to the out-parameter.
    fn query_interface(
        completion_handler: &ComPtr<IActivateAudioInterfaceCompletionHandler>,
        riid: REFIID,
    ) -> (HRESULT, *mut c_void) {
        let mut interface: *mut c_void = std::ptr::null_mut();
        // SAFETY: completion_handler has a valid lpVtbl pointer, `riid` points to a live GUID and
        // `interface` is a valid location for the out-parameter.
        let res = unsafe {
            ((*completion_handler.lpVtbl).parent.QueryInterface)(
                completion_handler.as_raw() as *mut IUnknown,
                riid,
                &mut interface,
            )
        };
        (res, interface)
    }

    fn release(completion_handler: &ComPtr<IActivateAudioInterfaceCompletionHandler>) -> ULONG {
        // SAFETY: completion_handler has a valid lpVtbl pointer
        unsafe {
            ((*completion_handler.lpVtbl).parent.Release)(
                completion_handler.as_raw() as *mut IUnknown
            )
        }
    }

    fn add_ref(completion_handler: &ComPtr<IActivateAudioInterfaceCompletionHandler>) -> ULONG {
        // SAFETY: completion_handler has a valid lpVtbl pointer
        unsafe {
            ((*completion_handler.lpVtbl).parent.AddRef)(
                completion_handler.as_raw() as *mut IUnknown
            )
        }
    }
}
319