1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::sync::atomic::AtomicU32;
6 use std::sync::atomic::Ordering::SeqCst;
7
8 use base::info;
9 use base::Event;
10 use libc::c_void;
11 use winapi::shared::guiddef::IsEqualGUID;
12 use winapi::shared::guiddef::REFIID;
13 use winapi::shared::minwindef::ULONG;
14 use winapi::shared::winerror::E_INVALIDARG;
15 use winapi::shared::winerror::E_NOINTERFACE;
16 use winapi::shared::winerror::NOERROR;
17 use winapi::shared::winerror::S_OK;
18 use winapi::um::mmdeviceapi::*;
19 use winapi::um::objidlbase::IAgileObject;
20 use winapi::um::unknwnbase::IUnknown;
21 use winapi::um::unknwnbase::IUnknownVtbl;
22 use winapi::um::winnt::HRESULT;
23 use winapi::Interface;
24 use wio::com::ComPtr;
25
/// This struct is used to create the completion handler `IActivateAudioInterfaceCompletionHandler`
/// that is passed into `ActivateAudioInterfaceAsync`. A pointer to this struct is handed out as an
/// `IActivateAudioInterfaceCompletionHandler` COM object, so the first field must be the pointer to
/// the `IActivateAudioInterfaceCompletionHandlerVtbl` (`lp_vtbl`).
///
/// The struct therefore starts with the same leading vtable pointer as
/// `IActivateAudioInterfaceCompletionHandler`. Two extra fields follow that the COM interface does
/// not see:
/// - `ref_count`, the COM reference count for the handler.
/// - the event that is signaled when activation completes.
#[repr(C)]
pub struct WinAudioActivateAudioInterfaceCompletionHandler {
    /// Pointer to the COM vtable. Must remain the first field for the COM layout to hold.
    pub lp_vtbl: &'static IActivateAudioInterfaceCompletionHandlerVtbl,
    /// COM reference count; `create_com_ptr` initializes it to 1.
    ref_count: AtomicU32,
    /// Signaled from `ActivateCompleted` once `ActivateAudioInterfaceAsync` finishes.
    activate_audio_interface_complete_event: Event,
}
39
40 impl WinAudioActivateAudioInterfaceCompletionHandler {
41 /// The ComPtr is a `WinAudioActivateAudioInterfaceCompletionHandler` casted as an
42 /// `IActivateAudioInterfaceCompletionHandler`.
create_com_ptr( activate_audio_interface_complete_event: Event, ) -> ComPtr<IActivateAudioInterfaceCompletionHandler>43 pub fn create_com_ptr(
44 activate_audio_interface_complete_event: Event,
45 ) -> ComPtr<IActivateAudioInterfaceCompletionHandler> {
46 let win_completion_handler = Box::new(WinAudioActivateAudioInterfaceCompletionHandler {
47 lp_vtbl: IWIN_AUDIO_COMPLETION_HANDLER_VTBL,
48 ref_count: AtomicU32::new(1),
49 activate_audio_interface_complete_event,
50 });
51
52 // This is safe if the value passed into `from_raw` is structured in a way where it can
53 // match `IActivateAudioInterfaceCompletionHandler`.
54 // Since `win_completion_handler.cast_to_com_ptr()` does, this is safe.
55 // SAFETY: We are passing in a valid COM object that implements `IUnknown` into
56 // `from_raw`.
57 unsafe {
58 ComPtr::from_raw(Box::into_raw(win_completion_handler)
59 as *mut IActivateAudioInterfaceCompletionHandler)
60 }
61 }
62
63 /// Unsafe if `thing` cannot because casted to
64 /// `WinAudioActivateAudioInterfaceCompletionHandler`. This is safe because `thing` is
65 /// originally a `WinAudioActivateAudioInterfaceCompletionHandler.
increment_counter(&self) -> ULONG66 unsafe fn increment_counter(&self) -> ULONG {
67 self.ref_count.fetch_add(1, SeqCst) + 1
68 }
69
decrement_counter(&mut self) -> ULONG70 fn decrement_counter(&mut self) -> ULONG {
71 let old_val = self.ref_count.fetch_sub(1, SeqCst);
72 if old_val == 0 {
73 panic!("Attempted to decrement WinAudioActivateInterfaceCompletionHandler ref count when it is already 0.");
74 }
75 old_val - 1
76 }
77
activate_completed(&self)78 fn activate_completed(&self) {
79 info!("Activate Completed handler called from ActiviateAudioInterfaceAsync.");
80 self.activate_audio_interface_complete_event
81 .signal()
82 .expect("Failed to notify audioclientevent");
83 }
84 }
85
86 impl Drop for WinAudioActivateAudioInterfaceCompletionHandler {
drop(&mut self)87 fn drop(&mut self) {
88 info!("IActivateAudioInterfaceCompletionHandler is dropped.");
89 }
90 }
91
92 /// This is the callback when `ActivateAudioInterfaceAsync` is completed. When this is callback is
93 /// triggered, the IAudioClient will be available.
94 /// More info: https://docs.microsoft.com/en-us/windows/win32/api/mmdeviceapi/nf-mmdeviceapi-iactivateaudiointerfacecompletionhandler-activatecompleted
95 ///
96 /// Safe because we are certain that `completion_handler` can be casted to
97 /// `WinAudioActivateAudioInterfaceHandler`, since that is its original type during construction.
activate_completed( completion_handler: *mut IActivateAudioInterfaceCompletionHandler, _activate_operation: *mut IActivateAudioInterfaceAsyncOperation, ) -> HRESULT98 unsafe extern "system" fn activate_completed(
99 completion_handler: *mut IActivateAudioInterfaceCompletionHandler,
100 _activate_operation: *mut IActivateAudioInterfaceAsyncOperation,
101 ) -> HRESULT {
102 let win_audio_activate_interface =
103 completion_handler as *mut WinAudioActivateAudioInterfaceCompletionHandler;
104 (*win_audio_activate_interface).activate_completed();
105
106 S_OK
107 }
108
109 const IWIN_AUDIO_COMPLETION_HANDLER_VTBL: &IActivateAudioInterfaceCompletionHandlerVtbl =
110 // Implementation based on
111 // https://docs.microsoft.com/en-us/office/client-developer/outlook/mapi/implementing-iunknown-in-c-plus-plus
112 &IActivateAudioInterfaceCompletionHandlerVtbl {
113 parent: IUnknownVtbl {
114 QueryInterface: {
115 /// Safe because if `this` is not implemented (fails the RIID check) this
116 /// function will just return. If it valid, it should be
117 /// able to safely increment the ref counter and set the
118 /// pointer `ppv_object`.
query_interface( this: *mut IUnknown, riid: REFIID, ppv_object: *mut *mut c_void, ) -> HRESULT119 unsafe extern "system" fn query_interface(
120 this: *mut IUnknown,
121 riid: REFIID,
122 ppv_object: *mut *mut c_void,
123 ) -> HRESULT {
124 if ppv_object.is_null() {
125 return E_INVALIDARG;
126 }
127
128 *ppv_object = std::ptr::null_mut();
129
130 // Check for valid RIID's
131 if IsEqualGUID(&*riid, &IUnknown::uuidof())
132 || IsEqualGUID(
133 &*riid,
134 &IActivateAudioInterfaceCompletionHandler::uuidof(),
135 )
136 || IsEqualGUID(&*riid, &IAgileObject::uuidof())
137 {
138 *ppv_object = this as *mut c_void;
139 (*this).AddRef();
140 return NOERROR;
141 }
142 E_NOINTERFACE
143 }
144 query_interface
145 },
146 AddRef: {
147 /// Unsafe if `this` cannot because casted to
148 /// `WinAudioActivateAudioInterfaceCompletionHandler`.
149 ///
150 /// This is safe because `this` is
151 /// originally a `WinAudioActivateAudioInterfaceCompletionHandler.
add_ref(this: *mut IUnknown) -> ULONG152 unsafe extern "system" fn add_ref(this: *mut IUnknown) -> ULONG {
153 info!("Adding ref in IActivateAudioInterfaceCompletionHandler.");
154 let win_audio_completion_handler =
155 this as *mut WinAudioActivateAudioInterfaceCompletionHandler;
156 (*win_audio_completion_handler).increment_counter()
157 }
158 add_ref
159 },
160 Release: {
161 /// Unsafe if `this` cannot because casted to
162 /// `WinAudioActivateAudioInterfaceCompletionHandler`. Also would be unsafe
163 /// if `release` is called more than `add_ref`.
164 ///
165 /// This is safe because `this` is
166 /// originally a `WinAudioActivateAudioInterfaceCompletionHandler and isn't
167 /// called more than `add_ref`.
release(this: *mut IUnknown) -> ULONG168 unsafe extern "system" fn release(this: *mut IUnknown) -> ULONG {
169 info!("Releasing ref in IActivateAudioInterfaceCompletionHandler.");
170 // Decrementing will free the `this` pointer if it's ref_count becomes 0.
171 let win_audio_completion_handler =
172 this as *mut WinAudioActivateAudioInterfaceCompletionHandler;
173 let ref_count = (*win_audio_completion_handler).decrement_counter();
174 if ref_count == 0 {
175 // Delete the pointer
176 drop(Box::from_raw(
177 this as *mut WinAudioActivateAudioInterfaceCompletionHandler,
178 ));
179 }
180 ref_count
181 }
182 release
183 },
184 },
185 ActivateCompleted: activate_completed,
186 };
187
// SAFETY:
// `ActivateAudioInterfaceAsync` requires the completion handler
// (`IActivateAudioInterfaceCompletionHandler`) to implement `IAgileObject`. That means the handler
// is free-threaded and its methods may be called from any apartment/thread. These impls allow the
// handler to be moved to and shared with those threads.
// The reference count is an `AtomicU32`, and `lp_vtbl` is an immutable `'static` reference.
// NOTE(review): this also assumes signaling `Event` through `&self` is thread-safe — confirm.
unsafe impl Send for WinAudioActivateAudioInterfaceCompletionHandler {}
// SAFETY: see above
unsafe impl Sync for WinAudioActivateAudioInterfaceCompletionHandler {}
195
#[cfg(test)]
mod test {
    use base::EventExt;

