1 // Copyright 2015 The Rust Project Developers.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8
9 use std::cmp::min;
10 use std::io::{self, IoSlice};
11 use std::marker::PhantomData;
12 use std::mem::{self, size_of, MaybeUninit};
13 use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
14 use std::os::windows::io::{
15 AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
16 };
17 use std::path::Path;
18 use std::sync::Once;
19 use std::time::{Duration, Instant};
20 use std::{process, ptr, slice};
21
22 use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT};
23 use windows_sys::Win32::Networking::WinSock::{
24 self, tcp_keepalive, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0,
25 POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND, SIO_KEEPALIVE_VALS,
26 SOCKET_ERROR, WSABUF, WSAEMSGSIZE, WSAESHUTDOWN, WSAPOLLFD, WSAPROTOCOL_INFOW,
27 WSA_FLAG_NO_HANDLE_INHERIT, WSA_FLAG_OVERLAPPED,
28 };
29 #[cfg(feature = "all")]
30 use windows_sys::Win32::Networking::WinSock::{
31 IP6T_SO_ORIGINAL_DST, SOL_IP, SO_ORIGINAL_DST, SO_PROTOCOL_INFOW,
32 };
33 use windows_sys::Win32::System::Threading::INFINITE;
34
35 use crate::{MsgHdr, RecvFlags, SockAddr, TcpKeepalive, Type};
36
37 #[allow(non_camel_case_types)]
38 pub(crate) type c_int = std::os::raw::c_int;
39
40 /// Fake MSG_TRUNC flag for the [`RecvFlags`] struct.
41 ///
42 /// The flag is enabled when a `WSARecv[From]` call returns `WSAEMSGSIZE`. The
43 /// value of the flag is defined by us.
44 pub(crate) const MSG_TRUNC: c_int = 0x01;
45
46 // Used in `Domain`.
47 pub(crate) const AF_INET: c_int = windows_sys::Win32::Networking::WinSock::AF_INET as c_int;
48 pub(crate) const AF_INET6: c_int = windows_sys::Win32::Networking::WinSock::AF_INET6 as c_int;
49 pub(crate) const AF_UNIX: c_int = windows_sys::Win32::Networking::WinSock::AF_UNIX as c_int;
50 pub(crate) const AF_UNSPEC: c_int = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as c_int;
51 // Used in `Type`.
52 pub(crate) const SOCK_STREAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_STREAM as c_int;
53 pub(crate) const SOCK_DGRAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_DGRAM as c_int;
54 pub(crate) const SOCK_RAW: c_int = windows_sys::Win32::Networking::WinSock::SOCK_RAW as c_int;
55 const SOCK_RDM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_RDM as c_int;
56 pub(crate) const SOCK_SEQPACKET: c_int =
57 windows_sys::Win32::Networking::WinSock::SOCK_SEQPACKET as c_int;
58 // Used in `Protocol`.
59 pub(crate) use windows_sys::Win32::Networking::WinSock::{
60 IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_TCP, IPPROTO_UDP,
61 };
62 // Used in `SockAddr`.
63 pub(crate) use windows_sys::Win32::Networking::WinSock::{
64 SOCKADDR as sockaddr, SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6,
65 SOCKADDR_STORAGE as sockaddr_storage,
66 };
67 #[allow(non_camel_case_types)]
68 pub(crate) type sa_family_t = windows_sys::Win32::Networking::WinSock::ADDRESS_FAMILY;
69 #[allow(non_camel_case_types)]
70 pub(crate) type socklen_t = windows_sys::Win32::Networking::WinSock::socklen_t;
71 // Used in `Socket`.
72 #[cfg(feature = "all")]
73 pub(crate) use windows_sys::Win32::Networking::WinSock::IP_HDRINCL;
74 pub(crate) use windows_sys::Win32::Networking::WinSock::{
75 IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_MREQ as Ipv6Mreq,
76 IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_RECVTCLASS,
77 IPV6_UNICAST_HOPS, IPV6_V6ONLY, IP_ADD_MEMBERSHIP, IP_ADD_SOURCE_MEMBERSHIP,
78 IP_DROP_MEMBERSHIP, IP_DROP_SOURCE_MEMBERSHIP, IP_MREQ as IpMreq,
79 IP_MREQ_SOURCE as IpMreqSource, IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL,
80 IP_RECVTOS, IP_TOS, IP_TTL, LINGER as linger, MSG_OOB, MSG_PEEK, SO_BROADCAST, SO_ERROR,
81 SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE, SO_RCVBUF, SO_RCVTIMEO, SO_REUSEADDR, SO_SNDBUF,
82 SO_SNDTIMEO, SO_TYPE, TCP_NODELAY,
83 };
84 pub(crate) const IPPROTO_IP: c_int = windows_sys::Win32::Networking::WinSock::IPPROTO_IP as c_int;
85 pub(crate) const SOL_SOCKET: c_int = windows_sys::Win32::Networking::WinSock::SOL_SOCKET as c_int;
86
87 /// Type used in set/getsockopt to retrieve the `TCP_NODELAY` option.
88 ///
89 /// NOTE: <https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-getsockopt>
90 /// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a
91 /// `BOOL` (alias for `c_int`, 4 bytes), however in practice this turns out to
92 /// be false (or misleading) as a `BOOLEAN` (`c_uchar`, 1 byte) is returned by
93 /// `getsockopt`.
94 pub(crate) type Bool = windows_sys::Win32::Foundation::BOOLEAN;
95
96 /// Maximum size of a buffer passed to system call like `recv` and `send`.
97 const MAX_BUF_LEN: usize = c_int::MAX as usize;
98
/// Helper macro to execute a system call that returns an `io::Result`.
///
/// `$fn` is resolved inside `windows_sys::Win32::Networking::WinSock`. Its raw
/// return value is compared against `$err_value` using `$err_test` (for
/// example `PartialEq::eq` with `SOCKET_ERROR`, or `PartialEq::ne` with `0`).
/// When the test is true the macro evaluates to
/// `Err(io::Error::last_os_error())`; otherwise it evaluates to
/// `Ok(raw_return_value)`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{
        // Several callers already invoke this from an `unsafe` block or
        // `unsafe fn`, where the inner `unsafe` would trip `unused_unsafe`.
        #[allow(unused_unsafe)]
        let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) };
        if $err_test(&res, &$err_value) {
            Err(io::Error::last_os_error())
        } else {
            Ok(res)
        }
    }};
}
111
// `Debug` for `Domain`, expressed in terms of the address family constants
// defined in this module (see the `impl_debug!` macro for the exact output).
impl_debug!(
    crate::Domain,
    self::AF_INET,
    self::AF_INET6,
    self::AF_UNIX,
    self::AF_UNSPEC,
);
119
120 /// Windows only API.
121 impl Type {
122 /// Our custom flag to set `WSA_FLAG_NO_HANDLE_INHERIT` on socket creation.
123 /// Trying to mimic `Type::cloexec` on windows.
124 const NO_INHERIT: c_int = 1 << ((size_of::<c_int>() * 8) - 1); // Last bit.
125
126 /// Set `WSA_FLAG_NO_HANDLE_INHERIT` on the socket.
127 #[cfg(feature = "all")]
128 #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
no_inherit(self) -> Type129 pub const fn no_inherit(self) -> Type {
130 self._no_inherit()
131 }
132
_no_inherit(self) -> Type133 pub(crate) const fn _no_inherit(self) -> Type {
134 Type(self.0 | Type::NO_INHERIT)
135 }
136 }
137
// `Debug` for `Type`, in terms of the socket type constants of this module.
impl_debug!(
    crate::Type,
    self::SOCK_STREAM,
    self::SOCK_DGRAM,
    self::SOCK_RAW,
    self::SOCK_RDM,
    self::SOCK_SEQPACKET,
);

