• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::cmp::Ordering;
6 use std::convert::TryFrom;
7 use std::ffi::OsString;
8 use std::fs::remove_file;
9 use std::io;
10 use std::mem;
11 use std::mem::size_of;
12 use std::net::Ipv4Addr;
13 use std::net::Ipv6Addr;
14 use std::net::SocketAddr;
15 use std::net::SocketAddrV4;
16 use std::net::SocketAddrV6;
17 use std::net::TcpListener;
18 use std::net::TcpStream;
19 use std::net::ToSocketAddrs;
20 use std::ops::Deref;
21 use std::os::unix::ffi::OsStringExt;
22 use std::path::Path;
23 use std::path::PathBuf;
24 use std::ptr::null_mut;
25 use std::time::Duration;
26 use std::time::Instant;
27 
28 use libc::c_int;
29 use libc::recvfrom;
30 use libc::sa_family_t;
31 use libc::sockaddr;
32 use libc::sockaddr_in;
33 use libc::sockaddr_in6;
34 use libc::socklen_t;
35 use libc::AF_INET;
36 use libc::AF_INET6;
37 use libc::MSG_PEEK;
38 use libc::MSG_TRUNC;
39 use log::warn;
40 use serde::Deserialize;
41 use serde::Serialize;
42 
43 use crate::descriptor::AsRawDescriptor;
44 use crate::descriptor::FromRawDescriptor;
45 use crate::descriptor::IntoRawDescriptor;
46 use crate::handle_eintr_errno;
47 use crate::sys::sockaddr_un;
48 use crate::sys::sockaddrv4_to_lib_c;
49 use crate::sys::sockaddrv6_to_lib_c;
50 use crate::Error;
51 use crate::RawDescriptor;
52 use crate::SafeDescriptor;
53 
/// IP protocol version of a socket address, used to pick between the IPv4 and IPv6 flavours of
/// the libc socket structures.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
    V4,
    V6,
}

