// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13 
14 use std::collections::VecDeque;
15 use std::ffi::c_void;
16 use std::io;
17 use std::marker::PhantomPinned;
18 use std::mem::size_of;
19 use std::os::windows::io::RawSocket;
20 use std::pin::Pin;
21 use std::ptr::null_mut;
22 use std::sync::atomic::{AtomicBool, Ordering};
23 use std::sync::{Arc, Mutex};
24 use std::time::Duration;
25 
26 use windows_sys::Win32::Foundation::{
27     ERROR_INVALID_HANDLE, ERROR_IO_PENDING, HANDLE, STATUS_CANCELLED, WAIT_TIMEOUT,
28 };
29 use windows_sys::Win32::Networking::WinSock::{
30     WSAGetLastError, WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE, SIO_BSP_HANDLE_POLL,
31     SIO_BSP_HANDLE_SELECT, SOCKET_ERROR,
32 };
33 use windows_sys::Win32::System::IO::OVERLAPPED;
34 
35 use crate::sys::windows::afd;
36 use crate::sys::windows::afd::{Afd, AfdGroup, AfdPollInfo};
37 use crate::sys::windows::events::{
38     Events, ERROR_FLAGS, READABLE_FLAGS, READ_CLOSED_FLAGS, WRITABLE_FLAGS, WRITE_CLOSED_FLAGS,
39 };
40 use crate::sys::windows::io_status_block::IoStatusBlock;
41 use crate::sys::windows::iocp::{CompletionPort, CompletionStatus};
42 use crate::sys::NetInner;
43 use crate::{Event, Interest, Token};
44 
/// A wrapper to block different OS polling system.
/// Linux: epoll
/// Windows: iocp
#[derive(Debug)]
pub struct Selector {
    // Shared selector internals; an `Arc` so registered IO sources can hold
    // a reference to the same completion port and update queue.
    inner: Arc<SelectorInner>,
}
52 
53 impl Selector {
new() -> io::Result<Selector>54     pub(crate) fn new() -> io::Result<Selector> {
55         SelectorInner::new().map(|inner| Selector {
56             inner: Arc::new(inner),
57         })
58     }
59 
select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()>60     pub(crate) fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
61         self.inner.select(events, timeout)
62     }
63 
register( &self, socket: RawSocket, token: Token, interests: Interest, ) -> io::Result<NetInner>64     pub(crate) fn register(
65         &self,
66         socket: RawSocket,
67         token: Token,
68         interests: Interest,
69     ) -> io::Result<NetInner> {
70         SelectorInner::register(&self.inner, socket, token, interests)
71     }
72 
reregister( &self, sock_state: Pin<Arc<Mutex<SockState>>>, token: Token, interests: Interest, ) -> io::Result<()>73     pub(crate) fn reregister(
74         &self,
75         sock_state: Pin<Arc<Mutex<SockState>>>,
76         token: Token,
77         interests: Interest,
78     ) -> io::Result<()> {
79         self.inner.reregister(sock_state, token, interests)
80     }
81 
clone_cp(&self) -> Arc<CompletionPort>82     pub(crate) fn clone_cp(&self) -> Arc<CompletionPort> {
83         self.inner.completion_port.clone()
84     }
85 }
86 
#[derive(Debug)]
pub(crate) struct SelectorInner {
    /// IOCP Handle.
    completion_port: Arc<CompletionPort>,
    /// Registered/re-registered IO events are placed in this queue; the
    /// queued states are flushed to the AFD driver on the next poll.
    update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
    /// Afd Group.
    afd_group: AfdGroup,
    /// Whether the Selector is currently polling.
    polling: AtomicBool,
}
98 
99 impl SelectorInner {
100     /// Creates a new SelectorInner
new() -> io::Result<SelectorInner>101     fn new() -> io::Result<SelectorInner> {
102         CompletionPort::new().map(|cp| {
103             let arc_cp = Arc::new(cp);
104             let cp_afd = Arc::clone(&arc_cp);
105 
106             SelectorInner {
107                 completion_port: arc_cp,
108                 update_queue: Mutex::new(VecDeque::new()),
109                 afd_group: AfdGroup::new(cp_afd),
110                 polling: AtomicBool::new(false),
111             }
112         })
113     }
114 
115     /// Start poll
select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()>116     fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
117         events.clear();
118 
119         match timeout {
120             None => loop {
121                 let len = self.select_inner(events, timeout)?;
122                 if len != 0 {
123                     return Ok(());
124                 }
125             },
126             Some(_) => {
127                 self.select_inner(events, timeout)?;
128                 Ok(())
129             }
130         }
131     }
132 
select_inner(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<usize>133     fn select_inner(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<usize> {
134         // We can only poll once at the same time.
135         if self.polling.swap(true, Ordering::AcqRel) {
136             panic!("Can't be polling twice at same time!");
137         }
138 
139         unsafe { self.update_sockets_events() }?;
140 
141         let results = self
142             .completion_port
143             .get_results(&mut events.status, timeout);
144 
145         self.polling.store(false, Ordering::Relaxed);
146 
147         match results {
148             Ok(iocp_events) => Ok(unsafe { self.feed_events(&mut events.events, iocp_events) }),
149             Err(ref e) if e.raw_os_error() == Some(WAIT_TIMEOUT as i32) => Ok(0),
150             Err(e) => Err(e),
151         }
152     }
153 
154     /// Process completed operation and put them into events; regular AFD events
155     /// are put back into VecDeque
feed_events( &self, events: &mut Vec<Event>, iocp_events: &[CompletionStatus], ) -> usize156     unsafe fn feed_events(
157         &self,
158         events: &mut Vec<Event>,
159         iocp_events: &[CompletionStatus],
160     ) -> usize {
161         let mut epoll_event_count = 0;
162         let mut update_queue = self.update_queue.lock().unwrap();
163         for iocp_event in iocp_events.iter() {
164             if iocp_event.overlapped().is_null() {
165                 events.push(Event::from_completion_status(iocp_event));
166                 epoll_event_count += 1;
167                 continue;
168             } else if iocp_event.token() % 2 == 1 {
169                 // Non-AFD event, including pipe.
170                 let callback = (*(iocp_event.overlapped() as *mut super::Overlapped)).callback;
171 
172                 let len = events.len();
173                 callback(iocp_event.entry(), Some(events));
174                 epoll_event_count += events.len() - len;
175                 continue;
176             }
177 
178             // General asynchronous IO event.
179             let sock_state = from_overlapped(iocp_event.overlapped());
180             let mut sock_guard = sock_state.lock().unwrap();
181             if let Some(event) = sock_guard.sock_feed_event() {
182                 events.push(event);
183                 epoll_event_count += 1;
184             }
185 
186             // Reregister the socket.
187             if !sock_guard.is_delete_pending() {
188                 update_queue.push_back(sock_state.clone());
189             }
190         }
191 
192         self.afd_group.release_unused_afd();
193         epoll_event_count
194     }
195 
196     /// Updates each SockState in the Deque, started only when Poll::poll() is
197     /// called externally
update_sockets_events(&self) -> io::Result<()>198     unsafe fn update_sockets_events(&self) -> io::Result<()> {
199         let mut update_queue = self.update_queue.lock().unwrap();
200         for sock in update_queue.iter_mut() {
201             let mut sock_internal = sock.lock().unwrap();
202             if !sock_internal.delete_pending {
203                 sock_internal.update(sock)?;
204             }
205         }
206         // Deletes events which has been updated successful.
207         update_queue.retain(|sock| sock.lock().unwrap().has_error());
208 
209         self.afd_group.release_unused_afd();
210         Ok(())
211     }
212 
213     /// No actual system call is made at register, it only starts at
214     /// Poll::poll(). Return Arc<NetInternal> and put it in the asynchronous
215     /// IO structure
register( this: &Arc<Self>, raw_socket: RawSocket, token: Token, interests: Interest, ) -> io::Result<NetInner>216     pub(crate) fn register(
217         this: &Arc<Self>,
218         raw_socket: RawSocket,
219         token: Token,
220         interests: Interest,
221     ) -> io::Result<NetInner> {
222         // Creates Afd
223         let afd = this.afd_group.acquire()?;
224         let mut sock_state = SockState::new(raw_socket, afd)?;
225 
226         let flags = interests_to_afd_flags(interests);
227         sock_state.set_event(flags, token.0 as u64);
228 
229         let pin_sock_state = Arc::pin(Mutex::new(sock_state));
230 
231         let net_internal = NetInner::new(this.clone(), token, interests, pin_sock_state.clone());
232 
233         // Adds SockState to VecDeque
234         this.queue_state(pin_sock_state);
235 
236         if this.polling.load(Ordering::Acquire) {
237             unsafe { this.update_sockets_events()? }
238         }
239 
240         Ok(net_internal)
241     }
242 
243     /// Re-register, put SockState back into VecDeque
reregister( &self, state: Pin<Arc<Mutex<SockState>>>, token: Token, interests: Interest, ) -> io::Result<()>244     pub(crate) fn reregister(
245         &self,
246         state: Pin<Arc<Mutex<SockState>>>,
247         token: Token,
248         interests: Interest,
249     ) -> io::Result<()> {
250         let flags = interests_to_afd_flags(interests);
251         state.lock().unwrap().set_event(flags, token.0 as u64);
252 
253         // Put back in the update queue VecDeque
254         self.queue_state(state);
255 
256         if self.polling.load(Ordering::Acquire) {
257             unsafe { self.update_sockets_events() }
258         } else {
259             Ok(())
260         }
261     }
262 
263     /// Adds SockState to VecDeque last.
queue_state(&self, sock_state: Pin<Arc<Mutex<SockState>>>)264     fn queue_state(&self, sock_state: Pin<Arc<Mutex<SockState>>>) {
265         let mut update_queue = self.update_queue.lock().unwrap();
266         update_queue.push_back(sock_state);
267     }
268 }
269 
impl Drop for SelectorInner {
    /// Drains every outstanding completion packet so the per-socket `Arc`s
    /// leaked into overlapped pointers are reclaimed and pipe callbacks get a
    /// chance to release their resources before the port is closed.
    fn drop(&mut self) {
        loop {
            let complete_num: usize;
            let mut status: [CompletionStatus; 1024] = [CompletionStatus::zero(); 1024];

            // Zero timeout: only collect packets that are already queued.
            let result = self
                .completion_port
                .get_results(&mut status, Some(Duration::from_millis(0)));

            match result {
                Ok(iocp_events) => {
                    complete_num = iocp_events.iter().len();
                    for iocp_event in iocp_events.iter() {
                        if iocp_event.overlapped().is_null() {
                            // User event: no resources attached, nothing to free.
                        } else if iocp_event.token() % 2 == 1 {
                            // For pipe, dispatch the event so it can release resources.
                            // SAFETY: odd tokens are only ever attached to
                            // `super::Overlapped`-backed completions — TODO confirm
                            // against the pipe registration path.
                            let callback = unsafe {
                                (*(iocp_event.overlapped() as *mut super::Overlapped)).callback
                            };

                            // `None` events: deliver for cleanup only.
                            callback(iocp_event.entry(), None);
                        } else {
                            // Release memory of Arc reference taken in `into_overlapped`.
                            let _ = from_overlapped(iocp_event.overlapped());
                        }
                    }
                }

                Err(_) => {
                    break;
                }
            }

            if complete_num == 0 {
                // The port is fully drained; stop looping.
                break;
            }
        }

        self.afd_group.release_unused_afd();
    }
}
314 
/// Lifecycle of the AFD poll request associated with one socket.
#[derive(Debug, PartialEq)]
enum SockPollStatus {
    /// Initial Value.
    Idle,
    /// System function called when updating sockets_events, set from Idle to
    /// Pending. Update only when polling. Only the socket of Pending can be
    /// cancelled.
    Pending,
    /// After calling the system api to cancel the sock, set it to Cancelled.
    Cancelled,
}
326 
/// Saves all information of the socket during polling.
#[derive(Debug)]
pub struct SockState {
    // IO status block handed to the AFD driver; must stay pinned while a
    // poll is pending.
    iosb: IoStatusBlock,
    // Poll request/response buffer shared with the AFD driver.
    poll_info: AfdPollInfo,
    /// The file handle to which request is bound.
    afd: Arc<Afd>,
    /// SOCKET of the request (base provider handle resolved at creation).
    base_socket: RawSocket,
    /// User Token
    user_token: u64,
    /// user Interest
    user_interests_flags: u32,
    /// When this socket is polled, save user_interests_flags in
    /// polling_interests_flags. Used for comparison during re-registration.
    polling_interests_flags: u32,
    /// Current Status. When this is Pending, System API calls must be made.
    poll_status: SockPollStatus,
    /// Mark if it is deleted.
    delete_pending: bool,
    /// Error during updating
    error: Option<i32>,

    // The AFD driver keeps raw pointers into this struct, so it must never
    // move once a poll has been submitted.
    _pinned: PhantomPinned,
}
352 
impl SockState {
    /// Creates a new SockState with RawSocket and Afd.
    ///
    /// Resolves the base provider socket up front; fails if the base handle
    /// cannot be obtained.
    fn new(socket: RawSocket, afd: Arc<Afd>) -> io::Result<SockState> {
        Ok(SockState {
            iosb: IoStatusBlock::zeroed(),
            poll_info: AfdPollInfo::zeroed(),
            afd,
            base_socket: get_base_socket(socket)?,
            user_interests_flags: 0,
            polling_interests_flags: 0,
            user_token: 0,
            poll_status: SockPollStatus::Idle,
            delete_pending: false,

            error: None,
            _pinned: PhantomPinned,
        })
    }

    /// Update SockState in Deque, poll for each Afd.
    ///
    /// `self_arc` must be the pinned Arc that owns `self`; it is leaked into
    /// the overlapped pointer while a poll is pending and reclaimed when the
    /// completion packet arrives (or immediately on submission failure).
    fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
        // delete_pending must false.
        if self.delete_pending {
            panic!("SockState update when delete_pending is true, {:#?}", self);
        }

        // Make sure to reset previous error before a new update
        self.error = None;

        match self.poll_status {
            // Starts poll
            SockPollStatus::Idle => {
                // Init AfdPollInfo
                self.poll_info.exclusive = 0;
                self.poll_info.number_of_handles = 1;
                // Effectively infinite: completion is driven by events or
                // cancellation, not by this timeout.
                self.poll_info.timeout = i64::MAX;
                self.poll_info.handles[0].handle = self.base_socket as HANDLE;
                self.poll_info.handles[0].status = 0;
                // Always watch for local close so the state can be dropped
                // when the user closes the socket.
                self.poll_info.handles[0].events =
                    self.user_interests_flags | afd::POLL_LOCAL_CLOSE;

                // Leak an Arc reference into the overlapped pointer; it is
                // reclaimed in feed_events / Drop via from_overlapped.
                let overlapped_ptr = into_overlapped(self_arc.clone());

                // System call to run current event.
                // SAFETY: poll_info and iosb live inside this pinned state
                // and stay valid for the duration of the pending request.
                let result = unsafe {
                    self.afd
                        .poll(&mut self.poll_info, &mut *self.iosb, overlapped_ptr)
                };

                if let Err(e) = result {
                    let code = e.raw_os_error().unwrap();
                    // ERROR_IO_PENDING means the request was queued: not an error.
                    if code != ERROR_IO_PENDING as i32 {
                        // The request never reached the driver, so reclaim the
                        // Arc reference leaked above to avoid a leak.
                        drop(from_overlapped(overlapped_ptr as *mut _));

                        return if code == ERROR_INVALID_HANDLE as i32 {
                            // Socket closed; it'll be dropped.
                            self.start_drop();
                            Ok(())
                        } else {
                            self.error = e.raw_os_error();
                            Err(e)
                        };
                    }
                };

                // The poll request was successfully submitted.
                self.poll_status = SockPollStatus::Pending;
                self.polling_interests_flags = self.user_interests_flags;
            }
            SockPollStatus::Pending => {
                if (self.user_interests_flags & afd::ALL_EVENTS & !self.polling_interests_flags)
                    == 0
                {
                    // All the events the user is interested in are already
                    // being monitored by the pending poll
                    // operation. It might spuriously complete because of an
                    // event that we're no longer interested in; when that
                    // happens we'll submit a new poll
                    // operation with the updated event mask.
                } else {
                    // A poll operation is already pending, but it's not monitoring for all the
                    // events that the user is interested in. Therefore, cancel the pending
                    // poll operation; when we receive it's completion package, a new poll
                    // operation will be submitted with the correct event mask.
                    if let Err(e) = self.cancel() {
                        self.error = e.raw_os_error();
                        return Err(e);
                    }
                }
            }
            // Do nothing: a new poll is submitted once the cancelled
            // request's completion packet is consumed.
            SockPollStatus::Cancelled => {}
        }

        Ok(())
    }

    /// Stores the user's interests and token.
    ///
    /// Returns true if user_interests_flags is inconsistent with
    /// polling_interests_flags (i.e. a new/updated poll would be needed).
    fn set_event(&mut self, flags: u32, token_data: u64) -> bool {
        // Connect-fail and abort are always monitored regardless of interest.
        self.user_interests_flags = flags | afd::POLL_CONNECT_FAIL | afd::POLL_ABORT;
        self.user_token = token_data;

        (self.user_interests_flags & !self.polling_interests_flags) != 0
    }

    /// Process completed IO operation.
    ///
    /// Returns the user-facing event produced by the completed AFD poll, or
    /// `None` when nothing relevant happened (cancelled, deleted, closed, or
    /// only uninteresting events fired).
    fn sock_feed_event(&mut self) -> Option<Event> {
        // The pending request has completed, whatever its outcome.
        self.poll_status = SockPollStatus::Idle;
        self.polling_interests_flags = 0;

        let mut afd_events = 0;
        // Uses the status info in IO_STATUS_BLOCK to determine the socket poll status.
        // It is unsafe to use a pointer of IO_STATUS_BLOCK.
        unsafe {
            if self.delete_pending {
                return None;
            } else if self.iosb.Anonymous.Status == STATUS_CANCELLED {
                // The poll request was cancelled by CancelIoEx.
            } else if self.iosb.Anonymous.Status < 0 {
                // The overlapped request itself failed in an unexpected way.
                afd_events = afd::POLL_CONNECT_FAIL;
            } else if self.poll_info.number_of_handles < 1 {
                // This poll operation succeeded but didn't report any socket
                // events.
            } else if self.poll_info.handles[0].events & afd::POLL_LOCAL_CLOSE != 0 {
                // The poll operation reported that the socket was closed.
                self.start_drop();
                return None;
            } else {
                afd_events = self.poll_info.handles[0].events;
            }
        }
        // Filter out events that the user didn't ask for.
        afd_events &= self.user_interests_flags;

        if afd_events == 0 {
            return None;
        }

        // Simulates Edge-triggered behavior to match API usage.
        // Intercept all read/write from user which may cause WouldBlock usage,
        // And reregister the socket to reset the interests.
        self.user_interests_flags &= !afd_events;

        Some(Event {
            data: self.user_token,
            flags: afd_events,
        })
    }

    /// Starts drop SockState
    pub(crate) fn start_drop(&mut self) {
        if !self.delete_pending {
            // if it is Pending, it means SockState has been register in IOCP,
            // must system call to cancel socket.
            // else set delete_pending=true is enough.
            if let SockPollStatus::Pending = self.poll_status {
                // Best-effort: a cancel failure is deliberately ignored here.
                drop(self.cancel());
            }
            self.delete_pending = true;
        }
    }

    /// Only can cancel SockState of SockPollStatus::Pending, Set to
    /// SockPollStatus::Cancelled.
    fn cancel(&mut self) -> io::Result<()> {
        // Checks poll_status again.
        if self.poll_status != SockPollStatus::Pending {
            unreachable!("Invalid poll status during cancel, {:#?}", self);
        }

        // SAFETY: the IOSB passed here is the one the pending request was
        // submitted with, and it remains pinned inside this state.
        unsafe {
            self.afd.cancel(&mut *self.iosb)?;
        }

        // Only here set SockPollStatus::Cancelled, SockStates has been system called to
        // cancel
        self.poll_status = SockPollStatus::Cancelled;
        self.polling_interests_flags = 0;

        Ok(())
    }

    // Whether this state has been marked for deletion.
    fn is_delete_pending(&self) -> bool {
        self.delete_pending
    }

    // Whether the last update attempt recorded an OS error.
    fn has_error(&self) -> bool {
        self.error.is_some()
    }
}
545 
impl Drop for SockState {
    /// Ensures any pending AFD poll is cancelled and the state is marked
    /// delete-pending before the memory is released.
    fn drop(&mut self) {
        self.start_drop();
    }
}
551 
get_base_socket(raw_socket: RawSocket) -> io::Result<RawSocket>552 fn get_base_socket(raw_socket: RawSocket) -> io::Result<RawSocket> {
553     let res = base_socket_inner(raw_socket, SIO_BASE_HANDLE);
554     if let Ok(base_socket) = res {
555         return Ok(base_socket);
556     }
557 
558     for &ioctl in &[SIO_BSP_HANDLE_SELECT, SIO_BSP_HANDLE_POLL, SIO_BSP_HANDLE] {
559         if let Ok(base_socket) = base_socket_inner(raw_socket, ioctl) {
560             if base_socket != raw_socket {
561                 return Ok(base_socket);
562             }
563         }
564     }
565 
566     Err(io::Error::from_raw_os_error(res.unwrap_err()))
567 }
568 
/// Performs one `WSAIoctl` query with the given `control_code`, writing the
/// resulting handle into a local `RawSocket`.
///
/// Returns the resolved socket on success, or the raw `WSAGetLastError` code
/// on failure.
fn base_socket_inner(raw_socket: RawSocket, control_code: u32) -> Result<RawSocket, i32> {
    let mut base_socket: RawSocket = 0;
    let mut bytes_returned: u32 = 0;
    // SAFETY: the output buffer is a valid, live `RawSocket` local whose size
    // is passed correctly; no input buffer, overlapped, or completion routine
    // is used, so null/None are valid for those parameters.
    unsafe {
        if WSAIoctl(
            raw_socket as usize,
            control_code,
            null_mut(),
            0,
            &mut base_socket as *mut _ as *mut c_void,
            size_of::<RawSocket>() as u32,
            &mut bytes_returned,
            null_mut(),
            None,
        ) != SOCKET_ERROR
        {
            Ok(base_socket)
        } else {
            // Returns the error status for the last Windows Sockets operation that failed.
            Err(WSAGetLastError())
        }
    }
}
592 
593 /// Interests convert to flags.
interests_to_afd_flags(interests: Interest) -> u32594 fn interests_to_afd_flags(interests: Interest) -> u32 {
595     let mut flags = 0;
596 
597     // Sets readable flags.
598     if interests.is_readable() {
599         flags |= READABLE_FLAGS | READ_CLOSED_FLAGS | ERROR_FLAGS;
600     }
601 
602     // Sets writable flags.
603     if interests.is_writable() {
604         flags |= WRITABLE_FLAGS | WRITE_CLOSED_FLAGS | ERROR_FLAGS;
605     }
606 
607     flags
608 }
609 
/// Converts the pointer to a `SockState` into a raw pointer.
///
/// Leaks one strong `Arc` reference into the returned pointer; the matching
/// `from_overlapped` call must eventually reclaim it or the state leaks.
fn into_overlapped(sock_state: Pin<Arc<Mutex<SockState>>>) -> *mut c_void {
    // SAFETY: the pin is never violated — the Arc's pointee is not moved;
    // the raw pointer is only ever turned back into a pinned Arc by
    // `from_overlapped`.
    let overlapped_ptr: *const Mutex<SockState> =
        unsafe { Arc::into_raw(Pin::into_inner_unchecked(sock_state)) };
    overlapped_ptr as *mut _
}
616 
/// Convert a raw overlapped pointer into a reference to `SockState`.
///
/// Takes back the strong `Arc` reference previously leaked by
/// `into_overlapped`; must be called exactly once per leaked pointer.
fn from_overlapped(ptr: *mut OVERLAPPED) -> Pin<Arc<Mutex<SockState>>> {
    let sock_ptr: *const Mutex<SockState> = ptr as *const _;
    // SAFETY: `ptr` was produced by `into_overlapped` from an
    // `Arc<Mutex<SockState>>`, so re-pinning the reconstructed Arc is sound.
    unsafe { Pin::new_unchecked(Arc::from_raw(sock_ptr)) }
}
622