// `Debug` for `Protocol`, in terms of the Winsock `IPPROTO_*` constants.
impl_debug!(
    crate::Protocol,
    WinSock::IPPROTO_ICMP,
    WinSock::IPPROTO_ICMPV6,
    WinSock::IPPROTO_TCP,
    WinSock::IPPROTO_UDP,
);
154
155 impl std::fmt::Debug for RecvFlags {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 f.debug_struct("RecvFlags")
158 .field("is_truncated", &self.is_truncated())
159 .finish()
160 }
161 }
162
/// A `WSABUF` describing a possibly uninitialised byte buffer that is
/// borrowed for `'a`.
///
/// `#[repr(transparent)]` gives it exactly the layout of `WSABUF`
/// (presumably so slices of it can be passed to `WSARecv`/`WSARecvFrom` as a
/// `WSABUF` array — verify against the callers' `bufs.as_mut_ptr().cast()`).
#[repr(transparent)]
pub struct MaybeUninitSlice<'a> {
    vec: WSABUF,
    // Ties the raw pointer stored in `vec` to the borrow it came from.
    _lifetime: PhantomData<&'a mut [MaybeUninit<u8>]>,
}

// SAFETY: the raw pointer in `vec` only stands in for the
// `&'a mut [MaybeUninit<u8>]` borrow taken in `new`, and that borrow is `Send`.
unsafe impl<'a> Send for MaybeUninitSlice<'a> {}

