1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 //! IO completion port wrapper.
6
7 use std::collections::VecDeque;
8 use std::io;
9 use std::ptr::null_mut;
10 use std::sync::Arc;
11 use std::sync::Condvar;
12 use std::time::Duration;
13
14 use base::error;
15 use base::info;
16 use base::AsRawDescriptor;
17 use base::Error as SysError;
18 use base::Event;
19 use base::EventWaitResult;
20 use base::FromRawDescriptor;
21 use base::RawDescriptor;
22 use base::SafeDescriptor;
23 use base::WorkerThread;
24 use smallvec::smallvec;
25 use smallvec::SmallVec;
26 use sync::Mutex;
27 use winapi::shared::minwindef::BOOL;
28 use winapi::shared::minwindef::DWORD;
29 use winapi::shared::minwindef::ULONG;
30 use winapi::um::handleapi::INVALID_HANDLE_VALUE;
31 use winapi::um::ioapiset::CreateIoCompletionPort;
32 use winapi::um::ioapiset::GetOverlappedResult;
33 use winapi::um::ioapiset::GetQueuedCompletionStatus;
34 use winapi::um::ioapiset::GetQueuedCompletionStatusEx;
35 use winapi::um::ioapiset::PostQueuedCompletionStatus;
36 use winapi::um::minwinbase::LPOVERLAPPED_ENTRY;
37 use winapi::um::minwinbase::OVERLAPPED;
38 use winapi::um::minwinbase::OVERLAPPED_ENTRY;
39 use winapi::um::winbase::INFINITE;
40
41 use super::handle_executor::Error;
42 use super::handle_executor::Result;
43
/// The number of IOCP packets we accept per poll operation.
///
/// It is the upper bound on how many packets a single poll returns. It also sizes the inline
/// storage of the `SmallVec`s that hold completion packets and `OVERLAPPED_ENTRY`s.
/// Because this is only used for SmallVec sizes, clippy thinks it is unused.
/// NOTE(review): the constant is now also used at runtime (the drain loop in `poll` and the
/// wake-up count in `IoCompletionPort`'s `Drop`), so this `allow` may no longer be needed —
/// confirm before removing.
#[allow(dead_code)]
const ENTRIES_PER_POLL: usize = 16;
48
/// A minimal version of completion packets from an IoCompletionPort.
pub(crate) struct CompletionPacket {
    /// Completion key of the packet.
    ///
    /// For handles registered via `register_descriptor`, this is the handle value. For packets
    /// posted via `post_status`, it is the key passed there (`INVALID_HANDLE_VALUE` for `wake`).
    pub completion_key: usize,
    /// Address of the `OVERLAPPED` structure of the completed operation, stored as an integer.
    ///
    /// It is 0 for packets posted via `post_status`, which passes a null `OVERLAPPED`.
    pub overlapped_ptr: usize,
    /// Number of bytes transferred on success, or the error reported for the overlapped
    /// operation.
    pub result: std::result::Result<usize, SysError>,
}
55
/// A raw completion-port handle wrapped so that it can be moved into a worker thread.
///
/// This is only a copy of the handle value. No `Drop` impl is defined here, so dropping a
/// `Port` does not close the handle; the owning `SafeDescriptor` lives in `IoCompletionPort`.
struct Port {
    inner: RawDescriptor,
}