    use super::*;

    /// Creates a completion handler whose single reference is owned by the returned `ComPtr`.
    fn new_completion_handler() -> ComPtr<IActivateAudioInterfaceCompletionHandler> {
        WinAudioActivateAudioInterfaceCompletionHandler::create_com_ptr(
            Event::new_auto_reset().unwrap(),
        )
    }

    /// Calls `QueryInterface` for `riid`. Returns the HRESULT together with the value written to
    /// the out-parameter. The out-parameter starts as a non-null sentinel, so tests can tell
    /// whether `QueryInterface` overwrote it.
    fn query_interface(
        completion_handler: &ComPtr<IActivateAudioInterfaceCompletionHandler>,
        riid: REFIID,
    ) -> (HRESULT, *mut c_void) {
        let mut object: *mut c_void = std::ptr::NonNull::dangling().as_ptr();
        // SAFETY: completion_handler has a valid lpVtbl pointer, `riid` points to a GUID that
        // outlives the call, and `object` is a valid out-parameter.
        let res = unsafe {
            ((*completion_handler.lpVtbl).parent.QueryInterface)(
                completion_handler.as_raw() as *mut IUnknown,
                riid,
                &mut object,
            )
        };
        (res, object)
    }

    #[test]
    fn test_query_interface_valid() {
        let completion_handler = new_completion_handler();
        let handler_ptr = completion_handler.as_raw() as *mut c_void;

        for ref_iid in [
            IUnknown::uuidof(),
            IActivateAudioInterfaceCompletionHandler::uuidof(),
            IAgileObject::uuidof(),
        ] {
            let (res, object) = query_interface(&completion_handler, &ref_iid);
            assert_eq!(res, NOERROR);
            // A successful query hands back the handler itself.
            assert_eq!(object, handler_ptr);
            // Release the reference taken by `QueryInterface`.
            release(&completion_handler);
        }
    }