// SAFETY: as above; `&'a mut [MaybeUninit<u8>]` is also `Sync`.
unsafe impl<'a> Sync for MaybeUninitSlice<'a> {}
172
173 impl<'a> MaybeUninitSlice<'a> {
new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a>174 pub fn new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a> {
175 assert!(buf.len() <= u32::MAX as usize);
176 MaybeUninitSlice {
177 vec: WSABUF {
178 len: buf.len() as u32,
179 buf: buf.as_mut_ptr().cast(),
180 },
181 _lifetime: PhantomData,
182 }
183 }
184
as_slice(&self) -> &[MaybeUninit<u8>]185 pub fn as_slice(&self) -> &[MaybeUninit<u8>] {
186 unsafe { slice::from_raw_parts(self.vec.buf.cast(), self.vec.len as usize) }
187 }
188
as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>]189 pub fn as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>] {
190 unsafe { slice::from_raw_parts_mut(self.vec.buf.cast(), self.vec.len as usize) }
191 }
192 }
193
194 // Used in `MsgHdr`.
195 pub(crate) use windows_sys::Win32::Networking::WinSock::WSAMSG as msghdr;
196
set_msghdr_name(msg: &mut msghdr, name: &SockAddr)197 pub(crate) fn set_msghdr_name(msg: &mut msghdr, name: &SockAddr) {
198 msg.name = name.as_ptr() as *mut _;
199 msg.namelen = name.len();
200 }
201
set_msghdr_iov(msg: &mut msghdr, ptr: *mut WSABUF, len: usize)202 pub(crate) fn set_msghdr_iov(msg: &mut msghdr, ptr: *mut WSABUF, len: usize) {
203 msg.lpBuffers = ptr;
204 msg.dwBufferCount = min(len, u32::MAX as usize) as u32;
205 }
206
set_msghdr_control(msg: &mut msghdr, ptr: *mut u8, len: usize)207 pub(crate) fn set_msghdr_control(msg: &mut msghdr, ptr: *mut u8, len: usize) {
208 msg.Control.buf = ptr;
209 msg.Control.len = len as u32;
210 }
211
set_msghdr_flags(msg: &mut msghdr, flags: c_int)212 pub(crate) fn set_msghdr_flags(msg: &mut msghdr, flags: c_int) {
213 msg.dwFlags = flags as u32;
214 }
215
msghdr_flags(msg: &msghdr) -> RecvFlags216 pub(crate) fn msghdr_flags(msg: &msghdr) -> RecvFlags {
217 RecvFlags(msg.dwFlags as c_int)
218 }
219
msghdr_control_len(msg: &msghdr) -> usize220 pub(crate) fn msghdr_control_len(msg: &msghdr) -> usize {
221 msg.Control.len as _
222 }
223
/// Makes sure Winsock is initialised, at most once per process.
///
/// Initialisation is delegated to the standard library: creating a dummy
/// `std::net` socket causes libstd to initialise Winsock.
fn init() {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        // Initialize winsock through the standard library by just creating a
        // dummy socket. Whether this is successful or not we drop the result
        // as libstd will be sure to have initialized winsock.
        //
        // Port 0 lets the OS choose an ephemeral port. The previous fixed
        // port (34254) could briefly occupy a port that another application
        // is trying to bind.
        let _ = net::UdpSocket::bind("127.0.0.1:0");
    });
}
234
235 pub(crate) type Socket = windows_sys::Win32::Networking::WinSock::SOCKET;
236
socket_from_raw(socket: Socket) -> crate::socket::Inner237 pub(crate) unsafe fn socket_from_raw(socket: Socket) -> crate::socket::Inner {
238 crate::socket::Inner::from_raw_socket(socket as RawSocket)
239 }
240
socket_as_raw(socket: &crate::socket::Inner) -> Socket241 pub(crate) fn socket_as_raw(socket: &crate::socket::Inner) -> Socket {
242 socket.as_raw_socket() as Socket
243 }
244
socket_into_raw(socket: crate::socket::Inner) -> Socket245 pub(crate) fn socket_into_raw(socket: crate::socket::Inner) -> Socket {
246 socket.into_raw_socket() as Socket
247 }
248
socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket>249 pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket> {
250 init();
251
252 // Check if we set our custom flag.
253 let flags = if ty & Type::NO_INHERIT != 0 {
254 ty = ty & !Type::NO_INHERIT;
255 WSA_FLAG_NO_HANDLE_INHERIT
256 } else {
257 0
258 };
259
260 syscall!(
261 WSASocketW(
262 family,
263 ty,
264 protocol,
265 ptr::null_mut(),
266 0,
267 WSA_FLAG_OVERLAPPED | flags,
268 ),
269 PartialEq::eq,
270 INVALID_SOCKET
271 )
272 }
273
bind(socket: Socket, addr: &SockAddr) -> io::Result<()>274 pub(crate) fn bind(socket: Socket, addr: &SockAddr) -> io::Result<()> {
275 syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
276 }
277
connect(socket: Socket, addr: &SockAddr) -> io::Result<()>278 pub(crate) fn connect(socket: Socket, addr: &SockAddr) -> io::Result<()> {
279 syscall!(connect(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
280 }
281
poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()>282 pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
283 let start = Instant::now();
284
285 let mut fd_array = WSAPOLLFD {
286 fd: socket.as_raw(),
287 events: (POLLRDNORM | POLLWRNORM) as i16,
288 revents: 0,
289 };
290
291 loop {
292 let elapsed = start.elapsed();
293 if elapsed >= timeout {
294 return Err(io::ErrorKind::TimedOut.into());
295 }
296
297 let timeout = (timeout - elapsed).as_millis();
298 let timeout = clamp(timeout, 1, c_int::MAX as u128) as c_int;
299
300 match syscall!(
301 WSAPoll(&mut fd_array, 1, timeout),
302 PartialEq::eq,
303 SOCKET_ERROR
304 ) {
305 Ok(0) => return Err(io::ErrorKind::TimedOut.into()),
306 Ok(_) => {
307 // Error or hang up indicates an error (or failure to connect).
308 if (fd_array.revents & POLLERR as i16) != 0
309 || (fd_array.revents & POLLHUP as i16) != 0
310 {
311 match socket.take_error() {
312 Ok(Some(err)) => return Err(err),
313 Ok(None) => {
314 return Err(io::Error::new(
315 io::ErrorKind::Other,
316 "no error set after POLLHUP",
317 ))
318 }
319 Err(err) => return Err(err),
320 }
321 }
322 return Ok(());
323 }
324 // Got interrupted, try again.
325 Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
326 Err(err) => return Err(err),
327 }
328 }
329 }
330
/// Restricts `value` to the inclusive range `min..=max`.
///
/// Unlike `Ord::clamp` this never panics: if `min > max`, the `min` check
/// comes first and therefore wins.
// TODO: use clamp from std lib, stable since 1.50.
fn clamp<T>(value: T, min: T, max: T) -> T
where
    T: Ord,
{
    if value <= min {
        return min;
    }
    if value >= max {
        return max;
    }
    value
}
344
listen(socket: Socket, backlog: c_int) -> io::Result<()>345 pub(crate) fn listen(socket: Socket, backlog: c_int) -> io::Result<()> {
346 syscall!(listen(socket, backlog), PartialEq::ne, 0).map(|_| ())
347 }
348
accept(socket: Socket) -> io::Result<(Socket, SockAddr)>349 pub(crate) fn accept(socket: Socket) -> io::Result<(Socket, SockAddr)> {
350 // Safety: `accept` initialises the `SockAddr` for us.
351 unsafe {
352 SockAddr::try_init(|storage, len| {
353 syscall!(
354 accept(socket, storage.cast(), len),
355 PartialEq::eq,
356 INVALID_SOCKET
357 )
358 })
359 }
360 }
361
getsockname(socket: Socket) -> io::Result<SockAddr>362 pub(crate) fn getsockname(socket: Socket) -> io::Result<SockAddr> {
363 // Safety: `getsockname` initialises the `SockAddr` for us.
364 unsafe {
365 SockAddr::try_init(|storage, len| {
366 syscall!(
367 getsockname(socket, storage.cast(), len),
368 PartialEq::eq,
369 SOCKET_ERROR
370 )
371 })
372 }
373 .map(|(_, addr)| addr)
374 }
375
getpeername(socket: Socket) -> io::Result<SockAddr>376 pub(crate) fn getpeername(socket: Socket) -> io::Result<SockAddr> {
377 // Safety: `getpeername` initialises the `SockAddr` for us.
378 unsafe {
379 SockAddr::try_init(|storage, len| {
380 syscall!(
381 getpeername(socket, storage.cast(), len),
382 PartialEq::eq,
383 SOCKET_ERROR
384 )
385 })
386 }
387 .map(|(_, addr)| addr)
388 }
389
try_clone(socket: Socket) -> io::Result<Socket>390 pub(crate) fn try_clone(socket: Socket) -> io::Result<Socket> {
391 let mut info: MaybeUninit<WSAPROTOCOL_INFOW> = MaybeUninit::uninit();
392 syscall!(
393 // NOTE: `process.id` is the same as `GetCurrentProcessId`.
394 WSADuplicateSocketW(socket, process::id(), info.as_mut_ptr()),
395 PartialEq::eq,
396 SOCKET_ERROR
397 )?;
398 // Safety: `WSADuplicateSocketW` intialised `info` for us.
399 let mut info = unsafe { info.assume_init() };
400
401 syscall!(
402 WSASocketW(
403 info.iAddressFamily,
404 info.iSocketType,
405 info.iProtocol,
406 &mut info,
407 0,
408 WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT,
409 ),
410 PartialEq::eq,
411 INVALID_SOCKET
412 )
413 }
414
set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()>415 pub(crate) fn set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()> {
416 let mut nonblocking = if nonblocking { 1 } else { 0 };
417 ioctlsocket(socket, FIONBIO, &mut nonblocking)
418 }
419
shutdown(socket: Socket, how: Shutdown) -> io::Result<()>420 pub(crate) fn shutdown(socket: Socket, how: Shutdown) -> io::Result<()> {
421 let how = match how {
422 Shutdown::Write => SD_SEND,
423 Shutdown::Read => SD_RECEIVE,
424 Shutdown::Both => SD_BOTH,
425 } as i32;
426 syscall!(shutdown(socket, how), PartialEq::eq, SOCKET_ERROR).map(|_| ())
427 }
428
recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize>429 pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
430 let res = syscall!(
431 recv(
432 socket,
433 buf.as_mut_ptr().cast(),
434 min(buf.len(), MAX_BUF_LEN) as c_int,
435 flags,
436 ),
437 PartialEq::eq,
438 SOCKET_ERROR
439 );
440 match res {
441 Ok(n) => Ok(n as usize),
442 Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok(0),
443 Err(err) => Err(err),
444 }
445 }
446
recv_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags)>447 pub(crate) fn recv_vectored(
448 socket: Socket,
449 bufs: &mut [crate::MaybeUninitSlice<'_>],
450 flags: c_int,
451 ) -> io::Result<(usize, RecvFlags)> {
452 let mut nread = 0;
453 let mut flags = flags as u32;
454 let res = syscall!(
455 WSARecv(
456 socket,
457 bufs.as_mut_ptr().cast(),
458 min(bufs.len(), u32::MAX as usize) as u32,
459 &mut nread,
460 &mut flags,
461 ptr::null_mut(),
462 None,
463 ),
464 PartialEq::eq,
465 SOCKET_ERROR
466 );
467 match res {
468 Ok(_) => Ok((nread as usize, RecvFlags(0))),
469 Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok((0, RecvFlags(0))),
470 Err(ref err) if err.raw_os_error() == Some(WSAEMSGSIZE as i32) => {
471 Ok((nread as usize, RecvFlags(MSG_TRUNC)))
472 }
473 Err(err) => Err(err),
474 }
475 }
476
recv_from( socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int, ) -> io::Result<(usize, SockAddr)>477 pub(crate) fn recv_from(
478 socket: Socket,
479 buf: &mut [MaybeUninit<u8>],
480 flags: c_int,
481 ) -> io::Result<(usize, SockAddr)> {
482 // Safety: `recvfrom` initialises the `SockAddr` for us.
483 unsafe {
484 SockAddr::try_init(|storage, addrlen| {
485 let res = syscall!(
486 recvfrom(
487 socket,
488 buf.as_mut_ptr().cast(),
489 min(buf.len(), MAX_BUF_LEN) as c_int,
490 flags,
491 storage.cast(),
492 addrlen,
493 ),
494 PartialEq::eq,
495 SOCKET_ERROR
496 );
497 match res {
498 Ok(n) => Ok(n as usize),
499 Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok(0),
500 Err(err) => Err(err),
501 }
502 })
503 }
504 }
505
peek_sender(socket: Socket) -> io::Result<SockAddr>506 pub(crate) fn peek_sender(socket: Socket) -> io::Result<SockAddr> {
507 // Safety: `recvfrom` initialises the `SockAddr` for us.
508 let ((), sender) = unsafe {
509 SockAddr::try_init(|storage, addrlen| {
510 let res = syscall!(
511 recvfrom(
512 socket,
513 // Windows *appears* not to care if you pass a null pointer.
514 ptr::null_mut(),
515 0,
516 MSG_PEEK,
517 storage.cast(),
518 addrlen,
519 ),
520 PartialEq::eq,
521 SOCKET_ERROR
522 );
523 match res {
524 Ok(_n) => Ok(()),
525 Err(e) => match e.raw_os_error() {
526 Some(code) if code == (WSAESHUTDOWN as i32) || code == (WSAEMSGSIZE as i32) => {
527 Ok(())
528 }
529 _ => Err(e),
530 },
531 }
532 })
533 }?;
534
535 Ok(sender)
536 }
537
recv_from_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags, SockAddr)>538 pub(crate) fn recv_from_vectored(
539 socket: Socket,
540 bufs: &mut [crate::MaybeUninitSlice<'_>],
541 flags: c_int,
542 ) -> io::Result<(usize, RecvFlags, SockAddr)> {
543 // Safety: `recvfrom` initialises the `SockAddr` for us.
544 unsafe {
545 SockAddr::try_init(|storage, addrlen| {
546 let mut nread = 0;
547 let mut flags = flags as u32;
548 let res = syscall!(
549 WSARecvFrom(
550 socket,
551 bufs.as_mut_ptr().cast(),
552 min(bufs.len(), u32::MAX as usize) as u32,
553 &mut nread,
554 &mut flags,
555 storage.cast(),
556 addrlen,
557 ptr::null_mut(),
558 None,
559 ),
560 PartialEq::eq,
561 SOCKET_ERROR
562 );
563 match res {
564 Ok(_) => Ok((nread as usize, RecvFlags(0))),
565 Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => {
566 Ok((nread as usize, RecvFlags(0)))
567 }
568 Err(ref err) if err.raw_os_error() == Some(WSAEMSGSIZE as i32) => {
569 Ok((nread as usize, RecvFlags(MSG_TRUNC)))
570 }
571 Err(err) => Err(err),
572 }
573 })
574 }
575 .map(|((n, recv_flags), addr)| (n, recv_flags, addr))
576 }
577
send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize>578 pub(crate) fn send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
579 syscall!(
580 send(
581 socket,
582 buf.as_ptr().cast(),
583 min(buf.len(), MAX_BUF_LEN) as c_int,
584 flags,
585 ),
586 PartialEq::eq,
587 SOCKET_ERROR
588 )
589 .map(|n| n as usize)
590 }
591
send_vectored( socket: Socket, bufs: &[IoSlice<'_>], flags: c_int, ) -> io::Result<usize>592 pub(crate) fn send_vectored(
593 socket: Socket,
594 bufs: &[IoSlice<'_>],
595 flags: c_int,
596 ) -> io::Result<usize> {
597 let mut nsent = 0;
598 syscall!(
599 WSASend(
600 socket,
601 // FIXME: From the `WSASend` docs [1]:
602 // > For a Winsock application, once the WSASend function is called,
603 // > the system owns these buffers and the application may not
604 // > access them.
605 //
606 // So what we're doing is actually UB as `bufs` needs to be `&mut
607 // [IoSlice<'_>]`.
608 //
609 // Tracking issue: https://github.com/rust-lang/socket2-rs/issues/129.
610 //
611 // NOTE: `send_to_vectored` has the same problem.
612 //
613 // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend
614 bufs.as_ptr() as *mut _,
615 min(bufs.len(), u32::MAX as usize) as u32,
616 &mut nsent,
617 flags as u32,
618 std::ptr::null_mut(),
619 None,
620 ),
621 PartialEq::eq,
622 SOCKET_ERROR
623 )
624 .map(|_| nsent as usize)
625 }
626
send_to( socket: Socket, buf: &[u8], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>627 pub(crate) fn send_to(
628 socket: Socket,
629 buf: &[u8],
630 addr: &SockAddr,
631 flags: c_int,
632 ) -> io::Result<usize> {
633 syscall!(
634 sendto(
635 socket,
636 buf.as_ptr().cast(),
637 min(buf.len(), MAX_BUF_LEN) as c_int,
638 flags,
639 addr.as_ptr(),
640 addr.len(),
641 ),
642 PartialEq::eq,
643 SOCKET_ERROR
644 )
645 .map(|n| n as usize)
646 }
647
send_to_vectored( socket: Socket, bufs: &[IoSlice<'_>], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>648 pub(crate) fn send_to_vectored(
649 socket: Socket,
650 bufs: &[IoSlice<'_>],
651 addr: &SockAddr,
652 flags: c_int,
653 ) -> io::Result<usize> {
654 let mut nsent = 0;
655 syscall!(
656 WSASendTo(
657 socket,
658 // FIXME: Same problem as in `send_vectored`.
659 bufs.as_ptr() as *mut _,
660 bufs.len().min(u32::MAX as usize) as u32,
661 &mut nsent,
662 flags as u32,
663 addr.as_ptr(),
664 addr.len(),
665 ptr::null_mut(),
666 None,
667 ),
668 PartialEq::eq,
669 SOCKET_ERROR
670 )
671 .map(|_| nsent as usize)
672 }
673
sendmsg(socket: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io::Result<usize>674 pub(crate) fn sendmsg(socket: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io::Result<usize> {
675 let mut nsent = 0;
676 syscall!(
677 WSASendMsg(
678 socket,
679 &msg.inner,
680 flags as u32,
681 &mut nsent,
682 ptr::null_mut(),
683 None,
684 ),
685 PartialEq::eq,
686 SOCKET_ERROR
687 )
688 .map(|_| nsent as usize)
689 }
690
691 /// Wrapper around `getsockopt` to deal with platform specific timeouts.
timeout_opt(fd: Socket, lvl: c_int, name: i32) -> io::Result<Option<Duration>>692 pub(crate) fn timeout_opt(fd: Socket, lvl: c_int, name: i32) -> io::Result<Option<Duration>> {
693 unsafe { getsockopt(fd, lvl, name).map(from_ms) }
694 }
695
/// Converts a millisecond timeout into a `Duration`, treating `0` as "no
/// timeout" (`None`).
fn from_ms(duration: u32) -> Option<Duration> {
    match duration {
        0 => None,
        millis => Some(Duration::from_millis(u64::from(millis))),
    }
}
705
706 /// Wrapper around `setsockopt` to deal with platform specific timeouts.
set_timeout_opt( socket: Socket, level: c_int, optname: i32, duration: Option<Duration>, ) -> io::Result<()>707 pub(crate) fn set_timeout_opt(
708 socket: Socket,
709 level: c_int,
710 optname: i32,
711 duration: Option<Duration>,
712 ) -> io::Result<()> {
713 let duration = into_ms(duration);
714 unsafe { setsockopt(socket, level, optname, duration) }
715 }
716
into_ms(duration: Option<Duration>) -> u32717 fn into_ms(duration: Option<Duration>) -> u32 {
718 // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the
719 // timeouts in windows APIs are typically u32 milliseconds. To translate, we
720 // have two pieces to take care of:
721 //
722 // * Nanosecond precision is rounded up
723 // * Greater than u32::MAX milliseconds (50 days) is rounded up to
724 // INFINITE (never time out).
725 duration.map_or(0, |duration| {
726 min(duration.as_millis(), INFINITE as u128) as u32
727 })
728 }
729
set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()>730 pub(crate) fn set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()> {
731 let mut keepalive = tcp_keepalive {
732 onoff: 1,
733 keepalivetime: into_ms(keepalive.time),
734 keepaliveinterval: into_ms(keepalive.interval),
735 };
736 let mut out = 0;
737 syscall!(
738 WSAIoctl(
739 socket,
740 SIO_KEEPALIVE_VALS,
741 &mut keepalive as *mut _ as *mut _,
742 size_of::<tcp_keepalive>() as _,
743 ptr::null_mut(),
744 0,
745 &mut out,
746 ptr::null_mut(),
747 None,
748 ),
749 PartialEq::eq,
750 SOCKET_ERROR
751 )
752 .map(|_| ())
753 }
754
755 /// Caller must ensure `T` is the correct type for `level` and `optname`.
756 // NOTE: `optname` is actually `i32`, but all constants are `u32`.
getsockopt<T>(socket: Socket, level: c_int, optname: i32) -> io::Result<T>757 pub(crate) unsafe fn getsockopt<T>(socket: Socket, level: c_int, optname: i32) -> io::Result<T> {
758 let mut optval: MaybeUninit<T> = MaybeUninit::uninit();
759 let mut optlen = mem::size_of::<T>() as c_int;
760 syscall!(
761 getsockopt(
762 socket,
763 level as i32,
764 optname,
765 optval.as_mut_ptr().cast(),
766 &mut optlen,
767 ),
768 PartialEq::eq,
769 SOCKET_ERROR
770 )
771 .map(|_| {
772 debug_assert_eq!(optlen as usize, mem::size_of::<T>());
773 // Safety: `getsockopt` initialised `optval` for us.
774 optval.assume_init()
775 })
776 }
777
778 /// Caller must ensure `T` is the correct type for `level` and `optname`.
779 // NOTE: `optname` is actually `i32`, but all constants are `u32`.
setsockopt<T>( socket: Socket, level: c_int, optname: i32, optval: T, ) -> io::Result<()>780 pub(crate) unsafe fn setsockopt<T>(
781 socket: Socket,
782 level: c_int,
783 optname: i32,
784 optval: T,
785 ) -> io::Result<()> {
786 syscall!(
787 setsockopt(
788 socket,
789 level as i32,
790 optname,
791 (&optval as *const T).cast(),
792 mem::size_of::<T>() as c_int,
793 ),
794 PartialEq::eq,
795 SOCKET_ERROR
796 )
797 .map(|_| ())
798 }
799
ioctlsocket(socket: Socket, cmd: i32, payload: &mut u32) -> io::Result<()>800 fn ioctlsocket(socket: Socket, cmd: i32, payload: &mut u32) -> io::Result<()> {
801 syscall!(
802 ioctlsocket(socket, cmd, payload),
803 PartialEq::eq,
804 SOCKET_ERROR
805 )
806 .map(|_| ())
807 }
808
to_in_addr(addr: &Ipv4Addr) -> IN_ADDR809 pub(crate) fn to_in_addr(addr: &Ipv4Addr) -> IN_ADDR {
810 IN_ADDR {
811 S_un: IN_ADDR_0 {
812 // `S_un` is stored as BE on all machines, and the array is in BE
813 // order. So the native endian conversion method is used so that
814 // it's never swapped.
815 S_addr: u32::from_ne_bytes(addr.octets()),
816 },
817 }
818 }
819
from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr820 pub(crate) fn from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr {
821 Ipv4Addr::from(unsafe { in_addr.S_un.S_addr }.to_ne_bytes())
822 }
823
to_in6_addr(addr: &Ipv6Addr) -> IN6_ADDR824 pub(crate) fn to_in6_addr(addr: &Ipv6Addr) -> IN6_ADDR {
825 IN6_ADDR {
826 u: IN6_ADDR_0 {
827 Byte: addr.octets(),
828 },
829 }
830 }
831
from_in6_addr(addr: IN6_ADDR) -> Ipv6Addr832 pub(crate) fn from_in6_addr(addr: IN6_ADDR) -> Ipv6Addr {
833 Ipv6Addr::from(unsafe { addr.u.Byte })
834 }
835
to_mreqn( multiaddr: &Ipv4Addr, interface: &crate::socket::InterfaceIndexOrAddress, ) -> IpMreq836 pub(crate) fn to_mreqn(
837 multiaddr: &Ipv4Addr,
838 interface: &crate::socket::InterfaceIndexOrAddress,
839 ) -> IpMreq {
840 IpMreq {
841 imr_multiaddr: to_in_addr(multiaddr),
842 // Per https://docs.microsoft.com/en-us/windows/win32/api/ws2ipdef/ns-ws2ipdef-ip_mreq#members:
843 //
844 // imr_interface
845 //
846 // The local IPv4 address of the interface or the interface index on
847 // which the multicast group should be joined or dropped. This value is
848 // in network byte order. If this member specifies an IPv4 address of
849 // 0.0.0.0, the default IPv4 multicast interface is used.
850 //
851 // To use an interface index of 1 would be the same as an IP address of
852 // 0.0.0.1.
853 imr_interface: match interface {
854 crate::socket::InterfaceIndexOrAddress::Index(interface) => {
855 to_in_addr(&(*interface).into())
856 }
857 crate::socket::InterfaceIndexOrAddress::Address(interface) => to_in_addr(interface),
858 },
859 }
860 }
861
/// Returns the original (pre-redirect) IPv4 destination of `socket`, read
/// with `getsockopt(SOL_IP, SO_ORIGINAL_DST)`.
#[cfg(feature = "all")]
pub(crate) fn original_dst(socket: Socket) -> io::Result<SockAddr> {
    // SAFETY: on success `getsockopt` wrote the address into `storage` and
    // its length into `len`.
    let (_, destination) = unsafe {
        SockAddr::try_init(|storage, len| {
            syscall!(
                getsockopt(socket, SOL_IP as i32, SO_ORIGINAL_DST as i32, storage.cast(), len),
                PartialEq::eq,
                SOCKET_ERROR
            )
        })
    }?;
    Ok(destination)
}

/// Returns the original (pre-redirect) IPv6 destination of `socket`, read
/// with `getsockopt(SOL_IP, IP6T_SO_ORIGINAL_DST)`.
// NOTE(review): the IPv6 variant also queries at level `SOL_IP`; confirm this
// against the Windows documentation for `IP6T_SO_ORIGINAL_DST`.
#[cfg(feature = "all")]
pub(crate) fn original_dst_ipv6(socket: Socket) -> io::Result<SockAddr> {
    // SAFETY: on success `getsockopt` wrote the address into `storage` and
    // its length into `len`.
    let (_, destination) = unsafe {
        SockAddr::try_init(|storage, len| {
            syscall!(
                getsockopt(
                    socket,
                    SOL_IP as i32,
                    IP6T_SO_ORIGINAL_DST as i32,
                    storage.cast(),
                    len,
                ),
                PartialEq::eq,
                SOCKET_ERROR
            )
        })
    }?;
    Ok(destination)
}
901
/// Builds an `AF_UNIX` [`SockAddr`] for `path`.
///
/// # Errors
///
/// Returns `io::ErrorKind::InvalidInput` if `path` is not valid UTF-8, or if
/// it does not fit into `sun_path` together with a terminating NUL.
// NOTE(review): this function is not `unsafe`, so this `allow` appears to
// have no effect.
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
    // SAFETY: a `sockaddr_storage` of all zeros is valid.
    let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
    let len = {
        // View the (larger) zeroed storage as a `SOCKADDR_UN`; this `storage`
        // shadows the outer one until the end of the block.
        let storage: &mut windows_sys::Win32::Networking::WinSock::SOCKADDR_UN =
            unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() };

        // Windows expects a UTF-8 path here even though Windows paths are
        // usually UCS-2 encoded. If Rust exposed OsStr's Wtf8 encoded
        // buffer, this could be used directly, relying on Windows to
        // validate the path, but Rust hides this implementation detail.
        //
        // See <https://github.com/rust-lang/rust/pull/95290>.
        let bytes = path
            .to_str()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "path must be valid UTF-8"))?
            .as_bytes();

        // Windows appears to allow non-null-terminated paths, but this is
        // not documented, so do not rely on it yet.
        //
        // See <https://github.com/rust-lang/socket2/issues/331>.
        if bytes.len() >= storage.sun_path.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "path must be shorter than SUN_LEN",
            ));
        }

        storage.sun_family = crate::sys::AF_UNIX as sa_family_t;
        // `storage` was initialized to zero above, so the path is
        // already null terminated.
        storage.sun_path[..bytes.len()].copy_from_slice(bytes);

        // Byte offset of `sun_path` within `SOCKADDR_UN`, taken from the
        // addresses of the struct and the field.
        let base = storage as *const _ as usize;
        let path = &storage.sun_path as *const _ as usize;
        let sun_path_offset = path - base;
        // Header up to `sun_path`, the path bytes, and the NUL terminator.
        sun_path_offset + bytes.len() + 1
    };
    Ok(unsafe { SockAddr::new(storage, len as socklen_t) })
}
944
945 /// Windows only API.
946 impl crate::Socket {
947 /// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`.
948 #[cfg(feature = "all")]
949 #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
set_no_inherit(&self, no_inherit: bool) -> io::Result<()>950 pub fn set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
951 self._set_no_inherit(no_inherit)
952 }
953
_set_no_inherit(&self, no_inherit: bool) -> io::Result<()>954 pub(crate) fn _set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
955 // NOTE: can't use `syscall!` because it expects the function in the
956 // `windows_sys::Win32::Networking::WinSock::` path.
957 let res = unsafe {
958 SetHandleInformation(
959 self.as_raw() as HANDLE,
960 HANDLE_FLAG_INHERIT,
961 !no_inherit as _,
962 )
963 };
964 if res == 0 {
965 // Zero means error.
966 Err(io::Error::last_os_error())
967 } else {
968 Ok(())
969 }
970 }
971
972 /// Returns the [`Protocol`] of this socket by checking the `SO_PROTOCOL_INFOW`
973 /// option on this socket.
974 ///
975 /// [`Protocol`]: crate::Protocol
976 #[cfg(feature = "all")]
protocol(&self) -> io::Result<Option<crate::Protocol>>977 pub fn protocol(&self) -> io::Result<Option<crate::Protocol>> {
978 let info = unsafe {
979 getsockopt::<WSAPROTOCOL_INFOW>(self.as_raw(), SOL_SOCKET, SO_PROTOCOL_INFOW)?
980 };
981 match info.iProtocol {
982 0 => Ok(None),
983 p => Ok(Some(crate::Protocol::from(p))),
984 }
985 }
986 }
987
988 #[cfg_attr(docsrs, doc(cfg(windows)))]
989 impl AsSocket for crate::Socket {
as_socket(&self) -> BorrowedSocket<'_>990 fn as_socket(&self) -> BorrowedSocket<'_> {
991 // SAFETY: lifetime is bound by self.
992 unsafe { BorrowedSocket::borrow_raw(self.as_raw() as RawSocket) }
993 }
994 }
995
996 #[cfg_attr(docsrs, doc(cfg(windows)))]
997 impl AsRawSocket for crate::Socket {
as_raw_socket(&self) -> RawSocket998 fn as_raw_socket(&self) -> RawSocket {
999 self.as_raw() as RawSocket
1000 }
1001 }
1002
1003 #[cfg_attr(docsrs, doc(cfg(windows)))]
1004 impl From<crate::Socket> for OwnedSocket {
from(sock: crate::Socket) -> OwnedSocket1005 fn from(sock: crate::Socket) -> OwnedSocket {
1006 // SAFETY: sock.into_raw() always returns a valid fd.
1007 unsafe { OwnedSocket::from_raw_socket(sock.into_raw() as RawSocket) }
1008 }
1009 }
1010
1011 #[cfg_attr(docsrs, doc(cfg(windows)))]
1012 impl IntoRawSocket for crate::Socket {
into_raw_socket(self) -> RawSocket1013 fn into_raw_socket(self) -> RawSocket {
1014 self.into_raw() as RawSocket
1015 }
1016 }
1017
1018 #[cfg_attr(docsrs, doc(cfg(windows)))]
1019 impl From<OwnedSocket> for crate::Socket {
from(fd: OwnedSocket) -> crate::Socket1020 fn from(fd: OwnedSocket) -> crate::Socket {
1021 // SAFETY: `OwnedFd` ensures the fd is valid.
1022 unsafe { crate::Socket::from_raw_socket(fd.into_raw_socket()) }
1023 }
1024 }
1025
1026 #[cfg_attr(docsrs, doc(cfg(windows)))]
1027 impl FromRawSocket for crate::Socket {
from_raw_socket(socket: RawSocket) -> crate::Socket1028 unsafe fn from_raw_socket(socket: RawSocket) -> crate::Socket {
1029 crate::Socket::from_raw(socket as Socket)
1030 }
1031 }
1032
/// `to_in_addr`/`from_in_addr` must keep the octets in network order in
/// memory. The expected `S_addr` values are the bytes read as a
/// little-endian `u32` (assumes a little-endian target).
#[test]
fn in_addr_conversion() {
    let ip = Ipv4Addr::new(127, 0, 0, 1);
    let raw = to_in_addr(&ip);
    assert_eq!(unsafe { raw.S_un.S_addr }, 127 | 1 << 24);
    assert_eq!(from_in_addr(raw), ip);

    let ip = Ipv4Addr::new(127, 34, 4, 12);
    let raw = to_in_addr(&ip);
    assert_eq!(
        unsafe { raw.S_un.S_addr },
        127 | 34 << 8 | 4 << 16 | 12 << 24
    );
    assert_eq!(from_in_addr(raw), ip);
}

/// `to_in6_addr`/`from_in6_addr` must round-trip, with every 16-bit word
/// stored in network (big endian) byte order.
#[test]
fn in6_addr_conversion() {
    let ip = Ipv6Addr::new(0x2000, 1, 2, 3, 4, 5, 6, 7);
    let raw = to_in6_addr(&ip);
    let want = [
        0x2000u16.to_be(),
        1u16.to_be(),
        2u16.to_be(),
        3u16.to_be(),
        4u16.to_be(),
        5u16.to_be(),
        6u16.to_be(),
        7u16.to_be(),
    ];
    assert_eq!(unsafe { raw.u.Word }, want);
    assert_eq!(from_in6_addr(raw), ip);
}
1066