// SAFETY:
// Safe because the Port is dropped before IoCompletionPort goes out of scope. The worker threads
// that hold a Port are signalled and released in IoCompletionPort's Drop, so the raw handle stays
// valid while it is used from another thread.
unsafe impl Send for Port {}
63
/// Wraps an IO Completion Port (IOCP). These ports are very similar to an epoll
/// context on unix. Handles (equivalent to FDs) we want to wait on for
/// readiness are added to the port, and then the port can be waited on using a
/// syscall (GetQueuedCompletionStatus). IOCP is a little more flexible than
/// epoll because custom messages can be enqueued and received from the port
/// just like if a handle became ready (see [IoCompletionPort::post_status]).
///
/// Note that a handle can only be associated with (subscribed to) a completion port. It can
/// never be disassociated again. Handles are removed from the port automatically when they are
/// closed.
///
/// Registered handles have their completion key set to their handle number.
///
/// With a concurrency greater than 1 the port is "threaded". Worker threads dequeue packets from
/// the port and push them onto `completed`, and [IoCompletionPort::poll] reads from there instead
/// of from the port.
pub(crate) struct IoCompletionPort {
    /// Owned handle of the underlying completion port.
    port: SafeDescriptor,
    /// Worker threads draining the port. Empty unless threaded (concurrency > 1).
    threads: Vec<WorkerThread<Result<()>>>,
    /// Packets forwarded by the worker threads, plus a condvar notified when packets are added.
    completed: Arc<(Mutex<VecDeque<CompletionPacket>>, Condvar)>,
    /// Value passed to CreateIoCompletionPort. Values greater than 1 enable the worker threads.
    concurrency: u32,
}
81
82 /// Gets a completion packet from the completion port. If the underlying IO operation
83 /// encountered an error, it will be contained inside the completion packet. If this method
84 /// encountered an error getting a completion packet, the error will be returned directly.
85 /// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.
86 #[deny(unsafe_op_in_unsafe_fn)]
get_completion_status( handle: RawDescriptor, timeout: DWORD, ) -> io::Result<CompletionPacket>87 unsafe fn get_completion_status(
88 handle: RawDescriptor,
89 timeout: DWORD,
90 ) -> io::Result<CompletionPacket> {
91 let mut bytes_transferred = 0;
92 let mut completion_key = 0;
93 // SAFETY: trivially safe
94 let mut overlapped: *mut OVERLAPPED = unsafe { std::mem::zeroed() };
95
96 // SAFETY:
97 // Safe because:
98 // 1. Memory of pointers passed is stack allocated and lives as long as the syscall.
99 // 2. We check the error so we don't use invalid output values (e.g. overlapped).
100 let success = unsafe {
101 GetQueuedCompletionStatus(
102 handle,
103 &mut bytes_transferred,
104 &mut completion_key,
105 &mut overlapped as *mut *mut OVERLAPPED,
106 timeout,
107 )
108 } != 0;
109
110 if success {
111 return Ok(CompletionPacket {
112 result: Ok(bytes_transferred as usize),
113 completion_key,
114 overlapped_ptr: overlapped as usize,
115 });
116 }
117
118 // Did the IOCP operation fail, or did the overlapped operation fail?
119 if overlapped.is_null() {
120 // IOCP failed somehow.
121 Err(io::Error::last_os_error())
122 } else {
123 // Overlapped operation failed.
124 Ok(CompletionPacket {
125 result: Err(SysError::last()),
126 completion_key,
127 overlapped_ptr: overlapped as usize,
128 })
129 }
130 }
131
132 /// Waits for completion events to arrive & returns the completion keys.
133 /// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.
134 #[deny(unsafe_op_in_unsafe_fn)]
poll(port: RawDescriptor) -> Result<Vec<CompletionPacket>>135 unsafe fn poll(port: RawDescriptor) -> Result<Vec<CompletionPacket>> {
136 let mut completion_packets = vec![];
137 completion_packets.push(
138 // SAFETY: caller has ensured that the handle is valid and is for io completion port
139 unsafe {
140 get_completion_status(port, INFINITE)
141 .map_err(|e| Error::IocpOperationFailed(SysError::from(e)))?
142 },
143 );
144
145 // Drain any waiting completion packets.
146 //
147 // Wondering why we don't use GetQueuedCompletionStatusEx instead? Well, there's no way to
148 // get detailed error information for each of the returned overlapped IO operations without
149 // calling GetOverlappedResult. If we have to do that, then it's cheaper to just get each
150 // completion packet individually.
151 while completion_packets.len() < ENTRIES_PER_POLL {
152 // SAFETY:
153 // Safety: caller has ensured that the handle is valid and is for io completion port
154 match unsafe { get_completion_status(port, 0) } {
155 Ok(pkt) => {
156 completion_packets.push(pkt);
157 }
158 Err(e) if e.kind() == io::ErrorKind::TimedOut => break,
159 Err(e) => return Err(Error::IocpOperationFailed(SysError::from(e))),
160 }
161 }
162
163 Ok(completion_packets)
164 }
165
166 /// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.
iocp_waiter_thread( port: Arc<Mutex<Port>>, kill_evt: Event, completed: Arc<(Mutex<VecDeque<CompletionPacket>>, Condvar)>, ) -> Result<()>167 fn iocp_waiter_thread(
168 port: Arc<Mutex<Port>>,
169 kill_evt: Event,
170 completed: Arc<(Mutex<VecDeque<CompletionPacket>>, Condvar)>,
171 ) -> Result<()> {
172 let port = port.lock();
173 loop {
174 // SAFETY: caller has ensured that the handle is valid and is for io completion port
175 let packets = unsafe { poll(port.inner)? };
176 if !packets.is_empty() {
177 {
178 let mut c = completed.0.lock();
179 for packet in packets {
180 c.push_back(packet);
181 }
182 completed.1.notify_one();
183 }
184 }
185 if kill_evt
186 .wait_timeout(Duration::from_nanos(0))
187 .map_err(Error::IocpOperationFailed)?
188 == EventWaitResult::Signaled
189 {
190 return Ok(());
191 }
192 }
193 }
194
195 impl Drop for IoCompletionPort {
drop(&mut self)196 fn drop(&mut self) {
197 if !self.threaded() {
198 return;
199 }
200
201 let mut threads = std::mem::take(&mut self.threads);
202 for thread in &mut threads {
203 // let the thread know that it should exit
204 if let Err(e) = thread.signal() {
205 error!("faild to signal iocp thread: {}", e);
206 }
207 }
208
209 // interrupt all poll/get status on ports.
210 // Single thread can consume more ENTRIES_PER_POLL number of completion statuses.
211 // We send enough post_status so that all threads have enough data to be woken up by the
212 // completion ports.
213 // This is slightly unpleasant way to interrupt all the threads.
214 for _ in 0..(threads.len() * ENTRIES_PER_POLL) {
215 if let Err(e) = self.wake() {
216 error!("post_status failed during thread exit:{}", e);
217 }
218 }
219 }
220 }
221
222 impl IoCompletionPort {
new(concurrency: u32) -> Result<Self>223 pub fn new(concurrency: u32) -> Result<Self> {
224 let completed = Arc::new((Mutex::new(VecDeque::new()), Condvar::new()));
225 // Unwrap is safe because we're creating a new IOCP and will receive the owned handle
226 // back.
227 let port = create_iocp(None, None, 0, concurrency)?.unwrap();
228 let mut threads = vec![];
229 if concurrency > 1 {
230 info!("creating iocp with concurrency: {}", concurrency);
231 for i in 0..concurrency {
232 let completed_clone = completed.clone();
233 let port_desc = Arc::new(Mutex::new(Port {
234 inner: port.as_raw_descriptor(),
235 }));
236 threads.push(WorkerThread::start(
237 format!("overlapped_io_{}", i),
238 move |kill_evt| {
239 iocp_waiter_thread(port_desc, kill_evt, completed_clone).unwrap();
240 Ok(())
241 },
242 ));
243 }
244 }
245 Ok(Self {
246 port,
247 threads,
248 completed,
249 concurrency,
250 })
251 }
252
threaded(&self) -> bool253 fn threaded(&self) -> bool {
254 self.concurrency > 1
255 }
256
257 /// Register the provided descriptor with this completion port. Registered descriptors cannot
258 /// be deregistered. To deregister, close the descriptor.
register_descriptor(&self, desc: &dyn AsRawDescriptor) -> Result<()>259 pub fn register_descriptor(&self, desc: &dyn AsRawDescriptor) -> Result<()> {
260 create_iocp(
261 Some(desc),
262 Some(&self.port),
263 desc.as_raw_descriptor() as usize,
264 self.concurrency,
265 )?;
266 Ok(())
267 }
268
269 /// Posts a completion packet to the IO completion port.
post_status(&self, bytes_transferred: u32, completion_key: usize) -> Result<()>270 pub fn post_status(&self, bytes_transferred: u32, completion_key: usize) -> Result<()> {
271 // SAFETY:
272 // Safe because the IOCP handle is valid.
273 let res = unsafe {
274 PostQueuedCompletionStatus(
275 self.port.as_raw_descriptor(),
276 bytes_transferred,
277 completion_key,
278 null_mut(),
279 )
280 };
281 if res == 0 {
282 return Err(Error::IocpOperationFailed(SysError::last()));
283 }
284 Ok(())
285 }
286
287 /// Wake up thread waiting on this iocp.
288 /// If there are more than one thread waiting, then you may need to call this function
289 /// multiple times.
wake(&self) -> Result<()>290 pub fn wake(&self) -> Result<()> {
291 self.post_status(0, INVALID_HANDLE_VALUE as usize)
292 }
293
294 /// Get up to ENTRIES_PER_POLL completion packets from the IOCP in one shot.
295 #[allow(dead_code)]
get_completion_status_ex( &self, timeout: DWORD, ) -> Result<SmallVec<[OVERLAPPED_ENTRY; ENTRIES_PER_POLL]>>296 fn get_completion_status_ex(
297 &self,
298 timeout: DWORD,
299 ) -> Result<SmallVec<[OVERLAPPED_ENTRY; ENTRIES_PER_POLL]>> {
300 let mut overlapped_entries: SmallVec<[OVERLAPPED_ENTRY; ENTRIES_PER_POLL]> =
301 smallvec!(OVERLAPPED_ENTRY::default(); ENTRIES_PER_POLL);
302
303 let mut entries_removed: ULONG = 0;
304 // SAFETY:
305 // Safe because:
306 // 1. IOCP is guaranteed to exist by self.
307 // 2. Memory of pointers passed is stack allocated and lives as long as the syscall.
308 // 3. We check the error so we don't use invalid output values (e.g. overlapped).
309 let success = unsafe {
310 GetQueuedCompletionStatusEx(
311 self.port.as_raw_descriptor(),
312 overlapped_entries.as_mut_ptr() as LPOVERLAPPED_ENTRY,
313 ENTRIES_PER_POLL as ULONG,
314 &mut entries_removed,
315 timeout,
316 // We are normally called from a polling loop. It's more efficient (loop latency
317 // wise) to hold the thread instead of performing an alertable wait.
318 /* fAlertable= */
319 false as BOOL,
320 )
321 } != 0;
322
323 if success {
324 overlapped_entries.truncate(entries_removed as usize);
325 return Ok(overlapped_entries);
326 }
327
328 // Overlapped operation failed.
329 Err(Error::IocpOperationFailed(SysError::last()))
330 }
331
332 /// Waits for completion events to arrive & returns the completion keys.
poll_threaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>333 fn poll_threaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
334 let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);
335 let mut packets = self.completed.0.lock();
336 loop {
337 let len = usize::min(ENTRIES_PER_POLL, packets.len());
338 for p in packets.drain(..len) {
339 completion_packets.push(p)
340 }
341 if !completion_packets.is_empty() {
342 return Ok(completion_packets);
343 }
344 packets = self.completed.1.wait(packets).unwrap();
345 }
346 }
347
348 /// Waits for completion events to arrive & returns the completion keys.
poll_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>349 fn poll_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
350 // SAFETY: safe because port is in scope for the duration of the call.
351 let packets = unsafe { poll(self.port.as_raw_descriptor())? };
352 let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);
353 for pkt in packets {
354 completion_packets.push(pkt);
355 }
356 Ok(completion_packets)
357 }
358
poll(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>359 pub fn poll(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
360 if self.threaded() {
361 self.poll_threaded()
362 } else {
363 self.poll_unthreaded()
364 }
365 }
366
367 /// Waits for completion events to arrive & returns the completion keys. Internally uses
368 /// GetCompletionStatusEx.
369 ///
370 /// WARNING: do NOT use completion keys that are not IO handles except for INVALID_HANDLE_VALUE
371 /// or undefined behavior will result.
372 #[allow(dead_code)]
poll_ex(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>373 pub fn poll_ex(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
374 if self.threaded() {
375 self.poll()
376 } else {
377 self.poll_ex_unthreaded()
378 }
379 }
380
poll_ex_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>381 pub fn poll_ex_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
382 let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);
383 let overlapped_entries = self.get_completion_status_ex(INFINITE)?;
384
385 for entry in &overlapped_entries {
386 if entry.lpCompletionKey as RawDescriptor == INVALID_HANDLE_VALUE {
387 completion_packets.push(CompletionPacket {
388 result: Ok(0),
389 completion_key: entry.lpCompletionKey,
390 overlapped_ptr: entry.lpOverlapped as usize,
391 });
392 continue;
393 }
394
395 let mut bytes_transferred = 0;
396 // SAFETY: trivially safe with return value checked
397 let success = unsafe {
398 GetOverlappedResult(
399 entry.lpCompletionKey as RawDescriptor,
400 entry.lpOverlapped,
401 &mut bytes_transferred,
402 // We don't need to wait because IOCP told us the IO is complete.
403 /* bWait= */
404 false as BOOL,
405 )
406 } != 0;
407 if success {
408 completion_packets.push(CompletionPacket {
409 result: Ok(bytes_transferred as usize),
410 completion_key: entry.lpCompletionKey,
411 overlapped_ptr: entry.lpOverlapped as usize,
412 });
413 } else {
414 completion_packets.push(CompletionPacket {
415 result: Err(SysError::last()),
416 completion_key: entry.lpCompletionKey,
417 overlapped_ptr: entry.lpOverlapped as usize,
418 });
419 }
420 }
421 Ok(completion_packets)
422 }
423 }
424
425 /// If existing_iocp is None, will return the created IOCP.
create_iocp( file: Option<&dyn AsRawDescriptor>, existing_iocp: Option<&dyn AsRawDescriptor>, completion_key: usize, concurrency: u32, ) -> Result<Option<SafeDescriptor>>426 fn create_iocp(
427 file: Option<&dyn AsRawDescriptor>,
428 existing_iocp: Option<&dyn AsRawDescriptor>,
429 completion_key: usize,
430 concurrency: u32,
431 ) -> Result<Option<SafeDescriptor>> {
432 let raw_file = match file {
433 Some(file) => file.as_raw_descriptor(),
434 None => INVALID_HANDLE_VALUE,
435 };
436 let raw_existing_iocp = match existing_iocp {
437 Some(iocp) => iocp.as_raw_descriptor(),
438 None => null_mut(),
439 };
440
441 let port =
442 // SAFETY:
443 // Safe because:
444 // 1. The file handle is open because we have a reference to it.
445 // 2. The existing IOCP (if applicable) is valid.
446 unsafe { CreateIoCompletionPort(raw_file, raw_existing_iocp, completion_key, concurrency) };
447
448 if port.is_null() {
449 return Err(Error::IocpOperationFailed(SysError::last()));
450 }
451
452 if existing_iocp.is_some() {
453 Ok(None)
454 } else {
455 // SAFETY:
456 // Safe because:
457 // 1. We are creating a new IOCP.
458 // 2. We exclusively own the handle.
459 // 3. The handle is valid since CreateIoCompletionPort returned without errors.
460 Ok(Some(unsafe { SafeDescriptor::from_raw_descriptor(port) }))
461 }
462 }
463
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::fs::OpenOptions;
    use std::os::windows::fs::OpenOptionsExt;
    use std::path::PathBuf;

    use tempfile::TempDir;
    use winapi::um::winbase::FILE_FLAG_OVERLAPPED;

    use super::*;

    static TEST_IO_CONCURRENCY: u32 = 4;

    /// A polling entry point under test (`IoCompletionPort::poll` or `poll_ex`).
    type PollFn = fn(&IoCompletionPort) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>;

    fn tempfile_path() -> (PathBuf, TempDir) {
        let dir = tempfile::TempDir::new().unwrap();
        let mut file_path = PathBuf::from(dir.path());
        file_path.push("test");
        (file_path, dir)
    }

    fn open_overlapped(path: &PathBuf) -> File {
        OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .custom_flags(FILE_FLAG_OVERLAPPED)
            .open(path)
            .unwrap()
    }

    /// Writes 16 bytes to an overlapped file registered with a fresh IOCP of the given
    /// `concurrency`. It then checks that `poll_fn` returns exactly one packet, and that the
    /// packet describes that write.
    fn check_single_write_completion(concurrency: u32, poll_fn: PollFn) {
        let iocp = IoCompletionPort::new(concurrency).unwrap();
        let (file_path, _tmpdir) = tempfile_path();
        let mut overlapped = OVERLAPPED::default();
        let f = open_overlapped(&file_path);

        iocp.register_descriptor(&f).unwrap();
        let buf = [0u8; 16];
        // SAFETY: Safe given file is valid, buffers are allocated and initialized and return value
        // is checked.
        unsafe {
            base::windows::write_file(&f, buf.as_ptr(), buf.len(), Some(&mut overlapped)).unwrap()
        };

        let packets = poll_fn(&iocp).unwrap();
        assert_eq!(packets.len(), 1);
        let packet = &packets[0];
        // Registered handles use their handle value as the completion key.
        assert_eq!(packet.completion_key, f.as_raw_descriptor() as usize);
        assert_eq!(
            packet.overlapped_ptr,
            &overlapped as *const OVERLAPPED as usize
        );
        assert!(matches!(packet.result, Ok(n) if n == buf.len()));
    }

    /// Checks that a `wake` packet is delivered through `poll_fn` as an `INVALID_HANDLE_VALUE`
    /// key with no OVERLAPPED and zero bytes.
    fn check_wake(concurrency: u32, poll_fn: PollFn) {
        let iocp = IoCompletionPort::new(concurrency).unwrap();
        iocp.wake().unwrap();

        let packets = poll_fn(&iocp).unwrap();
        assert_eq!(packets.len(), 1);
        let packet = &packets[0];
        assert_eq!(packet.completion_key, INVALID_HANDLE_VALUE as usize);
        assert_eq!(packet.overlapped_ptr, 0);
        assert!(matches!(packet.result, Ok(0)));
    }

    #[test]
    fn basic_iocp_test_unthreaded() {
        check_single_write_completion(1, IoCompletionPort::poll)
    }

    #[test]
    fn basic_iocp_test_threaded() {
        check_single_write_completion(TEST_IO_CONCURRENCY, IoCompletionPort::poll)
    }

    #[test]
    fn basic_iocp_test_poll_ex_unthreaded() {
        check_single_write_completion(1, IoCompletionPort::poll_ex);
    }

    #[test]
    fn basic_iocp_test_poll_ex_threaded() {
        check_single_write_completion(TEST_IO_CONCURRENCY, IoCompletionPort::poll_ex);
    }

    #[test]
    fn wake_test_unthreaded() {
        check_wake(1, IoCompletionPort::poll);
        check_wake(1, IoCompletionPort::poll_ex);
    }

    #[test]
    fn wake_test_threaded() {
        check_wake(TEST_IO_CONCURRENCY, IoCompletionPort::poll);
        check_wake(TEST_IO_CONCURRENCY, IoCompletionPort::poll_ex);
    }
}
547