// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::VecDeque;
use std::ffi::c_void;
use std::io;
use std::marker::PhantomPinned;
use std::mem::size_of;
use std::os::windows::io::RawSocket;
use std::pin::Pin;
use std::ptr::null_mut;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;

use windows_sys::Win32::Foundation::{
    ERROR_INVALID_HANDLE, ERROR_IO_PENDING, HANDLE, STATUS_CANCELLED, WAIT_TIMEOUT,
};
use windows_sys::Win32::Networking::WinSock::{
    WSAGetLastError, WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE, SIO_BSP_HANDLE_POLL,
    SIO_BSP_HANDLE_SELECT, SOCKET_ERROR,
};
use windows_sys::Win32::System::IO::OVERLAPPED;

use crate::sys::windows::afd;
use crate::sys::windows::afd::{Afd, AfdGroup, AfdPollInfo};
use crate::sys::windows::events::{
    Events, ERROR_FLAGS, READABLE_FLAGS, READ_CLOSED_FLAGS, WRITABLE_FLAGS, WRITE_CLOSED_FLAGS,
};
use crate::sys::windows::io_status_block::IoStatusBlock;
use crate::sys::windows::iocp::{CompletionPort, CompletionStatus};
use crate::sys::NetInner;
use crate::{Event, Interest, Token};
/// A wrapper over the OS polling system.
/// Linux: epoll
/// Windows: IOCP
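///
/// A hedged usage sketch; `Events::with_capacity` and the timeout value are
/// illustrative, not guaranteed by this module:
///
/// ```ignore
/// let selector = Selector::new()?;
/// let mut events = Events::with_capacity(256);
/// // Blocks for at most 100 ms, filling `events` with ready IO events.
/// selector.select(&mut events, Some(Duration::from_millis(100)))?;
/// ```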
#[derive(Debug)]
pub struct Selector {
    inner: Arc<SelectorInner>,
}

impl Selector {
    pub(crate) fn new() -> io::Result<Selector> {
        SelectorInner::new().map(|inner| Selector {
            inner: Arc::new(inner),
        })
    }

    pub(crate) fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        self.inner.select(events, timeout)
    }

    pub(crate) fn register(
        &self,
        socket: RawSocket,
        token: Token,
        interests: Interest,
    ) -> io::Result<NetInner> {
        SelectorInner::register(&self.inner, socket, token, interests)
    }

    pub(crate) fn reregister(
        &self,
        sock_state: Pin<Arc<Mutex<SockState>>>,
        token: Token,
        interests: Interest,
    ) -> io::Result<()> {
        self.inner.reregister(sock_state, token, interests)
    }

    pub(crate) fn clone_cp(&self) -> Arc<CompletionPort> {
        self.inner.completion_port.clone()
    }
}

#[derive(Debug)]
pub(crate) struct SelectorInner {
    /// IOCP handle.
    completion_port: Arc<CompletionPort>,
    /// Registered/re-registered IO events are placed in this queue.
    update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
    /// Afd group.
    afd_group: AfdGroup,
    /// Whether the Selector is currently polling.
    polling: AtomicBool,
}

impl SelectorInner {
    /// Creates a new SelectorInner.
    fn new() -> io::Result<SelectorInner> {
        CompletionPort::new().map(|cp| {
            let arc_cp = Arc::new(cp);
            let cp_afd = Arc::clone(&arc_cp);

            SelectorInner {
                completion_port: arc_cp,
                update_queue: Mutex::new(VecDeque::new()),
                afd_group: AfdGroup::new(cp_afd),
                polling: AtomicBool::new(false),
            }
        })
    }

    /// Starts polling. With no timeout, loops until at least one event is
    /// returned; with a timeout, performs a single poll.
    fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        events.clear();

        match timeout {
            None => loop {
                let len = self.select_inner(events, timeout)?;
                if len != 0 {
                    return Ok(());
                }
            },
            Some(_) => {
                self.select_inner(events, timeout)?;
                Ok(())
            }
        }
    }

    fn select_inner(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<usize> {
        // Only one poll can run at a time.
        if self.polling.swap(true, Ordering::AcqRel) {
            panic!("Can't be polling twice at the same time!");
        }

        unsafe { self.update_sockets_events() }?;

        let results = self
            .completion_port
            .get_results(&mut events.status, timeout);

        self.polling.store(false, Ordering::Relaxed);

        match results {
            Ok(iocp_events) => Ok(unsafe { self.feed_events(&mut events.events, iocp_events) }),
            Err(ref e) if e.raw_os_error() == Some(WAIT_TIMEOUT as i32) => Ok(0),
            Err(e) => Err(e),
        }
    }

    /// Processes completed operations and puts them into `events`; sockets
    /// with regular AFD events are put back into the update queue.
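    ///
    /// Dispatch rules, as implemented below:
    /// - Null overlapped pointer: a user-posted event, converted directly
    ///   into an `Event`.
    /// - Odd token: a non-AFD completion (e.g. a pipe); its stored callback
    ///   is invoked to produce events.
    /// - Otherwise: a regular AFD socket poll completion.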
    unsafe fn feed_events(
        &self,
        events: &mut Vec<Event>,
        iocp_events: &[CompletionStatus],
    ) -> usize {
        let mut epoll_event_count = 0;
        let mut update_queue = self.update_queue.lock().unwrap();
        for iocp_event in iocp_events.iter() {
            if iocp_event.overlapped().is_null() {
                events.push(Event::from_completion_status(iocp_event));
                epoll_event_count += 1;
                continue;
            } else if iocp_event.token() % 2 == 1 {
                // Non-AFD event, including pipe.
                let callback = (*(iocp_event.overlapped() as *mut super::Overlapped)).callback;

                let len = events.len();
                callback(iocp_event.entry(), Some(events));
                epoll_event_count += events.len() - len;
                continue;
            }

            // General asynchronous IO event.
            let sock_state = from_overlapped(iocp_event.overlapped());
            let mut sock_guard = sock_state.lock().unwrap();
            if let Some(event) = sock_guard.sock_feed_event() {
                events.push(event);
                epoll_event_count += 1;
            }

            // Reregister the socket.
            if !sock_guard.is_delete_pending() {
                update_queue.push_back(sock_state.clone());
            }
        }

        self.afd_group.release_unused_afd();
        epoll_event_count
    }

    /// Updates each SockState in the queue; only runs while Poll::poll() is
    /// being called externally.
    unsafe fn update_sockets_events(&self) -> io::Result<()> {
        let mut update_queue = self.update_queue.lock().unwrap();
        for sock in update_queue.iter_mut() {
            let mut sock_internal = sock.lock().unwrap();
            if !sock_internal.delete_pending {
                sock_internal.update(sock)?;
            }
        }
        // Removes the sockets that were updated successfully; those with
        // errors stay queued for retry.
        update_queue.retain(|sock| sock.lock().unwrap().has_error());

        self.afd_group.release_unused_afd();
        Ok(())
    }

    /// No actual system call is made during register; that only happens in
    /// Poll::poll(). Returns a `NetInner` to be stored in the asynchronous
    /// IO structure.
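    ///
    /// A hedged sketch of the call path; the `stream`/`as_raw_socket` usage
    /// and the chosen interest are illustrative:
    ///
    /// ```ignore
    /// let inner = Arc::new(SelectorInner::new()?);
    /// let net_inner = SelectorInner::register(
    ///     &inner,
    ///     stream.as_raw_socket(),
    ///     Token(0),
    ///     Interest::READABLE,
    /// )?;
    /// ```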
    pub(crate) fn register(
        this: &Arc<Self>,
        raw_socket: RawSocket,
        token: Token,
        interests: Interest,
    ) -> io::Result<NetInner> {
        // Creates Afd.
        let afd = this.afd_group.acquire()?;
        let mut sock_state = SockState::new(raw_socket, afd)?;

        let flags = interests_to_afd_flags(interests);
        sock_state.set_event(flags, token.0 as u64);

        let pin_sock_state = Arc::pin(Mutex::new(sock_state));

        let net_internal = NetInner::new(this.clone(), token, interests, pin_sock_state.clone());

        // Adds SockState to the update queue.
        this.queue_state(pin_sock_state);

        if this.polling.load(Ordering::Acquire) {
            unsafe { this.update_sockets_events()? }
        }

        Ok(net_internal)
    }

    /// Re-registers, putting the SockState back into the update queue.
    pub(crate) fn reregister(
        &self,
        state: Pin<Arc<Mutex<SockState>>>,
        token: Token,
        interests: Interest,
    ) -> io::Result<()> {
        let flags = interests_to_afd_flags(interests);
        state.lock().unwrap().set_event(flags, token.0 as u64);

        // Puts the state back into the update queue.
        self.queue_state(state);

        if self.polling.load(Ordering::Acquire) {
            unsafe { self.update_sockets_events() }
        } else {
            Ok(())
        }
    }

    /// Appends the SockState to the back of the update queue.
    fn queue_state(&self, sock_state: Pin<Arc<Mutex<SockState>>>) {
        let mut update_queue = self.update_queue.lock().unwrap();
        update_queue.push_back(sock_state);
    }
}

impl Drop for SelectorInner {
    fn drop(&mut self) {
        loop {
            let complete_num: usize;
            let mut status: [CompletionStatus; 1024] = [CompletionStatus::zero(); 1024];

            let result = self
                .completion_port
                .get_results(&mut status, Some(Duration::from_millis(0)));

            match result {
                Ok(iocp_events) => {
                    complete_num = iocp_events.iter().len();
                    for iocp_event in iocp_events.iter() {
                        if iocp_event.overlapped().is_null() {
                            // User event.
                        } else if iocp_event.token() % 2 == 1 {
                            // For a pipe, dispatch the event so it can release
                            // its resources.
                            let callback = unsafe {
                                (*(iocp_event.overlapped() as *mut super::Overlapped)).callback
                            };

                            callback(iocp_event.entry(), None);
                        } else {
                            // Releases the memory held by the Arc reference.
                            let _ = from_overlapped(iocp_event.overlapped());
                        }
                    }
                }

                Err(_) => {
                    break;
                }
            }

            if complete_num == 0 {
                // All completion statuses have been drained; stop looping.
                break;
            }
        }

        self.afd_group.release_unused_afd();
    }
}

#[derive(Debug, PartialEq)]
enum SockPollStatus {
    /// Initial value.
    Idle,
    /// Set from Idle to Pending when the poll system call is made while
    /// updating sockets' events; this only happens during polling. Only a
    /// Pending socket can be cancelled.
    Pending,
    /// Set to Cancelled after the system API has been called to cancel the
    /// socket.
    Cancelled,
}
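
// State transitions, as driven by `SockState::update`, `SockState::cancel`
// and `SockState::sock_feed_event` below:
//   Idle -> Pending       when an AFD poll request is submitted;
//   Pending -> Cancelled  when the pending request is cancelled;
//   * -> Idle             once the completion packet is consumed.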

/// Saves all the information of the socket during polling.
#[derive(Debug)]
pub struct SockState {
    iosb: IoStatusBlock,
    poll_info: AfdPollInfo,
    /// The Afd handle to which the request is bound.
    afd: Arc<Afd>,
    /// Base SOCKET of the request.
    base_socket: RawSocket,
    /// User token.
    user_token: u64,
    /// User interests.
    user_interests_flags: u32,
    /// When this socket is polled, user_interests_flags is saved into
    /// polling_interests_flags. Used for comparison during re-registration.
    polling_interests_flags: u32,
    /// Current status. While this is Pending, a system call is required to
    /// cancel the request.
    poll_status: SockPollStatus,
    /// Marks whether deletion is pending.
    delete_pending: bool,
    /// Error encountered during updating.
    error: Option<i32>,

    _pinned: PhantomPinned,
}

impl SockState {
    /// Creates a new SockState from a RawSocket and an Afd.
    fn new(socket: RawSocket, afd: Arc<Afd>) -> io::Result<SockState> {
        Ok(SockState {
            iosb: IoStatusBlock::zeroed(),
            poll_info: AfdPollInfo::zeroed(),
            afd,
            base_socket: get_base_socket(socket)?,
            user_interests_flags: 0,
            polling_interests_flags: 0,
            user_token: 0,
            poll_status: SockPollStatus::Idle,
            delete_pending: false,
            error: None,
            _pinned: PhantomPinned,
        })
    }

    /// Updates this SockState from the queue, submitting a poll to its Afd.
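    ///
    /// Behavior by current status, as implemented below:
    /// - Idle: submits a new AFD poll request and moves to Pending.
    /// - Pending: if the in-flight request already covers the user's
    ///   interests, does nothing; otherwise cancels it so a new request with
    ///   the updated mask can be submitted later.
    /// - Cancelled: does nothing until the completion packet arrives.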
    fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
        // delete_pending must be false here.
        if self.delete_pending {
            panic!("SockState update when delete_pending is true, {:#?}", self);
        }

        // Make sure to reset any previous error before a new update.
        self.error = None;

        match self.poll_status {
            // Starts a new poll.
            SockPollStatus::Idle => {
                // Init AfdPollInfo.
                self.poll_info.exclusive = 0;
                self.poll_info.number_of_handles = 1;
                self.poll_info.timeout = i64::MAX;
                self.poll_info.handles[0].handle = self.base_socket as HANDLE;
                self.poll_info.handles[0].status = 0;
                self.poll_info.handles[0].events =
                    self.user_interests_flags | afd::POLL_LOCAL_CLOSE;

                let overlapped_ptr = into_overlapped(self_arc.clone());

                // System call to submit the poll request.
                let result = unsafe {
                    self.afd
                        .poll(&mut self.poll_info, &mut *self.iosb, overlapped_ptr)
                };

                if let Err(e) = result {
                    let code = e.raw_os_error().unwrap();
                    if code != ERROR_IO_PENDING as i32 {
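                        // The kernel rejected the request, so reclaim the
                        // Arc reference leaked by `into_overlapped` above.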
                        drop(from_overlapped(overlapped_ptr as *mut _));

                        return if code == ERROR_INVALID_HANDLE as i32 {
                            // Socket closed; it'll be dropped.
                            self.start_drop();
                            Ok(())
                        } else {
                            self.error = e.raw_os_error();
                            Err(e)
                        };
                    }
                };

                // The poll request was successfully submitted.
                self.poll_status = SockPollStatus::Pending;
                self.polling_interests_flags = self.user_interests_flags;
            }
            SockPollStatus::Pending => {
                if (self.user_interests_flags & afd::ALL_EVENTS & !self.polling_interests_flags)
                    == 0
                {
                    // All the events the user is interested in are already
                    // being monitored by the pending poll operation. It might
                    // spuriously complete because of an event that we're no
                    // longer interested in; when that happens we'll submit a
                    // new poll operation with the updated event mask.
                } else {
                    // A poll operation is already pending, but it's not
                    // monitoring all the events that the user is interested
                    // in. Therefore, cancel the pending poll operation; when
                    // we receive its completion packet, a new poll operation
                    // will be submitted with the correct event mask.
                    if let Err(e) = self.cancel() {
                        self.error = e.raw_os_error();
                        return Err(e);
                    }
                }
            }
            // Do nothing.
            SockPollStatus::Cancelled => {}
        }

        Ok(())
    }

    /// Returns true if user_interests_flags is inconsistent with
    /// polling_interests_flags.
    fn set_event(&mut self, flags: u32, token_data: u64) -> bool {
        self.user_interests_flags = flags | afd::POLL_CONNECT_FAIL | afd::POLL_ABORT;
        self.user_token = token_data;

        (self.user_interests_flags & !self.polling_interests_flags) != 0
    }

    /// Processes a completed IO operation.
    fn sock_feed_event(&mut self) -> Option<Event> {
        self.poll_status = SockPollStatus::Idle;
        self.polling_interests_flags = 0;

        let mut afd_events = 0;
        // Uses the status info in the IO_STATUS_BLOCK to determine the socket
        // poll status. Reading through the IO_STATUS_BLOCK union is unsafe.
        unsafe {
            if self.delete_pending {
                return None;
            } else if self.iosb.Anonymous.Status == STATUS_CANCELLED {
                // The poll request was cancelled by CancelIoEx.
            } else if self.iosb.Anonymous.Status < 0 {
                // The overlapped request itself failed in an unexpected way.
                afd_events = afd::POLL_CONNECT_FAIL;
            } else if self.poll_info.number_of_handles < 1 {
                // This poll operation succeeded but didn't report any socket
                // events.
            } else if self.poll_info.handles[0].events & afd::POLL_LOCAL_CLOSE != 0 {
                // The poll operation reported that the socket was closed.
                self.start_drop();
                return None;
            } else {
                afd_events = self.poll_info.handles[0].events;
            }
        }
        // Filter out events that the user didn't ask for.
        afd_events &= self.user_interests_flags;

        if afd_events == 0 {
            return None;
        }

        // Simulates edge-triggered behavior to match API usage: the delivered
        // events are cleared from the interests, so reads/writes that may hit
        // WouldBlock require re-registering the socket to reset the interests.
        self.user_interests_flags &= !afd_events;

        Some(Event {
            data: self.user_token,
            flags: afd_events,
        })
    }

    /// Starts dropping the SockState.
    pub(crate) fn start_drop(&mut self) {
        if !self.delete_pending {
            // If it is Pending, the SockState has been registered in the
            // IOCP, so a system call is required to cancel the socket;
            // otherwise, setting delete_pending = true is enough.
            if let SockPollStatus::Pending = self.poll_status {
                drop(self.cancel());
            }
            self.delete_pending = true;
        }
    }

    /// Only a SockState in SockPollStatus::Pending can be cancelled; on
    /// success it is set to SockPollStatus::Cancelled.
    fn cancel(&mut self) -> io::Result<()> {
        // Checks poll_status again.
        if self.poll_status != SockPollStatus::Pending {
            unreachable!("Invalid poll status during cancel, {:#?}", self);
        }

        unsafe {
            self.afd.cancel(&mut *self.iosb)?;
        }

        // This is the only place SockPollStatus::Cancelled is set; the system
        // call cancelling the SockState has been made.
        self.poll_status = SockPollStatus::Cancelled;
        self.polling_interests_flags = 0;

        Ok(())
    }

    fn is_delete_pending(&self) -> bool {
        self.delete_pending
    }

    fn has_error(&self) -> bool {
        self.error.is_some()
    }
}

impl Drop for SockState {
    fn drop(&mut self) {
        self.start_drop();
    }
}

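/// Retrieves the base service-provider socket beneath any layered service
/// providers, trying SIO_BASE_HANDLE first and falling back to the
/// SIO_BSP_HANDLE* ioctls.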
fn get_base_socket(raw_socket: RawSocket) -> io::Result<RawSocket> {
    let res = base_socket_inner(raw_socket, SIO_BASE_HANDLE);
    if let Ok(base_socket) = res {
        return Ok(base_socket);
    }

    for &ioctl in &[SIO_BSP_HANDLE_SELECT, SIO_BSP_HANDLE_POLL, SIO_BSP_HANDLE] {
        if let Ok(base_socket) = base_socket_inner(raw_socket, ioctl) {
            if base_socket != raw_socket {
                return Ok(base_socket);
            }
        }
    }

    Err(io::Error::from_raw_os_error(res.unwrap_err()))
}

fn base_socket_inner(raw_socket: RawSocket, control_code: u32) -> Result<RawSocket, i32> {
    let mut base_socket: RawSocket = 0;
    let mut bytes_returned: u32 = 0;
    unsafe {
        if WSAIoctl(
            raw_socket as usize,
            control_code,
            null_mut(),
            0,
            &mut base_socket as *mut _ as *mut c_void,
            size_of::<RawSocket>() as u32,
            &mut bytes_returned,
            null_mut(),
            None,
        ) != SOCKET_ERROR
        {
            Ok(base_socket)
        } else {
            // Returns the error status for the last Windows Sockets operation
            // that failed.
            Err(WSAGetLastError())
        }
    }
}

/// Converts Interest into AFD poll flags.
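///
/// For example, a read-only interest yields
/// `READABLE_FLAGS | READ_CLOSED_FLAGS | ERROR_FLAGS`.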
fn interests_to_afd_flags(interests: Interest) -> u32 {
    let mut flags = 0;

    // Sets readable flags.
    if interests.is_readable() {
        flags |= READABLE_FLAGS | READ_CLOSED_FLAGS | ERROR_FLAGS;
    }

    // Sets writable flags.
    if interests.is_writable() {
        flags |= WRITABLE_FLAGS | WRITE_CLOSED_FLAGS | ERROR_FLAGS;
    }

    flags
}

/// Converts a pinned `SockState` into a raw pointer for use as an
/// `OVERLAPPED` pointer, leaking one `Arc` reference that `from_overlapped`
/// later reclaims.
fn into_overlapped(sock_state: Pin<Arc<Mutex<SockState>>>) -> *mut c_void {
    let overlapped_ptr: *const Mutex<SockState> =
        unsafe { Arc::into_raw(Pin::into_inner_unchecked(sock_state)) };
    overlapped_ptr as *mut _
}

/// Converts a raw overlapped pointer back into a pinned `SockState`,
/// reclaiming the `Arc` reference leaked by `into_overlapped`.
fn from_overlapped(ptr: *mut OVERLAPPED) -> Pin<Arc<Mutex<SockState>>> {
    let sock_ptr: *const Mutex<SockState> = ptr as *const _;
    unsafe { Pin::new_unchecked(Arc::from_raw(sock_ptr)) }
}