• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! IO completion port wrapper.
6 
7 use std::collections::VecDeque;
8 use std::io;
9 use std::ptr::null_mut;
10 use std::sync::Arc;
11 use std::sync::Condvar;
12 use std::time::Duration;
13 
14 use base::error;
15 use base::info;
16 use base::AsRawDescriptor;
17 use base::Error as SysError;
18 use base::Event;
19 use base::EventWaitResult;
20 use base::FromRawDescriptor;
21 use base::RawDescriptor;
22 use base::SafeDescriptor;
23 use base::WorkerThread;
24 use smallvec::smallvec;
25 use smallvec::SmallVec;
26 use sync::Mutex;
27 use winapi::shared::minwindef::BOOL;
28 use winapi::shared::minwindef::DWORD;
29 use winapi::shared::minwindef::ULONG;
30 use winapi::um::handleapi::INVALID_HANDLE_VALUE;
31 use winapi::um::ioapiset::CreateIoCompletionPort;
32 use winapi::um::ioapiset::GetOverlappedResult;
33 use winapi::um::ioapiset::GetQueuedCompletionStatus;
34 use winapi::um::ioapiset::GetQueuedCompletionStatusEx;
35 use winapi::um::ioapiset::PostQueuedCompletionStatus;
36 use winapi::um::minwinbase::LPOVERLAPPED_ENTRY;
37 use winapi::um::minwinbase::OVERLAPPED;
38 use winapi::um::minwinbase::OVERLAPPED_ENTRY;
39 use winapi::um::winbase::INFINITE;
40 
41 use super::handle_executor::Error;
42 use super::handle_executor::Result;
43 
/// Maximum number of IOCP packets accepted by a single poll operation.
///
/// The value also sizes the inline storage of the `SmallVec`s handed back by the poll methods.
// Clippy considers the constant unused when it only appears in `SmallVec` sizes, hence the allow.
#[allow(dead_code)]
const ENTRIES_PER_POLL: usize = 16;
48 
49 /// A minimal version of completion packets from an IoCompletionPort.
50 pub(crate) struct CompletionPacket {
51     pub completion_key: usize,
52     pub overlapped_ptr: usize,
53     pub result: std::result::Result<usize, SysError>,
54 }
55 
56 struct Port {
57     inner: RawDescriptor,
58 }
59 
60 // SAFETY:
61 // Safe because the Port is dropped before IoCompletionPort goes out of scope
62 unsafe impl Send for Port {}
63 
64 /// Wraps an IO Completion Port (iocp). These ports are very similar to an epoll
65 /// context on unix. Handles (equivalent to FDs) we want to wait on for
66 /// readiness are added to the port, and then the port can be waited on using a
67 /// syscall (GetQueuedCompletionStatus). IOCP is a little more flexible than
68 /// epoll because custom messages can be enqueued and received from the port
69 /// just like if a handle became ready (see [IoCompletionPort::post_status]).
70 ///
71 /// Note that completion ports can only be subscribed to a handle, they
72 /// can never be unsubscribed. Handles are removed from the port automatically when they are closed.
73 ///
74 /// Registered handles have their completion key set to their handle number.
75 pub(crate) struct IoCompletionPort {
76     port: SafeDescriptor,
77     threads: Vec<WorkerThread<Result<()>>>,
78     completed: Arc<(Mutex<VecDeque<CompletionPacket>>, Condvar)>,
79     concurrency: u32,
80 }
81 
82 /// Gets a completion packet from the completion port. If the underlying IO operation
83 /// encountered an error, it will be contained inside the completion packet. If this method
84 /// encountered an error getting a completion packet, the error will be returned directly.
85 /// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.
86 #[deny(unsafe_op_in_unsafe_fn)]
get_completion_status( handle: RawDescriptor, timeout: DWORD, ) -> io::Result<CompletionPacket>87 unsafe fn get_completion_status(
88     handle: RawDescriptor,
89     timeout: DWORD,
90 ) -> io::Result<CompletionPacket> {
91     let mut bytes_transferred = 0;
92     let mut completion_key = 0;
93     // SAFETY: trivially safe
94     let mut overlapped: *mut OVERLAPPED = unsafe { std::mem::zeroed() };
95 
96     // SAFETY:
97     // Safe because:
98     //      1. Memory of pointers passed is stack allocated and lives as long as the syscall.
99     //      2. We check the error so we don't use invalid output values (e.g. overlapped).
100     let success = unsafe {
101         GetQueuedCompletionStatus(
102             handle,
103             &mut bytes_transferred,
104             &mut completion_key,
105             &mut overlapped as *mut *mut OVERLAPPED,
106             timeout,
107         )
108     } != 0;
109 
110     if success {
111         return Ok(CompletionPacket {
112             result: Ok(bytes_transferred as usize),
113             completion_key,
114             overlapped_ptr: overlapped as usize,
115         });
116     }
117 
118     // Did the IOCP operation fail, or did the overlapped operation fail?
119     if overlapped.is_null() {
120         // IOCP failed somehow.
121         Err(io::Error::last_os_error())
122     } else {
123         // Overlapped operation failed.
124         Ok(CompletionPacket {
125             result: Err(SysError::last()),
126             completion_key,
127             overlapped_ptr: overlapped as usize,
128         })
129     }
130 }
131 
132 /// Waits for completion events to arrive & returns the completion keys.
133 /// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.
134 #[deny(unsafe_op_in_unsafe_fn)]
poll(port: RawDescriptor) -> Result<Vec<CompletionPacket>>135 unsafe fn poll(port: RawDescriptor) -> Result<Vec<CompletionPacket>> {
136     let mut completion_packets = vec![];
137     completion_packets.push(
138         // SAFETY: caller has ensured that the handle is valid and is for io completion port
139         unsafe {
140             get_completion_status(port, INFINITE)
141                 .map_err(|e| Error::IocpOperationFailed(SysError::from(e)))?
142         },
143     );
144 
145     // Drain any waiting completion packets.
146     //
147     // Wondering why we don't use GetQueuedCompletionStatusEx instead? Well, there's no way to
148     // get detailed error information for each of the returned overlapped IO operations without
149     // calling GetOverlappedResult. If we have to do that, then it's cheaper to just get each
150     // completion packet individually.
151     while completion_packets.len() < ENTRIES_PER_POLL {
152         // SAFETY:
153         // Safety: caller has ensured that the handle is valid and is for io completion port
154         match unsafe { get_completion_status(port, 0) } {
155             Ok(pkt) => {
156                 completion_packets.push(pkt);
157             }
158             Err(e) if e.kind() == io::ErrorKind::TimedOut => break,
159             Err(e) => return Err(Error::IocpOperationFailed(SysError::from(e))),
160         }
161     }
162 
163     Ok(completion_packets)
164 }
165 
166 /// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.
iocp_waiter_thread( port: Arc<Mutex<Port>>, kill_evt: Event, completed: Arc<(Mutex<VecDeque<CompletionPacket>>, Condvar)>, ) -> Result<()>167 fn iocp_waiter_thread(
168     port: Arc<Mutex<Port>>,
169     kill_evt: Event,
170     completed: Arc<(Mutex<VecDeque<CompletionPacket>>, Condvar)>,
171 ) -> Result<()> {
172     let port = port.lock();
173     loop {
174         // SAFETY: caller has ensured that the handle is valid and is for io completion port
175         let packets = unsafe { poll(port.inner)? };
176         if !packets.is_empty() {
177             {
178                 let mut c = completed.0.lock();
179                 for packet in packets {
180                     c.push_back(packet);
181                 }
182                 completed.1.notify_one();
183             }
184         }
185         if kill_evt
186             .wait_timeout(Duration::from_nanos(0))
187             .map_err(Error::IocpOperationFailed)?
188             == EventWaitResult::Signaled
189         {
190             return Ok(());
191         }
192     }
193 }
194 
195 impl Drop for IoCompletionPort {
drop(&mut self)196     fn drop(&mut self) {
197         if !self.threaded() {
198             return;
199         }
200 
201         let mut threads = std::mem::take(&mut self.threads);
202         for thread in &mut threads {
203             // let the thread know that it should exit
204             if let Err(e) = thread.signal() {
205                 error!("faild to signal iocp thread: {}", e);
206             }
207         }
208 
209         // interrupt all poll/get status on ports.
210         // Single thread can consume more ENTRIES_PER_POLL number of completion statuses.
211         // We send enough post_status so that all threads have enough data to be woken up by the
212         // completion ports.
213         // This is slightly unpleasant way to interrupt all the threads.
214         for _ in 0..(threads.len() * ENTRIES_PER_POLL) {
215             if let Err(e) = self.wake() {
216                 error!("post_status failed during thread exit:{}", e);
217             }
218         }
219     }
220 }
221 
222 impl IoCompletionPort {
new(concurrency: u32) -> Result<Self>223     pub fn new(concurrency: u32) -> Result<Self> {
224         let completed = Arc::new((Mutex::new(VecDeque::new()), Condvar::new()));
225         // Unwrap is safe because we're creating a new IOCP and will receive the owned handle
226         // back.
227         let port = create_iocp(None, None, 0, concurrency)?.unwrap();
228         let mut threads = vec![];
229         if concurrency > 1 {
230             info!("creating iocp with concurrency: {}", concurrency);
231             for i in 0..concurrency {
232                 let completed_clone = completed.clone();
233                 let port_desc = Arc::new(Mutex::new(Port {
234                     inner: port.as_raw_descriptor(),
235                 }));
236                 threads.push(WorkerThread::start(
237                     format!("overlapped_io_{}", i),
238                     move |kill_evt| {
239                         iocp_waiter_thread(port_desc, kill_evt, completed_clone).unwrap();
240                         Ok(())
241                     },
242                 ));
243             }
244         }
245         Ok(Self {
246             port,
247             threads,
248             completed,
249             concurrency,
250         })
251     }
252 
threaded(&self) -> bool253     fn threaded(&self) -> bool {
254         self.concurrency > 1
255     }
256 
257     /// Register the provided descriptor with this completion port. Registered descriptors cannot
258     /// be deregistered. To deregister, close the descriptor.
register_descriptor(&self, desc: &dyn AsRawDescriptor) -> Result<()>259     pub fn register_descriptor(&self, desc: &dyn AsRawDescriptor) -> Result<()> {
260         create_iocp(
261             Some(desc),
262             Some(&self.port),
263             desc.as_raw_descriptor() as usize,
264             self.concurrency,
265         )?;
266         Ok(())
267     }
268 
269     /// Posts a completion packet to the IO completion port.
post_status(&self, bytes_transferred: u32, completion_key: usize) -> Result<()>270     pub fn post_status(&self, bytes_transferred: u32, completion_key: usize) -> Result<()> {
271         // SAFETY:
272         // Safe because the IOCP handle is valid.
273         let res = unsafe {
274             PostQueuedCompletionStatus(
275                 self.port.as_raw_descriptor(),
276                 bytes_transferred,
277                 completion_key,
278                 null_mut(),
279             )
280         };
281         if res == 0 {
282             return Err(Error::IocpOperationFailed(SysError::last()));
283         }
284         Ok(())
285     }
286 
287     /// Wake up thread waiting on this iocp.
288     /// If there are more than one thread waiting, then you may need to call this function
289     /// multiple times.
wake(&self) -> Result<()>290     pub fn wake(&self) -> Result<()> {
291         self.post_status(0, INVALID_HANDLE_VALUE as usize)
292     }
293 
294     /// Get up to ENTRIES_PER_POLL completion packets from the IOCP in one shot.
295     #[allow(dead_code)]
get_completion_status_ex( &self, timeout: DWORD, ) -> Result<SmallVec<[OVERLAPPED_ENTRY; ENTRIES_PER_POLL]>>296     fn get_completion_status_ex(
297         &self,
298         timeout: DWORD,
299     ) -> Result<SmallVec<[OVERLAPPED_ENTRY; ENTRIES_PER_POLL]>> {
300         let mut overlapped_entries: SmallVec<[OVERLAPPED_ENTRY; ENTRIES_PER_POLL]> =
301             smallvec!(OVERLAPPED_ENTRY::default(); ENTRIES_PER_POLL);
302 
303         let mut entries_removed: ULONG = 0;
304         // SAFETY:
305         // Safe because:
306         //      1. IOCP is guaranteed to exist by self.
307         //      2. Memory of pointers passed is stack allocated and lives as long as the syscall.
308         //      3. We check the error so we don't use invalid output values (e.g. overlapped).
309         let success = unsafe {
310             GetQueuedCompletionStatusEx(
311                 self.port.as_raw_descriptor(),
312                 overlapped_entries.as_mut_ptr() as LPOVERLAPPED_ENTRY,
313                 ENTRIES_PER_POLL as ULONG,
314                 &mut entries_removed,
315                 timeout,
316                 // We are normally called from a polling loop. It's more efficient (loop latency
317                 // wise) to hold the thread instead of performing an alertable wait.
318                 /* fAlertable= */
319                 false as BOOL,
320             )
321         } != 0;
322 
323         if success {
324             overlapped_entries.truncate(entries_removed as usize);
325             return Ok(overlapped_entries);
326         }
327 
328         // Overlapped operation failed.
329         Err(Error::IocpOperationFailed(SysError::last()))
330     }
331 
332     /// Waits for completion events to arrive & returns the completion keys.
poll_threaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>333     fn poll_threaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
334         let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);
335         let mut packets = self.completed.0.lock();
336         loop {
337             let len = usize::min(ENTRIES_PER_POLL, packets.len());
338             for p in packets.drain(..len) {
339                 completion_packets.push(p)
340             }
341             if !completion_packets.is_empty() {
342                 return Ok(completion_packets);
343             }
344             packets = self.completed.1.wait(packets).unwrap();
345         }
346     }
347 
348     /// Waits for completion events to arrive & returns the completion keys.
poll_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>349     fn poll_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
350         // SAFETY: safe because port is in scope for the duration of the call.
351         let packets = unsafe { poll(self.port.as_raw_descriptor())? };
352         let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);
353         for pkt in packets {
354             completion_packets.push(pkt);
355         }
356         Ok(completion_packets)
357     }
358 
poll(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>359     pub fn poll(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
360         if self.threaded() {
361             self.poll_threaded()
362         } else {
363             self.poll_unthreaded()
364         }
365     }
366 
367     /// Waits for completion events to arrive & returns the completion keys. Internally uses
368     /// GetCompletionStatusEx.
369     ///
370     /// WARNING: do NOT use completion keys that are not IO handles except for INVALID_HANDLE_VALUE
371     /// or undefined behavior will result.
372     #[allow(dead_code)]
poll_ex(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>373     pub fn poll_ex(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
374         if self.threaded() {
375             self.poll()
376         } else {
377             self.poll_ex_unthreaded()
378         }
379     }
380 
poll_ex_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>381     pub fn poll_ex_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
382         let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);
383         let overlapped_entries = self.get_completion_status_ex(INFINITE)?;
384 
385         for entry in &overlapped_entries {
386             if entry.lpCompletionKey as RawDescriptor == INVALID_HANDLE_VALUE {
387                 completion_packets.push(CompletionPacket {
388                     result: Ok(0),
389                     completion_key: entry.lpCompletionKey,
390                     overlapped_ptr: entry.lpOverlapped as usize,
391                 });
392                 continue;
393             }
394 
395             let mut bytes_transferred = 0;
396             // SAFETY: trivially safe with return value checked
397             let success = unsafe {
398                 GetOverlappedResult(
399                     entry.lpCompletionKey as RawDescriptor,
400                     entry.lpOverlapped,
401                     &mut bytes_transferred,
402                     // We don't need to wait because IOCP told us the IO is complete.
403                     /* bWait= */
404                     false as BOOL,
405                 )
406             } != 0;
407             if success {
408                 completion_packets.push(CompletionPacket {
409                     result: Ok(bytes_transferred as usize),
410                     completion_key: entry.lpCompletionKey,
411                     overlapped_ptr: entry.lpOverlapped as usize,
412                 });
413             } else {
414                 completion_packets.push(CompletionPacket {
415                     result: Err(SysError::last()),
416                     completion_key: entry.lpCompletionKey,
417                     overlapped_ptr: entry.lpOverlapped as usize,
418                 });
419             }
420         }
421         Ok(completion_packets)
422     }
423 }
424 
425 /// If existing_iocp is None, will return the created IOCP.
create_iocp( file: Option<&dyn AsRawDescriptor>, existing_iocp: Option<&dyn AsRawDescriptor>, completion_key: usize, concurrency: u32, ) -> Result<Option<SafeDescriptor>>426 fn create_iocp(
427     file: Option<&dyn AsRawDescriptor>,
428     existing_iocp: Option<&dyn AsRawDescriptor>,
429     completion_key: usize,
430     concurrency: u32,
431 ) -> Result<Option<SafeDescriptor>> {
432     let raw_file = match file {
433         Some(file) => file.as_raw_descriptor(),
434         None => INVALID_HANDLE_VALUE,
435     };
436     let raw_existing_iocp = match existing_iocp {
437         Some(iocp) => iocp.as_raw_descriptor(),
438         None => null_mut(),
439     };
440 
441     let port =
442         // SAFETY:
443         // Safe because:
444         //      1. The file handle is open because we have a reference to it.
445         //      2. The existing IOCP (if applicable) is valid.
446         unsafe { CreateIoCompletionPort(raw_file, raw_existing_iocp, completion_key, concurrency) };
447 
448     if port.is_null() {
449         return Err(Error::IocpOperationFailed(SysError::last()));
450     }
451 
452     if existing_iocp.is_some() {
453         Ok(None)
454     } else {
455         // SAFETY:
456         // Safe because:
457         // 1. We are creating a new IOCP.
458         // 2. We exclusively own the handle.
459         // 3. The handle is valid since CreateIoCompletionPort returned without errors.
460         Ok(Some(unsafe { SafeDescriptor::from_raw_descriptor(port) }))
461     }
462 }
463 
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::fs::OpenOptions;
    use std::os::windows::fs::OpenOptionsExt;
    use std::path::Path;
    use std::path::PathBuf;

    use tempfile::TempDir;
    use winapi::um::winbase::FILE_FLAG_OVERLAPPED;

    use super::*;

    /// Concurrency used by the threaded variants of the tests.
    const TEST_IO_CONCURRENCY: u32 = 4;

    /// Returns the path of a file named `test` inside a fresh temporary directory, together with
    /// the directory guard that must be kept alive while the path is in use.
    fn tempfile_path() -> (PathBuf, TempDir) {
        let dir = tempfile::TempDir::new().unwrap();
        let file_path = dir.path().join("test");
        (file_path, dir)
    }

    /// Opens (creating if necessary) `path` for overlapped reads and writes.
    fn open_overlapped(path: &Path) -> File {
        OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .custom_flags(FILE_FLAG_OVERLAPPED)
            .open(path)
            .unwrap()
    }

    /// Registers a file with a fresh completion port, issues one overlapped write and asserts
    /// that `poll` hands back exactly one completion packet.
    fn assert_single_completion(
        concurrency: u32,
        poll: impl FnOnce(
            &IoCompletionPort,
        ) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>>,
    ) {
        let iocp = IoCompletionPort::new(concurrency).unwrap();
        let (file_path, _tmpdir) = tempfile_path();
        let mut overlapped = OVERLAPPED::default();
        let f = open_overlapped(&file_path);

        iocp.register_descriptor(&f).unwrap();
        let buf = [0u8; 16];
        // SAFETY: Safe given file is valid, buffers are allocated and initialized and return value
        // is checked.
        unsafe {
            base::windows::write_file(&f, buf.as_ptr(), buf.len(), Some(&mut overlapped)).unwrap()
        };
        let packets = poll(&iocp).unwrap();
        assert_eq!(packets.len(), 1);
    }

    fn basic_iocp_test_with(concurrency: u32) {
        assert_single_completion(concurrency, |iocp| iocp.poll());
    }

    #[test]
    fn basic_iocp_test_unthreaded() {
        basic_iocp_test_with(1)
    }

    #[test]
    fn basic_iocp_test_threaded() {
        basic_iocp_test_with(TEST_IO_CONCURRENCY)
    }

    fn basic_iocp_test_poll_ex(concurrency: u32) {
        assert_single_completion(concurrency, |iocp| iocp.poll_ex());
    }

    #[test]
    fn basic_iocp_test_poll_ex_unthreaded() {
        basic_iocp_test_poll_ex(1);
    }

    #[test]
    fn basic_iocp_test_poll_ex_threaded() {
        basic_iocp_test_poll_ex(TEST_IO_CONCURRENCY);
    }
}
547