    #[test]
    fn test_query_interface_invalid() {
        let completion_handler = new_completion_handler();

        let (res, object) =
            query_interface(&completion_handler, &IMMDeviceCollection::uuidof());
        assert_eq!(res, E_NOINTERFACE);
        // A failed query must null the out-parameter.
        assert!(object.is_null());
    }

    #[test]
    fn test_add_ref() {
        // ref_count = 1
        let completion_handler = new_completion_handler();
        // ref_count = 2
        let ref_count = add_ref(&completion_handler);
        assert_eq!(ref_count, 2);
        // ref_count = 1
        release(&completion_handler);
        // ref_count = 0 since ComPtr drops
    }

    #[test]
    fn test_release() {
        // ref_count = 1
        let completion_handler = new_completion_handler();
        // ref_count = 2
        let ref_count = add_ref(&completion_handler);
        assert_eq!(ref_count, 2);
        // ref_count = 1
        let ref_count = release(&completion_handler);
        assert_eq!(ref_count, 1);
        // ref_count = 0 since ComPtr drops
    }

    /// Calls the vtable's `Release` directly and returns the remaining reference count.
    fn release(completion_handler: &ComPtr<IActivateAudioInterfaceCompletionHandler>) -> ULONG {
        // SAFETY: completion_handler has a valid lpVtbl pointer
        unsafe {
            ((*completion_handler.lpVtbl).parent.Release)(
                completion_handler.as_raw() as *mut IUnknown
            )
        }
    }

    /// Calls the vtable's `AddRef` directly and returns the new reference count.
    fn add_ref(completion_handler: &ComPtr<IActivateAudioInterfaceCompletionHandler>) -> ULONG {
        // SAFETY: completion_handler has a valid lpVtbl pointer
        unsafe {
            ((*completion_handler.lpVtbl).parent.AddRef)(
                completion_handler.as_raw() as *mut IUnknown
            )
        }
    }
}
319