impl InetVersion {
    /// Returns the IP version that the address `s` belongs to.
    pub fn from_sockaddr(s: &SocketAddr) -> Self {
        if s.is_ipv4() {
            InetVersion::V4
        } else {
            InetVersion::V6
        }
    }
}
69 
70 impl From<InetVersion> for sa_family_t {
from(v: InetVersion) -> sa_family_t71     fn from(v: InetVersion) -> sa_family_t {
72         match v {
73             InetVersion::V4 => AF_INET as sa_family_t,
74             InetVersion::V6 => AF_INET6 as sa_family_t,
75         }
76     }
77 }
78 
socket( domain: c_int, sock_type: c_int, protocol: c_int, ) -> io::Result<SafeDescriptor>79 pub(in crate::sys) fn socket(
80     domain: c_int,
81     sock_type: c_int,
82     protocol: c_int,
83 ) -> io::Result<SafeDescriptor> {
84     // SAFETY:
85     // Safe socket initialization since we handle the returned error.
86     match unsafe { libc::socket(domain, sock_type, protocol) } {
87         -1 => Err(io::Error::last_os_error()),
88         // SAFETY:
89         // Safe because we own the file descriptor.
90         fd => Ok(unsafe { SafeDescriptor::from_raw_descriptor(fd) }),
91     }
92 }
93 
socketpair( domain: c_int, sock_type: c_int, protocol: c_int, ) -> io::Result<(SafeDescriptor, SafeDescriptor)>94 pub(in crate::sys) fn socketpair(
95     domain: c_int,
96     sock_type: c_int,
97     protocol: c_int,
98 ) -> io::Result<(SafeDescriptor, SafeDescriptor)> {
99     let mut fds = [0, 0];
100     // SAFETY:
101     // Safe because we give enough space to store all the fds and we check the return value.
102     match unsafe { libc::socketpair(domain, sock_type, protocol, fds.as_mut_ptr()) } {
103         -1 => Err(io::Error::last_os_error()),
104         _ => Ok(
105             // SAFETY:
106             // Safe because we own the file descriptors.
107             unsafe {
108                 (
109                     SafeDescriptor::from_raw_descriptor(fds[0]),
110                     SafeDescriptor::from_raw_descriptor(fds[1]),
111                 )
112             },
113         ),
114     }
115 }
116 
117 /// A TCP socket.
118 ///
119 /// Do not use this class unless you need to change socket options or query the
120 /// state of the socket prior to calling listen or connect. Instead use either TcpStream or
121 /// TcpListener.
122 #[derive(Debug)]
123 pub struct TcpSocket {
124     pub(in crate::sys) inet_version: InetVersion,
125     pub(in crate::sys) descriptor: SafeDescriptor,
126 }
127 
128 impl TcpSocket {
bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()>129     pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
130         let sockaddr = addr
131             .to_socket_addrs()
132             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
133             .next()
134             .unwrap();
135 
136         let ret = match sockaddr {
137             SocketAddr::V4(a) => {
138                 let sin = sockaddrv4_to_lib_c(&a);
139                 // SAFETY:
140                 // Safe because this doesn't modify any memory and we check the return value.
141                 unsafe {
142                     libc::bind(
143                         self.as_raw_descriptor(),
144                         &sin as *const sockaddr_in as *const sockaddr,
145                         size_of::<sockaddr_in>() as socklen_t,
146                     )
147                 }
148             }
149             SocketAddr::V6(a) => {
150                 let sin6 = sockaddrv6_to_lib_c(&a);
151                 // SAFETY:
152                 // Safe because this doesn't modify any memory and we check the return value.
153                 unsafe {
154                     libc::bind(
155                         self.as_raw_descriptor(),
156                         &sin6 as *const sockaddr_in6 as *const sockaddr,
157                         size_of::<sockaddr_in6>() as socklen_t,
158                     )
159                 }
160             }
161         };
162         if ret < 0 {
163             let bind_err = io::Error::last_os_error();
164             Err(bind_err)
165         } else {
166             Ok(())
167         }
168     }
169 
connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream>170     pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
171         let sockaddr = addr
172             .to_socket_addrs()
173             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
174             .next()
175             .unwrap();
176 
177         let ret = match sockaddr {
178             SocketAddr::V4(a) => {
179                 let sin = sockaddrv4_to_lib_c(&a);
180                 // SAFETY:
181                 // Safe because this doesn't modify any memory and we check the return value.
182                 unsafe {
183                     libc::connect(
184                         self.as_raw_descriptor(),
185                         &sin as *const sockaddr_in as *const sockaddr,
186                         size_of::<sockaddr_in>() as socklen_t,
187                     )
188                 }
189             }
190             SocketAddr::V6(a) => {
191                 let sin6 = sockaddrv6_to_lib_c(&a);
192                 // SAFETY:
193                 // Safe because this doesn't modify any memory and we check the return value.
194                 unsafe {
195                     libc::connect(
196                         self.as_raw_descriptor(),
197                         &sin6 as *const sockaddr_in6 as *const sockaddr,
198                         size_of::<sockaddr_in>() as socklen_t,
199                     )
200                 }
201             }
202         };
203 
204         if ret < 0 {
205             let connect_err = io::Error::last_os_error();
206             Err(connect_err)
207         } else {
208             Ok(TcpStream::from(self.descriptor))
209         }
210     }
211 
listen(self) -> io::Result<TcpListener>212     pub fn listen(self) -> io::Result<TcpListener> {
213         // SAFETY:
214         // Safe because this doesn't modify any memory and we check the return value.
215         let ret = unsafe { libc::listen(self.as_raw_descriptor(), 1) };
216         if ret < 0 {
217             let listen_err = io::Error::last_os_error();
218             Err(listen_err)
219         } else {
220             Ok(TcpListener::from(self.descriptor))
221         }
222     }
223 
224     /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u16>225     pub fn local_port(&self) -> io::Result<u16> {
226         match self.inet_version {
227             InetVersion::V4 => {
228                 let mut sin = sockaddrv4_to_lib_c(&SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0));
229 
230                 let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
231                 // SAFETY:
232                 // Safe because we give a valid pointer for addrlen and check the length.
233                 let ret = unsafe {
234                     // Get the socket address that was actually bound.
235                     libc::getsockname(
236                         self.as_raw_descriptor(),
237                         &mut sin as *mut sockaddr_in as *mut sockaddr,
238                         &mut addrlen as *mut socklen_t,
239                     )
240                 };
241                 if ret < 0 {
242                     let getsockname_err = io::Error::last_os_error();
243                     Err(getsockname_err)
244                 } else {
245                     // If this doesn't match, it's not safe to get the port out of the sockaddr.
246                     assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
247 
248                     Ok(u16::from_be(sin.sin_port))
249                 }
250             }
251             InetVersion::V6 => {
252                 let mut sin6 = sockaddrv6_to_lib_c(&SocketAddrV6::new(
253                     Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
254                     0,
255                     0,
256                     0,
257                 ));
258 
259                 let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
260                 // SAFETY:
261                 // Safe because we give a valid pointer for addrlen and check the length.
262                 let ret = unsafe {
263                     // Get the socket address that was actually bound.
264                     libc::getsockname(
265                         self.as_raw_descriptor(),
266                         &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
267                         &mut addrlen as *mut socklen_t,
268                     )
269                 };
270                 if ret < 0 {
271                     let getsockname_err = io::Error::last_os_error();
272                     Err(getsockname_err)
273                 } else {
274                     // If this doesn't match, it's not safe to get the port out of the sockaddr.
275                     assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
276 
277                     Ok(u16::from_be(sin6.sin6_port))
278                 }
279             }
280         }
281     }
282 }
283 
284 impl AsRawDescriptor for TcpSocket {
as_raw_descriptor(&self) -> RawDescriptor285     fn as_raw_descriptor(&self) -> RawDescriptor {
286         self.descriptor.as_raw_descriptor()
287     }
288 }
289 
290 // Offset of sun_path in structure sockaddr_un.
sun_path_offset() -> usize291 pub(in crate::sys) fn sun_path_offset() -> usize {
292     // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
293     #[allow(clippy::zero_ptr)]
294     let addr = 0 as *const libc::sockaddr_un;
295     // SAFETY:
296     // Safe because we only use the dereference to create a pointer to the desired field in
297     // calculating the offset.
298     unsafe { &(*addr).sun_path as *const _ as usize }
299 }
300 
301 /// A Unix `SOCK_SEQPACKET` socket point to given `path`
302 #[derive(Debug, Serialize, Deserialize)]
303 pub struct UnixSeqpacket(SafeDescriptor);
304 
305 impl UnixSeqpacket {
306     /// Open a `SOCK_SEQPACKET` connection to socket named by `path`.
307     ///
308     /// # Arguments
309     /// * `path` - Path to `SOCK_SEQPACKET` socket
310     ///
311     /// # Returns
312     /// A `UnixSeqpacket` structure point to the socket
313     ///
314     /// # Errors
315     /// Return `io::Error` when error occurs.
connect<P: AsRef<Path>>(path: P) -> io::Result<Self>316     pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
317         let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
318         let (addr, len) = sockaddr_un(path.as_ref())?;
319         // SAFETY:
320         // Safe connect since we handle the error and use the right length generated from
321         // `sockaddr_un`.
322         unsafe {
323             let ret = libc::connect(
324                 descriptor.as_raw_descriptor(),
325                 &addr as *const _ as *const _,
326                 len,
327             );
328             if ret < 0 {
329                 return Err(io::Error::last_os_error());
330             }
331         }
332         Ok(UnixSeqpacket(descriptor))
333     }
334 
335     /// Clone the underlying FD.
try_clone(&self) -> io::Result<Self>336     pub fn try_clone(&self) -> io::Result<Self> {
337         Ok(Self(self.0.try_clone()?))
338     }
339 
340     /// Gets the number of bytes that can be read from this socket without blocking.
get_readable_bytes(&self) -> io::Result<usize>341     pub fn get_readable_bytes(&self) -> io::Result<usize> {
342         let mut byte_count = 0i32;
343         // SAFETY:
344         // Safe because self has valid raw descriptor and return value are checked.
345         let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONREAD, &mut byte_count) };
346         if ret < 0 {
347             Err(io::Error::last_os_error())
348         } else {
349             Ok(byte_count as usize)
350         }
351     }
352 
353     /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
354     /// respecting the blocking and timeout settings of the underlying socket.
next_packet_size(&self) -> io::Result<usize>355     pub fn next_packet_size(&self) -> io::Result<usize> {
356         #[cfg(not(debug_assertions))]
357         let buf = null_mut();
358         // Work around for qemu's syscall translation which will reject null pointers in recvfrom.
359         // This only matters for running the unit tests for a non-native architecture. See the
360         // upstream thread for the qemu fix:
361         // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
362         #[cfg(debug_assertions)]
363         let buf = &mut 0 as *mut _ as *mut _;
364 
365         // SAFETY:
366         // This form of recvfrom doesn't modify any data because all null pointers are used. We only
367         // use the return value and check for errors on an FD owned by this structure.
368         let ret = unsafe {
369             recvfrom(
370                 self.as_raw_descriptor(),
371                 buf,
372                 0,
373                 MSG_TRUNC | MSG_PEEK,
374                 null_mut(),
375                 null_mut(),
376             )
377         };
378         if ret < 0 {
379             Err(io::Error::last_os_error())
380         } else {
381             Ok(ret as usize)
382         }
383     }
384 
385     /// Write data from a given buffer to the socket fd
386     ///
387     /// # Arguments
388     /// * `buf` - A reference to the data buffer.
389     ///
390     /// # Returns
391     /// * `usize` - The size of bytes written to the buffer.
392     ///
393     /// # Errors
394     /// Returns error when `libc::write` failed.
send(&self, buf: &[u8]) -> io::Result<usize>395     pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
396         // SAFETY:
397         // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
398         unsafe {
399             let ret = libc::write(
400                 self.as_raw_descriptor(),
401                 buf.as_ptr() as *const _,
402                 buf.len(),
403             );
404             if ret < 0 {
405                 Err(io::Error::last_os_error())
406             } else {
407                 Ok(ret as usize)
408             }
409         }
410     }
411 
412     /// Read data from the socket fd to a given buffer
413     ///
414     /// # Arguments
415     /// * `buf` - A mut reference to the data buffer.
416     ///
417     /// # Returns
418     /// * `usize` - The size of bytes read to the buffer.
419     ///
420     /// # Errors
421     /// Returns error when `libc::read` failed.
recv(&self, buf: &mut [u8]) -> io::Result<usize>422     pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
423         // SAFETY:
424         // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
425         unsafe {
426             let ret = libc::read(
427                 self.as_raw_descriptor(),
428                 buf.as_mut_ptr() as *mut _,
429                 buf.len(),
430             );
431             if ret < 0 {
432                 Err(io::Error::last_os_error())
433             } else {
434                 Ok(ret as usize)
435             }
436         }
437     }
438 
439     /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
440     ///
441     /// # Arguments
442     /// * `buf` - A mut reference to a `Vec` to resize and read into.
443     ///
444     /// # Errors
445     /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()>446     pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
447         let packet_size = self.next_packet_size()?;
448         buf.resize(packet_size, 0);
449         let read_bytes = self.recv(buf)?;
450         buf.resize(read_bytes, 0);
451         Ok(())
452     }
453 
454     /// Read data from the socket fd to a new `Vec`.
455     ///
456     /// # Returns
457     /// * `vec` - A new `Vec` with the entire received packet.
458     ///
459     /// # Errors
460     /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_as_vec(&self) -> io::Result<Vec<u8>>461     pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
462         let mut buf = Vec::new();
463         self.recv_to_vec(&mut buf)?;
464         Ok(buf)
465     }
466 
467     #[allow(clippy::useless_conversion)]
set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()>468     fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
469         let timeval = match timeout {
470             Some(t) => {
471                 if t.as_secs() == 0 && t.subsec_micros() == 0 {
472                     return Err(io::Error::new(
473                         io::ErrorKind::InvalidInput,
474                         "zero timeout duration is invalid",
475                     ));
476                 }
477                 // subsec_micros fits in i32 because it is defined to be less than one million.
478                 let nsec = t.subsec_micros() as i32;
479                 libc::timeval {
480                     tv_sec: t.as_secs() as libc::time_t,
481                     tv_usec: libc::suseconds_t::from(nsec),
482                 }
483             }
484             None => libc::timeval {
485                 tv_sec: 0,
486                 tv_usec: 0,
487             },
488         };
489         // SAFETY:
490         // Safe because we own the fd, and the length of the pointer's data is the same as the
491         // passed in length parameter. The level argument is valid, the kind is assumed to be valid,
492         // and the return value is checked.
493         let ret = unsafe {
494             libc::setsockopt(
495                 self.as_raw_descriptor(),
496                 libc::SOL_SOCKET,
497                 kind,
498                 &timeval as *const libc::timeval as *const libc::c_void,
499                 mem::size_of::<libc::timeval>() as libc::socklen_t,
500             )
501         };
502         if ret < 0 {
503             Err(io::Error::last_os_error())
504         } else {
505             Ok(())
506         }
507     }
508 
509     /// Sets or removes the timeout for read/recv operations on this socket.
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>510     pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
511         self.set_timeout(timeout, libc::SO_RCVTIMEO)
512     }
513 
514     /// Sets or removes the timeout for write/send operations on this socket.
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>515     pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
516         self.set_timeout(timeout, libc::SO_SNDTIMEO)
517     }
518 
519     /// Sets the blocking mode for this socket.
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>520     pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
521         let mut nonblocking = nonblocking as libc::c_int;
522         // SAFETY:
523         // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
524         // and does not continue holding the file descriptor after the call.
525         let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
526         if ret < 0 {
527             Err(io::Error::last_os_error())
528         } else {
529             Ok(())
530         }
531     }
532 }
533 
534 impl From<UnixSeqpacket> for SafeDescriptor {
from(s: UnixSeqpacket) -> Self535     fn from(s: UnixSeqpacket) -> Self {
536         s.0
537     }
538 }
539 
540 impl From<SafeDescriptor> for UnixSeqpacket {
from(s: SafeDescriptor) -> Self541     fn from(s: SafeDescriptor) -> Self {
542         Self(s)
543     }
544 }
545 
546 impl FromRawDescriptor for UnixSeqpacket {
from_raw_descriptor(descriptor: RawDescriptor) -> Self547     unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self {
548         Self(SafeDescriptor::from_raw_descriptor(descriptor))
549     }
550 }
551 
552 impl AsRawDescriptor for UnixSeqpacket {
as_raw_descriptor(&self) -> RawDescriptor553     fn as_raw_descriptor(&self) -> RawDescriptor {
554         self.0.as_raw_descriptor()
555     }
556 }
557 
558 impl IntoRawDescriptor for UnixSeqpacket {
into_raw_descriptor(self) -> RawDescriptor559     fn into_raw_descriptor(self) -> RawDescriptor {
560         self.0.into_raw_descriptor()
561     }
562 }
563 
564 impl io::Read for UnixSeqpacket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>565     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
566         self.recv(buf)
567     }
568 }
569 
570 impl io::Write for UnixSeqpacket {
write(&mut self, buf: &[u8]) -> io::Result<usize>571     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
572         self.send(buf)
573     }
574 
flush(&mut self) -> io::Result<()>575     fn flush(&mut self) -> io::Result<()> {
576         Ok(())
577     }
578 }
579 
580 /// Like a `UnixListener` but for accepting `UnixSeqpacket` type sockets.
581 pub struct UnixSeqpacketListener {
582     descriptor: SafeDescriptor,
583     no_path: bool,
584 }
585 
586 impl UnixSeqpacketListener {
587     /// Creates a new `UnixSeqpacketListener` bound to the given path.
bind<P: AsRef<Path>>(path: P) -> io::Result<Self>588     pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
589         if path.as_ref().starts_with("/proc/self/fd/") {
590             let fd = path
591                 .as_ref()
592                 .file_name()
593                 .expect("Failed to get fd filename")
594                 .to_str()
595                 .expect("fd filename should be unicode")
596                 .parse::<i32>()
597                 .expect("fd should be an integer");
598             let mut result: c_int = 0;
599             let mut result_len = size_of::<c_int>() as libc::socklen_t;
600             // SAFETY: Safe because fd and other args are valid and the return value is checked.
601             let ret = unsafe {
602                 libc::getsockopt(
603                     fd,
604                     libc::SOL_SOCKET,
605                     libc::SO_ACCEPTCONN,
606                     &mut result as *mut _ as *mut libc::c_void,
607                     &mut result_len,
608                 )
609             };
610             if ret < 0 {
611                 return Err(io::Error::last_os_error());
612             }
613             if result != 1 {
614                 return Err(io::Error::new(
615                     io::ErrorKind::InvalidInput,
616                     "specified descriptor is not a listening socket",
617                 ));
618             }
619             // SAFETY:
620             // Safe because we validated the socket file descriptor.
621             let descriptor = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
622             return Ok(UnixSeqpacketListener {
623                 descriptor,
624                 no_path: true,
625             });
626         }
627 
628         let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
629         let (addr, len) = sockaddr_un(path.as_ref())?;
630 
631         // SAFETY:
632         // Safe connect since we handle the error and use the right length generated from
633         // `sockaddr_un`.
634         unsafe {
635             let ret = handle_eintr_errno!(libc::bind(
636                 descriptor.as_raw_descriptor(),
637                 &addr as *const _ as *const _,
638                 len
639             ));
640             if ret < 0 {
641                 return Err(io::Error::last_os_error());
642             }
643             let ret = handle_eintr_errno!(libc::listen(descriptor.as_raw_descriptor(), 128));
644             if ret < 0 {
645                 return Err(io::Error::last_os_error());
646             }
647         }
648         Ok(UnixSeqpacketListener {
649             descriptor,
650             no_path: false,
651         })
652     }
653 
accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket>654     pub fn accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket> {
655         let start = Instant::now();
656 
657         loop {
658             let mut fds = libc::pollfd {
659                 fd: self.as_raw_descriptor(),
660                 events: libc::POLLIN,
661                 revents: 0,
662             };
663             let elapsed = Instant::now().saturating_duration_since(start);
664             let remaining = timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO);
665             let cur_timeout_ms = i32::try_from(remaining.as_millis()).unwrap_or(i32::MAX);
666             // SAFETY:
667             // Safe because we give a valid pointer to a list (of 1) FD and we check
668             // the return value.
669             match unsafe { libc::poll(&mut fds, 1, cur_timeout_ms) }.cmp(&0) {
670                 Ordering::Greater => return self.accept(),
671                 Ordering::Equal => return Err(io::Error::from_raw_os_error(libc::ETIMEDOUT)),
672                 Ordering::Less => {
673                     if Error::last() != Error::new(libc::EINTR) {
674                         return Err(io::Error::last_os_error());
675                     }
676                 }
677             }
678         }
679     }
680 
681     /// Gets the path that this listener is bound to.
path(&self) -> io::Result<PathBuf>682     pub fn path(&self) -> io::Result<PathBuf> {
683         let mut addr = sockaddr_un(Path::new(""))?.0;
684         if self.no_path {
685             return Err(io::Error::new(
686                 io::ErrorKind::InvalidInput,
687                 "socket has no path",
688             ));
689         }
690         let sun_path_offset = (&addr.sun_path as *const _ as usize
691             - &addr.sun_family as *const _ as usize)
692             as libc::socklen_t;
693         let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
694         // SAFETY:
695         // Safe because the length given matches the length of the data of the given pointer, and we
696         // check the return value.
697         let ret = unsafe {
698             handle_eintr_errno!(libc::getsockname(
699                 self.as_raw_descriptor(),
700                 &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
701                 &mut len
702             ))
703         };
704         if ret < 0 {
705             return Err(io::Error::last_os_error());
706         }
707         if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
708             || addr.sun_path[0] == 0
709             || len < 1 + sun_path_offset
710         {
711             return Err(io::Error::new(
712                 io::ErrorKind::InvalidInput,
713                 "getsockname on socket returned invalid value",
714             ));
715         }
716 
717         let path_os_str = OsString::from_vec(
718             addr.sun_path[..(len - sun_path_offset - 1) as usize]
719                 .iter()
720                 .map(|&c| c as _)
721                 .collect(),
722         );
723         Ok(path_os_str.into())
724     }
725 
726     /// Sets the blocking mode for this socket.
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>727     pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
728         let mut nonblocking = nonblocking as libc::c_int;
729         // SAFETY:
730         // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
731         // and does not continue holding the file descriptor after the call.
732         let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
733         if ret < 0 {
734             Err(io::Error::last_os_error())
735         } else {
736             Ok(())
737         }
738     }
739 }
740 
741 impl AsRawDescriptor for UnixSeqpacketListener {
as_raw_descriptor(&self) -> RawDescriptor742     fn as_raw_descriptor(&self) -> RawDescriptor {
743         self.descriptor.as_raw_descriptor()
744     }
745 }
746 
747 /// Used to attempt to clean up a `UnixSeqpacketListener` after it is dropped.
748 pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
749 
750 impl AsRawDescriptor for UnlinkUnixSeqpacketListener {
as_raw_descriptor(&self) -> RawDescriptor751     fn as_raw_descriptor(&self) -> RawDescriptor {
752         self.0.as_raw_descriptor()
753     }
754 }
755 
756 impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
as_ref(&self) -> &UnixSeqpacketListener757     fn as_ref(&self) -> &UnixSeqpacketListener {
758         &self.0
759     }
760 }
761 
762 impl Deref for UnlinkUnixSeqpacketListener {
763     type Target = UnixSeqpacketListener;
deref(&self) -> &Self::Target764     fn deref(&self) -> &Self::Target {
765         &self.0
766     }
767 }
768 
769 impl Drop for UnlinkUnixSeqpacketListener {
drop(&mut self)770     fn drop(&mut self) {
771         if let Ok(path) = self.0.path() {
772             if let Err(e) = remove_file(path) {
773                 warn!("failed to remove control socket file: {:?}", e);
774             }
775         }
776     }
777 }
778 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sockaddr_un_zero_length_input() {
        let _ = sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_long_input_err() {
        // 108 bytes leaves no room for the trailing NUL in `sun_path`.
        let too_long = "a".repeat(108);
        assert!(sockaddr_un(Path::new(&too_long)).is_err());
    }

    #[test]
    fn sockaddr_un_long_input_pass() {
        let longest = "a".repeat(107);
        let _ = sockaddr_un(Path::new(&longest)).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_len_check() {
        let path = "a".repeat(50);
        let (_, len) = sockaddr_un(Path::new(&path)).expect("sockaddr_un failed");
        // Header, then the 50 path bytes, then the trailing NUL.
        assert_eq!(len, (sun_path_offset() + 50 + 1) as u32);
    }

    #[test]
    #[allow(clippy::unnecessary_cast)]
    // c_char is u8 on aarch64 and i8 on x86, so clippy's suggested fix of changing
    // `'a' as libc::c_char` below to `b'a'` won't work everywhere.
    #[allow(clippy::char_lit_as_u8)]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let (addr, len) =
            sockaddr_un(Path::new(&"a".repeat(path_size))).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // `sun_path` must hold `path_size` 'a' characters followed by zeros.
        let expected: Vec<libc::c_char> = std::iter::repeat('a' as libc::c_char)
            .take(path_size)
            .chain(std::iter::repeat(0 as libc::c_char))
            .take(108)
            .collect();
        for (actual, want) in addr.sun_path.iter().zip(expected.iter()) {
            assert_eq!(actual, want);
        }
    }
}
828